Skip to content

Commit e90a294

Browse files
yinghsienwucopybara-github
authored and committed
feat: Gemma 4 local tokenizer support
PiperOrigin-RevId: 929258534
1 parent 176f247 commit e90a294

6 files changed

Lines changed: 177 additions & 35 deletions

File tree

google/genai/_local_tokenizer_loader.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,13 @@
1818
import hashlib
1919
import os
2020
import tempfile
21-
from typing import Optional, cast
21+
from typing import Any, Optional, cast
2222
import uuid
2323

2424
import requests # type: ignore
2525
import sentencepiece as spm
2626
from sentencepiece import sentencepiece_model_pb2
27+
from transformers import AutoProcessor
2728

2829

2930
# Source of truth: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models
@@ -47,21 +48,26 @@
4748
"gemini-3-pro-preview": "gemma3",
4849
}
4950

51+
# gemma_pytorch (https://github.com/google/gemma_pytorch) does not support Gemma 4 and later, so these models use Hugging Face tokenizers.
52+
_GEMINI_MODELS_TO_HUGGINGFACE_TOKENIZER_NAMES = {
53+
"gemini-3.5-flash": "gemma4",
54+
"gemini-3.1-flash-lite": "gemma4",
55+
"gemini-3.1-pro-preview": "gemma4",
56+
"gemini-4-flash-preview": "gemma4",
57+
}
58+
59+
GEMMA_TOKENIZER_TO_MODEL_NAMES = {
60+
"gemma4": "google/gemma-4-E4B-it",
61+
}
62+
5063

5164
@dataclasses.dataclass(frozen=True)
5265
class _TokenizerConfig:
5366
model_url: str
5467
model_hash: str
5568

5669

