From 98d252357931c44c7278bd37426047f557941666 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Thu, 4 Jun 2026 15:14:05 +0200 Subject: [PATCH 01/47] Parametrize chunk and message join separators --- src/mistral_common/protocol/instruct/normalize.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py index c96a1940..4ef3f740 100644 --- a/src/mistral_common/protocol/instruct/normalize.py +++ b/src/mistral_common/protocol/instruct/normalize.py @@ -32,13 +32,14 @@ from mistral_common.tokens.tokenizers.base import InstructRequestType, TokenizerVersion from mistral_common.tokens.tokenizers.model_settings_builder import ModelSettingsBuilder -_DEFAULT_JOIN_STR = "\n\n" +_MSG_JOIN_STR = "\n\n" +_CHUNK_JOIN_STR = "" def _aggregate_content_chunks_impl( contents: list[list[ContentChunk] | str | None], - msg_join_str: str, - chunk_join_str: str, + msg_join_str: str = _MSG_JOIN_STR, + chunk_join_str: str = _CHUNK_JOIN_STR, ) -> list[ContentChunk] | str: r"""Coalesce TextChunks within the same message and across different messages. @@ -128,8 +129,8 @@ class InstructRequestNormalizer( _system_prompt_in_begin: bool = False _allow_tool_call_and_content: bool = False - _chunk_join_str: str = _DEFAULT_JOIN_STR - _msg_join_str: str = _DEFAULT_JOIN_STR + _chunk_join_str: str = _CHUNK_JOIN_STR + _msg_join_str: str = _MSG_JOIN_STR def __init__( self, From 7aa8189da5469f2b5f752f5a49f9f5c4dda48842 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Thu, 4 Jun 2026 15:17:47 +0200 Subject: [PATCH 02/47] Use TypeGuard for type-safe content narrowing --- src/mistral_common/protocol/instruct/normalize.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py index 4ef3f740..ea48db1b 100644 --- a/src/mistral_common/protocol/instruct/normalize.py +++ b/src/mistral_common/protocol/instruct/normalize.py @@ -297,8 +297,10 @@ def _aggregate_assistant_messages(self, messages: list[UATS]) -> AssistantMessag ) weight = message.weight - if isinstance(content, str) or _is_assistant_content(content): + if isinstance(content, str): narrowed_content: str | list[TextChunk | ThinkChunk] = content + elif _is_assistant_content(content): + narrowed_content = content else: raise InvalidRequestException( f"Unexpected content chunk types in assistant message: {[type(c).__name__ for c in content]}" @@ -317,7 +319,9 @@ def _aggregate_assistant_messages(self, messages: list[UATS]) -> AssistantMessag def _aggregate_user_messages(self, messages: list[UATS]) -> UserMessageType: """Coalesce neighboring blocks of ContentChunks in user messages.""" content = self._aggregate_content_chunks(messages) - if isinstance(content, str) or _is_user_content(content): + if isinstance(content, str): + return self._user_message_class(content=content) + elif _is_user_content(content): return self._user_message_class(content=content) else: raise InvalidRequestException( From 0cb2e81cdb63a81f84d90a4c524820f64d4a05bd Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Thu, 4 Jun 2026 17:04:51 +0200 Subject: [PATCH 03/47] Fix chat template intra-message chunk joining to match normalizer --- .../chat_templates/template_generator.py | 16 ++++++++++++++-- tests/data/chat_templates/v1.jinja | 16 ++++++++++++++-- tests/data/chat_templates/v11.jinja | 10 +++++++++- tests/data/chat_templates/v11_audio.jinja | 10 +++++++++- tests/data/chat_templates/v11_image.jinja | 10 +++++++++- tests/data/chat_templates/v11_image_think.jinja | 10 +++++++++- tests/data/chat_templates/v11_think.jinja | 10 +++++++++- tests/data/chat_templates/v13.jinja | 10 +++++++++- tests/data/chat_templates/v13_audio.jinja | 10 +++++++++- tests/data/chat_templates/v13_image.jinja | 10 +++++++++- tests/data/chat_templates/v13_image_think.jinja | 10 +++++++++- tests/data/chat_templates/v13_think.jinja | 10 +++++++++- tests/data/chat_templates/v1_spm.jinja | 16 ++++++++++++++-- tests/data/chat_templates/v2.jinja | 16 ++++++++++++++-- tests/data/chat_templates/v2_spm.jinja | 16 ++++++++++++++-- tests/data/chat_templates/v3.jinja | 16 ++++++++++++++-- tests/data/chat_templates/v3_image.jinja | 16 ++++++++++++++-- tests/data/chat_templates/v3_image_spm.jinja | 16 ++++++++++++++-- tests/data/chat_templates/v3_spm.jinja | 16 ++++++++++++++-- tests/data/chat_templates/v7.jinja | 10 +++++++++- tests/data/chat_templates/v7_audio.jinja | 10 +++++++++- tests/data/chat_templates/v7_image.jinja | 10 +++++++++- tests/data/chat_templates/v7_image_spm.jinja | 10 +++++++++- tests/data/chat_templates/v7_spm.jinja | 10 +++++++++- 24 files changed, 261 insertions(+), 33 deletions(-) diff --git a/src/mistral_common/integrations/chat_templates/template_generator.py b/src/mistral_common/integrations/chat_templates/template_generator.py index 807ed528..9916cb83 100644 --- a/src/mistral_common/integrations/chat_templates/template_generator.py +++ b/src/mistral_common/integrations/chat_templates/template_generator.py @@ -287,13 +287,17 @@ def _generate_system_prompt_handling_pre_v7(config: TemplateConfig) -> list[str] " {%- if message['content'] is string %}", " {%- set ns_sys.system_parts = ns_sys.system_parts + [message['content']] %}", " {%- else %}", + " {%- set ns_sys_msg = namespace(parts=[]) %}", " {%- for block in message['content'] %}", " {%- if block['type'] == 'text' %}", - " {%- set ns_sys.system_parts = ns_sys.system_parts + [block['text']] %}", + " {%- set ns_sys_msg.parts = ns_sys_msg.parts + [block['text']] %}", " {%- else %}", " {{- raise_exception('Only text chunks are supported in system message content.') }}", " {%- endif %}", " {%- endfor %}", + " {%- if ns_sys_msg.parts | length > 0 %}", + " {%- set ns_sys.system_parts = ns_sys.system_parts + [ns_sys_msg.parts | join('')] %}", + " {%- endif %}", " {%- endif %}", " {%- else %}", " {%- set ns_sys.filtered = ns_sys.filtered + [message] %}", @@ -796,10 +800,15 @@ def _generate_flush_logic(config: TemplateConfig) -> list[str]: else: list_content_lines = [ " {%- elif msg['content'] is not none %}", + " {%- set ns_msg = namespace(msg_text_parts=[]) %}", " {%- for block in msg['content'] %}", " {%- if block['type'] == 'text' %}", - " {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}", + " {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}", " {%- else %}", + " {%- if ns_msg.msg_text_parts | length > 0 %}", + " {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}", # noqa: E501 + " {%- set ns_msg.msg_text_parts = [] %}", + " {%- endif %}", " {%- if ns_c.text_parts | length > 0 %}", " {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\\n\\n')}] %}", # noqa: E501 " {%- set ns_c.text_parts = [] %}", @@ -808,6 +817,9 @@ def _generate_flush_logic(config: TemplateConfig) -> list[str]: " {%- set ns_c.has_non_text = true %}", " {%- endif %}", " {%- endfor %}", + " {%- if ns_msg.msg_text_parts | length > 0 %}", + " {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}", + " {%- endif %}", " {%- endif %}", ] diff --git a/tests/data/chat_templates/v1.jinja b/tests/data/chat_templates/v1.jinja index 2b6162c7..340c8963 100644 --- a/tests/data/chat_templates/v1.jinja +++ b/tests/data/chat_templates/v1.jinja @@ -13,13 +13,17 @@ {%- if message['content'] is string %} {%- set ns_sys.system_parts = ns_sys.system_parts + [message['content']] %} {%- else %} + {%- set ns_sys_msg = namespace(parts=[]) %} {%- for block in message['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_sys.system_parts = ns_sys.system_parts + [block['text']] %} + {%- set ns_sys_msg.parts = ns_sys_msg.parts + [block['text']] %} {%- else %} {{- raise_exception('Only text chunks are supported in system message content.') }} {%- endif %} {%- endfor %} + {%- if ns_sys_msg.parts | length > 0 %} + {%- set ns_sys.system_parts = ns_sys.system_parts + [ns_sys_msg.parts | join('')] %} + {%- endif %} {%- endif %} {%- else %} {%- set ns_sys.filtered = ns_sys.filtered + [message] %} @@ -45,10 +49,15 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} + {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} + {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} {%- else %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- set ns_msg.msg_text_parts = [] %} + {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -57,6 +66,9 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v11.jinja b/tests/data/chat_templates/v11.jinja index 79943d0d..2a1d7562 100644 --- a/tests/data/chat_templates/v11.jinja +++ b/tests/data/chat_templates/v11.jinja @@ -50,10 +50,15 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} + {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} + {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} {%- else %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- set ns_msg.msg_text_parts = [] %} + {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -62,6 +67,9 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v11_audio.jinja b/tests/data/chat_templates/v11_audio.jinja index e96cdbf9..c95b1dc7 100644 --- a/tests/data/chat_templates/v11_audio.jinja +++ b/tests/data/chat_templates/v11_audio.jinja @@ -60,10 +60,15 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} + {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} + {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} {%- else %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- set ns_msg.msg_text_parts = [] %} + {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -72,6 +77,9 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v11_image.jinja b/tests/data/chat_templates/v11_image.jinja index 164a6763..1944dca4 100644 --- a/tests/data/chat_templates/v11_image.jinja +++ b/tests/data/chat_templates/v11_image.jinja @@ -52,10 +52,15 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} + {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} + {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} {%- else %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- set ns_msg.msg_text_parts = [] %} + {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -64,6 +69,9 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v11_image_think.jinja b/tests/data/chat_templates/v11_image_think.jinja index 950746ef..39683fce 100644 --- a/tests/data/chat_templates/v11_image_think.jinja +++ b/tests/data/chat_templates/v11_image_think.jinja @@ -79,10 +79,15 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} + {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} + {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} {%- else %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- set ns_msg.msg_text_parts = [] %} + {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -91,6 +96,9 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v11_think.jinja b/tests/data/chat_templates/v11_think.jinja index ca99772c..a4820c07 100644 --- a/tests/data/chat_templates/v11_think.jinja +++ b/tests/data/chat_templates/v11_think.jinja @@ -77,10 +77,15 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} + {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} + {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} {%- else %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- set ns_msg.msg_text_parts = [] %} + {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -89,6 +94,9 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v13.jinja b/tests/data/chat_templates/v13.jinja index 38c253d9..14e0ce58 100644 --- a/tests/data/chat_templates/v13.jinja +++ b/tests/data/chat_templates/v13.jinja @@ -50,10 +50,15 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} + {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} + {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} {%- else %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- set ns_msg.msg_text_parts = [] %} + {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -62,6 +67,9 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v13_audio.jinja b/tests/data/chat_templates/v13_audio.jinja index c0a2fdec..5fbb6ccb 100644 --- a/tests/data/chat_templates/v13_audio.jinja +++ b/tests/data/chat_templates/v13_audio.jinja @@ -60,10 +60,15 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} + {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} + {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} {%- else %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- set ns_msg.msg_text_parts = [] %} + {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -72,6 +77,9 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v13_image.jinja b/tests/data/chat_templates/v13_image.jinja index 1b79382a..38d608bf 100644 --- a/tests/data/chat_templates/v13_image.jinja +++ b/tests/data/chat_templates/v13_image.jinja @@ -52,10 +52,15 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} + {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} + {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} {%- else %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- set ns_msg.msg_text_parts = [] %} + {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -64,6 +69,9 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v13_image_think.jinja b/tests/data/chat_templates/v13_image_think.jinja index cb39ae85..cc5079f3 100644 --- a/tests/data/chat_templates/v13_image_think.jinja +++ b/tests/data/chat_templates/v13_image_think.jinja @@ -79,10 +79,15 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} + {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} + {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} {%- else %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- set ns_msg.msg_text_parts = [] %} + {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -91,6 +96,9 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v13_think.jinja b/tests/data/chat_templates/v13_think.jinja index 5313eabe..8606f7c3 100644 --- a/tests/data/chat_templates/v13_think.jinja +++ b/tests/data/chat_templates/v13_think.jinja @@ -77,10 +77,15 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} + {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} + {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} {%- else %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- set ns_msg.msg_text_parts = [] %} + {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -89,6 +94,9 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v1_spm.jinja b/tests/data/chat_templates/v1_spm.jinja index 5fb2e7d0..7fae6926 100644 --- a/tests/data/chat_templates/v1_spm.jinja +++ b/tests/data/chat_templates/v1_spm.jinja @@ -13,13 +13,17 @@ {%- if message['content'] is string %} {%- set ns_sys.system_parts = ns_sys.system_parts + [message['content']] %} {%- else %} + {%- set ns_sys_msg = namespace(parts=[]) %} {%- for block in message['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_sys.system_parts = ns_sys.system_parts + [block['text']] %} + {%- set ns_sys_msg.parts = ns_sys_msg.parts + [block['text']] %} {%- else %} {{- raise_exception('Only text chunks are supported in system message content.') }} {%- endif %} {%- endfor %} + {%- if ns_sys_msg.parts | length > 0 %} + {%- set ns_sys.system_parts = ns_sys.system_parts + [ns_sys_msg.parts | join('')] %} + {%- endif %} {%- endif %} {%- else %} {%- set ns_sys.filtered = ns_sys.filtered + [message] %} @@ -45,10 +49,15 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} + {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} + {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} {%- else %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- set ns_msg.msg_text_parts = [] %} + {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -57,6 +66,9 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v2.jinja b/tests/data/chat_templates/v2.jinja index 0ba2a722..cd009d81 100644 --- a/tests/data/chat_templates/v2.jinja +++ b/tests/data/chat_templates/v2.jinja @@ -13,13 +13,17 @@ {%- if message['content'] is string %} {%- set ns_sys.system_parts = ns_sys.system_parts + [message['content']] %} {%- else %} + {%- set ns_sys_msg = namespace(parts=[]) %} {%- for block in message['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_sys.system_parts = ns_sys.system_parts + [block['text']] %} + {%- set ns_sys_msg.parts = ns_sys_msg.parts + [block['text']] %} {%- else %} {{- raise_exception('Only text chunks are supported in system message content.') }} {%- endif %} {%- endfor %} + {%- if ns_sys_msg.parts | length > 0 %} + {%- set ns_sys.system_parts = ns_sys.system_parts + [ns_sys_msg.parts | join('')] %} + {%- endif %} {%- endif %} {%- else %} {%- set ns_sys.filtered = ns_sys.filtered + [message] %} @@ -71,10 +75,15 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} + {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} + {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} {%- else %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- set ns_msg.msg_text_parts = [] %} + {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -83,6 +92,9 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v2_spm.jinja b/tests/data/chat_templates/v2_spm.jinja index 647ad3ac..8b9d18ca 100644 --- a/tests/data/chat_templates/v2_spm.jinja +++ b/tests/data/chat_templates/v2_spm.jinja @@ -13,13 +13,17 @@ {%- if message['content'] is string %} {%- set ns_sys.system_parts = ns_sys.system_parts + [message['content']] %} {%- else %} + {%- set ns_sys_msg = namespace(parts=[]) %} {%- for block in message['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_sys.system_parts = ns_sys.system_parts + [block['text']] %} + {%- set ns_sys_msg.parts = ns_sys_msg.parts + [block['text']] %} {%- else %} {{- raise_exception('Only text chunks are supported in system message content.') }} {%- endif %} {%- endfor %} + {%- if ns_sys_msg.parts | length > 0 %} + {%- set ns_sys.system_parts = ns_sys.system_parts + [ns_sys_msg.parts | join('')] %} + {%- endif %} {%- endif %} {%- else %} {%- set ns_sys.filtered = ns_sys.filtered + [message] %} @@ -71,10 +75,15 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} + {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} + {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} {%- else %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- set ns_msg.msg_text_parts = [] %} + {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -83,6 +92,9 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v3.jinja b/tests/data/chat_templates/v3.jinja index 8a0bfa7d..6678f208 100644 --- a/tests/data/chat_templates/v3.jinja +++ b/tests/data/chat_templates/v3.jinja @@ -13,13 +13,17 @@ {%- if message['content'] is string %} {%- set ns_sys.system_parts = ns_sys.system_parts + [message['content']] %} {%- else %} + {%- set ns_sys_msg = namespace(parts=[]) %} {%- for block in message['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_sys.system_parts = ns_sys.system_parts + [block['text']] %} + {%- set ns_sys_msg.parts = ns_sys_msg.parts + [block['text']] %} {%- else %} {{- raise_exception('Only text chunks are supported in system message content.') }} {%- endif %} {%- endfor %} + {%- if ns_sys_msg.parts | length > 0 %} + {%- set ns_sys.system_parts = ns_sys.system_parts + [ns_sys_msg.parts | join('')] %} + {%- endif %} {%- endif %} {%- else %} {%- set ns_sys.filtered = ns_sys.filtered + [message] %} @@ -71,10 +75,15 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} + {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} + {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} {%- else %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- set ns_msg.msg_text_parts = [] %} + {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -83,6 +92,9 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v3_image.jinja b/tests/data/chat_templates/v3_image.jinja index d00b3529..f62b7b28 100644 --- a/tests/data/chat_templates/v3_image.jinja +++ b/tests/data/chat_templates/v3_image.jinja @@ -13,13 +13,17 @@ {%- if message['content'] is string %} {%- set ns_sys.system_parts = ns_sys.system_parts + [message['content']] %} {%- else %} + {%- set ns_sys_msg = namespace(parts=[]) %} {%- for block in message['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_sys.system_parts = ns_sys.system_parts + [block['text']] %} + {%- set ns_sys_msg.parts = ns_sys_msg.parts + [block['text']] %} {%- else %} {{- raise_exception('Only text chunks are supported in system message content.') }} {%- endif %} {%- endfor %} + {%- if ns_sys_msg.parts | length > 0 %} + {%- set ns_sys.system_parts = ns_sys.system_parts + [ns_sys_msg.parts | join('')] %} + {%- endif %} {%- endif %} {%- else %} {%- set ns_sys.filtered = ns_sys.filtered + [message] %} @@ -73,10 +77,15 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} + {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} + {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} {%- else %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- set ns_msg.msg_text_parts = [] %} + {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -85,6 +94,9 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v3_image_spm.jinja b/tests/data/chat_templates/v3_image_spm.jinja index 8cb31b30..eb1cf1ee 100644 --- a/tests/data/chat_templates/v3_image_spm.jinja +++ b/tests/data/chat_templates/v3_image_spm.jinja @@ -13,13 +13,17 @@ {%- if message['content'] is string %} {%- set ns_sys.system_parts = ns_sys.system_parts + [message['content']] %} {%- else %} + {%- set ns_sys_msg = namespace(parts=[]) %} {%- for block in message['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_sys.system_parts = ns_sys.system_parts + [block['text']] %} + {%- set ns_sys_msg.parts = ns_sys_msg.parts + [block['text']] %} {%- else %} {{- raise_exception('Only text chunks are supported in system message content.') }} {%- endif %} {%- endfor %} + {%- if ns_sys_msg.parts | length > 0 %} + {%- set ns_sys.system_parts = ns_sys.system_parts + [ns_sys_msg.parts | join('')] %} + {%- endif %} {%- endif %} {%- else %} {%- set ns_sys.filtered = ns_sys.filtered + [message] %} @@ -79,10 +83,15 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} + {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} + {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} {%- else %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- set ns_msg.msg_text_parts = [] %} + {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -91,6 +100,9 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v3_spm.jinja b/tests/data/chat_templates/v3_spm.jinja index 304644e4..f1aef3c2 100644 --- a/tests/data/chat_templates/v3_spm.jinja +++ b/tests/data/chat_templates/v3_spm.jinja @@ -13,13 +13,17 @@ {%- if message['content'] is string %} {%- set ns_sys.system_parts = ns_sys.system_parts + [message['content']] %} {%- else %} + {%- set ns_sys_msg = namespace(parts=[]) %} {%- for block in message['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_sys.system_parts = ns_sys.system_parts + [block['text']] %} + {%- set ns_sys_msg.parts = ns_sys_msg.parts + [block['text']] %} {%- else %} {{- raise_exception('Only text chunks are supported in system message content.') }} {%- endif %} {%- endfor %} + {%- if ns_sys_msg.parts | length > 0 %} + {%- set ns_sys.system_parts = ns_sys.system_parts + [ns_sys_msg.parts | join('')] %} + {%- endif %} {%- endif %} {%- else %} {%- set ns_sys.filtered = ns_sys.filtered + [message] %} @@ -71,10 +75,15 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} + {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} + {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} {%- else %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- set ns_msg.msg_text_parts = [] %} + {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -83,6 +92,9 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v7.jinja b/tests/data/chat_templates/v7.jinja index 519991ec..516ea1e0 100644 --- a/tests/data/chat_templates/v7.jinja +++ b/tests/data/chat_templates/v7.jinja @@ -50,10 +50,15 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} + {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} + {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} {%- else %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- set ns_msg.msg_text_parts = [] %} + {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -62,6 +67,9 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v7_audio.jinja b/tests/data/chat_templates/v7_audio.jinja index f4e98518..c3d6cbf8 100644 --- a/tests/data/chat_templates/v7_audio.jinja +++ b/tests/data/chat_templates/v7_audio.jinja @@ -60,10 +60,15 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} + {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} + {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} {%- else %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- set ns_msg.msg_text_parts = [] %} + {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -72,6 +77,9 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v7_image.jinja b/tests/data/chat_templates/v7_image.jinja index e3c372df..0c0a1223 100644 --- a/tests/data/chat_templates/v7_image.jinja +++ b/tests/data/chat_templates/v7_image.jinja @@ -52,10 +52,15 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} + {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} + {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} {%- else %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- set ns_msg.msg_text_parts = [] %} + {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -64,6 +69,9 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v7_image_spm.jinja b/tests/data/chat_templates/v7_image_spm.jinja index 69088620..dda5bef3 100644 --- a/tests/data/chat_templates/v7_image_spm.jinja +++ b/tests/data/chat_templates/v7_image_spm.jinja @@ -58,10 +58,15 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} + {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} + {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} {%- else %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- set ns_msg.msg_text_parts = [] %} + {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -70,6 +75,9 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v7_spm.jinja b/tests/data/chat_templates/v7_spm.jinja index 02875ee6..84d76105 100644 --- a/tests/data/chat_templates/v7_spm.jinja +++ b/tests/data/chat_templates/v7_spm.jinja @@ -50,10 +50,15 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} + {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} + {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} {%- else %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- set ns_msg.msg_text_parts = [] %} + {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -62,6 +67,9 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} + {%- if ns_msg.msg_text_parts | length > 0 %} + {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} + {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} From d9d2e634b211052d121f9abe39b06e0bf0ac8513 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Thu, 4 Jun 2026 17:48:17 +0200 Subject: [PATCH 04/47] Restrict intra-message chunk join change to V15 only --- .../chat_templates/template_generator.py | 17 +++-------------- .../protocol/instruct/normalize.py | 11 +++++------ tests/data/chat_templates/v1.jinja | 16 ++-------------- tests/data/chat_templates/v11.jinja | 10 +--------- tests/data/chat_templates/v11_audio.jinja | 10 +--------- tests/data/chat_templates/v11_image.jinja | 10 +--------- tests/data/chat_templates/v11_image_think.jinja | 10 +--------- tests/data/chat_templates/v11_think.jinja | 10 +--------- tests/data/chat_templates/v13.jinja | 10 +--------- tests/data/chat_templates/v13_audio.jinja | 10 +--------- tests/data/chat_templates/v13_image.jinja | 10 +--------- tests/data/chat_templates/v13_image_think.jinja | 10 +--------- tests/data/chat_templates/v13_think.jinja | 10 +--------- tests/data/chat_templates/v1_spm.jinja | 16 ++-------------- tests/data/chat_templates/v2.jinja | 16 ++-------------- tests/data/chat_templates/v2_spm.jinja | 16 ++-------------- tests/data/chat_templates/v3.jinja | 16 ++-------------- tests/data/chat_templates/v3_image.jinja | 16 ++-------------- tests/data/chat_templates/v3_image_spm.jinja | 16 ++-------------- tests/data/chat_templates/v3_spm.jinja | 16 ++-------------- tests/data/chat_templates/v7.jinja | 10 +--------- tests/data/chat_templates/v7_audio.jinja | 10 +--------- tests/data/chat_templates/v7_image.jinja | 10 +--------- tests/data/chat_templates/v7_image_spm.jinja | 10 +--------- tests/data/chat_templates/v7_spm.jinja | 10 +--------- 25 files changed, 39 insertions(+), 267 deletions(-) diff --git a/src/mistral_common/integrations/chat_templates/template_generator.py b/src/mistral_common/integrations/chat_templates/template_generator.py index 9916cb83..d580acb8 100644 --- a/src/mistral_common/integrations/chat_templates/template_generator.py +++ b/src/mistral_common/integrations/chat_templates/template_generator.py @@ -287,17 +287,13 @@ def _generate_system_prompt_handling_pre_v7(config: TemplateConfig) -> list[str] " {%- if message['content'] is string %}", " {%- set ns_sys.system_parts = ns_sys.system_parts + [message['content']] %}", " {%- else %}", - " {%- set ns_sys_msg = namespace(parts=[]) %}", " {%- for block in message['content'] %}", " {%- if block['type'] == 'text' %}", - " {%- set ns_sys_msg.parts = ns_sys_msg.parts + [block['text']] %}", + " {%- set ns_sys.system_parts = ns_sys.system_parts + [block['text']] %}", " {%- else %}", " {{- raise_exception('Only text chunks are supported in system message content.') }}", " {%- endif %}", " {%- endfor %}", - " {%- if ns_sys_msg.parts | length > 0 %}", - " {%- set ns_sys.system_parts = ns_sys.system_parts + [ns_sys_msg.parts | join('')] %}", - " {%- endif %}", " {%- endif %}", " {%- else %}", " {%- set ns_sys.filtered = ns_sys.filtered + [message] %}", @@ -800,15 +796,10 @@ def _generate_flush_logic(config: TemplateConfig) -> list[str]: else: list_content_lines = [ " {%- elif msg['content'] is not none %}", - " {%- set ns_msg = namespace(msg_text_parts=[]) %}", " {%- for block in msg['content'] %}", " {%- if block['type'] == 'text' %}", - " {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}", + " {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}", " {%- else %}", - " {%- if ns_msg.msg_text_parts | length > 0 %}", - " {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}", # noqa: E501 - " {%- set ns_msg.msg_text_parts = [] %}", - " {%- endif %}", " {%- if ns_c.text_parts | length > 0 %}", " {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\\n\\n')}] %}", # noqa: E501 " {%- set ns_c.text_parts = [] %}", @@ -817,9 +808,7 @@ def _generate_flush_logic(config: TemplateConfig) -> list[str]: " {%- set ns_c.has_non_text = true %}", " {%- endif %}", " {%- endfor %}", - " {%- if ns_msg.msg_text_parts | length > 0 %}", - " {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}", - " {%- endif %}", + " {%- endif %}", ] diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py index ea48db1b..56325e37 100644 --- a/src/mistral_common/protocol/instruct/normalize.py +++ b/src/mistral_common/protocol/instruct/normalize.py @@ -32,14 +32,13 @@ from mistral_common.tokens.tokenizers.base import InstructRequestType, TokenizerVersion from mistral_common.tokens.tokenizers.model_settings_builder import ModelSettingsBuilder -_MSG_JOIN_STR = "\n\n" -_CHUNK_JOIN_STR = "" +_DEFAULT_JOIN_STR = "\n\n" def _aggregate_content_chunks_impl( contents: list[list[ContentChunk] | str | None], - msg_join_str: str = _MSG_JOIN_STR, - chunk_join_str: str = _CHUNK_JOIN_STR, + msg_join_str: str, + chunk_join_str: str, ) -> list[ContentChunk] | str: r"""Coalesce TextChunks within the same message and across different messages. @@ -129,8 +128,8 @@ class InstructRequestNormalizer( _system_prompt_in_begin: bool = False _allow_tool_call_and_content: bool = False - _chunk_join_str: str = _CHUNK_JOIN_STR - _msg_join_str: str = _MSG_JOIN_STR + _chunk_join_str: str = _DEFAULT_JOIN_STR + _msg_join_str: str = _DEFAULT_JOIN_STR def __init__( self, diff --git a/tests/data/chat_templates/v1.jinja b/tests/data/chat_templates/v1.jinja index 340c8963..2b6162c7 100644 --- a/tests/data/chat_templates/v1.jinja +++ b/tests/data/chat_templates/v1.jinja @@ -13,17 +13,13 @@ {%- if message['content'] is string %} {%- set ns_sys.system_parts = ns_sys.system_parts + [message['content']] %} {%- else %} - {%- set ns_sys_msg = namespace(parts=[]) %} {%- for block in message['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_sys_msg.parts = ns_sys_msg.parts + [block['text']] %} + {%- set ns_sys.system_parts = ns_sys.system_parts + [block['text']] %} {%- else %} {{- raise_exception('Only text chunks are supported in system message content.') }} {%- endif %} {%- endfor %} - {%- if ns_sys_msg.parts | length > 0 %} - {%- set ns_sys.system_parts = ns_sys.system_parts + [ns_sys_msg.parts | join('')] %} - {%- endif %} {%- endif %} {%- else %} {%- set ns_sys.filtered = ns_sys.filtered + [message] %} @@ -49,15 +45,10 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} - {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} + {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} {%- else %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- set ns_msg.msg_text_parts = [] %} - {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -66,9 +57,6 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v11.jinja b/tests/data/chat_templates/v11.jinja index 2a1d7562..79943d0d 100644 --- a/tests/data/chat_templates/v11.jinja +++ b/tests/data/chat_templates/v11.jinja @@ -50,15 +50,10 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} - {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} + {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} {%- else %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- set ns_msg.msg_text_parts = [] %} - {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -67,9 +62,6 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v11_audio.jinja b/tests/data/chat_templates/v11_audio.jinja index c95b1dc7..e96cdbf9 100644 --- a/tests/data/chat_templates/v11_audio.jinja +++ b/tests/data/chat_templates/v11_audio.jinja @@ -60,15 +60,10 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} - {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} + {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} {%- else %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- set ns_msg.msg_text_parts = [] %} - {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -77,9 +72,6 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v11_image.jinja b/tests/data/chat_templates/v11_image.jinja index 1944dca4..164a6763 100644 --- a/tests/data/chat_templates/v11_image.jinja +++ b/tests/data/chat_templates/v11_image.jinja @@ -52,15 +52,10 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} - {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} + {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} {%- else %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- set ns_msg.msg_text_parts = [] %} - {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -69,9 +64,6 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v11_image_think.jinja b/tests/data/chat_templates/v11_image_think.jinja index 39683fce..950746ef 100644 --- a/tests/data/chat_templates/v11_image_think.jinja +++ b/tests/data/chat_templates/v11_image_think.jinja @@ -79,15 +79,10 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} - {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} + {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} {%- else %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- set ns_msg.msg_text_parts = [] %} - {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -96,9 +91,6 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v11_think.jinja b/tests/data/chat_templates/v11_think.jinja index a4820c07..ca99772c 100644 --- a/tests/data/chat_templates/v11_think.jinja +++ b/tests/data/chat_templates/v11_think.jinja @@ -77,15 +77,10 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} - {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} + {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} {%- else %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- set ns_msg.msg_text_parts = [] %} - {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -94,9 +89,6 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v13.jinja b/tests/data/chat_templates/v13.jinja index 14e0ce58..38c253d9 100644 --- a/tests/data/chat_templates/v13.jinja +++ b/tests/data/chat_templates/v13.jinja @@ -50,15 +50,10 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} - {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} + {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} {%- else %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- set ns_msg.msg_text_parts = [] %} - {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -67,9 +62,6 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v13_audio.jinja b/tests/data/chat_templates/v13_audio.jinja index 5fbb6ccb..c0a2fdec 100644 --- a/tests/data/chat_templates/v13_audio.jinja +++ b/tests/data/chat_templates/v13_audio.jinja @@ -60,15 +60,10 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} - {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} + {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} {%- else %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- set ns_msg.msg_text_parts = [] %} - {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -77,9 +72,6 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v13_image.jinja b/tests/data/chat_templates/v13_image.jinja index 38d608bf..1b79382a 100644 --- a/tests/data/chat_templates/v13_image.jinja +++ b/tests/data/chat_templates/v13_image.jinja @@ -52,15 +52,10 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} - {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} + {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} {%- else %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- set ns_msg.msg_text_parts = [] %} - {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -69,9 +64,6 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v13_image_think.jinja b/tests/data/chat_templates/v13_image_think.jinja index cc5079f3..cb39ae85 100644 --- a/tests/data/chat_templates/v13_image_think.jinja +++ b/tests/data/chat_templates/v13_image_think.jinja @@ -79,15 +79,10 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} - {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} + {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} {%- else %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- set ns_msg.msg_text_parts = [] %} - {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -96,9 +91,6 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v13_think.jinja b/tests/data/chat_templates/v13_think.jinja index 8606f7c3..5313eabe 100644 --- a/tests/data/chat_templates/v13_think.jinja +++ b/tests/data/chat_templates/v13_think.jinja @@ -77,15 +77,10 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} - {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} + {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} {%- else %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- set ns_msg.msg_text_parts = [] %} - {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -94,9 +89,6 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v1_spm.jinja b/tests/data/chat_templates/v1_spm.jinja index 7fae6926..5fb2e7d0 100644 --- a/tests/data/chat_templates/v1_spm.jinja +++ b/tests/data/chat_templates/v1_spm.jinja @@ -13,17 +13,13 @@ {%- if message['content'] is string %} {%- set ns_sys.system_parts = ns_sys.system_parts + [message['content']] %} {%- else %} - {%- set ns_sys_msg = namespace(parts=[]) %} {%- for block in message['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_sys_msg.parts = ns_sys_msg.parts + [block['text']] %} + {%- set ns_sys.system_parts = ns_sys.system_parts + [block['text']] %} {%- else %} {{- raise_exception('Only text chunks are supported in system message content.') }} {%- endif %} {%- endfor %} - {%- if ns_sys_msg.parts | length > 0 %} - {%- set ns_sys.system_parts = ns_sys.system_parts + [ns_sys_msg.parts | join('')] %} - {%- endif %} {%- endif %} {%- else %} {%- set ns_sys.filtered = ns_sys.filtered + [message] %} @@ -49,15 +45,10 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} - {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} + {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} {%- else %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- set ns_msg.msg_text_parts = [] %} - {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -66,9 +57,6 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v2.jinja b/tests/data/chat_templates/v2.jinja index cd009d81..0ba2a722 100644 --- a/tests/data/chat_templates/v2.jinja +++ b/tests/data/chat_templates/v2.jinja @@ -13,17 +13,13 @@ {%- if message['content'] is string %} {%- set ns_sys.system_parts = ns_sys.system_parts + [message['content']] %} {%- else %} - {%- set ns_sys_msg = namespace(parts=[]) %} {%- for block in message['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_sys_msg.parts = ns_sys_msg.parts + [block['text']] %} + {%- set ns_sys.system_parts = ns_sys.system_parts + [block['text']] %} {%- else %} {{- raise_exception('Only text chunks are supported in system message content.') }} {%- endif %} {%- endfor %} - {%- if ns_sys_msg.parts | length > 0 %} - {%- set ns_sys.system_parts = ns_sys.system_parts + [ns_sys_msg.parts | join('')] %} - {%- endif %} {%- endif %} {%- else %} {%- set ns_sys.filtered = ns_sys.filtered + [message] %} @@ -75,15 +71,10 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} - {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} + {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} {%- else %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- set ns_msg.msg_text_parts = [] %} - {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -92,9 +83,6 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v2_spm.jinja b/tests/data/chat_templates/v2_spm.jinja index 8b9d18ca..647ad3ac 100644 --- a/tests/data/chat_templates/v2_spm.jinja +++ b/tests/data/chat_templates/v2_spm.jinja @@ -13,17 +13,13 @@ {%- if message['content'] is string %} {%- set ns_sys.system_parts = ns_sys.system_parts + [message['content']] %} {%- else %} - {%- set ns_sys_msg = namespace(parts=[]) %} {%- for block in message['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_sys_msg.parts = ns_sys_msg.parts + [block['text']] %} + {%- set ns_sys.system_parts = ns_sys.system_parts + [block['text']] %} {%- else %} {{- raise_exception('Only text chunks are supported in system message content.') }} {%- endif %} {%- endfor %} - {%- if ns_sys_msg.parts | length > 0 %} - {%- set ns_sys.system_parts = ns_sys.system_parts + [ns_sys_msg.parts | join('')] %} - {%- endif %} {%- endif %} {%- else %} {%- set ns_sys.filtered = ns_sys.filtered + [message] %} @@ -75,15 +71,10 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} - {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} + {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} {%- else %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- set ns_msg.msg_text_parts = [] %} - {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -92,9 +83,6 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v3.jinja b/tests/data/chat_templates/v3.jinja index 6678f208..8a0bfa7d 100644 --- a/tests/data/chat_templates/v3.jinja +++ b/tests/data/chat_templates/v3.jinja @@ -13,17 +13,13 @@ {%- if message['content'] is string %} {%- set ns_sys.system_parts = ns_sys.system_parts + [message['content']] %} {%- else %} - {%- set ns_sys_msg = namespace(parts=[]) %} {%- for block in message['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_sys_msg.parts = ns_sys_msg.parts + [block['text']] %} + {%- set ns_sys.system_parts = ns_sys.system_parts + [block['text']] %} {%- else %} {{- raise_exception('Only text chunks are supported in system message content.') }} {%- endif %} {%- endfor %} - {%- if ns_sys_msg.parts | length > 0 %} - {%- set ns_sys.system_parts = ns_sys.system_parts + [ns_sys_msg.parts | join('')] %} - {%- endif %} {%- endif %} {%- else %} {%- set ns_sys.filtered = ns_sys.filtered + [message] %} @@ -75,15 +71,10 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} - {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} + {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} {%- else %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- set ns_msg.msg_text_parts = [] %} - {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -92,9 +83,6 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v3_image.jinja b/tests/data/chat_templates/v3_image.jinja index f62b7b28..d00b3529 100644 --- a/tests/data/chat_templates/v3_image.jinja +++ b/tests/data/chat_templates/v3_image.jinja @@ -13,17 +13,13 @@ {%- if message['content'] is string %} {%- set ns_sys.system_parts = ns_sys.system_parts + [message['content']] %} {%- else %} - {%- set ns_sys_msg = namespace(parts=[]) %} {%- for block in message['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_sys_msg.parts = ns_sys_msg.parts + [block['text']] %} + {%- set ns_sys.system_parts = ns_sys.system_parts + [block['text']] %} {%- else %} {{- raise_exception('Only text chunks are supported in system message content.') }} {%- endif %} {%- endfor %} - {%- if ns_sys_msg.parts | length > 0 %} - {%- set ns_sys.system_parts = ns_sys.system_parts + [ns_sys_msg.parts | join('')] %} - {%- endif %} {%- endif %} {%- else %} {%- set ns_sys.filtered = ns_sys.filtered + [message] %} @@ -77,15 +73,10 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} - {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} + {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} {%- else %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- set ns_msg.msg_text_parts = [] %} - {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -94,9 +85,6 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v3_image_spm.jinja b/tests/data/chat_templates/v3_image_spm.jinja index eb1cf1ee..8cb31b30 100644 --- a/tests/data/chat_templates/v3_image_spm.jinja +++ b/tests/data/chat_templates/v3_image_spm.jinja @@ -13,17 +13,13 @@ {%- if message['content'] is string %} {%- set ns_sys.system_parts = ns_sys.system_parts + [message['content']] %} {%- else %} - {%- set ns_sys_msg = namespace(parts=[]) %} {%- for block in message['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_sys_msg.parts = ns_sys_msg.parts + [block['text']] %} + {%- set ns_sys.system_parts = ns_sys.system_parts + [block['text']] %} {%- else %} {{- raise_exception('Only text chunks are supported in system message content.') }} {%- endif %} {%- endfor %} - {%- if ns_sys_msg.parts | length > 0 %} - {%- set ns_sys.system_parts = ns_sys.system_parts + [ns_sys_msg.parts | join('')] %} - {%- endif %} {%- endif %} {%- else %} {%- set ns_sys.filtered = ns_sys.filtered + [message] %} @@ -83,15 +79,10 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} - {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} + {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} {%- else %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- set ns_msg.msg_text_parts = [] %} - {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -100,9 +91,6 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v3_spm.jinja b/tests/data/chat_templates/v3_spm.jinja index f1aef3c2..304644e4 100644 --- a/tests/data/chat_templates/v3_spm.jinja +++ b/tests/data/chat_templates/v3_spm.jinja @@ -13,17 +13,13 @@ {%- if message['content'] is string %} {%- set ns_sys.system_parts = ns_sys.system_parts + [message['content']] %} {%- else %} - {%- set ns_sys_msg = namespace(parts=[]) %} {%- for block in message['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_sys_msg.parts = ns_sys_msg.parts + [block['text']] %} + {%- set ns_sys.system_parts = ns_sys.system_parts + [block['text']] %} {%- else %} {{- raise_exception('Only text chunks are supported in system message content.') }} {%- endif %} {%- endfor %} - {%- if ns_sys_msg.parts | length > 0 %} - {%- set ns_sys.system_parts = ns_sys.system_parts + [ns_sys_msg.parts | join('')] %} - {%- endif %} {%- endif %} {%- else %} {%- set ns_sys.filtered = ns_sys.filtered + [message] %} @@ -75,15 +71,10 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} - {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} + {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} {%- else %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- set ns_msg.msg_text_parts = [] %} - {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -92,9 +83,6 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v7.jinja b/tests/data/chat_templates/v7.jinja index 516ea1e0..519991ec 100644 --- a/tests/data/chat_templates/v7.jinja +++ b/tests/data/chat_templates/v7.jinja @@ -50,15 +50,10 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} - {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} + {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} {%- else %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- set ns_msg.msg_text_parts = [] %} - {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -67,9 +62,6 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v7_audio.jinja b/tests/data/chat_templates/v7_audio.jinja index c3d6cbf8..f4e98518 100644 --- a/tests/data/chat_templates/v7_audio.jinja +++ b/tests/data/chat_templates/v7_audio.jinja @@ -60,15 +60,10 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} - {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} + {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} {%- else %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- set ns_msg.msg_text_parts = [] %} - {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -77,9 +72,6 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v7_image.jinja b/tests/data/chat_templates/v7_image.jinja index 0c0a1223..e3c372df 100644 --- a/tests/data/chat_templates/v7_image.jinja +++ b/tests/data/chat_templates/v7_image.jinja @@ -52,15 +52,10 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} - {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} + {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} {%- else %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- set ns_msg.msg_text_parts = [] %} - {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -69,9 +64,6 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v7_image_spm.jinja b/tests/data/chat_templates/v7_image_spm.jinja index dda5bef3..69088620 100644 --- a/tests/data/chat_templates/v7_image_spm.jinja +++ b/tests/data/chat_templates/v7_image_spm.jinja @@ -58,15 +58,10 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} - {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} + {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} {%- else %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- set ns_msg.msg_text_parts = [] %} - {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -75,9 +70,6 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} diff --git a/tests/data/chat_templates/v7_spm.jinja b/tests/data/chat_templates/v7_spm.jinja index 84d76105..02875ee6 100644 --- a/tests/data/chat_templates/v7_spm.jinja +++ b/tests/data/chat_templates/v7_spm.jinja @@ -50,15 +50,10 @@ {%- if msg['content'] is string %} {%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %} {%- elif msg['content'] is not none %} - {%- set ns_msg = namespace(msg_text_parts=[]) %} {%- for block in msg['content'] %} {%- if block['type'] == 'text' %} - {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %} + {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %} {%- else %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- set ns_msg.msg_text_parts = [] %} - {%- endif %} {%- if ns_c.text_parts | length > 0 %} {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %} {%- set ns_c.text_parts = [] %} @@ -67,9 +62,6 @@ {%- set ns_c.has_non_text = true %} {%- endif %} {%- endfor %} - {%- if ns_msg.msg_text_parts | length > 0 %} - {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %} - {%- endif %} {%- endif %} {%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %} {%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %} From 1842eb44f079e705cc2f82a08d55154964b9a454 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Fri, 5 Jun 2026 10:28:26 +0200 Subject: [PATCH 05/47] Address review nits: remove cur_text_len, combine branches with or --- src/mistral_common/protocol/instruct/normalize.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py index 56325e37..c96a1940 100644 --- a/src/mistral_common/protocol/instruct/normalize.py +++ b/src/mistral_common/protocol/instruct/normalize.py @@ -296,10 +296,8 @@ def _aggregate_assistant_messages(self, messages: list[UATS]) -> AssistantMessag ) weight = message.weight - if isinstance(content, str): + if isinstance(content, str) or _is_assistant_content(content): narrowed_content: str | list[TextChunk | ThinkChunk] = content - elif _is_assistant_content(content): - narrowed_content = content else: raise InvalidRequestException( f"Unexpected content chunk types in assistant message: {[type(c).__name__ for c in content]}" @@ -318,9 +316,7 @@ def _aggregate_assistant_messages(self, messages: list[UATS]) -> AssistantMessag def _aggregate_user_messages(self, messages: list[UATS]) -> UserMessageType: """Coalesce neighboring blocks of ContentChunks in user messages.""" content = self._aggregate_content_chunks(messages) - if isinstance(content, str): - return self._user_message_class(content=content) - elif _is_user_content(content): + if isinstance(content, str) or _is_user_content(content): return self._user_message_class(content=content) else: raise InvalidRequestException( From 16ecf704ff28fbe81942f2617911be20659de4e5 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Fri, 5 Jun 2026 16:48:44 +0200 Subject: [PATCH 06/47] Consolidate multimodal ContentChunk support for all message roles Add per-role content type aliases: AssistantContentChunk, SystemContentChunk, ToolContentChunk. Widen all message content fields to use per-role aliases. Add shared _content_to_openai/_content_from_openai helpers on BaseMessage. Refactor all to_openai/from_openai to use shared helpers. Add normalizer _narrow_*_content methods with version-aware validation: - _narrow_assistant_content: pre-V15 allows TextChunk/ThinkChunk, V15 allows all - _narrow_tool_content: pre-V15 allows text only, V15 allows all - _narrow_system_content: V7+ allows text/think/audio, V15 rejects ThinkChunk Widen encode_tool/assistant/system_message return types to tuple[list[int], list[np.ndarray], list[Audio]]. Add TemplateConfig properties: tool_supports_multimodal, assistant_supports_multimodal, system_supports_audio, system_supports_thinking. Update template generation with dynamic supported_types_desc. Closes #236, closes #237, closes #238 --- .../experimental/app/routers.py | 4 +- .../chat_templates/template_generator.py | 70 ++++- src/mistral_common/protocol/instruct/chunk.py | 8 +- .../protocol/instruct/messages.py | 115 ++++--- .../protocol/instruct/normalize.py | 174 +++++++++-- .../tokens/tokenizers/instruct.py | 117 +++++--- tests/data/chat_templates/v15.jinja | 4 +- tests/data/chat_templates/v15_audio.jinja | 8 +- tests/data/chat_templates/v15_image.jinja | 6 +- .../data/chat_templates/v15_image_think.jinja | 6 +- tests/data/chat_templates/v15_think.jinja | 4 +- tests/experimental/test_app.py | 6 +- tests/guidance/test_guidance.py | 5 +- .../chat_templates/fixtures_data.py | 66 ++++ .../chat_templates/test_parity.py | 125 ++++++++ .../transformers/test_core_parity.py | 20 +- .../chat_templates/unit/test_v15.py | 281 ++++++++++++++++++ tests/test_converters.py | 19 +- tests/test_normalization.py | 190 ++++++++++++ tests/test_tokenizer_v11.py | 16 +- tests/test_tokenizer_v13.py | 18 +- tests/test_tokenizer_v15.py | 206 ++++++++++++- 22 files changed, 1312 insertions(+), 156 deletions(-) diff --git a/src/mistral_common/experimental/app/routers.py b/src/mistral_common/experimental/app/routers.py index 73c9c6e3..1c8aabc4 100644 --- a/src/mistral_common/experimental/app/routers.py +++ b/src/mistral_common/experimental/app/routers.py @@ -14,7 +14,7 @@ ) from mistral_common.experimental.think import _split_content_and_think_chunks from mistral_common.experimental.tools import _decode_tool_calls, _split_content_and_tool_calls -from mistral_common.protocol.instruct.chunk import TextChunk, ThinkChunk +from mistral_common.protocol.instruct.chunk import AssistantContentChunk, TextChunk, ThinkChunk from mistral_common.protocol.instruct.messages import AssistantMessage from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy, Tokenized, TokenizerVersion @@ -100,7 +100,7 @@ async def detokenize_to_assistant_message( else: content_tokens, tool_calls_tokens = tokens, () - content: str | list[TextChunk | ThinkChunk] | None = None + content: str | list[AssistantContentChunk] | None = None if settings.tokenizer.instruct_tokenizer.tokenizer.version >= TokenizerVersion.v13: assert isinstance(settings.tokenizer.instruct_tokenizer, InstructTokenizerV13) diff --git a/src/mistral_common/integrations/chat_templates/template_generator.py b/src/mistral_common/integrations/chat_templates/template_generator.py index d580acb8..9d8a4373 100644 --- a/src/mistral_common/integrations/chat_templates/template_generator.py +++ b/src/mistral_common/integrations/chat_templates/template_generator.py @@ -216,6 +216,35 @@ def validates_assistant_non_empty(self) -> bool: r"""Whether to validate that assistant messages have non-empty content or tool calls.""" return self.version >= TokenizerVersion.v7 or (self.version >= TokenizerVersion.v3 and not self.spm) + @property + def tool_supports_multimodal(self) -> bool: + r"""Whether tool messages can contain non-text content chunks. V15+.""" + return self.version >= TokenizerVersion.v15 + + @property + def assistant_supports_multimodal(self) -> bool: + r"""Whether assistant messages can contain non-text content chunks. V15+.""" + return self.version >= TokenizerVersion.v15 + + @property + def system_supports_audio(self) -> bool: + r"""Whether system messages can contain audio. V15+ with audio support.""" + return self.audio_support and self.version >= TokenizerVersion.v15 + + +def _join_types_desc(parts: list[str]) -> str: + r"""Join type names into a human-readable description string. + + Args: + parts: List of type names (e.g. ["text", "thinking", "image"]). + + Returns: + Formatted string like "text", "text and thinking", or "text, thinking and image". + """ + if len(parts) == 1: + return parts[0] + return ", ".join(parts[:-1]) + " and " + parts[-1] + def _generate_header(config: TemplateConfig) -> str: r"""Generate template header with default system message. @@ -873,6 +902,8 @@ def _generate_system_message_handling(config: TemplateConfig) -> str: if has_extra_types: if config.system_supports_thinking: rc_args += ", supported_types_desc='text and thinking'" + elif config.system_supports_audio: + rc_args += ", supported_types_desc='text and audio'" else: rc_args += ", supported_types_desc='text'" if config.any_thinking_support: @@ -883,7 +914,7 @@ def _generate_system_message_handling(config: TemplateConfig) -> str: if config.image_support: rc_args += ", support_images=false" if config.audio_support: - rc_args += ", support_audio=false" + rc_args += f", support_audio={'true' if config.system_supports_audio else 'false'}" lines.append(" {{- render_content(" + rc_args + ") -}}") lines.append(" {{- '" + _END_SYSTEM + "' -}}") @@ -1236,16 +1267,20 @@ def _generate_assistant_message_handling(config: TemplateConfig) -> str: has_extra_types = config.any_thinking_support or config.image_support or config.audio_support rc_call_args = "message['content'], 'assistant message contents'" if has_extra_types: + desc_parts = ["text"] if config.any_thinking_support: - rc_call_args += ", supported_types_desc='text and thinking'" - else: - rc_call_args += ", supported_types_desc='text'" + desc_parts.append("thinking") + if config.assistant_supports_multimodal and config.image_support: + desc_parts.append("image") + if config.assistant_supports_multimodal and config.audio_support: + desc_parts.append("audio") + rc_call_args += f", supported_types_desc='{_join_types_desc(desc_parts)}'" if config.any_thinking_support: rc_call_args += ", support_thinking=true" if config.image_support: - rc_call_args += ", support_images=false" + rc_call_args += f", support_images={'true' if config.assistant_supports_multimodal else 'false'}" if config.audio_support: - rc_call_args += ", support_audio=false" + rc_call_args += f", support_audio={'true' if config.assistant_supports_multimodal else 'false'}" lines.append(" {%- if message['content'] %}") @@ -1485,9 +1520,26 @@ def _generate_tool_message_handling(config: TemplateConfig) -> str: + "' }}" # noqa: E501 ) elif config.uses_simple_tool_results: - lines.append( - " {{- '" + _BEGIN_TOOL_RESULTS + "' + message['content']|string + '" + _END_TOOL_RESULTS + "' }}" - ) # noqa: E501 + if config.tool_supports_multimodal: + tool_rc_args = "message['content'], 'tool message contents'" + if config.image_support or config.audio_support: + desc_parts = ["text"] + if config.image_support: + desc_parts.append("image") + if config.audio_support: + desc_parts.append("audio") + tool_rc_args += f", supported_types_desc='{_join_types_desc(desc_parts)}'" + if config.image_support: + tool_rc_args += ", support_images=true" + if config.audio_support: + tool_rc_args += ", support_audio=true" + lines.append(" {{- '" + _BEGIN_TOOL_RESULTS + "' -}}") + lines.append(" {{- render_content(" + tool_rc_args + ") -}}") + lines.append(" {{- '" + _END_TOOL_RESULTS + "' }}") + else: + lines.append( + " {{- '" + _BEGIN_TOOL_RESULTS + "' + message['content']|string + '" + _END_TOOL_RESULTS + "' }}" + ) # noqa: E501 else: # v3 non-spm style lines.extend(_emit_int_float_parsing(" ")) diff --git a/src/mistral_common/protocol/instruct/chunk.py b/src/mistral_common/protocol/instruct/chunk.py index 2ed61528..430c8c3d 100644 --- a/src/mistral_common/protocol/instruct/chunk.py +++ b/src/mistral_common/protocol/instruct/chunk.py @@ -7,7 +7,7 @@ from urllib.parse import urlparse from pydantic import ConfigDict, Field, ValidationError, field_validator, model_validator -from typing_extensions import Annotated +from typing_extensions import Annotated, TypeAlias from mistral_common.base import MistralBase from mistral_common.deprecation import warn_once @@ -455,6 +455,12 @@ def from_openai(cls, openai_chunk: dict[str, Any]) -> "ThinkChunk": TextChunk | ImageChunk | ImageURLChunk | AudioChunk | AudioURLChunk, Field(discriminator="type") ] +AssistantContentChunk: TypeAlias = ContentChunk + +SystemContentChunk = Annotated[TextChunk | AudioChunk | ThinkChunk, Field(discriminator="type")] + +ToolContentChunk: TypeAlias = ContentChunk + def _convert_openai_content_chunks(openai_content_chunks: dict[str, Any]) -> ContentChunk: content_type_str = openai_content_chunks.get("type") diff --git a/src/mistral_common/protocol/instruct/messages.py b/src/mistral_common/protocol/instruct/messages.py index e82379cb..140a1f3a 100644 --- a/src/mistral_common/protocol/instruct/messages.py +++ b/src/mistral_common/protocol/instruct/messages.py @@ -1,16 +1,20 @@ import warnings +from collections.abc import Sequence from enum import Enum -from typing import Any, Literal, TypeGuard, TypeVar +from typing import Any, Literal, TypeVar from pydantic import Field -from typing_extensions import Annotated, TypeAlias +from typing_extensions import Annotated, TypeAlias, TypeGuard from mistral_common.base import MistralBase from mistral_common.exceptions import InvalidAssistantMessageException from mistral_common.protocol.instruct.chunk import ( + AssistantContentChunk, ContentChunk, + SystemContentChunk, TextChunk, ThinkChunk, + ToolContentChunk, UserContentChunk, _convert_openai_content_chunks, ) @@ -23,12 +27,9 @@ ) -def _are_think_chunks(think_chunks: list[ThinkChunk | TextChunk]) -> TypeGuard[list[ThinkChunk]]: - return all(isinstance(c, ThinkChunk) for c in think_chunks) - - -def _are_text_chunks(think_chunks: list[ThinkChunk | TextChunk]) -> TypeGuard[list[TextChunk]]: - return all(isinstance(c, TextChunk) for c in think_chunks) +def _are_think_chunks(chunks: list[ContentChunk]) -> TypeGuard[list[ThinkChunk]]: + r"""Narrow a ContentChunk list to ThinkChunk list.""" + return all(isinstance(c, ThinkChunk) for c in chunks) class ReasoningFieldFormat(str, Enum): @@ -73,6 +74,44 @@ class BaseMessage(MistralBase): role: Literal[Roles.system, Roles.user, Roles.assistant, Roles.tool] + @staticmethod + def _content_to_openai( + content: str | Sequence[ContentChunk] | None, + ) -> str | list[dict[str, Any]] | None: + r"""Serialize message content to OpenAI format. + + Args: + content: String, list of content chunks, or None. + + Returns: + String content as-is, list of chunks serialized via each chunk's + to_openai(), or None. + """ + if content is None or isinstance(content, str): + return content + return [chunk.to_openai() for chunk in content] + + @staticmethod + def _content_from_openai( + raw: str | list[dict[str, Any]] | None, + ) -> str | list[ContentChunk] | None: + r"""Deserialize content from OpenAI format. + + Args: + raw: Raw content from OpenAI message dict. + + Returns: + String content as-is, list of deserialized content chunks, or None. + + Raises: + ValueError: If content type is unrecognized. + """ + if raw is None or isinstance(raw, str): + return raw + if isinstance(raw, list): + return [_convert_openai_content_chunks(chunk) for chunk in raw] + raise ValueError(f"Unknown content type: {type(raw)}") + def to_openai(self) -> dict[str, Any]: r"""Converts the message to the OpenAI format. @@ -104,20 +143,13 @@ class UserMessage(BaseMessage): def to_openai(self) -> dict[str, Any]: r"""Converts the message to the OpenAI format.""" - if isinstance(self.content, str): - return {"role": self.role, "content": self.content} - return {"role": self.role, "content": [chunk.to_openai() for chunk in self.content]} + return {"role": self.role, "content": self._content_to_openai(self.content)} @classmethod def from_openai(cls, openai_message: dict[str, Any]) -> "UserMessage": r"""Converts the OpenAI message to the Mistral format.""" - if isinstance(openai_message["content"], str): - return cls.model_validate_ignore_extra(openai_message) return cls.model_validate( - { - "role": openai_message["role"], - "content": [_convert_openai_content_chunks(chunk) for chunk in openai_message["content"]], - }, + {"role": openai_message["role"], "content": cls._content_from_openai(openai_message.get("content"))} ) @@ -132,16 +164,18 @@ class SystemMessage(BaseMessage): """ role: Literal[Roles.system] = Roles.system - content: str | list[TextChunk | ThinkChunk] + content: str | list[SystemContentChunk] def to_openai(self) -> dict[str, Any]: r"""Converts the message to the OpenAI format.""" - return self.model_dump() + return {"role": self.role, "content": self._content_to_openai(self.content)} @classmethod def from_openai(cls, openai_message: dict[str, Any]) -> "SystemMessage": r"""Converts the OpenAI message to the Mistral format.""" - return cls.model_validate_ignore_extra(openai_message) + return cls.model_validate( + {"role": openai_message["role"], "content": cls._content_from_openai(openai_message.get("content"))} + ) class AssistantMessage(BaseMessage): @@ -158,7 +192,7 @@ class AssistantMessage(BaseMessage): """ role: Literal[Roles.assistant] = Roles.assistant - content: str | list[TextChunk | ThinkChunk] | None = None + content: str | list[AssistantContentChunk] | None = None tool_calls: list[ToolCall] | None = None prefix: bool = False @@ -205,18 +239,18 @@ def to_openai( match reasoning_field_format: case None | ReasoningFieldFormat.thinking_chunks: - out_dict["content"] = [chunk.to_openai() for chunk in self.content] + out_dict["content"] = self._content_to_openai(self.content) case ReasoningFieldFormat.reasoning | ReasoningFieldFormat.reasoning_content: think_chunks, content_chunks = self.content[: last_think_idx + 1], self.content[last_think_idx + 1 :] - if not _are_think_chunks(think_chunks) or not _are_text_chunks(content_chunks): - raise RuntimeError("Impossible, only think or content chunks should have been present.") + if not _are_think_chunks(think_chunks): + raise RuntimeError("Expected only ThinkChunks in the leading portion.") if len(think_chunks) > 0: out_dict[reasoning_field_format.value] = "\n".join(tc.thinking for tc in think_chunks) - if len(content_chunks) == 1: + if len(content_chunks) == 1 and isinstance(content_chunks[0], TextChunk): out_dict["content"] = content_chunks[0].text elif content_chunks: - out_dict["content"] = [chunk.to_openai() for chunk in content_chunks] + out_dict["content"] = self._content_to_openai(content_chunks) case _: raise ValueError(f"{reasoning_field_format=} is not supported.") @@ -234,14 +268,7 @@ def from_openai(cls, openai_message: dict[str, Any]) -> "AssistantMessage": tools_calls.append(ToolCall.from_openai(openai_tool_call)) else: raise ValueError(f"tool_calls must be a list, got {type(openai_tool_calls)}") - openai_content = openai_message.get("content", None) - content: str | list[ContentChunk] | None = None - if openai_content is None or isinstance(openai_content, str): - content = openai_content - elif isinstance(openai_content, list): - content = [_convert_openai_content_chunks(chunk) for chunk in openai_content] - else: - raise ValueError(f"Unknown content type: {type(openai_content)}") + content = cls._content_from_openai(openai_message.get("content")) reasoning_content: str | None = openai_message.get("reasoning_content") reasoning: str | None = openai_message.get("reasoning") @@ -308,7 +335,7 @@ class ToolMessage(BaseMessage): >>> message = ToolMessage(content="Hello, how can I help you?", tool_call_id="123") """ - content: str | list[TextChunk] + content: str | list[ToolContentChunk] role: Literal[Roles.tool] = Roles.tool tool_call_id: str | None = None @@ -318,12 +345,24 @@ class ToolMessage(BaseMessage): def to_openai(self) -> dict[str, Any]: r"""Converts the message to the OpenAI format.""" assert self.tool_call_id is not None, "tool_call_id must be provided for tool messages." - return self.model_dump(exclude={"name"}) + return { + "role": self.role, + "tool_call_id": self.tool_call_id, + "content": self._content_to_openai(self.content), + } @classmethod - def from_openai(cls, messages: dict[str, str | list[dict[str, str | dict[str, Any]]]]) -> "ToolMessage": + def from_openai(cls, openai_message: dict[str, Any]) -> "ToolMessage": r"""Converts the OpenAI message to the Mistral format.""" - tool_message = cls.model_validate_ignore_extra(messages) + content = cls._content_from_openai(openai_message.get("content")) + tool_message = cls.model_validate( + { + "role": openai_message["role"], + "tool_call_id": openai_message.get("tool_call_id"), + "content": content, + "name": openai_message.get("name"), + } + ) assert tool_message.tool_call_id is not None, "tool_call_id must be provided for tool messages." return tool_message diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py index c96a1940..b053b381 100644 --- a/src/mistral_common/protocol/instruct/normalize.py +++ b/src/mistral_common/protocol/instruct/normalize.py @@ -1,8 +1,8 @@ import json import warnings -from typing import Generic, Sequence, TypeGuard +from typing import Generic, Sequence -from typing_extensions import assert_never +from typing_extensions import TypeGuard, assert_never from mistral_common.exceptions import InvalidRequestException from mistral_common.protocol.instruct.chunk import ( @@ -35,6 +35,13 @@ _DEFAULT_JOIN_STR = "\n\n" +def _is_user_content( + chunks: list[ContentChunk], +) -> TypeGuard[list[TextChunk | ImageChunk | ImageURLChunk | AudioChunk | AudioURLChunk]]: + r"""Narrow ContentChunk list to user-compatible types (no ThinkChunk).""" + return all(not isinstance(c, ThinkChunk) for c in chunks) + + def _aggregate_content_chunks_impl( contents: list[list[ContentChunk] | str | None], msg_join_str: str, @@ -98,18 +105,6 @@ def _flush_text() -> None: return all_content -def _is_assistant_content(chunks: list[ContentChunk]) -> TypeGuard[list[TextChunk | ThinkChunk]]: - """Narrow ContentChunk list to assistant-compatible types.""" - return all(isinstance(c, (TextChunk, ThinkChunk)) for c in chunks) - - -def _is_user_content( - chunks: list[ContentChunk], -) -> TypeGuard[list[TextChunk | ImageChunk | ImageURLChunk | AudioChunk | AudioURLChunk]]: - """Narrow ContentChunk list to user-compatible types.""" - return all(not isinstance(c, ThinkChunk) for c in chunks) - - class InstructRequestNormalizer( Generic[UserMessageType, AssistantMessageType, ToolMessageType, SystemMessageType, InstructRequestType] ): @@ -244,13 +239,17 @@ def _aggregate_system_prompts(self, messages: list[UATS]) -> str | None: def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]: """Normalize tool messages without aggregation across messages. - Each tool message's content chunks are aggregated individually and JSON is normalized. + Each tool message's content is validated and JSON-normalized. """ tool_messages: list[ToolMessageType] = [] for message in messages: assert isinstance(message, self._tool_message_class), "Expected tool message" - content = self._aggregate_content_chunks_to_str_same_message(message) - normalized_content = self._normalize_json_content(content) + content = self._aggregate_content_chunks([message]) + validated = self._narrow_tool_content(content) + if isinstance(validated, str): + normalized_content: str | list[ContentChunk] = self._normalize_json_content(validated) + else: + normalized_content = validated tool_messages.append( self._tool_message_class( content=normalized_content, tool_call_id=message.tool_call_id, name=message.name @@ -266,6 +265,51 @@ def _normalize_tool_call(self, tool_call: ToolCall) -> ToolCall: id=tool_call.id, ) + def _narrow_assistant_content(self, content: list[ContentChunk] | str) -> str | list[ContentChunk]: + r"""Validate and narrow content chunks for assistant messages. + + Pre-V15 normalizers only allow TextChunk and ThinkChunk. + + Args: + content: The aggregated content chunks. + + Returns: + The validated and narrowed content. + + Raises: + InvalidRequestException: If unsupported chunk types are found. + """ + if isinstance(content, str) or all(isinstance(c, (TextChunk, ThinkChunk)) for c in content): + return content + raise InvalidRequestException( + f"Unexpected content chunk types in assistant message: {[type(c).__name__ for c in content]}" + ) + + def _narrow_tool_content(self, content: list[ContentChunk] | str) -> str | list[ContentChunk]: + r"""Validate and narrow content for tool messages. + + Pre-V15 normalizers only allow text content. + + Args: + content: The raw or aggregated content. + + Returns: + The content as a string. + + Raises: + InvalidRequestException: If non-text content chunks are found. + """ + if isinstance(content, str): + return content + text_parts: list[str] = [] + for c in content: + if not isinstance(c, TextChunk): + raise InvalidRequestException( + f"Unexpected content chunk types in tool message: {[type(c).__name__ for c in content]}" + ) + text_parts.append(c.text) + return "".join(text_parts) + def _aggregate_system_messages(self, messages: list[UATS]) -> list[SystemMessageType]: return [] @@ -296,15 +340,10 @@ def _aggregate_assistant_messages(self, messages: list[UATS]) -> AssistantMessag ) weight = message.weight - if isinstance(content, str) or _is_assistant_content(content): - narrowed_content: str | list[TextChunk | ThinkChunk] = content - else: - raise InvalidRequestException( - f"Unexpected content chunk types in assistant message: {[type(c).__name__ for c in content]}" - ) + validated_content = self._narrow_assistant_content(content) aggregated_message = self._assistant_message_class( - content=narrowed_content, + content=validated_content, tool_calls=tool_calls or None, prefix=prefix, ) @@ -318,10 +357,9 @@ def _aggregate_user_messages(self, messages: list[UATS]) -> UserMessageType: content = self._aggregate_content_chunks(messages) if isinstance(content, str) or _is_user_content(content): return self._user_message_class(content=content) - else: - raise InvalidRequestException( - f"Unexpected content chunk types in user message: {[type(c).__name__ for c in content]}" - ) + raise InvalidRequestException( + f"Unexpected content chunk types in user message: {[type(c).__name__ for c in content]}" + ) def _aggregate_role(self, messages: list[UATS], role: Roles | None, latest_call_ids: list[str]) -> Sequence[UATS]: if role == Roles.tool: @@ -444,12 +482,28 @@ def normalizer(model_settings_builder: ModelSettingsBuilder | None = None) -> "I UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest[UATS, Tool], None ) + def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | list[ContentChunk]: + r"""Validate content chunks for system messages. + + V7+ accepts all SystemContentChunk types (Pydantic validates at construction). + V15 overrides to reject ThinkChunk. + + Args: + content: The aggregated content chunks. + + Returns: + The validated content. + """ + return content + def _aggregate_system_messages(self, messages: list[UATS]) -> list[SystemMessageType]: - return [ - self._system_message_class(content=self._aggregate_content_chunks([message])) - for message in messages - if isinstance(message, self._system_message_class) - ] + aggregated: list[SystemMessageType] = [] + for message in messages: + if isinstance(message, self._system_message_class): + content = self._aggregate_content_chunks([message]) + validated = self._narrow_system_content(content) + aggregated.append(self._system_message_class(content=validated)) + return aggregated def _aggregate_role(self, messages: list[UATS], role: Roles | None, latest_call_ids: list[str]) -> Sequence[UATS]: if role == Roles.tool: @@ -555,6 +609,60 @@ class InstructRequestNormalizerV15(InstructRequestNormalizerV13): _chunk_join_str: str = "" + def _narrow_assistant_content(self, content: list[ContentChunk] | str) -> str | list[ContentChunk]: + r"""V15 accepts all ContentChunk types in assistant messages.""" + return content + + def _narrow_tool_content(self, content: list[ContentChunk] | str) -> str | list[ContentChunk]: + r"""V15 accepts all ContentChunk types in tool messages.""" + if isinstance(content, str): + return content + text_parts: list[str] = [] + for c in content: + if not isinstance(c, TextChunk): + return content + text_parts.append(c.text) + return "".join(text_parts) + + def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | list[ContentChunk]: + r"""V15 system messages allow TextChunk and AudioChunk but reject ThinkChunk.""" + if isinstance(content, str): + return content + if any(isinstance(c, ThinkChunk) for c in content): + raise InvalidRequestException("ThinkChunk in system message is not supported for V15") + return content + + def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]: + r"""V15 tool messages preserve non-text content chunks. + + Text-only content is aggregated and JSON-normalized. Mixed content is preserved as-is. + """ + tool_messages: list[ToolMessageType] = [] + for message in messages: + assert isinstance(message, self._tool_message_class), "Expected tool message" + content = self._aggregate_content_chunks([message]) + validated = self._narrow_tool_content(content) + if isinstance(validated, str): + normalized_content: str | list[ContentChunk] = self._normalize_json_content(validated) + else: + normalized_content = validated + tool_messages.append( + self._tool_message_class( + content=normalized_content, tool_call_id=message.tool_call_id, name=message.name + ) + ) + + # Reorder by tool call order + id_to_tool_call_idx = {call_id: idx for idx, call_id in enumerate(latest_call_ids)} + id_to_tool_result_idx = {message.tool_call_id: idx for idx, message in enumerate(tool_messages)} + tool_messages.sort( + key=lambda msg: ( + id_to_tool_call_idx.get(msg.tool_call_id or "null", float("inf")), + id_to_tool_result_idx[msg.tool_call_id], + ), + ) + return tool_messages + @staticmethod def normalizer(model_settings_builder: ModelSettingsBuilder | None = None) -> "InstructRequestNormalizerV15": r"""Returns a normalizer for the V15 instruct request. diff --git a/src/mistral_common/tokens/tokenizers/instruct.py b/src/mistral_common/tokens/tokenizers/instruct.py index 7992df1c..08de0c95 100644 --- a/src/mistral_common/tokens/tokenizers/instruct.py +++ b/src/mistral_common/tokens/tokenizers/instruct.py @@ -104,7 +104,9 @@ def find_first_last_user(request: InstructRequest) -> tuple[int, int]: return first_user_idx, last_user_idx @abstractmethod - def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: bool) -> list[int]: + def encode_tool_message( + self, message: ToolMessage, is_before_last_user_message: bool + ) -> tuple[list[int], list[np.ndarray], list[Audio]]: r"""Encode a tool message. Raises: @@ -115,7 +117,7 @@ def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: @abstractmethod def encode_assistant_message( self, message: AssistantMessageType, is_before_last_user_message: bool, continue_message: bool - ) -> list[int]: + ) -> tuple[list[int], list[np.ndarray], list[Audio]]: r"""Encode an assistant message. Raises: @@ -192,17 +194,23 @@ def encode_instruct( images.extend(new_images) audios.extend(new_audios) elif isinstance(msg, ToolMessage): - new_tokens = self.encode_tool_message(msg, msg_idx < last_user_idx) + new_tokens, new_images, new_audios = self.encode_tool_message(msg, msg_idx < last_user_idx) + images.extend(new_images) + audios.extend(new_audios) elif isinstance(msg, AssistantMessage): continue_message = request.continue_final_message and (msg_idx == len(request.messages) - 1) - new_tokens = self.encode_assistant_message( + new_tokens, new_images, new_audios = self.encode_assistant_message( msg, msg_idx < last_user_idx, continue_message=continue_message ) + images.extend(new_images) + audios.extend(new_audios) if msg_idx == len(request.messages) - 1: prefix_ids = new_tokens elif isinstance(msg, SystemMessage): - new_tokens = self.encode_system_message(msg) + new_tokens, new_images, new_audios = self.encode_system_message(msg) + images.extend(new_images) + audios.extend(new_audios) else: raise TokenizerException(f"Unknown message type {type(msg)}") @@ -289,7 +297,7 @@ def encode_user_message( curr_tokens, image, audio = self.encode_user_content(content=message_txt, is_last=False, system_prompt=None) return curr_tokens, image, audio - def encode_system_message(self, message: SystemMessage) -> list[int]: + def encode_system_message(self, message: SystemMessage) -> tuple[list[int], list[np.ndarray], list[Audio]]: raise NotImplementedError(f"System message encoding not implemented for {self.__class__.__name__}") def encode_user_content( @@ -318,7 +326,9 @@ def encode_user_content( tokens = self.tokenizer.encode(content, bos=False, eos=False) return tokens, [], [] - def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: bool) -> list[int]: + def encode_tool_message( + self, message: ToolMessage, is_before_last_user_message: bool + ) -> tuple[list[int], list[np.ndarray], list[Audio]]: r"""Encode a tool message. Raises: @@ -328,7 +338,7 @@ def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: def encode_assistant_message( self, message: AssistantMessageType, is_before_last_user_message: bool, continue_message: bool - ) -> list[int]: + ) -> tuple[list[int], list[np.ndarray], list[Audio]]: r"""Encode an assistant message. Args: @@ -338,7 +348,7 @@ def encode_assistant_message( Only use this if the assistant message is the last message. Returns: - The encoded tokens. + The encoded tokens, images, and audios. """ assert isinstance(message, AssistantMessage), message if message.tool_calls is not None and len(message.tool_calls) > 0: @@ -354,7 +364,7 @@ def encode_assistant_message( raise TokenizerException(f"{message.content} // {message.tool_calls}") if not message.prefix and not continue_message: curr_tokens.append(self.tokenizer.eos_id) - return curr_tokens + return curr_tokens, [], [] def encode_think(self, chunk: ThinkChunk) -> list[int]: r"""Encode a think chunk. @@ -480,9 +490,9 @@ def _parse_json_content(self, content: str) -> Any: except json.JSONDecodeError: return content - def _parse_tool_content(self, content: str | list[TextChunk]) -> Any: + def _parse_tool_content(self, content: str | list[ContentChunk]) -> Any: if isinstance(content, list): - content = "".join(chunk.text for chunk in content) + content = "".join(chunk.text for chunk in content if isinstance(chunk, TextChunk)) return self._parse_json_content(content) def _prepare_tool_result(self, tool_message: ToolMessage) -> dict[str, Any]: @@ -492,7 +502,9 @@ def _prepare_tool_result(self, tool_message: ToolMessage) -> dict[str, Any]: "content": self._parse_tool_content(tool_message.content), } - def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: bool) -> list[int]: + def encode_tool_message( + self, message: ToolMessage, is_before_last_user_message: bool + ) -> tuple[list[int], list[np.ndarray], list[Audio]]: r"""Encode a tool message. Args: @@ -501,11 +513,11 @@ def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: not encoded. Returns: - The encoded tokens. + The encoded tokens, images, and audios. """ if is_before_last_user_message: # don't tokenize last tool response before last user msg - return [] + return [], [], [] # Currently only supports single tool results tool_result_str = json.dumps([self._prepare_tool_result(message)], ensure_ascii=False) @@ -514,7 +526,7 @@ def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: *self.tokenizer.encode(tool_result_str, bos=False, eos=False), self.END_TOOL_RESULTS, ] - return curr_tokens + return curr_tokens, [], [] def _prepare_function_call(self, tool_call: ToolCall) -> dict[str, Any]: r"""Bit of a hack due to the way function calls are tokenized.""" @@ -550,7 +562,7 @@ def _encode_settings( def encode_assistant_message( self, message: AssistantMessageType, is_before_last_user_message: bool, continue_message: bool - ) -> list[int]: + ) -> tuple[list[int], list[np.ndarray], list[Audio]]: r"""Encode an assistant message. Args: @@ -561,7 +573,7 @@ def encode_assistant_message( Only use this if the assistant message is the last message. Returns: - The encoded tokens. + The encoded tokens, images, and audios. """ if message.tool_calls and message.content: raise ValueError(f"Cannot have tool calls and content defined in the same assistant message {message}") @@ -573,7 +585,7 @@ def encode_assistant_message( if message.tool_calls: if is_before_last_user_message: # don't tokenize tool call before last user message - return [] + return [], [], [] curr_tokens = self._encode_tool_calls_in_assistant_message(message) elif message.content: assert isinstance(message.content, str), "Message content must be a string for tokenizer < V7" @@ -582,7 +594,7 @@ def encode_assistant_message( raise TokenizerException(f"Invalid assistant message: {message.content}") if not message.prefix and not continue_message: curr_tokens.append(self.tokenizer.eos_id) - return curr_tokens + return curr_tokens, [], [] def _encode_infilling(self, text: str) -> list[int]: r"""Remove prefix space in the case of SentencePieceTokenizers.""" @@ -652,7 +664,9 @@ def _prepare_tool_result(self, tool_message: ToolMessage) -> dict[str, Any]: "call_id": tool_message.tool_call_id, } - def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: bool) -> list[int]: + def encode_tool_message( + self, message: ToolMessage, is_before_last_user_message: bool + ) -> tuple[list[int], list[np.ndarray], list[Audio]]: r"""Encode a tool message. Note: @@ -665,7 +679,7 @@ def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: not encoded. Returns: - The encoded tokens. + The encoded tokens, images, and audios. """ tool_result_str = json.dumps(self._prepare_tool_result(message), ensure_ascii=False) curr_tokens = [ @@ -673,11 +687,11 @@ def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: *self.tokenizer.encode(tool_result_str, bos=False, eos=False), self.END_TOOL_RESULTS, ] - return curr_tokens + return curr_tokens, [], [] def encode_assistant_message( self, message: AssistantMessageType, is_before_last_user_message: bool, continue_message: bool - ) -> list[int]: + ) -> tuple[list[int], list[np.ndarray], list[Audio]]: r"""Encode an assistant message. Note: @@ -691,7 +705,7 @@ def encode_assistant_message( is_before_last_user_message: Not used. Returns: - The encoded tokens. + The encoded tokens, images, and audios. """ return super().encode_assistant_message(message, False, continue_message) @@ -867,22 +881,22 @@ def drop(idx: int) -> None: if to_drop > 0: raise TokenizerException("Input couldn't fit in truncate_at_max_token") - def encode_system_message(self, message: SystemMessage) -> list[int]: + def encode_system_message(self, message: SystemMessage) -> tuple[list[int], list[np.ndarray], list[Audio]]: r"""Encode a system message. Args: message: The message to encode. Returns: - The encoded tokens. + The encoded tokens, images, and audios. """ - tokens = [self.BEGIN_SYSTEM] if isinstance(content := message.content, str): content = [TextChunk(text=content)] - tokens += self._encode_content_chunks(content)[0] + content_tokens, images, audios = self._encode_content_chunks(content) + tokens += content_tokens tokens.append(self.END_SYSTEM) - return tokens + return tokens, images, audios def encode_user_content( self, @@ -1081,7 +1095,9 @@ def _has_audio(messages: list[UATS]) -> bool: for message in messages ) - def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: bool) -> list[int]: + def encode_tool_message( + self, message: ToolMessage, is_before_last_user_message: bool + ) -> tuple[list[int], list[np.ndarray], list[Audio]]: r"""Encode a tool message. Note: @@ -1093,7 +1109,7 @@ def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: is_before_last_user_message: Not used. Returns: - The encoded tokens. + The encoded tokens, images, and audios. """ assert message.tool_call_id is not None assert isinstance(message.content, str), "Message content must be normalized" @@ -1110,11 +1126,11 @@ def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: *tokens, self.END_TOOL_RESULTS, ] - return curr_tokens + return curr_tokens, [], [] def encode_assistant_message( self, message: AssistantMessageType, is_before_last_user_message: bool, continue_message: bool - ) -> list[int]: + ) -> tuple[list[int], list[np.ndarray], list[Audio]]: r"""Encode an assistant message. Args: @@ -1124,7 +1140,7 @@ def encode_assistant_message( Only use this if the assistant message is the last message. Returns: - The encoded tokens. + The encoded tokens, images, and audios. """ if not message.content and not message.tool_calls: raise TokenizerException(f"Invalid assistant message: {message}") @@ -1134,17 +1150,22 @@ def encode_assistant_message( ) curr_tokens: list = [] + images: list[np.ndarray] = [] + audios: list[Audio] = [] if message.content: if isinstance(message.content, str): curr_tokens = self._encode_normal_content_assistant_message(message) elif isinstance(message.content, list): - curr_tokens += self._encode_content_chunks(message.content)[0] + content_tokens, new_images, new_audios = self._encode_content_chunks(message.content) + curr_tokens += content_tokens + images.extend(new_images) + audios.extend(new_audios) if message.tool_calls: curr_tokens += self._encode_tool_calls_in_assistant_message(message) if not message.prefix and not continue_message: curr_tokens.append(self.tokenizer.eos_id) - return curr_tokens + return curr_tokens, images, audios def _encode_audio_for_speech_request(self, ref_audio: str | bytes | None, voice: str | None) -> Tokenized: r"""Encode reference audio or voice preset into a Tokenized object. @@ -1282,28 +1303,32 @@ def _encode_tool_calls_in_assistant_message(self, message: AssistantMessageType) ] return curr_tokens - def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: bool) -> list[int]: + def encode_tool_message( + self, message: ToolMessage, is_before_last_user_message: bool + ) -> tuple[list[int], list[np.ndarray], list[Audio]]: r"""Encode a tool message. Args: message: The message to encode. is_before_last_user_message: Not used. Returns: - The encoded tokens. + The encoded tokens, images, and audios. """ assert message.tool_call_id is not None, "Tool call id must be provided for tokenizer >= v13" - content = message.content - if not isinstance(content, str): - content = "".join(chunk.text for chunk in content) + if isinstance(message.content, str): + content_tokens = self.tokenizer.encode(message.content, bos=False, eos=False) + images: list[np.ndarray] = [] + audios: list[Audio] = [] + else: + content_tokens, images, audios = self._encode_content_chunks(message.content) - tokens = self.tokenizer.encode(content, bos=False, eos=False) curr_tokens = [ self.BEGIN_TOOL_RESULTS, - *tokens, + *content_tokens, self.END_TOOL_RESULTS, ] - return curr_tokens + return curr_tokens, images, audios def encode_think(self, chunk: ThinkChunk) -> list[int]: r"""Encode a thinking chunk. @@ -1368,7 +1393,7 @@ def _encode_settings( ] return settings_tokens - def encode_system_message(self, message: SystemMessage) -> list[int]: + def encode_system_message(self, message: SystemMessage) -> tuple[list[int], list[np.ndarray], list[Audio]]: r"""Encode a system message, rejecting ThinkChunk content.""" if isinstance(message.content, list): if any(isinstance(chunk, ThinkChunk) for chunk in message.content): diff --git a/tests/data/chat_templates/v15.jinja b/tests/data/chat_templates/v15.jinja index 26b4f9d2..d1f29add 100644 --- a/tests/data/chat_templates/v15.jinja +++ b/tests/data/chat_templates/v15.jinja @@ -182,7 +182,9 @@ {#- Tool messages only supports text content. #} {%- elif message['role'] == 'tool' %} - {{- '[TOOL_RESULTS]' + message['content']|string + '[/TOOL_RESULTS]' }} + {{- '[TOOL_RESULTS]' -}} + {{- render_content(message['content'], 'tool message contents') -}} + {{- '[/TOOL_RESULTS]' }} {#- System messages. #} {%- elif message['role'] == 'system' %} diff --git a/tests/data/chat_templates/v15_audio.jinja b/tests/data/chat_templates/v15_audio.jinja index 68ce5346..381c72cb 100644 --- a/tests/data/chat_templates/v15_audio.jinja +++ b/tests/data/chat_templates/v15_audio.jinja @@ -158,7 +158,7 @@ {%- endif %} {%- if message['content'] %} - {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text', support_audio=false) -}} + {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text and audio', support_audio=true) -}} {%- endif %} {%- if message['tool_calls'] is defined and message['tool_calls'] is not none and message['tool_calls']|length > 0 %} @@ -179,12 +179,14 @@ {#- Tool messages only supports text content. #} {%- elif message['role'] == 'tool' %} - {{- '[TOOL_RESULTS]' + message['content']|string + '[/TOOL_RESULTS]' }} + {{- '[TOOL_RESULTS]' -}} + {{- render_content(message['content'], 'tool message contents', supported_types_desc='text and audio', support_audio=true) -}} + {{- '[/TOOL_RESULTS]' }} {#- System messages. #} {%- elif message['role'] == 'system' %} {{- '[SYSTEM_PROMPT]' -}} - {{- render_content(message['content'], 'system message contents', supported_types_desc='text', support_audio=false) -}} + {{- render_content(message['content'], 'system message contents', supported_types_desc='text and audio', support_audio=true) -}} {{- '[/SYSTEM_PROMPT]' -}} {#- Raise exception for unsupported roles. #} diff --git a/tests/data/chat_templates/v15_image.jinja b/tests/data/chat_templates/v15_image.jinja index bb338233..58146ed7 100644 --- a/tests/data/chat_templates/v15_image.jinja +++ b/tests/data/chat_templates/v15_image.jinja @@ -169,7 +169,7 @@ {%- endif %} {%- if message['content'] %} - {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text', support_images=false) -}} + {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text and image', support_images=true) -}} {%- endif %} {%- if message['tool_calls'] is defined and message['tool_calls'] is not none and message['tool_calls']|length > 0 %} @@ -190,7 +190,9 @@ {#- Tool messages only supports text content. #} {%- elif message['role'] == 'tool' %} - {{- '[TOOL_RESULTS]' + message['content']|string + '[/TOOL_RESULTS]' }} + {{- '[TOOL_RESULTS]' -}} + {{- render_content(message['content'], 'tool message contents', supported_types_desc='text and image', support_images=true) -}} + {{- '[/TOOL_RESULTS]' }} {#- System messages. #} {%- elif message['role'] == 'system' %} diff --git a/tests/data/chat_templates/v15_image_think.jinja b/tests/data/chat_templates/v15_image_think.jinja index 63c61f6b..1460ec9b 100644 --- a/tests/data/chat_templates/v15_image_think.jinja +++ b/tests/data/chat_templates/v15_image_think.jinja @@ -196,7 +196,7 @@ {%- endif %} {%- if message['content'] %} - {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text and thinking', support_thinking=true, support_images=false) -}} + {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text, thinking and image', support_thinking=true, support_images=true) -}} {%- endif %} {%- if message['tool_calls'] is defined and message['tool_calls'] is not none and message['tool_calls']|length > 0 %} @@ -217,7 +217,9 @@ {#- Tool messages only supports text content. #} {%- elif message['role'] == 'tool' %} - {{- '[TOOL_RESULTS]' + message['content']|string + '[/TOOL_RESULTS]' }} + {{- '[TOOL_RESULTS]' -}} + {{- render_content(message['content'], 'tool message contents', supported_types_desc='text and image', support_images=true) -}} + {{- '[/TOOL_RESULTS]' }} {#- System messages. #} {%- elif message['role'] == 'system' %} diff --git a/tests/data/chat_templates/v15_think.jinja b/tests/data/chat_templates/v15_think.jinja index 7df9072a..d91b3909 100644 --- a/tests/data/chat_templates/v15_think.jinja +++ b/tests/data/chat_templates/v15_think.jinja @@ -209,7 +209,9 @@ {#- Tool messages only supports text content. #} {%- elif message['role'] == 'tool' %} - {{- '[TOOL_RESULTS]' + message['content']|string + '[/TOOL_RESULTS]' }} + {{- '[TOOL_RESULTS]' -}} + {{- render_content(message['content'], 'tool message contents') -}} + {{- '[/TOOL_RESULTS]' }} {#- System messages. #} {%- elif message['role'] == 'system' %} diff --git a/tests/experimental/test_app.py b/tests/experimental/test_app.py index 9851cb06..27bb9dd8 100644 --- a/tests/experimental/test_app.py +++ b/tests/experimental/test_app.py @@ -394,7 +394,9 @@ def test_detokenize_assistant_message( def test_detokenize_assistant_message_think_chunks( assistant_message: AssistantMessage, mistral_tokenizer_v13: MistralTokenizer, tekken_v13_client: TestClient ) -> None: - encoded_tokens = mistral_tokenizer_v13.instruct_tokenizer.encode_assistant_message(assistant_message, False, False) # type: ignore[attr-defined] + encoded_tokens, _, _ = mistral_tokenizer_v13.instruct_tokenizer.encode_assistant_message( # type: ignore[attr-defined] + assistant_message, False, False + ) response = tekken_v13_client.post("/v1/detokenize/", json=encoded_tokens) assert response.status_code == 200 @@ -462,7 +464,7 @@ def test_generate( engine_request: dict | ChatCompletionRequest | OpenAIChatCompletionRequest, output_assistant_message: AssistantMessage, ) -> None: - output_tokens = mistral_tokenizer_v13.instruct_tokenizer.encode_assistant_message( # type: ignore[attr-defined] + output_tokens, _, _ = mistral_tokenizer_v13.instruct_tokenizer.encode_assistant_message( # type: ignore[attr-defined] output_assistant_message, False, False ) if output_assistant_message.tool_calls: diff --git a/tests/guidance/test_guidance.py b/tests/guidance/test_guidance.py index 9afe6f33..cac90e4c 100644 --- a/tests/guidance/test_guidance.py +++ b/tests/guidance/test_guidance.py @@ -228,16 +228,17 @@ def _encode_content( tokenizer = instruct_tokenizer.tokenizer if isinstance(content, str): - return instruct_tokenizer.encode_assistant_message( + result, _, _ = instruct_tokenizer.encode_assistant_message( AssistantMessage(content=content), is_before_last_user_message=False, continue_message=False ) + return result tool_calls = [x for x in content if isinstance(x, ToolCall)] content_chunks = [x for x in content if not isinstance(x, ToolCall)] tokens: list[int] = [] if content_chunks: - tokens = instruct_tokenizer.encode_assistant_message( + tokens, _, _ = instruct_tokenizer.encode_assistant_message( AssistantMessage(content=content_chunks), is_before_last_user_message=False, continue_message=False, diff --git a/tests/integrations/chat_templates/fixtures_data.py b/tests/integrations/chat_templates/fixtures_data.py index 6f0e6e55..c45f1b79 100644 --- a/tests/integrations/chat_templates/fixtures_data.py +++ b/tests/integrations/chat_templates/fixtures_data.py @@ -952,6 +952,60 @@ ] ) +# -- Multimodal content in non-user messages (v15+) -- + +REQUEST_TOOL_IMAGE_TRAIN = ChatCompletionRequest( # type: ignore[type-var] + messages=[ + UserMessage(content="What is in this image?"), + AssistantMessage( + content=None, + tool_calls=[ + ToolCall( + id="tl1mg2345", + function=FunctionCall( + name="tool1", + arguments={"location": "San Francisco, CA"}, # type: ignore[arg-type] + ), + ), + ], + ), + ToolMessage( + content=[ + TextChunk(text="Here is the result."), + ImageURLChunk(image_url=_IMAGE_URL), + ], + tool_call_id="tl1mg2345", + ), + AssistantMessage(content="The tool returned an image of a red square."), + ], + tools=_TOOLS, +) + +REQUEST_ASSISTANT_IMAGE_TRAIN = ChatCompletionRequest( # type: ignore[type-var] + messages=[ + UserMessage(content="Generate an image for me."), + AssistantMessage( + content=[ + TextChunk(text="Here is the generated image."), + ImageURLChunk(image_url=_IMAGE_URL), + ], + ), + ], +) + +REQUEST_SYSTEM_AUDIO_TRAIN = ChatCompletionRequest( # type: ignore[type-var] + messages=[ + SystemMessage( + content=[ + TextChunk(text="You are an audio assistant."), + AudioChunk(input_audio=_AUDIO), + ], + ), + UserMessage(content="What was that sound?"), + AssistantMessage(content="That was a sine wave tone."), + ], +) + def _get_conversations( tokenizer_version: TokenizerVersion, @@ -1102,6 +1156,18 @@ def _get_conversations( ] ) + # v15+ only: multimodal content in non-user messages (finetuning only) + if tokenizer_version >= TokenizerVersion.v15 and validation_mode == ValidationMode.finetuning: + if image: + conversations.extend( + [ + REQUEST_TOOL_IMAGE_TRAIN, + REQUEST_ASSISTANT_IMAGE_TRAIN, + ] + ) + if audio: + conversations.append(REQUEST_SYSTEM_AUDIO_TRAIN) + conversations = [c.model_copy(deep=True) for c in conversations] if think and tokenizer_version >= TokenizerVersion.v15: diff --git a/tests/integrations/chat_templates/test_parity.py b/tests/integrations/chat_templates/test_parity.py index 5394d5cb..0cce8183 100644 --- a/tests/integrations/chat_templates/test_parity.py +++ b/tests/integrations/chat_templates/test_parity.py @@ -426,12 +426,133 @@ def test_dynamic_template_comprehensive(config: TestConfig) -> None: ] ) + # V15+ multimodal: tool messages with image/audio content + if config.version >= TokenizerVersion.v15 and config.image: + test_cases.append( + { + "name": "v15_tool_with_image", + "messages": [ + {"role": "user", "content": "Use tool"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + {"id": "abc123def", "function": {"name": "fn", "arguments": "{}"}}, + ], + }, + { + "role": "tool", + "content": [ + {"type": "text", "text": "result"}, + {"type": "image_url", "image_url": "http://example.com/img.png"}, + ], + "tool_call_id": "abc123def", + }, + {"role": "assistant", "content": "Done"}, + ], + "tools": [ + { + "type": "function", + "function": {"name": "fn", "description": "test", "parameters": {}}, + } + ], + } + ) + + if config.version >= TokenizerVersion.v15 and config.audio: + test_cases.append( + { + "name": "v15_tool_with_audio", + "messages": [ + {"role": "user", "content": "Use tool"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + {"id": "abc123def", "function": {"name": "fn", "arguments": "{}"}}, + ], + }, + { + "role": "tool", + "content": [ + {"type": "text", "text": "result"}, + {"type": "input_audio", "input_audio": {"data": "abc", "format": "wav"}}, + ], + "tool_call_id": "abc123def", + }, + {"role": "assistant", "content": "Done"}, + ], + "tools": [ + { + "type": "function", + "function": {"name": "fn", "description": "test", "parameters": {}}, + } + ], + } + ) + + # V15+ multimodal: assistant messages with image/audio content + if config.version >= TokenizerVersion.v15 and config.image: + test_cases.append( + { + "name": "v15_assistant_with_image", + "messages": [ + {"role": "user", "content": "Show me"}, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Here is the image"}, + {"type": "image_url", "image_url": "http://example.com/img.png"}, + ], + }, + ], + } + ) + + if config.version >= TokenizerVersion.v15 and config.audio: + test_cases.append( + { + "name": "v15_assistant_with_audio", + "messages": [ + {"role": "user", "content": "Listen"}, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Here is audio"}, + {"type": "input_audio", "input_audio": {"data": "abc", "format": "wav"}}, + ], + }, + ], + } + ) + + # V15+ multimodal: system messages with audio content + if config.version >= TokenizerVersion.v15 and config.audio: + test_cases.append( + { + "name": "v15_system_with_audio", + "messages": [ + { + "role": "system", + "content": [ + {"type": "text", "text": "Listen to context"}, + {"type": "input_audio", "input_audio": {"data": "abc", "format": "wav"}}, + ], + }, + {"role": "user", "content": "Summarize"}, + {"role": "assistant", "content": "Done"}, + ], + } + ) + skip_names_image = {"with_image", "consecutive_users_with_image", "consecutive_users_multi_image"} skip_names_audio = {"with_audio", "consecutive_users_multi_audio"} skip_names_think = {"with_thinking", "consecutive_assistants_think", "consecutive_systems_think"} # ThinkChunks in system messages are only supported in v13 (not v15+) skip_names_think_system = {"consecutive_systems_think"} skip_names_tools = {"with_tools_definition", "with_tool_call_and_result"} + skip_names_v15_image = {"v15_tool_with_image", "v15_assistant_with_image"} + skip_names_v15_audio = {"v15_tool_with_audio", "v15_assistant_with_audio", "v15_system_with_audio"} # Not using parametrize here because test_cases are built dynamically based on the # config (version/image/audio/think) from the outer parametrize. Each sub-case is @@ -451,6 +572,10 @@ def test_dynamic_template_comprehensive(config: TestConfig) -> None: continue if test_name in skip_names_tools and config.version <= TokenizerVersion.v1: continue + if test_name in skip_names_v15_image and (config.version < TokenizerVersion.v15 or not config.image): + continue + if test_name in skip_names_v15_audio and (config.version < TokenizerVersion.v15 or not config.audio): + continue static_output = render_template(static_template, messages, tools=tools) # type: ignore dynamic_output = render_template(dynamic_template, messages, tools=tools) # type: ignore diff --git a/tests/integrations/chat_templates/transformers/test_core_parity.py b/tests/integrations/chat_templates/transformers/test_core_parity.py index bf5eb23d..ffb47608 100644 --- a/tests/integrations/chat_templates/transformers/test_core_parity.py +++ b/tests/integrations/chat_templates/transformers/test_core_parity.py @@ -205,10 +205,16 @@ def test_invalid_chunks( # Not using parametrize here because invalid_convs depends on the version/image/audio/think # parameters from the outer parametrize. Each sub-case is identifiable via the TemplateError # match string which includes the role and allowed chunks. + is_v15_plus = config.version >= TokenizerVersion.v15 for conv in invalid_convs: msg_template = "Only {chunks} chunks are supported in {role} message content." if conv in sp_invalids: - chunks = "text and thinking" if config.think and config.version < TokenizerVersion.v15 else "text" + if config.think and not is_v15_plus: + chunks = "text and thinking" + elif config.audio and is_v15_plus: + chunks = "text and audio" + else: + chunks = "text" role = "system" elif conv in user_invalids: chunks = "text" @@ -218,7 +224,17 @@ def test_invalid_chunks( chunks += ", input_audio and audio_url" role = "user" elif conv in assistant_invalids: - chunks = "text and thinking" if config.think else "text" + desc_parts = ["text"] + if config.think: + desc_parts.append("thinking") + if is_v15_plus and config.image: + desc_parts.append("image") + if is_v15_plus and config.audio: + desc_parts.append("audio") + if len(desc_parts) == 1: + chunks = "text" + else: + chunks = ", ".join(desc_parts[:-1]) + " and " + desc_parts[-1] role = "assistant" err_msg = msg_template.format(chunks=chunks, role=role) diff --git a/tests/integrations/chat_templates/unit/test_v15.py b/tests/integrations/chat_templates/unit/test_v15.py index 1364e5f1..3401a8cf 100644 --- a/tests/integrations/chat_templates/unit/test_v15.py +++ b/tests/integrations/chat_templates/unit/test_v15.py @@ -246,3 +246,284 @@ def test_v15_think_template_rejects_think_in_system(self, image: bool) -> None: with pytest.raises(ValueError, match="Only text chunks are supported in system message contents"): render_template(template, messages) + + +class TestV15MultimodalContent: + def test_v15_tool_message_with_image_content(self) -> None: + r"""V15 image template renders tool message with image content using render_content.""" + template = generate_chat_template( + spm=False, + tokenizer_version=TokenizerVersion.v15, + image_support=True, + audio_support=False, + thinking_support=False, + default_system_prompt=None, + plain_thinking_support=False, + use_special_token_variables=True, + ) + + messages: list[dict[str, Any]] = [ + {"role": "user", "content": "Use tool"}, + { + "role": "assistant", + "content": "", + "tool_calls": [{"id": "test12345", "function": {"name": "fn", "arguments": "{}"}}], + }, + { + "role": "tool", + "content": [ + {"type": "text", "text": "result"}, + {"type": "image_url", "image_url": "http://example.com/img.png"}, + ], + "tool_call_id": "test12345", + }, + {"role": "assistant", "content": "Done"}, + ] + + tools = [{"type": "function", "function": {"name": "fn", "description": "test", "parameters": {}}}] + output = render_template(template, messages, tools=tools, reasoning_effort="none") + assert "[TOOL_RESULTS]" in output + assert "[IMG]" in output + assert "result" in output + + def test_v15_tool_message_with_audio_content(self) -> None: + r"""V15 audio template renders tool message with audio content.""" + template = generate_chat_template( + spm=False, + tokenizer_version=TokenizerVersion.v15, + image_support=False, + audio_support=True, + thinking_support=False, + default_system_prompt=None, + plain_thinking_support=False, + use_special_token_variables=True, + ) + + messages: list[dict[str, Any]] = [ + {"role": "user", "content": "Use tool"}, + { + "role": "assistant", + "content": "", + "tool_calls": [{"id": "test12345", "function": {"name": "fn", "arguments": "{}"}}], + }, + { + "role": "tool", + "content": [ + {"type": "text", "text": "result"}, + {"type": "input_audio", "input_audio": {"data": "abc", "format": "wav"}}, + ], + "tool_call_id": "test12345", + }, + {"role": "assistant", "content": "Done"}, + ] + + tools = [{"type": "function", "function": {"name": "fn", "description": "test", "parameters": {}}}] + output = render_template(template, messages, tools=tools, reasoning_effort="none") + assert "[TOOL_RESULTS]" in output + assert "[AUDIO]" in output + assert "result" in output + + def test_v15_assistant_message_with_image_content(self) -> None: + r"""V15 image template renders assistant message with image content.""" + template = generate_chat_template( + spm=False, + tokenizer_version=TokenizerVersion.v15, + image_support=True, + audio_support=False, + thinking_support=False, + default_system_prompt=None, + plain_thinking_support=False, + use_special_token_variables=True, + ) + + messages: list[dict[str, Any]] = [ + {"role": "user", "content": "Show me"}, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Here is the image"}, + {"type": "image_url", "image_url": "http://example.com/img.png"}, + ], + }, + ] + + output = render_template(template, messages, reasoning_effort="none") + assert "Here is the image" in output + assert "[IMG]" in output + + def test_v15_assistant_message_with_audio_content(self) -> None: + r"""V15 audio template renders assistant message with audio content.""" + template = generate_chat_template( + spm=False, + tokenizer_version=TokenizerVersion.v15, + image_support=False, + audio_support=True, + thinking_support=False, + default_system_prompt=None, + plain_thinking_support=False, + use_special_token_variables=True, + ) + + messages: list[dict[str, Any]] = [ + {"role": "user", "content": "Listen"}, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Here is audio"}, + {"type": "input_audio", "input_audio": {"data": "abc", "format": "wav"}}, + ], + }, + ] + + output = render_template(template, messages, reasoning_effort="none") + assert "Here is audio" in output + assert "[AUDIO]" in output + + def test_v15_system_message_with_audio_content(self) -> None: + r"""V15 audio template renders system message with audio content.""" + template = generate_chat_template( + spm=False, + tokenizer_version=TokenizerVersion.v15, + image_support=False, + audio_support=True, + thinking_support=False, + default_system_prompt=None, + plain_thinking_support=False, + use_special_token_variables=True, + ) + + messages: list[dict[str, Any]] = [ + { + "role": "system", + "content": [ + {"type": "text", "text": "Listen to context"}, + {"type": "input_audio", "input_audio": {"data": "abc", "format": "wav"}}, + ], + }, + {"role": "user", "content": "Summarize"}, + {"role": "assistant", "content": "Done"}, + ] + + output = render_template(template, messages, reasoning_effort="none") + assert "[SYSTEM_PROMPT]" in output + assert "Listen to context" in output + assert "[AUDIO]" in output + + def test_pre_v15_image_template_rejects_image_in_assistant(self) -> None: + r"""Pre-V15 image template rejects image chunks in assistant message.""" + template = generate_chat_template( + spm=False, + tokenizer_version=TokenizerVersion.v13, + image_support=True, + audio_support=False, + thinking_support=False, + default_system_prompt=None, + plain_thinking_support=False, + use_special_token_variables=True, + ) + + messages: list[dict[str, Any]] = [ + {"role": "user", "content": "Show me"}, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Here"}, + {"type": "image_url", "image_url": "http://example.com/img.png"}, + ], + }, + ] + + with pytest.raises(ValueError, match="Only text chunks are supported in assistant message contents"): + render_template(template, messages) + + def test_pre_v15_audio_template_rejects_audio_in_assistant(self) -> None: + r"""Pre-V15 audio template rejects audio chunks in assistant message.""" + template = generate_chat_template( + spm=False, + tokenizer_version=TokenizerVersion.v13, + image_support=False, + audio_support=True, + thinking_support=False, + default_system_prompt=None, + plain_thinking_support=False, + use_special_token_variables=True, + ) + + messages: list[dict[str, Any]] = [ + {"role": "user", "content": "Listen"}, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Here"}, + {"type": "input_audio", "input_audio": {"data": "abc", "format": "wav"}}, + ], + }, + ] + + with pytest.raises(ValueError, match="Only text chunks are supported in assistant message contents"): + render_template(template, messages) + + def test_pre_v15_audio_template_rejects_audio_in_system(self) -> None: + r"""Pre-V15 audio template rejects audio chunks in system message.""" + template = generate_chat_template( + spm=False, + tokenizer_version=TokenizerVersion.v13, + image_support=False, + audio_support=True, + thinking_support=False, + default_system_prompt=None, + plain_thinking_support=False, + use_special_token_variables=True, + ) + + messages: list[dict[str, Any]] = [ + { + "role": "system", + "content": [ + {"type": "text", "text": "Context"}, + {"type": "input_audio", "input_audio": {"data": "abc", "format": "wav"}}, + ], + }, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi"}, + ] + + with pytest.raises(ValueError, match="Only text chunks are supported in system message contents"): + render_template(template, messages) + + def test_pre_v15_image_template_rejects_image_in_tool(self) -> None: + r"""Pre-V15 image template rejects image chunks in tool message content.""" + template = generate_chat_template( + spm=False, + tokenizer_version=TokenizerVersion.v13, + image_support=True, + audio_support=False, + thinking_support=False, + default_system_prompt=None, + plain_thinking_support=False, + use_special_token_variables=True, + ) + + messages: list[dict[str, Any]] = [ + {"role": "user", "content": "Use tool"}, + { + "role": "assistant", + "content": "", + "tool_calls": [{"id": "test12345", "function": {"name": "fn", "arguments": "{}"}}], + }, + { + "role": "tool", + "content": [ + {"type": "text", "text": "result"}, + {"type": "image_url", "image_url": "http://example.com/img.png"}, + ], + "tool_call_id": "test12345", + }, + {"role": "assistant", "content": "Done"}, + ] + + tools = [{"type": "function", "function": {"name": "fn", "description": "test", "parameters": {}}}] + # Pre-V15 uses message['content']|string which coerces list to string representation + # rather than rendering through render_content — no [IMG] token produced + output = render_template(template, messages, tools=tools) + assert "[IMG]" not in output diff --git a/tests/test_converters.py b/tests/test_converters.py index d093d945..e09ad399 100644 --- a/tests/test_converters.py +++ b/tests/test_converters.py @@ -588,7 +588,7 @@ def test_non_leading_think_chunks_construction_ok() -> None: ) def test_non_leading_think_chunks_to_openai_raises(content: list[TextChunk | ThinkChunk]) -> None: """to_openai raises when ThinkChunks are not leading.""" - msg = AssistantMessage(content=content) + msg = AssistantMessage(content=content) # type: ignore[arg-type] with pytest.raises(InvalidAssistantMessageException, match="ThinkChunks must be leading"): msg.to_openai() @@ -1426,3 +1426,20 @@ def test_from_openai_drops_extra_fields(from_openai_call: Any, expected: Any) -> def test_direct_construction_still_strict(constructor: Any) -> None: with pytest.raises(Exception): constructor() + + +def test_assistant_message_to_openai_reasoning_with_multimodal() -> None: + r"""Reasoning format with multimodal content serializes non-think portion as list.""" + msg = AssistantMessage( + content=[ + ThinkChunk(thinking="Let me think", closed=True), + TextChunk(text="Here is the result"), + ImageURLChunk(image_url="data:image/png;base64,iVBORw0"), + ] + ) + result = msg.to_openai(reasoning_field_format=ReasoningFieldFormat.reasoning) + assert result["reasoning"] == "Let me think" + assert isinstance(result["content"], list) + assert len(result["content"]) == 2 + assert result["content"][0] == {"type": "text", "text": "Here is the result"} + assert result["content"][1]["type"] == "image_url" diff --git a/tests/test_normalization.py b/tests/test_normalization.py index 1ff09546..b8a55921 100644 --- a/tests/test_normalization.py +++ b/tests/test_normalization.py @@ -1,8 +1,11 @@ import json import pytest +from pydantic import ValidationError +from mistral_common.exceptions import InvalidRequestException from mistral_common.protocol.instruct.chunk import ( + AudioChunk, ChunkTypes, ContentChunk, ImageURLChunk, @@ -1121,3 +1124,190 @@ def test_get_normalizer_version_mapping( normalizer = get_normalizer(version, model_settings_builder) assert isinstance(normalizer, expected_class) assert normalizer._model_settings_builder == model_settings_builder + + +class TestToolMessageContentChunk: + @pytest.fixture() + def normalizer_v15(self) -> InstructRequestNormalizerV15: + return InstructRequestNormalizerV15( + UserMessage, + AssistantMessage, + ToolMessage, + SystemMessage, + InstructRequest, + ModelSettingsBuilder( + reasoning_effort=EnumBuilder[ReasoningEffort]( + values=list(ReasoningEffort), accepts_none=False, default=None + ) + ), + ) + + @pytest.fixture() + def normalizer_v13(self) -> InstructRequestNormalizerV13: + return InstructRequestNormalizerV13( + UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest, None + ) + + def test_v15_preserves_non_text_tool_content(self, normalizer_v15: InstructRequestNormalizerV15) -> None: + image_chunk = ImageURLChunk(image_url="https://example.com/image.png") + request = ChatCompletionRequest( # type: ignore[type-var] + messages=[ + UserMessage(content="query"), + AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")]), + ToolMessage(content=[image_chunk], tool_call_id="c1"), + ], + reasoning_effort=ReasoningEffort.high, + ) + parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request) + tool_msg = parsed.messages[2] + assert isinstance(tool_msg, ToolMessage) + assert isinstance(tool_msg.content, list) + assert tool_msg.content == [image_chunk] + + def test_pre_v15_rejects_non_text_tool_content(self) -> None: + r"""Pre-V15 normalizer raises InvalidRequestException for non-text tool content.""" + normalizer = get_normalizer(TokenizerVersion.v13) + request = mock_chat_completion( + messages=[ + UserMessage(content="hi"), + AssistantMessage( + tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")], + ), + ToolMessage( + content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")], + tool_call_id="test12345", + ), + ] + ) + with pytest.raises(InvalidRequestException, match="Unexpected content chunk types in tool message"): + normalizer.from_chat_completion_request(request) + + def test_pre_v15_aggregates_text_tool_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None: + request = mock_chat_completion( + messages=[ + UserMessage(content="query"), + AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")]), + ToolMessage(content=[TextChunk(text="hello"), TextChunk(text="world")], tool_call_id="c1"), + ], + ) + parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(request) + tool_msg = parsed.messages[2] + assert isinstance(tool_msg, ToolMessage) + assert isinstance(tool_msg.content, str) + assert tool_msg.content == "hello\n\nworld" + + def test_pre_v15_rejects_audio_in_tool_content(self) -> None: + r"""Pre-V15 normalizer raises InvalidRequestException for audio tool content.""" + normalizer = get_normalizer(TokenizerVersion.v13) + request = mock_chat_completion( + messages=[ + UserMessage(content="hi"), + AssistantMessage( + tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")], + ), + ToolMessage( + content=[AudioChunk(input_audio=b"fake_audio_data")], + tool_call_id="test12345", + ), + ] + ) + with pytest.raises(InvalidRequestException, match="Unexpected content chunk types in tool message"): + normalizer.from_chat_completion_request(request) + + +class TestAssistantMessageContentChunk: + @pytest.fixture() + def normalizer_v15(self) -> InstructRequestNormalizerV15: + return InstructRequestNormalizerV15( + UserMessage, + AssistantMessage, + ToolMessage, + SystemMessage, + InstructRequest, + ModelSettingsBuilder( + reasoning_effort=EnumBuilder[ReasoningEffort]( + values=list(ReasoningEffort), accepts_none=False, default=None + ) + ), + ) + + @pytest.fixture() + def normalizer_v13(self) -> InstructRequestNormalizerV13: + return InstructRequestNormalizerV13( + UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest, None + ) + + def test_v15_preserves_non_text_assistant_content(self, normalizer_v15: InstructRequestNormalizerV15) -> None: + image_chunk = ImageURLChunk(image_url="https://example.com/image.png") + text_chunk = TextChunk(text="description") + request = ChatCompletionRequest( # type: ignore[type-var] + messages=[ + UserMessage(content="query"), + AssistantMessage(content=[image_chunk, text_chunk]), + ], + reasoning_effort=ReasoningEffort.high, + ) + parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request) + assistant_msg = parsed.messages[1] + assert isinstance(assistant_msg, AssistantMessage) + assert isinstance(assistant_msg.content, list) + assert assistant_msg.content == [image_chunk, TextChunk(text="description")] + + def test_pre_v15_rejects_non_text_assistant_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None: + image_chunk = ImageURLChunk(image_url="https://example.com/image.png") + request = mock_chat_completion( + messages=[ + UserMessage(content="query"), + AssistantMessage(content=[image_chunk]), + ], + ) + with pytest.raises(InvalidRequestException, match="Unexpected content chunk types in assistant message"): + normalizer_v13.from_chat_completion_request(request) + + +class TestSystemMessageContentChunk: + def test_system_message_accepts_audio_chunk(self) -> None: + msg = SystemMessage(content=[AudioChunk(input_audio="dGVzdA==")]) + assert isinstance(msg.content, list) + assert len(msg.content) == 1 + assert isinstance(msg.content[0], AudioChunk) + + def test_system_message_rejects_image_chunk(self) -> None: + with pytest.raises(ValidationError): + SystemMessage(content=[ImageURLChunk(image_url="https://example.com/image.png")]) # type: ignore[list-item] + + def test_v15_rejects_think_in_system_message(self) -> None: + r"""V15 normalizer rejects ThinkChunk in system messages.""" + normalizer = get_normalizer( + TokenizerVersion.v15, + model_settings_builder=ModelSettingsBuilder( + reasoning_effort=EnumBuilder[ReasoningEffort]( + values=list(ReasoningEffort), accepts_none=False, default=None + ) + ), + ) + request = mock_chat_completion( + messages=[ + SystemMessage(content=[TextChunk(text="hello"), ThinkChunk(thinking="thinking", closed=True)]), + UserMessage(content="test"), + ] + ) + with pytest.raises(InvalidRequestException, match="ThinkChunk"): + normalizer.from_chat_completion_request(request) + + def test_v7_normalization_preserves_audio_in_system_message(self) -> None: + r"""V7 normalizer preserves AudioChunk in system messages.""" + normalizer = InstructRequestNormalizerV7.normalizer() + request = mock_chat_completion( + messages=[ + SystemMessage(content=[TextChunk(text="hello"), AudioChunk(input_audio=b"fake_audio_data")]), + UserMessage(content="test"), + ] + ) + parsed: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request) + system_msg = parsed.messages[0] + assert isinstance(system_msg, SystemMessage) + assert isinstance(system_msg.content, list) + assert len(system_msg.content) == 2 + assert isinstance(system_msg.content[0], TextChunk) + assert isinstance(system_msg.content[1], AudioChunk) diff --git a/tests/test_tokenizer_v11.py b/tests/test_tokenizer_v11.py index 8c3c7330..59bcb466 100644 --- a/tests/test_tokenizer_v11.py +++ b/tests/test_tokenizer_v11.py @@ -32,13 +32,15 @@ def test_special_tokens(tekkenizer: InstructTokenizerV11) -> None: def test_tokenize_assistant_message(tekkenizer: InstructTokenizerV11) -> None: - tokens = tekkenizer.encode_assistant_message( + tokens, images, audios = tekkenizer.encode_assistant_message( AssistantMessage( tool_calls=[ToolCall(function=FunctionCall(name="a_a_a", arguments="blabla"))], ), is_before_last_user_message=False, continue_message=False, ) + assert images == [] + assert audios == [] assert tokens == [ tekkenizer.TOOL_CALLS, 197, @@ -61,13 +63,15 @@ def test_tokenize_assistant_message(tekkenizer: InstructTokenizerV11) -> None: def test_tokenize_assistant_message_continue_message(tekkenizer: InstructTokenizerV11) -> None: - tokens = tekkenizer.encode_assistant_message( + tokens, images, audios = tekkenizer.encode_assistant_message( AssistantMessage( content='"blabla"', ), is_before_last_user_message=False, continue_message=True, ) + assert images == [] + assert audios == [] assert tokens == [ 134, 198, @@ -95,7 +99,7 @@ def test_tokenize_assistant_message_continue_message(tekkenizer: InstructTokeniz def test_tokenize_assistant_messages(tekkenizer: InstructTokenizerV11) -> None: - tokens = tekkenizer.encode_assistant_message( + tokens, images, audios = tekkenizer.encode_assistant_message( AssistantMessage( tool_calls=[ ToolCall(function=FunctionCall(name="a_a_a", arguments="blabla")), @@ -105,6 +109,8 @@ def test_tokenize_assistant_messages(tekkenizer: InstructTokenizerV11) -> None: is_before_last_user_message=False, continue_message=False, ) + assert images == [] + assert audios == [] assert tokens == [ tekkenizer.TOOL_CALLS, 197, @@ -135,13 +141,15 @@ def test_tokenize_assistant_messages(tekkenizer: InstructTokenizerV11) -> None: def test_tokenize_assistant_message_train(tekkenizer: InstructTokenizerV11) -> None: - tokens = tekkenizer.encode_assistant_message( + tokens, images, audios = tekkenizer.encode_assistant_message( AssistantMessage( tool_calls=[ToolCall(function=FunctionCall(name="a_a_a", arguments="blabla"), id="ABC")], ), is_before_last_user_message=True, continue_message=False, ) + assert images == [] + assert audios == [] assert tokens == [ tekkenizer.TOOL_CALLS, 197, diff --git a/tests/test_tokenizer_v13.py b/tests/test_tokenizer_v13.py index 15bb6f3a..cec43105 100644 --- a/tests/test_tokenizer_v13.py +++ b/tests/test_tokenizer_v13.py @@ -237,13 +237,17 @@ def test_end_to_end_v13_wrong_order( def test_encode_tool_message(v13_tekkenizer: InstructTokenizerV13) -> None: tool_message = ToolMessage(content="R1", tool_call_id="123456789") assert isinstance(v13_tekkenizer, InstructTokenizerV13) - encoded = v13_tekkenizer.encode_tool_message(tool_message, is_before_last_user_message=False) + encoded, images, audios = v13_tekkenizer.encode_tool_message(tool_message, is_before_last_user_message=False) assert encoded == [7, 182, 149, 8] + assert images == [] + assert audios == [] tool_message = ToolMessage(content=[TextChunk(text="R1"), TextChunk(text="R2")], tool_call_id="123456789") assert isinstance(v13_tekkenizer, InstructTokenizerV13) - encoded = v13_tekkenizer.encode_tool_message(tool_message, is_before_last_user_message=False) + encoded, images, audios = v13_tekkenizer.encode_tool_message(tool_message, is_before_last_user_message=False) assert encoded == [7, 182, 149, 182, 150, 8] + assert images == [] + assert audios == [] def test_encode_think_chunk(v13_tekkenizer_think: InstructTokenizerV13) -> None: @@ -295,7 +299,7 @@ def test_tokenize_assistant_message( v13_tekkenizer_think: InstructTokenizerV13, message: AssistantMessage, expected: str, continue_final_message: bool ) -> None: if not continue_final_message: - tokens = v13_tekkenizer_think.encode_assistant_message( + tokens, images, audios = v13_tekkenizer_think.encode_assistant_message( message, is_before_last_user_message=False, continue_message=continue_final_message ) if not message.prefix: @@ -310,10 +314,12 @@ def test_tokenize_assistant_message( message, is_before_last_user_message=False, continue_message=continue_final_message ) return - tokens = v13_tekkenizer_think.encode_assistant_message( + tokens, images, audios = v13_tekkenizer_think.encode_assistant_message( message, is_before_last_user_message=False, continue_message=continue_final_message ) assert v13_tekkenizer_think.decode(tokens, special_token_policy=SpecialTokenPolicy.KEEP) == expected + assert images == [] + assert audios == [] def test_tokenize_assistant_message_error(v13_tekkenizer: InstructTokenizerV13) -> None: @@ -371,8 +377,10 @@ def test_tokenize_assistant_message_error(v13_tekkenizer: InstructTokenizerV13) def test_encode_system_message( v13_tekkenizer_think: InstructTokenizerV13, message: SystemMessage, expected: str ) -> None: - encoded = v13_tekkenizer_think.encode_system_message(message) + encoded, images, audios = v13_tekkenizer_think.encode_system_message(message) assert v13_tekkenizer_think.decode(encoded, special_token_policy=SpecialTokenPolicy.KEEP) == expected + assert images == [] + assert audios == [] @pytest.mark.parametrize("audio_fixture", ["audio_chunk", "audio_url_chunk"]) diff --git a/tests/test_tokenizer_v15.py b/tests/test_tokenizer_v15.py index fdd84c33..8b45e69a 100644 --- a/tests/test_tokenizer_v15.py +++ b/tests/test_tokenizer_v15.py @@ -1,7 +1,18 @@ +import base64 +from collections.abc import Callable +from io import BytesIO + import pytest +from PIL import Image from mistral_common.exceptions import InvalidRequestException, TokenizerException -from mistral_common.protocol.instruct.chunk import ThinkChunk +from mistral_common.protocol.instruct.chunk import ( + AudioChunk, + AudioURLChunk, + ImageURLChunk, + TextChunk, + ThinkChunk, +) from mistral_common.protocol.instruct.messages import ( AssistantMessage, ChatMessage, @@ -18,11 +29,14 @@ ) from mistral_common.protocol.instruct.tool_calls import Function, FunctionCall, Tool, ToolCall from mistral_common.protocol.instruct.validator import ValidationMode, get_validator -from mistral_common.tokens.tokenizers.base import TokenizerVersion +from mistral_common.tokens.tokenizers.audio import AudioConfig, AudioEncoder, AudioSpectrogramConfig, SpecialAudioIDs +from mistral_common.tokens.tokenizers.base import SpecialTokens, TokenizerVersion +from mistral_common.tokens.tokenizers.image import ImageConfig, ImageEncoder, SpecialImageIDs from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV15 from mistral_common.tokens.tokenizers.mistral import MistralTokenizer from mistral_common.tokens.tokenizers.model_settings_builder import EnumBuilder, ModelSettingsBuilder from mistral_common.tokens.tokenizers.tekken import Tekkenizer +from tests.fixtures.audio import get_dummy_audio_chunk, get_dummy_audio_url_chunk from tests.test_tekken import get_special_tokens, quick_vocab EXPECTED_TEXT_V15: str = ( @@ -294,3 +308,191 @@ def test_encode_chat_completion_continue_final_message() -> None: eos_id = tokenizer_v15.instruct_tokenizer.tokenizer.eos_id assert encoded.tokens[-1] != eos_id + + +@pytest.fixture(scope="session") +def audio_chunk() -> AudioChunk: + return get_dummy_audio_chunk() + + +def get_v15_mistral_tokenizer_with_audio() -> MistralTokenizer: + r"""Build a V15 MistralTokenizer with audio encoder.""" + builder = _build_model_settings_builder(tuple(ReasoningEffort)) + tekkenizer = Tekkenizer( + quick_vocab([b"a", b"b", b"c", b"f", b"de"]), + special_tokens=get_special_tokens(TokenizerVersion.v15, add_audio=True), + pattern=r".+", + vocab_size=256 + 100, + num_special_tokens=100, + version=TokenizerVersion.v15, + model_settings_builder=builder, + ) + audio_config = AudioConfig( + sampling_rate=24_000, + frame_rate=12.5, + encoding_config=AudioSpectrogramConfig( + num_mel_bins=128, + hop_length=160, + window_size=400, + ), + ) + special_audio_ids = SpecialAudioIDs( + audio=tekkenizer.get_special_token(SpecialTokens.audio.value), + begin_audio=tekkenizer.get_special_token(SpecialTokens.begin_audio.value), + streaming_pad=None, + text_to_audio=None, + audio_to_text=None, + ) + audio_encoder = AudioEncoder(audio_config, special_audio_ids) + instruct_tokenizer = InstructTokenizerV15(tekkenizer, audio_encoder=audio_encoder) + request_normalizer = get_normalizer(TokenizerVersion.v15, tekkenizer.model_settings_builder) + validator = get_validator(TokenizerVersion.v15, mode=ValidationMode.test) + return MistralTokenizer( + instruct_tokenizer=instruct_tokenizer, + validator=validator, + request_normalizer=request_normalizer, + ) + + +def _get_dummy_image_url_chunk() -> ImageURLChunk: + r"""Build a small base64-encoded image URL chunk for testing.""" + img = Image.new("RGB", (4, 4), "red") + buf = BytesIO() + img.save(buf, "PNG") + data_url = f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}" + return ImageURLChunk(image_url=data_url) + + +def get_v15_mistral_tokenizer_with_image() -> MistralTokenizer: + r"""Build a V15 MistralTokenizer with image encoder.""" + builder = _build_model_settings_builder(tuple(ReasoningEffort)) + tekkenizer = Tekkenizer( + quick_vocab([b"a", b"b", b"c", b"f", b"de"]), + special_tokens=get_special_tokens(TokenizerVersion.v15, add_think=True), + pattern=r".+", + vocab_size=256 + 100, + num_special_tokens=100, + version=TokenizerVersion.v15, + model_settings_builder=builder, + ) + image_config = ImageConfig(image_patch_size=16, max_image_size=1024) + special_image_ids = SpecialImageIDs( + img=tekkenizer.get_special_token(SpecialTokens.img.value), + img_break=tekkenizer.get_special_token(SpecialTokens.img_break.value), + img_end=tekkenizer.get_special_token(SpecialTokens.img_end.value), + ) + image_encoder = ImageEncoder(image_config, special_image_ids) + instruct_tokenizer = InstructTokenizerV15(tekkenizer, image_encoder=image_encoder) + request_normalizer = get_normalizer(TokenizerVersion.v15, tekkenizer.model_settings_builder) + validator = get_validator(TokenizerVersion.v15, mode=ValidationMode.test) + return MistralTokenizer( + instruct_tokenizer=instruct_tokenizer, + validator=validator, + request_normalizer=request_normalizer, + ) + + +# Multimodal content chunks and their corresponding tokenizer factories for parametrized tests. +_AUDIO_CHUNK_PARAMS = pytest.param(get_dummy_audio_chunk(), 1, 0, get_v15_mistral_tokenizer_with_audio, id="audio") +_AUDIO_URL_CHUNK_PARAMS = pytest.param( + get_dummy_audio_url_chunk(), 1, 0, get_v15_mistral_tokenizer_with_audio, id="audio_url" +) +_IMAGE_URL_CHUNK_PARAMS = pytest.param( + _get_dummy_image_url_chunk(), 0, 1, get_v15_mistral_tokenizer_with_image, id="image_url" +) + +_ALL_MULTIMODAL_PARAMS = [_AUDIO_CHUNK_PARAMS, _AUDIO_URL_CHUNK_PARAMS, _IMAGE_URL_CHUNK_PARAMS] +_AUDIO_ONLY_PARAMS = [_AUDIO_CHUNK_PARAMS] + + +@pytest.mark.parametrize( + ("content_chunk", "expected_audios", "expected_images", "tokenizer_factory"), + _ALL_MULTIMODAL_PARAMS, +) +def test_encode_chat_completion_with_multimodal_tool( + content_chunk: AudioChunk | AudioURLChunk | ImageURLChunk, + expected_audios: int, + expected_images: int, + tokenizer_factory: Callable[[], MistralTokenizer], +) -> None: + mistral_tokenizer = tokenizer_factory() + chat_request = ChatCompletionRequest( # type: ignore[type-var] + messages=[ + UserMessage(content="Use the tool"), + AssistantMessage(tool_calls=[ToolCall(id="test12345", function=FunctionCall(name="fn", arguments="{}"))]), + ToolMessage( + content=[TextChunk(text="result"), content_chunk], + tool_call_id="test12345", + ), + ], + tools=[Tool(function=Function(name="fn", description="test", parameters={}))], + ) + encoded = mistral_tokenizer.encode_chat_completion(chat_request) + assert len(encoded.audios) == expected_audios + assert len(encoded.images) == expected_images + + +@pytest.mark.parametrize( + ("content_chunk", "expected_audios", "expected_images", "tokenizer_factory"), + _ALL_MULTIMODAL_PARAMS, +) +def test_encode_chat_completion_with_multimodal_assistant( + content_chunk: AudioChunk | AudioURLChunk | ImageURLChunk, + expected_audios: int, + expected_images: int, + tokenizer_factory: Callable[[], MistralTokenizer], +) -> None: + mistral_tokenizer = tokenizer_factory() + chat_request = ChatCompletionRequest( # type: ignore[type-var] + messages=[ + UserMessage(content="Hello"), + AssistantMessage(content=[TextChunk(text="Here is content"), content_chunk]), + UserMessage(content="Thanks"), + ], + ) + encoded = mistral_tokenizer.encode_chat_completion(chat_request) + assert len(encoded.audios) == expected_audios + assert len(encoded.images) == expected_images + + +@pytest.mark.parametrize( + ("content_chunk", "expected_audios", "expected_images", "tokenizer_factory"), + _AUDIO_ONLY_PARAMS, +) +def test_encode_chat_completion_with_multimodal_system( + content_chunk: AudioChunk, + expected_audios: int, + expected_images: int, + tokenizer_factory: Callable[[], MistralTokenizer], +) -> None: + mistral_tokenizer = tokenizer_factory() + chat_request = ChatCompletionRequest( # type: ignore[type-var] + messages=[ + SystemMessage(content=[TextChunk(text="System with content"), content_chunk]), + UserMessage(content="Hello"), + ], + ) + encoded = mistral_tokenizer.encode_chat_completion(chat_request) + assert len(encoded.audios) == expected_audios + assert len(encoded.images) == expected_images + + +@pytest.mark.parametrize( + ("content_chunk", "expected_audios", "expected_images", "tokenizer_factory"), + _ALL_MULTIMODAL_PARAMS, +) +def test_encode_chat_completion_with_multimodal_user( + content_chunk: AudioChunk | AudioURLChunk | ImageURLChunk, + expected_audios: int, + expected_images: int, + tokenizer_factory: Callable[[], MistralTokenizer], +) -> None: + mistral_tokenizer = tokenizer_factory() + chat_request = ChatCompletionRequest( + messages=[ + UserMessage(content=[TextChunk(text="Here is content"), content_chunk]), + ], + ) + encoded = mistral_tokenizer.encode_chat_completion(chat_request) + assert len(encoded.audios) == expected_audios + assert len(encoded.images) == expected_images From 4f0735861d48a4871f89cfb7532497241952e509 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Tue, 9 Jun 2026 10:24:17 +0200 Subject: [PATCH 07/47] Address PR review: remove assistant multimodal, simplify normalizer, use positive type guards --- .../chat_templates/template_generator.py | 14 +---- src/mistral_common/protocol/instruct/chunk.py | 2 +- .../protocol/instruct/messages.py | 7 +-- .../protocol/instruct/normalize.py | 56 +++---------------- tests/data/chat_templates/v15_audio.jinja | 2 +- tests/data/chat_templates/v15_image.jinja | 2 +- .../data/chat_templates/v15_image_think.jinja | 2 +- .../chat_templates/fixtures_data.py | 13 ----- .../chat_templates/test_parity.py | 39 +------------ .../chat_templates/unit/test_v15.py | 56 ------------------- tests/test_converters.py | 17 ------ tests/test_normalization.py | 50 ----------------- tests/test_tokenizer_v15.py | 23 -------- 13 files changed, 19 insertions(+), 264 deletions(-) diff --git a/src/mistral_common/integrations/chat_templates/template_generator.py b/src/mistral_common/integrations/chat_templates/template_generator.py index 9d8a4373..868cf1c2 100644 --- a/src/mistral_common/integrations/chat_templates/template_generator.py +++ b/src/mistral_common/integrations/chat_templates/template_generator.py @@ -221,11 +221,6 @@ def tool_supports_multimodal(self) -> bool: r"""Whether tool messages can contain non-text content chunks. V15+.""" return self.version >= TokenizerVersion.v15 - @property - def assistant_supports_multimodal(self) -> bool: - r"""Whether assistant messages can contain non-text content chunks. V15+.""" - return self.version >= TokenizerVersion.v15 - @property def system_supports_audio(self) -> bool: r"""Whether system messages can contain audio. V15+ with audio support.""" @@ -837,7 +832,6 @@ def _generate_flush_logic(config: TemplateConfig) -> list[str]: " {%- set ns_c.has_non_text = true %}", " {%- endif %}", " {%- endfor %}", - " {%- endif %}", ] @@ -1270,17 +1264,13 @@ def _generate_assistant_message_handling(config: TemplateConfig) -> str: desc_parts = ["text"] if config.any_thinking_support: desc_parts.append("thinking") - if config.assistant_supports_multimodal and config.image_support: - desc_parts.append("image") - if config.assistant_supports_multimodal and config.audio_support: - desc_parts.append("audio") rc_call_args += f", supported_types_desc='{_join_types_desc(desc_parts)}'" if config.any_thinking_support: rc_call_args += ", support_thinking=true" if config.image_support: - rc_call_args += f", support_images={'true' if config.assistant_supports_multimodal else 'false'}" + rc_call_args += ", support_images=false" if config.audio_support: - rc_call_args += f", support_audio={'true' if config.assistant_supports_multimodal else 'false'}" + rc_call_args += ", support_audio=false" lines.append(" {%- if message['content'] %}") diff --git a/src/mistral_common/protocol/instruct/chunk.py b/src/mistral_common/protocol/instruct/chunk.py index 430c8c3d..3ba496b9 100644 --- a/src/mistral_common/protocol/instruct/chunk.py +++ b/src/mistral_common/protocol/instruct/chunk.py @@ -455,7 +455,7 @@ def from_openai(cls, openai_chunk: dict[str, Any]) -> "ThinkChunk": TextChunk | ImageChunk | ImageURLChunk | AudioChunk | AudioURLChunk, Field(discriminator="type") ] -AssistantContentChunk: TypeAlias = ContentChunk +AssistantContentChunk = Annotated[TextChunk | ThinkChunk, Field(discriminator="type")] SystemContentChunk = Annotated[TextChunk | AudioChunk | ThinkChunk, Field(discriminator="type")] diff --git a/src/mistral_common/protocol/instruct/messages.py b/src/mistral_common/protocol/instruct/messages.py index 140a1f3a..96aa5fd7 100644 --- a/src/mistral_common/protocol/instruct/messages.py +++ b/src/mistral_common/protocol/instruct/messages.py @@ -27,8 +27,8 @@ ) -def _are_think_chunks(chunks: list[ContentChunk]) -> TypeGuard[list[ThinkChunk]]: - r"""Narrow a ContentChunk list to ThinkChunk list.""" +def _are_think_chunks(chunks: Sequence[TextChunk | ThinkChunk]) -> TypeGuard[list[ThinkChunk]]: + r"""Narrow a chunk list to ThinkChunk list.""" return all(isinstance(c, ThinkChunk) for c in chunks) @@ -358,12 +358,11 @@ def from_openai(cls, openai_message: dict[str, Any]) -> "ToolMessage": tool_message = cls.model_validate( { "role": openai_message["role"], - "tool_call_id": openai_message.get("tool_call_id"), + "tool_call_id": openai_message["tool_call_id"], "content": content, "name": openai_message.get("name"), } ) - assert tool_message.tool_call_id is not None, "tool_call_id must be provided for tool messages." return tool_message diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py index b053b381..0dd77243 100644 --- a/src/mistral_common/protocol/instruct/normalize.py +++ b/src/mistral_common/protocol/instruct/normalize.py @@ -1,6 +1,6 @@ import json import warnings -from typing import Generic, Sequence +from typing import Generic, Sequence, cast from typing_extensions import TypeGuard, assert_never @@ -38,8 +38,8 @@ def _is_user_content( chunks: list[ContentChunk], ) -> TypeGuard[list[TextChunk | ImageChunk | ImageURLChunk | AudioChunk | AudioURLChunk]]: - r"""Narrow ContentChunk list to user-compatible types (no ThinkChunk).""" - return all(not isinstance(c, ThinkChunk) for c in chunks) + r"""Narrow ContentChunk list to user-compatible types.""" + return all(isinstance(c, (TextChunk, ImageChunk, ImageURLChunk, AudioChunk, AudioURLChunk)) for c in chunks) def _aggregate_content_chunks_impl( @@ -265,7 +265,7 @@ def _normalize_tool_call(self, tool_call: ToolCall) -> ToolCall: id=tool_call.id, ) - def _narrow_assistant_content(self, content: list[ContentChunk] | str) -> str | list[ContentChunk]: + def _narrow_assistant_content(self, content: list[ContentChunk] | str) -> str | list[TextChunk | ThinkChunk]: r"""Validate and narrow content chunks for assistant messages. Pre-V15 normalizers only allow TextChunk and ThinkChunk. @@ -279,8 +279,10 @@ def _narrow_assistant_content(self, content: list[ContentChunk] | str) -> str | Raises: InvalidRequestException: If unsupported chunk types are found. """ - if isinstance(content, str) or all(isinstance(c, (TextChunk, ThinkChunk)) for c in content): + if isinstance(content, str): return content + if all(isinstance(c, (TextChunk, ThinkChunk)) for c in content): + return cast(list[TextChunk | ThinkChunk], content) raise InvalidRequestException( f"Unexpected content chunk types in assistant message: {[type(c).__name__ for c in content]}" ) @@ -609,20 +611,9 @@ class InstructRequestNormalizerV15(InstructRequestNormalizerV13): _chunk_join_str: str = "" - def _narrow_assistant_content(self, content: list[ContentChunk] | str) -> str | list[ContentChunk]: - r"""V15 accepts all ContentChunk types in assistant messages.""" - return content - def _narrow_tool_content(self, content: list[ContentChunk] | str) -> str | list[ContentChunk]: r"""V15 accepts all ContentChunk types in tool messages.""" - if isinstance(content, str): - return content - text_parts: list[str] = [] - for c in content: - if not isinstance(c, TextChunk): - return content - text_parts.append(c.text) - return "".join(text_parts) + return content def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | list[ContentChunk]: r"""V15 system messages allow TextChunk and AudioChunk but reject ThinkChunk.""" @@ -632,37 +623,6 @@ def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | lis raise InvalidRequestException("ThinkChunk in system message is not supported for V15") return content - def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]: - r"""V15 tool messages preserve non-text content chunks. - - Text-only content is aggregated and JSON-normalized. Mixed content is preserved as-is. - """ - tool_messages: list[ToolMessageType] = [] - for message in messages: - assert isinstance(message, self._tool_message_class), "Expected tool message" - content = self._aggregate_content_chunks([message]) - validated = self._narrow_tool_content(content) - if isinstance(validated, str): - normalized_content: str | list[ContentChunk] = self._normalize_json_content(validated) - else: - normalized_content = validated - tool_messages.append( - self._tool_message_class( - content=normalized_content, tool_call_id=message.tool_call_id, name=message.name - ) - ) - - # Reorder by tool call order - id_to_tool_call_idx = {call_id: idx for idx, call_id in enumerate(latest_call_ids)} - id_to_tool_result_idx = {message.tool_call_id: idx for idx, message in enumerate(tool_messages)} - tool_messages.sort( - key=lambda msg: ( - id_to_tool_call_idx.get(msg.tool_call_id or "null", float("inf")), - id_to_tool_result_idx[msg.tool_call_id], - ), - ) - return tool_messages - @staticmethod def normalizer(model_settings_builder: ModelSettingsBuilder | None = None) -> "InstructRequestNormalizerV15": r"""Returns a normalizer for the V15 instruct request. diff --git a/tests/data/chat_templates/v15_audio.jinja b/tests/data/chat_templates/v15_audio.jinja index 381c72cb..bbd04595 100644 --- a/tests/data/chat_templates/v15_audio.jinja +++ b/tests/data/chat_templates/v15_audio.jinja @@ -158,7 +158,7 @@ {%- endif %} {%- if message['content'] %} - {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text and audio', support_audio=true) -}} + {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text', support_audio=false) -}} {%- endif %} {%- if message['tool_calls'] is defined and message['tool_calls'] is not none and message['tool_calls']|length > 0 %} diff --git a/tests/data/chat_templates/v15_image.jinja b/tests/data/chat_templates/v15_image.jinja index 58146ed7..35b872fe 100644 --- a/tests/data/chat_templates/v15_image.jinja +++ b/tests/data/chat_templates/v15_image.jinja @@ -169,7 +169,7 @@ {%- endif %} {%- if message['content'] %} - {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text and image', support_images=true) -}} + {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text', support_images=false) -}} {%- endif %} {%- if message['tool_calls'] is defined and message['tool_calls'] is not none and message['tool_calls']|length > 0 %} diff --git a/tests/data/chat_templates/v15_image_think.jinja b/tests/data/chat_templates/v15_image_think.jinja index 1460ec9b..5fab24c0 100644 --- a/tests/data/chat_templates/v15_image_think.jinja +++ b/tests/data/chat_templates/v15_image_think.jinja @@ -196,7 +196,7 @@ {%- endif %} {%- if message['content'] %} - {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text, thinking and image', support_thinking=true, support_images=true) -}} + {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text and thinking', support_thinking=true, support_images=false) -}} {%- endif %} {%- if message['tool_calls'] is defined and message['tool_calls'] is not none and message['tool_calls']|length > 0 %} diff --git a/tests/integrations/chat_templates/fixtures_data.py b/tests/integrations/chat_templates/fixtures_data.py index c45f1b79..8fad3fcf 100644 --- a/tests/integrations/chat_templates/fixtures_data.py +++ b/tests/integrations/chat_templates/fixtures_data.py @@ -981,18 +981,6 @@ tools=_TOOLS, ) -REQUEST_ASSISTANT_IMAGE_TRAIN = ChatCompletionRequest( # type: ignore[type-var] - messages=[ - UserMessage(content="Generate an image for me."), - AssistantMessage( - content=[ - TextChunk(text="Here is the generated image."), - ImageURLChunk(image_url=_IMAGE_URL), - ], - ), - ], -) - REQUEST_SYSTEM_AUDIO_TRAIN = ChatCompletionRequest( # type: ignore[type-var] messages=[ SystemMessage( @@ -1162,7 +1150,6 @@ def _get_conversations( conversations.extend( [ REQUEST_TOOL_IMAGE_TRAIN, - REQUEST_ASSISTANT_IMAGE_TRAIN, ] ) if audio: diff --git a/tests/integrations/chat_templates/test_parity.py b/tests/integrations/chat_templates/test_parity.py index 0cce8183..8c90c97d 100644 --- a/tests/integrations/chat_templates/test_parity.py +++ b/tests/integrations/chat_templates/test_parity.py @@ -491,41 +491,6 @@ def test_dynamic_template_comprehensive(config: TestConfig) -> None: } ) - # V15+ multimodal: assistant messages with image/audio content - if config.version >= TokenizerVersion.v15 and config.image: - test_cases.append( - { - "name": "v15_assistant_with_image", - "messages": [ - {"role": "user", "content": "Show me"}, - { - "role": "assistant", - "content": [ - {"type": "text", "text": "Here is the image"}, - {"type": "image_url", "image_url": "http://example.com/img.png"}, - ], - }, - ], - } - ) - - if config.version >= TokenizerVersion.v15 and config.audio: - test_cases.append( - { - "name": "v15_assistant_with_audio", - "messages": [ - {"role": "user", "content": "Listen"}, - { - "role": "assistant", - "content": [ - {"type": "text", "text": "Here is audio"}, - {"type": "input_audio", "input_audio": {"data": "abc", "format": "wav"}}, - ], - }, - ], - } - ) - # V15+ multimodal: system messages with audio content if config.version >= TokenizerVersion.v15 and config.audio: test_cases.append( @@ -551,8 +516,8 @@ def test_dynamic_template_comprehensive(config: TestConfig) -> None: # ThinkChunks in system messages are only supported in v13 (not v15+) skip_names_think_system = {"consecutive_systems_think"} skip_names_tools = {"with_tools_definition", "with_tool_call_and_result"} - skip_names_v15_image = {"v15_tool_with_image", "v15_assistant_with_image"} - skip_names_v15_audio = {"v15_tool_with_audio", "v15_assistant_with_audio", "v15_system_with_audio"} + skip_names_v15_image = {"v15_tool_with_image"} + skip_names_v15_audio = {"v15_tool_with_audio", "v15_system_with_audio"} # Not using parametrize here because test_cases are built dynamically based on the # config (version/image/audio/think) from the outer parametrize. Each sub-case is diff --git a/tests/integrations/chat_templates/unit/test_v15.py b/tests/integrations/chat_templates/unit/test_v15.py index 3401a8cf..39c4705a 100644 --- a/tests/integrations/chat_templates/unit/test_v15.py +++ b/tests/integrations/chat_templates/unit/test_v15.py @@ -323,62 +323,6 @@ def test_v15_tool_message_with_audio_content(self) -> None: assert "[AUDIO]" in output assert "result" in output - def test_v15_assistant_message_with_image_content(self) -> None: - r"""V15 image template renders assistant message with image content.""" - template = generate_chat_template( - spm=False, - tokenizer_version=TokenizerVersion.v15, - image_support=True, - audio_support=False, - thinking_support=False, - default_system_prompt=None, - plain_thinking_support=False, - use_special_token_variables=True, - ) - - messages: list[dict[str, Any]] = [ - {"role": "user", "content": "Show me"}, - { - "role": "assistant", - "content": [ - {"type": "text", "text": "Here is the image"}, - {"type": "image_url", "image_url": "http://example.com/img.png"}, - ], - }, - ] - - output = render_template(template, messages, reasoning_effort="none") - assert "Here is the image" in output - assert "[IMG]" in output - - def test_v15_assistant_message_with_audio_content(self) -> None: - r"""V15 audio template renders assistant message with audio content.""" - template = generate_chat_template( - spm=False, - tokenizer_version=TokenizerVersion.v15, - image_support=False, - audio_support=True, - thinking_support=False, - default_system_prompt=None, - plain_thinking_support=False, - use_special_token_variables=True, - ) - - messages: list[dict[str, Any]] = [ - {"role": "user", "content": "Listen"}, - { - "role": "assistant", - "content": [ - {"type": "text", "text": "Here is audio"}, - {"type": "input_audio", "input_audio": {"data": "abc", "format": "wav"}}, - ], - }, - ] - - output = render_template(template, messages, reasoning_effort="none") - assert "Here is audio" in output - assert "[AUDIO]" in output - def test_v15_system_message_with_audio_content(self) -> None: r"""V15 audio template renders system message with audio content.""" template = generate_chat_template( diff --git a/tests/test_converters.py b/tests/test_converters.py index e09ad399..3cd546a8 100644 --- a/tests/test_converters.py +++ b/tests/test_converters.py @@ -1426,20 +1426,3 @@ def test_from_openai_drops_extra_fields(from_openai_call: Any, expected: Any) -> def test_direct_construction_still_strict(constructor: Any) -> None: with pytest.raises(Exception): constructor() - - -def test_assistant_message_to_openai_reasoning_with_multimodal() -> None: - r"""Reasoning format with multimodal content serializes non-think portion as list.""" - msg = AssistantMessage( - content=[ - ThinkChunk(thinking="Let me think", closed=True), - TextChunk(text="Here is the result"), - ImageURLChunk(image_url="data:image/png;base64,iVBORw0"), - ] - ) - result = msg.to_openai(reasoning_field_format=ReasoningFieldFormat.reasoning) - assert result["reasoning"] == "Let me think" - assert isinstance(result["content"], list) - assert len(result["content"]) == 2 - assert result["content"][0] == {"type": "text", "text": "Here is the result"} - assert result["content"][1]["type"] == "image_url" diff --git a/tests/test_normalization.py b/tests/test_normalization.py index b8a55921..b1d88d80 100644 --- a/tests/test_normalization.py +++ b/tests/test_normalization.py @@ -1215,56 +1215,6 @@ def test_pre_v15_rejects_audio_in_tool_content(self) -> None: normalizer.from_chat_completion_request(request) -class TestAssistantMessageContentChunk: - @pytest.fixture() - def normalizer_v15(self) -> InstructRequestNormalizerV15: - return InstructRequestNormalizerV15( - UserMessage, - AssistantMessage, - ToolMessage, - SystemMessage, - InstructRequest, - ModelSettingsBuilder( - reasoning_effort=EnumBuilder[ReasoningEffort]( - values=list(ReasoningEffort), accepts_none=False, default=None - ) - ), - ) - - @pytest.fixture() - def normalizer_v13(self) -> InstructRequestNormalizerV13: - return InstructRequestNormalizerV13( - UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest, None - ) - - def test_v15_preserves_non_text_assistant_content(self, normalizer_v15: InstructRequestNormalizerV15) -> None: - image_chunk = ImageURLChunk(image_url="https://example.com/image.png") - text_chunk = TextChunk(text="description") - request = ChatCompletionRequest( # type: ignore[type-var] - messages=[ - UserMessage(content="query"), - AssistantMessage(content=[image_chunk, text_chunk]), - ], - reasoning_effort=ReasoningEffort.high, - ) - parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request) - assistant_msg = parsed.messages[1] - assert isinstance(assistant_msg, AssistantMessage) - assert isinstance(assistant_msg.content, list) - assert assistant_msg.content == [image_chunk, TextChunk(text="description")] - - def test_pre_v15_rejects_non_text_assistant_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None: - image_chunk = ImageURLChunk(image_url="https://example.com/image.png") - request = mock_chat_completion( - messages=[ - UserMessage(content="query"), - AssistantMessage(content=[image_chunk]), - ], - ) - with pytest.raises(InvalidRequestException, match="Unexpected content chunk types in assistant message"): - normalizer_v13.from_chat_completion_request(request) - - class TestSystemMessageContentChunk: def test_system_message_accepts_audio_chunk(self) -> None: msg = SystemMessage(content=[AudioChunk(input_audio="dGVzdA==")]) diff --git a/tests/test_tokenizer_v15.py b/tests/test_tokenizer_v15.py index 8b45e69a..739ec75a 100644 --- a/tests/test_tokenizer_v15.py +++ b/tests/test_tokenizer_v15.py @@ -432,29 +432,6 @@ def test_encode_chat_completion_with_multimodal_tool( assert len(encoded.images) == expected_images -@pytest.mark.parametrize( - ("content_chunk", "expected_audios", "expected_images", "tokenizer_factory"), - _ALL_MULTIMODAL_PARAMS, -) -def test_encode_chat_completion_with_multimodal_assistant( - content_chunk: AudioChunk | AudioURLChunk | ImageURLChunk, - expected_audios: int, - expected_images: int, - tokenizer_factory: Callable[[], MistralTokenizer], -) -> None: - mistral_tokenizer = tokenizer_factory() - chat_request = ChatCompletionRequest( # type: ignore[type-var] - messages=[ - UserMessage(content="Hello"), - AssistantMessage(content=[TextChunk(text="Here is content"), content_chunk]), - UserMessage(content="Thanks"), - ], - ) - encoded = mistral_tokenizer.encode_chat_completion(chat_request) - assert len(encoded.audios) == expected_audios - assert len(encoded.images) == expected_images - - @pytest.mark.parametrize( ("content_chunk", "expected_audios", "expected_images", "tokenizer_factory"), _AUDIO_ONLY_PARAMS, From 823ee67728232eeb57045c06f413841f608ae021 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Tue, 9 Jun 2026 10:27:51 +0200 Subject: [PATCH 08/47] Require content in UserMessage and SystemMessage from_openai --- src/mistral_common/protocol/instruct/messages.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mistral_common/protocol/instruct/messages.py b/src/mistral_common/protocol/instruct/messages.py index 96aa5fd7..f8f0c366 100644 --- a/src/mistral_common/protocol/instruct/messages.py +++ b/src/mistral_common/protocol/instruct/messages.py @@ -149,7 +149,7 @@ def to_openai(self) -> dict[str, Any]: def from_openai(cls, openai_message: dict[str, Any]) -> "UserMessage": r"""Converts the OpenAI message to the Mistral format.""" return cls.model_validate( - {"role": openai_message["role"], "content": cls._content_from_openai(openai_message.get("content"))} + {"role": openai_message["role"], "content": cls._content_from_openai(openai_message["content"])} ) @@ -174,7 +174,7 @@ def to_openai(self) -> dict[str, Any]: def from_openai(cls, openai_message: dict[str, Any]) -> "SystemMessage": r"""Converts the OpenAI message to the Mistral format.""" return cls.model_validate( - {"role": openai_message["role"], "content": cls._content_from_openai(openai_message.get("content"))} + {"role": openai_message["role"], "content": cls._content_from_openai(openai_message["content"])} ) From 655180c805f6f7b7ab45892873beb033840f9b4a Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Tue, 9 Jun 2026 10:42:02 +0200 Subject: [PATCH 09/47] Skip JSON normalization of tool content for V7+ normalizers --- src/mistral_common/protocol/instruct/normalize.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py index 0dd77243..40610744 100644 --- a/src/mistral_common/protocol/instruct/normalize.py +++ b/src/mistral_common/protocol/instruct/normalize.py @@ -498,6 +498,21 @@ def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | lis """ return content + def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]: + """Normalize tool messages without JSON normalization. + + V7+ normalizers skip JSON content normalization for tool messages. + """ + tool_messages: list[ToolMessageType] = [] + for message in messages: + assert isinstance(message, self._tool_message_class), "Expected tool message" + content = self._aggregate_content_chunks([message]) + validated = self._narrow_tool_content(content) + tool_messages.append( + self._tool_message_class(content=validated, tool_call_id=message.tool_call_id, name=message.name) + ) + return tool_messages + def _aggregate_system_messages(self, messages: list[UATS]) -> list[SystemMessageType]: aggregated: list[SystemMessageType] = [] for message in messages: From ebe26d9e142e05bba26113e1c7a56b249c37b2f3 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Tue, 9 Jun 2026 11:03:24 +0200 Subject: [PATCH 10/47] Use TypeGuard for assistant content, simplify tool content narrowing, add tests --- .../protocol/instruct/normalize.py | 36 +++++---- tests/test_normalization.py | 75 +++++++++++++++++++ 2 files changed, 95 insertions(+), 16 deletions(-) diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py index 40610744..d199f96b 100644 --- a/src/mistral_common/protocol/instruct/normalize.py +++ b/src/mistral_common/protocol/instruct/normalize.py @@ -1,6 +1,6 @@ import json import warnings -from typing import Generic, Sequence, cast +from typing import Generic, Sequence from typing_extensions import TypeGuard, assert_never @@ -42,6 +42,13 @@ def _is_user_content( return all(isinstance(c, (TextChunk, ImageChunk, ImageURLChunk, AudioChunk, AudioURLChunk)) for c in chunks) +def _is_assistant_content( + chunks: list[ContentChunk], +) -> TypeGuard[list[TextChunk | ThinkChunk]]: + r"""Narrow ContentChunk list to assistant-compatible types.""" + return all(isinstance(c, (TextChunk, ThinkChunk)) for c in chunks) + + def _aggregate_content_chunks_impl( contents: list[list[ContentChunk] | str | None], msg_join_str: str, @@ -268,7 +275,7 @@ def _normalize_tool_call(self, tool_call: ToolCall) -> ToolCall: def _narrow_assistant_content(self, content: list[ContentChunk] | str) -> str | list[TextChunk | ThinkChunk]: r"""Validate and narrow content chunks for assistant messages. - Pre-V15 normalizers only allow TextChunk and ThinkChunk. + Only TextChunk and ThinkChunk are allowed. Args: content: The aggregated content chunks. @@ -281,36 +288,33 @@ def _narrow_assistant_content(self, content: list[ContentChunk] | str) -> str | """ if isinstance(content, str): return content - if all(isinstance(c, (TextChunk, ThinkChunk)) for c in content): - return cast(list[TextChunk | ThinkChunk], content) + if _is_assistant_content(content): + return content raise InvalidRequestException( f"Unexpected content chunk types in assistant message: {[type(c).__name__ for c in content]}" ) def _narrow_tool_content(self, content: list[ContentChunk] | str) -> str | list[ContentChunk]: - r"""Validate and narrow content for tool messages. + r"""Validate that tool content is text-only. - Pre-V15 normalizers only allow text content. + Pre-V15 normalizers only allow text content. Since ``_aggregate_content_chunks`` + already collapses text-only content to ``str``, receiving a list here means + non-text chunks are present. Args: - content: The raw or aggregated content. + content: The aggregated content. Returns: - The content as a string. + The validated content as a string. Raises: InvalidRequestException: If non-text content chunks are found. """ if isinstance(content, str): return content - text_parts: list[str] = [] - for c in content: - if not isinstance(c, TextChunk): - raise InvalidRequestException( - f"Unexpected content chunk types in tool message: {[type(c).__name__ for c in content]}" - ) - text_parts.append(c.text) - return "".join(text_parts) + raise InvalidRequestException( + f"Unexpected content chunk types in tool message: {[type(c).__name__ for c in content]}" + ) def _aggregate_system_messages(self, messages: list[UATS]) -> list[SystemMessageType]: return [] diff --git a/tests/test_normalization.py b/tests/test_normalization.py index b1d88d80..eb60f180 100644 --- a/tests/test_normalization.py +++ b/tests/test_normalization.py @@ -1126,6 +1126,45 @@ def test_get_normalizer_version_mapping( assert normalizer._model_settings_builder == model_settings_builder +class TestAssistantContentNarrowing: + def test_accepts_text_and_think_chunks(self) -> None: + r"""Normalizer accepts TextChunk and ThinkChunk in assistant messages.""" + normalizer = get_normalizer(TokenizerVersion.v13) + request = mock_chat_completion( + messages=[ + UserMessage(content="query"), + AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]), + ], + ) + parsed: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request) + assistant_msg = parsed.messages[1] + assert isinstance(assistant_msg, AssistantMessage) + assert isinstance(assistant_msg.content, list) + assert len(assistant_msg.content) == 2 + + def test_accepts_string_content(self) -> None: + r"""Normalizer accepts string content in assistant messages.""" + normalizer = get_normalizer( + TokenizerVersion.v15, + model_settings_builder=ModelSettingsBuilder( + reasoning_effort=EnumBuilder[ReasoningEffort]( + values=list(ReasoningEffort), accepts_none=False, default=None + ) + ), + ) + request = ChatCompletionRequest( + messages=[ + UserMessage(content="query"), + AssistantMessage(content="plain text"), + ], + reasoning_effort=ReasoningEffort.high, + ) + parsed: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request) + assistant_msg = parsed.messages[1] + assert isinstance(assistant_msg, AssistantMessage) + assert assistant_msg.content == "plain text" + + class TestToolMessageContentChunk: @pytest.fixture() def normalizer_v15(self) -> InstructRequestNormalizerV15: @@ -1214,6 +1253,42 @@ def test_pre_v15_rejects_audio_in_tool_content(self) -> None: with pytest.raises(InvalidRequestException, match="Unexpected content chunk types in tool message"): normalizer.from_chat_completion_request(request) + def test_base_normalizer_json_normalizes_tool_content(self) -> None: + r"""Base normalizer (v1-v3) JSON-normalizes tool message content.""" + normalizer = InstructRequestNormalizer( + UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest, None + ) + messy_json = '{"key" : "value" , "num": 1}' + request = mock_chat_completion( + messages=[ + UserMessage(content="query"), + AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")]), + ToolMessage(content=messy_json, tool_call_id="c1"), + ], + ) + parsed: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request) + tool_msg = parsed.messages[2] + assert isinstance(tool_msg, ToolMessage) + assert tool_msg.content == '{"key": "value", "num": 1}' + + def test_v7_skips_json_normalization_on_tool_content(self) -> None: + r"""V7+ normalizers do not JSON-normalize tool message content.""" + normalizer = InstructRequestNormalizerV7( + UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest, None + ) + messy_json = '{"key" : "value" , "num": 1}' + request = mock_chat_completion( + messages=[ + UserMessage(content="query"), + AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")]), + ToolMessage(content=messy_json, tool_call_id="c1"), + ], + ) + parsed: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request) + tool_msg = parsed.messages[2] + assert isinstance(tool_msg, ToolMessage) + assert tool_msg.content == messy_json + class TestSystemMessageContentChunk: def test_system_message_accepts_audio_chunk(self) -> None: From fee10213ad45ce381532169e302e9a13cff767b6 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Tue, 9 Jun 2026 12:29:00 +0200 Subject: [PATCH 11/47] Remove _narrow_tool_content, inline logic directly --- .../protocol/instruct/messages.py | 2 +- .../protocol/instruct/normalize.py | 40 ++++--------------- 2 files changed, 8 insertions(+), 34 deletions(-) diff --git a/src/mistral_common/protocol/instruct/messages.py b/src/mistral_common/protocol/instruct/messages.py index f8f0c366..c0b922d3 100644 --- a/src/mistral_common/protocol/instruct/messages.py +++ b/src/mistral_common/protocol/instruct/messages.py @@ -354,7 +354,7 @@ def to_openai(self) -> dict[str, Any]: @classmethod def from_openai(cls, openai_message: dict[str, Any]) -> "ToolMessage": r"""Converts the OpenAI message to the Mistral format.""" - content = cls._content_from_openai(openai_message.get("content")) + content = cls._content_from_openai(openai_message["content"]) tool_message = cls.model_validate( { "role": openai_message["role"], diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py index d199f96b..c69fe912 100644 --- a/src/mistral_common/protocol/instruct/normalize.py +++ b/src/mistral_common/protocol/instruct/normalize.py @@ -252,11 +252,12 @@ def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[s for message in messages: assert isinstance(message, self._tool_message_class), "Expected tool message" content = self._aggregate_content_chunks([message]) - validated = self._narrow_tool_content(content) - if isinstance(validated, str): - normalized_content: str | list[ContentChunk] = self._normalize_json_content(validated) - else: - normalized_content = validated + if not isinstance(content, str): + raise InvalidRequestException( + f"Unexpected content chunk types in tool message: {[type(c).__name__ for c in content]}" + ) + normalized_content = self._normalize_json_content(content) + tool_messages.append( self._tool_message_class( content=normalized_content, tool_call_id=message.tool_call_id, name=message.name @@ -294,28 +295,6 @@ def _narrow_assistant_content(self, content: list[ContentChunk] | str) -> str | f"Unexpected content chunk types in assistant message: {[type(c).__name__ for c in content]}" ) - def _narrow_tool_content(self, content: list[ContentChunk] | str) -> str | list[ContentChunk]: - r"""Validate that tool content is text-only. - - Pre-V15 normalizers only allow text content. Since ``_aggregate_content_chunks`` - already collapses text-only content to ``str``, receiving a list here means - non-text chunks are present. - - Args: - content: The aggregated content. - - Returns: - The validated content as a string. - - Raises: - InvalidRequestException: If non-text content chunks are found. - """ - if isinstance(content, str): - return content - raise InvalidRequestException( - f"Unexpected content chunk types in tool message: {[type(c).__name__ for c in content]}" - ) - def _aggregate_system_messages(self, messages: list[UATS]) -> list[SystemMessageType]: return [] @@ -511,9 +490,8 @@ def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[s for message in messages: assert isinstance(message, self._tool_message_class), "Expected tool message" content = self._aggregate_content_chunks([message]) - validated = self._narrow_tool_content(content) tool_messages.append( - self._tool_message_class(content=validated, tool_call_id=message.tool_call_id, name=message.name) + self._tool_message_class(content=content, tool_call_id=message.tool_call_id, name=message.name) ) return tool_messages @@ -630,10 +608,6 @@ class InstructRequestNormalizerV15(InstructRequestNormalizerV13): _chunk_join_str: str = "" - def _narrow_tool_content(self, content: list[ContentChunk] | str) -> str | list[ContentChunk]: - r"""V15 accepts all ContentChunk types in tool messages.""" - return content - def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | list[ContentChunk]: r"""V15 system messages allow TextChunk and AudioChunk but reject ThinkChunk.""" if isinstance(content, str): From dc09e05fab29954f0782f6ea6331fb52ff246e11 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Tue, 9 Jun 2026 14:31:16 +0200 Subject: [PATCH 12/47] Fix mypy errors and enable V15 assistant multimodal in chat templates --- .../chat_templates/template_generator.py | 21 ++++++++++++++----- tests/test_converters.py | 2 +- tests/test_normalization.py | 2 +- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/src/mistral_common/integrations/chat_templates/template_generator.py b/src/mistral_common/integrations/chat_templates/template_generator.py index 868cf1c2..eaf66a15 100644 --- a/src/mistral_common/integrations/chat_templates/template_generator.py +++ b/src/mistral_common/integrations/chat_templates/template_generator.py @@ -1230,10 +1230,17 @@ def _generate_assistant_message_handling(config: TemplateConfig) -> str: """ lines = [] + assistant_supports_image = config.image_support and config.version >= TokenizerVersion.v15 + assistant_supports_audio = config.audio_support and config.version >= TokenizerVersion.v15 + + comment_parts = ["text"] if config.any_thinking_support: - chunk_types = "text and thinking" - else: - chunk_types = "text" + comment_parts.append("thinking") + if assistant_supports_image: + comment_parts.append("image") + if assistant_supports_audio: + comment_parts.append("audio") + chunk_types = _join_types_desc(comment_parts) comment = f"{{#- Assistant messages supports {chunk_types} content. #}}" lines.append("") @@ -1264,13 +1271,17 @@ def _generate_assistant_message_handling(config: TemplateConfig) -> str: desc_parts = ["text"] if config.any_thinking_support: desc_parts.append("thinking") + if assistant_supports_image: + desc_parts.append("image") + if assistant_supports_audio: + desc_parts.append("audio") rc_call_args += f", supported_types_desc='{_join_types_desc(desc_parts)}'" if config.any_thinking_support: rc_call_args += ", support_thinking=true" if config.image_support: - rc_call_args += ", support_images=false" + rc_call_args += f", support_images={'true' if assistant_supports_image else 'false'}" if config.audio_support: - rc_call_args += ", support_audio=false" + rc_call_args += f", support_audio={'true' if assistant_supports_audio else 'false'}" lines.append(" {%- if message['content'] %}") diff --git a/tests/test_converters.py b/tests/test_converters.py index 3cd546a8..d093d945 100644 --- a/tests/test_converters.py +++ b/tests/test_converters.py @@ -588,7 +588,7 @@ def test_non_leading_think_chunks_construction_ok() -> None: ) def test_non_leading_think_chunks_to_openai_raises(content: list[TextChunk | ThinkChunk]) -> None: """to_openai raises when ThinkChunks are not leading.""" - msg = AssistantMessage(content=content) # type: ignore[arg-type] + msg = AssistantMessage(content=content) with pytest.raises(InvalidAssistantMessageException, match="ThinkChunks must be leading"): msg.to_openai() diff --git a/tests/test_normalization.py b/tests/test_normalization.py index eb60f180..001fd34a 100644 --- a/tests/test_normalization.py +++ b/tests/test_normalization.py @@ -1152,7 +1152,7 @@ def test_accepts_string_content(self) -> None: ) ), ) - request = ChatCompletionRequest( + request = ChatCompletionRequest[ChatMessage]( messages=[ UserMessage(content="query"), AssistantMessage(content="plain text"), From db684eb1846bcaefa23b652e4f5d0395b3a562d4 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Tue, 9 Jun 2026 14:34:57 +0200 Subject: [PATCH 13/47] Restore V7 tool content validation, add V15 override accepting all chunks --- .../protocol/instruct/normalize.py | 27 ++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py index c69fe912..ff742266 100644 --- a/src/mistral_common/protocol/instruct/normalize.py +++ b/src/mistral_common/protocol/instruct/normalize.py @@ -484,12 +484,17 @@ def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | lis def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]: """Normalize tool messages without JSON normalization. - V7+ normalizers skip JSON content normalization for tool messages. + V7+ normalizers skip JSON content normalization for tool messages but still + reject non-text content chunks. """ tool_messages: list[ToolMessageType] = [] for message in messages: assert isinstance(message, self._tool_message_class), "Expected tool message" content = self._aggregate_content_chunks([message]) + if not isinstance(content, str): + raise InvalidRequestException( + f"Unexpected content chunk types in tool message: {[type(c).__name__ for c in content]}" + ) tool_messages.append( self._tool_message_class(content=content, tool_call_id=message.tool_call_id, name=message.name) ) @@ -608,6 +613,26 @@ class InstructRequestNormalizerV15(InstructRequestNormalizerV13): _chunk_join_str: str = "" + def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]: + r"""V15 accepts all ContentChunk types in tool messages.""" + tool_messages: list[ToolMessageType] = [] + for message in messages: + assert isinstance(message, self._tool_message_class), "Expected tool message" + content = self._aggregate_content_chunks([message]) + tool_messages.append( + self._tool_message_class(content=content, tool_call_id=message.tool_call_id, name=message.name) + ) + # Reorder tool messages based on the tool call order (from V13). + id_to_tool_call_idx = {call_id: idx for idx, call_id in enumerate(latest_call_ids)} + id_to_tool_result_idx = {message.tool_call_id: idx for idx, message in enumerate(tool_messages)} + tool_messages.sort( + key=lambda msg: ( + id_to_tool_call_idx.get(msg.tool_call_id or "null", float("inf")), + id_to_tool_result_idx[msg.tool_call_id], + ), + ) + return tool_messages + def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | list[ContentChunk]: r"""V15 system messages allow TextChunk and AudioChunk but reject ThinkChunk.""" if isinstance(content, str): From affebf53ad64332e2f9848e77c450942b796723c Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Tue, 9 Jun 2026 14:39:09 +0200 Subject: [PATCH 14/47] Regenerate V15 golden chat templates with assistant multimodal support --- tests/data/chat_templates/v15_audio.jinja | 4 ++-- tests/data/chat_templates/v15_image.jinja | 4 ++-- tests/data/chat_templates/v15_image_think.jinja | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/data/chat_templates/v15_audio.jinja b/tests/data/chat_templates/v15_audio.jinja index bbd04595..0d843e90 100644 --- a/tests/data/chat_templates/v15_audio.jinja +++ b/tests/data/chat_templates/v15_audio.jinja @@ -151,14 +151,14 @@ {{- render_content(message['content'], 'user message content', supported_types_desc='text, input_audio and audio_url', support_audio=true) -}} {{- '[/INST]' }} - {#- Assistant messages supports text content. #} + {#- Assistant messages supports text and audio content. #} {%- elif message['role'] == 'assistant' %} {%- if (message['content'] is none or message['content'] == '' or message['content']|length == 0) and (message['tool_calls'] is not defined or message['tool_calls'] is none or message['tool_calls']|length == 0) %} {{- raise_exception('Assistant message must have a string or a list of chunks in content or a list of tool calls.') }} {%- endif %} {%- if message['content'] %} - {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text', support_audio=false) -}} + {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text and audio', support_audio=true) -}} {%- endif %} {%- if message['tool_calls'] is defined and message['tool_calls'] is not none and message['tool_calls']|length > 0 %} diff --git a/tests/data/chat_templates/v15_image.jinja b/tests/data/chat_templates/v15_image.jinja index 35b872fe..bd3eefa3 100644 --- a/tests/data/chat_templates/v15_image.jinja +++ b/tests/data/chat_templates/v15_image.jinja @@ -162,14 +162,14 @@ {{- render_content(user_content, 'user message content', supported_types_desc='text, image and image_url', support_images=true) -}} {{- '[/INST]' }} - {#- Assistant messages supports text content. #} + {#- Assistant messages supports text and image content. #} {%- elif message['role'] == 'assistant' %} {%- if (message['content'] is none or message['content'] == '' or message['content']|length == 0) and (message['tool_calls'] is not defined or message['tool_calls'] is none or message['tool_calls']|length == 0) %} {{- raise_exception('Assistant message must have a string or a list of chunks in content or a list of tool calls.') }} {%- endif %} {%- if message['content'] %} - {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text', support_images=false) -}} + {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text and image', support_images=true) -}} {%- endif %} {%- if message['tool_calls'] is defined and message['tool_calls'] is not none and message['tool_calls']|length > 0 %} diff --git a/tests/data/chat_templates/v15_image_think.jinja b/tests/data/chat_templates/v15_image_think.jinja index 5fab24c0..570cb7b8 100644 --- a/tests/data/chat_templates/v15_image_think.jinja +++ b/tests/data/chat_templates/v15_image_think.jinja @@ -189,14 +189,14 @@ {{- render_content(user_content, 'user message content', supported_types_desc='text, image and image_url', support_thinking=false, support_images=true) -}} {{- '[/INST]' }} - {#- Assistant messages supports text and thinking content. #} + {#- Assistant messages supports text, thinking and image content. #} {%- elif message['role'] == 'assistant' %} {%- if (message['content'] is none or message['content'] == '' or message['content']|length == 0) and (message['tool_calls'] is not defined or message['tool_calls'] is none or message['tool_calls']|length == 0) %} {{- raise_exception('Assistant message must have a string or a list of chunks in content or a list of tool calls.') }} {%- endif %} {%- if message['content'] %} - {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text and thinking', support_thinking=true, support_images=false) -}} + {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text, thinking and image', support_thinking=true, support_images=true) -}} {%- endif %} {%- if message['tool_calls'] is defined and message['tool_calls'] is not none and message['tool_calls']|length > 0 %} From a8b65ccdb1834940362910cbf3cd93115d112745 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Tue, 9 Jun 2026 14:50:04 +0200 Subject: [PATCH 15/47] Restore _are_text_chunks guard, remove assistant image/audio from templates, add V15 tool sort test --- .../chat_templates/template_generator.py | 15 ++-------- .../protocol/instruct/messages.py | 9 ++++-- .../protocol/instruct/normalize.py | 20 +++++-------- tests/data/chat_templates/v15_audio.jinja | 4 +-- tests/data/chat_templates/v15_image.jinja | 4 +-- .../data/chat_templates/v15_image_think.jinja | 4 +-- .../transformers/test_core_parity.py | 4 --- tests/test_normalization.py | 29 +++++++++++++++++++ 8 files changed, 52 insertions(+), 37 deletions(-) diff --git a/src/mistral_common/integrations/chat_templates/template_generator.py b/src/mistral_common/integrations/chat_templates/template_generator.py index eaf66a15..10cf02fe 100644 --- a/src/mistral_common/integrations/chat_templates/template_generator.py +++ b/src/mistral_common/integrations/chat_templates/template_generator.py @@ -1230,16 +1230,9 @@ def _generate_assistant_message_handling(config: TemplateConfig) -> str: """ lines = [] - assistant_supports_image = config.image_support and config.version >= TokenizerVersion.v15 - assistant_supports_audio = config.audio_support and config.version >= TokenizerVersion.v15 - comment_parts = ["text"] if config.any_thinking_support: comment_parts.append("thinking") - if assistant_supports_image: - comment_parts.append("image") - if assistant_supports_audio: - comment_parts.append("audio") chunk_types = _join_types_desc(comment_parts) comment = f"{{#- Assistant messages supports {chunk_types} content. #}}" @@ -1271,17 +1264,13 @@ def _generate_assistant_message_handling(config: TemplateConfig) -> str: desc_parts = ["text"] if config.any_thinking_support: desc_parts.append("thinking") - if assistant_supports_image: - desc_parts.append("image") - if assistant_supports_audio: - desc_parts.append("audio") rc_call_args += f", supported_types_desc='{_join_types_desc(desc_parts)}'" if config.any_thinking_support: rc_call_args += ", support_thinking=true" if config.image_support: - rc_call_args += f", support_images={'true' if assistant_supports_image else 'false'}" + rc_call_args += ", support_images=false" if config.audio_support: - rc_call_args += f", support_audio={'true' if assistant_supports_audio else 'false'}" + rc_call_args += ", support_audio=false" lines.append(" {%- if message['content'] %}") diff --git a/src/mistral_common/protocol/instruct/messages.py b/src/mistral_common/protocol/instruct/messages.py index c0b922d3..3b9c8c02 100644 --- a/src/mistral_common/protocol/instruct/messages.py +++ b/src/mistral_common/protocol/instruct/messages.py @@ -32,6 +32,11 @@ def _are_think_chunks(chunks: Sequence[TextChunk | ThinkChunk]) -> TypeGuard[lis return all(isinstance(c, ThinkChunk) for c in chunks) +def _are_text_chunks(chunks: Sequence[TextChunk | ThinkChunk]) -> TypeGuard[list[TextChunk]]: + r"""Narrow a chunk list to TextChunk list.""" + return all(isinstance(c, TextChunk) for c in chunks) + + class ReasoningFieldFormat(str, Enum): r"""How to serialize leading `ThinkChunk` in `AssistantMessage.to_openai()`. @@ -242,8 +247,8 @@ def to_openai( out_dict["content"] = self._content_to_openai(self.content) case ReasoningFieldFormat.reasoning | ReasoningFieldFormat.reasoning_content: think_chunks, content_chunks = self.content[: last_think_idx + 1], self.content[last_think_idx + 1 :] - if not _are_think_chunks(think_chunks): - raise RuntimeError("Expected only ThinkChunks in the leading portion.") + if not _are_think_chunks(think_chunks) or not _are_text_chunks(content_chunks): + raise RuntimeError("Impossible, only think or content chunks should have been present.") if len(think_chunks) > 0: out_dict[reasoning_field_format.value] = "\n".join(tc.thinking for tc in think_chunks) diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py index ff742266..810e1a3b 100644 --- a/src/mistral_common/protocol/instruct/normalize.py +++ b/src/mistral_common/protocol/instruct/normalize.py @@ -588,17 +588,21 @@ def normalizer(model_settings_builder: ModelSettingsBuilder | None = None) -> "I UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest[UATS, Tool], None ) - def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]: - tool_messages: list[ToolMessageType] = super()._aggregate_tool_messages(messages, latest_call_ids) + @staticmethod + def _inplace_sort_tool_messages(tool_messages: list[ToolMessageType], latest_call_ids: list[str]) -> None: id_to_tool_call_idx = {call_id: idx for idx, call_id in enumerate(latest_call_ids)} id_to_tool_result_idx = {message.tool_call_id: idx for idx, message in enumerate(tool_messages)} # First order by tool call idx and then by tool result idx - tool_messages.sort( + return tool_messages.sort( key=lambda msg: ( id_to_tool_call_idx.get(msg.tool_call_id or "null", float("inf")), id_to_tool_result_idx[msg.tool_call_id], ), ) + + def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]: + tool_messages: list[ToolMessageType] = super()._aggregate_tool_messages(messages, latest_call_ids) + self._inplace_sort_tool_messages(tool_messages=tool_messages, latest_call_ids=latest_call_ids) return tool_messages @@ -622,15 +626,7 @@ def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[s tool_messages.append( self._tool_message_class(content=content, tool_call_id=message.tool_call_id, name=message.name) ) - # Reorder tool messages based on the tool call order (from V13). - id_to_tool_call_idx = {call_id: idx for idx, call_id in enumerate(latest_call_ids)} - id_to_tool_result_idx = {message.tool_call_id: idx for idx, message in enumerate(tool_messages)} - tool_messages.sort( - key=lambda msg: ( - id_to_tool_call_idx.get(msg.tool_call_id or "null", float("inf")), - id_to_tool_result_idx[msg.tool_call_id], - ), - ) + self._inplace_sort_tool_messages(tool_messages=tool_messages, latest_call_ids=latest_call_ids) return tool_messages def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | list[ContentChunk]: diff --git a/tests/data/chat_templates/v15_audio.jinja b/tests/data/chat_templates/v15_audio.jinja index 0d843e90..bbd04595 100644 --- a/tests/data/chat_templates/v15_audio.jinja +++ b/tests/data/chat_templates/v15_audio.jinja @@ -151,14 +151,14 @@ {{- render_content(message['content'], 'user message content', supported_types_desc='text, input_audio and audio_url', support_audio=true) -}} {{- '[/INST]' }} - {#- Assistant messages supports text and audio content. #} + {#- Assistant messages supports text content. #} {%- elif message['role'] == 'assistant' %} {%- if (message['content'] is none or message['content'] == '' or message['content']|length == 0) and (message['tool_calls'] is not defined or message['tool_calls'] is none or message['tool_calls']|length == 0) %} {{- raise_exception('Assistant message must have a string or a list of chunks in content or a list of tool calls.') }} {%- endif %} {%- if message['content'] %} - {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text and audio', support_audio=true) -}} + {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text', support_audio=false) -}} {%- endif %} {%- if message['tool_calls'] is defined and message['tool_calls'] is not none and message['tool_calls']|length > 0 %} diff --git a/tests/data/chat_templates/v15_image.jinja b/tests/data/chat_templates/v15_image.jinja index bd3eefa3..35b872fe 100644 --- a/tests/data/chat_templates/v15_image.jinja +++ b/tests/data/chat_templates/v15_image.jinja @@ -162,14 +162,14 @@ {{- render_content(user_content, 'user message content', supported_types_desc='text, image and image_url', support_images=true) -}} {{- '[/INST]' }} - {#- Assistant messages supports text and image content. #} + {#- Assistant messages supports text content. #} {%- elif message['role'] == 'assistant' %} {%- if (message['content'] is none or message['content'] == '' or message['content']|length == 0) and (message['tool_calls'] is not defined or message['tool_calls'] is none or message['tool_calls']|length == 0) %} {{- raise_exception('Assistant message must have a string or a list of chunks in content or a list of tool calls.') }} {%- endif %} {%- if message['content'] %} - {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text and image', support_images=true) -}} + {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text', support_images=false) -}} {%- endif %} {%- if message['tool_calls'] is defined and message['tool_calls'] is not none and message['tool_calls']|length > 0 %} diff --git a/tests/data/chat_templates/v15_image_think.jinja b/tests/data/chat_templates/v15_image_think.jinja index 570cb7b8..5fab24c0 100644 --- a/tests/data/chat_templates/v15_image_think.jinja +++ b/tests/data/chat_templates/v15_image_think.jinja @@ -189,14 +189,14 @@ {{- render_content(user_content, 'user message content', supported_types_desc='text, image and image_url', support_thinking=false, support_images=true) -}} {{- '[/INST]' }} - {#- Assistant messages supports text, thinking and image content. #} + {#- Assistant messages supports text and thinking content. #} {%- elif message['role'] == 'assistant' %} {%- if (message['content'] is none or message['content'] == '' or message['content']|length == 0) and (message['tool_calls'] is not defined or message['tool_calls'] is none or message['tool_calls']|length == 0) %} {{- raise_exception('Assistant message must have a string or a list of chunks in content or a list of tool calls.') }} {%- endif %} {%- if message['content'] %} - {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text, thinking and image', support_thinking=true, support_images=true) -}} + {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text and thinking', support_thinking=true, support_images=false) -}} {%- endif %} {%- if message['tool_calls'] is defined and message['tool_calls'] is not none and message['tool_calls']|length > 0 %} diff --git a/tests/integrations/chat_templates/transformers/test_core_parity.py b/tests/integrations/chat_templates/transformers/test_core_parity.py index ffb47608..c579e53c 100644 --- a/tests/integrations/chat_templates/transformers/test_core_parity.py +++ b/tests/integrations/chat_templates/transformers/test_core_parity.py @@ -227,10 +227,6 @@ def test_invalid_chunks( desc_parts = ["text"] if config.think: desc_parts.append("thinking") - if is_v15_plus and config.image: - desc_parts.append("image") - if is_v15_plus and config.audio: - desc_parts.append("audio") if len(desc_parts) == 1: chunks = "text" else: diff --git a/tests/test_normalization.py b/tests/test_normalization.py index 001fd34a..6b396e06 100644 --- a/tests/test_normalization.py +++ b/tests/test_normalization.py @@ -1203,6 +1203,35 @@ def test_v15_preserves_non_text_tool_content(self, normalizer_v15: InstructReque assert isinstance(tool_msg.content, list) assert tool_msg.content == [image_chunk] + def test_v15_sorts_multimodal_tool_messages(self, normalizer_v15: InstructRequestNormalizerV15) -> None: + image_chunk_1 = ImageURLChunk(image_url="https://example.com/img1.png") + image_chunk_2 = ImageURLChunk(image_url="https://example.com/img2.png") + request = ChatCompletionRequest( # type: ignore[type-var] + messages=[ + UserMessage(content="query"), + AssistantMessage( + tool_calls=[ + ToolCall(function=FunctionCall(name="fn1", arguments="{}"), id="c1"), + ToolCall(function=FunctionCall(name="fn2", arguments="{}"), id="c2"), + ] + ), + ToolMessage(content=[image_chunk_2], tool_call_id="c2"), + ToolMessage(content=[image_chunk_1], tool_call_id="c1"), + ], + reasoning_effort=ReasoningEffort.high, + ) + parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request) + + tool_msg_1 = parsed.messages[2] + assert isinstance(tool_msg_1, ToolMessage) + assert tool_msg_1.tool_call_id == "c1" + assert tool_msg_1.content == [image_chunk_1] + + tool_msg_2 = parsed.messages[3] + assert isinstance(tool_msg_2, ToolMessage) + assert tool_msg_2.tool_call_id == "c2" + assert tool_msg_2.content == [image_chunk_2] + def test_pre_v15_rejects_non_text_tool_content(self) -> None: r"""Pre-V15 normalizer raises InvalidRequestException for non-text tool content.""" normalizer = get_normalizer(TokenizerVersion.v13) From d16b7ef1edcb1a0beab99010072b9b7920ff2c8e Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Tue, 9 Jun 2026 15:42:22 +0200 Subject: [PATCH 16/47] Apply suggestions from code review Co-authored-by: Julien Denize <40604584+juliendenize@users.noreply.github.com> --- src/mistral_common/protocol/instruct/messages.py | 2 +- src/mistral_common/protocol/instruct/normalize.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/mistral_common/protocol/instruct/messages.py b/src/mistral_common/protocol/instruct/messages.py index 3b9c8c02..2b922dfb 100644 --- a/src/mistral_common/protocol/instruct/messages.py +++ b/src/mistral_common/protocol/instruct/messages.py @@ -252,7 +252,7 @@ def to_openai( if len(think_chunks) > 0: out_dict[reasoning_field_format.value] = "\n".join(tc.thinking for tc in think_chunks) - if len(content_chunks) == 1 and isinstance(content_chunks[0], TextChunk): + if len(content_chunks) == 1: out_dict["content"] = content_chunks[0].text elif content_chunks: out_dict["content"] = self._content_to_openai(content_chunks) diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py index 810e1a3b..45969660 100644 --- a/src/mistral_common/protocol/instruct/normalize.py +++ b/src/mistral_common/protocol/instruct/normalize.py @@ -287,9 +287,7 @@ def _narrow_assistant_content(self, content: list[ContentChunk] | str) -> str | Raises: InvalidRequestException: If unsupported chunk types are found. """ - if isinstance(content, str): - return content - if _is_assistant_content(content): + if isinstance(content, str) or _is_assistant_content(content): return content raise InvalidRequestException( f"Unexpected content chunk types in assistant message: {[type(c).__name__ for c in content]}" From a701ae337797ac2266ee2fe5793202382ec7f346 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Tue, 9 Jun 2026 15:45:55 +0200 Subject: [PATCH 17/47] Use content chunk type aliases in TypeGuard and narrowing function signatures --- .../protocol/instruct/messages.py | 4 +- .../protocol/instruct/normalize.py | 38 ++++++++++++++----- 2 files changed, 30 insertions(+), 12 deletions(-) diff --git a/src/mistral_common/protocol/instruct/messages.py b/src/mistral_common/protocol/instruct/messages.py index 2b922dfb..6ca46b2f 100644 --- a/src/mistral_common/protocol/instruct/messages.py +++ b/src/mistral_common/protocol/instruct/messages.py @@ -27,12 +27,12 @@ ) -def _are_think_chunks(chunks: Sequence[TextChunk | ThinkChunk]) -> TypeGuard[list[ThinkChunk]]: +def _are_think_chunks(chunks: Sequence[AssistantContentChunk]) -> TypeGuard[list[ThinkChunk]]: r"""Narrow a chunk list to ThinkChunk list.""" return all(isinstance(c, ThinkChunk) for c in chunks) -def _are_text_chunks(chunks: Sequence[TextChunk | ThinkChunk]) -> TypeGuard[list[TextChunk]]: +def _are_text_chunks(chunks: Sequence[AssistantContentChunk]) -> TypeGuard[list[TextChunk]]: r"""Narrow a chunk list to TextChunk list.""" return all(isinstance(c, TextChunk) for c in chunks) diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py index 45969660..b774b102 100644 --- a/src/mistral_common/protocol/instruct/normalize.py +++ b/src/mistral_common/protocol/instruct/normalize.py @@ -6,13 +6,16 @@ from mistral_common.exceptions import InvalidRequestException from mistral_common.protocol.instruct.chunk import ( + AssistantContentChunk, AudioChunk, AudioURLChunk, ContentChunk, ImageChunk, ImageURLChunk, + SystemContentChunk, TextChunk, ThinkChunk, + UserContentChunk, ) from mistral_common.protocol.instruct.messages import ( UATS, @@ -37,18 +40,25 @@ def _is_user_content( chunks: list[ContentChunk], -) -> TypeGuard[list[TextChunk | ImageChunk | ImageURLChunk | AudioChunk | AudioURLChunk]]: +) -> TypeGuard[list[UserContentChunk]]: r"""Narrow ContentChunk list to user-compatible types.""" return all(isinstance(c, (TextChunk, ImageChunk, ImageURLChunk, AudioChunk, AudioURLChunk)) for c in chunks) def _is_assistant_content( chunks: list[ContentChunk], -) -> TypeGuard[list[TextChunk | ThinkChunk]]: +) -> TypeGuard[list[AssistantContentChunk]]: r"""Narrow ContentChunk list to assistant-compatible types.""" return all(isinstance(c, (TextChunk, ThinkChunk)) for c in chunks) +def _is_system_content( + chunks: list[ContentChunk], +) -> TypeGuard[list[SystemContentChunk]]: + r"""Narrow ContentChunk list to system-compatible types.""" + return all(isinstance(c, (TextChunk, AudioChunk, ThinkChunk)) for c in chunks) + + def _aggregate_content_chunks_impl( contents: list[list[ContentChunk] | str | None], msg_join_str: str, @@ -273,7 +283,7 @@ def _normalize_tool_call(self, tool_call: ToolCall) -> ToolCall: id=tool_call.id, ) - def _narrow_assistant_content(self, content: list[ContentChunk] | str) -> str | list[TextChunk | ThinkChunk]: + def _narrow_assistant_content(self, content: list[ContentChunk] | str) -> str | list[AssistantContentChunk]: r"""Validate and narrow content chunks for assistant messages. Only TextChunk and ThinkChunk are allowed. @@ -465,7 +475,7 @@ def normalizer(model_settings_builder: ModelSettingsBuilder | None = None) -> "I UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest[UATS, Tool], None ) - def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | list[ContentChunk]: + def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | list[SystemContentChunk]: r"""Validate content chunks for system messages. V7+ accepts all SystemContentChunk types (Pydantic validates at construction). @@ -476,8 +486,15 @@ def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | lis Returns: The validated content. + + Raises: + InvalidRequestException: If unsupported chunk types are found. """ - return content + if isinstance(content, str) or _is_system_content(content): + return content + raise InvalidRequestException( + f"Unexpected content chunk types in system message: {[type(c).__name__ for c in content]}" + ) def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]: """Normalize tool messages without JSON normalization. @@ -627,13 +644,14 @@ def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[s self._inplace_sort_tool_messages(tool_messages=tool_messages, latest_call_ids=latest_call_ids) return tool_messages - def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | list[ContentChunk]: + def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | list[SystemContentChunk]: r"""V15 system messages allow TextChunk and AudioChunk but reject ThinkChunk.""" - if isinstance(content, str): - return content - if any(isinstance(c, ThinkChunk) for c in content): + validated = super()._narrow_system_content(content) + if isinstance(validated, str): + return validated + if any(isinstance(c, ThinkChunk) for c in validated): raise InvalidRequestException("ThinkChunk in system message is not supported for V15") - return content + return validated @staticmethod def normalizer(model_settings_builder: ModelSettingsBuilder | None = None) -> "InstructRequestNormalizerV15": From a974af19add4131686bb2b5a2dc86e0931915742 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Tue, 9 Jun 2026 16:07:50 +0200 Subject: [PATCH 18/47] Revert encode_assistant_message to list[int], reject audio in V7-V13 system messages --- .../protocol/instruct/normalize.py | 19 ++++++---- .../tokens/tokenizers/instruct.py | 37 ++++++++----------- tests/experimental/test_app.py | 4 +- tests/guidance/test_guidance.py | 4 +- tests/test_normalization.py | 26 ++++++++++++- tests/test_tokenizer_v11.py | 16 ++------ tests/test_tokenizer_v13.py | 6 +-- 7 files changed, 61 insertions(+), 51 deletions(-) diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py index b774b102..b84b4f5b 100644 --- a/src/mistral_common/protocol/instruct/normalize.py +++ b/src/mistral_common/protocol/instruct/normalize.py @@ -478,8 +478,8 @@ def normalizer(model_settings_builder: ModelSettingsBuilder | None = None) -> "I def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | list[SystemContentChunk]: r"""Validate content chunks for system messages. - V7+ accepts all SystemContentChunk types (Pydantic validates at construction). - V15 overrides to reject ThinkChunk. + V7-V13 accepts text and thinking chunks. V15 overrides to allow audio + and reject thinking. Args: content: The aggregated content chunks. @@ -491,6 +491,8 @@ def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | lis InvalidRequestException: If unsupported chunk types are found. """ if isinstance(content, str) or _is_system_content(content): + if isinstance(content, list) and any(isinstance(c, (AudioChunk, AudioURLChunk)) for c in content): + raise InvalidRequestException("Audio chunks in system messages are only supported from V15") return content raise InvalidRequestException( f"Unexpected content chunk types in system message: {[type(c).__name__ for c in content]}" @@ -646,12 +648,15 @@ def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[s def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | list[SystemContentChunk]: r"""V15 system messages allow TextChunk and AudioChunk but reject ThinkChunk.""" - validated = super()._narrow_system_content(content) - if isinstance(validated, str): - return validated - if any(isinstance(c, ThinkChunk) for c in validated): + if isinstance(content, str): + return content + if not _is_system_content(content): + raise InvalidRequestException( + f"Unexpected content chunk types in system message: {[type(c).__name__ for c in content]}" + ) + if any(isinstance(c, ThinkChunk) for c in content): raise InvalidRequestException("ThinkChunk in system message is not supported for V15") - return validated + return content @staticmethod def normalizer(model_settings_builder: ModelSettingsBuilder | None = None) -> "InstructRequestNormalizerV15": diff --git a/src/mistral_common/tokens/tokenizers/instruct.py b/src/mistral_common/tokens/tokenizers/instruct.py index 08de0c95..3ec6730e 100644 --- a/src/mistral_common/tokens/tokenizers/instruct.py +++ b/src/mistral_common/tokens/tokenizers/instruct.py @@ -117,7 +117,7 @@ def encode_tool_message( @abstractmethod def encode_assistant_message( self, message: AssistantMessageType, is_before_last_user_message: bool, continue_message: bool - ) -> tuple[list[int], list[np.ndarray], list[Audio]]: + ) -> list[int]: r"""Encode an assistant message. Raises: @@ -200,11 +200,9 @@ def encode_instruct( elif isinstance(msg, AssistantMessage): continue_message = request.continue_final_message and (msg_idx == len(request.messages) - 1) - new_tokens, new_images, new_audios = self.encode_assistant_message( + new_tokens = self.encode_assistant_message( msg, msg_idx < last_user_idx, continue_message=continue_message ) - images.extend(new_images) - audios.extend(new_audios) if msg_idx == len(request.messages) - 1: prefix_ids = new_tokens elif isinstance(msg, SystemMessage): @@ -338,7 +336,7 @@ def encode_tool_message( def encode_assistant_message( self, message: AssistantMessageType, is_before_last_user_message: bool, continue_message: bool - ) -> tuple[list[int], list[np.ndarray], list[Audio]]: + ) -> list[int]: r"""Encode an assistant message. Args: @@ -348,7 +346,7 @@ def encode_assistant_message( Only use this if the assistant message is the last message. Returns: - The encoded tokens, images, and audios. + The encoded tokens. """ assert isinstance(message, AssistantMessage), message if message.tool_calls is not None and len(message.tool_calls) > 0: @@ -364,7 +362,7 @@ def encode_assistant_message( raise TokenizerException(f"{message.content} // {message.tool_calls}") if not message.prefix and not continue_message: curr_tokens.append(self.tokenizer.eos_id) - return curr_tokens, [], [] + return curr_tokens def encode_think(self, chunk: ThinkChunk) -> list[int]: r"""Encode a think chunk. @@ -562,7 +560,7 @@ def _encode_settings( def encode_assistant_message( self, message: AssistantMessageType, is_before_last_user_message: bool, continue_message: bool - ) -> tuple[list[int], list[np.ndarray], list[Audio]]: + ) -> list[int]: r"""Encode an assistant message. Args: @@ -573,7 +571,7 @@ def encode_assistant_message( Only use this if the assistant message is the last message. Returns: - The encoded tokens, images, and audios. + The encoded tokens. """ if message.tool_calls and message.content: raise ValueError(f"Cannot have tool calls and content defined in the same assistant message {message}") @@ -585,7 +583,7 @@ def encode_assistant_message( if message.tool_calls: if is_before_last_user_message: # don't tokenize tool call before last user message - return [], [], [] + return [] curr_tokens = self._encode_tool_calls_in_assistant_message(message) elif message.content: assert isinstance(message.content, str), "Message content must be a string for tokenizer < V7" @@ -594,7 +592,7 @@ def encode_assistant_message( raise TokenizerException(f"Invalid assistant message: {message.content}") if not message.prefix and not continue_message: curr_tokens.append(self.tokenizer.eos_id) - return curr_tokens, [], [] + return curr_tokens def _encode_infilling(self, text: str) -> list[int]: r"""Remove prefix space in the case of SentencePieceTokenizers.""" @@ -691,7 +689,7 @@ def encode_tool_message( def encode_assistant_message( self, message: AssistantMessageType, is_before_last_user_message: bool, continue_message: bool - ) -> tuple[list[int], list[np.ndarray], list[Audio]]: + ) -> list[int]: r"""Encode an assistant message. Note: @@ -705,7 +703,7 @@ def encode_assistant_message( is_before_last_user_message: Not used. Returns: - The encoded tokens, images, and audios. + The encoded tokens. """ return super().encode_assistant_message(message, False, continue_message) @@ -1130,7 +1128,7 @@ def encode_tool_message( def encode_assistant_message( self, message: AssistantMessageType, is_before_last_user_message: bool, continue_message: bool - ) -> tuple[list[int], list[np.ndarray], list[Audio]]: + ) -> list[int]: r"""Encode an assistant message. Args: @@ -1140,7 +1138,7 @@ def encode_assistant_message( Only use this if the assistant message is the last message. Returns: - The encoded tokens, images, and audios. + The encoded tokens. """ if not message.content and not message.tool_calls: raise TokenizerException(f"Invalid assistant message: {message}") @@ -1150,22 +1148,17 @@ def encode_assistant_message( ) curr_tokens: list = [] - images: list[np.ndarray] = [] - audios: list[Audio] = [] if message.content: if isinstance(message.content, str): curr_tokens = self._encode_normal_content_assistant_message(message) elif isinstance(message.content, list): - content_tokens, new_images, new_audios = self._encode_content_chunks(message.content) - curr_tokens += content_tokens - images.extend(new_images) - audios.extend(new_audios) + curr_tokens += self._encode_content_chunks(message.content)[0] if message.tool_calls: curr_tokens += self._encode_tool_calls_in_assistant_message(message) if not message.prefix and not continue_message: curr_tokens.append(self.tokenizer.eos_id) - return curr_tokens, images, audios + return curr_tokens def _encode_audio_for_speech_request(self, ref_audio: str | bytes | None, voice: str | None) -> Tokenized: r"""Encode reference audio or voice preset into a Tokenized object. diff --git a/tests/experimental/test_app.py b/tests/experimental/test_app.py index 27bb9dd8..5b38d2c6 100644 --- a/tests/experimental/test_app.py +++ b/tests/experimental/test_app.py @@ -394,7 +394,7 @@ def test_detokenize_assistant_message( def test_detokenize_assistant_message_think_chunks( assistant_message: AssistantMessage, mistral_tokenizer_v13: MistralTokenizer, tekken_v13_client: TestClient ) -> None: - encoded_tokens, _, _ = mistral_tokenizer_v13.instruct_tokenizer.encode_assistant_message( # type: ignore[attr-defined] + encoded_tokens = mistral_tokenizer_v13.instruct_tokenizer.encode_assistant_message( # type: ignore[attr-defined] assistant_message, False, False ) @@ -464,7 +464,7 @@ def test_generate( engine_request: dict | ChatCompletionRequest | OpenAIChatCompletionRequest, output_assistant_message: AssistantMessage, ) -> None: - output_tokens, _, _ = mistral_tokenizer_v13.instruct_tokenizer.encode_assistant_message( # type: ignore[attr-defined] + output_tokens = mistral_tokenizer_v13.instruct_tokenizer.encode_assistant_message( # type: ignore[attr-defined] output_assistant_message, False, False ) if output_assistant_message.tool_calls: diff --git a/tests/guidance/test_guidance.py b/tests/guidance/test_guidance.py index cac90e4c..a5c7a8a4 100644 --- a/tests/guidance/test_guidance.py +++ b/tests/guidance/test_guidance.py @@ -228,7 +228,7 @@ def _encode_content( tokenizer = instruct_tokenizer.tokenizer if isinstance(content, str): - result, _, _ = instruct_tokenizer.encode_assistant_message( + result = instruct_tokenizer.encode_assistant_message( AssistantMessage(content=content), is_before_last_user_message=False, continue_message=False ) return result @@ -238,7 +238,7 @@ def _encode_content( tokens: list[int] = [] if content_chunks: - tokens, _, _ = instruct_tokenizer.encode_assistant_message( + tokens = instruct_tokenizer.encode_assistant_message( AssistantMessage(content=content_chunks), is_before_last_user_message=False, continue_message=False, diff --git a/tests/test_normalization.py b/tests/test_normalization.py index 6b396e06..eb6de338 100644 --- a/tests/test_normalization.py +++ b/tests/test_normalization.py @@ -1349,8 +1349,8 @@ def test_v15_rejects_think_in_system_message(self) -> None: with pytest.raises(InvalidRequestException, match="ThinkChunk"): normalizer.from_chat_completion_request(request) - def test_v7_normalization_preserves_audio_in_system_message(self) -> None: - r"""V7 normalizer preserves AudioChunk in system messages.""" + def test_v7_rejects_audio_in_system_content(self) -> None: + r"""V7 normalizer rejects AudioChunk in system messages.""" normalizer = InstructRequestNormalizerV7.normalizer() request = mock_chat_completion( messages=[ @@ -1358,6 +1358,28 @@ def test_v7_normalization_preserves_audio_in_system_message(self) -> None: UserMessage(content="test"), ] ) + with pytest.raises( + InvalidRequestException, match="Audio chunks in system messages are only supported from V15" + ): + normalizer.from_chat_completion_request(request) + + def test_v15_preserves_audio_in_system_message(self) -> None: + r"""V15 normalizer preserves AudioChunk in system messages.""" + normalizer = get_normalizer( + TokenizerVersion.v15, + model_settings_builder=ModelSettingsBuilder( + reasoning_effort=EnumBuilder[ReasoningEffort]( + values=list(ReasoningEffort), accepts_none=False, default=None + ) + ), + ) + request = ChatCompletionRequest[ChatMessage]( + messages=[ + SystemMessage(content=[TextChunk(text="hello"), AudioChunk(input_audio=b"fake_audio_data")]), + UserMessage(content="test"), + ], + reasoning_effort=ReasoningEffort.high, + ) parsed: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request) system_msg = parsed.messages[0] assert isinstance(system_msg, SystemMessage) diff --git a/tests/test_tokenizer_v11.py b/tests/test_tokenizer_v11.py index 59bcb466..8c3c7330 100644 --- a/tests/test_tokenizer_v11.py +++ b/tests/test_tokenizer_v11.py @@ -32,15 +32,13 @@ def test_special_tokens(tekkenizer: InstructTokenizerV11) -> None: def test_tokenize_assistant_message(tekkenizer: InstructTokenizerV11) -> None: - tokens, images, audios = tekkenizer.encode_assistant_message( + tokens = tekkenizer.encode_assistant_message( AssistantMessage( tool_calls=[ToolCall(function=FunctionCall(name="a_a_a", arguments="blabla"))], ), is_before_last_user_message=False, continue_message=False, ) - assert images == [] - assert audios == [] assert tokens == [ tekkenizer.TOOL_CALLS, 197, @@ -63,15 +61,13 @@ def test_tokenize_assistant_message(tekkenizer: InstructTokenizerV11) -> None: def test_tokenize_assistant_message_continue_message(tekkenizer: InstructTokenizerV11) -> None: - tokens, images, audios = tekkenizer.encode_assistant_message( + tokens = tekkenizer.encode_assistant_message( AssistantMessage( content='"blabla"', ), is_before_last_user_message=False, continue_message=True, ) - assert images == [] - assert audios == [] assert tokens == [ 134, 198, @@ -99,7 +95,7 @@ def test_tokenize_assistant_message_continue_message(tekkenizer: InstructTokeniz def test_tokenize_assistant_messages(tekkenizer: InstructTokenizerV11) -> None: - tokens, images, audios = tekkenizer.encode_assistant_message( + tokens = tekkenizer.encode_assistant_message( AssistantMessage( tool_calls=[ ToolCall(function=FunctionCall(name="a_a_a", arguments="blabla")), @@ -109,8 +105,6 @@ def test_tokenize_assistant_messages(tekkenizer: InstructTokenizerV11) -> None: is_before_last_user_message=False, continue_message=False, ) - assert images == [] - assert audios == [] assert tokens == [ tekkenizer.TOOL_CALLS, 197, @@ -141,15 +135,13 @@ def test_tokenize_assistant_messages(tekkenizer: InstructTokenizerV11) -> None: def test_tokenize_assistant_message_train(tekkenizer: InstructTokenizerV11) -> None: - tokens, images, audios = tekkenizer.encode_assistant_message( + tokens = tekkenizer.encode_assistant_message( AssistantMessage( tool_calls=[ToolCall(function=FunctionCall(name="a_a_a", arguments="blabla"), id="ABC")], ), is_before_last_user_message=True, continue_message=False, ) - assert images == [] - assert audios == [] assert tokens == [ tekkenizer.TOOL_CALLS, 197, diff --git a/tests/test_tokenizer_v13.py b/tests/test_tokenizer_v13.py index cec43105..1d6412ad 100644 --- a/tests/test_tokenizer_v13.py +++ b/tests/test_tokenizer_v13.py @@ -299,7 +299,7 @@ def test_tokenize_assistant_message( v13_tekkenizer_think: InstructTokenizerV13, message: AssistantMessage, expected: str, continue_final_message: bool ) -> None: if not continue_final_message: - tokens, images, audios = v13_tekkenizer_think.encode_assistant_message( + tokens = v13_tekkenizer_think.encode_assistant_message( message, is_before_last_user_message=False, continue_message=continue_final_message ) if not message.prefix: @@ -314,12 +314,10 @@ def test_tokenize_assistant_message( message, is_before_last_user_message=False, continue_message=continue_final_message ) return - tokens, images, audios = v13_tekkenizer_think.encode_assistant_message( + tokens = v13_tekkenizer_think.encode_assistant_message( message, is_before_last_user_message=False, continue_message=continue_final_message ) assert v13_tekkenizer_think.decode(tokens, special_token_policy=SpecialTokenPolicy.KEEP) == expected - assert images == [] - assert audios == [] def test_tokenize_assistant_message_error(v13_tekkenizer: InstructTokenizerV13) -> None: From 2cf685e17969e192a5a6af672701368082a9cf14 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Tue, 9 Jun 2026 16:14:23 +0200 Subject: [PATCH 19/47] Revert V7 audio rejection: audio supported in system messages from V7+ --- .../protocol/instruct/normalize.py | 19 +++++++------------ tests/test_normalization.py | 15 +++++++++------ 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py index b84b4f5b..38cb4d89 100644 --- a/src/mistral_common/protocol/instruct/normalize.py +++ b/src/mistral_common/protocol/instruct/normalize.py @@ -478,8 +478,8 @@ def normalizer(model_settings_builder: ModelSettingsBuilder | None = None) -> "I def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | list[SystemContentChunk]: r"""Validate content chunks for system messages. - V7-V13 accepts text and thinking chunks. V15 overrides to allow audio - and reject thinking. + V7+ accepts all SystemContentChunk types (text, audio, thinking). + V15 overrides to reject ThinkChunk. Args: content: The aggregated content chunks. @@ -491,8 +491,6 @@ def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | lis InvalidRequestException: If unsupported chunk types are found. """ if isinstance(content, str) or _is_system_content(content): - if isinstance(content, list) and any(isinstance(c, (AudioChunk, AudioURLChunk)) for c in content): - raise InvalidRequestException("Audio chunks in system messages are only supported from V15") return content raise InvalidRequestException( f"Unexpected content chunk types in system message: {[type(c).__name__ for c in content]}" @@ -648,15 +646,12 @@ def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[s def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | list[SystemContentChunk]: r"""V15 system messages allow TextChunk and AudioChunk but reject ThinkChunk.""" - if isinstance(content, str): - return content - if not _is_system_content(content): - raise InvalidRequestException( - f"Unexpected content chunk types in system message: {[type(c).__name__ for c in content]}" - ) - if any(isinstance(c, ThinkChunk) for c in content): + validated = super()._narrow_system_content(content) + if isinstance(validated, str): + return validated + if any(isinstance(c, ThinkChunk) for c in validated): raise InvalidRequestException("ThinkChunk in system message is not supported for V15") - return content + return validated @staticmethod def normalizer(model_settings_builder: ModelSettingsBuilder | None = None) -> "InstructRequestNormalizerV15": diff --git a/tests/test_normalization.py b/tests/test_normalization.py index eb6de338..b68ea020 100644 --- a/tests/test_normalization.py +++ b/tests/test_normalization.py @@ -1349,8 +1349,8 @@ def test_v15_rejects_think_in_system_message(self) -> None: with pytest.raises(InvalidRequestException, match="ThinkChunk"): normalizer.from_chat_completion_request(request) - def test_v7_rejects_audio_in_system_content(self) -> None: - r"""V7 normalizer rejects AudioChunk in system messages.""" + def test_v7_preserves_audio_in_system_message(self) -> None: + r"""V7 normalizer preserves AudioChunk in system messages.""" normalizer = InstructRequestNormalizerV7.normalizer() request = mock_chat_completion( messages=[ @@ -1358,10 +1358,13 @@ def test_v7_rejects_audio_in_system_content(self) -> None: UserMessage(content="test"), ] ) - with pytest.raises( - InvalidRequestException, match="Audio chunks in system messages are only supported from V15" - ): - normalizer.from_chat_completion_request(request) + parsed: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request) + system_msg = parsed.messages[0] + assert isinstance(system_msg, SystemMessage) + assert isinstance(system_msg.content, list) + assert len(system_msg.content) == 2 + assert isinstance(system_msg.content[0], TextChunk) + assert isinstance(system_msg.content[1], AudioChunk) def test_v15_preserves_audio_in_system_message(self) -> None: r"""V15 normalizer preserves AudioChunk in system messages.""" From 953873e6a79b9df49067deaafab345e2426ef161 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Tue, 9 Jun 2026 17:23:11 +0200 Subject: [PATCH 20/47] Add pre-V7 system message rejection tests for audio and think chunks --- tests/test_normalization.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/test_normalization.py b/tests/test_normalization.py index b68ea020..2c446d83 100644 --- a/tests/test_normalization.py +++ b/tests/test_normalization.py @@ -1366,6 +1366,36 @@ def test_v7_preserves_audio_in_system_message(self) -> None: assert isinstance(system_msg.content[0], TextChunk) assert isinstance(system_msg.content[1], AudioChunk) + def test_pre_v7_rejects_audio_in_system_message(self) -> None: + r"""Pre-V7 normalizer rejects AudioChunk in system messages.""" + normalizer = InstructRequestNormalizer( + UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest, None + ) + request = mock_chat_completion( + messages=[ + SystemMessage(content=[TextChunk(text="hello"), AudioChunk(input_audio=b"fake_audio_data")]), + UserMessage(content="query"), + AssistantMessage(content="answer"), + ] + ) + with pytest.raises(AssertionError): + normalizer.from_chat_completion_request(request) + + def test_pre_v7_rejects_think_in_system_message(self) -> None: + r"""Pre-V7 normalizer rejects ThinkChunk in system messages.""" + normalizer = InstructRequestNormalizer( + UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest, None + ) + request = mock_chat_completion( + messages=[ + SystemMessage(content=[TextChunk(text="hello"), ThinkChunk(thinking="thinking", closed=True)]), + UserMessage(content="query"), + AssistantMessage(content="answer"), + ] + ) + with pytest.raises(AssertionError): + normalizer.from_chat_completion_request(request) + def test_v15_preserves_audio_in_system_message(self) -> None: r"""V15 normalizer preserves AudioChunk in system messages.""" normalizer = get_normalizer( From e73d56c21e6142df5803c8ffcd7610db7348043e Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Tue, 9 Jun 2026 20:01:50 +0200 Subject: [PATCH 21/47] Fix review nits: stale comment, misleading return, test parametrize --- .../chat_templates/template_generator.py | 5 +++- src/mistral_common/protocol/instruct/chunk.py | 2 +- .../protocol/instruct/normalize.py | 2 +- tests/data/chat_templates/v15.jinja | 2 +- tests/data/chat_templates/v15_audio.jinja | 2 +- tests/data/chat_templates/v15_image.jinja | 2 +- .../data/chat_templates/v15_image_think.jinja | 2 +- tests/data/chat_templates/v15_think.jinja | 2 +- tests/test_normalization.py | 28 +++++++------------ 9 files changed, 21 insertions(+), 26 deletions(-) diff --git a/src/mistral_common/integrations/chat_templates/template_generator.py b/src/mistral_common/integrations/chat_templates/template_generator.py index 10cf02fe..b9dc662a 100644 --- a/src/mistral_common/integrations/chat_templates/template_generator.py +++ b/src/mistral_common/integrations/chat_templates/template_generator.py @@ -1449,7 +1449,10 @@ def _generate_tool_message_handling(config: TemplateConfig) -> str: lines.append(" {#- Tool messages supports int, float or text content. #}") lines.append(" {%- elif message['role'] == 'tool' and ns.index > ns.max_idx_user %}") else: - lines.append(" {#- Tool messages only supports text content. #}") + if config.tool_supports_multimodal: + lines.append(" {#- Tool messages (multimodal). #}") + else: + lines.append(" {#- Tool messages only supports text content. #}") lines.append(" {%- elif message['role'] == 'tool' %}") if config.uses_spm_space_tracking: diff --git a/src/mistral_common/protocol/instruct/chunk.py b/src/mistral_common/protocol/instruct/chunk.py index 3ba496b9..786c38e6 100644 --- a/src/mistral_common/protocol/instruct/chunk.py +++ b/src/mistral_common/protocol/instruct/chunk.py @@ -459,7 +459,7 @@ def from_openai(cls, openai_chunk: dict[str, Any]) -> "ThinkChunk": SystemContentChunk = Annotated[TextChunk | AudioChunk | ThinkChunk, Field(discriminator="type")] -ToolContentChunk: TypeAlias = ContentChunk +ToolContentChunk: TypeAlias = ContentChunk # Accepts all content chunk types (no restriction on tool messages). def _convert_openai_content_chunks(openai_content_chunks: dict[str, Any]) -> ContentChunk: diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py index 38cb4d89..4863d7bb 100644 --- a/src/mistral_common/protocol/instruct/normalize.py +++ b/src/mistral_common/protocol/instruct/normalize.py @@ -608,7 +608,7 @@ def _inplace_sort_tool_messages(tool_messages: list[ToolMessageType], latest_cal id_to_tool_call_idx = {call_id: idx for idx, call_id in enumerate(latest_call_ids)} id_to_tool_result_idx = {message.tool_call_id: idx for idx, message in enumerate(tool_messages)} # First order by tool call idx and then by tool result idx - return tool_messages.sort( + tool_messages.sort( key=lambda msg: ( id_to_tool_call_idx.get(msg.tool_call_id or "null", float("inf")), id_to_tool_result_idx[msg.tool_call_id], diff --git a/tests/data/chat_templates/v15.jinja b/tests/data/chat_templates/v15.jinja index d1f29add..e4df96ef 100644 --- a/tests/data/chat_templates/v15.jinja +++ b/tests/data/chat_templates/v15.jinja @@ -180,7 +180,7 @@ {{- eos_token }} - {#- Tool messages only supports text content. #} + {#- Tool messages (multimodal). #} {%- elif message['role'] == 'tool' %} {{- '[TOOL_RESULTS]' -}} {{- render_content(message['content'], 'tool message contents') -}} diff --git a/tests/data/chat_templates/v15_audio.jinja b/tests/data/chat_templates/v15_audio.jinja index bbd04595..f3c90f1c 100644 --- a/tests/data/chat_templates/v15_audio.jinja +++ b/tests/data/chat_templates/v15_audio.jinja @@ -177,7 +177,7 @@ {{- eos_token }} - {#- Tool messages only supports text content. #} + {#- Tool messages (multimodal). #} {%- elif message['role'] == 'tool' %} {{- '[TOOL_RESULTS]' -}} {{- render_content(message['content'], 'tool message contents', supported_types_desc='text and audio', support_audio=true) -}} diff --git a/tests/data/chat_templates/v15_image.jinja b/tests/data/chat_templates/v15_image.jinja index 35b872fe..bfbfdd17 100644 --- a/tests/data/chat_templates/v15_image.jinja +++ b/tests/data/chat_templates/v15_image.jinja @@ -188,7 +188,7 @@ {{- eos_token }} - {#- Tool messages only supports text content. #} + {#- Tool messages (multimodal). #} {%- elif message['role'] == 'tool' %} {{- '[TOOL_RESULTS]' -}} {{- render_content(message['content'], 'tool message contents', supported_types_desc='text and image', support_images=true) -}} diff --git a/tests/data/chat_templates/v15_image_think.jinja b/tests/data/chat_templates/v15_image_think.jinja index 5fab24c0..e1f26e8f 100644 --- a/tests/data/chat_templates/v15_image_think.jinja +++ b/tests/data/chat_templates/v15_image_think.jinja @@ -215,7 +215,7 @@ {{- eos_token }} - {#- Tool messages only supports text content. #} + {#- Tool messages (multimodal). #} {%- elif message['role'] == 'tool' %} {{- '[TOOL_RESULTS]' -}} {{- render_content(message['content'], 'tool message contents', supported_types_desc='text and image', support_images=true) -}} diff --git a/tests/data/chat_templates/v15_think.jinja b/tests/data/chat_templates/v15_think.jinja index d91b3909..40b8a7fc 100644 --- a/tests/data/chat_templates/v15_think.jinja +++ b/tests/data/chat_templates/v15_think.jinja @@ -207,7 +207,7 @@ {{- eos_token }} - {#- Tool messages only supports text content. #} + {#- Tool messages (multimodal). #} {%- elif message['role'] == 'tool' %} {{- '[TOOL_RESULTS]' -}} {{- render_content(message['content'], 'tool message contents') -}} diff --git a/tests/test_normalization.py b/tests/test_normalization.py index 2c446d83..cef30eb9 100644 --- a/tests/test_normalization.py +++ b/tests/test_normalization.py @@ -1366,29 +1366,21 @@ def test_v7_preserves_audio_in_system_message(self) -> None: assert isinstance(system_msg.content[0], TextChunk) assert isinstance(system_msg.content[1], AudioChunk) - def test_pre_v7_rejects_audio_in_system_message(self) -> None: - r"""Pre-V7 normalizer rejects AudioChunk in system messages.""" - normalizer = InstructRequestNormalizer( - UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest, None - ) - request = mock_chat_completion( - messages=[ - SystemMessage(content=[TextChunk(text="hello"), AudioChunk(input_audio=b"fake_audio_data")]), - UserMessage(content="query"), - AssistantMessage(content="answer"), - ] - ) - with pytest.raises(AssertionError): - normalizer.from_chat_completion_request(request) - - def test_pre_v7_rejects_think_in_system_message(self) -> None: - r"""Pre-V7 normalizer rejects ThinkChunk in system messages.""" + @pytest.mark.parametrize( + "chunk", + [ + pytest.param(AudioChunk(input_audio=b"fake_audio_data"), id="audio"), + pytest.param(ThinkChunk(thinking="thinking", closed=True), id="think"), + ], + ) + def test_pre_v7_rejects_non_text_in_system_message(self, chunk: AudioChunk | ThinkChunk) -> None: + r"""Pre-V7 normalizer rejects non-text chunks in system messages.""" normalizer = InstructRequestNormalizer( UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest, None ) request = mock_chat_completion( messages=[ - SystemMessage(content=[TextChunk(text="hello"), ThinkChunk(thinking="thinking", closed=True)]), + SystemMessage(content=[TextChunk(text="hello"), chunk]), UserMessage(content="query"), AssistantMessage(content="answer"), ] From 354e244e4d381be0c7a7adb49873f14b32a3bc10 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Tue, 9 Jun 2026 21:20:30 +0200 Subject: [PATCH 22/47] Narrow encode_system_message return type to exclude images --- src/mistral_common/tokens/tokenizers/instruct.py | 15 +++++++-------- tests/experimental/test_app.py | 4 +--- tests/guidance/test_guidance.py | 3 +-- tests/test_tokenizer_v13.py | 3 +-- 4 files changed, 10 insertions(+), 15 deletions(-) diff --git a/src/mistral_common/tokens/tokenizers/instruct.py b/src/mistral_common/tokens/tokenizers/instruct.py index 3ec6730e..a939ffb7 100644 --- a/src/mistral_common/tokens/tokenizers/instruct.py +++ b/src/mistral_common/tokens/tokenizers/instruct.py @@ -206,8 +206,7 @@ def encode_instruct( if msg_idx == len(request.messages) - 1: prefix_ids = new_tokens elif isinstance(msg, SystemMessage): - new_tokens, new_images, new_audios = self.encode_system_message(msg) - images.extend(new_images) + new_tokens, new_audios = self.encode_system_message(msg) audios.extend(new_audios) else: raise TokenizerException(f"Unknown message type {type(msg)}") @@ -295,7 +294,7 @@ def encode_user_message( curr_tokens, image, audio = self.encode_user_content(content=message_txt, is_last=False, system_prompt=None) return curr_tokens, image, audio - def encode_system_message(self, message: SystemMessage) -> tuple[list[int], list[np.ndarray], list[Audio]]: + def encode_system_message(self, message: SystemMessage) -> tuple[list[int], list[Audio]]: raise NotImplementedError(f"System message encoding not implemented for {self.__class__.__name__}") def encode_user_content( @@ -879,22 +878,22 @@ def drop(idx: int) -> None: if to_drop > 0: raise TokenizerException("Input couldn't fit in truncate_at_max_token") - def encode_system_message(self, message: SystemMessage) -> tuple[list[int], list[np.ndarray], list[Audio]]: + def encode_system_message(self, message: SystemMessage) -> tuple[list[int], list[Audio]]: r"""Encode a system message. Args: message: The message to encode. Returns: - The encoded tokens, images, and audios. + The encoded tokens and audios. """ tokens = [self.BEGIN_SYSTEM] if isinstance(content := message.content, str): content = [TextChunk(text=content)] - content_tokens, images, audios = self._encode_content_chunks(content) + content_tokens, _images, audios = self._encode_content_chunks(content) tokens += content_tokens tokens.append(self.END_SYSTEM) - return tokens, images, audios + return tokens, audios def encode_user_content( self, @@ -1386,7 +1385,7 @@ def _encode_settings( ] return settings_tokens - def encode_system_message(self, message: SystemMessage) -> tuple[list[int], list[np.ndarray], list[Audio]]: + def encode_system_message(self, message: SystemMessage) -> tuple[list[int], list[Audio]]: r"""Encode a system message, rejecting ThinkChunk content.""" if isinstance(message.content, list): if any(isinstance(chunk, ThinkChunk) for chunk in message.content): diff --git a/tests/experimental/test_app.py b/tests/experimental/test_app.py index 5b38d2c6..9851cb06 100644 --- a/tests/experimental/test_app.py +++ b/tests/experimental/test_app.py @@ -394,9 +394,7 @@ def test_detokenize_assistant_message( def test_detokenize_assistant_message_think_chunks( assistant_message: AssistantMessage, mistral_tokenizer_v13: MistralTokenizer, tekken_v13_client: TestClient ) -> None: - encoded_tokens = mistral_tokenizer_v13.instruct_tokenizer.encode_assistant_message( # type: ignore[attr-defined] - assistant_message, False, False - ) + encoded_tokens = mistral_tokenizer_v13.instruct_tokenizer.encode_assistant_message(assistant_message, False, False) # type: ignore[attr-defined] response = tekken_v13_client.post("/v1/detokenize/", json=encoded_tokens) assert response.status_code == 200 diff --git a/tests/guidance/test_guidance.py b/tests/guidance/test_guidance.py index a5c7a8a4..9afe6f33 100644 --- a/tests/guidance/test_guidance.py +++ b/tests/guidance/test_guidance.py @@ -228,10 +228,9 @@ def _encode_content( tokenizer = instruct_tokenizer.tokenizer if isinstance(content, str): - result = instruct_tokenizer.encode_assistant_message( + return instruct_tokenizer.encode_assistant_message( AssistantMessage(content=content), is_before_last_user_message=False, continue_message=False ) - return result tool_calls = [x for x in content if isinstance(x, ToolCall)] content_chunks = [x for x in content if not isinstance(x, ToolCall)] diff --git a/tests/test_tokenizer_v13.py b/tests/test_tokenizer_v13.py index 1d6412ad..64380ac3 100644 --- a/tests/test_tokenizer_v13.py +++ b/tests/test_tokenizer_v13.py @@ -375,9 +375,8 @@ def test_tokenize_assistant_message_error(v13_tekkenizer: InstructTokenizerV13) def test_encode_system_message( v13_tekkenizer_think: InstructTokenizerV13, message: SystemMessage, expected: str ) -> None: - encoded, images, audios = v13_tekkenizer_think.encode_system_message(message) + encoded, audios = v13_tekkenizer_think.encode_system_message(message) assert v13_tekkenizer_think.decode(encoded, special_token_policy=SpecialTokenPolicy.KEEP) == expected - assert images == [] assert audios == [] From 113acb0f02ab238c8c7d7dadb9943e27aee162d4 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Tue, 9 Jun 2026 21:30:24 +0200 Subject: [PATCH 23/47] Add REQUEST_TOOL_AUDIO_TRAIN fixture for tool message with audio --- .../chat_templates/fixtures_data.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/integrations/chat_templates/fixtures_data.py b/tests/integrations/chat_templates/fixtures_data.py index 8fad3fcf..b396f457 100644 --- a/tests/integrations/chat_templates/fixtures_data.py +++ b/tests/integrations/chat_templates/fixtures_data.py @@ -981,6 +981,33 @@ tools=_TOOLS, ) +REQUEST_TOOL_AUDIO_TRAIN = ChatCompletionRequest( # type: ignore[type-var] + messages=[ + UserMessage(content="What does this sound like?"), + AssistantMessage( + content=None, + tool_calls=[ + ToolCall( + id="tl1mg2345", + function=FunctionCall( + name="tool1", + arguments={"location": "San Francisco, CA"}, # type: ignore[arg-type] + ), + ), + ], + ), + ToolMessage( + content=[ + TextChunk(text="Here is the audio result."), + AudioChunk(input_audio=_AUDIO), + ], + tool_call_id="tl1mg2345", + ), + AssistantMessage(content="The tool returned audio of a sine wave."), + ], + tools=_TOOLS, +) + REQUEST_SYSTEM_AUDIO_TRAIN = ChatCompletionRequest( # type: ignore[type-var] messages=[ SystemMessage( @@ -1154,6 +1181,7 @@ def _get_conversations( ) if audio: conversations.append(REQUEST_SYSTEM_AUDIO_TRAIN) + conversations.append(REQUEST_TOOL_AUDIO_TRAIN) conversations = [c.model_copy(deep=True) for c in conversations] From baf29e6b6315a4c93b902be01db9bbd5c28325ac Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Tue, 9 Jun 2026 21:30:26 +0200 Subject: [PATCH 24/47] Expand TestAssistantContentNarrowing to cover V7, V13, V15 with rejection tests --- tests/test_normalization.py | 56 +++++++++++++++++++++++++++---------- 1 file changed, 42 insertions(+), 14 deletions(-) diff --git a/tests/test_normalization.py b/tests/test_normalization.py index cef30eb9..4d066b86 100644 --- a/tests/test_normalization.py +++ b/tests/test_normalization.py @@ -1127,10 +1127,34 @@ def test_get_normalizer_version_mapping( class TestAssistantContentNarrowing: - def test_accepts_text_and_think_chunks(self) -> None: + @pytest.fixture(params=[TokenizerVersion.v7, TokenizerVersion.v13, TokenizerVersion.v15]) + def normalizer(self, request: pytest.FixtureRequest) -> InstructRequestNormalizer: + r"""Normalizer fixture parametrized across V7, V13, V15.""" + version = request.param + if version == TokenizerVersion.v15: + return get_normalizer( + version, + model_settings_builder=ModelSettingsBuilder( + reasoning_effort=EnumBuilder[ReasoningEffort]( + values=list(ReasoningEffort), accepts_none=False, default=None + ) + ), + ) + return get_normalizer(version) + + @staticmethod + def _make_request( + normalizer: InstructRequestNormalizer, messages: list[ChatMessage] + ) -> ChatCompletionRequest[ChatMessage]: + r"""Build a ChatCompletionRequest, adding reasoning_effort for V15.""" + if isinstance(normalizer, InstructRequestNormalizerV15): + return ChatCompletionRequest(messages=messages, reasoning_effort=ReasoningEffort.high) + return mock_chat_completion(messages=messages) + + def test_accepts_text_and_think_chunks(self, normalizer: InstructRequestNormalizer) -> None: r"""Normalizer accepts TextChunk and ThinkChunk in assistant messages.""" - normalizer = get_normalizer(TokenizerVersion.v13) - request = mock_chat_completion( + request = self._make_request( + normalizer, messages=[ UserMessage(content="query"), AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]), @@ -1142,28 +1166,32 @@ def test_accepts_text_and_think_chunks(self) -> None: assert isinstance(assistant_msg.content, list) assert len(assistant_msg.content) == 2 - def test_accepts_string_content(self) -> None: + def test_accepts_string_content(self, normalizer: InstructRequestNormalizer) -> None: r"""Normalizer accepts string content in assistant messages.""" - normalizer = get_normalizer( - TokenizerVersion.v15, - model_settings_builder=ModelSettingsBuilder( - reasoning_effort=EnumBuilder[ReasoningEffort]( - values=list(ReasoningEffort), accepts_none=False, default=None - ) - ), - ) - request = ChatCompletionRequest[ChatMessage]( + request = self._make_request( + normalizer, messages=[ UserMessage(content="query"), AssistantMessage(content="plain text"), ], - reasoning_effort=ReasoningEffort.high, ) parsed: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request) assistant_msg = parsed.messages[1] assert isinstance(assistant_msg, AssistantMessage) assert assistant_msg.content == "plain text" + def test_pydantic_rejects_image_in_assistant(self) -> None: + with pytest.raises(ValidationError): + AssistantMessage( + content=[TextChunk(text="answer"), ImageURLChunk(image_url="https://example.com/img.png")] # type: ignore[list-item] + ) + + def test_pydantic_rejects_audio_in_assistant(self) -> None: + with pytest.raises(ValidationError): + AssistantMessage( + content=[TextChunk(text="answer"), AudioChunk(input_audio=b"fake_audio_data")] # type: ignore[list-item] + ) + class TestToolMessageContentChunk: @pytest.fixture() From 3c1228d321079ad36f827cd681e860e526d70be1 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Tue, 9 Jun 2026 23:09:02 +0200 Subject: [PATCH 25/47] Refactor test_parity to use _get_conversations with to_openai conversion --- .../chat_templates/test_parity.py | 537 ++---------------- 1 file changed, 51 insertions(+), 486 deletions(-) diff --git a/tests/integrations/chat_templates/test_parity.py b/tests/integrations/chat_templates/test_parity.py index 8c90c97d..1bd6feed 100644 --- a/tests/integrations/chat_templates/test_parity.py +++ b/tests/integrations/chat_templates/test_parity.py @@ -3,8 +3,11 @@ import pytest from mistral_common.integrations.chat_templates.template_generator import build_chat_template +from mistral_common.protocol.instruct.request import ChatCompletionRequest +from mistral_common.protocol.instruct.validator import ValidationMode from mistral_common.tokens.tokenizers.base import TokenizerVersion from tests.integrations.chat_templates.conftest import ALL_CONFIGS, _config_id +from tests.integrations.chat_templates.fixtures_data import _get_conversations from tests.integrations.chat_templates.helpers import ( TestConfig, _load_golden_template, @@ -13,6 +16,43 @@ ) +def _request_to_render_args(request: ChatCompletionRequest) -> dict[str, Any]: + r"""Convert a ChatCompletionRequest to render_template kwargs. + + Enriches tool messages with the ``name`` field resolved from the preceding + assistant's ``tool_calls``, which the v2 Jinja template requires but + ``ToolMessage.to_openai()`` omits. + """ + openai = request.to_openai() + messages = openai["messages"] + + # Build a mapping from tool_call_id -> function name so tool messages + # can carry the ``name`` field expected by v2 templates. + tool_call_names: dict[str, str] = {} + for msg in messages: + for tc in msg.get("tool_calls", []): + tc_id = tc.get("id") + fn_name = tc.get("function", {}).get("name") + if tc_id and fn_name: + tool_call_names[tc_id] = fn_name + + for msg in messages: + if msg.get("role") == "tool" and "name" not in msg: + tc_id = msg.get("tool_call_id") + if tc_id and tc_id in tool_call_names: + msg["name"] = tool_call_names[tc_id] + + kwargs: dict[str, Any] = { + "messages": messages, + } + if "tools" in openai and openai["tools"]: + kwargs["tools"] = openai["tools"] + reasoning = openai.get("reasoning_effort") + if reasoning is not None: + kwargs["reasoning_effort"] = reasoning + return kwargs + + @pytest.mark.parametrize( "config", [c for c in ALL_CONFIGS if not c.plain_think], @@ -64,489 +104,14 @@ def test_dynamic_template_comprehensive(config: TestConfig) -> None: static_template = _load_golden_template(template_config) dynamic_template = build_chat_template(template_config) - # Test cases - test_cases = [ - # Simple one-turn - { - "name": "one_turn", - "messages": [ - {"role": "user", "content": "Hello"}, - {"role": "assistant", "content": "Hi there!"}, - ], - }, - # With system message - { - "name": "with_system", - "messages": [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Hello"}, - {"role": "assistant", "content": "Hi there!"}, - ], - }, - # Multi-turn - { - "name": "multi_turn", - "messages": [ - {"role": "user", "content": "Hello"}, - {"role": "assistant", "content": "Hi there!"}, - {"role": "user", "content": "How are you?"}, - {"role": "assistant", "content": "I'm doing well!"}, - ], - }, - # Content as list of chunks - { - "name": "content_chunks", - "messages": [ - {"role": "user", "content": [{"type": "text", "text": "Hello"}]}, - {"role": "assistant", "content": [{"type": "text", "text": "Hi there!"}]}, - ], - }, - ] - - # Tool call scenarios (v2+ only) - if config.version > TokenizerVersion.v1: - test_cases.extend( - [ - { - "name": "with_tools_definition", - "messages": [ - {"role": "user", "content": "What's the weather?"}, - {"role": "assistant", "content": "It's sunny."}, - ], - "tools": [ - { - "type": "function", - "function": { - "name": "get_weather", - "description": "Get weather", - "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}, - }, - } - ], - }, - { - "name": "with_tool_call_and_result", - "messages": [ - {"role": "user", "content": "What's the weather in Paris?"}, - { - "role": "assistant", - "content": "", - "tool_calls": [ - { - "id": "abc123def", - "type": "function", - "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'}, - } - ], - }, - {"role": "tool", "content": '{"temp": 20}', "tool_call_id": "abc123def", "name": "get_weather"}, - {"role": "assistant", "content": "It's 20 degrees in Paris."}, - ], - "tools": [ - { - "type": "function", - "function": { - "name": "get_weather", - "description": "Get weather", - "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}, - }, - } - ], - }, - ] - ) - - # Add image test case if image support - if config.image: - test_cases.append( - { - "name": "with_image", - "messages": [ - { - "role": "user", - "content": [ - {"type": "text", "text": "What is this?"}, - {"type": "image_url", "image_url": "http://example.com/image.png"}, - ], - }, - {"role": "assistant", "content": "It's an image."}, - ], - } - ) - - # Add audio test case if audio support - if config.audio: - test_cases.append( - { - "name": "with_audio", - "messages": [ - { - "role": "user", - "content": [ - {"type": "text", "text": "What is this?"}, - {"type": "audio_url", "audio_url": "http://example.com/audio.mp3"}, - ], - }, - {"role": "assistant", "content": "It's an audio file."}, - ], - } - ) - - # Add thinking test case if thinking support - if config.think: - test_cases.append( - { - "name": "with_thinking", - "messages": [ - {"role": "user", "content": "Solve this problem"}, - { - "role": "assistant", - "content": [ - {"type": "thinking", "thinking": "Let me think..."}, - {"type": "text", "text": "The answer is 42."}, - ], - }, - ], - } - ) - - # Add message aggregation test cases - test_cases.extend( - [ - { - "name": "consecutive_users", - "messages": [ - {"role": "user", "content": "Hello"}, - {"role": "user", "content": "World"}, - {"role": "assistant", "content": "Hi there"}, - ], - }, - { - "name": "consecutive_users_with_system", - "messages": [ - {"role": "system", "content": "You are helpful."}, - {"role": "user", "content": "Hello"}, - {"role": "user", "content": "World"}, - {"role": "assistant", "content": "Hi there"}, - ], - }, - { - "name": "consecutive_assistants", - "messages": [ - {"role": "user", "content": "Hello"}, - {"role": "assistant", "content": "Hi"}, - {"role": "assistant", "content": "How can I help?"}, - {"role": "user", "content": "Thanks"}, - {"role": "assistant", "content": "Welcome"}, - ], - }, - { - "name": "multiple_systems", - "messages": [ - {"role": "system", "content": "System 1."}, - {"role": "system", "content": "System 2."}, - {"role": "user", "content": "Hello"}, - {"role": "assistant", "content": "Hi"}, - ], - }, - { - "name": "mid_conv_system", - "messages": [ - {"role": "user", "content": "Hello"}, - {"role": "system", "content": "New instruction."}, - {"role": "assistant", "content": "Got it"}, - ], - }, - { - "name": "mid_conv_system_with_consecutive_users", - "messages": [ - {"role": "system", "content": "Be helpful."}, - {"role": "user", "content": "Hello"}, - {"role": "user", "content": "World"}, - {"role": "system", "content": "Now be concise."}, - {"role": "assistant", "content": "Got it"}, - ], - }, - ] - ) - - # Multi-chunk aggregation test cases - test_cases.extend( - [ - { - "name": "consecutive_users_text_chunks", - "messages": [ - {"role": "user", "content": "First as string"}, - {"role": "user", "content": [{"type": "text", "text": "Second as chunk"}]}, - { - "role": "user", - "content": [ - {"type": "text", "text": "Third part A"}, - {"type": "text", "text": "Third part B"}, - ], - }, - {"role": "assistant", "content": "Response"}, - ], - }, - { - "name": "system_text_chunks", - "messages": [ - { - "role": "system", - "content": [ - {"type": "text", "text": "You are helpful."}, - {"type": "text", "text": "Be concise."}, - ], - }, - {"role": "user", "content": "Hello"}, - {"role": "assistant", "content": "Hi"}, - ], - }, - ] - ) - - if config.image: - test_cases.extend( - [ - { - "name": "consecutive_users_with_image", - "messages": [ - {"role": "user", "content": "What is this?"}, - { - "role": "user", - "content": [ - {"type": "image_url", "image_url": "http://example.com/image.png"}, - {"type": "text", "text": "Describe it"}, - ], - }, - {"role": "assistant", "content": "It's an image."}, - ], - }, - { - "name": "consecutive_users_multi_image", - "messages": [ - { - "role": "user", - "content": [ - {"type": "text", "text": "Describe this"}, - {"type": "image_url", "image_url": "http://example.com/a.png"}, - {"type": "text", "text": "What color?"}, - ], - }, - { - "role": "user", - "content": [ - {"type": "text", "text": "Also this"}, - {"type": "image_url", "image_url": "http://example.com/b.png"}, - {"type": "text", "text": "What shape?"}, - ], - }, - {"role": "assistant", "content": "Both are red squares."}, - ], - }, - ] - ) - - if config.audio: - test_cases.append( - { - "name": "consecutive_users_multi_audio", - "messages": [ - { - "role": "user", - "content": [ - {"type": "text", "text": "Listen"}, - {"type": "audio_url", "audio_url": "http://example.com/a.wav"}, - {"type": "text", "text": "What language?"}, - ], - }, - { - "role": "user", - "content": [ - {"type": "text", "text": "And this"}, - {"type": "audio_url", "audio_url": "http://example.com/b.wav"}, - {"type": "text", "text": "Transcribe it"}, - ], - }, - {"role": "assistant", "content": "Both are in English."}, - ], - } - ) - - if config.think: - test_cases.extend( - [ - { - "name": "consecutive_assistants_think", - "messages": [ - {"role": "user", "content": "Solve this"}, - { - "role": "assistant", - "content": [ - {"type": "text", "text": "Hmm."}, - {"type": "thinking", "thinking": "Let me think..."}, - {"type": "text", "text": "I need more context."}, - ], - }, - { - "role": "assistant", - "content": [ - {"type": "text", "text": "OK."}, - {"type": "thinking", "thinking": "Now I understand."}, - {"type": "text", "text": "The answer is 42."}, - ], - }, - {"role": "user", "content": "Thanks"}, - {"role": "assistant", "content": "You're welcome"}, - ], - }, - { - "name": "consecutive_systems_think", - "messages": [ - { - "role": "system", - "content": [ - {"type": "text", "text": "Rule A"}, - {"type": "text", "text": "Rule B"}, - {"type": "thinking", "thinking": "Think 1"}, - ], - }, - { - "role": "system", - "content": [ - {"type": "thinking", "thinking": "Think 2"}, - {"type": "text", "text": "Rule C"}, - {"type": "text", "text": "Rule D"}, - ], - }, - {"role": "user", "content": "Hello"}, - {"role": "assistant", "content": "Hi"}, - ], - }, - ] - ) - - # V15+ multimodal: tool messages with image/audio content - if config.version >= TokenizerVersion.v15 and config.image: - test_cases.append( - { - "name": "v15_tool_with_image", - "messages": [ - {"role": "user", "content": "Use tool"}, - { - "role": "assistant", - "content": "", - "tool_calls": [ - {"id": "abc123def", "function": {"name": "fn", "arguments": "{}"}}, - ], - }, - { - "role": "tool", - "content": [ - {"type": "text", "text": "result"}, - {"type": "image_url", "image_url": "http://example.com/img.png"}, - ], - "tool_call_id": "abc123def", - }, - {"role": "assistant", "content": "Done"}, - ], - "tools": [ - { - "type": "function", - "function": {"name": "fn", "description": "test", "parameters": {}}, - } - ], - } - ) - - if config.version >= TokenizerVersion.v15 and config.audio: - test_cases.append( - { - "name": "v15_tool_with_audio", - "messages": [ - {"role": "user", "content": "Use tool"}, - { - "role": "assistant", - "content": "", - "tool_calls": [ - {"id": "abc123def", "function": {"name": "fn", "arguments": "{}"}}, - ], - }, - { - "role": "tool", - "content": [ - {"type": "text", "text": "result"}, - {"type": "input_audio", "input_audio": {"data": "abc", "format": "wav"}}, - ], - "tool_call_id": "abc123def", - }, - {"role": "assistant", "content": "Done"}, - ], - "tools": [ - { - "type": "function", - "function": {"name": "fn", "description": "test", "parameters": {}}, - } - ], - } - ) - - # V15+ multimodal: system messages with audio content - if config.version >= TokenizerVersion.v15 and config.audio: - test_cases.append( - { - "name": "v15_system_with_audio", - "messages": [ - { - "role": "system", - "content": [ - {"type": "text", "text": "Listen to context"}, - {"type": "input_audio", "input_audio": {"data": "abc", "format": "wav"}}, - ], - }, - {"role": "user", "content": "Summarize"}, - {"role": "assistant", "content": "Done"}, - ], - } - ) - - skip_names_image = {"with_image", "consecutive_users_with_image", "consecutive_users_multi_image"} - skip_names_audio = {"with_audio", "consecutive_users_multi_audio"} - skip_names_think = {"with_thinking", "consecutive_assistants_think", "consecutive_systems_think"} - # ThinkChunks in system messages are only supported in v13 (not v15+) - skip_names_think_system = {"consecutive_systems_think"} - skip_names_tools = {"with_tools_definition", "with_tool_call_and_result"} - skip_names_v15_image = {"v15_tool_with_image"} - skip_names_v15_audio = {"v15_tool_with_audio", "v15_system_with_audio"} - - # Not using parametrize here because test_cases are built dynamically based on the - # config (version/image/audio/think) from the outer parametrize. Each sub-case is - # identifiable via test_name in the assertion failure message. - for test_case in test_cases: - test_name = test_case["name"] - messages = test_case["messages"] - tools = test_case.get("tools") - - if test_name in skip_names_image and not config.image: - continue - if test_name in skip_names_audio and not config.audio: - continue - if test_name in skip_names_think and not config.think: - continue - if test_name in skip_names_think_system and config.version >= TokenizerVersion.v15: - continue - if test_name in skip_names_tools and config.version <= TokenizerVersion.v1: - continue - if test_name in skip_names_v15_image and (config.version < TokenizerVersion.v15 or not config.image): - continue - if test_name in skip_names_v15_audio and (config.version < TokenizerVersion.v15 or not config.audio): - continue - - static_output = render_template(static_template, messages, tools=tools) # type: ignore - dynamic_output = render_template(dynamic_template, messages, tools=tools) # type: ignore - - assert static_output == dynamic_output, ( - f"Output mismatch for {config}, case={test_name}\n\n" - f"Static output: {static_output}\n" - f"Dynamic output: {dynamic_output}" - ) + for mode in (ValidationMode.finetuning, ValidationMode.test): + conversations = _get_conversations(config.version, mode, config.image, config.audio, config.think) + for idx, request in enumerate(conversations): + render_args = _request_to_render_args(request) + static_output = render_template(static_template, **render_args) + dynamic_output = render_template(dynamic_template, **render_args) + assert static_output == dynamic_output, ( + f"Output mismatch for {config}, mode={mode}, conversation={idx}\n\n" + f"Static output: {static_output}\n" + f"Dynamic output: {dynamic_output}" + ) From 4ac82b7629eb4094c3e5e41f387d602b6d7f87dc Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Tue, 9 Jun 2026 23:09:03 +0200 Subject: [PATCH 26/47] Reorganize V15 tokenizer test fixtures, add text assertions --- tests/test_tokenizer_v15.py | 331 ++++++++++++++++++++++-------------- 1 file changed, 206 insertions(+), 125 deletions(-) diff --git a/tests/test_tokenizer_v15.py b/tests/test_tokenizer_v15.py index 739ec75a..968128f9 100644 --- a/tests/test_tokenizer_v15.py +++ b/tests/test_tokenizer_v15.py @@ -63,6 +63,44 @@ r"[INST]U2[/INST]" ) +EXPECTED_TEXT_TOOL_AUDIO: str = ( + r"" + r'[AVAILABLE_TOOLS][{"type": "function", "function": {"name": "fn",' + r' "description": "test", "parameters": {}}}]' + r'[/AVAILABLE_TOOLS][MODEL_SETTINGS]{"reasoning_effort": "none"}[/MODEL_SETTINGS]' + r"[INST]Use the tool[/INST]" + r"[TOOL_CALLS]fn[ARGS]{}" + r"[TOOL_RESULTS]result[BEGIN_AUDIO][AUDIO][AUDIO][/TOOL_RESULTS]" +) + +EXPECTED_TEXT_TOOL_IMAGE: str = ( + r"" + r'[AVAILABLE_TOOLS][{"type": "function", "function": {"name": "fn",' + r' "description": "test", "parameters": {}}}]' + r'[/AVAILABLE_TOOLS][MODEL_SETTINGS]{"reasoning_effort": "none"}[/MODEL_SETTINGS]' + r"[INST]Use the tool[/INST]" + r"[TOOL_CALLS]fn[ARGS]{}" + r"[TOOL_RESULTS]result[IMG][IMG_END][/TOOL_RESULTS]" +) + +EXPECTED_TEXT_SYSTEM_AUDIO: str = ( + r"[SYSTEM_PROMPT]System with content[BEGIN_AUDIO][AUDIO][AUDIO][/SYSTEM_PROMPT]" + r'[MODEL_SETTINGS]{"reasoning_effort": "none"}[/MODEL_SETTINGS]' + r"[INST]Hello[/INST]" +) + +EXPECTED_TEXT_USER_AUDIO: str = ( + r"" + r'[MODEL_SETTINGS]{"reasoning_effort": "none"}[/MODEL_SETTINGS]' + r"[INST]Here is content[BEGIN_AUDIO][AUDIO][AUDIO][/INST]" +) + +EXPECTED_TEXT_USER_IMAGE: str = ( + r"" + r'[MODEL_SETTINGS]{"reasoning_effort": "none"}[/MODEL_SETTINGS]' + r"[INST][IMG][IMG_END]Here is content[/INST]" +) + def _build_v15_tekkenizer(model_settings_builder: ModelSettingsBuilder | None) -> Tekkenizer: r"""Build a v15 Tekkenizer with the given model settings builder. @@ -103,6 +141,107 @@ def get_v15_mistral_tokenizer( ) +def _build_model_settings_builder( + allowed_reasoning_effort: tuple[str, ...] | None, +) -> ModelSettingsBuilder: + """Build a ModelSettingsBuilder from allowed reasoning effort values. + + When `allowed_reasoning_effort` is `None`, returns `ModelSettingsBuilder.none()` + (all fields ignored). This matches the behavior of `Tekkenizer.from_file` when no + `model_settings_builder` key is present in the JSON. + """ + if allowed_reasoning_effort is None: + return ModelSettingsBuilder.none() + if not allowed_reasoning_effort: + return ModelSettingsBuilder( + reasoning_effort=EnumBuilder[ReasoningEffort](values=[], accepts_none=True, default=None) + ) + return ModelSettingsBuilder( + reasoning_effort=EnumBuilder[ReasoningEffort]( + values=[ReasoningEffort(v) for v in allowed_reasoning_effort], + accepts_none=True, + default=ReasoningEffort(allowed_reasoning_effort[0]) if allowed_reasoning_effort else None, + ) + ) + + +def _get_dummy_image_url_chunk() -> ImageURLChunk: + r"""Build a small base64-encoded image URL chunk for testing.""" + img = Image.new("RGB", (4, 4), "red") + buf = BytesIO() + img.save(buf, "PNG") + data_url = f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}" + return ImageURLChunk(image_url=data_url) + + +def get_v15_mistral_tokenizer_with_audio() -> MistralTokenizer: + r"""Build a V15 MistralTokenizer with audio encoder.""" + builder = _build_model_settings_builder(tuple(ReasoningEffort)) + tekkenizer = Tekkenizer( + quick_vocab([b"a", b"b", b"c", b"f", b"de"]), + special_tokens=get_special_tokens(TokenizerVersion.v15, add_audio=True), + pattern=r".+", + vocab_size=256 + 100, + num_special_tokens=100, + version=TokenizerVersion.v15, + model_settings_builder=builder, + ) + audio_config = AudioConfig( + sampling_rate=24_000, + frame_rate=12.5, + encoding_config=AudioSpectrogramConfig( + num_mel_bins=128, + hop_length=160, + window_size=400, + ), + ) + special_audio_ids = SpecialAudioIDs( + audio=tekkenizer.get_special_token(SpecialTokens.audio.value), + begin_audio=tekkenizer.get_special_token(SpecialTokens.begin_audio.value), + streaming_pad=None, + text_to_audio=None, + audio_to_text=None, + ) + audio_encoder = AudioEncoder(audio_config, special_audio_ids) + instruct_tokenizer = InstructTokenizerV15(tekkenizer, audio_encoder=audio_encoder) + request_normalizer = get_normalizer(TokenizerVersion.v15, tekkenizer.model_settings_builder) + validator = get_validator(TokenizerVersion.v15, mode=ValidationMode.test) + return MistralTokenizer( + instruct_tokenizer=instruct_tokenizer, + validator=validator, + request_normalizer=request_normalizer, + ) + + +def get_v15_mistral_tokenizer_with_image() -> MistralTokenizer: + r"""Build a V15 MistralTokenizer with image encoder.""" + builder = _build_model_settings_builder(tuple(ReasoningEffort)) + tekkenizer = Tekkenizer( + quick_vocab([b"a", b"b", b"c", b"f", b"de"]), + special_tokens=get_special_tokens(TokenizerVersion.v15, add_think=True), + pattern=r".+", + vocab_size=256 + 100, + num_special_tokens=100, + version=TokenizerVersion.v15, + model_settings_builder=builder, + ) + image_config = ImageConfig(image_patch_size=16, max_image_size=1024) + special_image_ids = SpecialImageIDs( + img=tekkenizer.get_special_token(SpecialTokens.img.value), + img_break=tekkenizer.get_special_token(SpecialTokens.img_break.value), + img_end=tekkenizer.get_special_token(SpecialTokens.img_end.value), + ) + image_encoder = ImageEncoder(image_config, special_image_ids) + instruct_tokenizer = InstructTokenizerV15(tekkenizer, image_encoder=image_encoder) + request_normalizer = get_normalizer(TokenizerVersion.v15, tekkenizer.model_settings_builder) + validator = get_validator(TokenizerVersion.v15, mode=ValidationMode.test) + return MistralTokenizer( + instruct_tokenizer=instruct_tokenizer, + validator=validator, + request_normalizer=request_normalizer, + ) + + @pytest.fixture(scope="session") def v15_tekkenizer() -> InstructTokenizerV15: return get_v15_tekkenizer(_build_model_settings_builder(tuple(ReasoningEffort))) @@ -153,6 +292,61 @@ def messages() -> list[ChatMessage]: ] +@pytest.fixture(scope="session") +def audio_chunk() -> AudioChunk: + return get_dummy_audio_chunk() + + +# Multimodal content chunks and their corresponding tokenizer factories for parametrized tests. +_TOOL_MULTIMODAL_PARAMS = [ + pytest.param( + get_dummy_audio_chunk(), 1, 0, get_v15_mistral_tokenizer_with_audio, EXPECTED_TEXT_TOOL_AUDIO, id="audio" + ), + pytest.param( + get_dummy_audio_url_chunk(), + 1, + 0, + get_v15_mistral_tokenizer_with_audio, + EXPECTED_TEXT_TOOL_AUDIO, + id="audio_url", + ), + pytest.param( + _get_dummy_image_url_chunk(), + 0, + 1, + get_v15_mistral_tokenizer_with_image, + EXPECTED_TEXT_TOOL_IMAGE, + id="image_url", + ), +] +_SYSTEM_MULTIMODAL_PARAMS = [ + pytest.param( + get_dummy_audio_chunk(), 1, 0, get_v15_mistral_tokenizer_with_audio, EXPECTED_TEXT_SYSTEM_AUDIO, id="audio" + ), +] +_USER_MULTIMODAL_PARAMS = [ + pytest.param( + get_dummy_audio_chunk(), 1, 0, get_v15_mistral_tokenizer_with_audio, EXPECTED_TEXT_USER_AUDIO, id="audio" + ), + pytest.param( + get_dummy_audio_url_chunk(), + 1, + 0, + get_v15_mistral_tokenizer_with_audio, + EXPECTED_TEXT_USER_AUDIO, + id="audio_url", + ), + pytest.param( + _get_dummy_image_url_chunk(), + 0, + 1, + get_v15_mistral_tokenizer_with_image, + EXPECTED_TEXT_USER_IMAGE, + id="image_url", + ), +] + + def test_tools_and_reasoning_effort( v15_tekkenizer: InstructTokenizerV15, available_tools: list[Tool], messages: list[ChatMessage] ) -> None: @@ -191,30 +385,6 @@ def test_system_think_chunk_raises_v15(v15_tekkenizer: InstructTokenizerV15) -> v15_tekkenizer.encode_instruct(request) -def _build_model_settings_builder( - allowed_reasoning_effort: tuple[str, ...] | None, -) -> ModelSettingsBuilder: - """Build a ModelSettingsBuilder from allowed reasoning effort values. - - When `allowed_reasoning_effort` is `None`, returns `ModelSettingsBuilder.none()` - (all fields ignored). This matches the behavior of `Tekkenizer.from_file` when no - `model_settings_builder` key is present in the JSON. - """ - if allowed_reasoning_effort is None: - return ModelSettingsBuilder.none() - if not allowed_reasoning_effort: - return ModelSettingsBuilder( - reasoning_effort=EnumBuilder[ReasoningEffort](values=[], accepts_none=True, default=None) - ) - return ModelSettingsBuilder( - reasoning_effort=EnumBuilder[ReasoningEffort]( - values=[ReasoningEffort(v) for v in allowed_reasoning_effort], - accepts_none=True, - default=ReasoningEffort(allowed_reasoning_effort[0]) if allowed_reasoning_effort else None, - ) - ) - - @pytest.mark.parametrize( ("reasoning_effort", "allowed_reasoning_effort", "raises", "match"), [ @@ -310,110 +480,16 @@ def test_encode_chat_completion_continue_final_message() -> None: assert encoded.tokens[-1] != eos_id -@pytest.fixture(scope="session") -def audio_chunk() -> AudioChunk: - return get_dummy_audio_chunk() - - -def get_v15_mistral_tokenizer_with_audio() -> MistralTokenizer: - r"""Build a V15 MistralTokenizer with audio encoder.""" - builder = _build_model_settings_builder(tuple(ReasoningEffort)) - tekkenizer = Tekkenizer( - quick_vocab([b"a", b"b", b"c", b"f", b"de"]), - special_tokens=get_special_tokens(TokenizerVersion.v15, add_audio=True), - pattern=r".+", - vocab_size=256 + 100, - num_special_tokens=100, - version=TokenizerVersion.v15, - model_settings_builder=builder, - ) - audio_config = AudioConfig( - sampling_rate=24_000, - frame_rate=12.5, - encoding_config=AudioSpectrogramConfig( - num_mel_bins=128, - hop_length=160, - window_size=400, - ), - ) - special_audio_ids = SpecialAudioIDs( - audio=tekkenizer.get_special_token(SpecialTokens.audio.value), - begin_audio=tekkenizer.get_special_token(SpecialTokens.begin_audio.value), - streaming_pad=None, - text_to_audio=None, - audio_to_text=None, - ) - audio_encoder = AudioEncoder(audio_config, special_audio_ids) - instruct_tokenizer = InstructTokenizerV15(tekkenizer, audio_encoder=audio_encoder) - request_normalizer = get_normalizer(TokenizerVersion.v15, tekkenizer.model_settings_builder) - validator = get_validator(TokenizerVersion.v15, mode=ValidationMode.test) - return MistralTokenizer( - instruct_tokenizer=instruct_tokenizer, - validator=validator, - request_normalizer=request_normalizer, - ) - - -def _get_dummy_image_url_chunk() -> ImageURLChunk: - r"""Build a small base64-encoded image URL chunk for testing.""" - img = Image.new("RGB", (4, 4), "red") - buf = BytesIO() - img.save(buf, "PNG") - data_url = f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}" - return ImageURLChunk(image_url=data_url) - - -def get_v15_mistral_tokenizer_with_image() -> MistralTokenizer: - r"""Build a V15 MistralTokenizer with image encoder.""" - builder = _build_model_settings_builder(tuple(ReasoningEffort)) - tekkenizer = Tekkenizer( - quick_vocab([b"a", b"b", b"c", b"f", b"de"]), - special_tokens=get_special_tokens(TokenizerVersion.v15, add_think=True), - pattern=r".+", - vocab_size=256 + 100, - num_special_tokens=100, - version=TokenizerVersion.v15, - model_settings_builder=builder, - ) - image_config = ImageConfig(image_patch_size=16, max_image_size=1024) - special_image_ids = SpecialImageIDs( - img=tekkenizer.get_special_token(SpecialTokens.img.value), - img_break=tekkenizer.get_special_token(SpecialTokens.img_break.value), - img_end=tekkenizer.get_special_token(SpecialTokens.img_end.value), - ) - image_encoder = ImageEncoder(image_config, special_image_ids) - instruct_tokenizer = InstructTokenizerV15(tekkenizer, image_encoder=image_encoder) - request_normalizer = get_normalizer(TokenizerVersion.v15, tekkenizer.model_settings_builder) - validator = get_validator(TokenizerVersion.v15, mode=ValidationMode.test) - return MistralTokenizer( - instruct_tokenizer=instruct_tokenizer, - validator=validator, - request_normalizer=request_normalizer, - ) - - -# Multimodal content chunks and their corresponding tokenizer factories for parametrized tests. -_AUDIO_CHUNK_PARAMS = pytest.param(get_dummy_audio_chunk(), 1, 0, get_v15_mistral_tokenizer_with_audio, id="audio") -_AUDIO_URL_CHUNK_PARAMS = pytest.param( - get_dummy_audio_url_chunk(), 1, 0, get_v15_mistral_tokenizer_with_audio, id="audio_url" -) -_IMAGE_URL_CHUNK_PARAMS = pytest.param( - _get_dummy_image_url_chunk(), 0, 1, get_v15_mistral_tokenizer_with_image, id="image_url" -) - -_ALL_MULTIMODAL_PARAMS = [_AUDIO_CHUNK_PARAMS, _AUDIO_URL_CHUNK_PARAMS, _IMAGE_URL_CHUNK_PARAMS] -_AUDIO_ONLY_PARAMS = [_AUDIO_CHUNK_PARAMS] - - @pytest.mark.parametrize( - ("content_chunk", "expected_audios", "expected_images", "tokenizer_factory"), - _ALL_MULTIMODAL_PARAMS, + ("content_chunk", "expected_audios", "expected_images", "tokenizer_factory", "expected_text"), + _TOOL_MULTIMODAL_PARAMS, ) def test_encode_chat_completion_with_multimodal_tool( content_chunk: AudioChunk | AudioURLChunk | ImageURLChunk, expected_audios: int, expected_images: int, tokenizer_factory: Callable[[], MistralTokenizer], + expected_text: str, ) -> None: mistral_tokenizer = tokenizer_factory() chat_request = ChatCompletionRequest( # type: ignore[type-var] @@ -428,19 +504,21 @@ def test_encode_chat_completion_with_multimodal_tool( tools=[Tool(function=Function(name="fn", description="test", parameters={}))], ) encoded = mistral_tokenizer.encode_chat_completion(chat_request) + assert encoded.text == expected_text, encoded.text assert len(encoded.audios) == expected_audios assert len(encoded.images) == expected_images @pytest.mark.parametrize( - ("content_chunk", "expected_audios", "expected_images", "tokenizer_factory"), - _AUDIO_ONLY_PARAMS, + ("content_chunk", "expected_audios", "expected_images", "tokenizer_factory", "expected_text"), + _SYSTEM_MULTIMODAL_PARAMS, ) def test_encode_chat_completion_with_multimodal_system( content_chunk: AudioChunk, expected_audios: int, expected_images: int, tokenizer_factory: Callable[[], MistralTokenizer], + expected_text: str, ) -> None: mistral_tokenizer = tokenizer_factory() chat_request = ChatCompletionRequest( # type: ignore[type-var] @@ -450,19 +528,21 @@ def test_encode_chat_completion_with_multimodal_system( ], ) encoded = mistral_tokenizer.encode_chat_completion(chat_request) + assert encoded.text == expected_text, encoded.text assert len(encoded.audios) == expected_audios assert len(encoded.images) == expected_images @pytest.mark.parametrize( - ("content_chunk", "expected_audios", "expected_images", "tokenizer_factory"), - _ALL_MULTIMODAL_PARAMS, + ("content_chunk", "expected_audios", "expected_images", "tokenizer_factory", "expected_text"), + _USER_MULTIMODAL_PARAMS, ) def test_encode_chat_completion_with_multimodal_user( content_chunk: AudioChunk | AudioURLChunk | ImageURLChunk, expected_audios: int, expected_images: int, tokenizer_factory: Callable[[], MistralTokenizer], + expected_text: str, ) -> None: mistral_tokenizer = tokenizer_factory() chat_request = ChatCompletionRequest( @@ -471,5 +551,6 @@ def test_encode_chat_completion_with_multimodal_user( ], ) encoded = mistral_tokenizer.encode_chat_completion(chat_request) + assert encoded.text == expected_text, encoded.text assert len(encoded.audios) == expected_audios assert len(encoded.images) == expected_images From 1035be10f136d9b08ba01a7ae5c4df58f25781ad Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Wed, 10 Jun 2026 00:01:46 +0200 Subject: [PATCH 27/47] Use ToolContentChunk type alias in _parse_tool_content signature --- src/mistral_common/tokens/tokenizers/instruct.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/mistral_common/tokens/tokenizers/instruct.py b/src/mistral_common/tokens/tokenizers/instruct.py index a939ffb7..9b044f6c 100644 --- a/src/mistral_common/tokens/tokenizers/instruct.py +++ b/src/mistral_common/tokens/tokenizers/instruct.py @@ -20,6 +20,7 @@ ImageURLChunk, TextChunk, ThinkChunk, + ToolContentChunk, UserContentChunk, ) from mistral_common.protocol.instruct.messages import ( @@ -487,7 +488,7 @@ def _parse_json_content(self, content: str) -> Any: except json.JSONDecodeError: return content - def _parse_tool_content(self, content: str | list[ContentChunk]) -> Any: + def _parse_tool_content(self, content: str | list[ToolContentChunk]) -> Any: if isinstance(content, list): content = "".join(chunk.text for chunk in content if isinstance(chunk, TextChunk)) return self._parse_json_content(content) From 729050d24ae03693a27eace9d4a7fa24f32a1faa Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Wed, 10 Jun 2026 00:03:03 +0200 Subject: [PATCH 28/47] Use walrus operator for reasoning_effort, fix double backticks --- tests/integrations/chat_templates/hf_utils.py | 4 ++-- tests/integrations/chat_templates/test_parity.py | 11 +++++------ 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/integrations/chat_templates/hf_utils.py b/tests/integrations/chat_templates/hf_utils.py index 9e8133d2..6d071b3d 100644 --- a/tests/integrations/chat_templates/hf_utils.py +++ b/tests/integrations/chat_templates/hf_utils.py @@ -1,7 +1,7 @@ r"""HuggingFace transformers/tokenizers utilities for chat template tests. -This module contains helpers that depend on ``transformers`` and -``tokenizers``. General-purpose helpers live in ``helpers.py``. +This module contains helpers that depend on `transformers` and +`tokenizers`. General-purpose helpers live in `helpers.py`. """ import json diff --git a/tests/integrations/chat_templates/test_parity.py b/tests/integrations/chat_templates/test_parity.py index 1bd6feed..56d79c5e 100644 --- a/tests/integrations/chat_templates/test_parity.py +++ b/tests/integrations/chat_templates/test_parity.py @@ -19,15 +19,15 @@ def _request_to_render_args(request: ChatCompletionRequest) -> dict[str, Any]: r"""Convert a ChatCompletionRequest to render_template kwargs. - Enriches tool messages with the ``name`` field resolved from the preceding - assistant's ``tool_calls``, which the v2 Jinja template requires but - ``ToolMessage.to_openai()`` omits. + Enriches tool messages with the `name` field resolved from the preceding + assistant's `tool_calls`, which the v2 Jinja template requires but + `ToolMessage.to_openai()` omits. """ openai = request.to_openai() messages = openai["messages"] # Build a mapping from tool_call_id -> function name so tool messages - # can carry the ``name`` field expected by v2 templates. + # can carry the `name` field expected by v2 templates. tool_call_names: dict[str, str] = {} for msg in messages: for tc in msg.get("tool_calls", []): @@ -47,8 +47,7 @@ def _request_to_render_args(request: ChatCompletionRequest) -> dict[str, Any]: } if "tools" in openai and openai["tools"]: kwargs["tools"] = openai["tools"] - reasoning = openai.get("reasoning_effort") - if reasoning is not None: + if (reasoning := openai.get("reasoning_effort")) is not None: kwargs["reasoning_effort"] = reasoning return kwargs From 69ab7294ed00d37215fae51bc6435090b754bc28 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Wed, 10 Jun 2026 00:07:02 +0200 Subject: [PATCH 29/47] Refactor normalizer tests into version-specific classes, add error match strings --- tests/test_normalization.py | 497 ++++++++++++++++++------------------ 1 file changed, 242 insertions(+), 255 deletions(-) diff --git a/tests/test_normalization.py b/tests/test_normalization.py index 4d066b86..bd4f5fe8 100644 --- a/tests/test_normalization.py +++ b/tests/test_normalization.py @@ -426,6 +426,71 @@ def test_continue_final_message_forwarded(self, normalizer: InstructRequestNorma result: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request) assert result.continue_final_message is True + def test_pydantic_rejects_image_in_assistant(self) -> None: + r"""Pydantic rejects ImageURLChunk in assistant message content.""" + with pytest.raises(ValidationError, match="union_tag_invalid"): + AssistantMessage( + content=[TextChunk(text="answer"), ImageURLChunk(image_url="https://example.com/img.png")] # type: ignore[list-item] + ) + + def test_pydantic_rejects_audio_in_assistant(self) -> None: + r"""Pydantic rejects AudioChunk in assistant message content.""" + with pytest.raises(ValidationError, match="union_tag_invalid"): + AssistantMessage( + content=[TextChunk(text="answer"), AudioChunk(input_audio=b"fake_audio_data")] # type: ignore[list-item] + ) + + def test_system_message_accepts_audio_chunk(self) -> None: + r"""SystemMessage Pydantic model accepts AudioChunk in content.""" + msg = SystemMessage(content=[AudioChunk(input_audio="dGVzdA==")]) + assert isinstance(msg.content, list) + assert len(msg.content) == 1 + assert isinstance(msg.content[0], AudioChunk) + + def test_system_message_rejects_image_chunk(self) -> None: + r"""SystemMessage Pydantic model rejects ImageURLChunk in content.""" + with pytest.raises(ValidationError, match="union_tag_invalid"): + SystemMessage(content=[ImageURLChunk(image_url="https://example.com/image.png")]) # type: ignore[list-item] + + def test_rejects_audio_in_system_message(self, normalizer: InstructRequestNormalizer) -> None: + r"""Pre-V7 normalizer rejects AudioChunk in system messages.""" + request = mock_chat_completion( + messages=[ + SystemMessage(content=[TextChunk(text="hello"), AudioChunk(input_audio=b"fake_audio_data")]), + UserMessage(content="query"), + AssistantMessage(content="answer"), + ] + ) + with pytest.raises(AssertionError, match="AudioChunk"): + normalizer.from_chat_completion_request(request) + + def test_rejects_think_in_system_message(self, normalizer: InstructRequestNormalizer) -> None: + r"""Pre-V7 normalizer rejects ThinkChunk in system messages.""" + request = mock_chat_completion( + messages=[ + SystemMessage(content=[TextChunk(text="hello"), ThinkChunk(thinking="thinking", closed=True)]), + UserMessage(content="query"), + AssistantMessage(content="answer"), + ] + ) + with pytest.raises(AssertionError, match="ThinkChunk"): + normalizer.from_chat_completion_request(request) + + def test_json_normalizes_tool_content(self, normalizer: InstructRequestNormalizer) -> None: + r"""Base normalizer (v1-v3) JSON-normalizes tool message content.""" + messy_json = '{"key" : "value" , "num": 1}' + request = mock_chat_completion( + messages=[ + UserMessage(content="query"), + AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")]), + ToolMessage(content=messy_json, tool_call_id="c1"), + ], + ) + parsed: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request) + tool_msg = parsed.messages[2] + assert isinstance(tool_msg, ToolMessage) + assert tool_msg.content == '{"key": "value", "num": 1}' + class TestChatCompletionRequestNormalizationV7: @pytest.fixture(autouse=True) @@ -634,6 +699,64 @@ def test_complex_assistant_aggregation(self, normalizer_v7: InstructRequestNorma TextChunk(text="C\n\nD"), ] + def test_accepts_text_and_think_chunks(self, normalizer_v7: InstructRequestNormalizerV7) -> None: + r"""V7 normalizer accepts TextChunk and ThinkChunk in assistant messages.""" + request = mock_chat_completion( + messages=[ + UserMessage(content="query"), + AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]), + ], + ) + parsed: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(request) + assistant_msg = parsed.messages[1] + assert isinstance(assistant_msg, AssistantMessage) + assert isinstance(assistant_msg.content, list) + assert len(assistant_msg.content) == 2 + + def test_accepts_string_content(self, normalizer_v7: InstructRequestNormalizerV7) -> None: + r"""V7 normalizer accepts string content in assistant messages.""" + request = mock_chat_completion( + messages=[ + UserMessage(content="query"), + AssistantMessage(content="plain text"), + ], + ) + parsed: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(request) + assistant_msg = parsed.messages[1] + assert isinstance(assistant_msg, AssistantMessage) + assert assistant_msg.content == "plain text" + + def test_skips_json_normalization_on_tool_content(self, normalizer_v7: InstructRequestNormalizerV7) -> None: + r"""V7+ normalizers do not JSON-normalize tool message content.""" + messy_json = '{"key" : "value" , "num": 1}' + request = mock_chat_completion( + messages=[ + UserMessage(content="query"), + AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")]), + ToolMessage(content=messy_json, tool_call_id="c1"), + ], + ) + parsed: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(request) + tool_msg = parsed.messages[2] + assert isinstance(tool_msg, ToolMessage) + assert tool_msg.content == messy_json + + def test_preserves_audio_in_system_message(self, normalizer_v7: InstructRequestNormalizerV7) -> None: + r"""V7 normalizer preserves AudioChunk in system messages.""" + request = mock_chat_completion( + messages=[ + SystemMessage(content=[TextChunk(text="hello"), AudioChunk(input_audio=b"fake_audio_data")]), + UserMessage(content="test"), + ] + ) + parsed: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(request) + system_msg = parsed.messages[0] + assert isinstance(system_msg, SystemMessage) + assert isinstance(system_msg.content, list) + assert len(system_msg.content) == 2 + assert isinstance(system_msg.content[0], TextChunk) + assert isinstance(system_msg.content[1], AudioChunk) + class TestFineTuningNormalizer: @pytest.fixture(autouse=True) @@ -978,6 +1101,82 @@ def test_continue_final_message_forwarded(self, normalizer_v13: InstructRequestN result: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(request) assert result.continue_final_message is True + def test_accepts_text_and_think_chunks(self, normalizer_v13: InstructRequestNormalizerV13) -> None: + r"""V13 normalizer accepts TextChunk and ThinkChunk in assistant messages.""" + request = mock_chat_completion( + messages=[ + UserMessage(content="query"), + AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]), + ], + ) + parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(request) + assistant_msg = parsed.messages[1] + assert isinstance(assistant_msg, AssistantMessage) + assert isinstance(assistant_msg.content, list) + assert len(assistant_msg.content) == 2 + + def test_accepts_string_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None: + r"""V13 normalizer accepts string content in assistant messages.""" + request = mock_chat_completion( + messages=[ + UserMessage(content="query"), + AssistantMessage(content="plain text"), + ], + ) + parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(request) + assistant_msg = parsed.messages[1] + assert isinstance(assistant_msg, AssistantMessage) + assert assistant_msg.content == "plain text" + + def test_rejects_non_text_tool_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None: + r"""V13 normalizer raises InvalidRequestException for non-text tool content.""" + request = mock_chat_completion( + messages=[ + UserMessage(content="hi"), + AssistantMessage( + tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")], + ), + ToolMessage( + content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")], + tool_call_id="test12345", + ), + ] + ) + with pytest.raises(InvalidRequestException, match="Unexpected content chunk types in tool message"): + normalizer_v13.from_chat_completion_request(request) + + def test_aggregates_text_tool_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None: + r"""V13 normalizer aggregates TextChunks in tool messages to a string.""" + request = mock_chat_completion( + messages=[ + UserMessage(content="query"), + AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")]), + ToolMessage(content=[TextChunk(text="hello"), TextChunk(text="world")], tool_call_id="c1"), + ], + ) + parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(request) + tool_msg = parsed.messages[2] + assert isinstance(tool_msg, ToolMessage) + assert isinstance(tool_msg.content, str) + assert tool_msg.content == "hello\n\nworld" + + def test_rejects_audio_in_tool_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None: + r"""V13 normalizer raises InvalidRequestException for audio tool content.""" + request = mock_chat_completion( + messages=[ + UserMessage(content="hi"), + AssistantMessage( + tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")], + ), + ToolMessage( + content=[AudioChunk(input_audio=b"fake_audio_data")], + tool_call_id="test12345", + ), + ] + ) + with pytest.raises(InvalidRequestException, match="Unexpected content chunk types in tool message"): + normalizer_v13.from_chat_completion_request(request) + class TestChatCompletionRequestNormalizationV15: @pytest.fixture(autouse=True) @@ -1097,125 +1296,37 @@ def test_v15_tool_message_text_chunks_joined_without_separator( assert isinstance(tool_msg, ToolMessage) assert tool_msg.content == "XY" - -@pytest.mark.parametrize( - "version,expected_class,model_settings_builder", - [ - (TokenizerVersion.v1, InstructRequestNormalizer, None), - (TokenizerVersion.v2, InstructRequestNormalizer, None), - (TokenizerVersion.v3, InstructRequestNormalizer, None), - (TokenizerVersion.v7, InstructRequestNormalizerV7, None), - (TokenizerVersion.v11, InstructRequestNormalizerV7, None), - (TokenizerVersion.v13, InstructRequestNormalizerV13, None), - ( - TokenizerVersion.v15, - InstructRequestNormalizerV15, - ModelSettingsBuilder( - reasoning_effort=EnumBuilder[ReasoningEffort]( - values=list(ReasoningEffort), accepts_none=False, default=None - ) - ), - ), - ], -) -def test_get_normalizer_version_mapping( - version: TokenizerVersion, expected_class: type, model_settings_builder: ModelSettingsBuilder -) -> None: - normalizer = get_normalizer(version, model_settings_builder) - assert isinstance(normalizer, expected_class) - assert normalizer._model_settings_builder == model_settings_builder - - -class TestAssistantContentNarrowing: - @pytest.fixture(params=[TokenizerVersion.v7, TokenizerVersion.v13, TokenizerVersion.v15]) - def normalizer(self, request: pytest.FixtureRequest) -> InstructRequestNormalizer: - r"""Normalizer fixture parametrized across V7, V13, V15.""" - version = request.param - if version == TokenizerVersion.v15: - return get_normalizer( - version, - model_settings_builder=ModelSettingsBuilder( - reasoning_effort=EnumBuilder[ReasoningEffort]( - values=list(ReasoningEffort), accepts_none=False, default=None - ) - ), - ) - return get_normalizer(version) - - @staticmethod - def _make_request( - normalizer: InstructRequestNormalizer, messages: list[ChatMessage] - ) -> ChatCompletionRequest[ChatMessage]: - r"""Build a ChatCompletionRequest, adding reasoning_effort for V15.""" - if isinstance(normalizer, InstructRequestNormalizerV15): - return ChatCompletionRequest(messages=messages, reasoning_effort=ReasoningEffort.high) - return mock_chat_completion(messages=messages) - - def test_accepts_text_and_think_chunks(self, normalizer: InstructRequestNormalizer) -> None: - r"""Normalizer accepts TextChunk and ThinkChunk in assistant messages.""" - request = self._make_request( - normalizer, + def test_accepts_text_and_think_chunks(self, normalizer_v15: InstructRequestNormalizerV15) -> None: + r"""V15 normalizer accepts TextChunk and ThinkChunk in assistant messages.""" + request = ChatCompletionRequest[ChatMessage]( messages=[ UserMessage(content="query"), AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]), ], + reasoning_effort=ReasoningEffort.high, ) - parsed: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request) + parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request) assistant_msg = parsed.messages[1] assert isinstance(assistant_msg, AssistantMessage) assert isinstance(assistant_msg.content, list) assert len(assistant_msg.content) == 2 - def test_accepts_string_content(self, normalizer: InstructRequestNormalizer) -> None: - r"""Normalizer accepts string content in assistant messages.""" - request = self._make_request( - normalizer, + def test_accepts_string_content(self, normalizer_v15: InstructRequestNormalizerV15) -> None: + r"""V15 normalizer accepts string content in assistant messages.""" + request = ChatCompletionRequest[ChatMessage]( messages=[ UserMessage(content="query"), AssistantMessage(content="plain text"), ], + reasoning_effort=ReasoningEffort.high, ) - parsed: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request) + parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request) assistant_msg = parsed.messages[1] assert isinstance(assistant_msg, AssistantMessage) assert assistant_msg.content == "plain text" - def test_pydantic_rejects_image_in_assistant(self) -> None: - with pytest.raises(ValidationError): - AssistantMessage( - content=[TextChunk(text="answer"), ImageURLChunk(image_url="https://example.com/img.png")] # type: ignore[list-item] - ) - - def test_pydantic_rejects_audio_in_assistant(self) -> None: - with pytest.raises(ValidationError): - AssistantMessage( - content=[TextChunk(text="answer"), AudioChunk(input_audio=b"fake_audio_data")] # type: ignore[list-item] - ) - - -class TestToolMessageContentChunk: - @pytest.fixture() - def normalizer_v15(self) -> InstructRequestNormalizerV15: - return InstructRequestNormalizerV15( - UserMessage, - AssistantMessage, - ToolMessage, - SystemMessage, - InstructRequest, - ModelSettingsBuilder( - reasoning_effort=EnumBuilder[ReasoningEffort]( - values=list(ReasoningEffort), accepts_none=False, default=None - ) - ), - ) - - @pytest.fixture() - def normalizer_v13(self) -> InstructRequestNormalizerV13: - return InstructRequestNormalizerV13( - UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest, None - ) - - def test_v15_preserves_non_text_tool_content(self, normalizer_v15: InstructRequestNormalizerV15) -> None: + def test_preserves_non_text_tool_content(self, normalizer_v15: InstructRequestNormalizerV15) -> None: + r"""V15 normalizer preserves non-text chunks in tool messages.""" image_chunk = ImageURLChunk(image_url="https://example.com/image.png") request = ChatCompletionRequest( # type: ignore[type-var] messages=[ @@ -1231,7 +1342,8 @@ def test_v15_preserves_non_text_tool_content(self, normalizer_v15: InstructReque assert isinstance(tool_msg.content, list) assert tool_msg.content == [image_chunk] - def test_v15_sorts_multimodal_tool_messages(self, normalizer_v15: InstructRequestNormalizerV15) -> None: + def test_sorts_multimodal_tool_messages(self, normalizer_v15: InstructRequestNormalizerV15) -> None: + r"""V15 normalizer sorts multimodal tool messages by tool call order.""" image_chunk_1 = ImageURLChunk(image_url="https://example.com/img1.png") image_chunk_2 = ImageURLChunk(image_url="https://example.com/img2.png") request = ChatCompletionRequest( # type: ignore[type-var] @@ -1260,114 +1372,8 @@ def test_v15_sorts_multimodal_tool_messages(self, normalizer_v15: InstructReques assert tool_msg_2.tool_call_id == "c2" assert tool_msg_2.content == [image_chunk_2] - def test_pre_v15_rejects_non_text_tool_content(self) -> None: - r"""Pre-V15 normalizer raises InvalidRequestException for non-text tool content.""" - normalizer = get_normalizer(TokenizerVersion.v13) - request = mock_chat_completion( - messages=[ - UserMessage(content="hi"), - AssistantMessage( - tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")], - ), - ToolMessage( - content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")], - tool_call_id="test12345", - ), - ] - ) - with pytest.raises(InvalidRequestException, match="Unexpected content chunk types in tool message"): - normalizer.from_chat_completion_request(request) - - def test_pre_v15_aggregates_text_tool_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None: - request = mock_chat_completion( - messages=[ - UserMessage(content="query"), - AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")]), - ToolMessage(content=[TextChunk(text="hello"), TextChunk(text="world")], tool_call_id="c1"), - ], - ) - parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(request) - tool_msg = parsed.messages[2] - assert isinstance(tool_msg, ToolMessage) - assert isinstance(tool_msg.content, str) - assert tool_msg.content == "hello\n\nworld" - - def test_pre_v15_rejects_audio_in_tool_content(self) -> None: - r"""Pre-V15 normalizer raises InvalidRequestException for audio tool content.""" - normalizer = get_normalizer(TokenizerVersion.v13) - request = mock_chat_completion( - messages=[ - UserMessage(content="hi"), - AssistantMessage( - tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")], - ), - ToolMessage( - content=[AudioChunk(input_audio=b"fake_audio_data")], - tool_call_id="test12345", - ), - ] - ) - with pytest.raises(InvalidRequestException, match="Unexpected content chunk types in tool message"): - normalizer.from_chat_completion_request(request) - - def test_base_normalizer_json_normalizes_tool_content(self) -> None: - r"""Base normalizer (v1-v3) JSON-normalizes tool message content.""" - normalizer = InstructRequestNormalizer( - UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest, None - ) - messy_json = '{"key" : "value" , "num": 1}' - request = mock_chat_completion( - messages=[ - UserMessage(content="query"), - AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")]), - ToolMessage(content=messy_json, tool_call_id="c1"), - ], - ) - parsed: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request) - tool_msg = parsed.messages[2] - assert isinstance(tool_msg, ToolMessage) - assert tool_msg.content == '{"key": "value", "num": 1}' - - def test_v7_skips_json_normalization_on_tool_content(self) -> None: - r"""V7+ normalizers do not JSON-normalize tool message content.""" - normalizer = InstructRequestNormalizerV7( - UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest, None - ) - messy_json = '{"key" : "value" , "num": 1}' - request = mock_chat_completion( - messages=[ - UserMessage(content="query"), - AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")]), - ToolMessage(content=messy_json, tool_call_id="c1"), - ], - ) - parsed: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request) - tool_msg = parsed.messages[2] - assert isinstance(tool_msg, ToolMessage) - assert tool_msg.content == messy_json - - -class TestSystemMessageContentChunk: - def test_system_message_accepts_audio_chunk(self) -> None: - msg = SystemMessage(content=[AudioChunk(input_audio="dGVzdA==")]) - assert isinstance(msg.content, list) - assert len(msg.content) == 1 - assert isinstance(msg.content[0], AudioChunk) - - def test_system_message_rejects_image_chunk(self) -> None: - with pytest.raises(ValidationError): - SystemMessage(content=[ImageURLChunk(image_url="https://example.com/image.png")]) # type: ignore[list-item] - - def test_v15_rejects_think_in_system_message(self) -> None: + def test_rejects_think_in_system_message(self, normalizer_v15: InstructRequestNormalizerV15) -> None: r"""V15 normalizer rejects ThinkChunk in system messages.""" - normalizer = get_normalizer( - TokenizerVersion.v15, - model_settings_builder=ModelSettingsBuilder( - reasoning_effort=EnumBuilder[ReasoningEffort]( - values=list(ReasoningEffort), accepts_none=False, default=None - ) - ), - ) request = mock_chat_completion( messages=[ SystemMessage(content=[TextChunk(text="hello"), ThinkChunk(thinking="thinking", closed=True)]), @@ -1375,18 +1381,18 @@ def test_v15_rejects_think_in_system_message(self) -> None: ] ) with pytest.raises(InvalidRequestException, match="ThinkChunk"): - normalizer.from_chat_completion_request(request) + normalizer_v15.from_chat_completion_request(request) - def test_v7_preserves_audio_in_system_message(self) -> None: - r"""V7 normalizer preserves AudioChunk in system messages.""" - normalizer = InstructRequestNormalizerV7.normalizer() - request = mock_chat_completion( + def test_preserves_audio_in_system_message(self, normalizer_v15: InstructRequestNormalizerV15) -> None: + r"""V15 normalizer preserves AudioChunk in system messages.""" + request = ChatCompletionRequest[ChatMessage]( messages=[ SystemMessage(content=[TextChunk(text="hello"), AudioChunk(input_audio=b"fake_audio_data")]), UserMessage(content="test"), - ] + ], + reasoning_effort=ReasoningEffort.high, ) - parsed: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request) + parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request) system_msg = parsed.messages[0] assert isinstance(system_msg, SystemMessage) assert isinstance(system_msg.content, list) @@ -1394,49 +1400,30 @@ def test_v7_preserves_audio_in_system_message(self) -> None: assert isinstance(system_msg.content[0], TextChunk) assert isinstance(system_msg.content[1], AudioChunk) - @pytest.mark.parametrize( - "chunk", - [ - pytest.param(AudioChunk(input_audio=b"fake_audio_data"), id="audio"), - pytest.param(ThinkChunk(thinking="thinking", closed=True), id="think"), - ], - ) - def test_pre_v7_rejects_non_text_in_system_message(self, chunk: AudioChunk | ThinkChunk) -> None: - r"""Pre-V7 normalizer rejects non-text chunks in system messages.""" - normalizer = InstructRequestNormalizer( - UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest, None - ) - request = mock_chat_completion( - messages=[ - SystemMessage(content=[TextChunk(text="hello"), chunk]), - UserMessage(content="query"), - AssistantMessage(content="answer"), - ] - ) - with pytest.raises(AssertionError): - normalizer.from_chat_completion_request(request) - def test_v15_preserves_audio_in_system_message(self) -> None: - r"""V15 normalizer preserves AudioChunk in system messages.""" - normalizer = get_normalizer( +@pytest.mark.parametrize( + "version,expected_class,model_settings_builder", + [ + (TokenizerVersion.v1, InstructRequestNormalizer, None), + (TokenizerVersion.v2, InstructRequestNormalizer, None), + (TokenizerVersion.v3, InstructRequestNormalizer, None), + (TokenizerVersion.v7, InstructRequestNormalizerV7, None), + (TokenizerVersion.v11, InstructRequestNormalizerV7, None), + (TokenizerVersion.v13, InstructRequestNormalizerV13, None), + ( TokenizerVersion.v15, - model_settings_builder=ModelSettingsBuilder( + InstructRequestNormalizerV15, + ModelSettingsBuilder( reasoning_effort=EnumBuilder[ReasoningEffort]( values=list(ReasoningEffort), accepts_none=False, default=None ) ), - ) - request = ChatCompletionRequest[ChatMessage]( - messages=[ - SystemMessage(content=[TextChunk(text="hello"), AudioChunk(input_audio=b"fake_audio_data")]), - UserMessage(content="test"), - ], - reasoning_effort=ReasoningEffort.high, - ) - parsed: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request) - system_msg = parsed.messages[0] - assert isinstance(system_msg, SystemMessage) - assert isinstance(system_msg.content, list) - assert len(system_msg.content) == 2 - assert isinstance(system_msg.content[0], TextChunk) - assert isinstance(system_msg.content[1], AudioChunk) + ), + ], +) +def test_get_normalizer_version_mapping( + version: TokenizerVersion, expected_class: type, model_settings_builder: ModelSettingsBuilder +) -> None: + normalizer = get_normalizer(version, model_settings_builder) + assert isinstance(normalizer, expected_class) + assert normalizer._model_settings_builder == model_settings_builder From 21860dce198598924d78674a96343d83f6dd4b8d Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Wed, 10 Jun 2026 10:10:09 +0200 Subject: [PATCH 30/47] Reject ThinkChunk in assistant messages for pre-v11 normalizers, remove Pydantic tests --- .../protocol/instruct/normalize.py | 17 ++- tests/test_normalization.py | 132 ++++++------------ 2 files changed, 55 insertions(+), 94 deletions(-) diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py index 4863d7bb..8cd5f904 100644 --- a/src/mistral_common/protocol/instruct/normalize.py +++ b/src/mistral_common/protocol/instruct/normalize.py @@ -284,20 +284,21 @@ def _normalize_tool_call(self, tool_call: ToolCall) -> ToolCall: ) def _narrow_assistant_content(self, content: list[ContentChunk] | str) -> str | list[AssistantContentChunk]: - r"""Validate and narrow content chunks for assistant messages. + r"""Validate content chunks for assistant messages. - Only TextChunk and ThinkChunk are allowed. + Pre-v11 only allows text content in assistant messages. + V11+ overrides to also accept ThinkChunk. Args: content: The aggregated content chunks. Returns: - The validated and narrowed content. + The validated content. Raises: InvalidRequestException: If unsupported chunk types are found. """ - if isinstance(content, str) or _is_assistant_content(content): + if isinstance(content, str): return content raise InvalidRequestException( f"Unexpected content chunk types in assistant message: {[type(c).__name__ for c in content]}" @@ -615,6 +616,14 @@ def _inplace_sort_tool_messages(tool_messages: list[ToolMessageType], latest_cal ), ) + def _narrow_assistant_content(self, content: list[ContentChunk] | str) -> str | list[AssistantContentChunk]: + r"""V11+ accepts TextChunk and ThinkChunk in assistant messages.""" + if isinstance(content, str) or _is_assistant_content(content): + return content + raise InvalidRequestException( + f"Unexpected content chunk types in assistant message: {[type(c).__name__ for c in content]}" + ) + def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]: tool_messages: list[ToolMessageType] = super()._aggregate_tool_messages(messages, latest_call_ids) self._inplace_sort_tool_messages(tool_messages=tool_messages, latest_call_ids=latest_call_ids) diff --git a/tests/test_normalization.py b/tests/test_normalization.py index bd4f5fe8..2e25b931 100644 --- a/tests/test_normalization.py +++ b/tests/test_normalization.py @@ -1,7 +1,6 @@ import json import pytest -from pydantic import ValidationError from mistral_common.exceptions import InvalidRequestException from mistral_common.protocol.instruct.chunk import ( @@ -165,41 +164,6 @@ def check_merge( assert isinstance(x, (UserMessage, AssistantMessage)) assert x.content == expected - def check_merge_chunks( - self, - roles: list[str], - expected_roles: list[str], - expected_content: list[list[ContentChunk] | str], - normalizer: InstructRequestNormalizer, - ) -> None: - letter_to_cls: dict[str, ChatMessage] = { - "s": SystemMessage(content="s"), - "u": UserMessage(content="u"), - "a": AssistantMessage(content="a"), - "a2": AssistantMessage( - content=[ - ThinkChunk(thinking="t1"), - ThinkChunk(thinking="t2"), - TextChunk(text="a1"), - TextChunk(text="a2"), - TextChunk(text="a3"), - ] - ), - } - - chat_completion_request = mock_chat_completion( - messages=[letter_to_cls[r] for r in roles], - ) - parsed_request = normalizer.from_chat_completion_request(chat_completion_request) - assert len(parsed_request.messages) == len(expected_roles) - assert [message.role for message in parsed_request.messages] == [ - letter_to_cls[role].role for role in expected_roles - ] - assert len(expected_content) == len(parsed_request.messages) - for x, expected in zip(parsed_request.messages, expected_content): - assert isinstance(x, (UserMessage, AssistantMessage)) - assert x.content == expected - def test_message_aggregation(self, normalizer: InstructRequestNormalizer) -> None: self.check_merge(["s", "s", "s", "u"], ["u"], ["u"], normalizer) self.check_merge(["s", "s", "s", "u", "u"], ["u"], ["u\n\nu"], normalizer) @@ -219,21 +183,6 @@ def test_message_aggregation(self, normalizer: InstructRequestNormalizer) -> Non normalizer, ) - self.check_merge_chunks( - ["u", "a2", "u"], - ["u", "a", "u"], - [ - "u", - [ - ThinkChunk(thinking="t1"), - ThinkChunk(thinking="t2"), - TextChunk(text="a1\n\na2\n\na3"), - ], - "u", - ], - normalizer, - ) - def test_tool_chunk_aggregation(self, normalizer: InstructRequestNormalizer) -> None: messages = [ ToolMessage(content="C", tool_call_id="1"), @@ -426,32 +375,6 @@ def test_continue_final_message_forwarded(self, normalizer: InstructRequestNorma result: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request) assert result.continue_final_message is True - def test_pydantic_rejects_image_in_assistant(self) -> None: - r"""Pydantic rejects ImageURLChunk in assistant message content.""" - with pytest.raises(ValidationError, match="union_tag_invalid"): - AssistantMessage( - content=[TextChunk(text="answer"), ImageURLChunk(image_url="https://example.com/img.png")] # type: ignore[list-item] - ) - - def test_pydantic_rejects_audio_in_assistant(self) -> None: - r"""Pydantic rejects AudioChunk in assistant message content.""" - with pytest.raises(ValidationError, match="union_tag_invalid"): - AssistantMessage( - content=[TextChunk(text="answer"), AudioChunk(input_audio=b"fake_audio_data")] # type: ignore[list-item] - ) - - def test_system_message_accepts_audio_chunk(self) -> None: - r"""SystemMessage Pydantic model accepts AudioChunk in content.""" - msg = SystemMessage(content=[AudioChunk(input_audio="dGVzdA==")]) - assert isinstance(msg.content, list) - assert len(msg.content) == 1 - assert isinstance(msg.content[0], AudioChunk) - - def test_system_message_rejects_image_chunk(self) -> None: - r"""SystemMessage Pydantic model rejects ImageURLChunk in content.""" - with pytest.raises(ValidationError, match="union_tag_invalid"): - SystemMessage(content=[ImageURLChunk(image_url="https://example.com/image.png")]) # type: ignore[list-item] - def test_rejects_audio_in_system_message(self, normalizer: InstructRequestNormalizer) -> None: r"""Pre-V7 normalizer rejects AudioChunk in system messages.""" request = mock_chat_completion( @@ -491,6 +414,17 @@ def test_json_normalizes_tool_content(self, normalizer: InstructRequestNormalize assert isinstance(tool_msg, ToolMessage) assert tool_msg.content == '{"key": "value", "num": 1}' + def test_rejects_think_in_assistant(self, normalizer: InstructRequestNormalizer) -> None: + r"""Pre-v11 normalizer rejects ThinkChunk in assistant messages.""" + request = mock_chat_completion( + messages=[ + UserMessage(content="query"), + AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]), + ] + ) + with pytest.raises(InvalidRequestException, match="Unexpected content chunk types in assistant message"): + normalizer.from_chat_completion_request(request) + class TestChatCompletionRequestNormalizationV7: @pytest.fixture(autouse=True) @@ -681,7 +615,6 @@ def test_complex_assistant_aggregation(self, normalizer_v7: InstructRequestNorma AssistantMessage(content=[TextChunk(text="")]), AssistantMessage( content=[ - ThinkChunk(thinking="T"), TextChunk(text="C"), TextChunk(text="D"), ] @@ -693,25 +626,18 @@ def test_complex_assistant_aggregation(self, normalizer_v7: InstructRequestNorma ) first_message = parsed_request.messages[0] assert isinstance(first_message, AssistantMessage) - assert first_message.content == [ - TextChunk(text="A\n\nB"), - ThinkChunk(thinking="T"), - TextChunk(text="C\n\nD"), - ] + assert first_message.content == "A\n\nB\n\nC\n\nD" - def test_accepts_text_and_think_chunks(self, normalizer_v7: InstructRequestNormalizerV7) -> None: - r"""V7 normalizer accepts TextChunk and ThinkChunk in assistant messages.""" + def test_rejects_think_in_assistant(self, normalizer_v7: InstructRequestNormalizerV7) -> None: + r"""V7 normalizer rejects ThinkChunk in assistant messages (pre-v11).""" request = mock_chat_completion( messages=[ UserMessage(content="query"), AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]), ], ) - parsed: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(request) - assistant_msg = parsed.messages[1] - assert isinstance(assistant_msg, AssistantMessage) - assert isinstance(assistant_msg.content, list) - assert len(assistant_msg.content) == 2 + with pytest.raises(InvalidRequestException, match="Unexpected content chunk types in assistant message"): + normalizer_v7.from_chat_completion_request(request) def test_accepts_string_content(self, normalizer_v7: InstructRequestNormalizerV7) -> None: r"""V7 normalizer accepts string content in assistant messages.""" @@ -1128,6 +1054,32 @@ def test_accepts_string_content(self, normalizer_v13: InstructRequestNormalizerV assert isinstance(assistant_msg, AssistantMessage) assert assistant_msg.content == "plain text" + def test_assistant_think_chunk_aggregation(self, normalizer_v13: InstructRequestNormalizerV13) -> None: + r"""V13 normalizer preserves ThinkChunks in assistant message aggregation.""" + chat_completion_request = mock_chat_completion( + messages=[ + AssistantMessage(content="A"), + AssistantMessage(content=[TextChunk(text="B")]), + AssistantMessage( + content=[ + ThinkChunk(thinking="T"), + TextChunk(text="C"), + TextChunk(text="D"), + ] + ), + ] + ) + parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request( + chat_completion_request + ) + first_message = parsed.messages[0] + assert isinstance(first_message, AssistantMessage) + assert first_message.content == [ + TextChunk(text="A\n\nB"), + ThinkChunk(thinking="T"), + TextChunk(text="C\n\nD"), + ] + def test_rejects_non_text_tool_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None: r"""V13 normalizer raises InvalidRequestException for non-text tool content.""" request = mock_chat_completion( From 53057e257a73dc24779a8ee1e4bdb512e9b15005 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Wed, 10 Jun 2026 10:31:33 +0200 Subject: [PATCH 31/47] Refactor normalizer tests to use Pydantic model equality assertions --- tests/test_normalization.py | 198 ++++++++++++------------------------ 1 file changed, 64 insertions(+), 134 deletions(-) diff --git a/tests/test_normalization.py b/tests/test_normalization.py index 2e25b931..27a7bdf5 100644 --- a/tests/test_normalization.py +++ b/tests/test_normalization.py @@ -100,9 +100,7 @@ def test_system_assistant_user(self, normalizer: InstructRequestNormalizer) -> N parsed_request = normalizer.from_chat_completion_request(chat_completion_request) - first_message = parsed_request.messages[0] - assert isinstance(first_message, UserMessage) - assert first_message.content == "" + assert parsed_request.messages[0] == UserMessage(content="") assert parsed_request.system_prompt == "S" def test_assistant_content_with_tool_calls(self, normalizer: InstructRequestNormalizer) -> None: @@ -132,9 +130,7 @@ def test_assistant_system_user_adds_user(self, normalizer: InstructRequestNormal assert len(parsed_request.messages) == 3 # 1 user message added, system message removed - first_message = parsed_request.messages[0] - assert isinstance(first_message, UserMessage) - assert first_message.content == "" + assert parsed_request.messages[0] == UserMessage(content="") assert parsed_request.system_prompt == "S" def check_merge( @@ -219,9 +215,7 @@ def test_normalize_chunks(self, normalizer: InstructRequestNormalizer) -> None: parsed_request = normalizer.from_chat_completion_request(chat_completion_request) - first_message = parsed_request.messages[0] - assert isinstance(first_message, UserMessage) - assert first_message.content == "foo\n\nchunk\n\nfoo\n\nchunk" + assert parsed_request.messages[0] == UserMessage(content="foo\n\nchunk\n\nfoo\n\nchunk") def test_many_chunks_in_user_message(self, normalizer: InstructRequestNormalizer) -> None: chat_completion_request = mock_chat_completion( @@ -237,9 +231,7 @@ def test_many_chunks_in_user_message(self, normalizer: InstructRequestNormalizer ], ) parsed_request = normalizer.from_chat_completion_request(chat_completion_request) - first_message = parsed_request.messages[0] - assert isinstance(first_message, UserMessage) - assert first_message.content == "foo\n\nchunk1\n\nchunk2\n\nchunk3" + assert parsed_request.messages[0] == UserMessage(content="foo\n\nchunk1\n\nchunk2\n\nchunk3") def test_ignore_middle_empty_text_chunks(self, normalizer: InstructRequestNormalizer) -> None: chat_completion_request = mock_chat_completion( @@ -261,12 +253,8 @@ def test_ignore_middle_empty_text_chunks(self, normalizer: InstructRequestNormal ] ) parsed_request = normalizer.from_chat_completion_request(chat_completion_request) - first_message = parsed_request.messages[0] - assert isinstance(first_message, UserMessage) - assert first_message.content == "U\n\nV" - second_message = parsed_request.messages[1] - assert isinstance(second_message, AssistantMessage) - assert second_message.content == "A\n\nB" + assert parsed_request.messages[0] == UserMessage(content="U\n\nV") + assert parsed_request.messages[1] == AssistantMessage(content="A\n\nB") def test_safety_prompt_aggregated(self, normalizer: InstructRequestNormalizer) -> None: chat_completion_request = ChatCompletionRequest[ChatMessage]( @@ -280,9 +268,7 @@ def test_safety_prompt_aggregated(self, normalizer: InstructRequestNormalizer) - ) parsed_request = normalizer.from_chat_completion_request(chat_completion_request) - first_message = parsed_request.messages[0] - assert isinstance(first_message, UserMessage) - assert first_message.content == "user" + assert parsed_request.messages[0] == UserMessage(content="user") assert parsed_request.system_prompt == "system" def test_normalize_tools(self, normalizer: InstructRequestNormalizer) -> None: @@ -410,9 +396,7 @@ def test_json_normalizes_tool_content(self, normalizer: InstructRequestNormalize ], ) parsed: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request) - tool_msg = parsed.messages[2] - assert isinstance(tool_msg, ToolMessage) - assert tool_msg.content == '{"key": "value", "num": 1}' + assert parsed.messages[2] == ToolMessage(content='{"key": "value", "num": 1}', tool_call_id="c1") def test_rejects_think_in_assistant(self, normalizer: InstructRequestNormalizer) -> None: r"""Pre-v11 normalizer rejects ThinkChunk in assistant messages.""" @@ -444,14 +428,8 @@ def test_system_assistant_user_v7(self, normalizer_v7: InstructRequestNormalizer parsed_request: InstructRequest = normalizer_v7.from_chat_completion_request(chat_completion_request) - first_message = parsed_request.messages[0] - assert isinstance(first_message, SystemMessage) - assert first_message.content == "S" - - second_message = parsed_request.messages[1] - assert isinstance(second_message, AssistantMessage) - assert second_message.content == "A" - + assert parsed_request.messages[0] == SystemMessage(content="S") + assert parsed_request.messages[1] == AssistantMessage(content="A") assert parsed_request.system_prompt is None def test_assistant_assistant_system_v7(self, normalizer_v7: InstructRequestNormalizer) -> None: @@ -468,13 +446,8 @@ def test_assistant_assistant_system_v7(self, normalizer_v7: InstructRequestNorma assert len(parsed_request.messages) == 2 - first_message = parsed_request.messages[0] - assert isinstance(first_message, AssistantMessage) - assert first_message.content == "A" - - second_message = parsed_request.messages[1] - assert isinstance(second_message, SystemMessage) - assert second_message.content == "S" + assert parsed_request.messages[0] == AssistantMessage(content="A") + assert parsed_request.messages[1] == SystemMessage(content="S") def test_assistant_content_with_tool_calls(self, normalizer_v7: InstructRequestNormalizer) -> None: chat_completion_request = mock_chat_completion( @@ -487,10 +460,9 @@ def test_assistant_content_with_tool_calls(self, normalizer_v7: InstructRequestN ) normalized_chat_req = normalizer_v7.from_chat_completion_request(chat_completion_request) - assert normalized_chat_req.messages[0].content == "A", normalized_chat_req.messages[0].content - assert len(normalized_chat_req.messages[0].tool_calls) == 1, normalized_chat_req.messages[0].tool_calls - assert normalized_chat_req.messages[0].tool_calls[0].function.name == "tool1", ( - normalized_chat_req.messages[0].tool_calls[0].function.name + assert normalized_chat_req.messages[0] == AssistantMessage( + content="A", + tool_calls=[ToolCall(function=FunctionCall(name="tool1", arguments='{"input": "11"}'))], ) def test_assistant_content_with_more_tool_calls(self, normalizer_v7: InstructRequestNormalizer) -> None: @@ -563,13 +535,8 @@ def test_only_empty_text_chunks(self, normalizer_v7: InstructRequestNormalizerV7 parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request( chat_completion_request ) - first_message = parsed_request.messages[0] - assert isinstance(first_message, UserMessage) - assert first_message.content == "" - second_message = parsed_request.messages[1] - assert isinstance(second_message, AssistantMessage) - # Empty string content is passed through directly - assert second_message.content == "" + assert parsed_request.messages[0] == UserMessage(content="") + assert parsed_request.messages[1] == AssistantMessage(content="") def test_complex_user_aggregation(self, normalizer_v7: InstructRequestNormalizerV7) -> None: """Complex multi-user-message aggregation with mixed str, chunks, empty, and non-text chunks.""" @@ -596,13 +563,13 @@ def test_complex_user_aggregation(self, normalizer_v7: InstructRequestNormalizer parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request( chat_completion_request ) - first_message = parsed_request.messages[0] - assert isinstance(first_message, UserMessage) - assert first_message.content == [ - TextChunk(text="A\n\nB\n\nC\n\nD"), - ImageURLChunk(image_url="E"), - TextChunk(text="G\n\nH"), - ] + assert parsed_request.messages[0] == UserMessage( + content=[ + TextChunk(text="A\n\nB\n\nC\n\nD"), + ImageURLChunk(image_url="E"), + TextChunk(text="G\n\nH"), + ] + ) def test_complex_assistant_aggregation(self, normalizer_v7: InstructRequestNormalizerV7) -> None: """Complex multi-assistant-message aggregation with mixed str, chunks, and empty content.""" @@ -624,9 +591,7 @@ def test_complex_assistant_aggregation(self, normalizer_v7: InstructRequestNorma parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request( chat_completion_request ) - first_message = parsed_request.messages[0] - assert isinstance(first_message, AssistantMessage) - assert first_message.content == "A\n\nB\n\nC\n\nD" + assert parsed_request.messages[0] == AssistantMessage(content="A\n\nB\n\nC\n\nD") def test_rejects_think_in_assistant(self, normalizer_v7: InstructRequestNormalizerV7) -> None: r"""V7 normalizer rejects ThinkChunk in assistant messages (pre-v11).""" @@ -648,9 +613,7 @@ def test_accepts_string_content(self, normalizer_v7: InstructRequestNormalizerV7 ], ) parsed: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(request) - assistant_msg = parsed.messages[1] - assert isinstance(assistant_msg, AssistantMessage) - assert assistant_msg.content == "plain text" + assert parsed.messages[1] == AssistantMessage(content="plain text") def test_skips_json_normalization_on_tool_content(self, normalizer_v7: InstructRequestNormalizerV7) -> None: r"""V7+ normalizers do not JSON-normalize tool message content.""" @@ -663,9 +626,7 @@ def test_skips_json_normalization_on_tool_content(self, normalizer_v7: InstructR ], ) parsed: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(request) - tool_msg = parsed.messages[2] - assert isinstance(tool_msg, ToolMessage) - assert tool_msg.content == messy_json + assert parsed.messages[2] == ToolMessage(content=messy_json, tool_call_id="c1") def test_preserves_audio_in_system_message(self, normalizer_v7: InstructRequestNormalizerV7) -> None: r"""V7 normalizer preserves AudioChunk in system messages.""" @@ -676,12 +637,12 @@ def test_preserves_audio_in_system_message(self, normalizer_v7: InstructRequestN ] ) parsed: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(request) - system_msg = parsed.messages[0] - assert isinstance(system_msg, SystemMessage) - assert isinstance(system_msg.content, list) - assert len(system_msg.content) == 2 - assert isinstance(system_msg.content[0], TextChunk) - assert isinstance(system_msg.content[1], AudioChunk) + assert parsed.messages[0] == SystemMessage( + content=[ + TextChunk(text="hello"), + AudioChunk(input_audio=b"fake_audio_data"), + ] + ) class TestFineTuningNormalizer: @@ -1036,10 +997,9 @@ def test_accepts_text_and_think_chunks(self, normalizer_v13: InstructRequestNorm ], ) parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(request) - assistant_msg = parsed.messages[1] - assert isinstance(assistant_msg, AssistantMessage) - assert isinstance(assistant_msg.content, list) - assert len(assistant_msg.content) == 2 + assert parsed.messages[1] == AssistantMessage( + content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")] + ) def test_accepts_string_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None: r"""V13 normalizer accepts string content in assistant messages.""" @@ -1050,9 +1010,7 @@ def test_accepts_string_content(self, normalizer_v13: InstructRequestNormalizerV ], ) parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(request) - assistant_msg = parsed.messages[1] - assert isinstance(assistant_msg, AssistantMessage) - assert assistant_msg.content == "plain text" + assert parsed.messages[1] == AssistantMessage(content="plain text") def test_assistant_think_chunk_aggregation(self, normalizer_v13: InstructRequestNormalizerV13) -> None: r"""V13 normalizer preserves ThinkChunks in assistant message aggregation.""" @@ -1072,13 +1030,13 @@ def test_assistant_think_chunk_aggregation(self, normalizer_v13: InstructRequest parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request( chat_completion_request ) - first_message = parsed.messages[0] - assert isinstance(first_message, AssistantMessage) - assert first_message.content == [ - TextChunk(text="A\n\nB"), - ThinkChunk(thinking="T"), - TextChunk(text="C\n\nD"), - ] + assert parsed.messages[0] == AssistantMessage( + content=[ + TextChunk(text="A\n\nB"), + ThinkChunk(thinking="T"), + TextChunk(text="C\n\nD"), + ] + ) def test_rejects_non_text_tool_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None: r"""V13 normalizer raises InvalidRequestException for non-text tool content.""" @@ -1107,10 +1065,7 @@ def test_aggregates_text_tool_content(self, normalizer_v13: InstructRequestNorma ], ) parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(request) - tool_msg = parsed.messages[2] - assert isinstance(tool_msg, ToolMessage) - assert isinstance(tool_msg.content, str) - assert tool_msg.content == "hello\n\nworld" + assert parsed.messages[2] == ToolMessage(content="hello\n\nworld", tool_call_id="c1") def test_rejects_audio_in_tool_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None: r"""V13 normalizer raises InvalidRequestException for audio tool content.""" @@ -1179,12 +1134,8 @@ def test_v15_intra_message_chunks_joined_without_separator( reasoning_effort=ReasoningEffort.high, ) parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request) - user_msg = parsed.messages[0] - assert isinstance(user_msg, UserMessage) - assert user_msg.content == "AB" - assistant_msg = parsed.messages[1] - assert isinstance(assistant_msg, AssistantMessage) - assert assistant_msg.content == "CD" + assert parsed.messages[0] == UserMessage(content="AB") + assert parsed.messages[1] == AssistantMessage(content="CD") def test_v15_inter_message_join_still_uses_separator(self, normalizer_v15: InstructRequestNormalizerV15) -> None: r"""V15 still joins text across different messages with '\n\n'.""" @@ -1197,9 +1148,7 @@ def test_v15_inter_message_join_still_uses_separator(self, normalizer_v15: Instr reasoning_effort=ReasoningEffort.high, ) parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request) - user_msg = parsed.messages[0] - assert isinstance(user_msg, UserMessage) - assert user_msg.content == "First\n\nSecond" + assert parsed.messages[0] == UserMessage(content="First\n\nSecond") def test_v15_mixed_intra_and_inter_message(self, normalizer_v15: InstructRequestNormalizerV15) -> None: r"""V15 combines intra-message ('') and inter-message ('\n\n') joining.""" @@ -1212,9 +1161,7 @@ def test_v15_mixed_intra_and_inter_message(self, normalizer_v15: InstructRequest reasoning_effort=ReasoningEffort.high, ) parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request) - user_msg = parsed.messages[0] - assert isinstance(user_msg, UserMessage) - assert user_msg.content == "AB\n\nCD" + assert parsed.messages[0] == UserMessage(content="AB\n\nCD") def test_v15_mixed_intra_and_inter_assistant_messages(self, normalizer_v15: InstructRequestNormalizerV15) -> None: r"""V15 combines intra-message ('') and inter-message ('\n\n') joining for assistant messages.""" @@ -1227,9 +1174,7 @@ def test_v15_mixed_intra_and_inter_assistant_messages(self, normalizer_v15: Inst reasoning_effort=ReasoningEffort.high, ) parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request) - assistant_msg = parsed.messages[1] - assert isinstance(assistant_msg, AssistantMessage) - assert assistant_msg.content == "AB\n\nCD" + assert parsed.messages[1] == AssistantMessage(content="AB\n\nCD") def test_v15_tool_message_text_chunks_joined_without_separator( self, normalizer_v15: InstructRequestNormalizerV15 @@ -1244,9 +1189,7 @@ def test_v15_tool_message_text_chunks_joined_without_separator( reasoning_effort=ReasoningEffort.high, ) parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request) - tool_msg = parsed.messages[2] - assert isinstance(tool_msg, ToolMessage) - assert tool_msg.content == "XY" + assert parsed.messages[2] == ToolMessage(content="XY", tool_call_id="c1") def test_accepts_text_and_think_chunks(self, normalizer_v15: InstructRequestNormalizerV15) -> None: r"""V15 normalizer accepts TextChunk and ThinkChunk in assistant messages.""" @@ -1258,10 +1201,9 @@ def test_accepts_text_and_think_chunks(self, normalizer_v15: InstructRequestNorm reasoning_effort=ReasoningEffort.high, ) parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request) - assistant_msg = parsed.messages[1] - assert isinstance(assistant_msg, AssistantMessage) - assert isinstance(assistant_msg.content, list) - assert len(assistant_msg.content) == 2 + assert parsed.messages[1] == AssistantMessage( + content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")] + ) def test_accepts_string_content(self, normalizer_v15: InstructRequestNormalizerV15) -> None: r"""V15 normalizer accepts string content in assistant messages.""" @@ -1273,9 +1215,7 @@ def test_accepts_string_content(self, normalizer_v15: InstructRequestNormalizerV reasoning_effort=ReasoningEffort.high, ) parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request) - assistant_msg = parsed.messages[1] - assert isinstance(assistant_msg, AssistantMessage) - assert assistant_msg.content == "plain text" + assert parsed.messages[1] == AssistantMessage(content="plain text") def test_preserves_non_text_tool_content(self, normalizer_v15: InstructRequestNormalizerV15) -> None: r"""V15 normalizer preserves non-text chunks in tool messages.""" @@ -1289,10 +1229,7 @@ def test_preserves_non_text_tool_content(self, normalizer_v15: InstructRequestNo reasoning_effort=ReasoningEffort.high, ) parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request) - tool_msg = parsed.messages[2] - assert isinstance(tool_msg, ToolMessage) - assert isinstance(tool_msg.content, list) - assert tool_msg.content == [image_chunk] + assert parsed.messages[2] == ToolMessage(content=[image_chunk], tool_call_id="c1") def test_sorts_multimodal_tool_messages(self, normalizer_v15: InstructRequestNormalizerV15) -> None: r"""V15 normalizer sorts multimodal tool messages by tool call order.""" @@ -1314,15 +1251,8 @@ def test_sorts_multimodal_tool_messages(self, normalizer_v15: InstructRequestNor ) parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request) - tool_msg_1 = parsed.messages[2] - assert isinstance(tool_msg_1, ToolMessage) - assert tool_msg_1.tool_call_id == "c1" - assert tool_msg_1.content == [image_chunk_1] - - tool_msg_2 = parsed.messages[3] - assert isinstance(tool_msg_2, ToolMessage) - assert tool_msg_2.tool_call_id == "c2" - assert tool_msg_2.content == [image_chunk_2] + assert parsed.messages[2] == ToolMessage(content=[image_chunk_1], tool_call_id="c1") + assert parsed.messages[3] == ToolMessage(content=[image_chunk_2], tool_call_id="c2") def test_rejects_think_in_system_message(self, normalizer_v15: InstructRequestNormalizerV15) -> None: r"""V15 normalizer rejects ThinkChunk in system messages.""" @@ -1345,12 +1275,12 @@ def test_preserves_audio_in_system_message(self, normalizer_v15: InstructRequest reasoning_effort=ReasoningEffort.high, ) parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request) - system_msg = parsed.messages[0] - assert isinstance(system_msg, SystemMessage) - assert isinstance(system_msg.content, list) - assert len(system_msg.content) == 2 - assert isinstance(system_msg.content[0], TextChunk) - assert isinstance(system_msg.content[1], AudioChunk) + assert parsed.messages[0] == SystemMessage( + content=[ + TextChunk(text="hello"), + AudioChunk(input_audio=b"fake_audio_data"), + ] + ) @pytest.mark.parametrize( From 66af65b3be0f7e23082a84f11c6d7194ba807c0a Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Wed, 10 Jun 2026 11:21:43 +0200 Subject: [PATCH 32/47] Add back intra-message ThinkChunk aggregation test to V13 class --- tests/test_normalization.py | 36 ++++++++++++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/tests/test_normalization.py b/tests/test_normalization.py index 27a7bdf5..fe7041b7 100644 --- a/tests/test_normalization.py +++ b/tests/test_normalization.py @@ -1012,8 +1012,10 @@ def test_accepts_string_content(self, normalizer_v13: InstructRequestNormalizerV parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(request) assert parsed.messages[1] == AssistantMessage(content="plain text") - def test_assistant_think_chunk_aggregation(self, normalizer_v13: InstructRequestNormalizerV13) -> None: - r"""V13 normalizer preserves ThinkChunks in assistant message aggregation.""" + def test_assistant_think_chunk_inter_message_aggregation( + self, normalizer_v13: InstructRequestNormalizerV13 + ) -> None: + r"""V13 normalizer preserves ThinkChunks across multiple assistant messages.""" chat_completion_request = mock_chat_completion( messages=[ AssistantMessage(content="A"), @@ -1038,6 +1040,36 @@ def test_assistant_think_chunk_aggregation(self, normalizer_v13: InstructRequest ] ) + def test_assistant_think_chunk_intra_message_aggregation( + self, normalizer_v13: InstructRequestNormalizerV13 + ) -> None: + r"""V13 normalizer coalesces TextChunks and preserves multiple ThinkChunks within a single message.""" + chat_completion_request = mock_chat_completion( + messages=[ + UserMessage(content="u"), + AssistantMessage( + content=[ + ThinkChunk(thinking="t1"), + ThinkChunk(thinking="t2"), + TextChunk(text="a1"), + TextChunk(text="a2"), + TextChunk(text="a3"), + ] + ), + UserMessage(content="u"), + ] + ) + parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request( + chat_completion_request + ) + assert parsed.messages[1] == AssistantMessage( + content=[ + ThinkChunk(thinking="t1"), + ThinkChunk(thinking="t2"), + TextChunk(text="a1\n\na2\n\na3"), + ] + ) + def test_rejects_non_text_tool_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None: r"""V13 normalizer raises InvalidRequestException for non-text tool content.""" request = mock_chat_completion( From 864d127ea2828d58d5bd099f782684232309a359 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Wed, 10 Jun 2026 11:35:22 +0200 Subject: [PATCH 33/47] Assert full InstructRequest output in normalizer tests --- tests/test_normalization.py | 680 +++++++++++++++++++++++------------- 1 file changed, 432 insertions(+), 248 deletions(-) diff --git a/tests/test_normalization.py b/tests/test_normalization.py index fe7041b7..26f7d8be 100644 --- a/tests/test_normalization.py +++ b/tests/test_normalization.py @@ -6,7 +6,6 @@ from mistral_common.protocol.instruct.chunk import ( AudioChunk, ChunkTypes, - ContentChunk, ImageURLChunk, TextChunk, ThinkChunk, @@ -64,7 +63,10 @@ def test_user_system_user(self, normalizer: InstructRequestNormalizer) -> None: ) parsed_request = normalizer.from_chat_completion_request(chat_completion_request) - assert parsed_request.system_prompt == "S" + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="U"), UserMessage(content="U")], + system_prompt="S", + ) def test_multiple_system(self, normalizer: InstructRequestNormalizer) -> None: chat_completion_request = mock_chat_completion( @@ -76,7 +78,10 @@ def test_multiple_system(self, normalizer: InstructRequestNormalizer) -> None: ) parsed_request = normalizer.from_chat_completion_request(chat_completion_request) - assert parsed_request.system_prompt == "S\n\nS\n\nS" + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="")], + system_prompt="S\n\nS\n\nS", + ) def test_single_system(self, normalizer: InstructRequestNormalizer) -> None: chat_completion_request = mock_chat_completion( @@ -86,8 +91,10 @@ def test_single_system(self, normalizer: InstructRequestNormalizer) -> None: ) parsed_request = normalizer.from_chat_completion_request(chat_completion_request) - - assert parsed_request.system_prompt == "S" + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="")], + system_prompt="S", + ) def test_system_assistant_user(self, normalizer: InstructRequestNormalizer) -> None: chat_completion_request = mock_chat_completion( @@ -99,9 +106,10 @@ def test_system_assistant_user(self, normalizer: InstructRequestNormalizer) -> N ) parsed_request = normalizer.from_chat_completion_request(chat_completion_request) - - assert parsed_request.messages[0] == UserMessage(content="") - assert parsed_request.system_prompt == "S" + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content=""), AssistantMessage(content="A"), UserMessage(content="U")], + system_prompt="S", + ) def test_assistant_content_with_tool_calls(self, normalizer: InstructRequestNormalizer) -> None: chat_completion_request = mock_chat_completion( @@ -125,58 +133,89 @@ def test_assistant_system_user_adds_user(self, normalizer: InstructRequestNormal ) parsed_request = normalizer.from_chat_completion_request(chat_completion_request) - - assert parsed_request.system_prompt == "S" - - assert len(parsed_request.messages) == 3 # 1 user message added, system message removed - - assert parsed_request.messages[0] == UserMessage(content="") - assert parsed_request.system_prompt == "S" - - def check_merge( - self, - roles: list[str], - expected_roles: list[str], - expected_content: list[list[ContentChunk] | str], - normalizer: InstructRequestNormalizer, - ) -> None: - letter_to_cls: dict[str, ChatMessage] = { - "s": SystemMessage(content="s"), - "u": UserMessage(content="u"), - "a": AssistantMessage(content="a"), - "u2": UserMessage(content="u2"), - } - - chat_completion_request = mock_chat_completion( - messages=[letter_to_cls[r] for r in roles], + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content=""), AssistantMessage(content="A"), UserMessage(content="U")], + system_prompt="S", + ) + + def test_message_aggregation_system_then_user(self, normalizer: InstructRequestNormalizer) -> None: + parsed = normalizer.from_chat_completion_request( + mock_chat_completion( + [ + SystemMessage(content="s"), + SystemMessage(content="s"), + SystemMessage(content="s"), + UserMessage(content="u"), + ] + ) + ) + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="u")], system_prompt="s\n\ns\n\ns" ) - parsed_request = normalizer.from_chat_completion_request(chat_completion_request) - assert len(parsed_request.messages) == len(expected_roles) - assert [message.role for message in parsed_request.messages] == [ - letter_to_cls[role].role for role in expected_roles - ] - assert len(expected_content) == len(parsed_request.messages) - for x, expected in zip(parsed_request.messages, expected_content): - assert isinstance(x, (UserMessage, AssistantMessage)) - assert x.content == expected - - def test_message_aggregation(self, normalizer: InstructRequestNormalizer) -> None: - self.check_merge(["s", "s", "s", "u"], ["u"], ["u"], normalizer) - self.check_merge(["s", "s", "s", "u", "u"], ["u"], ["u\n\nu"], normalizer) - self.check_merge(["s", "s", "s", "u", "u", "s", "a", "u"], ["u", "a", "u"], ["u\n\nu", "a", "u"], normalizer) - self.check_merge( - ["s", "s", "s", "u", "u", "a", "a", "u"], - ["u", "a", "u"], - ["u\n\nu", "a\n\na", "u"], - normalizer, + def test_message_aggregation_system_then_users(self, normalizer: InstructRequestNormalizer) -> None: + parsed = normalizer.from_chat_completion_request( + mock_chat_completion( + [ + SystemMessage(content="s"), + SystemMessage(content="s"), + SystemMessage(content="s"), + UserMessage(content="u"), + UserMessage(content="u"), + ] + ) + ) + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="u\n\nu")], system_prompt="s\n\ns\n\ns" + ) + + def test_message_aggregation_mixed_with_middle_system(self, normalizer: InstructRequestNormalizer) -> None: + parsed = normalizer.from_chat_completion_request( + mock_chat_completion( + [ + SystemMessage(content="s"), + SystemMessage(content="s"), + SystemMessage(content="s"), + UserMessage(content="u"), + UserMessage(content="u"), + SystemMessage(content="s"), + AssistantMessage(content="a"), + UserMessage(content="u"), + ] + ) + ) + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="u\n\nu"), AssistantMessage(content="a"), UserMessage(content="u")], + system_prompt="s\n\ns\n\ns\n\ns", + ) + + def test_message_aggregation_consecutive_assistants(self, normalizer: InstructRequestNormalizer) -> None: + parsed = normalizer.from_chat_completion_request( + mock_chat_completion( + [ + SystemMessage(content="s"), + SystemMessage(content="s"), + SystemMessage(content="s"), + UserMessage(content="u"), + UserMessage(content="u"), + AssistantMessage(content="a"), + AssistantMessage(content="a"), + UserMessage(content="u"), + ] + ) + ) + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="u\n\nu"), AssistantMessage(content="a\n\na"), UserMessage(content="u")], + system_prompt="s\n\ns\n\ns", ) - self.check_merge( - ["s", "a", "u"], - ["u", "a", "u"], - ["", "a", "u"], - normalizer, + def test_message_aggregation_system_assistant_user(self, normalizer: InstructRequestNormalizer) -> None: + parsed = normalizer.from_chat_completion_request( + mock_chat_completion([SystemMessage(content="s"), AssistantMessage(content="a"), UserMessage(content="u")]) + ) + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content=""), AssistantMessage(content="a"), UserMessage(content="u")], + system_prompt="s", ) def test_tool_chunk_aggregation(self, normalizer: InstructRequestNormalizer) -> None: @@ -214,8 +253,9 @@ def test_normalize_chunks(self, normalizer: InstructRequestNormalizer) -> None: ) parsed_request = normalizer.from_chat_completion_request(chat_completion_request) - - assert parsed_request.messages[0] == UserMessage(content="foo\n\nchunk\n\nfoo\n\nchunk") + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="foo\n\nchunk\n\nfoo\n\nchunk")] + ) def test_many_chunks_in_user_message(self, normalizer: InstructRequestNormalizer) -> None: chat_completion_request = mock_chat_completion( @@ -231,7 +271,9 @@ def test_many_chunks_in_user_message(self, normalizer: InstructRequestNormalizer ], ) parsed_request = normalizer.from_chat_completion_request(chat_completion_request) - assert parsed_request.messages[0] == UserMessage(content="foo\n\nchunk1\n\nchunk2\n\nchunk3") + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="foo\n\nchunk1\n\nchunk2\n\nchunk3")] + ) def test_ignore_middle_empty_text_chunks(self, normalizer: InstructRequestNormalizer) -> None: chat_completion_request = mock_chat_completion( @@ -253,8 +295,9 @@ def test_ignore_middle_empty_text_chunks(self, normalizer: InstructRequestNormal ] ) parsed_request = normalizer.from_chat_completion_request(chat_completion_request) - assert parsed_request.messages[0] == UserMessage(content="U\n\nV") - assert parsed_request.messages[1] == AssistantMessage(content="A\n\nB") + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="U\n\nV"), AssistantMessage(content="A\n\nB")], + ) def test_safety_prompt_aggregated(self, normalizer: InstructRequestNormalizer) -> None: chat_completion_request = ChatCompletionRequest[ChatMessage]( @@ -268,8 +311,10 @@ def test_safety_prompt_aggregated(self, normalizer: InstructRequestNormalizer) - ) parsed_request = normalizer.from_chat_completion_request(chat_completion_request) - assert parsed_request.messages[0] == UserMessage(content="user") - assert parsed_request.system_prompt == "system" + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="user")], + system_prompt="system", + ) def test_normalize_tools(self, normalizer: InstructRequestNormalizer) -> None: """ @@ -348,18 +393,19 @@ def test_assert_parsed_settings( normalizer: InstructRequestNormalizer, ) -> None: chat_completion_request = ChatCompletionRequest(messages=[UserMessage(content="B")]) - parsed_request: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request( - chat_completion_request - ) - assert parsed_request.settings == ModelSettings.none() + parsed_request = normalizer.from_chat_completion_request(chat_completion_request) + assert parsed_request == InstructRequest[ChatMessage, Tool](messages=[UserMessage(content="B")]) def test_continue_final_message_forwarded(self, normalizer: InstructRequestNormalizer) -> None: request = ChatCompletionRequest[ChatMessage]( messages=[UserMessage(content="a"), AssistantMessage(content="b")], continue_final_message=True, ) - result: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request) - assert result.continue_final_message is True + result = normalizer.from_chat_completion_request(request) + assert result == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="a"), AssistantMessage(content="b")], + continue_final_message=True, + ) def test_rejects_audio_in_system_message(self, normalizer: InstructRequestNormalizer) -> None: r"""Pre-V7 normalizer rejects AudioChunk in system messages.""" @@ -395,8 +441,17 @@ def test_json_normalizes_tool_content(self, normalizer: InstructRequestNormalize ToolMessage(content=messy_json, tool_call_id="c1"), ], ) - parsed: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request) - assert parsed.messages[2] == ToolMessage(content='{"key": "value", "num": 1}', tool_call_id="c1") + parsed = normalizer.from_chat_completion_request(request) + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[ + UserMessage(content="query"), + AssistantMessage( + content="", + tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")], + ), + ToolMessage(content='{"key": "value", "num": 1}', tool_call_id="c1"), + ], + ) def test_rejects_think_in_assistant(self, normalizer: InstructRequestNormalizer) -> None: r"""Pre-v11 normalizer rejects ThinkChunk in assistant messages.""" @@ -426,13 +481,14 @@ def test_system_assistant_user_v7(self, normalizer_v7: InstructRequestNormalizer ] ) - parsed_request: InstructRequest = normalizer_v7.from_chat_completion_request(chat_completion_request) - - assert parsed_request.messages[0] == SystemMessage(content="S") - assert parsed_request.messages[1] == AssistantMessage(content="A") - assert parsed_request.system_prompt is None + parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request( + chat_completion_request + ) + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[SystemMessage(content="S"), AssistantMessage(content="A"), UserMessage(content="U")], + ) - def test_assistant_assistant_system_v7(self, normalizer_v7: InstructRequestNormalizer) -> None: + def test_assistant_assistant_system_v7(self, normalizer_v7: InstructRequestNormalizerV7) -> None: chat_completion_request = mock_chat_completion( messages=[ AssistantMessage(content="A"), @@ -440,16 +496,14 @@ def test_assistant_assistant_system_v7(self, normalizer_v7: InstructRequestNorma ] ) - parsed_request = normalizer_v7.from_chat_completion_request(chat_completion_request) - - assert parsed_request.system_prompt is None - - assert len(parsed_request.messages) == 2 - - assert parsed_request.messages[0] == AssistantMessage(content="A") - assert parsed_request.messages[1] == SystemMessage(content="S") + parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request( + chat_completion_request + ) + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[AssistantMessage(content="A"), SystemMessage(content="S")], + ) - def test_assistant_content_with_tool_calls(self, normalizer_v7: InstructRequestNormalizer) -> None: + def test_assistant_content_with_tool_calls(self, normalizer_v7: InstructRequestNormalizerV7) -> None: chat_completion_request = mock_chat_completion( messages=[ AssistantMessage( @@ -458,27 +512,28 @@ def test_assistant_content_with_tool_calls(self, normalizer_v7: InstructRequestN ) ] ) - normalized_chat_req = normalizer_v7.from_chat_completion_request(chat_completion_request) - - assert normalized_chat_req.messages[0] == AssistantMessage( - content="A", - tool_calls=[ToolCall(function=FunctionCall(name="tool1", arguments='{"input": "11"}'))], + normalized_chat_req: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request( + chat_completion_request + ) + assert normalized_chat_req == InstructRequest[ChatMessage, Tool]( + messages=[ + AssistantMessage( + content="A", + tool_calls=[ToolCall(function=FunctionCall(name="tool1", arguments='{"input": "11"}'))], + ), + ], ) - def test_assistant_content_with_more_tool_calls(self, normalizer_v7: InstructRequestNormalizer) -> None: + def test_assistant_content_with_more_tool_calls(self, normalizer_v7: InstructRequestNormalizerV7) -> None: chat_completion_request = mock_chat_completion( messages=[ UserMessage(content="A1"), - AssistantMessage( - content="B1", - ), + AssistantMessage(content="B1"), AssistantMessage( content="B2", tool_calls=[ToolCall(function=FunctionCall(name="tool1", arguments='{"input": "1"}'))], ), - AssistantMessage( - content="B3", - ), + AssistantMessage(content="B3"), AssistantMessage( content="B4", tool_calls=[ @@ -486,24 +541,27 @@ def test_assistant_content_with_more_tool_calls(self, normalizer_v7: InstructReq ToolCall(function=FunctionCall(name="tool22", arguments='{"input": "22"}')), ], ), + AssistantMessage(content="B5"), + UserMessage(content="C1"), + ] + ) + normalized_chat_req: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request( + chat_completion_request + ) + assert normalized_chat_req == InstructRequest[ChatMessage, Tool]( + messages=[ + UserMessage(content="A1"), AssistantMessage( - content="B5", + content="B1\n\nB2\n\nB3\n\nB4\n\nB5", + tool_calls=[ + ToolCall(function=FunctionCall(name="tool1", arguments='{"input": "1"}')), + ToolCall(function=FunctionCall(name="tool21", arguments='{"input": "21"}')), + ToolCall(function=FunctionCall(name="tool22", arguments='{"input": "22"}')), + ], ), UserMessage(content="C1"), - ] + ], ) - normalized_chat_req = normalizer_v7.from_chat_completion_request(chat_completion_request) - - assert normalized_chat_req.messages[0].content == "A1" - assert normalized_chat_req.messages[1].content.split("\n\n") == [f"B{i}" for i in range(1, 6)] - - tool_calls = normalized_chat_req.messages[1].tool_calls - - assert len(tool_calls) == 3 - - tool_key = ["1", "21", "22"] - assert all([t.function.name == f"tool{tool_key[i]}" for i, t in enumerate(tool_calls)]) - assert all([json.loads(t.function.arguments)["input"] == tool_key[i] for i, t in enumerate(tool_calls)]) def test_assert_parsed_settings( self, @@ -513,7 +571,7 @@ def test_assert_parsed_settings( parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request( chat_completion_request ) - assert parsed_request.settings == ModelSettings.none() + assert parsed_request == InstructRequest[ChatMessage, Tool](messages=[UserMessage(content="B")]) def test_continue_final_message_forwarded(self, normalizer_v7: InstructRequestNormalizerV7) -> None: request = ChatCompletionRequest[ChatMessage]( @@ -521,7 +579,10 @@ def test_continue_final_message_forwarded(self, normalizer_v7: InstructRequestNo continue_final_message=True, ) result: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(request) - assert result.continue_final_message is True + assert result == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="a"), AssistantMessage(content="b")], + continue_final_message=True, + ) @pytest.mark.parametrize("num_empty", [0, 1, 2]) def test_only_empty_text_chunks(self, normalizer_v7: InstructRequestNormalizerV7, num_empty: int) -> None: @@ -535,8 +596,9 @@ def test_only_empty_text_chunks(self, normalizer_v7: InstructRequestNormalizerV7 parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request( chat_completion_request ) - assert parsed_request.messages[0] == UserMessage(content="") - assert parsed_request.messages[1] == AssistantMessage(content="") + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content=""), AssistantMessage(content="")], + ) def test_complex_user_aggregation(self, normalizer_v7: InstructRequestNormalizerV7) -> None: """Complex multi-user-message aggregation with mixed str, chunks, empty, and non-text chunks.""" @@ -563,12 +625,16 @@ def test_complex_user_aggregation(self, normalizer_v7: InstructRequestNormalizer parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request( chat_completion_request ) - assert parsed_request.messages[0] == UserMessage( - content=[ - TextChunk(text="A\n\nB\n\nC\n\nD"), - ImageURLChunk(image_url="E"), - TextChunk(text="G\n\nH"), - ] + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[ + UserMessage( + content=[ + TextChunk(text="A\n\nB\n\nC\n\nD"), + ImageURLChunk(image_url="E"), + TextChunk(text="G\n\nH"), + ] + ), + ], ) def test_complex_assistant_aggregation(self, normalizer_v7: InstructRequestNormalizerV7) -> None: @@ -591,7 +657,9 @@ def test_complex_assistant_aggregation(self, normalizer_v7: InstructRequestNorma parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request( chat_completion_request ) - assert parsed_request.messages[0] == AssistantMessage(content="A\n\nB\n\nC\n\nD") + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[AssistantMessage(content="A\n\nB\n\nC\n\nD")], + ) def test_rejects_think_in_assistant(self, normalizer_v7: InstructRequestNormalizerV7) -> None: r"""V7 normalizer rejects ThinkChunk in assistant messages (pre-v11).""" @@ -613,7 +681,9 @@ def test_accepts_string_content(self, normalizer_v7: InstructRequestNormalizerV7 ], ) parsed: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(request) - assert parsed.messages[1] == AssistantMessage(content="plain text") + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="query"), AssistantMessage(content="plain text")], + ) def test_skips_json_normalization_on_tool_content(self, normalizer_v7: InstructRequestNormalizerV7) -> None: r"""V7+ normalizers do not JSON-normalize tool message content.""" @@ -626,7 +696,16 @@ def test_skips_json_normalization_on_tool_content(self, normalizer_v7: InstructR ], ) parsed: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(request) - assert parsed.messages[2] == ToolMessage(content=messy_json, tool_call_id="c1") + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[ + UserMessage(content="query"), + AssistantMessage( + content="", + tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")], + ), + ToolMessage(content=messy_json, tool_call_id="c1"), + ], + ) def test_preserves_audio_in_system_message(self, normalizer_v7: InstructRequestNormalizerV7) -> None: r"""V7 normalizer preserves AudioChunk in system messages.""" @@ -637,11 +716,13 @@ def test_preserves_audio_in_system_message(self, normalizer_v7: InstructRequestN ] ) parsed: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(request) - assert parsed.messages[0] == SystemMessage( - content=[ - TextChunk(text="hello"), - AudioChunk(input_audio=b"fake_audio_data"), - ] + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[ + SystemMessage( + content=[TextChunk(text="hello"), AudioChunk(input_audio=b"fake_audio_data")], + ), + UserMessage(content="test"), + ], ) @@ -733,18 +814,20 @@ def test_no_reorder_tool_messages(self, normalizer_v13: InstructRequestNormalize parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request( chat_completion_request ) - assert parsed_request.messages == [ - UserMessage(content="A"), - AssistantMessage( - content="B", - tool_calls=[ - ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")), - ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")), - ], - ), - ToolMessage(content="C", tool_call_id="1"), - ToolMessage(content="D", tool_call_id="2"), - ] + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[ + UserMessage(content="A"), + AssistantMessage( + content="B", + tool_calls=[ + ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")), + ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")), + ], + ), + ToolMessage(content="C", tool_call_id="1"), + ToolMessage(content="D", tool_call_id="2"), + ], + ) def test_reorder_last_tool_messages(self, normalizer_v13: InstructRequestNormalizerV13) -> None: chat_completion_request: ChatCompletionRequest = self._mock_chat_completion( @@ -764,18 +847,20 @@ def test_reorder_last_tool_messages(self, normalizer_v13: InstructRequestNormali parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request( chat_completion_request ) - assert parsed_request.messages == [ - UserMessage(content="A"), - AssistantMessage( - content="B", - tool_calls=[ - ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")), - ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")), - ], - ), - ToolMessage(content="C", tool_call_id="1"), - ToolMessage(content="D", tool_call_id="2"), - ] + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[ + UserMessage(content="A"), + AssistantMessage( + content="B", + tool_calls=[ + ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")), + ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")), + ], + ), + ToolMessage(content="C", tool_call_id="1"), + ToolMessage(content="D", tool_call_id="2"), + ], + ) def test_reorder_internal_tool_messages(self, normalizer_v13: InstructRequestNormalizerV13) -> None: chat_completion_request: ChatCompletionRequest = self._mock_chat_completion( @@ -796,19 +881,21 @@ def test_reorder_internal_tool_messages(self, normalizer_v13: InstructRequestNor parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request( chat_completion_request ) - assert parsed_request.messages == [ - UserMessage(content="A"), - AssistantMessage( - content="B", - tool_calls=[ - ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")), - ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")), - ], - ), - ToolMessage(content="C", tool_call_id="1"), - ToolMessage(content="D", tool_call_id="2"), - AssistantMessage(content="E"), - ] + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[ + UserMessage(content="A"), + AssistantMessage( + content="B", + tool_calls=[ + ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")), + ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")), + ], + ), + ToolMessage(content="C", tool_call_id="1"), + ToolMessage(content="D", tool_call_id="2"), + AssistantMessage(content="E"), + ], + ) def test_reorder_extra_tool_messages(self, normalizer_v13: InstructRequestNormalizerV13) -> None: chat_completion_request: ChatCompletionRequest = self._mock_chat_completion( @@ -827,17 +914,19 @@ def test_reorder_extra_tool_messages(self, normalizer_v13: InstructRequestNormal parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request( chat_completion_request ) - assert parsed_request.messages == [ - UserMessage(content="A"), - AssistantMessage( - content="B", - tool_calls=[ - ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")), - ], - ), - ToolMessage(content="C", tool_call_id="1"), - ToolMessage(content="D", tool_call_id="2"), - ] + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[ + UserMessage(content="A"), + AssistantMessage( + content="B", + tool_calls=[ + ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")), + ], + ), + ToolMessage(content="C", tool_call_id="1"), + ToolMessage(content="D", tool_call_id="2"), + ], + ) def test_reorder_only_from_latest_assistant_message(self, normalizer_v13: InstructRequestNormalizerV13) -> None: chat_completion_request: ChatCompletionRequest = self._mock_chat_completion( @@ -866,27 +955,29 @@ def test_reorder_only_from_latest_assistant_message(self, normalizer_v13: Instru parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request( chat_completion_request ) - assert parsed_request.messages == [ - UserMessage(content="A"), - AssistantMessage( - content="B", - tool_calls=[ - ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")), - ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")), - ], - ), - ToolMessage(content="C", tool_call_id="1"), - ToolMessage(content="D", tool_call_id="2"), - AssistantMessage( - content="E", - tool_calls=[ - ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")), - ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")), - ], - ), - ToolMessage(content="D", tool_call_id="2"), - ToolMessage(content="C", tool_call_id="1"), - ] + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[ + UserMessage(content="A"), + AssistantMessage( + content="B", + tool_calls=[ + ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")), + ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")), + ], + ), + ToolMessage(content="C", tool_call_id="1"), + ToolMessage(content="D", tool_call_id="2"), + AssistantMessage( + content="E", + tool_calls=[ + ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")), + ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")), + ], + ), + ToolMessage(content="D", tool_call_id="2"), + ToolMessage(content="C", tool_call_id="1"), + ], + ) @pytest.mark.parametrize( ["system_message", "expected_system_message"], @@ -932,7 +1023,9 @@ def test_aggregate_system_prompt_content( parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request( chat_completion_request ) - assert parsed_request.messages == [expected_system_message, UserMessage(content="B")] + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[expected_system_message, UserMessage(content="B")] + ) def test_system_messages_no_aggregation(self, normalizer_v13: InstructRequestNormalizerV13) -> None: """Consecutive system messages are NOT aggregated into one in V7+.""" @@ -947,12 +1040,14 @@ def test_system_messages_no_aggregation(self, normalizer_v13: InstructRequestNor parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request( chat_completion_request ) - assert parsed_request.messages == [ - SystemMessage(content="A"), - SystemMessage(content="B"), - UserMessage(content="C"), - SystemMessage(content="D"), - ] + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[ + SystemMessage(content="A"), + SystemMessage(content="B"), + UserMessage(content="C"), + SystemMessage(content="D"), + ], + ) def test_system_messages_normalization(self, normalizer_v13: InstructRequestNormalizerV13) -> None: """System message chunks within the same message are aggregated with no separator.""" @@ -965,10 +1060,12 @@ def test_system_messages_normalization(self, normalizer_v13: InstructRequestNorm parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request( chat_completion_request ) - assert parsed_request.messages == [ - SystemMessage(content="A\n\nB"), - SystemMessage(content="C"), - ] + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[ + SystemMessage(content="A\n\nB"), + SystemMessage(content="C"), + ], + ) def test_assert_parsed_settings( self, @@ -978,7 +1075,7 @@ def test_assert_parsed_settings( parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request( chat_completion_request ) - assert parsed_request.settings == ModelSettings.none() + assert parsed_request == InstructRequest[ChatMessage, Tool](messages=[UserMessage(content="B")]) def test_continue_final_message_forwarded(self, normalizer_v13: InstructRequestNormalizerV13) -> None: request = ChatCompletionRequest[ChatMessage]( @@ -986,7 +1083,10 @@ def test_continue_final_message_forwarded(self, normalizer_v13: InstructRequestN continue_final_message=True, ) result: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(request) - assert result.continue_final_message is True + assert result == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="a"), AssistantMessage(content="b")], + continue_final_message=True, + ) def test_accepts_text_and_think_chunks(self, normalizer_v13: InstructRequestNormalizerV13) -> None: r"""V13 normalizer accepts TextChunk and ThinkChunk in assistant messages.""" @@ -997,8 +1097,11 @@ def test_accepts_text_and_think_chunks(self, normalizer_v13: InstructRequestNorm ], ) parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(request) - assert parsed.messages[1] == AssistantMessage( - content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")] + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[ + UserMessage(content="query"), + AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]), + ], ) def test_accepts_string_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None: @@ -1010,7 +1113,9 @@ def test_accepts_string_content(self, normalizer_v13: InstructRequestNormalizerV ], ) parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(request) - assert parsed.messages[1] == AssistantMessage(content="plain text") + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="query"), AssistantMessage(content="plain text")], + ) def test_assistant_think_chunk_inter_message_aggregation( self, normalizer_v13: InstructRequestNormalizerV13 @@ -1032,12 +1137,16 @@ def test_assistant_think_chunk_inter_message_aggregation( parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request( chat_completion_request ) - assert parsed.messages[0] == AssistantMessage( - content=[ - TextChunk(text="A\n\nB"), - ThinkChunk(thinking="T"), - TextChunk(text="C\n\nD"), - ] + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[ + AssistantMessage( + content=[ + TextChunk(text="A\n\nB"), + ThinkChunk(thinking="T"), + TextChunk(text="C\n\nD"), + ] + ), + ], ) def test_assistant_think_chunk_intra_message_aggregation( @@ -1062,12 +1171,18 @@ def test_assistant_think_chunk_intra_message_aggregation( parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request( chat_completion_request ) - assert parsed.messages[1] == AssistantMessage( - content=[ - ThinkChunk(thinking="t1"), - ThinkChunk(thinking="t2"), - TextChunk(text="a1\n\na2\n\na3"), - ] + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[ + UserMessage(content="u"), + AssistantMessage( + content=[ + ThinkChunk(thinking="t1"), + ThinkChunk(thinking="t2"), + TextChunk(text="a1\n\na2\n\na3"), + ] + ), + UserMessage(content="u"), + ], ) def test_rejects_non_text_tool_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None: @@ -1097,7 +1212,16 @@ def test_aggregates_text_tool_content(self, normalizer_v13: InstructRequestNorma ], ) parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(request) - assert parsed.messages[2] == ToolMessage(content="hello\n\nworld", tool_call_id="c1") + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[ + UserMessage(content="query"), + AssistantMessage( + content="", + tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")], + ), + ToolMessage(content="hello\n\nworld", tool_call_id="c1"), + ], + ) def test_rejects_audio_in_tool_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None: r"""V13 normalizer raises InvalidRequestException for audio tool content.""" @@ -1143,7 +1267,10 @@ def test_assert_parsed_settings( parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request( chat_completion_request ) - assert parsed_request.settings == ModelSettings(reasoning_effort=reasoning_effort) + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="B")], + settings=ModelSettings(reasoning_effort=reasoning_effort), + ) def test_continue_final_message_forwarded(self, normalizer_v15: InstructRequestNormalizerV15) -> None: request = ChatCompletionRequest[ChatMessage]( @@ -1152,7 +1279,11 @@ def test_continue_final_message_forwarded(self, normalizer_v15: InstructRequestN reasoning_effort=ReasoningEffort.high, ) result: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request) - assert result.continue_final_message is True + assert result == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="a"), AssistantMessage(content="b")], + continue_final_message=True, + settings=ModelSettings(reasoning_effort=ReasoningEffort.high), + ) def test_v15_intra_message_chunks_joined_without_separator( self, normalizer_v15: InstructRequestNormalizerV15 @@ -1166,8 +1297,10 @@ def test_v15_intra_message_chunks_joined_without_separator( reasoning_effort=ReasoningEffort.high, ) parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request) - assert parsed.messages[0] == UserMessage(content="AB") - assert parsed.messages[1] == AssistantMessage(content="CD") + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="AB"), AssistantMessage(content="CD")], + settings=ModelSettings(reasoning_effort=ReasoningEffort.high), + ) def test_v15_inter_message_join_still_uses_separator(self, normalizer_v15: InstructRequestNormalizerV15) -> None: r"""V15 still joins text across different messages with '\n\n'.""" @@ -1180,7 +1313,10 @@ def test_v15_inter_message_join_still_uses_separator(self, normalizer_v15: Instr reasoning_effort=ReasoningEffort.high, ) parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request) - assert parsed.messages[0] == UserMessage(content="First\n\nSecond") + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="First\n\nSecond"), AssistantMessage(content="Reply")], + settings=ModelSettings(reasoning_effort=ReasoningEffort.high), + ) def test_v15_mixed_intra_and_inter_message(self, normalizer_v15: InstructRequestNormalizerV15) -> None: r"""V15 combines intra-message ('') and inter-message ('\n\n') joining.""" @@ -1193,7 +1329,10 @@ def test_v15_mixed_intra_and_inter_message(self, normalizer_v15: InstructRequest reasoning_effort=ReasoningEffort.high, ) parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request) - assert parsed.messages[0] == UserMessage(content="AB\n\nCD") + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="AB\n\nCD"), AssistantMessage(content="Reply")], + settings=ModelSettings(reasoning_effort=ReasoningEffort.high), + ) def test_v15_mixed_intra_and_inter_assistant_messages(self, normalizer_v15: InstructRequestNormalizerV15) -> None: r"""V15 combines intra-message ('') and inter-message ('\n\n') joining for assistant messages.""" @@ -1206,7 +1345,10 @@ def test_v15_mixed_intra_and_inter_assistant_messages(self, normalizer_v15: Inst reasoning_effort=ReasoningEffort.high, ) parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request) - assert parsed.messages[1] == AssistantMessage(content="AB\n\nCD") + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="Hello"), AssistantMessage(content="AB\n\nCD")], + settings=ModelSettings(reasoning_effort=ReasoningEffort.high), + ) def test_v15_tool_message_text_chunks_joined_without_separator( self, normalizer_v15: InstructRequestNormalizerV15 @@ -1221,7 +1363,17 @@ def test_v15_tool_message_text_chunks_joined_without_separator( reasoning_effort=ReasoningEffort.high, ) parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request) - assert parsed.messages[2] == ToolMessage(content="XY", tool_call_id="c1") + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[ + UserMessage(content="query"), + AssistantMessage( + content="", + tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")], + ), + ToolMessage(content="XY", tool_call_id="c1"), + ], + settings=ModelSettings(reasoning_effort=ReasoningEffort.high), + ) def test_accepts_text_and_think_chunks(self, normalizer_v15: InstructRequestNormalizerV15) -> None: r"""V15 normalizer accepts TextChunk and ThinkChunk in assistant messages.""" @@ -1233,8 +1385,12 @@ def test_accepts_text_and_think_chunks(self, normalizer_v15: InstructRequestNorm reasoning_effort=ReasoningEffort.high, ) parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request) - assert parsed.messages[1] == AssistantMessage( - content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")] + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[ + UserMessage(content="query"), + AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]), + ], + settings=ModelSettings(reasoning_effort=ReasoningEffort.high), ) def test_accepts_string_content(self, normalizer_v15: InstructRequestNormalizerV15) -> None: @@ -1247,7 +1403,10 @@ def test_accepts_string_content(self, normalizer_v15: InstructRequestNormalizerV reasoning_effort=ReasoningEffort.high, ) parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request) - assert parsed.messages[1] == AssistantMessage(content="plain text") + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="query"), AssistantMessage(content="plain text")], + settings=ModelSettings(reasoning_effort=ReasoningEffort.high), + ) def test_preserves_non_text_tool_content(self, normalizer_v15: InstructRequestNormalizerV15) -> None: r"""V15 normalizer preserves non-text chunks in tool messages.""" @@ -1261,7 +1420,17 @@ def test_preserves_non_text_tool_content(self, normalizer_v15: InstructRequestNo reasoning_effort=ReasoningEffort.high, ) parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request) - assert parsed.messages[2] == ToolMessage(content=[image_chunk], tool_call_id="c1") + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[ + UserMessage(content="query"), + AssistantMessage( + content="", + tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")], + ), + ToolMessage(content=[image_chunk], tool_call_id="c1"), + ], + settings=ModelSettings(reasoning_effort=ReasoningEffort.high), + ) def test_sorts_multimodal_tool_messages(self, normalizer_v15: InstructRequestNormalizerV15) -> None: r"""V15 normalizer sorts multimodal tool messages by tool call order.""" @@ -1282,9 +1451,21 @@ def test_sorts_multimodal_tool_messages(self, normalizer_v15: InstructRequestNor reasoning_effort=ReasoningEffort.high, ) parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request) - - assert parsed.messages[2] == ToolMessage(content=[image_chunk_1], tool_call_id="c1") - assert parsed.messages[3] == ToolMessage(content=[image_chunk_2], tool_call_id="c2") + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[ + UserMessage(content="query"), + AssistantMessage( + content="", + tool_calls=[ + ToolCall(function=FunctionCall(name="fn1", arguments="{}"), id="c1"), + ToolCall(function=FunctionCall(name="fn2", arguments="{}"), id="c2"), + ], + ), + ToolMessage(content=[image_chunk_1], tool_call_id="c1"), + ToolMessage(content=[image_chunk_2], tool_call_id="c2"), + ], + settings=ModelSettings(reasoning_effort=ReasoningEffort.high), + ) def test_rejects_think_in_system_message(self, normalizer_v15: InstructRequestNormalizerV15) -> None: r"""V15 normalizer rejects ThinkChunk in system messages.""" @@ -1307,11 +1488,14 @@ def test_preserves_audio_in_system_message(self, normalizer_v15: InstructRequest reasoning_effort=ReasoningEffort.high, ) parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request) - assert parsed.messages[0] == SystemMessage( - content=[ - TextChunk(text="hello"), - AudioChunk(input_audio=b"fake_audio_data"), - ] + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[ + SystemMessage( + content=[TextChunk(text="hello"), AudioChunk(input_audio=b"fake_audio_data")], + ), + UserMessage(content="test"), + ], + settings=ModelSettings(reasoning_effort=ReasoningEffort.high), ) From 75f78e95add60c2f6e8a53ecc132a26e064d6c4d Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Fri, 12 Jun 2026 18:06:29 +0200 Subject: [PATCH 34/47] Remove redundant empty reasoning-effort branch in test helper --- tests/test_tokenizer_v15.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/test_tokenizer_v15.py b/tests/test_tokenizer_v15.py index 968128f9..2c882cac 100644 --- a/tests/test_tokenizer_v15.py +++ b/tests/test_tokenizer_v15.py @@ -152,10 +152,6 @@ def _build_model_settings_builder( """ if allowed_reasoning_effort is None: return ModelSettingsBuilder.none() - if not allowed_reasoning_effort: - return ModelSettingsBuilder( - reasoning_effort=EnumBuilder[ReasoningEffort](values=[], accepts_none=True, default=None) - ) return ModelSettingsBuilder( reasoning_effort=EnumBuilder[ReasoningEffort]( values=[ReasoningEffort(v) for v in allowed_reasoning_effort], From 093f44eb530156747af26492b016b20ed544e89b Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Fri, 12 Jun 2026 18:18:28 +0200 Subject: [PATCH 35/47] Add content-chunk type validation to base request validator --- .../protocol/instruct/validator.py | 49 ++++++++++++- tests/validation/test_chat_validation.py | 70 ++++++++++++++++++- 2 files changed, 116 insertions(+), 3 deletions(-) diff --git a/src/mistral_common/protocol/instruct/validator.py b/src/mistral_common/protocol/instruct/validator.py index f121e49e..5902422c 100644 --- a/src/mistral_common/protocol/instruct/validator.py +++ b/src/mistral_common/protocol/instruct/validator.py @@ -1,4 +1,5 @@ import re +from collections.abc import Sequence from enum import Enum from typing import Generic @@ -14,8 +15,15 @@ InvalidToolException, InvalidToolMessageException, InvalidToolSchemaException, + InvalidUserMessageException, + MistralCommonException, +) +from mistral_common.protocol.instruct.chunk import ( + AudioChunk, + AudioURLChunk, + ContentChunk, + TextChunk, ) -from mistral_common.protocol.instruct.chunk import AudioChunk, AudioURLChunk from mistral_common.protocol.instruct.messages import ( UATS, AssistantMessage, @@ -38,6 +46,23 @@ from mistral_common.tokens.tokenizers.base import TokenizerVersion +def _validate_content_chunk_types( + content: "str | Sequence[ContentChunk] | None", + allowed: tuple[type, ...], + role: str, + exception_cls: type[MistralCommonException], +) -> None: + r"""Raise if any chunk in a list content is not an instance of `allowed`. + + String or None content is always accepted (covered elsewhere). + """ + if not isinstance(content, list): + return + invalid = sorted({type(chunk).__name__ for chunk in content if not isinstance(chunk, allowed)}) + if invalid: + raise exception_cls(f"Unexpected content chunk types in {role} message: {invalid}") + + class ValidationMode(str, Enum): r"""Enum for the validation mode. @@ -153,13 +178,30 @@ def _validate_tools(self, tools: list[Tool]) -> None: self._validate_function(tool.function) def _validate_user_message(self, message: UserMessageType) -> None: - pass + self._validate_user_content_chunks(message.content) + + def _validate_user_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None: + r"""v1/v2 user messages accept text content only (image >= v3, audio >= v7).""" + _validate_content_chunk_types(content, (TextChunk,), "user", InvalidUserMessageException) + + def _validate_assistant_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None: + r"""Pre-v11 assistant messages accept text content only.""" + _validate_content_chunk_types(content, (TextChunk,), "assistant", InvalidAssistantMessageException) + + def _validate_system_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None: + r"""v1-v3 system messages accept text content only.""" + _validate_content_chunk_types(content, (TextChunk,), "system", InvalidSystemPromptException) + + def _validate_tool_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None: + r"""Pre-v15 tool messages accept text content only.""" + _validate_content_chunk_types(content, (TextChunk,), "tool", InvalidToolMessageException) def _validate_tool_message(self, message: ToolMessageType) -> None: """ Checks: - The tool name is valid """ + self._validate_tool_content_chunks(message.content) if message.name is not None: if not re.match(r"^[a-zA-Z0-9_-]{1,64}$", message.name): raise InvalidToolMessageException( @@ -174,6 +216,7 @@ def _validate_system_message(self, message: SystemMessageType) -> None: """ if message.content is None: raise InvalidSystemPromptException("System prompt must have content") + self._validate_system_content_chunks(message.content) def _validate_function_call(self, function_call: FunctionCall) -> None: """ @@ -201,6 +244,8 @@ def _validate_assistant_message(self, message: AssistantMessageType, is_last_mes - That the tool calls are valid """ + self._validate_assistant_content_chunks(message.content) + # Validate that the message has either text or tool_calls # but not both and not neither. if (not self._allow_tool_call_and_content) and (bool(message.content) == bool(message.tool_calls)): diff --git a/tests/validation/test_chat_validation.py b/tests/validation/test_chat_validation.py index f4c8653c..d5161d8a 100644 --- a/tests/validation/test_chat_validation.py +++ b/tests/validation/test_chat_validation.py @@ -4,8 +4,16 @@ InvalidAssistantMessageException, InvalidMessageStructureException, InvalidRequestException, + InvalidSystemPromptException, + InvalidToolMessageException, + InvalidUserMessageException, +) +from mistral_common.protocol.instruct.chunk import ( + AudioChunk, + AudioURLChunk, + TextChunk, + ThinkChunk, ) -from mistral_common.protocol.instruct.chunk import AudioChunk, AudioURLChunk from mistral_common.protocol.instruct.messages import ( AssistantMessage, SystemMessage, @@ -621,3 +629,63 @@ def test_build_settings_v15_reasoning_effort( ) -> None: request = ChatCompletionRequest(messages=[UserMessage(content="Hello")], reasoning_effort=reasoning_effort) validator_v15._validate_model_settings(request) + + +class TestBaseValidatorContentChunks: + @pytest.fixture + def base_validator(self) -> MistralRequestValidator: + return MistralRequestValidator(ValidationMode.serving) + + def test_rejects_think_in_assistant(self, base_validator: MistralRequestValidator) -> None: + messages = [ + UserMessage(content="hi"), + AssistantMessage(content=[ThinkChunk(thinking="r"), TextChunk(text="a")]), + UserMessage(content="hi again"), + ] + with pytest.raises( + InvalidAssistantMessageException, match="Unexpected content chunk types in assistant message" + ): + base_validator.validate_messages(messages, continue_final_message=False) + + def test_rejects_audio_in_system(self, base_validator: MistralRequestValidator) -> None: + messages = [ + SystemMessage(content=[TextChunk(text="x"), AudioChunk(input_audio=b"fake")]), + UserMessage(content="hi"), + ] + with pytest.raises(InvalidSystemPromptException, match="Unexpected content chunk types in system message"): + base_validator.validate_messages(messages, continue_final_message=False) + + def test_rejects_image_in_tool(self, base_validator: MistralRequestValidator) -> None: + from mistral_common.protocol.instruct.chunk import ImageURLChunk + + messages = [ + UserMessage(content="hi"), + AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")]), + ToolMessage( + content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")], + tool_call_id="test12345", + ), + ] + with pytest.raises(InvalidToolMessageException, match="Unexpected content chunk types in tool message"): + base_validator.validate_messages(messages, continue_final_message=False) + + def test_rejects_image_in_user(self, base_validator: MistralRequestValidator) -> None: + from mistral_common.protocol.instruct.chunk import ImageURLChunk + + messages = [UserMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")])] + with pytest.raises(InvalidUserMessageException, match="Unexpected content chunk types in user message"): + base_validator.validate_messages(messages, continue_final_message=False) + + def test_rejects_audio_in_user(self, base_validator: MistralRequestValidator) -> None: + messages = [UserMessage(content=[AudioChunk(input_audio=b"fake")])] + with pytest.raises(InvalidUserMessageException, match="Unexpected content chunk types in user message"): + base_validator.validate_messages(messages, continue_final_message=False) + + def test_accepts_multiple_text_chunks(self, base_validator: MistralRequestValidator) -> None: + messages = [ + SystemMessage(content=[TextChunk(text="a"), TextChunk(text="b")]), + UserMessage(content="hi"), + AssistantMessage(content=[TextChunk(text="c"), TextChunk(text="d")]), + UserMessage(content="hi again"), + ] + base_validator.validate_messages(messages, continue_final_message=False) From 091e1f41bc8406d72f294d7fde0fab285ff7acea Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Fri, 12 Jun 2026 18:21:47 +0200 Subject: [PATCH 36/47] Add per-version content-chunk validation overrides --- .../protocol/instruct/validator.py | 37 ++++++++ tests/validation/test_chat_validation.py | 93 +++++++++++++++++++ 2 files changed, 130 insertions(+) diff --git a/src/mistral_common/protocol/instruct/validator.py b/src/mistral_common/protocol/instruct/validator.py index 5902422c..6a271aa7 100644 --- a/src/mistral_common/protocol/instruct/validator.py +++ b/src/mistral_common/protocol/instruct/validator.py @@ -22,7 +22,10 @@ AudioChunk, AudioURLChunk, ContentChunk, + ImageChunk, + ImageURLChunk, TextChunk, + ThinkChunk, ) from mistral_common.protocol.instruct.messages import ( UATS, @@ -404,12 +407,19 @@ class MistralRequestValidatorV3(MistralRequestValidator): >>> validator = MistralRequestValidatorV3() """ + def _validate_user_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None: + r"""v3 user messages accept text and image chunks (audio >= v7).""" + _validate_content_chunk_types( + content, (TextChunk, ImageChunk, ImageURLChunk), "user", InvalidUserMessageException + ) + def _validate_tool_message(self, message: ToolMessageType) -> None: """ Checks: - The tool name is valid - Tool call id is valid """ + self._validate_tool_content_chunks(message.content) if message.name is not None: if not re.match(r"^[a-zA-Z0-9_-]{1,64}$", message.name): raise InvalidToolMessageException( @@ -468,6 +478,21 @@ class MistralRequestValidatorV5(MistralRequestValidatorV3): _allow_tool_call_and_content: bool = True + def _validate_user_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None: + r"""v7+ user messages accept text, image and audio chunks.""" + _validate_content_chunk_types( + content, + (TextChunk, ImageChunk, ImageURLChunk, AudioChunk, AudioURLChunk), + "user", + InvalidUserMessageException, + ) + + def _validate_system_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None: + r"""v7+ system messages accept text, audio and thinking chunks.""" + _validate_content_chunk_types( + content, (TextChunk, AudioChunk, ThinkChunk), "system", InvalidSystemPromptException + ) + def _validate_system_prompt_and_audio(self, messages: list[UATS]) -> None: r"""Validates that system prompts and audio chunks are not used together in v5.""" @@ -552,12 +577,24 @@ def _validate_tool_calls_followed_by_tool_messages(self, messages: list[UATS]) - elif len(expected_tool_ids) < len(observed_tool_ids) and self._mode == ValidationMode.finetuning: raise InvalidMessageStructureException("More tool responses than tool calls") + def _validate_assistant_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None: + r"""v11+ assistant messages accept text and thinking chunks.""" + _validate_content_chunk_types(content, (TextChunk, ThinkChunk), "assistant", InvalidAssistantMessageException) + def _validate_system_prompt_and_audio(self, messages: list[UATS]) -> None: r"""Allows system prompts and audio chunks to coexist in v13.""" return class MistralRequestValidatorV15(MistralRequestValidatorV13): + def _validate_system_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None: + r"""v15 system messages accept text and audio but reject thinking chunks.""" + _validate_content_chunk_types(content, (TextChunk, AudioChunk), "system", InvalidSystemPromptException) + + def _validate_tool_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None: + r"""v15 tool messages accept all content chunk types.""" + return + def _validate_model_settings(self, request: ChatCompletionRequest) -> None: pass diff --git a/tests/validation/test_chat_validation.py b/tests/validation/test_chat_validation.py index d5161d8a..17f88e5b 100644 --- a/tests/validation/test_chat_validation.py +++ b/tests/validation/test_chat_validation.py @@ -689,3 +689,96 @@ def test_accepts_multiple_text_chunks(self, base_validator: MistralRequestValida UserMessage(content="hi again"), ] base_validator.validate_messages(messages, continue_final_message=False) + + +class TestVersionedValidatorContentChunks: + def test_v3_allows_image_rejects_audio_in_user(self) -> None: + from mistral_common.protocol.instruct.chunk import ImageURLChunk + + validator = MistralRequestValidatorV3(ValidationMode.serving) + ok = [UserMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")])] + validator.validate_messages(ok, continue_final_message=False) + bad = [UserMessage(content=[AudioChunk(input_audio=b"fake")])] + with pytest.raises(InvalidUserMessageException, match="Unexpected content chunk types in user message"): + validator.validate_messages(bad, continue_final_message=False) + + def test_v5_allows_image_and_audio_in_user(self) -> None: + from mistral_common.protocol.instruct.chunk import ImageURLChunk + + validator = MistralRequestValidatorV5(ValidationMode.serving) + messages = [ + UserMessage( + content=[ + ImageURLChunk(image_url="data:image/png;base64,iVBORw0"), + AudioChunk(input_audio=b"fake"), + ] + ) + ] + validator.validate_messages(messages, continue_final_message=False) + + def test_v5_allows_audio_and_think_in_system(self) -> None: + validator = MistralRequestValidatorV5(ValidationMode.serving) + messages = [ + SystemMessage(content=[TextChunk(text="x"), AudioChunk(input_audio=b"fake"), ThinkChunk(thinking="t")]), + UserMessage(content="hi"), + ] + validator.validate_messages(messages, continue_final_message=False) + + def test_v5_still_rejects_think_in_assistant(self) -> None: + validator = MistralRequestValidatorV5(ValidationMode.serving) + messages = [ + UserMessage(content="hi"), + AssistantMessage(content=[ThinkChunk(thinking="r"), TextChunk(text="a")]), + UserMessage(content="next"), + ] + with pytest.raises(InvalidAssistantMessageException, match="Unexpected content chunk types"): + validator.validate_messages(messages, continue_final_message=False) + + def test_v13_allows_think_in_assistant(self) -> None: + validator = MistralRequestValidatorV13(ValidationMode.serving) + messages = [ + UserMessage(content="hi"), + AssistantMessage(content=[ThinkChunk(thinking="r"), TextChunk(text="a")]), + UserMessage(content="next"), + ] + validator.validate_messages(messages, continue_final_message=False) + + def test_v13_still_rejects_image_in_tool(self) -> None: + from mistral_common.protocol.instruct.chunk import ImageURLChunk + + validator = MistralRequestValidatorV13(ValidationMode.serving) + messages = [ + UserMessage(content="hi"), + AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")]), + ToolMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")], tool_call_id="test12345"), + ] + with pytest.raises(InvalidToolMessageException, match="Unexpected content chunk types in tool message"): + validator.validate_messages(messages, continue_final_message=False) + + def test_v15_rejects_think_in_system(self) -> None: + validator = MistralRequestValidatorV15(ValidationMode.serving) + messages = [ + SystemMessage(content=[TextChunk(text="x"), ThinkChunk(thinking="t")]), + UserMessage(content="hi"), + ] + with pytest.raises(InvalidSystemPromptException, match="Unexpected content chunk types in system message"): + validator.validate_messages(messages, continue_final_message=False) + + def test_v15_allows_audio_in_system(self) -> None: + validator = MistralRequestValidatorV15(ValidationMode.serving) + messages = [ + SystemMessage(content=[TextChunk(text="x"), AudioChunk(input_audio=b"fake")]), + UserMessage(content="hi"), + ] + validator.validate_messages(messages, continue_final_message=False) + + def test_v15_allows_image_in_tool(self) -> None: + from mistral_common.protocol.instruct.chunk import ImageURLChunk + + validator = MistralRequestValidatorV15(ValidationMode.serving) + messages = [ + UserMessage(content="hi"), + AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")]), + ToolMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")], tool_call_id="test12345"), + ] + validator.validate_messages(messages, continue_final_message=False) From b21094399a7ddc1809b3c4fdcdab0551c19f9bd9 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Fri, 12 Jun 2026 18:31:14 +0200 Subject: [PATCH 37/47] Clean up content-chunk validator hooks per review --- .../protocol/instruct/validator.py | 32 ++++++++----------- tests/validation/test_chat_validation.py | 32 +++++++++++-------- 2 files changed, 32 insertions(+), 32 deletions(-) diff --git a/src/mistral_common/protocol/instruct/validator.py b/src/mistral_common/protocol/instruct/validator.py index 6a271aa7..e15d0935 100644 --- a/src/mistral_common/protocol/instruct/validator.py +++ b/src/mistral_common/protocol/instruct/validator.py @@ -50,8 +50,8 @@ def _validate_content_chunk_types( - content: "str | Sequence[ContentChunk] | None", - allowed: tuple[type, ...], + content: str | Sequence[ContentChunk] | None, + allowed: tuple[type[ContentChunk], ...], role: str, exception_cls: type[MistralCommonException], ) -> None: @@ -183,19 +183,19 @@ def _validate_tools(self, tools: list[Tool]) -> None: def _validate_user_message(self, message: UserMessageType) -> None: self._validate_user_content_chunks(message.content) - def _validate_user_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None: + def _validate_user_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None: r"""v1/v2 user messages accept text content only (image >= v3, audio >= v7).""" _validate_content_chunk_types(content, (TextChunk,), "user", InvalidUserMessageException) - def _validate_assistant_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None: + def _validate_assistant_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None: r"""Pre-v11 assistant messages accept text content only.""" _validate_content_chunk_types(content, (TextChunk,), "assistant", InvalidAssistantMessageException) - def _validate_system_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None: + def _validate_system_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None: r"""v1-v3 system messages accept text content only.""" _validate_content_chunk_types(content, (TextChunk,), "system", InvalidSystemPromptException) - def _validate_tool_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None: + def _validate_tool_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None: r"""Pre-v15 tool messages accept text content only.""" _validate_content_chunk_types(content, (TextChunk,), "tool", InvalidToolMessageException) @@ -407,7 +407,7 @@ class MistralRequestValidatorV3(MistralRequestValidator): >>> validator = MistralRequestValidatorV3() """ - def _validate_user_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None: + def _validate_user_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None: r"""v3 user messages accept text and image chunks (audio >= v7).""" _validate_content_chunk_types( content, (TextChunk, ImageChunk, ImageURLChunk), "user", InvalidUserMessageException @@ -419,13 +419,7 @@ def _validate_tool_message(self, message: ToolMessageType) -> None: - The tool name is valid - Tool call id is valid """ - self._validate_tool_content_chunks(message.content) - if message.name is not None: - if not re.match(r"^[a-zA-Z0-9_-]{1,64}$", message.name): - raise InvalidToolMessageException( - f"Function name was {message.name} but must be a-z, A-Z, 0-9, " - "or contain underscores and dashes, with a maximum length of 64." - ) + super()._validate_tool_message(message) if message.tool_call_id is None: raise InvalidRequestException("Tool call id has to be defined.") @@ -478,7 +472,7 @@ class MistralRequestValidatorV5(MistralRequestValidatorV3): _allow_tool_call_and_content: bool = True - def _validate_user_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None: + def _validate_user_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None: r"""v7+ user messages accept text, image and audio chunks.""" _validate_content_chunk_types( content, @@ -487,7 +481,7 @@ def _validate_user_content_chunks(self, content: "str | Sequence[ContentChunk] | InvalidUserMessageException, ) - def _validate_system_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None: + def _validate_system_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None: r"""v7+ system messages accept text, audio and thinking chunks.""" _validate_content_chunk_types( content, (TextChunk, AudioChunk, ThinkChunk), "system", InvalidSystemPromptException @@ -577,7 +571,7 @@ def _validate_tool_calls_followed_by_tool_messages(self, messages: list[UATS]) - elif len(expected_tool_ids) < len(observed_tool_ids) and self._mode == ValidationMode.finetuning: raise InvalidMessageStructureException("More tool responses than tool calls") - def _validate_assistant_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None: + def _validate_assistant_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None: r"""v11+ assistant messages accept text and thinking chunks.""" _validate_content_chunk_types(content, (TextChunk, ThinkChunk), "assistant", InvalidAssistantMessageException) @@ -587,11 +581,11 @@ def _validate_system_prompt_and_audio(self, messages: list[UATS]) -> None: class MistralRequestValidatorV15(MistralRequestValidatorV13): - def _validate_system_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None: + def _validate_system_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None: r"""v15 system messages accept text and audio but reject thinking chunks.""" _validate_content_chunk_types(content, (TextChunk, AudioChunk), "system", InvalidSystemPromptException) - def _validate_tool_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None: + def _validate_tool_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None: r"""v15 tool messages accept all content chunk types.""" return diff --git a/tests/validation/test_chat_validation.py b/tests/validation/test_chat_validation.py index 17f88e5b..74a5a01a 100644 --- a/tests/validation/test_chat_validation.py +++ b/tests/validation/test_chat_validation.py @@ -11,6 +11,7 @@ from mistral_common.protocol.instruct.chunk import ( AudioChunk, AudioURLChunk, + ImageURLChunk, TextChunk, ThinkChunk, ) @@ -656,8 +657,6 @@ def test_rejects_audio_in_system(self, base_validator: MistralRequestValidator) base_validator.validate_messages(messages, continue_final_message=False) def test_rejects_image_in_tool(self, base_validator: MistralRequestValidator) -> None: - from mistral_common.protocol.instruct.chunk import ImageURLChunk - messages = [ UserMessage(content="hi"), AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")]), @@ -670,8 +669,6 @@ def test_rejects_image_in_tool(self, base_validator: MistralRequestValidator) -> base_validator.validate_messages(messages, continue_final_message=False) def test_rejects_image_in_user(self, base_validator: MistralRequestValidator) -> None: - from mistral_common.protocol.instruct.chunk import ImageURLChunk - messages = [UserMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")])] with pytest.raises(InvalidUserMessageException, match="Unexpected content chunk types in user message"): base_validator.validate_messages(messages, continue_final_message=False) @@ -681,6 +678,21 @@ def test_rejects_audio_in_user(self, base_validator: MistralRequestValidator) -> with pytest.raises(InvalidUserMessageException, match="Unexpected content chunk types in user message"): base_validator.validate_messages(messages, continue_final_message=False) + def test_reports_sorted_unique_invalid_types(self, base_validator: MistralRequestValidator) -> None: + messages = [ + UserMessage( + content=[ + AudioChunk(input_audio=b"fake"), + ImageURLChunk(image_url="data:image/png;base64,iVBORw0"), + ] + ) + ] + with pytest.raises( + InvalidUserMessageException, + match=r"Unexpected content chunk types in user message: \['AudioChunk', 'ImageURLChunk'\]", + ): + base_validator.validate_messages(messages, continue_final_message=False) + def test_accepts_multiple_text_chunks(self, base_validator: MistralRequestValidator) -> None: messages = [ SystemMessage(content=[TextChunk(text="a"), TextChunk(text="b")]), @@ -693,8 +705,6 @@ def test_accepts_multiple_text_chunks(self, base_validator: MistralRequestValida class TestVersionedValidatorContentChunks: def test_v3_allows_image_rejects_audio_in_user(self) -> None: - from mistral_common.protocol.instruct.chunk import ImageURLChunk - validator = MistralRequestValidatorV3(ValidationMode.serving) ok = [UserMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")])] validator.validate_messages(ok, continue_final_message=False) @@ -703,8 +713,6 @@ def test_v3_allows_image_rejects_audio_in_user(self) -> None: validator.validate_messages(bad, continue_final_message=False) def test_v5_allows_image_and_audio_in_user(self) -> None: - from mistral_common.protocol.instruct.chunk import ImageURLChunk - validator = MistralRequestValidatorV5(ValidationMode.serving) messages = [ UserMessage( @@ -731,7 +739,9 @@ def test_v5_still_rejects_think_in_assistant(self) -> None: AssistantMessage(content=[ThinkChunk(thinking="r"), TextChunk(text="a")]), UserMessage(content="next"), ] - with pytest.raises(InvalidAssistantMessageException, match="Unexpected content chunk types"): + with pytest.raises( + InvalidAssistantMessageException, match="Unexpected content chunk types in assistant message" + ): validator.validate_messages(messages, continue_final_message=False) def test_v13_allows_think_in_assistant(self) -> None: @@ -744,8 +754,6 @@ def test_v13_allows_think_in_assistant(self) -> None: validator.validate_messages(messages, continue_final_message=False) def test_v13_still_rejects_image_in_tool(self) -> None: - from mistral_common.protocol.instruct.chunk import ImageURLChunk - validator = MistralRequestValidatorV13(ValidationMode.serving) messages = [ UserMessage(content="hi"), @@ -773,8 +781,6 @@ def test_v15_allows_audio_in_system(self) -> None: validator.validate_messages(messages, continue_final_message=False) def test_v15_allows_image_in_tool(self) -> None: - from mistral_common.protocol.instruct.chunk import ImageURLChunk - validator = MistralRequestValidatorV15(ValidationMode.serving) messages = [ UserMessage(content="hi"), From 43b15e66ef22e2caac6a284282fdf15dbbbcc7eb Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Fri, 12 Jun 2026 18:33:54 +0200 Subject: [PATCH 38/47] Move content-chunk validation out of the normalizer --- .../protocol/instruct/normalize.py | 114 ++---------------- 1 file changed, 8 insertions(+), 106 deletions(-) diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py index 8cd5f904..d6785b34 100644 --- a/src/mistral_common/protocol/instruct/normalize.py +++ b/src/mistral_common/protocol/instruct/normalize.py @@ -1,20 +1,15 @@ import json import warnings -from typing import Generic, Sequence +from typing import Generic, Sequence, cast -from typing_extensions import TypeGuard, assert_never +from typing_extensions import assert_never from mistral_common.exceptions import InvalidRequestException from mistral_common.protocol.instruct.chunk import ( AssistantContentChunk, - AudioChunk, - AudioURLChunk, ContentChunk, - ImageChunk, - ImageURLChunk, SystemContentChunk, TextChunk, - ThinkChunk, UserContentChunk, ) from mistral_common.protocol.instruct.messages import ( @@ -38,27 +33,6 @@ _DEFAULT_JOIN_STR = "\n\n" -def _is_user_content( - chunks: list[ContentChunk], -) -> TypeGuard[list[UserContentChunk]]: - r"""Narrow ContentChunk list to user-compatible types.""" - return all(isinstance(c, (TextChunk, ImageChunk, ImageURLChunk, AudioChunk, AudioURLChunk)) for c in chunks) - - -def _is_assistant_content( - chunks: list[ContentChunk], -) -> TypeGuard[list[AssistantContentChunk]]: - r"""Narrow ContentChunk list to assistant-compatible types.""" - return all(isinstance(c, (TextChunk, ThinkChunk)) for c in chunks) - - -def _is_system_content( - chunks: list[ContentChunk], -) -> TypeGuard[list[SystemContentChunk]]: - r"""Narrow ContentChunk list to system-compatible types.""" - return all(isinstance(c, (TextChunk, AudioChunk, ThinkChunk)) for c in chunks) - - def _aggregate_content_chunks_impl( contents: list[list[ContentChunk] | str | None], msg_join_str: str, @@ -262,10 +236,7 @@ def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[s for message in messages: assert isinstance(message, self._tool_message_class), "Expected tool message" content = self._aggregate_content_chunks([message]) - if not isinstance(content, str): - raise InvalidRequestException( - f"Unexpected content chunk types in tool message: {[type(c).__name__ for c in content]}" - ) + assert isinstance(content, str), f"Unexpected content chunk types in tool message: {content}" normalized_content = self._normalize_json_content(content) tool_messages.append( @@ -283,27 +254,6 @@ def _normalize_tool_call(self, tool_call: ToolCall) -> ToolCall: id=tool_call.id, ) - def _narrow_assistant_content(self, content: list[ContentChunk] | str) -> str | list[AssistantContentChunk]: - r"""Validate content chunks for assistant messages. - - Pre-v11 only allows text content in assistant messages. - V11+ overrides to also accept ThinkChunk. - - Args: - content: The aggregated content chunks. - - Returns: - The validated content. - - Raises: - InvalidRequestException: If unsupported chunk types are found. - """ - if isinstance(content, str): - return content - raise InvalidRequestException( - f"Unexpected content chunk types in assistant message: {[type(c).__name__ for c in content]}" - ) - def _aggregate_system_messages(self, messages: list[UATS]) -> list[SystemMessageType]: return [] @@ -334,10 +284,8 @@ def _aggregate_assistant_messages(self, messages: list[UATS]) -> AssistantMessag ) weight = message.weight - validated_content = self._narrow_assistant_content(content) - aggregated_message = self._assistant_message_class( - content=validated_content, + content=cast("str | list[AssistantContentChunk]", content), tool_calls=tool_calls or None, prefix=prefix, ) @@ -349,11 +297,7 @@ def _aggregate_assistant_messages(self, messages: list[UATS]) -> AssistantMessag def _aggregate_user_messages(self, messages: list[UATS]) -> UserMessageType: """Coalesce neighboring blocks of ContentChunks in user messages.""" content = self._aggregate_content_chunks(messages) - if isinstance(content, str) or _is_user_content(content): - return self._user_message_class(content=content) - raise InvalidRequestException( - f"Unexpected content chunk types in user message: {[type(c).__name__ for c in content]}" - ) + return self._user_message_class(content=cast("str | list[UserContentChunk]", content)) def _aggregate_role(self, messages: list[UATS], role: Roles | None, latest_call_ids: list[str]) -> Sequence[UATS]: if role == Roles.tool: @@ -476,27 +420,6 @@ def normalizer(model_settings_builder: ModelSettingsBuilder | None = None) -> "I UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest[UATS, Tool], None ) - def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | list[SystemContentChunk]: - r"""Validate content chunks for system messages. - - V7+ accepts all SystemContentChunk types (text, audio, thinking). - V15 overrides to reject ThinkChunk. - - Args: - content: The aggregated content chunks. - - Returns: - The validated content. - - Raises: - InvalidRequestException: If unsupported chunk types are found. - """ - if isinstance(content, str) or _is_system_content(content): - return content - raise InvalidRequestException( - f"Unexpected content chunk types in system message: {[type(c).__name__ for c in content]}" - ) - def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]: """Normalize tool messages without JSON normalization. @@ -507,10 +430,7 @@ def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[s for message in messages: assert isinstance(message, self._tool_message_class), "Expected tool message" content = self._aggregate_content_chunks([message]) - if not isinstance(content, str): - raise InvalidRequestException( - f"Unexpected content chunk types in tool message: {[type(c).__name__ for c in content]}" - ) + assert isinstance(content, str), f"Unexpected content chunk types in tool message: {content}" tool_messages.append( self._tool_message_class(content=content, tool_call_id=message.tool_call_id, name=message.name) ) @@ -521,8 +441,7 @@ def _aggregate_system_messages(self, messages: list[UATS]) -> list[SystemMessage for message in messages: if isinstance(message, self._system_message_class): content = self._aggregate_content_chunks([message]) - validated = self._narrow_system_content(content) - aggregated.append(self._system_message_class(content=validated)) + aggregated.append(self._system_message_class(content=cast("str | list[SystemContentChunk]", content))) return aggregated def _aggregate_role(self, messages: list[UATS], role: Roles | None, latest_call_ids: list[str]) -> Sequence[UATS]: @@ -616,14 +535,6 @@ def _inplace_sort_tool_messages(tool_messages: list[ToolMessageType], latest_cal ), ) - def _narrow_assistant_content(self, content: list[ContentChunk] | str) -> str | list[AssistantContentChunk]: - r"""V11+ accepts TextChunk and ThinkChunk in assistant messages.""" - if isinstance(content, str) or _is_assistant_content(content): - return content - raise InvalidRequestException( - f"Unexpected content chunk types in assistant message: {[type(c).__name__ for c in content]}" - ) - def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]: tool_messages: list[ToolMessageType] = super()._aggregate_tool_messages(messages, latest_call_ids) self._inplace_sort_tool_messages(tool_messages=tool_messages, latest_call_ids=latest_call_ids) @@ -642,7 +553,7 @@ class InstructRequestNormalizerV15(InstructRequestNormalizerV13): _chunk_join_str: str = "" def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]: - r"""V15 accepts all ContentChunk types in tool messages.""" + r"""V15 keeps all aggregated tool content (validation handled by the validator).""" tool_messages: list[ToolMessageType] = [] for message in messages: assert isinstance(message, self._tool_message_class), "Expected tool message" @@ -653,15 +564,6 @@ def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[s self._inplace_sort_tool_messages(tool_messages=tool_messages, latest_call_ids=latest_call_ids) return tool_messages - def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | list[SystemContentChunk]: - r"""V15 system messages allow TextChunk and AudioChunk but reject ThinkChunk.""" - validated = super()._narrow_system_content(content) - if isinstance(validated, str): - return validated - if any(isinstance(c, ThinkChunk) for c in validated): - raise InvalidRequestException("ThinkChunk in system message is not supported for V15") - return validated - @staticmethod def normalizer(model_settings_builder: ModelSettingsBuilder | None = None) -> "InstructRequestNormalizerV15": r"""Returns a normalizer for the V15 instruct request. From f995ddd6771ce494fb709e0601bd3b3bc79f01af Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Fri, 12 Jun 2026 18:36:37 +0200 Subject: [PATCH 39/47] Reconcile chunk-type tests across Pydantic, normalizer and validator layers --- tests/test_messages.py | 41 ++++++++++++++++++++++ tests/test_normalization.py | 68 +++++++------------------------------ 2 files changed, 54 insertions(+), 55 deletions(-) create mode 100644 tests/test_messages.py diff --git a/tests/test_messages.py b/tests/test_messages.py new file mode 100644 index 00000000..ca6f6486 --- /dev/null +++ b/tests/test_messages.py @@ -0,0 +1,41 @@ +import pytest +from pydantic import ValidationError + +from mistral_common.protocol.instruct.chunk import ( + AudioChunk, + ImageURLChunk, + TextChunk, + ThinkChunk, +) +from mistral_common.protocol.instruct.messages import ( + AssistantMessage, + SystemMessage, + UserMessage, +) + + +class TestMessageContentChunkUnions: + def test_assistant_rejects_image(self) -> None: + with pytest.raises(ValidationError): + AssistantMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")]) + + def test_assistant_rejects_audio(self) -> None: + with pytest.raises(ValidationError): + AssistantMessage(content=[AudioChunk(input_audio=b"fake")]) + + def test_assistant_accepts_text_and_think(self) -> None: + AssistantMessage(content=[ThinkChunk(thinking="r"), TextChunk(text="a")]) + + def test_system_rejects_image(self) -> None: + with pytest.raises(ValidationError): + SystemMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")]) + + def test_system_accepts_audio(self) -> None: + SystemMessage(content=[TextChunk(text="x"), AudioChunk(input_audio=b"fake")]) + + def test_user_rejects_think(self) -> None: + with pytest.raises(ValidationError): + UserMessage(content=[ThinkChunk(thinking="r")]) + + def test_user_accepts_image(self) -> None: + UserMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")]) diff --git a/tests/test_normalization.py b/tests/test_normalization.py index 26f7d8be..6fa9d19a 100644 --- a/tests/test_normalization.py +++ b/tests/test_normalization.py @@ -2,7 +2,6 @@ import pytest -from mistral_common.exceptions import InvalidRequestException from mistral_common.protocol.instruct.chunk import ( AudioChunk, ChunkTypes, @@ -453,16 +452,18 @@ def test_json_normalizes_tool_content(self, normalizer: InstructRequestNormalize ], ) - def test_rejects_think_in_assistant(self, normalizer: InstructRequestNormalizer) -> None: - r"""Pre-v11 normalizer rejects ThinkChunk in assistant messages.""" + def test_passes_think_in_assistant_through(self, normalizer: InstructRequestNormalizer) -> None: + r"""Normalizer no longer gates version rules; the validator rejects pre-v11 think.""" request = mock_chat_completion( messages=[ UserMessage(content="query"), AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]), ] ) - with pytest.raises(InvalidRequestException, match="Unexpected content chunk types in assistant message"): - normalizer.from_chat_completion_request(request) + parsed = normalizer.from_chat_completion_request(request) + assert parsed.messages[-1] == AssistantMessage( + content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")] + ) class TestChatCompletionRequestNormalizationV7: @@ -661,16 +662,18 @@ def test_complex_assistant_aggregation(self, normalizer_v7: InstructRequestNorma messages=[AssistantMessage(content="A\n\nB\n\nC\n\nD")], ) - def test_rejects_think_in_assistant(self, normalizer_v7: InstructRequestNormalizerV7) -> None: - r"""V7 normalizer rejects ThinkChunk in assistant messages (pre-v11).""" + def test_passes_think_in_assistant_through(self, normalizer_v7: InstructRequestNormalizerV7) -> None: + r"""Normalizer no longer gates version rules; the validator rejects pre-v11 think.""" request = mock_chat_completion( messages=[ UserMessage(content="query"), AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]), - ], + ] + ) + parsed = normalizer_v7.from_chat_completion_request(request) + assert parsed.messages[-1] == AssistantMessage( + content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")] ) - with pytest.raises(InvalidRequestException, match="Unexpected content chunk types in assistant message"): - normalizer_v7.from_chat_completion_request(request) def test_accepts_string_content(self, normalizer_v7: InstructRequestNormalizerV7) -> None: r"""V7 normalizer accepts string content in assistant messages.""" @@ -1185,23 +1188,6 @@ def test_assistant_think_chunk_intra_message_aggregation( ], ) - def test_rejects_non_text_tool_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None: - r"""V13 normalizer raises InvalidRequestException for non-text tool content.""" - request = mock_chat_completion( - messages=[ - UserMessage(content="hi"), - AssistantMessage( - tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")], - ), - ToolMessage( - content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")], - tool_call_id="test12345", - ), - ] - ) - with pytest.raises(InvalidRequestException, match="Unexpected content chunk types in tool message"): - normalizer_v13.from_chat_completion_request(request) - def test_aggregates_text_tool_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None: r"""V13 normalizer aggregates TextChunks in tool messages to a string.""" request = mock_chat_completion( @@ -1223,23 +1209,6 @@ def test_aggregates_text_tool_content(self, normalizer_v13: InstructRequestNorma ], ) - def test_rejects_audio_in_tool_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None: - r"""V13 normalizer raises InvalidRequestException for audio tool content.""" - request = mock_chat_completion( - messages=[ - UserMessage(content="hi"), - AssistantMessage( - tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")], - ), - ToolMessage( - content=[AudioChunk(input_audio=b"fake_audio_data")], - tool_call_id="test12345", - ), - ] - ) - with pytest.raises(InvalidRequestException, match="Unexpected content chunk types in tool message"): - normalizer_v13.from_chat_completion_request(request) - class TestChatCompletionRequestNormalizationV15: @pytest.fixture(autouse=True) @@ -1467,17 +1436,6 @@ def test_sorts_multimodal_tool_messages(self, normalizer_v15: InstructRequestNor settings=ModelSettings(reasoning_effort=ReasoningEffort.high), ) - def test_rejects_think_in_system_message(self, normalizer_v15: InstructRequestNormalizerV15) -> None: - r"""V15 normalizer rejects ThinkChunk in system messages.""" - request = mock_chat_completion( - messages=[ - SystemMessage(content=[TextChunk(text="hello"), ThinkChunk(thinking="thinking", closed=True)]), - UserMessage(content="test"), - ] - ) - with pytest.raises(InvalidRequestException, match="ThinkChunk"): - normalizer_v15.from_chat_completion_request(request) - def test_preserves_audio_in_system_message(self, normalizer_v15: InstructRequestNormalizerV15) -> None: r"""V15 normalizer preserves AudioChunk in system messages.""" request = ChatCompletionRequest[ChatMessage]( From a14f8551a48e960f4a105d1b51c22ab9e86b6886 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Fri, 12 Jun 2026 18:42:59 +0200 Subject: [PATCH 40/47] Polish normalizer docstrings and tighten chunk-type tests per review --- src/mistral_common/protocol/instruct/normalize.py | 14 +++++++++----- tests/test_messages.py | 4 ++++ tests/test_normalization.py | 14 ++++++++++---- 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py index d6785b34..988e25ec 100644 --- a/src/mistral_common/protocol/instruct/normalize.py +++ b/src/mistral_common/protocol/instruct/normalize.py @@ -230,13 +230,15 @@ def _aggregate_system_prompts(self, messages: list[UATS]) -> str | None: def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]: """Normalize tool messages without aggregation across messages. - Each tool message's content is validated and JSON-normalized. + Each tool message's content is JSON-normalized; chunk types are guaranteed by the validator. """ tool_messages: list[ToolMessageType] = [] for message in messages: assert isinstance(message, self._tool_message_class), "Expected tool message" content = self._aggregate_content_chunks([message]) - assert isinstance(content, str), f"Unexpected content chunk types in tool message: {content}" + assert isinstance(content, str), ( + f"Unexpected content chunk types in tool message: {[type(c).__name__ for c in content]}" + ) normalized_content = self._normalize_json_content(content) tool_messages.append( @@ -423,14 +425,16 @@ def normalizer(model_settings_builder: ModelSettingsBuilder | None = None) -> "I def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]: """Normalize tool messages without JSON normalization. - V7+ normalizers skip JSON content normalization for tool messages but still - reject non-text content chunks. + V7+ normalizers skip JSON content normalization for tool messages (chunk-type validation is + handled by the validator). """ tool_messages: list[ToolMessageType] = [] for message in messages: assert isinstance(message, self._tool_message_class), "Expected tool message" content = self._aggregate_content_chunks([message]) - assert isinstance(content, str), f"Unexpected content chunk types in tool message: {content}" + assert isinstance(content, str), ( + f"Unexpected content chunk types in tool message: {[type(c).__name__ for c in content]}" + ) tool_messages.append( self._tool_message_class(content=content, tool_call_id=message.tool_call_id, name=message.name) ) diff --git a/tests/test_messages.py b/tests/test_messages.py index ca6f6486..4b8a3f6c 100644 --- a/tests/test_messages.py +++ b/tests/test_messages.py @@ -10,6 +10,7 @@ from mistral_common.protocol.instruct.messages import ( AssistantMessage, SystemMessage, + ToolMessage, UserMessage, ) @@ -39,3 +40,6 @@ def test_user_rejects_think(self) -> None: def test_user_accepts_image(self) -> None: UserMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")]) + + def test_tool_accepts_arbitrary_chunks(self) -> None: + ToolMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")], tool_call_id="c1") diff --git a/tests/test_normalization.py b/tests/test_normalization.py index 6fa9d19a..2edad20c 100644 --- a/tests/test_normalization.py +++ b/tests/test_normalization.py @@ -461,8 +461,11 @@ def test_passes_think_in_assistant_through(self, normalizer: InstructRequestNorm ] ) parsed = normalizer.from_chat_completion_request(request) - assert parsed.messages[-1] == AssistantMessage( - content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")] + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[ + UserMessage(content="query"), + AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]), + ], ) @@ -671,8 +674,11 @@ def test_passes_think_in_assistant_through(self, normalizer_v7: InstructRequestN ] ) parsed = normalizer_v7.from_chat_completion_request(request) - assert parsed.messages[-1] == AssistantMessage( - content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")] + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[ + UserMessage(content="query"), + AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]), + ], ) def test_accepts_string_content(self, normalizer_v7: InstructRequestNormalizerV7) -> None: From 0385e802c33eec2b3b8339a3cf6e405cb3aaa586 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Fri, 12 Jun 2026 19:25:38 +0200 Subject: [PATCH 41/47] Replace normalizer casts with content-chunk TypeGuards in chunk.py --- src/mistral_common/protocol/instruct/chunk.py | 55 ++++++++++++++++++- .../protocol/instruct/normalize.py | 17 +++--- 2 files changed, 64 insertions(+), 8 deletions(-) diff --git a/src/mistral_common/protocol/instruct/chunk.py b/src/mistral_common/protocol/instruct/chunk.py index 786c38e6..fda7b995 100644 --- a/src/mistral_common/protocol/instruct/chunk.py +++ b/src/mistral_common/protocol/instruct/chunk.py @@ -7,7 +7,7 @@ from urllib.parse import urlparse from pydantic import ConfigDict, Field, ValidationError, field_validator, model_validator -from typing_extensions import Annotated, TypeAlias +from typing_extensions import Annotated, TypeAlias, TypeGuard from mistral_common.base import MistralBase from mistral_common.deprecation import warn_once @@ -462,6 +462,59 @@ def from_openai(cls, openai_chunk: dict[str, Any]) -> "ThinkChunk": ToolContentChunk: TypeAlias = ContentChunk # Accepts all content chunk types (no restriction on tool messages). +def is_user_content(content: str | list[ContentChunk]) -> TypeGuard[str | list[UserContentChunk]]: + r"""Narrow aggregated content to the types allowed in a user message. + + String content is always accepted. List content is accepted only when every chunk is a + `TextChunk`, `ImageChunk`, `ImageURLChunk`, `AudioChunk` or `AudioURLChunk`. + + Args: + content: The content to narrow. + + Returns: + Whether the content only contains user-compatible chunk types. + """ + if isinstance(content, str): + return True + return all( + isinstance(chunk, (TextChunk, ImageChunk, ImageURLChunk, AudioChunk, AudioURLChunk)) for chunk in content + ) + + +def is_assistant_content(content: str | list[ContentChunk]) -> TypeGuard[str | list[AssistantContentChunk]]: + r"""Narrow aggregated content to the types allowed in an assistant message. + + String content is always accepted. List content is accepted only when every chunk is a + `TextChunk` or `ThinkChunk`. + + Args: + content: The content to narrow. + + Returns: + Whether the content only contains assistant-compatible chunk types. + """ + if isinstance(content, str): + return True + return all(isinstance(chunk, (TextChunk, ThinkChunk)) for chunk in content) + + +def is_system_content(content: str | list[ContentChunk]) -> TypeGuard[str | list[SystemContentChunk]]: + r"""Narrow aggregated content to the types allowed in a system message. + + String content is always accepted. List content is accepted only when every chunk is a + `TextChunk`, `AudioChunk` or `ThinkChunk`. + + Args: + content: The content to narrow. + + Returns: + Whether the content only contains system-compatible chunk types. + """ + if isinstance(content, str): + return True + return all(isinstance(chunk, (TextChunk, AudioChunk, ThinkChunk)) for chunk in content) + + def _convert_openai_content_chunks(openai_content_chunks: dict[str, Any]) -> ContentChunk: content_type_str = openai_content_chunks.get("type") diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py index 988e25ec..6ed7b736 100644 --- a/src/mistral_common/protocol/instruct/normalize.py +++ b/src/mistral_common/protocol/instruct/normalize.py @@ -1,16 +1,16 @@ import json import warnings -from typing import Generic, Sequence, cast +from typing import Generic, Sequence from typing_extensions import assert_never from mistral_common.exceptions import InvalidRequestException from mistral_common.protocol.instruct.chunk import ( - AssistantContentChunk, ContentChunk, - SystemContentChunk, TextChunk, - UserContentChunk, + is_assistant_content, + is_system_content, + is_user_content, ) from mistral_common.protocol.instruct.messages import ( UATS, @@ -286,8 +286,9 @@ def _aggregate_assistant_messages(self, messages: list[UATS]) -> AssistantMessag ) weight = message.weight + assert is_assistant_content(content) aggregated_message = self._assistant_message_class( - content=cast("str | list[AssistantContentChunk]", content), + content=content, tool_calls=tool_calls or None, prefix=prefix, ) @@ -299,7 +300,8 @@ def _aggregate_assistant_messages(self, messages: list[UATS]) -> AssistantMessag def _aggregate_user_messages(self, messages: list[UATS]) -> UserMessageType: """Coalesce neighboring blocks of ContentChunks in user messages.""" content = self._aggregate_content_chunks(messages) - return self._user_message_class(content=cast("str | list[UserContentChunk]", content)) + assert is_user_content(content) + return self._user_message_class(content=content) def _aggregate_role(self, messages: list[UATS], role: Roles | None, latest_call_ids: list[str]) -> Sequence[UATS]: if role == Roles.tool: @@ -445,7 +447,8 @@ def _aggregate_system_messages(self, messages: list[UATS]) -> list[SystemMessage for message in messages: if isinstance(message, self._system_message_class): content = self._aggregate_content_chunks([message]) - aggregated.append(self._system_message_class(content=cast("str | list[SystemContentChunk]", content))) + assert is_system_content(content) + aggregated.append(self._system_message_class(content=content)) return aggregated def _aggregate_role(self, messages: list[UATS], role: Roles | None, latest_call_ids: list[str]) -> Sequence[UATS]: From c36ddbffe254d9434fb1ee117e5ff87ec6f26580 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Fri, 12 Jun 2026 19:25:45 +0200 Subject: [PATCH 42/47] Reorganize chunk-type tests into per-version classes; drop normalizer validator-concern tests --- tests/test_normalization.py | 56 ----- tests/validation/test_chat_validation.py | 271 +++++++++++------------ 2 files changed, 134 insertions(+), 193 deletions(-) diff --git a/tests/test_normalization.py b/tests/test_normalization.py index 2edad20c..70145f16 100644 --- a/tests/test_normalization.py +++ b/tests/test_normalization.py @@ -406,30 +406,6 @@ def test_continue_final_message_forwarded(self, normalizer: InstructRequestNorma continue_final_message=True, ) - def test_rejects_audio_in_system_message(self, normalizer: InstructRequestNormalizer) -> None: - r"""Pre-V7 normalizer rejects AudioChunk in system messages.""" - request = mock_chat_completion( - messages=[ - SystemMessage(content=[TextChunk(text="hello"), AudioChunk(input_audio=b"fake_audio_data")]), - UserMessage(content="query"), - AssistantMessage(content="answer"), - ] - ) - with pytest.raises(AssertionError, match="AudioChunk"): - normalizer.from_chat_completion_request(request) - - def test_rejects_think_in_system_message(self, normalizer: InstructRequestNormalizer) -> None: - r"""Pre-V7 normalizer rejects ThinkChunk in system messages.""" - request = mock_chat_completion( - messages=[ - SystemMessage(content=[TextChunk(text="hello"), ThinkChunk(thinking="thinking", closed=True)]), - UserMessage(content="query"), - AssistantMessage(content="answer"), - ] - ) - with pytest.raises(AssertionError, match="ThinkChunk"): - normalizer.from_chat_completion_request(request) - def test_json_normalizes_tool_content(self, normalizer: InstructRequestNormalizer) -> None: r"""Base normalizer (v1-v3) JSON-normalizes tool message content.""" messy_json = '{"key" : "value" , "num": 1}' @@ -452,22 +428,6 @@ def test_json_normalizes_tool_content(self, normalizer: InstructRequestNormalize ], ) - def test_passes_think_in_assistant_through(self, normalizer: InstructRequestNormalizer) -> None: - r"""Normalizer no longer gates version rules; the validator rejects pre-v11 think.""" - request = mock_chat_completion( - messages=[ - UserMessage(content="query"), - AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]), - ] - ) - parsed = normalizer.from_chat_completion_request(request) - assert parsed == InstructRequest[ChatMessage, Tool]( - messages=[ - UserMessage(content="query"), - AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]), - ], - ) - class TestChatCompletionRequestNormalizationV7: @pytest.fixture(autouse=True) @@ -665,22 +625,6 @@ def test_complex_assistant_aggregation(self, normalizer_v7: InstructRequestNorma messages=[AssistantMessage(content="A\n\nB\n\nC\n\nD")], ) - def test_passes_think_in_assistant_through(self, normalizer_v7: InstructRequestNormalizerV7) -> None: - r"""Normalizer no longer gates version rules; the validator rejects pre-v11 think.""" - request = mock_chat_completion( - messages=[ - UserMessage(content="query"), - AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]), - ] - ) - parsed = normalizer_v7.from_chat_completion_request(request) - assert parsed == InstructRequest[ChatMessage, Tool]( - messages=[ - UserMessage(content="query"), - AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]), - ], - ) - def test_accepts_string_content(self, normalizer_v7: InstructRequestNormalizerV7) -> None: r"""V7 normalizer accepts string content in assistant messages.""" request = mock_chat_completion( diff --git a/tests/validation/test_chat_validation.py b/tests/validation/test_chat_validation.py index 74a5a01a..2f1913ca 100644 --- a/tests/validation/test_chat_validation.py +++ b/tests/validation/test_chat_validation.py @@ -58,6 +58,16 @@ def validator(request: pytest.FixtureRequest) -> MistralRequestValidator: return request.param # type: ignore +@pytest.fixture +def validator_base() -> MistralRequestValidator: + return MistralRequestValidator(ValidationMode.serving) + + +@pytest.fixture +def validator_v3() -> MistralRequestValidatorV3: + return MistralRequestValidatorV3(ValidationMode.serving) + + @pytest.fixture def validator_v5() -> MistralRequestValidatorV5: return MistralRequestValidatorV5(ValidationMode.serving) @@ -362,6 +372,82 @@ def test_build_settings_raises_error(self, validator: MistralRequestValidator) - with pytest.raises(InvalidRequestException, match="reasoning_effort='none' is not supported for this model"): validator._validate_model_settings(request) + def test_rejects_think_in_assistant_content_chunks(self, validator_base: MistralRequestValidator) -> None: + messages = [ + UserMessage(content="hi"), + AssistantMessage(content=[ThinkChunk(thinking="r"), TextChunk(text="a")]), + UserMessage(content="hi again"), + ] + with pytest.raises( + InvalidAssistantMessageException, match="Unexpected content chunk types in assistant message" + ): + validator_base.validate_messages(messages, continue_final_message=False) + + def test_rejects_audio_in_system_content_chunks(self, validator_base: MistralRequestValidator) -> None: + messages = [ + SystemMessage(content=[TextChunk(text="x"), AudioChunk(input_audio=b"fake")]), + UserMessage(content="hi"), + ] + with pytest.raises(InvalidSystemPromptException, match="Unexpected content chunk types in system message"): + validator_base.validate_messages(messages, continue_final_message=False) + + def test_rejects_image_in_tool_content_chunks(self, validator_base: MistralRequestValidator) -> None: + messages = [ + UserMessage(content="hi"), + AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")]), + ToolMessage( + content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")], + tool_call_id="test12345", + ), + ] + with pytest.raises(InvalidToolMessageException, match="Unexpected content chunk types in tool message"): + validator_base.validate_messages(messages, continue_final_message=False) + + def test_rejects_image_in_user_content_chunks(self, validator_base: MistralRequestValidator) -> None: + messages = [UserMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")])] + with pytest.raises(InvalidUserMessageException, match="Unexpected content chunk types in user message"): + validator_base.validate_messages(messages, continue_final_message=False) + + def test_rejects_audio_in_user_content_chunks(self, validator_base: MistralRequestValidator) -> None: + messages = [UserMessage(content=[AudioChunk(input_audio=b"fake")])] + with pytest.raises(InvalidUserMessageException, match="Unexpected content chunk types in user message"): + validator_base.validate_messages(messages, continue_final_message=False) + + def test_reports_sorted_unique_invalid_chunk_types(self, validator_base: MistralRequestValidator) -> None: + messages = [ + UserMessage( + content=[ + AudioChunk(input_audio=b"fake"), + ImageURLChunk(image_url="data:image/png;base64,iVBORw0"), + ] + ) + ] + with pytest.raises( + InvalidUserMessageException, + match=r"Unexpected content chunk types in user message: \['AudioChunk', 'ImageURLChunk'\]", + ): + validator_base.validate_messages(messages, continue_final_message=False) + + def test_accepts_multiple_text_content_chunks(self, validator_base: MistralRequestValidator) -> None: + messages = [ + SystemMessage(content=[TextChunk(text="a"), TextChunk(text="b")]), + UserMessage(content="hi"), + AssistantMessage(content=[TextChunk(text="c"), TextChunk(text="d")]), + UserMessage(content="hi again"), + ] + validator_base.validate_messages(messages, continue_final_message=False) + + +class TestChatValidationV3: + def test_allows_image_in_user_content_chunks(self, validator_v3: MistralRequestValidatorV3) -> None: + messages = [UserMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")])] + validator_v3.validate_messages(messages, continue_final_message=False) + + def test_rejects_audio_in_user_content_chunks(self, validator_v3: MistralRequestValidatorV3) -> None: + messages = [UserMessage(content=[AudioChunk(input_audio=b"fake")])] + with pytest.raises(InvalidUserMessageException, match="Unexpected content chunk types in user message"): + validator_v3.validate_messages(messages, continue_final_message=False) + class TestChatValidationV5: @pytest.mark.parametrize("audio_fixture", ["audio_chunk", "audio_url_chunk"]) @@ -412,6 +498,35 @@ def test_build_settings_raises_error(self, validator: MistralRequestValidator) - with pytest.raises(InvalidRequestException, match="reasoning_effort='none' is not supported for this model"): validator._validate_model_settings(request) + def test_allows_image_and_audio_in_user_content_chunks(self, validator_v5: MistralRequestValidatorV5) -> None: + messages = [ + UserMessage( + content=[ + ImageURLChunk(image_url="data:image/png;base64,iVBORw0"), + AudioChunk(input_audio=b"fake"), + ] + ) + ] + validator_v5.validate_messages(messages, continue_final_message=False) + + def test_allows_audio_and_think_in_system_content_chunks(self, validator_v5: MistralRequestValidatorV5) -> None: + messages = [ + SystemMessage(content=[TextChunk(text="x"), AudioChunk(input_audio=b"fake"), ThinkChunk(thinking="t")]), + UserMessage(content="hi"), + ] + validator_v5.validate_messages(messages, continue_final_message=False) + + def test_rejects_think_in_assistant_content_chunks(self, validator_v5: MistralRequestValidatorV5) -> None: + messages = [ + UserMessage(content="hi"), + AssistantMessage(content=[ThinkChunk(thinking="r"), TextChunk(text="a")]), + UserMessage(content="next"), + ] + with pytest.raises( + InvalidAssistantMessageException, match="Unexpected content chunk types in assistant message" + ): + validator_v5.validate_messages(messages, continue_final_message=False) + class TestChatValidationV13: def test_right_number_results_invalid_id(self, validator_v13: MistralRequestValidatorV13) -> None: @@ -622,169 +737,51 @@ def test_audio_with_system_prompt_raises_ok( continue_final_message=False, ) - -class TestChatValidationV15: - @pytest.mark.parametrize("reasoning_effort", [*list(ReasoningEffort), None]) - def test_build_settings_v15_reasoning_effort( - self, reasoning_effort: ReasoningEffort | None, validator_v15: MistralRequestValidatorV15 - ) -> None: - request = ChatCompletionRequest(messages=[UserMessage(content="Hello")], reasoning_effort=reasoning_effort) - validator_v15._validate_model_settings(request) - - -class TestBaseValidatorContentChunks: - @pytest.fixture - def base_validator(self) -> MistralRequestValidator: - return MistralRequestValidator(ValidationMode.serving) - - def test_rejects_think_in_assistant(self, base_validator: MistralRequestValidator) -> None: - messages = [ - UserMessage(content="hi"), - AssistantMessage(content=[ThinkChunk(thinking="r"), TextChunk(text="a")]), - UserMessage(content="hi again"), - ] - with pytest.raises( - InvalidAssistantMessageException, match="Unexpected content chunk types in assistant message" - ): - base_validator.validate_messages(messages, continue_final_message=False) - - def test_rejects_audio_in_system(self, base_validator: MistralRequestValidator) -> None: - messages = [ - SystemMessage(content=[TextChunk(text="x"), AudioChunk(input_audio=b"fake")]), - UserMessage(content="hi"), - ] - with pytest.raises(InvalidSystemPromptException, match="Unexpected content chunk types in system message"): - base_validator.validate_messages(messages, continue_final_message=False) - - def test_rejects_image_in_tool(self, base_validator: MistralRequestValidator) -> None: - messages = [ - UserMessage(content="hi"), - AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")]), - ToolMessage( - content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")], - tool_call_id="test12345", - ), - ] - with pytest.raises(InvalidToolMessageException, match="Unexpected content chunk types in tool message"): - base_validator.validate_messages(messages, continue_final_message=False) - - def test_rejects_image_in_user(self, base_validator: MistralRequestValidator) -> None: - messages = [UserMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")])] - with pytest.raises(InvalidUserMessageException, match="Unexpected content chunk types in user message"): - base_validator.validate_messages(messages, continue_final_message=False) - - def test_rejects_audio_in_user(self, base_validator: MistralRequestValidator) -> None: - messages = [UserMessage(content=[AudioChunk(input_audio=b"fake")])] - with pytest.raises(InvalidUserMessageException, match="Unexpected content chunk types in user message"): - base_validator.validate_messages(messages, continue_final_message=False) - - def test_reports_sorted_unique_invalid_types(self, base_validator: MistralRequestValidator) -> None: - messages = [ - UserMessage( - content=[ - AudioChunk(input_audio=b"fake"), - ImageURLChunk(image_url="data:image/png;base64,iVBORw0"), - ] - ) - ] - with pytest.raises( - InvalidUserMessageException, - match=r"Unexpected content chunk types in user message: \['AudioChunk', 'ImageURLChunk'\]", - ): - base_validator.validate_messages(messages, continue_final_message=False) - - def test_accepts_multiple_text_chunks(self, base_validator: MistralRequestValidator) -> None: - messages = [ - SystemMessage(content=[TextChunk(text="a"), TextChunk(text="b")]), - UserMessage(content="hi"), - AssistantMessage(content=[TextChunk(text="c"), TextChunk(text="d")]), - UserMessage(content="hi again"), - ] - base_validator.validate_messages(messages, continue_final_message=False) - - -class TestVersionedValidatorContentChunks: - def test_v3_allows_image_rejects_audio_in_user(self) -> None: - validator = MistralRequestValidatorV3(ValidationMode.serving) - ok = [UserMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")])] - validator.validate_messages(ok, continue_final_message=False) - bad = [UserMessage(content=[AudioChunk(input_audio=b"fake")])] - with pytest.raises(InvalidUserMessageException, match="Unexpected content chunk types in user message"): - validator.validate_messages(bad, continue_final_message=False) - - def test_v5_allows_image_and_audio_in_user(self) -> None: - validator = MistralRequestValidatorV5(ValidationMode.serving) - messages = [ - UserMessage( - content=[ - ImageURLChunk(image_url="data:image/png;base64,iVBORw0"), - AudioChunk(input_audio=b"fake"), - ] - ) - ] - validator.validate_messages(messages, continue_final_message=False) - - def test_v5_allows_audio_and_think_in_system(self) -> None: - validator = MistralRequestValidatorV5(ValidationMode.serving) - messages = [ - SystemMessage(content=[TextChunk(text="x"), AudioChunk(input_audio=b"fake"), ThinkChunk(thinking="t")]), - UserMessage(content="hi"), - ] - validator.validate_messages(messages, continue_final_message=False) - - def test_v5_still_rejects_think_in_assistant(self) -> None: - validator = MistralRequestValidatorV5(ValidationMode.serving) - messages = [ - UserMessage(content="hi"), - AssistantMessage(content=[ThinkChunk(thinking="r"), TextChunk(text="a")]), - UserMessage(content="next"), - ] - with pytest.raises( - InvalidAssistantMessageException, match="Unexpected content chunk types in assistant message" - ): - validator.validate_messages(messages, continue_final_message=False) - - def test_v13_allows_think_in_assistant(self) -> None: - validator = MistralRequestValidatorV13(ValidationMode.serving) + def test_allows_think_in_assistant_content_chunks(self, validator_v13: MistralRequestValidatorV13) -> None: messages = [ UserMessage(content="hi"), AssistantMessage(content=[ThinkChunk(thinking="r"), TextChunk(text="a")]), UserMessage(content="next"), ] - validator.validate_messages(messages, continue_final_message=False) + validator_v13.validate_messages(messages, continue_final_message=False) - def test_v13_still_rejects_image_in_tool(self) -> None: - validator = MistralRequestValidatorV13(ValidationMode.serving) + def test_rejects_image_in_tool_content_chunks(self, validator_v13: MistralRequestValidatorV13) -> None: messages = [ UserMessage(content="hi"), AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")]), ToolMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")], tool_call_id="test12345"), ] with pytest.raises(InvalidToolMessageException, match="Unexpected content chunk types in tool message"): - validator.validate_messages(messages, continue_final_message=False) + validator_v13.validate_messages(messages, continue_final_message=False) + + +class TestChatValidationV15: + @pytest.mark.parametrize("reasoning_effort", [*list(ReasoningEffort), None]) + def test_build_settings_v15_reasoning_effort( + self, reasoning_effort: ReasoningEffort | None, validator_v15: MistralRequestValidatorV15 + ) -> None: + request = ChatCompletionRequest(messages=[UserMessage(content="Hello")], reasoning_effort=reasoning_effort) + validator_v15._validate_model_settings(request) - def test_v15_rejects_think_in_system(self) -> None: - validator = MistralRequestValidatorV15(ValidationMode.serving) + def test_rejects_think_in_system_content_chunks(self, validator_v15: MistralRequestValidatorV15) -> None: messages = [ SystemMessage(content=[TextChunk(text="x"), ThinkChunk(thinking="t")]), UserMessage(content="hi"), ] with pytest.raises(InvalidSystemPromptException, match="Unexpected content chunk types in system message"): - validator.validate_messages(messages, continue_final_message=False) + validator_v15.validate_messages(messages, continue_final_message=False) - def test_v15_allows_audio_in_system(self) -> None: - validator = MistralRequestValidatorV15(ValidationMode.serving) + def test_allows_audio_in_system_content_chunks(self, validator_v15: MistralRequestValidatorV15) -> None: messages = [ SystemMessage(content=[TextChunk(text="x"), AudioChunk(input_audio=b"fake")]), UserMessage(content="hi"), ] - validator.validate_messages(messages, continue_final_message=False) + validator_v15.validate_messages(messages, continue_final_message=False) - def test_v15_allows_image_in_tool(self) -> None: - validator = MistralRequestValidatorV15(ValidationMode.serving) + def test_allows_image_in_tool_content_chunks(self, validator_v15: MistralRequestValidatorV15) -> None: messages = [ UserMessage(content="hi"), AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")]), ToolMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")], tool_call_id="test12345"), ] - validator.validate_messages(messages, continue_final_message=False) + validator_v15.validate_messages(messages, continue_final_message=False) From 216e85dacc518e8073e9ce695027578d10202e7d Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Fri, 12 Jun 2026 19:57:06 +0200 Subject: [PATCH 43/47] Add informative messages to normalizer content TypeGuard asserts --- src/mistral_common/protocol/instruct/normalize.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py index 6ed7b736..96f1dddf 100644 --- a/src/mistral_common/protocol/instruct/normalize.py +++ b/src/mistral_common/protocol/instruct/normalize.py @@ -286,7 +286,9 @@ def _aggregate_assistant_messages(self, messages: list[UATS]) -> AssistantMessag ) weight = message.weight - assert is_assistant_content(content) + assert is_assistant_content(content), ( + f"Unexpected content chunk types in assistant message: {[type(c).__name__ for c in content]}" + ) aggregated_message = self._assistant_message_class( content=content, tool_calls=tool_calls or None, @@ -300,7 +302,9 @@ def _aggregate_assistant_messages(self, messages: list[UATS]) -> AssistantMessag def _aggregate_user_messages(self, messages: list[UATS]) -> UserMessageType: """Coalesce neighboring blocks of ContentChunks in user messages.""" content = self._aggregate_content_chunks(messages) - assert is_user_content(content) + assert is_user_content(content), ( + f"Unexpected content chunk types in user message: {[type(c).__name__ for c in content]}" + ) return self._user_message_class(content=content) def _aggregate_role(self, messages: list[UATS], role: Roles | None, latest_call_ids: list[str]) -> Sequence[UATS]: @@ -447,7 +451,9 @@ def _aggregate_system_messages(self, messages: list[UATS]) -> list[SystemMessage for message in messages: if isinstance(message, self._system_message_class): content = self._aggregate_content_chunks([message]) - assert is_system_content(content) + assert is_system_content(content), ( + f"Unexpected content chunk types in system message: {[type(c).__name__ for c in content]}" + ) aggregated.append(self._system_message_class(content=content)) return aggregated From 5bfde189cbe47bc9fad9c481f58b2ae02d67c027 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Fri, 12 Jun 2026 19:57:07 +0200 Subject: [PATCH 44/47] Restructure chunk-type tests into exhaustive allow/disallow methods --- tests/fixtures/chunks.py | 35 ++++ tests/test_messages.py | 59 +++--- tests/validation/test_chat_validation.py | 234 ++++++++++------------- 3 files changed, 165 insertions(+), 163 deletions(-) create mode 100644 tests/fixtures/chunks.py diff --git a/tests/fixtures/chunks.py b/tests/fixtures/chunks.py new file mode 100644 index 00000000..72b3f85d --- /dev/null +++ b/tests/fixtures/chunks.py @@ -0,0 +1,35 @@ +from PIL import Image + +from mistral_common.protocol.instruct.chunk import ( + ContentChunk, + ImageChunk, + ImageURLChunk, + TextChunk, + ThinkChunk, +) +from tests.fixtures.audio import get_dummy_audio_chunk, get_dummy_audio_url_chunk + + +def get_content_chunk(name: str) -> ContentChunk: + r"""Return a single instance of the requested content chunk type. + + Args: + name: One of "text", "image", "image_url", "audio", "audio_url" or "think". + + Returns: + A content chunk of the requested type. + """ + chunks: dict[str, ContentChunk] = { + "text": TextChunk(text="hello"), + "image": ImageChunk(image=Image.new("RGB", (4, 4), "red")), + "image_url": ImageURLChunk(image_url="data:image/png;base64,iVBORw0"), + "audio": get_dummy_audio_chunk(), + "audio_url": get_dummy_audio_url_chunk(), + "think": ThinkChunk(thinking="reasoning"), + } + return chunks[name] + + +def get_content_chunks(names: tuple[str, ...]) -> list[ContentChunk]: + r"""Return a list of content chunks for the requested type names.""" + return [get_content_chunk(name) for name in names] diff --git a/tests/test_messages.py b/tests/test_messages.py index 4b8a3f6c..44aba528 100644 --- a/tests/test_messages.py +++ b/tests/test_messages.py @@ -1,45 +1,44 @@ import pytest from pydantic import ValidationError -from mistral_common.protocol.instruct.chunk import ( - AudioChunk, - ImageURLChunk, - TextChunk, - ThinkChunk, -) from mistral_common.protocol.instruct.messages import ( AssistantMessage, SystemMessage, ToolMessage, UserMessage, ) +from tests.fixtures.chunks import get_content_chunks as _chunks class TestMessageContentChunkUnions: - def test_assistant_rejects_image(self) -> None: - with pytest.raises(ValidationError): - AssistantMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")]) - - def test_assistant_rejects_audio(self) -> None: - with pytest.raises(ValidationError): - AssistantMessage(content=[AudioChunk(input_audio=b"fake")]) - - def test_assistant_accepts_text_and_think(self) -> None: - AssistantMessage(content=[ThinkChunk(thinking="r"), TextChunk(text="a")]) + r"""Pydantic-level (version-independent) content-chunk unions for each message role.""" - def test_system_rejects_image(self) -> None: - with pytest.raises(ValidationError): - SystemMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")]) - - def test_system_accepts_audio(self) -> None: - SystemMessage(content=[TextChunk(text="x"), AudioChunk(input_audio=b"fake")]) + def test_user_allows_text_image_audio(self) -> None: + UserMessage(content=_chunks(("text", "image", "image_url", "audio", "audio_url"))) def test_user_rejects_think(self) -> None: - with pytest.raises(ValidationError): - UserMessage(content=[ThinkChunk(thinking="r")]) - - def test_user_accepts_image(self) -> None: - UserMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")]) - - def test_tool_accepts_arbitrary_chunks(self) -> None: - ToolMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")], tool_call_id="c1") + for name in ("think",): + with pytest.raises(ValidationError): + UserMessage(content=_chunks((name,))) + + def test_assistant_allows_text_and_think(self) -> None: + AssistantMessage(content=_chunks(("text", "think"))) + + def test_assistant_rejects_image_and_audio(self) -> None: + for name in ("image", "image_url", "audio", "audio_url"): + with pytest.raises(ValidationError): + AssistantMessage(content=_chunks((name,))) + + def test_system_allows_text_audio_think(self) -> None: + SystemMessage(content=_chunks(("text", "audio", "think"))) + + def test_system_rejects_image_and_audio_url(self) -> None: + for name in ("image", "image_url", "audio_url"): + with pytest.raises(ValidationError): + SystemMessage(content=_chunks((name,))) + + def test_tool_allows_all_chunk_types(self) -> None: + ToolMessage( + content=_chunks(("text", "image", "image_url", "audio", "audio_url", "think")), + tool_call_id="c1", + ) diff --git a/tests/validation/test_chat_validation.py b/tests/validation/test_chat_validation.py index 2f1913ca..90212bf8 100644 --- a/tests/validation/test_chat_validation.py +++ b/tests/validation/test_chat_validation.py @@ -11,9 +11,7 @@ from mistral_common.protocol.instruct.chunk import ( AudioChunk, AudioURLChunk, - ImageURLChunk, - TextChunk, - ThinkChunk, + ContentChunk, ) from mistral_common.protocol.instruct.messages import ( AssistantMessage, @@ -35,6 +33,29 @@ ValidationMode, ) from tests.fixtures.audio import get_dummy_audio_chunk, get_dummy_audio_url_chunk +from tests.fixtures.chunks import get_content_chunks + +_Messages = list[UserMessage | AssistantMessage | SystemMessage | ToolMessage] + + +def _user_convo(content: "str | list[ContentChunk]") -> _Messages: + return [UserMessage(content=content)] + + +def _assistant_convo(content: "str | list[ContentChunk]") -> _Messages: + return [UserMessage(content="hi"), AssistantMessage(content=content), UserMessage(content="next")] + + +def _system_convo(content: "str | list[ContentChunk]") -> _Messages: + return [SystemMessage(content=content), UserMessage(content="hi")] + + +def _tool_convo(content: "str | list[ContentChunk]") -> _Messages: + return [ + UserMessage(content="hi"), + AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")]), + ToolMessage(content=content, tool_call_id="test12345"), + ] @pytest.fixture(scope="module") @@ -372,81 +393,56 @@ def test_build_settings_raises_error(self, validator: MistralRequestValidator) - with pytest.raises(InvalidRequestException, match="reasoning_effort='none' is not supported for this model"): validator._validate_model_settings(request) - def test_rejects_think_in_assistant_content_chunks(self, validator_base: MistralRequestValidator) -> None: - messages = [ - UserMessage(content="hi"), - AssistantMessage(content=[ThinkChunk(thinking="r"), TextChunk(text="a")]), - UserMessage(content="hi again"), - ] - with pytest.raises( - InvalidAssistantMessageException, match="Unexpected content chunk types in assistant message" - ): - validator_base.validate_messages(messages, continue_final_message=False) - - def test_rejects_audio_in_system_content_chunks(self, validator_base: MistralRequestValidator) -> None: - messages = [ - SystemMessage(content=[TextChunk(text="x"), AudioChunk(input_audio=b"fake")]), - UserMessage(content="hi"), - ] - with pytest.raises(InvalidSystemPromptException, match="Unexpected content chunk types in system message"): - validator_base.validate_messages(messages, continue_final_message=False) - - def test_rejects_image_in_tool_content_chunks(self, validator_base: MistralRequestValidator) -> None: - messages = [ - UserMessage(content="hi"), - AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")]), - ToolMessage( - content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")], - tool_call_id="test12345", - ), - ] - with pytest.raises(InvalidToolMessageException, match="Unexpected content chunk types in tool message"): - validator_base.validate_messages(messages, continue_final_message=False) - - def test_rejects_image_in_user_content_chunks(self, validator_base: MistralRequestValidator) -> None: - messages = [UserMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")])] - with pytest.raises(InvalidUserMessageException, match="Unexpected content chunk types in user message"): - validator_base.validate_messages(messages, continue_final_message=False) - - def test_rejects_audio_in_user_content_chunks(self, validator_base: MistralRequestValidator) -> None: - messages = [UserMessage(content=[AudioChunk(input_audio=b"fake")])] - with pytest.raises(InvalidUserMessageException, match="Unexpected content chunk types in user message"): - validator_base.validate_messages(messages, continue_final_message=False) + def test_allows_text_content_chunks(self, validator_base: MistralRequestValidator) -> None: + validator_base.validate_messages(_user_convo(get_content_chunks(("text",))), continue_final_message=False) + validator_base.validate_messages(_assistant_convo(get_content_chunks(("text",))), continue_final_message=False) + validator_base.validate_messages(_system_convo(get_content_chunks(("text",))), continue_final_message=False) + validator_base.validate_messages(_tool_convo(get_content_chunks(("text",))), continue_final_message=False) + + def test_rejects_non_text_in_user(self, validator_base: MistralRequestValidator) -> None: + for name in ("image", "image_url", "audio", "audio_url"): + with pytest.raises(InvalidUserMessageException, match="Unexpected content chunk types in user message"): + validator_base.validate_messages(_user_convo(get_content_chunks((name,))), continue_final_message=False) + + def test_rejects_non_text_in_assistant(self, validator_base: MistralRequestValidator) -> None: + for name in ("think",): + with pytest.raises( + InvalidAssistantMessageException, match="Unexpected content chunk types in assistant message" + ): + validator_base.validate_messages( + _assistant_convo(get_content_chunks((name,))), continue_final_message=False + ) + + def test_rejects_non_text_in_system(self, validator_base: MistralRequestValidator) -> None: + for name in ("audio", "think"): + with pytest.raises(InvalidSystemPromptException, match="Unexpected content chunk types in system message"): + validator_base.validate_messages( + _system_convo(get_content_chunks((name,))), continue_final_message=False + ) + + def test_rejects_non_text_in_tool(self, validator_base: MistralRequestValidator) -> None: + for name in ("image", "image_url", "audio", "audio_url", "think"): + with pytest.raises(InvalidToolMessageException, match="Unexpected content chunk types in tool message"): + validator_base.validate_messages(_tool_convo(get_content_chunks((name,))), continue_final_message=False) def test_reports_sorted_unique_invalid_chunk_types(self, validator_base: MistralRequestValidator) -> None: - messages = [ - UserMessage( - content=[ - AudioChunk(input_audio=b"fake"), - ImageURLChunk(image_url="data:image/png;base64,iVBORw0"), - ] - ) - ] + content = get_content_chunks(("audio", "image_url")) with pytest.raises( InvalidUserMessageException, match=r"Unexpected content chunk types in user message: \['AudioChunk', 'ImageURLChunk'\]", ): - validator_base.validate_messages(messages, continue_final_message=False) - - def test_accepts_multiple_text_content_chunks(self, validator_base: MistralRequestValidator) -> None: - messages = [ - SystemMessage(content=[TextChunk(text="a"), TextChunk(text="b")]), - UserMessage(content="hi"), - AssistantMessage(content=[TextChunk(text="c"), TextChunk(text="d")]), - UserMessage(content="hi again"), - ] - validator_base.validate_messages(messages, continue_final_message=False) + validator_base.validate_messages(_user_convo(content), continue_final_message=False) class TestChatValidationV3: - def test_allows_image_in_user_content_chunks(self, validator_v3: MistralRequestValidatorV3) -> None: - messages = [UserMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")])] - validator_v3.validate_messages(messages, continue_final_message=False) + def test_allows_text_and_image_in_user(self, validator_v3: MistralRequestValidatorV3) -> None: + content = get_content_chunks(("text", "image", "image_url")) + validator_v3.validate_messages(_user_convo(content), continue_final_message=False) - def test_rejects_audio_in_user_content_chunks(self, validator_v3: MistralRequestValidatorV3) -> None: - messages = [UserMessage(content=[AudioChunk(input_audio=b"fake")])] - with pytest.raises(InvalidUserMessageException, match="Unexpected content chunk types in user message"): - validator_v3.validate_messages(messages, continue_final_message=False) + def test_rejects_audio_in_user(self, validator_v3: MistralRequestValidatorV3) -> None: + for name in ("audio", "audio_url"): + with pytest.raises(InvalidUserMessageException, match="Unexpected content chunk types in user message"): + validator_v3.validate_messages(_user_convo(get_content_chunks((name,))), continue_final_message=False) class TestChatValidationV5: @@ -498,34 +494,22 @@ def test_build_settings_raises_error(self, validator: MistralRequestValidator) - with pytest.raises(InvalidRequestException, match="reasoning_effort='none' is not supported for this model"): validator._validate_model_settings(request) - def test_allows_image_and_audio_in_user_content_chunks(self, validator_v5: MistralRequestValidatorV5) -> None: - messages = [ - UserMessage( - content=[ - ImageURLChunk(image_url="data:image/png;base64,iVBORw0"), - AudioChunk(input_audio=b"fake"), - ] - ) - ] - validator_v5.validate_messages(messages, continue_final_message=False) - - def test_allows_audio_and_think_in_system_content_chunks(self, validator_v5: MistralRequestValidatorV5) -> None: - messages = [ - SystemMessage(content=[TextChunk(text="x"), AudioChunk(input_audio=b"fake"), ThinkChunk(thinking="t")]), - UserMessage(content="hi"), - ] - validator_v5.validate_messages(messages, continue_final_message=False) - - def test_rejects_think_in_assistant_content_chunks(self, validator_v5: MistralRequestValidatorV5) -> None: - messages = [ - UserMessage(content="hi"), - AssistantMessage(content=[ThinkChunk(thinking="r"), TextChunk(text="a")]), - UserMessage(content="next"), - ] - with pytest.raises( - InvalidAssistantMessageException, match="Unexpected content chunk types in assistant message" - ): - validator_v5.validate_messages(messages, continue_final_message=False) + def test_allows_text_image_audio_in_user(self, validator_v5: MistralRequestValidatorV5) -> None: + content = get_content_chunks(("text", "image", "image_url", "audio", "audio_url")) + validator_v5.validate_messages(_user_convo(content), continue_final_message=False) + + def test_allows_text_audio_think_in_system(self, validator_v5: MistralRequestValidatorV5) -> None: + content = get_content_chunks(("text", "audio", "think")) + validator_v5.validate_messages(_system_convo(content), continue_final_message=False) + + def test_rejects_think_in_assistant(self, validator_v5: MistralRequestValidatorV5) -> None: + for name in ("think",): + with pytest.raises( + InvalidAssistantMessageException, match="Unexpected content chunk types in assistant message" + ): + validator_v5.validate_messages( + _assistant_convo(get_content_chunks((name,))), continue_final_message=False + ) class TestChatValidationV13: @@ -737,22 +721,14 @@ def test_audio_with_system_prompt_raises_ok( continue_final_message=False, ) - def test_allows_think_in_assistant_content_chunks(self, validator_v13: MistralRequestValidatorV13) -> None: - messages = [ - UserMessage(content="hi"), - AssistantMessage(content=[ThinkChunk(thinking="r"), TextChunk(text="a")]), - UserMessage(content="next"), - ] - validator_v13.validate_messages(messages, continue_final_message=False) + def test_allows_text_and_think_in_assistant(self, validator_v13: MistralRequestValidatorV13) -> None: + content = get_content_chunks(("text", "think")) + validator_v13.validate_messages(_assistant_convo(content), continue_final_message=False) - def test_rejects_image_in_tool_content_chunks(self, validator_v13: MistralRequestValidatorV13) -> None: - messages = [ - UserMessage(content="hi"), - AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")]), - ToolMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")], tool_call_id="test12345"), - ] - with pytest.raises(InvalidToolMessageException, match="Unexpected content chunk types in tool message"): - validator_v13.validate_messages(messages, continue_final_message=False) + def test_rejects_non_text_in_tool(self, validator_v13: MistralRequestValidatorV13) -> None: + for name in ("image", "image_url", "audio", "audio_url", "think"): + with pytest.raises(InvalidToolMessageException, match="Unexpected content chunk types in tool message"): + validator_v13.validate_messages(_tool_convo(get_content_chunks((name,))), continue_final_message=False) class TestChatValidationV15: @@ -763,25 +739,17 @@ def test_build_settings_v15_reasoning_effort( request = ChatCompletionRequest(messages=[UserMessage(content="Hello")], reasoning_effort=reasoning_effort) validator_v15._validate_model_settings(request) - def test_rejects_think_in_system_content_chunks(self, validator_v15: MistralRequestValidatorV15) -> None: - messages = [ - SystemMessage(content=[TextChunk(text="x"), ThinkChunk(thinking="t")]), - UserMessage(content="hi"), - ] - with pytest.raises(InvalidSystemPromptException, match="Unexpected content chunk types in system message"): - validator_v15.validate_messages(messages, continue_final_message=False) - - def test_allows_audio_in_system_content_chunks(self, validator_v15: MistralRequestValidatorV15) -> None: - messages = [ - SystemMessage(content=[TextChunk(text="x"), AudioChunk(input_audio=b"fake")]), - UserMessage(content="hi"), - ] - validator_v15.validate_messages(messages, continue_final_message=False) - - def test_allows_image_in_tool_content_chunks(self, validator_v15: MistralRequestValidatorV15) -> None: - messages = [ - UserMessage(content="hi"), - AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")]), - ToolMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")], tool_call_id="test12345"), - ] - validator_v15.validate_messages(messages, continue_final_message=False) + def test_allows_text_and_audio_in_system(self, validator_v15: MistralRequestValidatorV15) -> None: + content = get_content_chunks(("text", "audio")) + validator_v15.validate_messages(_system_convo(content), continue_final_message=False) + + def test_rejects_think_in_system(self, validator_v15: MistralRequestValidatorV15) -> None: + for name in ("think",): + with pytest.raises(InvalidSystemPromptException, match="Unexpected content chunk types in system message"): + validator_v15.validate_messages( + _system_convo(get_content_chunks((name,))), continue_final_message=False + ) + + def test_allows_all_chunk_types_in_tool(self, validator_v15: MistralRequestValidatorV15) -> None: + content = get_content_chunks(("text", "image", "image_url", "audio", "audio_url", "think")) + validator_v15.validate_messages(_tool_convo(content), continue_final_message=False) From 34d57d9437d7f2ae612fca6129e2bad77a794bce Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Fri, 12 Jun 2026 20:24:02 +0200 Subject: [PATCH 45/47] Type test chunk helpers as Any to satisfy mypy on message construction --- tests/fixtures/chunks.py | 11 +++++++++-- tests/validation/test_chat_validation.py | 11 ++++++----- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/tests/fixtures/chunks.py b/tests/fixtures/chunks.py index 72b3f85d..289765f5 100644 --- a/tests/fixtures/chunks.py +++ b/tests/fixtures/chunks.py @@ -1,3 +1,5 @@ +from typing import Any + from PIL import Image from mistral_common.protocol.instruct.chunk import ( @@ -30,6 +32,11 @@ def get_content_chunk(name: str) -> ContentChunk: return chunks[name] -def get_content_chunks(names: tuple[str, ...]) -> list[ContentChunk]: - r"""Return a list of content chunks for the requested type names.""" +def get_content_chunks(names: tuple[str, ...]) -> list[Any]: + r"""Return a list of content chunks for the requested type names. + + The element type is intentionally `Any` so the chunks can be fed to any message role + constructor (including deliberately invalid combinations) in tests, leaving the runtime + Pydantic validation to enforce the actual rules. + """ return [get_content_chunk(name) for name in names] diff --git a/tests/validation/test_chat_validation.py b/tests/validation/test_chat_validation.py index 90212bf8..c036e2a8 100644 --- a/tests/validation/test_chat_validation.py +++ b/tests/validation/test_chat_validation.py @@ -1,3 +1,5 @@ +from typing import Any + import pytest from mistral_common.exceptions import ( @@ -11,7 +13,6 @@ from mistral_common.protocol.instruct.chunk import ( AudioChunk, AudioURLChunk, - ContentChunk, ) from mistral_common.protocol.instruct.messages import ( AssistantMessage, @@ -38,19 +39,19 @@ _Messages = list[UserMessage | AssistantMessage | SystemMessage | ToolMessage] -def _user_convo(content: "str | list[ContentChunk]") -> _Messages: +def _user_convo(content: "str | list[Any]") -> _Messages: return [UserMessage(content=content)] -def _assistant_convo(content: "str | list[ContentChunk]") -> _Messages: +def _assistant_convo(content: "str | list[Any]") -> _Messages: return [UserMessage(content="hi"), AssistantMessage(content=content), UserMessage(content="next")] -def _system_convo(content: "str | list[ContentChunk]") -> _Messages: +def _system_convo(content: "str | list[Any]") -> _Messages: return [SystemMessage(content=content), UserMessage(content="hi")] -def _tool_convo(content: "str | list[ContentChunk]") -> _Messages: +def _tool_convo(content: "str | list[Any]") -> _Messages: return [ UserMessage(content="hi"), AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")]), From 46bd99f139a71efd391fb816ad9a47a2dee4132a Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Mon, 15 Jun 2026 13:24:36 +0200 Subject: [PATCH 46/47] Centralize content-chunk allow-lists on message classes --- .../experimental/app/routers.py | 4 +- src/mistral_common/protocol/instruct/chunk.py | 64 +------------------ .../protocol/instruct/messages.py | 57 +++++++++++++---- .../protocol/instruct/normalize.py | 12 ---- src/mistral_common/tokens/tokenizers/base.py | 4 +- .../tokens/tokenizers/instruct.py | 42 ++++++++---- tests/test_tokenizer_v15.py | 4 +- tests/test_tokenizer_v7_audio.py | 6 +- 8 files changed, 84 insertions(+), 109 deletions(-) diff --git a/src/mistral_common/experimental/app/routers.py b/src/mistral_common/experimental/app/routers.py index 1c8aabc4..40103e70 100644 --- a/src/mistral_common/experimental/app/routers.py +++ b/src/mistral_common/experimental/app/routers.py @@ -14,7 +14,7 @@ ) from mistral_common.experimental.think import _split_content_and_think_chunks from mistral_common.experimental.tools import _decode_tool_calls, _split_content_and_tool_calls -from mistral_common.protocol.instruct.chunk import AssistantContentChunk, TextChunk, ThinkChunk +from mistral_common.protocol.instruct.chunk import ContentChunk, TextChunk, ThinkChunk from mistral_common.protocol.instruct.messages import AssistantMessage from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy, Tokenized, TokenizerVersion @@ -100,7 +100,7 @@ async def detokenize_to_assistant_message( else: content_tokens, tool_calls_tokens = tokens, () - content: str | list[AssistantContentChunk] | None = None + content: str | list[ContentChunk] | None = None if settings.tokenizer.instruct_tokenizer.tokenizer.version >= TokenizerVersion.v13: assert isinstance(settings.tokenizer.instruct_tokenizer, InstructTokenizerV13) diff --git a/src/mistral_common/protocol/instruct/chunk.py b/src/mistral_common/protocol/instruct/chunk.py index fda7b995..91fcb679 100644 --- a/src/mistral_common/protocol/instruct/chunk.py +++ b/src/mistral_common/protocol/instruct/chunk.py @@ -7,7 +7,7 @@ from urllib.parse import urlparse from pydantic import ConfigDict, Field, ValidationError, field_validator, model_validator -from typing_extensions import Annotated, TypeAlias, TypeGuard +from typing_extensions import Annotated from mistral_common.base import MistralBase from mistral_common.deprecation import warn_once @@ -451,68 +451,6 @@ def from_openai(cls, openai_chunk: dict[str, Any]) -> "ThinkChunk": ContentChunk = Annotated[ TextChunk | ImageChunk | ImageURLChunk | AudioChunk | AudioURLChunk | ThinkChunk, Field(discriminator="type") ] -UserContentChunk = Annotated[ - TextChunk | ImageChunk | ImageURLChunk | AudioChunk | AudioURLChunk, Field(discriminator="type") -] - -AssistantContentChunk = Annotated[TextChunk | ThinkChunk, Field(discriminator="type")] - -SystemContentChunk = Annotated[TextChunk | AudioChunk | ThinkChunk, Field(discriminator="type")] - -ToolContentChunk: TypeAlias = ContentChunk # Accepts all content chunk types (no restriction on tool messages). - - -def is_user_content(content: str | list[ContentChunk]) -> TypeGuard[str | list[UserContentChunk]]: - r"""Narrow aggregated content to the types allowed in a user message. - - String content is always accepted. List content is accepted only when every chunk is a - `TextChunk`, `ImageChunk`, `ImageURLChunk`, `AudioChunk` or `AudioURLChunk`. - - Args: - content: The content to narrow. - - Returns: - Whether the content only contains user-compatible chunk types. - """ - if isinstance(content, str): - return True - return all( - isinstance(chunk, (TextChunk, ImageChunk, ImageURLChunk, AudioChunk, AudioURLChunk)) for chunk in content - ) - - -def is_assistant_content(content: str | list[ContentChunk]) -> TypeGuard[str | list[AssistantContentChunk]]: - r"""Narrow aggregated content to the types allowed in an assistant message. - - String content is always accepted. List content is accepted only when every chunk is a - `TextChunk` or `ThinkChunk`. - - Args: - content: The content to narrow. - - Returns: - Whether the content only contains assistant-compatible chunk types. - """ - if isinstance(content, str): - return True - return all(isinstance(chunk, (TextChunk, ThinkChunk)) for chunk in content) - - -def is_system_content(content: str | list[ContentChunk]) -> TypeGuard[str | list[SystemContentChunk]]: - r"""Narrow aggregated content to the types allowed in a system message. - - String content is always accepted. List content is accepted only when every chunk is a - `TextChunk`, `AudioChunk` or `ThinkChunk`. - - Args: - content: The content to narrow. - - Returns: - Whether the content only contains system-compatible chunk types. - """ - if isinstance(content, str): - return True - return all(isinstance(chunk, (TextChunk, AudioChunk, ThinkChunk)) for chunk in content) def _convert_openai_content_chunks(openai_content_chunks: dict[str, Any]) -> ContentChunk: diff --git a/src/mistral_common/protocol/instruct/messages.py b/src/mistral_common/protocol/instruct/messages.py index 6ca46b2f..f96b78af 100644 --- a/src/mistral_common/protocol/instruct/messages.py +++ b/src/mistral_common/protocol/instruct/messages.py @@ -1,21 +1,22 @@ import warnings from collections.abc import Sequence from enum import Enum -from typing import Any, Literal, TypeVar +from typing import Any, ClassVar, Literal, TypeVar -from pydantic import Field +from pydantic import Field, model_validator from typing_extensions import Annotated, TypeAlias, TypeGuard from mistral_common.base import MistralBase from mistral_common.exceptions import InvalidAssistantMessageException from mistral_common.protocol.instruct.chunk import ( - AssistantContentChunk, + AudioChunk, + AudioURLChunk, + BaseContentChunk, ContentChunk, - SystemContentChunk, + ImageChunk, + ImageURLChunk, TextChunk, ThinkChunk, - ToolContentChunk, - UserContentChunk, _convert_openai_content_chunks, ) from mistral_common.protocol.instruct.tool_calls import ToolCall @@ -27,12 +28,12 @@ ) -def _are_think_chunks(chunks: Sequence[AssistantContentChunk]) -> TypeGuard[list[ThinkChunk]]: +def _are_think_chunks(chunks: Sequence[ContentChunk]) -> TypeGuard[list[ThinkChunk]]: r"""Narrow a chunk list to ThinkChunk list.""" return all(isinstance(c, ThinkChunk) for c in chunks) -def _are_text_chunks(chunks: Sequence[AssistantContentChunk]) -> TypeGuard[list[TextChunk]]: +def _are_text_chunks(chunks: Sequence[ContentChunk]) -> TypeGuard[list[TextChunk]]: r"""Narrow a chunk list to TextChunk list.""" return all(isinstance(c, TextChunk) for c in chunks) @@ -79,6 +80,19 @@ class BaseMessage(MistralBase): role: Literal[Roles.system, Roles.user, Roles.assistant, Roles.tool] + # Allow-list of content chunk types accepted by this message. Must be set by each subclass. + _allowed_content_chunks: ClassVar[tuple[type[BaseContentChunk], ...]] + + @model_validator(mode="after") + def _validate_allowed_content_chunks(self) -> "BaseMessage": + r"""Enforce the per-message content chunk allow-list.""" + content = getattr(self, "content", None) + if isinstance(content, list): + for chunk in content: + if not isinstance(chunk, self._allowed_content_chunks): + raise ValueError(f"{type(chunk).__name__} cannot be used in {self.role} message.") + return self + @staticmethod def _content_to_openai( content: str | Sequence[ContentChunk] | None, @@ -144,7 +158,14 @@ class UserMessage(BaseMessage): """ role: Literal[Roles.user] = Roles.user - content: str | list[UserContentChunk] + content: str | list[ContentChunk] + _allowed_content_chunks: ClassVar[tuple[type[BaseContentChunk], ...]] = ( + TextChunk, + ImageChunk, + ImageURLChunk, + AudioChunk, + AudioURLChunk, + ) def to_openai(self) -> dict[str, Any]: r"""Converts the message to the OpenAI format.""" @@ -169,7 +190,8 @@ class SystemMessage(BaseMessage): """ role: Literal[Roles.system] = Roles.system - content: str | list[SystemContentChunk] + content: str | list[ContentChunk] + _allowed_content_chunks: ClassVar[tuple[type[BaseContentChunk], ...]] = (TextChunk, AudioChunk, ThinkChunk) def to_openai(self) -> dict[str, Any]: r"""Converts the message to the OpenAI format.""" @@ -197,7 +219,8 @@ class AssistantMessage(BaseMessage): """ role: Literal[Roles.assistant] = Roles.assistant - content: str | list[AssistantContentChunk] | None = None + content: str | list[ContentChunk] | None = None + _allowed_content_chunks: ClassVar[tuple[type[BaseContentChunk], ...]] = (TextChunk, ThinkChunk) tool_calls: list[ToolCall] | None = None prefix: bool = False @@ -340,13 +363,23 @@ class ToolMessage(BaseMessage): >>> message = ToolMessage(content="Hello, how can I help you?", tool_call_id="123") """ - content: str | list[ToolContentChunk] + content: str | list[ContentChunk] role: Literal[Roles.tool] = Roles.tool tool_call_id: str | None = None # Deprecated in V3 tokenization name: str | None = None + # Tool messages accept all content chunk types. + _allowed_content_chunks: ClassVar[tuple[type[BaseContentChunk], ...]] = ( + TextChunk, + ImageChunk, + ImageURLChunk, + AudioChunk, + AudioURLChunk, + ThinkChunk, + ) + def to_openai(self) -> dict[str, Any]: r"""Converts the message to the OpenAI format.""" assert self.tool_call_id is not None, "tool_call_id must be provided for tool messages." diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py index 96f1dddf..c423cda0 100644 --- a/src/mistral_common/protocol/instruct/normalize.py +++ b/src/mistral_common/protocol/instruct/normalize.py @@ -8,9 +8,6 @@ from mistral_common.protocol.instruct.chunk import ( ContentChunk, TextChunk, - is_assistant_content, - is_system_content, - is_user_content, ) from mistral_common.protocol.instruct.messages import ( UATS, @@ -286,9 +283,6 @@ def _aggregate_assistant_messages(self, messages: list[UATS]) -> AssistantMessag ) weight = message.weight - assert is_assistant_content(content), ( - f"Unexpected content chunk types in assistant message: {[type(c).__name__ for c in content]}" - ) aggregated_message = self._assistant_message_class( content=content, tool_calls=tool_calls or None, @@ -302,9 +296,6 @@ def _aggregate_assistant_messages(self, messages: list[UATS]) -> AssistantMessag def _aggregate_user_messages(self, messages: list[UATS]) -> UserMessageType: """Coalesce neighboring blocks of ContentChunks in user messages.""" content = self._aggregate_content_chunks(messages) - assert is_user_content(content), ( - f"Unexpected content chunk types in user message: {[type(c).__name__ for c in content]}" - ) return self._user_message_class(content=content) def _aggregate_role(self, messages: list[UATS], role: Roles | None, latest_call_ids: list[str]) -> Sequence[UATS]: @@ -451,9 +442,6 @@ def _aggregate_system_messages(self, messages: list[UATS]) -> list[SystemMessage for message in messages: if isinstance(message, self._system_message_class): content = self._aggregate_content_chunks([message]) - assert is_system_content(content), ( - f"Unexpected content chunk types in system message: {[type(c).__name__ for c in content]}" - ) aggregated.append(self._system_message_class(content=content)) return aggregated diff --git a/src/mistral_common/tokens/tokenizers/base.py b/src/mistral_common/tokens/tokenizers/base.py index 1fe580fa..06dff715 100644 --- a/src/mistral_common/tokens/tokenizers/base.py +++ b/src/mistral_common/tokens/tokenizers/base.py @@ -9,7 +9,7 @@ from mistral_common.base import MistralBase from mistral_common.protocol.fim.request import FIMRequest -from mistral_common.protocol.instruct.chunk import UserContentChunk +from mistral_common.protocol.instruct.chunk import ContentChunk from mistral_common.protocol.instruct.messages import ( AssistantMessageType, UserMessage, @@ -428,7 +428,7 @@ def encode_user_message( @abstractmethod def encode_user_content( self, - content: str | list[UserContentChunk], + content: str | list[ContentChunk], is_last: bool, system_prompt: str | None = None, force_img_first: bool = False, diff --git a/src/mistral_common/tokens/tokenizers/instruct.py b/src/mistral_common/tokens/tokenizers/instruct.py index 9b044f6c..7d0e5ab8 100644 --- a/src/mistral_common/tokens/tokenizers/instruct.py +++ b/src/mistral_common/tokens/tokenizers/instruct.py @@ -20,8 +20,6 @@ ImageURLChunk, TextChunk, ThinkChunk, - ToolContentChunk, - UserContentChunk, ) from mistral_common.protocol.instruct.messages import ( UATS, @@ -300,7 +298,7 @@ def encode_system_message(self, message: SystemMessage) -> tuple[list[int], list def encode_user_content( self, - content: str | list[UserContentChunk], + content: str | list[ContentChunk], is_last: bool, system_prompt: str | None = None, force_img_first: bool = False, @@ -488,9 +486,15 @@ def _parse_json_content(self, content: str) -> Any: except json.JSONDecodeError: return content - def _parse_tool_content(self, content: str | list[ToolContentChunk]) -> Any: + def _parse_tool_content(self, content: str | list[ContentChunk]) -> Any: if isinstance(content, list): - content = "".join(chunk.text for chunk in content if isinstance(chunk, TextChunk)) + text_parts: list[str] = [] + for chunk in content: + assert isinstance(chunk, TextChunk), ( + f"Tool content only supports text chunks, got {type(chunk).__name__}." + ) + text_parts.append(chunk.text) + content = "".join(text_parts) return self._parse_json_content(content) def _prepare_tool_result(self, tool_message: ToolMessage) -> dict[str, Any]: @@ -753,7 +757,7 @@ def _encode_content_chunks( def encode_user_content( self, - content: str | list[UserContentChunk], + content: str | list[ContentChunk], is_last: bool, system_prompt: str | None = None, force_img_first: bool = False, @@ -793,13 +797,18 @@ def encode_user_content( assert not content_str, ( f"It is not possible that `content` is non-empty when chunk is of type {type(chunk)}." ) - chunk_tokens, _, chunk_audio = self._encode_content_chunk(chunk) + chunk_tokens, maybe_image, chunk_audio = self._encode_content_chunk(chunk) + assert maybe_image is None, f"Unexpected image for audio chunk {type(chunk).__name__}." audio.append(chunk_audio) elif isinstance(chunk, (ImageChunk, ImageURLChunk)): - chunk_tokens, chunk_image, _ = self._encode_content_chunk(chunk) + chunk_tokens, chunk_image, maybe_audio = self._encode_content_chunk(chunk) + assert maybe_audio is None, f"Unexpected audio for image chunk {type(chunk).__name__}." images.append(chunk_image) else: - chunk_tokens = self._encode_content_chunk(chunk)[0] + chunk_tokens, maybe_image, maybe_audio = self._encode_content_chunk(chunk) + assert maybe_image is None and maybe_audio is None, ( + f"Unexpected image/audio for chunk {type(chunk).__name__}." + ) tokens.extend(chunk_tokens) return tokens, images, audio @@ -891,14 +900,15 @@ def encode_system_message(self, message: SystemMessage) -> tuple[list[int], list tokens = [self.BEGIN_SYSTEM] if isinstance(content := message.content, str): content = [TextChunk(text=content)] - content_tokens, _images, audios = self._encode_content_chunks(content) + content_tokens, images, audios = self._encode_content_chunks(content) + assert not images, f"System messages cannot contain images, got {len(images)}." tokens += content_tokens tokens.append(self.END_SYSTEM) return tokens, audios def encode_user_content( self, - content: str | list[UserContentChunk], + content: str | list[ContentChunk], is_last: bool, system_prompt: str | None = None, force_img_first: bool = False, @@ -995,7 +1005,7 @@ def _encode_instruct_transcription(self, request: TranscriptionRequest) -> Token ) assert self.TRANSCRIBE is not None, f"{self.__class__.__name__} needs to have a TRANSCRIBE token" prefix = self.start() - tokens, _, audio = self.encode_user_message( + tokens, images, audio = self.encode_user_message( UserMessage(content=[AudioChunk(input_audio=request.audio)]), available_tools=[], is_last=True, @@ -1003,6 +1013,7 @@ def _encode_instruct_transcription(self, request: TranscriptionRequest) -> Token system_prompt=None, settings=ModelSettings.none(), ) + assert not images, f"Transcription input cannot contain images, got {len(images)}." tokens = [*prefix, *tokens] if request.language is not None: @@ -1152,7 +1163,12 @@ def encode_assistant_message( if isinstance(message.content, str): curr_tokens = self._encode_normal_content_assistant_message(message) elif isinstance(message.content, list): - curr_tokens += self._encode_content_chunks(message.content)[0] + content_tokens, images, audios = self._encode_content_chunks(message.content) + assert not images and not audios, ( + f"Assistant messages cannot contain images or audios, got {len(images)} images " + f"and {len(audios)} audios." + ) + curr_tokens += content_tokens if message.tool_calls: curr_tokens += self._encode_tool_calls_in_assistant_message(message) if not message.prefix and not continue_message: diff --git a/tests/test_tokenizer_v15.py b/tests/test_tokenizer_v15.py index 2c882cac..c8abe52c 100644 --- a/tests/test_tokenizer_v15.py +++ b/tests/test_tokenizer_v15.py @@ -488,7 +488,7 @@ def test_encode_chat_completion_with_multimodal_tool( expected_text: str, ) -> None: mistral_tokenizer = tokenizer_factory() - chat_request = ChatCompletionRequest( # type: ignore[type-var] + chat_request: ChatCompletionRequest = ChatCompletionRequest( messages=[ UserMessage(content="Use the tool"), AssistantMessage(tool_calls=[ToolCall(id="test12345", function=FunctionCall(name="fn", arguments="{}"))]), @@ -517,7 +517,7 @@ def test_encode_chat_completion_with_multimodal_system( expected_text: str, ) -> None: mistral_tokenizer = tokenizer_factory() - chat_request = ChatCompletionRequest( # type: ignore[type-var] + chat_request: ChatCompletionRequest = ChatCompletionRequest( messages=[ SystemMessage(content=[TextChunk(text="System with content"), content_chunk]), UserMessage(content="Hello"), diff --git a/tests/test_tokenizer_v7_audio.py b/tests/test_tokenizer_v7_audio.py index 10995957..4ef3d47c 100644 --- a/tests/test_tokenizer_v7_audio.py +++ b/tests/test_tokenizer_v7_audio.py @@ -11,8 +11,8 @@ AudioChunk, AudioURL, AudioURLChunk, + ContentChunk, TextChunk, - UserContentChunk, ) from mistral_common.protocol.instruct.messages import ( UATS, @@ -215,7 +215,7 @@ def test_tokenize_user_message(tekkenizer: InstructTokenizerV7, audio_first: boo text_chunk = TextChunk(text="a") num_expected_frames = int(np.ceil(duration * frame_rate)) - chunks: list[UserContentChunk] = [audio_chunk, text_chunk] if audio_first else [text_chunk, audio_chunk] + chunks: list[ContentChunk] = [audio_chunk, text_chunk] if audio_first else [text_chunk, audio_chunk] tokenized = tekkenizer.encode_instruct( InstructRequest( @@ -261,7 +261,7 @@ def test_tokenize_multi_turn(tekkenizer: InstructTokenizerV7) -> None: text_chunk = TextChunk(text="a") num_expected_frames = int(np.ceil(duration * frame_rate)) - chunks: list[UserContentChunk] = [audio_chunk, text_chunk] + chunks: list[ContentChunk] = [audio_chunk, text_chunk] tokenized = tekkenizer.encode_instruct( InstructRequest( From b57956eb3c53523793c32d67032f8e180992cfcd Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Mon, 15 Jun 2026 14:23:57 +0200 Subject: [PATCH 47/47] Fix mypy. --- tests/test_converters.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_converters.py b/tests/test_converters.py index d093d945..19b69a22 100644 --- a/tests/test_converters.py +++ b/tests/test_converters.py @@ -37,6 +37,7 @@ AudioChunk, AudioURL, AudioURLChunk, + ContentChunk, ImageChunk, ImageURL, ImageURLChunk, @@ -586,7 +587,7 @@ def test_non_leading_think_chunks_construction_ok() -> None: [TextChunk(text="A"), TextChunk(text="B"), ThinkChunk(thinking="End", closed=True)], ], ) -def test_non_leading_think_chunks_to_openai_raises(content: list[TextChunk | ThinkChunk]) -> None: +def test_non_leading_think_chunks_to_openai_raises(content: list[ContentChunk]) -> None: """to_openai raises when ThinkChunks are not leading.""" msg = AssistantMessage(content=content) with pytest.raises(InvalidAssistantMessageException, match="ThinkChunks must be leading"):