diff --git a/src/mistral_common/experimental/app/routers.py b/src/mistral_common/experimental/app/routers.py
index 73c9c6e3..40103e70 100644
--- a/src/mistral_common/experimental/app/routers.py
+++ b/src/mistral_common/experimental/app/routers.py
@@ -14,7 +14,7 @@
)
from mistral_common.experimental.think import _split_content_and_think_chunks
from mistral_common.experimental.tools import _decode_tool_calls, _split_content_and_tool_calls
-from mistral_common.protocol.instruct.chunk import TextChunk, ThinkChunk
+from mistral_common.protocol.instruct.chunk import ContentChunk, TextChunk, ThinkChunk
from mistral_common.protocol.instruct.messages import AssistantMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy, Tokenized, TokenizerVersion
@@ -100,7 +100,7 @@ async def detokenize_to_assistant_message(
else:
content_tokens, tool_calls_tokens = tokens, ()
- content: str | list[TextChunk | ThinkChunk] | None = None
+ content: str | list[ContentChunk] | None = None
if settings.tokenizer.instruct_tokenizer.tokenizer.version >= TokenizerVersion.v13:
assert isinstance(settings.tokenizer.instruct_tokenizer, InstructTokenizerV13)
diff --git a/src/mistral_common/integrations/chat_templates/template_generator.py b/src/mistral_common/integrations/chat_templates/template_generator.py
index 807ed528..b9dc662a 100644
--- a/src/mistral_common/integrations/chat_templates/template_generator.py
+++ b/src/mistral_common/integrations/chat_templates/template_generator.py
@@ -216,6 +216,30 @@ def validates_assistant_non_empty(self) -> bool:
r"""Whether to validate that assistant messages have non-empty content or tool calls."""
return self.version >= TokenizerVersion.v7 or (self.version >= TokenizerVersion.v3 and not self.spm)
+ @property
+ def tool_supports_multimodal(self) -> bool:
+ r"""Whether tool messages can contain non-text content chunks. V15+."""
+ return self.version >= TokenizerVersion.v15
+
+ @property
+ def system_supports_audio(self) -> bool:
+ r"""Whether system messages can contain audio. V15+ with audio support."""
+ return self.audio_support and self.version >= TokenizerVersion.v15
+
+
+def _join_types_desc(parts: list[str]) -> str:
+ r"""Join type names into a human-readable description string.
+
+ Args:
+ parts: List of type names (e.g. ["text", "thinking", "image"]).
+
+ Returns:
+ Formatted string like "text", "text and thinking", or "text, thinking and image".
+ """
+ if len(parts) == 1:
+ return parts[0]
+ return ", ".join(parts[:-1]) + " and " + parts[-1]
+
def _generate_header(config: TemplateConfig) -> str:
r"""Generate template header with default system message.
@@ -872,6 +896,8 @@ def _generate_system_message_handling(config: TemplateConfig) -> str:
if has_extra_types:
if config.system_supports_thinking:
rc_args += ", supported_types_desc='text and thinking'"
+ elif config.system_supports_audio:
+ rc_args += ", supported_types_desc='text and audio'"
else:
rc_args += ", supported_types_desc='text'"
if config.any_thinking_support:
@@ -882,7 +908,7 @@ def _generate_system_message_handling(config: TemplateConfig) -> str:
if config.image_support:
rc_args += ", support_images=false"
if config.audio_support:
- rc_args += ", support_audio=false"
+ rc_args += f", support_audio={'true' if config.system_supports_audio else 'false'}"
lines.append(" {{- render_content(" + rc_args + ") -}}")
lines.append(" {{- '" + _END_SYSTEM + "' -}}")
@@ -1204,10 +1230,10 @@ def _generate_assistant_message_handling(config: TemplateConfig) -> str:
"""
lines = []
+ comment_parts = ["text"]
if config.any_thinking_support:
- chunk_types = "text and thinking"
- else:
- chunk_types = "text"
+ comment_parts.append("thinking")
+ chunk_types = _join_types_desc(comment_parts)
comment = f"{{#- Assistant messages supports {chunk_types} content. #}}"
lines.append("")
@@ -1235,10 +1261,10 @@ def _generate_assistant_message_handling(config: TemplateConfig) -> str:
has_extra_types = config.any_thinking_support or config.image_support or config.audio_support
rc_call_args = "message['content'], 'assistant message contents'"
if has_extra_types:
+ desc_parts = ["text"]
if config.any_thinking_support:
- rc_call_args += ", supported_types_desc='text and thinking'"
- else:
- rc_call_args += ", supported_types_desc='text'"
+ desc_parts.append("thinking")
+ rc_call_args += f", supported_types_desc='{_join_types_desc(desc_parts)}'"
if config.any_thinking_support:
rc_call_args += ", support_thinking=true"
if config.image_support:
@@ -1423,7 +1449,10 @@ def _generate_tool_message_handling(config: TemplateConfig) -> str:
lines.append(" {#- Tool messages supports int, float or text content. #}")
lines.append(" {%- elif message['role'] == 'tool' and ns.index > ns.max_idx_user %}")
else:
- lines.append(" {#- Tool messages only supports text content. #}")
+ if config.tool_supports_multimodal:
+ lines.append(" {#- Tool messages (multimodal). #}")
+ else:
+ lines.append(" {#- Tool messages only supports text content. #}")
lines.append(" {%- elif message['role'] == 'tool' %}")
if config.uses_spm_space_tracking:
@@ -1484,9 +1513,26 @@ def _generate_tool_message_handling(config: TemplateConfig) -> str:
+ "' }}" # noqa: E501
)
elif config.uses_simple_tool_results:
- lines.append(
- " {{- '" + _BEGIN_TOOL_RESULTS + "' + message['content']|string + '" + _END_TOOL_RESULTS + "' }}"
- ) # noqa: E501
+ if config.tool_supports_multimodal:
+ tool_rc_args = "message['content'], 'tool message contents'"
+ if config.image_support or config.audio_support:
+ desc_parts = ["text"]
+ if config.image_support:
+ desc_parts.append("image")
+ if config.audio_support:
+ desc_parts.append("audio")
+ tool_rc_args += f", supported_types_desc='{_join_types_desc(desc_parts)}'"
+ if config.image_support:
+ tool_rc_args += ", support_images=true"
+ if config.audio_support:
+ tool_rc_args += ", support_audio=true"
+ lines.append(" {{- '" + _BEGIN_TOOL_RESULTS + "' -}}")
+ lines.append(" {{- render_content(" + tool_rc_args + ") -}}")
+ lines.append(" {{- '" + _END_TOOL_RESULTS + "' }}")
+ else:
+ lines.append(
+ " {{- '" + _BEGIN_TOOL_RESULTS + "' + message['content']|string + '" + _END_TOOL_RESULTS + "' }}"
+ ) # noqa: E501
else:
# v3 non-spm style
lines.extend(_emit_int_float_parsing(" "))
diff --git a/src/mistral_common/protocol/instruct/chunk.py b/src/mistral_common/protocol/instruct/chunk.py
index 2ed61528..91fcb679 100644
--- a/src/mistral_common/protocol/instruct/chunk.py
+++ b/src/mistral_common/protocol/instruct/chunk.py
@@ -451,9 +451,6 @@ def from_openai(cls, openai_chunk: dict[str, Any]) -> "ThinkChunk":
ContentChunk = Annotated[
TextChunk | ImageChunk | ImageURLChunk | AudioChunk | AudioURLChunk | ThinkChunk, Field(discriminator="type")
]
-UserContentChunk = Annotated[
- TextChunk | ImageChunk | ImageURLChunk | AudioChunk | AudioURLChunk, Field(discriminator="type")
-]
def _convert_openai_content_chunks(openai_content_chunks: dict[str, Any]) -> ContentChunk:
diff --git a/src/mistral_common/protocol/instruct/messages.py b/src/mistral_common/protocol/instruct/messages.py
index e82379cb..f96b78af 100644
--- a/src/mistral_common/protocol/instruct/messages.py
+++ b/src/mistral_common/protocol/instruct/messages.py
@@ -1,17 +1,22 @@
import warnings
+from collections.abc import Sequence
from enum import Enum
-from typing import Any, Literal, TypeGuard, TypeVar
+from typing import Any, ClassVar, Literal, TypeVar
-from pydantic import Field
-from typing_extensions import Annotated, TypeAlias
+from pydantic import Field, model_validator
+from typing_extensions import Annotated, TypeAlias, TypeGuard
from mistral_common.base import MistralBase
from mistral_common.exceptions import InvalidAssistantMessageException
from mistral_common.protocol.instruct.chunk import (
+ AudioChunk,
+ AudioURLChunk,
+ BaseContentChunk,
ContentChunk,
+ ImageChunk,
+ ImageURLChunk,
TextChunk,
ThinkChunk,
- UserContentChunk,
_convert_openai_content_chunks,
)
from mistral_common.protocol.instruct.tool_calls import ToolCall
@@ -23,12 +28,14 @@
)
-def _are_think_chunks(think_chunks: list[ThinkChunk | TextChunk]) -> TypeGuard[list[ThinkChunk]]:
- return all(isinstance(c, ThinkChunk) for c in think_chunks)
+def _are_think_chunks(chunks: Sequence[ContentChunk]) -> TypeGuard[list[ThinkChunk]]:
+ r"""Return True if every chunk is a ThinkChunk, narrowing the sequence to a ThinkChunk list."""
+ return all(isinstance(c, ThinkChunk) for c in chunks)
-def _are_text_chunks(think_chunks: list[ThinkChunk | TextChunk]) -> TypeGuard[list[TextChunk]]:
- return all(isinstance(c, TextChunk) for c in think_chunks)
+def _are_text_chunks(chunks: Sequence[ContentChunk]) -> TypeGuard[list[TextChunk]]:
+ r"""Return True if every chunk is a TextChunk, narrowing the sequence to a TextChunk list."""
+ return all(isinstance(c, TextChunk) for c in chunks)
class ReasoningFieldFormat(str, Enum):
@@ -73,6 +80,57 @@ class BaseMessage(MistralBase):
role: Literal[Roles.system, Roles.user, Roles.assistant, Roles.tool]
+ # Allow-list of content chunk types accepted by this message. Must be set by each subclass.
+ _allowed_content_chunks: ClassVar[tuple[type[BaseContentChunk], ...]]
+
+ @model_validator(mode="after")
+ def _validate_allowed_content_chunks(self) -> "BaseMessage":
+ r"""Enforce the per-message content chunk allow-list."""
+ content = getattr(self, "content", None)
+ if isinstance(content, list):
+ for chunk in content:
+ if not isinstance(chunk, self._allowed_content_chunks):
+ raise ValueError(f"{type(chunk).__name__} cannot be used in {self.role} message.")
+ return self
+
+ @staticmethod
+ def _content_to_openai(
+ content: str | Sequence[ContentChunk] | None,
+ ) -> str | list[dict[str, Any]] | None:
+ r"""Serialize message content to OpenAI format.
+
+ Args:
+ content: String, list of content chunks, or None.
+
+ Returns:
+ String content as-is, list of chunks serialized via each chunk's
+ to_openai(), or None.
+ """
+ if content is None or isinstance(content, str):
+ return content
+ return [chunk.to_openai() for chunk in content]
+
+ @staticmethod
+ def _content_from_openai(
+ raw: str | list[dict[str, Any]] | None,
+ ) -> str | list[ContentChunk] | None:
+ r"""Deserialize content from OpenAI format.
+
+ Args:
+ raw: Raw content from OpenAI message dict.
+
+ Returns:
+ String content as-is, list of deserialized content chunks, or None.
+
+ Raises:
+ ValueError: If content type is unrecognized.
+ """
+ if raw is None or isinstance(raw, str):
+ return raw
+ if isinstance(raw, list):
+ return [_convert_openai_content_chunks(chunk) for chunk in raw]
+ raise ValueError(f"Unknown content type: {type(raw)}")
+
def to_openai(self) -> dict[str, Any]:
r"""Converts the message to the OpenAI format.
@@ -100,24 +158,24 @@ class UserMessage(BaseMessage):
"""
role: Literal[Roles.user] = Roles.user
- content: str | list[UserContentChunk]
+ content: str | list[ContentChunk]
+ _allowed_content_chunks: ClassVar[tuple[type[BaseContentChunk], ...]] = (
+ TextChunk,
+ ImageChunk,
+ ImageURLChunk,
+ AudioChunk,
+ AudioURLChunk,
+ )
def to_openai(self) -> dict[str, Any]:
r"""Converts the message to the OpenAI format."""
- if isinstance(self.content, str):
- return {"role": self.role, "content": self.content}
- return {"role": self.role, "content": [chunk.to_openai() for chunk in self.content]}
+ return {"role": self.role, "content": self._content_to_openai(self.content)}
@classmethod
def from_openai(cls, openai_message: dict[str, Any]) -> "UserMessage":
r"""Converts the OpenAI message to the Mistral format."""
- if isinstance(openai_message["content"], str):
- return cls.model_validate_ignore_extra(openai_message)
return cls.model_validate(
- {
- "role": openai_message["role"],
- "content": [_convert_openai_content_chunks(chunk) for chunk in openai_message["content"]],
- },
+ {"role": openai_message["role"], "content": cls._content_from_openai(openai_message["content"])}
)
@@ -132,16 +190,19 @@ class SystemMessage(BaseMessage):
"""
role: Literal[Roles.system] = Roles.system
- content: str | list[TextChunk | ThinkChunk]
+ content: str | list[ContentChunk]
+ _allowed_content_chunks: ClassVar[tuple[type[BaseContentChunk], ...]] = (TextChunk, AudioChunk, ThinkChunk)
def to_openai(self) -> dict[str, Any]:
r"""Converts the message to the OpenAI format."""
- return self.model_dump()
+ return {"role": self.role, "content": self._content_to_openai(self.content)}
@classmethod
def from_openai(cls, openai_message: dict[str, Any]) -> "SystemMessage":
r"""Converts the OpenAI message to the Mistral format."""
- return cls.model_validate_ignore_extra(openai_message)
+ return cls.model_validate(
+ {"role": openai_message["role"], "content": cls._content_from_openai(openai_message["content"])}
+ )
class AssistantMessage(BaseMessage):
@@ -158,7 +219,8 @@ class AssistantMessage(BaseMessage):
"""
role: Literal[Roles.assistant] = Roles.assistant
- content: str | list[TextChunk | ThinkChunk] | None = None
+ content: str | list[ContentChunk] | None = None
+ _allowed_content_chunks: ClassVar[tuple[type[BaseContentChunk], ...]] = (TextChunk, ThinkChunk)
tool_calls: list[ToolCall] | None = None
prefix: bool = False
@@ -205,7 +267,7 @@ def to_openai(
match reasoning_field_format:
case None | ReasoningFieldFormat.thinking_chunks:
- out_dict["content"] = [chunk.to_openai() for chunk in self.content]
+ out_dict["content"] = self._content_to_openai(self.content)
case ReasoningFieldFormat.reasoning | ReasoningFieldFormat.reasoning_content:
think_chunks, content_chunks = self.content[: last_think_idx + 1], self.content[last_think_idx + 1 :]
if not _are_think_chunks(think_chunks) or not _are_text_chunks(content_chunks):
@@ -216,7 +278,7 @@ def to_openai(
if len(content_chunks) == 1:
out_dict["content"] = content_chunks[0].text
elif content_chunks:
- out_dict["content"] = [chunk.to_openai() for chunk in content_chunks]
+ out_dict["content"] = self._content_to_openai(content_chunks)
case _:
raise ValueError(f"{reasoning_field_format=} is not supported.")
@@ -234,14 +296,7 @@ def from_openai(cls, openai_message: dict[str, Any]) -> "AssistantMessage":
tools_calls.append(ToolCall.from_openai(openai_tool_call))
else:
raise ValueError(f"tool_calls must be a list, got {type(openai_tool_calls)}")
- openai_content = openai_message.get("content", None)
- content: str | list[ContentChunk] | None = None
- if openai_content is None or isinstance(openai_content, str):
- content = openai_content
- elif isinstance(openai_content, list):
- content = [_convert_openai_content_chunks(chunk) for chunk in openai_content]
- else:
- raise ValueError(f"Unknown content type: {type(openai_content)}")
+ content = cls._content_from_openai(openai_message.get("content"))
reasoning_content: str | None = openai_message.get("reasoning_content")
reasoning: str | None = openai_message.get("reasoning")
@@ -308,23 +363,44 @@ class ToolMessage(BaseMessage):
>>> message = ToolMessage(content="Hello, how can I help you?", tool_call_id="123")
"""
- content: str | list[TextChunk]
+ content: str | list[ContentChunk]
role: Literal[Roles.tool] = Roles.tool
tool_call_id: str | None = None
# Deprecated in V3 tokenization
name: str | None = None
+ # Tool messages accept all content chunk types.
+ _allowed_content_chunks: ClassVar[tuple[type[BaseContentChunk], ...]] = (
+ TextChunk,
+ ImageChunk,
+ ImageURLChunk,
+ AudioChunk,
+ AudioURLChunk,
+ ThinkChunk,
+ )
+
def to_openai(self) -> dict[str, Any]:
r"""Converts the message to the OpenAI format."""
assert self.tool_call_id is not None, "tool_call_id must be provided for tool messages."
- return self.model_dump(exclude={"name"})
+ return {
+ "role": self.role,
+ "tool_call_id": self.tool_call_id,
+ "content": self._content_to_openai(self.content),
+ }
@classmethod
- def from_openai(cls, messages: dict[str, str | list[dict[str, str | dict[str, Any]]]]) -> "ToolMessage":
+ def from_openai(cls, openai_message: dict[str, Any]) -> "ToolMessage":
r"""Converts the OpenAI message to the Mistral format."""
- tool_message = cls.model_validate_ignore_extra(messages)
- assert tool_message.tool_call_id is not None, "tool_call_id must be provided for tool messages."
+ content = cls._content_from_openai(openai_message["content"])
+ tool_message = cls.model_validate(
+ {
+ "role": openai_message["role"],
+ "tool_call_id": openai_message["tool_call_id"],
+ "content": content,
+ "name": openai_message.get("name"),
+ }
+ )
return tool_message
diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py
index c96a1940..c423cda0 100644
--- a/src/mistral_common/protocol/instruct/normalize.py
+++ b/src/mistral_common/protocol/instruct/normalize.py
@@ -1,18 +1,13 @@
import json
import warnings
-from typing import Generic, Sequence, TypeGuard
+from typing import Generic, Sequence
from typing_extensions import assert_never
from mistral_common.exceptions import InvalidRequestException
from mistral_common.protocol.instruct.chunk import (
- AudioChunk,
- AudioURLChunk,
ContentChunk,
- ImageChunk,
- ImageURLChunk,
TextChunk,
- ThinkChunk,
)
from mistral_common.protocol.instruct.messages import (
UATS,
@@ -98,18 +93,6 @@ def _flush_text() -> None:
return all_content
-def _is_assistant_content(chunks: list[ContentChunk]) -> TypeGuard[list[TextChunk | ThinkChunk]]:
- """Narrow ContentChunk list to assistant-compatible types."""
- return all(isinstance(c, (TextChunk, ThinkChunk)) for c in chunks)
-
-
-def _is_user_content(
- chunks: list[ContentChunk],
-) -> TypeGuard[list[TextChunk | ImageChunk | ImageURLChunk | AudioChunk | AudioURLChunk]]:
- """Narrow ContentChunk list to user-compatible types."""
- return all(not isinstance(c, ThinkChunk) for c in chunks)
-
-
class InstructRequestNormalizer(
Generic[UserMessageType, AssistantMessageType, ToolMessageType, SystemMessageType, InstructRequestType]
):
@@ -244,13 +227,17 @@ def _aggregate_system_prompts(self, messages: list[UATS]) -> str | None:
def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]:
"""Normalize tool messages without aggregation across messages.
- Each tool message's content chunks are aggregated individually and JSON is normalized.
+ Each tool message's content is JSON-normalized; chunk types are guaranteed by the validator.
"""
tool_messages: list[ToolMessageType] = []
for message in messages:
assert isinstance(message, self._tool_message_class), "Expected tool message"
- content = self._aggregate_content_chunks_to_str_same_message(message)
+ content = self._aggregate_content_chunks([message])
+ assert isinstance(content, str), (
+ f"Unexpected content chunk types in tool message: {[type(c).__name__ for c in content]}"
+ )
normalized_content = self._normalize_json_content(content)
+
tool_messages.append(
self._tool_message_class(
content=normalized_content, tool_call_id=message.tool_call_id, name=message.name
@@ -296,15 +283,8 @@ def _aggregate_assistant_messages(self, messages: list[UATS]) -> AssistantMessag
)
weight = message.weight
- if isinstance(content, str) or _is_assistant_content(content):
- narrowed_content: str | list[TextChunk | ThinkChunk] = content
- else:
- raise InvalidRequestException(
- f"Unexpected content chunk types in assistant message: {[type(c).__name__ for c in content]}"
- )
-
aggregated_message = self._assistant_message_class(
- content=narrowed_content,
+ content=content,
tool_calls=tool_calls or None,
prefix=prefix,
)
@@ -316,12 +296,7 @@ def _aggregate_assistant_messages(self, messages: list[UATS]) -> AssistantMessag
def _aggregate_user_messages(self, messages: list[UATS]) -> UserMessageType:
"""Coalesce neighboring blocks of ContentChunks in user messages."""
content = self._aggregate_content_chunks(messages)
- if isinstance(content, str) or _is_user_content(content):
- return self._user_message_class(content=content)
- else:
- raise InvalidRequestException(
- f"Unexpected content chunk types in user message: {[type(c).__name__ for c in content]}"
- )
+ return self._user_message_class(content=content)
def _aggregate_role(self, messages: list[UATS], role: Roles | None, latest_call_ids: list[str]) -> Sequence[UATS]:
if role == Roles.tool:
@@ -444,12 +419,31 @@ def normalizer(model_settings_builder: ModelSettingsBuilder | None = None) -> "I
UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest[UATS, Tool], None
)
+ def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]:
+ """Aggregate each tool message's content without JSON normalization.
+
+ V7+ normalizers skip JSON content normalization for tool messages (chunk-type validation is
+ handled by the validator).
+ """
+ tool_messages: list[ToolMessageType] = []
+ for message in messages:
+ assert isinstance(message, self._tool_message_class), "Expected tool message"
+ content = self._aggregate_content_chunks([message])
+ assert isinstance(content, str), (
+ f"Unexpected content chunk types in tool message: {[type(c).__name__ for c in content]}"
+ )
+ tool_messages.append(
+ self._tool_message_class(content=content, tool_call_id=message.tool_call_id, name=message.name)
+ )
+ return tool_messages
+
def _aggregate_system_messages(self, messages: list[UATS]) -> list[SystemMessageType]:
- return [
- self._system_message_class(content=self._aggregate_content_chunks([message]))
- for message in messages
- if isinstance(message, self._system_message_class)
- ]
+ aggregated: list[SystemMessageType] = []
+ for message in messages:
+ if isinstance(message, self._system_message_class):
+ content = self._aggregate_content_chunks([message])
+ aggregated.append(self._system_message_class(content=content))
+ return aggregated
def _aggregate_role(self, messages: list[UATS], role: Roles | None, latest_call_ids: list[str]) -> Sequence[UATS]:
if role == Roles.tool:
@@ -530,8 +524,8 @@ def normalizer(model_settings_builder: ModelSettingsBuilder | None = None) -> "I
UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest[UATS, Tool], None
)
- def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]:
- tool_messages: list[ToolMessageType] = super()._aggregate_tool_messages(messages, latest_call_ids)
+ @staticmethod
+ def _inplace_sort_tool_messages(tool_messages: list[ToolMessageType], latest_call_ids: list[str]) -> None:
id_to_tool_call_idx = {call_id: idx for idx, call_id in enumerate(latest_call_ids)}
id_to_tool_result_idx = {message.tool_call_id: idx for idx, message in enumerate(tool_messages)}
# First order by tool call idx and then by tool result idx
@@ -541,6 +535,10 @@ def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[s
id_to_tool_result_idx[msg.tool_call_id],
),
)
+
+ def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]:
+ tool_messages: list[ToolMessageType] = super()._aggregate_tool_messages(messages, latest_call_ids)
+ self._inplace_sort_tool_messages(tool_messages=tool_messages, latest_call_ids=latest_call_ids)
return tool_messages
@@ -555,6 +553,18 @@ class InstructRequestNormalizerV15(InstructRequestNormalizerV13):
_chunk_join_str: str = ""
+ def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]:
+ r"""V15 keeps all aggregated tool content (validation handled by the validator)."""
+ tool_messages: list[ToolMessageType] = []
+ for message in messages:
+ assert isinstance(message, self._tool_message_class), "Expected tool message"
+ content = self._aggregate_content_chunks([message])
+ tool_messages.append(
+ self._tool_message_class(content=content, tool_call_id=message.tool_call_id, name=message.name)
+ )
+ self._inplace_sort_tool_messages(tool_messages=tool_messages, latest_call_ids=latest_call_ids)
+ return tool_messages
+
@staticmethod
def normalizer(model_settings_builder: ModelSettingsBuilder | None = None) -> "InstructRequestNormalizerV15":
r"""Returns a normalizer for the V15 instruct request.
diff --git a/src/mistral_common/protocol/instruct/validator.py b/src/mistral_common/protocol/instruct/validator.py
index f121e49e..e15d0935 100644
--- a/src/mistral_common/protocol/instruct/validator.py
+++ b/src/mistral_common/protocol/instruct/validator.py
@@ -1,4 +1,5 @@
import re
+from collections.abc import Sequence
from enum import Enum
from typing import Generic
@@ -14,8 +15,18 @@
InvalidToolException,
InvalidToolMessageException,
InvalidToolSchemaException,
+ InvalidUserMessageException,
+ MistralCommonException,
+)
+from mistral_common.protocol.instruct.chunk import (
+ AudioChunk,
+ AudioURLChunk,
+ ContentChunk,
+ ImageChunk,
+ ImageURLChunk,
+ TextChunk,
+ ThinkChunk,
)
-from mistral_common.protocol.instruct.chunk import AudioChunk, AudioURLChunk
from mistral_common.protocol.instruct.messages import (
UATS,
AssistantMessage,
@@ -38,6 +49,23 @@
from mistral_common.tokens.tokenizers.base import TokenizerVersion
+def _validate_content_chunk_types(
+ content: str | Sequence[ContentChunk] | None,
+ allowed: tuple[type[ContentChunk], ...],
+ role: str,
+ exception_cls: type[MistralCommonException],
+) -> None:
+ r"""Raise `exception_cls` if a list content holds a chunk that is not one of the `allowed` types.
+
+ String or None content is always accepted (covered elsewhere).
+ """
+ if not isinstance(content, list):
+ return
+ invalid = sorted({type(chunk).__name__ for chunk in content if not isinstance(chunk, allowed)})
+ if invalid:
+ raise exception_cls(f"Unexpected content chunk types in {role} message: {invalid}")
+
+
class ValidationMode(str, Enum):
r"""Enum for the validation mode.
@@ -153,13 +181,30 @@ def _validate_tools(self, tools: list[Tool]) -> None:
self._validate_function(tool.function)
def _validate_user_message(self, message: UserMessageType) -> None:
- pass
+ self._validate_user_content_chunks(message.content)
+
+ def _validate_user_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None:
+ r"""v1/v2 user messages accept text content only (image >= v3, audio >= v7)."""
+ _validate_content_chunk_types(content, (TextChunk,), "user", InvalidUserMessageException)
+
+ def _validate_assistant_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None:
+ r"""Pre-v11 assistant messages accept text content only."""
+ _validate_content_chunk_types(content, (TextChunk,), "assistant", InvalidAssistantMessageException)
+
+ def _validate_system_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None:
+ r"""v1-v3 system messages accept text content only."""
+ _validate_content_chunk_types(content, (TextChunk,), "system", InvalidSystemPromptException)
+
+ def _validate_tool_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None:
+ r"""Pre-v15 tool messages accept text content only."""
+ _validate_content_chunk_types(content, (TextChunk,), "tool", InvalidToolMessageException)
def _validate_tool_message(self, message: ToolMessageType) -> None:
"""
Checks:
- The tool name is valid
"""
+ self._validate_tool_content_chunks(message.content)
if message.name is not None:
if not re.match(r"^[a-zA-Z0-9_-]{1,64}$", message.name):
raise InvalidToolMessageException(
@@ -174,6 +219,7 @@ def _validate_system_message(self, message: SystemMessageType) -> None:
"""
if message.content is None:
raise InvalidSystemPromptException("System prompt must have content")
+ self._validate_system_content_chunks(message.content)
def _validate_function_call(self, function_call: FunctionCall) -> None:
"""
@@ -201,6 +247,8 @@ def _validate_assistant_message(self, message: AssistantMessageType, is_last_mes
- That the tool calls are valid
"""
+ self._validate_assistant_content_chunks(message.content)
+
# Validate that the message has either text or tool_calls
# but not both and not neither.
if (not self._allow_tool_call_and_content) and (bool(message.content) == bool(message.tool_calls)):
@@ -359,18 +407,19 @@ class MistralRequestValidatorV3(MistralRequestValidator):
>>> validator = MistralRequestValidatorV3()
"""
+ def _validate_user_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None:
+ r"""v3 user messages accept text and image chunks (audio >= v7)."""
+ _validate_content_chunk_types(
+ content, (TextChunk, ImageChunk, ImageURLChunk), "user", InvalidUserMessageException
+ )
+
def _validate_tool_message(self, message: ToolMessageType) -> None:
"""
Checks:
- The tool name is valid
- Tool call id is valid
"""
- if message.name is not None:
- if not re.match(r"^[a-zA-Z0-9_-]{1,64}$", message.name):
- raise InvalidToolMessageException(
- f"Function name was {message.name} but must be a-z, A-Z, 0-9, "
- "or contain underscores and dashes, with a maximum length of 64."
- )
+ super()._validate_tool_message(message)
if message.tool_call_id is None:
raise InvalidRequestException("Tool call id has to be defined.")
@@ -423,6 +472,21 @@ class MistralRequestValidatorV5(MistralRequestValidatorV3):
_allow_tool_call_and_content: bool = True
+ def _validate_user_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None:
+ r"""v7+ user messages accept text, image and audio chunks."""
+ _validate_content_chunk_types(
+ content,
+ (TextChunk, ImageChunk, ImageURLChunk, AudioChunk, AudioURLChunk),
+ "user",
+ InvalidUserMessageException,
+ )
+
+ def _validate_system_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None:
+ r"""v7+ system messages accept text, audio and thinking chunks."""
+ _validate_content_chunk_types(
+ content, (TextChunk, AudioChunk, ThinkChunk), "system", InvalidSystemPromptException
+ )
+
def _validate_system_prompt_and_audio(self, messages: list[UATS]) -> None:
r"""Validates that system prompts and audio chunks are not used together in v5."""
@@ -507,12 +571,24 @@ def _validate_tool_calls_followed_by_tool_messages(self, messages: list[UATS]) -
elif len(expected_tool_ids) < len(observed_tool_ids) and self._mode == ValidationMode.finetuning:
raise InvalidMessageStructureException("More tool responses than tool calls")
+ def _validate_assistant_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None:
+ r"""v11+ assistant messages accept text and thinking chunks."""
+ _validate_content_chunk_types(content, (TextChunk, ThinkChunk), "assistant", InvalidAssistantMessageException)
+
def _validate_system_prompt_and_audio(self, messages: list[UATS]) -> None:
r"""Allows system prompts and audio chunks to coexist in v13."""
return
class MistralRequestValidatorV15(MistralRequestValidatorV13):
+ def _validate_system_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None:
+ r"""v15 system messages accept text and audio but reject thinking chunks."""
+ _validate_content_chunk_types(content, (TextChunk, AudioChunk), "system", InvalidSystemPromptException)
+
+ def _validate_tool_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None:
+ r"""v15 tool messages accept all content chunk types."""
+ return
+
def _validate_model_settings(self, request: ChatCompletionRequest) -> None:
pass
diff --git a/src/mistral_common/tokens/tokenizers/base.py b/src/mistral_common/tokens/tokenizers/base.py
index 1fe580fa..06dff715 100644
--- a/src/mistral_common/tokens/tokenizers/base.py
+++ b/src/mistral_common/tokens/tokenizers/base.py
@@ -9,7 +9,7 @@
from mistral_common.base import MistralBase
from mistral_common.protocol.fim.request import FIMRequest
-from mistral_common.protocol.instruct.chunk import UserContentChunk
+from mistral_common.protocol.instruct.chunk import ContentChunk
from mistral_common.protocol.instruct.messages import (
AssistantMessageType,
UserMessage,
@@ -428,7 +428,7 @@ def encode_user_message(
@abstractmethod
def encode_user_content(
self,
- content: str | list[UserContentChunk],
+ content: str | list[ContentChunk],
is_last: bool,
system_prompt: str | None = None,
force_img_first: bool = False,
diff --git a/src/mistral_common/tokens/tokenizers/instruct.py b/src/mistral_common/tokens/tokenizers/instruct.py
index 7992df1c..7d0e5ab8 100644
--- a/src/mistral_common/tokens/tokenizers/instruct.py
+++ b/src/mistral_common/tokens/tokenizers/instruct.py
@@ -20,7 +20,6 @@
ImageURLChunk,
TextChunk,
ThinkChunk,
- UserContentChunk,
)
from mistral_common.protocol.instruct.messages import (
UATS,
@@ -104,7 +103,9 @@ def find_first_last_user(request: InstructRequest) -> tuple[int, int]:
return first_user_idx, last_user_idx
@abstractmethod
- def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: bool) -> list[int]:
+ def encode_tool_message(
+ self, message: ToolMessage, is_before_last_user_message: bool
+ ) -> tuple[list[int], list[np.ndarray], list[Audio]]:
r"""Encode a tool message.
Raises:
@@ -192,7 +193,9 @@ def encode_instruct(
images.extend(new_images)
audios.extend(new_audios)
elif isinstance(msg, ToolMessage):
- new_tokens = self.encode_tool_message(msg, msg_idx < last_user_idx)
+ new_tokens, new_images, new_audios = self.encode_tool_message(msg, msg_idx < last_user_idx)
+ images.extend(new_images)
+ audios.extend(new_audios)
elif isinstance(msg, AssistantMessage):
continue_message = request.continue_final_message and (msg_idx == len(request.messages) - 1)
@@ -202,7 +205,8 @@ def encode_instruct(
if msg_idx == len(request.messages) - 1:
prefix_ids = new_tokens
elif isinstance(msg, SystemMessage):
- new_tokens = self.encode_system_message(msg)
+ new_tokens, new_audios = self.encode_system_message(msg)
+ audios.extend(new_audios)
else:
raise TokenizerException(f"Unknown message type {type(msg)}")
@@ -289,12 +293,12 @@ def encode_user_message(
curr_tokens, image, audio = self.encode_user_content(content=message_txt, is_last=False, system_prompt=None)
return curr_tokens, image, audio
- def encode_system_message(self, message: SystemMessage) -> list[int]:
+ def encode_system_message(self, message: SystemMessage) -> tuple[list[int], list[Audio]]:
raise NotImplementedError(f"System message encoding not implemented for {self.__class__.__name__}")
def encode_user_content(
self,
- content: str | list[UserContentChunk],
+ content: str | list[ContentChunk],
is_last: bool,
system_prompt: str | None = None,
force_img_first: bool = False,
@@ -318,7 +322,9 @@ def encode_user_content(
tokens = self.tokenizer.encode(content, bos=False, eos=False)
return tokens, [], []
- def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: bool) -> list[int]:
+ def encode_tool_message(
+ self, message: ToolMessage, is_before_last_user_message: bool
+ ) -> tuple[list[int], list[np.ndarray], list[Audio]]:
r"""Encode a tool message.
Raises:
@@ -480,9 +486,15 @@ def _parse_json_content(self, content: str) -> Any:
except json.JSONDecodeError:
return content
- def _parse_tool_content(self, content: str | list[TextChunk]) -> Any:
+ def _parse_tool_content(self, content: str | list[ContentChunk]) -> Any:
if isinstance(content, list):
- content = "".join(chunk.text for chunk in content)
+ text_parts: list[str] = []
+ for chunk in content:
+ assert isinstance(chunk, TextChunk), (
+ f"Tool content only supports text chunks, got {type(chunk).__name__}."
+ )
+ text_parts.append(chunk.text)
+ content = "".join(text_parts)
return self._parse_json_content(content)
def _prepare_tool_result(self, tool_message: ToolMessage) -> dict[str, Any]:
@@ -492,7 +504,9 @@ def _prepare_tool_result(self, tool_message: ToolMessage) -> dict[str, Any]:
"content": self._parse_tool_content(tool_message.content),
}
- def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: bool) -> list[int]:
+ def encode_tool_message(
+ self, message: ToolMessage, is_before_last_user_message: bool
+ ) -> tuple[list[int], list[np.ndarray], list[Audio]]:
r"""Encode a tool message.
Args:
@@ -501,11 +515,11 @@ def encode_tool_message(self, message: ToolMessage, is_before_last_user_message:
not encoded.
Returns:
- The encoded tokens.
+ The encoded tokens, images, and audios.
"""
if is_before_last_user_message:
# don't tokenize last tool response before last user msg
- return []
+ return [], [], []
# Currently only supports single tool results
tool_result_str = json.dumps([self._prepare_tool_result(message)], ensure_ascii=False)
@@ -514,7 +528,7 @@ def encode_tool_message(self, message: ToolMessage, is_before_last_user_message:
*self.tokenizer.encode(tool_result_str, bos=False, eos=False),
self.END_TOOL_RESULTS,
]
- return curr_tokens
+ return curr_tokens, [], []
def _prepare_function_call(self, tool_call: ToolCall) -> dict[str, Any]:
r"""Bit of a hack due to the way function calls are tokenized."""
@@ -652,7 +666,9 @@ def _prepare_tool_result(self, tool_message: ToolMessage) -> dict[str, Any]:
"call_id": tool_message.tool_call_id,
}
- def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: bool) -> list[int]:
+ def encode_tool_message(
+ self, message: ToolMessage, is_before_last_user_message: bool
+ ) -> tuple[list[int], list[np.ndarray], list[Audio]]:
r"""Encode a tool message.
Note:
@@ -665,7 +681,7 @@ def encode_tool_message(self, message: ToolMessage, is_before_last_user_message:
not encoded.
Returns:
- The encoded tokens.
+ The encoded tokens, images, and audios.
"""
tool_result_str = json.dumps(self._prepare_tool_result(message), ensure_ascii=False)
curr_tokens = [
@@ -673,7 +689,7 @@ def encode_tool_message(self, message: ToolMessage, is_before_last_user_message:
*self.tokenizer.encode(tool_result_str, bos=False, eos=False),
self.END_TOOL_RESULTS,
]
- return curr_tokens
+ return curr_tokens, [], []
def encode_assistant_message(
self, message: AssistantMessageType, is_before_last_user_message: bool, continue_message: bool
@@ -741,7 +757,7 @@ def _encode_content_chunks(
def encode_user_content(
self,
- content: str | list[UserContentChunk],
+ content: str | list[ContentChunk],
is_last: bool,
system_prompt: str | None = None,
force_img_first: bool = False,
@@ -781,13 +797,18 @@ def encode_user_content(
assert not content_str, (
f"It is not possible that `content` is non-empty when chunk is of type {type(chunk)}."
)
- chunk_tokens, _, chunk_audio = self._encode_content_chunk(chunk)
+ chunk_tokens, maybe_image, chunk_audio = self._encode_content_chunk(chunk)
+ assert maybe_image is None, f"Unexpected image for audio chunk {type(chunk).__name__}."
audio.append(chunk_audio)
elif isinstance(chunk, (ImageChunk, ImageURLChunk)):
- chunk_tokens, chunk_image, _ = self._encode_content_chunk(chunk)
+ chunk_tokens, chunk_image, maybe_audio = self._encode_content_chunk(chunk)
+ assert maybe_audio is None, f"Unexpected audio for image chunk {type(chunk).__name__}."
images.append(chunk_image)
else:
- chunk_tokens = self._encode_content_chunk(chunk)[0]
+ chunk_tokens, maybe_image, maybe_audio = self._encode_content_chunk(chunk)
+ assert maybe_image is None and maybe_audio is None, (
+ f"Unexpected image/audio for chunk {type(chunk).__name__}."
+ )
tokens.extend(chunk_tokens)
return tokens, images, audio
@@ -867,26 +888,27 @@ def drop(idx: int) -> None:
if to_drop > 0:
raise TokenizerException("Input couldn't fit in truncate_at_max_token")
- def encode_system_message(self, message: SystemMessage) -> list[int]:
+ def encode_system_message(self, message: SystemMessage) -> tuple[list[int], list[Audio]]:
r"""Encode a system message.
Args:
message: The message to encode.
Returns:
- The encoded tokens.
+ The encoded tokens and audios.
"""
-
tokens = [self.BEGIN_SYSTEM]
if isinstance(content := message.content, str):
content = [TextChunk(text=content)]
- tokens += self._encode_content_chunks(content)[0]
+ content_tokens, images, audios = self._encode_content_chunks(content)
+ assert not images, f"System messages cannot contain images, got {len(images)}."
+ tokens += content_tokens
tokens.append(self.END_SYSTEM)
- return tokens
+ return tokens, audios
def encode_user_content(
self,
- content: str | list[UserContentChunk],
+ content: str | list[ContentChunk],
is_last: bool,
system_prompt: str | None = None,
force_img_first: bool = False,
@@ -983,7 +1005,7 @@ def _encode_instruct_transcription(self, request: TranscriptionRequest) -> Token
)
assert self.TRANSCRIBE is not None, f"{self.__class__.__name__} needs to have a TRANSCRIBE token"
prefix = self.start()
- tokens, _, audio = self.encode_user_message(
+ tokens, images, audio = self.encode_user_message(
UserMessage(content=[AudioChunk(input_audio=request.audio)]),
available_tools=[],
is_last=True,
@@ -991,6 +1013,7 @@ def _encode_instruct_transcription(self, request: TranscriptionRequest) -> Token
system_prompt=None,
settings=ModelSettings.none(),
)
+ assert not images, f"Transcription input cannot contain images, got {len(images)}."
tokens = [*prefix, *tokens]
if request.language is not None:
@@ -1081,7 +1104,9 @@ def _has_audio(messages: list[UATS]) -> bool:
for message in messages
)
- def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: bool) -> list[int]:
+ def encode_tool_message(
+ self, message: ToolMessage, is_before_last_user_message: bool
+ ) -> tuple[list[int], list[np.ndarray], list[Audio]]:
r"""Encode a tool message.
Note:
@@ -1093,7 +1118,7 @@ def encode_tool_message(self, message: ToolMessage, is_before_last_user_message:
is_before_last_user_message: Not used.
Returns:
- The encoded tokens.
+ The encoded tokens, images, and audios.
"""
assert message.tool_call_id is not None
assert isinstance(message.content, str), "Message content must be normalized"
@@ -1110,7 +1135,7 @@ def encode_tool_message(self, message: ToolMessage, is_before_last_user_message:
*tokens,
self.END_TOOL_RESULTS,
]
- return curr_tokens
+ return curr_tokens, [], []
def encode_assistant_message(
self, message: AssistantMessageType, is_before_last_user_message: bool, continue_message: bool
@@ -1138,7 +1163,12 @@ def encode_assistant_message(
if isinstance(message.content, str):
curr_tokens = self._encode_normal_content_assistant_message(message)
elif isinstance(message.content, list):
- curr_tokens += self._encode_content_chunks(message.content)[0]
+ content_tokens, images, audios = self._encode_content_chunks(message.content)
+ assert not images and not audios, (
+ f"Assistant messages cannot contain images or audios, got {len(images)} images "
+ f"and {len(audios)} audios."
+ )
+ curr_tokens += content_tokens
if message.tool_calls:
curr_tokens += self._encode_tool_calls_in_assistant_message(message)
if not message.prefix and not continue_message:
@@ -1282,28 +1312,32 @@ def _encode_tool_calls_in_assistant_message(self, message: AssistantMessageType)
]
return curr_tokens
- def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: bool) -> list[int]:
+ def encode_tool_message(
+ self, message: ToolMessage, is_before_last_user_message: bool
+ ) -> tuple[list[int], list[np.ndarray], list[Audio]]:
r"""Encode a tool message.
Args:
message: The message to encode.
is_before_last_user_message: Not used.
Returns:
- The encoded tokens.
+ The encoded tokens, images, and audios.
"""
assert message.tool_call_id is not None, "Tool call id must be provided for tokenizer >= v13"
- content = message.content
- if not isinstance(content, str):
- content = "".join(chunk.text for chunk in content)
+ if isinstance(message.content, str):
+ content_tokens = self.tokenizer.encode(message.content, bos=False, eos=False)
+ images: list[np.ndarray] = []
+ audios: list[Audio] = []
+ else:
+ content_tokens, images, audios = self._encode_content_chunks(message.content)
- tokens = self.tokenizer.encode(content, bos=False, eos=False)
curr_tokens = [
self.BEGIN_TOOL_RESULTS,
- *tokens,
+ *content_tokens,
self.END_TOOL_RESULTS,
]
- return curr_tokens
+ return curr_tokens, images, audios
def encode_think(self, chunk: ThinkChunk) -> list[int]:
r"""Encode a thinking chunk.
@@ -1368,7 +1402,7 @@ def _encode_settings(
]
return settings_tokens
- def encode_system_message(self, message: SystemMessage) -> list[int]:
+ def encode_system_message(self, message: SystemMessage) -> tuple[list[int], list[Audio]]:
r"""Encode a system message, rejecting ThinkChunk content."""
if isinstance(message.content, list):
if any(isinstance(chunk, ThinkChunk) for chunk in message.content):
diff --git a/tests/data/chat_templates/v15.jinja b/tests/data/chat_templates/v15.jinja
index 26b4f9d2..e4df96ef 100644
--- a/tests/data/chat_templates/v15.jinja
+++ b/tests/data/chat_templates/v15.jinja
@@ -180,9 +180,11 @@
{{- eos_token }}
- {#- Tool messages only supports text content. #}
+ {#- Tool messages (multimodal). #}
{%- elif message['role'] == 'tool' %}
- {{- '[TOOL_RESULTS]' + message['content']|string + '[/TOOL_RESULTS]' }}
+ {{- '[TOOL_RESULTS]' -}}
+ {{- render_content(message['content'], 'tool message contents') -}}
+ {{- '[/TOOL_RESULTS]' }}
{#- System messages. #}
{%- elif message['role'] == 'system' %}
diff --git a/tests/data/chat_templates/v15_audio.jinja b/tests/data/chat_templates/v15_audio.jinja
index 68ce5346..f3c90f1c 100644
--- a/tests/data/chat_templates/v15_audio.jinja
+++ b/tests/data/chat_templates/v15_audio.jinja
@@ -177,14 +177,16 @@
{{- eos_token }}
- {#- Tool messages only supports text content. #}
+ {#- Tool messages (multimodal). #}
{%- elif message['role'] == 'tool' %}
- {{- '[TOOL_RESULTS]' + message['content']|string + '[/TOOL_RESULTS]' }}
+ {{- '[TOOL_RESULTS]' -}}
+ {{- render_content(message['content'], 'tool message contents', supported_types_desc='text and audio', support_audio=true) -}}
+ {{- '[/TOOL_RESULTS]' }}
{#- System messages. #}
{%- elif message['role'] == 'system' %}
{{- '[SYSTEM_PROMPT]' -}}
- {{- render_content(message['content'], 'system message contents', supported_types_desc='text', support_audio=false) -}}
+ {{- render_content(message['content'], 'system message contents', supported_types_desc='text and audio', support_audio=true) -}}
{{- '[/SYSTEM_PROMPT]' -}}
{#- Raise exception for unsupported roles. #}
diff --git a/tests/data/chat_templates/v15_image.jinja b/tests/data/chat_templates/v15_image.jinja
index bb338233..bfbfdd17 100644
--- a/tests/data/chat_templates/v15_image.jinja
+++ b/tests/data/chat_templates/v15_image.jinja
@@ -188,9 +188,11 @@
{{- eos_token }}
- {#- Tool messages only supports text content. #}
+ {#- Tool messages (multimodal). #}
{%- elif message['role'] == 'tool' %}
- {{- '[TOOL_RESULTS]' + message['content']|string + '[/TOOL_RESULTS]' }}
+ {{- '[TOOL_RESULTS]' -}}
+ {{- render_content(message['content'], 'tool message contents', supported_types_desc='text and image', support_images=true) -}}
+ {{- '[/TOOL_RESULTS]' }}
{#- System messages. #}
{%- elif message['role'] == 'system' %}
diff --git a/tests/data/chat_templates/v15_image_think.jinja b/tests/data/chat_templates/v15_image_think.jinja
index 63c61f6b..e1f26e8f 100644
--- a/tests/data/chat_templates/v15_image_think.jinja
+++ b/tests/data/chat_templates/v15_image_think.jinja
@@ -215,9 +215,11 @@
{{- eos_token }}
- {#- Tool messages only supports text content. #}
+ {#- Tool messages (multimodal). #}
{%- elif message['role'] == 'tool' %}
- {{- '[TOOL_RESULTS]' + message['content']|string + '[/TOOL_RESULTS]' }}
+ {{- '[TOOL_RESULTS]' -}}
+ {{- render_content(message['content'], 'tool message contents', supported_types_desc='text and image', support_images=true) -}}
+ {{- '[/TOOL_RESULTS]' }}
{#- System messages. #}
{%- elif message['role'] == 'system' %}
diff --git a/tests/data/chat_templates/v15_think.jinja b/tests/data/chat_templates/v15_think.jinja
index 7df9072a..40b8a7fc 100644
--- a/tests/data/chat_templates/v15_think.jinja
+++ b/tests/data/chat_templates/v15_think.jinja
@@ -207,9 +207,11 @@
{{- eos_token }}
- {#- Tool messages only supports text content. #}
+ {#- Tool messages (multimodal). #}
{%- elif message['role'] == 'tool' %}
- {{- '[TOOL_RESULTS]' + message['content']|string + '[/TOOL_RESULTS]' }}
+ {{- '[TOOL_RESULTS]' -}}
+ {{- render_content(message['content'], 'tool message contents') -}}
+ {{- '[/TOOL_RESULTS]' }}
{#- System messages. #}
{%- elif message['role'] == 'system' %}
diff --git a/tests/fixtures/chunks.py b/tests/fixtures/chunks.py
new file mode 100644
index 00000000..289765f5
--- /dev/null
+++ b/tests/fixtures/chunks.py
@@ -0,0 +1,42 @@
+from typing import Any
+
+from PIL import Image
+
+from mistral_common.protocol.instruct.chunk import (
+ ContentChunk,
+ ImageChunk,
+ ImageURLChunk,
+ TextChunk,
+ ThinkChunk,
+)
+from tests.fixtures.audio import get_dummy_audio_chunk, get_dummy_audio_url_chunk
+
+
+def get_content_chunk(name: str) -> ContentChunk:
+ r"""Return a single instance of the requested content chunk type.
+
+ Args:
+ name: One of "text", "image", "image_url", "audio", "audio_url" or "think".
+
+ Returns:
+ A content chunk of the requested type.
+ """
+ chunks: dict[str, ContentChunk] = {
+ "text": TextChunk(text="hello"),
+ "image": ImageChunk(image=Image.new("RGB", (4, 4), "red")),
+ "image_url": ImageURLChunk(image_url="data:image/png;base64,iVBORw0"),
+ "audio": get_dummy_audio_chunk(),
+ "audio_url": get_dummy_audio_url_chunk(),
+ "think": ThinkChunk(thinking="reasoning"),
+ }
+ return chunks[name]
+
+
+def get_content_chunks(names: tuple[str, ...]) -> list[Any]:
+ r"""Return a list of content chunks for the requested type names.
+
+ The element type is intentionally `Any` so the chunks can be fed to any message role
+ constructor (including deliberately invalid combinations) in tests, leaving the runtime
+ Pydantic validation to enforce the actual rules.
+ """
+ return [get_content_chunk(name) for name in names]
diff --git a/tests/integrations/chat_templates/fixtures_data.py b/tests/integrations/chat_templates/fixtures_data.py
index 6f0e6e55..b396f457 100644
--- a/tests/integrations/chat_templates/fixtures_data.py
+++ b/tests/integrations/chat_templates/fixtures_data.py
@@ -952,6 +952,75 @@
]
)
+# -- Multimodal content in non-user messages (v15+) --
+
+REQUEST_TOOL_IMAGE_TRAIN = ChatCompletionRequest( # type: ignore[type-var]
+ messages=[
+ UserMessage(content="What is in this image?"),
+ AssistantMessage(
+ content=None,
+ tool_calls=[
+ ToolCall(
+ id="tl1mg2345",
+ function=FunctionCall(
+ name="tool1",
+ arguments={"location": "San Francisco, CA"}, # type: ignore[arg-type]
+ ),
+ ),
+ ],
+ ),
+ ToolMessage(
+ content=[
+ TextChunk(text="Here is the result."),
+ ImageURLChunk(image_url=_IMAGE_URL),
+ ],
+ tool_call_id="tl1mg2345",
+ ),
+ AssistantMessage(content="The tool returned an image of a red square."),
+ ],
+ tools=_TOOLS,
+)
+
+REQUEST_TOOL_AUDIO_TRAIN = ChatCompletionRequest( # type: ignore[type-var]
+ messages=[
+ UserMessage(content="What does this sound like?"),
+ AssistantMessage(
+ content=None,
+ tool_calls=[
+ ToolCall(
+ id="tl1mg2345",
+ function=FunctionCall(
+ name="tool1",
+ arguments={"location": "San Francisco, CA"}, # type: ignore[arg-type]
+ ),
+ ),
+ ],
+ ),
+ ToolMessage(
+ content=[
+ TextChunk(text="Here is the audio result."),
+ AudioChunk(input_audio=_AUDIO),
+ ],
+ tool_call_id="tl1mg2345",
+ ),
+ AssistantMessage(content="The tool returned audio of a sine wave."),
+ ],
+ tools=_TOOLS,
+)
+
+REQUEST_SYSTEM_AUDIO_TRAIN = ChatCompletionRequest( # type: ignore[type-var]
+ messages=[
+ SystemMessage(
+ content=[
+ TextChunk(text="You are an audio assistant."),
+ AudioChunk(input_audio=_AUDIO),
+ ],
+ ),
+ UserMessage(content="What was that sound?"),
+ AssistantMessage(content="That was a sine wave tone."),
+ ],
+)
+
def _get_conversations(
tokenizer_version: TokenizerVersion,
@@ -1102,6 +1171,18 @@ def _get_conversations(
]
)
+ # v15+ only: multimodal content in non-user messages (finetuning only)
+ if tokenizer_version >= TokenizerVersion.v15 and validation_mode == ValidationMode.finetuning:
+ if image:
+ conversations.extend(
+ [
+ REQUEST_TOOL_IMAGE_TRAIN,
+ ]
+ )
+ if audio:
+ conversations.append(REQUEST_SYSTEM_AUDIO_TRAIN)
+ conversations.append(REQUEST_TOOL_AUDIO_TRAIN)
+
conversations = [c.model_copy(deep=True) for c in conversations]
if think and tokenizer_version >= TokenizerVersion.v15:
diff --git a/tests/integrations/chat_templates/hf_utils.py b/tests/integrations/chat_templates/hf_utils.py
index 9e8133d2..6d071b3d 100644
--- a/tests/integrations/chat_templates/hf_utils.py
+++ b/tests/integrations/chat_templates/hf_utils.py
@@ -1,7 +1,7 @@
r"""HuggingFace transformers/tokenizers utilities for chat template tests.
-This module contains helpers that depend on ``transformers`` and
-``tokenizers``. General-purpose helpers live in ``helpers.py``.
+This module contains helpers that depend on `transformers` and
+`tokenizers`. General-purpose helpers live in `helpers.py`.
"""
import json
diff --git a/tests/integrations/chat_templates/test_parity.py b/tests/integrations/chat_templates/test_parity.py
index 5394d5cb..56d79c5e 100644
--- a/tests/integrations/chat_templates/test_parity.py
+++ b/tests/integrations/chat_templates/test_parity.py
@@ -3,8 +3,11 @@
import pytest
from mistral_common.integrations.chat_templates.template_generator import build_chat_template
+from mistral_common.protocol.instruct.request import ChatCompletionRequest
+from mistral_common.protocol.instruct.validator import ValidationMode
from mistral_common.tokens.tokenizers.base import TokenizerVersion
from tests.integrations.chat_templates.conftest import ALL_CONFIGS, _config_id
+from tests.integrations.chat_templates.fixtures_data import _get_conversations
from tests.integrations.chat_templates.helpers import (
TestConfig,
_load_golden_template,
@@ -13,6 +16,42 @@
)
+def _request_to_render_args(request: ChatCompletionRequest) -> dict[str, Any]:
+ r"""Convert a ChatCompletionRequest to render_template kwargs.
+
+ Enriches tool messages with the `name` field that the v2 Jinja template
+ requires but `ToolMessage.to_openai()` omits. The name is looked up by
+ `tool_call_id` among the `tool_calls` of all messages in the request.
+ """
+ openai = request.to_openai()
+ messages = openai["messages"]
+
+ # Build a mapping from tool_call_id -> function name so tool messages
+ # can carry the `name` field expected by v2 templates.
+ tool_call_names: dict[str, str] = {}
+ for msg in messages:
+ for tc in msg.get("tool_calls", []):
+ tc_id = tc.get("id")
+ fn_name = tc.get("function", {}).get("name")
+ if tc_id and fn_name:
+ tool_call_names[tc_id] = fn_name
+
+ for msg in messages:
+ if msg.get("role") == "tool" and "name" not in msg:
+ tc_id = msg.get("tool_call_id")
+ if tc_id and tc_id in tool_call_names:
+ msg["name"] = tool_call_names[tc_id]
+
+ kwargs: dict[str, Any] = {
+ "messages": messages,
+ }
+ if "tools" in openai and openai["tools"]:
+ kwargs["tools"] = openai["tools"]
+ if (reasoning := openai.get("reasoning_effort")) is not None:
+ kwargs["reasoning_effort"] = reasoning
+ return kwargs
+
+
@pytest.mark.parametrize(
"config",
[c for c in ALL_CONFIGS if not c.plain_think],
@@ -64,399 +103,14 @@ def test_dynamic_template_comprehensive(config: TestConfig) -> None:
static_template = _load_golden_template(template_config)
dynamic_template = build_chat_template(template_config)
- # Test cases
- test_cases = [
- # Simple one-turn
- {
- "name": "one_turn",
- "messages": [
- {"role": "user", "content": "Hello"},
- {"role": "assistant", "content": "Hi there!"},
- ],
- },
- # With system message
- {
- "name": "with_system",
- "messages": [
- {"role": "system", "content": "You are a helpful assistant."},
- {"role": "user", "content": "Hello"},
- {"role": "assistant", "content": "Hi there!"},
- ],
- },
- # Multi-turn
- {
- "name": "multi_turn",
- "messages": [
- {"role": "user", "content": "Hello"},
- {"role": "assistant", "content": "Hi there!"},
- {"role": "user", "content": "How are you?"},
- {"role": "assistant", "content": "I'm doing well!"},
- ],
- },
- # Content as list of chunks
- {
- "name": "content_chunks",
- "messages": [
- {"role": "user", "content": [{"type": "text", "text": "Hello"}]},
- {"role": "assistant", "content": [{"type": "text", "text": "Hi there!"}]},
- ],
- },
- ]
-
- # Tool call scenarios (v2+ only)
- if config.version > TokenizerVersion.v1:
- test_cases.extend(
- [
- {
- "name": "with_tools_definition",
- "messages": [
- {"role": "user", "content": "What's the weather?"},
- {"role": "assistant", "content": "It's sunny."},
- ],
- "tools": [
- {
- "type": "function",
- "function": {
- "name": "get_weather",
- "description": "Get weather",
- "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
- },
- }
- ],
- },
- {
- "name": "with_tool_call_and_result",
- "messages": [
- {"role": "user", "content": "What's the weather in Paris?"},
- {
- "role": "assistant",
- "content": "",
- "tool_calls": [
- {
- "id": "abc123def",
- "type": "function",
- "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
- }
- ],
- },
- {"role": "tool", "content": '{"temp": 20}', "tool_call_id": "abc123def", "name": "get_weather"},
- {"role": "assistant", "content": "It's 20 degrees in Paris."},
- ],
- "tools": [
- {
- "type": "function",
- "function": {
- "name": "get_weather",
- "description": "Get weather",
- "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
- },
- }
- ],
- },
- ]
- )
-
- # Add image test case if image support
- if config.image:
- test_cases.append(
- {
- "name": "with_image",
- "messages": [
- {
- "role": "user",
- "content": [
- {"type": "text", "text": "What is this?"},
- {"type": "image_url", "image_url": "http://example.com/image.png"},
- ],
- },
- {"role": "assistant", "content": "It's an image."},
- ],
- }
- )
-
- # Add audio test case if audio support
- if config.audio:
- test_cases.append(
- {
- "name": "with_audio",
- "messages": [
- {
- "role": "user",
- "content": [
- {"type": "text", "text": "What is this?"},
- {"type": "audio_url", "audio_url": "http://example.com/audio.mp3"},
- ],
- },
- {"role": "assistant", "content": "It's an audio file."},
- ],
- }
- )
-
- # Add thinking test case if thinking support
- if config.think:
- test_cases.append(
- {
- "name": "with_thinking",
- "messages": [
- {"role": "user", "content": "Solve this problem"},
- {
- "role": "assistant",
- "content": [
- {"type": "thinking", "thinking": "Let me think..."},
- {"type": "text", "text": "The answer is 42."},
- ],
- },
- ],
- }
- )
-
- # Add message aggregation test cases
- test_cases.extend(
- [
- {
- "name": "consecutive_users",
- "messages": [
- {"role": "user", "content": "Hello"},
- {"role": "user", "content": "World"},
- {"role": "assistant", "content": "Hi there"},
- ],
- },
- {
- "name": "consecutive_users_with_system",
- "messages": [
- {"role": "system", "content": "You are helpful."},
- {"role": "user", "content": "Hello"},
- {"role": "user", "content": "World"},
- {"role": "assistant", "content": "Hi there"},
- ],
- },
- {
- "name": "consecutive_assistants",
- "messages": [
- {"role": "user", "content": "Hello"},
- {"role": "assistant", "content": "Hi"},
- {"role": "assistant", "content": "How can I help?"},
- {"role": "user", "content": "Thanks"},
- {"role": "assistant", "content": "Welcome"},
- ],
- },
- {
- "name": "multiple_systems",
- "messages": [
- {"role": "system", "content": "System 1."},
- {"role": "system", "content": "System 2."},
- {"role": "user", "content": "Hello"},
- {"role": "assistant", "content": "Hi"},
- ],
- },
- {
- "name": "mid_conv_system",
- "messages": [
- {"role": "user", "content": "Hello"},
- {"role": "system", "content": "New instruction."},
- {"role": "assistant", "content": "Got it"},
- ],
- },
- {
- "name": "mid_conv_system_with_consecutive_users",
- "messages": [
- {"role": "system", "content": "Be helpful."},
- {"role": "user", "content": "Hello"},
- {"role": "user", "content": "World"},
- {"role": "system", "content": "Now be concise."},
- {"role": "assistant", "content": "Got it"},
- ],
- },
- ]
- )
-
- # Multi-chunk aggregation test cases
- test_cases.extend(
- [
- {
- "name": "consecutive_users_text_chunks",
- "messages": [
- {"role": "user", "content": "First as string"},
- {"role": "user", "content": [{"type": "text", "text": "Second as chunk"}]},
- {
- "role": "user",
- "content": [
- {"type": "text", "text": "Third part A"},
- {"type": "text", "text": "Third part B"},
- ],
- },
- {"role": "assistant", "content": "Response"},
- ],
- },
- {
- "name": "system_text_chunks",
- "messages": [
- {
- "role": "system",
- "content": [
- {"type": "text", "text": "You are helpful."},
- {"type": "text", "text": "Be concise."},
- ],
- },
- {"role": "user", "content": "Hello"},
- {"role": "assistant", "content": "Hi"},
- ],
- },
- ]
- )
-
- if config.image:
- test_cases.extend(
- [
- {
- "name": "consecutive_users_with_image",
- "messages": [
- {"role": "user", "content": "What is this?"},
- {
- "role": "user",
- "content": [
- {"type": "image_url", "image_url": "http://example.com/image.png"},
- {"type": "text", "text": "Describe it"},
- ],
- },
- {"role": "assistant", "content": "It's an image."},
- ],
- },
- {
- "name": "consecutive_users_multi_image",
- "messages": [
- {
- "role": "user",
- "content": [
- {"type": "text", "text": "Describe this"},
- {"type": "image_url", "image_url": "http://example.com/a.png"},
- {"type": "text", "text": "What color?"},
- ],
- },
- {
- "role": "user",
- "content": [
- {"type": "text", "text": "Also this"},
- {"type": "image_url", "image_url": "http://example.com/b.png"},
- {"type": "text", "text": "What shape?"},
- ],
- },
- {"role": "assistant", "content": "Both are red squares."},
- ],
- },
- ]
- )
-
- if config.audio:
- test_cases.append(
- {
- "name": "consecutive_users_multi_audio",
- "messages": [
- {
- "role": "user",
- "content": [
- {"type": "text", "text": "Listen"},
- {"type": "audio_url", "audio_url": "http://example.com/a.wav"},
- {"type": "text", "text": "What language?"},
- ],
- },
- {
- "role": "user",
- "content": [
- {"type": "text", "text": "And this"},
- {"type": "audio_url", "audio_url": "http://example.com/b.wav"},
- {"type": "text", "text": "Transcribe it"},
- ],
- },
- {"role": "assistant", "content": "Both are in English."},
- ],
- }
- )
-
- if config.think:
- test_cases.extend(
- [
- {
- "name": "consecutive_assistants_think",
- "messages": [
- {"role": "user", "content": "Solve this"},
- {
- "role": "assistant",
- "content": [
- {"type": "text", "text": "Hmm."},
- {"type": "thinking", "thinking": "Let me think..."},
- {"type": "text", "text": "I need more context."},
- ],
- },
- {
- "role": "assistant",
- "content": [
- {"type": "text", "text": "OK."},
- {"type": "thinking", "thinking": "Now I understand."},
- {"type": "text", "text": "The answer is 42."},
- ],
- },
- {"role": "user", "content": "Thanks"},
- {"role": "assistant", "content": "You're welcome"},
- ],
- },
- {
- "name": "consecutive_systems_think",
- "messages": [
- {
- "role": "system",
- "content": [
- {"type": "text", "text": "Rule A"},
- {"type": "text", "text": "Rule B"},
- {"type": "thinking", "thinking": "Think 1"},
- ],
- },
- {
- "role": "system",
- "content": [
- {"type": "thinking", "thinking": "Think 2"},
- {"type": "text", "text": "Rule C"},
- {"type": "text", "text": "Rule D"},
- ],
- },
- {"role": "user", "content": "Hello"},
- {"role": "assistant", "content": "Hi"},
- ],
- },
- ]
- )
-
- skip_names_image = {"with_image", "consecutive_users_with_image", "consecutive_users_multi_image"}
- skip_names_audio = {"with_audio", "consecutive_users_multi_audio"}
- skip_names_think = {"with_thinking", "consecutive_assistants_think", "consecutive_systems_think"}
- # ThinkChunks in system messages are only supported in v13 (not v15+)
- skip_names_think_system = {"consecutive_systems_think"}
- skip_names_tools = {"with_tools_definition", "with_tool_call_and_result"}
-
- # Not using parametrize here because test_cases are built dynamically based on the
- # config (version/image/audio/think) from the outer parametrize. Each sub-case is
- # identifiable via test_name in the assertion failure message.
- for test_case in test_cases:
- test_name = test_case["name"]
- messages = test_case["messages"]
- tools = test_case.get("tools")
-
- if test_name in skip_names_image and not config.image:
- continue
- if test_name in skip_names_audio and not config.audio:
- continue
- if test_name in skip_names_think and not config.think:
- continue
- if test_name in skip_names_think_system and config.version >= TokenizerVersion.v15:
- continue
- if test_name in skip_names_tools and config.version <= TokenizerVersion.v1:
- continue
-
- static_output = render_template(static_template, messages, tools=tools) # type: ignore
- dynamic_output = render_template(dynamic_template, messages, tools=tools) # type: ignore
-
- assert static_output == dynamic_output, (
- f"Output mismatch for {config}, case={test_name}\n\n"
- f"Static output: {static_output}\n"
- f"Dynamic output: {dynamic_output}"
- )
+ for mode in (ValidationMode.finetuning, ValidationMode.test):
+ conversations = _get_conversations(config.version, mode, config.image, config.audio, config.think)
+ for idx, request in enumerate(conversations):
+ render_args = _request_to_render_args(request)
+ static_output = render_template(static_template, **render_args)
+ dynamic_output = render_template(dynamic_template, **render_args)
+ assert static_output == dynamic_output, (
+ f"Output mismatch for {config}, mode={mode}, conversation={idx}\n\n"
+ f"Static output: {static_output}\n"
+ f"Dynamic output: {dynamic_output}"
+ )
diff --git a/tests/integrations/chat_templates/transformers/test_core_parity.py b/tests/integrations/chat_templates/transformers/test_core_parity.py
index bf5eb23d..c579e53c 100644
--- a/tests/integrations/chat_templates/transformers/test_core_parity.py
+++ b/tests/integrations/chat_templates/transformers/test_core_parity.py
@@ -205,10 +205,16 @@ def test_invalid_chunks(
# Not using parametrize here because invalid_convs depends on the version/image/audio/think
# parameters from the outer parametrize. Each sub-case is identifiable via the TemplateError
# match string which includes the role and allowed chunks.
+ is_v15_plus = config.version >= TokenizerVersion.v15
for conv in invalid_convs:
msg_template = "Only {chunks} chunks are supported in {role} message content."
if conv in sp_invalids:
- chunks = "text and thinking" if config.think and config.version < TokenizerVersion.v15 else "text"
+ if config.think and not is_v15_plus:
+ chunks = "text and thinking"
+ elif config.audio and is_v15_plus:
+ chunks = "text and audio"
+ else:
+ chunks = "text"
role = "system"
elif conv in user_invalids:
chunks = "text"
@@ -218,7 +224,13 @@ def test_invalid_chunks(
chunks += ", input_audio and audio_url"
role = "user"
elif conv in assistant_invalids:
- chunks = "text and thinking" if config.think else "text"
+ desc_parts = ["text"]
+ if config.think:
+ desc_parts.append("thinking")
+ if len(desc_parts) == 1:
+ chunks = "text"
+ else:
+ chunks = ", ".join(desc_parts[:-1]) + " and " + desc_parts[-1]
role = "assistant"
err_msg = msg_template.format(chunks=chunks, role=role)
diff --git a/tests/integrations/chat_templates/unit/test_v15.py b/tests/integrations/chat_templates/unit/test_v15.py
index 1364e5f1..39c4705a 100644
--- a/tests/integrations/chat_templates/unit/test_v15.py
+++ b/tests/integrations/chat_templates/unit/test_v15.py
@@ -246,3 +246,228 @@ def test_v15_think_template_rejects_think_in_system(self, image: bool) -> None:
with pytest.raises(ValueError, match="Only text chunks are supported in system message contents"):
render_template(template, messages)
+
+
+class TestV15MultimodalContent:
+ def test_v15_tool_message_with_image_content(self) -> None:
+ r"""V15 image template renders tool message with image content using render_content."""
+ template = generate_chat_template(
+ spm=False,
+ tokenizer_version=TokenizerVersion.v15,
+ image_support=True,
+ audio_support=False,
+ thinking_support=False,
+ default_system_prompt=None,
+ plain_thinking_support=False,
+ use_special_token_variables=True,
+ )
+
+ messages: list[dict[str, Any]] = [
+ {"role": "user", "content": "Use tool"},
+ {
+ "role": "assistant",
+ "content": "",
+ "tool_calls": [{"id": "test12345", "function": {"name": "fn", "arguments": "{}"}}],
+ },
+ {
+ "role": "tool",
+ "content": [
+ {"type": "text", "text": "result"},
+ {"type": "image_url", "image_url": "http://example.com/img.png"},
+ ],
+ "tool_call_id": "test12345",
+ },
+ {"role": "assistant", "content": "Done"},
+ ]
+
+ tools = [{"type": "function", "function": {"name": "fn", "description": "test", "parameters": {}}}]
+ output = render_template(template, messages, tools=tools, reasoning_effort="none")
+ assert "[TOOL_RESULTS]" in output
+ assert "[IMG]" in output
+ assert "result" in output
+
+ def test_v15_tool_message_with_audio_content(self) -> None:
+ r"""V15 audio template renders tool message with audio content."""
+ template = generate_chat_template(
+ spm=False,
+ tokenizer_version=TokenizerVersion.v15,
+ image_support=False,
+ audio_support=True,
+ thinking_support=False,
+ default_system_prompt=None,
+ plain_thinking_support=False,
+ use_special_token_variables=True,
+ )
+
+ messages: list[dict[str, Any]] = [
+ {"role": "user", "content": "Use tool"},
+ {
+ "role": "assistant",
+ "content": "",
+ "tool_calls": [{"id": "test12345", "function": {"name": "fn", "arguments": "{}"}}],
+ },
+ {
+ "role": "tool",
+ "content": [
+ {"type": "text", "text": "result"},
+ {"type": "input_audio", "input_audio": {"data": "abc", "format": "wav"}},
+ ],
+ "tool_call_id": "test12345",
+ },
+ {"role": "assistant", "content": "Done"},
+ ]
+
+ tools = [{"type": "function", "function": {"name": "fn", "description": "test", "parameters": {}}}]
+ output = render_template(template, messages, tools=tools, reasoning_effort="none")
+ assert "[TOOL_RESULTS]" in output
+ assert "[AUDIO]" in output
+ assert "result" in output
+
+ def test_v15_system_message_with_audio_content(self) -> None:
+ r"""V15 audio template renders system message with audio content."""
+ template = generate_chat_template(
+ spm=False,
+ tokenizer_version=TokenizerVersion.v15,
+ image_support=False,
+ audio_support=True,
+ thinking_support=False,
+ default_system_prompt=None,
+ plain_thinking_support=False,
+ use_special_token_variables=True,
+ )
+
+ messages: list[dict[str, Any]] = [
+ {
+ "role": "system",
+ "content": [
+ {"type": "text", "text": "Listen to context"},
+ {"type": "input_audio", "input_audio": {"data": "abc", "format": "wav"}},
+ ],
+ },
+ {"role": "user", "content": "Summarize"},
+ {"role": "assistant", "content": "Done"},
+ ]
+
+ output = render_template(template, messages, reasoning_effort="none")
+ assert "[SYSTEM_PROMPT]" in output
+ assert "Listen to context" in output
+ assert "[AUDIO]" in output
+
+ def test_pre_v15_image_template_rejects_image_in_assistant(self) -> None:
+ r"""Pre-V15 image template rejects image chunks in assistant message."""
+ template = generate_chat_template(
+ spm=False,
+ tokenizer_version=TokenizerVersion.v13,
+ image_support=True,
+ audio_support=False,
+ thinking_support=False,
+ default_system_prompt=None,
+ plain_thinking_support=False,
+ use_special_token_variables=True,
+ )
+
+ messages: list[dict[str, Any]] = [
+ {"role": "user", "content": "Show me"},
+ {
+ "role": "assistant",
+ "content": [
+ {"type": "text", "text": "Here"},
+ {"type": "image_url", "image_url": "http://example.com/img.png"},
+ ],
+ },
+ ]
+
+ with pytest.raises(ValueError, match="Only text chunks are supported in assistant message contents"):
+ render_template(template, messages)
+
+ def test_pre_v15_audio_template_rejects_audio_in_assistant(self) -> None:
+ r"""Pre-V15 audio template rejects audio chunks in assistant message."""
+ template = generate_chat_template(
+ spm=False,
+ tokenizer_version=TokenizerVersion.v13,
+ image_support=False,
+ audio_support=True,
+ thinking_support=False,
+ default_system_prompt=None,
+ plain_thinking_support=False,
+ use_special_token_variables=True,
+ )
+
+ messages: list[dict[str, Any]] = [
+ {"role": "user", "content": "Listen"},
+ {
+ "role": "assistant",
+ "content": [
+ {"type": "text", "text": "Here"},
+ {"type": "input_audio", "input_audio": {"data": "abc", "format": "wav"}},
+ ],
+ },
+ ]
+
+ with pytest.raises(ValueError, match="Only text chunks are supported in assistant message contents"):
+ render_template(template, messages)
+
+ def test_pre_v15_audio_template_rejects_audio_in_system(self) -> None:
+ r"""Pre-V15 audio template rejects audio chunks in system message."""
+ template = generate_chat_template(
+ spm=False,
+ tokenizer_version=TokenizerVersion.v13,
+ image_support=False,
+ audio_support=True,
+ thinking_support=False,
+ default_system_prompt=None,
+ plain_thinking_support=False,
+ use_special_token_variables=True,
+ )
+
+ messages: list[dict[str, Any]] = [
+ {
+ "role": "system",
+ "content": [
+ {"type": "text", "text": "Context"},
+ {"type": "input_audio", "input_audio": {"data": "abc", "format": "wav"}},
+ ],
+ },
+ {"role": "user", "content": "Hello"},
+ {"role": "assistant", "content": "Hi"},
+ ]
+
+ with pytest.raises(ValueError, match="Only text chunks are supported in system message contents"):
+ render_template(template, messages)
+
+ def test_pre_v15_image_template_rejects_image_in_tool(self) -> None:
+ r"""Pre-V15 image template does not render image chunks in tool message content (no [IMG] token)."""
+ template = generate_chat_template(
+ spm=False,
+ tokenizer_version=TokenizerVersion.v13,
+ image_support=True,
+ audio_support=False,
+ thinking_support=False,
+ default_system_prompt=None,
+ plain_thinking_support=False,
+ use_special_token_variables=True,
+ )
+
+ messages: list[dict[str, Any]] = [
+ {"role": "user", "content": "Use tool"},
+ {
+ "role": "assistant",
+ "content": "",
+ "tool_calls": [{"id": "test12345", "function": {"name": "fn", "arguments": "{}"}}],
+ },
+ {
+ "role": "tool",
+ "content": [
+ {"type": "text", "text": "result"},
+ {"type": "image_url", "image_url": "http://example.com/img.png"},
+ ],
+ "tool_call_id": "test12345",
+ },
+ {"role": "assistant", "content": "Done"},
+ ]
+
+ tools = [{"type": "function", "function": {"name": "fn", "description": "test", "parameters": {}}}]
+ # Pre-V15 uses message['content']|string, which coerces the list to its string representation
+ # rather than rendering it through render_content, so no [IMG] token is produced.
+ output = render_template(template, messages, tools=tools)
+ assert "[IMG]" not in output
diff --git a/tests/test_converters.py b/tests/test_converters.py
index d093d945..19b69a22 100644
--- a/tests/test_converters.py
+++ b/tests/test_converters.py
@@ -37,6 +37,7 @@
AudioChunk,
AudioURL,
AudioURLChunk,
+ ContentChunk,
ImageChunk,
ImageURL,
ImageURLChunk,
@@ -586,7 +587,7 @@ def test_non_leading_think_chunks_construction_ok() -> None:
[TextChunk(text="A"), TextChunk(text="B"), ThinkChunk(thinking="End", closed=True)],
],
)
-def test_non_leading_think_chunks_to_openai_raises(content: list[TextChunk | ThinkChunk]) -> None:
+def test_non_leading_think_chunks_to_openai_raises(content: list[ContentChunk]) -> None:
"""to_openai raises when ThinkChunks are not leading."""
msg = AssistantMessage(content=content)
with pytest.raises(InvalidAssistantMessageException, match="ThinkChunks must be leading"):
diff --git a/tests/test_messages.py b/tests/test_messages.py
new file mode 100644
index 00000000..44aba528
--- /dev/null
+++ b/tests/test_messages.py
@@ -0,0 +1,44 @@
+import pytest
+from pydantic import ValidationError
+
+from mistral_common.protocol.instruct.messages import (
+ AssistantMessage,
+ SystemMessage,
+ ToolMessage,
+ UserMessage,
+)
+from tests.fixtures.chunks import get_content_chunks as _chunks
+
+
+class TestMessageContentChunkUnions:
+ r"""Pydantic-level (version-independent) content-chunk unions for each message role."""
+
+ def test_user_allows_text_image_audio(self) -> None:
+ UserMessage(content=_chunks(("text", "image", "image_url", "audio", "audio_url")))
+
+ def test_user_rejects_think(self) -> None:
+ for name in ("think",):
+ with pytest.raises(ValidationError):
+ UserMessage(content=_chunks((name,)))
+
+ def test_assistant_allows_text_and_think(self) -> None:
+ AssistantMessage(content=_chunks(("text", "think")))
+
+ def test_assistant_rejects_image_and_audio(self) -> None:
+ for name in ("image", "image_url", "audio", "audio_url"):
+ with pytest.raises(ValidationError):
+ AssistantMessage(content=_chunks((name,)))
+
+ def test_system_allows_text_audio_think(self) -> None:
+ SystemMessage(content=_chunks(("text", "audio", "think")))
+
+ def test_system_rejects_image_and_audio_url(self) -> None:
+ for name in ("image", "image_url", "audio_url"):
+ with pytest.raises(ValidationError):
+ SystemMessage(content=_chunks((name,)))
+
+ def test_tool_allows_all_chunk_types(self) -> None:
+ ToolMessage(
+ content=_chunks(("text", "image", "image_url", "audio", "audio_url", "think")),
+ tool_call_id="c1",
+ )
diff --git a/tests/test_normalization.py b/tests/test_normalization.py
index 1ff09546..70145f16 100644
--- a/tests/test_normalization.py
+++ b/tests/test_normalization.py
@@ -3,8 +3,8 @@
import pytest
from mistral_common.protocol.instruct.chunk import (
+ AudioChunk,
ChunkTypes,
- ContentChunk,
ImageURLChunk,
TextChunk,
ThinkChunk,
@@ -62,7 +62,10 @@ def test_user_system_user(self, normalizer: InstructRequestNormalizer) -> None:
)
parsed_request = normalizer.from_chat_completion_request(chat_completion_request)
- assert parsed_request.system_prompt == "S"
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="U"), UserMessage(content="U")],
+ system_prompt="S",
+ )
def test_multiple_system(self, normalizer: InstructRequestNormalizer) -> None:
chat_completion_request = mock_chat_completion(
@@ -74,7 +77,10 @@ def test_multiple_system(self, normalizer: InstructRequestNormalizer) -> None:
)
parsed_request = normalizer.from_chat_completion_request(chat_completion_request)
- assert parsed_request.system_prompt == "S\n\nS\n\nS"
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="")],
+ system_prompt="S\n\nS\n\nS",
+ )
def test_single_system(self, normalizer: InstructRequestNormalizer) -> None:
chat_completion_request = mock_chat_completion(
@@ -84,8 +90,10 @@ def test_single_system(self, normalizer: InstructRequestNormalizer) -> None:
)
parsed_request = normalizer.from_chat_completion_request(chat_completion_request)
-
- assert parsed_request.system_prompt == "S"
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="")],
+ system_prompt="S",
+ )
def test_system_assistant_user(self, normalizer: InstructRequestNormalizer) -> None:
chat_completion_request = mock_chat_completion(
@@ -97,11 +105,10 @@ def test_system_assistant_user(self, normalizer: InstructRequestNormalizer) -> N
)
parsed_request = normalizer.from_chat_completion_request(chat_completion_request)
-
- first_message = parsed_request.messages[0]
- assert isinstance(first_message, UserMessage)
- assert first_message.content == ""
- assert parsed_request.system_prompt == "S"
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content=""), AssistantMessage(content="A"), UserMessage(content="U")],
+ system_prompt="S",
+ )
def test_assistant_content_with_tool_calls(self, normalizer: InstructRequestNormalizer) -> None:
chat_completion_request = mock_chat_completion(
@@ -125,110 +132,89 @@ def test_assistant_system_user_adds_user(self, normalizer: InstructRequestNormal
)
parsed_request = normalizer.from_chat_completion_request(chat_completion_request)
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content=""), AssistantMessage(content="A"), UserMessage(content="U")],
+ system_prompt="S",
+ )
- assert parsed_request.system_prompt == "S"
-
- assert len(parsed_request.messages) == 3 # 1 user message added, system message removed
-
- first_message = parsed_request.messages[0]
- assert isinstance(first_message, UserMessage)
- assert first_message.content == ""
- assert parsed_request.system_prompt == "S"
+ def test_message_aggregation_system_then_user(self, normalizer: InstructRequestNormalizer) -> None:
+ parsed = normalizer.from_chat_completion_request(
+ mock_chat_completion(
+ [
+ SystemMessage(content="s"),
+ SystemMessage(content="s"),
+ SystemMessage(content="s"),
+ UserMessage(content="u"),
+ ]
+ )
+ )
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="u")], system_prompt="s\n\ns\n\ns"
+ )
- def check_merge(
- self,
- roles: list[str],
- expected_roles: list[str],
- expected_content: list[list[ContentChunk] | str],
- normalizer: InstructRequestNormalizer,
- ) -> None:
- letter_to_cls: dict[str, ChatMessage] = {
- "s": SystemMessage(content="s"),
- "u": UserMessage(content="u"),
- "a": AssistantMessage(content="a"),
- "u2": UserMessage(content="u2"),
- }
+ def test_message_aggregation_system_then_users(self, normalizer: InstructRequestNormalizer) -> None:
+ parsed = normalizer.from_chat_completion_request(
+ mock_chat_completion(
+ [
+ SystemMessage(content="s"),
+ SystemMessage(content="s"),
+ SystemMessage(content="s"),
+ UserMessage(content="u"),
+ UserMessage(content="u"),
+ ]
+ )
+ )
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="u\n\nu")], system_prompt="s\n\ns\n\ns"
+ )
- chat_completion_request = mock_chat_completion(
- messages=[letter_to_cls[r] for r in roles],
+ def test_message_aggregation_mixed_with_middle_system(self, normalizer: InstructRequestNormalizer) -> None:
+ parsed = normalizer.from_chat_completion_request(
+ mock_chat_completion(
+ [
+ SystemMessage(content="s"),
+ SystemMessage(content="s"),
+ SystemMessage(content="s"),
+ UserMessage(content="u"),
+ UserMessage(content="u"),
+ SystemMessage(content="s"),
+ AssistantMessage(content="a"),
+ UserMessage(content="u"),
+ ]
+ )
+ )
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="u\n\nu"), AssistantMessage(content="a"), UserMessage(content="u")],
+ system_prompt="s\n\ns\n\ns\n\ns",
)
- parsed_request = normalizer.from_chat_completion_request(chat_completion_request)
- assert len(parsed_request.messages) == len(expected_roles)
- assert [message.role for message in parsed_request.messages] == [
- letter_to_cls[role].role for role in expected_roles
- ]
- assert len(expected_content) == len(parsed_request.messages)
- for x, expected in zip(parsed_request.messages, expected_content):
- assert isinstance(x, (UserMessage, AssistantMessage))
- assert x.content == expected
- def check_merge_chunks(
- self,
- roles: list[str],
- expected_roles: list[str],
- expected_content: list[list[ContentChunk] | str],
- normalizer: InstructRequestNormalizer,
- ) -> None:
- letter_to_cls: dict[str, ChatMessage] = {
- "s": SystemMessage(content="s"),
- "u": UserMessage(content="u"),
- "a": AssistantMessage(content="a"),
- "a2": AssistantMessage(
- content=[
- ThinkChunk(thinking="t1"),
- ThinkChunk(thinking="t2"),
- TextChunk(text="a1"),
- TextChunk(text="a2"),
- TextChunk(text="a3"),
+ def test_message_aggregation_consecutive_assistants(self, normalizer: InstructRequestNormalizer) -> None:
+ parsed = normalizer.from_chat_completion_request(
+ mock_chat_completion(
+ [
+ SystemMessage(content="s"),
+ SystemMessage(content="s"),
+ SystemMessage(content="s"),
+ UserMessage(content="u"),
+ UserMessage(content="u"),
+ AssistantMessage(content="a"),
+ AssistantMessage(content="a"),
+ UserMessage(content="u"),
]
- ),
- }
+ )
+ )
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="u\n\nu"), AssistantMessage(content="a\n\na"), UserMessage(content="u")],
+ system_prompt="s\n\ns\n\ns",
+ )
- chat_completion_request = mock_chat_completion(
- messages=[letter_to_cls[r] for r in roles],
+ def test_message_aggregation_system_assistant_user(self, normalizer: InstructRequestNormalizer) -> None:
+ parsed = normalizer.from_chat_completion_request(
+ mock_chat_completion([SystemMessage(content="s"), AssistantMessage(content="a"), UserMessage(content="u")])
)
- parsed_request = normalizer.from_chat_completion_request(chat_completion_request)
- assert len(parsed_request.messages) == len(expected_roles)
- assert [message.role for message in parsed_request.messages] == [
- letter_to_cls[role].role for role in expected_roles
- ]
- assert len(expected_content) == len(parsed_request.messages)
- for x, expected in zip(parsed_request.messages, expected_content):
- assert isinstance(x, (UserMessage, AssistantMessage))
- assert x.content == expected
-
- def test_message_aggregation(self, normalizer: InstructRequestNormalizer) -> None:
- self.check_merge(["s", "s", "s", "u"], ["u"], ["u"], normalizer)
- self.check_merge(["s", "s", "s", "u", "u"], ["u"], ["u\n\nu"], normalizer)
- self.check_merge(["s", "s", "s", "u", "u", "s", "a", "u"], ["u", "a", "u"], ["u\n\nu", "a", "u"], normalizer)
-
- self.check_merge(
- ["s", "s", "s", "u", "u", "a", "a", "u"],
- ["u", "a", "u"],
- ["u\n\nu", "a\n\na", "u"],
- normalizer,
- )
-
- self.check_merge(
- ["s", "a", "u"],
- ["u", "a", "u"],
- ["", "a", "u"],
- normalizer,
- )
-
- self.check_merge_chunks(
- ["u", "a2", "u"],
- ["u", "a", "u"],
- [
- "u",
- [
- ThinkChunk(thinking="t1"),
- ThinkChunk(thinking="t2"),
- TextChunk(text="a1\n\na2\n\na3"),
- ],
- "u",
- ],
- normalizer,
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content=""), AssistantMessage(content="a"), UserMessage(content="u")],
+ system_prompt="s",
)
def test_tool_chunk_aggregation(self, normalizer: InstructRequestNormalizer) -> None:
@@ -266,10 +252,9 @@ def test_normalize_chunks(self, normalizer: InstructRequestNormalizer) -> None:
)
parsed_request = normalizer.from_chat_completion_request(chat_completion_request)
-
- first_message = parsed_request.messages[0]
- assert isinstance(first_message, UserMessage)
- assert first_message.content == "foo\n\nchunk\n\nfoo\n\nchunk"
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="foo\n\nchunk\n\nfoo\n\nchunk")]
+ )
def test_many_chunks_in_user_message(self, normalizer: InstructRequestNormalizer) -> None:
chat_completion_request = mock_chat_completion(
@@ -285,9 +270,9 @@ def test_many_chunks_in_user_message(self, normalizer: InstructRequestNormalizer
],
)
parsed_request = normalizer.from_chat_completion_request(chat_completion_request)
- first_message = parsed_request.messages[0]
- assert isinstance(first_message, UserMessage)
- assert first_message.content == "foo\n\nchunk1\n\nchunk2\n\nchunk3"
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="foo\n\nchunk1\n\nchunk2\n\nchunk3")]
+ )
def test_ignore_middle_empty_text_chunks(self, normalizer: InstructRequestNormalizer) -> None:
chat_completion_request = mock_chat_completion(
@@ -309,12 +294,9 @@ def test_ignore_middle_empty_text_chunks(self, normalizer: InstructRequestNormal
]
)
parsed_request = normalizer.from_chat_completion_request(chat_completion_request)
- first_message = parsed_request.messages[0]
- assert isinstance(first_message, UserMessage)
- assert first_message.content == "U\n\nV"
- second_message = parsed_request.messages[1]
- assert isinstance(second_message, AssistantMessage)
- assert second_message.content == "A\n\nB"
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="U\n\nV"), AssistantMessage(content="A\n\nB")],
+ )
def test_safety_prompt_aggregated(self, normalizer: InstructRequestNormalizer) -> None:
chat_completion_request = ChatCompletionRequest[ChatMessage](
@@ -328,10 +310,10 @@ def test_safety_prompt_aggregated(self, normalizer: InstructRequestNormalizer) -
)
parsed_request = normalizer.from_chat_completion_request(chat_completion_request)
- first_message = parsed_request.messages[0]
- assert isinstance(first_message, UserMessage)
- assert first_message.content == "user"
- assert parsed_request.system_prompt == "system"
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="user")],
+ system_prompt="system",
+ )
def test_normalize_tools(self, normalizer: InstructRequestNormalizer) -> None:
"""
@@ -410,18 +392,41 @@ def test_assert_parsed_settings(
normalizer: InstructRequestNormalizer,
) -> None:
chat_completion_request = ChatCompletionRequest(messages=[UserMessage(content="B")])
- parsed_request: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(
- chat_completion_request
- )
- assert parsed_request.settings == ModelSettings.none()
+ parsed_request = normalizer.from_chat_completion_request(chat_completion_request)
+ assert parsed_request == InstructRequest[ChatMessage, Tool](messages=[UserMessage(content="B")])
def test_continue_final_message_forwarded(self, normalizer: InstructRequestNormalizer) -> None:
request = ChatCompletionRequest[ChatMessage](
messages=[UserMessage(content="a"), AssistantMessage(content="b")],
continue_final_message=True,
)
- result: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request)
- assert result.continue_final_message is True
+ result = normalizer.from_chat_completion_request(request)
+ assert result == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="a"), AssistantMessage(content="b")],
+ continue_final_message=True,
+ )
+
+ def test_json_normalizes_tool_content(self, normalizer: InstructRequestNormalizer) -> None:
+ r"""Base normalizer (v1-v3) JSON-normalizes tool message content."""
+ messy_json = '{"key" : "value" , "num": 1}'
+ request = mock_chat_completion(
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")]),
+ ToolMessage(content=messy_json, tool_call_id="c1"),
+ ],
+ )
+ parsed = normalizer.from_chat_completion_request(request)
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(
+ content="",
+ tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")],
+ ),
+ ToolMessage(content='{"key": "value", "num": 1}', tool_call_id="c1"),
+ ],
+ )
class TestChatCompletionRequestNormalizationV7:
@@ -440,19 +445,14 @@ def test_system_assistant_user_v7(self, normalizer_v7: InstructRequestNormalizer
]
)
- parsed_request: InstructRequest = normalizer_v7.from_chat_completion_request(chat_completion_request)
-
- first_message = parsed_request.messages[0]
- assert isinstance(first_message, SystemMessage)
- assert first_message.content == "S"
-
- second_message = parsed_request.messages[1]
- assert isinstance(second_message, AssistantMessage)
- assert second_message.content == "A"
-
- assert parsed_request.system_prompt is None
+ parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(
+ chat_completion_request
+ )
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[SystemMessage(content="S"), AssistantMessage(content="A"), UserMessage(content="U")],
+ )
- def test_assistant_assistant_system_v7(self, normalizer_v7: InstructRequestNormalizer) -> None:
+ def test_assistant_assistant_system_v7(self, normalizer_v7: InstructRequestNormalizerV7) -> None:
chat_completion_request = mock_chat_completion(
messages=[
AssistantMessage(content="A"),
@@ -460,21 +460,14 @@ def test_assistant_assistant_system_v7(self, normalizer_v7: InstructRequestNorma
]
)
- parsed_request = normalizer_v7.from_chat_completion_request(chat_completion_request)
-
- assert parsed_request.system_prompt is None
-
- assert len(parsed_request.messages) == 2
-
- first_message = parsed_request.messages[0]
- assert isinstance(first_message, AssistantMessage)
- assert first_message.content == "A"
-
- second_message = parsed_request.messages[1]
- assert isinstance(second_message, SystemMessage)
- assert second_message.content == "S"
+ parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(
+ chat_completion_request
+ )
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[AssistantMessage(content="A"), SystemMessage(content="S")],
+ )
- def test_assistant_content_with_tool_calls(self, normalizer_v7: InstructRequestNormalizer) -> None:
+ def test_assistant_content_with_tool_calls(self, normalizer_v7: InstructRequestNormalizerV7) -> None:
chat_completion_request = mock_chat_completion(
messages=[
AssistantMessage(
@@ -483,28 +476,28 @@ def test_assistant_content_with_tool_calls(self, normalizer_v7: InstructRequestN
)
]
)
- normalized_chat_req = normalizer_v7.from_chat_completion_request(chat_completion_request)
-
- assert normalized_chat_req.messages[0].content == "A", normalized_chat_req.messages[0].content
- assert len(normalized_chat_req.messages[0].tool_calls) == 1, normalized_chat_req.messages[0].tool_calls
- assert normalized_chat_req.messages[0].tool_calls[0].function.name == "tool1", (
- normalized_chat_req.messages[0].tool_calls[0].function.name
+ normalized_chat_req: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(
+ chat_completion_request
+ )
+ assert normalized_chat_req == InstructRequest[ChatMessage, Tool](
+ messages=[
+ AssistantMessage(
+ content="A",
+ tool_calls=[ToolCall(function=FunctionCall(name="tool1", arguments='{"input": "11"}'))],
+ ),
+ ],
)
- def test_assistant_content_with_more_tool_calls(self, normalizer_v7: InstructRequestNormalizer) -> None:
+ def test_assistant_content_with_more_tool_calls(self, normalizer_v7: InstructRequestNormalizerV7) -> None:
chat_completion_request = mock_chat_completion(
messages=[
UserMessage(content="A1"),
- AssistantMessage(
- content="B1",
- ),
+ AssistantMessage(content="B1"),
AssistantMessage(
content="B2",
tool_calls=[ToolCall(function=FunctionCall(name="tool1", arguments='{"input": "1"}'))],
),
- AssistantMessage(
- content="B3",
- ),
+ AssistantMessage(content="B3"),
AssistantMessage(
content="B4",
tool_calls=[
@@ -512,24 +505,27 @@ def test_assistant_content_with_more_tool_calls(self, normalizer_v7: InstructReq
ToolCall(function=FunctionCall(name="tool22", arguments='{"input": "22"}')),
],
),
+ AssistantMessage(content="B5"),
+ UserMessage(content="C1"),
+ ]
+ )
+ normalized_chat_req: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(
+ chat_completion_request
+ )
+ assert normalized_chat_req == InstructRequest[ChatMessage, Tool](
+ messages=[
+ UserMessage(content="A1"),
AssistantMessage(
- content="B5",
+ content="B1\n\nB2\n\nB3\n\nB4\n\nB5",
+ tool_calls=[
+ ToolCall(function=FunctionCall(name="tool1", arguments='{"input": "1"}')),
+ ToolCall(function=FunctionCall(name="tool21", arguments='{"input": "21"}')),
+ ToolCall(function=FunctionCall(name="tool22", arguments='{"input": "22"}')),
+ ],
),
UserMessage(content="C1"),
- ]
+ ],
)
- normalized_chat_req = normalizer_v7.from_chat_completion_request(chat_completion_request)
-
- assert normalized_chat_req.messages[0].content == "A1"
- assert normalized_chat_req.messages[1].content.split("\n\n") == [f"B{i}" for i in range(1, 6)]
-
- tool_calls = normalized_chat_req.messages[1].tool_calls
-
- assert len(tool_calls) == 3
-
- tool_key = ["1", "21", "22"]
- assert all([t.function.name == f"tool{tool_key[i]}" for i, t in enumerate(tool_calls)])
- assert all([json.loads(t.function.arguments)["input"] == tool_key[i] for i, t in enumerate(tool_calls)])
def test_assert_parsed_settings(
self,
@@ -539,7 +535,7 @@ def test_assert_parsed_settings(
parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(
chat_completion_request
)
- assert parsed_request.settings == ModelSettings.none()
+ assert parsed_request == InstructRequest[ChatMessage, Tool](messages=[UserMessage(content="B")])
def test_continue_final_message_forwarded(self, normalizer_v7: InstructRequestNormalizerV7) -> None:
request = ChatCompletionRequest[ChatMessage](
@@ -547,7 +543,10 @@ def test_continue_final_message_forwarded(self, normalizer_v7: InstructRequestNo
continue_final_message=True,
)
result: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(request)
- assert result.continue_final_message is True
+ assert result == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="a"), AssistantMessage(content="b")],
+ continue_final_message=True,
+ )
@pytest.mark.parametrize("num_empty", [0, 1, 2])
def test_only_empty_text_chunks(self, normalizer_v7: InstructRequestNormalizerV7, num_empty: int) -> None:
@@ -561,13 +560,9 @@ def test_only_empty_text_chunks(self, normalizer_v7: InstructRequestNormalizerV7
parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(
chat_completion_request
)
- first_message = parsed_request.messages[0]
- assert isinstance(first_message, UserMessage)
- assert first_message.content == ""
- second_message = parsed_request.messages[1]
- assert isinstance(second_message, AssistantMessage)
- # Empty string content is passed through directly
- assert second_message.content == ""
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content=""), AssistantMessage(content="")],
+ )
def test_complex_user_aggregation(self, normalizer_v7: InstructRequestNormalizerV7) -> None:
"""Complex multi-user-message aggregation with mixed str, chunks, empty, and non-text chunks."""
@@ -594,13 +589,17 @@ def test_complex_user_aggregation(self, normalizer_v7: InstructRequestNormalizer
parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(
chat_completion_request
)
- first_message = parsed_request.messages[0]
- assert isinstance(first_message, UserMessage)
- assert first_message.content == [
- TextChunk(text="A\n\nB\n\nC\n\nD"),
- ImageURLChunk(image_url="E"),
- TextChunk(text="G\n\nH"),
- ]
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[
+ UserMessage(
+ content=[
+ TextChunk(text="A\n\nB\n\nC\n\nD"),
+ ImageURLChunk(image_url="E"),
+ TextChunk(text="G\n\nH"),
+ ]
+ ),
+ ],
+ )
def test_complex_assistant_aggregation(self, normalizer_v7: InstructRequestNormalizerV7) -> None:
"""Complex multi-assistant-message aggregation with mixed str, chunks, and empty content."""
@@ -613,7 +612,6 @@ def test_complex_assistant_aggregation(self, normalizer_v7: InstructRequestNorma
AssistantMessage(content=[TextChunk(text="")]),
AssistantMessage(
content=[
- ThinkChunk(thinking="T"),
TextChunk(text="C"),
TextChunk(text="D"),
]
@@ -623,13 +621,62 @@ def test_complex_assistant_aggregation(self, normalizer_v7: InstructRequestNorma
parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(
chat_completion_request
)
- first_message = parsed_request.messages[0]
- assert isinstance(first_message, AssistantMessage)
- assert first_message.content == [
- TextChunk(text="A\n\nB"),
- ThinkChunk(thinking="T"),
- TextChunk(text="C\n\nD"),
- ]
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[AssistantMessage(content="A\n\nB\n\nC\n\nD")],
+ )
+
+ def test_accepts_string_content(self, normalizer_v7: InstructRequestNormalizerV7) -> None:
+ r"""V7 normalizer accepts string content in assistant messages."""
+ request = mock_chat_completion(
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(content="plain text"),
+ ],
+ )
+ parsed: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(request)
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="query"), AssistantMessage(content="plain text")],
+ )
+
+ def test_skips_json_normalization_on_tool_content(self, normalizer_v7: InstructRequestNormalizerV7) -> None:
+ r"""V7+ normalizers do not JSON-normalize tool message content."""
+ messy_json = '{"key" : "value" , "num": 1}'
+ request = mock_chat_completion(
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")]),
+ ToolMessage(content=messy_json, tool_call_id="c1"),
+ ],
+ )
+ parsed: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(request)
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(
+ content="",
+ tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")],
+ ),
+ ToolMessage(content=messy_json, tool_call_id="c1"),
+ ],
+ )
+
+ def test_preserves_audio_in_system_message(self, normalizer_v7: InstructRequestNormalizerV7) -> None:
+ r"""V7 normalizer preserves AudioChunk in system messages."""
+ request = mock_chat_completion(
+ messages=[
+ SystemMessage(content=[TextChunk(text="hello"), AudioChunk(input_audio=b"fake_audio_data")]),
+ UserMessage(content="test"),
+ ]
+ )
+ parsed: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(request)
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[
+ SystemMessage(
+ content=[TextChunk(text="hello"), AudioChunk(input_audio=b"fake_audio_data")],
+ ),
+ UserMessage(content="test"),
+ ],
+ )
class TestFineTuningNormalizer:
@@ -720,18 +767,20 @@ def test_no_reorder_tool_messages(self, normalizer_v13: InstructRequestNormalize
parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(
chat_completion_request
)
- assert parsed_request.messages == [
- UserMessage(content="A"),
- AssistantMessage(
- content="B",
- tool_calls=[
- ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")),
- ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")),
- ],
- ),
- ToolMessage(content="C", tool_call_id="1"),
- ToolMessage(content="D", tool_call_id="2"),
- ]
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[
+ UserMessage(content="A"),
+ AssistantMessage(
+ content="B",
+ tool_calls=[
+ ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")),
+ ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")),
+ ],
+ ),
+ ToolMessage(content="C", tool_call_id="1"),
+ ToolMessage(content="D", tool_call_id="2"),
+ ],
+ )
def test_reorder_last_tool_messages(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
chat_completion_request: ChatCompletionRequest = self._mock_chat_completion(
@@ -751,18 +800,20 @@ def test_reorder_last_tool_messages(self, normalizer_v13: InstructRequestNormali
parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(
chat_completion_request
)
- assert parsed_request.messages == [
- UserMessage(content="A"),
- AssistantMessage(
- content="B",
- tool_calls=[
- ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")),
- ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")),
- ],
- ),
- ToolMessage(content="C", tool_call_id="1"),
- ToolMessage(content="D", tool_call_id="2"),
- ]
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[
+ UserMessage(content="A"),
+ AssistantMessage(
+ content="B",
+ tool_calls=[
+ ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")),
+ ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")),
+ ],
+ ),
+ ToolMessage(content="C", tool_call_id="1"),
+ ToolMessage(content="D", tool_call_id="2"),
+ ],
+ )
def test_reorder_internal_tool_messages(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
chat_completion_request: ChatCompletionRequest = self._mock_chat_completion(
@@ -783,19 +834,21 @@ def test_reorder_internal_tool_messages(self, normalizer_v13: InstructRequestNor
parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(
chat_completion_request
)
- assert parsed_request.messages == [
- UserMessage(content="A"),
- AssistantMessage(
- content="B",
- tool_calls=[
- ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")),
- ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")),
- ],
- ),
- ToolMessage(content="C", tool_call_id="1"),
- ToolMessage(content="D", tool_call_id="2"),
- AssistantMessage(content="E"),
- ]
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[
+ UserMessage(content="A"),
+ AssistantMessage(
+ content="B",
+ tool_calls=[
+ ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")),
+ ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")),
+ ],
+ ),
+ ToolMessage(content="C", tool_call_id="1"),
+ ToolMessage(content="D", tool_call_id="2"),
+ AssistantMessage(content="E"),
+ ],
+ )
def test_reorder_extra_tool_messages(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
chat_completion_request: ChatCompletionRequest = self._mock_chat_completion(
@@ -814,17 +867,19 @@ def test_reorder_extra_tool_messages(self, normalizer_v13: InstructRequestNormal
parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(
chat_completion_request
)
- assert parsed_request.messages == [
- UserMessage(content="A"),
- AssistantMessage(
- content="B",
- tool_calls=[
- ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")),
- ],
- ),
- ToolMessage(content="C", tool_call_id="1"),
- ToolMessage(content="D", tool_call_id="2"),
- ]
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[
+ UserMessage(content="A"),
+ AssistantMessage(
+ content="B",
+ tool_calls=[
+ ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")),
+ ],
+ ),
+ ToolMessage(content="C", tool_call_id="1"),
+ ToolMessage(content="D", tool_call_id="2"),
+ ],
+ )
def test_reorder_only_from_latest_assistant_message(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
chat_completion_request: ChatCompletionRequest = self._mock_chat_completion(
@@ -853,27 +908,29 @@ def test_reorder_only_from_latest_assistant_message(self, normalizer_v13: Instru
parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(
chat_completion_request
)
- assert parsed_request.messages == [
- UserMessage(content="A"),
- AssistantMessage(
- content="B",
- tool_calls=[
- ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")),
- ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")),
- ],
- ),
- ToolMessage(content="C", tool_call_id="1"),
- ToolMessage(content="D", tool_call_id="2"),
- AssistantMessage(
- content="E",
- tool_calls=[
- ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")),
- ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")),
- ],
- ),
- ToolMessage(content="D", tool_call_id="2"),
- ToolMessage(content="C", tool_call_id="1"),
- ]
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[
+ UserMessage(content="A"),
+ AssistantMessage(
+ content="B",
+ tool_calls=[
+ ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")),
+ ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")),
+ ],
+ ),
+ ToolMessage(content="C", tool_call_id="1"),
+ ToolMessage(content="D", tool_call_id="2"),
+ AssistantMessage(
+ content="E",
+ tool_calls=[
+ ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")),
+ ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")),
+ ],
+ ),
+ ToolMessage(content="D", tool_call_id="2"),
+ ToolMessage(content="C", tool_call_id="1"),
+ ],
+ )
@pytest.mark.parametrize(
["system_message", "expected_system_message"],
@@ -919,7 +976,9 @@ def test_aggregate_system_prompt_content(
parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(
chat_completion_request
)
- assert parsed_request.messages == [expected_system_message, UserMessage(content="B")]
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[expected_system_message, UserMessage(content="B")]
+ )
def test_system_messages_no_aggregation(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
"""Consecutive system messages are NOT aggregated into one in V7+."""
@@ -934,12 +993,14 @@ def test_system_messages_no_aggregation(self, normalizer_v13: InstructRequestNor
parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(
chat_completion_request
)
- assert parsed_request.messages == [
- SystemMessage(content="A"),
- SystemMessage(content="B"),
- UserMessage(content="C"),
- SystemMessage(content="D"),
- ]
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[
+ SystemMessage(content="A"),
+ SystemMessage(content="B"),
+ UserMessage(content="C"),
+ SystemMessage(content="D"),
+ ],
+ )
def test_system_messages_normalization(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
"""System message chunks within the same message are aggregated with no separator."""
@@ -952,10 +1013,12 @@ def test_system_messages_normalization(self, normalizer_v13: InstructRequestNorm
parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(
chat_completion_request
)
- assert parsed_request.messages == [
- SystemMessage(content="A\n\nB"),
- SystemMessage(content="C"),
- ]
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[
+ SystemMessage(content="A\n\nB"),
+ SystemMessage(content="C"),
+ ],
+ )
def test_assert_parsed_settings(
self,
@@ -965,7 +1028,7 @@ def test_assert_parsed_settings(
parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(
chat_completion_request
)
- assert parsed_request.settings == ModelSettings.none()
+ assert parsed_request == InstructRequest[ChatMessage, Tool](messages=[UserMessage(content="B")])
def test_continue_final_message_forwarded(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
request = ChatCompletionRequest[ChatMessage](
@@ -973,7 +1036,128 @@ def test_continue_final_message_forwarded(self, normalizer_v13: InstructRequestN
continue_final_message=True,
)
result: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(request)
- assert result.continue_final_message is True
+ assert result == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="a"), AssistantMessage(content="b")],
+ continue_final_message=True,
+ )
+
+ def test_accepts_text_and_think_chunks(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
+ r"""V13 normalizer accepts TextChunk and ThinkChunk in assistant messages."""
+ request = mock_chat_completion(
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]),
+ ],
+ )
+ parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(request)
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]),
+ ],
+ )
+
+ def test_accepts_string_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
+ r"""V13 normalizer accepts string content in assistant messages."""
+ request = mock_chat_completion(
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(content="plain text"),
+ ],
+ )
+ parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(request)
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="query"), AssistantMessage(content="plain text")],
+ )
+
+ def test_assistant_think_chunk_inter_message_aggregation(
+ self, normalizer_v13: InstructRequestNormalizerV13
+ ) -> None:
+ r"""V13 normalizer preserves ThinkChunks across multiple assistant messages."""
+ chat_completion_request = mock_chat_completion(
+ messages=[
+ AssistantMessage(content="A"),
+ AssistantMessage(content=[TextChunk(text="B")]),
+ AssistantMessage(
+ content=[
+ ThinkChunk(thinking="T"),
+ TextChunk(text="C"),
+ TextChunk(text="D"),
+ ]
+ ),
+ ]
+ )
+ parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(
+ chat_completion_request
+ )
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[
+ AssistantMessage(
+ content=[
+ TextChunk(text="A\n\nB"),
+ ThinkChunk(thinking="T"),
+ TextChunk(text="C\n\nD"),
+ ]
+ ),
+ ],
+ )
+
+ def test_assistant_think_chunk_intra_message_aggregation(
+ self, normalizer_v13: InstructRequestNormalizerV13
+ ) -> None:
+ r"""V13 normalizer coalesces TextChunks and preserves multiple ThinkChunks within a single message."""
+ chat_completion_request = mock_chat_completion(
+ messages=[
+ UserMessage(content="u"),
+ AssistantMessage(
+ content=[
+ ThinkChunk(thinking="t1"),
+ ThinkChunk(thinking="t2"),
+ TextChunk(text="a1"),
+ TextChunk(text="a2"),
+ TextChunk(text="a3"),
+ ]
+ ),
+ UserMessage(content="u"),
+ ]
+ )
+ parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(
+ chat_completion_request
+ )
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[
+ UserMessage(content="u"),
+ AssistantMessage(
+ content=[
+ ThinkChunk(thinking="t1"),
+ ThinkChunk(thinking="t2"),
+ TextChunk(text="a1\n\na2\n\na3"),
+ ]
+ ),
+ UserMessage(content="u"),
+ ],
+ )
+
+ def test_aggregates_text_tool_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None:
+ r"""V13 normalizer aggregates TextChunks in tool messages to a string."""
+ request = mock_chat_completion(
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")]),
+ ToolMessage(content=[TextChunk(text="hello"), TextChunk(text="world")], tool_call_id="c1"),
+ ],
+ )
+ parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(request)
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(
+ content="",
+ tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")],
+ ),
+ ToolMessage(content="hello\n\nworld", tool_call_id="c1"),
+ ],
+ )
class TestChatCompletionRequestNormalizationV15:
@@ -1002,7 +1186,10 @@ def test_assert_parsed_settings(
parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(
chat_completion_request
)
- assert parsed_request.settings == ModelSettings(reasoning_effort=reasoning_effort)
+ assert parsed_request == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="B")],
+ settings=ModelSettings(reasoning_effort=reasoning_effort),
+ )
def test_continue_final_message_forwarded(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
request = ChatCompletionRequest[ChatMessage](
@@ -1011,7 +1198,11 @@ def test_continue_final_message_forwarded(self, normalizer_v15: InstructRequestN
reasoning_effort=ReasoningEffort.high,
)
result: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
- assert result.continue_final_message is True
+ assert result == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="a"), AssistantMessage(content="b")],
+ continue_final_message=True,
+ settings=ModelSettings(reasoning_effort=ReasoningEffort.high),
+ )
def test_v15_intra_message_chunks_joined_without_separator(
self, normalizer_v15: InstructRequestNormalizerV15
@@ -1025,12 +1216,10 @@ def test_v15_intra_message_chunks_joined_without_separator(
reasoning_effort=ReasoningEffort.high,
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
- user_msg = parsed.messages[0]
- assert isinstance(user_msg, UserMessage)
- assert user_msg.content == "AB"
- assistant_msg = parsed.messages[1]
- assert isinstance(assistant_msg, AssistantMessage)
- assert assistant_msg.content == "CD"
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="AB"), AssistantMessage(content="CD")],
+ settings=ModelSettings(reasoning_effort=ReasoningEffort.high),
+ )
def test_v15_inter_message_join_still_uses_separator(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
r"""V15 still joins text across different messages with '\n\n'."""
@@ -1043,9 +1232,10 @@ def test_v15_inter_message_join_still_uses_separator(self, normalizer_v15: Instr
reasoning_effort=ReasoningEffort.high,
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
- user_msg = parsed.messages[0]
- assert isinstance(user_msg, UserMessage)
- assert user_msg.content == "First\n\nSecond"
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="First\n\nSecond"), AssistantMessage(content="Reply")],
+ settings=ModelSettings(reasoning_effort=ReasoningEffort.high),
+ )
def test_v15_mixed_intra_and_inter_message(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
r"""V15 combines intra-message ('') and inter-message ('\n\n') joining."""
@@ -1058,9 +1248,10 @@ def test_v15_mixed_intra_and_inter_message(self, normalizer_v15: InstructRequest
reasoning_effort=ReasoningEffort.high,
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
- user_msg = parsed.messages[0]
- assert isinstance(user_msg, UserMessage)
- assert user_msg.content == "AB\n\nCD"
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="AB\n\nCD"), AssistantMessage(content="Reply")],
+ settings=ModelSettings(reasoning_effort=ReasoningEffort.high),
+ )
def test_v15_mixed_intra_and_inter_assistant_messages(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
r"""V15 combines intra-message ('') and inter-message ('\n\n') joining for assistant messages."""
@@ -1073,9 +1264,10 @@ def test_v15_mixed_intra_and_inter_assistant_messages(self, normalizer_v15: Inst
reasoning_effort=ReasoningEffort.high,
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
- assistant_msg = parsed.messages[1]
- assert isinstance(assistant_msg, AssistantMessage)
- assert assistant_msg.content == "AB\n\nCD"
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="Hello"), AssistantMessage(content="AB\n\nCD")],
+ settings=ModelSettings(reasoning_effort=ReasoningEffort.high),
+ )
def test_v15_tool_message_text_chunks_joined_without_separator(
self, normalizer_v15: InstructRequestNormalizerV15
@@ -1090,9 +1282,129 @@ def test_v15_tool_message_text_chunks_joined_without_separator(
reasoning_effort=ReasoningEffort.high,
)
parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
- tool_msg = parsed.messages[2]
- assert isinstance(tool_msg, ToolMessage)
- assert tool_msg.content == "XY"
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(
+ content="",
+ tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")],
+ ),
+ ToolMessage(content="XY", tool_call_id="c1"),
+ ],
+ settings=ModelSettings(reasoning_effort=ReasoningEffort.high),
+ )
+
+ def test_accepts_text_and_think_chunks(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
+ r"""V15 normalizer accepts TextChunk and ThinkChunk in assistant messages."""
+ request = ChatCompletionRequest[ChatMessage](
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]),
+ ],
+ reasoning_effort=ReasoningEffort.high,
+ )
+ parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]),
+ ],
+ settings=ModelSettings(reasoning_effort=ReasoningEffort.high),
+ )
+
+ def test_accepts_string_content(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
+ r"""V15 normalizer accepts string content in assistant messages."""
+ request = ChatCompletionRequest[ChatMessage](
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(content="plain text"),
+ ],
+ reasoning_effort=ReasoningEffort.high,
+ )
+ parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[UserMessage(content="query"), AssistantMessage(content="plain text")],
+ settings=ModelSettings(reasoning_effort=ReasoningEffort.high),
+ )
+
+ def test_preserves_non_text_tool_content(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
+ r"""V15 normalizer preserves non-text chunks in tool messages."""
+ image_chunk = ImageURLChunk(image_url="https://example.com/image.png")
+ request = ChatCompletionRequest( # type: ignore[type-var]
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")]),
+ ToolMessage(content=[image_chunk], tool_call_id="c1"),
+ ],
+ reasoning_effort=ReasoningEffort.high,
+ )
+ parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(
+ content="",
+ tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")],
+ ),
+ ToolMessage(content=[image_chunk], tool_call_id="c1"),
+ ],
+ settings=ModelSettings(reasoning_effort=ReasoningEffort.high),
+ )
+
+ def test_sorts_multimodal_tool_messages(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
+ r"""V15 normalizer sorts multimodal tool messages by tool call order."""
+ image_chunk_1 = ImageURLChunk(image_url="https://example.com/img1.png")
+ image_chunk_2 = ImageURLChunk(image_url="https://example.com/img2.png")
+ request = ChatCompletionRequest( # type: ignore[type-var]
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(
+ tool_calls=[
+ ToolCall(function=FunctionCall(name="fn1", arguments="{}"), id="c1"),
+ ToolCall(function=FunctionCall(name="fn2", arguments="{}"), id="c2"),
+ ]
+ ),
+ ToolMessage(content=[image_chunk_2], tool_call_id="c2"),
+ ToolMessage(content=[image_chunk_1], tool_call_id="c1"),
+ ],
+ reasoning_effort=ReasoningEffort.high,
+ )
+ parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[
+ UserMessage(content="query"),
+ AssistantMessage(
+ content="",
+ tool_calls=[
+ ToolCall(function=FunctionCall(name="fn1", arguments="{}"), id="c1"),
+ ToolCall(function=FunctionCall(name="fn2", arguments="{}"), id="c2"),
+ ],
+ ),
+ ToolMessage(content=[image_chunk_1], tool_call_id="c1"),
+ ToolMessage(content=[image_chunk_2], tool_call_id="c2"),
+ ],
+ settings=ModelSettings(reasoning_effort=ReasoningEffort.high),
+ )
+
+ def test_preserves_audio_in_system_message(self, normalizer_v15: InstructRequestNormalizerV15) -> None:
+ r"""V15 normalizer preserves AudioChunk in system messages."""
+ request = ChatCompletionRequest[ChatMessage](
+ messages=[
+ SystemMessage(content=[TextChunk(text="hello"), AudioChunk(input_audio=b"fake_audio_data")]),
+ UserMessage(content="test"),
+ ],
+ reasoning_effort=ReasoningEffort.high,
+ )
+ parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request)
+ assert parsed == InstructRequest[ChatMessage, Tool](
+ messages=[
+ SystemMessage(
+ content=[TextChunk(text="hello"), AudioChunk(input_audio=b"fake_audio_data")],
+ ),
+ UserMessage(content="test"),
+ ],
+ settings=ModelSettings(reasoning_effort=ReasoningEffort.high),
+ )
@pytest.mark.parametrize(
diff --git a/tests/test_tokenizer_v13.py b/tests/test_tokenizer_v13.py
index 15bb6f3a..64380ac3 100644
--- a/tests/test_tokenizer_v13.py
+++ b/tests/test_tokenizer_v13.py
@@ -237,13 +237,17 @@ def test_end_to_end_v13_wrong_order(
def test_encode_tool_message(v13_tekkenizer: InstructTokenizerV13) -> None:
tool_message = ToolMessage(content="R1", tool_call_id="123456789")
assert isinstance(v13_tekkenizer, InstructTokenizerV13)
- encoded = v13_tekkenizer.encode_tool_message(tool_message, is_before_last_user_message=False)
+ encoded, images, audios = v13_tekkenizer.encode_tool_message(tool_message, is_before_last_user_message=False)
assert encoded == [7, 182, 149, 8]
+ assert images == []
+ assert audios == []
tool_message = ToolMessage(content=[TextChunk(text="R1"), TextChunk(text="R2")], tool_call_id="123456789")
assert isinstance(v13_tekkenizer, InstructTokenizerV13)
- encoded = v13_tekkenizer.encode_tool_message(tool_message, is_before_last_user_message=False)
+ encoded, images, audios = v13_tekkenizer.encode_tool_message(tool_message, is_before_last_user_message=False)
assert encoded == [7, 182, 149, 182, 150, 8]
+ assert images == []
+ assert audios == []
def test_encode_think_chunk(v13_tekkenizer_think: InstructTokenizerV13) -> None:
@@ -371,8 +375,9 @@ def test_tokenize_assistant_message_error(v13_tekkenizer: InstructTokenizerV13)
def test_encode_system_message(
v13_tekkenizer_think: InstructTokenizerV13, message: SystemMessage, expected: str
) -> None:
- encoded = v13_tekkenizer_think.encode_system_message(message)
+ encoded, audios = v13_tekkenizer_think.encode_system_message(message)
assert v13_tekkenizer_think.decode(encoded, special_token_policy=SpecialTokenPolicy.KEEP) == expected
+ assert audios == []
@pytest.mark.parametrize("audio_fixture", ["audio_chunk", "audio_url_chunk"])
diff --git a/tests/test_tokenizer_v15.py b/tests/test_tokenizer_v15.py
index fdd84c33..c8abe52c 100644
--- a/tests/test_tokenizer_v15.py
+++ b/tests/test_tokenizer_v15.py
@@ -1,7 +1,18 @@
+import base64
+from collections.abc import Callable
+from io import BytesIO
+
import pytest
+from PIL import Image
from mistral_common.exceptions import InvalidRequestException, TokenizerException
-from mistral_common.protocol.instruct.chunk import ThinkChunk
+from mistral_common.protocol.instruct.chunk import (
+ AudioChunk,
+ AudioURLChunk,
+ ImageURLChunk,
+ TextChunk,
+ ThinkChunk,
+)
from mistral_common.protocol.instruct.messages import (
AssistantMessage,
ChatMessage,
@@ -18,11 +29,14 @@
)
from mistral_common.protocol.instruct.tool_calls import Function, FunctionCall, Tool, ToolCall
from mistral_common.protocol.instruct.validator import ValidationMode, get_validator
-from mistral_common.tokens.tokenizers.base import TokenizerVersion
+from mistral_common.tokens.tokenizers.audio import AudioConfig, AudioEncoder, AudioSpectrogramConfig, SpecialAudioIDs
+from mistral_common.tokens.tokenizers.base import SpecialTokens, TokenizerVersion
+from mistral_common.tokens.tokenizers.image import ImageConfig, ImageEncoder, SpecialImageIDs
from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV15
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.model_settings_builder import EnumBuilder, ModelSettingsBuilder
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
+from tests.fixtures.audio import get_dummy_audio_chunk, get_dummy_audio_url_chunk
from tests.test_tekken import get_special_tokens, quick_vocab
EXPECTED_TEXT_V15: str = (
@@ -49,6 +63,44 @@
r"[INST]U2[/INST]"
)
+EXPECTED_TEXT_TOOL_AUDIO: str = (
+ r""
+ r'[AVAILABLE_TOOLS][{"type": "function", "function": {"name": "fn",'
+ r' "description": "test", "parameters": {}}}]'
+ r'[/AVAILABLE_TOOLS][MODEL_SETTINGS]{"reasoning_effort": "none"}[/MODEL_SETTINGS]'
+ r"[INST]Use the tool[/INST]"
+ r"[TOOL_CALLS]fn[ARGS]{}"
+ r"[TOOL_RESULTS]result[BEGIN_AUDIO][AUDIO][AUDIO][/TOOL_RESULTS]"
+)
+
+EXPECTED_TEXT_TOOL_IMAGE: str = (
+ r""
+ r'[AVAILABLE_TOOLS][{"type": "function", "function": {"name": "fn",'
+ r' "description": "test", "parameters": {}}}]'
+ r'[/AVAILABLE_TOOLS][MODEL_SETTINGS]{"reasoning_effort": "none"}[/MODEL_SETTINGS]'
+ r"[INST]Use the tool[/INST]"
+ r"[TOOL_CALLS]fn[ARGS]{}"
+ r"[TOOL_RESULTS]result[IMG][IMG_END][/TOOL_RESULTS]"
+)
+
+EXPECTED_TEXT_SYSTEM_AUDIO: str = (
+ r"[SYSTEM_PROMPT]System with content[BEGIN_AUDIO][AUDIO][AUDIO][/SYSTEM_PROMPT]"
+ r'[MODEL_SETTINGS]{"reasoning_effort": "none"}[/MODEL_SETTINGS]'
+ r"[INST]Hello[/INST]"
+)
+
+EXPECTED_TEXT_USER_AUDIO: str = (
+ r""
+ r'[MODEL_SETTINGS]{"reasoning_effort": "none"}[/MODEL_SETTINGS]'
+ r"[INST]Here is content[BEGIN_AUDIO][AUDIO][AUDIO][/INST]"
+)
+
+EXPECTED_TEXT_USER_IMAGE: str = (
+ r""
+ r'[MODEL_SETTINGS]{"reasoning_effort": "none"}[/MODEL_SETTINGS]'
+ r"[INST][IMG][IMG_END]Here is content[/INST]"
+)
+
def _build_v15_tekkenizer(model_settings_builder: ModelSettingsBuilder | None) -> Tekkenizer:
r"""Build a v15 Tekkenizer with the given model settings builder.
@@ -89,6 +141,103 @@ def get_v15_mistral_tokenizer(
)
+def _build_model_settings_builder(
+    allowed_reasoning_effort: tuple[str, ...] | None,
+) -> ModelSettingsBuilder:
+    """Create a ModelSettingsBuilder from the allowed reasoning effort values.
+
+    `None` yields `ModelSettingsBuilder.none()` (all fields ignored), matching the behavior of
+    `Tekkenizer.from_file` when the JSON has no `model_settings_builder` key.
+
+    Any tuple, including an empty one, yields an `EnumBuilder[ReasoningEffort]` that accepts `None`.
+    Its default is the first allowed value, or `None` when the tuple is empty.
+    """
+    if allowed_reasoning_effort is None:
+        return ModelSettingsBuilder.none()
+
+    efforts = [ReasoningEffort(value) for value in allowed_reasoning_effort]
+    default_effort = efforts[0] if efforts else None
+    effort_builder = EnumBuilder[ReasoningEffort](values=efforts, accepts_none=True, default=default_effort)
+    return ModelSettingsBuilder(reasoning_effort=effort_builder)
+
+
+def _get_dummy_image_url_chunk() -> ImageURLChunk:
+    r"""Return an `ImageURLChunk` whose URL is a base64 PNG data URL of a 4x4 red image."""
+    png_buffer = BytesIO()
+    Image.new("RGB", (4, 4), "red").save(png_buffer, "PNG")
+    encoded_png = base64.b64encode(png_buffer.getvalue()).decode()
+    return ImageURLChunk(image_url=f"data:image/png;base64,{encoded_png}")
+
+
+def get_v15_mistral_tokenizer_with_audio() -> MistralTokenizer:
+    r"""Assemble a V15 MistralTokenizer whose instruct tokenizer carries an audio encoder.
+
+    The model settings builder allows every `ReasoningEffort` value and the validator runs in
+    `ValidationMode.test`.
+    """
+    settings_builder = _build_model_settings_builder(tuple(ReasoningEffort))
+    tekkenizer = Tekkenizer(
+        quick_vocab([b"a", b"b", b"c", b"f", b"de"]),
+        special_tokens=get_special_tokens(TokenizerVersion.v15, add_audio=True),
+        pattern=r".+",
+        vocab_size=256 + 100,
+        num_special_tokens=100,
+        version=TokenizerVersion.v15,
+        model_settings_builder=settings_builder,
+    )
+
+    # Only the `audio` and `begin_audio` ids are looked up; the other special audio ids stay unset.
+    audio_encoder = AudioEncoder(
+        AudioConfig(
+            sampling_rate=24_000,
+            frame_rate=12.5,
+            encoding_config=AudioSpectrogramConfig(num_mel_bins=128, hop_length=160, window_size=400),
+        ),
+        SpecialAudioIDs(
+            audio=tekkenizer.get_special_token(SpecialTokens.audio.value),
+            begin_audio=tekkenizer.get_special_token(SpecialTokens.begin_audio.value),
+            streaming_pad=None,
+            text_to_audio=None,
+            audio_to_text=None,
+        ),
+    )
+
+    return MistralTokenizer(
+        instruct_tokenizer=InstructTokenizerV15(tekkenizer, audio_encoder=audio_encoder),
+        request_normalizer=get_normalizer(TokenizerVersion.v15, tekkenizer.model_settings_builder),
+        validator=get_validator(TokenizerVersion.v15, mode=ValidationMode.test),
+    )
+
+
+def get_v15_mistral_tokenizer_with_image() -> MistralTokenizer:
+    r"""Assemble a V15 MistralTokenizer whose instruct tokenizer carries an image encoder.
+
+    The model settings builder allows every `ReasoningEffort` value and the validator runs in
+    `ValidationMode.test`.
+    """
+    settings_builder = _build_model_settings_builder(tuple(ReasoningEffort))
+    # NOTE(review): the special tokens are requested with `add_think=True`, not an image-specific
+    # flag — presumably the img/img_break/img_end tokens are always present; confirm.
+    tekkenizer = Tekkenizer(
+        quick_vocab([b"a", b"b", b"c", b"f", b"de"]),
+        special_tokens=get_special_tokens(TokenizerVersion.v15, add_think=True),
+        pattern=r".+",
+        vocab_size=256 + 100,
+        num_special_tokens=100,
+        version=TokenizerVersion.v15,
+        model_settings_builder=settings_builder,
+    )
+
+    image_encoder = ImageEncoder(
+        ImageConfig(image_patch_size=16, max_image_size=1024),
+        SpecialImageIDs(
+            img=tekkenizer.get_special_token(SpecialTokens.img.value),
+            img_break=tekkenizer.get_special_token(SpecialTokens.img_break.value),
+            img_end=tekkenizer.get_special_token(SpecialTokens.img_end.value),
+        ),
+    )
+
+    return MistralTokenizer(
+        instruct_tokenizer=InstructTokenizerV15(tekkenizer, image_encoder=image_encoder),
+        request_normalizer=get_normalizer(TokenizerVersion.v15, tekkenizer.model_settings_builder),
+        validator=get_validator(TokenizerVersion.v15, mode=ValidationMode.test),
+    )
+
+
@pytest.fixture(scope="session")
def v15_tekkenizer() -> InstructTokenizerV15:
return get_v15_tekkenizer(_build_model_settings_builder(tuple(ReasoningEffort)))
@@ -139,6 +288,61 @@ def messages() -> list[ChatMessage]:
]
+@pytest.fixture(scope="session")
+def audio_chunk() -> AudioChunk:
+    r"""Session-scoped fixture returning the dummy audio chunk from `tests.fixtures.audio`."""
+    dummy_chunk: AudioChunk = get_dummy_audio_chunk()
+    return dummy_chunk
+
+
+# Multimodal content chunks and their corresponding tokenizer factories for parametrized tests.
+#
+# Each `pytest.param` is positional and must match the parametrize argnames used by the tests:
+#   (content_chunk, expected_audios, expected_images, tokenizer_factory, expected_text)
+# The chunks are built here, at import time; the tokenizer factories are passed uncalled and are
+# invoked inside each test.
+
+# Cases for a tool message whose content is text plus one multimodal chunk.
+_TOOL_MULTIMODAL_PARAMS = [
+    pytest.param(
+        get_dummy_audio_chunk(), 1, 0, get_v15_mistral_tokenizer_with_audio, EXPECTED_TEXT_TOOL_AUDIO, id="audio"
+    ),
+    pytest.param(
+        get_dummy_audio_url_chunk(),
+        1,
+        0,
+        get_v15_mistral_tokenizer_with_audio,
+        EXPECTED_TEXT_TOOL_AUDIO,
+        id="audio_url",
+    ),
+    pytest.param(
+        _get_dummy_image_url_chunk(),
+        0,
+        1,
+        get_v15_mistral_tokenizer_with_image,
+        EXPECTED_TEXT_TOOL_IMAGE,
+        id="image_url",
+    ),
+]
+# Cases for a system message; only the raw audio chunk is exercised here.
+_SYSTEM_MULTIMODAL_PARAMS = [
+    pytest.param(
+        get_dummy_audio_chunk(), 1, 0, get_v15_mistral_tokenizer_with_audio, EXPECTED_TEXT_SYSTEM_AUDIO, id="audio"
+    ),
+]
+# Cases for a user message whose content is text plus one multimodal chunk.
+_USER_MULTIMODAL_PARAMS = [
+    pytest.param(
+        get_dummy_audio_chunk(), 1, 0, get_v15_mistral_tokenizer_with_audio, EXPECTED_TEXT_USER_AUDIO, id="audio"
+    ),
+    pytest.param(
+        get_dummy_audio_url_chunk(),
+        1,
+        0,
+        get_v15_mistral_tokenizer_with_audio,
+        EXPECTED_TEXT_USER_AUDIO,
+        id="audio_url",
+    ),
+    pytest.param(
+        _get_dummy_image_url_chunk(),
+        0,
+        1,
+        get_v15_mistral_tokenizer_with_image,
+        EXPECTED_TEXT_USER_IMAGE,
+        id="image_url",
+    ),
+]
+
+
def test_tools_and_reasoning_effort(
v15_tekkenizer: InstructTokenizerV15, available_tools: list[Tool], messages: list[ChatMessage]
) -> None:
@@ -177,30 +381,6 @@ def test_system_think_chunk_raises_v15(v15_tekkenizer: InstructTokenizerV15) ->
v15_tekkenizer.encode_instruct(request)
-def _build_model_settings_builder(
- allowed_reasoning_effort: tuple[str, ...] | None,
-) -> ModelSettingsBuilder:
- """Build a ModelSettingsBuilder from allowed reasoning effort values.
-
- When `allowed_reasoning_effort` is `None`, returns `ModelSettingsBuilder.none()`
- (all fields ignored). This matches the behavior of `Tekkenizer.from_file` when no
- `model_settings_builder` key is present in the JSON.
- """
- if allowed_reasoning_effort is None:
- return ModelSettingsBuilder.none()
- if not allowed_reasoning_effort:
- return ModelSettingsBuilder(
- reasoning_effort=EnumBuilder[ReasoningEffort](values=[], accepts_none=True, default=None)
- )
- return ModelSettingsBuilder(
- reasoning_effort=EnumBuilder[ReasoningEffort](
- values=[ReasoningEffort(v) for v in allowed_reasoning_effort],
- accepts_none=True,
- default=ReasoningEffort(allowed_reasoning_effort[0]) if allowed_reasoning_effort else None,
- )
- )
-
-
@pytest.mark.parametrize(
("reasoning_effort", "allowed_reasoning_effort", "raises", "match"),
[
@@ -294,3 +474,79 @@ def test_encode_chat_completion_continue_final_message() -> None:
eos_id = tokenizer_v15.instruct_tokenizer.tokenizer.eos_id
assert encoded.tokens[-1] != eos_id
+
+
+@pytest.mark.parametrize(
+    ("content_chunk", "expected_audios", "expected_images", "tokenizer_factory", "expected_text"),
+    _TOOL_MULTIMODAL_PARAMS,
+)
+def test_encode_chat_completion_with_multimodal_tool(
+    content_chunk: AudioChunk | AudioURLChunk | ImageURLChunk,
+    expected_audios: int,
+    expected_images: int,
+    tokenizer_factory: Callable[[], MistralTokenizer],
+    expected_text: str,
+) -> None:
+    r"""Encode a tool result made of text plus one multimodal chunk and check text and media counts."""
+    mistral_tokenizer = tokenizer_factory()
+
+    # The tool call id must match the `tool_call_id` of the tool message answering it.
+    tool_call = ToolCall(id="test12345", function=FunctionCall(name="fn", arguments="{}"))
+    conversation: list[ChatMessage] = [
+        UserMessage(content="Use the tool"),
+        AssistantMessage(tool_calls=[tool_call]),
+        ToolMessage(content=[TextChunk(text="result"), content_chunk], tool_call_id="test12345"),
+    ]
+    fn_tool = Tool(function=Function(name="fn", description="test", parameters={}))
+    chat_request: ChatCompletionRequest = ChatCompletionRequest(messages=conversation, tools=[fn_tool])
+
+    encoded = mistral_tokenizer.encode_chat_completion(chat_request)
+    assert encoded.text == expected_text, encoded.text
+    assert len(encoded.audios) == expected_audios
+    assert len(encoded.images) == expected_images
+
+
+@pytest.mark.parametrize(
+    ("content_chunk", "expected_audios", "expected_images", "tokenizer_factory", "expected_text"),
+    _SYSTEM_MULTIMODAL_PARAMS,
+)
+def test_encode_chat_completion_with_multimodal_system(
+    content_chunk: AudioChunk,
+    expected_audios: int,
+    expected_images: int,
+    tokenizer_factory: Callable[[], MistralTokenizer],
+    expected_text: str,
+) -> None:
+    r"""Encode a system prompt made of text plus one multimodal chunk and check text and media counts."""
+    mistral_tokenizer = tokenizer_factory()
+
+    conversation: list[ChatMessage] = [
+        SystemMessage(content=[TextChunk(text="System with content"), content_chunk]),
+        UserMessage(content="Hello"),
+    ]
+    chat_request: ChatCompletionRequest = ChatCompletionRequest(messages=conversation)
+
+    encoded = mistral_tokenizer.encode_chat_completion(chat_request)
+    assert encoded.text == expected_text, encoded.text
+    assert len(encoded.audios) == expected_audios
+    assert len(encoded.images) == expected_images
+
+
+@pytest.mark.parametrize(
+    ("content_chunk", "expected_audios", "expected_images", "tokenizer_factory", "expected_text"),
+    _USER_MULTIMODAL_PARAMS,
+)
+def test_encode_chat_completion_with_multimodal_user(
+    content_chunk: AudioChunk | AudioURLChunk | ImageURLChunk,
+    expected_audios: int,
+    expected_images: int,
+    tokenizer_factory: Callable[[], MistralTokenizer],
+    expected_text: str,
+) -> None:
+    r"""Encode a user message made of text plus one multimodal chunk and check text and media counts."""
+    mistral_tokenizer = tokenizer_factory()
+
+    user_turn = UserMessage(content=[TextChunk(text="Here is content"), content_chunk])
+    chat_request: ChatCompletionRequest = ChatCompletionRequest(messages=[user_turn])
+
+    encoded = mistral_tokenizer.encode_chat_completion(chat_request)
+    assert encoded.text == expected_text, encoded.text
+    assert len(encoded.audios) == expected_audios
+    assert len(encoded.images) == expected_images
diff --git a/tests/test_tokenizer_v7_audio.py b/tests/test_tokenizer_v7_audio.py
index 10995957..4ef3d47c 100644
--- a/tests/test_tokenizer_v7_audio.py
+++ b/tests/test_tokenizer_v7_audio.py
@@ -11,8 +11,8 @@
AudioChunk,
AudioURL,
AudioURLChunk,
+ ContentChunk,
TextChunk,
- UserContentChunk,
)
from mistral_common.protocol.instruct.messages import (
UATS,
@@ -215,7 +215,7 @@ def test_tokenize_user_message(tekkenizer: InstructTokenizerV7, audio_first: boo
text_chunk = TextChunk(text="a")
num_expected_frames = int(np.ceil(duration * frame_rate))
- chunks: list[UserContentChunk] = [audio_chunk, text_chunk] if audio_first else [text_chunk, audio_chunk]
+ chunks: list[ContentChunk] = [audio_chunk, text_chunk] if audio_first else [text_chunk, audio_chunk]
tokenized = tekkenizer.encode_instruct(
InstructRequest(
@@ -261,7 +261,7 @@ def test_tokenize_multi_turn(tekkenizer: InstructTokenizerV7) -> None:
text_chunk = TextChunk(text="a")
num_expected_frames = int(np.ceil(duration * frame_rate))
- chunks: list[UserContentChunk] = [audio_chunk, text_chunk]
+ chunks: list[ContentChunk] = [audio_chunk, text_chunk]
tokenized = tekkenizer.encode_instruct(
InstructRequest(
diff --git a/tests/validation/test_chat_validation.py b/tests/validation/test_chat_validation.py
index f4c8653c..c036e2a8 100644
--- a/tests/validation/test_chat_validation.py
+++ b/tests/validation/test_chat_validation.py
@@ -1,11 +1,19 @@
+from typing import Any
+
import pytest
from mistral_common.exceptions import (
InvalidAssistantMessageException,
InvalidMessageStructureException,
InvalidRequestException,
+ InvalidSystemPromptException,
+ InvalidToolMessageException,
+ InvalidUserMessageException,
+)
+from mistral_common.protocol.instruct.chunk import (
+ AudioChunk,
+ AudioURLChunk,
)
-from mistral_common.protocol.instruct.chunk import AudioChunk, AudioURLChunk
from mistral_common.protocol.instruct.messages import (
AssistantMessage,
SystemMessage,
@@ -26,6 +34,29 @@
ValidationMode,
)
from tests.fixtures.audio import get_dummy_audio_chunk, get_dummy_audio_url_chunk
+from tests.fixtures.chunks import get_content_chunks
+
+# Return type of the `_*_convo` helpers below: a list that may mix all four message roles.
+_Messages = list[UserMessage | AssistantMessage | SystemMessage | ToolMessage]
+
+
+def _user_convo(content: "str | list[Any]") -> _Messages:
+    """Return a one-message conversation: a user turn carrying `content`."""
+    user_turn = UserMessage(content=content)
+    return [user_turn]
+
+
+def _assistant_convo(content: "str | list[Any]") -> _Messages:
+    """Return user("hi") -> assistant(`content`) -> user("next")."""
+    opening_turn = UserMessage(content="hi")
+    assistant_turn = AssistantMessage(content=content)
+    closing_turn = UserMessage(content="next")
+    return [opening_turn, assistant_turn, closing_turn]
+
+
+def _system_convo(content: "str | list[Any]") -> _Messages:
+    """Return system(`content`) -> user("hi")."""
+    system_turn = SystemMessage(content=content)
+    user_turn = UserMessage(content="hi")
+    return [system_turn, user_turn]
+
+
+def _tool_convo(content: "str | list[Any]") -> _Messages:
+    """Return user("hi") -> assistant calling `fn` -> tool result carrying `content`.
+
+    The tool message answers the assistant's single tool call (same id on both).
+    """
+    fn_call = ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")
+    user_turn = UserMessage(content="hi")
+    assistant_turn = AssistantMessage(tool_calls=[fn_call])
+    tool_turn = ToolMessage(content=content, tool_call_id="test12345")
+    return [user_turn, assistant_turn, tool_turn]
@pytest.fixture(scope="module")
@@ -49,6 +80,16 @@ def validator(request: pytest.FixtureRequest) -> MistralRequestValidator:
return request.param # type: ignore
+@pytest.fixture
+def validator_base() -> MistralRequestValidator:
+    """Provide a base `MistralRequestValidator` in serving mode."""
+    base_validator = MistralRequestValidator(ValidationMode.serving)
+    return base_validator
+
+
+@pytest.fixture
+def validator_v3() -> MistralRequestValidatorV3:
+    """Provide a `MistralRequestValidatorV3` in serving mode."""
+    v3_validator = MistralRequestValidatorV3(ValidationMode.serving)
+    return v3_validator
+
+
@pytest.fixture
def validator_v5() -> MistralRequestValidatorV5:
return MistralRequestValidatorV5(ValidationMode.serving)
@@ -353,6 +394,57 @@ def test_build_settings_raises_error(self, validator: MistralRequestValidator) -
with pytest.raises(InvalidRequestException, match="reasoning_effort='none' is not supported for this model"):
validator._validate_model_settings(request)
+    def test_allows_text_content_chunks(self, validator_base: MistralRequestValidator) -> None:
+        """Text-only chunk content passes validation in user, assistant, system and tool messages."""
+        for build_convo in (_user_convo, _assistant_convo, _system_convo, _tool_convo):
+            text_only_convo = build_convo(get_content_chunks(("text",)))
+            validator_base.validate_messages(text_only_convo, continue_final_message=False)
+
+    @pytest.mark.parametrize("chunk_name", ["image", "image_url", "audio", "audio_url"])
+    def test_rejects_non_text_in_user(self, validator_base: MistralRequestValidator, chunk_name: str) -> None:
+        """The base validator rejects a user message holding a single non-text chunk.
+
+        Parametrized per chunk kind, so each kind runs and is reported on its own; a loop inside the
+        test would stop at the first failing kind.
+        """
+        with pytest.raises(InvalidUserMessageException, match="Unexpected content chunk types in user message"):
+            validator_base.validate_messages(
+                _user_convo(get_content_chunks((chunk_name,))), continue_final_message=False
+            )
+
+    def test_rejects_non_text_in_assistant(self, validator_base: MistralRequestValidator) -> None:
+        """The base validator rejects an assistant message whose content is a think chunk."""
+        expected_error = "Unexpected content chunk types in assistant message"
+        with pytest.raises(InvalidAssistantMessageException, match=expected_error):
+            think_convo = _assistant_convo(get_content_chunks(("think",)))
+            validator_base.validate_messages(think_convo, continue_final_message=False)
+
+    @pytest.mark.parametrize("chunk_name", ["audio", "think"])
+    def test_rejects_non_text_in_system(self, validator_base: MistralRequestValidator, chunk_name: str) -> None:
+        """The base validator rejects a system message holding a single audio or think chunk.
+
+        Parametrized per chunk kind, so each kind runs and is reported on its own; a loop inside the
+        test would stop at the first failing kind.
+        """
+        with pytest.raises(InvalidSystemPromptException, match="Unexpected content chunk types in system message"):
+            validator_base.validate_messages(
+                _system_convo(get_content_chunks((chunk_name,))), continue_final_message=False
+            )
+
+    @pytest.mark.parametrize("chunk_name", ["image", "image_url", "audio", "audio_url", "think"])
+    def test_rejects_non_text_in_tool(self, validator_base: MistralRequestValidator, chunk_name: str) -> None:
+        """The base validator rejects a tool message holding a single non-text chunk.
+
+        Parametrized per chunk kind, so each kind runs and is reported on its own; a loop inside the
+        test would stop at the first failing kind.
+        """
+        with pytest.raises(InvalidToolMessageException, match="Unexpected content chunk types in tool message"):
+            validator_base.validate_messages(
+                _tool_convo(get_content_chunks((chunk_name,))), continue_final_message=False
+            )
+
+    def test_reports_sorted_unique_invalid_chunk_types(self, validator_base: MistralRequestValidator) -> None:
+        """The error message lists the offending chunk class names as `['AudioChunk', 'ImageURLChunk']`."""
+        expected_error = r"Unexpected content chunk types in user message: \['AudioChunk', 'ImageURLChunk'\]"
+        mixed_convo = _user_convo(get_content_chunks(("audio", "image_url")))
+        with pytest.raises(InvalidUserMessageException, match=expected_error):
+            validator_base.validate_messages(mixed_convo, continue_final_message=False)
+
+
+class TestChatValidationV3:
+    """Content-chunk checks for `MistralRequestValidatorV3` on user messages."""
+
+    def test_allows_text_and_image_in_user(self, validator_v3: MistralRequestValidatorV3) -> None:
+        """Text, image and image-URL chunks together pass validation in a user message."""
+        content = get_content_chunks(("text", "image", "image_url"))
+        validator_v3.validate_messages(_user_convo(content), continue_final_message=False)
+
+    @pytest.mark.parametrize("chunk_name", ["audio", "audio_url"])
+    def test_rejects_audio_in_user(self, validator_v3: MistralRequestValidatorV3, chunk_name: str) -> None:
+        """A user message holding a single audio or audio-URL chunk is rejected.
+
+        Parametrized per chunk kind, so each kind runs and is reported on its own; a loop inside the
+        test would stop at the first failing kind.
+        """
+        with pytest.raises(InvalidUserMessageException, match="Unexpected content chunk types in user message"):
+            validator_v3.validate_messages(
+                _user_convo(get_content_chunks((chunk_name,))), continue_final_message=False
+            )
+
class TestChatValidationV5:
@pytest.mark.parametrize("audio_fixture", ["audio_chunk", "audio_url_chunk"])
@@ -403,6 +495,23 @@ def test_build_settings_raises_error(self, validator: MistralRequestValidator) -
with pytest.raises(InvalidRequestException, match="reasoning_effort='none' is not supported for this model"):
validator._validate_model_settings(request)
+    def test_allows_text_image_audio_in_user(self, validator_v5: MistralRequestValidatorV5) -> None:
+        """Text, image, image-URL, audio and audio-URL chunks together pass validation in a user message."""
+        user_convo = _user_convo(get_content_chunks(("text", "image", "image_url", "audio", "audio_url")))
+        validator_v5.validate_messages(user_convo, continue_final_message=False)
+
+    def test_allows_text_audio_think_in_system(self, validator_v5: MistralRequestValidatorV5) -> None:
+        """Text, audio and think chunks together pass validation in a system message."""
+        system_convo = _system_convo(get_content_chunks(("text", "audio", "think")))
+        validator_v5.validate_messages(system_convo, continue_final_message=False)
+
+    def test_rejects_think_in_assistant(self, validator_v5: MistralRequestValidatorV5) -> None:
+        """An assistant message whose content is a think chunk is rejected."""
+        expected_error = "Unexpected content chunk types in assistant message"
+        with pytest.raises(InvalidAssistantMessageException, match=expected_error):
+            think_convo = _assistant_convo(get_content_chunks(("think",)))
+            validator_v5.validate_messages(think_convo, continue_final_message=False)
+
class TestChatValidationV13:
def test_right_number_results_invalid_id(self, validator_v13: MistralRequestValidatorV13) -> None:
@@ -613,6 +722,15 @@ def test_audio_with_system_prompt_raises_ok(
continue_final_message=False,
)
+    def test_allows_text_and_think_in_assistant(self, validator_v13: MistralRequestValidatorV13) -> None:
+        """Text and think chunks together pass validation in an assistant message."""
+        assistant_convo = _assistant_convo(get_content_chunks(("text", "think")))
+        validator_v13.validate_messages(assistant_convo, continue_final_message=False)
+
+    @pytest.mark.parametrize("chunk_name", ["image", "image_url", "audio", "audio_url", "think"])
+    def test_rejects_non_text_in_tool(self, validator_v13: MistralRequestValidatorV13, chunk_name: str) -> None:
+        """A tool message holding a single non-text chunk is rejected by the V13 validator.
+
+        Parametrized per chunk kind, so each kind runs and is reported on its own; a loop inside the
+        test would stop at the first failing kind.
+        """
+        with pytest.raises(InvalidToolMessageException, match="Unexpected content chunk types in tool message"):
+            validator_v13.validate_messages(
+                _tool_convo(get_content_chunks((chunk_name,))), continue_final_message=False
+            )
+
class TestChatValidationV15:
@pytest.mark.parametrize("reasoning_effort", [*list(ReasoningEffort), None])
@@ -621,3 +739,18 @@ def test_build_settings_v15_reasoning_effort(
) -> None:
request = ChatCompletionRequest(messages=[UserMessage(content="Hello")], reasoning_effort=reasoning_effort)
validator_v15._validate_model_settings(request)
+
+    def test_allows_text_and_audio_in_system(self, validator_v15: MistralRequestValidatorV15) -> None:
+        """Text and audio chunks together pass validation in a system message."""
+        system_convo = _system_convo(get_content_chunks(("text", "audio")))
+        validator_v15.validate_messages(system_convo, continue_final_message=False)
+
+    def test_rejects_think_in_system(self, validator_v15: MistralRequestValidatorV15) -> None:
+        """A system message whose content is a think chunk is rejected."""
+        expected_error = "Unexpected content chunk types in system message"
+        with pytest.raises(InvalidSystemPromptException, match=expected_error):
+            think_convo = _system_convo(get_content_chunks(("think",)))
+            validator_v15.validate_messages(think_convo, continue_final_message=False)
+
+    def test_allows_all_chunk_types_in_tool(self, validator_v15: MistralRequestValidatorV15) -> None:
+        """Text, image, image-URL, audio, audio-URL and think chunks together pass validation in a tool message."""
+        tool_convo = _tool_convo(get_content_chunks(("text", "image", "image_url", "audio", "audio_url", "think")))
+        validator_v15.validate_messages(tool_convo, continue_final_message=False)