57-
# TODO: update gemma3 tokenizer
5870
_TOKENIZERS = {
59-
"gemma2": _TokenizerConfig(
60-
model_url="https://raw.githubusercontent.com/google/gemma_pytorch/33b652c465537c6158f9a472ea5700e5e770ad3f/tokenizer/tokenizer.model",
61-
model_hash=(
62-
"61a7b147390c64585d6c3543dd6fc636906c9af3865a5548f27f31aee1d4c8e2"
63-
),
64-
),
6571
"gemma3": _TokenizerConfig(
6672
model_url="https://raw.githubusercontent.com/google/gemma_pytorch/014acb7ac4563a5f77c76d7ff98f31b568c16508/tokenizer/gemma3_cleaned_262144_v2.spiece.model",
6773
model_hash=(
@@ -177,7 +183,7 @@ def _load_model_proto_bytes(tokenizer_name: str) -> bytes:
177183
"""Loads model proto bytes from the given tokenizer name."""
178184
if tokenizer_name not in _TOKENIZERS:
179185
raise ValueError(
180-
f"Tokenizer {tokenizer_name} is not supported."
186+
f"Tokenizer {tokenizer_name} is not supported. "
181187
f"Supported tokenizers: {list(_TOKENIZERS.keys())}"
182188
)
183189
return _load(
@@ -202,11 +208,23 @@ def get_tokenizer_name(model_name: str) -> str:
202208
return _GEMINI_MODELS_TO_TOKENIZER_NAMES[model_name]
203209
if model_name in _GEMINI_STABLE_MODELS_TO_TOKENIZER_NAMES.keys():
204210
return _GEMINI_STABLE_MODELS_TO_TOKENIZER_NAMES[model_name]
211+
if model_name in _GEMINI_MODELS_TO_HUGGINGFACE_TOKENIZER_NAMES.keys():
212+
return _GEMINI_MODELS_TO_HUGGINGFACE_TOKENIZER_NAMES[model_name]
205213
raise ValueError(
206214
f"Model {model_name} is not supported. Supported models: {', '.join(_GEMINI_MODELS_TO_TOKENIZER_NAMES.keys())}, {', '.join(_GEMINI_STABLE_MODELS_TO_TOKENIZER_NAMES.keys())}.\n" # pylint: disable=line-too-long
207215
)
208216

209217

218+
@functools.lru_cache()
def get_huggingface_tokenizer(tokenizer_name: str) -> Any:
  """Loads the Hugging Face tokenizer registered for the given tokenizer name.

  The result is memoized per tokenizer name (mirroring `get_sentencepiece`), so
  building several local tokenizers for the same model family loads the
  processor only once.

  Args:
    tokenizer_name: A key of `GEMMA_TOKENIZER_TO_MODEL_NAMES`, e.g. "gemma4".

  Returns:
    The `tokenizer` attribute of the processor returned by
    `AutoProcessor.from_pretrained` for the mapped model name.

  Raises:
    KeyError: If `tokenizer_name` has no entry in
      `GEMMA_TOKENIZER_TO_MODEL_NAMES`.
  """
  try:
    pretrained_model_name = GEMMA_TOKENIZER_TO_MODEL_NAMES[tokenizer_name]
  except KeyError:
    # Keep the KeyError type callers may rely on, but say what is supported
    # instead of surfacing only the missing key.
    raise KeyError(
        f"Tokenizer {tokenizer_name} is not supported. "
        f"Supported tokenizers: {list(GEMMA_TOKENIZER_TO_MODEL_NAMES.keys())}"
    ) from None

  # The processor bundles the tokenizer; only the tokenizer is handed back.
  processor = AutoProcessor.from_pretrained(  # type: ignore[no-untyped-call]
      pretrained_model_name
  )
  return processor.tokenizer
226+
227+
210228
@functools.lru_cache()
211229
def get_sentencepiece(tokenizer_name: str) -> spm.SentencePieceProcessor:
212230
"""Loads sentencepiece tokenizer from the given tokenizer name."""

google/genai/local_tokenizer.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -291,8 +291,12 @@ class LocalTokenizer:
291291

292292
def __init__(self, model_name: str):
293293
self._tokenizer_name = loader.get_tokenizer_name(model_name)
294-
self._model_proto = loader.load_model_proto(self._tokenizer_name)
295-
self._tokenizer = loader.get_sentencepiece(self._tokenizer_name)
294+
self._model_proto = None
295+
if self._tokenizer_name in loader.GEMMA_TOKENIZER_TO_MODEL_NAMES:
296+
self._tokenizer = loader.get_huggingface_tokenizer(self._tokenizer_name)
297+
else:
298+
self._model_proto = loader.load_model_proto(self._tokenizer_name)
299+
self._tokenizer = loader.get_sentencepiece(self._tokenizer_name)
296300

297301
@_common.experimental_warning(
298302
"The SDK's local tokenizer implementation is experimental and may change"
@@ -365,27 +369,46 @@ def compute_tokens(
365369
# tokens_info=[TokensInfo(token_ids=[279, 329, 1313, 2508, 13], tokens=[b' What', b' is', b' your', b' name', b'?'], role='user')]
366370
"""
367371
processed_contents = t.t_contents(contents)
372+
roles = []
373+
368374
text_accumulator = _TextsAccumulator()
369375
for content in processed_contents:
370376
text_accumulator.add_content(content)
371-
tokens_protos = self._tokenizer.EncodeAsImmutableProto(
372-
text_accumulator.get_texts()
373-
)
374-
375-
roles = []
376-
for content in processed_contents:
377377
if content.parts:
378378
for _ in content.parts:
379379
roles.append(content.role)
380380

381381
token_infos = []
382+
if self._tokenizer_name in loader.GEMMA_TOKENIZER_TO_MODEL_NAMES:
383+
# Use the HuggingFace tokenizer since gemma_pytorch is not available for
384+
# gemma 4+.
385+
token_ids = self._tokenizer.encode(list(text_accumulator.get_texts()))
386+
for token_id, role in zip(token_ids, roles):
387+
token_infos.append(
388+
types.TokensInfo(
389+
token_ids=token_id,
390+
tokens=[
391+
token.replace("_", " ")
392+
.encode("utf-8")
393+
.replace(b"\xe2\x96\x81", b" ")
394+
for token in self._tokenizer.convert_ids_to_tokens(token_id)
395+
],
396+
role=role,
397+
)
398+
)
399+
return types.ComputeTokensResult(tokens_info=token_infos)
400+
401+
tokens_protos = self._tokenizer.EncodeAsImmutableProto(
402+
text_accumulator.get_texts()
403+
)
404+
382405
for tokens_proto, role in zip(tokens_protos, roles):
383406
token_infos.append(
384407
types.TokensInfo(
385408
token_ids=[piece.id for piece in tokens_proto.pieces],
386409
tokens=[
387410
_token_str_to_bytes(
388-
piece.piece, self._model_proto.pieces[piece.id].type
411+
piece.piece, self._model_proto.pieces[piece.id].type # type: ignore[union-attr]
389412
)
390413
for piece in tokens_proto.pieces
391414
],

google/genai/tests/local_tokenizer/test_local_tokenizer.py

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def setUp(self):
2929
self.mock_load_model_proto = patch(
3030
'genai._local_tokenizer_loader.load_model_proto'
3131
).start()
32+
self.addCleanup(patch.stopall)
3233
self.mock_get_sentencepiece = patch(
3334
'genai._local_tokenizer_loader.get_sentencepiece'
3435
).start()
@@ -39,9 +40,6 @@ def setUp(self):
3940

4041
self.tokenizer = local_tokenizer.LocalTokenizer(model_name='gemini-3-pro-preview')
4142

42-
def tearDown(self):
43-
patch.stopall()
44-
4543
def test_count_tokens_simple_string(self):
4644
self.mock_tokenizer.encode.return_value = [[1, 2, 3]]
4745
result = self.tokenizer.count_tokens('Hello world')
@@ -341,3 +339,72 @@ def test_invalid_format(self):
341339
def test_invalid_hex_value(self):
342340
with self.assertRaisesRegex(ValueError, 'Invalid hex value'):
343341
local_tokenizer._parse_hex_byte('<0xFG>')
342+
343+
344+
class TestLocalTokenizerHuggingFace(unittest.TestCase):
  """Exercises LocalTokenizer when it is backed by a Hugging Face tokenizer."""

  def setUp(self):
    # The loader's Hugging Face entry point is replaced so that no real
    # tokenizer is loaded; it hands back a MagicMock stand-in instead.
    self.mock_tokenizer = MagicMock()
    loader_patcher = patch(
        'genai._local_tokenizer_loader.get_huggingface_tokenizer',
        return_value=self.mock_tokenizer,
    )
    self.mock_get_huggingface_tokenizer = loader_patcher.start()
    self.addCleanup(patch.stopall)

    # gemini-3.5-flash maps to gemma4 (HuggingFace)
    self.tokenizer = local_tokenizer.LocalTokenizer(
        model_name='gemini-3.5-flash'
    )

  def test_count_tokens_simple_string(self):
    """A single string is encoded once and its token ids are counted."""
    self.mock_tokenizer.encode.return_value = [[1, 2, 3]]

    counted = self.tokenizer.count_tokens('Hello world')

    self.assertEqual(counted.total_tokens, 3)
    self.mock_tokenizer.encode.assert_called_once_with(['Hello world'])

  def test_compute_tokens_simple_string(self):
    """Token ids, byte tokens and the default role are reported for a string."""
    encode = self.mock_tokenizer.encode
    ids_to_tokens = self.mock_tokenizer.convert_ids_to_tokens
    encode.return_value = [[1, 2, 3]]
    ids_to_tokens.return_value = ['He', 'llo', ' world']

    computed = self.tokenizer.compute_tokens('Hello world')

    self.assertEqual(len(computed.tokens_info), 1)
    only_info = computed.tokens_info[0]
    self.assertEqual(only_info.token_ids, [1, 2, 3])
    self.assertEqual(only_info.tokens, [b'He', b'llo', b' world'])
    self.assertEqual(only_info.role, 'user')

    encode.assert_called_once_with(['Hello world'])
    ids_to_tokens.assert_called_once_with([1, 2, 3])

  def test_compute_tokens_special_characters(self):
    """Underscore and U+2581 markers both come back as a plain space."""
    self.mock_tokenizer.encode.return_value = [[1, 2]]
    # One token uses an ASCII underscore, the other U+2581 (lower one eighth
    # block).
    self.mock_tokenizer.convert_ids_to_tokens.return_value = [
        '_world',
        '\u2581hello',
    ]

    computed = self.tokenizer.compute_tokens('dummy')

    self.assertEqual(computed.tokens_info[0].tokens, [b' world', b' hello'])

  def test_compute_tokens_with_chat_history(self):
    """Each turn of a chat history yields its own TokensInfo with its role."""
    encode = self.mock_tokenizer.encode
    ids_to_tokens = self.mock_tokenizer.convert_ids_to_tokens
    encode.return_value = [[1], [2, 3]]
    ids_to_tokens.side_effect = [['Hello'], ['Hi', ' there!']]
    user_turn = types.Content(role='user', parts=[types.Part(text='Hello')])
    model_turn = types.Content(
        role='model', parts=[types.Part(text='Hi there!')]
    )

    computed = self.tokenizer.compute_tokens([user_turn, model_turn])

    self.assertEqual(len(computed.tokens_info), 2)
    expected_per_turn = (
        ([1], [b'Hello'], 'user'),
        ([2, 3], [b'Hi', b' there!'], 'model'),
    )
    for turn_index, (ids, pieces, role) in enumerate(expected_per_turn):
      info = computed.tokens_info[turn_index]
      self.assertEqual(info.token_ids, ids)
      self.assertEqual(info.tokens, pieces)
      self.assertEqual(info.role, role)

    encode.assert_called_once_with(['Hello', 'Hi there!'])
    ids_to_tokens.assert_has_calls([
        unittest.mock.call([1]),
        unittest.mock.call([2, 3]),
    ])

google/genai/tests/local_tokenizer/test_local_tokenizer_loader.py

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
]
4848
).SerializeToString()
4949

50-
GEMMA2_HASH = "61a7b147390c64585d6c3543dd6fc636906c9af3865a5548f27f31aee1d4c8e2"
50+
GEMMA3_HASH = "1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c"
5151

5252

5353
class TestGetTokenizerName(unittest.TestCase):
@@ -58,6 +58,18 @@ def test_get_tokenizer_name_success(self):
5858
loader.get_tokenizer_name("gemini-2.5-pro-preview-06-05"), "gemma3"
5959
)
6060

61+
def test_get_tokenizer_name_huggingface(self):
  """Hugging Face-backed Gemini model names resolve to the gemma4 tokenizer."""
  huggingface_backed_models = (
      "gemini-3.5-flash",
      "gemini-3.1-flash-lite",
      "gemini-3.1-pro-preview",
      "gemini-4-flash-preview",
  )
  for model_name in huggingface_backed_models:
    self.assertEqual(loader.get_tokenizer_name(model_name), "gemma4")
72+
6173
def test_get_tokenizer_name_unsupported(self):
6274
with self.assertRaisesRegex(
6375
ValueError, "Model unsupported-model is not supported"
@@ -105,9 +117,9 @@ def test_load_model_proto_from_url(
105117
):
106118
mock_exists.return_value = False # Don't use cache
107119
self._setup_get_mock(mock_get)
108-
mock_sha256.return_value.hexdigest.return_value = GEMMA2_HASH
120+
mock_sha256.return_value.hexdigest.return_value = GEMMA3_HASH
109121

110-
proto = loader.load_model_proto("gemma2")
122+
proto = loader.load_model_proto("gemma3")
111123

112124
self.assertIsInstance(proto, sentencepiece_model_pb2.ModelProto)
113125
self.assertEqual(len(proto.pieces), 4)
@@ -128,9 +140,9 @@ def test_load_model_proto_from_cache(
128140
):
129141
mock_exists.return_value = True # Use cache
130142
mock_open_func.return_value.read.return_value = FAKE_MODEL_CONTENT
131-
mock_sha256.return_value.hexdigest.return_value = GEMMA2_HASH
143+
mock_sha256.return_value.hexdigest.return_value = GEMMA3_HASH
132144

133-
proto = loader.load_model_proto("gemma2")
145+
proto = loader.load_model_proto("gemma3")
134146

135147
self.assertIsInstance(proto, sentencepiece_model_pb2.ModelProto)
136148
mock_get.assert_not_called()
@@ -154,10 +166,10 @@ def test_load_model_proto_corrupted_cache(
154166
# First hash for corrupted cache, second for good download
155167
mock_sha256.side_effect = [
156168
MagicMock(hexdigest=MagicMock(return_value="wrong_hash")),
157-
MagicMock(hexdigest=MagicMock(return_value=GEMMA2_HASH)),
169+
MagicMock(hexdigest=MagicMock(return_value=GEMMA3_HASH)),
158170
]
159171

160-
proto = loader.load_model_proto("gemma2")
172+
proto = loader.load_model_proto("gemma3")
161173

162174
self.assertIsInstance(proto, sentencepiece_model_pb2.ModelProto)
163175
mock_remove.assert_called_once()
@@ -180,7 +192,7 @@ def test_load_model_proto_bad_hash_from_url(
180192
with self.assertRaisesRegex(
181193
ValueError, "Downloaded model file is corrupted"
182194
):
183-
loader.load_model_proto("gemma2")
195+
loader.load_model_proto("gemma3")
184196

185197
def test_load_model_proto_unsupported(self, *args):
186198
with self.assertRaisesRegex(
@@ -200,9 +212,9 @@ def test_get_sentencepiece_success(
200212
):
201213
mock_exists.return_value = False
202214
self._setup_get_mock(mock_get)
203-
mock_sha256.return_value.hexdigest.return_value = GEMMA2_HASH
215+
mock_sha256.return_value.hexdigest.return_value = GEMMA3_HASH
204216

205-
processor = loader.get_sentencepiece("gemma2")
217+
processor = loader.get_sentencepiece("gemma3")
206218

207219
self.assertIsInstance(processor, spm.SentencePieceProcessor)
208220
mock_get.assert_called_once()
@@ -225,11 +237,32 @@ def test_get_sentencepiece_caching(
225237
):
226238
mock_exists.return_value = False
227239
self._setup_get_mock(mock_get)
228-
mock_sha256.return_value.hexdigest.return_value = GEMMA2_HASH
240+
mock_sha256.return_value.hexdigest.return_value = GEMMA3_HASH
229241

230242
# Call twice
231-
loader.get_sentencepiece("gemma2")
232-
loader.get_sentencepiece("gemma2")
243+
loader.get_sentencepiece("gemma3")
244+
loader.get_sentencepiece("gemma3")
233245

234246
# Should only be loaded once due to lru_cache
235247
mock_get.assert_called_once()
248+
249+
250+
class TestGetHuggingFaceTokenizer(unittest.TestCase):
  """Covers loader.get_huggingface_tokenizer with AutoProcessor stubbed out."""

  def test_get_huggingface_tokenizer_success(self):
    """The tokenizer of the processor loaded for gemma4 is returned."""
    fake_tokenizer = MagicMock()
    with patch(
        "genai._local_tokenizer_loader.AutoProcessor"
    ) as auto_processor_cls:
      # The stubbed processor exposes the fake tokenizer as `.tokenizer`.
      auto_processor_cls.from_pretrained.return_value = MagicMock(
          tokenizer=fake_tokenizer
      )

      loaded = loader.get_huggingface_tokenizer("gemma4")

      self.assertEqual(loaded, fake_tokenizer)
      auto_processor_cls.from_pretrained.assert_called_once_with(
          "google/gemma-4-E4B-it"
      )

  def test_get_huggingface_tokenizer_unsupported(self):
    """A tokenizer name without a model mapping raises KeyError."""
    self.assertRaises(
        KeyError, loader.get_huggingface_tokenizer, "unsupported"
    )

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ dependencies = [
3939

4040
[project.optional-dependencies]
4141
aiohttp = ["aiohttp>=3.10.11, <4.0.0"]
42-
local-tokenizer = ["sentencepiece>=0.2.0", "protobuf"]
42+
local-tokenizer = ["sentencepiece>=0.2.0", "protobuf", "transformers"]
4343
pyopenssl = ["pyopenssl"]
4444

4545
[project.urls]

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,4 @@ websockets==16.0
3232
mcp>=1.14.0; python_version > '3.9'
3333
sentencepiece>=0.2.0
3434
protobuf
35+
transformers>=5.10.1

0 commit comments

Comments
 (0)