Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions test/test_pre_tokenizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
#include <nlohmann/json.hpp>
#include <re2/re2.h>

#include <atomic>
#include <thread>
#include <vector>

// Local
#include <pytorch/tokenizers/pre_tokenizer.h>

Expand Down Expand Up @@ -388,6 +392,83 @@ TEST_F(PreTokenizerConfigTest, SplitWithUnsupportedBehavior) {
std::runtime_error);
}

// Regex cache (unicode_regex_split) ///////////////////////////////////////////
// The ByteLevel pre-tokenizer drives unicode_regex_split, which caches the
// compiled std::regex/std::wregex per pattern. This test guards that cache:
// (1) results are deterministic: every repeated call must reproduce a
// single-threaded reference run, over inputs that sample the Unicode
// White_Space class and near-miss non-whitespace codepoints, and
// (2) it is thread-safe (the tokenizer pool calls it concurrently).
// Note: outputs are compared against the reference run, not golden values.
// Fixture with no state of its own; it only gives the regex-cache test a
// suite name.
class RegexCacheTest : public ::testing::Test {};

// Runs one shared ByteLevelPreTokenizer over a fixed corpus, first once on the
// calling thread to record a reference, then repeatedly from kThreads worker
// threads. The test passes when no reference entry is empty and no concurrent
// call returns something different from its reference entry.
TEST_F(RegexCacheTest, ByteLevelDeterministicAndThreadSafe) {
  ByteLevelPreTokenizer ptok(/*add_prefix_space=*/false);

  // Adjacent string literals are deliberately split after each \x escape so
  // that a following hex-digit letter ('b', 'c') is not absorbed into it.
  const std::vector<std::string> corpus = {
      "Hello World",
      " leading and multiple spaces ",
      "tabs\t\tand\nnewlines\r\n",
      "code: { return x; } // trailing ",
      // Codepoints that belong to the Unicode White_Space class.
      "a\xC2\x85"
      "b", // U+0085 NEL
      "a\x0B"
      "b\x0C"
      "c", // U+000B VT, U+000C FF
      "a\xE1\x9A\x80"
      "b", // U+1680 OGHAM SPACE MARK
      "a\xE2\x80\x80\xE2\x80\x8A"
      "b", // U+2000 and U+200A (the two ends of U+2000..U+200A)
      "a\xE2\x80\xAF"
      "b\xE2\x81\x9F"
      "c", // U+202F NNBSP, U+205F MMSP
      "a\xE3\x80\x80"
      "b", // U+3000 ideographic space
      // Near-miss codepoints that are NOT in the White_Space class.
      "a\xE2\x80\x8B"
      "b", // U+200B ZERO WIDTH SPACE
      "a\xE1\xA0\x8E"
      "b", // U+180E MONGOLIAN VOWEL SEPARATOR
      "a\xEF\xBB\xBF"
      "b", // U+FEFF BOM / ZWNBSP
      // Non-ASCII letters, then a representative SID prompt fragment.
      "caf\xC3\xA9 \xE4\xBD\xA0\xE5\xA5\xBD",
      "history: i0:<1326-617-1617> i1:<197-296-385> next:",
  };

  // Reference pass on the calling thread: ref[k] holds the pieces produced
  // for corpus[k], and each entry must contain at least one piece.
  std::vector<std::vector<std::string>> ref;
  ref.reserve(corpus.size());
  for (size_t k = 0; k < corpus.size(); ++k) {
    ref.push_back(ptok.pre_tokenize(corpus[k]));
    EXPECT_FALSE(ref.back().empty());
  }

  // Concurrent pass: all workers share `ptok`, `corpus` and `ref` by
  // reference (safe because every worker is joined before this scope ends)
  // and count any result that differs from the reference. A data race may
  // instead show up as a crash or a sanitizer report.
  constexpr int kThreads = 16;
  constexpr int kIters = 200;
  std::atomic<int> mismatches{0};

  const auto worker = [&]() {
    for (int round = 0; round < kIters; ++round) {
      size_t k = 0;
      for (const auto& text : corpus) {
        if (ptok.pre_tokenize(text) != ref[k]) {
          // Only the final total matters, so relaxed ordering is enough.
          mismatches.fetch_add(1, std::memory_order_relaxed);
        }
        ++k;
      }
    }
  };

  std::vector<std::thread> pool;
  pool.reserve(kThreads);
  for (int started = 0; started < kThreads; ++started) {
    pool.emplace_back(worker);
  }
  for (size_t k = 0; k < pool.size(); ++k) {
    pool[k].join();
  }

  EXPECT_EQ(0, mismatches.load());
}

