Skip to content

Commit 941d575

Browse files
yinghsienwucopybara-github
authored andcommitted
feat: Gemma 4 local tokenizer support
PiperOrigin-RevId: 929258534
1 parent c754ebf commit 941d575

6 files changed

Lines changed: 165 additions & 8 deletions

File tree

google/genai/_local_tokenizer_loader.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import requests # type: ignore
2525
import sentencepiece as spm
2626
from sentencepiece import sentencepiece_model_pb2
27+
from transformers import AutoProcessor
2728

2829

2930
# Source of truth: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models
@@ -47,6 +48,18 @@
4748
"gemini-3-pro-preview": "gemma3",
4849
}
4950

51+
# https://github.com/google/gemma_pytorch stop supporting gemma 4 moving forward.
52+
_GEMINI_MODELS_TO_HUGGINGFACE_TOKENIZER_NAMES = {
53+
"gemini-3.5-flash": "gemma4",
54+
"gemini-3.1-flash-lite": "gemma4",
55+
"gemini-3.1-pro-preview": "gemma4",
56+
"gemini-4-flash-preview": "gemma4",
57+
}
58+
59+
GEMMA_TOKENIZER_TO_MODEL_NAMES = {
60+
"gemma4": "google/gemma-4-E4B-it",
61+
}
62+
5063

5164
@dataclasses.dataclass(frozen=True)
5265
class _TokenizerConfig:
@@ -202,11 +215,23 @@ def get_tokenizer_name(model_name: str) -> str:
202215
return _GEMINI_MODELS_TO_TOKENIZER_NAMES[model_name]
203216
if model_name in _GEMINI_STABLE_MODELS_TO_TOKENIZER_NAMES.keys():
204217
return _GEMINI_STABLE_MODELS_TO_TOKENIZER_NAMES[model_name]
218+
if model_name in _GEMINI_MODELS_TO_HUGGINGFACE_TOKENIZER_NAMES.keys():
219+
return _GEMINI_MODELS_TO_HUGGINGFACE_TOKENIZER_NAMES[model_name]
205220
raise ValueError(
206221
f"Model {model_name} is not supported. Supported models: {', '.join(_GEMINI_MODELS_TO_TOKENIZER_NAMES.keys())}, {', '.join(_GEMINI_STABLE_MODELS_TO_TOKENIZER_NAMES.keys())}.\n" # pylint: disable=line-too-long
207222
)
208223

209224

225+
def get_huggingface_tokenizer(tokenizer_name: str):
226+
"""Loads huggingface tokenizer from the given tokenizer name."""
227+
# Load the processor which includes the tokenizer
228+
processor = AutoProcessor.from_pretrained(
229+
GEMMA_TOKENIZER_TO_MODEL_NAMES[tokenizer_name]
230+
)
231+
# Access the underlying tokenizer if needed
232+
return processor.tokenizer
233+
234+
210235
@functools.lru_cache()
211236
def get_sentencepiece(tokenizer_name: str) -> spm.SentencePieceProcessor:
212237
"""Loads sentencepiece tokenizer from the given tokenizer name."""

google/genai/local_tokenizer.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,10 @@ class LocalTokenizer:
292292
def __init__(self, model_name: str):
293293
self._tokenizer_name = loader.get_tokenizer_name(model_name)
294294
self._model_proto = loader.load_model_proto(self._tokenizer_name)
295-
self._tokenizer = loader.get_sentencepiece(self._tokenizer_name)
295+
if self._tokenizer_name != "gemma3":
296+
self._tokenizer = loader.get_huggingface_tokenizer(self._tokenizer_name)
297+
else:
298+
self._tokenizer = loader.get_sentencepiece(self._tokenizer_name)
296299

