def _shift_count(stats, old_pair, new_pair):
    """Move one occurrence from old_pair to new_pair inside stats (in place)."""
    # Explicit check instead of `assert`: asserts are stripped under `python -O`,
    # which would turn an inconsistent stats dict into an unexplained KeyError.
    if old_pair not in stats:
        raise ValueError(
            f"stats is out of sync with ids: pair {old_pair} is adjacent in ids "
            "but has no entry in stats"
        )
    stats[old_pair] -= 1
    stats[new_pair] = stats.get(new_pair, 0) + 1


def merge(ids, pair, idx, stats: Optional[dict] = None):
    """
    In the list of integers (ids), replace all consecutive occurrences
    of pair with the new integer token idx, and return the new list.
    Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]

    ids is scanned left to right and is not modified. Matches do not
    overlap: once a pair is merged, scanning resumes after it.

    If stats (a dict mapping (int, int) pairs to their counts in ids) is
    given, it is updated in place for the pairs that touch each merge:
      - (left neighbour, pair[0]) loses one count and
        (left neighbour, idx) gains one;
      - (pair[1], right neighbour) loses one count and
        (idx, right neighbour) gains one.
    The merged occurrences of pair itself are NOT subtracted from stats;
    the caller is responsible for resetting stats[pair] afterwards.
    Entries whose count drops to 0 are kept in the dict.

    Raises ValueError if stats is given but lacks an entry for a pair
    that is adjacent in ids (i.e. stats does not describe ids).
    """
    newids = []
    n = len(ids)
    i = 0
    while i < n:
        # if not at the very last position AND the pair matches, replace it
        if i < n - 1 and ids[i] == pair[0] and ids[i + 1] == pair[1]:
            if stats is not None:
                if newids:
                    # newids[-1] is the token that now sits left of the merge;
                    # it may itself be idx if the previous pair was merged too
                    left = newids[-1]
                    _shift_count(stats, (left, ids[i]), (left, idx))
                if i + 2 < n:
                    # the token right after the pair is still unprocessed
                    right = ids[i + 2]
                    _shift_count(stats, (ids[i + 1], right), (idx, right))
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids
key=stats.get) # mint a new token: assign it the next available id idx = 256 + i - # replace all occurrences of pair in ids with idx + # replace all occurrences of pair in ids with idx and update pair-count simultaneously ids = [merge(chunk_ids, pair, idx) for chunk_ids in ids] # save the merge merges[pair] = idx @@ -64,6 +62,7 @@ def train(self, text, vocab_size, verbose=False): # prints if verbose: print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences") + stats[pair] = 0 # save class variables self.merges = merges # used in encode()