TEST_F(PreTokenizerConfigTest, SplitWithInvertTrue) {
PreTokenizerConfig config;
EXPECT_THROW(
Expand Down
96 changes: 65 additions & 31 deletions third-party/llama.cpp-unicode/src/unicode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ SOFTWARE.
#include <limits>
#include <locale>
#include <map>
#include <memory>
#include <mutex>
#include <regex>
#include <stdexcept>
#include <string>
Expand Down Expand Up @@ -562,7 +564,25 @@ static std::vector<size_t> unicode_regex_split_stl(
const std::wstring& wtext,
const std::wstring& regex_expr,
const std::vector<size_t>& offsets) {
std::wregex expr(regex_expr);
// Cache the compiled regex per pattern and reuse it. Matching on a const
// std::wregex from multiple threads is safe; the cache itself is guarded
// by a mutex. The pattern set is tiny and fixed (a handful of tokenizer
// patterns), so the cache does not grow unbounded.
static std::mutex s_mutex;
static std::unordered_map<std::wstring, std::shared_ptr<const std::wregex>>
s_cache;
std::shared_ptr<const std::wregex> expr_ptr;
{
std::lock_guard<std::mutex> lock(s_mutex);
auto it = s_cache.find(regex_expr);
if (it == s_cache.end()) {
expr_ptr = std::make_shared<const std::wregex>(regex_expr);
s_cache.emplace(regex_expr, expr_ptr);
} else {
expr_ptr = it->second;
}
}
const std::wregex& expr = *expr_ptr;
std::vector<size_t> bpe_offsets; // store the offset of each word
bpe_offsets.reserve(
offsets.size()); // Reserve memory for the approximate size
Expand Down Expand Up @@ -597,7 +617,21 @@ static std::vector<size_t> unicode_regex_split_stl(
const std::string& text,
const std::string& regex_expr,
const std::vector<size_t>& offsets) {
std::regex expr(regex_expr);
static std::mutex s_mutex;
static std::unordered_map<std::string, std::shared_ptr<const std::regex>>
s_cache;
std::shared_ptr<const std::regex> expr_ptr;
{
std::lock_guard<std::mutex> lock(s_mutex);
auto it = s_cache.find(regex_expr);
if (it == s_cache.end()) {
expr_ptr = std::make_shared<const std::regex>(regex_expr);
s_cache.emplace(regex_expr, expr_ptr);
} else {
expr_ptr = it->second;
}
}
const std::regex& expr = *expr_ptr;
std::vector<size_t> bpe_offsets; // store the offset of each word
bpe_offsets.reserve(
offsets.size()); // Reserve memory for the approximate size
Expand Down Expand Up @@ -943,7 +977,7 @@ std::vector<std::string> unicode_regex_split(
// Get canonical combining class for a codepoint using existing flags data
static uint8_t get_combining_class(uint32_t cpt) {
codepoint_flags flags = unicode_cpt_flags(cpt);

// Use the existing flag system to determine combining class
if (flags.is_accent_mark) {
// Most combining marks have class 230, but some have different classes
Expand All @@ -956,11 +990,11 @@ static uint8_t get_combining_class(uint32_t cpt) {
if (cpt >= 0x06D6 && cpt <= 0x06E4) return 230; // Arabic small high marks
if (cpt >= 0x06E7 && cpt <= 0x06E8) return 230; // Arabic small high marks
if (cpt >= 0x06EA && cpt <= 0x06ED) return 220; // Arabic small low marks

// Default combining class for most combining marks
return 230;
}

return 0; // Non-combining character (starter)
}

Expand All @@ -970,7 +1004,7 @@ static void canonical_order(std::vector<uint32_t>& cpts) {
for (size_t j = i; j > 0; --j) {
uint8_t cc1 = get_combining_class(cpts[j-1]);
uint8_t cc2 = get_combining_class(cpts[j]);

// Only reorder if both have non-zero combining class and are out of order
if (cc1 > cc2 && cc2 != 0) {
std::swap(cpts[j-1], cpts[j]);
Expand All @@ -984,24 +1018,24 @@ static void canonical_order(std::vector<uint32_t>& cpts) {
// Build composition table by reverse-engineering the NFD data
static std::unordered_map<std::pair<uint32_t, uint32_t>, uint32_t> build_composition_table() {
std::unordered_map<std::pair<uint32_t, uint32_t>, uint32_t> composition_map;

// Iterate through all NFD mappings to build reverse composition table
for (const auto& range : unicode_ranges_nfd) {
for (uint32_t cpt = range.first; cpt <= range.last; ++cpt) {
uint32_t base = range.nfd;

// For NFC, we need to figure out what combining character was removed
// This is a simplified approach that works for the most common cases

// Common diacritic mappings based on the composed character
uint32_t combining = 0;

// Determine combining character based on the composed character
// This is derived from common Unicode patterns
switch (cpt) {
// Grave accent (0x0300)
case 0x00C0: case 0x00E0: // À à
case 0x00C8: case 0x00E8: // È è
case 0x00C8: case 0x00E8: // È è
case 0x00CC: case 0x00EC: // Ì ì
case 0x00D2: case 0x00F2: // Ò ò
case 0x00D9: case 0x00F9: // Ù ù
Expand All @@ -1010,7 +1044,7 @@ static std::unordered_map<std::pair<uint32_t, uint32_t>, uint32_t> build_composi
case 0x01D1: case 0x01D2: // Ǒ ǒ
case 0x01D3: case 0x01D4: // Ǔ ǔ
combining = 0x0300; break;

// Acute accent (0x0301)
case 0x00C1: case 0x00E1: // Á á
case 0x00C9: case 0x00E9: // É é
Expand All @@ -1019,21 +1053,21 @@ static std::unordered_map<std::pair<uint32_t, uint32_t>, uint32_t> build_composi
case 0x00DA: case 0x00FA: // Ú ú
case 0x00DD: case 0x00FD: // Ý ý
combining = 0x0301; break;

// Circumflex (0x0302)
case 0x00C2: case 0x00E2: // Â â
case 0x00CA: case 0x00EA: // Ê ê
case 0x00CE: case 0x00EE: // Î î
case 0x00D4: case 0x00F4: // Ô ô
case 0x00DB: case 0x00FB: // Û û
combining = 0x0302; break;

// Tilde (0x0303)
case 0x00C3: case 0x00E3: // Ã ã
case 0x00D1: case 0x00F1: // Ñ ñ
case 0x00D5: case 0x00F5: // Õ õ
combining = 0x0303; break;

// Diaeresis (0x0308)
case 0x00C4: case 0x00E4: // Ä ä
case 0x00CB: case 0x00EB: // Ë ë
Expand All @@ -1042,15 +1076,15 @@ static std::unordered_map<std::pair<uint32_t, uint32_t>, uint32_t> build_composi
case 0x00DC: case 0x00FC: // Ü ü
case 0x00FF: // ÿ
combining = 0x0308; break;

// Ring above (0x030A)
case 0x00C5: case 0x00E5: // Å å
combining = 0x030A; break;

// Cedilla (0x0327)
case 0x00C7: case 0x00E7: // Ç ç
combining = 0x0327; break;

default:
// For other characters, try to infer from Unicode blocks
if (cpt >= 0x0100 && cpt <= 0x017F) {
Expand All @@ -1067,14 +1101,14 @@ static std::unordered_map<std::pair<uint32_t, uint32_t>, uint32_t> build_composi
}
break;
}

// Only add to composition table if we identified a combining character
if (combining != 0) {
composition_map[{base, combining}] = cpt;
}
}
}

return composition_map;
}

Expand All @@ -1086,48 +1120,48 @@ static const std::unordered_map<std::pair<uint32_t, uint32_t>, uint32_t>& get_co

std::vector<uint32_t> unicode_cpts_normalize_nfc(
const std::vector<uint32_t>& cpts) {

// Step 1: Apply NFD (canonical decomposition) using existing implementation
std::vector<uint32_t> nfd_result = unicode_cpts_normalize_nfd(cpts);

// Step 2: Apply canonical ordering
canonical_order(nfd_result);

// Step 3: Apply canonical composition
const auto& composition_table = get_composition_table();
std::vector<uint32_t> result;
result.reserve(nfd_result.size());

size_t i = 0;
while (i < nfd_result.size()) {
uint32_t starter = nfd_result[i];
result.push_back(starter);

// Only try to compose if this is a starter (combining class 0)
if (get_combining_class(starter) == 0) {
size_t last_starter_pos = result.size() - 1;

// Look for composable combining marks after this starter
size_t j = i + 1;
while (j < nfd_result.size()) {
uint32_t combining = nfd_result[j];
uint8_t cc = get_combining_class(combining);

// If we hit another starter, stop
if (cc == 0) break;

// Try to compose with the last starter
auto key = std::make_pair(result[last_starter_pos], combining);
auto it = composition_table.find(key);

if (it != composition_table.end()) {
// Compose: replace starter with composed character
result[last_starter_pos] = it->second;
// Skip this combining character
++j;
continue;
}

// No composition possible, add the combining character
result.push_back(combining);
++j;
Expand All @@ -1137,6 +1171,6 @@ std::vector<uint32_t> unicode_cpts_normalize_nfc(
++i;
}
}

return result;
}
Loading