From 98d252357931c44c7278bd37426047f557941666 Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Thu, 4 Jun 2026 15:14:05 +0200
Subject: [PATCH 01/47] Parametrize chunk and message join separators
---
src/mistral_common/protocol/instruct/normalize.py | 11 ++++++-----
1 file changed, 6 insertions(+), 5 deletions(-)
diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py
index c96a1940..4ef3f740 100644
--- a/src/mistral_common/protocol/instruct/normalize.py
+++ b/src/mistral_common/protocol/instruct/normalize.py
@@ -32,13 +32,14 @@
from mistral_common.tokens.tokenizers.base import InstructRequestType, TokenizerVersion
from mistral_common.tokens.tokenizers.model_settings_builder import ModelSettingsBuilder
-_DEFAULT_JOIN_STR = "\n\n"
+_MSG_JOIN_STR = "\n\n"
+_CHUNK_JOIN_STR = ""
def _aggregate_content_chunks_impl(
contents: list[list[ContentChunk] | str | None],
- msg_join_str: str,
- chunk_join_str: str,
+ msg_join_str: str = _MSG_JOIN_STR,
+ chunk_join_str: str = _CHUNK_JOIN_STR,
) -> list[ContentChunk] | str:
r"""Coalesce TextChunks within the same message and across different messages.
@@ -128,8 +129,8 @@ class InstructRequestNormalizer(
_system_prompt_in_begin: bool = False
_allow_tool_call_and_content: bool = False
- _chunk_join_str: str = _DEFAULT_JOIN_STR
- _msg_join_str: str = _DEFAULT_JOIN_STR
+ _chunk_join_str: str = _CHUNK_JOIN_STR
+ _msg_join_str: str = _MSG_JOIN_STR
def __init__(
self,
From 7aa8189da5469f2b5f752f5a49f9f5c4dda48842 Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Thu, 4 Jun 2026 15:17:47 +0200
Subject: [PATCH 02/47] Use TypeGuard for type-safe content narrowing
---
src/mistral_common/protocol/instruct/normalize.py | 8 ++++++--
1 file changed, 6 insertions(+), 2 deletions(-)
diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py
index 4ef3f740..ea48db1b 100644
--- a/src/mistral_common/protocol/instruct/normalize.py
+++ b/src/mistral_common/protocol/instruct/normalize.py
@@ -297,8 +297,10 @@ def _aggregate_assistant_messages(self, messages: list[UATS]) -> AssistantMessag
)
weight = message.weight
- if isinstance(content, str) or _is_assistant_content(content):
+ if isinstance(content, str):
narrowed_content: str | list[TextChunk | ThinkChunk] = content
+ elif _is_assistant_content(content):
+ narrowed_content = content
else:
raise InvalidRequestException(
f"Unexpected content chunk types in assistant message: {[type(c).__name__ for c in content]}"
@@ -317,7 +319,9 @@ def _aggregate_assistant_messages(self, messages: list[UATS]) -> AssistantMessag
def _aggregate_user_messages(self, messages: list[UATS]) -> UserMessageType:
"""Coalesce neighboring blocks of ContentChunks in user messages."""
content = self._aggregate_content_chunks(messages)
- if isinstance(content, str) or _is_user_content(content):
+ if isinstance(content, str):
+ return self._user_message_class(content=content)
+ elif _is_user_content(content):
return self._user_message_class(content=content)
else:
raise InvalidRequestException(
From 0cb2e81cdb63a81f84d90a4c524820f64d4a05bd Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Thu, 4 Jun 2026 17:04:51 +0200
Subject: [PATCH 03/47] Fix chat template intra-message chunk joining to match
normalizer
---
.../chat_templates/template_generator.py | 16 ++++++++++++++--
tests/data/chat_templates/v1.jinja | 16 ++++++++++++++--
tests/data/chat_templates/v11.jinja | 10 +++++++++-
tests/data/chat_templates/v11_audio.jinja | 10 +++++++++-
tests/data/chat_templates/v11_image.jinja | 10 +++++++++-
tests/data/chat_templates/v11_image_think.jinja | 10 +++++++++-
tests/data/chat_templates/v11_think.jinja | 10 +++++++++-
tests/data/chat_templates/v13.jinja | 10 +++++++++-
tests/data/chat_templates/v13_audio.jinja | 10 +++++++++-
tests/data/chat_templates/v13_image.jinja | 10 +++++++++-
tests/data/chat_templates/v13_image_think.jinja | 10 +++++++++-
tests/data/chat_templates/v13_think.jinja | 10 +++++++++-
tests/data/chat_templates/v1_spm.jinja | 16 ++++++++++++++--
tests/data/chat_templates/v2.jinja | 16 ++++++++++++++--
tests/data/chat_templates/v2_spm.jinja | 16 ++++++++++++++--
tests/data/chat_templates/v3.jinja | 16 ++++++++++++++--
tests/data/chat_templates/v3_image.jinja | 16 ++++++++++++++--
tests/data/chat_templates/v3_image_spm.jinja | 16 ++++++++++++++--
tests/data/chat_templates/v3_spm.jinja | 16 ++++++++++++++--
tests/data/chat_templates/v7.jinja | 10 +++++++++-
tests/data/chat_templates/v7_audio.jinja | 10 +++++++++-
tests/data/chat_templates/v7_image.jinja | 10 +++++++++-
tests/data/chat_templates/v7_image_spm.jinja | 10 +++++++++-
tests/data/chat_templates/v7_spm.jinja | 10 +++++++++-
24 files changed, 261 insertions(+), 33 deletions(-)
diff --git a/src/mistral_common/integrations/chat_templates/template_generator.py b/src/mistral_common/integrations/chat_templates/template_generator.py
index 807ed528..9916cb83 100644
--- a/src/mistral_common/integrations/chat_templates/template_generator.py
+++ b/src/mistral_common/integrations/chat_templates/template_generator.py
@@ -287,13 +287,17 @@ def _generate_system_prompt_handling_pre_v7(config: TemplateConfig) -> list[str]
" {%- if message['content'] is string %}",
" {%- set ns_sys.system_parts = ns_sys.system_parts + [message['content']] %}",
" {%- else %}",
+ " {%- set ns_sys_msg = namespace(parts=[]) %}",
" {%- for block in message['content'] %}",
" {%- if block['type'] == 'text' %}",
- " {%- set ns_sys.system_parts = ns_sys.system_parts + [block['text']] %}",
+ " {%- set ns_sys_msg.parts = ns_sys_msg.parts + [block['text']] %}",
" {%- else %}",
" {{- raise_exception('Only text chunks are supported in system message content.') }}",
" {%- endif %}",
" {%- endfor %}",
+ " {%- if ns_sys_msg.parts | length > 0 %}",
+ " {%- set ns_sys.system_parts = ns_sys.system_parts + [ns_sys_msg.parts | join('')] %}",
+ " {%- endif %}",
" {%- endif %}",
" {%- else %}",
" {%- set ns_sys.filtered = ns_sys.filtered + [message] %}",
@@ -796,10 +800,15 @@ def _generate_flush_logic(config: TemplateConfig) -> list[str]:
else:
list_content_lines = [
" {%- elif msg['content'] is not none %}",
+ " {%- set ns_msg = namespace(msg_text_parts=[]) %}",
" {%- for block in msg['content'] %}",
" {%- if block['type'] == 'text' %}",
- " {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}",
+ " {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}",
" {%- else %}",
+ " {%- if ns_msg.msg_text_parts | length > 0 %}",
+ " {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}", # noqa: E501
+ " {%- set ns_msg.msg_text_parts = [] %}",
+ " {%- endif %}",
" {%- if ns_c.text_parts | length > 0 %}",
" {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\\n\\n')}] %}", # noqa: E501
" {%- set ns_c.text_parts = [] %}",
@@ -808,6 +817,9 @@ def _generate_flush_logic(config: TemplateConfig) -> list[str]:
" {%- set ns_c.has_non_text = true %}",
" {%- endif %}",
" {%- endfor %}",
+ " {%- if ns_msg.msg_text_parts | length > 0 %}",
+ " {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}",
+ " {%- endif %}",
" {%- endif %}",
]
diff --git a/tests/data/chat_templates/v1.jinja b/tests/data/chat_templates/v1.jinja
index 2b6162c7..340c8963 100644
--- a/tests/data/chat_templates/v1.jinja
+++ b/tests/data/chat_templates/v1.jinja
@@ -13,13 +13,17 @@
{%- if message['content'] is string %}
{%- set ns_sys.system_parts = ns_sys.system_parts + [message['content']] %}
{%- else %}
+ {%- set ns_sys_msg = namespace(parts=[]) %}
{%- for block in message['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_sys.system_parts = ns_sys.system_parts + [block['text']] %}
+ {%- set ns_sys_msg.parts = ns_sys_msg.parts + [block['text']] %}
{%- else %}
{{- raise_exception('Only text chunks are supported in system message content.') }}
{%- endif %}
{%- endfor %}
+ {%- if ns_sys_msg.parts | length > 0 %}
+ {%- set ns_sys.system_parts = ns_sys.system_parts + [ns_sys_msg.parts | join('')] %}
+ {%- endif %}
{%- endif %}
{%- else %}
{%- set ns_sys.filtered = ns_sys.filtered + [message] %}
@@ -45,10 +49,15 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
+ {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
+ {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
{%- else %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- set ns_msg.msg_text_parts = [] %}
+ {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -57,6 +66,9 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v11.jinja b/tests/data/chat_templates/v11.jinja
index 79943d0d..2a1d7562 100644
--- a/tests/data/chat_templates/v11.jinja
+++ b/tests/data/chat_templates/v11.jinja
@@ -50,10 +50,15 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
+ {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
+ {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
{%- else %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- set ns_msg.msg_text_parts = [] %}
+ {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -62,6 +67,9 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v11_audio.jinja b/tests/data/chat_templates/v11_audio.jinja
index e96cdbf9..c95b1dc7 100644
--- a/tests/data/chat_templates/v11_audio.jinja
+++ b/tests/data/chat_templates/v11_audio.jinja
@@ -60,10 +60,15 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
+ {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
+ {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
{%- else %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- set ns_msg.msg_text_parts = [] %}
+ {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -72,6 +77,9 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v11_image.jinja b/tests/data/chat_templates/v11_image.jinja
index 164a6763..1944dca4 100644
--- a/tests/data/chat_templates/v11_image.jinja
+++ b/tests/data/chat_templates/v11_image.jinja
@@ -52,10 +52,15 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
+ {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
+ {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
{%- else %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- set ns_msg.msg_text_parts = [] %}
+ {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -64,6 +69,9 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v11_image_think.jinja b/tests/data/chat_templates/v11_image_think.jinja
index 950746ef..39683fce 100644
--- a/tests/data/chat_templates/v11_image_think.jinja
+++ b/tests/data/chat_templates/v11_image_think.jinja
@@ -79,10 +79,15 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
+ {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
+ {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
{%- else %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- set ns_msg.msg_text_parts = [] %}
+ {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -91,6 +96,9 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v11_think.jinja b/tests/data/chat_templates/v11_think.jinja
index ca99772c..a4820c07 100644
--- a/tests/data/chat_templates/v11_think.jinja
+++ b/tests/data/chat_templates/v11_think.jinja
@@ -77,10 +77,15 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
+ {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
+ {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
{%- else %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- set ns_msg.msg_text_parts = [] %}
+ {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -89,6 +94,9 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v13.jinja b/tests/data/chat_templates/v13.jinja
index 38c253d9..14e0ce58 100644
--- a/tests/data/chat_templates/v13.jinja
+++ b/tests/data/chat_templates/v13.jinja
@@ -50,10 +50,15 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
+ {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
+ {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
{%- else %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- set ns_msg.msg_text_parts = [] %}
+ {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -62,6 +67,9 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v13_audio.jinja b/tests/data/chat_templates/v13_audio.jinja
index c0a2fdec..5fbb6ccb 100644
--- a/tests/data/chat_templates/v13_audio.jinja
+++ b/tests/data/chat_templates/v13_audio.jinja
@@ -60,10 +60,15 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
+ {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
+ {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
{%- else %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- set ns_msg.msg_text_parts = [] %}
+ {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -72,6 +77,9 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v13_image.jinja b/tests/data/chat_templates/v13_image.jinja
index 1b79382a..38d608bf 100644
--- a/tests/data/chat_templates/v13_image.jinja
+++ b/tests/data/chat_templates/v13_image.jinja
@@ -52,10 +52,15 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
+ {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
+ {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
{%- else %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- set ns_msg.msg_text_parts = [] %}
+ {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -64,6 +69,9 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v13_image_think.jinja b/tests/data/chat_templates/v13_image_think.jinja
index cb39ae85..cc5079f3 100644
--- a/tests/data/chat_templates/v13_image_think.jinja
+++ b/tests/data/chat_templates/v13_image_think.jinja
@@ -79,10 +79,15 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
+ {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
+ {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
{%- else %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- set ns_msg.msg_text_parts = [] %}
+ {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -91,6 +96,9 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v13_think.jinja b/tests/data/chat_templates/v13_think.jinja
index 5313eabe..8606f7c3 100644
--- a/tests/data/chat_templates/v13_think.jinja
+++ b/tests/data/chat_templates/v13_think.jinja
@@ -77,10 +77,15 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
+ {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
+ {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
{%- else %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- set ns_msg.msg_text_parts = [] %}
+ {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -89,6 +94,9 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v1_spm.jinja b/tests/data/chat_templates/v1_spm.jinja
index 5fb2e7d0..7fae6926 100644
--- a/tests/data/chat_templates/v1_spm.jinja
+++ b/tests/data/chat_templates/v1_spm.jinja
@@ -13,13 +13,17 @@
{%- if message['content'] is string %}
{%- set ns_sys.system_parts = ns_sys.system_parts + [message['content']] %}
{%- else %}
+ {%- set ns_sys_msg = namespace(parts=[]) %}
{%- for block in message['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_sys.system_parts = ns_sys.system_parts + [block['text']] %}
+ {%- set ns_sys_msg.parts = ns_sys_msg.parts + [block['text']] %}
{%- else %}
{{- raise_exception('Only text chunks are supported in system message content.') }}
{%- endif %}
{%- endfor %}
+ {%- if ns_sys_msg.parts | length > 0 %}
+ {%- set ns_sys.system_parts = ns_sys.system_parts + [ns_sys_msg.parts | join('')] %}
+ {%- endif %}
{%- endif %}
{%- else %}
{%- set ns_sys.filtered = ns_sys.filtered + [message] %}
@@ -45,10 +49,15 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
+ {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
+ {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
{%- else %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- set ns_msg.msg_text_parts = [] %}
+ {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -57,6 +66,9 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v2.jinja b/tests/data/chat_templates/v2.jinja
index 0ba2a722..cd009d81 100644
--- a/tests/data/chat_templates/v2.jinja
+++ b/tests/data/chat_templates/v2.jinja
@@ -13,13 +13,17 @@
{%- if message['content'] is string %}
{%- set ns_sys.system_parts = ns_sys.system_parts + [message['content']] %}
{%- else %}
+ {%- set ns_sys_msg = namespace(parts=[]) %}
{%- for block in message['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_sys.system_parts = ns_sys.system_parts + [block['text']] %}
+ {%- set ns_sys_msg.parts = ns_sys_msg.parts + [block['text']] %}
{%- else %}
{{- raise_exception('Only text chunks are supported in system message content.') }}
{%- endif %}
{%- endfor %}
+ {%- if ns_sys_msg.parts | length > 0 %}
+ {%- set ns_sys.system_parts = ns_sys.system_parts + [ns_sys_msg.parts | join('')] %}
+ {%- endif %}
{%- endif %}
{%- else %}
{%- set ns_sys.filtered = ns_sys.filtered + [message] %}
@@ -71,10 +75,15 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
+ {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
+ {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
{%- else %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- set ns_msg.msg_text_parts = [] %}
+ {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -83,6 +92,9 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v2_spm.jinja b/tests/data/chat_templates/v2_spm.jinja
index 647ad3ac..8b9d18ca 100644
--- a/tests/data/chat_templates/v2_spm.jinja
+++ b/tests/data/chat_templates/v2_spm.jinja
@@ -13,13 +13,17 @@
{%- if message['content'] is string %}
{%- set ns_sys.system_parts = ns_sys.system_parts + [message['content']] %}
{%- else %}
+ {%- set ns_sys_msg = namespace(parts=[]) %}
{%- for block in message['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_sys.system_parts = ns_sys.system_parts + [block['text']] %}
+ {%- set ns_sys_msg.parts = ns_sys_msg.parts + [block['text']] %}
{%- else %}
{{- raise_exception('Only text chunks are supported in system message content.') }}
{%- endif %}
{%- endfor %}
+ {%- if ns_sys_msg.parts | length > 0 %}
+ {%- set ns_sys.system_parts = ns_sys.system_parts + [ns_sys_msg.parts | join('')] %}
+ {%- endif %}
{%- endif %}
{%- else %}
{%- set ns_sys.filtered = ns_sys.filtered + [message] %}
@@ -71,10 +75,15 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
+ {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
+ {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
{%- else %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- set ns_msg.msg_text_parts = [] %}
+ {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -83,6 +92,9 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v3.jinja b/tests/data/chat_templates/v3.jinja
index 8a0bfa7d..6678f208 100644
--- a/tests/data/chat_templates/v3.jinja
+++ b/tests/data/chat_templates/v3.jinja
@@ -13,13 +13,17 @@
{%- if message['content'] is string %}
{%- set ns_sys.system_parts = ns_sys.system_parts + [message['content']] %}
{%- else %}
+ {%- set ns_sys_msg = namespace(parts=[]) %}
{%- for block in message['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_sys.system_parts = ns_sys.system_parts + [block['text']] %}
+ {%- set ns_sys_msg.parts = ns_sys_msg.parts + [block['text']] %}
{%- else %}
{{- raise_exception('Only text chunks are supported in system message content.') }}
{%- endif %}
{%- endfor %}
+ {%- if ns_sys_msg.parts | length > 0 %}
+ {%- set ns_sys.system_parts = ns_sys.system_parts + [ns_sys_msg.parts | join('')] %}
+ {%- endif %}
{%- endif %}
{%- else %}
{%- set ns_sys.filtered = ns_sys.filtered + [message] %}
@@ -71,10 +75,15 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
+ {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
+ {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
{%- else %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- set ns_msg.msg_text_parts = [] %}
+ {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -83,6 +92,9 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v3_image.jinja b/tests/data/chat_templates/v3_image.jinja
index d00b3529..f62b7b28 100644
--- a/tests/data/chat_templates/v3_image.jinja
+++ b/tests/data/chat_templates/v3_image.jinja
@@ -13,13 +13,17 @@
{%- if message['content'] is string %}
{%- set ns_sys.system_parts = ns_sys.system_parts + [message['content']] %}
{%- else %}
+ {%- set ns_sys_msg = namespace(parts=[]) %}
{%- for block in message['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_sys.system_parts = ns_sys.system_parts + [block['text']] %}
+ {%- set ns_sys_msg.parts = ns_sys_msg.parts + [block['text']] %}
{%- else %}
{{- raise_exception('Only text chunks are supported in system message content.') }}
{%- endif %}
{%- endfor %}
+ {%- if ns_sys_msg.parts | length > 0 %}
+ {%- set ns_sys.system_parts = ns_sys.system_parts + [ns_sys_msg.parts | join('')] %}
+ {%- endif %}
{%- endif %}
{%- else %}
{%- set ns_sys.filtered = ns_sys.filtered + [message] %}
@@ -73,10 +77,15 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
+ {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
+ {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
{%- else %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- set ns_msg.msg_text_parts = [] %}
+ {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -85,6 +94,9 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v3_image_spm.jinja b/tests/data/chat_templates/v3_image_spm.jinja
index 8cb31b30..eb1cf1ee 100644
--- a/tests/data/chat_templates/v3_image_spm.jinja
+++ b/tests/data/chat_templates/v3_image_spm.jinja
@@ -13,13 +13,17 @@
{%- if message['content'] is string %}
{%- set ns_sys.system_parts = ns_sys.system_parts + [message['content']] %}
{%- else %}
+ {%- set ns_sys_msg = namespace(parts=[]) %}
{%- for block in message['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_sys.system_parts = ns_sys.system_parts + [block['text']] %}
+ {%- set ns_sys_msg.parts = ns_sys_msg.parts + [block['text']] %}
{%- else %}
{{- raise_exception('Only text chunks are supported in system message content.') }}
{%- endif %}
{%- endfor %}
+ {%- if ns_sys_msg.parts | length > 0 %}
+ {%- set ns_sys.system_parts = ns_sys.system_parts + [ns_sys_msg.parts | join('')] %}
+ {%- endif %}
{%- endif %}
{%- else %}
{%- set ns_sys.filtered = ns_sys.filtered + [message] %}
@@ -79,10 +83,15 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
+ {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
+ {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
{%- else %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- set ns_msg.msg_text_parts = [] %}
+ {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -91,6 +100,9 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v3_spm.jinja b/tests/data/chat_templates/v3_spm.jinja
index 304644e4..f1aef3c2 100644
--- a/tests/data/chat_templates/v3_spm.jinja
+++ b/tests/data/chat_templates/v3_spm.jinja
@@ -13,13 +13,17 @@
{%- if message['content'] is string %}
{%- set ns_sys.system_parts = ns_sys.system_parts + [message['content']] %}
{%- else %}
+ {%- set ns_sys_msg = namespace(parts=[]) %}
{%- for block in message['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_sys.system_parts = ns_sys.system_parts + [block['text']] %}
+ {%- set ns_sys_msg.parts = ns_sys_msg.parts + [block['text']] %}
{%- else %}
{{- raise_exception('Only text chunks are supported in system message content.') }}
{%- endif %}
{%- endfor %}
+ {%- if ns_sys_msg.parts | length > 0 %}
+ {%- set ns_sys.system_parts = ns_sys.system_parts + [ns_sys_msg.parts | join('')] %}
+ {%- endif %}
{%- endif %}
{%- else %}
{%- set ns_sys.filtered = ns_sys.filtered + [message] %}
@@ -71,10 +75,15 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
+ {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
+ {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
{%- else %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- set ns_msg.msg_text_parts = [] %}
+ {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -83,6 +92,9 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v7.jinja b/tests/data/chat_templates/v7.jinja
index 519991ec..516ea1e0 100644
--- a/tests/data/chat_templates/v7.jinja
+++ b/tests/data/chat_templates/v7.jinja
@@ -50,10 +50,15 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
+ {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
+ {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
{%- else %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- set ns_msg.msg_text_parts = [] %}
+ {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -62,6 +67,9 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v7_audio.jinja b/tests/data/chat_templates/v7_audio.jinja
index f4e98518..c3d6cbf8 100644
--- a/tests/data/chat_templates/v7_audio.jinja
+++ b/tests/data/chat_templates/v7_audio.jinja
@@ -60,10 +60,15 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
+ {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
+ {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
{%- else %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- set ns_msg.msg_text_parts = [] %}
+ {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -72,6 +77,9 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v7_image.jinja b/tests/data/chat_templates/v7_image.jinja
index e3c372df..0c0a1223 100644
--- a/tests/data/chat_templates/v7_image.jinja
+++ b/tests/data/chat_templates/v7_image.jinja
@@ -52,10 +52,15 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
+ {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
+ {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
{%- else %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- set ns_msg.msg_text_parts = [] %}
+ {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -64,6 +69,9 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v7_image_spm.jinja b/tests/data/chat_templates/v7_image_spm.jinja
index 69088620..dda5bef3 100644
--- a/tests/data/chat_templates/v7_image_spm.jinja
+++ b/tests/data/chat_templates/v7_image_spm.jinja
@@ -58,10 +58,15 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
+ {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
+ {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
{%- else %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- set ns_msg.msg_text_parts = [] %}
+ {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -70,6 +75,9 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v7_spm.jinja b/tests/data/chat_templates/v7_spm.jinja
index 02875ee6..84d76105 100644
--- a/tests/data/chat_templates/v7_spm.jinja
+++ b/tests/data/chat_templates/v7_spm.jinja
@@ -50,10 +50,15 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
+ {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
+ {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
{%- else %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- set ns_msg.msg_text_parts = [] %}
+ {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -62,6 +67,9 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
+ {%- if ns_msg.msg_text_parts | length > 0 %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
+ {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
From d9d2e634b211052d121f9abe39b06e0bf0ac8513 Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Thu, 4 Jun 2026 17:48:17 +0200
Subject: [PATCH 04/47] Restrict intra-message chunk join change to V15 only
---
.../chat_templates/template_generator.py | 17 +++--------------
.../protocol/instruct/normalize.py | 11 +++++------
tests/data/chat_templates/v1.jinja | 16 ++--------------
tests/data/chat_templates/v11.jinja | 10 +---------
tests/data/chat_templates/v11_audio.jinja | 10 +---------
tests/data/chat_templates/v11_image.jinja | 10 +---------
tests/data/chat_templates/v11_image_think.jinja | 10 +---------
tests/data/chat_templates/v11_think.jinja | 10 +---------
tests/data/chat_templates/v13.jinja | 10 +---------
tests/data/chat_templates/v13_audio.jinja | 10 +---------
tests/data/chat_templates/v13_image.jinja | 10 +---------
tests/data/chat_templates/v13_image_think.jinja | 10 +---------
tests/data/chat_templates/v13_think.jinja | 10 +---------
tests/data/chat_templates/v1_spm.jinja | 16 ++--------------
tests/data/chat_templates/v2.jinja | 16 ++--------------
tests/data/chat_templates/v2_spm.jinja | 16 ++--------------
tests/data/chat_templates/v3.jinja | 16 ++--------------
tests/data/chat_templates/v3_image.jinja | 16 ++--------------
tests/data/chat_templates/v3_image_spm.jinja | 16 ++--------------
tests/data/chat_templates/v3_spm.jinja | 16 ++--------------
tests/data/chat_templates/v7.jinja | 10 +---------
tests/data/chat_templates/v7_audio.jinja | 10 +---------
tests/data/chat_templates/v7_image.jinja | 10 +---------
tests/data/chat_templates/v7_image_spm.jinja | 10 +---------
tests/data/chat_templates/v7_spm.jinja | 10 +---------
25 files changed, 39 insertions(+), 267 deletions(-)
diff --git a/src/mistral_common/integrations/chat_templates/template_generator.py b/src/mistral_common/integrations/chat_templates/template_generator.py
index 9916cb83..d580acb8 100644
--- a/src/mistral_common/integrations/chat_templates/template_generator.py
+++ b/src/mistral_common/integrations/chat_templates/template_generator.py
@@ -287,17 +287,13 @@ def _generate_system_prompt_handling_pre_v7(config: TemplateConfig) -> list[str]
" {%- if message['content'] is string %}",
" {%- set ns_sys.system_parts = ns_sys.system_parts + [message['content']] %}",
" {%- else %}",
- " {%- set ns_sys_msg = namespace(parts=[]) %}",
" {%- for block in message['content'] %}",
" {%- if block['type'] == 'text' %}",
- " {%- set ns_sys_msg.parts = ns_sys_msg.parts + [block['text']] %}",
+ " {%- set ns_sys.system_parts = ns_sys.system_parts + [block['text']] %}",
" {%- else %}",
" {{- raise_exception('Only text chunks are supported in system message content.') }}",
" {%- endif %}",
" {%- endfor %}",
- " {%- if ns_sys_msg.parts | length > 0 %}",
- " {%- set ns_sys.system_parts = ns_sys.system_parts + [ns_sys_msg.parts | join('')] %}",
- " {%- endif %}",
" {%- endif %}",
" {%- else %}",
" {%- set ns_sys.filtered = ns_sys.filtered + [message] %}",
@@ -800,15 +796,10 @@ def _generate_flush_logic(config: TemplateConfig) -> list[str]:
else:
list_content_lines = [
" {%- elif msg['content'] is not none %}",
- " {%- set ns_msg = namespace(msg_text_parts=[]) %}",
" {%- for block in msg['content'] %}",
" {%- if block['type'] == 'text' %}",
- " {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}",
+ " {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}",
" {%- else %}",
- " {%- if ns_msg.msg_text_parts | length > 0 %}",
- " {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}", # noqa: E501
- " {%- set ns_msg.msg_text_parts = [] %}",
- " {%- endif %}",
" {%- if ns_c.text_parts | length > 0 %}",
" {%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\\n\\n')}] %}", # noqa: E501
" {%- set ns_c.text_parts = [] %}",
@@ -817,9 +808,7 @@ def _generate_flush_logic(config: TemplateConfig) -> list[str]:
" {%- set ns_c.has_non_text = true %}",
" {%- endif %}",
" {%- endfor %}",
- " {%- if ns_msg.msg_text_parts | length > 0 %}",
- " {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}",
- " {%- endif %}",
+
" {%- endif %}",
]
diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py
index ea48db1b..56325e37 100644
--- a/src/mistral_common/protocol/instruct/normalize.py
+++ b/src/mistral_common/protocol/instruct/normalize.py
@@ -32,14 +32,13 @@
from mistral_common.tokens.tokenizers.base import InstructRequestType, TokenizerVersion
from mistral_common.tokens.tokenizers.model_settings_builder import ModelSettingsBuilder
-_MSG_JOIN_STR = "\n\n"
-_CHUNK_JOIN_STR = ""
+_DEFAULT_JOIN_STR = "\n\n"
def _aggregate_content_chunks_impl(
contents: list[list[ContentChunk] | str | None],
- msg_join_str: str = _MSG_JOIN_STR,
- chunk_join_str: str = _CHUNK_JOIN_STR,
+ msg_join_str: str,
+ chunk_join_str: str,
) -> list[ContentChunk] | str:
r"""Coalesce TextChunks within the same message and across different messages.
@@ -129,8 +128,8 @@ class InstructRequestNormalizer(
_system_prompt_in_begin: bool = False
_allow_tool_call_and_content: bool = False
- _chunk_join_str: str = _CHUNK_JOIN_STR
- _msg_join_str: str = _MSG_JOIN_STR
+ _chunk_join_str: str = _DEFAULT_JOIN_STR
+ _msg_join_str: str = _DEFAULT_JOIN_STR
def __init__(
self,
diff --git a/tests/data/chat_templates/v1.jinja b/tests/data/chat_templates/v1.jinja
index 340c8963..2b6162c7 100644
--- a/tests/data/chat_templates/v1.jinja
+++ b/tests/data/chat_templates/v1.jinja
@@ -13,17 +13,13 @@
{%- if message['content'] is string %}
{%- set ns_sys.system_parts = ns_sys.system_parts + [message['content']] %}
{%- else %}
- {%- set ns_sys_msg = namespace(parts=[]) %}
{%- for block in message['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_sys_msg.parts = ns_sys_msg.parts + [block['text']] %}
+ {%- set ns_sys.system_parts = ns_sys.system_parts + [block['text']] %}
{%- else %}
{{- raise_exception('Only text chunks are supported in system message content.') }}
{%- endif %}
{%- endfor %}
- {%- if ns_sys_msg.parts | length > 0 %}
- {%- set ns_sys.system_parts = ns_sys.system_parts + [ns_sys_msg.parts | join('')] %}
- {%- endif %}
{%- endif %}
{%- else %}
{%- set ns_sys.filtered = ns_sys.filtered + [message] %}
@@ -49,15 +45,10 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
- {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
{%- else %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- set ns_msg.msg_text_parts = [] %}
- {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -66,9 +57,6 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v11.jinja b/tests/data/chat_templates/v11.jinja
index 2a1d7562..79943d0d 100644
--- a/tests/data/chat_templates/v11.jinja
+++ b/tests/data/chat_templates/v11.jinja
@@ -50,15 +50,10 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
- {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
{%- else %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- set ns_msg.msg_text_parts = [] %}
- {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -67,9 +62,6 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v11_audio.jinja b/tests/data/chat_templates/v11_audio.jinja
index c95b1dc7..e96cdbf9 100644
--- a/tests/data/chat_templates/v11_audio.jinja
+++ b/tests/data/chat_templates/v11_audio.jinja
@@ -60,15 +60,10 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
- {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
{%- else %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- set ns_msg.msg_text_parts = [] %}
- {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -77,9 +72,6 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v11_image.jinja b/tests/data/chat_templates/v11_image.jinja
index 1944dca4..164a6763 100644
--- a/tests/data/chat_templates/v11_image.jinja
+++ b/tests/data/chat_templates/v11_image.jinja
@@ -52,15 +52,10 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
- {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
{%- else %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- set ns_msg.msg_text_parts = [] %}
- {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -69,9 +64,6 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v11_image_think.jinja b/tests/data/chat_templates/v11_image_think.jinja
index 39683fce..950746ef 100644
--- a/tests/data/chat_templates/v11_image_think.jinja
+++ b/tests/data/chat_templates/v11_image_think.jinja
@@ -79,15 +79,10 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
- {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
{%- else %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- set ns_msg.msg_text_parts = [] %}
- {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -96,9 +91,6 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v11_think.jinja b/tests/data/chat_templates/v11_think.jinja
index a4820c07..ca99772c 100644
--- a/tests/data/chat_templates/v11_think.jinja
+++ b/tests/data/chat_templates/v11_think.jinja
@@ -77,15 +77,10 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
- {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
{%- else %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- set ns_msg.msg_text_parts = [] %}
- {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -94,9 +89,6 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v13.jinja b/tests/data/chat_templates/v13.jinja
index 14e0ce58..38c253d9 100644
--- a/tests/data/chat_templates/v13.jinja
+++ b/tests/data/chat_templates/v13.jinja
@@ -50,15 +50,10 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
- {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
{%- else %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- set ns_msg.msg_text_parts = [] %}
- {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -67,9 +62,6 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v13_audio.jinja b/tests/data/chat_templates/v13_audio.jinja
index 5fbb6ccb..c0a2fdec 100644
--- a/tests/data/chat_templates/v13_audio.jinja
+++ b/tests/data/chat_templates/v13_audio.jinja
@@ -60,15 +60,10 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
- {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
{%- else %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- set ns_msg.msg_text_parts = [] %}
- {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -77,9 +72,6 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v13_image.jinja b/tests/data/chat_templates/v13_image.jinja
index 38d608bf..1b79382a 100644
--- a/tests/data/chat_templates/v13_image.jinja
+++ b/tests/data/chat_templates/v13_image.jinja
@@ -52,15 +52,10 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
- {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
{%- else %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- set ns_msg.msg_text_parts = [] %}
- {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -69,9 +64,6 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v13_image_think.jinja b/tests/data/chat_templates/v13_image_think.jinja
index cc5079f3..cb39ae85 100644
--- a/tests/data/chat_templates/v13_image_think.jinja
+++ b/tests/data/chat_templates/v13_image_think.jinja
@@ -79,15 +79,10 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
- {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
{%- else %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- set ns_msg.msg_text_parts = [] %}
- {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -96,9 +91,6 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v13_think.jinja b/tests/data/chat_templates/v13_think.jinja
index 8606f7c3..5313eabe 100644
--- a/tests/data/chat_templates/v13_think.jinja
+++ b/tests/data/chat_templates/v13_think.jinja
@@ -77,15 +77,10 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
- {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
{%- else %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- set ns_msg.msg_text_parts = [] %}
- {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -94,9 +89,6 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v1_spm.jinja b/tests/data/chat_templates/v1_spm.jinja
index 7fae6926..5fb2e7d0 100644
--- a/tests/data/chat_templates/v1_spm.jinja
+++ b/tests/data/chat_templates/v1_spm.jinja
@@ -13,17 +13,13 @@
{%- if message['content'] is string %}
{%- set ns_sys.system_parts = ns_sys.system_parts + [message['content']] %}
{%- else %}
- {%- set ns_sys_msg = namespace(parts=[]) %}
{%- for block in message['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_sys_msg.parts = ns_sys_msg.parts + [block['text']] %}
+ {%- set ns_sys.system_parts = ns_sys.system_parts + [block['text']] %}
{%- else %}
{{- raise_exception('Only text chunks are supported in system message content.') }}
{%- endif %}
{%- endfor %}
- {%- if ns_sys_msg.parts | length > 0 %}
- {%- set ns_sys.system_parts = ns_sys.system_parts + [ns_sys_msg.parts | join('')] %}
- {%- endif %}
{%- endif %}
{%- else %}
{%- set ns_sys.filtered = ns_sys.filtered + [message] %}
@@ -49,15 +45,10 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
- {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
{%- else %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- set ns_msg.msg_text_parts = [] %}
- {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -66,9 +57,6 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v2.jinja b/tests/data/chat_templates/v2.jinja
index cd009d81..0ba2a722 100644
--- a/tests/data/chat_templates/v2.jinja
+++ b/tests/data/chat_templates/v2.jinja
@@ -13,17 +13,13 @@
{%- if message['content'] is string %}
{%- set ns_sys.system_parts = ns_sys.system_parts + [message['content']] %}
{%- else %}
- {%- set ns_sys_msg = namespace(parts=[]) %}
{%- for block in message['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_sys_msg.parts = ns_sys_msg.parts + [block['text']] %}
+ {%- set ns_sys.system_parts = ns_sys.system_parts + [block['text']] %}
{%- else %}
{{- raise_exception('Only text chunks are supported in system message content.') }}
{%- endif %}
{%- endfor %}
- {%- if ns_sys_msg.parts | length > 0 %}
- {%- set ns_sys.system_parts = ns_sys.system_parts + [ns_sys_msg.parts | join('')] %}
- {%- endif %}
{%- endif %}
{%- else %}
{%- set ns_sys.filtered = ns_sys.filtered + [message] %}
@@ -75,15 +71,10 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
- {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
{%- else %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- set ns_msg.msg_text_parts = [] %}
- {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -92,9 +83,6 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v2_spm.jinja b/tests/data/chat_templates/v2_spm.jinja
index 8b9d18ca..647ad3ac 100644
--- a/tests/data/chat_templates/v2_spm.jinja
+++ b/tests/data/chat_templates/v2_spm.jinja
@@ -13,17 +13,13 @@
{%- if message['content'] is string %}
{%- set ns_sys.system_parts = ns_sys.system_parts + [message['content']] %}
{%- else %}
- {%- set ns_sys_msg = namespace(parts=[]) %}
{%- for block in message['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_sys_msg.parts = ns_sys_msg.parts + [block['text']] %}
+ {%- set ns_sys.system_parts = ns_sys.system_parts + [block['text']] %}
{%- else %}
{{- raise_exception('Only text chunks are supported in system message content.') }}
{%- endif %}
{%- endfor %}
- {%- if ns_sys_msg.parts | length > 0 %}
- {%- set ns_sys.system_parts = ns_sys.system_parts + [ns_sys_msg.parts | join('')] %}
- {%- endif %}
{%- endif %}
{%- else %}
{%- set ns_sys.filtered = ns_sys.filtered + [message] %}
@@ -75,15 +71,10 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
- {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
{%- else %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- set ns_msg.msg_text_parts = [] %}
- {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -92,9 +83,6 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v3.jinja b/tests/data/chat_templates/v3.jinja
index 6678f208..8a0bfa7d 100644
--- a/tests/data/chat_templates/v3.jinja
+++ b/tests/data/chat_templates/v3.jinja
@@ -13,17 +13,13 @@
{%- if message['content'] is string %}
{%- set ns_sys.system_parts = ns_sys.system_parts + [message['content']] %}
{%- else %}
- {%- set ns_sys_msg = namespace(parts=[]) %}
{%- for block in message['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_sys_msg.parts = ns_sys_msg.parts + [block['text']] %}
+ {%- set ns_sys.system_parts = ns_sys.system_parts + [block['text']] %}
{%- else %}
{{- raise_exception('Only text chunks are supported in system message content.') }}
{%- endif %}
{%- endfor %}
- {%- if ns_sys_msg.parts | length > 0 %}
- {%- set ns_sys.system_parts = ns_sys.system_parts + [ns_sys_msg.parts | join('')] %}
- {%- endif %}
{%- endif %}
{%- else %}
{%- set ns_sys.filtered = ns_sys.filtered + [message] %}
@@ -75,15 +71,10 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
- {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
{%- else %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- set ns_msg.msg_text_parts = [] %}
- {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -92,9 +83,6 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v3_image.jinja b/tests/data/chat_templates/v3_image.jinja
index f62b7b28..d00b3529 100644
--- a/tests/data/chat_templates/v3_image.jinja
+++ b/tests/data/chat_templates/v3_image.jinja
@@ -13,17 +13,13 @@
{%- if message['content'] is string %}
{%- set ns_sys.system_parts = ns_sys.system_parts + [message['content']] %}
{%- else %}
- {%- set ns_sys_msg = namespace(parts=[]) %}
{%- for block in message['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_sys_msg.parts = ns_sys_msg.parts + [block['text']] %}
+ {%- set ns_sys.system_parts = ns_sys.system_parts + [block['text']] %}
{%- else %}
{{- raise_exception('Only text chunks are supported in system message content.') }}
{%- endif %}
{%- endfor %}
- {%- if ns_sys_msg.parts | length > 0 %}
- {%- set ns_sys.system_parts = ns_sys.system_parts + [ns_sys_msg.parts | join('')] %}
- {%- endif %}
{%- endif %}
{%- else %}
{%- set ns_sys.filtered = ns_sys.filtered + [message] %}
@@ -77,15 +73,10 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
- {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
{%- else %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- set ns_msg.msg_text_parts = [] %}
- {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -94,9 +85,6 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v3_image_spm.jinja b/tests/data/chat_templates/v3_image_spm.jinja
index eb1cf1ee..8cb31b30 100644
--- a/tests/data/chat_templates/v3_image_spm.jinja
+++ b/tests/data/chat_templates/v3_image_spm.jinja
@@ -13,17 +13,13 @@
{%- if message['content'] is string %}
{%- set ns_sys.system_parts = ns_sys.system_parts + [message['content']] %}
{%- else %}
- {%- set ns_sys_msg = namespace(parts=[]) %}
{%- for block in message['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_sys_msg.parts = ns_sys_msg.parts + [block['text']] %}
+ {%- set ns_sys.system_parts = ns_sys.system_parts + [block['text']] %}
{%- else %}
{{- raise_exception('Only text chunks are supported in system message content.') }}
{%- endif %}
{%- endfor %}
- {%- if ns_sys_msg.parts | length > 0 %}
- {%- set ns_sys.system_parts = ns_sys.system_parts + [ns_sys_msg.parts | join('')] %}
- {%- endif %}
{%- endif %}
{%- else %}
{%- set ns_sys.filtered = ns_sys.filtered + [message] %}
@@ -83,15 +79,10 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
- {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
{%- else %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- set ns_msg.msg_text_parts = [] %}
- {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -100,9 +91,6 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v3_spm.jinja b/tests/data/chat_templates/v3_spm.jinja
index f1aef3c2..304644e4 100644
--- a/tests/data/chat_templates/v3_spm.jinja
+++ b/tests/data/chat_templates/v3_spm.jinja
@@ -13,17 +13,13 @@
{%- if message['content'] is string %}
{%- set ns_sys.system_parts = ns_sys.system_parts + [message['content']] %}
{%- else %}
- {%- set ns_sys_msg = namespace(parts=[]) %}
{%- for block in message['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_sys_msg.parts = ns_sys_msg.parts + [block['text']] %}
+ {%- set ns_sys.system_parts = ns_sys.system_parts + [block['text']] %}
{%- else %}
{{- raise_exception('Only text chunks are supported in system message content.') }}
{%- endif %}
{%- endfor %}
- {%- if ns_sys_msg.parts | length > 0 %}
- {%- set ns_sys.system_parts = ns_sys.system_parts + [ns_sys_msg.parts | join('')] %}
- {%- endif %}
{%- endif %}
{%- else %}
{%- set ns_sys.filtered = ns_sys.filtered + [message] %}
@@ -75,15 +71,10 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
- {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
{%- else %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- set ns_msg.msg_text_parts = [] %}
- {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -92,9 +83,6 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v7.jinja b/tests/data/chat_templates/v7.jinja
index 516ea1e0..519991ec 100644
--- a/tests/data/chat_templates/v7.jinja
+++ b/tests/data/chat_templates/v7.jinja
@@ -50,15 +50,10 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
- {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
{%- else %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- set ns_msg.msg_text_parts = [] %}
- {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -67,9 +62,6 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v7_audio.jinja b/tests/data/chat_templates/v7_audio.jinja
index c3d6cbf8..f4e98518 100644
--- a/tests/data/chat_templates/v7_audio.jinja
+++ b/tests/data/chat_templates/v7_audio.jinja
@@ -60,15 +60,10 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
- {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
{%- else %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- set ns_msg.msg_text_parts = [] %}
- {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -77,9 +72,6 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v7_image.jinja b/tests/data/chat_templates/v7_image.jinja
index 0c0a1223..e3c372df 100644
--- a/tests/data/chat_templates/v7_image.jinja
+++ b/tests/data/chat_templates/v7_image.jinja
@@ -52,15 +52,10 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
- {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
{%- else %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- set ns_msg.msg_text_parts = [] %}
- {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -69,9 +64,6 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v7_image_spm.jinja b/tests/data/chat_templates/v7_image_spm.jinja
index dda5bef3..69088620 100644
--- a/tests/data/chat_templates/v7_image_spm.jinja
+++ b/tests/data/chat_templates/v7_image_spm.jinja
@@ -58,15 +58,10 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
- {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
{%- else %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- set ns_msg.msg_text_parts = [] %}
- {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -75,9 +70,6 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
diff --git a/tests/data/chat_templates/v7_spm.jinja b/tests/data/chat_templates/v7_spm.jinja
index 84d76105..02875ee6 100644
--- a/tests/data/chat_templates/v7_spm.jinja
+++ b/tests/data/chat_templates/v7_spm.jinja
@@ -50,15 +50,10 @@
{%- if msg['content'] is string %}
{%- set ns_c.text_parts = ns_c.text_parts + [msg['content']] %}
{%- elif msg['content'] is not none %}
- {%- set ns_msg = namespace(msg_text_parts=[]) %}
{%- for block in msg['content'] %}
{%- if block['type'] == 'text' %}
- {%- set ns_msg.msg_text_parts = ns_msg.msg_text_parts + [block['text']] %}
+ {%- set ns_c.text_parts = ns_c.text_parts + [block['text']] %}
{%- else %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- set ns_msg.msg_text_parts = [] %}
- {%- endif %}
{%- if ns_c.text_parts | length > 0 %}
{%- set ns_c.chunks = ns_c.chunks + [{'type': 'text', 'text': ns_c.text_parts | join('\n\n')}] %}
{%- set ns_c.text_parts = [] %}
@@ -67,9 +62,6 @@
{%- set ns_c.has_non_text = true %}
{%- endif %}
{%- endfor %}
- {%- if ns_msg.msg_text_parts | length > 0 %}
- {%- set ns_c.text_parts = ns_c.text_parts + [ns_msg.msg_text_parts | join('')] %}
- {%- endif %}
{%- endif %}
{%- if msg['tool_calls'] is defined and msg['tool_calls'] is not none %}
{%- set ns_c.tool_calls = ns_c.tool_calls + msg['tool_calls'] | list %}
From 1842eb44f079e705cc2f82a08d55154964b9a454 Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Fri, 5 Jun 2026 10:28:26 +0200
Subject: [PATCH 05/47] Address review nits: remove cur_text_len, combine
branches with or
---
src/mistral_common/protocol/instruct/normalize.py | 8 ++------
1 file changed, 2 insertions(+), 6 deletions(-)
diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py
index 56325e37..c96a1940 100644
--- a/src/mistral_common/protocol/instruct/normalize.py
+++ b/src/mistral_common/protocol/instruct/normalize.py
@@ -296,10 +296,8 @@ def _aggregate_assistant_messages(self, messages: list[UATS]) -> AssistantMessag
)
weight = message.weight
- if isinstance(content, str):
+ if isinstance(content, str) or _is_assistant_content(content):
narrowed_content: str | list[TextChunk | ThinkChunk] = content
- elif _is_assistant_content(content):
- narrowed_content = content
else:
raise InvalidRequestException(
f"Unexpected content chunk types in assistant message: {[type(c).__name__ for c in content]}"
@@ -318,9 +316,7 @@ def _aggregate_assistant_messages(self, messages: list[UATS]) -> AssistantMessag
def _aggregate_user_messages(self, messages: list[UATS]) -> UserMessageType:
"""Coalesce neighboring blocks of ContentChunks in user messages."""
content = self._aggregate_content_chunks(messages)
- if isinstance(content, str):
- return self._user_message_class(content=content)
- elif _is_user_content(content):
+ if isinstance(content, str) or _is_user_content(content):
return self._user_message_class(content=content)
else:
raise InvalidRequestException(
From 16ecf704ff28fbe81942f2617911be20659de4e5 Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Fri, 5 Jun 2026 16:48:44 +0200
Subject: [PATCH 06/47] Consolidate multimodal ContentChunk support for all
message roles
Add per-role content type aliases: AssistantContentChunk, SystemContentChunk,
ToolContentChunk. Widen all message content fields to use per-role aliases.
Add shared _content_to_openai/_content_from_openai helpers on BaseMessage.
Refactor all to_openai/from_openai to use shared helpers.
Add normalizer _narrow_*_content methods with version-aware validation:
- _narrow_assistant_content: pre-V15 allows TextChunk/ThinkChunk, V15 allows all
- _narrow_tool_content: pre-V15 allows text only, V15 allows all
- _narrow_system_content: V7+ allows text/think/audio, V15 rejects ThinkChunk
Widen encode_tool/assistant/system_message return types to
tuple[list[int], list[np.ndarray], list[Audio]].
Add TemplateConfig properties: tool_supports_multimodal,
assistant_supports_multimodal, system_supports_audio, system_supports_thinking.
Update template generation with dynamic supported_types_desc.
Closes #236, closes #237, closes #238
---
.../experimental/app/routers.py | 4 +-
.../chat_templates/template_generator.py | 70 ++++-
src/mistral_common/protocol/instruct/chunk.py | 8 +-
.../protocol/instruct/messages.py | 115 ++++---
.../protocol/instruct/normalize.py | 174 +++++++++--
.../tokens/tokenizers/instruct.py | 117 +++++---
tests/data/chat_templates/v15.jinja | 4 +-
tests/data/chat_templates/v15_audio.jinja | 8 +-
tests/data/chat_templates/v15_image.jinja | 6 +-
.../data/chat_templates/v15_image_think.jinja | 6 +-
tests/data/chat_templates/v15_think.jinja | 4 +-
tests/experimental/test_app.py | 6 +-
tests/guidance/test_guidance.py | 5 +-
.../chat_templates/fixtures_data.py | 66 ++++
.../chat_templates/test_parity.py | 125 ++++++++
.../transformers/test_core_parity.py | 20 +-
.../chat_templates/unit/test_v15.py | 281 ++++++++++++++++++
tests/test_converters.py | 19 +-
tests/test_normalization.py | 190 ++++++++++++
tests/test_tokenizer_v11.py | 16 +-
tests/test_tokenizer_v13.py | 18 +-
tests/test_tokenizer_v15.py | 206 ++++++++++++-
22 files changed, 1312 insertions(+), 156 deletions(-)
diff --git a/src/mistral_common/experimental/app/routers.py b/src/mistral_common/experimental/app/routers.py
index 73c9c6e3..1c8aabc4 100644
--- a/src/mistral_common/experimental/app/routers.py
+++ b/src/mistral_common/experimental/app/routers.py
@@ -14,7 +14,7 @@
)
from mistral_common.experimental.think import _split_content_and_think_chunks
from mistral_common.experimental.tools import _decode_tool_calls, _split_content_and_tool_calls
-from mistral_common.protocol.instruct.chunk import TextChunk, ThinkChunk
+from mistral_common.protocol.instruct.chunk import AssistantContentChunk, TextChunk, ThinkChunk
from mistral_common.protocol.instruct.messages import AssistantMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy, Tokenized, TokenizerVersion
@@ -100,7 +100,7 @@ async def detokenize_to_assistant_message(
else:
content_tokens, tool_calls_tokens = tokens, ()
- content: str | list[TextChunk | ThinkChunk] | None = None
+ content: str | list[AssistantContentChunk] | None = None
if settings.tokenizer.instruct_tokenizer.tokenizer.version >= TokenizerVersion.v13:
assert isinstance(settings.tokenizer.instruct_tokenizer, InstructTokenizerV13)
diff --git a/src/mistral_common/integrations/chat_templates/template_generator.py b/src/mistral_common/integrations/chat_templates/template_generator.py
index d580acb8..9d8a4373 100644
--- a/src/mistral_common/integrations/chat_templates/template_generator.py
+++ b/src/mistral_common/integrations/chat_templates/template_generator.py
@@ -216,6 +216,35 @@ def validates_assistant_non_empty(self) -> bool:
r"""Whether to validate that assistant messages have non-empty content or tool calls."""
return self.version >= TokenizerVersion.v7 or (self.version >= TokenizerVersion.v3 and not self.spm)
+ @property
+ def tool_supports_multimodal(self) -> bool:
+ r"""Whether tool messages can contain non-text content chunks. V15+."""
+ return self.version >= TokenizerVersion.v15
+
+ @property
+ def assistant_supports_multimodal(self) -> bool:
+ r"""Whether assistant messages can contain non-text content chunks. V15+."""
+ return self.version >= TokenizerVersion.v15
+
+ @property
+ def system_supports_audio(self) -> bool:
+ r"""Whether system messages can contain audio. V15+ with audio support."""
+ return self.audio_support and self.version >= TokenizerVersion.v15
+
+
+def _join_types_desc(parts: list[str]) -> str:
+ r"""Join type names into a human-readable description string.
+
+ Args:
+ parts: List of type names (e.g. ["text", "thinking", "image"]).
+
+ Returns:
+ Formatted string like "text", "text and thinking", or "text, thinking and image".
+ """
+ if len(parts) == 1:
+ return parts[0]
+ return ", ".join(parts[:-1]) + " and " + parts[-1]
+
def _generate_header(config: TemplateConfig) -> str:
r"""Generate template header with default system message.
@@ -873,6 +902,8 @@ def _generate_system_message_handling(config: TemplateConfig) -> str:
if has_extra_types:
if config.system_supports_thinking:
rc_args += ", supported_types_desc='text and thinking'"
+ elif config.system_supports_audio:
+ rc_args += ", supported_types_desc='text and audio'"
else:
rc_args += ", supported_types_desc='text'"
if config.any_thinking_support:
@@ -883,7 +914,7 @@ def _generate_system_message_handling(config: TemplateConfig) -> str:
if config.image_support:
rc_args += ", support_images=false"
if config.audio_support:
- rc_args += ", support_audio=false"
+ rc_args += f", support_audio={'true' if config.system_supports_audio else 'false'}"
lines.append(" {{- render_content(" + rc_args + ") -}}")
lines.append(" {{- '" + _END_SYSTEM + "' -}}")
@@ -1236,16 +1267,20 @@ def _generate_assistant_message_handling(config: TemplateConfig) -> str:
has_extra_types = config.any_thinking_support or config.image_support or config.audio_support
rc_call_args = "message['content'], 'assistant message contents'"
if has_extra_types:
+ desc_parts = ["text"]
if config.any_thinking_support:
- rc_call_args += ", supported_types_desc='text and thinking'"
- else:
- rc_call_args += ", supported_types_desc='text'"
+ desc_parts.append("thinking")
+ if config.assistant_supports_multimodal and config.image_support:
+ desc_parts.append("image")
+ if config.assistant_supports_multimodal and config.audio_support:
+ desc_parts.append("audio")
+ rc_call_args += f", supported_types_desc='{_join_types_desc(desc_parts)}'"
if config.any_thinking_support:
rc_call_args += ", support_thinking=true"
if config.image_support:
- rc_call_args += ", support_images=false"
+ rc_call_args += f", support_images={'true' if config.assistant_supports_multimodal else 'false'}"
if config.audio_support:
- rc_call_args += ", support_audio=false"
+ rc_call_args += f", support_audio={'true' if config.assistant_supports_multimodal else 'false'}"
lines.append(" {%- if message['content'] %}")
@@ -1485,9 +1520,26 @@ def _generate_tool_message_handling(config: TemplateConfig) -> str:
+ "' }}" # noqa: E501
)
elif config.uses_simple_tool_results:
- lines.append(
- " {{- '" + _BEGIN_TOOL_RESULTS + "' + message['content']|string + '" + _END_TOOL_RESULTS + "' }}"
- ) # noqa: E501
+ if config.tool_supports_multimodal:
+ tool_rc_args = "message['content'], 'tool message contents'"
+ if config.image_support or config.audio_support:
+ desc_parts = ["text"]
+ if config.image_support:
+ desc_parts.append("image")
+ if config.audio_support:
+ desc_parts.append("audio")
+ tool_rc_args += f", supported_types_desc='{_join_types_desc(desc_parts)}'"
+ if config.image_support:
+ tool_rc_args += ", support_images=true"
+ if config.audio_support:
+ tool_rc_args += ", support_audio=true"
+ lines.append(" {{- '" + _BEGIN_TOOL_RESULTS + "' -}}")
+ lines.append(" {{- render_content(" + tool_rc_args + ") -}}")
+ lines.append(" {{- '" + _END_TOOL_RESULTS + "' }}")
+ else:
+ lines.append(
+ " {{- '" + _BEGIN_TOOL_RESULTS + "' + message['content']|string + '" + _END_TOOL_RESULTS + "' }}"
+ ) # noqa: E501
else:
# v3 non-spm style
lines.extend(_emit_int_float_parsing(" "))
diff --git a/src/mistral_common/protocol/instruct/chunk.py b/src/mistral_common/protocol/instruct/chunk.py
index 2ed61528..430c8c3d 100644
--- a/src/mistral_common/protocol/instruct/chunk.py
+++ b/src/mistral_common/protocol/instruct/chunk.py
@@ -7,7 +7,7 @@
from urllib.parse import urlparse
from pydantic import ConfigDict, Field, ValidationError, field_validator, model_validator
-from typing_extensions import Annotated
+from typing_extensions import Annotated, TypeAlias
from mistral_common.base import MistralBase
from mistral_common.deprecation import warn_once
@@ -455,6 +455,12 @@ def from_openai(cls, openai_chunk: dict[str, Any]) -> "ThinkChunk":
TextChunk | ImageChunk | ImageURLChunk | AudioChunk | AudioURLChunk, Field(discriminator="type")
]
+AssistantContentChunk: TypeAlias = ContentChunk
+
+SystemContentChunk = Annotated[TextChunk | AudioChunk | ThinkChunk, Field(discriminator="type")]
+
+ToolContentChunk: TypeAlias = ContentChunk
+
def _convert_openai_content_chunks(openai_content_chunks: dict[str, Any]) -> ContentChunk:
content_type_str = openai_content_chunks.get("type")
diff --git a/src/mistral_common/protocol/instruct/messages.py b/src/mistral_common/protocol/instruct/messages.py
index e82379cb..140a1f3a 100644
--- a/src/mistral_common/protocol/instruct/messages.py
+++ b/src/mistral_common/protocol/instruct/messages.py
@@ -1,16 +1,20 @@
import warnings
+from collections.abc import Sequence
from enum import Enum
-from typing import Any, Literal, TypeGuard, TypeVar
+from typing import Any, Literal, TypeVar
from pydantic import Field
-from typing_extensions import Annotated, TypeAlias
+from typing_extensions import Annotated, TypeAlias, TypeGuard
from mistral_common.base import MistralBase
from mistral_common.exceptions import InvalidAssistantMessageException
from mistral_common.protocol.instruct.chunk import (
+ AssistantContentChunk,
ContentChunk,
+ SystemContentChunk,
TextChunk,
ThinkChunk,
+ ToolContentChunk,
UserContentChunk,
_convert_openai_content_chunks,
)
@@ -23,12 +27,9 @@
)
-def _are_think_chunks(think_chunks: list[ThinkChunk | TextChunk]) -> TypeGuard[list[ThinkChunk]]:
- return all(isinstance(c, ThinkChunk) for c in think_chunks)
-
-
-def _are_text_chunks(think_chunks: list[ThinkChunk | TextChunk]) -> TypeGuard[list[TextChunk]]:
- return all(isinstance(c, TextChunk) for c in think_chunks)
+def _are_think_chunks(chunks: list[ContentChunk]) -> TypeGuard[list[ThinkChunk]]:
+ r"""Narrow a ContentChunk list to ThinkChunk list."""
+ return all(isinstance(c, ThinkChunk) for c in chunks)
class ReasoningFieldFormat(str, Enum):
@@ -73,6 +74,44 @@ class BaseMessage(MistralBase):
role: Literal[Roles.system, Roles.user, Roles.assistant, Roles.tool]
+ @staticmethod
+ def _content_to_openai(
+ content: str | Sequence[ContentChunk] | None,
+ ) -> str | list[dict[str, Any]] | None:
+ r"""Serialize message content to OpenAI format.
+
+ Args:
+ content: String, list of content chunks, or None.
+
+ Returns:
+ String content as-is, list of chunks serialized via each chunk's
+ to_openai(), or None.
+ """
+ if content is None or isinstance(content, str):
+ return content
+ return [chunk.to_openai() for chunk in content]
+
+ @staticmethod
+ def _content_from_openai(
+ raw: str | list[dict[str, Any]] | None,
+ ) -> str | list[ContentChunk] | None:
+ r"""Deserialize content from OpenAI format.
+
+ Args:
+ raw: Raw content from OpenAI message dict.
+
+ Returns:
+ String content as-is, list of deserialized content chunks, or None.
+
+ Raises:
+ ValueError: If content type is unrecognized.
+ """
+ if raw is None or isinstance(raw, str):
+ return raw
+ if isinstance(raw, list):
+ return [_convert_openai_content_chunks(chunk) for chunk in raw]
+ raise ValueError(f"Unknown content type: {type(raw)}")
+
def to_openai(self) -> dict[str, Any]:
r"""Converts the message to the OpenAI format.
@@ -104,20 +143,13 @@ class UserMessage(BaseMessage):
def to_openai(self) -> dict[str, Any]:
r"""Converts the message to the OpenAI format."""
- if isinstance(self.content, str):
- return {"role": self.role, "content": self.content}
- return {"role": self.role, "content": [chunk.to_openai() for chunk in self.content]}
+ return {"role": self.role, "content": self._content_to_openai(self.content)}
@classmethod
def from_openai(cls, openai_message: dict[str, Any]) -> "UserMessage":
r"""Converts the OpenAI message to the Mistral format."""
- if isinstance(openai_message["content"], str):
- return cls.model_validate_ignore_extra(openai_message)
return cls.model_validate(
- {
- "role": openai_message["role"],
- "content": [_convert_openai_content_chunks(chunk) for chunk in openai_message["content"]],
- },
+ {"role": openai_message["role"], "content": cls._content_from_openai(openai_message.get("content"))}
)
@@ -132,16 +164,18 @@ class SystemMessage(BaseMessage):
"""
role: Literal[Roles.system] = Roles.system
- content: str | list[TextChunk | ThinkChunk]
+ content: str | list[SystemContentChunk]
def to_openai(self) -> dict[str, Any]:
r"""Converts the message to the OpenAI format."""
- return self.model_dump()
+ return {"role": self.role, "content": self._content_to_openai(self.content)}
@classmethod
def from_openai(cls, openai_message: dict[str, Any]) -> "SystemMessage":
r"""Converts the OpenAI message to the Mistral format."""
- return cls.model_validate_ignore_extra(openai_message)
+ return cls.model_validate(
+ {"role": openai_message["role"], "content": cls._content_from_openai(openai_message.get("content"))}
+ )
class AssistantMessage(BaseMessage):
@@ -158,7 +192,7 @@ class AssistantMessage(BaseMessage):
"""
role: Literal[Roles.assistant] = Roles.assistant
- content: str | list[TextChunk | ThinkChunk] | None = None
+ content: str | list[AssistantContentChunk] | None = None
tool_calls: list[ToolCall] | None = None
prefix: bool = False
@@ -205,18 +239,18 @@ def to_openai(
match reasoning_field_format:
case None | ReasoningFieldFormat.thinking_chunks:
- out_dict["content"] = [chunk.to_openai() for chunk in self.content]
+ out_dict["content"] = self._content_to_openai(self.content)
case ReasoningFieldFormat.reasoning | ReasoningFieldFormat.reasoning_content:
think_chunks, content_chunks = self.content[: last_think_idx + 1], self.content[last_think_idx + 1 :]
- if not _are_think_chunks(think_chunks) or not _are_text_chunks(content_chunks):
- raise RuntimeError("Impossible, only think or content chunks should have been present.")
+ if not _are_think_chunks(think_chunks):
+ raise RuntimeError("Expected only ThinkChunks in the leading portion.")
if len(think_chunks) > 0:
out_dict[reasoning_field_format.value] = "\n".join(tc.thinking for tc in think_chunks)
- if len(content_chunks) == 1:
+ if len(content_chunks) == 1 and isinstance(content_chunks[0], TextChunk):
out_dict["content"] = content_chunks[0].text
elif content_chunks:
- out_dict["content"] = [chunk.to_openai() for chunk in content_chunks]
+ out_dict["content"] = self._content_to_openai(content_chunks)
case _:
raise ValueError(f"{reasoning_field_format=} is not supported.")
@@ -234,14 +268,7 @@ def from_openai(cls, openai_message: dict[str, Any]) -> "AssistantMessage":
tools_calls.append(ToolCall.from_openai(openai_tool_call))
else:
raise ValueError(f"tool_calls must be a list, got {type(openai_tool_calls)}")
- openai_content = openai_message.get("content", None)
- content: str | list[ContentChunk] | None = None
- if openai_content is None or isinstance(openai_content, str):
- content = openai_content
- elif isinstance(openai_content, list):
- content = [_convert_openai_content_chunks(chunk) for chunk in openai_content]
- else:
- raise ValueError(f"Unknown content type: {type(openai_content)}")
+ content = cls._content_from_openai(openai_message.get("content"))
reasoning_content: str | None = openai_message.get("reasoning_content")
reasoning: str | None = openai_message.get("reasoning")
@@ -308,7 +335,7 @@ class ToolMessage(BaseMessage):
>>> message = ToolMessage(content="Hello, how can I help you?", tool_call_id="123")
"""
- content: str | list[TextChunk]
+ content: str | list[ToolContentChunk]
role: Literal[Roles.tool] = Roles.tool
tool_call_id: str | None = None
@@ -318,12 +345,24 @@ class ToolMessage(BaseMessage):
def to_openai(self) -> dict[str, Any]:
r"""Converts the message to the OpenAI format."""
assert self.tool_call_id is not None, "tool_call_id must be provided for tool messages."
- return self.model_dump(exclude={"name"})
+ return {
+ "role": self.role,
+ "tool_call_id": self.tool_call_id,
+ "content": self._content_to_openai(self.content),
+ }
@classmethod
- def from_openai(cls, messages: dict[str, str | list[dict[str, str | dict[str, Any]]]]) -> "ToolMessage":
+ def from_openai(cls, openai_message: dict[str, Any]) -> "ToolMessage":
r"""Converts the OpenAI message to the Mistral format."""
- tool_message = cls.model_validate_ignore_extra(messages)
+ content = cls._content_from_openai(openai_message.get("content"))
+ tool_message = cls.model_validate(
+ {
+ "role": openai_message["role"],
+ "tool_call_id": openai_message.get("tool_call_id"),
+ "content": content,
+ "name": openai_message.get("name"),
+ }
+ )
assert tool_message.tool_call_id is not None, "tool_call_id must be provided for tool messages."
return tool_message
diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py
index c96a1940..b053b381 100644
--- a/src/mistral_common/protocol/instruct/normalize.py
+++ b/src/mistral_common/protocol/instruct/normalize.py
@@ -1,8 +1,8 @@
import json
import warnings
-from typing import Generic, Sequence, TypeGuard
+from typing import Generic, Sequence
-from typing_extensions import assert_never
+from typing_extensions import TypeGuard, assert_never
from mistral_common.exceptions import InvalidRequestException
from mistral_common.protocol.instruct.chunk import (
@@ -35,6 +35,13 @@
_DEFAULT_JOIN_STR = "\n\n"
+def _is_user_content(
+ chunks: list[ContentChunk],
+) -> TypeGuard[list[TextChunk | ImageChunk | ImageURLChunk | AudioChunk | AudioURLChunk]]:
+ r"""Narrow ContentChunk list to user-compatible types (no ThinkChunk)."""
+ return all(not isinstance(c, ThinkChunk) for c in chunks)
+
+
def _aggregate_content_chunks_impl(
contents: list[list[ContentChunk] | str | None],
msg_join_str: str,
@@ -98,18 +105,6 @@ def _flush_text() -> None:
return all_content
-def _is_assistant_content(chunks: list[ContentChunk]) -> TypeGuard[list[TextChunk | ThinkChunk]]:
- """Narrow ContentChunk list to assistant-compatible types."""
- return all(isinstance(c, (TextChunk, ThinkChunk)) for c in chunks)
-
-
-def _is_user_content(
- chunks: list[ContentChunk],
-) -> TypeGuard[list[TextChunk | ImageChunk | ImageURLChunk | AudioChunk | AudioURLChunk]]:
- """Narrow ContentChunk list to user-compatible types."""
- return all(not isinstance(c, ThinkChunk) for c in chunks)
-
-
class InstructRequestNormalizer(
Generic[UserMessageType, AssistantMessageType, ToolMessageType, SystemMessageType, InstructRequestType]
):
@@ -244,13 +239,17 @@ def _aggregate_system_prompts(self, messages: list[UATS]) -> str | None:
def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]:
"""Normalize tool messages without aggregation across messages.
- Each tool message's content chunks are aggregated individually and JSON is normalized.
+ Each tool message's content is validated and JSON-normalized.
"""
tool_messages: list[ToolMessageType] = []
for message in messages:
assert isinstance(message, self._tool_message_class), "Expected tool message"
- content = self._aggregate_content_chunks_to_str_same_message(message)
- normalized_content = self._normalize_json_content(content)
+ content = self._aggregate_content_chunks([message])
+ validated = self._narrow_tool_content(content)
+ if isinstance(validated, str):
+ normalized_content: str | list[ContentChunk] = self._normalize_json_content(validated)
+ else:
+ normalized_content = validated
tool_messages.append(
self._tool_message_class(
content=normalized_content, tool_call_id=message.tool_call_id, name=message.name
@@ -266,6 +265,51 @@ def _normalize_tool_call(self, tool_call: ToolCall) -> ToolCall:
id=tool_call.id,
)
+ def _narrow_assistant_content(self, content: list[ContentChunk] | str) -> str | list[ContentChunk]:
+ r"""Validate and narrow content chunks for assistant messages.
+
+ Pre-V15 normalizers only allow TextChunk and ThinkChunk.
+
+ Args:
+ content: The aggregated content chunks.
+
+ Returns:
+ The validated and narrowed content.
+
+ Raises:
+ InvalidRequestException: If unsupported chunk types are found.
+ """
+ if isinstance(content, str) or all(isinstance(c, (TextChunk, ThinkChunk)) for c in content):
+ return content
+ raise InvalidRequestException(
+ f"Unexpected content chunk types in assistant message: {[type(c).__name__ for c in content]}"
+ )
+
+ def _narrow_tool_content(self, content: list[ContentChunk] | str) -> str | list[ContentChunk]:
+ r"""Validate and narrow content for tool messages.
+
+ Pre-V15 normalizers only allow text content.
+
+ Args:
+ content: The raw or aggregated content.
+
+ Returns:
+ The content as a string.
+
+ Raises:
+ InvalidRequestException: If non-text content chunks are found.
+ """
+ if isinstance(content, str):
+ return content
+ text_parts: list[str] = []
+ for c in content:
+ if not isinstance(c, TextChunk):
+ raise InvalidRequestException(
+ f"Unexpected content chunk types in tool message: {[type(c).__name__ for c in content]}"
+ )
+ text_parts.append(c.text)
+ return "".join(text_parts)
+
def _aggregate_system_messages(self, messages: list[UATS]) -> list[SystemMessageType]:
return []
@@ -296,15 +340,10 @@ def _aggregate_assistant_messages(self, messages: list[UATS]) -> AssistantMessag
)
weight = message.weight
- if isinstance(content, str) or _is_assistant_content(content):
- narrowed_content: str | list[TextChunk | ThinkChunk] = content
- else:
- raise InvalidRequestException(
- f"Unexpected content chunk types in assistant message: {[type(c).__name__ for c in content]}"
- )
+ validated_content = self._narrow_assistant_content(content)
aggregated_message = self._assistant_message_class(
- content=narrowed_content,
+ content=validated_content,
tool_calls=tool_calls or None,
prefix=prefix,
)
@@ -318,10 +357,9 @@ def _aggregate_user_messages(self, messages: list[UATS]) -> UserMessageType:
content = self._aggregate_content_chunks(messages)
if isinstance(content, str) or _is_user_content(content):
return self._user_message_class(content=content)
- else:
- raise InvalidRequestException(
- f"Unexpected content chunk types in user message: {[type(c).__name__ for c in content]}"
- )
+ raise InvalidRequestException(
+ f"Unexpected content chunk types in user message: {[type(c).__name__ for c in content]}"
+ )
def _aggregate_role(self, messages: list[UATS], role: Roles | None, latest_call_ids: list[str]) -> Sequence[UATS]:
if role == Roles.tool:
@@ -444,12 +482,28 @@ def normalizer(model_settings_builder: ModelSettingsBuilder | None = None) -> "I
UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest[UATS, Tool], None
)
+ def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | list[ContentChunk]:
+ r"""Validate content chunks for system messages.
+
+ V7+ accepts all SystemContentChunk types (Pydantic validates at construction).
+ V15 overrides to reject ThinkChunk.
+
+ Args:
+ content: The aggregated content chunks.
+
+ Returns:
+ The validated content.
+ """
+ return content
+
def _aggregate_system_messages(self, messages: list[UATS]) -> list[SystemMessageType]:
- return [
- self._system_message_class(content=self._aggregate_content_chunks([message]))
- for message in messages
- if isinstance(message, self._system_message_class)
- ]
+ aggregated: list[SystemMessageType] = []
+ for message in messages:
+ if isinstance(message, self._system_message_class):
+ content = self._aggregate_content_chunks([message])
+ validated = self._narrow_system_content(content)
+ aggregated.append(self._system_message_class(content=validated))
+ return aggregated
def _aggregate_role(self, messages: list[UATS], role: Roles | None, latest_call_ids: list[str]) -> Sequence[UATS]:
if role == Roles.tool:
@@ -555,6 +609,60 @@ class InstructRequestNormalizerV15(InstructRequestNormalizerV13):
_chunk_join_str: str = ""
+ def _narrow_assistant_content(self, content: list[ContentChunk] | str) -> str | list[ContentChunk]:
+ r"""V15 accepts all ContentChunk types in assistant messages."""
+ return content
+
+ def _narrow_tool_content(self, content: list[ContentChunk] | str) -> str | list[ContentChunk]:
+ r"""V15 accepts all ContentChunk types in tool messages."""
+ if isinstance(content, str):
+ return content
+ text_parts: list[str] = []
+ for c in content:
+ if not isinstance(c, TextChunk):
+ return content
+ text_parts.append(c.text)
+ return "".join(text_parts)
+
+ def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | list[ContentChunk]:
+ r"""V15 system messages allow TextChunk and AudioChunk but reject ThinkChunk."""
+ if isinstance(content, str):
+ return content
+ if any(isinstance(c, ThinkChunk) for c in content):
+ raise InvalidRequestException("ThinkChunk in system message is not supported for V15")
+ return content
+
+ def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]:
+ r"""V15 tool messages preserve non-text content chunks.
+
+ Text-only content is aggregated and JSON-normalized. Mixed content is preserved as-is.
+ """
+ tool_messages: list[ToolMessageType] = []
+ for message in messages:
+ assert isinstance(message, self._tool_message_class), "Expected tool message"
+ content = self._aggregate_content_chunks([message])
+ validated = self._narrow_tool_content(content)
+ if isinstance(validated, str):
+ normalized_content: str | list[ContentChunk] = self._normalize_json_content(validated)
+ else:
+ normalized_content = validated
+ tool_messages.append(
+ self._tool_message_class(
+ content=normalized_content, tool_call_id=message.tool_call_id, name=message.name
+ )
+ )
+
+ # Reorder by tool call order
+ id_to_tool_call_idx = {call_id: idx for idx, call_id in enumerate(latest_call_ids)}
+ id_to_tool_result_idx = {message.tool_call_id: idx for idx, message in enumerate(tool_messages)}
+ tool_messages.sort(
+ key=lambda msg: (
+ id_to_tool_call_idx.get(msg.tool_call_id or "null", float("inf")),
+ id_to_tool_result_idx[msg.tool_call_id],
+ ),
+ )
+ return tool_messages
+
@staticmethod
def normalizer(model_settings_builder: ModelSettingsBuilder | None = None) -> "InstructRequestNormalizerV15":
r"""Returns a normalizer for the V15 instruct request.
diff --git a/src/mistral_common/tokens/tokenizers/instruct.py b/src/mistral_common/tokens/tokenizers/instruct.py
index 7992df1c..08de0c95 100644
--- a/src/mistral_common/tokens/tokenizers/instruct.py
+++ b/src/mistral_common/tokens/tokenizers/instruct.py
@@ -104,7 +104,9 @@ def find_first_last_user(request: InstructRequest) -> tuple[int, int]:
return first_user_idx, last_user_idx
@abstractmethod
- def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: bool) -> list[int]:
+ def encode_tool_message(
+ self, message: ToolMessage, is_before_last_user_message: bool
+ ) -> tuple[list[int], list[np.ndarray], list[Audio]]:
r"""Encode a tool message.
Raises:
@@ -115,7 +117,7 @@ def encode_tool_message(self, message: ToolMessage, is_before_last_user_message:
@abstractmethod
def encode_assistant_message(
self, message: AssistantMessageType, is_before_last_user_message: bool, continue_message: bool
- ) -> list[int]:
+ ) -> tuple[list[int], list[np.ndarray], list[Audio]]:
r"""Encode an assistant message.
Raises:
@@ -192,17 +194,23 @@ def encode_instruct(
images.extend(new_images)
audios.extend(new_audios)
elif isinstance(msg, ToolMessage):
- new_tokens = self.encode_tool_message(msg, msg_idx < last_user_idx)
+ new_tokens, new_images, new_audios = self.encode_tool_message(msg, msg_idx < last_user_idx)
+ images.extend(new_images)
+ audios.extend(new_audios)
elif isinstance(msg, AssistantMessage):
continue_message = request.continue_final_message and (msg_idx == len(request.messages) - 1)
- new_tokens = self.encode_assistant_message(
+ new_tokens, new_images, new_audios = self.encode_assistant_message(
msg, msg_idx < last_user_idx, continue_message=continue_message
)
+ images.extend(new_images)
+ audios.extend(new_audios)
if msg_idx == len(request.messages) - 1:
prefix_ids = new_tokens
elif isinstance(msg, SystemMessage):
- new_tokens = self.encode_system_message(msg)
+ new_tokens, new_images, new_audios = self.encode_system_message(msg)
+ images.extend(new_images)
+ audios.extend(new_audios)
else:
raise TokenizerException(f"Unknown message type {type(msg)}")
@@ -289,7 +297,7 @@ def encode_user_message(
curr_tokens, image, audio = self.encode_user_content(content=message_txt, is_last=False, system_prompt=None)
return curr_tokens, image, audio
- def encode_system_message(self, message: SystemMessage) -> list[int]:
+ def encode_system_message(self, message: SystemMessage) -> tuple[list[int], list[np.ndarray], list[Audio]]:
raise NotImplementedError(f"System message encoding not implemented for {self.__class__.__name__}")
def encode_user_content(
@@ -318,7 +326,9 @@ def encode_user_content(
tokens = self.tokenizer.encode(content, bos=False, eos=False)
return tokens, [], []
- def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: bool) -> list[int]:
+ def encode_tool_message(
+ self, message: ToolMessage, is_before_last_user_message: bool
+ ) -> tuple[list[int], list[np.ndarray], list[Audio]]:
r"""Encode a tool message.
Raises:
@@ -328,7 +338,7 @@ def encode_tool_message(self, message: ToolMessage, is_before_last_user_message:
def encode_assistant_message(
self, message: AssistantMessageType, is_before_last_user_message: bool, continue_message: bool
- ) -> list[int]:
+ ) -> tuple[list[int], list[np.ndarray], list[Audio]]:
r"""Encode an assistant message.
Args:
@@ -338,7 +348,7 @@ def encode_assistant_message(
Only use this if the assistant message is the last message.
Returns:
- The encoded tokens.
+ The encoded tokens, images, and audios.
"""
assert isinstance(message, AssistantMessage), message
if message.tool_calls is not None and len(message.tool_calls) > 0:
@@ -354,7 +364,7 @@ def encode_assistant_message(
raise TokenizerException(f"{message.content} // {message.tool_calls}")
if not message.prefix and not continue_message:
curr_tokens.append(self.tokenizer.eos_id)
- return curr_tokens
+ return curr_tokens, [], []
def encode_think(self, chunk: ThinkChunk) -> list[int]:
r"""Encode a think chunk.
@@ -480,9 +490,9 @@ def _parse_json_content(self, content: str) -> Any:
except json.JSONDecodeError:
return content
- def _parse_tool_content(self, content: str | list[TextChunk]) -> Any:
+ def _parse_tool_content(self, content: str | list[ContentChunk]) -> Any:
if isinstance(content, list):
- content = "".join(chunk.text for chunk in content)
+ content = "".join(chunk.text for chunk in content if isinstance(chunk, TextChunk))
return self._parse_json_content(content)
def _prepare_tool_result(self, tool_message: ToolMessage) -> dict[str, Any]:
@@ -492,7 +502,9 @@ def _prepare_tool_result(self, tool_message: ToolMessage) -> dict[str, Any]:
"content": self._parse_tool_content(tool_message.content),
}
- def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: bool) -> list[int]:
+ def encode_tool_message(
+ self, message: ToolMessage, is_before_last_user_message: bool
+ ) -> tuple[list[int], list[np.ndarray], list[Audio]]:
r"""Encode a tool message.
Args:
@@ -501,11 +513,11 @@ def encode_tool_message(self, message: ToolMessage, is_before_last_user_message:
not encoded.
Returns:
- The encoded tokens.
+ The encoded tokens, images, and audios.
"""
if is_before_last_user_message:
# don't tokenize last tool response before last user msg
- return []
+ return [], [], []
# Currently only supports single tool results
tool_result_str = json.dumps([self._prepare_tool_result(message)], ensure_ascii=False)
@@ -514,7 +526,7 @@ def encode_tool_message(self, message: ToolMessage, is_before_last_user_message:
*self.tokenizer.encode(tool_result_str, bos=False, eos=False),
self.END_TOOL_RESULTS,
]
- return curr_tokens
+ return curr_tokens, [], []
def _prepare_function_call(self, tool_call: ToolCall) -> dict[str, Any]:
r"""Bit of a hack due to the way function calls are tokenized."""
@@ -550,7 +562,7 @@ def _encode_settings(
def encode_assistant_message(
self, message: AssistantMessageType, is_before_last_user_message: bool, continue_message: bool
- ) -> list[int]:
+ ) -> tuple[list[int], list[np.ndarray], list[Audio]]:
r"""Encode an assistant message.
Args:
@@ -561,7 +573,7 @@ def encode_assistant_message(
Only use this if the assistant message is the last message.
Returns:
- The encoded tokens.
+ The encoded tokens, images, and audios.
"""
if message.tool_calls and message.content:
raise ValueError(f"Cannot have tool calls and content defined in the same assistant message {message}")
@@ -573,7 +585,7 @@ def encode_assistant_message(
if message.tool_calls:
if is_before_last_user_message:
# don't tokenize tool call before last user message
- return []
+ return [], [], []
curr_tokens = self._encode_tool_calls_in_assistant_message(message)
elif message.content:
assert isinstance(message.content, str), "Message content must be a string for tokenizer < V7"
@@ -582,7 +594,7 @@ def encode_assistant_message(
raise TokenizerException(f"Invalid assistant message: {message.content}")
if not message.prefix and not continue_message:
curr_tokens.append(self.tokenizer.eos_id)
- return curr_tokens
+ return curr_tokens, [], []
def _encode_infilling(self, text: str) -> list[int]:
r"""Remove prefix space in the case of SentencePieceTokenizers."""
@@ -652,7 +664,9 @@ def _prepare_tool_result(self, tool_message: ToolMessage) -> dict[str, Any]:
"call_id": tool_message.tool_call_id,
}
- def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: bool) -> list[int]:
+ def encode_tool_message(
+ self, message: ToolMessage, is_before_last_user_message: bool
+ ) -> tuple[list[int], list[np.ndarray], list[Audio]]:
r"""Encode a tool message.
Note:
@@ -665,7 +679,7 @@ def encode_tool_message(self, message: ToolMessage, is_before_last_user_message:
not encoded.
Returns:
- The encoded tokens.
+ The encoded tokens, images, and audios.
"""
tool_result_str = json.dumps(self._prepare_tool_result(message), ensure_ascii=False)
curr_tokens = [
@@ -673,11 +687,11 @@ def encode_tool_message(self, message: ToolMessage, is_before_last_user_message:
*self.tokenizer.encode(tool_result_str, bos=False, eos=False),
self.END_TOOL_RESULTS,
]
- return curr_tokens
+ return curr_tokens, [], []
def encode_assistant_message(
self, message: AssistantMessageType, is_before_last_user_message: bool, continue_message: bool
- ) -> list[int]:
+ ) -> tuple[list[int], list[np.ndarray], list[Audio]]:
r"""Encode an assistant message.
Note:
@@ -691,7 +705,7 @@ def encode_assistant_message(
is_before_last_user_message: Not used.
Returns:
- The encoded tokens.
+ The encoded tokens, images, and audios.
"""
return super().encode_assistant_message(message, False, continue_message)
@@ -867,22 +881,22 @@ def drop(idx: int) -> None:
if to_drop > 0:
raise TokenizerException("Input couldn't fit in truncate_at_max_token")
- def encode_system_message(self, message: SystemMessage) -> list[int]:
+ def encode_system_message(self, message: SystemMessage) -> tuple[list[int], list[np.ndarray], list[Audio]]:
r"""Encode a system message.
Args:
message: The message to encode.
Returns:
- The encoded tokens.
+ The encoded tokens, images, and audios.
"""
-
tokens = [self.BEGIN_SYSTEM]
if isinstance(content := message.content, str):
content = [TextChunk(text=content)]
- tokens += self._encode_content_chunks(content)[0]
+ content_tokens, images, audios = self._encode_content_chunks(content)
+ tokens += content_tokens
tokens.append(self.END_SYSTEM)
- return tokens
+ return tokens, images, audios
def encode_user_content(
self,
@@ -1081,7 +1095,9 @@ def _has_audio(messages: list[UATS]) -> bool:
for message in messages
)
- def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: bool) -> list[int]:
+ def encode_tool_message(
+ self, message: ToolMessage, is_before_last_user_message: bool
+ ) -> tuple[list[int], list[np.ndarray], list[Audio]]:
r"""Encode a tool message.
Note:
@@ -1093,7 +1109,7 @@ def encode_tool_message(self, message: ToolMessage, is_before_last_user_message:
is_before_last_user_message: Not used.
Returns:
- The encoded tokens.
+ The encoded tokens, images, and audios.
"""
assert message.tool_call_id is not None
assert isinstance(message.content, str), "Message content must be normalized"
@@ -1110,11 +1126,11 @@ def encode_tool_message(self, message: ToolMessage, is_before_last_user_message:
*tokens,
self.END_TOOL_RESULTS,
]
- return curr_tokens
+ return curr_tokens, [], []
def encode_assistant_message(
self, message: AssistantMessageType, is_before_last_user_message: bool, continue_message: bool
- ) -> list[int]:
+ ) -> tuple[list[int], list[np.ndarray], list[Audio]]:
r"""Encode an assistant message.
Args:
@@ -1124,7 +1140,7 @@ def encode_assistant_message(
Only use this if the assistant message is the last message.
Returns:
- The encoded tokens.
+ The encoded tokens, images, and audios.
"""
if not message.content and not message.tool_calls:
raise TokenizerException(f"Invalid assistant message: {message}")
@@ -1134,17 +1150,22 @@ def encode_assistant_message(
)
curr_tokens: list = []
+ images: list[np.ndarray] = []
+ audios: list[Audio] = []
if message.content:
if isinstance(message.content, str):
curr_tokens = self._encode_normal_content_assistant_message(message)
elif isinstance(message.content, list):
- curr_tokens += self._encode_content_chunks(message.content)[0]
+ content_tokens, new_images, new_audios = self._encode_content_chunks(message.content)
+ curr_tokens += content_tokens
+ images.extend(new_images)
+ audios.extend(new_audios)
if message.tool_calls:
curr_tokens += self._encode_tool_calls_in_assistant_message(message)
if not message.prefix and not continue_message:
curr_tokens.append(self.tokenizer.eos_id)
- return curr_tokens
+ return curr_tokens, images, audios
def _encode_audio_for_speech_request(self, ref_audio: str | bytes | None, voice: str | None) -> Tokenized:
r"""Encode reference audio or voice preset into a Tokenized object.
@@ -1282,28 +1303,32 @@ def _encode_tool_calls_in_assistant_message(self, message: AssistantMessageType)
]
return curr_tokens
- def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: bool) -> list[int]:
+ def encode_tool_message(
+ self, message: ToolMessage, is_before_last_user_message: bool
+ ) -> tuple[list[int], list[np.ndarray], list[Audio]]:
r"""Encode a tool message.
Args:
message: The message to encode.
is_before_last_user_message: Not used.
Returns:
- The encoded tokens.
+ The encoded tokens, images, and audios.
"""
assert message.tool_call_id is not None, "Tool call id must be provided for tokenizer >= v13"
- content = message.content
- if not isinstance(content, str):
- content = "".join(chunk.text for chunk in content)
+ if isinstance(message.content, str):
+ content_tokens = self.tokenizer.encode(message.content, bos=False, eos=False)
+ images: list[np.ndarray] = []
+ audios: list[Audio] = []
+ else:
+ content_tokens, images, audios = self._encode_content_chunks(message.content)
- tokens = self.tokenizer.encode(content, bos=False, eos=False)
curr_tokens = [
self.BEGIN_TOOL_RESULTS,
- *tokens,
+ *content_tokens,
self.END_TOOL_RESULTS,
]
- return curr_tokens
+ return curr_tokens, images, audios
def encode_think(self, chunk: ThinkChunk) -> list[int]:
r"""Encode a thinking chunk.
@@ -1368,7 +1393,7 @@ def _encode_settings(
]
return settings_tokens
- def encode_system_message(self, message: SystemMessage) -> list[int]:
+ def encode_system_message(self, message: SystemMessage) -> tuple[list[int], list[np.ndarray], list[Audio]]:
r"""Encode a system message, rejecting ThinkChunk content."""
if isinstance(message.content, list):
if any(isinstance(chunk, ThinkChunk) for chunk in message.content):
diff --git a/tests/data/chat_templates/v15.jinja b/tests/data/chat_templates/v15.jinja
index 26b4f9d2..d1f29add 100644
--- a/tests/data/chat_templates/v15.jinja
+++ b/tests/data/chat_templates/v15.jinja
@@ -182,7 +182,9 @@
{#- Tool messages only supports text content. #}
{%- elif message['role'] == 'tool' %}
- {{- '[TOOL_RESULTS]' + message['content']|string + '[/TOOL_RESULTS]' }}
+ {{- '[TOOL_RESULTS]' -}}
+ {{- render_content(message['content'], 'tool message contents') -}}
+ {{- '[/TOOL_RESULTS]' }}
{#- System messages. #}
{%- elif message['role'] == 'system' %}
diff --git a/tests/data/chat_templates/v15_audio.jinja b/tests/data/chat_templates/v15_audio.jinja
index 68ce5346..381c72cb 100644
--- a/tests/data/chat_templates/v15_audio.jinja
+++ b/tests/data/chat_templates/v15_audio.jinja
@@ -158,7 +158,7 @@
{%- endif %}
{%- if message['content'] %}
- {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text', support_audio=false) -}}
+ {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text and audio', support_audio=true) -}}
{%- endif %}
{%- if message['tool_calls'] is defined and message['tool_calls'] is not none and message['tool_calls']|length > 0 %}
@@ -179,12 +179,14 @@
{#- Tool messages only supports text content. #}
{%- elif message['role'] == 'tool' %}
- {{- '[TOOL_RESULTS]' + message['content']|string + '[/TOOL_RESULTS]' }}
+ {{- '[TOOL_RESULTS]' -}}
+ {{- render_content(message['content'], 'tool message contents', supported_types_desc='text and audio', support_audio=true) -}}
+ {{- '[/TOOL_RESULTS]' }}
{#- System messages. #}
{%- elif message['role'] == 'system' %}
{{- '[SYSTEM_PROMPT]' -}}
- {{- render_content(message['content'], 'system message contents', supported_types_desc='text', support_audio=false) -}}
+ {{- render_content(message['content'], 'system message contents', supported_types_desc='text and audio', support_audio=true) -}}
{{- '[/SYSTEM_PROMPT]' -}}
{#- Raise exception for unsupported roles. #}
diff --git a/tests/data/chat_templates/v15_image.jinja b/tests/data/chat_templates/v15_image.jinja
index bb338233..58146ed7 100644
--- a/tests/data/chat_templates/v15_image.jinja
+++ b/tests/data/chat_templates/v15_image.jinja
@@ -169,7 +169,7 @@
{%- endif %}
{%- if message['content'] %}
- {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text', support_images=false) -}}
+ {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text and image', support_images=true) -}}
{%- endif %}
{%- if message['tool_calls'] is defined and message['tool_calls'] is not none and message['tool_calls']|length > 0 %}
@@ -190,7 +190,9 @@
{#- Tool messages only supports text content. #}
{%- elif message['role'] == 'tool' %}
- {{- '[TOOL_RESULTS]' + message['content']|string + '[/TOOL_RESULTS]' }}
+ {{- '[TOOL_RESULTS]' -}}
+ {{- render_content(message['content'], 'tool message contents', supported_types_desc='text and image', support_images=true) -}}
+ {{- '[/TOOL_RESULTS]' }}
{#- System messages. #}
{%- elif message['role'] == 'system' %}
diff --git a/tests/data/chat_templates/v15_image_think.jinja b/tests/data/chat_templates/v15_image_think.jinja
index 63c61f6b..1460ec9b 100644
--- a/tests/data/chat_templates/v15_image_think.jinja
+++ b/tests/data/chat_templates/v15_image_think.jinja
@@ -196,7 +196,7 @@
{%- endif %}
{%- if message['content'] %}
- {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text and thinking', support_thinking=true, support_images=false) -}}
+ {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text, thinking and image', support_thinking=true, support_images=true) -}}
{%- endif %}
{%- if message['tool_calls'] is defined and message['tool_calls'] is not none and message['tool_calls']|length > 0 %}
@@ -217,7 +217,9 @@
{#- Tool messages only supports text content. #}
{%- elif message['role'] == 'tool' %}
- {{- '[TOOL_RESULTS]' + message['content']|string + '[/TOOL_RESULTS]' }}
+ {{- '[TOOL_RESULTS]' -}}
+ {{- render_content(message['content'], 'tool message contents', supported_types_desc='text and image', support_images=true) -}}
+ {{- '[/TOOL_RESULTS]' }}
{#- System messages. #}
{%- elif message['role'] == 'system' %}
diff --git a/tests/data/chat_templates/v15_think.jinja b/tests/data/chat_templates/v15_think.jinja
index 7df9072a..d91b3909 100644
--- a/tests/data/chat_templates/v15_think.jinja
+++ b/tests/data/chat_templates/v15_think.jinja
@@ -209,7 +209,9 @@
{#- Tool messages only supports text content. #}
{%- elif message['role'] == 'tool' %}
- {{- '[TOOL_RESULTS]' + message['content']|string + '[/TOOL_RESULTS]' }}
+ {{- '[TOOL_RESULTS]' -}}
+ {{- render_content(message['content'], 'tool message contents') -}}
+ {{- '[/TOOL_RESULTS]' }}
{#- System messages. #}
{%- elif message['role'] == 'system' %}
diff --git a/tests/experimental/test_app.py b/tests/experimental/test_app.py
index 9851cb06..27bb9dd8 100644
--- a/tests/experimental/test_app.py
+++ b/tests/experimental/test_app.py
@@ -394,7 +394,9 @@ def test_detokenize_assistant_message(
def test_detokenize_assistant_message_think_chunks(
assistant_message: AssistantMessage, mistral_tokenizer_v13: MistralTokenizer, tekken_v13_client: TestClient
) -> None:
- encoded_tokens = mistral_tokenizer_v13.instruct_tokenizer.encode_assistant_message(assistant_message, False, False) # type: ignore[attr-defined]
+ encoded_tokens, _, _ = mistral_tokenizer_v13.instruct_tokenizer.encode_assistant_message( # type: ignore[attr-defined]
+ assistant_message, False, False
+ )
response = tekken_v13_client.post("/v1/detokenize/", json=encoded_tokens)
assert response.status_code == 200
@@ -462,7 +464,7 @@ def test_generate(
engine_request: dict | ChatCompletionRequest | OpenAIChatCompletionRequest,
output_assistant_message: AssistantMessage,
) -> None:
- output_tokens = mistral_tokenizer_v13.instruct_tokenizer.encode_assistant_message( # type: ignore[attr-defined]
+ output_tokens, _, _ = mistral_tokenizer_v13.instruct_tokenizer.encode_assistant_message( # type: ignore[attr-defined]
output_assistant_message, False, False
)
if output_assistant_message.tool_calls:
diff --git a/tests/guidance/test_guidance.py b/tests/guidance/test_guidance.py
index 9afe6f33..cac90e4c 100644
--- a/tests/guidance/test_guidance.py
+++ b/tests/guidance/test_guidance.py
@@ -228,16 +228,17 @@ def _encode_content(
tokenizer = instruct_tokenizer.tokenizer
if isinstance(content, str):
- return instruct_tokenizer.encode_assistant_message(
+ result, _, _ = instruct_tokenizer.encode_assistant_message(
AssistantMessage(content=content), is_before_last_user_message=False, continue_message=False
)
+ return result
tool_calls = [x for x in content if isinstance(x, ToolCall)]
content_chunks = [x for x in content if not isinstance(x, ToolCall)]
tokens: list[int] = []
if content_chunks:
- tokens = instruct_tokenizer.encode_assistant_message(
+ tokens, _, _ = instruct_tokenizer.encode_assistant_message(
AssistantMessage(content=content_chunks),
is_before_last_user_message=False,
continue_message=False,
diff --git a/tests/integrations/chat_templates/fixtures_data.py b/tests/integrations/chat_templates/fixtures_data.py
index 6f0e6e55..c45f1b79 100644
--- a/tests/integrations/chat_templates/fixtures_data.py
+++ b/tests/integrations/chat_templates/fixtures_data.py
@@ -952,6 +952,60 @@
]
)
+# -- Multimodal content in non-user messages (v15+) --
+
+REQUEST_TOOL_IMAGE_TRAIN = ChatCompletionRequest( # type: ignore[type-var]
+ messages=[
+ UserMessage(content="What is in this image?"),
+ AssistantMessage(
+ content=None,
+ tool_calls=[
+ ToolCall(
+ id="tl1mg2345",
+ function=FunctionCall(
+ name="tool1",
+ arguments={"location": "San Francisco, CA"}, # type: ignore[arg-type]
+ ),
+ ),
+ ],
+ ),
+ ToolMessage(
+ content=[
+ TextChunk(text="Here is the result."),
+ ImageURLChunk(image_url=_IMAGE_URL),
+ ],
+ tool_call_id="tl1mg2345",
+ ),
+ AssistantMessage(content="The tool returned an image of a red square."),
+ ],
+ tools=_TOOLS,
+)
+
+REQUEST_ASSISTANT_IMAGE_TRAIN = ChatCompletionRequest( # type: ignore[type-var]
+ messages=[
+ UserMessage(content="Generate an image for me."),
+ AssistantMessage(
+ content=[
+ TextChunk(text="Here is the generated image."),
+ ImageURLChunk(image_url=_IMAGE_URL),
+ ],
+ ),
+ ],
+)
+
+REQUEST_SYSTEM_AUDIO_TRAIN = ChatCompletionRequest( # type: ignore[type-var]
+ messages=[
+ SystemMessage(
+ content=[
+ TextChunk(text="You are an audio assistant."),
+ AudioChunk(input_audio=_AUDIO),
+ ],
+ ),
+ UserMessage(content="What was that sound?"),
+ AssistantMessage(content="That was a sine wave tone."),
+ ],
+)
+
def _get_conversations(
tokenizer_version: TokenizerVersion,
@@ -1102,6 +1156,18 @@ def _get_conversations(
]
)
+ # v15+ only: multimodal content in non-user messages (finetuning only)
+ if tokenizer_version >= TokenizerVersion.v15 and validation_mode == ValidationMode.finetuning:
+ if image:
+ conversations.extend(
+ [
+ REQUEST_TOOL_IMAGE_TRAIN,
+ REQUEST_ASSISTANT_IMAGE_TRAIN,
+ ]
+ )
+ if audio:
+ conversations.append(REQUEST_SYSTEM_AUDIO_TRAIN)
+
conversations = [c.model_copy(deep=True) for c in conversations]
if think and tokenizer_version >= TokenizerVersion.v15:
diff --git a/tests/integrations/chat_templates/test_parity.py b/tests/integrations/chat_templates/test_parity.py
index 5394d5cb..0cce8183 100644
--- a/tests/integrations/chat_templates/test_parity.py
+++ b/tests/integrations/chat_templates/test_parity.py
@@ -426,12 +426,133 @@ def test_dynamic_template_comprehensive(config: TestConfig) -> None:
]
)
+ # V15+ multimodal: tool messages with image/audio content
+ if config.version >= TokenizerVersion.v15 and config.image:
+ test_cases.append(
+ {
+ "name": "v15_tool_with_image",
+ "messages": [
+ {"role": "user", "content": "Use tool"},
+ {
+ "role": "assistant",
+ "content": "",
+ "tool_calls": [
+ {"id": "abc123def", "function": {"name": "fn", "arguments": "{}"}},
+ ],
+ },
+ {
+ "role": "tool",
+ "content": [
+ {"type": "text", "text": "result"},
+ {"type": "image_url", "image_url": "http://example.com/img.png"},
+ ],
+ "tool_call_id": "abc123def",
+ },
+ {"role": "assistant", "content": "Done"},
+ ],
+ "tools": [
+ {
+ "type": "function",
+ "function": {"name": "fn", "description": "test", "parameters": {}},
+ }
+ ],
+ }
+ )
+
+ if config.version >= TokenizerVersion.v15 and config.audio:
+ test_cases.append(
+ {
+ "name": "v15_tool_with_audio",
+ "messages": [
+ {"role": "user", "content": "Use tool"},
+ {
+ "role": "assistant",
+ "content": "",
+ "tool_calls": [
+ {"id": "abc123def", "function": {"name": "fn", "arguments": "{}"}},
+ ],
+ },
+ {
+ "role": "tool",
+ "content": [
+ {"type": "text", "text": "result"},
+ {"type": "input_audio", "input_audio": {"data": "abc", "format": "wav"}},
+ ],
+ "tool_call_id": "abc123def",
+ },
+ {"role": "assistant", "content": "Done"},
+ ],
+ "tools": [
+ {
+ "type": "function",
+ "function": {"name": "fn", "description": "test", "parameters": {}},
+ }
+ ],
+ }
+ )
+
+ # V15+ multimodal: assistant messages with image/audio content
+ if config.version >= TokenizerVersion.v15 and config.image:
+ test_cases.append(
+ {
+ "name": "v15_assistant_with_image",
+ "messages": [
+ {"role": "user", "content": "Show me"},
+ {
+ "role": "assistant",
+ "content": [
+ {"type": "text", "text": "Here is the image"},
+ {"type": "image_url", "image_url": "http://example.com/img.png"},
+ ],
+ },
+ ],
+ }
+ )
+
+ if config.version >= TokenizerVersion.v15 and config.audio:
+ test_cases.append(
+ {
+ "name": "v15_assistant_with_audio",
+ "messages": [
+ {"role": "user", "content": "Listen"},
+ {
+ "role": "assistant",
+ "content": [
+ {"type": "text", "text": "Here is audio"},
+ {"type": "input_audio", "input_audio": {"data": "abc", "format": "wav"}},
+ ],
+ },
+ ],
+ }
+ )
+
+ # V15+ multimodal: system messages with audio content
+ if config.version >= TokenizerVersion.v15 and config.audio:
+ test_cases.append(
+ {
+ "name": "v15_system_with_audio",
+ "messages": [
+ {
+ "role": "system",
+ "content": [
+ {"type": "text", "text": "Listen to context"},
+ {"type": "input_audio", "input_audio": {"data": "abc", "format": "wav"}},
+ ],
+ },
+ {"role": "user", "content": "Summarize"},
+ {"role": "assistant", "content": "Done"},
+ ],
+ }
+ )
+
skip_names_image = {"with_image", "consecutive_users_with_image", "consecutive_users_multi_image"}
skip_names_audio = {"with_audio", "consecutive_users_multi_audio"}
skip_names_think = {"with_thinking", "consecutive_assistants_think", "consecutive_systems_think"}
# ThinkChunks in system messages are only supported in v13 (not v15+)
skip_names_think_system = {"consecutive_systems_think"}
skip_names_tools = {"with_tools_definition", "with_tool_call_and_result"}
+ skip_names_v15_image = {"v15_tool_with_image", "v15_assistant_with_image"}
+ skip_names_v15_audio = {"v15_tool_with_audio", "v15_assistant_with_audio", "v15_system_with_audio"}
# Not using parametrize here because test_cases are built dynamically based on the
# config (version/image/audio/think) from the outer parametrize. Each sub-case is
@@ -451,6 +572,10 @@ def test_dynamic_template_comprehensive(config: TestConfig) -> None:
continue
if test_name in skip_names_tools and config.version <= TokenizerVersion.v1:
continue
+ if test_name in skip_names_v15_image and (config.version < TokenizerVersion.v15 or not config.image):
+ continue
+ if test_name in skip_names_v15_audio and (config.version < TokenizerVersion.v15 or not config.audio):
+ continue
static_output = render_template(static_template, messages, tools=tools) # type: ignore
dynamic_output = render_template(dynamic_template, messages, tools=tools) # type: ignore
diff --git a/tests/integrations/chat_templates/transformers/test_core_parity.py b/tests/integrations/chat_templates/transformers/test_core_parity.py
index bf5eb23d..ffb47608 100644
--- a/tests/integrations/chat_templates/transformers/test_core_parity.py
+++ b/tests/integrations/chat_templates/transformers/test_core_parity.py
@@ -205,10 +205,16 @@ def test_invalid_chunks(
# Not using parametrize here because invalid_convs depends on the version/image/audio/think
# parameters from the outer parametrize. Each sub-case is identifiable via the TemplateError
# match string which includes the role and allowed chunks.
+ is_v15_plus = config.version >= TokenizerVersion.v15
for conv in invalid_convs:
msg_template = "Only {chunks} chunks are supported in {role} message content."
if conv in sp_invalids:
- chunks = "text and thinking" if config.think and config.version < TokenizerVersion.v15 else "text"
+ if config.think and not is_v15_plus:
+ chunks = "text and thinking"
+ elif config.audio and is_v15_plus:
+ chunks = "text and audio"
+ else:
+ chunks = "text"
role = "system"
elif conv in user_invalids:
chunks = "text"
@@ -218,7 +224,17 @@ def test_invalid_chunks(
chunks += ", input_audio and audio_url"
role = "user"
elif conv in assistant_invalids:
- chunks = "text and thinking" if config.think else "text"
+ desc_parts = ["text"]
+ if config.think:
+ desc_parts.append("thinking")
+ if is_v15_plus and config.image:
+ desc_parts.append("image")
+ if is_v15_plus and config.audio:
+ desc_parts.append("audio")
+ if len(desc_parts) == 1:
+ chunks = "text"
+ else:
+ chunks = ", ".join(desc_parts[:-1]) + " and " + desc_parts[-1]
role = "assistant"
err_msg = msg_template.format(chunks=chunks, role=role)
diff --git a/tests/integrations/chat_templates/unit/test_v15.py b/tests/integrations/chat_templates/unit/test_v15.py
index 1364e5f1..3401a8cf 100644
--- a/tests/integrations/chat_templates/unit/test_v15.py
+++ b/tests/integrations/chat_templates/unit/test_v15.py
@@ -246,3 +246,284 @@ def test_v15_think_template_rejects_think_in_system(self, image: bool) -> None:
with pytest.raises(ValueError, match="Only text chunks are supported in system message contents"):
render_template(template, messages)
+
+
+class TestV15MultimodalContent:
+ def test_v15_tool_message_with_image_content(self) -> None:
+ r"""V15 image template renders tool message with image content using render_content."""
+ template = generate_chat_template(
+ spm=False,
+ tokenizer_version=TokenizerVersion.v15,
+ image_support=True,
+ audio_support=False,
+ thinking_support=False,
+ default_system_prompt=None,
+ plain_thinking_support=False,
+ use_special_token_variables=True,
+ )
+
+ messages: list[dict[str, Any]] = [
+ {"role": "user", "content": "Use tool"},
+ {
+ "role": "assistant",
+ "content": "",
+ "tool_calls": [{"id": "test12345", "function": {"name": "fn", "arguments": "{}"}}],
+ },
+ {
+ "role": "tool",
+ "content": [
+ {"type": "text", "text": "result"},
+ {"type": "image_url", "image_url": "http://example.com/img.png"},
+ ],
+ "tool_call_id": "test12345",
+ },
+ {"role": "assistant", "content": "Done"},
+ ]
+
+ tools = [{"type": "function", "function": {"name": "fn", "description": "test", "parameters": {}}}]
+ output = render_template(template, messages, tools=tools, reasoning_effort="none")
+ assert "[TOOL_RESULTS]" in output
+ assert "[IMG]" in output
+ assert "result" in output
+
+ def test_v15_tool_message_with_audio_content(self) -> None:
+ r"""V15 audio template renders tool message with audio content."""
+ template = generate_chat_template(
+ spm=False,
+ tokenizer_version=TokenizerVersion.v15,
+ image_support=False,
+ audio_support=True,
+ thinking_support=False,
+ default_system_prompt=None,
+ plain_thinking_support=False,
+ use_special_token_variables=True,
+ )
+
+ messages: list[dict[str, Any]] = [
+ {"role": "user", "content": "Use tool"},
+ {
+ "role": "assistant",
+ "content": "",
+ "tool_calls": [{"id": "test12345", "function": {"name": "fn", "arguments": "{}"}}],
+ },
+ {
+ "role": "tool",
+ "content": [
+ {"type": "text", "text": "result"},
+ {"type": "input_audio", "input_audio": {"data": "abc", "format": "wav"}},
+ ],
+ "tool_call_id": "test12345",
+ },
+ {"role": "assistant", "content": "Done"},
+ ]
+
+ tools = [{"type": "function", "function": {"name": "fn", "description": "test", "parameters": {}}}]
+ output = render_template(template, messages, tools=tools, reasoning_effort="none")
+ assert "[TOOL_RESULTS]" in output
+ assert "[AUDIO]" in output
+ assert "result" in output
+
+ def test_v15_assistant_message_with_image_content(self) -> None:
+ r"""V15 image template renders assistant message with image content."""
+ template = generate_chat_template(
+ spm=False,
+ tokenizer_version=TokenizerVersion.v15,
+ image_support=True,
+ audio_support=False,
+ thinking_support=False,
+ default_system_prompt=None,
+ plain_thinking_support=False,
+ use_special_token_variables=True,
+ )
+
+ messages: list[dict[str, Any]] = [
+ {"role": "user", "content": "Show me"},
+ {
+ "role": "assistant",
+ "content": [
+ {"type": "text", "text": "Here is the image"},
+ {"type": "image_url", "image_url": "http://example.com/img.png"},
+ ],
+ },
+ ]
+
+ output = render_template(template, messages, reasoning_effort="none")
+ assert "Here is the image" in output
+ assert "[IMG]" in output
+
+ def test_v15_assistant_message_with_audio_content(self) -> None:
+ r"""V15 audio template renders assistant message with audio content."""
+ template = generate_chat_template(
+ spm=False,
+ tokenizer_version=TokenizerVersion.v15,
+ image_support=False,
+ audio_support=True,
+ thinking_support=False,
+ default_system_prompt=None,
+ plain_thinking_support=False,
+ use_special_token_variables=True,
+ )
+
+ messages: list[dict[str, Any]] = [
+ {"role": "user", "content": "Listen"},
+ {
+ "role": "assistant",
+ "content": [
+ {"type": "text", "text": "Here is audio"},
+ {"type": "input_audio", "input_audio": {"data": "abc", "format": "wav"}},
+ ],
+ },
+ ]
+
+ output = render_template(template, messages, reasoning_effort="none")
+ assert "Here is audio" in output
+ assert "[AUDIO]" in output
+
+ def test_v15_system_message_with_audio_content(self) -> None:
+ r"""V15 audio template renders system message with audio content."""
+ template = generate_chat_template(
+ spm=False,
+ tokenizer_version=TokenizerVersion.v15,
+ image_support=False,
+ audio_support=True,
+ thinking_support=False,
+ default_system_prompt=None,
+ plain_thinking_support=False,
+ use_special_token_variables=True,
+ )
+
+ messages: list[dict[str, Any]] = [
+ {
+ "role": "system",
+ "content": [
+ {"type": "text", "text": "Listen to context"},
+ {"type": "input_audio", "input_audio": {"data": "abc", "format": "wav"}},
+ ],
+ },
+ {"role": "user", "content": "Summarize"},
+ {"role": "assistant", "content": "Done"},
+ ]
+
+ output = render_template(template, messages, reasoning_effort="none")
+ assert "[SYSTEM_PROMPT]" in output
+ assert "Listen to context" in output
+ assert "[AUDIO]" in output
+
+ def test_pre_v15_image_template_rejects_image_in_assistant(self) -> None:
+ r"""Pre-V15 image template rejects image chunks in assistant message."""
+ template = generate_chat_template(
+ spm=False,
+ tokenizer_version=TokenizerVersion.v13,
+ image_support=True,
+ audio_support=False,
+ thinking_support=False,
+ default_system_prompt=None,
+ plain_thinking_support=False,
+ use_special_token_variables=True,
+ )
+
+ messages: list[dict[str, Any]] = [
+ {"role": "user", "content": "Show me"},
+ {
+ "role": "assistant",
+ "content": [
+ {"type": "text", "text": "Here"},
+ {"type": "image_url", "image_url": "http://example.com/img.png"},
+ ],
+ },
+ ]
+
+ with pytest.raises(ValueError, match="Only text chunks are supported in assistant message contents"):
+ render_template(template, messages)
+
+ def test_pre_v15_audio_template_rejects_audio_in_assistant(self) -> None:
+ r"""Pre-V15 audio template rejects audio chunks in assistant message."""
+ template = generate_chat_template(
+ spm=False,
+ tokenizer_version=TokenizerVersion.v13,
+ image_support=False,
+ audio_support=True,
+ thinking_support=False,
+ default_system_prompt=None,
+ plain_thinking_support=False,
+ use_special_token_variables=True,
+ )
+
+ messages: list[dict[str, Any]] = [
+ {"role": "user", "content": "Listen"},
+ {
+ "role": "assistant",
+ "content": [
+ {"type": "text", "text": "Here"},
+ {"type": "input_audio", "input_audio": {"data": "abc", "format": "wav"}},
+ ],
+ },
+ ]
+
+ with pytest.raises(ValueError, match="Only text chunks are supported in assistant message contents"):
+ render_template(template, messages)
+
+ def test_pre_v15_audio_template_rejects_audio_in_system(self) -> None:
+ r"""Pre-V15 audio template rejects audio chunks in system message."""
+ template = generate_chat_template(
+ spm=False,
+ tokenizer_version=TokenizerVersion.v13,
+ image_support=False,
+ audio_support=True,
+ thinking_support=False,
+ default_system_prompt=None,
+ plain_thinking_support=False,
+ use_special_token_variables=True,
+ )
+
+ messages: list[dict[str, Any]] = [
+ {
+ "role": "system",
+ "content": [
+ {"type": "text", "text": "Context"},
+ {"type": "input_audio", "input_audio": {"data": "abc", "format": "wav"}},
+ ],
+ },
+ {"role": "user", "content": "Hello"},
+ {"role": "assistant", "content": "Hi"},
+ ]
+
+ with pytest.raises(ValueError, match="Only text chunks are supported in system message contents"):
+ render_template(template, messages)
+
+ def test_pre_v15_image_template_rejects_image_in_tool(self) -> None:
+ r"""Pre-V15 image template rejects image chunks in tool message content."""
+ template = generate_chat_template(
+ spm=False,
+ tokenizer_version=TokenizerVersion.v13,
+ image_support=True,
+ audio_support=False,
+ thinking_support=False,
+ default_system_prompt=None,
+ plain_thinking_support=False,
+ use_special_token_variables=True,
+ )
+
+ messages: list[dict[str, Any]] = [
+ {"role": "user", "content": "Use tool"},
+ {
+ "role": "assistant",
+ "content": "",
+ "tool_calls": [{"id": "test12345", "function": {"name": "fn", "arguments": "{}"}}],
+ },
+ {
+ "role": "tool",
+ "content": [
+ {"type": "text", "text": "result"},
+ {"type": "image_url", "image_url": "http://example.com/img.png"},
+ ],
+ "tool_call_id": "test12345",
+ },
+ {"role": "assistant", "content": "Done"},
+ ]
+
+ tools = [{"type": "function", "function": {"name": "fn", "description": "test", "parameters": {}}}]
+ # Pre-V15 uses message['content']|string which coerces list to string representation
+ # rather than rendering through render_content — no [IMG] token produced
+ output = render_template(template, messages, tools=tools)
+ assert "[IMG]" not in output
diff --git a/tests/test_converters.py b/tests/test_converters.py
index d093d945..e09ad399 100644
--- a/tests/test_converters.py
+++ b/tests/test_converters.py
@@ -588,7 +588,7 @@ def test_non_leading_think_chunks_construction_ok() -> None:
)
def test_non_leading_think_chunks_to_openai_raises(content: list[TextChunk | ThinkChunk]) -> None:
"""to_openai raises when ThinkChunks are not leading."""
- msg = AssistantMessage(content=content)
+ msg = AssistantMessage(content=content) # type: ignore[arg-type]
with pytest.raises(InvalidAssistantMessageException, match="ThinkChunks must be leading"):
msg.to_openai()
@@ -1426,3 +1426,20 @@ def test_from_openai_drops_extra_fields(from_openai_call: Any, expected: Any) ->
def test_direct_construction_still_strict(constructor: Any) -> None:
with pytest.raises(Exception):
constructor()
+
+
+def test_assistant_message_to_openai_reasoning_with_multimodal() -> None:
+ r"""Reasoning format with multimodal content serializes non-think portion as list."""
+ msg = AssistantMessage(
+ content=[
+ ThinkChunk(thinking="Let me think", closed=True),
+ TextChunk(text="Here is the result"),
+ ImageURLChunk(image_url="data:image/png;base64,iVBORw0"),
+ ]
+ )
+ result = msg.to_openai(reasoning_field_format=ReasoningFieldFormat.reasoning)
+ assert result["reasoning"] == "Let me think"
+ assert isinstance(result["content"], list)
+ assert len(result["content"]) == 2
+ assert result["content"][0] == {"type": "text", "text": "Here is the result"}
+ assert result["content"][1]["type"] == "image_url"
diff --git a/tests/test_normalization.py b/tests/test_normalization.py
index 1ff09546..b8a55921 100644
--- a/tests/test_normalization.py
+++ b/tests/test_normalization.py
@@ -1,8 +1,11 @@
import json
import pytest
+from pydantic import ValidationError
+from mistral_common.exceptions import InvalidRequestException
from mistral_common.protocol.instruct.chunk import (
+ AudioChunk,
ChunkTypes,
ContentChunk,
ImageURLChunk,
@@ -1121,3 +1124,190 @@ def test_get_normalizer_version_mapping(
normalizer = get_normalizer(version, model_settings_builder)
assert isinstance(normalizer, expected_class)
assert normalizer._model_settings_builder == model_settings_builder
+
+
+class TestToolMessageContentChunk:
+ @pytest.fixture()
+ def normalizer_v15(self) -> InstructRequestNormalizerV15:
+ return InstructRequestNormalizerV15(
+ UserMessage,
+ AssistantMessage,
+ ToolMessage,
+ SystemMessage,
+ InstructRequest,
+ ModelSettingsBuilder(
+ reasoning_effort=EnumBuilder[ReasoningEffort](
+ values=list(ReasoningEffort), accepts_none=False, default=None
+ )
+ ),
+ )
+
+ @pytest.fixture()
+ def normalizer_v13(self) -> InstructRequestNormalizerV13:
+ return InstructRequestNormalizerV13(
+ UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest, None
+ )
+
+ def test_v15_preserves_non_text_tool_content(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
+ image_chunk = ImageURLChunk(image_url="https://example.com/image.png")
+ request = ChatCompletionRequest( # type: ignore[type-var]
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")]),
+ ToolMessage(content=[image_chunk], tool_call_id="c1"),
+ ],
+ reasoning_effort=ReasoningEffort.high,
+ )
+ parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
+ tool_msg = parsed.messages[2]
+ assert isinstance(tool_msg, ToolMessage)
+ assert isinstance(tool_msg.content, list)
+ assert tool_msg.content == [image_chunk]
+
+ def test_pre_v15_rejects_non_text_tool_content(self) -> None:
+ r"""Pre-V15 normalizer raises InvalidRequestException for non-text tool content."""
+ normalizer = get_normalizer(TokenizerVersion.v13)
+ request = mock_chat_completion(
+ messages=[
+ UserMessage(content="hi"),
+ AssistantMessage(
+ tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")],
+ ),
+ ToolMessage(
+ content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")],
+ tool_call_id="test12345",
+ ),
+ ]
+ )
+ with pytest.raises(InvalidRequestException, match="Unexpected content chunk types in tool message"):
+ normalizer.from_chat_completion_request(request)
+
+ def test_pre_v15_aggregates_text_tool_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
+ request = mock_chat_completion(
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")]),
+ ToolMessage(content=[TextChunk(text="hello"), TextChunk(text="world")], tool_call_id="c1"),
+ ],
+ )
+ parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(request)
+ tool_msg = parsed.messages[2]
+ assert isinstance(tool_msg, ToolMessage)
+ assert isinstance(tool_msg.content, str)
+ assert tool_msg.content == "hello\n\nworld"
+
+ def test_pre_v15_rejects_audio_in_tool_content(self) -> None:
+ r"""Pre-V15 normalizer raises InvalidRequestException for audio tool content."""
+ normalizer = get_normalizer(TokenizerVersion.v13)
+ request = mock_chat_completion(
+ messages=[
+ UserMessage(content="hi"),
+ AssistantMessage(
+ tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")],
+ ),
+ ToolMessage(
+ content=[AudioChunk(input_audio=b"fake_audio_data")],
+ tool_call_id="test12345",
+ ),
+ ]
+ )
+ with pytest.raises(InvalidRequestException, match="Unexpected content chunk types in tool message"):
+ normalizer.from_chat_completion_request(request)
+
+
+class TestAssistantMessageContentChunk:
+ @pytest.fixture()
+ def normalizer_v15(self) -> InstructRequestNormalizerV15:
+ return InstructRequestNormalizerV15(
+ UserMessage,
+ AssistantMessage,
+ ToolMessage,
+ SystemMessage,
+ InstructRequest,
+ ModelSettingsBuilder(
+ reasoning_effort=EnumBuilder[ReasoningEffort](
+ values=list(ReasoningEffort), accepts_none=False, default=None
+ )
+ ),
+ )
+
+ @pytest.fixture()
+ def normalizer_v13(self) -> InstructRequestNormalizerV13:
+ return InstructRequestNormalizerV13(
+ UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest, None
+ )
+
+ def test_v15_preserves_non_text_assistant_content(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
+ image_chunk = ImageURLChunk(image_url="https://example.com/image.png")
+ text_chunk = TextChunk(text="description")
+ request = ChatCompletionRequest( # type: ignore[type-var]
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(content=[image_chunk, text_chunk]),
+ ],
+ reasoning_effort=ReasoningEffort.high,
+ )
+ parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
+ assistant_msg = parsed.messages[1]
+ assert isinstance(assistant_msg, AssistantMessage)
+ assert isinstance(assistant_msg.content, list)
+ assert assistant_msg.content == [image_chunk, TextChunk(text="description")]
+
+ def test_pre_v15_rejects_non_text_assistant_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
+ image_chunk = ImageURLChunk(image_url="https://example.com/image.png")
+ request = mock_chat_completion(
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(content=[image_chunk]),
+ ],
+ )
+ with pytest.raises(InvalidRequestException, match="Unexpected content chunk types in assistant message"):
+ normalizer_v13.from_chat_completion_request(request)
+
+
+class TestSystemMessageContentChunk:
+ def test_system_message_accepts_audio_chunk(self) -> None:
+ msg = SystemMessage(content=[AudioChunk(input_audio="dGVzdA==")])
+ assert isinstance(msg.content, list)
+ assert len(msg.content) == 1
+ assert isinstance(msg.content[0], AudioChunk)
+
+ def test_system_message_rejects_image_chunk(self) -> None:
+ with pytest.raises(ValidationError):
+ SystemMessage(content=[ImageURLChunk(image_url="https://example.com/image.png")]) # type: ignore[list-item]
+
+ def test_v15_rejects_think_in_system_message(self) -> None:
+ r"""V15 normalizer rejects ThinkChunk in system messages."""
+ normalizer = get_normalizer(
+ TokenizerVersion.v15,
+ model_settings_builder=ModelSettingsBuilder(
+ reasoning_effort=EnumBuilder[ReasoningEffort](
+ values=list(ReasoningEffort), accepts_none=False, default=None
+ )
+ ),
+ )
+ request = mock_chat_completion(
+ messages=[
+ SystemMessage(content=[TextChunk(text="hello"), ThinkChunk(thinking="thinking", closed=True)]),
+ UserMessage(content="test"),
+ ]
+ )
+ with pytest.raises(InvalidRequestException, match="ThinkChunk"):
+ normalizer.from_chat_completion_request(request)
+
+ def test_v7_normalization_preserves_audio_in_system_message(self) -> None:
+ r"""V7 normalizer preserves AudioChunk in system messages."""
+ normalizer = InstructRequestNormalizerV7.normalizer()
+ request = mock_chat_completion(
+ messages=[
+ SystemMessage(content=[TextChunk(text="hello"), AudioChunk(input_audio=b"fake_audio_data")]),
+ UserMessage(content="test"),
+ ]
+ )
+ parsed: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request)
+ system_msg = parsed.messages[0]
+ assert isinstance(system_msg, SystemMessage)
+ assert isinstance(system_msg.content, list)
+ assert len(system_msg.content) == 2
+ assert isinstance(system_msg.content[0], TextChunk)
+ assert isinstance(system_msg.content[1], AudioChunk)
diff --git a/tests/test_tokenizer_v11.py b/tests/test_tokenizer_v11.py
index 8c3c7330..59bcb466 100644
--- a/tests/test_tokenizer_v11.py
+++ b/tests/test_tokenizer_v11.py
@@ -32,13 +32,15 @@ def test_special_tokens(tekkenizer: InstructTokenizerV11) -> None:
def test_tokenize_assistant_message(tekkenizer: InstructTokenizerV11) -> None:
- tokens = tekkenizer.encode_assistant_message(
+ tokens, images, audios = tekkenizer.encode_assistant_message(
AssistantMessage(
tool_calls=[ToolCall(function=FunctionCall(name="a_a_a", arguments="blabla"))],
),
is_before_last_user_message=False,
continue_message=False,
)
+ assert images == []
+ assert audios == []
assert tokens == [
tekkenizer.TOOL_CALLS,
197,
@@ -61,13 +63,15 @@ def test_tokenize_assistant_message(tekkenizer: InstructTokenizerV11) -> None:
def test_tokenize_assistant_message_continue_message(tekkenizer: InstructTokenizerV11) -> None:
- tokens = tekkenizer.encode_assistant_message(
+ tokens, images, audios = tekkenizer.encode_assistant_message(
AssistantMessage(
content='"blabla"',
),
is_before_last_user_message=False,
continue_message=True,
)
+ assert images == []
+ assert audios == []
assert tokens == [
134,
198,
@@ -95,7 +99,7 @@ def test_tokenize_assistant_message_continue_message(tekkenizer: InstructTokeniz
def test_tokenize_assistant_messages(tekkenizer: InstructTokenizerV11) -> None:
- tokens = tekkenizer.encode_assistant_message(
+ tokens, images, audios = tekkenizer.encode_assistant_message(
AssistantMessage(
tool_calls=[
ToolCall(function=FunctionCall(name="a_a_a", arguments="blabla")),
@@ -105,6 +109,8 @@ def test_tokenize_assistant_messages(tekkenizer: InstructTokenizerV11) -> None:
is_before_last_user_message=False,
continue_message=False,
)
+ assert images == []
+ assert audios == []
assert tokens == [
tekkenizer.TOOL_CALLS,
197,
@@ -135,13 +141,15 @@ def test_tokenize_assistant_messages(tekkenizer: InstructTokenizerV11) -> None:
def test_tokenize_assistant_message_train(tekkenizer: InstructTokenizerV11) -> None:
- tokens = tekkenizer.encode_assistant_message(
+ tokens, images, audios = tekkenizer.encode_assistant_message(
AssistantMessage(
tool_calls=[ToolCall(function=FunctionCall(name="a_a_a", arguments="blabla"), id="ABC")],
),
is_before_last_user_message=True,
continue_message=False,
)
+ assert images == []
+ assert audios == []
assert tokens == [
tekkenizer.TOOL_CALLS,
197,
diff --git a/tests/test_tokenizer_v13.py b/tests/test_tokenizer_v13.py
index 15bb6f3a..cec43105 100644
--- a/tests/test_tokenizer_v13.py
+++ b/tests/test_tokenizer_v13.py
@@ -237,13 +237,17 @@ def test_end_to_end_v13_wrong_order(
def test_encode_tool_message(v13_tekkenizer: InstructTokenizerV13) -> None:
tool_message = ToolMessage(content="R1", tool_call_id="123456789")
assert isinstance(v13_tekkenizer, InstructTokenizerV13)
- encoded = v13_tekkenizer.encode_tool_message(tool_message, is_before_last_user_message=False)
+ encoded, images, audios = v13_tekkenizer.encode_tool_message(tool_message, is_before_last_user_message=False)
assert encoded == [7, 182, 149, 8]
+ assert images == []
+ assert audios == []
tool_message = ToolMessage(content=[TextChunk(text="R1"), TextChunk(text="R2")], tool_call_id="123456789")
assert isinstance(v13_tekkenizer, InstructTokenizerV13)
- encoded = v13_tekkenizer.encode_tool_message(tool_message, is_before_last_user_message=False)
+ encoded, images, audios = v13_tekkenizer.encode_tool_message(tool_message, is_before_last_user_message=False)
assert encoded == [7, 182, 149, 182, 150, 8]
+ assert images == []
+ assert audios == []
def test_encode_think_chunk(v13_tekkenizer_think: InstructTokenizerV13) -> None:
@@ -295,7 +299,7 @@ def test_tokenize_assistant_message(
v13_tekkenizer_think: InstructTokenizerV13, message: AssistantMessage, expected: str, continue_final_message: bool
) -> None:
if not continue_final_message:
- tokens = v13_tekkenizer_think.encode_assistant_message(
+ tokens, images, audios = v13_tekkenizer_think.encode_assistant_message(
message, is_before_last_user_message=False, continue_message=continue_final_message
)
if not message.prefix:
@@ -310,10 +314,12 @@ def test_tokenize_assistant_message(
message, is_before_last_user_message=False, continue_message=continue_final_message
)
return
- tokens = v13_tekkenizer_think.encode_assistant_message(
+ tokens, images, audios = v13_tekkenizer_think.encode_assistant_message(
message, is_before_last_user_message=False, continue_message=continue_final_message
)
assert v13_tekkenizer_think.decode(tokens, special_token_policy=SpecialTokenPolicy.KEEP) == expected
+ assert images == []
+ assert audios == []
def test_tokenize_assistant_message_error(v13_tekkenizer: InstructTokenizerV13) -> None:
@@ -371,8 +377,10 @@ def test_tokenize_assistant_message_error(v13_tekkenizer: InstructTokenizerV13)
def test_encode_system_message(
v13_tekkenizer_think: InstructTokenizerV13, message: SystemMessage, expected: str
) -> None:
- encoded = v13_tekkenizer_think.encode_system_message(message)
+ encoded, images, audios = v13_tekkenizer_think.encode_system_message(message)
assert v13_tekkenizer_think.decode(encoded, special_token_policy=SpecialTokenPolicy.KEEP) == expected
+ assert images == []
+ assert audios == []
@pytest.mark.parametrize("audio_fixture", ["audio_chunk", "audio_url_chunk"])
diff --git a/tests/test_tokenizer_v15.py b/tests/test_tokenizer_v15.py
index fdd84c33..8b45e69a 100644
--- a/tests/test_tokenizer_v15.py
+++ b/tests/test_tokenizer_v15.py
@@ -1,7 +1,18 @@
+import base64
+from collections.abc import Callable
+from io import BytesIO
+
import pytest
+from PIL import Image
from mistral_common.exceptions import InvalidRequestException, TokenizerException
-from mistral_common.protocol.instruct.chunk import ThinkChunk
+from mistral_common.protocol.instruct.chunk import (
+ AudioChunk,
+ AudioURLChunk,
+ ImageURLChunk,
+ TextChunk,
+ ThinkChunk,
+)
from mistral_common.protocol.instruct.messages import (
AssistantMessage,
ChatMessage,
@@ -18,11 +29,14 @@
)
from mistral_common.protocol.instruct.tool_calls import Function, FunctionCall, Tool, ToolCall
from mistral_common.protocol.instruct.validator import ValidationMode, get_validator
-from mistral_common.tokens.tokenizers.base import TokenizerVersion
+from mistral_common.tokens.tokenizers.audio import AudioConfig, AudioEncoder, AudioSpectrogramConfig, SpecialAudioIDs
+from mistral_common.tokens.tokenizers.base import SpecialTokens, TokenizerVersion
+from mistral_common.tokens.tokenizers.image import ImageConfig, ImageEncoder, SpecialImageIDs
from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV15
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.model_settings_builder import EnumBuilder, ModelSettingsBuilder
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
+from tests.fixtures.audio import get_dummy_audio_chunk, get_dummy_audio_url_chunk
from tests.test_tekken import get_special_tokens, quick_vocab
EXPECTED_TEXT_V15: str = (
@@ -294,3 +308,191 @@ def test_encode_chat_completion_continue_final_message() -> None:
eos_id = tokenizer_v15.instruct_tokenizer.tokenizer.eos_id
assert encoded.tokens[-1] != eos_id
+
+
+@pytest.fixture(scope="session")
+def audio_chunk() -> AudioChunk:
+ return get_dummy_audio_chunk()
+
+
+def get_v15_mistral_tokenizer_with_audio() -> MistralTokenizer:
+ r"""Build a V15 MistralTokenizer with audio encoder."""
+ builder = _build_model_settings_builder(tuple(ReasoningEffort))
+ tekkenizer = Tekkenizer(
+ quick_vocab([b"a", b"b", b"c", b"f", b"de"]),
+ special_tokens=get_special_tokens(TokenizerVersion.v15, add_audio=True),
+ pattern=r".+",
+ vocab_size=256 + 100,
+ num_special_tokens=100,
+ version=TokenizerVersion.v15,
+ model_settings_builder=builder,
+ )
+ audio_config = AudioConfig(
+ sampling_rate=24_000,
+ frame_rate=12.5,
+ encoding_config=AudioSpectrogramConfig(
+ num_mel_bins=128,
+ hop_length=160,
+ window_size=400,
+ ),
+ )
+ special_audio_ids = SpecialAudioIDs(
+ audio=tekkenizer.get_special_token(SpecialTokens.audio.value),
+ begin_audio=tekkenizer.get_special_token(SpecialTokens.begin_audio.value),
+ streaming_pad=None,
+ text_to_audio=None,
+ audio_to_text=None,
+ )
+ audio_encoder = AudioEncoder(audio_config, special_audio_ids)
+ instruct_tokenizer = InstructTokenizerV15(tekkenizer, audio_encoder=audio_encoder)
+ request_normalizer = get_normalizer(TokenizerVersion.v15, tekkenizer.model_settings_builder)
+ validator = get_validator(TokenizerVersion.v15, mode=ValidationMode.test)
+ return MistralTokenizer(
+ instruct_tokenizer=instruct_tokenizer,
+ validator=validator,
+ request_normalizer=request_normalizer,
+ )
+
+
+def _get_dummy_image_url_chunk() -> ImageURLChunk:
+ r"""Build a small base64-encoded image URL chunk for testing."""
+ img = Image.new("RGB", (4, 4), "red")
+ buf = BytesIO()
+ img.save(buf, "PNG")
+ data_url = f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}"
+ return ImageURLChunk(image_url=data_url)
+
+
+def get_v15_mistral_tokenizer_with_image() -> MistralTokenizer:
+ r"""Build a V15 MistralTokenizer with image encoder."""
+ builder = _build_model_settings_builder(tuple(ReasoningEffort))
+ tekkenizer = Tekkenizer(
+ quick_vocab([b"a", b"b", b"c", b"f", b"de"]),
+ special_tokens=get_special_tokens(TokenizerVersion.v15, add_think=True),
+ pattern=r".+",
+ vocab_size=256 + 100,
+ num_special_tokens=100,
+ version=TokenizerVersion.v15,
+ model_settings_builder=builder,
+ )
+ image_config = ImageConfig(image_patch_size=16, max_image_size=1024)
+ special_image_ids = SpecialImageIDs(
+ img=tekkenizer.get_special_token(SpecialTokens.img.value),
+ img_break=tekkenizer.get_special_token(SpecialTokens.img_break.value),
+ img_end=tekkenizer.get_special_token(SpecialTokens.img_end.value),
+ )
+ image_encoder = ImageEncoder(image_config, special_image_ids)
+ instruct_tokenizer = InstructTokenizerV15(tekkenizer, image_encoder=image_encoder)
+ request_normalizer = get_normalizer(TokenizerVersion.v15, tekkenizer.model_settings_builder)
+ validator = get_validator(TokenizerVersion.v15, mode=ValidationMode.test)
+ return MistralTokenizer(
+ instruct_tokenizer=instruct_tokenizer,
+ validator=validator,
+ request_normalizer=request_normalizer,
+ )
+
+
+# Multimodal content chunks and their corresponding tokenizer factories for parametrized tests.
+_AUDIO_CHUNK_PARAMS = pytest.param(get_dummy_audio_chunk(), 1, 0, get_v15_mistral_tokenizer_with_audio, id="audio")
+_AUDIO_URL_CHUNK_PARAMS = pytest.param(
+ get_dummy_audio_url_chunk(), 1, 0, get_v15_mistral_tokenizer_with_audio, id="audio_url"
+)
+_IMAGE_URL_CHUNK_PARAMS = pytest.param(
+ _get_dummy_image_url_chunk(), 0, 1, get_v15_mistral_tokenizer_with_image, id="image_url"
+)
+
+_ALL_MULTIMODAL_PARAMS = [_AUDIO_CHUNK_PARAMS, _AUDIO_URL_CHUNK_PARAMS, _IMAGE_URL_CHUNK_PARAMS]
+_AUDIO_ONLY_PARAMS = [_AUDIO_CHUNK_PARAMS]
+
+
+@pytest.mark.parametrize(
+ ("content_chunk", "expected_audios", "expected_images", "tokenizer_factory"),
+ _ALL_MULTIMODAL_PARAMS,
+)
+def test_encode_chat_completion_with_multimodal_tool(
+ content_chunk: AudioChunk | AudioURLChunk | ImageURLChunk,
+ expected_audios: int,
+ expected_images: int,
+ tokenizer_factory: Callable[[], MistralTokenizer],
+) -> None:
+ mistral_tokenizer = tokenizer_factory()
+ chat_request = ChatCompletionRequest( # type: ignore[type-var]
+ messages=[
+ UserMessage(content="Use the tool"),
+ AssistantMessage(tool_calls=[ToolCall(id="test12345", function=FunctionCall(name="fn", arguments="{}"))]),
+ ToolMessage(
+ content=[TextChunk(text="result"), content_chunk],
+ tool_call_id="test12345",
+ ),
+ ],
+ tools=[Tool(function=Function(name="fn", description="test", parameters={}))],
+ )
+ encoded = mistral_tokenizer.encode_chat_completion(chat_request)
+ assert len(encoded.audios) == expected_audios
+ assert len(encoded.images) == expected_images
+
+
+@pytest.mark.parametrize(
+ ("content_chunk", "expected_audios", "expected_images", "tokenizer_factory"),
+ _ALL_MULTIMODAL_PARAMS,
+)
+def test_encode_chat_completion_with_multimodal_assistant(
+ content_chunk: AudioChunk | AudioURLChunk | ImageURLChunk,
+ expected_audios: int,
+ expected_images: int,
+ tokenizer_factory: Callable[[], MistralTokenizer],
+) -> None:
+ mistral_tokenizer = tokenizer_factory()
+ chat_request = ChatCompletionRequest( # type: ignore[type-var]
+ messages=[
+ UserMessage(content="Hello"),
+ AssistantMessage(content=[TextChunk(text="Here is content"), content_chunk]),
+ UserMessage(content="Thanks"),
+ ],
+ )
+ encoded = mistral_tokenizer.encode_chat_completion(chat_request)
+ assert len(encoded.audios) == expected_audios
+ assert len(encoded.images) == expected_images
+
+
+@pytest.mark.parametrize(
+ ("content_chunk", "expected_audios", "expected_images", "tokenizer_factory"),
+ _AUDIO_ONLY_PARAMS,
+)
+def test_encode_chat_completion_with_multimodal_system(
+ content_chunk: AudioChunk,
+ expected_audios: int,
+ expected_images: int,
+ tokenizer_factory: Callable[[], MistralTokenizer],
+) -> None:
+ mistral_tokenizer = tokenizer_factory()
+ chat_request = ChatCompletionRequest( # type: ignore[type-var]
+ messages=[
+ SystemMessage(content=[TextChunk(text="System with content"), content_chunk]),
+ UserMessage(content="Hello"),
+ ],
+ )
+ encoded = mistral_tokenizer.encode_chat_completion(chat_request)
+ assert len(encoded.audios) == expected_audios
+ assert len(encoded.images) == expected_images
+
+
+@pytest.mark.parametrize(
+ ("content_chunk", "expected_audios", "expected_images", "tokenizer_factory"),
+ _ALL_MULTIMODAL_PARAMS,
+)
+def test_encode_chat_completion_with_multimodal_user(
+ content_chunk: AudioChunk | AudioURLChunk | ImageURLChunk,
+ expected_audios: int,
+ expected_images: int,
+ tokenizer_factory: Callable[[], MistralTokenizer],
+) -> None:
+ mistral_tokenizer = tokenizer_factory()
+ chat_request = ChatCompletionRequest(
+ messages=[
+ UserMessage(content=[TextChunk(text="Here is content"), content_chunk]),
+ ],
+ )
+ encoded = mistral_tokenizer.encode_chat_completion(chat_request)
+ assert len(encoded.audios) == expected_audios
+ assert len(encoded.images) == expected_images
From 4f0735861d48a4871f89cfb7532497241952e509 Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Tue, 9 Jun 2026 10:24:17 +0200
Subject: [PATCH 07/47] Address PR review: remove assistant multimodal,
simplify normalizer, use positive type guards
---
.../chat_templates/template_generator.py | 14 +----
src/mistral_common/protocol/instruct/chunk.py | 2 +-
.../protocol/instruct/messages.py | 7 +--
.../protocol/instruct/normalize.py | 56 +++----------------
tests/data/chat_templates/v15_audio.jinja | 2 +-
tests/data/chat_templates/v15_image.jinja | 2 +-
.../data/chat_templates/v15_image_think.jinja | 2 +-
.../chat_templates/fixtures_data.py | 13 -----
.../chat_templates/test_parity.py | 39 +------------
.../chat_templates/unit/test_v15.py | 56 -------------------
tests/test_converters.py | 17 ------
tests/test_normalization.py | 50 -----------------
tests/test_tokenizer_v15.py | 23 --------
13 files changed, 19 insertions(+), 264 deletions(-)
diff --git a/src/mistral_common/integrations/chat_templates/template_generator.py b/src/mistral_common/integrations/chat_templates/template_generator.py
index 9d8a4373..868cf1c2 100644
--- a/src/mistral_common/integrations/chat_templates/template_generator.py
+++ b/src/mistral_common/integrations/chat_templates/template_generator.py
@@ -221,11 +221,6 @@ def tool_supports_multimodal(self) -> bool:
r"""Whether tool messages can contain non-text content chunks. V15+."""
return self.version >= TokenizerVersion.v15
- @property
- def assistant_supports_multimodal(self) -> bool:
- r"""Whether assistant messages can contain non-text content chunks. V15+."""
- return self.version >= TokenizerVersion.v15
-
@property
def system_supports_audio(self) -> bool:
r"""Whether system messages can contain audio. V15+ with audio support."""
@@ -837,7 +832,6 @@ def _generate_flush_logic(config: TemplateConfig) -> list[str]:
" {%- set ns_c.has_non_text = true %}",
" {%- endif %}",
" {%- endfor %}",
-
" {%- endif %}",
]
@@ -1270,17 +1264,13 @@ def _generate_assistant_message_handling(config: TemplateConfig) -> str:
desc_parts = ["text"]
if config.any_thinking_support:
desc_parts.append("thinking")
- if config.assistant_supports_multimodal and config.image_support:
- desc_parts.append("image")
- if config.assistant_supports_multimodal and config.audio_support:
- desc_parts.append("audio")
rc_call_args += f", supported_types_desc='{_join_types_desc(desc_parts)}'"
if config.any_thinking_support:
rc_call_args += ", support_thinking=true"
if config.image_support:
- rc_call_args += f", support_images={'true' if config.assistant_supports_multimodal else 'false'}"
+ rc_call_args += ", support_images=false"
if config.audio_support:
- rc_call_args += f", support_audio={'true' if config.assistant_supports_multimodal else 'false'}"
+ rc_call_args += ", support_audio=false"
lines.append(" {%- if message['content'] %}")
diff --git a/src/mistral_common/protocol/instruct/chunk.py b/src/mistral_common/protocol/instruct/chunk.py
index 430c8c3d..3ba496b9 100644
--- a/src/mistral_common/protocol/instruct/chunk.py
+++ b/src/mistral_common/protocol/instruct/chunk.py
@@ -455,7 +455,7 @@ def from_openai(cls, openai_chunk: dict[str, Any]) -> "ThinkChunk":
TextChunk | ImageChunk | ImageURLChunk | AudioChunk | AudioURLChunk, Field(discriminator="type")
]
-AssistantContentChunk: TypeAlias = ContentChunk
+AssistantContentChunk = Annotated[TextChunk | ThinkChunk, Field(discriminator="type")]
SystemContentChunk = Annotated[TextChunk | AudioChunk | ThinkChunk, Field(discriminator="type")]
diff --git a/src/mistral_common/protocol/instruct/messages.py b/src/mistral_common/protocol/instruct/messages.py
index 140a1f3a..96aa5fd7 100644
--- a/src/mistral_common/protocol/instruct/messages.py
+++ b/src/mistral_common/protocol/instruct/messages.py
@@ -27,8 +27,8 @@
)
-def _are_think_chunks(chunks: list[ContentChunk]) -> TypeGuard[list[ThinkChunk]]:
- r"""Narrow a ContentChunk list to ThinkChunk list."""
+def _are_think_chunks(chunks: Sequence[TextChunk | ThinkChunk]) -> TypeGuard[list[ThinkChunk]]:
+ r"""Narrow a chunk list to ThinkChunk list."""
return all(isinstance(c, ThinkChunk) for c in chunks)
@@ -358,12 +358,11 @@ def from_openai(cls, openai_message: dict[str, Any]) -> "ToolMessage":
tool_message = cls.model_validate(
{
"role": openai_message["role"],
- "tool_call_id": openai_message.get("tool_call_id"),
+ "tool_call_id": openai_message["tool_call_id"],
"content": content,
"name": openai_message.get("name"),
}
)
- assert tool_message.tool_call_id is not None, "tool_call_id must be provided for tool messages."
return tool_message
diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py
index b053b381..0dd77243 100644
--- a/src/mistral_common/protocol/instruct/normalize.py
+++ b/src/mistral_common/protocol/instruct/normalize.py
@@ -1,6 +1,6 @@
import json
import warnings
-from typing import Generic, Sequence
+from typing import Generic, Sequence, cast
from typing_extensions import TypeGuard, assert_never
@@ -38,8 +38,8 @@
def _is_user_content(
chunks: list[ContentChunk],
) -> TypeGuard[list[TextChunk | ImageChunk | ImageURLChunk | AudioChunk | AudioURLChunk]]:
- r"""Narrow ContentChunk list to user-compatible types (no ThinkChunk)."""
- return all(not isinstance(c, ThinkChunk) for c in chunks)
+ r"""Narrow ContentChunk list to user-compatible types."""
+ return all(isinstance(c, (TextChunk, ImageChunk, ImageURLChunk, AudioChunk, AudioURLChunk)) for c in chunks)
def _aggregate_content_chunks_impl(
@@ -265,7 +265,7 @@ def _normalize_tool_call(self, tool_call: ToolCall) -> ToolCall:
id=tool_call.id,
)
- def _narrow_assistant_content(self, content: list[ContentChunk] | str) -> str | list[ContentChunk]:
+ def _narrow_assistant_content(self, content: list[ContentChunk] | str) -> str | list[TextChunk | ThinkChunk]:
r"""Validate and narrow content chunks for assistant messages.
Pre-V15 normalizers only allow TextChunk and ThinkChunk.
@@ -279,8 +279,10 @@ def _narrow_assistant_content(self, content: list[ContentChunk] | str) -> str |
Raises:
InvalidRequestException: If unsupported chunk types are found.
"""
- if isinstance(content, str) or all(isinstance(c, (TextChunk, ThinkChunk)) for c in content):
+ if isinstance(content, str):
return content
+ if all(isinstance(c, (TextChunk, ThinkChunk)) for c in content):
+ return cast(list[TextChunk | ThinkChunk], content)
raise InvalidRequestException(
f"Unexpected content chunk types in assistant message: {[type(c).__name__ for c in content]}"
)
@@ -609,20 +611,9 @@ class InstructRequestNormalizerV15(InstructRequestNormalizerV13):
_chunk_join_str: str = ""
- def _narrow_assistant_content(self, content: list[ContentChunk] | str) -> str | list[ContentChunk]:
- r"""V15 accepts all ContentChunk types in assistant messages."""
- return content
-
def _narrow_tool_content(self, content: list[ContentChunk] | str) -> str | list[ContentChunk]:
r"""V15 accepts all ContentChunk types in tool messages."""
- if isinstance(content, str):
- return content
- text_parts: list[str] = []
- for c in content:
- if not isinstance(c, TextChunk):
- return content
- text_parts.append(c.text)
- return "".join(text_parts)
+ return content
def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | list[ContentChunk]:
r"""V15 system messages allow TextChunk and AudioChunk but reject ThinkChunk."""
@@ -632,37 +623,6 @@ def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | lis
raise InvalidRequestException("ThinkChunk in system message is not supported for V15")
return content
- def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]:
- r"""V15 tool messages preserve non-text content chunks.
-
- Text-only content is aggregated and JSON-normalized. Mixed content is preserved as-is.
- """
- tool_messages: list[ToolMessageType] = []
- for message in messages:
- assert isinstance(message, self._tool_message_class), "Expected tool message"
- content = self._aggregate_content_chunks([message])
- validated = self._narrow_tool_content(content)
- if isinstance(validated, str):
- normalized_content: str | list[ContentChunk] = self._normalize_json_content(validated)
- else:
- normalized_content = validated
- tool_messages.append(
- self._tool_message_class(
- content=normalized_content, tool_call_id=message.tool_call_id, name=message.name
- )
- )
-
- # Reorder by tool call order
- id_to_tool_call_idx = {call_id: idx for idx, call_id in enumerate(latest_call_ids)}
- id_to_tool_result_idx = {message.tool_call_id: idx for idx, message in enumerate(tool_messages)}
- tool_messages.sort(
- key=lambda msg: (
- id_to_tool_call_idx.get(msg.tool_call_id or "null", float("inf")),
- id_to_tool_result_idx[msg.tool_call_id],
- ),
- )
- return tool_messages
-
@staticmethod
def normalizer(model_settings_builder: ModelSettingsBuilder | None = None) -> "InstructRequestNormalizerV15":
r"""Returns a normalizer for the V15 instruct request.
diff --git a/tests/data/chat_templates/v15_audio.jinja b/tests/data/chat_templates/v15_audio.jinja
index 381c72cb..bbd04595 100644
--- a/tests/data/chat_templates/v15_audio.jinja
+++ b/tests/data/chat_templates/v15_audio.jinja
@@ -158,7 +158,7 @@
{%- endif %}
{%- if message['content'] %}
- {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text and audio', support_audio=true) -}}
+ {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text', support_audio=false) -}}
{%- endif %}
{%- if message['tool_calls'] is defined and message['tool_calls'] is not none and message['tool_calls']|length > 0 %}
diff --git a/tests/data/chat_templates/v15_image.jinja b/tests/data/chat_templates/v15_image.jinja
index 58146ed7..35b872fe 100644
--- a/tests/data/chat_templates/v15_image.jinja
+++ b/tests/data/chat_templates/v15_image.jinja
@@ -169,7 +169,7 @@
{%- endif %}
{%- if message['content'] %}
- {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text and image', support_images=true) -}}
+ {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text', support_images=false) -}}
{%- endif %}
{%- if message['tool_calls'] is defined and message['tool_calls'] is not none and message['tool_calls']|length > 0 %}
diff --git a/tests/data/chat_templates/v15_image_think.jinja b/tests/data/chat_templates/v15_image_think.jinja
index 1460ec9b..5fab24c0 100644
--- a/tests/data/chat_templates/v15_image_think.jinja
+++ b/tests/data/chat_templates/v15_image_think.jinja
@@ -196,7 +196,7 @@
{%- endif %}
{%- if message['content'] %}
- {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text, thinking and image', support_thinking=true, support_images=true) -}}
+ {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text and thinking', support_thinking=true, support_images=false) -}}
{%- endif %}
{%- if message['tool_calls'] is defined and message['tool_calls'] is not none and message['tool_calls']|length > 0 %}
diff --git a/tests/integrations/chat_templates/fixtures_data.py b/tests/integrations/chat_templates/fixtures_data.py
index c45f1b79..8fad3fcf 100644
--- a/tests/integrations/chat_templates/fixtures_data.py
+++ b/tests/integrations/chat_templates/fixtures_data.py
@@ -981,18 +981,6 @@
tools=_TOOLS,
)
-REQUEST_ASSISTANT_IMAGE_TRAIN = ChatCompletionRequest( # type: ignore[type-var]
- messages=[
- UserMessage(content="Generate an image for me."),
- AssistantMessage(
- content=[
- TextChunk(text="Here is the generated image."),
- ImageURLChunk(image_url=_IMAGE_URL),
- ],
- ),
- ],
-)
-
REQUEST_SYSTEM_AUDIO_TRAIN = ChatCompletionRequest( # type: ignore[type-var]
messages=[
SystemMessage(
@@ -1162,7 +1150,6 @@ def _get_conversations(
conversations.extend(
[
REQUEST_TOOL_IMAGE_TRAIN,
- REQUEST_ASSISTANT_IMAGE_TRAIN,
]
)
if audio:
diff --git a/tests/integrations/chat_templates/test_parity.py b/tests/integrations/chat_templates/test_parity.py
index 0cce8183..8c90c97d 100644
--- a/tests/integrations/chat_templates/test_parity.py
+++ b/tests/integrations/chat_templates/test_parity.py
@@ -491,41 +491,6 @@ def test_dynamic_template_comprehensive(config: TestConfig) -> None:
}
)
- # V15+ multimodal: assistant messages with image/audio content
- if config.version >= TokenizerVersion.v15 and config.image:
- test_cases.append(
- {
- "name": "v15_assistant_with_image",
- "messages": [
- {"role": "user", "content": "Show me"},
- {
- "role": "assistant",
- "content": [
- {"type": "text", "text": "Here is the image"},
- {"type": "image_url", "image_url": "http://example.com/img.png"},
- ],
- },
- ],
- }
- )
-
- if config.version >= TokenizerVersion.v15 and config.audio:
- test_cases.append(
- {
- "name": "v15_assistant_with_audio",
- "messages": [
- {"role": "user", "content": "Listen"},
- {
- "role": "assistant",
- "content": [
- {"type": "text", "text": "Here is audio"},
- {"type": "input_audio", "input_audio": {"data": "abc", "format": "wav"}},
- ],
- },
- ],
- }
- )
-
# V15+ multimodal: system messages with audio content
if config.version >= TokenizerVersion.v15 and config.audio:
test_cases.append(
@@ -551,8 +516,8 @@ def test_dynamic_template_comprehensive(config: TestConfig) -> None:
# ThinkChunks in system messages are only supported in v13 (not v15+)
skip_names_think_system = {"consecutive_systems_think"}
skip_names_tools = {"with_tools_definition", "with_tool_call_and_result"}
- skip_names_v15_image = {"v15_tool_with_image", "v15_assistant_with_image"}
- skip_names_v15_audio = {"v15_tool_with_audio", "v15_assistant_with_audio", "v15_system_with_audio"}
+ skip_names_v15_image = {"v15_tool_with_image"}
+ skip_names_v15_audio = {"v15_tool_with_audio", "v15_system_with_audio"}
# Not using parametrize here because test_cases are built dynamically based on the
# config (version/image/audio/think) from the outer parametrize. Each sub-case is
diff --git a/tests/integrations/chat_templates/unit/test_v15.py b/tests/integrations/chat_templates/unit/test_v15.py
index 3401a8cf..39c4705a 100644
--- a/tests/integrations/chat_templates/unit/test_v15.py
+++ b/tests/integrations/chat_templates/unit/test_v15.py
@@ -323,62 +323,6 @@ def test_v15_tool_message_with_audio_content(self) -> None:
assert "[AUDIO]" in output
assert "result" in output
- def test_v15_assistant_message_with_image_content(self) -> None:
- r"""V15 image template renders assistant message with image content."""
- template = generate_chat_template(
- spm=False,
- tokenizer_version=TokenizerVersion.v15,
- image_support=True,
- audio_support=False,
- thinking_support=False,
- default_system_prompt=None,
- plain_thinking_support=False,
- use_special_token_variables=True,
- )
-
- messages: list[dict[str, Any]] = [
- {"role": "user", "content": "Show me"},
- {
- "role": "assistant",
- "content": [
- {"type": "text", "text": "Here is the image"},
- {"type": "image_url", "image_url": "http://example.com/img.png"},
- ],
- },
- ]
-
- output = render_template(template, messages, reasoning_effort="none")
- assert "Here is the image" in output
- assert "[IMG]" in output
-
- def test_v15_assistant_message_with_audio_content(self) -> None:
- r"""V15 audio template renders assistant message with audio content."""
- template = generate_chat_template(
- spm=False,
- tokenizer_version=TokenizerVersion.v15,
- image_support=False,
- audio_support=True,
- thinking_support=False,
- default_system_prompt=None,
- plain_thinking_support=False,
- use_special_token_variables=True,
- )
-
- messages: list[dict[str, Any]] = [
- {"role": "user", "content": "Listen"},
- {
- "role": "assistant",
- "content": [
- {"type": "text", "text": "Here is audio"},
- {"type": "input_audio", "input_audio": {"data": "abc", "format": "wav"}},
- ],
- },
- ]
-
- output = render_template(template, messages, reasoning_effort="none")
- assert "Here is audio" in output
- assert "[AUDIO]" in output
-
def test_v15_system_message_with_audio_content(self) -> None:
r"""V15 audio template renders system message with audio content."""
template = generate_chat_template(
diff --git a/tests/test_converters.py b/tests/test_converters.py
index e09ad399..3cd546a8 100644
--- a/tests/test_converters.py
+++ b/tests/test_converters.py
@@ -1426,20 +1426,3 @@ def test_from_openai_drops_extra_fields(from_openai_call: Any, expected: Any) ->
def test_direct_construction_still_strict(constructor: Any) -> None:
with pytest.raises(Exception):
constructor()
-
-
-def test_assistant_message_to_openai_reasoning_with_multimodal() -> None:
- r"""Reasoning format with multimodal content serializes non-think portion as list."""
- msg = AssistantMessage(
- content=[
- ThinkChunk(thinking="Let me think", closed=True),
- TextChunk(text="Here is the result"),
- ImageURLChunk(image_url="data:image/png;base64,iVBORw0"),
- ]
- )
- result = msg.to_openai(reasoning_field_format=ReasoningFieldFormat.reasoning)
- assert result["reasoning"] == "Let me think"
- assert isinstance(result["content"], list)
- assert len(result["content"]) == 2
- assert result["content"][0] == {"type": "text", "text": "Here is the result"}
- assert result["content"][1]["type"] == "image_url"
diff --git a/tests/test_normalization.py b/tests/test_normalization.py
index b8a55921..b1d88d80 100644
--- a/tests/test_normalization.py
+++ b/tests/test_normalization.py
@@ -1215,56 +1215,6 @@ def test_pre_v15_rejects_audio_in_tool_content(self) -> None:
normalizer.from_chat_completion_request(request)
-class TestAssistantMessageContentChunk:
- @pytest.fixture()
- def normalizer_v15(self) -> InstructRequestNormalizerV15:
- return InstructRequestNormalizerV15(
- UserMessage,
- AssistantMessage,
- ToolMessage,
- SystemMessage,
- InstructRequest,
- ModelSettingsBuilder(
- reasoning_effort=EnumBuilder[ReasoningEffort](
- values=list(ReasoningEffort), accepts_none=False, default=None
- )
- ),
- )
-
- @pytest.fixture()
- def normalizer_v13(self) -> InstructRequestNormalizerV13:
- return InstructRequestNormalizerV13(
- UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest, None
- )
-
- def test_v15_preserves_non_text_assistant_content(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
- image_chunk = ImageURLChunk(image_url="https://example.com/image.png")
- text_chunk = TextChunk(text="description")
- request = ChatCompletionRequest( # type: ignore[type-var]
- messages=[
- UserMessage(content="query"),
- AssistantMessage(content=[image_chunk, text_chunk]),
- ],
- reasoning_effort=ReasoningEffort.high,
- )
- parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
- assistant_msg = parsed.messages[1]
- assert isinstance(assistant_msg, AssistantMessage)
- assert isinstance(assistant_msg.content, list)
- assert assistant_msg.content == [image_chunk, TextChunk(text="description")]
-
- def test_pre_v15_rejects_non_text_assistant_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
- image_chunk = ImageURLChunk(image_url="https://example.com/image.png")
- request = mock_chat_completion(
- messages=[
- UserMessage(content="query"),
- AssistantMessage(content=[image_chunk]),
- ],
- )
- with pytest.raises(InvalidRequestException, match="Unexpected content chunk types in assistant message"):
- normalizer_v13.from_chat_completion_request(request)
-
-
class TestSystemMessageContentChunk:
def test_system_message_accepts_audio_chunk(self) -> None:
msg = SystemMessage(content=[AudioChunk(input_audio="dGVzdA==")])
diff --git a/tests/test_tokenizer_v15.py b/tests/test_tokenizer_v15.py
index 8b45e69a..739ec75a 100644
--- a/tests/test_tokenizer_v15.py
+++ b/tests/test_tokenizer_v15.py
@@ -432,29 +432,6 @@ def test_encode_chat_completion_with_multimodal_tool(
assert len(encoded.images) == expected_images
-@pytest.mark.parametrize(
- ("content_chunk", "expected_audios", "expected_images", "tokenizer_factory"),
- _ALL_MULTIMODAL_PARAMS,
-)
-def test_encode_chat_completion_with_multimodal_assistant(
- content_chunk: AudioChunk | AudioURLChunk | ImageURLChunk,
- expected_audios: int,
- expected_images: int,
- tokenizer_factory: Callable[[], MistralTokenizer],
-) -> None:
- mistral_tokenizer = tokenizer_factory()
- chat_request = ChatCompletionRequest( # type: ignore[type-var]
- messages=[
- UserMessage(content="Hello"),
- AssistantMessage(content=[TextChunk(text="Here is content"), content_chunk]),
- UserMessage(content="Thanks"),
- ],
- )
- encoded = mistral_tokenizer.encode_chat_completion(chat_request)
- assert len(encoded.audios) == expected_audios
- assert len(encoded.images) == expected_images
-
-
@pytest.mark.parametrize(
("content_chunk", "expected_audios", "expected_images", "tokenizer_factory"),
_AUDIO_ONLY_PARAMS,
From 823ee67728232eeb57045c06f413841f608ae021 Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Tue, 9 Jun 2026 10:27:51 +0200
Subject: [PATCH 08/47] Require content in UserMessage and SystemMessage
from_openai
---
src/mistral_common/protocol/instruct/messages.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/src/mistral_common/protocol/instruct/messages.py b/src/mistral_common/protocol/instruct/messages.py
index 96aa5fd7..f8f0c366 100644
--- a/src/mistral_common/protocol/instruct/messages.py
+++ b/src/mistral_common/protocol/instruct/messages.py
@@ -149,7 +149,7 @@ def to_openai(self) -> dict[str, Any]:
def from_openai(cls, openai_message: dict[str, Any]) -> "UserMessage":
r"""Converts the OpenAI message to the Mistral format."""
return cls.model_validate(
- {"role": openai_message["role"], "content": cls._content_from_openai(openai_message.get("content"))}
+ {"role": openai_message["role"], "content": cls._content_from_openai(openai_message["content"])}
)
@@ -174,7 +174,7 @@ def to_openai(self) -> dict[str, Any]:
def from_openai(cls, openai_message: dict[str, Any]) -> "SystemMessage":
r"""Converts the OpenAI message to the Mistral format."""
return cls.model_validate(
- {"role": openai_message["role"], "content": cls._content_from_openai(openai_message.get("content"))}
+ {"role": openai_message["role"], "content": cls._content_from_openai(openai_message["content"])}
)
From 655180c805f6f7b7ab45892873beb033840f9b4a Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Tue, 9 Jun 2026 10:42:02 +0200
Subject: [PATCH 09/47] Skip JSON normalization of tool content for V7+
normalizers
---
src/mistral_common/protocol/instruct/normalize.py | 15 +++++++++++++++
1 file changed, 15 insertions(+)
diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py
index 0dd77243..40610744 100644
--- a/src/mistral_common/protocol/instruct/normalize.py
+++ b/src/mistral_common/protocol/instruct/normalize.py
@@ -498,6 +498,21 @@ def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | lis
"""
return content
+ def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]:
+ """Normalize tool messages without JSON normalization.
+
+ V7+ normalizers skip JSON content normalization for tool messages.
+ """
+ tool_messages: list[ToolMessageType] = []
+ for message in messages:
+ assert isinstance(message, self._tool_message_class), "Expected tool message"
+ content = self._aggregate_content_chunks([message])
+ validated = self._narrow_tool_content(content)
+ tool_messages.append(
+ self._tool_message_class(content=validated, tool_call_id=message.tool_call_id, name=message.name)
+ )
+ return tool_messages
+
def _aggregate_system_messages(self, messages: list[UATS]) -> list[SystemMessageType]:
aggregated: list[SystemMessageType] = []
for message in messages:
From ebe26d9e142e05bba26113e1c7a56b249c37b2f3 Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Tue, 9 Jun 2026 11:03:24 +0200
Subject: [PATCH 10/47] Use TypeGuard for assistant content, simplify tool
content narrowing, add tests
---
.../protocol/instruct/normalize.py | 36 +++++----
tests/test_normalization.py | 75 +++++++++++++++++++
2 files changed, 95 insertions(+), 16 deletions(-)
diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py
index 40610744..d199f96b 100644
--- a/src/mistral_common/protocol/instruct/normalize.py
+++ b/src/mistral_common/protocol/instruct/normalize.py
@@ -1,6 +1,6 @@
import json
import warnings
-from typing import Generic, Sequence, cast
+from typing import Generic, Sequence
from typing_extensions import TypeGuard, assert_never
@@ -42,6 +42,13 @@ def _is_user_content(
return all(isinstance(c, (TextChunk, ImageChunk, ImageURLChunk, AudioChunk, AudioURLChunk)) for c in chunks)
+def _is_assistant_content(
+ chunks: list[ContentChunk],
+) -> TypeGuard[list[TextChunk | ThinkChunk]]:
+ r"""Narrow ContentChunk list to assistant-compatible types."""
+ return all(isinstance(c, (TextChunk, ThinkChunk)) for c in chunks)
+
+
def _aggregate_content_chunks_impl(
contents: list[list[ContentChunk] | str | None],
msg_join_str: str,
@@ -268,7 +275,7 @@ def _normalize_tool_call(self, tool_call: ToolCall) -> ToolCall:
def _narrow_assistant_content(self, content: list[ContentChunk] | str) -> str | list[TextChunk | ThinkChunk]:
r"""Validate and narrow content chunks for assistant messages.
- Pre-V15 normalizers only allow TextChunk and ThinkChunk.
+ Only TextChunk and ThinkChunk are allowed.
Args:
content: The aggregated content chunks.
@@ -281,36 +288,33 @@ def _narrow_assistant_content(self, content: list[ContentChunk] | str) -> str |
"""
if isinstance(content, str):
return content
- if all(isinstance(c, (TextChunk, ThinkChunk)) for c in content):
- return cast(list[TextChunk | ThinkChunk], content)
+ if _is_assistant_content(content):
+ return content
raise InvalidRequestException(
f"Unexpected content chunk types in assistant message: {[type(c).__name__ for c in content]}"
)
def _narrow_tool_content(self, content: list[ContentChunk] | str) -> str | list[ContentChunk]:
- r"""Validate and narrow content for tool messages.
+ r"""Validate that tool content is text-only.
- Pre-V15 normalizers only allow text content.
+ Pre-V15 normalizers only allow text content. Since ``_aggregate_content_chunks``
+ already collapses text-only content to ``str``, receiving a list here means
+ non-text chunks are present.
Args:
- content: The raw or aggregated content.
+ content: The aggregated content.
Returns:
- The content as a string.
+ The validated content as a string.
Raises:
InvalidRequestException: If non-text content chunks are found.
"""
if isinstance(content, str):
return content
- text_parts: list[str] = []
- for c in content:
- if not isinstance(c, TextChunk):
- raise InvalidRequestException(
- f"Unexpected content chunk types in tool message: {[type(c).__name__ for c in content]}"
- )
- text_parts.append(c.text)
- return "".join(text_parts)
+ raise InvalidRequestException(
+ f"Unexpected content chunk types in tool message: {[type(c).__name__ for c in content]}"
+ )
def _aggregate_system_messages(self, messages: list[UATS]) -> list[SystemMessageType]:
return []
diff --git a/tests/test_normalization.py b/tests/test_normalization.py
index b1d88d80..eb60f180 100644
--- a/tests/test_normalization.py
+++ b/tests/test_normalization.py
@@ -1126,6 +1126,45 @@ def test_get_normalizer_version_mapping(
assert normalizer._model_settings_builder == model_settings_builder
+class TestAssistantContentNarrowing:
+ def test_accepts_text_and_think_chunks(self) -> None:
+ r"""Normalizer accepts TextChunk and ThinkChunk in assistant messages."""
+ normalizer = get_normalizer(TokenizerVersion.v13)
+ request = mock_chat_completion(
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]),
+ ],
+ )
+ parsed: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request)
+ assistant_msg = parsed.messages[1]
+ assert isinstance(assistant_msg, AssistantMessage)
+ assert isinstance(assistant_msg.content, list)
+ assert len(assistant_msg.content) == 2
+
+ def test_accepts_string_content(self) -> None:
+ r"""Normalizer accepts string content in assistant messages."""
+ normalizer = get_normalizer(
+ TokenizerVersion.v15,
+ model_settings_builder=ModelSettingsBuilder(
+ reasoning_effort=EnumBuilder[ReasoningEffort](
+ values=list(ReasoningEffort), accepts_none=False, default=None
+ )
+ ),
+ )
+ request = ChatCompletionRequest(
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(content="plain text"),
+ ],
+ reasoning_effort=ReasoningEffort.high,
+ )
+ parsed: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request)
+ assistant_msg = parsed.messages[1]
+ assert isinstance(assistant_msg, AssistantMessage)
+ assert assistant_msg.content == "plain text"
+
+
class TestToolMessageContentChunk:
@pytest.fixture()
def normalizer_v15(self) -> InstructRequestNormalizerV15:
@@ -1214,6 +1253,42 @@ def test_pre_v15_rejects_audio_in_tool_content(self) -> None:
with pytest.raises(InvalidRequestException, match="Unexpected content chunk types in tool message"):
normalizer.from_chat_completion_request(request)
+ def test_base_normalizer_json_normalizes_tool_content(self) -> None:
+ r"""Base normalizer (v1-v3) JSON-normalizes tool message content."""
+ normalizer = InstructRequestNormalizer(
+ UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest, None
+ )
+ messy_json = '{"key" : "value" , "num": 1}'
+ request = mock_chat_completion(
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")]),
+ ToolMessage(content=messy_json, tool_call_id="c1"),
+ ],
+ )
+ parsed: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request)
+ tool_msg = parsed.messages[2]
+ assert isinstance(tool_msg, ToolMessage)
+ assert tool_msg.content == '{"key": "value", "num": 1}'
+
+ def test_v7_skips_json_normalization_on_tool_content(self) -> None:
+ r"""V7+ normalizers do not JSON-normalize tool message content."""
+ normalizer = InstructRequestNormalizerV7(
+ UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest, None
+ )
+ messy_json = '{"key" : "value" , "num": 1}'
+ request = mock_chat_completion(
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")]),
+ ToolMessage(content=messy_json, tool_call_id="c1"),
+ ],
+ )
+ parsed: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request)
+ tool_msg = parsed.messages[2]
+ assert isinstance(tool_msg, ToolMessage)
+ assert tool_msg.content == messy_json
+
class TestSystemMessageContentChunk:
def test_system_message_accepts_audio_chunk(self) -> None:
From fee10213ad45ce381532169e302e9a13cff767b6 Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Tue, 9 Jun 2026 12:29:00 +0200
Subject: [PATCH 11/47] Remove _narrow_tool_content, inline logic directly
---
.../protocol/instruct/messages.py | 2 +-
.../protocol/instruct/normalize.py | 40 ++++---------------
2 files changed, 8 insertions(+), 34 deletions(-)
diff --git a/src/mistral_common/protocol/instruct/messages.py b/src/mistral_common/protocol/instruct/messages.py
index f8f0c366..c0b922d3 100644
--- a/src/mistral_common/protocol/instruct/messages.py
+++ b/src/mistral_common/protocol/instruct/messages.py
@@ -354,7 +354,7 @@ def to_openai(self) -> dict[str, Any]:
@classmethod
def from_openai(cls, openai_message: dict[str, Any]) -> "ToolMessage":
r"""Converts the OpenAI message to the Mistral format."""
- content = cls._content_from_openai(openai_message.get("content"))
+ content = cls._content_from_openai(openai_message["content"])
tool_message = cls.model_validate(
{
"role": openai_message["role"],
diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py
index d199f96b..c69fe912 100644
--- a/src/mistral_common/protocol/instruct/normalize.py
+++ b/src/mistral_common/protocol/instruct/normalize.py
@@ -252,11 +252,12 @@ def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[s
for message in messages:
assert isinstance(message, self._tool_message_class), "Expected tool message"
content = self._aggregate_content_chunks([message])
- validated = self._narrow_tool_content(content)
- if isinstance(validated, str):
- normalized_content: str | list[ContentChunk] = self._normalize_json_content(validated)
- else:
- normalized_content = validated
+ if not isinstance(content, str):
+ raise InvalidRequestException(
+ f"Unexpected content chunk types in tool message: {[type(c).__name__ for c in content]}"
+ )
+ normalized_content = self._normalize_json_content(content)
+
tool_messages.append(
self._tool_message_class(
content=normalized_content, tool_call_id=message.tool_call_id, name=message.name
@@ -294,28 +295,6 @@ def _narrow_assistant_content(self, content: list[ContentChunk] | str) -> str |
f"Unexpected content chunk types in assistant message: {[type(c).__name__ for c in content]}"
)
- def _narrow_tool_content(self, content: list[ContentChunk] | str) -> str | list[ContentChunk]:
- r"""Validate that tool content is text-only.
-
- Pre-V15 normalizers only allow text content. Since ``_aggregate_content_chunks``
- already collapses text-only content to ``str``, receiving a list here means
- non-text chunks are present.
-
- Args:
- content: The aggregated content.
-
- Returns:
- The validated content as a string.
-
- Raises:
- InvalidRequestException: If non-text content chunks are found.
- """
- if isinstance(content, str):
- return content
- raise InvalidRequestException(
- f"Unexpected content chunk types in tool message: {[type(c).__name__ for c in content]}"
- )
-
def _aggregate_system_messages(self, messages: list[UATS]) -> list[SystemMessageType]:
return []
@@ -511,9 +490,8 @@ def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[s
for message in messages:
assert isinstance(message, self._tool_message_class), "Expected tool message"
content = self._aggregate_content_chunks([message])
- validated = self._narrow_tool_content(content)
tool_messages.append(
- self._tool_message_class(content=validated, tool_call_id=message.tool_call_id, name=message.name)
+ self._tool_message_class(content=content, tool_call_id=message.tool_call_id, name=message.name)
)
return tool_messages
@@ -630,10 +608,6 @@ class InstructRequestNormalizerV15(InstructRequestNormalizerV13):
_chunk_join_str: str = ""
- def _narrow_tool_content(self, content: list[ContentChunk] | str) -> str | list[ContentChunk]:
- r"""V15 accepts all ContentChunk types in tool messages."""
- return content
-
def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | list[ContentChunk]:
r"""V15 system messages allow TextChunk and AudioChunk but reject ThinkChunk."""
if isinstance(content, str):
From dc09e05fab29954f0782f6ea6331fb52ff246e11 Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Tue, 9 Jun 2026 14:31:16 +0200
Subject: [PATCH 12/47] Fix mypy errors and enable V15 assistant multimodal in
chat templates
---
.../chat_templates/template_generator.py | 21 ++++++++++++++-----
tests/test_converters.py | 2 +-
tests/test_normalization.py | 2 +-
3 files changed, 18 insertions(+), 7 deletions(-)
diff --git a/src/mistral_common/integrations/chat_templates/template_generator.py b/src/mistral_common/integrations/chat_templates/template_generator.py
index 868cf1c2..eaf66a15 100644
--- a/src/mistral_common/integrations/chat_templates/template_generator.py
+++ b/src/mistral_common/integrations/chat_templates/template_generator.py
@@ -1230,10 +1230,17 @@ def _generate_assistant_message_handling(config: TemplateConfig) -> str:
"""
lines = []
+ assistant_supports_image = config.image_support and config.version >= TokenizerVersion.v15
+ assistant_supports_audio = config.audio_support and config.version >= TokenizerVersion.v15
+
+ comment_parts = ["text"]
if config.any_thinking_support:
- chunk_types = "text and thinking"
- else:
- chunk_types = "text"
+ comment_parts.append("thinking")
+ if assistant_supports_image:
+ comment_parts.append("image")
+ if assistant_supports_audio:
+ comment_parts.append("audio")
+ chunk_types = _join_types_desc(comment_parts)
comment = f"{{#- Assistant messages supports {chunk_types} content. #}}"
lines.append("")
@@ -1264,13 +1271,17 @@ def _generate_assistant_message_handling(config: TemplateConfig) -> str:
desc_parts = ["text"]
if config.any_thinking_support:
desc_parts.append("thinking")
+ if assistant_supports_image:
+ desc_parts.append("image")
+ if assistant_supports_audio:
+ desc_parts.append("audio")
rc_call_args += f", supported_types_desc='{_join_types_desc(desc_parts)}'"
if config.any_thinking_support:
rc_call_args += ", support_thinking=true"
if config.image_support:
- rc_call_args += ", support_images=false"
+ rc_call_args += f", support_images={'true' if assistant_supports_image else 'false'}"
if config.audio_support:
- rc_call_args += ", support_audio=false"
+ rc_call_args += f", support_audio={'true' if assistant_supports_audio else 'false'}"
lines.append(" {%- if message['content'] %}")
diff --git a/tests/test_converters.py b/tests/test_converters.py
index 3cd546a8..d093d945 100644
--- a/tests/test_converters.py
+++ b/tests/test_converters.py
@@ -588,7 +588,7 @@ def test_non_leading_think_chunks_construction_ok() -> None:
)
def test_non_leading_think_chunks_to_openai_raises(content: list[TextChunk | ThinkChunk]) -> None:
"""to_openai raises when ThinkChunks are not leading."""
- msg = AssistantMessage(content=content) # type: ignore[arg-type]
+ msg = AssistantMessage(content=content)
with pytest.raises(InvalidAssistantMessageException, match="ThinkChunks must be leading"):
msg.to_openai()
diff --git a/tests/test_normalization.py b/tests/test_normalization.py
index eb60f180..001fd34a 100644
--- a/tests/test_normalization.py
+++ b/tests/test_normalization.py
@@ -1152,7 +1152,7 @@ def test_accepts_string_content(self) -> None:
)
),
)
- request = ChatCompletionRequest(
+ request = ChatCompletionRequest[ChatMessage](
messages=[
UserMessage(content="query"),
AssistantMessage(content="plain text"),
From db684eb1846bcaefa23b652e4f5d0395b3a562d4 Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Tue, 9 Jun 2026 14:34:57 +0200
Subject: [PATCH 13/47] Restore V7 tool content validation, add V15 override
accepting all chunks
---
.../protocol/instruct/normalize.py | 27 ++++++++++++++++++-
1 file changed, 26 insertions(+), 1 deletion(-)
diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py
index c69fe912..ff742266 100644
--- a/src/mistral_common/protocol/instruct/normalize.py
+++ b/src/mistral_common/protocol/instruct/normalize.py
@@ -484,12 +484,17 @@ def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | lis
def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]:
"""Normalize tool messages without JSON normalization.
- V7+ normalizers skip JSON content normalization for tool messages.
+ V7+ normalizers skip JSON content normalization for tool messages but still
+ reject non-text content chunks.
"""
tool_messages: list[ToolMessageType] = []
for message in messages:
assert isinstance(message, self._tool_message_class), "Expected tool message"
content = self._aggregate_content_chunks([message])
+ if not isinstance(content, str):
+ raise InvalidRequestException(
+ f"Unexpected content chunk types in tool message: {[type(c).__name__ for c in content]}"
+ )
tool_messages.append(
self._tool_message_class(content=content, tool_call_id=message.tool_call_id, name=message.name)
)
@@ -608,6 +613,26 @@ class InstructRequestNormalizerV15(InstructRequestNormalizerV13):
_chunk_join_str: str = ""
+ def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]:
+ r"""V15 accepts all ContentChunk types in tool messages."""
+ tool_messages: list[ToolMessageType] = []
+ for message in messages:
+ assert isinstance(message, self._tool_message_class), "Expected tool message"
+ content = self._aggregate_content_chunks([message])
+ tool_messages.append(
+ self._tool_message_class(content=content, tool_call_id=message.tool_call_id, name=message.name)
+ )
+ # Reorder tool messages based on the tool call order (from V13).
+ id_to_tool_call_idx = {call_id: idx for idx, call_id in enumerate(latest_call_ids)}
+ id_to_tool_result_idx = {message.tool_call_id: idx for idx, message in enumerate(tool_messages)}
+ tool_messages.sort(
+ key=lambda msg: (
+ id_to_tool_call_idx.get(msg.tool_call_id or "null", float("inf")),
+ id_to_tool_result_idx[msg.tool_call_id],
+ ),
+ )
+ return tool_messages
+
def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | list[ContentChunk]:
r"""V15 system messages allow TextChunk and AudioChunk but reject ThinkChunk."""
if isinstance(content, str):
From affebf53ad64332e2f9848e77c450942b796723c Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Tue, 9 Jun 2026 14:39:09 +0200
Subject: [PATCH 14/47] Regenerate V15 golden chat templates with assistant
multimodal support
---
tests/data/chat_templates/v15_audio.jinja | 4 ++--
tests/data/chat_templates/v15_image.jinja | 4 ++--
tests/data/chat_templates/v15_image_think.jinja | 4 ++--
3 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/tests/data/chat_templates/v15_audio.jinja b/tests/data/chat_templates/v15_audio.jinja
index bbd04595..0d843e90 100644
--- a/tests/data/chat_templates/v15_audio.jinja
+++ b/tests/data/chat_templates/v15_audio.jinja
@@ -151,14 +151,14 @@
{{- render_content(message['content'], 'user message content', supported_types_desc='text, input_audio and audio_url', support_audio=true) -}}
{{- '[/INST]' }}
- {#- Assistant messages supports text content. #}
+ {#- Assistant messages supports text and audio content. #}
{%- elif message['role'] == 'assistant' %}
{%- if (message['content'] is none or message['content'] == '' or message['content']|length == 0) and (message['tool_calls'] is not defined or message['tool_calls'] is none or message['tool_calls']|length == 0) %}
{{- raise_exception('Assistant message must have a string or a list of chunks in content or a list of tool calls.') }}
{%- endif %}
{%- if message['content'] %}
- {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text', support_audio=false) -}}
+ {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text and audio', support_audio=true) -}}
{%- endif %}
{%- if message['tool_calls'] is defined and message['tool_calls'] is not none and message['tool_calls']|length > 0 %}
diff --git a/tests/data/chat_templates/v15_image.jinja b/tests/data/chat_templates/v15_image.jinja
index 35b872fe..bd3eefa3 100644
--- a/tests/data/chat_templates/v15_image.jinja
+++ b/tests/data/chat_templates/v15_image.jinja
@@ -162,14 +162,14 @@
{{- render_content(user_content, 'user message content', supported_types_desc='text, image and image_url', support_images=true) -}}
{{- '[/INST]' }}
- {#- Assistant messages supports text content. #}
+ {#- Assistant messages supports text and image content. #}
{%- elif message['role'] == 'assistant' %}
{%- if (message['content'] is none or message['content'] == '' or message['content']|length == 0) and (message['tool_calls'] is not defined or message['tool_calls'] is none or message['tool_calls']|length == 0) %}
{{- raise_exception('Assistant message must have a string or a list of chunks in content or a list of tool calls.') }}
{%- endif %}
{%- if message['content'] %}
- {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text', support_images=false) -}}
+ {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text and image', support_images=true) -}}
{%- endif %}
{%- if message['tool_calls'] is defined and message['tool_calls'] is not none and message['tool_calls']|length > 0 %}
diff --git a/tests/data/chat_templates/v15_image_think.jinja b/tests/data/chat_templates/v15_image_think.jinja
index 5fab24c0..570cb7b8 100644
--- a/tests/data/chat_templates/v15_image_think.jinja
+++ b/tests/data/chat_templates/v15_image_think.jinja
@@ -189,14 +189,14 @@
{{- render_content(user_content, 'user message content', supported_types_desc='text, image and image_url', support_thinking=false, support_images=true) -}}
{{- '[/INST]' }}
- {#- Assistant messages supports text and thinking content. #}
+ {#- Assistant messages supports text, thinking and image content. #}
{%- elif message['role'] == 'assistant' %}
{%- if (message['content'] is none or message['content'] == '' or message['content']|length == 0) and (message['tool_calls'] is not defined or message['tool_calls'] is none or message['tool_calls']|length == 0) %}
{{- raise_exception('Assistant message must have a string or a list of chunks in content or a list of tool calls.') }}
{%- endif %}
{%- if message['content'] %}
- {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text and thinking', support_thinking=true, support_images=false) -}}
+ {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text, thinking and image', support_thinking=true, support_images=true) -}}
{%- endif %}
{%- if message['tool_calls'] is defined and message['tool_calls'] is not none and message['tool_calls']|length > 0 %}
From a8b65ccdb1834940362910cbf3cd93115d112745 Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Tue, 9 Jun 2026 14:50:04 +0200
Subject: [PATCH 15/47] Restore _are_text_chunks guard, remove assistant
image/audio from templates, add V15 tool sort test
---
.../chat_templates/template_generator.py | 15 ++--------
.../protocol/instruct/messages.py | 9 ++++--
.../protocol/instruct/normalize.py | 20 +++++--------
tests/data/chat_templates/v15_audio.jinja | 4 +--
tests/data/chat_templates/v15_image.jinja | 4 +--
.../data/chat_templates/v15_image_think.jinja | 4 +--
.../transformers/test_core_parity.py | 4 ---
tests/test_normalization.py | 29 +++++++++++++++++++
8 files changed, 52 insertions(+), 37 deletions(-)
diff --git a/src/mistral_common/integrations/chat_templates/template_generator.py b/src/mistral_common/integrations/chat_templates/template_generator.py
index eaf66a15..10cf02fe 100644
--- a/src/mistral_common/integrations/chat_templates/template_generator.py
+++ b/src/mistral_common/integrations/chat_templates/template_generator.py
@@ -1230,16 +1230,9 @@ def _generate_assistant_message_handling(config: TemplateConfig) -> str:
"""
lines = []
- assistant_supports_image = config.image_support and config.version >= TokenizerVersion.v15
- assistant_supports_audio = config.audio_support and config.version >= TokenizerVersion.v15
-
comment_parts = ["text"]
if config.any_thinking_support:
comment_parts.append("thinking")
- if assistant_supports_image:
- comment_parts.append("image")
- if assistant_supports_audio:
- comment_parts.append("audio")
chunk_types = _join_types_desc(comment_parts)
comment = f"{{#- Assistant messages supports {chunk_types} content. #}}"
@@ -1271,17 +1264,13 @@ def _generate_assistant_message_handling(config: TemplateConfig) -> str:
desc_parts = ["text"]
if config.any_thinking_support:
desc_parts.append("thinking")
- if assistant_supports_image:
- desc_parts.append("image")
- if assistant_supports_audio:
- desc_parts.append("audio")
rc_call_args += f", supported_types_desc='{_join_types_desc(desc_parts)}'"
if config.any_thinking_support:
rc_call_args += ", support_thinking=true"
if config.image_support:
- rc_call_args += f", support_images={'true' if assistant_supports_image else 'false'}"
+ rc_call_args += ", support_images=false"
if config.audio_support:
- rc_call_args += f", support_audio={'true' if assistant_supports_audio else 'false'}"
+ rc_call_args += ", support_audio=false"
lines.append(" {%- if message['content'] %}")
diff --git a/src/mistral_common/protocol/instruct/messages.py b/src/mistral_common/protocol/instruct/messages.py
index c0b922d3..3b9c8c02 100644
--- a/src/mistral_common/protocol/instruct/messages.py
+++ b/src/mistral_common/protocol/instruct/messages.py
@@ -32,6 +32,11 @@ def _are_think_chunks(chunks: Sequence[TextChunk | ThinkChunk]) -> TypeGuard[lis
return all(isinstance(c, ThinkChunk) for c in chunks)
+def _are_text_chunks(chunks: Sequence[TextChunk | ThinkChunk]) -> TypeGuard[list[TextChunk]]:
+ r"""Narrow a chunk list to TextChunk list."""
+ return all(isinstance(c, TextChunk) for c in chunks)
+
+
class ReasoningFieldFormat(str, Enum):
r"""How to serialize leading `ThinkChunk` in `AssistantMessage.to_openai()`.
@@ -242,8 +247,8 @@ def to_openai(
out_dict["content"] = self._content_to_openai(self.content)
case ReasoningFieldFormat.reasoning | ReasoningFieldFormat.reasoning_content:
think_chunks, content_chunks = self.content[: last_think_idx + 1], self.content[last_think_idx + 1 :]
- if not _are_think_chunks(think_chunks):
- raise RuntimeError("Expected only ThinkChunks in the leading portion.")
+ if not _are_think_chunks(think_chunks) or not _are_text_chunks(content_chunks):
+ raise RuntimeError("Impossible, only think or content chunks should have been present.")
if len(think_chunks) > 0:
out_dict[reasoning_field_format.value] = "\n".join(tc.thinking for tc in think_chunks)
diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py
index ff742266..810e1a3b 100644
--- a/src/mistral_common/protocol/instruct/normalize.py
+++ b/src/mistral_common/protocol/instruct/normalize.py
@@ -588,17 +588,21 @@ def normalizer(model_settings_builder: ModelSettingsBuilder | None = None) -> "I
UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest[UATS, Tool], None
)
- def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]:
- tool_messages: list[ToolMessageType] = super()._aggregate_tool_messages(messages, latest_call_ids)
+ @staticmethod
+ def _inplace_sort_tool_messages(tool_messages: list[ToolMessageType], latest_call_ids: list[str]) -> None:
id_to_tool_call_idx = {call_id: idx for idx, call_id in enumerate(latest_call_ids)}
id_to_tool_result_idx = {message.tool_call_id: idx for idx, message in enumerate(tool_messages)}
# First order by tool call idx and then by tool result idx
- tool_messages.sort(
+ return tool_messages.sort(
key=lambda msg: (
id_to_tool_call_idx.get(msg.tool_call_id or "null", float("inf")),
id_to_tool_result_idx[msg.tool_call_id],
),
)
+
+ def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]:
+ tool_messages: list[ToolMessageType] = super()._aggregate_tool_messages(messages, latest_call_ids)
+ self._inplace_sort_tool_messages(tool_messages=tool_messages, latest_call_ids=latest_call_ids)
return tool_messages
@@ -622,15 +626,7 @@ def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[s
tool_messages.append(
self._tool_message_class(content=content, tool_call_id=message.tool_call_id, name=message.name)
)
- # Reorder tool messages based on the tool call order (from V13).
- id_to_tool_call_idx = {call_id: idx for idx, call_id in enumerate(latest_call_ids)}
- id_to_tool_result_idx = {message.tool_call_id: idx for idx, message in enumerate(tool_messages)}
- tool_messages.sort(
- key=lambda msg: (
- id_to_tool_call_idx.get(msg.tool_call_id or "null", float("inf")),
- id_to_tool_result_idx[msg.tool_call_id],
- ),
- )
+ self._inplace_sort_tool_messages(tool_messages=tool_messages, latest_call_ids=latest_call_ids)
return tool_messages
def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | list[ContentChunk]:
diff --git a/tests/data/chat_templates/v15_audio.jinja b/tests/data/chat_templates/v15_audio.jinja
index 0d843e90..bbd04595 100644
--- a/tests/data/chat_templates/v15_audio.jinja
+++ b/tests/data/chat_templates/v15_audio.jinja
@@ -151,14 +151,14 @@
{{- render_content(message['content'], 'user message content', supported_types_desc='text, input_audio and audio_url', support_audio=true) -}}
{{- '[/INST]' }}
- {#- Assistant messages supports text and audio content. #}
+ {#- Assistant messages supports text content. #}
{%- elif message['role'] == 'assistant' %}
{%- if (message['content'] is none or message['content'] == '' or message['content']|length == 0) and (message['tool_calls'] is not defined or message['tool_calls'] is none or message['tool_calls']|length == 0) %}
{{- raise_exception('Assistant message must have a string or a list of chunks in content or a list of tool calls.') }}
{%- endif %}
{%- if message['content'] %}
- {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text and audio', support_audio=true) -}}
+ {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text', support_audio=false) -}}
{%- endif %}
{%- if message['tool_calls'] is defined and message['tool_calls'] is not none and message['tool_calls']|length > 0 %}
diff --git a/tests/data/chat_templates/v15_image.jinja b/tests/data/chat_templates/v15_image.jinja
index bd3eefa3..35b872fe 100644
--- a/tests/data/chat_templates/v15_image.jinja
+++ b/tests/data/chat_templates/v15_image.jinja
@@ -162,14 +162,14 @@
{{- render_content(user_content, 'user message content', supported_types_desc='text, image and image_url', support_images=true) -}}
{{- '[/INST]' }}
- {#- Assistant messages supports text and image content. #}
+ {#- Assistant messages supports text content. #}
{%- elif message['role'] == 'assistant' %}
{%- if (message['content'] is none or message['content'] == '' or message['content']|length == 0) and (message['tool_calls'] is not defined or message['tool_calls'] is none or message['tool_calls']|length == 0) %}
{{- raise_exception('Assistant message must have a string or a list of chunks in content or a list of tool calls.') }}
{%- endif %}
{%- if message['content'] %}
- {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text and image', support_images=true) -}}
+ {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text', support_images=false) -}}
{%- endif %}
{%- if message['tool_calls'] is defined and message['tool_calls'] is not none and message['tool_calls']|length > 0 %}
diff --git a/tests/data/chat_templates/v15_image_think.jinja b/tests/data/chat_templates/v15_image_think.jinja
index 570cb7b8..5fab24c0 100644
--- a/tests/data/chat_templates/v15_image_think.jinja
+++ b/tests/data/chat_templates/v15_image_think.jinja
@@ -189,14 +189,14 @@
{{- render_content(user_content, 'user message content', supported_types_desc='text, image and image_url', support_thinking=false, support_images=true) -}}
{{- '[/INST]' }}
- {#- Assistant messages supports text, thinking and image content. #}
+ {#- Assistant messages supports text and thinking content. #}
{%- elif message['role'] == 'assistant' %}
{%- if (message['content'] is none or message['content'] == '' or message['content']|length == 0) and (message['tool_calls'] is not defined or message['tool_calls'] is none or message['tool_calls']|length == 0) %}
{{- raise_exception('Assistant message must have a string or a list of chunks in content or a list of tool calls.') }}
{%- endif %}
{%- if message['content'] %}
- {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text, thinking and image', support_thinking=true, support_images=true) -}}
+ {{- render_content(message['content'], 'assistant message contents', supported_types_desc='text and thinking', support_thinking=true, support_images=false) -}}
{%- endif %}
{%- if message['tool_calls'] is defined and message['tool_calls'] is not none and message['tool_calls']|length > 0 %}
diff --git a/tests/integrations/chat_templates/transformers/test_core_parity.py b/tests/integrations/chat_templates/transformers/test_core_parity.py
index ffb47608..c579e53c 100644
--- a/tests/integrations/chat_templates/transformers/test_core_parity.py
+++ b/tests/integrations/chat_templates/transformers/test_core_parity.py
@@ -227,10 +227,6 @@ def test_invalid_chunks(
desc_parts = ["text"]
if config.think:
desc_parts.append("thinking")
- if is_v15_plus and config.image:
- desc_parts.append("image")
- if is_v15_plus and config.audio:
- desc_parts.append("audio")
if len(desc_parts) == 1:
chunks = "text"
else:
diff --git a/tests/test_normalization.py b/tests/test_normalization.py
index 001fd34a..6b396e06 100644
--- a/tests/test_normalization.py
+++ b/tests/test_normalization.py
@@ -1203,6 +1203,35 @@ def test_v15_preserves_non_text_tool_content(self, normalizer_v15: InstructReque
assert isinstance(tool_msg.content, list)
assert tool_msg.content == [image_chunk]
+ def test_v15_sorts_multimodal_tool_messages(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
+ image_chunk_1 = ImageURLChunk(image_url="https://example.com/img1.png")
+ image_chunk_2 = ImageURLChunk(image_url="https://example.com/img2.png")
+ request = ChatCompletionRequest( # type: ignore[type-var]
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(
+ tool_calls=[
+ ToolCall(function=FunctionCall(name="fn1", arguments="{}"), id="c1"),
+ ToolCall(function=FunctionCall(name="fn2", arguments="{}"), id="c2"),
+ ]
+ ),
+ ToolMessage(content=[image_chunk_2], tool_call_id="c2"),
+ ToolMessage(content=[image_chunk_1], tool_call_id="c1"),
+ ],
+ reasoning_effort=ReasoningEffort.high,
+ )
+ parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
+
+ tool_msg_1 = parsed.messages[2]
+ assert isinstance(tool_msg_1, ToolMessage)
+ assert tool_msg_1.tool_call_id == "c1"
+ assert tool_msg_1.content == [image_chunk_1]
+
+ tool_msg_2 = parsed.messages[3]
+ assert isinstance(tool_msg_2, ToolMessage)
+ assert tool_msg_2.tool_call_id == "c2"
+ assert tool_msg_2.content == [image_chunk_2]
+
def test_pre_v15_rejects_non_text_tool_content(self) -> None:
r"""Pre-V15 normalizer raises InvalidRequestException for non-text tool content."""
normalizer = get_normalizer(TokenizerVersion.v13)
From d16b7ef1edcb1a0beab99010072b9b7920ff2c8e Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Tue, 9 Jun 2026 15:42:22 +0200
Subject: [PATCH 16/47] Apply suggestions from code review
Co-authored-by: Julien Denize <40604584+juliendenize@users.noreply.github.com>
---
src/mistral_common/protocol/instruct/messages.py | 2 +-
src/mistral_common/protocol/instruct/normalize.py | 4 +---
2 files changed, 2 insertions(+), 4 deletions(-)
diff --git a/src/mistral_common/protocol/instruct/messages.py b/src/mistral_common/protocol/instruct/messages.py
index 3b9c8c02..2b922dfb 100644
--- a/src/mistral_common/protocol/instruct/messages.py
+++ b/src/mistral_common/protocol/instruct/messages.py
@@ -252,7 +252,7 @@ def to_openai(
if len(think_chunks) > 0:
out_dict[reasoning_field_format.value] = "\n".join(tc.thinking for tc in think_chunks)
- if len(content_chunks) == 1 and isinstance(content_chunks[0], TextChunk):
+ if len(content_chunks) == 1:
out_dict["content"] = content_chunks[0].text
elif content_chunks:
out_dict["content"] = self._content_to_openai(content_chunks)
diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py
index 810e1a3b..45969660 100644
--- a/src/mistral_common/protocol/instruct/normalize.py
+++ b/src/mistral_common/protocol/instruct/normalize.py
@@ -287,9 +287,7 @@ def _narrow_assistant_content(self, content: list[ContentChunk] | str) -> str |
Raises:
InvalidRequestException: If unsupported chunk types are found.
"""
- if isinstance(content, str):
- return content
- if _is_assistant_content(content):
+ if isinstance(content, str) or _is_assistant_content(content):
return content
raise InvalidRequestException(
f"Unexpected content chunk types in assistant message: {[type(c).__name__ for c in content]}"
From a701ae337797ac2266ee2fe5793202382ec7f346 Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Tue, 9 Jun 2026 15:45:55 +0200
Subject: [PATCH 17/47] Use content chunk type aliases in TypeGuard and
narrowing function signatures
---
.../protocol/instruct/messages.py | 4 +-
.../protocol/instruct/normalize.py | 38 ++++++++++++++-----
2 files changed, 30 insertions(+), 12 deletions(-)
diff --git a/src/mistral_common/protocol/instruct/messages.py b/src/mistral_common/protocol/instruct/messages.py
index 2b922dfb..6ca46b2f 100644
--- a/src/mistral_common/protocol/instruct/messages.py
+++ b/src/mistral_common/protocol/instruct/messages.py
@@ -27,12 +27,12 @@
)
-def _are_think_chunks(chunks: Sequence[TextChunk | ThinkChunk]) -> TypeGuard[list[ThinkChunk]]:
+def _are_think_chunks(chunks: Sequence[AssistantContentChunk]) -> TypeGuard[list[ThinkChunk]]:
r"""Narrow a chunk list to ThinkChunk list."""
return all(isinstance(c, ThinkChunk) for c in chunks)
-def _are_text_chunks(chunks: Sequence[TextChunk | ThinkChunk]) -> TypeGuard[list[TextChunk]]:
+def _are_text_chunks(chunks: Sequence[AssistantContentChunk]) -> TypeGuard[list[TextChunk]]:
r"""Narrow a chunk list to TextChunk list."""
return all(isinstance(c, TextChunk) for c in chunks)
diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py
index 45969660..b774b102 100644
--- a/src/mistral_common/protocol/instruct/normalize.py
+++ b/src/mistral_common/protocol/instruct/normalize.py
@@ -6,13 +6,16 @@
from mistral_common.exceptions import InvalidRequestException
from mistral_common.protocol.instruct.chunk import (
+ AssistantContentChunk,
AudioChunk,
AudioURLChunk,
ContentChunk,
ImageChunk,
ImageURLChunk,
+ SystemContentChunk,
TextChunk,
ThinkChunk,
+ UserContentChunk,
)
from mistral_common.protocol.instruct.messages import (
UATS,
@@ -37,18 +40,25 @@
def _is_user_content(
chunks: list[ContentChunk],
-) -> TypeGuard[list[TextChunk | ImageChunk | ImageURLChunk | AudioChunk | AudioURLChunk]]:
+) -> TypeGuard[list[UserContentChunk]]:
r"""Narrow ContentChunk list to user-compatible types."""
return all(isinstance(c, (TextChunk, ImageChunk, ImageURLChunk, AudioChunk, AudioURLChunk)) for c in chunks)
def _is_assistant_content(
chunks: list[ContentChunk],
-) -> TypeGuard[list[TextChunk | ThinkChunk]]:
+) -> TypeGuard[list[AssistantContentChunk]]:
r"""Narrow ContentChunk list to assistant-compatible types."""
return all(isinstance(c, (TextChunk, ThinkChunk)) for c in chunks)
+def _is_system_content(
+ chunks: list[ContentChunk],
+) -> TypeGuard[list[SystemContentChunk]]:
+ r"""Narrow ContentChunk list to system-compatible types."""
+ return all(isinstance(c, (TextChunk, AudioChunk, ThinkChunk)) for c in chunks)
+
+
def _aggregate_content_chunks_impl(
contents: list[list[ContentChunk] | str | None],
msg_join_str: str,
@@ -273,7 +283,7 @@ def _normalize_tool_call(self, tool_call: ToolCall) -> ToolCall:
id=tool_call.id,
)
- def _narrow_assistant_content(self, content: list[ContentChunk] | str) -> str | list[TextChunk | ThinkChunk]:
+ def _narrow_assistant_content(self, content: list[ContentChunk] | str) -> str | list[AssistantContentChunk]:
r"""Validate and narrow content chunks for assistant messages.
Only TextChunk and ThinkChunk are allowed.
@@ -465,7 +475,7 @@ def normalizer(model_settings_builder: ModelSettingsBuilder | None = None) -> "I
UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest[UATS, Tool], None
)
- def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | list[ContentChunk]:
+ def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | list[SystemContentChunk]:
r"""Validate content chunks for system messages.
V7+ accepts all SystemContentChunk types (Pydantic validates at construction).
@@ -476,8 +486,15 @@ def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | lis
Returns:
The validated content.
+
+ Raises:
+ InvalidRequestException: If unsupported chunk types are found.
"""
- return content
+ if isinstance(content, str) or _is_system_content(content):
+ return content
+ raise InvalidRequestException(
+ f"Unexpected content chunk types in system message: {[type(c).__name__ for c in content]}"
+ )
def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]:
"""Normalize tool messages without JSON normalization.
@@ -627,13 +644,14 @@ def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[s
self._inplace_sort_tool_messages(tool_messages=tool_messages, latest_call_ids=latest_call_ids)
return tool_messages
- def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | list[ContentChunk]:
+ def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | list[SystemContentChunk]:
r"""V15 system messages allow TextChunk and AudioChunk but reject ThinkChunk."""
- if isinstance(content, str):
- return content
- if any(isinstance(c, ThinkChunk) for c in content):
+ validated = super()._narrow_system_content(content)
+ if isinstance(validated, str):
+ return validated
+ if any(isinstance(c, ThinkChunk) for c in validated):
raise InvalidRequestException("ThinkChunk in system message is not supported for V15")
- return content
+ return validated
@staticmethod
def normalizer(model_settings_builder: ModelSettingsBuilder | None = None) -> "InstructRequestNormalizerV15":
From a974af19add4131686bb2b5a2dc86e0931915742 Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Tue, 9 Jun 2026 16:07:50 +0200
Subject: [PATCH 18/47] Revert encode_assistant_message to list[int], reject
audio in V7-V13 system messages
---
.../protocol/instruct/normalize.py | 19 ++++++----
.../tokens/tokenizers/instruct.py | 37 ++++++++-----------
tests/experimental/test_app.py | 4 +-
tests/guidance/test_guidance.py | 4 +-
tests/test_normalization.py | 26 ++++++++++++-
tests/test_tokenizer_v11.py | 16 ++------
tests/test_tokenizer_v13.py | 6 +--
7 files changed, 61 insertions(+), 51 deletions(-)
diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py
index b774b102..b84b4f5b 100644
--- a/src/mistral_common/protocol/instruct/normalize.py
+++ b/src/mistral_common/protocol/instruct/normalize.py
@@ -478,8 +478,8 @@ def normalizer(model_settings_builder: ModelSettingsBuilder | None = None) -> "I
def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | list[SystemContentChunk]:
r"""Validate content chunks for system messages.
- V7+ accepts all SystemContentChunk types (Pydantic validates at construction).
- V15 overrides to reject ThinkChunk.
+ V7-V13 accepts text and thinking chunks. V15 overrides to allow audio
+ and reject thinking.
Args:
content: The aggregated content chunks.
@@ -491,6 +491,8 @@ def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | lis
InvalidRequestException: If unsupported chunk types are found.
"""
if isinstance(content, str) or _is_system_content(content):
+ if isinstance(content, list) and any(isinstance(c, (AudioChunk, AudioURLChunk)) for c in content):
+ raise InvalidRequestException("Audio chunks in system messages are only supported from V15")
return content
raise InvalidRequestException(
f"Unexpected content chunk types in system message: {[type(c).__name__ for c in content]}"
@@ -646,12 +648,15 @@ def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[s
def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | list[SystemContentChunk]:
r"""V15 system messages allow TextChunk and AudioChunk but reject ThinkChunk."""
- validated = super()._narrow_system_content(content)
- if isinstance(validated, str):
- return validated
- if any(isinstance(c, ThinkChunk) for c in validated):
+ if isinstance(content, str):
+ return content
+ if not _is_system_content(content):
+ raise InvalidRequestException(
+ f"Unexpected content chunk types in system message: {[type(c).__name__ for c in content]}"
+ )
+ if any(isinstance(c, ThinkChunk) for c in content):
raise InvalidRequestException("ThinkChunk in system message is not supported for V15")
- return validated
+ return content
@staticmethod
def normalizer(model_settings_builder: ModelSettingsBuilder | None = None) -> "InstructRequestNormalizerV15":
diff --git a/src/mistral_common/tokens/tokenizers/instruct.py b/src/mistral_common/tokens/tokenizers/instruct.py
index 08de0c95..3ec6730e 100644
--- a/src/mistral_common/tokens/tokenizers/instruct.py
+++ b/src/mistral_common/tokens/tokenizers/instruct.py
@@ -117,7 +117,7 @@ def encode_tool_message(
@abstractmethod
def encode_assistant_message(
self, message: AssistantMessageType, is_before_last_user_message: bool, continue_message: bool
- ) -> tuple[list[int], list[np.ndarray], list[Audio]]:
+ ) -> list[int]:
r"""Encode an assistant message.
Raises:
@@ -200,11 +200,9 @@ def encode_instruct(
elif isinstance(msg, AssistantMessage):
continue_message = request.continue_final_message and (msg_idx == len(request.messages) - 1)
- new_tokens, new_images, new_audios = self.encode_assistant_message(
+ new_tokens = self.encode_assistant_message(
msg, msg_idx < last_user_idx, continue_message=continue_message
)
- images.extend(new_images)
- audios.extend(new_audios)
if msg_idx == len(request.messages) - 1:
prefix_ids = new_tokens
elif isinstance(msg, SystemMessage):
@@ -338,7 +336,7 @@ def encode_tool_message(
def encode_assistant_message(
self, message: AssistantMessageType, is_before_last_user_message: bool, continue_message: bool
- ) -> tuple[list[int], list[np.ndarray], list[Audio]]:
+ ) -> list[int]:
r"""Encode an assistant message.
Args:
@@ -348,7 +346,7 @@ def encode_assistant_message(
Only use this if the assistant message is the last message.
Returns:
- The encoded tokens, images, and audios.
+ The encoded tokens.
"""
assert isinstance(message, AssistantMessage), message
if message.tool_calls is not None and len(message.tool_calls) > 0:
@@ -364,7 +362,7 @@ def encode_assistant_message(
raise TokenizerException(f"{message.content} // {message.tool_calls}")
if not message.prefix and not continue_message:
curr_tokens.append(self.tokenizer.eos_id)
- return curr_tokens, [], []
+ return curr_tokens
def encode_think(self, chunk: ThinkChunk) -> list[int]:
r"""Encode a think chunk.
@@ -562,7 +560,7 @@ def _encode_settings(
def encode_assistant_message(
self, message: AssistantMessageType, is_before_last_user_message: bool, continue_message: bool
- ) -> tuple[list[int], list[np.ndarray], list[Audio]]:
+ ) -> list[int]:
r"""Encode an assistant message.
Args:
@@ -573,7 +571,7 @@ def encode_assistant_message(
Only use this if the assistant message is the last message.
Returns:
- The encoded tokens, images, and audios.
+ The encoded tokens.
"""
if message.tool_calls and message.content:
raise ValueError(f"Cannot have tool calls and content defined in the same assistant message {message}")
@@ -585,7 +583,7 @@ def encode_assistant_message(
if message.tool_calls:
if is_before_last_user_message:
# don't tokenize tool call before last user message
- return [], [], []
+ return []
curr_tokens = self._encode_tool_calls_in_assistant_message(message)
elif message.content:
assert isinstance(message.content, str), "Message content must be a string for tokenizer < V7"
@@ -594,7 +592,7 @@ def encode_assistant_message(
raise TokenizerException(f"Invalid assistant message: {message.content}")
if not message.prefix and not continue_message:
curr_tokens.append(self.tokenizer.eos_id)
- return curr_tokens, [], []
+ return curr_tokens
def _encode_infilling(self, text: str) -> list[int]:
r"""Remove prefix space in the case of SentencePieceTokenizers."""
@@ -691,7 +689,7 @@ def encode_tool_message(
def encode_assistant_message(
self, message: AssistantMessageType, is_before_last_user_message: bool, continue_message: bool
- ) -> tuple[list[int], list[np.ndarray], list[Audio]]:
+ ) -> list[int]:
r"""Encode an assistant message.
Note:
@@ -705,7 +703,7 @@ def encode_assistant_message(
is_before_last_user_message: Not used.
Returns:
- The encoded tokens, images, and audios.
+ The encoded tokens.
"""
return super().encode_assistant_message(message, False, continue_message)
@@ -1130,7 +1128,7 @@ def encode_tool_message(
def encode_assistant_message(
self, message: AssistantMessageType, is_before_last_user_message: bool, continue_message: bool
- ) -> tuple[list[int], list[np.ndarray], list[Audio]]:
+ ) -> list[int]:
r"""Encode an assistant message.
Args:
@@ -1140,7 +1138,7 @@ def encode_assistant_message(
Only use this if the assistant message is the last message.
Returns:
- The encoded tokens, images, and audios.
+ The encoded tokens.
"""
if not message.content and not message.tool_calls:
raise TokenizerException(f"Invalid assistant message: {message}")
@@ -1150,22 +1148,17 @@ def encode_assistant_message(
)
curr_tokens: list = []
- images: list[np.ndarray] = []
- audios: list[Audio] = []
if message.content:
if isinstance(message.content, str):
curr_tokens = self._encode_normal_content_assistant_message(message)
elif isinstance(message.content, list):
- content_tokens, new_images, new_audios = self._encode_content_chunks(message.content)
- curr_tokens += content_tokens
- images.extend(new_images)
- audios.extend(new_audios)
+ curr_tokens += self._encode_content_chunks(message.content)[0]
if message.tool_calls:
curr_tokens += self._encode_tool_calls_in_assistant_message(message)
if not message.prefix and not continue_message:
curr_tokens.append(self.tokenizer.eos_id)
- return curr_tokens, images, audios
+ return curr_tokens
def _encode_audio_for_speech_request(self, ref_audio: str | bytes | None, voice: str | None) -> Tokenized:
r"""Encode reference audio or voice preset into a Tokenized object.
diff --git a/tests/experimental/test_app.py b/tests/experimental/test_app.py
index 27bb9dd8..5b38d2c6 100644
--- a/tests/experimental/test_app.py
+++ b/tests/experimental/test_app.py
@@ -394,7 +394,7 @@ def test_detokenize_assistant_message(
def test_detokenize_assistant_message_think_chunks(
assistant_message: AssistantMessage, mistral_tokenizer_v13: MistralTokenizer, tekken_v13_client: TestClient
) -> None:
- encoded_tokens, _, _ = mistral_tokenizer_v13.instruct_tokenizer.encode_assistant_message( # type: ignore[attr-defined]
+ encoded_tokens = mistral_tokenizer_v13.instruct_tokenizer.encode_assistant_message( # type: ignore[attr-defined]
assistant_message, False, False
)
@@ -464,7 +464,7 @@ def test_generate(
engine_request: dict | ChatCompletionRequest | OpenAIChatCompletionRequest,
output_assistant_message: AssistantMessage,
) -> None:
- output_tokens, _, _ = mistral_tokenizer_v13.instruct_tokenizer.encode_assistant_message( # type: ignore[attr-defined]
+ output_tokens = mistral_tokenizer_v13.instruct_tokenizer.encode_assistant_message( # type: ignore[attr-defined]
output_assistant_message, False, False
)
if output_assistant_message.tool_calls:
diff --git a/tests/guidance/test_guidance.py b/tests/guidance/test_guidance.py
index cac90e4c..a5c7a8a4 100644
--- a/tests/guidance/test_guidance.py
+++ b/tests/guidance/test_guidance.py
@@ -228,7 +228,7 @@ def _encode_content(
tokenizer = instruct_tokenizer.tokenizer
if isinstance(content, str):
- result, _, _ = instruct_tokenizer.encode_assistant_message(
+ result = instruct_tokenizer.encode_assistant_message(
AssistantMessage(content=content), is_before_last_user_message=False, continue_message=False
)
return result
@@ -238,7 +238,7 @@ def _encode_content(
tokens: list[int] = []
if content_chunks:
- tokens, _, _ = instruct_tokenizer.encode_assistant_message(
+ tokens = instruct_tokenizer.encode_assistant_message(
AssistantMessage(content=content_chunks),
is_before_last_user_message=False,
continue_message=False,
diff --git a/tests/test_normalization.py b/tests/test_normalization.py
index 6b396e06..eb6de338 100644
--- a/tests/test_normalization.py
+++ b/tests/test_normalization.py
@@ -1349,8 +1349,8 @@ def test_v15_rejects_think_in_system_message(self) -> None:
with pytest.raises(InvalidRequestException, match="ThinkChunk"):
normalizer.from_chat_completion_request(request)
- def test_v7_normalization_preserves_audio_in_system_message(self) -> None:
- r"""V7 normalizer preserves AudioChunk in system messages."""
+ def test_v7_rejects_audio_in_system_content(self) -> None:
+ r"""V7 normalizer rejects AudioChunk in system messages."""
normalizer = InstructRequestNormalizerV7.normalizer()
request = mock_chat_completion(
messages=[
@@ -1358,6 +1358,28 @@ def test_v7_normalization_preserves_audio_in_system_message(self) -> None:
UserMessage(content="test"),
]
)
+ with pytest.raises(
+ InvalidRequestException, match="Audio chunks in system messages are only supported from V15"
+ ):
+ normalizer.from_chat_completion_request(request)
+
+ def test_v15_preserves_audio_in_system_message(self) -> None:
+ r"""V15 normalizer preserves AudioChunk in system messages."""
+ normalizer = get_normalizer(
+ TokenizerVersion.v15,
+ model_settings_builder=ModelSettingsBuilder(
+ reasoning_effort=EnumBuilder[ReasoningEffort](
+ values=list(ReasoningEffort), accepts_none=False, default=None
+ )
+ ),
+ )
+ request = ChatCompletionRequest[ChatMessage](
+ messages=[
+ SystemMessage(content=[TextChunk(text="hello"), AudioChunk(input_audio=b"fake_audio_data")]),
+ UserMessage(content="test"),
+ ],
+ reasoning_effort=ReasoningEffort.high,
+ )
parsed: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request)
system_msg = parsed.messages[0]
assert isinstance(system_msg, SystemMessage)
diff --git a/tests/test_tokenizer_v11.py b/tests/test_tokenizer_v11.py
index 59bcb466..8c3c7330 100644
--- a/tests/test_tokenizer_v11.py
+++ b/tests/test_tokenizer_v11.py
@@ -32,15 +32,13 @@ def test_special_tokens(tekkenizer: InstructTokenizerV11) -> None:
def test_tokenize_assistant_message(tekkenizer: InstructTokenizerV11) -> None:
- tokens, images, audios = tekkenizer.encode_assistant_message(
+ tokens = tekkenizer.encode_assistant_message(
AssistantMessage(
tool_calls=[ToolCall(function=FunctionCall(name="a_a_a", arguments="blabla"))],
),
is_before_last_user_message=False,
continue_message=False,
)
- assert images == []
- assert audios == []
assert tokens == [
tekkenizer.TOOL_CALLS,
197,
@@ -63,15 +61,13 @@ def test_tokenize_assistant_message(tekkenizer: InstructTokenizerV11) -> None:
def test_tokenize_assistant_message_continue_message(tekkenizer: InstructTokenizerV11) -> None:
- tokens, images, audios = tekkenizer.encode_assistant_message(
+ tokens = tekkenizer.encode_assistant_message(
AssistantMessage(
content='"blabla"',
),
is_before_last_user_message=False,
continue_message=True,
)
- assert images == []
- assert audios == []
assert tokens == [
134,
198,
@@ -99,7 +95,7 @@ def test_tokenize_assistant_message_continue_message(tekkenizer: InstructTokeniz
def test_tokenize_assistant_messages(tekkenizer: InstructTokenizerV11) -> None:
- tokens, images, audios = tekkenizer.encode_assistant_message(
+ tokens = tekkenizer.encode_assistant_message(
AssistantMessage(
tool_calls=[
ToolCall(function=FunctionCall(name="a_a_a", arguments="blabla")),
@@ -109,8 +105,6 @@ def test_tokenize_assistant_messages(tekkenizer: InstructTokenizerV11) -> None:
is_before_last_user_message=False,
continue_message=False,
)
- assert images == []
- assert audios == []
assert tokens == [
tekkenizer.TOOL_CALLS,
197,
@@ -141,15 +135,13 @@ def test_tokenize_assistant_messages(tekkenizer: InstructTokenizerV11) -> None:
def test_tokenize_assistant_message_train(tekkenizer: InstructTokenizerV11) -> None:
- tokens, images, audios = tekkenizer.encode_assistant_message(
+ tokens = tekkenizer.encode_assistant_message(
AssistantMessage(
tool_calls=[ToolCall(function=FunctionCall(name="a_a_a", arguments="blabla"), id="ABC")],
),
is_before_last_user_message=True,
continue_message=False,
)
- assert images == []
- assert audios == []
assert tokens == [
tekkenizer.TOOL_CALLS,
197,
diff --git a/tests/test_tokenizer_v13.py b/tests/test_tokenizer_v13.py
index cec43105..1d6412ad 100644
--- a/tests/test_tokenizer_v13.py
+++ b/tests/test_tokenizer_v13.py
@@ -299,7 +299,7 @@ def test_tokenize_assistant_message(
v13_tekkenizer_think: InstructTokenizerV13, message: AssistantMessage, expected: str, continue_final_message: bool
) -> None:
if not continue_final_message:
- tokens, images, audios = v13_tekkenizer_think.encode_assistant_message(
+ tokens = v13_tekkenizer_think.encode_assistant_message(
message, is_before_last_user_message=False, continue_message=continue_final_message
)
if not message.prefix:
@@ -314,12 +314,10 @@ def test_tokenize_assistant_message(
message, is_before_last_user_message=False, continue_message=continue_final_message
)
return
- tokens, images, audios = v13_tekkenizer_think.encode_assistant_message(
+ tokens = v13_tekkenizer_think.encode_assistant_message(
message, is_before_last_user_message=False, continue_message=continue_final_message
)
assert v13_tekkenizer_think.decode(tokens, special_token_policy=SpecialTokenPolicy.KEEP) == expected
- assert images == []
- assert audios == []
def test_tokenize_assistant_message_error(v13_tekkenizer: InstructTokenizerV13) -> None:
From 2cf685e17969e192a5a6af672701368082a9cf14 Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Tue, 9 Jun 2026 16:14:23 +0200
Subject: [PATCH 19/47] Revert V7 audio rejection: audio supported in system
messages from V7+
---
.../protocol/instruct/normalize.py | 19 +++++++------------
tests/test_normalization.py | 15 +++++++++------
2 files changed, 16 insertions(+), 18 deletions(-)
diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py
index b84b4f5b..38cb4d89 100644
--- a/src/mistral_common/protocol/instruct/normalize.py
+++ b/src/mistral_common/protocol/instruct/normalize.py
@@ -478,8 +478,8 @@ def normalizer(model_settings_builder: ModelSettingsBuilder | None = None) -> "I
def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | list[SystemContentChunk]:
r"""Validate content chunks for system messages.
- V7-V13 accepts text and thinking chunks. V15 overrides to allow audio
- and reject thinking.
+ V7+ accepts all SystemContentChunk types (text, audio, thinking).
+ V15 overrides to reject ThinkChunk.
Args:
content: The aggregated content chunks.
@@ -491,8 +491,6 @@ def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | lis
InvalidRequestException: If unsupported chunk types are found.
"""
if isinstance(content, str) or _is_system_content(content):
- if isinstance(content, list) and any(isinstance(c, (AudioChunk, AudioURLChunk)) for c in content):
- raise InvalidRequestException("Audio chunks in system messages are only supported from V15")
return content
raise InvalidRequestException(
f"Unexpected content chunk types in system message: {[type(c).__name__ for c in content]}"
@@ -648,15 +646,12 @@ def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[s
def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | list[SystemContentChunk]:
r"""V15 system messages allow TextChunk and AudioChunk but reject ThinkChunk."""
- if isinstance(content, str):
- return content
- if not _is_system_content(content):
- raise InvalidRequestException(
- f"Unexpected content chunk types in system message: {[type(c).__name__ for c in content]}"
- )
- if any(isinstance(c, ThinkChunk) for c in content):
+ validated = super()._narrow_system_content(content)
+ if isinstance(validated, str):
+ return validated
+ if any(isinstance(c, ThinkChunk) for c in validated):
raise InvalidRequestException("ThinkChunk in system message is not supported for V15")
- return content
+ return validated
@staticmethod
def normalizer(model_settings_builder: ModelSettingsBuilder | None = None) -> "InstructRequestNormalizerV15":
diff --git a/tests/test_normalization.py b/tests/test_normalization.py
index eb6de338..b68ea020 100644
--- a/tests/test_normalization.py
+++ b/tests/test_normalization.py
@@ -1349,8 +1349,8 @@ def test_v15_rejects_think_in_system_message(self) -> None:
with pytest.raises(InvalidRequestException, match="ThinkChunk"):
normalizer.from_chat_completion_request(request)
- def test_v7_rejects_audio_in_system_content(self) -> None:
- r"""V7 normalizer rejects AudioChunk in system messages."""
+ def test_v7_preserves_audio_in_system_message(self) -> None:
+ r"""V7 normalizer preserves AudioChunk in system messages."""
normalizer = InstructRequestNormalizerV7.normalizer()
request = mock_chat_completion(
messages=[
@@ -1358,10 +1358,13 @@ def test_v7_rejects_audio_in_system_content(self) -> None:
UserMessage(content="test"),
]
)
- with pytest.raises(
- InvalidRequestException, match="Audio chunks in system messages are only supported from V15"
- ):
- normalizer.from_chat_completion_request(request)
+ parsed: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request)
+ system_msg = parsed.messages[0]
+ assert isinstance(system_msg, SystemMessage)
+ assert isinstance(system_msg.content, list)
+ assert len(system_msg.content) == 2
+ assert isinstance(system_msg.content[0], TextChunk)
+ assert isinstance(system_msg.content[1], AudioChunk)
def test_v15_preserves_audio_in_system_message(self) -> None:
r"""V15 normalizer preserves AudioChunk in system messages."""
From 953873e6a79b9df49067deaafab345e2426ef161 Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Tue, 9 Jun 2026 17:23:11 +0200
Subject: [PATCH 20/47] Add pre-V7 system message rejection tests for audio and
think chunks
---
tests/test_normalization.py | 30 ++++++++++++++++++++++++++++++
1 file changed, 30 insertions(+)
diff --git a/tests/test_normalization.py b/tests/test_normalization.py
index b68ea020..2c446d83 100644
--- a/tests/test_normalization.py
+++ b/tests/test_normalization.py
@@ -1366,6 +1366,36 @@ def test_v7_preserves_audio_in_system_message(self) -> None:
assert isinstance(system_msg.content[0], TextChunk)
assert isinstance(system_msg.content[1], AudioChunk)
+ def test_pre_v7_rejects_audio_in_system_message(self) -> None:
+ r"""Pre-V7 normalizer rejects AudioChunk in system messages."""
+ normalizer = InstructRequestNormalizer(
+ UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest, None
+ )
+ request = mock_chat_completion(
+ messages=[
+ SystemMessage(content=[TextChunk(text="hello"), AudioChunk(input_audio=b"fake_audio_data")]),
+ UserMessage(content="query"),
+ AssistantMessage(content="answer"),
+ ]
+ )
+ with pytest.raises(AssertionError):
+ normalizer.from_chat_completion_request(request)
+
+ def test_pre_v7_rejects_think_in_system_message(self) -> None:
+ r"""Pre-V7 normalizer rejects ThinkChunk in system messages."""
+ normalizer = InstructRequestNormalizer(
+ UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest, None
+ )
+ request = mock_chat_completion(
+ messages=[
+ SystemMessage(content=[TextChunk(text="hello"), ThinkChunk(thinking="thinking", closed=True)]),
+ UserMessage(content="query"),
+ AssistantMessage(content="answer"),
+ ]
+ )
+ with pytest.raises(AssertionError):
+ normalizer.from_chat_completion_request(request)
+
def test_v15_preserves_audio_in_system_message(self) -> None:
r"""V15 normalizer preserves AudioChunk in system messages."""
normalizer = get_normalizer(
From e73d56c21e6142df5803c8ffcd7610db7348043e Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Tue, 9 Jun 2026 20:01:50 +0200
Subject: [PATCH 21/47] Fix review nits: stale comment, misleading return, test
parametrize
---
.../chat_templates/template_generator.py | 5 +++-
src/mistral_common/protocol/instruct/chunk.py | 2 +-
.../protocol/instruct/normalize.py | 2 +-
tests/data/chat_templates/v15.jinja | 2 +-
tests/data/chat_templates/v15_audio.jinja | 2 +-
tests/data/chat_templates/v15_image.jinja | 2 +-
.../data/chat_templates/v15_image_think.jinja | 2 +-
tests/data/chat_templates/v15_think.jinja | 2 +-
tests/test_normalization.py | 28 +++++++------------
9 files changed, 21 insertions(+), 26 deletions(-)
diff --git a/src/mistral_common/integrations/chat_templates/template_generator.py b/src/mistral_common/integrations/chat_templates/template_generator.py
index 10cf02fe..b9dc662a 100644
--- a/src/mistral_common/integrations/chat_templates/template_generator.py
+++ b/src/mistral_common/integrations/chat_templates/template_generator.py
@@ -1449,7 +1449,10 @@ def _generate_tool_message_handling(config: TemplateConfig) -> str:
lines.append(" {#- Tool messages supports int, float or text content. #}")
lines.append(" {%- elif message['role'] == 'tool' and ns.index > ns.max_idx_user %}")
else:
- lines.append(" {#- Tool messages only supports text content. #}")
+ if config.tool_supports_multimodal:
+ lines.append(" {#- Tool messages (multimodal). #}")
+ else:
+ lines.append(" {#- Tool messages only supports text content. #}")
lines.append(" {%- elif message['role'] == 'tool' %}")
if config.uses_spm_space_tracking:
diff --git a/src/mistral_common/protocol/instruct/chunk.py b/src/mistral_common/protocol/instruct/chunk.py
index 3ba496b9..786c38e6 100644
--- a/src/mistral_common/protocol/instruct/chunk.py
+++ b/src/mistral_common/protocol/instruct/chunk.py
@@ -459,7 +459,7 @@ def from_openai(cls, openai_chunk: dict[str, Any]) -> "ThinkChunk":
SystemContentChunk = Annotated[TextChunk | AudioChunk | ThinkChunk, Field(discriminator="type")]
-ToolContentChunk: TypeAlias = ContentChunk
+ToolContentChunk: TypeAlias = ContentChunk # Accepts all content chunk types (no restriction on tool messages).
def _convert_openai_content_chunks(openai_content_chunks: dict[str, Any]) -> ContentChunk:
diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py
index 38cb4d89..4863d7bb 100644
--- a/src/mistral_common/protocol/instruct/normalize.py
+++ b/src/mistral_common/protocol/instruct/normalize.py
@@ -608,7 +608,7 @@ def _inplace_sort_tool_messages(tool_messages: list[ToolMessageType], latest_cal
id_to_tool_call_idx = {call_id: idx for idx, call_id in enumerate(latest_call_ids)}
id_to_tool_result_idx = {message.tool_call_id: idx for idx, message in enumerate(tool_messages)}
# First order by tool call idx and then by tool result idx
- return tool_messages.sort(
+ tool_messages.sort(
key=lambda msg: (
id_to_tool_call_idx.get(msg.tool_call_id or "null", float("inf")),
id_to_tool_result_idx[msg.tool_call_id],
diff --git a/tests/data/chat_templates/v15.jinja b/tests/data/chat_templates/v15.jinja
index d1f29add..e4df96ef 100644
--- a/tests/data/chat_templates/v15.jinja
+++ b/tests/data/chat_templates/v15.jinja
@@ -180,7 +180,7 @@
{{- eos_token }}
- {#- Tool messages only supports text content. #}
+ {#- Tool messages (multimodal). #}
{%- elif message['role'] == 'tool' %}
{{- '[TOOL_RESULTS]' -}}
{{- render_content(message['content'], 'tool message contents') -}}
diff --git a/tests/data/chat_templates/v15_audio.jinja b/tests/data/chat_templates/v15_audio.jinja
index bbd04595..f3c90f1c 100644
--- a/tests/data/chat_templates/v15_audio.jinja
+++ b/tests/data/chat_templates/v15_audio.jinja
@@ -177,7 +177,7 @@
{{- eos_token }}
- {#- Tool messages only supports text content. #}
+ {#- Tool messages (multimodal). #}
{%- elif message['role'] == 'tool' %}
{{- '[TOOL_RESULTS]' -}}
{{- render_content(message['content'], 'tool message contents', supported_types_desc='text and audio', support_audio=true) -}}
diff --git a/tests/data/chat_templates/v15_image.jinja b/tests/data/chat_templates/v15_image.jinja
index 35b872fe..bfbfdd17 100644
--- a/tests/data/chat_templates/v15_image.jinja
+++ b/tests/data/chat_templates/v15_image.jinja
@@ -188,7 +188,7 @@
{{- eos_token }}
- {#- Tool messages only supports text content. #}
+ {#- Tool messages (multimodal). #}
{%- elif message['role'] == 'tool' %}
{{- '[TOOL_RESULTS]' -}}
{{- render_content(message['content'], 'tool message contents', supported_types_desc='text and image', support_images=true) -}}
diff --git a/tests/data/chat_templates/v15_image_think.jinja b/tests/data/chat_templates/v15_image_think.jinja
index 5fab24c0..e1f26e8f 100644
--- a/tests/data/chat_templates/v15_image_think.jinja
+++ b/tests/data/chat_templates/v15_image_think.jinja
@@ -215,7 +215,7 @@
{{- eos_token }}
- {#- Tool messages only supports text content. #}
+ {#- Tool messages (multimodal). #}
{%- elif message['role'] == 'tool' %}
{{- '[TOOL_RESULTS]' -}}
{{- render_content(message['content'], 'tool message contents', supported_types_desc='text and image', support_images=true) -}}
diff --git a/tests/data/chat_templates/v15_think.jinja b/tests/data/chat_templates/v15_think.jinja
index d91b3909..40b8a7fc 100644
--- a/tests/data/chat_templates/v15_think.jinja
+++ b/tests/data/chat_templates/v15_think.jinja
@@ -207,7 +207,7 @@
{{- eos_token }}
- {#- Tool messages only supports text content. #}
+ {#- Tool messages (multimodal). #}
{%- elif message['role'] == 'tool' %}
{{- '[TOOL_RESULTS]' -}}
{{- render_content(message['content'], 'tool message contents') -}}
diff --git a/tests/test_normalization.py b/tests/test_normalization.py
index 2c446d83..cef30eb9 100644
--- a/tests/test_normalization.py
+++ b/tests/test_normalization.py
@@ -1366,29 +1366,21 @@ def test_v7_preserves_audio_in_system_message(self) -> None:
assert isinstance(system_msg.content[0], TextChunk)
assert isinstance(system_msg.content[1], AudioChunk)
- def test_pre_v7_rejects_audio_in_system_message(self) -> None:
- r"""Pre-V7 normalizer rejects AudioChunk in system messages."""
- normalizer = InstructRequestNormalizer(
- UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest, None
- )
- request = mock_chat_completion(
- messages=[
- SystemMessage(content=[TextChunk(text="hello"), AudioChunk(input_audio=b"fake_audio_data")]),
- UserMessage(content="query"),
- AssistantMessage(content="answer"),
- ]
- )
- with pytest.raises(AssertionError):
- normalizer.from_chat_completion_request(request)
-
- def test_pre_v7_rejects_think_in_system_message(self) -> None:
- r"""Pre-V7 normalizer rejects ThinkChunk in system messages."""
+ @pytest.mark.parametrize(
+ "chunk",
+ [
+ pytest.param(AudioChunk(input_audio=b"fake_audio_data"), id="audio"),
+ pytest.param(ThinkChunk(thinking="thinking", closed=True), id="think"),
+ ],
+ )
+ def test_pre_v7_rejects_non_text_in_system_message(self, chunk: AudioChunk | ThinkChunk) -> None:
+ r"""Pre-V7 normalizer rejects non-text chunks in system messages."""
normalizer = InstructRequestNormalizer(
UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest, None
)
request = mock_chat_completion(
messages=[
- SystemMessage(content=[TextChunk(text="hello"), ThinkChunk(thinking="thinking", closed=True)]),
+ SystemMessage(content=[TextChunk(text="hello"), chunk]),
UserMessage(content="query"),
AssistantMessage(content="answer"),
]
From 354e244e4d381be0c7a7adb49873f14b32a3bc10 Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Tue, 9 Jun 2026 21:20:30 +0200
Subject: [PATCH 22/47] Narrow encode_system_message return type to exclude
images
---
src/mistral_common/tokens/tokenizers/instruct.py | 15 +++++++--------
tests/experimental/test_app.py | 4 +---
tests/guidance/test_guidance.py | 3 +--
tests/test_tokenizer_v13.py | 3 +--
4 files changed, 10 insertions(+), 15 deletions(-)
diff --git a/src/mistral_common/tokens/tokenizers/instruct.py b/src/mistral_common/tokens/tokenizers/instruct.py
index 3ec6730e..a939ffb7 100644
--- a/src/mistral_common/tokens/tokenizers/instruct.py
+++ b/src/mistral_common/tokens/tokenizers/instruct.py
@@ -206,8 +206,7 @@ def encode_instruct(
if msg_idx == len(request.messages) - 1:
prefix_ids = new_tokens
elif isinstance(msg, SystemMessage):
- new_tokens, new_images, new_audios = self.encode_system_message(msg)
- images.extend(new_images)
+ new_tokens, new_audios = self.encode_system_message(msg)
audios.extend(new_audios)
else:
raise TokenizerException(f"Unknown message type {type(msg)}")
@@ -295,7 +294,7 @@ def encode_user_message(
curr_tokens, image, audio = self.encode_user_content(content=message_txt, is_last=False, system_prompt=None)
return curr_tokens, image, audio
- def encode_system_message(self, message: SystemMessage) -> tuple[list[int], list[np.ndarray], list[Audio]]:
+ def encode_system_message(self, message: SystemMessage) -> tuple[list[int], list[Audio]]:
raise NotImplementedError(f"System message encoding not implemented for {self.__class__.__name__}")
def encode_user_content(
@@ -879,22 +878,22 @@ def drop(idx: int) -> None:
if to_drop > 0:
raise TokenizerException("Input couldn't fit in truncate_at_max_token")
- def encode_system_message(self, message: SystemMessage) -> tuple[list[int], list[np.ndarray], list[Audio]]:
+ def encode_system_message(self, message: SystemMessage) -> tuple[list[int], list[Audio]]:
r"""Encode a system message.
Args:
message: The message to encode.
Returns:
- The encoded tokens, images, and audios.
+ The encoded tokens and audios.
"""
tokens = [self.BEGIN_SYSTEM]
if isinstance(content := message.content, str):
content = [TextChunk(text=content)]
- content_tokens, images, audios = self._encode_content_chunks(content)
+ content_tokens, _images, audios = self._encode_content_chunks(content)
tokens += content_tokens
tokens.append(self.END_SYSTEM)
- return tokens, images, audios
+ return tokens, audios
def encode_user_content(
self,
@@ -1386,7 +1385,7 @@ def _encode_settings(
]
return settings_tokens
- def encode_system_message(self, message: SystemMessage) -> tuple[list[int], list[np.ndarray], list[Audio]]:
+ def encode_system_message(self, message: SystemMessage) -> tuple[list[int], list[Audio]]:
r"""Encode a system message, rejecting ThinkChunk content."""
if isinstance(message.content, list):
if any(isinstance(chunk, ThinkChunk) for chunk in message.content):
diff --git a/tests/experimental/test_app.py b/tests/experimental/test_app.py
index 5b38d2c6..9851cb06 100644
--- a/tests/experimental/test_app.py
+++ b/tests/experimental/test_app.py
@@ -394,9 +394,7 @@ def test_detokenize_assistant_message(
def test_detokenize_assistant_message_think_chunks(
assistant_message: AssistantMessage, mistral_tokenizer_v13: MistralTokenizer, tekken_v13_client: TestClient
) -> None:
- encoded_tokens = mistral_tokenizer_v13.instruct_tokenizer.encode_assistant_message( # type: ignore[attr-defined]
- assistant_message, False, False
- )
+ encoded_tokens = mistral_tokenizer_v13.instruct_tokenizer.encode_assistant_message(assistant_message, False, False) # type: ignore[attr-defined]
response = tekken_v13_client.post("/v1/detokenize/", json=encoded_tokens)
assert response.status_code == 200
diff --git a/tests/guidance/test_guidance.py b/tests/guidance/test_guidance.py
index a5c7a8a4..9afe6f33 100644
--- a/tests/guidance/test_guidance.py
+++ b/tests/guidance/test_guidance.py
@@ -228,10 +228,9 @@ def _encode_content(
tokenizer = instruct_tokenizer.tokenizer
if isinstance(content, str):
- result = instruct_tokenizer.encode_assistant_message(
+ return instruct_tokenizer.encode_assistant_message(
AssistantMessage(content=content), is_before_last_user_message=False, continue_message=False
)
- return result
tool_calls = [x for x in content if isinstance(x, ToolCall)]
content_chunks = [x for x in content if not isinstance(x, ToolCall)]
diff --git a/tests/test_tokenizer_v13.py b/tests/test_tokenizer_v13.py
index 1d6412ad..64380ac3 100644
--- a/tests/test_tokenizer_v13.py
+++ b/tests/test_tokenizer_v13.py
@@ -375,9 +375,8 @@ def test_tokenize_assistant_message_error(v13_tekkenizer: InstructTokenizerV13)
def test_encode_system_message(
v13_tekkenizer_think: InstructTokenizerV13, message: SystemMessage, expected: str
) -> None:
- encoded, images, audios = v13_tekkenizer_think.encode_system_message(message)
+ encoded, audios = v13_tekkenizer_think.encode_system_message(message)
assert v13_tekkenizer_think.decode(encoded, special_token_policy=SpecialTokenPolicy.KEEP) == expected
- assert images == []
assert audios == []
From 113acb0f02ab238c8c7d7dadb9943e27aee162d4 Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Tue, 9 Jun 2026 21:30:24 +0200
Subject: [PATCH 23/47] Add REQUEST_TOOL_AUDIO_TRAIN fixture for tool message
with audio
---
.../chat_templates/fixtures_data.py | 28 +++++++++++++++++++
1 file changed, 28 insertions(+)
diff --git a/tests/integrations/chat_templates/fixtures_data.py b/tests/integrations/chat_templates/fixtures_data.py
index 8fad3fcf..b396f457 100644
--- a/tests/integrations/chat_templates/fixtures_data.py
+++ b/tests/integrations/chat_templates/fixtures_data.py
@@ -981,6 +981,33 @@
tools=_TOOLS,
)
+REQUEST_TOOL_AUDIO_TRAIN = ChatCompletionRequest( # type: ignore[type-var]
+ messages=[
+ UserMessage(content="What does this sound like?"),
+ AssistantMessage(
+ content=None,
+ tool_calls=[
+ ToolCall(
+ id="tl1mg2345",
+ function=FunctionCall(
+ name="tool1",
+ arguments={"location": "San Francisco, CA"}, # type: ignore[arg-type]
+ ),
+ ),
+ ],
+ ),
+ ToolMessage(
+ content=[
+ TextChunk(text="Here is the audio result."),
+ AudioChunk(input_audio=_AUDIO),
+ ],
+ tool_call_id="tl1mg2345",
+ ),
+ AssistantMessage(content="The tool returned audio of a sine wave."),
+ ],
+ tools=_TOOLS,
+)
+
REQUEST_SYSTEM_AUDIO_TRAIN = ChatCompletionRequest( # type: ignore[type-var]
messages=[
SystemMessage(
@@ -1154,6 +1181,7 @@ def _get_conversations(
)
if audio:
conversations.append(REQUEST_SYSTEM_AUDIO_TRAIN)
+ conversations.append(REQUEST_TOOL_AUDIO_TRAIN)
conversations = [c.model_copy(deep=True) for c in conversations]
From baf29e6b6315a4c93b902be01db9bbd5c28325ac Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Tue, 9 Jun 2026 21:30:26 +0200
Subject: [PATCH 24/47] Expand TestAssistantContentNarrowing to cover V7, V13,
V15 with rejection tests
---
tests/test_normalization.py | 56 +++++++++++++++++++++++++++----------
1 file changed, 42 insertions(+), 14 deletions(-)
diff --git a/tests/test_normalization.py b/tests/test_normalization.py
index cef30eb9..4d066b86 100644
--- a/tests/test_normalization.py
+++ b/tests/test_normalization.py
@@ -1127,10 +1127,34 @@ def test_get_normalizer_version_mapping(
class TestAssistantContentNarrowing:
- def test_accepts_text_and_think_chunks(self) -> None:
+ @pytest.fixture(params=[TokenizerVersion.v7, TokenizerVersion.v13, TokenizerVersion.v15])
+ def normalizer(self, request: pytest.FixtureRequest) -> InstructRequestNormalizer:
+ r"""Normalizer fixture parametrized across V7, V13, V15."""
+ version = request.param
+ if version == TokenizerVersion.v15:
+ return get_normalizer(
+ version,
+ model_settings_builder=ModelSettingsBuilder(
+ reasoning_effort=EnumBuilder[ReasoningEffort](
+ values=list(ReasoningEffort), accepts_none=False, default=None
+ )
+ ),
+ )
+ return get_normalizer(version)
+
+ @staticmethod
+ def _make_request(
+ normalizer: InstructRequestNormalizer, messages: list[ChatMessage]
+ ) -> ChatCompletionRequest[ChatMessage]:
+ r"""Build a ChatCompletionRequest, adding reasoning_effort for V15."""
+ if isinstance(normalizer, InstructRequestNormalizerV15):
+ return ChatCompletionRequest(messages=messages, reasoning_effort=ReasoningEffort.high)
+ return mock_chat_completion(messages=messages)
+
+ def test_accepts_text_and_think_chunks(self, normalizer: InstructRequestNormalizer) -> None:
r"""Normalizer accepts TextChunk and ThinkChunk in assistant messages."""
- normalizer = get_normalizer(TokenizerVersion.v13)
- request = mock_chat_completion(
+ request = self._make_request(
+ normalizer,
messages=[
UserMessage(content="query"),
AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]),
@@ -1142,28 +1166,32 @@ def test_accepts_text_and_think_chunks(self) -> None:
assert isinstance(assistant_msg.content, list)
assert len(assistant_msg.content) == 2
- def test_accepts_string_content(self) -> None:
+ def test_accepts_string_content(self, normalizer: InstructRequestNormalizer) -> None:
r"""Normalizer accepts string content in assistant messages."""
- normalizer = get_normalizer(
- TokenizerVersion.v15,
- model_settings_builder=ModelSettingsBuilder(
- reasoning_effort=EnumBuilder[ReasoningEffort](
- values=list(ReasoningEffort), accepts_none=False, default=None
- )
- ),
- )
- request = ChatCompletionRequest[ChatMessage](
+ request = self._make_request(
+ normalizer,
messages=[
UserMessage(content="query"),
AssistantMessage(content="plain text"),
],
- reasoning_effort=ReasoningEffort.high,
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request)
assistant_msg = parsed.messages[1]
assert isinstance(assistant_msg, AssistantMessage)
assert assistant_msg.content == "plain text"
+ def test_pydantic_rejects_image_in_assistant(self) -> None:
+ with pytest.raises(ValidationError):
+ AssistantMessage(
+ content=[TextChunk(text="answer"), ImageURLChunk(image_url="https://example.com/img.png")] # type: ignore[list-item]
+ )
+
+ def test_pydantic_rejects_audio_in_assistant(self) -> None:
+ with pytest.raises(ValidationError):
+ AssistantMessage(
+ content=[TextChunk(text="answer"), AudioChunk(input_audio=b"fake_audio_data")] # type: ignore[list-item]
+ )
+
class TestToolMessageContentChunk:
@pytest.fixture()
From 3c1228d321079ad36f827cd681e860e526d70be1 Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Tue, 9 Jun 2026 23:09:02 +0200
Subject: [PATCH 25/47] Refactor test_parity to use _get_conversations with
to_openai conversion
---
.../chat_templates/test_parity.py | 537 ++----------------
1 file changed, 51 insertions(+), 486 deletions(-)
diff --git a/tests/integrations/chat_templates/test_parity.py b/tests/integrations/chat_templates/test_parity.py
index 8c90c97d..1bd6feed 100644
--- a/tests/integrations/chat_templates/test_parity.py
+++ b/tests/integrations/chat_templates/test_parity.py
@@ -3,8 +3,11 @@
import pytest
from mistral_common.integrations.chat_templates.template_generator import build_chat_template
+from mistral_common.protocol.instruct.request import ChatCompletionRequest
+from mistral_common.protocol.instruct.validator import ValidationMode
from mistral_common.tokens.tokenizers.base import TokenizerVersion
from tests.integrations.chat_templates.conftest import ALL_CONFIGS, _config_id
+from tests.integrations.chat_templates.fixtures_data import _get_conversations
from tests.integrations.chat_templates.helpers import (
TestConfig,
_load_golden_template,
@@ -13,6 +16,43 @@
)
+def _request_to_render_args(request: ChatCompletionRequest) -> dict[str, Any]:
+ r"""Convert a ChatCompletionRequest to render_template kwargs.
+
+ Enriches tool messages with the ``name`` field resolved from the preceding
+ assistant's ``tool_calls``, which the v2 Jinja template requires but
+ ``ToolMessage.to_openai()`` omits.
+ """
+ openai = request.to_openai()
+ messages = openai["messages"]
+
+ # Build a mapping from tool_call_id -> function name so tool messages
+ # can carry the ``name`` field expected by v2 templates.
+ tool_call_names: dict[str, str] = {}
+ for msg in messages:
+ for tc in msg.get("tool_calls", []):
+ tc_id = tc.get("id")
+ fn_name = tc.get("function", {}).get("name")
+ if tc_id and fn_name:
+ tool_call_names[tc_id] = fn_name
+
+ for msg in messages:
+ if msg.get("role") == "tool" and "name" not in msg:
+ tc_id = msg.get("tool_call_id")
+ if tc_id and tc_id in tool_call_names:
+ msg["name"] = tool_call_names[tc_id]
+
+ kwargs: dict[str, Any] = {
+ "messages": messages,
+ }
+ if "tools" in openai and openai["tools"]:
+ kwargs["tools"] = openai["tools"]
+ reasoning = openai.get("reasoning_effort")
+ if reasoning is not None:
+ kwargs["reasoning_effort"] = reasoning
+ return kwargs
+
+
@pytest.mark.parametrize(
"config",
[c for c in ALL_CONFIGS if not c.plain_think],
@@ -64,489 +104,14 @@ def test_dynamic_template_comprehensive(config: TestConfig) -> None:
static_template = _load_golden_template(template_config)
dynamic_template = build_chat_template(template_config)
- # Test cases
- test_cases = [
- # Simple one-turn
- {
- "name": "one_turn",
- "messages": [
- {"role": "user", "content": "Hello"},
- {"role": "assistant", "content": "Hi there!"},
- ],
- },
- # With system message
- {
- "name": "with_system",
- "messages": [
- {"role": "system", "content": "You are a helpful assistant."},
- {"role": "user", "content": "Hello"},
- {"role": "assistant", "content": "Hi there!"},
- ],
- },
- # Multi-turn
- {
- "name": "multi_turn",
- "messages": [
- {"role": "user", "content": "Hello"},
- {"role": "assistant", "content": "Hi there!"},
- {"role": "user", "content": "How are you?"},
- {"role": "assistant", "content": "I'm doing well!"},
- ],
- },
- # Content as list of chunks
- {
- "name": "content_chunks",
- "messages": [
- {"role": "user", "content": [{"type": "text", "text": "Hello"}]},
- {"role": "assistant", "content": [{"type": "text", "text": "Hi there!"}]},
- ],
- },
- ]
-
- # Tool call scenarios (v2+ only)
- if config.version > TokenizerVersion.v1:
- test_cases.extend(
- [
- {
- "name": "with_tools_definition",
- "messages": [
- {"role": "user", "content": "What's the weather?"},
- {"role": "assistant", "content": "It's sunny."},
- ],
- "tools": [
- {
- "type": "function",
- "function": {
- "name": "get_weather",
- "description": "Get weather",
- "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
- },
- }
- ],
- },
- {
- "name": "with_tool_call_and_result",
- "messages": [
- {"role": "user", "content": "What's the weather in Paris?"},
- {
- "role": "assistant",
- "content": "",
- "tool_calls": [
- {
- "id": "abc123def",
- "type": "function",
- "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
- }
- ],
- },
- {"role": "tool", "content": '{"temp": 20}', "tool_call_id": "abc123def", "name": "get_weather"},
- {"role": "assistant", "content": "It's 20 degrees in Paris."},
- ],
- "tools": [
- {
- "type": "function",
- "function": {
- "name": "get_weather",
- "description": "Get weather",
- "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
- },
- }
- ],
- },
- ]
- )
-
- # Add image test case if image support
- if config.image:
- test_cases.append(
- {
- "name": "with_image",
- "messages": [
- {
- "role": "user",
- "content": [
- {"type": "text", "text": "What is this?"},
- {"type": "image_url", "image_url": "http://example.com/image.png"},
- ],
- },
- {"role": "assistant", "content": "It's an image."},
- ],
- }
- )
-
- # Add audio test case if audio support
- if config.audio:
- test_cases.append(
- {
- "name": "with_audio",
- "messages": [
- {
- "role": "user",
- "content": [
- {"type": "text", "text": "What is this?"},
- {"type": "audio_url", "audio_url": "http://example.com/audio.mp3"},
- ],
- },
- {"role": "assistant", "content": "It's an audio file."},
- ],
- }
- )
-
- # Add thinking test case if thinking support
- if config.think:
- test_cases.append(
- {
- "name": "with_thinking",
- "messages": [
- {"role": "user", "content": "Solve this problem"},
- {
- "role": "assistant",
- "content": [
- {"type": "thinking", "thinking": "Let me think..."},
- {"type": "text", "text": "The answer is 42."},
- ],
- },
- ],
- }
- )
-
- # Add message aggregation test cases
- test_cases.extend(
- [
- {
- "name": "consecutive_users",
- "messages": [
- {"role": "user", "content": "Hello"},
- {"role": "user", "content": "World"},
- {"role": "assistant", "content": "Hi there"},
- ],
- },
- {
- "name": "consecutive_users_with_system",
- "messages": [
- {"role": "system", "content": "You are helpful."},
- {"role": "user", "content": "Hello"},
- {"role": "user", "content": "World"},
- {"role": "assistant", "content": "Hi there"},
- ],
- },
- {
- "name": "consecutive_assistants",
- "messages": [
- {"role": "user", "content": "Hello"},
- {"role": "assistant", "content": "Hi"},
- {"role": "assistant", "content": "How can I help?"},
- {"role": "user", "content": "Thanks"},
- {"role": "assistant", "content": "Welcome"},
- ],
- },
- {
- "name": "multiple_systems",
- "messages": [
- {"role": "system", "content": "System 1."},
- {"role": "system", "content": "System 2."},
- {"role": "user", "content": "Hello"},
- {"role": "assistant", "content": "Hi"},
- ],
- },
- {
- "name": "mid_conv_system",
- "messages": [
- {"role": "user", "content": "Hello"},
- {"role": "system", "content": "New instruction."},
- {"role": "assistant", "content": "Got it"},
- ],
- },
- {
- "name": "mid_conv_system_with_consecutive_users",
- "messages": [
- {"role": "system", "content": "Be helpful."},
- {"role": "user", "content": "Hello"},
- {"role": "user", "content": "World"},
- {"role": "system", "content": "Now be concise."},
- {"role": "assistant", "content": "Got it"},
- ],
- },
- ]
- )
-
- # Multi-chunk aggregation test cases
- test_cases.extend(
- [
- {
- "name": "consecutive_users_text_chunks",
- "messages": [
- {"role": "user", "content": "First as string"},
- {"role": "user", "content": [{"type": "text", "text": "Second as chunk"}]},
- {
- "role": "user",
- "content": [
- {"type": "text", "text": "Third part A"},
- {"type": "text", "text": "Third part B"},
- ],
- },
- {"role": "assistant", "content": "Response"},
- ],
- },
- {
- "name": "system_text_chunks",
- "messages": [
- {
- "role": "system",
- "content": [
- {"type": "text", "text": "You are helpful."},
- {"type": "text", "text": "Be concise."},
- ],
- },
- {"role": "user", "content": "Hello"},
- {"role": "assistant", "content": "Hi"},
- ],
- },
- ]
- )
-
- if config.image:
- test_cases.extend(
- [
- {
- "name": "consecutive_users_with_image",
- "messages": [
- {"role": "user", "content": "What is this?"},
- {
- "role": "user",
- "content": [
- {"type": "image_url", "image_url": "http://example.com/image.png"},
- {"type": "text", "text": "Describe it"},
- ],
- },
- {"role": "assistant", "content": "It's an image."},
- ],
- },
- {
- "name": "consecutive_users_multi_image",
- "messages": [
- {
- "role": "user",
- "content": [
- {"type": "text", "text": "Describe this"},
- {"type": "image_url", "image_url": "http://example.com/a.png"},
- {"type": "text", "text": "What color?"},
- ],
- },
- {
- "role": "user",
- "content": [
- {"type": "text", "text": "Also this"},
- {"type": "image_url", "image_url": "http://example.com/b.png"},
- {"type": "text", "text": "What shape?"},
- ],
- },
- {"role": "assistant", "content": "Both are red squares."},
- ],
- },
- ]
- )
-
- if config.audio:
- test_cases.append(
- {
- "name": "consecutive_users_multi_audio",
- "messages": [
- {
- "role": "user",
- "content": [
- {"type": "text", "text": "Listen"},
- {"type": "audio_url", "audio_url": "http://example.com/a.wav"},
- {"type": "text", "text": "What language?"},
- ],
- },
- {
- "role": "user",
- "content": [
- {"type": "text", "text": "And this"},
- {"type": "audio_url", "audio_url": "http://example.com/b.wav"},
- {"type": "text", "text": "Transcribe it"},
- ],
- },
- {"role": "assistant", "content": "Both are in English."},
- ],
- }
- )
-
- if config.think:
- test_cases.extend(
- [
- {
- "name": "consecutive_assistants_think",
- "messages": [
- {"role": "user", "content": "Solve this"},
- {
- "role": "assistant",
- "content": [
- {"type": "text", "text": "Hmm."},
- {"type": "thinking", "thinking": "Let me think..."},
- {"type": "text", "text": "I need more context."},
- ],
- },
- {
- "role": "assistant",
- "content": [
- {"type": "text", "text": "OK."},
- {"type": "thinking", "thinking": "Now I understand."},
- {"type": "text", "text": "The answer is 42."},
- ],
- },
- {"role": "user", "content": "Thanks"},
- {"role": "assistant", "content": "You're welcome"},
- ],
- },
- {
- "name": "consecutive_systems_think",
- "messages": [
- {
- "role": "system",
- "content": [
- {"type": "text", "text": "Rule A"},
- {"type": "text", "text": "Rule B"},
- {"type": "thinking", "thinking": "Think 1"},
- ],
- },
- {
- "role": "system",
- "content": [
- {"type": "thinking", "thinking": "Think 2"},
- {"type": "text", "text": "Rule C"},
- {"type": "text", "text": "Rule D"},
- ],
- },
- {"role": "user", "content": "Hello"},
- {"role": "assistant", "content": "Hi"},
- ],
- },
- ]
- )
-
- # V15+ multimodal: tool messages with image/audio content
- if config.version >= TokenizerVersion.v15 and config.image:
- test_cases.append(
- {
- "name": "v15_tool_with_image",
- "messages": [
- {"role": "user", "content": "Use tool"},
- {
- "role": "assistant",
- "content": "",
- "tool_calls": [
- {"id": "abc123def", "function": {"name": "fn", "arguments": "{}"}},
- ],
- },
- {
- "role": "tool",
- "content": [
- {"type": "text", "text": "result"},
- {"type": "image_url", "image_url": "http://example.com/img.png"},
- ],
- "tool_call_id": "abc123def",
- },
- {"role": "assistant", "content": "Done"},
- ],
- "tools": [
- {
- "type": "function",
- "function": {"name": "fn", "description": "test", "parameters": {}},
- }
- ],
- }
- )
-
- if config.version >= TokenizerVersion.v15 and config.audio:
- test_cases.append(
- {
- "name": "v15_tool_with_audio",
- "messages": [
- {"role": "user", "content": "Use tool"},
- {
- "role": "assistant",
- "content": "",
- "tool_calls": [
- {"id": "abc123def", "function": {"name": "fn", "arguments": "{}"}},
- ],
- },
- {
- "role": "tool",
- "content": [
- {"type": "text", "text": "result"},
- {"type": "input_audio", "input_audio": {"data": "abc", "format": "wav"}},
- ],
- "tool_call_id": "abc123def",
- },
- {"role": "assistant", "content": "Done"},
- ],
- "tools": [
- {
- "type": "function",
- "function": {"name": "fn", "description": "test", "parameters": {}},
- }
- ],
- }
- )
-
- # V15+ multimodal: system messages with audio content
- if config.version >= TokenizerVersion.v15 and config.audio:
- test_cases.append(
- {
- "name": "v15_system_with_audio",
- "messages": [
- {
- "role": "system",
- "content": [
- {"type": "text", "text": "Listen to context"},
- {"type": "input_audio", "input_audio": {"data": "abc", "format": "wav"}},
- ],
- },
- {"role": "user", "content": "Summarize"},
- {"role": "assistant", "content": "Done"},
- ],
- }
- )
-
- skip_names_image = {"with_image", "consecutive_users_with_image", "consecutive_users_multi_image"}
- skip_names_audio = {"with_audio", "consecutive_users_multi_audio"}
- skip_names_think = {"with_thinking", "consecutive_assistants_think", "consecutive_systems_think"}
- # ThinkChunks in system messages are only supported in v13 (not v15+)
- skip_names_think_system = {"consecutive_systems_think"}
- skip_names_tools = {"with_tools_definition", "with_tool_call_and_result"}
- skip_names_v15_image = {"v15_tool_with_image"}
- skip_names_v15_audio = {"v15_tool_with_audio", "v15_system_with_audio"}
-
- # Not using parametrize here because test_cases are built dynamically based on the
- # config (version/image/audio/think) from the outer parametrize. Each sub-case is
- # identifiable via test_name in the assertion failure message.
- for test_case in test_cases:
- test_name = test_case["name"]
- messages = test_case["messages"]
- tools = test_case.get("tools")
-
- if test_name in skip_names_image and not config.image:
- continue
- if test_name in skip_names_audio and not config.audio:
- continue
- if test_name in skip_names_think and not config.think:
- continue
- if test_name in skip_names_think_system and config.version >= TokenizerVersion.v15:
- continue
- if test_name in skip_names_tools and config.version <= TokenizerVersion.v1:
- continue
- if test_name in skip_names_v15_image and (config.version < TokenizerVersion.v15 or not config.image):
- continue
- if test_name in skip_names_v15_audio and (config.version < TokenizerVersion.v15 or not config.audio):
- continue
-
- static_output = render_template(static_template, messages, tools=tools) # type: ignore
- dynamic_output = render_template(dynamic_template, messages, tools=tools) # type: ignore
-
- assert static_output == dynamic_output, (
- f"Output mismatch for {config}, case={test_name}\n\n"
- f"Static output: {static_output}\n"
- f"Dynamic output: {dynamic_output}"
- )
+ for mode in (ValidationMode.finetuning, ValidationMode.test):
+ conversations = _get_conversations(config.version, mode, config.image, config.audio, config.think)
+ for idx, request in enumerate(conversations):
+ render_args = _request_to_render_args(request)
+ static_output = render_template(static_template, **render_args)
+ dynamic_output = render_template(dynamic_template, **render_args)
+ assert static_output == dynamic_output, (
+ f"Output mismatch for {config}, mode={mode}, conversation={idx}\n\n"
+ f"Static output: {static_output}\n"
+ f"Dynamic output: {dynamic_output}"
+ )
From 4ac82b7629eb4094c3e5e41f387d602b6d7f87dc Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Tue, 9 Jun 2026 23:09:03 +0200
Subject: [PATCH 26/47] Reorganize V15 tokenizer test fixtures, add text
assertions
---
tests/test_tokenizer_v15.py | 331 ++++++++++++++++++++++--------------
1 file changed, 206 insertions(+), 125 deletions(-)
diff --git a/tests/test_tokenizer_v15.py b/tests/test_tokenizer_v15.py
index 739ec75a..968128f9 100644
--- a/tests/test_tokenizer_v15.py
+++ b/tests/test_tokenizer_v15.py
@@ -63,6 +63,44 @@
r"[INST]U2[/INST]"
)
+EXPECTED_TEXT_TOOL_AUDIO: str = (
+ r""
+ r'[AVAILABLE_TOOLS][{"type": "function", "function": {"name": "fn",'
+ r' "description": "test", "parameters": {}}}]'
+ r'[/AVAILABLE_TOOLS][MODEL_SETTINGS]{"reasoning_effort": "none"}[/MODEL_SETTINGS]'
+ r"[INST]Use the tool[/INST]"
+ r"[TOOL_CALLS]fn[ARGS]{}"
+ r"[TOOL_RESULTS]result[BEGIN_AUDIO][AUDIO][AUDIO][/TOOL_RESULTS]"
+)
+
+EXPECTED_TEXT_TOOL_IMAGE: str = (
+ r""
+ r'[AVAILABLE_TOOLS][{"type": "function", "function": {"name": "fn",'
+ r' "description": "test", "parameters": {}}}]'
+ r'[/AVAILABLE_TOOLS][MODEL_SETTINGS]{"reasoning_effort": "none"}[/MODEL_SETTINGS]'
+ r"[INST]Use the tool[/INST]"
+ r"[TOOL_CALLS]fn[ARGS]{}"
+ r"[TOOL_RESULTS]result[IMG][IMG_END][/TOOL_RESULTS]"
+)
+
+EXPECTED_TEXT_SYSTEM_AUDIO: str = (
+ r"[SYSTEM_PROMPT]System with content[BEGIN_AUDIO][AUDIO][AUDIO][/SYSTEM_PROMPT]"
+ r'[MODEL_SETTINGS]{"reasoning_effort": "none"}[/MODEL_SETTINGS]'
+ r"[INST]Hello[/INST]"
+)
+
+EXPECTED_TEXT_USER_AUDIO: str = (
+ r""
+ r'[MODEL_SETTINGS]{"reasoning_effort": "none"}[/MODEL_SETTINGS]'
+ r"[INST]Here is content[BEGIN_AUDIO][AUDIO][AUDIO][/INST]"
+)
+
+EXPECTED_TEXT_USER_IMAGE: str = (
+ r""
+ r'[MODEL_SETTINGS]{"reasoning_effort": "none"}[/MODEL_SETTINGS]'
+ r"[INST][IMG][IMG_END]Here is content[/INST]"
+)
+
def _build_v15_tekkenizer(model_settings_builder: ModelSettingsBuilder | None) -> Tekkenizer:
r"""Build a v15 Tekkenizer with the given model settings builder.
@@ -103,6 +141,107 @@ def get_v15_mistral_tokenizer(
)
+def _build_model_settings_builder(
+ allowed_reasoning_effort: tuple[str, ...] | None,
+) -> ModelSettingsBuilder:
+ """Build a ModelSettingsBuilder from allowed reasoning effort values.
+
+ When `allowed_reasoning_effort` is `None`, returns `ModelSettingsBuilder.none()`
+ (all fields ignored). This matches the behavior of `Tekkenizer.from_file` when no
+ `model_settings_builder` key is present in the JSON.
+ """
+ if allowed_reasoning_effort is None:
+ return ModelSettingsBuilder.none()
+ if not allowed_reasoning_effort:
+ return ModelSettingsBuilder(
+ reasoning_effort=EnumBuilder[ReasoningEffort](values=[], accepts_none=True, default=None)
+ )
+ return ModelSettingsBuilder(
+ reasoning_effort=EnumBuilder[ReasoningEffort](
+ values=[ReasoningEffort(v) for v in allowed_reasoning_effort],
+ accepts_none=True,
+ default=ReasoningEffort(allowed_reasoning_effort[0]) if allowed_reasoning_effort else None,
+ )
+ )
+
+
+def _get_dummy_image_url_chunk() -> ImageURLChunk:
+ r"""Build a small base64-encoded image URL chunk for testing."""
+ img = Image.new("RGB", (4, 4), "red")
+ buf = BytesIO()
+ img.save(buf, "PNG")
+ data_url = f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}"
+ return ImageURLChunk(image_url=data_url)
+
+
+def get_v15_mistral_tokenizer_with_audio() -> MistralTokenizer:
+ r"""Build a V15 MistralTokenizer with audio encoder."""
+ builder = _build_model_settings_builder(tuple(ReasoningEffort))
+ tekkenizer = Tekkenizer(
+ quick_vocab([b"a", b"b", b"c", b"f", b"de"]),
+ special_tokens=get_special_tokens(TokenizerVersion.v15, add_audio=True),
+ pattern=r".+",
+ vocab_size=256 + 100,
+ num_special_tokens=100,
+ version=TokenizerVersion.v15,
+ model_settings_builder=builder,
+ )
+ audio_config = AudioConfig(
+ sampling_rate=24_000,
+ frame_rate=12.5,
+ encoding_config=AudioSpectrogramConfig(
+ num_mel_bins=128,
+ hop_length=160,
+ window_size=400,
+ ),
+ )
+ special_audio_ids = SpecialAudioIDs(
+ audio=tekkenizer.get_special_token(SpecialTokens.audio.value),
+ begin_audio=tekkenizer.get_special_token(SpecialTokens.begin_audio.value),
+ streaming_pad=None,
+ text_to_audio=None,
+ audio_to_text=None,
+ )
+ audio_encoder = AudioEncoder(audio_config, special_audio_ids)
+ instruct_tokenizer = InstructTokenizerV15(tekkenizer, audio_encoder=audio_encoder)
+ request_normalizer = get_normalizer(TokenizerVersion.v15, tekkenizer.model_settings_builder)
+ validator = get_validator(TokenizerVersion.v15, mode=ValidationMode.test)
+ return MistralTokenizer(
+ instruct_tokenizer=instruct_tokenizer,
+ validator=validator,
+ request_normalizer=request_normalizer,
+ )
+
+
+def get_v15_mistral_tokenizer_with_image() -> MistralTokenizer:
+ r"""Build a V15 MistralTokenizer with image encoder."""
+ builder = _build_model_settings_builder(tuple(ReasoningEffort))
+ tekkenizer = Tekkenizer(
+ quick_vocab([b"a", b"b", b"c", b"f", b"de"]),
+ special_tokens=get_special_tokens(TokenizerVersion.v15, add_think=True),
+ pattern=r".+",
+ vocab_size=256 + 100,
+ num_special_tokens=100,
+ version=TokenizerVersion.v15,
+ model_settings_builder=builder,
+ )
+ image_config = ImageConfig(image_patch_size=16, max_image_size=1024)
+ special_image_ids = SpecialImageIDs(
+ img=tekkenizer.get_special_token(SpecialTokens.img.value),
+ img_break=tekkenizer.get_special_token(SpecialTokens.img_break.value),
+ img_end=tekkenizer.get_special_token(SpecialTokens.img_end.value),
+ )
+ image_encoder = ImageEncoder(image_config, special_image_ids)
+ instruct_tokenizer = InstructTokenizerV15(tekkenizer, image_encoder=image_encoder)
+ request_normalizer = get_normalizer(TokenizerVersion.v15, tekkenizer.model_settings_builder)
+ validator = get_validator(TokenizerVersion.v15, mode=ValidationMode.test)
+ return MistralTokenizer(
+ instruct_tokenizer=instruct_tokenizer,
+ validator=validator,
+ request_normalizer=request_normalizer,
+ )
+
+
@pytest.fixture(scope="session")
def v15_tekkenizer() -> InstructTokenizerV15:
return get_v15_tekkenizer(_build_model_settings_builder(tuple(ReasoningEffort)))
@@ -153,6 +292,61 @@ def messages() -> list[ChatMessage]:
]
+@pytest.fixture(scope="session")
+def audio_chunk() -> AudioChunk:
+ return get_dummy_audio_chunk()
+
+
+# Multimodal content chunks and their corresponding tokenizer factories for parametrized tests.
+_TOOL_MULTIMODAL_PARAMS = [
+ pytest.param(
+ get_dummy_audio_chunk(), 1, 0, get_v15_mistral_tokenizer_with_audio, EXPECTED_TEXT_TOOL_AUDIO, id="audio"
+ ),
+ pytest.param(
+ get_dummy_audio_url_chunk(),
+ 1,
+ 0,
+ get_v15_mistral_tokenizer_with_audio,
+ EXPECTED_TEXT_TOOL_AUDIO,
+ id="audio_url",
+ ),
+ pytest.param(
+ _get_dummy_image_url_chunk(),
+ 0,
+ 1,
+ get_v15_mistral_tokenizer_with_image,
+ EXPECTED_TEXT_TOOL_IMAGE,
+ id="image_url",
+ ),
+]
+_SYSTEM_MULTIMODAL_PARAMS = [
+ pytest.param(
+ get_dummy_audio_chunk(), 1, 0, get_v15_mistral_tokenizer_with_audio, EXPECTED_TEXT_SYSTEM_AUDIO, id="audio"
+ ),
+]
+_USER_MULTIMODAL_PARAMS = [
+ pytest.param(
+ get_dummy_audio_chunk(), 1, 0, get_v15_mistral_tokenizer_with_audio, EXPECTED_TEXT_USER_AUDIO, id="audio"
+ ),
+ pytest.param(
+ get_dummy_audio_url_chunk(),
+ 1,
+ 0,
+ get_v15_mistral_tokenizer_with_audio,
+ EXPECTED_TEXT_USER_AUDIO,
+ id="audio_url",
+ ),
+ pytest.param(
+ _get_dummy_image_url_chunk(),
+ 0,
+ 1,
+ get_v15_mistral_tokenizer_with_image,
+ EXPECTED_TEXT_USER_IMAGE,
+ id="image_url",
+ ),
+]
+
+
def test_tools_and_reasoning_effort(
v15_tekkenizer: InstructTokenizerV15, available_tools: list[Tool], messages: list[ChatMessage]
) -> None:
@@ -191,30 +385,6 @@ def test_system_think_chunk_raises_v15(v15_tekkenizer: InstructTokenizerV15) ->
v15_tekkenizer.encode_instruct(request)
-def _build_model_settings_builder(
- allowed_reasoning_effort: tuple[str, ...] | None,
-) -> ModelSettingsBuilder:
- """Build a ModelSettingsBuilder from allowed reasoning effort values.
-
- When `allowed_reasoning_effort` is `None`, returns `ModelSettingsBuilder.none()`
- (all fields ignored). This matches the behavior of `Tekkenizer.from_file` when no
- `model_settings_builder` key is present in the JSON.
- """
- if allowed_reasoning_effort is None:
- return ModelSettingsBuilder.none()
- if not allowed_reasoning_effort:
- return ModelSettingsBuilder(
- reasoning_effort=EnumBuilder[ReasoningEffort](values=[], accepts_none=True, default=None)
- )
- return ModelSettingsBuilder(
- reasoning_effort=EnumBuilder[ReasoningEffort](
- values=[ReasoningEffort(v) for v in allowed_reasoning_effort],
- accepts_none=True,
- default=ReasoningEffort(allowed_reasoning_effort[0]) if allowed_reasoning_effort else None,
- )
- )
-
-
@pytest.mark.parametrize(
("reasoning_effort", "allowed_reasoning_effort", "raises", "match"),
[
@@ -310,110 +480,16 @@ def test_encode_chat_completion_continue_final_message() -> None:
assert encoded.tokens[-1] != eos_id
-@pytest.fixture(scope="session")
-def audio_chunk() -> AudioChunk:
- return get_dummy_audio_chunk()
-
-
-def get_v15_mistral_tokenizer_with_audio() -> MistralTokenizer:
- r"""Build a V15 MistralTokenizer with audio encoder."""
- builder = _build_model_settings_builder(tuple(ReasoningEffort))
- tekkenizer = Tekkenizer(
- quick_vocab([b"a", b"b", b"c", b"f", b"de"]),
- special_tokens=get_special_tokens(TokenizerVersion.v15, add_audio=True),
- pattern=r".+",
- vocab_size=256 + 100,
- num_special_tokens=100,
- version=TokenizerVersion.v15,
- model_settings_builder=builder,
- )
- audio_config = AudioConfig(
- sampling_rate=24_000,
- frame_rate=12.5,
- encoding_config=AudioSpectrogramConfig(
- num_mel_bins=128,
- hop_length=160,
- window_size=400,
- ),
- )
- special_audio_ids = SpecialAudioIDs(
- audio=tekkenizer.get_special_token(SpecialTokens.audio.value),
- begin_audio=tekkenizer.get_special_token(SpecialTokens.begin_audio.value),
- streaming_pad=None,
- text_to_audio=None,
- audio_to_text=None,
- )
- audio_encoder = AudioEncoder(audio_config, special_audio_ids)
- instruct_tokenizer = InstructTokenizerV15(tekkenizer, audio_encoder=audio_encoder)
- request_normalizer = get_normalizer(TokenizerVersion.v15, tekkenizer.model_settings_builder)
- validator = get_validator(TokenizerVersion.v15, mode=ValidationMode.test)
- return MistralTokenizer(
- instruct_tokenizer=instruct_tokenizer,
- validator=validator,
- request_normalizer=request_normalizer,
- )
-
-
-def _get_dummy_image_url_chunk() -> ImageURLChunk:
- r"""Build a small base64-encoded image URL chunk for testing."""
- img = Image.new("RGB", (4, 4), "red")
- buf = BytesIO()
- img.save(buf, "PNG")
- data_url = f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}"
- return ImageURLChunk(image_url=data_url)
-
-
-def get_v15_mistral_tokenizer_with_image() -> MistralTokenizer:
- r"""Build a V15 MistralTokenizer with image encoder."""
- builder = _build_model_settings_builder(tuple(ReasoningEffort))
- tekkenizer = Tekkenizer(
- quick_vocab([b"a", b"b", b"c", b"f", b"de"]),
- special_tokens=get_special_tokens(TokenizerVersion.v15, add_think=True),
- pattern=r".+",
- vocab_size=256 + 100,
- num_special_tokens=100,
- version=TokenizerVersion.v15,
- model_settings_builder=builder,
- )
- image_config = ImageConfig(image_patch_size=16, max_image_size=1024)
- special_image_ids = SpecialImageIDs(
- img=tekkenizer.get_special_token(SpecialTokens.img.value),
- img_break=tekkenizer.get_special_token(SpecialTokens.img_break.value),
- img_end=tekkenizer.get_special_token(SpecialTokens.img_end.value),
- )
- image_encoder = ImageEncoder(image_config, special_image_ids)
- instruct_tokenizer = InstructTokenizerV15(tekkenizer, image_encoder=image_encoder)
- request_normalizer = get_normalizer(TokenizerVersion.v15, tekkenizer.model_settings_builder)
- validator = get_validator(TokenizerVersion.v15, mode=ValidationMode.test)
- return MistralTokenizer(
- instruct_tokenizer=instruct_tokenizer,
- validator=validator,
- request_normalizer=request_normalizer,
- )
-
-
-# Multimodal content chunks and their corresponding tokenizer factories for parametrized tests.
-_AUDIO_CHUNK_PARAMS = pytest.param(get_dummy_audio_chunk(), 1, 0, get_v15_mistral_tokenizer_with_audio, id="audio")
-_AUDIO_URL_CHUNK_PARAMS = pytest.param(
- get_dummy_audio_url_chunk(), 1, 0, get_v15_mistral_tokenizer_with_audio, id="audio_url"
-)
-_IMAGE_URL_CHUNK_PARAMS = pytest.param(
- _get_dummy_image_url_chunk(), 0, 1, get_v15_mistral_tokenizer_with_image, id="image_url"
-)
-
-_ALL_MULTIMODAL_PARAMS = [_AUDIO_CHUNK_PARAMS, _AUDIO_URL_CHUNK_PARAMS, _IMAGE_URL_CHUNK_PARAMS]
-_AUDIO_ONLY_PARAMS = [_AUDIO_CHUNK_PARAMS]
-
-
@pytest.mark.parametrize(
- ("content_chunk", "expected_audios", "expected_images", "tokenizer_factory"),
- _ALL_MULTIMODAL_PARAMS,
+ ("content_chunk", "expected_audios", "expected_images", "tokenizer_factory", "expected_text"),
+ _TOOL_MULTIMODAL_PARAMS,
)
def test_encode_chat_completion_with_multimodal_tool(
content_chunk: AudioChunk | AudioURLChunk | ImageURLChunk,
expected_audios: int,
expected_images: int,
tokenizer_factory: Callable[[], MistralTokenizer],
+ expected_text: str,
) -> None:
mistral_tokenizer = tokenizer_factory()
chat_request = ChatCompletionRequest( # type: ignore[type-var]
@@ -428,19 +504,21 @@ def test_encode_chat_completion_with_multimodal_tool(
tools=[Tool(function=Function(name="fn", description="test", parameters={}))],
)
encoded = mistral_tokenizer.encode_chat_completion(chat_request)
+ assert encoded.text == expected_text, encoded.text
assert len(encoded.audios) == expected_audios
assert len(encoded.images) == expected_images
@pytest.mark.parametrize(
- ("content_chunk", "expected_audios", "expected_images", "tokenizer_factory"),
- _AUDIO_ONLY_PARAMS,
+ ("content_chunk", "expected_audios", "expected_images", "tokenizer_factory", "expected_text"),
+ _SYSTEM_MULTIMODAL_PARAMS,
)
def test_encode_chat_completion_with_multimodal_system(
content_chunk: AudioChunk,
expected_audios: int,
expected_images: int,
tokenizer_factory: Callable[[], MistralTokenizer],
+ expected_text: str,
) -> None:
mistral_tokenizer = tokenizer_factory()
chat_request = ChatCompletionRequest( # type: ignore[type-var]
@@ -450,19 +528,21 @@ def test_encode_chat_completion_with_multimodal_system(
],
)
encoded = mistral_tokenizer.encode_chat_completion(chat_request)
+ assert encoded.text == expected_text, encoded.text
assert len(encoded.audios) == expected_audios
assert len(encoded.images) == expected_images
@pytest.mark.parametrize(
- ("content_chunk", "expected_audios", "expected_images", "tokenizer_factory"),
- _ALL_MULTIMODAL_PARAMS,
+ ("content_chunk", "expected_audios", "expected_images", "tokenizer_factory", "expected_text"),
+ _USER_MULTIMODAL_PARAMS,
)
def test_encode_chat_completion_with_multimodal_user(
content_chunk: AudioChunk | AudioURLChunk | ImageURLChunk,
expected_audios: int,
expected_images: int,
tokenizer_factory: Callable[[], MistralTokenizer],
+ expected_text: str,
) -> None:
mistral_tokenizer = tokenizer_factory()
chat_request = ChatCompletionRequest(
@@ -471,5 +551,6 @@ def test_encode_chat_completion_with_multimodal_user(
],
)
encoded = mistral_tokenizer.encode_chat_completion(chat_request)
+ assert encoded.text == expected_text, encoded.text
assert len(encoded.audios) == expected_audios
assert len(encoded.images) == expected_images
From 1035be10f136d9b08ba01a7ae5c4df58f25781ad Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Wed, 10 Jun 2026 00:01:46 +0200
Subject: [PATCH 27/47] Use ToolContentChunk type alias in _parse_tool_content
signature
---
src/mistral_common/tokens/tokenizers/instruct.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/src/mistral_common/tokens/tokenizers/instruct.py b/src/mistral_common/tokens/tokenizers/instruct.py
index a939ffb7..9b044f6c 100644
--- a/src/mistral_common/tokens/tokenizers/instruct.py
+++ b/src/mistral_common/tokens/tokenizers/instruct.py
@@ -20,6 +20,7 @@
ImageURLChunk,
TextChunk,
ThinkChunk,
+ ToolContentChunk,
UserContentChunk,
)
from mistral_common.protocol.instruct.messages import (
@@ -487,7 +488,7 @@ def _parse_json_content(self, content: str) -> Any:
except json.JSONDecodeError:
return content
- def _parse_tool_content(self, content: str | list[ContentChunk]) -> Any:
+ def _parse_tool_content(self, content: str | list[ToolContentChunk]) -> Any:
if isinstance(content, list):
content = "".join(chunk.text for chunk in content if isinstance(chunk, TextChunk))
return self._parse_json_content(content)
From 729050d24ae03693a27eace9d4a7fa24f32a1faa Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Wed, 10 Jun 2026 00:03:03 +0200
Subject: [PATCH 28/47] Use walrus operator for reasoning_effort, fix double
backticks
---
tests/integrations/chat_templates/hf_utils.py | 4 ++--
tests/integrations/chat_templates/test_parity.py | 11 +++++------
2 files changed, 7 insertions(+), 8 deletions(-)
diff --git a/tests/integrations/chat_templates/hf_utils.py b/tests/integrations/chat_templates/hf_utils.py
index 9e8133d2..6d071b3d 100644
--- a/tests/integrations/chat_templates/hf_utils.py
+++ b/tests/integrations/chat_templates/hf_utils.py
@@ -1,7 +1,7 @@
r"""HuggingFace transformers/tokenizers utilities for chat template tests.
-This module contains helpers that depend on ``transformers`` and
-``tokenizers``. General-purpose helpers live in ``helpers.py``.
+This module contains helpers that depend on `transformers` and
+`tokenizers`. General-purpose helpers live in `helpers.py`.
"""
import json
diff --git a/tests/integrations/chat_templates/test_parity.py b/tests/integrations/chat_templates/test_parity.py
index 1bd6feed..56d79c5e 100644
--- a/tests/integrations/chat_templates/test_parity.py
+++ b/tests/integrations/chat_templates/test_parity.py
@@ -19,15 +19,15 @@
def _request_to_render_args(request: ChatCompletionRequest) -> dict[str, Any]:
r"""Convert a ChatCompletionRequest to render_template kwargs.
- Enriches tool messages with the ``name`` field resolved from the preceding
- assistant's ``tool_calls``, which the v2 Jinja template requires but
- ``ToolMessage.to_openai()`` omits.
+ Enriches tool messages with the `name` field resolved from the preceding
+ assistant's `tool_calls`, which the v2 Jinja template requires but
+ `ToolMessage.to_openai()` omits.
"""
openai = request.to_openai()
messages = openai["messages"]
# Build a mapping from tool_call_id -> function name so tool messages
- # can carry the ``name`` field expected by v2 templates.
+ # can carry the `name` field expected by v2 templates.
tool_call_names: dict[str, str] = {}
for msg in messages:
for tc in msg.get("tool_calls", []):
@@ -47,8 +47,7 @@ def _request_to_render_args(request: ChatCompletionRequest) -> dict[str, Any]:
}
if "tools" in openai and openai["tools"]:
kwargs["tools"] = openai["tools"]
- reasoning = openai.get("reasoning_effort")
- if reasoning is not None:
+ if (reasoning := openai.get("reasoning_effort")) is not None:
kwargs["reasoning_effort"] = reasoning
return kwargs
From 69ab7294ed00d37215fae51bc6435090b754bc28 Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Wed, 10 Jun 2026 00:07:02 +0200
Subject: [PATCH 29/47] Refactor normalizer tests into version-specific
classes, add error match strings
---
tests/test_normalization.py | 497 ++++++++++++++++++------------------
1 file changed, 242 insertions(+), 255 deletions(-)
diff --git a/tests/test_normalization.py b/tests/test_normalization.py
index 4d066b86..bd4f5fe8 100644
--- a/tests/test_normalization.py
+++ b/tests/test_normalization.py
@@ -426,6 +426,71 @@ def test_continue_final_message_forwarded(self, normalizer: InstructRequestNorma
result: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request)
assert result.continue_final_message is True
+ def test_pydantic_rejects_image_in_assistant(self) -> None:
+ r"""Pydantic rejects ImageURLChunk in assistant message content."""
+ with pytest.raises(ValidationError, match="union_tag_invalid"):
+ AssistantMessage(
+ content=[TextChunk(text="answer"), ImageURLChunk(image_url="https://example.com/img.png")] # type: ignore[list-item]
+ )
+
+ def test_pydantic_rejects_audio_in_assistant(self) -> None:
+ r"""Pydantic rejects AudioChunk in assistant message content."""
+ with pytest.raises(ValidationError, match="union_tag_invalid"):
+ AssistantMessage(
+ content=[TextChunk(text="answer"), AudioChunk(input_audio=b"fake_audio_data")] # type: ignore[list-item]
+ )
+
+ def test_system_message_accepts_audio_chunk(self) -> None:
+ r"""SystemMessage Pydantic model accepts AudioChunk in content."""
+ msg = SystemMessage(content=[AudioChunk(input_audio="dGVzdA==")])
+ assert isinstance(msg.content, list)
+ assert len(msg.content) == 1
+ assert isinstance(msg.content[0], AudioChunk)
+
+ def test_system_message_rejects_image_chunk(self) -> None:
+ r"""SystemMessage Pydantic model rejects ImageURLChunk in content."""
+ with pytest.raises(ValidationError, match="union_tag_invalid"):
+ SystemMessage(content=[ImageURLChunk(image_url="https://example.com/image.png")]) # type: ignore[list-item]
+
+ def test_rejects_audio_in_system_message(self, normalizer: InstructRequestNormalizer) -> None:
+ r"""Pre-V7 normalizer rejects AudioChunk in system messages."""
+ request = mock_chat_completion(
+ messages=[
+ SystemMessage(content=[TextChunk(text="hello"), AudioChunk(input_audio=b"fake_audio_data")]),
+ UserMessage(content="query"),
+ AssistantMessage(content="answer"),
+ ]
+ )
+ with pytest.raises(AssertionError, match="AudioChunk"):
+ normalizer.from_chat_completion_request(request)
+
+ def test_rejects_think_in_system_message(self, normalizer: InstructRequestNormalizer) -> None:
+ r"""Pre-V7 normalizer rejects ThinkChunk in system messages."""
+ request = mock_chat_completion(
+ messages=[
+ SystemMessage(content=[TextChunk(text="hello"), ThinkChunk(thinking="thinking", closed=True)]),
+ UserMessage(content="query"),
+ AssistantMessage(content="answer"),
+ ]
+ )
+ with pytest.raises(AssertionError, match="ThinkChunk"):
+ normalizer.from_chat_completion_request(request)
+
+ def test_json_normalizes_tool_content(self, normalizer: InstructRequestNormalizer) -> None:
+ r"""Base normalizer (v1-v3) JSON-normalizes tool message content."""
+ messy_json = '{"key" : "value" , "num": 1}'
+ request = mock_chat_completion(
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")]),
+ ToolMessage(content=messy_json, tool_call_id="c1"),
+ ],
+ )
+ parsed: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request)
+ tool_msg = parsed.messages[2]
+ assert isinstance(tool_msg, ToolMessage)
+ assert tool_msg.content == '{"key": "value", "num": 1}'
+
class TestChatCompletionRequestNormalizationV7:
@pytest.fixture(autouse=True)
@@ -634,6 +699,64 @@ def test_complex_assistant_aggregation(self, normalizer_v7: InstructRequestNorma
TextChunk(text="C\n\nD"),
]
+ def test_accepts_text_and_think_chunks(self, normalizer_v7: InstructRequestNormalizerV7) -> None:
+ r"""V7 normalizer accepts TextChunk and ThinkChunk in assistant messages."""
+ request = mock_chat_completion(
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]),
+ ],
+ )
+ parsed: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(request)
+ assistant_msg = parsed.messages[1]
+ assert isinstance(assistant_msg, AssistantMessage)
+ assert isinstance(assistant_msg.content, list)
+ assert len(assistant_msg.content) == 2
+
+ def test_accepts_string_content(self, normalizer_v7: InstructRequestNormalizerV7) -> None:
+ r"""V7 normalizer accepts string content in assistant messages."""
+ request = mock_chat_completion(
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(content="plain text"),
+ ],
+ )
+ parsed: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(request)
+ assistant_msg = parsed.messages[1]
+ assert isinstance(assistant_msg, AssistantMessage)
+ assert assistant_msg.content == "plain text"
+
+ def test_skips_json_normalization_on_tool_content(self, normalizer_v7: InstructRequestNormalizerV7) -> None:
+ r"""V7+ normalizers do not JSON-normalize tool message content."""
+ messy_json = '{"key" : "value" , "num": 1}'
+ request = mock_chat_completion(
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")]),
+ ToolMessage(content=messy_json, tool_call_id="c1"),
+ ],
+ )
+ parsed: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(request)
+ tool_msg = parsed.messages[2]
+ assert isinstance(tool_msg, ToolMessage)
+ assert tool_msg.content == messy_json
+
+ def test_preserves_audio_in_system_message(self, normalizer_v7: InstructRequestNormalizerV7) -> None:
+ r"""V7 normalizer preserves AudioChunk in system messages."""
+ request = mock_chat_completion(
+ messages=[
+ SystemMessage(content=[TextChunk(text="hello"), AudioChunk(input_audio=b"fake_audio_data")]),
+ UserMessage(content="test"),
+ ]
+ )
+ parsed: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(request)
+ system_msg = parsed.messages[0]
+ assert isinstance(system_msg, SystemMessage)
+ assert isinstance(system_msg.content, list)
+ assert len(system_msg.content) == 2
+ assert isinstance(system_msg.content[0], TextChunk)
+ assert isinstance(system_msg.content[1], AudioChunk)
+
class TestFineTuningNormalizer:
@pytest.fixture(autouse=True)
@@ -978,6 +1101,82 @@ def test_continue_final_message_forwarded(self, normalizer_v13: InstructRequestN
result: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(request)
assert result.continue_final_message is True
+ def test_accepts_text_and_think_chunks(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
+ r"""V13 normalizer accepts TextChunk and ThinkChunk in assistant messages."""
+ request = mock_chat_completion(
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]),
+ ],
+ )
+ parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(request)
+ assistant_msg = parsed.messages[1]
+ assert isinstance(assistant_msg, AssistantMessage)
+ assert isinstance(assistant_msg.content, list)
+ assert len(assistant_msg.content) == 2
+
+ def test_accepts_string_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
+ r"""V13 normalizer accepts string content in assistant messages."""
+ request = mock_chat_completion(
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(content="plain text"),
+ ],
+ )
+ parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(request)
+ assistant_msg = parsed.messages[1]
+ assert isinstance(assistant_msg, AssistantMessage)
+ assert assistant_msg.content == "plain text"
+
+ def test_rejects_non_text_tool_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
+ r"""V13 normalizer raises InvalidRequestException for non-text tool content."""
+ request = mock_chat_completion(
+ messages=[
+ UserMessage(content="hi"),
+ AssistantMessage(
+ tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")],
+ ),
+ ToolMessage(
+ content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")],
+ tool_call_id="test12345",
+ ),
+ ]
+ )
+ with pytest.raises(InvalidRequestException, match="Unexpected content chunk types in tool message"):
+ normalizer_v13.from_chat_completion_request(request)
+
+ def test_aggregates_text_tool_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
+ r"""V13 normalizer aggregates TextChunks in tool messages to a string."""
+ request = mock_chat_completion(
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")]),
+ ToolMessage(content=[TextChunk(text="hello"), TextChunk(text="world")], tool_call_id="c1"),
+ ],
+ )
+ parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(request)
+ tool_msg = parsed.messages[2]
+ assert isinstance(tool_msg, ToolMessage)
+ assert isinstance(tool_msg.content, str)
+ assert tool_msg.content == "hello\n\nworld"
+
+ def test_rejects_audio_in_tool_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
+ r"""V13 normalizer raises InvalidRequestException for audio tool content."""
+ request = mock_chat_completion(
+ messages=[
+ UserMessage(content="hi"),
+ AssistantMessage(
+ tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")],
+ ),
+ ToolMessage(
+ content=[AudioChunk(input_audio=b"fake_audio_data")],
+ tool_call_id="test12345",
+ ),
+ ]
+ )
+ with pytest.raises(InvalidRequestException, match="Unexpected content chunk types in tool message"):
+ normalizer_v13.from_chat_completion_request(request)
+
class TestChatCompletionRequestNormalizationV15:
@pytest.fixture(autouse=True)
@@ -1097,125 +1296,37 @@ def test_v15_tool_message_text_chunks_joined_without_separator(
assert isinstance(tool_msg, ToolMessage)
assert tool_msg.content == "XY"
-
-@pytest.mark.parametrize(
- "version,expected_class,model_settings_builder",
- [
- (TokenizerVersion.v1, InstructRequestNormalizer, None),
- (TokenizerVersion.v2, InstructRequestNormalizer, None),
- (TokenizerVersion.v3, InstructRequestNormalizer, None),
- (TokenizerVersion.v7, InstructRequestNormalizerV7, None),
- (TokenizerVersion.v11, InstructRequestNormalizerV7, None),
- (TokenizerVersion.v13, InstructRequestNormalizerV13, None),
- (
- TokenizerVersion.v15,
- InstructRequestNormalizerV15,
- ModelSettingsBuilder(
- reasoning_effort=EnumBuilder[ReasoningEffort](
- values=list(ReasoningEffort), accepts_none=False, default=None
- )
- ),
- ),
- ],
-)
-def test_get_normalizer_version_mapping(
- version: TokenizerVersion, expected_class: type, model_settings_builder: ModelSettingsBuilder
-) -> None:
- normalizer = get_normalizer(version, model_settings_builder)
- assert isinstance(normalizer, expected_class)
- assert normalizer._model_settings_builder == model_settings_builder
-
-
-class TestAssistantContentNarrowing:
- @pytest.fixture(params=[TokenizerVersion.v7, TokenizerVersion.v13, TokenizerVersion.v15])
- def normalizer(self, request: pytest.FixtureRequest) -> InstructRequestNormalizer:
- r"""Normalizer fixture parametrized across V7, V13, V15."""
- version = request.param
- if version == TokenizerVersion.v15:
- return get_normalizer(
- version,
- model_settings_builder=ModelSettingsBuilder(
- reasoning_effort=EnumBuilder[ReasoningEffort](
- values=list(ReasoningEffort), accepts_none=False, default=None
- )
- ),
- )
- return get_normalizer(version)
-
- @staticmethod
- def _make_request(
- normalizer: InstructRequestNormalizer, messages: list[ChatMessage]
- ) -> ChatCompletionRequest[ChatMessage]:
- r"""Build a ChatCompletionRequest, adding reasoning_effort for V15."""
- if isinstance(normalizer, InstructRequestNormalizerV15):
- return ChatCompletionRequest(messages=messages, reasoning_effort=ReasoningEffort.high)
- return mock_chat_completion(messages=messages)
-
- def test_accepts_text_and_think_chunks(self, normalizer: InstructRequestNormalizer) -> None:
- r"""Normalizer accepts TextChunk and ThinkChunk in assistant messages."""
- request = self._make_request(
- normalizer,
+ def test_accepts_text_and_think_chunks(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
+ r"""V15 normalizer accepts TextChunk and ThinkChunk in assistant messages."""
+ request = ChatCompletionRequest[ChatMessage](
messages=[
UserMessage(content="query"),
AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]),
],
+ reasoning_effort=ReasoningEffort.high,
)
- parsed: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request)
+ parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
assistant_msg = parsed.messages[1]
assert isinstance(assistant_msg, AssistantMessage)
assert isinstance(assistant_msg.content, list)
assert len(assistant_msg.content) == 2
- def test_accepts_string_content(self, normalizer: InstructRequestNormalizer) -> None:
- r"""Normalizer accepts string content in assistant messages."""
- request = self._make_request(
- normalizer,
+ def test_accepts_string_content(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
+ r"""V15 normalizer accepts string content in assistant messages."""
+ request = ChatCompletionRequest[ChatMessage](
messages=[
UserMessage(content="query"),
AssistantMessage(content="plain text"),
],
+ reasoning_effort=ReasoningEffort.high,
)
- parsed: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request)
+ parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
assistant_msg = parsed.messages[1]
assert isinstance(assistant_msg, AssistantMessage)
assert assistant_msg.content == "plain text"
- def test_pydantic_rejects_image_in_assistant(self) -> None:
- with pytest.raises(ValidationError):
- AssistantMessage(
- content=[TextChunk(text="answer"), ImageURLChunk(image_url="https://example.com/img.png")] # type: ignore[list-item]
- )
-
- def test_pydantic_rejects_audio_in_assistant(self) -> None:
- with pytest.raises(ValidationError):
- AssistantMessage(
- content=[TextChunk(text="answer"), AudioChunk(input_audio=b"fake_audio_data")] # type: ignore[list-item]
- )
-
-
-class TestToolMessageContentChunk:
- @pytest.fixture()
- def normalizer_v15(self) -> InstructRequestNormalizerV15:
- return InstructRequestNormalizerV15(
- UserMessage,
- AssistantMessage,
- ToolMessage,
- SystemMessage,
- InstructRequest,
- ModelSettingsBuilder(
- reasoning_effort=EnumBuilder[ReasoningEffort](
- values=list(ReasoningEffort), accepts_none=False, default=None
- )
- ),
- )
-
- @pytest.fixture()
- def normalizer_v13(self) -> InstructRequestNormalizerV13:
- return InstructRequestNormalizerV13(
- UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest, None
- )
-
- def test_v15_preserves_non_text_tool_content(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
+ def test_preserves_non_text_tool_content(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
+ r"""V15 normalizer preserves non-text chunks in tool messages."""
image_chunk = ImageURLChunk(image_url="https://example.com/image.png")
request = ChatCompletionRequest( # type: ignore[type-var]
messages=[
@@ -1231,7 +1342,8 @@ def test_v15_preserves_non_text_tool_content(self, normalizer_v15: InstructReque
assert isinstance(tool_msg.content, list)
assert tool_msg.content == [image_chunk]
- def test_v15_sorts_multimodal_tool_messages(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
+ def test_sorts_multimodal_tool_messages(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
+ r"""V15 normalizer sorts multimodal tool messages by tool call order."""
image_chunk_1 = ImageURLChunk(image_url="https://example.com/img1.png")
image_chunk_2 = ImageURLChunk(image_url="https://example.com/img2.png")
request = ChatCompletionRequest( # type: ignore[type-var]
@@ -1260,114 +1372,8 @@ def test_v15_sorts_multimodal_tool_messages(self, normalizer_v15: InstructReques
assert tool_msg_2.tool_call_id == "c2"
assert tool_msg_2.content == [image_chunk_2]
- def test_pre_v15_rejects_non_text_tool_content(self) -> None:
- r"""Pre-V15 normalizer raises InvalidRequestException for non-text tool content."""
- normalizer = get_normalizer(TokenizerVersion.v13)
- request = mock_chat_completion(
- messages=[
- UserMessage(content="hi"),
- AssistantMessage(
- tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")],
- ),
- ToolMessage(
- content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")],
- tool_call_id="test12345",
- ),
- ]
- )
- with pytest.raises(InvalidRequestException, match="Unexpected content chunk types in tool message"):
- normalizer.from_chat_completion_request(request)
-
- def test_pre_v15_aggregates_text_tool_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
- request = mock_chat_completion(
- messages=[
- UserMessage(content="query"),
- AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")]),
- ToolMessage(content=[TextChunk(text="hello"), TextChunk(text="world")], tool_call_id="c1"),
- ],
- )
- parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(request)
- tool_msg = parsed.messages[2]
- assert isinstance(tool_msg, ToolMessage)
- assert isinstance(tool_msg.content, str)
- assert tool_msg.content == "hello\n\nworld"
-
- def test_pre_v15_rejects_audio_in_tool_content(self) -> None:
- r"""Pre-V15 normalizer raises InvalidRequestException for audio tool content."""
- normalizer = get_normalizer(TokenizerVersion.v13)
- request = mock_chat_completion(
- messages=[
- UserMessage(content="hi"),
- AssistantMessage(
- tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")],
- ),
- ToolMessage(
- content=[AudioChunk(input_audio=b"fake_audio_data")],
- tool_call_id="test12345",
- ),
- ]
- )
- with pytest.raises(InvalidRequestException, match="Unexpected content chunk types in tool message"):
- normalizer.from_chat_completion_request(request)
-
- def test_base_normalizer_json_normalizes_tool_content(self) -> None:
- r"""Base normalizer (v1-v3) JSON-normalizes tool message content."""
- normalizer = InstructRequestNormalizer(
- UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest, None
- )
- messy_json = '{"key" : "value" , "num": 1}'
- request = mock_chat_completion(
- messages=[
- UserMessage(content="query"),
- AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")]),
- ToolMessage(content=messy_json, tool_call_id="c1"),
- ],
- )
- parsed: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request)
- tool_msg = parsed.messages[2]
- assert isinstance(tool_msg, ToolMessage)
- assert tool_msg.content == '{"key": "value", "num": 1}'
-
- def test_v7_skips_json_normalization_on_tool_content(self) -> None:
- r"""V7+ normalizers do not JSON-normalize tool message content."""
- normalizer = InstructRequestNormalizerV7(
- UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest, None
- )
- messy_json = '{"key" : "value" , "num": 1}'
- request = mock_chat_completion(
- messages=[
- UserMessage(content="query"),
- AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")]),
- ToolMessage(content=messy_json, tool_call_id="c1"),
- ],
- )
- parsed: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request)
- tool_msg = parsed.messages[2]
- assert isinstance(tool_msg, ToolMessage)
- assert tool_msg.content == messy_json
-
-
-class TestSystemMessageContentChunk:
- def test_system_message_accepts_audio_chunk(self) -> None:
- msg = SystemMessage(content=[AudioChunk(input_audio="dGVzdA==")])
- assert isinstance(msg.content, list)
- assert len(msg.content) == 1
- assert isinstance(msg.content[0], AudioChunk)
-
- def test_system_message_rejects_image_chunk(self) -> None:
- with pytest.raises(ValidationError):
- SystemMessage(content=[ImageURLChunk(image_url="https://example.com/image.png")]) # type: ignore[list-item]
-
- def test_v15_rejects_think_in_system_message(self) -> None:
+ def test_rejects_think_in_system_message(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
r"""V15 normalizer rejects ThinkChunk in system messages."""
- normalizer = get_normalizer(
- TokenizerVersion.v15,
- model_settings_builder=ModelSettingsBuilder(
- reasoning_effort=EnumBuilder[ReasoningEffort](
- values=list(ReasoningEffort), accepts_none=False, default=None
- )
- ),
- )
request = mock_chat_completion(
messages=[
SystemMessage(content=[TextChunk(text="hello"), ThinkChunk(thinking="thinking", closed=True)]),
@@ -1375,18 +1381,18 @@ def test_v15_rejects_think_in_system_message(self) -> None:
]
)
with pytest.raises(InvalidRequestException, match="ThinkChunk"):
- normalizer.from_chat_completion_request(request)
+ normalizer_v15.from_chat_completion_request(request)
- def test_v7_preserves_audio_in_system_message(self) -> None:
- r"""V7 normalizer preserves AudioChunk in system messages."""
- normalizer = InstructRequestNormalizerV7.normalizer()
- request = mock_chat_completion(
+ def test_preserves_audio_in_system_message(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
+ r"""V15 normalizer preserves AudioChunk in system messages."""
+ request = ChatCompletionRequest[ChatMessage](
messages=[
SystemMessage(content=[TextChunk(text="hello"), AudioChunk(input_audio=b"fake_audio_data")]),
UserMessage(content="test"),
- ]
+ ],
+ reasoning_effort=ReasoningEffort.high,
)
- parsed: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request)
+ parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
system_msg = parsed.messages[0]
assert isinstance(system_msg, SystemMessage)
assert isinstance(system_msg.content, list)
@@ -1394,49 +1400,30 @@ def test_v7_preserves_audio_in_system_message(self) -> None:
assert isinstance(system_msg.content[0], TextChunk)
assert isinstance(system_msg.content[1], AudioChunk)
- @pytest.mark.parametrize(
- "chunk",
- [
- pytest.param(AudioChunk(input_audio=b"fake_audio_data"), id="audio"),
- pytest.param(ThinkChunk(thinking="thinking", closed=True), id="think"),
- ],
- )
- def test_pre_v7_rejects_non_text_in_system_message(self, chunk: AudioChunk | ThinkChunk) -> None:
- r"""Pre-V7 normalizer rejects non-text chunks in system messages."""
- normalizer = InstructRequestNormalizer(
- UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest, None
- )
- request = mock_chat_completion(
- messages=[
- SystemMessage(content=[TextChunk(text="hello"), chunk]),
- UserMessage(content="query"),
- AssistantMessage(content="answer"),
- ]
- )
- with pytest.raises(AssertionError):
- normalizer.from_chat_completion_request(request)
- def test_v15_preserves_audio_in_system_message(self) -> None:
- r"""V15 normalizer preserves AudioChunk in system messages."""
- normalizer = get_normalizer(
+@pytest.mark.parametrize(
+ "version,expected_class,model_settings_builder",
+ [
+ (TokenizerVersion.v1, InstructRequestNormalizer, None),
+ (TokenizerVersion.v2, InstructRequestNormalizer, None),
+ (TokenizerVersion.v3, InstructRequestNormalizer, None),
+ (TokenizerVersion.v7, InstructRequestNormalizerV7, None),
+ (TokenizerVersion.v11, InstructRequestNormalizerV7, None),
+ (TokenizerVersion.v13, InstructRequestNormalizerV13, None),
+ (
TokenizerVersion.v15,
- model_settings_builder=ModelSettingsBuilder(
+ InstructRequestNormalizerV15,
+ ModelSettingsBuilder(
reasoning_effort=EnumBuilder[ReasoningEffort](
values=list(ReasoningEffort), accepts_none=False, default=None
)
),
- )
- request = ChatCompletionRequest[ChatMessage](
- messages=[
- SystemMessage(content=[TextChunk(text="hello"), AudioChunk(input_audio=b"fake_audio_data")]),
- UserMessage(content="test"),
- ],
- reasoning_effort=ReasoningEffort.high,
- )
- parsed: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request)
- system_msg = parsed.messages[0]
- assert isinstance(system_msg, SystemMessage)
- assert isinstance(system_msg.content, list)
- assert len(system_msg.content) == 2
- assert isinstance(system_msg.content[0], TextChunk)
- assert isinstance(system_msg.content[1], AudioChunk)
+ ),
+ ],
+)
+def test_get_normalizer_version_mapping(
+ version: TokenizerVersion, expected_class: type, model_settings_builder: ModelSettingsBuilder
+) -> None:
+ normalizer = get_normalizer(version, model_settings_builder)
+ assert isinstance(normalizer, expected_class)
+ assert normalizer._model_settings_builder == model_settings_builder
From 21860dce198598924d78674a96343d83f6dd4b8d Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Wed, 10 Jun 2026 10:10:09 +0200
Subject: [PATCH 30/47] Reject ThinkChunk in assistant messages for pre-v11
normalizers, remove Pydantic tests
---
.../protocol/instruct/normalize.py | 17 ++-
tests/test_normalization.py | 132 ++++++------------
2 files changed, 55 insertions(+), 94 deletions(-)
diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py
index 4863d7bb..8cd5f904 100644
--- a/src/mistral_common/protocol/instruct/normalize.py
+++ b/src/mistral_common/protocol/instruct/normalize.py
@@ -284,20 +284,21 @@ def _normalize_tool_call(self, tool_call: ToolCall) -> ToolCall:
)
def _narrow_assistant_content(self, content: list[ContentChunk] | str) -> str | list[AssistantContentChunk]:
- r"""Validate and narrow content chunks for assistant messages.
+ r"""Validate content chunks for assistant messages.
- Only TextChunk and ThinkChunk are allowed.
+ Pre-v11 only allows text content in assistant messages.
+ V11+ overrides to also accept ThinkChunk.
Args:
content: The aggregated content chunks.
Returns:
- The validated and narrowed content.
+ The validated content.
Raises:
InvalidRequestException: If unsupported chunk types are found.
"""
- if isinstance(content, str) or _is_assistant_content(content):
+ if isinstance(content, str):
return content
raise InvalidRequestException(
f"Unexpected content chunk types in assistant message: {[type(c).__name__ for c in content]}"
@@ -615,6 +616,14 @@ def _inplace_sort_tool_messages(tool_messages: list[ToolMessageType], latest_cal
),
)
+ def _narrow_assistant_content(self, content: list[ContentChunk] | str) -> str | list[AssistantContentChunk]:
+ r"""V11+ accepts TextChunk and ThinkChunk in assistant messages."""
+ if isinstance(content, str) or _is_assistant_content(content):
+ return content
+ raise InvalidRequestException(
+ f"Unexpected content chunk types in assistant message: {[type(c).__name__ for c in content]}"
+ )
+
def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]:
tool_messages: list[ToolMessageType] = super()._aggregate_tool_messages(messages, latest_call_ids)
self._inplace_sort_tool_messages(tool_messages=tool_messages, latest_call_ids=latest_call_ids)
diff --git a/tests/test_normalization.py b/tests/test_normalization.py
index bd4f5fe8..2e25b931 100644
--- a/tests/test_normalization.py
+++ b/tests/test_normalization.py
@@ -1,7 +1,6 @@
import json
import pytest
-from pydantic import ValidationError
from mistral_common.exceptions import InvalidRequestException
from mistral_common.protocol.instruct.chunk import (
@@ -165,41 +164,6 @@ def check_merge(
assert isinstance(x, (UserMessage, AssistantMessage))
assert x.content == expected
- def check_merge_chunks(
- self,
- roles: list[str],
- expected_roles: list[str],
- expected_content: list[list[ContentChunk] | str],
- normalizer: InstructRequestNormalizer,
- ) -> None:
- letter_to_cls: dict[str, ChatMessage] = {
- "s": SystemMessage(content="s"),
- "u": UserMessage(content="u"),
- "a": AssistantMessage(content="a"),
- "a2": AssistantMessage(
- content=[
- ThinkChunk(thinking="t1"),
- ThinkChunk(thinking="t2"),
- TextChunk(text="a1"),
- TextChunk(text="a2"),
- TextChunk(text="a3"),
- ]
- ),
- }
-
- chat_completion_request = mock_chat_completion(
- messages=[letter_to_cls[r] for r in roles],
- )
- parsed_request = normalizer.from_chat_completion_request(chat_completion_request)
- assert len(parsed_request.messages) == len(expected_roles)
- assert [message.role for message in parsed_request.messages] == [
- letter_to_cls[role].role for role in expected_roles
- ]
- assert len(expected_content) == len(parsed_request.messages)
- for x, expected in zip(parsed_request.messages, expected_content):
- assert isinstance(x, (UserMessage, AssistantMessage))
- assert x.content == expected
-
def test_message_aggregation(self, normalizer: InstructRequestNormalizer) -> None:
self.check_merge(["s", "s", "s", "u"], ["u"], ["u"], normalizer)
self.check_merge(["s", "s", "s", "u", "u"], ["u"], ["u\n\nu"], normalizer)
@@ -219,21 +183,6 @@ def test_message_aggregation(self, normalizer: InstructRequestNormalizer) -> Non
normalizer,
)
- self.check_merge_chunks(
- ["u", "a2", "u"],
- ["u", "a", "u"],
- [
- "u",
- [
- ThinkChunk(thinking="t1"),
- ThinkChunk(thinking="t2"),
- TextChunk(text="a1\n\na2\n\na3"),
- ],
- "u",
- ],
- normalizer,
- )
-
def test_tool_chunk_aggregation(self, normalizer: InstructRequestNormalizer) -> None:
messages = [
ToolMessage(content="C", tool_call_id="1"),
@@ -426,32 +375,6 @@ def test_continue_final_message_forwarded(self, normalizer: InstructRequestNorma
result: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request)
assert result.continue_final_message is True
- def test_pydantic_rejects_image_in_assistant(self) -> None:
- r"""Pydantic rejects ImageURLChunk in assistant message content."""
- with pytest.raises(ValidationError, match="union_tag_invalid"):
- AssistantMessage(
- content=[TextChunk(text="answer"), ImageURLChunk(image_url="https://example.com/img.png")] # type: ignore[list-item]
- )
-
- def test_pydantic_rejects_audio_in_assistant(self) -> None:
- r"""Pydantic rejects AudioChunk in assistant message content."""
- with pytest.raises(ValidationError, match="union_tag_invalid"):
- AssistantMessage(
- content=[TextChunk(text="answer"), AudioChunk(input_audio=b"fake_audio_data")] # type: ignore[list-item]
- )
-
- def test_system_message_accepts_audio_chunk(self) -> None:
- r"""SystemMessage Pydantic model accepts AudioChunk in content."""
- msg = SystemMessage(content=[AudioChunk(input_audio="dGVzdA==")])
- assert isinstance(msg.content, list)
- assert len(msg.content) == 1
- assert isinstance(msg.content[0], AudioChunk)
-
- def test_system_message_rejects_image_chunk(self) -> None:
- r"""SystemMessage Pydantic model rejects ImageURLChunk in content."""
- with pytest.raises(ValidationError, match="union_tag_invalid"):
- SystemMessage(content=[ImageURLChunk(image_url="https://example.com/image.png")]) # type: ignore[list-item]
-
def test_rejects_audio_in_system_message(self, normalizer: InstructRequestNormalizer) -> None:
r"""Pre-V7 normalizer rejects AudioChunk in system messages."""
request = mock_chat_completion(
@@ -491,6 +414,17 @@ def test_json_normalizes_tool_content(self, normalizer: InstructRequestNormalize
assert isinstance(tool_msg, ToolMessage)
assert tool_msg.content == '{"key": "value", "num": 1}'
+ def test_rejects_think_in_assistant(self, normalizer: InstructRequestNormalizer) -> None:
+ r"""Pre-v11 normalizer rejects ThinkChunk in assistant messages."""
+ request = mock_chat_completion(
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]),
+ ]
+ )
+ with pytest.raises(InvalidRequestException, match="Unexpected content chunk types in assistant message"):
+ normalizer.from_chat_completion_request(request)
+
class TestChatCompletionRequestNormalizationV7:
@pytest.fixture(autouse=True)
@@ -681,7 +615,6 @@ def test_complex_assistant_aggregation(self, normalizer_v7: InstructRequestNorma
AssistantMessage(content=[TextChunk(text="")]),
AssistantMessage(
content=[
- ThinkChunk(thinking="T"),
TextChunk(text="C"),
TextChunk(text="D"),
]
@@ -693,25 +626,18 @@ def test_complex_assistant_aggregation(self, normalizer_v7: InstructRequestNorma
)
first_message = parsed_request.messages[0]
assert isinstance(first_message, AssistantMessage)
- assert first_message.content == [
- TextChunk(text="A\n\nB"),
- ThinkChunk(thinking="T"),
- TextChunk(text="C\n\nD"),
- ]
+ assert first_message.content == "A\n\nB\n\nC\n\nD"
- def test_accepts_text_and_think_chunks(self, normalizer_v7: InstructRequestNormalizerV7) -> None:
- r"""V7 normalizer accepts TextChunk and ThinkChunk in assistant messages."""
+ def test_rejects_think_in_assistant(self, normalizer_v7: InstructRequestNormalizerV7) -> None:
+ r"""V7 normalizer rejects ThinkChunk in assistant messages (pre-v11)."""
request = mock_chat_completion(
messages=[
UserMessage(content="query"),
AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]),
],
)
- parsed: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(request)
- assistant_msg = parsed.messages[1]
- assert isinstance(assistant_msg, AssistantMessage)
- assert isinstance(assistant_msg.content, list)
- assert len(assistant_msg.content) == 2
+ with pytest.raises(InvalidRequestException, match="Unexpected content chunk types in assistant message"):
+ normalizer_v7.from_chat_completion_request(request)
def test_accepts_string_content(self, normalizer_v7: InstructRequestNormalizerV7) -> None:
r"""V7 normalizer accepts string content in assistant messages."""
@@ -1128,6 +1054,32 @@ def test_accepts_string_content(self, normalizer_v13: InstructRequestNormalizerV
assert isinstance(assistant_msg, AssistantMessage)
assert assistant_msg.content == "plain text"
+ def test_assistant_think_chunk_aggregation(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
+ r"""V13 normalizer preserves ThinkChunks in assistant message aggregation."""
+ chat_completion_request = mock_chat_completion(
+ messages=[
+ AssistantMessage(content="A"),
+ AssistantMessage(content=[TextChunk(text="B")]),
+ AssistantMessage(
+ content=[
+ ThinkChunk(thinking="T"),
+ TextChunk(text="C"),
+ TextChunk(text="D"),
+ ]
+ ),
+ ]
+ )
+ parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(
+ chat_completion_request
+ )
+ first_message = parsed.messages[0]
+ assert isinstance(first_message, AssistantMessage)
+ assert first_message.content == [
+ TextChunk(text="A\n\nB"),
+ ThinkChunk(thinking="T"),
+ TextChunk(text="C\n\nD"),
+ ]
+
def test_rejects_non_text_tool_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
r"""V13 normalizer raises InvalidRequestException for non-text tool content."""
request = mock_chat_completion(
From 53057e257a73dc24779a8ee1e4bdb512e9b15005 Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Wed, 10 Jun 2026 10:31:33 +0200
Subject: [PATCH 31/47] Refactor normalizer tests to use Pydantic model
equality assertions
---
tests/test_normalization.py | 198 ++++++++++++------------------------
1 file changed, 64 insertions(+), 134 deletions(-)
diff --git a/tests/test_normalization.py b/tests/test_normalization.py
index 2e25b931..27a7bdf5 100644
--- a/tests/test_normalization.py
+++ b/tests/test_normalization.py
@@ -100,9 +100,7 @@ def test_system_assistant_user(self, normalizer: InstructRequestNormalizer) -> N
parsed_request = normalizer.from_chat_completion_request(chat_completion_request)
- first_message = parsed_request.messages[0]
- assert isinstance(first_message, UserMessage)
- assert first_message.content == ""
+ assert parsed_request.messages[0] == UserMessage(content="")
assert parsed_request.system_prompt == "S"
def test_assistant_content_with_tool_calls(self, normalizer: InstructRequestNormalizer) -> None:
@@ -132,9 +130,7 @@ def test_assistant_system_user_adds_user(self, normalizer: InstructRequestNormal
assert len(parsed_request.messages) == 3 # 1 user message added, system message removed
- first_message = parsed_request.messages[0]
- assert isinstance(first_message, UserMessage)
- assert first_message.content == ""
+ assert parsed_request.messages[0] == UserMessage(content="")
assert parsed_request.system_prompt == "S"
def check_merge(
@@ -219,9 +215,7 @@ def test_normalize_chunks(self, normalizer: InstructRequestNormalizer) -> None:
parsed_request = normalizer.from_chat_completion_request(chat_completion_request)
- first_message = parsed_request.messages[0]
- assert isinstance(first_message, UserMessage)
- assert first_message.content == "foo\n\nchunk\n\nfoo\n\nchunk"
+ assert parsed_request.messages[0] == UserMessage(content="foo\n\nchunk\n\nfoo\n\nchunk")
def test_many_chunks_in_user_message(self, normalizer: InstructRequestNormalizer) -> None:
chat_completion_request = mock_chat_completion(
@@ -237,9 +231,7 @@ def test_many_chunks_in_user_message(self, normalizer: InstructRequestNormalizer
],
)
parsed_request = normalizer.from_chat_completion_request(chat_completion_request)
- first_message = parsed_request.messages[0]
- assert isinstance(first_message, UserMessage)
- assert first_message.content == "foo\n\nchunk1\n\nchunk2\n\nchunk3"
+ assert parsed_request.messages[0] == UserMessage(content="foo\n\nchunk1\n\nchunk2\n\nchunk3")
def test_ignore_middle_empty_text_chunks(self, normalizer: InstructRequestNormalizer) -> None:
chat_completion_request = mock_chat_completion(
@@ -261,12 +253,8 @@ def test_ignore_middle_empty_text_chunks(self, normalizer: InstructRequestNormal
]
)
parsed_request = normalizer.from_chat_completion_request(chat_completion_request)
- first_message = parsed_request.messages[0]
- assert isinstance(first_message, UserMessage)
- assert first_message.content == "U\n\nV"
- second_message = parsed_request.messages[1]
- assert isinstance(second_message, AssistantMessage)
- assert second_message.content == "A\n\nB"
+ assert parsed_request.messages[0] == UserMessage(content="U\n\nV")
+ assert parsed_request.messages[1] == AssistantMessage(content="A\n\nB")
def test_safety_prompt_aggregated(self, normalizer: InstructRequestNormalizer) -> None:
chat_completion_request = ChatCompletionRequest[ChatMessage](
@@ -280,9 +268,7 @@ def test_safety_prompt_aggregated(self, normalizer: InstructRequestNormalizer) -
)
parsed_request = normalizer.from_chat_completion_request(chat_completion_request)
- first_message = parsed_request.messages[0]
- assert isinstance(first_message, UserMessage)
- assert first_message.content == "user"
+ assert parsed_request.messages[0] == UserMessage(content="user")
assert parsed_request.system_prompt == "system"
def test_normalize_tools(self, normalizer: InstructRequestNormalizer) -> None:
@@ -410,9 +396,7 @@ def test_json_normalizes_tool_content(self, normalizer: InstructRequestNormalize
],
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request)
- tool_msg = parsed.messages[2]
- assert isinstance(tool_msg, ToolMessage)
- assert tool_msg.content == '{"key": "value", "num": 1}'
+ assert parsed.messages[2] == ToolMessage(content='{"key": "value", "num": 1}', tool_call_id="c1")
def test_rejects_think_in_assistant(self, normalizer: InstructRequestNormalizer) -> None:
r"""Pre-v11 normalizer rejects ThinkChunk in assistant messages."""
@@ -444,14 +428,8 @@ def test_system_assistant_user_v7(self, normalizer_v7: InstructRequestNormalizer
parsed_request: InstructRequest = normalizer_v7.from_chat_completion_request(chat_completion_request)
- first_message = parsed_request.messages[0]
- assert isinstance(first_message, SystemMessage)
- assert first_message.content == "S"
-
- second_message = parsed_request.messages[1]
- assert isinstance(second_message, AssistantMessage)
- assert second_message.content == "A"
-
+ assert parsed_request.messages[0] == SystemMessage(content="S")
+ assert parsed_request.messages[1] == AssistantMessage(content="A")
assert parsed_request.system_prompt is None
def test_assistant_assistant_system_v7(self, normalizer_v7: InstructRequestNormalizer) -> None:
@@ -468,13 +446,8 @@ def test_assistant_assistant_system_v7(self, normalizer_v7: InstructRequestNorma
assert len(parsed_request.messages) == 2
- first_message = parsed_request.messages[0]
- assert isinstance(first_message, AssistantMessage)
- assert first_message.content == "A"
-
- second_message = parsed_request.messages[1]
- assert isinstance(second_message, SystemMessage)
- assert second_message.content == "S"
+ assert parsed_request.messages[0] == AssistantMessage(content="A")
+ assert parsed_request.messages[1] == SystemMessage(content="S")
def test_assistant_content_with_tool_calls(self, normalizer_v7: InstructRequestNormalizer) -> None:
chat_completion_request = mock_chat_completion(
@@ -487,10 +460,9 @@ def test_assistant_content_with_tool_calls(self, normalizer_v7: InstructRequestN
)
normalized_chat_req = normalizer_v7.from_chat_completion_request(chat_completion_request)
- assert normalized_chat_req.messages[0].content == "A", normalized_chat_req.messages[0].content
- assert len(normalized_chat_req.messages[0].tool_calls) == 1, normalized_chat_req.messages[0].tool_calls
- assert normalized_chat_req.messages[0].tool_calls[0].function.name == "tool1", (
- normalized_chat_req.messages[0].tool_calls[0].function.name
+ assert normalized_chat_req.messages[0] == AssistantMessage(
+ content="A",
+ tool_calls=[ToolCall(function=FunctionCall(name="tool1", arguments='{"input": "11"}'))],
)
def test_assistant_content_with_more_tool_calls(self, normalizer_v7: InstructRequestNormalizer) -> None:
@@ -563,13 +535,8 @@ def test_only_empty_text_chunks(self, normalizer_v7: InstructRequestNormalizerV7
parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(
chat_completion_request
)
- first_message = parsed_request.messages[0]
- assert isinstance(first_message, UserMessage)
- assert first_message.content == ""
- second_message = parsed_request.messages[1]
- assert isinstance(second_message, AssistantMessage)
- # Empty string content is passed through directly
- assert second_message.content == ""
+ assert parsed_request.messages[0] == UserMessage(content="")
+ assert parsed_request.messages[1] == AssistantMessage(content="")
def test_complex_user_aggregation(self, normalizer_v7: InstructRequestNormalizerV7) -> None:
"""Complex multi-user-message aggregation with mixed str, chunks, empty, and non-text chunks."""
@@ -596,13 +563,13 @@ def test_complex_user_aggregation(self, normalizer_v7: InstructRequestNormalizer
parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(
chat_completion_request
)
- first_message = parsed_request.messages[0]
- assert isinstance(first_message, UserMessage)
- assert first_message.content == [
- TextChunk(text="A\n\nB\n\nC\n\nD"),
- ImageURLChunk(image_url="E"),
- TextChunk(text="G\n\nH"),
- ]
+ assert parsed_request.messages[0] == UserMessage(
+ content=[
+ TextChunk(text="A\n\nB\n\nC\n\nD"),
+ ImageURLChunk(image_url="E"),
+ TextChunk(text="G\n\nH"),
+ ]
+ )
def test_complex_assistant_aggregation(self, normalizer_v7: InstructRequestNormalizerV7) -> None:
"""Complex multi-assistant-message aggregation with mixed str, chunks, and empty content."""
@@ -624,9 +591,7 @@ def test_complex_assistant_aggregation(self, normalizer_v7: InstructRequestNorma
parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(
chat_completion_request
)
- first_message = parsed_request.messages[0]
- assert isinstance(first_message, AssistantMessage)
- assert first_message.content == "A\n\nB\n\nC\n\nD"
+ assert parsed_request.messages[0] == AssistantMessage(content="A\n\nB\n\nC\n\nD")
def test_rejects_think_in_assistant(self, normalizer_v7: InstructRequestNormalizerV7) -> None:
r"""V7 normalizer rejects ThinkChunk in assistant messages (pre-v11)."""
@@ -648,9 +613,7 @@ def test_accepts_string_content(self, normalizer_v7: InstructRequestNormalizerV7
],
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(request)
- assistant_msg = parsed.messages[1]
- assert isinstance(assistant_msg, AssistantMessage)
- assert assistant_msg.content == "plain text"
+ assert parsed.messages[1] == AssistantMessage(content="plain text")
def test_skips_json_normalization_on_tool_content(self, normalizer_v7: InstructRequestNormalizerV7) -> None:
r"""V7+ normalizers do not JSON-normalize tool message content."""
@@ -663,9 +626,7 @@ def test_skips_json_normalization_on_tool_content(self, normalizer_v7: InstructR
],
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(request)
- tool_msg = parsed.messages[2]
- assert isinstance(tool_msg, ToolMessage)
- assert tool_msg.content == messy_json
+ assert parsed.messages[2] == ToolMessage(content=messy_json, tool_call_id="c1")
def test_preserves_audio_in_system_message(self, normalizer_v7: InstructRequestNormalizerV7) -> None:
r"""V7 normalizer preserves AudioChunk in system messages."""
@@ -676,12 +637,12 @@ def test_preserves_audio_in_system_message(self, normalizer_v7: InstructRequestN
]
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(request)
- system_msg = parsed.messages[0]
- assert isinstance(system_msg, SystemMessage)
- assert isinstance(system_msg.content, list)
- assert len(system_msg.content) == 2
- assert isinstance(system_msg.content[0], TextChunk)
- assert isinstance(system_msg.content[1], AudioChunk)
+ assert parsed.messages[0] == SystemMessage(
+ content=[
+ TextChunk(text="hello"),
+ AudioChunk(input_audio=b"fake_audio_data"),
+ ]
+ )
class TestFineTuningNormalizer:
@@ -1036,10 +997,9 @@ def test_accepts_text_and_think_chunks(self, normalizer_v13: InstructRequestNorm
],
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(request)
- assistant_msg = parsed.messages[1]
- assert isinstance(assistant_msg, AssistantMessage)
- assert isinstance(assistant_msg.content, list)
- assert len(assistant_msg.content) == 2
+ assert parsed.messages[1] == AssistantMessage(
+ content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]
+ )
def test_accepts_string_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
r"""V13 normalizer accepts string content in assistant messages."""
@@ -1050,9 +1010,7 @@ def test_accepts_string_content(self, normalizer_v13: InstructRequestNormalizerV
],
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(request)
- assistant_msg = parsed.messages[1]
- assert isinstance(assistant_msg, AssistantMessage)
- assert assistant_msg.content == "plain text"
+ assert parsed.messages[1] == AssistantMessage(content="plain text")
def test_assistant_think_chunk_aggregation(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
r"""V13 normalizer preserves ThinkChunks in assistant message aggregation."""
@@ -1072,13 +1030,13 @@ def test_assistant_think_chunk_aggregation(self, normalizer_v13: InstructRequest
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(
chat_completion_request
)
- first_message = parsed.messages[0]
- assert isinstance(first_message, AssistantMessage)
- assert first_message.content == [
- TextChunk(text="A\n\nB"),
- ThinkChunk(thinking="T"),
- TextChunk(text="C\n\nD"),
- ]
+ assert parsed.messages[0] == AssistantMessage(
+ content=[
+ TextChunk(text="A\n\nB"),
+ ThinkChunk(thinking="T"),
+ TextChunk(text="C\n\nD"),
+ ]
+ )
def test_rejects_non_text_tool_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
r"""V13 normalizer raises InvalidRequestException for non-text tool content."""
@@ -1107,10 +1065,7 @@ def test_aggregates_text_tool_content(self, normalizer_v13: InstructRequestNorma
],
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(request)
- tool_msg = parsed.messages[2]
- assert isinstance(tool_msg, ToolMessage)
- assert isinstance(tool_msg.content, str)
- assert tool_msg.content == "hello\n\nworld"
+ assert parsed.messages[2] == ToolMessage(content="hello\n\nworld", tool_call_id="c1")
def test_rejects_audio_in_tool_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
r"""V13 normalizer raises InvalidRequestException for audio tool content."""
@@ -1179,12 +1134,8 @@ def test_v15_intra_message_chunks_joined_without_separator(
reasoning_effort=ReasoningEffort.high,
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
- user_msg = parsed.messages[0]
- assert isinstance(user_msg, UserMessage)
- assert user_msg.content == "AB"
- assistant_msg = parsed.messages[1]
- assert isinstance(assistant_msg, AssistantMessage)
- assert assistant_msg.content == "CD"
+ assert parsed.messages[0] == UserMessage(content="AB")
+ assert parsed.messages[1] == AssistantMessage(content="CD")
def test_v15_inter_message_join_still_uses_separator(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
r"""V15 still joins text across different messages with '\n\n'."""
@@ -1197,9 +1148,7 @@ def test_v15_inter_message_join_still_uses_separator(self, normalizer_v15: Instr
reasoning_effort=ReasoningEffort.high,
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
- user_msg = parsed.messages[0]
- assert isinstance(user_msg, UserMessage)
- assert user_msg.content == "First\n\nSecond"
+ assert parsed.messages[0] == UserMessage(content="First\n\nSecond")
def test_v15_mixed_intra_and_inter_message(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
r"""V15 combines intra-message ('') and inter-message ('\n\n') joining."""
@@ -1212,9 +1161,7 @@ def test_v15_mixed_intra_and_inter_message(self, normalizer_v15: InstructRequest
reasoning_effort=ReasoningEffort.high,
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
- user_msg = parsed.messages[0]
- assert isinstance(user_msg, UserMessage)
- assert user_msg.content == "AB\n\nCD"
+ assert parsed.messages[0] == UserMessage(content="AB\n\nCD")
def test_v15_mixed_intra_and_inter_assistant_messages(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
r"""V15 combines intra-message ('') and inter-message ('\n\n') joining for assistant messages."""
@@ -1227,9 +1174,7 @@ def test_v15_mixed_intra_and_inter_assistant_messages(self, normalizer_v15: Inst
reasoning_effort=ReasoningEffort.high,
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
- assistant_msg = parsed.messages[1]
- assert isinstance(assistant_msg, AssistantMessage)
- assert assistant_msg.content == "AB\n\nCD"
+ assert parsed.messages[1] == AssistantMessage(content="AB\n\nCD")
def test_v15_tool_message_text_chunks_joined_without_separator(
self, normalizer_v15: InstructRequestNormalizerV15
@@ -1244,9 +1189,7 @@ def test_v15_tool_message_text_chunks_joined_without_separator(
reasoning_effort=ReasoningEffort.high,
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
- tool_msg = parsed.messages[2]
- assert isinstance(tool_msg, ToolMessage)
- assert tool_msg.content == "XY"
+ assert parsed.messages[2] == ToolMessage(content="XY", tool_call_id="c1")
def test_accepts_text_and_think_chunks(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
r"""V15 normalizer accepts TextChunk and ThinkChunk in assistant messages."""
@@ -1258,10 +1201,9 @@ def test_accepts_text_and_think_chunks(self, normalizer_v15: InstructRequestNorm
reasoning_effort=ReasoningEffort.high,
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
- assistant_msg = parsed.messages[1]
- assert isinstance(assistant_msg, AssistantMessage)
- assert isinstance(assistant_msg.content, list)
- assert len(assistant_msg.content) == 2
+ assert parsed.messages[1] == AssistantMessage(
+ content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]
+ )
def test_accepts_string_content(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
r"""V15 normalizer accepts string content in assistant messages."""
@@ -1273,9 +1215,7 @@ def test_accepts_string_content(self, normalizer_v15: InstructRequestNormalizerV
reasoning_effort=ReasoningEffort.high,
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
- assistant_msg = parsed.messages[1]
- assert isinstance(assistant_msg, AssistantMessage)
- assert assistant_msg.content == "plain text"
+ assert parsed.messages[1] == AssistantMessage(content="plain text")
def test_preserves_non_text_tool_content(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
r"""V15 normalizer preserves non-text chunks in tool messages."""
@@ -1289,10 +1229,7 @@ def test_preserves_non_text_tool_content(self, normalizer_v15: InstructRequestNo
reasoning_effort=ReasoningEffort.high,
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
- tool_msg = parsed.messages[2]
- assert isinstance(tool_msg, ToolMessage)
- assert isinstance(tool_msg.content, list)
- assert tool_msg.content == [image_chunk]
+ assert parsed.messages[2] == ToolMessage(content=[image_chunk], tool_call_id="c1")
def test_sorts_multimodal_tool_messages(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
r"""V15 normalizer sorts multimodal tool messages by tool call order."""
@@ -1314,15 +1251,8 @@ def test_sorts_multimodal_tool_messages(self, normalizer_v15: InstructRequestNor
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
- tool_msg_1 = parsed.messages[2]
- assert isinstance(tool_msg_1, ToolMessage)
- assert tool_msg_1.tool_call_id == "c1"
- assert tool_msg_1.content == [image_chunk_1]
-
- tool_msg_2 = parsed.messages[3]
- assert isinstance(tool_msg_2, ToolMessage)
- assert tool_msg_2.tool_call_id == "c2"
- assert tool_msg_2.content == [image_chunk_2]
+ assert parsed.messages[2] == ToolMessage(content=[image_chunk_1], tool_call_id="c1")
+ assert parsed.messages[3] == ToolMessage(content=[image_chunk_2], tool_call_id="c2")
def test_rejects_think_in_system_message(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
r"""V15 normalizer rejects ThinkChunk in system messages."""
@@ -1345,12 +1275,12 @@ def test_preserves_audio_in_system_message(self, normalizer_v15: InstructRequest
reasoning_effort=ReasoningEffort.high,
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
- system_msg = parsed.messages[0]
- assert isinstance(system_msg, SystemMessage)
- assert isinstance(system_msg.content, list)
- assert len(system_msg.content) == 2
- assert isinstance(system_msg.content[0], TextChunk)
- assert isinstance(system_msg.content[1], AudioChunk)
+ assert parsed.messages[0] == SystemMessage(
+ content=[
+ TextChunk(text="hello"),
+ AudioChunk(input_audio=b"fake_audio_data"),
+ ]
+ )
@pytest.mark.parametrize(
From 66af65b3be0f7e23082a84f11c6d7194ba807c0a Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Wed, 10 Jun 2026 11:21:43 +0200
Subject: [PATCH 32/47] Add back intra-message ThinkChunk aggregation test to
V13 class
---
tests/test_normalization.py | 36 ++++++++++++++++++++++++++++++++++--
1 file changed, 34 insertions(+), 2 deletions(-)
diff --git a/tests/test_normalization.py b/tests/test_normalization.py
index 27a7bdf5..fe7041b7 100644
--- a/tests/test_normalization.py
+++ b/tests/test_normalization.py
@@ -1012,8 +1012,10 @@ def test_accepts_string_content(self, normalizer_v13: InstructRequestNormalizerV
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(request)
assert parsed.messages[1] == AssistantMessage(content="plain text")
- def test_assistant_think_chunk_aggregation(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
- r"""V13 normalizer preserves ThinkChunks in assistant message aggregation."""
+ def test_assistant_think_chunk_inter_message_aggregation(
+ self, normalizer_v13: InstructRequestNormalizerV13
+ ) -> None:
+ r"""V13 normalizer preserves ThinkChunks across multiple assistant messages."""
chat_completion_request = mock_chat_completion(
messages=[
AssistantMessage(content="A"),
@@ -1038,6 +1040,36 @@ def test_assistant_think_chunk_aggregation(self, normalizer_v13: InstructRequest
]
)
+ def test_assistant_think_chunk_intra_message_aggregation(
+ self, normalizer_v13: InstructRequestNormalizerV13
+ ) -> None:
+ r"""V13 normalizer coalesces TextChunks and preserves multiple ThinkChunks within a single message."""
+ chat_completion_request = mock_chat_completion(
+ messages=[
+ UserMessage(content="u"),
+ AssistantMessage(
+ content=[
+ ThinkChunk(thinking="t1"),
+ ThinkChunk(thinking="t2"),
+ TextChunk(text="a1"),
+ TextChunk(text="a2"),
+ TextChunk(text="a3"),
+ ]
+ ),
+ UserMessage(content="u"),
+ ]
+ )
+ parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(
+ chat_completion_request
+ )
+ assert parsed.messages[1] == AssistantMessage(
+ content=[
+ ThinkChunk(thinking="t1"),
+ ThinkChunk(thinking="t2"),
+ TextChunk(text="a1\n\na2\n\na3"),
+ ]
+ )
+
def test_rejects_non_text_tool_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
r"""V13 normalizer raises InvalidRequestException for non-text tool content."""
request = mock_chat_completion(
From 864d127ea2828d58d5bd099f782684232309a359 Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Wed, 10 Jun 2026 11:35:22 +0200
Subject: [PATCH 33/47] Assert full InstructRequest output in normalizer tests
---
tests/test_normalization.py | 680 +++++++++++++++++++++++-------------
1 file changed, 432 insertions(+), 248 deletions(-)
diff --git a/tests/test_normalization.py b/tests/test_normalization.py
index fe7041b7..26f7d8be 100644
--- a/tests/test_normalization.py
+++ b/tests/test_normalization.py
@@ -6,7 +6,6 @@
from mistral_common.protocol.instruct.chunk import (
AudioChunk,
ChunkTypes,
- ContentChunk,
ImageURLChunk,
TextChunk,
ThinkChunk,
@@ -64,7 +63,10 @@ def test_user_system_user(self, normalizer: InstructRequestNormalizer) -> None:
)
parsed_request = normalizer.from_chat_completion_request(chat_completion_request)
- assert parsed_request.system_prompt == "S"
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="U"), UserMessage(content="U")],
+ system_prompt="S",
+ )
def test_multiple_system(self, normalizer: InstructRequestNormalizer) -> None:
chat_completion_request = mock_chat_completion(
@@ -76,7 +78,10 @@ def test_multiple_system(self, normalizer: InstructRequestNormalizer) -> None:
)
parsed_request = normalizer.from_chat_completion_request(chat_completion_request)
- assert parsed_request.system_prompt == "S\n\nS\n\nS"
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="")],
+ system_prompt="S\n\nS\n\nS",
+ )
def test_single_system(self, normalizer: InstructRequestNormalizer) -> None:
chat_completion_request = mock_chat_completion(
@@ -86,8 +91,10 @@ def test_single_system(self, normalizer: InstructRequestNormalizer) -> None:
)
parsed_request = normalizer.from_chat_completion_request(chat_completion_request)
-
- assert parsed_request.system_prompt == "S"
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="")],
+ system_prompt="S",
+ )
def test_system_assistant_user(self, normalizer: InstructRequestNormalizer) -> None:
chat_completion_request = mock_chat_completion(
@@ -99,9 +106,10 @@ def test_system_assistant_user(self, normalizer: InstructRequestNormalizer) -> N
)
parsed_request = normalizer.from_chat_completion_request(chat_completion_request)
-
- assert parsed_request.messages[0] == UserMessage(content="")
- assert parsed_request.system_prompt == "S"
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content=""), AssistantMessage(content="A"), UserMessage(content="U")],
+ system_prompt="S",
+ )
def test_assistant_content_with_tool_calls(self, normalizer: InstructRequestNormalizer) -> None:
chat_completion_request = mock_chat_completion(
@@ -125,58 +133,89 @@ def test_assistant_system_user_adds_user(self, normalizer: InstructRequestNormal
)
parsed_request = normalizer.from_chat_completion_request(chat_completion_request)
-
- assert parsed_request.system_prompt == "S"
-
- assert len(parsed_request.messages) == 3 # 1 user message added, system message removed
-
- assert parsed_request.messages[0] == UserMessage(content="")
- assert parsed_request.system_prompt == "S"
-
- def check_merge(
- self,
- roles: list[str],
- expected_roles: list[str],
- expected_content: list[list[ContentChunk] | str],
- normalizer: InstructRequestNormalizer,
- ) -> None:
- letter_to_cls: dict[str, ChatMessage] = {
- "s": SystemMessage(content="s"),
- "u": UserMessage(content="u"),
- "a": AssistantMessage(content="a"),
- "u2": UserMessage(content="u2"),
- }
-
- chat_completion_request = mock_chat_completion(
- messages=[letter_to_cls[r] for r in roles],
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content=""), AssistantMessage(content="A"), UserMessage(content="U")],
+ system_prompt="S",
+ )
+
+ def test_message_aggregation_system_then_user(self, normalizer: InstructRequestNormalizer) -> None:
+ parsed = normalizer.from_chat_completion_request(
+ mock_chat_completion(
+ [
+ SystemMessage(content="s"),
+ SystemMessage(content="s"),
+ SystemMessage(content="s"),
+ UserMessage(content="u"),
+ ]
+ )
+ )
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="u")], system_prompt="s\n\ns\n\ns"
)
- parsed_request = normalizer.from_chat_completion_request(chat_completion_request)
- assert len(parsed_request.messages) == len(expected_roles)
- assert [message.role for message in parsed_request.messages] == [
- letter_to_cls[role].role for role in expected_roles
- ]
- assert len(expected_content) == len(parsed_request.messages)
- for x, expected in zip(parsed_request.messages, expected_content):
- assert isinstance(x, (UserMessage, AssistantMessage))
- assert x.content == expected
-
- def test_message_aggregation(self, normalizer: InstructRequestNormalizer) -> None:
- self.check_merge(["s", "s", "s", "u"], ["u"], ["u"], normalizer)
- self.check_merge(["s", "s", "s", "u", "u"], ["u"], ["u\n\nu"], normalizer)
- self.check_merge(["s", "s", "s", "u", "u", "s", "a", "u"], ["u", "a", "u"], ["u\n\nu", "a", "u"], normalizer)
- self.check_merge(
- ["s", "s", "s", "u", "u", "a", "a", "u"],
- ["u", "a", "u"],
- ["u\n\nu", "a\n\na", "u"],
- normalizer,
+ def test_message_aggregation_system_then_users(self, normalizer: InstructRequestNormalizer) -> None:
+ parsed = normalizer.from_chat_completion_request(
+ mock_chat_completion(
+ [
+ SystemMessage(content="s"),
+ SystemMessage(content="s"),
+ SystemMessage(content="s"),
+ UserMessage(content="u"),
+ UserMessage(content="u"),
+ ]
+ )
+ )
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="u\n\nu")], system_prompt="s\n\ns\n\ns"
+ )
+
+ def test_message_aggregation_mixed_with_middle_system(self, normalizer: InstructRequestNormalizer) -> None:
+ parsed = normalizer.from_chat_completion_request(
+ mock_chat_completion(
+ [
+ SystemMessage(content="s"),
+ SystemMessage(content="s"),
+ SystemMessage(content="s"),
+ UserMessage(content="u"),
+ UserMessage(content="u"),
+ SystemMessage(content="s"),
+ AssistantMessage(content="a"),
+ UserMessage(content="u"),
+ ]
+ )
+ )
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="u\n\nu"), AssistantMessage(content="a"), UserMessage(content="u")],
+ system_prompt="s\n\ns\n\ns\n\ns",
+ )
+
+ def test_message_aggregation_consecutive_assistants(self, normalizer: InstructRequestNormalizer) -> None:
+ parsed = normalizer.from_chat_completion_request(
+ mock_chat_completion(
+ [
+ SystemMessage(content="s"),
+ SystemMessage(content="s"),
+ SystemMessage(content="s"),
+ UserMessage(content="u"),
+ UserMessage(content="u"),
+ AssistantMessage(content="a"),
+ AssistantMessage(content="a"),
+ UserMessage(content="u"),
+ ]
+ )
+ )
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="u\n\nu"), AssistantMessage(content="a\n\na"), UserMessage(content="u")],
+ system_prompt="s\n\ns\n\ns",
)
- self.check_merge(
- ["s", "a", "u"],
- ["u", "a", "u"],
- ["", "a", "u"],
- normalizer,
+ def test_message_aggregation_system_assistant_user(self, normalizer: InstructRequestNormalizer) -> None:
+ parsed = normalizer.from_chat_completion_request(
+ mock_chat_completion([SystemMessage(content="s"), AssistantMessage(content="a"), UserMessage(content="u")])
+ )
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content=""), AssistantMessage(content="a"), UserMessage(content="u")],
+ system_prompt="s",
)
def test_tool_chunk_aggregation(self, normalizer: InstructRequestNormalizer) -> None:
@@ -214,8 +253,9 @@ def test_normalize_chunks(self, normalizer: InstructRequestNormalizer) -> None:
)
parsed_request = normalizer.from_chat_completion_request(chat_completion_request)
-
- assert parsed_request.messages[0] == UserMessage(content="foo\n\nchunk\n\nfoo\n\nchunk")
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="foo\n\nchunk\n\nfoo\n\nchunk")]
+ )
def test_many_chunks_in_user_message(self, normalizer: InstructRequestNormalizer) -> None:
chat_completion_request = mock_chat_completion(
@@ -231,7 +271,9 @@ def test_many_chunks_in_user_message(self, normalizer: InstructRequestNormalizer
],
)
parsed_request = normalizer.from_chat_completion_request(chat_completion_request)
- assert parsed_request.messages[0] == UserMessage(content="foo\n\nchunk1\n\nchunk2\n\nchunk3")
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="foo\n\nchunk1\n\nchunk2\n\nchunk3")]
+ )
def test_ignore_middle_empty_text_chunks(self, normalizer: InstructRequestNormalizer) -> None:
chat_completion_request = mock_chat_completion(
@@ -253,8 +295,9 @@ def test_ignore_middle_empty_text_chunks(self, normalizer: InstructRequestNormal
]
)
parsed_request = normalizer.from_chat_completion_request(chat_completion_request)
- assert parsed_request.messages[0] == UserMessage(content="U\n\nV")
- assert parsed_request.messages[1] == AssistantMessage(content="A\n\nB")
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="U\n\nV"), AssistantMessage(content="A\n\nB")],
+ )
def test_safety_prompt_aggregated(self, normalizer: InstructRequestNormalizer) -> None:
chat_completion_request = ChatCompletionRequest[ChatMessage](
@@ -268,8 +311,10 @@ def test_safety_prompt_aggregated(self, normalizer: InstructRequestNormalizer) -
)
parsed_request = normalizer.from_chat_completion_request(chat_completion_request)
- assert parsed_request.messages[0] == UserMessage(content="user")
- assert parsed_request.system_prompt == "system"
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="user")],
+ system_prompt="system",
+ )
def test_normalize_tools(self, normalizer: InstructRequestNormalizer) -> None:
"""
@@ -348,18 +393,19 @@ def test_assert_parsed_settings(
normalizer: InstructRequestNormalizer,
) -> None:
chat_completion_request = ChatCompletionRequest(messages=[UserMessage(content="B")])
- parsed_request: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(
- chat_completion_request
- )
- assert parsed_request.settings == ModelSettings.none()
+ parsed_request = normalizer.from_chat_completion_request(chat_completion_request)
+ assert parsed_request == InstructRequest[ChatMessage, Tool](messages=[UserMessage(content="B")])
def test_continue_final_message_forwarded(self, normalizer: InstructRequestNormalizer) -> None:
request = ChatCompletionRequest[ChatMessage](
messages=[UserMessage(content="a"), AssistantMessage(content="b")],
continue_final_message=True,
)
- result: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request)
- assert result.continue_final_message is True
+ result = normalizer.from_chat_completion_request(request)
+ assert result == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="a"), AssistantMessage(content="b")],
+ continue_final_message=True,
+ )
def test_rejects_audio_in_system_message(self, normalizer: InstructRequestNormalizer) -> None:
r"""Pre-V7 normalizer rejects AudioChunk in system messages."""
@@ -395,8 +441,17 @@ def test_json_normalizes_tool_content(self, normalizer: InstructRequestNormalize
ToolMessage(content=messy_json, tool_call_id="c1"),
],
)
- parsed: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request)
- assert parsed.messages[2] == ToolMessage(content='{"key": "value", "num": 1}', tool_call_id="c1")
+ parsed = normalizer.from_chat_completion_request(request)
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(
+ content="",
+ tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")],
+ ),
+ ToolMessage(content='{"key": "value", "num": 1}', tool_call_id="c1"),
+ ],
+ )
def test_rejects_think_in_assistant(self, normalizer: InstructRequestNormalizer) -> None:
r"""Pre-v11 normalizer rejects ThinkChunk in assistant messages."""
@@ -426,13 +481,14 @@ def test_system_assistant_user_v7(self, normalizer_v7: InstructRequestNormalizer
]
)
- parsed_request: InstructRequest = normalizer_v7.from_chat_completion_request(chat_completion_request)
-
- assert parsed_request.messages[0] == SystemMessage(content="S")
- assert parsed_request.messages[1] == AssistantMessage(content="A")
- assert parsed_request.system_prompt is None
+ parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(
+ chat_completion_request
+ )
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[SystemMessage(content="S"), AssistantMessage(content="A"), UserMessage(content="U")],
+ )
- def test_assistant_assistant_system_v7(self, normalizer_v7: InstructRequestNormalizer) -> None:
+ def test_assistant_assistant_system_v7(self, normalizer_v7: InstructRequestNormalizerV7) -> None:
chat_completion_request = mock_chat_completion(
messages=[
AssistantMessage(content="A"),
@@ -440,16 +496,14 @@ def test_assistant_assistant_system_v7(self, normalizer_v7: InstructRequestNorma
]
)
- parsed_request = normalizer_v7.from_chat_completion_request(chat_completion_request)
-
- assert parsed_request.system_prompt is None
-
- assert len(parsed_request.messages) == 2
-
- assert parsed_request.messages[0] == AssistantMessage(content="A")
- assert parsed_request.messages[1] == SystemMessage(content="S")
+ parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(
+ chat_completion_request
+ )
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[AssistantMessage(content="A"), SystemMessage(content="S")],
+ )
- def test_assistant_content_with_tool_calls(self, normalizer_v7: InstructRequestNormalizer) -> None:
+ def test_assistant_content_with_tool_calls(self, normalizer_v7: InstructRequestNormalizerV7) -> None:
chat_completion_request = mock_chat_completion(
messages=[
AssistantMessage(
@@ -458,27 +512,28 @@ def test_assistant_content_with_tool_calls(self, normalizer_v7: InstructRequestN
)
]
)
- normalized_chat_req = normalizer_v7.from_chat_completion_request(chat_completion_request)
-
- assert normalized_chat_req.messages[0] == AssistantMessage(
- content="A",
- tool_calls=[ToolCall(function=FunctionCall(name="tool1", arguments='{"input": "11"}'))],
+ normalized_chat_req: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(
+ chat_completion_request
+ )
+ assert normalized_chat_req == InstructRequest[ChatMessage, Tool](
+ messages=[
+ AssistantMessage(
+ content="A",
+ tool_calls=[ToolCall(function=FunctionCall(name="tool1", arguments='{"input": "11"}'))],
+ ),
+ ],
)
- def test_assistant_content_with_more_tool_calls(self, normalizer_v7: InstructRequestNormalizer) -> None:
+ def test_assistant_content_with_more_tool_calls(self, normalizer_v7: InstructRequestNormalizerV7) -> None:
chat_completion_request = mock_chat_completion(
messages=[
UserMessage(content="A1"),
- AssistantMessage(
- content="B1",
- ),
+ AssistantMessage(content="B1"),
AssistantMessage(
content="B2",
tool_calls=[ToolCall(function=FunctionCall(name="tool1", arguments='{"input": "1"}'))],
),
- AssistantMessage(
- content="B3",
- ),
+ AssistantMessage(content="B3"),
AssistantMessage(
content="B4",
tool_calls=[
@@ -486,24 +541,27 @@ def test_assistant_content_with_more_tool_calls(self, normalizer_v7: InstructReq
ToolCall(function=FunctionCall(name="tool22", arguments='{"input": "22"}')),
],
),
+ AssistantMessage(content="B5"),
+ UserMessage(content="C1"),
+ ]
+ )
+ normalized_chat_req: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(
+ chat_completion_request
+ )
+ assert normalized_chat_req == InstructRequest[ChatMessage, Tool](
+ messages=[
+ UserMessage(content="A1"),
AssistantMessage(
- content="B5",
+ content="B1\n\nB2\n\nB3\n\nB4\n\nB5",
+ tool_calls=[
+ ToolCall(function=FunctionCall(name="tool1", arguments='{"input": "1"}')),
+ ToolCall(function=FunctionCall(name="tool21", arguments='{"input": "21"}')),
+ ToolCall(function=FunctionCall(name="tool22", arguments='{"input": "22"}')),
+ ],
),
UserMessage(content="C1"),
- ]
+ ],
)
- normalized_chat_req = normalizer_v7.from_chat_completion_request(chat_completion_request)
-
- assert normalized_chat_req.messages[0].content == "A1"
- assert normalized_chat_req.messages[1].content.split("\n\n") == [f"B{i}" for i in range(1, 6)]
-
- tool_calls = normalized_chat_req.messages[1].tool_calls
-
- assert len(tool_calls) == 3
-
- tool_key = ["1", "21", "22"]
- assert all([t.function.name == f"tool{tool_key[i]}" for i, t in enumerate(tool_calls)])
- assert all([json.loads(t.function.arguments)["input"] == tool_key[i] for i, t in enumerate(tool_calls)])
def test_assert_parsed_settings(
self,
@@ -513,7 +571,7 @@ def test_assert_parsed_settings(
parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(
chat_completion_request
)
- assert parsed_request.settings == ModelSettings.none()
+ assert parsed_request == InstructRequest[ChatMessage, Tool](messages=[UserMessage(content="B")])
def test_continue_final_message_forwarded(self, normalizer_v7: InstructRequestNormalizerV7) -> None:
request = ChatCompletionRequest[ChatMessage](
@@ -521,7 +579,10 @@ def test_continue_final_message_forwarded(self, normalizer_v7: InstructRequestNo
continue_final_message=True,
)
result: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(request)
- assert result.continue_final_message is True
+ assert result == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="a"), AssistantMessage(content="b")],
+ continue_final_message=True,
+ )
@pytest.mark.parametrize("num_empty", [0, 1, 2])
def test_only_empty_text_chunks(self, normalizer_v7: InstructRequestNormalizerV7, num_empty: int) -> None:
@@ -535,8 +596,9 @@ def test_only_empty_text_chunks(self, normalizer_v7: InstructRequestNormalizerV7
parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(
chat_completion_request
)
- assert parsed_request.messages[0] == UserMessage(content="")
- assert parsed_request.messages[1] == AssistantMessage(content="")
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content=""), AssistantMessage(content="")],
+ )
def test_complex_user_aggregation(self, normalizer_v7: InstructRequestNormalizerV7) -> None:
"""Complex multi-user-message aggregation with mixed str, chunks, empty, and non-text chunks."""
@@ -563,12 +625,16 @@ def test_complex_user_aggregation(self, normalizer_v7: InstructRequestNormalizer
parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(
chat_completion_request
)
- assert parsed_request.messages[0] == UserMessage(
- content=[
- TextChunk(text="A\n\nB\n\nC\n\nD"),
- ImageURLChunk(image_url="E"),
- TextChunk(text="G\n\nH"),
- ]
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[
+ UserMessage(
+ content=[
+ TextChunk(text="A\n\nB\n\nC\n\nD"),
+ ImageURLChunk(image_url="E"),
+ TextChunk(text="G\n\nH"),
+ ]
+ ),
+ ],
)
def test_complex_assistant_aggregation(self, normalizer_v7: InstructRequestNormalizerV7) -> None:
@@ -591,7 +657,9 @@ def test_complex_assistant_aggregation(self, normalizer_v7: InstructRequestNorma
parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(
chat_completion_request
)
- assert parsed_request.messages[0] == AssistantMessage(content="A\n\nB\n\nC\n\nD")
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[AssistantMessage(content="A\n\nB\n\nC\n\nD")],
+ )
def test_rejects_think_in_assistant(self, normalizer_v7: InstructRequestNormalizerV7) -> None:
r"""V7 normalizer rejects ThinkChunk in assistant messages (pre-v11)."""
@@ -613,7 +681,9 @@ def test_accepts_string_content(self, normalizer_v7: InstructRequestNormalizerV7
],
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(request)
- assert parsed.messages[1] == AssistantMessage(content="plain text")
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="query"), AssistantMessage(content="plain text")],
+ )
def test_skips_json_normalization_on_tool_content(self, normalizer_v7: InstructRequestNormalizerV7) -> None:
r"""V7+ normalizers do not JSON-normalize tool message content."""
@@ -626,7 +696,16 @@ def test_skips_json_normalization_on_tool_content(self, normalizer_v7: InstructR
],
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(request)
- assert parsed.messages[2] == ToolMessage(content=messy_json, tool_call_id="c1")
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(
+ content="",
+ tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")],
+ ),
+ ToolMessage(content=messy_json, tool_call_id="c1"),
+ ],
+ )
def test_preserves_audio_in_system_message(self, normalizer_v7: InstructRequestNormalizerV7) -> None:
r"""V7 normalizer preserves AudioChunk in system messages."""
@@ -637,11 +716,13 @@ def test_preserves_audio_in_system_message(self, normalizer_v7: InstructRequestN
]
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(request)
- assert parsed.messages[0] == SystemMessage(
- content=[
- TextChunk(text="hello"),
- AudioChunk(input_audio=b"fake_audio_data"),
- ]
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[
+ SystemMessage(
+ content=[TextChunk(text="hello"), AudioChunk(input_audio=b"fake_audio_data")],
+ ),
+ UserMessage(content="test"),
+ ],
)
@@ -733,18 +814,20 @@ def test_no_reorder_tool_messages(self, normalizer_v13: InstructRequestNormalize
parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(
chat_completion_request
)
- assert parsed_request.messages == [
- UserMessage(content="A"),
- AssistantMessage(
- content="B",
- tool_calls=[
- ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")),
- ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")),
- ],
- ),
- ToolMessage(content="C", tool_call_id="1"),
- ToolMessage(content="D", tool_call_id="2"),
- ]
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[
+ UserMessage(content="A"),
+ AssistantMessage(
+ content="B",
+ tool_calls=[
+ ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")),
+ ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")),
+ ],
+ ),
+ ToolMessage(content="C", tool_call_id="1"),
+ ToolMessage(content="D", tool_call_id="2"),
+ ],
+ )
def test_reorder_last_tool_messages(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
chat_completion_request: ChatCompletionRequest = self._mock_chat_completion(
@@ -764,18 +847,20 @@ def test_reorder_last_tool_messages(self, normalizer_v13: InstructRequestNormali
parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(
chat_completion_request
)
- assert parsed_request.messages == [
- UserMessage(content="A"),
- AssistantMessage(
- content="B",
- tool_calls=[
- ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")),
- ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")),
- ],
- ),
- ToolMessage(content="C", tool_call_id="1"),
- ToolMessage(content="D", tool_call_id="2"),
- ]
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[
+ UserMessage(content="A"),
+ AssistantMessage(
+ content="B",
+ tool_calls=[
+ ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")),
+ ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")),
+ ],
+ ),
+ ToolMessage(content="C", tool_call_id="1"),
+ ToolMessage(content="D", tool_call_id="2"),
+ ],
+ )
def test_reorder_internal_tool_messages(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
chat_completion_request: ChatCompletionRequest = self._mock_chat_completion(
@@ -796,19 +881,21 @@ def test_reorder_internal_tool_messages(self, normalizer_v13: InstructRequestNor
parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(
chat_completion_request
)
- assert parsed_request.messages == [
- UserMessage(content="A"),
- AssistantMessage(
- content="B",
- tool_calls=[
- ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")),
- ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")),
- ],
- ),
- ToolMessage(content="C", tool_call_id="1"),
- ToolMessage(content="D", tool_call_id="2"),
- AssistantMessage(content="E"),
- ]
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[
+ UserMessage(content="A"),
+ AssistantMessage(
+ content="B",
+ tool_calls=[
+ ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")),
+ ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")),
+ ],
+ ),
+ ToolMessage(content="C", tool_call_id="1"),
+ ToolMessage(content="D", tool_call_id="2"),
+ AssistantMessage(content="E"),
+ ],
+ )
def test_reorder_extra_tool_messages(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
chat_completion_request: ChatCompletionRequest = self._mock_chat_completion(
@@ -827,17 +914,19 @@ def test_reorder_extra_tool_messages(self, normalizer_v13: InstructRequestNormal
parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(
chat_completion_request
)
- assert parsed_request.messages == [
- UserMessage(content="A"),
- AssistantMessage(
- content="B",
- tool_calls=[
- ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")),
- ],
- ),
- ToolMessage(content="C", tool_call_id="1"),
- ToolMessage(content="D", tool_call_id="2"),
- ]
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[
+ UserMessage(content="A"),
+ AssistantMessage(
+ content="B",
+ tool_calls=[
+ ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")),
+ ],
+ ),
+ ToolMessage(content="C", tool_call_id="1"),
+ ToolMessage(content="D", tool_call_id="2"),
+ ],
+ )
def test_reorder_only_from_latest_assistant_message(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
chat_completion_request: ChatCompletionRequest = self._mock_chat_completion(
@@ -866,27 +955,29 @@ def test_reorder_only_from_latest_assistant_message(self, normalizer_v13: Instru
parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(
chat_completion_request
)
- assert parsed_request.messages == [
- UserMessage(content="A"),
- AssistantMessage(
- content="B",
- tool_calls=[
- ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")),
- ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")),
- ],
- ),
- ToolMessage(content="C", tool_call_id="1"),
- ToolMessage(content="D", tool_call_id="2"),
- AssistantMessage(
- content="E",
- tool_calls=[
- ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")),
- ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")),
- ],
- ),
- ToolMessage(content="D", tool_call_id="2"),
- ToolMessage(content="C", tool_call_id="1"),
- ]
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[
+ UserMessage(content="A"),
+ AssistantMessage(
+ content="B",
+ tool_calls=[
+ ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")),
+ ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")),
+ ],
+ ),
+ ToolMessage(content="C", tool_call_id="1"),
+ ToolMessage(content="D", tool_call_id="2"),
+ AssistantMessage(
+ content="E",
+ tool_calls=[
+ ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")),
+ ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")),
+ ],
+ ),
+ ToolMessage(content="D", tool_call_id="2"),
+ ToolMessage(content="C", tool_call_id="1"),
+ ],
+ )
@pytest.mark.parametrize(
["system_message", "expected_system_message"],
@@ -932,7 +1023,9 @@ def test_aggregate_system_prompt_content(
parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(
chat_completion_request
)
- assert parsed_request.messages == [expected_system_message, UserMessage(content="B")]
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[expected_system_message, UserMessage(content="B")]
+ )
def test_system_messages_no_aggregation(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
"""Consecutive system messages are NOT aggregated into one in V7+."""
@@ -947,12 +1040,14 @@ def test_system_messages_no_aggregation(self, normalizer_v13: InstructRequestNor
parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(
chat_completion_request
)
- assert parsed_request.messages == [
- SystemMessage(content="A"),
- SystemMessage(content="B"),
- UserMessage(content="C"),
- SystemMessage(content="D"),
- ]
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[
+ SystemMessage(content="A"),
+ SystemMessage(content="B"),
+ UserMessage(content="C"),
+ SystemMessage(content="D"),
+ ],
+ )
def test_system_messages_normalization(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
"""System message chunks within the same message are aggregated with no separator."""
@@ -965,10 +1060,12 @@ def test_system_messages_normalization(self, normalizer_v13: InstructRequestNorm
parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(
chat_completion_request
)
- assert parsed_request.messages == [
- SystemMessage(content="A\n\nB"),
- SystemMessage(content="C"),
- ]
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[
+ SystemMessage(content="A\n\nB"),
+ SystemMessage(content="C"),
+ ],
+ )
def test_assert_parsed_settings(
self,
@@ -978,7 +1075,7 @@ def test_assert_parsed_settings(
parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(
chat_completion_request
)
- assert parsed_request.settings == ModelSettings.none()
+ assert parsed_request == InstructRequest[ChatMessage, Tool](messages=[UserMessage(content="B")])
def test_continue_final_message_forwarded(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
request = ChatCompletionRequest[ChatMessage](
@@ -986,7 +1083,10 @@ def test_continue_final_message_forwarded(self, normalizer_v13: InstructRequestN
continue_final_message=True,
)
result: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(request)
- assert result.continue_final_message is True
+ assert result == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="a"), AssistantMessage(content="b")],
+ continue_final_message=True,
+ )
def test_accepts_text_and_think_chunks(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
r"""V13 normalizer accepts TextChunk and ThinkChunk in assistant messages."""
@@ -997,8 +1097,11 @@ def test_accepts_text_and_think_chunks(self, normalizer_v13: InstructRequestNorm
],
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(request)
- assert parsed.messages[1] == AssistantMessage(
- content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]),
+ ],
)
def test_accepts_string_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
@@ -1010,7 +1113,9 @@ def test_accepts_string_content(self, normalizer_v13: InstructRequestNormalizerV
],
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(request)
- assert parsed.messages[1] == AssistantMessage(content="plain text")
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="query"), AssistantMessage(content="plain text")],
+ )
def test_assistant_think_chunk_inter_message_aggregation(
self, normalizer_v13: InstructRequestNormalizerV13
@@ -1032,12 +1137,16 @@ def test_assistant_think_chunk_inter_message_aggregation(
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(
chat_completion_request
)
- assert parsed.messages[0] == AssistantMessage(
- content=[
- TextChunk(text="A\n\nB"),
- ThinkChunk(thinking="T"),
- TextChunk(text="C\n\nD"),
- ]
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[
+ AssistantMessage(
+ content=[
+ TextChunk(text="A\n\nB"),
+ ThinkChunk(thinking="T"),
+ TextChunk(text="C\n\nD"),
+ ]
+ ),
+ ],
)
def test_assistant_think_chunk_intra_message_aggregation(
@@ -1062,12 +1171,18 @@ def test_assistant_think_chunk_intra_message_aggregation(
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(
chat_completion_request
)
- assert parsed.messages[1] == AssistantMessage(
- content=[
- ThinkChunk(thinking="t1"),
- ThinkChunk(thinking="t2"),
- TextChunk(text="a1\n\na2\n\na3"),
- ]
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[
+ UserMessage(content="u"),
+ AssistantMessage(
+ content=[
+ ThinkChunk(thinking="t1"),
+ ThinkChunk(thinking="t2"),
+ TextChunk(text="a1\n\na2\n\na3"),
+ ]
+ ),
+ UserMessage(content="u"),
+ ],
)
def test_rejects_non_text_tool_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
@@ -1097,7 +1212,16 @@ def test_aggregates_text_tool_content(self, normalizer_v13: InstructRequestNorma
],
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(request)
- assert parsed.messages[2] == ToolMessage(content="hello\n\nworld", tool_call_id="c1")
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(
+ content="",
+ tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")],
+ ),
+ ToolMessage(content="hello\n\nworld", tool_call_id="c1"),
+ ],
+ )
def test_rejects_audio_in_tool_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
r"""V13 normalizer raises InvalidRequestException for audio tool content."""
@@ -1143,7 +1267,10 @@ def test_assert_parsed_settings(
parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(
chat_completion_request
)
- assert parsed_request.settings == ModelSettings(reasoning_effort=reasoning_effort)
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="B")],
+ settings=ModelSettings(reasoning_effort=reasoning_effort),
+ )
def test_continue_final_message_forwarded(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
request = ChatCompletionRequest[ChatMessage](
@@ -1152,7 +1279,11 @@ def test_continue_final_message_forwarded(self, normalizer_v15: InstructRequestN
reasoning_effort=ReasoningEffort.high,
)
result: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
- assert result.continue_final_message is True
+ assert result == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="a"), AssistantMessage(content="b")],
+ continue_final_message=True,
+ settings=ModelSettings(reasoning_effort=ReasoningEffort.high),
+ )
def test_v15_intra_message_chunks_joined_without_separator(
self, normalizer_v15: InstructRequestNormalizerV15
@@ -1166,8 +1297,10 @@ def test_v15_intra_message_chunks_joined_without_separator(
reasoning_effort=ReasoningEffort.high,
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
- assert parsed.messages[0] == UserMessage(content="AB")
- assert parsed.messages[1] == AssistantMessage(content="CD")
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="AB"), AssistantMessage(content="CD")],
+ settings=ModelSettings(reasoning_effort=ReasoningEffort.high),
+ )
def test_v15_inter_message_join_still_uses_separator(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
r"""V15 still joins text across different messages with '\n\n'."""
@@ -1180,7 +1313,10 @@ def test_v15_inter_message_join_still_uses_separator(self, normalizer_v15: Instr
reasoning_effort=ReasoningEffort.high,
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
- assert parsed.messages[0] == UserMessage(content="First\n\nSecond")
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="First\n\nSecond"), AssistantMessage(content="Reply")],
+ settings=ModelSettings(reasoning_effort=ReasoningEffort.high),
+ )
def test_v15_mixed_intra_and_inter_message(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
r"""V15 combines intra-message ('') and inter-message ('\n\n') joining."""
@@ -1193,7 +1329,10 @@ def test_v15_mixed_intra_and_inter_message(self, normalizer_v15: InstructRequest
reasoning_effort=ReasoningEffort.high,
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
- assert parsed.messages[0] == UserMessage(content="AB\n\nCD")
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="AB\n\nCD"), AssistantMessage(content="Reply")],
+ settings=ModelSettings(reasoning_effort=ReasoningEffort.high),
+ )
def test_v15_mixed_intra_and_inter_assistant_messages(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
r"""V15 combines intra-message ('') and inter-message ('\n\n') joining for assistant messages."""
@@ -1206,7 +1345,10 @@ def test_v15_mixed_intra_and_inter_assistant_messages(self, normalizer_v15: Inst
reasoning_effort=ReasoningEffort.high,
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
- assert parsed.messages[1] == AssistantMessage(content="AB\n\nCD")
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="Hello"), AssistantMessage(content="AB\n\nCD")],
+ settings=ModelSettings(reasoning_effort=ReasoningEffort.high),
+ )
def test_v15_tool_message_text_chunks_joined_without_separator(
self, normalizer_v15: InstructRequestNormalizerV15
@@ -1221,7 +1363,17 @@ def test_v15_tool_message_text_chunks_joined_without_separator(
reasoning_effort=ReasoningEffort.high,
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
- assert parsed.messages[2] == ToolMessage(content="XY", tool_call_id="c1")
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(
+ content="",
+ tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")],
+ ),
+ ToolMessage(content="XY", tool_call_id="c1"),
+ ],
+ settings=ModelSettings(reasoning_effort=ReasoningEffort.high),
+ )
def test_accepts_text_and_think_chunks(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
r"""V15 normalizer accepts TextChunk and ThinkChunk in assistant messages."""
@@ -1233,8 +1385,12 @@ def test_accepts_text_and_think_chunks(self, normalizer_v15: InstructRequestNorm
reasoning_effort=ReasoningEffort.high,
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
- assert parsed.messages[1] == AssistantMessage(
- content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]),
+ ],
+ settings=ModelSettings(reasoning_effort=ReasoningEffort.high),
)
def test_accepts_string_content(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
@@ -1247,7 +1403,10 @@ def test_accepts_string_content(self, normalizer_v15: InstructRequestNormalizerV
reasoning_effort=ReasoningEffort.high,
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
- assert parsed.messages[1] == AssistantMessage(content="plain text")
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="query"), AssistantMessage(content="plain text")],
+ settings=ModelSettings(reasoning_effort=ReasoningEffort.high),
+ )
def test_preserves_non_text_tool_content(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
r"""V15 normalizer preserves non-text chunks in tool messages."""
@@ -1261,7 +1420,17 @@ def test_preserves_non_text_tool_content(self, normalizer_v15: InstructRequestNo
reasoning_effort=ReasoningEffort.high,
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
- assert parsed.messages[2] == ToolMessage(content=[image_chunk], tool_call_id="c1")
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(
+ content="",
+ tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")],
+ ),
+ ToolMessage(content=[image_chunk], tool_call_id="c1"),
+ ],
+ settings=ModelSettings(reasoning_effort=ReasoningEffort.high),
+ )
def test_sorts_multimodal_tool_messages(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
r"""V15 normalizer sorts multimodal tool messages by tool call order."""
@@ -1282,9 +1451,21 @@ def test_sorts_multimodal_tool_messages(self, normalizer_v15: InstructRequestNor
reasoning_effort=ReasoningEffort.high,
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
-
- assert parsed.messages[2] == ToolMessage(content=[image_chunk_1], tool_call_id="c1")
- assert parsed.messages[3] == ToolMessage(content=[image_chunk_2], tool_call_id="c2")
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(
+ content="",
+ tool_calls=[
+ ToolCall(function=FunctionCall(name="fn1", arguments="{}"), id="c1"),
+ ToolCall(function=FunctionCall(name="fn2", arguments="{}"), id="c2"),
+ ],
+ ),
+ ToolMessage(content=[image_chunk_1], tool_call_id="c1"),
+ ToolMessage(content=[image_chunk_2], tool_call_id="c2"),
+ ],
+ settings=ModelSettings(reasoning_effort=ReasoningEffort.high),
+ )
def test_rejects_think_in_system_message(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
r"""V15 normalizer rejects ThinkChunk in system messages."""
@@ -1307,11 +1488,14 @@ def test_preserves_audio_in_system_message(self, normalizer_v15: InstructRequest
reasoning_effort=ReasoningEffort.high,
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
- assert parsed.messages[0] == SystemMessage(
- content=[
- TextChunk(text="hello"),
- AudioChunk(input_audio=b"fake_audio_data"),
- ]
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[
+ SystemMessage(
+ content=[TextChunk(text="hello"), AudioChunk(input_audio=b"fake_audio_data")],
+ ),
+ UserMessage(content="test"),
+ ],
+ settings=ModelSettings(reasoning_effort=ReasoningEffort.high),
)
From 75f78e95add60c2f6e8a53ecc132a26e064d6c4d Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Fri, 12 Jun 2026 18:06:29 +0200
Subject: [PATCH 34/47] Remove redundant empty reasoning-effort branch in test
helper
---
tests/test_tokenizer_v15.py | 4 ----
1 file changed, 4 deletions(-)
diff --git a/tests/test_tokenizer_v15.py b/tests/test_tokenizer_v15.py
index 968128f9..2c882cac 100644
--- a/tests/test_tokenizer_v15.py
+++ b/tests/test_tokenizer_v15.py
@@ -152,10 +152,6 @@ def _build_model_settings_builder(
"""
if allowed_reasoning_effort is None:
return ModelSettingsBuilder.none()
- if not allowed_reasoning_effort:
- return ModelSettingsBuilder(
- reasoning_effort=EnumBuilder[ReasoningEffort](values=[], accepts_none=True, default=None)
- )
return ModelSettingsBuilder(
reasoning_effort=EnumBuilder[ReasoningEffort](
values=[ReasoningEffort(v) for v in allowed_reasoning_effort],
From 093f44eb530156747af26492b016b20ed544e89b Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Fri, 12 Jun 2026 18:18:28 +0200
Subject: [PATCH 35/47] Add content-chunk type validation to base request
validator
---
.../protocol/instruct/validator.py | 49 ++++++++++++-
tests/validation/test_chat_validation.py | 70 ++++++++++++++++++-
2 files changed, 116 insertions(+), 3 deletions(-)
diff --git a/src/mistral_common/protocol/instruct/validator.py b/src/mistral_common/protocol/instruct/validator.py
index f121e49e..5902422c 100644
--- a/src/mistral_common/protocol/instruct/validator.py
+++ b/src/mistral_common/protocol/instruct/validator.py
@@ -1,4 +1,5 @@
import re
+from collections.abc import Sequence
from enum import Enum
from typing import Generic
@@ -14,8 +15,15 @@
InvalidToolException,
InvalidToolMessageException,
InvalidToolSchemaException,
+ InvalidUserMessageException,
+ MistralCommonException,
+)
+from mistral_common.protocol.instruct.chunk import (
+ AudioChunk,
+ AudioURLChunk,
+ ContentChunk,
+ TextChunk,
)
-from mistral_common.protocol.instruct.chunk import AudioChunk, AudioURLChunk
from mistral_common.protocol.instruct.messages import (
UATS,
AssistantMessage,
@@ -38,6 +46,23 @@
from mistral_common.tokens.tokenizers.base import TokenizerVersion
+def _validate_content_chunk_types(
+ content: "str | Sequence[ContentChunk] | None",
+ allowed: tuple[type, ...],
+ role: str,
+ exception_cls: type[MistralCommonException],
+) -> None:
+ r"""Raise if any chunk in a list content is not an instance of `allowed`.
+
+ String or None content is always accepted (covered elsewhere).
+ """
+ if not isinstance(content, list):
+ return
+ invalid = sorted({type(chunk).__name__ for chunk in content if not isinstance(chunk, allowed)})
+ if invalid:
+ raise exception_cls(f"Unexpected content chunk types in {role} message: {invalid}")
+
+
class ValidationMode(str, Enum):
r"""Enum for the validation mode.
@@ -153,13 +178,30 @@ def _validate_tools(self, tools: list[Tool]) -> None:
self._validate_function(tool.function)
def _validate_user_message(self, message: UserMessageType) -> None:
- pass
+ self._validate_user_content_chunks(message.content)
+
+ def _validate_user_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None:
+ r"""v1/v2 user messages accept text content only (image >= v3, audio >= v7)."""
+ _validate_content_chunk_types(content, (TextChunk,), "user", InvalidUserMessageException)
+
+ def _validate_assistant_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None:
+ r"""Pre-v11 assistant messages accept text content only."""
+ _validate_content_chunk_types(content, (TextChunk,), "assistant", InvalidAssistantMessageException)
+
+ def _validate_system_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None:
+ r"""v1-v3 system messages accept text content only."""
+ _validate_content_chunk_types(content, (TextChunk,), "system", InvalidSystemPromptException)
+
+ def _validate_tool_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None:
+ r"""Pre-v15 tool messages accept text content only."""
+ _validate_content_chunk_types(content, (TextChunk,), "tool", InvalidToolMessageException)
def _validate_tool_message(self, message: ToolMessageType) -> None:
"""
Checks:
- The tool name is valid
"""
+ self._validate_tool_content_chunks(message.content)
if message.name is not None:
if not re.match(r"^[a-zA-Z0-9_-]{1,64}$", message.name):
raise InvalidToolMessageException(
@@ -174,6 +216,7 @@ def _validate_system_message(self, message: SystemMessageType) -> None:
"""
if message.content is None:
raise InvalidSystemPromptException("System prompt must have content")
+ self._validate_system_content_chunks(message.content)
def _validate_function_call(self, function_call: FunctionCall) -> None:
"""
@@ -201,6 +244,8 @@ def _validate_assistant_message(self, message: AssistantMessageType, is_last_mes
- That the tool calls are valid
"""
+ self._validate_assistant_content_chunks(message.content)
+
# Validate that the message has either text or tool_calls
# but not both and not neither.
if (not self._allow_tool_call_and_content) and (bool(message.content) == bool(message.tool_calls)):
diff --git a/tests/validation/test_chat_validation.py b/tests/validation/test_chat_validation.py
index f4c8653c..d5161d8a 100644
--- a/tests/validation/test_chat_validation.py
+++ b/tests/validation/test_chat_validation.py
@@ -4,8 +4,16 @@
InvalidAssistantMessageException,
InvalidMessageStructureException,
InvalidRequestException,
+ InvalidSystemPromptException,
+ InvalidToolMessageException,
+ InvalidUserMessageException,
+)
+from mistral_common.protocol.instruct.chunk import (
+ AudioChunk,
+ AudioURLChunk,
+ TextChunk,
+ ThinkChunk,
)
-from mistral_common.protocol.instruct.chunk import AudioChunk, AudioURLChunk
from mistral_common.protocol.instruct.messages import (
AssistantMessage,
SystemMessage,
@@ -621,3 +629,63 @@ def test_build_settings_v15_reasoning_effort(
) -> None:
request = ChatCompletionRequest(messages=[UserMessage(content="Hello")], reasoning_effort=reasoning_effort)
validator_v15._validate_model_settings(request)
+
+
+class TestBaseValidatorContentChunks:
+ @pytest.fixture
+ def base_validator(self) -> MistralRequestValidator:
+ return MistralRequestValidator(ValidationMode.serving)
+
+ def test_rejects_think_in_assistant(self, base_validator: MistralRequestValidator) -> None:
+ messages = [
+ UserMessage(content="hi"),
+ AssistantMessage(content=[ThinkChunk(thinking="r"), TextChunk(text="a")]),
+ UserMessage(content="hi again"),
+ ]
+ with pytest.raises(
+ InvalidAssistantMessageException, match="Unexpected content chunk types in assistant message"
+ ):
+ base_validator.validate_messages(messages, continue_final_message=False)
+
+ def test_rejects_audio_in_system(self, base_validator: MistralRequestValidator) -> None:
+ messages = [
+ SystemMessage(content=[TextChunk(text="x"), AudioChunk(input_audio=b"fake")]),
+ UserMessage(content="hi"),
+ ]
+ with pytest.raises(InvalidSystemPromptException, match="Unexpected content chunk types in system message"):
+ base_validator.validate_messages(messages, continue_final_message=False)
+
+ def test_rejects_image_in_tool(self, base_validator: MistralRequestValidator) -> None:
+ from mistral_common.protocol.instruct.chunk import ImageURLChunk
+
+ messages = [
+ UserMessage(content="hi"),
+ AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")]),
+ ToolMessage(
+ content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")],
+ tool_call_id="test12345",
+ ),
+ ]
+ with pytest.raises(InvalidToolMessageException, match="Unexpected content chunk types in tool message"):
+ base_validator.validate_messages(messages, continue_final_message=False)
+
+ def test_rejects_image_in_user(self, base_validator: MistralRequestValidator) -> None:
+ from mistral_common.protocol.instruct.chunk import ImageURLChunk
+
+ messages = [UserMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")])]
+ with pytest.raises(InvalidUserMessageException, match="Unexpected content chunk types in user message"):
+ base_validator.validate_messages(messages, continue_final_message=False)
+
+ def test_rejects_audio_in_user(self, base_validator: MistralRequestValidator) -> None:
+ messages = [UserMessage(content=[AudioChunk(input_audio=b"fake")])]
+ with pytest.raises(InvalidUserMessageException, match="Unexpected content chunk types in user message"):
+ base_validator.validate_messages(messages, continue_final_message=False)
+
+ def test_accepts_multiple_text_chunks(self, base_validator: MistralRequestValidator) -> None:
+ messages = [
+ SystemMessage(content=[TextChunk(text="a"), TextChunk(text="b")]),
+ UserMessage(content="hi"),
+ AssistantMessage(content=[TextChunk(text="c"), TextChunk(text="d")]),
+ UserMessage(content="hi again"),
+ ]
+ base_validator.validate_messages(messages, continue_final_message=False)
From 091e1f41bc8406d72f294d7fde0fab285ff7acea Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Fri, 12 Jun 2026 18:21:47 +0200
Subject: [PATCH 36/47] Add per-version content-chunk validation overrides
---
.../protocol/instruct/validator.py | 37 ++++++++
tests/validation/test_chat_validation.py | 93 +++++++++++++++++++
2 files changed, 130 insertions(+)
diff --git a/src/mistral_common/protocol/instruct/validator.py b/src/mistral_common/protocol/instruct/validator.py
index 5902422c..6a271aa7 100644
--- a/src/mistral_common/protocol/instruct/validator.py
+++ b/src/mistral_common/protocol/instruct/validator.py
@@ -22,7 +22,10 @@
AudioChunk,
AudioURLChunk,
ContentChunk,
+ ImageChunk,
+ ImageURLChunk,
TextChunk,
+ ThinkChunk,
)
from mistral_common.protocol.instruct.messages import (
UATS,
@@ -404,12 +407,19 @@ class MistralRequestValidatorV3(MistralRequestValidator):
>>> validator = MistralRequestValidatorV3()
"""
+ def _validate_user_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None:
+ r"""v3 user messages accept text and image chunks (audio >= v7)."""
+ _validate_content_chunk_types(
+ content, (TextChunk, ImageChunk, ImageURLChunk), "user", InvalidUserMessageException
+ )
+
def _validate_tool_message(self, message: ToolMessageType) -> None:
"""
Checks:
- The tool name is valid
- Tool call id is valid
"""
+ self._validate_tool_content_chunks(message.content)
if message.name is not None:
if not re.match(r"^[a-zA-Z0-9_-]{1,64}$", message.name):
raise InvalidToolMessageException(
@@ -468,6 +478,21 @@ class MistralRequestValidatorV5(MistralRequestValidatorV3):
_allow_tool_call_and_content: bool = True
+ def _validate_user_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None:
+ r"""v7+ user messages accept text, image and audio chunks."""
+ _validate_content_chunk_types(
+ content,
+ (TextChunk, ImageChunk, ImageURLChunk, AudioChunk, AudioURLChunk),
+ "user",
+ InvalidUserMessageException,
+ )
+
+ def _validate_system_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None:
+ r"""v7+ system messages accept text, audio and thinking chunks."""
+ _validate_content_chunk_types(
+ content, (TextChunk, AudioChunk, ThinkChunk), "system", InvalidSystemPromptException
+ )
+
def _validate_system_prompt_and_audio(self, messages: list[UATS]) -> None:
r"""Validates that system prompts and audio chunks are not used together in v5."""
@@ -552,12 +577,24 @@ def _validate_tool_calls_followed_by_tool_messages(self, messages: list[UATS]) -
elif len(expected_tool_ids) < len(observed_tool_ids) and self._mode == ValidationMode.finetuning:
raise InvalidMessageStructureException("More tool responses than tool calls")
+ def _validate_assistant_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None:
+ r"""v11+ assistant messages accept text and thinking chunks."""
+ _validate_content_chunk_types(content, (TextChunk, ThinkChunk), "assistant", InvalidAssistantMessageException)
+
def _validate_system_prompt_and_audio(self, messages: list[UATS]) -> None:
r"""Allows system prompts and audio chunks to coexist in v13."""
return
class MistralRequestValidatorV15(MistralRequestValidatorV13):
+ def _validate_system_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None:
+ r"""v15 system messages accept text and audio but reject thinking chunks."""
+ _validate_content_chunk_types(content, (TextChunk, AudioChunk), "system", InvalidSystemPromptException)
+
+ def _validate_tool_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None:
+ r"""v15 tool messages accept all content chunk types."""
+ return
+
def _validate_model_settings(self, request: ChatCompletionRequest) -> None:
pass
diff --git a/tests/validation/test_chat_validation.py b/tests/validation/test_chat_validation.py
index d5161d8a..17f88e5b 100644
--- a/tests/validation/test_chat_validation.py
+++ b/tests/validation/test_chat_validation.py
@@ -689,3 +689,96 @@ def test_accepts_multiple_text_chunks(self, base_validator: MistralRequestValida
UserMessage(content="hi again"),
]
base_validator.validate_messages(messages, continue_final_message=False)
+
+
+class TestVersionedValidatorContentChunks:
+ def test_v3_allows_image_rejects_audio_in_user(self) -> None:
+ from mistral_common.protocol.instruct.chunk import ImageURLChunk
+
+ validator = MistralRequestValidatorV3(ValidationMode.serving)
+ ok = [UserMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")])]
+ validator.validate_messages(ok, continue_final_message=False)
+ bad = [UserMessage(content=[AudioChunk(input_audio=b"fake")])]
+ with pytest.raises(InvalidUserMessageException, match="Unexpected content chunk types in user message"):
+ validator.validate_messages(bad, continue_final_message=False)
+
+ def test_v5_allows_image_and_audio_in_user(self) -> None:
+ from mistral_common.protocol.instruct.chunk import ImageURLChunk
+
+ validator = MistralRequestValidatorV5(ValidationMode.serving)
+ messages = [
+ UserMessage(
+ content=[
+ ImageURLChunk(image_url="data:image/png;base64,iVBORw0"),
+ AudioChunk(input_audio=b"fake"),
+ ]
+ )
+ ]
+ validator.validate_messages(messages, continue_final_message=False)
+
+ def test_v5_allows_audio_and_think_in_system(self) -> None:
+ validator = MistralRequestValidatorV5(ValidationMode.serving)
+ messages = [
+ SystemMessage(content=[TextChunk(text="x"), AudioChunk(input_audio=b"fake"), ThinkChunk(thinking="t")]),
+ UserMessage(content="hi"),
+ ]
+ validator.validate_messages(messages, continue_final_message=False)
+
+ def test_v5_still_rejects_think_in_assistant(self) -> None:
+ validator = MistralRequestValidatorV5(ValidationMode.serving)
+ messages = [
+ UserMessage(content="hi"),
+ AssistantMessage(content=[ThinkChunk(thinking="r"), TextChunk(text="a")]),
+ UserMessage(content="next"),
+ ]
+ with pytest.raises(InvalidAssistantMessageException, match="Unexpected content chunk types"):
+ validator.validate_messages(messages, continue_final_message=False)
+
+ def test_v13_allows_think_in_assistant(self) -> None:
+ validator = MistralRequestValidatorV13(ValidationMode.serving)
+ messages = [
+ UserMessage(content="hi"),
+ AssistantMessage(content=[ThinkChunk(thinking="r"), TextChunk(text="a")]),
+ UserMessage(content="next"),
+ ]
+ validator.validate_messages(messages, continue_final_message=False)
+
+ def test_v13_still_rejects_image_in_tool(self) -> None:
+ from mistral_common.protocol.instruct.chunk import ImageURLChunk
+
+ validator = MistralRequestValidatorV13(ValidationMode.serving)
+ messages = [
+ UserMessage(content="hi"),
+ AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")]),
+ ToolMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")], tool_call_id="test12345"),
+ ]
+ with pytest.raises(InvalidToolMessageException, match="Unexpected content chunk types in tool message"):
+ validator.validate_messages(messages, continue_final_message=False)
+
+ def test_v15_rejects_think_in_system(self) -> None:
+ validator = MistralRequestValidatorV15(ValidationMode.serving)
+ messages = [
+ SystemMessage(content=[TextChunk(text="x"), ThinkChunk(thinking="t")]),
+ UserMessage(content="hi"),
+ ]
+ with pytest.raises(InvalidSystemPromptException, match="Unexpected content chunk types in system message"):
+ validator.validate_messages(messages, continue_final_message=False)
+
+ def test_v15_allows_audio_in_system(self) -> None:
+ validator = MistralRequestValidatorV15(ValidationMode.serving)
+ messages = [
+ SystemMessage(content=[TextChunk(text="x"), AudioChunk(input_audio=b"fake")]),
+ UserMessage(content="hi"),
+ ]
+ validator.validate_messages(messages, continue_final_message=False)
+
+ def test_v15_allows_image_in_tool(self) -> None:
+ from mistral_common.protocol.instruct.chunk import ImageURLChunk
+
+ validator = MistralRequestValidatorV15(ValidationMode.serving)
+ messages = [
+ UserMessage(content="hi"),
+ AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")]),
+ ToolMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")], tool_call_id="test12345"),
+ ]
+ validator.validate_messages(messages, continue_final_message=False)
From b21094399a7ddc1809b3c4fdcdab0551c19f9bd9 Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Fri, 12 Jun 2026 18:31:14 +0200
Subject: [PATCH 37/47] Clean up content-chunk validator hooks per review
---
.../protocol/instruct/validator.py | 32 ++++++++-----------
tests/validation/test_chat_validation.py | 32 +++++++++++--------
2 files changed, 32 insertions(+), 32 deletions(-)
diff --git a/src/mistral_common/protocol/instruct/validator.py b/src/mistral_common/protocol/instruct/validator.py
index 6a271aa7..e15d0935 100644
--- a/src/mistral_common/protocol/instruct/validator.py
+++ b/src/mistral_common/protocol/instruct/validator.py
@@ -50,8 +50,8 @@
def _validate_content_chunk_types(
- content: "str | Sequence[ContentChunk] | None",
- allowed: tuple[type, ...],
+ content: str | Sequence[ContentChunk] | None,
+ allowed: tuple[type[ContentChunk], ...],
role: str,
exception_cls: type[MistralCommonException],
) -> None:
@@ -183,19 +183,19 @@ def _validate_tools(self, tools: list[Tool]) -> None:
def _validate_user_message(self, message: UserMessageType) -> None:
self._validate_user_content_chunks(message.content)
- def _validate_user_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None:
+ def _validate_user_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None:
r"""v1/v2 user messages accept text content only (image >= v3, audio >= v7)."""
_validate_content_chunk_types(content, (TextChunk,), "user", InvalidUserMessageException)
- def _validate_assistant_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None:
+ def _validate_assistant_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None:
r"""Pre-v11 assistant messages accept text content only."""
_validate_content_chunk_types(content, (TextChunk,), "assistant", InvalidAssistantMessageException)
- def _validate_system_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None:
+ def _validate_system_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None:
r"""v1-v3 system messages accept text content only."""
_validate_content_chunk_types(content, (TextChunk,), "system", InvalidSystemPromptException)
- def _validate_tool_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None:
+ def _validate_tool_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None:
r"""Pre-v15 tool messages accept text content only."""
_validate_content_chunk_types(content, (TextChunk,), "tool", InvalidToolMessageException)
@@ -407,7 +407,7 @@ class MistralRequestValidatorV3(MistralRequestValidator):
>>> validator = MistralRequestValidatorV3()
"""
- def _validate_user_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None:
+ def _validate_user_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None:
r"""v3 user messages accept text and image chunks (audio >= v7)."""
_validate_content_chunk_types(
content, (TextChunk, ImageChunk, ImageURLChunk), "user", InvalidUserMessageException
@@ -419,13 +419,7 @@ def _validate_tool_message(self, message: ToolMessageType) -> None:
- The tool name is valid
- Tool call id is valid
"""
- self._validate_tool_content_chunks(message.content)
- if message.name is not None:
- if not re.match(r"^[a-zA-Z0-9_-]{1,64}$", message.name):
- raise InvalidToolMessageException(
- f"Function name was {message.name} but must be a-z, A-Z, 0-9, "
- "or contain underscores and dashes, with a maximum length of 64."
- )
+ super()._validate_tool_message(message)
if message.tool_call_id is None:
raise InvalidRequestException("Tool call id has to be defined.")
@@ -478,7 +472,7 @@ class MistralRequestValidatorV5(MistralRequestValidatorV3):
_allow_tool_call_and_content: bool = True
- def _validate_user_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None:
+ def _validate_user_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None:
r"""v7+ user messages accept text, image and audio chunks."""
_validate_content_chunk_types(
content,
@@ -487,7 +481,7 @@ def _validate_user_content_chunks(self, content: "str | Sequence[ContentChunk] |
InvalidUserMessageException,
)
- def _validate_system_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None:
+ def _validate_system_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None:
r"""v7+ system messages accept text, audio and thinking chunks."""
_validate_content_chunk_types(
content, (TextChunk, AudioChunk, ThinkChunk), "system", InvalidSystemPromptException
@@ -577,7 +571,7 @@ def _validate_tool_calls_followed_by_tool_messages(self, messages: list[UATS]) -
elif len(expected_tool_ids) < len(observed_tool_ids) and self._mode == ValidationMode.finetuning:
raise InvalidMessageStructureException("More tool responses than tool calls")
- def _validate_assistant_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None:
+ def _validate_assistant_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None:
r"""v11+ assistant messages accept text and thinking chunks."""
_validate_content_chunk_types(content, (TextChunk, ThinkChunk), "assistant", InvalidAssistantMessageException)
@@ -587,11 +581,11 @@ def _validate_system_prompt_and_audio(self, messages: list[UATS]) -> None:
class MistralRequestValidatorV15(MistralRequestValidatorV13):
- def _validate_system_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None:
+ def _validate_system_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None:
r"""v15 system messages accept text and audio but reject thinking chunks."""
_validate_content_chunk_types(content, (TextChunk, AudioChunk), "system", InvalidSystemPromptException)
- def _validate_tool_content_chunks(self, content: "str | Sequence[ContentChunk] | None") -> None:
+ def _validate_tool_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None:
r"""v15 tool messages accept all content chunk types."""
return
diff --git a/tests/validation/test_chat_validation.py b/tests/validation/test_chat_validation.py
index 17f88e5b..74a5a01a 100644
--- a/tests/validation/test_chat_validation.py
+++ b/tests/validation/test_chat_validation.py
@@ -11,6 +11,7 @@
from mistral_common.protocol.instruct.chunk import (
AudioChunk,
AudioURLChunk,
+ ImageURLChunk,
TextChunk,
ThinkChunk,
)
@@ -656,8 +657,6 @@ def test_rejects_audio_in_system(self, base_validator: MistralRequestValidator)
base_validator.validate_messages(messages, continue_final_message=False)
def test_rejects_image_in_tool(self, base_validator: MistralRequestValidator) -> None:
- from mistral_common.protocol.instruct.chunk import ImageURLChunk
-
messages = [
UserMessage(content="hi"),
AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")]),
@@ -670,8 +669,6 @@ def test_rejects_image_in_tool(self, base_validator: MistralRequestValidator) ->
base_validator.validate_messages(messages, continue_final_message=False)
def test_rejects_image_in_user(self, base_validator: MistralRequestValidator) -> None:
- from mistral_common.protocol.instruct.chunk import ImageURLChunk
-
messages = [UserMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")])]
with pytest.raises(InvalidUserMessageException, match="Unexpected content chunk types in user message"):
base_validator.validate_messages(messages, continue_final_message=False)
@@ -681,6 +678,21 @@ def test_rejects_audio_in_user(self, base_validator: MistralRequestValidator) ->
with pytest.raises(InvalidUserMessageException, match="Unexpected content chunk types in user message"):
base_validator.validate_messages(messages, continue_final_message=False)
+ def test_reports_sorted_unique_invalid_types(self, base_validator: MistralRequestValidator) -> None:
+ messages = [
+ UserMessage(
+ content=[
+ AudioChunk(input_audio=b"fake"),
+ ImageURLChunk(image_url="data:image/png;base64,iVBORw0"),
+ ]
+ )
+ ]
+ with pytest.raises(
+ InvalidUserMessageException,
+ match=r"Unexpected content chunk types in user message: \['AudioChunk', 'ImageURLChunk'\]",
+ ):
+ base_validator.validate_messages(messages, continue_final_message=False)
+
def test_accepts_multiple_text_chunks(self, base_validator: MistralRequestValidator) -> None:
messages = [
SystemMessage(content=[TextChunk(text="a"), TextChunk(text="b")]),
@@ -693,8 +705,6 @@ def test_accepts_multiple_text_chunks(self, base_validator: MistralRequestValida
class TestVersionedValidatorContentChunks:
def test_v3_allows_image_rejects_audio_in_user(self) -> None:
- from mistral_common.protocol.instruct.chunk import ImageURLChunk
-
validator = MistralRequestValidatorV3(ValidationMode.serving)
ok = [UserMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")])]
validator.validate_messages(ok, continue_final_message=False)
@@ -703,8 +713,6 @@ def test_v3_allows_image_rejects_audio_in_user(self) -> None:
validator.validate_messages(bad, continue_final_message=False)
def test_v5_allows_image_and_audio_in_user(self) -> None:
- from mistral_common.protocol.instruct.chunk import ImageURLChunk
-
validator = MistralRequestValidatorV5(ValidationMode.serving)
messages = [
UserMessage(
@@ -731,7 +739,9 @@ def test_v5_still_rejects_think_in_assistant(self) -> None:
AssistantMessage(content=[ThinkChunk(thinking="r"), TextChunk(text="a")]),
UserMessage(content="next"),
]
- with pytest.raises(InvalidAssistantMessageException, match="Unexpected content chunk types"):
+ with pytest.raises(
+ InvalidAssistantMessageException, match="Unexpected content chunk types in assistant message"
+ ):
validator.validate_messages(messages, continue_final_message=False)
def test_v13_allows_think_in_assistant(self) -> None:
@@ -744,8 +754,6 @@ def test_v13_allows_think_in_assistant(self) -> None:
validator.validate_messages(messages, continue_final_message=False)
def test_v13_still_rejects_image_in_tool(self) -> None:
- from mistral_common.protocol.instruct.chunk import ImageURLChunk
-
validator = MistralRequestValidatorV13(ValidationMode.serving)
messages = [
UserMessage(content="hi"),
@@ -773,8 +781,6 @@ def test_v15_allows_audio_in_system(self) -> None:
validator.validate_messages(messages, continue_final_message=False)
def test_v15_allows_image_in_tool(self) -> None:
- from mistral_common.protocol.instruct.chunk import ImageURLChunk
-
validator = MistralRequestValidatorV15(ValidationMode.serving)
messages = [
UserMessage(content="hi"),
From 43b15e66ef22e2caac6a284282fdf15dbbbcc7eb Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Fri, 12 Jun 2026 18:33:54 +0200
Subject: [PATCH 38/47] Move content-chunk validation out of the normalizer
---
.../protocol/instruct/normalize.py | 114 ++----------------
1 file changed, 8 insertions(+), 106 deletions(-)
diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py
index 8cd5f904..d6785b34 100644
--- a/src/mistral_common/protocol/instruct/normalize.py
+++ b/src/mistral_common/protocol/instruct/normalize.py
@@ -1,20 +1,15 @@
import json
import warnings
-from typing import Generic, Sequence
+from typing import Generic, Sequence, cast
-from typing_extensions import TypeGuard, assert_never
+from typing_extensions import assert_never
from mistral_common.exceptions import InvalidRequestException
from mistral_common.protocol.instruct.chunk import (
AssistantContentChunk,
- AudioChunk,
- AudioURLChunk,
ContentChunk,
- ImageChunk,
- ImageURLChunk,
SystemContentChunk,
TextChunk,
- ThinkChunk,
UserContentChunk,
)
from mistral_common.protocol.instruct.messages import (
@@ -38,27 +33,6 @@
_DEFAULT_JOIN_STR = "\n\n"
-def _is_user_content(
- chunks: list[ContentChunk],
-) -> TypeGuard[list[UserContentChunk]]:
- r"""Narrow ContentChunk list to user-compatible types."""
- return all(isinstance(c, (TextChunk, ImageChunk, ImageURLChunk, AudioChunk, AudioURLChunk)) for c in chunks)
-
-
-def _is_assistant_content(
- chunks: list[ContentChunk],
-) -> TypeGuard[list[AssistantContentChunk]]:
- r"""Narrow ContentChunk list to assistant-compatible types."""
- return all(isinstance(c, (TextChunk, ThinkChunk)) for c in chunks)
-
-
-def _is_system_content(
- chunks: list[ContentChunk],
-) -> TypeGuard[list[SystemContentChunk]]:
- r"""Narrow ContentChunk list to system-compatible types."""
- return all(isinstance(c, (TextChunk, AudioChunk, ThinkChunk)) for c in chunks)
-
-
def _aggregate_content_chunks_impl(
contents: list[list[ContentChunk] | str | None],
msg_join_str: str,
@@ -262,10 +236,7 @@ def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[s
for message in messages:
assert isinstance(message, self._tool_message_class), "Expected tool message"
content = self._aggregate_content_chunks([message])
- if not isinstance(content, str):
- raise InvalidRequestException(
- f"Unexpected content chunk types in tool message: {[type(c).__name__ for c in content]}"
- )
+ assert isinstance(content, str), f"Unexpected content chunk types in tool message: {content}"
normalized_content = self._normalize_json_content(content)
tool_messages.append(
@@ -283,27 +254,6 @@ def _normalize_tool_call(self, tool_call: ToolCall) -> ToolCall:
id=tool_call.id,
)
- def _narrow_assistant_content(self, content: list[ContentChunk] | str) -> str | list[AssistantContentChunk]:
- r"""Validate content chunks for assistant messages.
-
- Pre-v11 only allows text content in assistant messages.
- V11+ overrides to also accept ThinkChunk.
-
- Args:
- content: The aggregated content chunks.
-
- Returns:
- The validated content.
-
- Raises:
- InvalidRequestException: If unsupported chunk types are found.
- """
- if isinstance(content, str):
- return content
- raise InvalidRequestException(
- f"Unexpected content chunk types in assistant message: {[type(c).__name__ for c in content]}"
- )
-
def _aggregate_system_messages(self, messages: list[UATS]) -> list[SystemMessageType]:
return []
@@ -334,10 +284,8 @@ def _aggregate_assistant_messages(self, messages: list[UATS]) -> AssistantMessag
)
weight = message.weight
- validated_content = self._narrow_assistant_content(content)
-
aggregated_message = self._assistant_message_class(
- content=validated_content,
+ content=cast("str | list[AssistantContentChunk]", content),
tool_calls=tool_calls or None,
prefix=prefix,
)
@@ -349,11 +297,7 @@ def _aggregate_assistant_messages(self, messages: list[UATS]) -> AssistantMessag
def _aggregate_user_messages(self, messages: list[UATS]) -> UserMessageType:
"""Coalesce neighboring blocks of ContentChunks in user messages."""
content = self._aggregate_content_chunks(messages)
- if isinstance(content, str) or _is_user_content(content):
- return self._user_message_class(content=content)
- raise InvalidRequestException(
- f"Unexpected content chunk types in user message: {[type(c).__name__ for c in content]}"
- )
+ return self._user_message_class(content=cast("str | list[UserContentChunk]", content))
def _aggregate_role(self, messages: list[UATS], role: Roles | None, latest_call_ids: list[str]) -> Sequence[UATS]:
if role == Roles.tool:
@@ -476,27 +420,6 @@ def normalizer(model_settings_builder: ModelSettingsBuilder | None = None) -> "I
UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest[UATS, Tool], None
)
- def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | list[SystemContentChunk]:
- r"""Validate content chunks for system messages.
-
- V7+ accepts all SystemContentChunk types (text, audio, thinking).
- V15 overrides to reject ThinkChunk.
-
- Args:
- content: The aggregated content chunks.
-
- Returns:
- The validated content.
-
- Raises:
- InvalidRequestException: If unsupported chunk types are found.
- """
- if isinstance(content, str) or _is_system_content(content):
- return content
- raise InvalidRequestException(
- f"Unexpected content chunk types in system message: {[type(c).__name__ for c in content]}"
- )
-
def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]:
"""Normalize tool messages without JSON normalization.
@@ -507,10 +430,7 @@ def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[s
for message in messages:
assert isinstance(message, self._tool_message_class), "Expected tool message"
content = self._aggregate_content_chunks([message])
- if not isinstance(content, str):
- raise InvalidRequestException(
- f"Unexpected content chunk types in tool message: {[type(c).__name__ for c in content]}"
- )
+ assert isinstance(content, str), f"Unexpected content chunk types in tool message: {content}"
tool_messages.append(
self._tool_message_class(content=content, tool_call_id=message.tool_call_id, name=message.name)
)
@@ -521,8 +441,7 @@ def _aggregate_system_messages(self, messages: list[UATS]) -> list[SystemMessage
for message in messages:
if isinstance(message, self._system_message_class):
content = self._aggregate_content_chunks([message])
- validated = self._narrow_system_content(content)
- aggregated.append(self._system_message_class(content=validated))
+ aggregated.append(self._system_message_class(content=cast("str | list[SystemContentChunk]", content)))
return aggregated
def _aggregate_role(self, messages: list[UATS], role: Roles | None, latest_call_ids: list[str]) -> Sequence[UATS]:
@@ -616,14 +535,6 @@ def _inplace_sort_tool_messages(tool_messages: list[ToolMessageType], latest_cal
),
)
- def _narrow_assistant_content(self, content: list[ContentChunk] | str) -> str | list[AssistantContentChunk]:
- r"""V11+ accepts TextChunk and ThinkChunk in assistant messages."""
- if isinstance(content, str) or _is_assistant_content(content):
- return content
- raise InvalidRequestException(
- f"Unexpected content chunk types in assistant message: {[type(c).__name__ for c in content]}"
- )
-
def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]:
tool_messages: list[ToolMessageType] = super()._aggregate_tool_messages(messages, latest_call_ids)
self._inplace_sort_tool_messages(tool_messages=tool_messages, latest_call_ids=latest_call_ids)
@@ -642,7 +553,7 @@ class InstructRequestNormalizerV15(InstructRequestNormalizerV13):
_chunk_join_str: str = ""
def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]:
- r"""V15 accepts all ContentChunk types in tool messages."""
+ r"""V15 keeps all aggregated tool content (validation handled by the validator)."""
tool_messages: list[ToolMessageType] = []
for message in messages:
assert isinstance(message, self._tool_message_class), "Expected tool message"
@@ -653,15 +564,6 @@ def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[s
self._inplace_sort_tool_messages(tool_messages=tool_messages, latest_call_ids=latest_call_ids)
return tool_messages
- def _narrow_system_content(self, content: list[ContentChunk] | str) -> str | list[SystemContentChunk]:
- r"""V15 system messages allow TextChunk and AudioChunk but reject ThinkChunk."""
- validated = super()._narrow_system_content(content)
- if isinstance(validated, str):
- return validated
- if any(isinstance(c, ThinkChunk) for c in validated):
- raise InvalidRequestException("ThinkChunk in system message is not supported for V15")
- return validated
-
@staticmethod
def normalizer(model_settings_builder: ModelSettingsBuilder | None = None) -> "InstructRequestNormalizerV15":
r"""Returns a normalizer for the V15 instruct request.
From f995ddd6771ce494fb709e0601bd3b3bc79f01af Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Fri, 12 Jun 2026 18:36:37 +0200
Subject: [PATCH 39/47] Reconcile chunk-type tests across Pydantic, normalizer
and validator layers
---
tests/test_messages.py | 41 ++++++++++++++++++++++
tests/test_normalization.py | 68 +++++++------------------------------
2 files changed, 54 insertions(+), 55 deletions(-)
create mode 100644 tests/test_messages.py
diff --git a/tests/test_messages.py b/tests/test_messages.py
new file mode 100644
index 00000000..ca6f6486
--- /dev/null
+++ b/tests/test_messages.py
@@ -0,0 +1,41 @@
+import pytest
+from pydantic import ValidationError
+
+from mistral_common.protocol.instruct.chunk import (
+ AudioChunk,
+ ImageURLChunk,
+ TextChunk,
+ ThinkChunk,
+)
+from mistral_common.protocol.instruct.messages import (
+ AssistantMessage,
+ SystemMessage,
+ UserMessage,
+)
+
+
+class TestMessageContentChunkUnions:
+ def test_assistant_rejects_image(self) -> None:
+ with pytest.raises(ValidationError):
+ AssistantMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")])
+
+ def test_assistant_rejects_audio(self) -> None:
+ with pytest.raises(ValidationError):
+ AssistantMessage(content=[AudioChunk(input_audio=b"fake")])
+
+ def test_assistant_accepts_text_and_think(self) -> None:
+ AssistantMessage(content=[ThinkChunk(thinking="r"), TextChunk(text="a")])
+
+ def test_system_rejects_image(self) -> None:
+ with pytest.raises(ValidationError):
+ SystemMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")])
+
+ def test_system_accepts_audio(self) -> None:
+ SystemMessage(content=[TextChunk(text="x"), AudioChunk(input_audio=b"fake")])
+
+ def test_user_rejects_think(self) -> None:
+ with pytest.raises(ValidationError):
+ UserMessage(content=[ThinkChunk(thinking="r")])
+
+ def test_user_accepts_image(self) -> None:
+ UserMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")])
diff --git a/tests/test_normalization.py b/tests/test_normalization.py
index 26f7d8be..6fa9d19a 100644
--- a/tests/test_normalization.py
+++ b/tests/test_normalization.py
@@ -2,7 +2,6 @@
import pytest
-from mistral_common.exceptions import InvalidRequestException
from mistral_common.protocol.instruct.chunk import (
AudioChunk,
ChunkTypes,
@@ -453,16 +452,18 @@ def test_json_normalizes_tool_content(self, normalizer: InstructRequestNormalize
],
)
- def test_rejects_think_in_assistant(self, normalizer: InstructRequestNormalizer) -> None:
- r"""Pre-v11 normalizer rejects ThinkChunk in assistant messages."""
+ def test_passes_think_in_assistant_through(self, normalizer: InstructRequestNormalizer) -> None:
+ r"""Normalizer no longer gates version rules; the validator rejects pre-v11 think."""
request = mock_chat_completion(
messages=[
UserMessage(content="query"),
AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]),
]
)
- with pytest.raises(InvalidRequestException, match="Unexpected content chunk types in assistant message"):
- normalizer.from_chat_completion_request(request)
+ parsed = normalizer.from_chat_completion_request(request)
+ assert parsed.messages[-1] == AssistantMessage(
+ content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]
+ )
class TestChatCompletionRequestNormalizationV7:
@@ -661,16 +662,18 @@ def test_complex_assistant_aggregation(self, normalizer_v7: InstructRequestNorma
messages=[AssistantMessage(content="A\n\nB\n\nC\n\nD")],
)
- def test_rejects_think_in_assistant(self, normalizer_v7: InstructRequestNormalizerV7) -> None:
- r"""V7 normalizer rejects ThinkChunk in assistant messages (pre-v11)."""
+ def test_passes_think_in_assistant_through(self, normalizer_v7: InstructRequestNormalizerV7) -> None:
+ r"""Normalizer no longer gates version rules; the validator rejects pre-v11 think."""
request = mock_chat_completion(
messages=[
UserMessage(content="query"),
AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]),
- ],
+ ]
+ )
+ parsed = normalizer_v7.from_chat_completion_request(request)
+ assert parsed.messages[-1] == AssistantMessage(
+ content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]
)
- with pytest.raises(InvalidRequestException, match="Unexpected content chunk types in assistant message"):
- normalizer_v7.from_chat_completion_request(request)
def test_accepts_string_content(self, normalizer_v7: InstructRequestNormalizerV7) -> None:
r"""V7 normalizer accepts string content in assistant messages."""
@@ -1185,23 +1188,6 @@ def test_assistant_think_chunk_intra_message_aggregation(
],
)
- def test_rejects_non_text_tool_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
- r"""V13 normalizer raises InvalidRequestException for non-text tool content."""
- request = mock_chat_completion(
- messages=[
- UserMessage(content="hi"),
- AssistantMessage(
- tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")],
- ),
- ToolMessage(
- content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")],
- tool_call_id="test12345",
- ),
- ]
- )
- with pytest.raises(InvalidRequestException, match="Unexpected content chunk types in tool message"):
- normalizer_v13.from_chat_completion_request(request)
-
def test_aggregates_text_tool_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
r"""V13 normalizer aggregates TextChunks in tool messages to a string."""
request = mock_chat_completion(
@@ -1223,23 +1209,6 @@ def test_aggregates_text_tool_content(self, normalizer_v13: InstructRequestNorma
],
)
- def test_rejects_audio_in_tool_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
- r"""V13 normalizer raises InvalidRequestException for audio tool content."""
- request = mock_chat_completion(
- messages=[
- UserMessage(content="hi"),
- AssistantMessage(
- tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")],
- ),
- ToolMessage(
- content=[AudioChunk(input_audio=b"fake_audio_data")],
- tool_call_id="test12345",
- ),
- ]
- )
- with pytest.raises(InvalidRequestException, match="Unexpected content chunk types in tool message"):
- normalizer_v13.from_chat_completion_request(request)
-
class TestChatCompletionRequestNormalizationV15:
@pytest.fixture(autouse=True)
@@ -1467,17 +1436,6 @@ def test_sorts_multimodal_tool_messages(self, normalizer_v15: InstructRequestNor
settings=ModelSettings(reasoning_effort=ReasoningEffort.high),
)
- def test_rejects_think_in_system_message(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
- r"""V15 normalizer rejects ThinkChunk in system messages."""
- request = mock_chat_completion(
- messages=[
- SystemMessage(content=[TextChunk(text="hello"), ThinkChunk(thinking="thinking", closed=True)]),
- UserMessage(content="test"),
- ]
- )
- with pytest.raises(InvalidRequestException, match="ThinkChunk"):
- normalizer_v15.from_chat_completion_request(request)
-
def test_preserves_audio_in_system_message(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
r"""V15 normalizer preserves AudioChunk in system messages."""
request = ChatCompletionRequest[ChatMessage](
From a14f8551a48e960f4a105d1b51c22ab9e86b6886 Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Fri, 12 Jun 2026 18:42:59 +0200
Subject: [PATCH 40/47] Polish normalizer docstrings and tighten chunk-type
tests per review
---
src/mistral_common/protocol/instruct/normalize.py | 14 +++++++++-----
tests/test_messages.py | 4 ++++
tests/test_normalization.py | 14 ++++++++++----
3 files changed, 23 insertions(+), 9 deletions(-)
diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py
index d6785b34..988e25ec 100644
--- a/src/mistral_common/protocol/instruct/normalize.py
+++ b/src/mistral_common/protocol/instruct/normalize.py
@@ -230,13 +230,15 @@ def _aggregate_system_prompts(self, messages: list[UATS]) -> str | None:
def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]:
"""Normalize tool messages without aggregation across messages.
- Each tool message's content is validated and JSON-normalized.
+ Each tool message's content is JSON-normalized; chunk types are guaranteed by the validator.
"""
tool_messages: list[ToolMessageType] = []
for message in messages:
assert isinstance(message, self._tool_message_class), "Expected tool message"
content = self._aggregate_content_chunks([message])
- assert isinstance(content, str), f"Unexpected content chunk types in tool message: {content}"
+ assert isinstance(content, str), (
+ f"Unexpected content chunk types in tool message: {[type(c).__name__ for c in content]}"
+ )
normalized_content = self._normalize_json_content(content)
tool_messages.append(
@@ -423,14 +425,16 @@ def normalizer(model_settings_builder: ModelSettingsBuilder | None = None) -> "I
def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]:
"""Normalize tool messages without JSON normalization.
- V7+ normalizers skip JSON content normalization for tool messages but still
- reject non-text content chunks.
+ V7+ normalizers skip JSON content normalization for tool messages (chunk-type validation is
+ handled by the validator).
"""
tool_messages: list[ToolMessageType] = []
for message in messages:
assert isinstance(message, self._tool_message_class), "Expected tool message"
content = self._aggregate_content_chunks([message])
- assert isinstance(content, str), f"Unexpected content chunk types in tool message: {content}"
+ assert isinstance(content, str), (
+ f"Unexpected content chunk types in tool message: {[type(c).__name__ for c in content]}"
+ )
tool_messages.append(
self._tool_message_class(content=content, tool_call_id=message.tool_call_id, name=message.name)
)
diff --git a/tests/test_messages.py b/tests/test_messages.py
index ca6f6486..4b8a3f6c 100644
--- a/tests/test_messages.py
+++ b/tests/test_messages.py
@@ -10,6 +10,7 @@
from mistral_common.protocol.instruct.messages import (
AssistantMessage,
SystemMessage,
+ ToolMessage,
UserMessage,
)
@@ -39,3 +40,6 @@ def test_user_rejects_think(self) -> None:
def test_user_accepts_image(self) -> None:
UserMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")])
+
+ def test_tool_accepts_arbitrary_chunks(self) -> None:
+ ToolMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")], tool_call_id="c1")
diff --git a/tests/test_normalization.py b/tests/test_normalization.py
index 6fa9d19a..2edad20c 100644
--- a/tests/test_normalization.py
+++ b/tests/test_normalization.py
@@ -461,8 +461,11 @@ def test_passes_think_in_assistant_through(self, normalizer: InstructRequestNorm
]
)
parsed = normalizer.from_chat_completion_request(request)
- assert parsed.messages[-1] == AssistantMessage(
- content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]),
+ ],
)
@@ -671,8 +674,11 @@ def test_passes_think_in_assistant_through(self, normalizer_v7: InstructRequestN
]
)
parsed = normalizer_v7.from_chat_completion_request(request)
- assert parsed.messages[-1] == AssistantMessage(
- content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]),
+ ],
)
def test_accepts_string_content(self, normalizer_v7: InstructRequestNormalizerV7) -> None:
From 0385e802c33eec2b3b8339a3cf6e405cb3aaa586 Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Fri, 12 Jun 2026 19:25:38 +0200
Subject: [PATCH 41/47] Replace normalizer casts with content-chunk TypeGuards
in chunk.py
---
src/mistral_common/protocol/instruct/chunk.py | 55 ++++++++++++++++++-
.../protocol/instruct/normalize.py | 17 +++---
2 files changed, 64 insertions(+), 8 deletions(-)
diff --git a/src/mistral_common/protocol/instruct/chunk.py b/src/mistral_common/protocol/instruct/chunk.py
index 786c38e6..fda7b995 100644
--- a/src/mistral_common/protocol/instruct/chunk.py
+++ b/src/mistral_common/protocol/instruct/chunk.py
@@ -7,7 +7,7 @@
from urllib.parse import urlparse
from pydantic import ConfigDict, Field, ValidationError, field_validator, model_validator
-from typing_extensions import Annotated, TypeAlias
+from typing_extensions import Annotated, TypeAlias, TypeGuard
from mistral_common.base import MistralBase
from mistral_common.deprecation import warn_once
@@ -462,6 +462,59 @@ def from_openai(cls, openai_chunk: dict[str, Any]) -> "ThinkChunk":
ToolContentChunk: TypeAlias = ContentChunk # Accepts all content chunk types (no restriction on tool messages).
+def is_user_content(content: str | list[ContentChunk]) -> TypeGuard[str | list[UserContentChunk]]:
+ r"""Narrow aggregated content to the types allowed in a user message.
+
+ String content is always accepted. List content is accepted only when every chunk is a
+ `TextChunk`, `ImageChunk`, `ImageURLChunk`, `AudioChunk` or `AudioURLChunk`.
+
+ Args:
+ content: The content to narrow.
+
+ Returns:
+ Whether the content only contains user-compatible chunk types.
+ """
+ if isinstance(content, str):
+ return True
+ return all(
+ isinstance(chunk, (TextChunk, ImageChunk, ImageURLChunk, AudioChunk, AudioURLChunk)) for chunk in content
+ )
+
+
+def is_assistant_content(content: str | list[ContentChunk]) -> TypeGuard[str | list[AssistantContentChunk]]:
+ r"""Narrow aggregated content to the types allowed in an assistant message.
+
+ String content is always accepted. List content is accepted only when every chunk is a
+ `TextChunk` or `ThinkChunk`.
+
+ Args:
+ content: The content to narrow.
+
+ Returns:
+ Whether the content only contains assistant-compatible chunk types.
+ """
+ if isinstance(content, str):
+ return True
+ return all(isinstance(chunk, (TextChunk, ThinkChunk)) for chunk in content)
+
+
+def is_system_content(content: str | list[ContentChunk]) -> TypeGuard[str | list[SystemContentChunk]]:
+ r"""Narrow aggregated content to the types allowed in a system message.
+
+ String content is always accepted. List content is accepted only when every chunk is a
+ `TextChunk`, `AudioChunk` or `ThinkChunk`.
+
+ Args:
+ content: The content to narrow.
+
+ Returns:
+ Whether the content only contains system-compatible chunk types.
+ """
+ if isinstance(content, str):
+ return True
+ return all(isinstance(chunk, (TextChunk, AudioChunk, ThinkChunk)) for chunk in content)
+
+
def _convert_openai_content_chunks(openai_content_chunks: dict[str, Any]) -> ContentChunk:
content_type_str = openai_content_chunks.get("type")
diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py
index 988e25ec..6ed7b736 100644
--- a/src/mistral_common/protocol/instruct/normalize.py
+++ b/src/mistral_common/protocol/instruct/normalize.py
@@ -1,16 +1,16 @@
import json
import warnings
-from typing import Generic, Sequence, cast
+from typing import Generic, Sequence
from typing_extensions import assert_never
from mistral_common.exceptions import InvalidRequestException
from mistral_common.protocol.instruct.chunk import (
- AssistantContentChunk,
ContentChunk,
- SystemContentChunk,
TextChunk,
- UserContentChunk,
+ is_assistant_content,
+ is_system_content,
+ is_user_content,
)
from mistral_common.protocol.instruct.messages import (
UATS,
@@ -286,8 +286,9 @@ def _aggregate_assistant_messages(self, messages: list[UATS]) -> AssistantMessag
)
weight = message.weight
+ assert is_assistant_content(content)
aggregated_message = self._assistant_message_class(
- content=cast("str | list[AssistantContentChunk]", content),
+ content=content,
tool_calls=tool_calls or None,
prefix=prefix,
)
@@ -299,7 +300,8 @@ def _aggregate_assistant_messages(self, messages: list[UATS]) -> AssistantMessag
def _aggregate_user_messages(self, messages: list[UATS]) -> UserMessageType:
"""Coalesce neighboring blocks of ContentChunks in user messages."""
content = self._aggregate_content_chunks(messages)
- return self._user_message_class(content=cast("str | list[UserContentChunk]", content))
+ assert is_user_content(content)
+ return self._user_message_class(content=content)
def _aggregate_role(self, messages: list[UATS], role: Roles | None, latest_call_ids: list[str]) -> Sequence[UATS]:
if role == Roles.tool:
@@ -445,7 +447,8 @@ def _aggregate_system_messages(self, messages: list[UATS]) -> list[SystemMessage
for message in messages:
if isinstance(message, self._system_message_class):
content = self._aggregate_content_chunks([message])
- aggregated.append(self._system_message_class(content=cast("str | list[SystemContentChunk]", content)))
+ assert is_system_content(content)
+ aggregated.append(self._system_message_class(content=content))
return aggregated
def _aggregate_role(self, messages: list[UATS], role: Roles | None, latest_call_ids: list[str]) -> Sequence[UATS]:
From c36ddbffe254d9434fb1ee117e5ff87ec6f26580 Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Fri, 12 Jun 2026 19:25:45 +0200
Subject: [PATCH 42/47] Reorganize chunk-type tests into per-version classes;
drop normalizer validator-concern tests
---
tests/test_normalization.py | 56 -----
tests/validation/test_chat_validation.py | 271 +++++++++++------------
2 files changed, 134 insertions(+), 193 deletions(-)
diff --git a/tests/test_normalization.py b/tests/test_normalization.py
index 2edad20c..70145f16 100644
--- a/tests/test_normalization.py
+++ b/tests/test_normalization.py
@@ -406,30 +406,6 @@ def test_continue_final_message_forwarded(self, normalizer: InstructRequestNorma
continue_final_message=True,
)
- def test_rejects_audio_in_system_message(self, normalizer: InstructRequestNormalizer) -> None:
- r"""Pre-V7 normalizer rejects AudioChunk in system messages."""
- request = mock_chat_completion(
- messages=[
- SystemMessage(content=[TextChunk(text="hello"), AudioChunk(input_audio=b"fake_audio_data")]),
- UserMessage(content="query"),
- AssistantMessage(content="answer"),
- ]
- )
- with pytest.raises(AssertionError, match="AudioChunk"):
- normalizer.from_chat_completion_request(request)
-
- def test_rejects_think_in_system_message(self, normalizer: InstructRequestNormalizer) -> None:
- r"""Pre-V7 normalizer rejects ThinkChunk in system messages."""
- request = mock_chat_completion(
- messages=[
- SystemMessage(content=[TextChunk(text="hello"), ThinkChunk(thinking="thinking", closed=True)]),
- UserMessage(content="query"),
- AssistantMessage(content="answer"),
- ]
- )
- with pytest.raises(AssertionError, match="ThinkChunk"):
- normalizer.from_chat_completion_request(request)
-
def test_json_normalizes_tool_content(self, normalizer: InstructRequestNormalizer) -> None:
r"""Base normalizer (v1-v3) JSON-normalizes tool message content."""
messy_json = '{"key" : "value" , "num": 1}'
@@ -452,22 +428,6 @@ def test_json_normalizes_tool_content(self, normalizer: InstructRequestNormalize
],
)
- def test_passes_think_in_assistant_through(self, normalizer: InstructRequestNormalizer) -> None:
- r"""Normalizer no longer gates version rules; the validator rejects pre-v11 think."""
- request = mock_chat_completion(
- messages=[
- UserMessage(content="query"),
- AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]),
- ]
- )
- parsed = normalizer.from_chat_completion_request(request)
- assert parsed == InstructRequest[ChatMessage, Tool](
- messages=[
- UserMessage(content="query"),
- AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]),
- ],
- )
-
class TestChatCompletionRequestNormalizationV7:
@pytest.fixture(autouse=True)
@@ -665,22 +625,6 @@ def test_complex_assistant_aggregation(self, normalizer_v7: InstructRequestNorma
messages=[AssistantMessage(content="A\n\nB\n\nC\n\nD")],
)
- def test_passes_think_in_assistant_through(self, normalizer_v7: InstructRequestNormalizerV7) -> None:
- r"""Normalizer no longer gates version rules; the validator rejects pre-v11 think."""
- request = mock_chat_completion(
- messages=[
- UserMessage(content="query"),
- AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]),
- ]
- )
- parsed = normalizer_v7.from_chat_completion_request(request)
- assert parsed == InstructRequest[ChatMessage, Tool](
- messages=[
- UserMessage(content="query"),
- AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]),
- ],
- )
-
def test_accepts_string_content(self, normalizer_v7: InstructRequestNormalizerV7) -> None:
r"""V7 normalizer accepts string content in assistant messages."""
request = mock_chat_completion(
diff --git a/tests/validation/test_chat_validation.py b/tests/validation/test_chat_validation.py
index 74a5a01a..2f1913ca 100644
--- a/tests/validation/test_chat_validation.py
+++ b/tests/validation/test_chat_validation.py
@@ -58,6 +58,16 @@ def validator(request: pytest.FixtureRequest) -> MistralRequestValidator:
return request.param # type: ignore
+@pytest.fixture
+def validator_base() -> MistralRequestValidator:
+ return MistralRequestValidator(ValidationMode.serving)
+
+
+@pytest.fixture
+def validator_v3() -> MistralRequestValidatorV3:
+ return MistralRequestValidatorV3(ValidationMode.serving)
+
+
@pytest.fixture
def validator_v5() -> MistralRequestValidatorV5:
return MistralRequestValidatorV5(ValidationMode.serving)
@@ -362,6 +372,82 @@ def test_build_settings_raises_error(self, validator: MistralRequestValidator) -
with pytest.raises(InvalidRequestException, match="reasoning_effort='none' is not supported for this model"):
validator._validate_model_settings(request)
+ def test_rejects_think_in_assistant_content_chunks(self, validator_base: MistralRequestValidator) -> None:
+ messages = [
+ UserMessage(content="hi"),
+ AssistantMessage(content=[ThinkChunk(thinking="r"), TextChunk(text="a")]),
+ UserMessage(content="hi again"),
+ ]
+ with pytest.raises(
+ InvalidAssistantMessageException, match="Unexpected content chunk types in assistant message"
+ ):
+ validator_base.validate_messages(messages, continue_final_message=False)
+
+ def test_rejects_audio_in_system_content_chunks(self, validator_base: MistralRequestValidator) -> None:
+ messages = [
+ SystemMessage(content=[TextChunk(text="x"), AudioChunk(input_audio=b"fake")]),
+ UserMessage(content="hi"),
+ ]
+ with pytest.raises(InvalidSystemPromptException, match="Unexpected content chunk types in system message"):
+ validator_base.validate_messages(messages, continue_final_message=False)
+
+ def test_rejects_image_in_tool_content_chunks(self, validator_base: MistralRequestValidator) -> None:
+ messages = [
+ UserMessage(content="hi"),
+ AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")]),
+ ToolMessage(
+ content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")],
+ tool_call_id="test12345",
+ ),
+ ]
+ with pytest.raises(InvalidToolMessageException, match="Unexpected content chunk types in tool message"):
+ validator_base.validate_messages(messages, continue_final_message=False)
+
+ def test_rejects_image_in_user_content_chunks(self, validator_base: MistralRequestValidator) -> None:
+ messages = [UserMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")])]
+ with pytest.raises(InvalidUserMessageException, match="Unexpected content chunk types in user message"):
+ validator_base.validate_messages(messages, continue_final_message=False)
+
+ def test_rejects_audio_in_user_content_chunks(self, validator_base: MistralRequestValidator) -> None:
+ messages = [UserMessage(content=[AudioChunk(input_audio=b"fake")])]
+ with pytest.raises(InvalidUserMessageException, match="Unexpected content chunk types in user message"):
+ validator_base.validate_messages(messages, continue_final_message=False)
+
+ def test_reports_sorted_unique_invalid_chunk_types(self, validator_base: MistralRequestValidator) -> None:
+ messages = [
+ UserMessage(
+ content=[
+ AudioChunk(input_audio=b"fake"),
+ ImageURLChunk(image_url="data:image/png;base64,iVBORw0"),
+ ]
+ )
+ ]
+ with pytest.raises(
+ InvalidUserMessageException,
+ match=r"Unexpected content chunk types in user message: \['AudioChunk', 'ImageURLChunk'\]",
+ ):
+ validator_base.validate_messages(messages, continue_final_message=False)
+
+ def test_accepts_multiple_text_content_chunks(self, validator_base: MistralRequestValidator) -> None:
+ messages = [
+ SystemMessage(content=[TextChunk(text="a"), TextChunk(text="b")]),
+ UserMessage(content="hi"),
+ AssistantMessage(content=[TextChunk(text="c"), TextChunk(text="d")]),
+ UserMessage(content="hi again"),
+ ]
+ validator_base.validate_messages(messages, continue_final_message=False)
+
+
+class TestChatValidationV3:
+ def test_allows_image_in_user_content_chunks(self, validator_v3: MistralRequestValidatorV3) -> None:
+ messages = [UserMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")])]
+ validator_v3.validate_messages(messages, continue_final_message=False)
+
+ def test_rejects_audio_in_user_content_chunks(self, validator_v3: MistralRequestValidatorV3) -> None:
+ messages = [UserMessage(content=[AudioChunk(input_audio=b"fake")])]
+ with pytest.raises(InvalidUserMessageException, match="Unexpected content chunk types in user message"):
+ validator_v3.validate_messages(messages, continue_final_message=False)
+
class TestChatValidationV5:
@pytest.mark.parametrize("audio_fixture", ["audio_chunk", "audio_url_chunk"])
@@ -412,6 +498,35 @@ def test_build_settings_raises_error(self, validator: MistralRequestValidator) -
with pytest.raises(InvalidRequestException, match="reasoning_effort='none' is not supported for this model"):
validator._validate_model_settings(request)
+ def test_allows_image_and_audio_in_user_content_chunks(self, validator_v5: MistralRequestValidatorV5) -> None:
+ messages = [
+ UserMessage(
+ content=[
+ ImageURLChunk(image_url="data:image/png;base64,iVBORw0"),
+ AudioChunk(input_audio=b"fake"),
+ ]
+ )
+ ]
+ validator_v5.validate_messages(messages, continue_final_message=False)
+
+ def test_allows_audio_and_think_in_system_content_chunks(self, validator_v5: MistralRequestValidatorV5) -> None:
+ messages = [
+ SystemMessage(content=[TextChunk(text="x"), AudioChunk(input_audio=b"fake"), ThinkChunk(thinking="t")]),
+ UserMessage(content="hi"),
+ ]
+ validator_v5.validate_messages(messages, continue_final_message=False)
+
+ def test_rejects_think_in_assistant_content_chunks(self, validator_v5: MistralRequestValidatorV5) -> None:
+ messages = [
+ UserMessage(content="hi"),
+ AssistantMessage(content=[ThinkChunk(thinking="r"), TextChunk(text="a")]),
+ UserMessage(content="next"),
+ ]
+ with pytest.raises(
+ InvalidAssistantMessageException, match="Unexpected content chunk types in assistant message"
+ ):
+ validator_v5.validate_messages(messages, continue_final_message=False)
+
class TestChatValidationV13:
def test_right_number_results_invalid_id(self, validator_v13: MistralRequestValidatorV13) -> None:
@@ -622,169 +737,51 @@ def test_audio_with_system_prompt_raises_ok(
continue_final_message=False,
)
-
-class TestChatValidationV15:
- @pytest.mark.parametrize("reasoning_effort", [*list(ReasoningEffort), None])
- def test_build_settings_v15_reasoning_effort(
- self, reasoning_effort: ReasoningEffort | None, validator_v15: MistralRequestValidatorV15
- ) -> None:
- request = ChatCompletionRequest(messages=[UserMessage(content="Hello")], reasoning_effort=reasoning_effort)
- validator_v15._validate_model_settings(request)
-
-
-class TestBaseValidatorContentChunks:
- @pytest.fixture
- def base_validator(self) -> MistralRequestValidator:
- return MistralRequestValidator(ValidationMode.serving)
-
- def test_rejects_think_in_assistant(self, base_validator: MistralRequestValidator) -> None:
- messages = [
- UserMessage(content="hi"),
- AssistantMessage(content=[ThinkChunk(thinking="r"), TextChunk(text="a")]),
- UserMessage(content="hi again"),
- ]
- with pytest.raises(
- InvalidAssistantMessageException, match="Unexpected content chunk types in assistant message"
- ):
- base_validator.validate_messages(messages, continue_final_message=False)
-
- def test_rejects_audio_in_system(self, base_validator: MistralRequestValidator) -> None:
- messages = [
- SystemMessage(content=[TextChunk(text="x"), AudioChunk(input_audio=b"fake")]),
- UserMessage(content="hi"),
- ]
- with pytest.raises(InvalidSystemPromptException, match="Unexpected content chunk types in system message"):
- base_validator.validate_messages(messages, continue_final_message=False)
-
- def test_rejects_image_in_tool(self, base_validator: MistralRequestValidator) -> None:
- messages = [
- UserMessage(content="hi"),
- AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")]),
- ToolMessage(
- content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")],
- tool_call_id="test12345",
- ),
- ]
- with pytest.raises(InvalidToolMessageException, match="Unexpected content chunk types in tool message"):
- base_validator.validate_messages(messages, continue_final_message=False)
-
- def test_rejects_image_in_user(self, base_validator: MistralRequestValidator) -> None:
- messages = [UserMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")])]
- with pytest.raises(InvalidUserMessageException, match="Unexpected content chunk types in user message"):
- base_validator.validate_messages(messages, continue_final_message=False)
-
- def test_rejects_audio_in_user(self, base_validator: MistralRequestValidator) -> None:
- messages = [UserMessage(content=[AudioChunk(input_audio=b"fake")])]
- with pytest.raises(InvalidUserMessageException, match="Unexpected content chunk types in user message"):
- base_validator.validate_messages(messages, continue_final_message=False)
-
- def test_reports_sorted_unique_invalid_types(self, base_validator: MistralRequestValidator) -> None:
- messages = [
- UserMessage(
- content=[
- AudioChunk(input_audio=b"fake"),
- ImageURLChunk(image_url="data:image/png;base64,iVBORw0"),
- ]
- )
- ]
- with pytest.raises(
- InvalidUserMessageException,
- match=r"Unexpected content chunk types in user message: \['AudioChunk', 'ImageURLChunk'\]",
- ):
- base_validator.validate_messages(messages, continue_final_message=False)
-
- def test_accepts_multiple_text_chunks(self, base_validator: MistralRequestValidator) -> None:
- messages = [
- SystemMessage(content=[TextChunk(text="a"), TextChunk(text="b")]),
- UserMessage(content="hi"),
- AssistantMessage(content=[TextChunk(text="c"), TextChunk(text="d")]),
- UserMessage(content="hi again"),
- ]
- base_validator.validate_messages(messages, continue_final_message=False)
-
-
-class TestVersionedValidatorContentChunks:
- def test_v3_allows_image_rejects_audio_in_user(self) -> None:
- validator = MistralRequestValidatorV3(ValidationMode.serving)
- ok = [UserMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")])]
- validator.validate_messages(ok, continue_final_message=False)
- bad = [UserMessage(content=[AudioChunk(input_audio=b"fake")])]
- with pytest.raises(InvalidUserMessageException, match="Unexpected content chunk types in user message"):
- validator.validate_messages(bad, continue_final_message=False)
-
- def test_v5_allows_image_and_audio_in_user(self) -> None:
- validator = MistralRequestValidatorV5(ValidationMode.serving)
- messages = [
- UserMessage(
- content=[
- ImageURLChunk(image_url="data:image/png;base64,iVBORw0"),
- AudioChunk(input_audio=b"fake"),
- ]
- )
- ]
- validator.validate_messages(messages, continue_final_message=False)
-
- def test_v5_allows_audio_and_think_in_system(self) -> None:
- validator = MistralRequestValidatorV5(ValidationMode.serving)
- messages = [
- SystemMessage(content=[TextChunk(text="x"), AudioChunk(input_audio=b"fake"), ThinkChunk(thinking="t")]),
- UserMessage(content="hi"),
- ]
- validator.validate_messages(messages, continue_final_message=False)
-
- def test_v5_still_rejects_think_in_assistant(self) -> None:
- validator = MistralRequestValidatorV5(ValidationMode.serving)
- messages = [
- UserMessage(content="hi"),
- AssistantMessage(content=[ThinkChunk(thinking="r"), TextChunk(text="a")]),
- UserMessage(content="next"),
- ]
- with pytest.raises(
- InvalidAssistantMessageException, match="Unexpected content chunk types in assistant message"
- ):
- validator.validate_messages(messages, continue_final_message=False)
-
- def test_v13_allows_think_in_assistant(self) -> None:
- validator = MistralRequestValidatorV13(ValidationMode.serving)
+ def test_allows_think_in_assistant_content_chunks(self, validator_v13: MistralRequestValidatorV13) -> None:
messages = [
UserMessage(content="hi"),
AssistantMessage(content=[ThinkChunk(thinking="r"), TextChunk(text="a")]),
UserMessage(content="next"),
]
- validator.validate_messages(messages, continue_final_message=False)
+ validator_v13.validate_messages(messages, continue_final_message=False)
- def test_v13_still_rejects_image_in_tool(self) -> None:
- validator = MistralRequestValidatorV13(ValidationMode.serving)
+ def test_rejects_image_in_tool_content_chunks(self, validator_v13: MistralRequestValidatorV13) -> None:
messages = [
UserMessage(content="hi"),
AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")]),
ToolMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")], tool_call_id="test12345"),
]
with pytest.raises(InvalidToolMessageException, match="Unexpected content chunk types in tool message"):
- validator.validate_messages(messages, continue_final_message=False)
+ validator_v13.validate_messages(messages, continue_final_message=False)
+
+
+class TestChatValidationV15:
+ @pytest.mark.parametrize("reasoning_effort", [*list(ReasoningEffort), None])
+ def test_build_settings_v15_reasoning_effort(
+ self, reasoning_effort: ReasoningEffort | None, validator_v15: MistralRequestValidatorV15
+ ) -> None:
+ request = ChatCompletionRequest(messages=[UserMessage(content="Hello")], reasoning_effort=reasoning_effort)
+ validator_v15._validate_model_settings(request)
- def test_v15_rejects_think_in_system(self) -> None:
- validator = MistralRequestValidatorV15(ValidationMode.serving)
+ def test_rejects_think_in_system_content_chunks(self, validator_v15: MistralRequestValidatorV15) -> None:
messages = [
SystemMessage(content=[TextChunk(text="x"), ThinkChunk(thinking="t")]),
UserMessage(content="hi"),
]
with pytest.raises(InvalidSystemPromptException, match="Unexpected content chunk types in system message"):
- validator.validate_messages(messages, continue_final_message=False)
+ validator_v15.validate_messages(messages, continue_final_message=False)
- def test_v15_allows_audio_in_system(self) -> None:
- validator = MistralRequestValidatorV15(ValidationMode.serving)
+ def test_allows_audio_in_system_content_chunks(self, validator_v15: MistralRequestValidatorV15) -> None:
messages = [
SystemMessage(content=[TextChunk(text="x"), AudioChunk(input_audio=b"fake")]),
UserMessage(content="hi"),
]
- validator.validate_messages(messages, continue_final_message=False)
+ validator_v15.validate_messages(messages, continue_final_message=False)
- def test_v15_allows_image_in_tool(self) -> None:
- validator = MistralRequestValidatorV15(ValidationMode.serving)
+ def test_allows_image_in_tool_content_chunks(self, validator_v15: MistralRequestValidatorV15) -> None:
messages = [
UserMessage(content="hi"),
AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")]),
ToolMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")], tool_call_id="test12345"),
]
- validator.validate_messages(messages, continue_final_message=False)
+ validator_v15.validate_messages(messages, continue_final_message=False)
From 216e85dacc518e8073e9ce695027578d10202e7d Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Fri, 12 Jun 2026 19:57:06 +0200
Subject: [PATCH 43/47] Add informative messages to normalizer content
TypeGuard asserts
---
src/mistral_common/protocol/instruct/normalize.py | 12 +++++++++---
1 file changed, 9 insertions(+), 3 deletions(-)
diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py
index 6ed7b736..96f1dddf 100644
--- a/src/mistral_common/protocol/instruct/normalize.py
+++ b/src/mistral_common/protocol/instruct/normalize.py
@@ -286,7 +286,9 @@ def _aggregate_assistant_messages(self, messages: list[UATS]) -> AssistantMessag
)
weight = message.weight
- assert is_assistant_content(content)
+ assert is_assistant_content(content), (
+ f"Unexpected content chunk types in assistant message: {[type(c).__name__ for c in content]}"
+ )
aggregated_message = self._assistant_message_class(
content=content,
tool_calls=tool_calls or None,
@@ -300,7 +302,9 @@ def _aggregate_assistant_messages(self, messages: list[UATS]) -> AssistantMessag
def _aggregate_user_messages(self, messages: list[UATS]) -> UserMessageType:
"""Coalesce neighboring blocks of ContentChunks in user messages."""
content = self._aggregate_content_chunks(messages)
- assert is_user_content(content)
+ assert is_user_content(content), (
+ f"Unexpected content chunk types in user message: {[type(c).__name__ for c in content]}"
+ )
return self._user_message_class(content=content)
def _aggregate_role(self, messages: list[UATS], role: Roles | None, latest_call_ids: list[str]) -> Sequence[UATS]:
@@ -447,7 +451,9 @@ def _aggregate_system_messages(self, messages: list[UATS]) -> list[SystemMessage
for message in messages:
if isinstance(message, self._system_message_class):
content = self._aggregate_content_chunks([message])
- assert is_system_content(content)
+ assert is_system_content(content), (
+ f"Unexpected content chunk types in system message: {[type(c).__name__ for c in content]}"
+ )
aggregated.append(self._system_message_class(content=content))
return aggregated
From 5bfde189cbe47bc9fad9c481f58b2ae02d67c027 Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Fri, 12 Jun 2026 19:57:07 +0200
Subject: [PATCH 44/47] Restructure chunk-type tests into exhaustive
allow/disallow methods
---
tests/fixtures/chunks.py | 35 ++++
tests/test_messages.py | 59 +++---
tests/validation/test_chat_validation.py | 234 ++++++++++-------------
3 files changed, 165 insertions(+), 163 deletions(-)
create mode 100644 tests/fixtures/chunks.py
diff --git a/tests/fixtures/chunks.py b/tests/fixtures/chunks.py
new file mode 100644
index 00000000..72b3f85d
--- /dev/null
+++ b/tests/fixtures/chunks.py
@@ -0,0 +1,35 @@
+from PIL import Image
+
+from mistral_common.protocol.instruct.chunk import (
+ ContentChunk,
+ ImageChunk,
+ ImageURLChunk,
+ TextChunk,
+ ThinkChunk,
+)
+from tests.fixtures.audio import get_dummy_audio_chunk, get_dummy_audio_url_chunk
+
+
+def get_content_chunk(name: str) -> ContentChunk:
+ r"""Return a single instance of the requested content chunk type.
+
+ Args:
+ name: One of "text", "image", "image_url", "audio", "audio_url" or "think".
+
+ Returns:
+ A content chunk of the requested type.
+ """
+ chunks: dict[str, ContentChunk] = {
+ "text": TextChunk(text="hello"),
+ "image": ImageChunk(image=Image.new("RGB", (4, 4), "red")),
+ "image_url": ImageURLChunk(image_url="data:image/png;base64,iVBORw0"),
+ "audio": get_dummy_audio_chunk(),
+ "audio_url": get_dummy_audio_url_chunk(),
+ "think": ThinkChunk(thinking="reasoning"),
+ }
+ return chunks[name]
+
+
+def get_content_chunks(names: tuple[str, ...]) -> list[ContentChunk]:
+ r"""Return a list of content chunks for the requested type names."""
+ return [get_content_chunk(name) for name in names]
diff --git a/tests/test_messages.py b/tests/test_messages.py
index 4b8a3f6c..44aba528 100644
--- a/tests/test_messages.py
+++ b/tests/test_messages.py
@@ -1,45 +1,44 @@
import pytest
from pydantic import ValidationError
-from mistral_common.protocol.instruct.chunk import (
- AudioChunk,
- ImageURLChunk,
- TextChunk,
- ThinkChunk,
-)
from mistral_common.protocol.instruct.messages import (
AssistantMessage,
SystemMessage,
ToolMessage,
UserMessage,
)
+from tests.fixtures.chunks import get_content_chunks as _chunks
class TestMessageContentChunkUnions:
- def test_assistant_rejects_image(self) -> None:
- with pytest.raises(ValidationError):
- AssistantMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")])
-
- def test_assistant_rejects_audio(self) -> None:
- with pytest.raises(ValidationError):
- AssistantMessage(content=[AudioChunk(input_audio=b"fake")])
-
- def test_assistant_accepts_text_and_think(self) -> None:
- AssistantMessage(content=[ThinkChunk(thinking="r"), TextChunk(text="a")])
+ r"""Pydantic-level (version-independent) content-chunk unions for each message role."""
- def test_system_rejects_image(self) -> None:
- with pytest.raises(ValidationError):
- SystemMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")])
-
- def test_system_accepts_audio(self) -> None:
- SystemMessage(content=[TextChunk(text="x"), AudioChunk(input_audio=b"fake")])
+ def test_user_allows_text_image_audio(self) -> None:
+ UserMessage(content=_chunks(("text", "image", "image_url", "audio", "audio_url")))
def test_user_rejects_think(self) -> None:
- with pytest.raises(ValidationError):
- UserMessage(content=[ThinkChunk(thinking="r")])
-
- def test_user_accepts_image(self) -> None:
- UserMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")])
-
- def test_tool_accepts_arbitrary_chunks(self) -> None:
- ToolMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")], tool_call_id="c1")
+ for name in ("think",):
+ with pytest.raises(ValidationError):
+ UserMessage(content=_chunks((name,)))
+
+ def test_assistant_allows_text_and_think(self) -> None:
+ AssistantMessage(content=_chunks(("text", "think")))
+
+ def test_assistant_rejects_image_and_audio(self) -> None:
+ for name in ("image", "image_url", "audio", "audio_url"):
+ with pytest.raises(ValidationError):
+ AssistantMessage(content=_chunks((name,)))
+
+ def test_system_allows_text_audio_think(self) -> None:
+ SystemMessage(content=_chunks(("text", "audio", "think")))
+
+ def test_system_rejects_image_and_audio_url(self) -> None:
+ for name in ("image", "image_url", "audio_url"):
+ with pytest.raises(ValidationError):
+ SystemMessage(content=_chunks((name,)))
+
+ def test_tool_allows_all_chunk_types(self) -> None:
+ ToolMessage(
+ content=_chunks(("text", "image", "image_url", "audio", "audio_url", "think")),
+ tool_call_id="c1",
+ )
diff --git a/tests/validation/test_chat_validation.py b/tests/validation/test_chat_validation.py
index 2f1913ca..90212bf8 100644
--- a/tests/validation/test_chat_validation.py
+++ b/tests/validation/test_chat_validation.py
@@ -11,9 +11,7 @@
from mistral_common.protocol.instruct.chunk import (
AudioChunk,
AudioURLChunk,
- ImageURLChunk,
- TextChunk,
- ThinkChunk,
+ ContentChunk,
)
from mistral_common.protocol.instruct.messages import (
AssistantMessage,
@@ -35,6 +33,29 @@
ValidationMode,
)
from tests.fixtures.audio import get_dummy_audio_chunk, get_dummy_audio_url_chunk
+from tests.fixtures.chunks import get_content_chunks
+
+_Messages = list[UserMessage | AssistantMessage | SystemMessage | ToolMessage]
+
+
+def _user_convo(content: "str | list[ContentChunk]") -> _Messages:
+ return [UserMessage(content=content)]
+
+
+def _assistant_convo(content: "str | list[ContentChunk]") -> _Messages:
+ return [UserMessage(content="hi"), AssistantMessage(content=content), UserMessage(content="next")]
+
+
+def _system_convo(content: "str | list[ContentChunk]") -> _Messages:
+ return [SystemMessage(content=content), UserMessage(content="hi")]
+
+
+def _tool_convo(content: "str | list[ContentChunk]") -> _Messages:
+ return [
+ UserMessage(content="hi"),
+ AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")]),
+ ToolMessage(content=content, tool_call_id="test12345"),
+ ]
@pytest.fixture(scope="module")
@@ -372,81 +393,56 @@ def test_build_settings_raises_error(self, validator: MistralRequestValidator) -
with pytest.raises(InvalidRequestException, match="reasoning_effort='none' is not supported for this model"):
validator._validate_model_settings(request)
- def test_rejects_think_in_assistant_content_chunks(self, validator_base: MistralRequestValidator) -> None:
- messages = [
- UserMessage(content="hi"),
- AssistantMessage(content=[ThinkChunk(thinking="r"), TextChunk(text="a")]),
- UserMessage(content="hi again"),
- ]
- with pytest.raises(
- InvalidAssistantMessageException, match="Unexpected content chunk types in assistant message"
- ):
- validator_base.validate_messages(messages, continue_final_message=False)
-
- def test_rejects_audio_in_system_content_chunks(self, validator_base: MistralRequestValidator) -> None:
- messages = [
- SystemMessage(content=[TextChunk(text="x"), AudioChunk(input_audio=b"fake")]),
- UserMessage(content="hi"),
- ]
- with pytest.raises(InvalidSystemPromptException, match="Unexpected content chunk types in system message"):
- validator_base.validate_messages(messages, continue_final_message=False)
-
- def test_rejects_image_in_tool_content_chunks(self, validator_base: MistralRequestValidator) -> None:
- messages = [
- UserMessage(content="hi"),
- AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")]),
- ToolMessage(
- content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")],
- tool_call_id="test12345",
- ),
- ]
- with pytest.raises(InvalidToolMessageException, match="Unexpected content chunk types in tool message"):
- validator_base.validate_messages(messages, continue_final_message=False)
-
- def test_rejects_image_in_user_content_chunks(self, validator_base: MistralRequestValidator) -> None:
- messages = [UserMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")])]
- with pytest.raises(InvalidUserMessageException, match="Unexpected content chunk types in user message"):
- validator_base.validate_messages(messages, continue_final_message=False)
-
- def test_rejects_audio_in_user_content_chunks(self, validator_base: MistralRequestValidator) -> None:
- messages = [UserMessage(content=[AudioChunk(input_audio=b"fake")])]
- with pytest.raises(InvalidUserMessageException, match="Unexpected content chunk types in user message"):
- validator_base.validate_messages(messages, continue_final_message=False)
+ def test_allows_text_content_chunks(self, validator_base: MistralRequestValidator) -> None:
+ validator_base.validate_messages(_user_convo(get_content_chunks(("text",))), continue_final_message=False)
+ validator_base.validate_messages(_assistant_convo(get_content_chunks(("text",))), continue_final_message=False)
+ validator_base.validate_messages(_system_convo(get_content_chunks(("text",))), continue_final_message=False)
+ validator_base.validate_messages(_tool_convo(get_content_chunks(("text",))), continue_final_message=False)
+
+ def test_rejects_non_text_in_user(self, validator_base: MistralRequestValidator) -> None:
+ for name in ("image", "image_url", "audio", "audio_url"):
+ with pytest.raises(InvalidUserMessageException, match="Unexpected content chunk types in user message"):
+ validator_base.validate_messages(_user_convo(get_content_chunks((name,))), continue_final_message=False)
+
+ def test_rejects_non_text_in_assistant(self, validator_base: MistralRequestValidator) -> None:
+ for name in ("think",):
+ with pytest.raises(
+ InvalidAssistantMessageException, match="Unexpected content chunk types in assistant message"
+ ):
+ validator_base.validate_messages(
+ _assistant_convo(get_content_chunks((name,))), continue_final_message=False
+ )
+
+ def test_rejects_non_text_in_system(self, validator_base: MistralRequestValidator) -> None:
+ for name in ("audio", "think"):
+ with pytest.raises(InvalidSystemPromptException, match="Unexpected content chunk types in system message"):
+ validator_base.validate_messages(
+ _system_convo(get_content_chunks((name,))), continue_final_message=False
+ )
+
+ def test_rejects_non_text_in_tool(self, validator_base: MistralRequestValidator) -> None:
+ for name in ("image", "image_url", "audio", "audio_url", "think"):
+ with pytest.raises(InvalidToolMessageException, match="Unexpected content chunk types in tool message"):
+ validator_base.validate_messages(_tool_convo(get_content_chunks((name,))), continue_final_message=False)
def test_reports_sorted_unique_invalid_chunk_types(self, validator_base: MistralRequestValidator) -> None:
- messages = [
- UserMessage(
- content=[
- AudioChunk(input_audio=b"fake"),
- ImageURLChunk(image_url="data:image/png;base64,iVBORw0"),
- ]
- )
- ]
+ content = get_content_chunks(("audio", "image_url"))
with pytest.raises(
InvalidUserMessageException,
match=r"Unexpected content chunk types in user message: \['AudioChunk', 'ImageURLChunk'\]",
):
- validator_base.validate_messages(messages, continue_final_message=False)
-
- def test_accepts_multiple_text_content_chunks(self, validator_base: MistralRequestValidator) -> None:
- messages = [
- SystemMessage(content=[TextChunk(text="a"), TextChunk(text="b")]),
- UserMessage(content="hi"),
- AssistantMessage(content=[TextChunk(text="c"), TextChunk(text="d")]),
- UserMessage(content="hi again"),
- ]
- validator_base.validate_messages(messages, continue_final_message=False)
+ validator_base.validate_messages(_user_convo(content), continue_final_message=False)
class TestChatValidationV3:
- def test_allows_image_in_user_content_chunks(self, validator_v3: MistralRequestValidatorV3) -> None:
- messages = [UserMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")])]
- validator_v3.validate_messages(messages, continue_final_message=False)
+ def test_allows_text_and_image_in_user(self, validator_v3: MistralRequestValidatorV3) -> None:
+ content = get_content_chunks(("text", "image", "image_url"))
+ validator_v3.validate_messages(_user_convo(content), continue_final_message=False)
- def test_rejects_audio_in_user_content_chunks(self, validator_v3: MistralRequestValidatorV3) -> None:
- messages = [UserMessage(content=[AudioChunk(input_audio=b"fake")])]
- with pytest.raises(InvalidUserMessageException, match="Unexpected content chunk types in user message"):
- validator_v3.validate_messages(messages, continue_final_message=False)
+ def test_rejects_audio_in_user(self, validator_v3: MistralRequestValidatorV3) -> None:
+ for name in ("audio", "audio_url"):
+ with pytest.raises(InvalidUserMessageException, match="Unexpected content chunk types in user message"):
+ validator_v3.validate_messages(_user_convo(get_content_chunks((name,))), continue_final_message=False)
class TestChatValidationV5:
@@ -498,34 +494,22 @@ def test_build_settings_raises_error(self, validator: MistralRequestValidator) -
with pytest.raises(InvalidRequestException, match="reasoning_effort='none' is not supported for this model"):
validator._validate_model_settings(request)
- def test_allows_image_and_audio_in_user_content_chunks(self, validator_v5: MistralRequestValidatorV5) -> None:
- messages = [
- UserMessage(
- content=[
- ImageURLChunk(image_url="data:image/png;base64,iVBORw0"),
- AudioChunk(input_audio=b"fake"),
- ]
- )
- ]
- validator_v5.validate_messages(messages, continue_final_message=False)
-
- def test_allows_audio_and_think_in_system_content_chunks(self, validator_v5: MistralRequestValidatorV5) -> None:
- messages = [
- SystemMessage(content=[TextChunk(text="x"), AudioChunk(input_audio=b"fake"), ThinkChunk(thinking="t")]),
- UserMessage(content="hi"),
- ]
- validator_v5.validate_messages(messages, continue_final_message=False)
-
- def test_rejects_think_in_assistant_content_chunks(self, validator_v5: MistralRequestValidatorV5) -> None:
- messages = [
- UserMessage(content="hi"),
- AssistantMessage(content=[ThinkChunk(thinking="r"), TextChunk(text="a")]),
- UserMessage(content="next"),
- ]
- with pytest.raises(
- InvalidAssistantMessageException, match="Unexpected content chunk types in assistant message"
- ):
- validator_v5.validate_messages(messages, continue_final_message=False)
+ def test_allows_text_image_audio_in_user(self, validator_v5: MistralRequestValidatorV5) -> None:
+ content = get_content_chunks(("text", "image", "image_url", "audio", "audio_url"))
+ validator_v5.validate_messages(_user_convo(content), continue_final_message=False)
+
+ def test_allows_text_audio_think_in_system(self, validator_v5: MistralRequestValidatorV5) -> None:
+ content = get_content_chunks(("text", "audio", "think"))
+ validator_v5.validate_messages(_system_convo(content), continue_final_message=False)
+
+ def test_rejects_think_in_assistant(self, validator_v5: MistralRequestValidatorV5) -> None:
+ for name in ("think",):
+ with pytest.raises(
+ InvalidAssistantMessageException, match="Unexpected content chunk types in assistant message"
+ ):
+ validator_v5.validate_messages(
+ _assistant_convo(get_content_chunks((name,))), continue_final_message=False
+ )
class TestChatValidationV13:
@@ -737,22 +721,14 @@ def test_audio_with_system_prompt_raises_ok(
continue_final_message=False,
)
- def test_allows_think_in_assistant_content_chunks(self, validator_v13: MistralRequestValidatorV13) -> None:
- messages = [
- UserMessage(content="hi"),
- AssistantMessage(content=[ThinkChunk(thinking="r"), TextChunk(text="a")]),
- UserMessage(content="next"),
- ]
- validator_v13.validate_messages(messages, continue_final_message=False)
+ def test_allows_text_and_think_in_assistant(self, validator_v13: MistralRequestValidatorV13) -> None:
+ content = get_content_chunks(("text", "think"))
+ validator_v13.validate_messages(_assistant_convo(content), continue_final_message=False)
- def test_rejects_image_in_tool_content_chunks(self, validator_v13: MistralRequestValidatorV13) -> None:
- messages = [
- UserMessage(content="hi"),
- AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")]),
- ToolMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")], tool_call_id="test12345"),
- ]
- with pytest.raises(InvalidToolMessageException, match="Unexpected content chunk types in tool message"):
- validator_v13.validate_messages(messages, continue_final_message=False)
+ def test_rejects_non_text_in_tool(self, validator_v13: MistralRequestValidatorV13) -> None:
+ for name in ("image", "image_url", "audio", "audio_url", "think"):
+ with pytest.raises(InvalidToolMessageException, match="Unexpected content chunk types in tool message"):
+ validator_v13.validate_messages(_tool_convo(get_content_chunks((name,))), continue_final_message=False)
class TestChatValidationV15:
@@ -763,25 +739,17 @@ def test_build_settings_v15_reasoning_effort(
request = ChatCompletionRequest(messages=[UserMessage(content="Hello")], reasoning_effort=reasoning_effort)
validator_v15._validate_model_settings(request)
- def test_rejects_think_in_system_content_chunks(self, validator_v15: MistralRequestValidatorV15) -> None:
- messages = [
- SystemMessage(content=[TextChunk(text="x"), ThinkChunk(thinking="t")]),
- UserMessage(content="hi"),
- ]
- with pytest.raises(InvalidSystemPromptException, match="Unexpected content chunk types in system message"):
- validator_v15.validate_messages(messages, continue_final_message=False)
-
- def test_allows_audio_in_system_content_chunks(self, validator_v15: MistralRequestValidatorV15) -> None:
- messages = [
- SystemMessage(content=[TextChunk(text="x"), AudioChunk(input_audio=b"fake")]),
- UserMessage(content="hi"),
- ]
- validator_v15.validate_messages(messages, continue_final_message=False)
-
- def test_allows_image_in_tool_content_chunks(self, validator_v15: MistralRequestValidatorV15) -> None:
- messages = [
- UserMessage(content="hi"),
- AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")]),
- ToolMessage(content=[ImageURLChunk(image_url="data:image/png;base64,iVBORw0")], tool_call_id="test12345"),
- ]
- validator_v15.validate_messages(messages, continue_final_message=False)
+ def test_allows_text_and_audio_in_system(self, validator_v15: MistralRequestValidatorV15) -> None:
+ content = get_content_chunks(("text", "audio"))
+ validator_v15.validate_messages(_system_convo(content), continue_final_message=False)
+
+ def test_rejects_think_in_system(self, validator_v15: MistralRequestValidatorV15) -> None:
+ for name in ("think",):
+ with pytest.raises(InvalidSystemPromptException, match="Unexpected content chunk types in system message"):
+ validator_v15.validate_messages(
+ _system_convo(get_content_chunks((name,))), continue_final_message=False
+ )
+
+ def test_allows_all_chunk_types_in_tool(self, validator_v15: MistralRequestValidatorV15) -> None:
+ content = get_content_chunks(("text", "image", "image_url", "audio", "audio_url", "think"))
+ validator_v15.validate_messages(_tool_convo(content), continue_final_message=False)
From 34d57d9437d7f2ae612fca6129e2bad77a794bce Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Fri, 12 Jun 2026 20:24:02 +0200
Subject: [PATCH 45/47] Type test chunk helpers as Any to satisfy mypy on
message construction
---
tests/fixtures/chunks.py | 11 +++++++++--
tests/validation/test_chat_validation.py | 11 ++++++-----
2 files changed, 15 insertions(+), 7 deletions(-)
diff --git a/tests/fixtures/chunks.py b/tests/fixtures/chunks.py
index 72b3f85d..289765f5 100644
--- a/tests/fixtures/chunks.py
+++ b/tests/fixtures/chunks.py
@@ -1,3 +1,5 @@
+from typing import Any
+
from PIL import Image
from mistral_common.protocol.instruct.chunk import (
@@ -30,6 +32,11 @@ def get_content_chunk(name: str) -> ContentChunk:
return chunks[name]
-def get_content_chunks(names: tuple[str, ...]) -> list[ContentChunk]:
- r"""Return a list of content chunks for the requested type names."""
+def get_content_chunks(names: tuple[str, ...]) -> list[Any]:
+ r"""Return a list of content chunks for the requested type names.
+
+ The element type is intentionally `Any` so the chunks can be fed to any message role
+ constructor (including deliberately invalid combinations) in tests, leaving the runtime
+ Pydantic validation to enforce the actual rules.
+ """
return [get_content_chunk(name) for name in names]
diff --git a/tests/validation/test_chat_validation.py b/tests/validation/test_chat_validation.py
index 90212bf8..c036e2a8 100644
--- a/tests/validation/test_chat_validation.py
+++ b/tests/validation/test_chat_validation.py
@@ -1,3 +1,5 @@
+from typing import Any
+
import pytest
from mistral_common.exceptions import (
@@ -11,7 +13,6 @@
from mistral_common.protocol.instruct.chunk import (
AudioChunk,
AudioURLChunk,
- ContentChunk,
)
from mistral_common.protocol.instruct.messages import (
AssistantMessage,
@@ -38,19 +39,19 @@
_Messages = list[UserMessage | AssistantMessage | SystemMessage | ToolMessage]
-def _user_convo(content: "str | list[ContentChunk]") -> _Messages:
+def _user_convo(content: "str | list[Any]") -> _Messages:
return [UserMessage(content=content)]
-def _assistant_convo(content: "str | list[ContentChunk]") -> _Messages:
+def _assistant_convo(content: "str | list[Any]") -> _Messages:
return [UserMessage(content="hi"), AssistantMessage(content=content), UserMessage(content="next")]
-def _system_convo(content: "str | list[ContentChunk]") -> _Messages:
+def _system_convo(content: "str | list[Any]") -> _Messages:
return [SystemMessage(content=content), UserMessage(content="hi")]
-def _tool_convo(content: "str | list[ContentChunk]") -> _Messages:
+def _tool_convo(content: "str | list[Any]") -> _Messages:
return [
UserMessage(content="hi"),
AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")]),
From 46bd99f139a71efd391fb816ad9a47a2dee4132a Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Mon, 15 Jun 2026 13:24:36 +0200
Subject: [PATCH 46/47] Centralize content-chunk allow-lists on message classes
---
.../experimental/app/routers.py | 4 +-
src/mistral_common/protocol/instruct/chunk.py | 64 +------------------
.../protocol/instruct/messages.py | 57 +++++++++++++----
.../protocol/instruct/normalize.py | 12 ----
src/mistral_common/tokens/tokenizers/base.py | 4 +-
.../tokens/tokenizers/instruct.py | 42 ++++++++----
tests/test_tokenizer_v15.py | 4 +-
tests/test_tokenizer_v7_audio.py | 6 +-
8 files changed, 84 insertions(+), 109 deletions(-)
diff --git a/src/mistral_common/experimental/app/routers.py b/src/mistral_common/experimental/app/routers.py
index 1c8aabc4..40103e70 100644
--- a/src/mistral_common/experimental/app/routers.py
+++ b/src/mistral_common/experimental/app/routers.py
@@ -14,7 +14,7 @@
)
from mistral_common.experimental.think import _split_content_and_think_chunks
from mistral_common.experimental.tools import _decode_tool_calls, _split_content_and_tool_calls
-from mistral_common.protocol.instruct.chunk import AssistantContentChunk, TextChunk, ThinkChunk
+from mistral_common.protocol.instruct.chunk import ContentChunk, TextChunk, ThinkChunk
from mistral_common.protocol.instruct.messages import AssistantMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy, Tokenized, TokenizerVersion
@@ -100,7 +100,7 @@ async def detokenize_to_assistant_message(
else:
content_tokens, tool_calls_tokens = tokens, ()
- content: str | list[AssistantContentChunk] | None = None
+ content: str | list[ContentChunk] | None = None
if settings.tokenizer.instruct_tokenizer.tokenizer.version >= TokenizerVersion.v13:
assert isinstance(settings.tokenizer.instruct_tokenizer, InstructTokenizerV13)
diff --git a/src/mistral_common/protocol/instruct/chunk.py b/src/mistral_common/protocol/instruct/chunk.py
index fda7b995..91fcb679 100644
--- a/src/mistral_common/protocol/instruct/chunk.py
+++ b/src/mistral_common/protocol/instruct/chunk.py
@@ -7,7 +7,7 @@
from urllib.parse import urlparse
from pydantic import ConfigDict, Field, ValidationError, field_validator, model_validator
-from typing_extensions import Annotated, TypeAlias, TypeGuard
+from typing_extensions import Annotated
from mistral_common.base import MistralBase
from mistral_common.deprecation import warn_once
@@ -451,68 +451,6 @@ def from_openai(cls, openai_chunk: dict[str, Any]) -> "ThinkChunk":
ContentChunk = Annotated[
TextChunk | ImageChunk | ImageURLChunk | AudioChunk | AudioURLChunk | ThinkChunk, Field(discriminator="type")
]
-UserContentChunk = Annotated[
- TextChunk | ImageChunk | ImageURLChunk | AudioChunk | AudioURLChunk, Field(discriminator="type")
-]
-
-AssistantContentChunk = Annotated[TextChunk | ThinkChunk, Field(discriminator="type")]
-
-SystemContentChunk = Annotated[TextChunk | AudioChunk | ThinkChunk, Field(discriminator="type")]
-
-ToolContentChunk: TypeAlias = ContentChunk # Accepts all content chunk types (no restriction on tool messages).
-
-
-def is_user_content(content: str | list[ContentChunk]) -> TypeGuard[str | list[UserContentChunk]]:
- r"""Narrow aggregated content to the types allowed in a user message.
-
- String content is always accepted. List content is accepted only when every chunk is a
- `TextChunk`, `ImageChunk`, `ImageURLChunk`, `AudioChunk` or `AudioURLChunk`.
-
- Args:
- content: The content to narrow.
-
- Returns:
- Whether the content only contains user-compatible chunk types.
- """
- if isinstance(content, str):
- return True
- return all(
- isinstance(chunk, (TextChunk, ImageChunk, ImageURLChunk, AudioChunk, AudioURLChunk)) for chunk in content
- )
-
-
-def is_assistant_content(content: str | list[ContentChunk]) -> TypeGuard[str | list[AssistantContentChunk]]:
- r"""Narrow aggregated content to the types allowed in an assistant message.
-
- String content is always accepted. List content is accepted only when every chunk is a
- `TextChunk` or `ThinkChunk`.
-
- Args:
- content: The content to narrow.
-
- Returns:
- Whether the content only contains assistant-compatible chunk types.
- """
- if isinstance(content, str):
- return True
- return all(isinstance(chunk, (TextChunk, ThinkChunk)) for chunk in content)
-
-
-def is_system_content(content: str | list[ContentChunk]) -> TypeGuard[str | list[SystemContentChunk]]:
- r"""Narrow aggregated content to the types allowed in a system message.
-
- String content is always accepted. List content is accepted only when every chunk is a
- `TextChunk`, `AudioChunk` or `ThinkChunk`.
-
- Args:
- content: The content to narrow.
-
- Returns:
- Whether the content only contains system-compatible chunk types.
- """
- if isinstance(content, str):
- return True
- return all(isinstance(chunk, (TextChunk, AudioChunk, ThinkChunk)) for chunk in content)
def _convert_openai_content_chunks(openai_content_chunks: dict[str, Any]) -> ContentChunk:
diff --git a/src/mistral_common/protocol/instruct/messages.py b/src/mistral_common/protocol/instruct/messages.py
index 6ca46b2f..f96b78af 100644
--- a/src/mistral_common/protocol/instruct/messages.py
+++ b/src/mistral_common/protocol/instruct/messages.py
@@ -1,21 +1,22 @@
import warnings
from collections.abc import Sequence
from enum import Enum
-from typing import Any, Literal, TypeVar
+from typing import Any, ClassVar, Literal, TypeVar
-from pydantic import Field
+from pydantic import Field, model_validator
from typing_extensions import Annotated, TypeAlias, TypeGuard
from mistral_common.base import MistralBase
from mistral_common.exceptions import InvalidAssistantMessageException
from mistral_common.protocol.instruct.chunk import (
- AssistantContentChunk,
+ AudioChunk,
+ AudioURLChunk,
+ BaseContentChunk,
ContentChunk,
- SystemContentChunk,
+ ImageChunk,
+ ImageURLChunk,
TextChunk,
ThinkChunk,
- ToolContentChunk,
- UserContentChunk,
_convert_openai_content_chunks,
)
from mistral_common.protocol.instruct.tool_calls import ToolCall
@@ -27,12 +28,12 @@
)
-def _are_think_chunks(chunks: Sequence[AssistantContentChunk]) -> TypeGuard[list[ThinkChunk]]:
+def _are_think_chunks(chunks: Sequence[ContentChunk]) -> TypeGuard[list[ThinkChunk]]:
r"""Narrow a chunk list to ThinkChunk list."""
return all(isinstance(c, ThinkChunk) for c in chunks)
-def _are_text_chunks(chunks: Sequence[AssistantContentChunk]) -> TypeGuard[list[TextChunk]]:
+def _are_text_chunks(chunks: Sequence[ContentChunk]) -> TypeGuard[list[TextChunk]]:
r"""Narrow a chunk list to TextChunk list."""
return all(isinstance(c, TextChunk) for c in chunks)
@@ -79,6 +80,19 @@ class BaseMessage(MistralBase):
role: Literal[Roles.system, Roles.user, Roles.assistant, Roles.tool]
+ # Allow-list of content chunk types accepted by this message. Must be set by each subclass.
+ _allowed_content_chunks: ClassVar[tuple[type[BaseContentChunk], ...]]
+
+ @model_validator(mode="after")
+ def _validate_allowed_content_chunks(self) -> "BaseMessage":
+ r"""Enforce the per-message content chunk allow-list."""
+ content = getattr(self, "content", None)
+ if isinstance(content, list):
+ for chunk in content:
+ if not isinstance(chunk, self._allowed_content_chunks):
+ raise ValueError(f"{type(chunk).__name__} cannot be used in {self.role} message.")
+ return self
+
@staticmethod
def _content_to_openai(
content: str | Sequence[ContentChunk] | None,
@@ -144,7 +158,14 @@ class UserMessage(BaseMessage):
"""
role: Literal[Roles.user] = Roles.user
- content: str | list[UserContentChunk]
+ content: str | list[ContentChunk]
+ _allowed_content_chunks: ClassVar[tuple[type[BaseContentChunk], ...]] = (
+ TextChunk,
+ ImageChunk,
+ ImageURLChunk,
+ AudioChunk,
+ AudioURLChunk,
+ )
def to_openai(self) -> dict[str, Any]:
r"""Converts the message to the OpenAI format."""
@@ -169,7 +190,8 @@ class SystemMessage(BaseMessage):
"""
role: Literal[Roles.system] = Roles.system
- content: str | list[SystemContentChunk]
+ content: str | list[ContentChunk]
+ _allowed_content_chunks: ClassVar[tuple[type[BaseContentChunk], ...]] = (TextChunk, AudioChunk, ThinkChunk)
def to_openai(self) -> dict[str, Any]:
r"""Converts the message to the OpenAI format."""
@@ -197,7 +219,8 @@ class AssistantMessage(BaseMessage):
"""
role: Literal[Roles.assistant] = Roles.assistant
- content: str | list[AssistantContentChunk] | None = None
+ content: str | list[ContentChunk] | None = None
+ _allowed_content_chunks: ClassVar[tuple[type[BaseContentChunk], ...]] = (TextChunk, ThinkChunk)
tool_calls: list[ToolCall] | None = None
prefix: bool = False
@@ -340,13 +363,23 @@ class ToolMessage(BaseMessage):
>>> message = ToolMessage(content="Hello, how can I help you?", tool_call_id="123")
"""
- content: str | list[ToolContentChunk]
+ content: str | list[ContentChunk]
role: Literal[Roles.tool] = Roles.tool
tool_call_id: str | None = None
# Deprecated in V3 tokenization
name: str | None = None
+ # Tool messages accept all content chunk types.
+ _allowed_content_chunks: ClassVar[tuple[type[BaseContentChunk], ...]] = (
+ TextChunk,
+ ImageChunk,
+ ImageURLChunk,
+ AudioChunk,
+ AudioURLChunk,
+ ThinkChunk,
+ )
+
def to_openai(self) -> dict[str, Any]:
r"""Converts the message to the OpenAI format."""
assert self.tool_call_id is not None, "tool_call_id must be provided for tool messages."
diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py
index 96f1dddf..c423cda0 100644
--- a/src/mistral_common/protocol/instruct/normalize.py
+++ b/src/mistral_common/protocol/instruct/normalize.py
@@ -8,9 +8,6 @@
from mistral_common.protocol.instruct.chunk import (
ContentChunk,
TextChunk,
- is_assistant_content,
- is_system_content,
- is_user_content,
)
from mistral_common.protocol.instruct.messages import (
UATS,
@@ -286,9 +283,6 @@ def _aggregate_assistant_messages(self, messages: list[UATS]) -> AssistantMessag
)
weight = message.weight
- assert is_assistant_content(content), (
- f"Unexpected content chunk types in assistant message: {[type(c).__name__ for c in content]}"
- )
aggregated_message = self._assistant_message_class(
content=content,
tool_calls=tool_calls or None,
@@ -302,9 +296,6 @@ def _aggregate_assistant_messages(self, messages: list[UATS]) -> AssistantMessag
def _aggregate_user_messages(self, messages: list[UATS]) -> UserMessageType:
"""Coalesce neighboring blocks of ContentChunks in user messages."""
content = self._aggregate_content_chunks(messages)
- assert is_user_content(content), (
- f"Unexpected content chunk types in user message: {[type(c).__name__ for c in content]}"
- )
return self._user_message_class(content=content)
def _aggregate_role(self, messages: list[UATS], role: Roles | None, latest_call_ids: list[str]) -> Sequence[UATS]:
@@ -451,9 +442,6 @@ def _aggregate_system_messages(self, messages: list[UATS]) -> list[SystemMessage
for message in messages:
if isinstance(message, self._system_message_class):
content = self._aggregate_content_chunks([message])
- assert is_system_content(content), (
- f"Unexpected content chunk types in system message: {[type(c).__name__ for c in content]}"
- )
aggregated.append(self._system_message_class(content=content))
return aggregated
diff --git a/src/mistral_common/tokens/tokenizers/base.py b/src/mistral_common/tokens/tokenizers/base.py
index 1fe580fa..06dff715 100644
--- a/src/mistral_common/tokens/tokenizers/base.py
+++ b/src/mistral_common/tokens/tokenizers/base.py
@@ -9,7 +9,7 @@
from mistral_common.base import MistralBase
from mistral_common.protocol.fim.request import FIMRequest
-from mistral_common.protocol.instruct.chunk import UserContentChunk
+from mistral_common.protocol.instruct.chunk import ContentChunk
from mistral_common.protocol.instruct.messages import (
AssistantMessageType,
UserMessage,
@@ -428,7 +428,7 @@ def encode_user_message(
@abstractmethod
def encode_user_content(
self,
- content: str | list[UserContentChunk],
+ content: str | list[ContentChunk],
is_last: bool,
system_prompt: str | None = None,
force_img_first: bool = False,
diff --git a/src/mistral_common/tokens/tokenizers/instruct.py b/src/mistral_common/tokens/tokenizers/instruct.py
index 9b044f6c..7d0e5ab8 100644
--- a/src/mistral_common/tokens/tokenizers/instruct.py
+++ b/src/mistral_common/tokens/tokenizers/instruct.py
@@ -20,8 +20,6 @@
ImageURLChunk,
TextChunk,
ThinkChunk,
- ToolContentChunk,
- UserContentChunk,
)
from mistral_common.protocol.instruct.messages import (
UATS,
@@ -300,7 +298,7 @@ def encode_system_message(self, message: SystemMessage) -> tuple[list[int], list
def encode_user_content(
self,
- content: str | list[UserContentChunk],
+ content: str | list[ContentChunk],
is_last: bool,
system_prompt: str | None = None,
force_img_first: bool = False,
@@ -488,9 +486,15 @@ def _parse_json_content(self, content: str) -> Any:
except json.JSONDecodeError:
return content
- def _parse_tool_content(self, content: str | list[ToolContentChunk]) -> Any:
+ def _parse_tool_content(self, content: str | list[ContentChunk]) -> Any:
if isinstance(content, list):
- content = "".join(chunk.text for chunk in content if isinstance(chunk, TextChunk))
+ text_parts: list[str] = []
+ for chunk in content:
+ assert isinstance(chunk, TextChunk), (
+ f"Tool content only supports text chunks, got {type(chunk).__name__}."
+ )
+ text_parts.append(chunk.text)
+ content = "".join(text_parts)
return self._parse_json_content(content)
def _prepare_tool_result(self, tool_message: ToolMessage) -> dict[str, Any]:
@@ -753,7 +757,7 @@ def _encode_content_chunks(
def encode_user_content(
self,
- content: str | list[UserContentChunk],
+ content: str | list[ContentChunk],
is_last: bool,
system_prompt: str | None = None,
force_img_first: bool = False,
@@ -793,13 +797,18 @@ def encode_user_content(
assert not content_str, (
f"It is not possible that `content` is non-empty when chunk is of type {type(chunk)}."
)
- chunk_tokens, _, chunk_audio = self._encode_content_chunk(chunk)
+ chunk_tokens, maybe_image, chunk_audio = self._encode_content_chunk(chunk)
+ assert maybe_image is None, f"Unexpected image for audio chunk {type(chunk).__name__}."
audio.append(chunk_audio)
elif isinstance(chunk, (ImageChunk, ImageURLChunk)):
- chunk_tokens, chunk_image, _ = self._encode_content_chunk(chunk)
+ chunk_tokens, chunk_image, maybe_audio = self._encode_content_chunk(chunk)
+ assert maybe_audio is None, f"Unexpected audio for image chunk {type(chunk).__name__}."
images.append(chunk_image)
else:
- chunk_tokens = self._encode_content_chunk(chunk)[0]
+ chunk_tokens, maybe_image, maybe_audio = self._encode_content_chunk(chunk)
+ assert maybe_image is None and maybe_audio is None, (
+ f"Unexpected image/audio for chunk {type(chunk).__name__}."
+ )
tokens.extend(chunk_tokens)
return tokens, images, audio
@@ -891,14 +900,15 @@ def encode_system_message(self, message: SystemMessage) -> tuple[list[int], list
tokens = [self.BEGIN_SYSTEM]
if isinstance(content := message.content, str):
content = [TextChunk(text=content)]
- content_tokens, _images, audios = self._encode_content_chunks(content)
+ content_tokens, images, audios = self._encode_content_chunks(content)
+ assert not images, f"System messages cannot contain images, got {len(images)}."
tokens += content_tokens
tokens.append(self.END_SYSTEM)
return tokens, audios
def encode_user_content(
self,
- content: str | list[UserContentChunk],
+ content: str | list[ContentChunk],
is_last: bool,
system_prompt: str | None = None,
force_img_first: bool = False,
@@ -995,7 +1005,7 @@ def _encode_instruct_transcription(self, request: TranscriptionRequest) -> Token
)
assert self.TRANSCRIBE is not None, f"{self.__class__.__name__} needs to have a TRANSCRIBE token"
prefix = self.start()
- tokens, _, audio = self.encode_user_message(
+ tokens, images, audio = self.encode_user_message(
UserMessage(content=[AudioChunk(input_audio=request.audio)]),
available_tools=[],
is_last=True,
@@ -1003,6 +1013,7 @@ def _encode_instruct_transcription(self, request: TranscriptionRequest) -> Token
system_prompt=None,
settings=ModelSettings.none(),
)
+ assert not images, f"Transcription input cannot contain images, got {len(images)}."
tokens = [*prefix, *tokens]
if request.language is not None:
@@ -1152,7 +1163,12 @@ def encode_assistant_message(
if isinstance(message.content, str):
curr_tokens = self._encode_normal_content_assistant_message(message)
elif isinstance(message.content, list):
- curr_tokens += self._encode_content_chunks(message.content)[0]
+ content_tokens, images, audios = self._encode_content_chunks(message.content)
+ assert not images and not audios, (
+ f"Assistant messages cannot contain images or audios, got {len(images)} images "
+ f"and {len(audios)} audios."
+ )
+ curr_tokens += content_tokens
if message.tool_calls:
curr_tokens += self._encode_tool_calls_in_assistant_message(message)
if not message.prefix and not continue_message:
diff --git a/tests/test_tokenizer_v15.py b/tests/test_tokenizer_v15.py
index 2c882cac..c8abe52c 100644
--- a/tests/test_tokenizer_v15.py
+++ b/tests/test_tokenizer_v15.py
@@ -488,7 +488,7 @@ def test_encode_chat_completion_with_multimodal_tool(
expected_text: str,
) -> None:
mistral_tokenizer = tokenizer_factory()
- chat_request = ChatCompletionRequest( # type: ignore[type-var]
+ chat_request: ChatCompletionRequest = ChatCompletionRequest(
messages=[
UserMessage(content="Use the tool"),
AssistantMessage(tool_calls=[ToolCall(id="test12345", function=FunctionCall(name="fn", arguments="{}"))]),
@@ -517,7 +517,7 @@ def test_encode_chat_completion_with_multimodal_system(
expected_text: str,
) -> None:
mistral_tokenizer = tokenizer_factory()
- chat_request = ChatCompletionRequest( # type: ignore[type-var]
+ chat_request: ChatCompletionRequest = ChatCompletionRequest(
messages=[
SystemMessage(content=[TextChunk(text="System with content"), content_chunk]),
UserMessage(content="Hello"),
diff --git a/tests/test_tokenizer_v7_audio.py b/tests/test_tokenizer_v7_audio.py
index 10995957..4ef3d47c 100644
--- a/tests/test_tokenizer_v7_audio.py
+++ b/tests/test_tokenizer_v7_audio.py
@@ -11,8 +11,8 @@
AudioChunk,
AudioURL,
AudioURLChunk,
+ ContentChunk,
TextChunk,
- UserContentChunk,
)
from mistral_common.protocol.instruct.messages import (
UATS,
@@ -215,7 +215,7 @@ def test_tokenize_user_message(tekkenizer: InstructTokenizerV7, audio_first: boo
text_chunk = TextChunk(text="a")
num_expected_frames = int(np.ceil(duration * frame_rate))
- chunks: list[UserContentChunk] = [audio_chunk, text_chunk] if audio_first else [text_chunk, audio_chunk]
+ chunks: list[ContentChunk] = [audio_chunk, text_chunk] if audio_first else [text_chunk, audio_chunk]
tokenized = tekkenizer.encode_instruct(
InstructRequest(
@@ -261,7 +261,7 @@ def test_tokenize_multi_turn(tekkenizer: InstructTokenizerV7) -> None:
text_chunk = TextChunk(text="a")
num_expected_frames = int(np.ceil(duration * frame_rate))
- chunks: list[UserContentChunk] = [audio_chunk, text_chunk]
+ chunks: list[ContentChunk] = [audio_chunk, text_chunk]
tokenized = tekkenizer.encode_instruct(
InstructRequest(
From b57956eb3c53523793c32d67032f8e180992cfcd Mon Sep 17 00:00:00 2001
From: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Date: Mon, 15 Jun 2026 14:23:57 +0200
Subject: [PATCH 47/47] Fix mypy.
---
tests/test_converters.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/tests/test_converters.py b/tests/test_converters.py
index d093d945..19b69a22 100644
--- a/tests/test_converters.py
+++ b/tests/test_converters.py
@@ -37,6 +37,7 @@
AudioChunk,
AudioURL,
AudioURLChunk,
+ ContentChunk,
ImageChunk,
ImageURL,
ImageURLChunk,
@@ -586,7 +587,7 @@ def test_non_leading_think_chunks_construction_ok() -> None:
[TextChunk(text="A"), TextChunk(text="B"), ThinkChunk(thinking="End", closed=True)],
],
)
-def test_non_leading_think_chunks_to_openai_raises(content: list[TextChunk | ThinkChunk]) -> None:
+def test_non_leading_think_chunks_to_openai_raises(content: list[ContentChunk]) -> None:
"""to_openai raises when ThinkChunks are not leading."""
msg = AssistantMessage(content=content)
with pytest.raises(InvalidAssistantMessageException, match="ThinkChunks must be leading"):