297300
@_common.experimental_warning(
298301
"The SDK's local tokenizer implementation is experimental and may change"
@@ -365,20 +368,39 @@ def compute_tokens(
365368
# tokens_info=[TokensInfo(token_ids=[279, 329, 1313, 2508, 13], tokens=[b' What', b' is', b' your', b' name', b'?'], role='user')]
366369
"""
367370
processed_contents = t.t_contents(contents)
371+
roles = []
372+
368373
text_accumulator = _TextsAccumulator()
369374
for content in processed_contents:
370375
text_accumulator.add_content(content)
371-
tokens_protos = self._tokenizer.EncodeAsImmutableProto(
372-
text_accumulator.get_texts()
373-
)
374-
375-
roles = []
376-
for content in processed_contents:
377376
if content.parts:
378377
for _ in content.parts:
379378
roles.append(content.role)
380379

381380
token_infos = []
381+
if self._tokenizer_name != "gemma3":
382+
# Use the HuggingFace tokenizer since gemma_pytorch is not available for
383+
# gemma 4+.
384+
token_ids = self._tokenizer.encode(list(text_accumulator.get_texts()))
385+
for token_id, role in zip(token_ids, roles):
386+
token_infos.append(
387+
types.TokensInfo(
388+
token_ids=token_id,
389+
tokens=[
390+
token.replace("_", " ")
391+
.encode("utf-8")
392+
.replace(b"\xe2\x96\x81", b" ")
393+
for token in self._tokenizer.convert_ids_to_tokens(token_id)
394+
],
395+
role=role,
396+
)
397+
)
398+
return types.ComputeTokensResult(tokens_info=token_infos)
399+
400+
tokens_protos = self._tokenizer.EncodeAsImmutableProto(
401+
text_accumulator.get_texts()
402+
)
403+
382404
for tokens_proto, role in zip(tokens_protos, roles):
383405
token_infos.append(
384406
types.TokensInfo(

google/genai/tests/local_tokenizer/test_local_tokenizer.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,3 +341,79 @@ def test_invalid_format(self):
341341
def test_invalid_hex_value(self):
342342
with self.assertRaisesRegex(ValueError, 'Invalid hex value'):
343343
local_tokenizer._parse_hex_byte('<0xFG>')
344+
345+
346+
class TestLocalTokenizerHuggingFace(unittest.TestCase):
347+
348+
def setUp(self):
349+
self.mock_load_model_proto = patch(
350+
'genai._local_tokenizer_loader.load_model_proto'
351+
).start()
352+
self.mock_get_huggingface_tokenizer = patch(
353+
'genai._local_tokenizer_loader.get_huggingface_tokenizer'
354+
).start()
355+
356+
self.mock_load_model_proto.return_value = MagicMock()
357+
self.mock_tokenizer = MagicMock()
358+
self.mock_get_huggingface_tokenizer.return_value = self.mock_tokenizer
359+
360+
# gemini-3.5-flash maps to gemma4 (HuggingFace)
361+
self.tokenizer = local_tokenizer.LocalTokenizer(model_name='gemini-3.5-flash')
362+
363+
def tearDown(self):
364+
patch.stopall()
365+
366+
def test_count_tokens_simple_string(self):
367+
self.mock_tokenizer.encode.return_value = [[1, 2, 3]]
368+
result = self.tokenizer.count_tokens('Hello world')
369+
self.assertEqual(result.total_tokens, 3)
370+
self.mock_tokenizer.encode.assert_called_once_with(['Hello world'])
371+
372+
def test_compute_tokens_simple_string(self):
373+
self.mock_tokenizer.encode.return_value = [[1, 2, 3]]
374+
self.mock_tokenizer.convert_ids_to_tokens.return_value = ['He', 'llo', ' world']
375+
376+
result = self.tokenizer.compute_tokens('Hello world')
377+
378+
self.assertEqual(len(result.tokens_info), 1)
379+
self.assertEqual(result.tokens_info[0].token_ids, [1, 2, 3])
380+
self.assertEqual(result.tokens_info[0].tokens, [b'He', b'llo', b' world'])
381+
self.assertEqual(result.tokens_info[0].role, 'user')
382+
383+
self.mock_tokenizer.encode.assert_called_once_with(['Hello world'])
384+
self.mock_tokenizer.convert_ids_to_tokens.assert_called_once_with([1, 2, 3])
385+
386+
def test_compute_tokens_special_characters(self):
387+
self.mock_tokenizer.encode.return_value = [[1, 2]]
388+
# Use U+2581 (lower one eighth block) and underscore
389+
self.mock_tokenizer.convert_ids_to_tokens.return_value = ['_world', '\u2581hello']
390+
391+
result = self.tokenizer.compute_tokens('dummy')
392+
393+
self.assertEqual(result.tokens_info[0].tokens, [b' world', b' hello'])
394+
395+
def test_compute_tokens_with_chat_history(self):
396+
self.mock_tokenizer.encode.return_value = [[1], [2, 3]]
397+
self.mock_tokenizer.convert_ids_to_tokens.side_effect = [
398+
['Hello'],
399+
['Hi', ' there!']
400+
]
401+
history = [
402+
types.Content(role='user', parts=[types.Part(text='Hello')]),
403+
types.Content(role='model', parts=[types.Part(text='Hi there!')]),
404+
]
405+
result = self.tokenizer.compute_tokens(history)
406+
self.assertEqual(len(result.tokens_info), 2)
407+
self.assertEqual(result.tokens_info[0].token_ids, [1])
408+
self.assertEqual(result.tokens_info[0].tokens, [b'Hello'])
409+
self.assertEqual(result.tokens_info[0].role, 'user')
410+
self.assertEqual(result.tokens_info[1].token_ids, [2, 3])
411+
self.assertEqual(result.tokens_info[1].tokens, [b'Hi', b' there!'])
412+
self.assertEqual(result.tokens_info[1].role, 'model')
413+
414+
self.mock_tokenizer.encode.assert_called_once_with(['Hello', 'Hi there!'])
415+
self.mock_tokenizer.convert_ids_to_tokens.assert_has_calls([
416+
unittest.mock.call([1]),
417+
unittest.mock.call([2, 3])
418+
])
419+

google/genai/tests/local_tokenizer/test_local_tokenizer_loader.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,18 @@ def test_get_tokenizer_name_success(self):
5858
loader.get_tokenizer_name("gemini-2.5-pro-preview-06-05"), "gemma3"
5959
)
6060

61+
def test_get_tokenizer_name_huggingface(self):
62+
self.assertEqual(loader.get_tokenizer_name("gemini-3.5-flash"), "gemma4")
63+
self.assertEqual(
64+
loader.get_tokenizer_name("gemini-3.1-flash-lite"), "gemma4"
65+
)
66+
self.assertEqual(
67+
loader.get_tokenizer_name("gemini-3.1-pro-preview"), "gemma4"
68+
)
69+
self.assertEqual(
70+
loader.get_tokenizer_name("gemini-4-flash-preview"), "gemma4"
71+
)
72+
6173
def test_get_tokenizer_name_unsupported(self):
6274
with self.assertRaisesRegex(
6375
ValueError, "Model unsupported-model is not supported"
@@ -233,3 +245,24 @@ def test_get_sentencepiece_caching(
233245

234246
# Should only be loaded once due to lru_cache
235247
mock_get.assert_called_once()
248+
249+
250+
class TestGetHuggingFaceTokenizer(unittest.TestCase):
251+
252+
@patch("genai._local_tokenizer_loader.AutoProcessor")
253+
def test_get_huggingface_tokenizer_success(self, mock_auto_processor):
254+
mock_processor = MagicMock()
255+
mock_tokenizer = MagicMock()
256+
mock_processor.tokenizer = mock_tokenizer
257+
mock_auto_processor.from_pretrained.return_value = mock_processor
258+
259+
tokenizer = loader.get_huggingface_tokenizer("gemma4")
260+
261+
self.assertEqual(tokenizer, mock_tokenizer)
262+
mock_auto_processor.from_pretrained.assert_called_once_with(
263+
"google/gemma-4-E4B-it"
264+
)
265+
266+
def test_get_huggingface_tokenizer_unsupported(self):
267+
with self.assertRaises(KeyError):
268+
loader.get_huggingface_tokenizer("unsupported")

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ dependencies = [
3939

4040
[project.optional-dependencies]
4141
aiohttp = ["aiohttp>=3.10.11, <4.0.0"]
42-
local-tokenizer = ["sentencepiece>=0.2.0", "protobuf"]
42+
local-tokenizer = ["sentencepiece>=0.2.0", "protobuf", "transformers"]
4343
pyopenssl = ["pyopenssl"]
4444

4545
[project.urls]

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,4 @@ websockets==16.0
3232
mcp>=1.14.0; python_version > '3.9'
3333
sentencepiece>=0.2.0
3434
protobuf
35+
transformers>=5.10.1

0 commit comments

Comments
 (0)