1717import sys
1818from typing import AsyncIterator , Awaitable , Optional , Union , get_args
1919
20+
21+ from . import _extra_utils
2022from . import _transformers as t
2123from . import types
2224from .models import AsyncModels , Models
@@ -113,7 +115,7 @@ def __init__(
113115 history : list [ContentOrDict ],
114116 ):
115117 self ._model = model
116- self ._config = config
118+ self ._config = _extra_utils . get_usage_header ( config , usage = "chat" )
117119 content_models = []
118120 for content in history :
119121 if not isinstance (content , Content ):
@@ -249,10 +251,14 @@ def send_message(
249251 f" { types .PartUnionDict } , got { type (message )} "
250252 )
251253 input_content = t .t_content (message )
254+ method_config = config if config else self ._config
255+ method_config = _extra_utils .get_usage_header (
256+ method_config , usage = "chat"
257+ )
252258 response = self ._modules .generate_content (
253259 model = self ._model ,
254260 contents = self ._curated_history + [input_content ], # type: ignore[arg-type]
255- config = config if config else self . _config ,
261+ config = method_config ,
256262 )
257263 model_output = (
258264 [response .candidates [0 ].content ]
@@ -306,11 +312,15 @@ def send_message_stream(
306312 finish_reason = None
307313 is_valid = True
308314 chunk = None
315+ method_config = config if config else self ._config
316+ method_config = _extra_utils .get_usage_header (
317+ method_config , usage = "chat"
318+ )
309319 if isinstance (self ._modules , Models ):
310320 for chunk in self ._modules .generate_content_stream (
311321 model = self ._model ,
312322 contents = self ._curated_history + [input_content ], # type: ignore[arg-type]
313- config = config if config else self . _config ,
323+ config = method_config ,
314324 ):
315325 if not _validate_response (chunk ):
316326 is_valid = False
@@ -411,10 +421,14 @@ async def send_message(
411421 f" { types .PartUnionDict } , got { type (message )} "
412422 )
413423 input_content = t .t_content (message )
424+ method_config = config if config else self ._config
425+ method_config = _extra_utils .get_usage_header (
426+ method_config , usage = "chat"
427+ )
414428 response = await self ._modules .generate_content (
415429 model = self ._model ,
416430 contents = self ._curated_history + [input_content ], # type: ignore[arg-type]
417- config = config if config else self . _config ,
431+ config = method_config ,
418432 )
419433 model_output = (
420434 [response .candidates [0 ].content ]
@@ -465,6 +479,11 @@ async def send_message_stream(
465479 )
466480 input_content = t .t_content (message )
467481
482+ method_config = config if config else self ._config
483+ method_config = _extra_utils .get_usage_header (
484+ method_config , usage = "chat"
485+ )
486+
468487 async def async_generator (): # type: ignore[no-untyped-def]
469488 output_contents = []
470489 finish_reason = None
@@ -473,7 +492,7 @@ async def async_generator(): # type: ignore[no-untyped-def]
473492 async for chunk in await self ._modules .generate_content_stream ( # type: ignore[attr-defined]
474493 model = self ._model ,
475494 contents = self ._curated_history + [input_content ], # type: ignore[arg-type]
476- config = config if config else self . _config ,
495+ config = method_config ,
477496 ):
478497 if not _validate_response (chunk ):
479498 is_valid = False
0 commit comments