Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion minbpe/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
some concessions are made for simplicity.
"""
import unicodedata
from typing import Optional

# -----------------------------------------------------------------------------
# a few helper functions useful for both BasicTokenizer and RegexTokenizer
Expand All @@ -22,21 +23,37 @@ def get_stats(ids, counts=None):
return counts


def merge(ids, pair, idx):
def merge(ids, pair, idx, stats: Optional[dict] = None):
"""
In the list of integers (ids), replace all consecutive occurrences
of pair with the new integer token idx
Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
"""
newids = []
i = 0
curr_new_idx = -1
while i < len(ids):
# if not at the very last position AND the pair matches, replace it
if ids[i] == pair[0] and i < len(ids) - 1 and ids[i+1] == pair[1]:
newids.append(idx)
curr_new_idx += 1
if stats is not None:
if (curr_new_idx > 0): # there is a token before this pair. Updating stat
old_pair = (newids[curr_new_idx - 1], ids[i])
new_pair = (newids[curr_new_idx - 1], idx)
assert old_pair in stats
stats[old_pair] -= 1
stats[new_pair] = stats.get(new_pair, 0) + 1
if (i < (len(ids) - 2)): # there is a token after this pair. Updating stat
old_pair = (ids[i + 1], ids[i + 2])
new_pair = (idx, ids[i + 2])
assert old_pair in stats
stats[old_pair] -= 1
stats[new_pair] = stats.get(new_pair, 0) + 1
i += 2
else:
newids.append(ids[i])
curr_new_idx += 1
i += 1
return newids

Expand Down
11 changes: 6 additions & 5 deletions minbpe/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,22 @@ def train(self, text, vocab_size, verbose=False):
# iteratively merge the most common pairs to create new tokens
merges = {} # (int, int) -> int
vocab = {idx: bytes([idx]) for idx in range(256)} # int -> bytes
# count up the number of times every consecutive pair appears
stats = get_stats(ids)
for i in range(num_merges):
# count up the number of times every consecutive pair appears
stats = get_stats(ids)
# find the pair with the highest count
pair = max(stats, key=stats.get)
# mint a new token: assign it the next available id
idx = 256 + i
# find the pair with the highest count
pair = max(stats, key=stats.get)
# replace all occurrences of pair in ids with idx
ids = merge(ids, pair, idx)
ids = merge(ids, pair, idx, stats)
# save the merge
merges[pair] = idx
vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
# prints
if verbose:
print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")
stats[pair] = 0

# save class variables
self.merges = merges # used in encode()
Expand Down
5 changes: 2 additions & 3 deletions minbpe/regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,23 +47,22 @@ def train(self, text, vocab_size, verbose=False):
merges = {} # (int, int) -> int
vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes
for i in range(num_merges):
# count the number of times every consecutive pair appears
stats = {}
for chunk_ids in ids:
# passing in stats will update it in place, adding up counts
get_stats(chunk_ids, stats)
# find the pair with the highest count
pair = max(stats, key=stats.get)
# mint a new token: assign it the next available id
idx = 256 + i
# replace all occurrences of pair in ids with idx
# replace all occurrences of pair in ids with idx and update pair-count simultaneously
ids = [merge(chunk_ids, pair, idx) for chunk_ids in ids]
# save the merge
merges[pair] = idx
vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
# prints
if verbose:
print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")
stats[pair] = 0

# save class variables
self.merges = merges # used in encode()
Expand Down