NOTE(review): this patch was reflowed from a copy in which line breaks were
collapsed and every `<...>` token had been stripped. The items below were
restored by inference -- confirm against the repository before applying:
  * template arguments (from usage and from the 80-column hunk headers);
  * the five added #include targets (from the std names the patch introduces);
  * the argument of std::make_shared (either the const or non-const regex type
    would compile);
  * indentation, and the width of the whitespace-only `-` lines.
Not recoverable here: the targets of the context `#include <...>` lines, and
any runs of spaces inside string literals (for example
" leading and multiple spaces " very likely contained repeated spaces).

diff --git a/test/test_pre_tokenizer.cpp b/test/test_pre_tokenizer.cpp
index 257e9d82..69416f67 100644
--- a/test/test_pre_tokenizer.cpp
+++ b/test/test_pre_tokenizer.cpp
@@ -11,6 +11,10 @@
 #include <...>
 #include <...>
 
+#include <atomic>
+#include <thread>
+#include <vector>
+
 // Local
 #include <...>
 
@@ -388,6 +392,83 @@ TEST_F(PreTokenizerConfigTest, SplitWithUnsupportedBehavior) {
       std::runtime_error);
 }
 
+// Regex cache (unicode_regex_split) ///////////////////////////////////////////
+// The ByteLevel pre-tokenizer drives unicode_regex_split, which caches the
+// compiled std::regex/std::wregex per pattern. This test guards that cache:
+// (1) results are deterministic / behavior-neutral, including across the full
+//     Unicode White_Space class and near-miss non-whitespace codepoints, and
+// (2) it is thread-safe (the tokenizer pool calls it concurrently).
+class RegexCacheTest : public ::testing::Test {};
+
+TEST_F(RegexCacheTest, ByteLevelDeterministicAndThreadSafe) {
+  ByteLevelPreTokenizer ptok(/*add_prefix_space=*/false);
+
+  const std::vector<std::string> corpus = {
+      "Hello World",
+      " leading and multiple spaces ",
+      "tabs\t\tand\nnewlines\r\n",
+      "code: { return x; } // trailing ",
+      // Unicode White_Space codepoints (must be treated as whitespace).
+      "a\xC2\x85"
+      "b", // U+0085 NEL
+      "a\x0B"
+      "b\x0C"
+      "c", // U+000B VT, U+000C FF
+      "a\xE1\x9A\x80"
+      "b", // U+1680 OGHAM SPACE MARK
+      "a\xE2\x80\x80\xE2\x80\x8A"
+      "b", // U+2000 .. U+200A
+      "a\xE2\x80\xAF"
+      "b\xE2\x81\x9F"
+      "c", // U+202F NNBSP, U+205F MMSP
+      "a\xE3\x80\x80"
+      "b", // U+3000 ideographic space
+      // Near-miss NON-whitespace codepoints (must NOT be treated as ws).
+      "a\xE2\x80\x8B"
+      "b", // U+200B ZERO WIDTH SPACE
+      "a\xE1\xA0\x8E"
+      "b", // U+180E MONGOLIAN VOWEL SEPARATOR
+      "a\xEF\xBB\xBF"
+      "b", // U+FEFF BOM / ZWNBSP
+      // Unicode letters and a representative SID prompt fragment.
+      "caf\xC3\xA9 \xE4\xBD\xA0\xE5\xA5\xBD",
+      "history: i0:<1326-617-1617> i1:<197-296-385> next:",
+  };
+
+  // Single-threaded reference.
+  std::vector<std::vector<std::string>> ref;
+  ref.reserve(corpus.size());
+  for (const auto& s : corpus) {
+    ref.push_back(ptok.pre_tokenize(s));
+    EXPECT_FALSE(ref.back().empty());
+  }
+
+  // Hammer the shared pre-tokenizer (and the static regex cache inside
+  // unicode_regex_split) from many threads; every result must match the
+  // reference. A recompile race or a wrong cache lookup would surface as a
+  // mismatch (or a crash under TSAN).
+  constexpr int kThreads = 16;
+  constexpr int kIters = 200;
+  std::atomic<int> mismatches{0};
+  std::vector<std::thread> threads;
+  threads.reserve(kThreads);
+  for (int t = 0; t < kThreads; ++t) {
+    threads.emplace_back([&]() {
+      for (int iter = 0; iter < kIters; ++iter) {
+        for (size_t i = 0; i < corpus.size(); ++i) {
+          if (ptok.pre_tokenize(corpus[i]) != ref[i]) {
+            mismatches.fetch_add(1, std::memory_order_relaxed);
+          }
+        }
+      }
+    });
+  }
+  for (auto& th : threads) {
+    th.join();
+  }
+  EXPECT_EQ(0, mismatches.load());
+}
+
 TEST_F(PreTokenizerConfigTest, SplitWithInvertTrue) {
   PreTokenizerConfig config;
   EXPECT_THROW(
diff --git a/third-party/llama.cpp-unicode/src/unicode.cpp b/third-party/llama.cpp-unicode/src/unicode.cpp
index 75f44ec5..169b6fa3 100644
--- a/third-party/llama.cpp-unicode/src/unicode.cpp
+++ b/third-party/llama.cpp-unicode/src/unicode.cpp
@@ -42,6 +42,8 @@ SOFTWARE.
 #include <...>
 #include <...>
 #include <...>
+#include <memory>
+#include <mutex>
 #include <...>
 #include <...>
 #include <...>
@@ -562,7 +564,25 @@ static std::vector<size_t> unicode_regex_split_stl(
     const std::wstring& wtext,
     const std::wstring& regex_expr,
     const std::vector<size_t>& offsets) {
-  std::wregex expr(regex_expr);
+  // Cache the compiled regex per pattern and reuse it. Matching on a const
+  // std::wregex from multiple threads is safe; the cache itself is guarded
+  // by a mutex. The pattern set is tiny and fixed (a handful of tokenizer
+  // patterns), so the cache does not grow unbounded.
+  static std::mutex s_mutex;
+  static std::unordered_map<std::wstring, std::shared_ptr<const std::wregex>>
+      s_cache;
+  std::shared_ptr<const std::wregex> expr_ptr;
+  {
+    std::lock_guard<std::mutex> lock(s_mutex);
+    auto it = s_cache.find(regex_expr);
+    if (it == s_cache.end()) {
+      expr_ptr = std::make_shared<const std::wregex>(regex_expr);
+      s_cache.emplace(regex_expr, expr_ptr);
+    } else {
+      expr_ptr = it->second;
+    }
+  }
+  const std::wregex& expr = *expr_ptr;
   std::vector<size_t> bpe_offsets; // store the offset of each word
   bpe_offsets.reserve(
       offsets.size()); // Reserve memory for the approximate size
@@ -597,7 +617,21 @@ static std::vector<size_t> unicode_regex_split_stl(
     const std::string& text,
     const std::string& regex_expr,
     const std::vector<size_t>& offsets) {
-  std::regex expr(regex_expr);
+  static std::mutex s_mutex;
+  static std::unordered_map<std::string, std::shared_ptr<const std::regex>>
+      s_cache;
+  std::shared_ptr<const std::regex> expr_ptr;
+  {
+    std::lock_guard<std::mutex> lock(s_mutex);
+    auto it = s_cache.find(regex_expr);
+    if (it == s_cache.end()) {
+      expr_ptr = std::make_shared<const std::regex>(regex_expr);
+      s_cache.emplace(regex_expr, expr_ptr);
+    } else {
+      expr_ptr = it->second;
+    }
+  }
+  const std::regex& expr = *expr_ptr;
   std::vector<size_t> bpe_offsets; // store the offset of each word
   bpe_offsets.reserve(
       offsets.size()); // Reserve memory for the approximate size
@@ -943,7 +977,7 @@ std::vector<std::string> unicode_regex_split(
 // Get canonical combining class for a codepoint using existing flags data
 static uint8_t get_combining_class(uint32_t cpt) {
     codepoint_flags flags = unicode_cpt_flags(cpt);
-    
+
     // Use the existing flag system to determine combining class
     if (flags.is_accent_mark) {
         // Most combining marks have class 230, but some have different classes
@@ -956,11 +990,11 @@ static uint8_t get_combining_class(uint32_t cpt) {
         if (cpt >= 0x06D6 && cpt <= 0x06E4) return 230; // Arabic small high marks
         if (cpt >= 0x06E7 && cpt <= 0x06E8) return 230; // Arabic small high marks
         if (cpt >= 0x06EA && cpt <= 0x06ED) return 220; // Arabic small low marks
-        
+
         // Default combining class for most combining marks
         return 230;
     }
-    
+
     return 0; // Non-combining character (starter)
 }
 
@@ -970,7 +1004,7 @@ static void canonical_order(std::vector<uint32_t>& cpts) {
         for (size_t j = i; j > 0; --j) {
             uint8_t cc1 = get_combining_class(cpts[j-1]);
             uint8_t cc2 = get_combining_class(cpts[j]);
-            
+
             // Only reorder if both have non-zero combining class and are out of order
             if (cc1 > cc2 && cc2 != 0) {
                 std::swap(cpts[j-1], cpts[j]);
@@ -984,24 +1018,24 @@ static void canonical_order(std::vector<uint32_t>& cpts) {
 // Build composition table by reverse-engineering the NFD data
 static std::unordered_map<std::pair<uint32_t, uint32_t>, uint32_t> build_composition_table() {
     std::unordered_map<std::pair<uint32_t, uint32_t>, uint32_t> composition_map;
-    
+
     // Iterate through all NFD mappings to build reverse composition table
     for (const auto& range : unicode_ranges_nfd) {
         for (uint32_t cpt = range.first; cpt <= range.last; ++cpt) {
             uint32_t base = range.nfd;
-            
+
             // For NFC, we need to figure out what combining character was removed
             // This is a simplified approach that works for the most common cases
-            
+
             // Common diacritic mappings based on the composed character
             uint32_t combining = 0;
-            
+
             // Determine combining character based on the composed character
             // This is derived from common Unicode patterns
             switch (cpt) {
                 // Grave accent (0x0300)
                 case 0x00C0: case 0x00E0: // À à
-                case 0x00C8: case 0x00E8: // È è 
+                case 0x00C8: case 0x00E8: // È è
                 case 0x00CC: case 0x00EC: // Ì ì
                 case 0x00D2: case 0x00F2: // Ò ò
                 case 0x00D9: case 0x00F9: // Ù ù
@@ -1010,7 +1044,7 @@ static std::unordered_map<std::pair<uint32_t, uint32_t>, uint32_t> build_composi
                 case 0x01D1: case 0x01D2: // Ǒ ǒ
                 case 0x01D3: case 0x01D4: // Ǔ ǔ
                     combining = 0x0300; break;
-                
+
                 // Acute accent (0x0301)
                 case 0x00C1: case 0x00E1: // Á á
                 case 0x00C9: case 0x00E9: // É é
@@ -1019,7 +1053,7 @@ static std::unordered_map<std::pair<uint32_t, uint32_t>, uint32_t> build_composi
                 case 0x00DA: case 0x00FA: // Ú ú
                 case 0x00DD: case 0x00FD: // Ý ý
                     combining = 0x0301; break;
-                
+
                 // Circumflex (0x0302)
                 case 0x00C2: case 0x00E2: // Â â
                 case 0x00CA: case 0x00EA: // Ê ê
@@ -1027,13 +1061,13 @@ static std::unordered_map<std::pair<uint32_t, uint32_t>, uint32_t> build_composi
                 case 0x00D4: case 0x00F4: // Ô ô
                 case 0x00DB: case 0x00FB: // Û û
                     combining = 0x0302; break;
-                
+
                 // Tilde (0x0303)
                 case 0x00C3: case 0x00E3: // Ã ã
                 case 0x00D1: case 0x00F1: // Ñ ñ
                 case 0x00D5: case 0x00F5: // Õ õ
                     combining = 0x0303; break;
-                
+
                 // Diaeresis (0x0308)
                 case 0x00C4: case 0x00E4: // Ä ä
                 case 0x00CB: case 0x00EB: // Ë ë
@@ -1042,15 +1076,15 @@ static std::unordered_map<std::pair<uint32_t, uint32_t>, uint32_t> build_composi
                 case 0x00DC: case 0x00FC: // Ü ü
                 case 0x00FF: // ÿ
                     combining = 0x0308; break;
-                
+
                 // Ring above (0x030A)
                 case 0x00C5: case 0x00E5: // Å å
                     combining = 0x030A; break;
-                
+
                 // Cedilla (0x0327)
                 case 0x00C7: case 0x00E7: // Ç ç
                     combining = 0x0327; break;
-                
+
                 default:
                     // For other characters, try to infer from Unicode blocks
                     if (cpt >= 0x0100 && cpt <= 0x017F) {
@@ -1067,14 +1101,14 @@ static std::unordered_map<std::pair<uint32_t, uint32_t>, uint32_t> build_composi
                     }
                     break;
             }
-            
+
             // Only add to composition table if we identified a combining character
             if (combining != 0) {
                 composition_map[{base, combining}] = cpt;
             }
         }
     }
-    
+
     return composition_map;
 }
 
@@ -1086,40 +1120,40 @@ static const std::unordered_map<std::pair<uint32_t, uint32_t>, uint32_t>& get_co
 
 std::vector<uint32_t> unicode_cpts_normalize_nfc(
     const std::vector<uint32_t>& cpts) {
-    
+
     // Step 1: Apply NFD (canonical decomposition) using existing implementation
     std::vector<uint32_t> nfd_result = unicode_cpts_normalize_nfd(cpts);
-    
+
     // Step 2: Apply canonical ordering
     canonical_order(nfd_result);
-    
+
     // Step 3: Apply canonical composition
     const auto& composition_table = get_composition_table();
     std::vector<uint32_t> result;
     result.reserve(nfd_result.size());
-    
+
     size_t i = 0;
     while (i < nfd_result.size()) {
         uint32_t starter = nfd_result[i];
         result.push_back(starter);
-        
+
         // Only try to compose if this is a starter (combining class 0)
         if (get_combining_class(starter) == 0) {
             size_t last_starter_pos = result.size() - 1;
-            
+
             // Look for composable combining marks after this starter
             size_t j = i + 1;
             while (j < nfd_result.size()) {
                 uint32_t combining = nfd_result[j];
                 uint8_t cc = get_combining_class(combining);
-                
+
                 // If we hit another starter, stop
                 if (cc == 0) break;
-                
+
                 // Try to compose with the last starter
                 auto key = std::make_pair(result[last_starter_pos], combining);
                 auto it = composition_table.find(key);
-                
+
                 if (it != composition_table.end()) {
                     // Compose: replace starter with composed character
                     result[last_starter_pos] = it->second;
@@ -1127,7 +1161,7 @@ std::vector<uint32_t> unicode_cpts_normalize_nfc(
                     ++j;
                     continue;
                 }
-                
+
                 // No composition possible, add the combining character
                 result.push_back(combining);
                 ++j;
@@ -1137,6 +1171,6 @@ std::vector<uint32_t> unicode_cpts_normalize_nfc(
             ++i;
         }
     }
-    
+
     return result;
 }