diff --git a/src/mistral_common/experimental/app/routers.py b/src/mistral_common/experimental/app/routers.py index 73c9c6e3..40103e70 100644 --- a/src/mistral_common/experimental/app/routers.py +++ b/src/mistral_common/experimental/app/routers.py @@ -14,7 +14,7 @@ ) from mistral_common.experimental.think import _split_content_and_think_chunks from mistral_common.experimental.tools import _decode_tool_calls, _split_content_and_tool_calls -from mistral_common.protocol.instruct.chunk import TextChunk, ThinkChunk +from mistral_common.protocol.instruct.chunk import ContentChunk, TextChunk, ThinkChunk from mistral_common.protocol.instruct.messages import AssistantMessage from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy, Tokenized, TokenizerVersion @@ -100,7 +100,7 @@ async def detokenize_to_assistant_message( else: content_tokens, tool_calls_tokens = tokens, () - content: str | list[TextChunk | ThinkChunk] | None = None + content: str | list[ContentChunk] | None = None if settings.tokenizer.instruct_tokenizer.tokenizer.version >= TokenizerVersion.v13: assert isinstance(settings.tokenizer.instruct_tokenizer, InstructTokenizerV13) diff --git a/src/mistral_common/integrations/chat_templates/template_generator.py b/src/mistral_common/integrations/chat_templates/template_generator.py index 807ed528..b9dc662a 100644 --- a/src/mistral_common/integrations/chat_templates/template_generator.py +++ b/src/mistral_common/integrations/chat_templates/template_generator.py @@ -216,6 +216,30 @@ def validates_assistant_non_empty(self) -> bool: r"""Whether to validate that assistant messages have non-empty content or tool calls.""" return self.version >= TokenizerVersion.v7 or (self.version >= TokenizerVersion.v3 and not self.spm) + @property + def tool_supports_multimodal(self) -> bool: + r"""Whether tool messages can contain non-text content chunks. 
V15+.""" + return self.version >= TokenizerVersion.v15 + + @property + def system_supports_audio(self) -> bool: + r"""Whether system messages can contain audio. V15+ with audio support.""" + return self.audio_support and self.version >= TokenizerVersion.v15 + + +def _join_types_desc(parts: list[str]) -> str: + r"""Join type names into a human-readable description string. + + Args: + parts: List of type names (e.g. ["text", "thinking", "image"]). + + Returns: + Formatted string like "text", "text and thinking", or "text, thinking and image"; "" when `parts` is empty. + """ + if len(parts) <= 1: + return "".join(parts) + return ", ".join(parts[:-1]) + " and " + parts[-1] + def _generate_header(config: TemplateConfig) -> str: r"""Generate template header with default system message. @@ -872,6 +896,8 @@ def _generate_system_message_handling(config: TemplateConfig) -> str: if has_extra_types: if config.system_supports_thinking: rc_args += ", supported_types_desc='text and thinking'" + elif config.system_supports_audio: + rc_args += ", supported_types_desc='text and audio'" else: rc_args += ", supported_types_desc='text'" if config.any_thinking_support: @@ -882,7 +908,7 @@ def _generate_system_message_handling(config: TemplateConfig) -> str: if config.image_support: rc_args += ", support_images=false" if config.audio_support: - rc_args += ", support_audio=false" + rc_args += f", support_audio={'true' if config.system_supports_audio else 'false'}" lines.append(" {{- render_content(" + rc_args + ") -}}") lines.append(" {{- '" + _END_SYSTEM + "' -}}") @@ -1204,10 +1230,10 @@ def _generate_assistant_message_handling(config: TemplateConfig) -> str: """ lines = [] + comment_parts = ["text"] if config.any_thinking_support: - chunk_types = "text and thinking" - else: - chunk_types = "text" + comment_parts.append("thinking") + chunk_types = _join_types_desc(comment_parts) comment = f"{{#- Assistant messages supports {chunk_types} content. 
#}}" lines.append("") @@ -1235,10 +1261,10 @@ def _generate_assistant_message_handling(config: TemplateConfig) -> str: has_extra_types = config.any_thinking_support or config.image_support or config.audio_support rc_call_args = "message['content'], 'assistant message contents'" if has_extra_types: + desc_parts = ["text"] if config.any_thinking_support: - rc_call_args += ", supported_types_desc='text and thinking'" - else: - rc_call_args += ", supported_types_desc='text'" + desc_parts.append("thinking") + rc_call_args += f", supported_types_desc='{_join_types_desc(desc_parts)}'" if config.any_thinking_support: rc_call_args += ", support_thinking=true" if config.image_support: @@ -1423,7 +1449,10 @@ def _generate_tool_message_handling(config: TemplateConfig) -> str: lines.append(" {#- Tool messages supports int, float or text content. #}") lines.append(" {%- elif message['role'] == 'tool' and ns.index > ns.max_idx_user %}") else: - lines.append(" {#- Tool messages only supports text content. #}") + if config.tool_supports_multimodal: + lines.append(" {#- Tool messages (multimodal). #}") + else: + lines.append(" {#- Tool messages only supports text content. 
#}") lines.append(" {%- elif message['role'] == 'tool' %}") if config.uses_spm_space_tracking: @@ -1484,9 +1513,26 @@ def _generate_tool_message_handling(config: TemplateConfig) -> str: + "' }}" # noqa: E501 ) elif config.uses_simple_tool_results: - lines.append( - " {{- '" + _BEGIN_TOOL_RESULTS + "' + message['content']|string + '" + _END_TOOL_RESULTS + "' }}" - ) # noqa: E501 + if config.tool_supports_multimodal: + tool_rc_args = "message['content'], 'tool message contents'" + if config.image_support or config.audio_support: + desc_parts = ["text"] + if config.image_support: + desc_parts.append("image") + if config.audio_support: + desc_parts.append("audio") + tool_rc_args += f", supported_types_desc='{_join_types_desc(desc_parts)}'" + if config.image_support: + tool_rc_args += ", support_images=true" + if config.audio_support: + tool_rc_args += ", support_audio=true" + lines.append(" {{- '" + _BEGIN_TOOL_RESULTS + "' -}}") + lines.append(" {{- render_content(" + tool_rc_args + ") -}}") + lines.append(" {{- '" + _END_TOOL_RESULTS + "' }}") + else: + lines.append( + " {{- '" + _BEGIN_TOOL_RESULTS + "' + message['content']|string + '" + _END_TOOL_RESULTS + "' }}" + ) # noqa: E501 else: # v3 non-spm style lines.extend(_emit_int_float_parsing(" ")) diff --git a/src/mistral_common/protocol/instruct/chunk.py b/src/mistral_common/protocol/instruct/chunk.py index 2ed61528..91fcb679 100644 --- a/src/mistral_common/protocol/instruct/chunk.py +++ b/src/mistral_common/protocol/instruct/chunk.py @@ -451,9 +451,6 @@ def from_openai(cls, openai_chunk: dict[str, Any]) -> "ThinkChunk": ContentChunk = Annotated[ TextChunk | ImageChunk | ImageURLChunk | AudioChunk | AudioURLChunk | ThinkChunk, Field(discriminator="type") ] -UserContentChunk = Annotated[ - TextChunk | ImageChunk | ImageURLChunk | AudioChunk | AudioURLChunk, Field(discriminator="type") -] def _convert_openai_content_chunks(openai_content_chunks: dict[str, Any]) -> ContentChunk: diff --git 
a/src/mistral_common/protocol/instruct/messages.py b/src/mistral_common/protocol/instruct/messages.py index e82379cb..f96b78af 100644 --- a/src/mistral_common/protocol/instruct/messages.py +++ b/src/mistral_common/protocol/instruct/messages.py @@ -1,17 +1,22 @@ import warnings +from collections.abc import Sequence from enum import Enum -from typing import Any, Literal, TypeGuard, TypeVar +from typing import Any, ClassVar, Literal, TypeVar -from pydantic import Field -from typing_extensions import Annotated, TypeAlias +from pydantic import Field, model_validator +from typing_extensions import Annotated, TypeAlias, TypeGuard from mistral_common.base import MistralBase from mistral_common.exceptions import InvalidAssistantMessageException from mistral_common.protocol.instruct.chunk import ( + AudioChunk, + AudioURLChunk, + BaseContentChunk, ContentChunk, + ImageChunk, + ImageURLChunk, TextChunk, ThinkChunk, - UserContentChunk, _convert_openai_content_chunks, ) from mistral_common.protocol.instruct.tool_calls import ToolCall @@ -23,12 +28,14 @@ ) -def _are_think_chunks(think_chunks: list[ThinkChunk | TextChunk]) -> TypeGuard[list[ThinkChunk]]: - return all(isinstance(c, ThinkChunk) for c in think_chunks) +def _are_think_chunks(chunks: Sequence[ContentChunk]) -> TypeGuard[list[ThinkChunk]]: + r"""Narrow a chunk list to ThinkChunk list.""" + return all(isinstance(c, ThinkChunk) for c in chunks) -def _are_text_chunks(think_chunks: list[ThinkChunk | TextChunk]) -> TypeGuard[list[TextChunk]]: - return all(isinstance(c, TextChunk) for c in think_chunks) +def _are_text_chunks(chunks: Sequence[ContentChunk]) -> TypeGuard[list[TextChunk]]: + r"""Narrow a chunk list to TextChunk list.""" + return all(isinstance(c, TextChunk) for c in chunks) class ReasoningFieldFormat(str, Enum): @@ -73,6 +80,57 @@ class BaseMessage(MistralBase): role: Literal[Roles.system, Roles.user, Roles.assistant, Roles.tool] + # Allow-list of content chunk types accepted by this message. 
Must be set by each subclass. + _allowed_content_chunks: ClassVar[tuple[type[BaseContentChunk], ...]] + + @model_validator(mode="after") + def _validate_allowed_content_chunks(self) -> "BaseMessage": + r"""Enforce the per-message content chunk allow-list.""" + content = getattr(self, "content", None) + if isinstance(content, list): + for chunk in content: + if not isinstance(chunk, self._allowed_content_chunks): + raise ValueError(f"{type(chunk).__name__} cannot be used in {self.role} message.") + return self + + @staticmethod + def _content_to_openai( + content: str | Sequence[ContentChunk] | None, + ) -> str | list[dict[str, Any]] | None: + r"""Serialize message content to OpenAI format. + + Args: + content: String, list of content chunks, or None. + + Returns: + String content as-is, list of chunks serialized via each chunk's + to_openai(), or None. + """ + if content is None or isinstance(content, str): + return content + return [chunk.to_openai() for chunk in content] + + @staticmethod + def _content_from_openai( + raw: str | list[dict[str, Any]] | None, + ) -> str | list[ContentChunk] | None: + r"""Deserialize content from OpenAI format. + + Args: + raw: Raw content from OpenAI message dict. + + Returns: + String content as-is, list of deserialized content chunks, or None. + + Raises: + ValueError: If content type is unrecognized. + """ + if raw is None or isinstance(raw, str): + return raw + if isinstance(raw, list): + return [_convert_openai_content_chunks(chunk) for chunk in raw] + raise ValueError(f"Unknown content type: {type(raw)}") + def to_openai(self) -> dict[str, Any]: r"""Converts the message to the OpenAI format. 
@@ -100,24 +158,24 @@ class UserMessage(BaseMessage): """ role: Literal[Roles.user] = Roles.user - content: str | list[UserContentChunk] + content: str | list[ContentChunk] + _allowed_content_chunks: ClassVar[tuple[type[BaseContentChunk], ...]] = ( + TextChunk, + ImageChunk, + ImageURLChunk, + AudioChunk, + AudioURLChunk, + ) def to_openai(self) -> dict[str, Any]: r"""Converts the message to the OpenAI format.""" - if isinstance(self.content, str): - return {"role": self.role, "content": self.content} - return {"role": self.role, "content": [chunk.to_openai() for chunk in self.content]} + return {"role": self.role, "content": self._content_to_openai(self.content)} @classmethod def from_openai(cls, openai_message: dict[str, Any]) -> "UserMessage": r"""Converts the OpenAI message to the Mistral format.""" - if isinstance(openai_message["content"], str): - return cls.model_validate_ignore_extra(openai_message) return cls.model_validate( - { - "role": openai_message["role"], - "content": [_convert_openai_content_chunks(chunk) for chunk in openai_message["content"]], - }, + {"role": openai_message["role"], "content": cls._content_from_openai(openai_message["content"])} ) @@ -132,16 +190,19 @@ class SystemMessage(BaseMessage): """ role: Literal[Roles.system] = Roles.system - content: str | list[TextChunk | ThinkChunk] + content: str | list[ContentChunk] + _allowed_content_chunks: ClassVar[tuple[type[BaseContentChunk], ...]] = (TextChunk, AudioChunk, ThinkChunk) def to_openai(self) -> dict[str, Any]: r"""Converts the message to the OpenAI format.""" - return self.model_dump() + return {"role": self.role, "content": self._content_to_openai(self.content)} @classmethod def from_openai(cls, openai_message: dict[str, Any]) -> "SystemMessage": r"""Converts the OpenAI message to the Mistral format.""" - return cls.model_validate_ignore_extra(openai_message) + return cls.model_validate( + {"role": openai_message["role"], "content": 
cls._content_from_openai(openai_message["content"])} + ) class AssistantMessage(BaseMessage): @@ -158,7 +219,8 @@ class AssistantMessage(BaseMessage): """ role: Literal[Roles.assistant] = Roles.assistant - content: str | list[TextChunk | ThinkChunk] | None = None + content: str | list[ContentChunk] | None = None + _allowed_content_chunks: ClassVar[tuple[type[BaseContentChunk], ...]] = (TextChunk, ThinkChunk) tool_calls: list[ToolCall] | None = None prefix: bool = False @@ -205,7 +267,7 @@ def to_openai( match reasoning_field_format: case None | ReasoningFieldFormat.thinking_chunks: - out_dict["content"] = [chunk.to_openai() for chunk in self.content] + out_dict["content"] = self._content_to_openai(self.content) case ReasoningFieldFormat.reasoning | ReasoningFieldFormat.reasoning_content: think_chunks, content_chunks = self.content[: last_think_idx + 1], self.content[last_think_idx + 1 :] if not _are_think_chunks(think_chunks) or not _are_text_chunks(content_chunks): @@ -216,7 +278,7 @@ def to_openai( if len(content_chunks) == 1: out_dict["content"] = content_chunks[0].text elif content_chunks: - out_dict["content"] = [chunk.to_openai() for chunk in content_chunks] + out_dict["content"] = self._content_to_openai(content_chunks) case _: raise ValueError(f"{reasoning_field_format=} is not supported.") @@ -234,14 +296,7 @@ def from_openai(cls, openai_message: dict[str, Any]) -> "AssistantMessage": tools_calls.append(ToolCall.from_openai(openai_tool_call)) else: raise ValueError(f"tool_calls must be a list, got {type(openai_tool_calls)}") - openai_content = openai_message.get("content", None) - content: str | list[ContentChunk] | None = None - if openai_content is None or isinstance(openai_content, str): - content = openai_content - elif isinstance(openai_content, list): - content = [_convert_openai_content_chunks(chunk) for chunk in openai_content] - else: - raise ValueError(f"Unknown content type: {type(openai_content)}") + content = 
cls._content_from_openai(openai_message.get("content")) reasoning_content: str | None = openai_message.get("reasoning_content") reasoning: str | None = openai_message.get("reasoning") @@ -308,23 +363,44 @@ class ToolMessage(BaseMessage): >>> message = ToolMessage(content="Hello, how can I help you?", tool_call_id="123") """ - content: str | list[TextChunk] + content: str | list[ContentChunk] role: Literal[Roles.tool] = Roles.tool tool_call_id: str | None = None # Deprecated in V3 tokenization name: str | None = None + # Tool messages accept all content chunk types. + _allowed_content_chunks: ClassVar[tuple[type[BaseContentChunk], ...]] = ( + TextChunk, + ImageChunk, + ImageURLChunk, + AudioChunk, + AudioURLChunk, + ThinkChunk, + ) + def to_openai(self) -> dict[str, Any]: r"""Converts the message to the OpenAI format.""" assert self.tool_call_id is not None, "tool_call_id must be provided for tool messages." - return self.model_dump(exclude={"name"}) + return { + "role": self.role, + "tool_call_id": self.tool_call_id, + "content": self._content_to_openai(self.content), + } @classmethod - def from_openai(cls, messages: dict[str, str | list[dict[str, str | dict[str, Any]]]]) -> "ToolMessage": + def from_openai(cls, openai_message: dict[str, Any]) -> "ToolMessage": r"""Converts the OpenAI message to the Mistral format.""" - tool_message = cls.model_validate_ignore_extra(messages) - assert tool_message.tool_call_id is not None, "tool_call_id must be provided for tool messages." 
+ content = cls._content_from_openai(openai_message["content"]) + tool_message = cls.model_validate( + { + "role": openai_message["role"], + "tool_call_id": openai_message["tool_call_id"], + "content": content, + "name": openai_message.get("name"), + } + ) return tool_message diff --git a/src/mistral_common/protocol/instruct/normalize.py b/src/mistral_common/protocol/instruct/normalize.py index c96a1940..c423cda0 100644 --- a/src/mistral_common/protocol/instruct/normalize.py +++ b/src/mistral_common/protocol/instruct/normalize.py @@ -1,18 +1,13 @@ import json import warnings -from typing import Generic, Sequence, TypeGuard +from typing import Generic, Sequence from typing_extensions import assert_never from mistral_common.exceptions import InvalidRequestException from mistral_common.protocol.instruct.chunk import ( - AudioChunk, - AudioURLChunk, ContentChunk, - ImageChunk, - ImageURLChunk, TextChunk, - ThinkChunk, ) from mistral_common.protocol.instruct.messages import ( UATS, @@ -98,18 +93,6 @@ def _flush_text() -> None: return all_content -def _is_assistant_content(chunks: list[ContentChunk]) -> TypeGuard[list[TextChunk | ThinkChunk]]: - """Narrow ContentChunk list to assistant-compatible types.""" - return all(isinstance(c, (TextChunk, ThinkChunk)) for c in chunks) - - -def _is_user_content( - chunks: list[ContentChunk], -) -> TypeGuard[list[TextChunk | ImageChunk | ImageURLChunk | AudioChunk | AudioURLChunk]]: - """Narrow ContentChunk list to user-compatible types.""" - return all(not isinstance(c, ThinkChunk) for c in chunks) - - class InstructRequestNormalizer( Generic[UserMessageType, AssistantMessageType, ToolMessageType, SystemMessageType, InstructRequestType] ): @@ -244,13 +227,17 @@ def _aggregate_system_prompts(self, messages: list[UATS]) -> str | None: def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]: """Normalize tool messages without aggregation across messages. 
- Each tool message's content chunks are aggregated individually and JSON is normalized. + Each tool message's content is JSON-normalized; chunk types are guaranteed by the validator. """ tool_messages: list[ToolMessageType] = [] for message in messages: assert isinstance(message, self._tool_message_class), "Expected tool message" - content = self._aggregate_content_chunks_to_str_same_message(message) + content = self._aggregate_content_chunks([message]) + if not isinstance(content, str): + chunk_types = [type(c).__name__ for c in content] + raise InvalidRequestException(f"Unexpected content chunk types in tool message: {chunk_types}") normalized_content = self._normalize_json_content(content) + tool_messages.append( self._tool_message_class( content=normalized_content, tool_call_id=message.tool_call_id, name=message.name @@ -296,15 +283,8 @@ def _aggregate_assistant_messages(self, messages: list[UATS]) -> AssistantMessag ) weight = message.weight - if isinstance(content, str) or _is_assistant_content(content): - narrowed_content: str | list[TextChunk | ThinkChunk] = content - else: - raise InvalidRequestException( - f"Unexpected content chunk types in assistant message: {[type(c).__name__ for c in content]}" - ) - aggregated_message = self._assistant_message_class( - content=narrowed_content, + content=content, tool_calls=tool_calls or None, prefix=prefix, ) @@ -316,12 +296,7 @@ def _aggregate_assistant_messages(self, messages: list[UATS]) -> AssistantMessag def _aggregate_user_messages(self, messages: list[UATS]) -> UserMessageType: """Coalesce neighboring blocks of ContentChunks in user messages.""" content = self._aggregate_content_chunks(messages) - if isinstance(content, str) or _is_user_content(content): - return self._user_message_class(content=content) - else: - raise InvalidRequestException( - f"Unexpected content chunk types in user message: {[type(c).__name__ for c in content]}" - ) + return self._user_message_class(content=content) def _aggregate_role(self, messages: list[UATS], role: Roles | None, 
latest_call_ids: list[str]) -> Sequence[UATS]: if role == Roles.tool: @@ -444,12 +419,31 @@ def normalizer(model_settings_builder: ModelSettingsBuilder | None = None) -> "I UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest[UATS, Tool], None ) + def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]: + """Normalize tool messages without JSON normalization. + + V7+ normalizers skip JSON content normalization for tool messages (chunk-type validation is + handled by the validator). + """ + tool_messages: list[ToolMessageType] = [] + for message in messages: + assert isinstance(message, self._tool_message_class), "Expected tool message" + content = self._aggregate_content_chunks([message]) + if not isinstance(content, str): + chunk_types = [type(c).__name__ for c in content] + raise InvalidRequestException(f"Unexpected content chunk types in tool message: {chunk_types}") + tool_messages.append( + self._tool_message_class(content=content, tool_call_id=message.tool_call_id, name=message.name) + ) + return tool_messages + def _aggregate_system_messages(self, messages: list[UATS]) -> list[SystemMessageType]: - return [ - self._system_message_class(content=self._aggregate_content_chunks([message])) - for message in messages - if isinstance(message, self._system_message_class) - ] + aggregated: list[SystemMessageType] = [] + for message in messages: + if isinstance(message, self._system_message_class): + content = self._aggregate_content_chunks([message]) + aggregated.append(self._system_message_class(content=content)) + return aggregated def _aggregate_role(self, messages: list[UATS], role: Roles | None, latest_call_ids: list[str]) -> Sequence[UATS]: if role == Roles.tool: @@ -530,8 +524,8 @@ def normalizer(model_settings_builder: ModelSettingsBuilder | None = None) -> "I UserMessage, AssistantMessage, ToolMessage, SystemMessage, InstructRequest[UATS, Tool], None ) - def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: 
list[str]) -> list[ToolMessageType]: - tool_messages: list[ToolMessageType] = super()._aggregate_tool_messages(messages, latest_call_ids) + @staticmethod + def _inplace_sort_tool_messages(tool_messages: list[ToolMessageType], latest_call_ids: list[str]) -> None: id_to_tool_call_idx = {call_id: idx for idx, call_id in enumerate(latest_call_ids)} id_to_tool_result_idx = {message.tool_call_id: idx for idx, message in enumerate(tool_messages)} # First order by tool call idx and then by tool result idx @@ -541,6 +535,10 @@ def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[s id_to_tool_result_idx[msg.tool_call_id], ), ) + + def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]: + tool_messages: list[ToolMessageType] = super()._aggregate_tool_messages(messages, latest_call_ids) + self._inplace_sort_tool_messages(tool_messages=tool_messages, latest_call_ids=latest_call_ids) return tool_messages @@ -555,6 +553,18 @@ class InstructRequestNormalizerV15(InstructRequestNormalizerV13): _chunk_join_str: str = "" + def _aggregate_tool_messages(self, messages: list[UATS], latest_call_ids: list[str]) -> list[ToolMessageType]: + r"""V15 keeps all aggregated tool content (validation handled by the validator).""" + tool_messages: list[ToolMessageType] = [] + for message in messages: + assert isinstance(message, self._tool_message_class), "Expected tool message" + content = self._aggregate_content_chunks([message]) + tool_messages.append( + self._tool_message_class(content=content, tool_call_id=message.tool_call_id, name=message.name) + ) + self._inplace_sort_tool_messages(tool_messages=tool_messages, latest_call_ids=latest_call_ids) + return tool_messages + @staticmethod def normalizer(model_settings_builder: ModelSettingsBuilder | None = None) -> "InstructRequestNormalizerV15": r"""Returns a normalizer for the V15 instruct request. 
diff --git a/src/mistral_common/protocol/instruct/validator.py b/src/mistral_common/protocol/instruct/validator.py index f121e49e..e15d0935 100644 --- a/src/mistral_common/protocol/instruct/validator.py +++ b/src/mistral_common/protocol/instruct/validator.py @@ -1,4 +1,5 @@ import re +from collections.abc import Sequence from enum import Enum from typing import Generic @@ -14,8 +15,18 @@ InvalidToolException, InvalidToolMessageException, InvalidToolSchemaException, + InvalidUserMessageException, + MistralCommonException, +) +from mistral_common.protocol.instruct.chunk import ( + AudioChunk, + AudioURLChunk, + ContentChunk, + ImageChunk, + ImageURLChunk, + TextChunk, + ThinkChunk, ) -from mistral_common.protocol.instruct.chunk import AudioChunk, AudioURLChunk from mistral_common.protocol.instruct.messages import ( UATS, AssistantMessage, @@ -38,6 +49,23 @@ from mistral_common.tokens.tokenizers.base import TokenizerVersion +def _validate_content_chunk_types( + content: str | Sequence[ContentChunk] | None, + allowed: tuple[type[ContentChunk], ...], + role: str, + exception_cls: type[MistralCommonException], +) -> None: + r"""Raise if any chunk in a sequence content is not an instance of `allowed`. + + String or None content is always accepted (covered elsewhere). + """ + if content is None or isinstance(content, str): + return + invalid = sorted({type(chunk).__name__ for chunk in content if not isinstance(chunk, allowed)}) + if invalid: + raise exception_cls(f"Unexpected content chunk types in {role} message: {invalid}") + + class ValidationMode(str, Enum): r"""Enum for the validation mode. 
@@ -153,13 +181,30 @@ def _validate_tools(self, tools: list[Tool]) -> None: self._validate_function(tool.function) def _validate_user_message(self, message: UserMessageType) -> None: - pass + self._validate_user_content_chunks(message.content) + + def _validate_user_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None: + r"""v1/v2 user messages accept text content only (image >= v3, audio >= v7).""" + _validate_content_chunk_types(content, (TextChunk,), "user", InvalidUserMessageException) + + def _validate_assistant_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None: + r"""Pre-v11 assistant messages accept text content only.""" + _validate_content_chunk_types(content, (TextChunk,), "assistant", InvalidAssistantMessageException) + + def _validate_system_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None: + r"""v1-v3 system messages accept text content only.""" + _validate_content_chunk_types(content, (TextChunk,), "system", InvalidSystemPromptException) + + def _validate_tool_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None: + r"""Pre-v15 tool messages accept text content only.""" + _validate_content_chunk_types(content, (TextChunk,), "tool", InvalidToolMessageException) def _validate_tool_message(self, message: ToolMessageType) -> None: """ Checks: - The tool name is valid """ + self._validate_tool_content_chunks(message.content) if message.name is not None: if not re.match(r"^[a-zA-Z0-9_-]{1,64}$", message.name): raise InvalidToolMessageException( @@ -174,6 +219,7 @@ def _validate_system_message(self, message: SystemMessageType) -> None: """ if message.content is None: raise InvalidSystemPromptException("System prompt must have content") + self._validate_system_content_chunks(message.content) def _validate_function_call(self, function_call: FunctionCall) -> None: """ @@ -201,6 +247,8 @@ def _validate_assistant_message(self, message: AssistantMessageType, is_last_mes - 
That the tool calls are valid """ + self._validate_assistant_content_chunks(message.content) + # Validate that the message has either text or tool_calls # but not both and not neither. if (not self._allow_tool_call_and_content) and (bool(message.content) == bool(message.tool_calls)): @@ -359,18 +407,19 @@ class MistralRequestValidatorV3(MistralRequestValidator): >>> validator = MistralRequestValidatorV3() """ + def _validate_user_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None: + r"""v3 user messages accept text and image chunks (audio >= v7).""" + _validate_content_chunk_types( + content, (TextChunk, ImageChunk, ImageURLChunk), "user", InvalidUserMessageException + ) + def _validate_tool_message(self, message: ToolMessageType) -> None: """ Checks: - The tool name is valid - Tool call id is valid """ - if message.name is not None: - if not re.match(r"^[a-zA-Z0-9_-]{1,64}$", message.name): - raise InvalidToolMessageException( - f"Function name was {message.name} but must be a-z, A-Z, 0-9, " - "or contain underscores and dashes, with a maximum length of 64." 
- ) + super()._validate_tool_message(message) if message.tool_call_id is None: raise InvalidRequestException("Tool call id has to be defined.") @@ -423,6 +472,21 @@ class MistralRequestValidatorV5(MistralRequestValidatorV3): _allow_tool_call_and_content: bool = True + def _validate_user_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None: + r"""v7+ user messages accept text, image and audio chunks.""" + _validate_content_chunk_types( + content, + (TextChunk, ImageChunk, ImageURLChunk, AudioChunk, AudioURLChunk), + "user", + InvalidUserMessageException, + ) + + def _validate_system_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None: + r"""v7+ system messages accept text, audio and thinking chunks.""" + _validate_content_chunk_types( + content, (TextChunk, AudioChunk, ThinkChunk), "system", InvalidSystemPromptException + ) + def _validate_system_prompt_and_audio(self, messages: list[UATS]) -> None: r"""Validates that system prompts and audio chunks are not used together in v5.""" @@ -507,12 +571,24 @@ def _validate_tool_calls_followed_by_tool_messages(self, messages: list[UATS]) - elif len(expected_tool_ids) < len(observed_tool_ids) and self._mode == ValidationMode.finetuning: raise InvalidMessageStructureException("More tool responses than tool calls") + def _validate_assistant_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None: + r"""v11+ assistant messages accept text and thinking chunks.""" + _validate_content_chunk_types(content, (TextChunk, ThinkChunk), "assistant", InvalidAssistantMessageException) + def _validate_system_prompt_and_audio(self, messages: list[UATS]) -> None: r"""Allows system prompts and audio chunks to coexist in v13.""" return class MistralRequestValidatorV15(MistralRequestValidatorV13): + def _validate_system_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None: + r"""v15 system messages accept text and audio but reject thinking chunks.""" + 
_validate_content_chunk_types(content, (TextChunk, AudioChunk), "system", InvalidSystemPromptException) + + def _validate_tool_content_chunks(self, content: str | Sequence[ContentChunk] | None) -> None: + r"""v15 tool messages accept all content chunk types.""" + return + def _validate_model_settings(self, request: ChatCompletionRequest) -> None: pass diff --git a/src/mistral_common/tokens/tokenizers/base.py b/src/mistral_common/tokens/tokenizers/base.py index 1fe580fa..06dff715 100644 --- a/src/mistral_common/tokens/tokenizers/base.py +++ b/src/mistral_common/tokens/tokenizers/base.py @@ -9,7 +9,7 @@ from mistral_common.base import MistralBase from mistral_common.protocol.fim.request import FIMRequest -from mistral_common.protocol.instruct.chunk import UserContentChunk +from mistral_common.protocol.instruct.chunk import ContentChunk from mistral_common.protocol.instruct.messages import ( AssistantMessageType, UserMessage, @@ -428,7 +428,7 @@ def encode_user_message( @abstractmethod def encode_user_content( self, - content: str | list[UserContentChunk], + content: str | list[ContentChunk], is_last: bool, system_prompt: str | None = None, force_img_first: bool = False, diff --git a/src/mistral_common/tokens/tokenizers/instruct.py b/src/mistral_common/tokens/tokenizers/instruct.py index 7992df1c..7d0e5ab8 100644 --- a/src/mistral_common/tokens/tokenizers/instruct.py +++ b/src/mistral_common/tokens/tokenizers/instruct.py @@ -20,7 +20,6 @@ ImageURLChunk, TextChunk, ThinkChunk, - UserContentChunk, ) from mistral_common.protocol.instruct.messages import ( UATS, @@ -104,7 +103,9 @@ def find_first_last_user(request: InstructRequest) -> tuple[int, int]: return first_user_idx, last_user_idx @abstractmethod - def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: bool) -> list[int]: + def encode_tool_message( + self, message: ToolMessage, is_before_last_user_message: bool + ) -> tuple[list[int], list[np.ndarray], list[Audio]]: r"""Encode a tool 
message. Raises: @@ -192,7 +193,9 @@ def encode_instruct( images.extend(new_images) audios.extend(new_audios) elif isinstance(msg, ToolMessage): - new_tokens = self.encode_tool_message(msg, msg_idx < last_user_idx) + new_tokens, new_images, new_audios = self.encode_tool_message(msg, msg_idx < last_user_idx) + images.extend(new_images) + audios.extend(new_audios) elif isinstance(msg, AssistantMessage): continue_message = request.continue_final_message and (msg_idx == len(request.messages) - 1) @@ -202,7 +205,8 @@ def encode_instruct( if msg_idx == len(request.messages) - 1: prefix_ids = new_tokens elif isinstance(msg, SystemMessage): - new_tokens = self.encode_system_message(msg) + new_tokens, new_audios = self.encode_system_message(msg) + audios.extend(new_audios) else: raise TokenizerException(f"Unknown message type {type(msg)}") @@ -289,12 +293,12 @@ def encode_user_message( curr_tokens, image, audio = self.encode_user_content(content=message_txt, is_last=False, system_prompt=None) return curr_tokens, image, audio - def encode_system_message(self, message: SystemMessage) -> list[int]: + def encode_system_message(self, message: SystemMessage) -> tuple[list[int], list[Audio]]: raise NotImplementedError(f"System message encoding not implemented for {self.__class__.__name__}") def encode_user_content( self, - content: str | list[UserContentChunk], + content: str | list[ContentChunk], is_last: bool, system_prompt: str | None = None, force_img_first: bool = False, @@ -318,7 +322,9 @@ def encode_user_content( tokens = self.tokenizer.encode(content, bos=False, eos=False) return tokens, [], [] - def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: bool) -> list[int]: + def encode_tool_message( + self, message: ToolMessage, is_before_last_user_message: bool + ) -> tuple[list[int], list[np.ndarray], list[Audio]]: r"""Encode a tool message. 
Raises: @@ -480,9 +486,15 @@ def _parse_json_content(self, content: str) -> Any: except json.JSONDecodeError: return content - def _parse_tool_content(self, content: str | list[TextChunk]) -> Any: + def _parse_tool_content(self, content: str | list[ContentChunk]) -> Any: if isinstance(content, list): - content = "".join(chunk.text for chunk in content) + text_parts: list[str] = [] + for chunk in content: + assert isinstance(chunk, TextChunk), ( + f"Tool content only supports text chunks, got {type(chunk).__name__}." + ) + text_parts.append(chunk.text) + content = "".join(text_parts) return self._parse_json_content(content) def _prepare_tool_result(self, tool_message: ToolMessage) -> dict[str, Any]: @@ -492,7 +504,9 @@ def _prepare_tool_result(self, tool_message: ToolMessage) -> dict[str, Any]: "content": self._parse_tool_content(tool_message.content), } - def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: bool) -> list[int]: + def encode_tool_message( + self, message: ToolMessage, is_before_last_user_message: bool + ) -> tuple[list[int], list[np.ndarray], list[Audio]]: r"""Encode a tool message. Args: @@ -501,11 +515,11 @@ def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: not encoded. Returns: - The encoded tokens. + The encoded tokens, images, and audios. 
""" if is_before_last_user_message: # don't tokenize last tool response before last user msg - return [] + return [], [], [] # Currently only supports single tool results tool_result_str = json.dumps([self._prepare_tool_result(message)], ensure_ascii=False) @@ -514,7 +528,7 @@ def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: *self.tokenizer.encode(tool_result_str, bos=False, eos=False), self.END_TOOL_RESULTS, ] - return curr_tokens + return curr_tokens, [], [] def _prepare_function_call(self, tool_call: ToolCall) -> dict[str, Any]: r"""Bit of a hack due to the way function calls are tokenized.""" @@ -652,7 +666,9 @@ def _prepare_tool_result(self, tool_message: ToolMessage) -> dict[str, Any]: "call_id": tool_message.tool_call_id, } - def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: bool) -> list[int]: + def encode_tool_message( + self, message: ToolMessage, is_before_last_user_message: bool + ) -> tuple[list[int], list[np.ndarray], list[Audio]]: r"""Encode a tool message. Note: @@ -665,7 +681,7 @@ def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: not encoded. Returns: - The encoded tokens. + The encoded tokens, images, and audios. 
""" tool_result_str = json.dumps(self._prepare_tool_result(message), ensure_ascii=False) curr_tokens = [ @@ -673,7 +689,7 @@ def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: *self.tokenizer.encode(tool_result_str, bos=False, eos=False), self.END_TOOL_RESULTS, ] - return curr_tokens + return curr_tokens, [], [] def encode_assistant_message( self, message: AssistantMessageType, is_before_last_user_message: bool, continue_message: bool @@ -741,7 +757,7 @@ def _encode_content_chunks( def encode_user_content( self, - content: str | list[UserContentChunk], + content: str | list[ContentChunk], is_last: bool, system_prompt: str | None = None, force_img_first: bool = False, @@ -781,13 +797,18 @@ def encode_user_content( assert not content_str, ( f"It is not possible that `content` is non-empty when chunk is of type {type(chunk)}." ) - chunk_tokens, _, chunk_audio = self._encode_content_chunk(chunk) + chunk_tokens, maybe_image, chunk_audio = self._encode_content_chunk(chunk) + assert maybe_image is None, f"Unexpected image for audio chunk {type(chunk).__name__}." audio.append(chunk_audio) elif isinstance(chunk, (ImageChunk, ImageURLChunk)): - chunk_tokens, chunk_image, _ = self._encode_content_chunk(chunk) + chunk_tokens, chunk_image, maybe_audio = self._encode_content_chunk(chunk) + assert maybe_audio is None, f"Unexpected audio for image chunk {type(chunk).__name__}." images.append(chunk_image) else: - chunk_tokens = self._encode_content_chunk(chunk)[0] + chunk_tokens, maybe_image, maybe_audio = self._encode_content_chunk(chunk) + assert maybe_image is None and maybe_audio is None, ( + f"Unexpected image/audio for chunk {type(chunk).__name__}." 
+ ) tokens.extend(chunk_tokens) return tokens, images, audio @@ -867,26 +888,27 @@ def drop(idx: int) -> None: if to_drop > 0: raise TokenizerException("Input couldn't fit in truncate_at_max_token") - def encode_system_message(self, message: SystemMessage) -> list[int]: + def encode_system_message(self, message: SystemMessage) -> tuple[list[int], list[Audio]]: r"""Encode a system message. Args: message: The message to encode. Returns: - The encoded tokens. + The encoded tokens and audios. """ - tokens = [self.BEGIN_SYSTEM] if isinstance(content := message.content, str): content = [TextChunk(text=content)] - tokens += self._encode_content_chunks(content)[0] + content_tokens, images, audios = self._encode_content_chunks(content) + assert not images, f"System messages cannot contain images, got {len(images)}." + tokens += content_tokens tokens.append(self.END_SYSTEM) - return tokens + return tokens, audios def encode_user_content( self, - content: str | list[UserContentChunk], + content: str | list[ContentChunk], is_last: bool, system_prompt: str | None = None, force_img_first: bool = False, @@ -983,7 +1005,7 @@ def _encode_instruct_transcription(self, request: TranscriptionRequest) -> Token ) assert self.TRANSCRIBE is not None, f"{self.__class__.__name__} needs to have a TRANSCRIBE token" prefix = self.start() - tokens, _, audio = self.encode_user_message( + tokens, images, audio = self.encode_user_message( UserMessage(content=[AudioChunk(input_audio=request.audio)]), available_tools=[], is_last=True, @@ -991,6 +1013,7 @@ def _encode_instruct_transcription(self, request: TranscriptionRequest) -> Token system_prompt=None, settings=ModelSettings.none(), ) + assert not images, f"Transcription input cannot contain images, got {len(images)}." 
tokens = [*prefix, *tokens] if request.language is not None: @@ -1081,7 +1104,9 @@ def _has_audio(messages: list[UATS]) -> bool: for message in messages ) - def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: bool) -> list[int]: + def encode_tool_message( + self, message: ToolMessage, is_before_last_user_message: bool + ) -> tuple[list[int], list[np.ndarray], list[Audio]]: r"""Encode a tool message. Note: @@ -1093,7 +1118,7 @@ def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: is_before_last_user_message: Not used. Returns: - The encoded tokens. + The encoded tokens, images, and audios. """ assert message.tool_call_id is not None assert isinstance(message.content, str), "Message content must be normalized" @@ -1110,7 +1135,7 @@ def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: *tokens, self.END_TOOL_RESULTS, ] - return curr_tokens + return curr_tokens, [], [] def encode_assistant_message( self, message: AssistantMessageType, is_before_last_user_message: bool, continue_message: bool @@ -1138,7 +1163,12 @@ def encode_assistant_message( if isinstance(message.content, str): curr_tokens = self._encode_normal_content_assistant_message(message) elif isinstance(message.content, list): - curr_tokens += self._encode_content_chunks(message.content)[0] + content_tokens, images, audios = self._encode_content_chunks(message.content) + assert not images and not audios, ( + f"Assistant messages cannot contain images or audios, got {len(images)} images " + f"and {len(audios)} audios." 
+ ) + curr_tokens += content_tokens if message.tool_calls: curr_tokens += self._encode_tool_calls_in_assistant_message(message) if not message.prefix and not continue_message: @@ -1282,28 +1312,32 @@ def _encode_tool_calls_in_assistant_message(self, message: AssistantMessageType) ] return curr_tokens - def encode_tool_message(self, message: ToolMessage, is_before_last_user_message: bool) -> list[int]: + def encode_tool_message( + self, message: ToolMessage, is_before_last_user_message: bool + ) -> tuple[list[int], list[np.ndarray], list[Audio]]: r"""Encode a tool message. Args: message: The message to encode. is_before_last_user_message: Not used. Returns: - The encoded tokens. + The encoded tokens, images, and audios. """ assert message.tool_call_id is not None, "Tool call id must be provided for tokenizer >= v13" - content = message.content - if not isinstance(content, str): - content = "".join(chunk.text for chunk in content) + if isinstance(message.content, str): + content_tokens = self.tokenizer.encode(message.content, bos=False, eos=False) + images: list[np.ndarray] = [] + audios: list[Audio] = [] + else: + content_tokens, images, audios = self._encode_content_chunks(message.content) - tokens = self.tokenizer.encode(content, bos=False, eos=False) curr_tokens = [ self.BEGIN_TOOL_RESULTS, - *tokens, + *content_tokens, self.END_TOOL_RESULTS, ] - return curr_tokens + return curr_tokens, images, audios def encode_think(self, chunk: ThinkChunk) -> list[int]: r"""Encode a thinking chunk. 
@@ -1368,7 +1402,7 @@ def _encode_settings( ] return settings_tokens - def encode_system_message(self, message: SystemMessage) -> list[int]: + def encode_system_message(self, message: SystemMessage) -> tuple[list[int], list[Audio]]: r"""Encode a system message, rejecting ThinkChunk content.""" if isinstance(message.content, list): if any(isinstance(chunk, ThinkChunk) for chunk in message.content): diff --git a/tests/data/chat_templates/v15.jinja b/tests/data/chat_templates/v15.jinja index 26b4f9d2..e4df96ef 100644 --- a/tests/data/chat_templates/v15.jinja +++ b/tests/data/chat_templates/v15.jinja @@ -180,9 +180,11 @@ {{- eos_token }} - {#- Tool messages only supports text content. #} + {#- Tool messages (multimodal). #} {%- elif message['role'] == 'tool' %} - {{- '[TOOL_RESULTS]' + message['content']|string + '[/TOOL_RESULTS]' }} + {{- '[TOOL_RESULTS]' -}} + {{- render_content(message['content'], 'tool message contents') -}} + {{- '[/TOOL_RESULTS]' }} {#- System messages. #} {%- elif message['role'] == 'system' %} diff --git a/tests/data/chat_templates/v15_audio.jinja b/tests/data/chat_templates/v15_audio.jinja index 68ce5346..f3c90f1c 100644 --- a/tests/data/chat_templates/v15_audio.jinja +++ b/tests/data/chat_templates/v15_audio.jinja @@ -177,14 +177,16 @@ {{- eos_token }} - {#- Tool messages only supports text content. #} + {#- Tool messages (multimodal). #} {%- elif message['role'] == 'tool' %} - {{- '[TOOL_RESULTS]' + message['content']|string + '[/TOOL_RESULTS]' }} + {{- '[TOOL_RESULTS]' -}} + {{- render_content(message['content'], 'tool message contents', supported_types_desc='text and audio', support_audio=true) -}} + {{- '[/TOOL_RESULTS]' }} {#- System messages. 
#} {%- elif message['role'] == 'system' %} {{- '[SYSTEM_PROMPT]' -}} - {{- render_content(message['content'], 'system message contents', supported_types_desc='text', support_audio=false) -}} + {{- render_content(message['content'], 'system message contents', supported_types_desc='text and audio', support_audio=true) -}} {{- '[/SYSTEM_PROMPT]' -}} {#- Raise exception for unsupported roles. #} diff --git a/tests/data/chat_templates/v15_image.jinja b/tests/data/chat_templates/v15_image.jinja index bb338233..bfbfdd17 100644 --- a/tests/data/chat_templates/v15_image.jinja +++ b/tests/data/chat_templates/v15_image.jinja @@ -188,9 +188,11 @@ {{- eos_token }} - {#- Tool messages only supports text content. #} + {#- Tool messages (multimodal). #} {%- elif message['role'] == 'tool' %} - {{- '[TOOL_RESULTS]' + message['content']|string + '[/TOOL_RESULTS]' }} + {{- '[TOOL_RESULTS]' -}} + {{- render_content(message['content'], 'tool message contents', supported_types_desc='text and image', support_images=true) -}} + {{- '[/TOOL_RESULTS]' }} {#- System messages. #} {%- elif message['role'] == 'system' %} diff --git a/tests/data/chat_templates/v15_image_think.jinja b/tests/data/chat_templates/v15_image_think.jinja index 63c61f6b..e1f26e8f 100644 --- a/tests/data/chat_templates/v15_image_think.jinja +++ b/tests/data/chat_templates/v15_image_think.jinja @@ -215,9 +215,11 @@ {{- eos_token }} - {#- Tool messages only supports text content. #} + {#- Tool messages (multimodal). #} {%- elif message['role'] == 'tool' %} - {{- '[TOOL_RESULTS]' + message['content']|string + '[/TOOL_RESULTS]' }} + {{- '[TOOL_RESULTS]' -}} + {{- render_content(message['content'], 'tool message contents', supported_types_desc='text and image', support_images=true) -}} + {{- '[/TOOL_RESULTS]' }} {#- System messages. 
#} {%- elif message['role'] == 'system' %} diff --git a/tests/data/chat_templates/v15_think.jinja b/tests/data/chat_templates/v15_think.jinja index 7df9072a..40b8a7fc 100644 --- a/tests/data/chat_templates/v15_think.jinja +++ b/tests/data/chat_templates/v15_think.jinja @@ -207,9 +207,11 @@ {{- eos_token }} - {#- Tool messages only supports text content. #} + {#- Tool messages (multimodal). #} {%- elif message['role'] == 'tool' %} - {{- '[TOOL_RESULTS]' + message['content']|string + '[/TOOL_RESULTS]' }} + {{- '[TOOL_RESULTS]' -}} + {{- render_content(message['content'], 'tool message contents') -}} + {{- '[/TOOL_RESULTS]' }} {#- System messages. #} {%- elif message['role'] == 'system' %} diff --git a/tests/fixtures/chunks.py b/tests/fixtures/chunks.py new file mode 100644 index 00000000..289765f5 --- /dev/null +++ b/tests/fixtures/chunks.py @@ -0,0 +1,42 @@ +from typing import Any + +from PIL import Image + +from mistral_common.protocol.instruct.chunk import ( + ContentChunk, + ImageChunk, + ImageURLChunk, + TextChunk, + ThinkChunk, +) +from tests.fixtures.audio import get_dummy_audio_chunk, get_dummy_audio_url_chunk + + +def get_content_chunk(name: str) -> ContentChunk: + r"""Return a single instance of the requested content chunk type. + + Args: + name: One of "text", "image", "image_url", "audio", "audio_url" or "think". + + Returns: + A content chunk of the requested type. + """ + chunks: dict[str, ContentChunk] = { + "text": TextChunk(text="hello"), + "image": ImageChunk(image=Image.new("RGB", (4, 4), "red")), + "image_url": ImageURLChunk(image_url="data:image/png;base64,iVBORw0"), + "audio": get_dummy_audio_chunk(), + "audio_url": get_dummy_audio_url_chunk(), + "think": ThinkChunk(thinking="reasoning"), + } + return chunks[name] + + +def get_content_chunks(names: tuple[str, ...]) -> list[Any]: + r"""Return a list of content chunks for the requested type names. 
+ + The element type is intentionally `Any` so the chunks can be fed to any message role + constructor (including deliberately invalid combinations) in tests, leaving the runtime + Pydantic validation to enforce the actual rules. + """ + return [get_content_chunk(name) for name in names] diff --git a/tests/integrations/chat_templates/fixtures_data.py b/tests/integrations/chat_templates/fixtures_data.py index 6f0e6e55..b396f457 100644 --- a/tests/integrations/chat_templates/fixtures_data.py +++ b/tests/integrations/chat_templates/fixtures_data.py @@ -952,6 +952,75 @@ ] ) +# -- Multimodal content in non-user messages (v15+) -- + +REQUEST_TOOL_IMAGE_TRAIN = ChatCompletionRequest( # type: ignore[type-var] + messages=[ + UserMessage(content="What is in this image?"), + AssistantMessage( + content=None, + tool_calls=[ + ToolCall( + id="tl1mg2345", + function=FunctionCall( + name="tool1", + arguments={"location": "San Francisco, CA"}, # type: ignore[arg-type] + ), + ), + ], + ), + ToolMessage( + content=[ + TextChunk(text="Here is the result."), + ImageURLChunk(image_url=_IMAGE_URL), + ], + tool_call_id="tl1mg2345", + ), + AssistantMessage(content="The tool returned an image of a red square."), + ], + tools=_TOOLS, +) + +REQUEST_TOOL_AUDIO_TRAIN = ChatCompletionRequest( # type: ignore[type-var] + messages=[ + UserMessage(content="What does this sound like?"), + AssistantMessage( + content=None, + tool_calls=[ + ToolCall( + id="tl1mg2345", + function=FunctionCall( + name="tool1", + arguments={"location": "San Francisco, CA"}, # type: ignore[arg-type] + ), + ), + ], + ), + ToolMessage( + content=[ + TextChunk(text="Here is the audio result."), + AudioChunk(input_audio=_AUDIO), + ], + tool_call_id="tl1mg2345", + ), + AssistantMessage(content="The tool returned audio of a sine wave."), + ], + tools=_TOOLS, +) + +REQUEST_SYSTEM_AUDIO_TRAIN = ChatCompletionRequest( # type: ignore[type-var] + messages=[ + SystemMessage( + content=[ + TextChunk(text="You are an audio 
assistant."), + AudioChunk(input_audio=_AUDIO), + ], + ), + UserMessage(content="What was that sound?"), + AssistantMessage(content="That was a sine wave tone."), + ], +) + def _get_conversations( tokenizer_version: TokenizerVersion, @@ -1102,6 +1171,18 @@ def _get_conversations( ] ) + # v15+ only: multimodal content in non-user messages (finetuning only) + if tokenizer_version >= TokenizerVersion.v15 and validation_mode == ValidationMode.finetuning: + if image: + conversations.extend( + [ + REQUEST_TOOL_IMAGE_TRAIN, + ] + ) + if audio: + conversations.append(REQUEST_SYSTEM_AUDIO_TRAIN) + conversations.append(REQUEST_TOOL_AUDIO_TRAIN) + conversations = [c.model_copy(deep=True) for c in conversations] if think and tokenizer_version >= TokenizerVersion.v15: diff --git a/tests/integrations/chat_templates/hf_utils.py b/tests/integrations/chat_templates/hf_utils.py index 9e8133d2..6d071b3d 100644 --- a/tests/integrations/chat_templates/hf_utils.py +++ b/tests/integrations/chat_templates/hf_utils.py @@ -1,7 +1,7 @@ r"""HuggingFace transformers/tokenizers utilities for chat template tests. -This module contains helpers that depend on ``transformers`` and -``tokenizers``. General-purpose helpers live in ``helpers.py``. +This module contains helpers that depend on `transformers` and +`tokenizers`. General-purpose helpers live in `helpers.py`. 
""" import json diff --git a/tests/integrations/chat_templates/test_parity.py b/tests/integrations/chat_templates/test_parity.py index 5394d5cb..56d79c5e 100644 --- a/tests/integrations/chat_templates/test_parity.py +++ b/tests/integrations/chat_templates/test_parity.py @@ -3,8 +3,11 @@ import pytest from mistral_common.integrations.chat_templates.template_generator import build_chat_template +from mistral_common.protocol.instruct.request import ChatCompletionRequest +from mistral_common.protocol.instruct.validator import ValidationMode from mistral_common.tokens.tokenizers.base import TokenizerVersion from tests.integrations.chat_templates.conftest import ALL_CONFIGS, _config_id +from tests.integrations.chat_templates.fixtures_data import _get_conversations from tests.integrations.chat_templates.helpers import ( TestConfig, _load_golden_template, @@ -13,6 +16,42 @@ ) +def _request_to_render_args(request: ChatCompletionRequest) -> dict[str, Any]: + r"""Convert a ChatCompletionRequest to render_template kwargs. + + Enriches tool messages with the `name` field resolved from the preceding + assistant's `tool_calls`, which the v2 Jinja template requires but + `ToolMessage.to_openai()` omits. + """ + openai = request.to_openai() + messages = openai["messages"] + + # Build a mapping from tool_call_id -> function name so tool messages + # can carry the `name` field expected by v2 templates. 
+ tool_call_names: dict[str, str] = {} + for msg in messages: + for tc in msg.get("tool_calls", []): + tc_id = tc.get("id") + fn_name = tc.get("function", {}).get("name") + if tc_id and fn_name: + tool_call_names[tc_id] = fn_name + + for msg in messages: + if msg.get("role") == "tool" and "name" not in msg: + tc_id = msg.get("tool_call_id") + if tc_id and tc_id in tool_call_names: + msg["name"] = tool_call_names[tc_id] + + kwargs: dict[str, Any] = { + "messages": messages, + } + if "tools" in openai and openai["tools"]: + kwargs["tools"] = openai["tools"] + if (reasoning := openai.get("reasoning_effort")) is not None: + kwargs["reasoning_effort"] = reasoning + return kwargs + + @pytest.mark.parametrize( "config", [c for c in ALL_CONFIGS if not c.plain_think], @@ -64,399 +103,14 @@ def test_dynamic_template_comprehensive(config: TestConfig) -> None: static_template = _load_golden_template(template_config) dynamic_template = build_chat_template(template_config) - # Test cases - test_cases = [ - # Simple one-turn - { - "name": "one_turn", - "messages": [ - {"role": "user", "content": "Hello"}, - {"role": "assistant", "content": "Hi there!"}, - ], - }, - # With system message - { - "name": "with_system", - "messages": [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Hello"}, - {"role": "assistant", "content": "Hi there!"}, - ], - }, - # Multi-turn - { - "name": "multi_turn", - "messages": [ - {"role": "user", "content": "Hello"}, - {"role": "assistant", "content": "Hi there!"}, - {"role": "user", "content": "How are you?"}, - {"role": "assistant", "content": "I'm doing well!"}, - ], - }, - # Content as list of chunks - { - "name": "content_chunks", - "messages": [ - {"role": "user", "content": [{"type": "text", "text": "Hello"}]}, - {"role": "assistant", "content": [{"type": "text", "text": "Hi there!"}]}, - ], - }, - ] - - # Tool call scenarios (v2+ only) - if config.version > TokenizerVersion.v1: - test_cases.extend( - 
[ - { - "name": "with_tools_definition", - "messages": [ - {"role": "user", "content": "What's the weather?"}, - {"role": "assistant", "content": "It's sunny."}, - ], - "tools": [ - { - "type": "function", - "function": { - "name": "get_weather", - "description": "Get weather", - "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}, - }, - } - ], - }, - { - "name": "with_tool_call_and_result", - "messages": [ - {"role": "user", "content": "What's the weather in Paris?"}, - { - "role": "assistant", - "content": "", - "tool_calls": [ - { - "id": "abc123def", - "type": "function", - "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'}, - } - ], - }, - {"role": "tool", "content": '{"temp": 20}', "tool_call_id": "abc123def", "name": "get_weather"}, - {"role": "assistant", "content": "It's 20 degrees in Paris."}, - ], - "tools": [ - { - "type": "function", - "function": { - "name": "get_weather", - "description": "Get weather", - "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}, - }, - } - ], - }, - ] - ) - - # Add image test case if image support - if config.image: - test_cases.append( - { - "name": "with_image", - "messages": [ - { - "role": "user", - "content": [ - {"type": "text", "text": "What is this?"}, - {"type": "image_url", "image_url": "http://example.com/image.png"}, - ], - }, - {"role": "assistant", "content": "It's an image."}, - ], - } - ) - - # Add audio test case if audio support - if config.audio: - test_cases.append( - { - "name": "with_audio", - "messages": [ - { - "role": "user", - "content": [ - {"type": "text", "text": "What is this?"}, - {"type": "audio_url", "audio_url": "http://example.com/audio.mp3"}, - ], - }, - {"role": "assistant", "content": "It's an audio file."}, - ], - } - ) - - # Add thinking test case if thinking support - if config.think: - test_cases.append( - { - "name": "with_thinking", - "messages": [ - {"role": "user", "content": "Solve this problem"}, - { - 
"role": "assistant", - "content": [ - {"type": "thinking", "thinking": "Let me think..."}, - {"type": "text", "text": "The answer is 42."}, - ], - }, - ], - } - ) - - # Add message aggregation test cases - test_cases.extend( - [ - { - "name": "consecutive_users", - "messages": [ - {"role": "user", "content": "Hello"}, - {"role": "user", "content": "World"}, - {"role": "assistant", "content": "Hi there"}, - ], - }, - { - "name": "consecutive_users_with_system", - "messages": [ - {"role": "system", "content": "You are helpful."}, - {"role": "user", "content": "Hello"}, - {"role": "user", "content": "World"}, - {"role": "assistant", "content": "Hi there"}, - ], - }, - { - "name": "consecutive_assistants", - "messages": [ - {"role": "user", "content": "Hello"}, - {"role": "assistant", "content": "Hi"}, - {"role": "assistant", "content": "How can I help?"}, - {"role": "user", "content": "Thanks"}, - {"role": "assistant", "content": "Welcome"}, - ], - }, - { - "name": "multiple_systems", - "messages": [ - {"role": "system", "content": "System 1."}, - {"role": "system", "content": "System 2."}, - {"role": "user", "content": "Hello"}, - {"role": "assistant", "content": "Hi"}, - ], - }, - { - "name": "mid_conv_system", - "messages": [ - {"role": "user", "content": "Hello"}, - {"role": "system", "content": "New instruction."}, - {"role": "assistant", "content": "Got it"}, - ], - }, - { - "name": "mid_conv_system_with_consecutive_users", - "messages": [ - {"role": "system", "content": "Be helpful."}, - {"role": "user", "content": "Hello"}, - {"role": "user", "content": "World"}, - {"role": "system", "content": "Now be concise."}, - {"role": "assistant", "content": "Got it"}, - ], - }, - ] - ) - - # Multi-chunk aggregation test cases - test_cases.extend( - [ - { - "name": "consecutive_users_text_chunks", - "messages": [ - {"role": "user", "content": "First as string"}, - {"role": "user", "content": [{"type": "text", "text": "Second as chunk"}]}, - { - "role": "user", - 
"content": [ - {"type": "text", "text": "Third part A"}, - {"type": "text", "text": "Third part B"}, - ], - }, - {"role": "assistant", "content": "Response"}, - ], - }, - { - "name": "system_text_chunks", - "messages": [ - { - "role": "system", - "content": [ - {"type": "text", "text": "You are helpful."}, - {"type": "text", "text": "Be concise."}, - ], - }, - {"role": "user", "content": "Hello"}, - {"role": "assistant", "content": "Hi"}, - ], - }, - ] - ) - - if config.image: - test_cases.extend( - [ - { - "name": "consecutive_users_with_image", - "messages": [ - {"role": "user", "content": "What is this?"}, - { - "role": "user", - "content": [ - {"type": "image_url", "image_url": "http://example.com/image.png"}, - {"type": "text", "text": "Describe it"}, - ], - }, - {"role": "assistant", "content": "It's an image."}, - ], - }, - { - "name": "consecutive_users_multi_image", - "messages": [ - { - "role": "user", - "content": [ - {"type": "text", "text": "Describe this"}, - {"type": "image_url", "image_url": "http://example.com/a.png"}, - {"type": "text", "text": "What color?"}, - ], - }, - { - "role": "user", - "content": [ - {"type": "text", "text": "Also this"}, - {"type": "image_url", "image_url": "http://example.com/b.png"}, - {"type": "text", "text": "What shape?"}, - ], - }, - {"role": "assistant", "content": "Both are red squares."}, - ], - }, - ] - ) - - if config.audio: - test_cases.append( - { - "name": "consecutive_users_multi_audio", - "messages": [ - { - "role": "user", - "content": [ - {"type": "text", "text": "Listen"}, - {"type": "audio_url", "audio_url": "http://example.com/a.wav"}, - {"type": "text", "text": "What language?"}, - ], - }, - { - "role": "user", - "content": [ - {"type": "text", "text": "And this"}, - {"type": "audio_url", "audio_url": "http://example.com/b.wav"}, - {"type": "text", "text": "Transcribe it"}, - ], - }, - {"role": "assistant", "content": "Both are in English."}, - ], - } - ) - - if config.think: - test_cases.extend( - [ 
- { - "name": "consecutive_assistants_think", - "messages": [ - {"role": "user", "content": "Solve this"}, - { - "role": "assistant", - "content": [ - {"type": "text", "text": "Hmm."}, - {"type": "thinking", "thinking": "Let me think..."}, - {"type": "text", "text": "I need more context."}, - ], - }, - { - "role": "assistant", - "content": [ - {"type": "text", "text": "OK."}, - {"type": "thinking", "thinking": "Now I understand."}, - {"type": "text", "text": "The answer is 42."}, - ], - }, - {"role": "user", "content": "Thanks"}, - {"role": "assistant", "content": "You're welcome"}, - ], - }, - { - "name": "consecutive_systems_think", - "messages": [ - { - "role": "system", - "content": [ - {"type": "text", "text": "Rule A"}, - {"type": "text", "text": "Rule B"}, - {"type": "thinking", "thinking": "Think 1"}, - ], - }, - { - "role": "system", - "content": [ - {"type": "thinking", "thinking": "Think 2"}, - {"type": "text", "text": "Rule C"}, - {"type": "text", "text": "Rule D"}, - ], - }, - {"role": "user", "content": "Hello"}, - {"role": "assistant", "content": "Hi"}, - ], - }, - ] - ) - - skip_names_image = {"with_image", "consecutive_users_with_image", "consecutive_users_multi_image"} - skip_names_audio = {"with_audio", "consecutive_users_multi_audio"} - skip_names_think = {"with_thinking", "consecutive_assistants_think", "consecutive_systems_think"} - # ThinkChunks in system messages are only supported in v13 (not v15+) - skip_names_think_system = {"consecutive_systems_think"} - skip_names_tools = {"with_tools_definition", "with_tool_call_and_result"} - - # Not using parametrize here because test_cases are built dynamically based on the - # config (version/image/audio/think) from the outer parametrize. Each sub-case is - # identifiable via test_name in the assertion failure message. 
- for test_case in test_cases: - test_name = test_case["name"] - messages = test_case["messages"] - tools = test_case.get("tools") - - if test_name in skip_names_image and not config.image: - continue - if test_name in skip_names_audio and not config.audio: - continue - if test_name in skip_names_think and not config.think: - continue - if test_name in skip_names_think_system and config.version >= TokenizerVersion.v15: - continue - if test_name in skip_names_tools and config.version <= TokenizerVersion.v1: - continue - - static_output = render_template(static_template, messages, tools=tools) # type: ignore - dynamic_output = render_template(dynamic_template, messages, tools=tools) # type: ignore - - assert static_output == dynamic_output, ( - f"Output mismatch for {config}, case={test_name}\n\n" - f"Static output: {static_output}\n" - f"Dynamic output: {dynamic_output}" - ) + for mode in (ValidationMode.finetuning, ValidationMode.test): + conversations = _get_conversations(config.version, mode, config.image, config.audio, config.think) + for idx, request in enumerate(conversations): + render_args = _request_to_render_args(request) + static_output = render_template(static_template, **render_args) + dynamic_output = render_template(dynamic_template, **render_args) + assert static_output == dynamic_output, ( + f"Output mismatch for {config}, mode={mode}, conversation={idx}\n\n" + f"Static output: {static_output}\n" + f"Dynamic output: {dynamic_output}" + ) diff --git a/tests/integrations/chat_templates/transformers/test_core_parity.py b/tests/integrations/chat_templates/transformers/test_core_parity.py index bf5eb23d..c579e53c 100644 --- a/tests/integrations/chat_templates/transformers/test_core_parity.py +++ b/tests/integrations/chat_templates/transformers/test_core_parity.py @@ -205,10 +205,16 @@ def test_invalid_chunks( # Not using parametrize here because invalid_convs depends on the version/image/audio/think # parameters from the outer parametrize. 
Each sub-case is identifiable via the TemplateError # match string which includes the role and allowed chunks. + is_v15_plus = config.version >= TokenizerVersion.v15 for conv in invalid_convs: msg_template = "Only {chunks} chunks are supported in {role} message content." if conv in sp_invalids: - chunks = "text and thinking" if config.think and config.version < TokenizerVersion.v15 else "text" + if config.think and not is_v15_plus: + chunks = "text and thinking" + elif config.audio and is_v15_plus: + chunks = "text and audio" + else: + chunks = "text" role = "system" elif conv in user_invalids: chunks = "text" @@ -218,7 +224,13 @@ def test_invalid_chunks( chunks += ", input_audio and audio_url" role = "user" elif conv in assistant_invalids: - chunks = "text and thinking" if config.think else "text" + desc_parts = ["text"] + if config.think: + desc_parts.append("thinking") + if len(desc_parts) == 1: + chunks = "text" + else: + chunks = ", ".join(desc_parts[:-1]) + " and " + desc_parts[-1] role = "assistant" err_msg = msg_template.format(chunks=chunks, role=role) diff --git a/tests/integrations/chat_templates/unit/test_v15.py b/tests/integrations/chat_templates/unit/test_v15.py index 1364e5f1..39c4705a 100644 --- a/tests/integrations/chat_templates/unit/test_v15.py +++ b/tests/integrations/chat_templates/unit/test_v15.py @@ -246,3 +246,228 @@ def test_v15_think_template_rejects_think_in_system(self, image: bool) -> None: with pytest.raises(ValueError, match="Only text chunks are supported in system message contents"): render_template(template, messages) + + +class TestV15MultimodalContent: + def test_v15_tool_message_with_image_content(self) -> None: + r"""V15 image template renders tool message with image content using render_content.""" + template = generate_chat_template( + spm=False, + tokenizer_version=TokenizerVersion.v15, + image_support=True, + audio_support=False, + thinking_support=False, + default_system_prompt=None, + plain_thinking_support=False, + 
use_special_token_variables=True, + ) + + messages: list[dict[str, Any]] = [ + {"role": "user", "content": "Use tool"}, + { + "role": "assistant", + "content": "", + "tool_calls": [{"id": "test12345", "function": {"name": "fn", "arguments": "{}"}}], + }, + { + "role": "tool", + "content": [ + {"type": "text", "text": "result"}, + {"type": "image_url", "image_url": "http://example.com/img.png"}, + ], + "tool_call_id": "test12345", + }, + {"role": "assistant", "content": "Done"}, + ] + + tools = [{"type": "function", "function": {"name": "fn", "description": "test", "parameters": {}}}] + output = render_template(template, messages, tools=tools, reasoning_effort="none") + assert "[TOOL_RESULTS]" in output + assert "[IMG]" in output + assert "result" in output + + def test_v15_tool_message_with_audio_content(self) -> None: + r"""V15 audio template renders tool message with audio content.""" + template = generate_chat_template( + spm=False, + tokenizer_version=TokenizerVersion.v15, + image_support=False, + audio_support=True, + thinking_support=False, + default_system_prompt=None, + plain_thinking_support=False, + use_special_token_variables=True, + ) + + messages: list[dict[str, Any]] = [ + {"role": "user", "content": "Use tool"}, + { + "role": "assistant", + "content": "", + "tool_calls": [{"id": "test12345", "function": {"name": "fn", "arguments": "{}"}}], + }, + { + "role": "tool", + "content": [ + {"type": "text", "text": "result"}, + {"type": "input_audio", "input_audio": {"data": "abc", "format": "wav"}}, + ], + "tool_call_id": "test12345", + }, + {"role": "assistant", "content": "Done"}, + ] + + tools = [{"type": "function", "function": {"name": "fn", "description": "test", "parameters": {}}}] + output = render_template(template, messages, tools=tools, reasoning_effort="none") + assert "[TOOL_RESULTS]" in output + assert "[AUDIO]" in output + assert "result" in output + + def test_v15_system_message_with_audio_content(self) -> None: + r"""V15 audio template 
renders system message with audio content.""" + template = generate_chat_template( + spm=False, + tokenizer_version=TokenizerVersion.v15, + image_support=False, + audio_support=True, + thinking_support=False, + default_system_prompt=None, + plain_thinking_support=False, + use_special_token_variables=True, + ) + + messages: list[dict[str, Any]] = [ + { + "role": "system", + "content": [ + {"type": "text", "text": "Listen to context"}, + {"type": "input_audio", "input_audio": {"data": "abc", "format": "wav"}}, + ], + }, + {"role": "user", "content": "Summarize"}, + {"role": "assistant", "content": "Done"}, + ] + + output = render_template(template, messages, reasoning_effort="none") + assert "[SYSTEM_PROMPT]" in output + assert "Listen to context" in output + assert "[AUDIO]" in output + + def test_pre_v15_image_template_rejects_image_in_assistant(self) -> None: + r"""Pre-V15 image template rejects image chunks in assistant message.""" + template = generate_chat_template( + spm=False, + tokenizer_version=TokenizerVersion.v13, + image_support=True, + audio_support=False, + thinking_support=False, + default_system_prompt=None, + plain_thinking_support=False, + use_special_token_variables=True, + ) + + messages: list[dict[str, Any]] = [ + {"role": "user", "content": "Show me"}, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Here"}, + {"type": "image_url", "image_url": "http://example.com/img.png"}, + ], + }, + ] + + with pytest.raises(ValueError, match="Only text chunks are supported in assistant message contents"): + render_template(template, messages) + + def test_pre_v15_audio_template_rejects_audio_in_assistant(self) -> None: + r"""Pre-V15 audio template rejects audio chunks in assistant message.""" + template = generate_chat_template( + spm=False, + tokenizer_version=TokenizerVersion.v13, + image_support=False, + audio_support=True, + thinking_support=False, + default_system_prompt=None, + plain_thinking_support=False, + 
use_special_token_variables=True, + ) + + messages: list[dict[str, Any]] = [ + {"role": "user", "content": "Listen"}, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Here"}, + {"type": "input_audio", "input_audio": {"data": "abc", "format": "wav"}}, + ], + }, + ] + + with pytest.raises(ValueError, match="Only text chunks are supported in assistant message contents"): + render_template(template, messages) + + def test_pre_v15_audio_template_rejects_audio_in_system(self) -> None: + r"""Pre-V15 audio template rejects audio chunks in system message.""" + template = generate_chat_template( + spm=False, + tokenizer_version=TokenizerVersion.v13, + image_support=False, + audio_support=True, + thinking_support=False, + default_system_prompt=None, + plain_thinking_support=False, + use_special_token_variables=True, + ) + + messages: list[dict[str, Any]] = [ + { + "role": "system", + "content": [ + {"type": "text", "text": "Context"}, + {"type": "input_audio", "input_audio": {"data": "abc", "format": "wav"}}, + ], + }, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi"}, + ] + + with pytest.raises(ValueError, match="Only text chunks are supported in system message contents"): + render_template(template, messages) + + def test_pre_v15_image_template_rejects_image_in_tool(self) -> None: + r"""Pre-V15 image template rejects image chunks in tool message content.""" + template = generate_chat_template( + spm=False, + tokenizer_version=TokenizerVersion.v13, + image_support=True, + audio_support=False, + thinking_support=False, + default_system_prompt=None, + plain_thinking_support=False, + use_special_token_variables=True, + ) + + messages: list[dict[str, Any]] = [ + {"role": "user", "content": "Use tool"}, + { + "role": "assistant", + "content": "", + "tool_calls": [{"id": "test12345", "function": {"name": "fn", "arguments": "{}"}}], + }, + { + "role": "tool", + "content": [ + {"type": "text", "text": "result"}, + {"type": 
"image_url", "image_url": "http://example.com/img.png"}, + ], + "tool_call_id": "test12345", + }, + {"role": "assistant", "content": "Done"}, + ] + + tools = [{"type": "function", "function": {"name": "fn", "description": "test", "parameters": {}}}] + # Pre-V15 uses message['content']|string which coerces list to string representation + # rather than rendering through render_content — no [IMG] token produced + output = render_template(template, messages, tools=tools) + assert "[IMG]" not in output diff --git a/tests/test_converters.py b/tests/test_converters.py index d093d945..19b69a22 100644 --- a/tests/test_converters.py +++ b/tests/test_converters.py @@ -37,6 +37,7 @@ AudioChunk, AudioURL, AudioURLChunk, + ContentChunk, ImageChunk, ImageURL, ImageURLChunk, @@ -586,7 +587,7 @@ def test_non_leading_think_chunks_construction_ok() -> None: [TextChunk(text="A"), TextChunk(text="B"), ThinkChunk(thinking="End", closed=True)], ], ) -def test_non_leading_think_chunks_to_openai_raises(content: list[TextChunk | ThinkChunk]) -> None: +def test_non_leading_think_chunks_to_openai_raises(content: list[ContentChunk]) -> None: """to_openai raises when ThinkChunks are not leading.""" msg = AssistantMessage(content=content) with pytest.raises(InvalidAssistantMessageException, match="ThinkChunks must be leading"): diff --git a/tests/test_messages.py b/tests/test_messages.py new file mode 100644 index 00000000..44aba528 --- /dev/null +++ b/tests/test_messages.py @@ -0,0 +1,44 @@ +import pytest +from pydantic import ValidationError + +from mistral_common.protocol.instruct.messages import ( + AssistantMessage, + SystemMessage, + ToolMessage, + UserMessage, +) +from tests.fixtures.chunks import get_content_chunks as _chunks + + +class TestMessageContentChunkUnions: + r"""Pydantic-level (version-independent) content-chunk unions for each message role.""" + + def test_user_allows_text_image_audio(self) -> None: + UserMessage(content=_chunks(("text", "image", "image_url", "audio", 
"audio_url"))) + + def test_user_rejects_think(self) -> None: + for name in ("think",): + with pytest.raises(ValidationError): + UserMessage(content=_chunks((name,))) + + def test_assistant_allows_text_and_think(self) -> None: + AssistantMessage(content=_chunks(("text", "think"))) + + def test_assistant_rejects_image_and_audio(self) -> None: + for name in ("image", "image_url", "audio", "audio_url"): + with pytest.raises(ValidationError): + AssistantMessage(content=_chunks((name,))) + + def test_system_allows_text_audio_think(self) -> None: + SystemMessage(content=_chunks(("text", "audio", "think"))) + + def test_system_rejects_image_and_audio_url(self) -> None: + for name in ("image", "image_url", "audio_url"): + with pytest.raises(ValidationError): + SystemMessage(content=_chunks((name,))) + + def test_tool_allows_all_chunk_types(self) -> None: + ToolMessage( + content=_chunks(("text", "image", "image_url", "audio", "audio_url", "think")), + tool_call_id="c1", + ) diff --git a/tests/test_normalization.py b/tests/test_normalization.py index 1ff09546..70145f16 100644 --- a/tests/test_normalization.py +++ b/tests/test_normalization.py @@ -3,8 +3,8 @@ import pytest from mistral_common.protocol.instruct.chunk import ( + AudioChunk, ChunkTypes, - ContentChunk, ImageURLChunk, TextChunk, ThinkChunk, @@ -62,7 +62,10 @@ def test_user_system_user(self, normalizer: InstructRequestNormalizer) -> None: ) parsed_request = normalizer.from_chat_completion_request(chat_completion_request) - assert parsed_request.system_prompt == "S" + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="U"), UserMessage(content="U")], + system_prompt="S", + ) def test_multiple_system(self, normalizer: InstructRequestNormalizer) -> None: chat_completion_request = mock_chat_completion( @@ -74,7 +77,10 @@ def test_multiple_system(self, normalizer: InstructRequestNormalizer) -> None: ) parsed_request = 
normalizer.from_chat_completion_request(chat_completion_request) - assert parsed_request.system_prompt == "S\n\nS\n\nS" + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="")], + system_prompt="S\n\nS\n\nS", + ) def test_single_system(self, normalizer: InstructRequestNormalizer) -> None: chat_completion_request = mock_chat_completion( @@ -84,8 +90,10 @@ def test_single_system(self, normalizer: InstructRequestNormalizer) -> None: ) parsed_request = normalizer.from_chat_completion_request(chat_completion_request) - - assert parsed_request.system_prompt == "S" + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="")], + system_prompt="S", + ) def test_system_assistant_user(self, normalizer: InstructRequestNormalizer) -> None: chat_completion_request = mock_chat_completion( @@ -97,11 +105,10 @@ def test_system_assistant_user(self, normalizer: InstructRequestNormalizer) -> N ) parsed_request = normalizer.from_chat_completion_request(chat_completion_request) - - first_message = parsed_request.messages[0] - assert isinstance(first_message, UserMessage) - assert first_message.content == "" - assert parsed_request.system_prompt == "S" + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content=""), AssistantMessage(content="A"), UserMessage(content="U")], + system_prompt="S", + ) def test_assistant_content_with_tool_calls(self, normalizer: InstructRequestNormalizer) -> None: chat_completion_request = mock_chat_completion( @@ -125,110 +132,89 @@ def test_assistant_system_user_adds_user(self, normalizer: InstructRequestNormal ) parsed_request = normalizer.from_chat_completion_request(chat_completion_request) + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content=""), AssistantMessage(content="A"), UserMessage(content="U")], + system_prompt="S", + ) - assert parsed_request.system_prompt == "S" - - assert 
len(parsed_request.messages) == 3 # 1 user message added, system message removed - - first_message = parsed_request.messages[0] - assert isinstance(first_message, UserMessage) - assert first_message.content == "" - assert parsed_request.system_prompt == "S" + def test_message_aggregation_system_then_user(self, normalizer: InstructRequestNormalizer) -> None: + parsed = normalizer.from_chat_completion_request( + mock_chat_completion( + [ + SystemMessage(content="s"), + SystemMessage(content="s"), + SystemMessage(content="s"), + UserMessage(content="u"), + ] + ) + ) + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="u")], system_prompt="s\n\ns\n\ns" + ) - def check_merge( - self, - roles: list[str], - expected_roles: list[str], - expected_content: list[list[ContentChunk] | str], - normalizer: InstructRequestNormalizer, - ) -> None: - letter_to_cls: dict[str, ChatMessage] = { - "s": SystemMessage(content="s"), - "u": UserMessage(content="u"), - "a": AssistantMessage(content="a"), - "u2": UserMessage(content="u2"), - } + def test_message_aggregation_system_then_users(self, normalizer: InstructRequestNormalizer) -> None: + parsed = normalizer.from_chat_completion_request( + mock_chat_completion( + [ + SystemMessage(content="s"), + SystemMessage(content="s"), + SystemMessage(content="s"), + UserMessage(content="u"), + UserMessage(content="u"), + ] + ) + ) + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="u\n\nu")], system_prompt="s\n\ns\n\ns" + ) - chat_completion_request = mock_chat_completion( - messages=[letter_to_cls[r] for r in roles], + def test_message_aggregation_mixed_with_middle_system(self, normalizer: InstructRequestNormalizer) -> None: + parsed = normalizer.from_chat_completion_request( + mock_chat_completion( + [ + SystemMessage(content="s"), + SystemMessage(content="s"), + SystemMessage(content="s"), + UserMessage(content="u"), + UserMessage(content="u"), + 
SystemMessage(content="s"), + AssistantMessage(content="a"), + UserMessage(content="u"), + ] + ) + ) + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="u\n\nu"), AssistantMessage(content="a"), UserMessage(content="u")], + system_prompt="s\n\ns\n\ns\n\ns", ) - parsed_request = normalizer.from_chat_completion_request(chat_completion_request) - assert len(parsed_request.messages) == len(expected_roles) - assert [message.role for message in parsed_request.messages] == [ - letter_to_cls[role].role for role in expected_roles - ] - assert len(expected_content) == len(parsed_request.messages) - for x, expected in zip(parsed_request.messages, expected_content): - assert isinstance(x, (UserMessage, AssistantMessage)) - assert x.content == expected - def check_merge_chunks( - self, - roles: list[str], - expected_roles: list[str], - expected_content: list[list[ContentChunk] | str], - normalizer: InstructRequestNormalizer, - ) -> None: - letter_to_cls: dict[str, ChatMessage] = { - "s": SystemMessage(content="s"), - "u": UserMessage(content="u"), - "a": AssistantMessage(content="a"), - "a2": AssistantMessage( - content=[ - ThinkChunk(thinking="t1"), - ThinkChunk(thinking="t2"), - TextChunk(text="a1"), - TextChunk(text="a2"), - TextChunk(text="a3"), + def test_message_aggregation_consecutive_assistants(self, normalizer: InstructRequestNormalizer) -> None: + parsed = normalizer.from_chat_completion_request( + mock_chat_completion( + [ + SystemMessage(content="s"), + SystemMessage(content="s"), + SystemMessage(content="s"), + UserMessage(content="u"), + UserMessage(content="u"), + AssistantMessage(content="a"), + AssistantMessage(content="a"), + UserMessage(content="u"), ] - ), - } + ) + ) + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="u\n\nu"), AssistantMessage(content="a\n\na"), UserMessage(content="u")], + system_prompt="s\n\ns\n\ns", + ) - chat_completion_request = mock_chat_completion( - 
messages=[letter_to_cls[r] for r in roles], + def test_message_aggregation_system_assistant_user(self, normalizer: InstructRequestNormalizer) -> None: + parsed = normalizer.from_chat_completion_request( + mock_chat_completion([SystemMessage(content="s"), AssistantMessage(content="a"), UserMessage(content="u")]) ) - parsed_request = normalizer.from_chat_completion_request(chat_completion_request) - assert len(parsed_request.messages) == len(expected_roles) - assert [message.role for message in parsed_request.messages] == [ - letter_to_cls[role].role for role in expected_roles - ] - assert len(expected_content) == len(parsed_request.messages) - for x, expected in zip(parsed_request.messages, expected_content): - assert isinstance(x, (UserMessage, AssistantMessage)) - assert x.content == expected - - def test_message_aggregation(self, normalizer: InstructRequestNormalizer) -> None: - self.check_merge(["s", "s", "s", "u"], ["u"], ["u"], normalizer) - self.check_merge(["s", "s", "s", "u", "u"], ["u"], ["u\n\nu"], normalizer) - self.check_merge(["s", "s", "s", "u", "u", "s", "a", "u"], ["u", "a", "u"], ["u\n\nu", "a", "u"], normalizer) - - self.check_merge( - ["s", "s", "s", "u", "u", "a", "a", "u"], - ["u", "a", "u"], - ["u\n\nu", "a\n\na", "u"], - normalizer, - ) - - self.check_merge( - ["s", "a", "u"], - ["u", "a", "u"], - ["", "a", "u"], - normalizer, - ) - - self.check_merge_chunks( - ["u", "a2", "u"], - ["u", "a", "u"], - [ - "u", - [ - ThinkChunk(thinking="t1"), - ThinkChunk(thinking="t2"), - TextChunk(text="a1\n\na2\n\na3"), - ], - "u", - ], - normalizer, + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content=""), AssistantMessage(content="a"), UserMessage(content="u")], + system_prompt="s", ) def test_tool_chunk_aggregation(self, normalizer: InstructRequestNormalizer) -> None: @@ -266,10 +252,9 @@ def test_normalize_chunks(self, normalizer: InstructRequestNormalizer) -> None: ) parsed_request = 
normalizer.from_chat_completion_request(chat_completion_request) - - first_message = parsed_request.messages[0] - assert isinstance(first_message, UserMessage) - assert first_message.content == "foo\n\nchunk\n\nfoo\n\nchunk" + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="foo\n\nchunk\n\nfoo\n\nchunk")] + ) def test_many_chunks_in_user_message(self, normalizer: InstructRequestNormalizer) -> None: chat_completion_request = mock_chat_completion( @@ -285,9 +270,9 @@ def test_many_chunks_in_user_message(self, normalizer: InstructRequestNormalizer ], ) parsed_request = normalizer.from_chat_completion_request(chat_completion_request) - first_message = parsed_request.messages[0] - assert isinstance(first_message, UserMessage) - assert first_message.content == "foo\n\nchunk1\n\nchunk2\n\nchunk3" + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="foo\n\nchunk1\n\nchunk2\n\nchunk3")] + ) def test_ignore_middle_empty_text_chunks(self, normalizer: InstructRequestNormalizer) -> None: chat_completion_request = mock_chat_completion( @@ -309,12 +294,9 @@ def test_ignore_middle_empty_text_chunks(self, normalizer: InstructRequestNormal ] ) parsed_request = normalizer.from_chat_completion_request(chat_completion_request) - first_message = parsed_request.messages[0] - assert isinstance(first_message, UserMessage) - assert first_message.content == "U\n\nV" - second_message = parsed_request.messages[1] - assert isinstance(second_message, AssistantMessage) - assert second_message.content == "A\n\nB" + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="U\n\nV"), AssistantMessage(content="A\n\nB")], + ) def test_safety_prompt_aggregated(self, normalizer: InstructRequestNormalizer) -> None: chat_completion_request = ChatCompletionRequest[ChatMessage]( @@ -328,10 +310,10 @@ def test_safety_prompt_aggregated(self, normalizer: InstructRequestNormalizer) - ) 
parsed_request = normalizer.from_chat_completion_request(chat_completion_request) - first_message = parsed_request.messages[0] - assert isinstance(first_message, UserMessage) - assert first_message.content == "user" - assert parsed_request.system_prompt == "system" + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="user")], + system_prompt="system", + ) def test_normalize_tools(self, normalizer: InstructRequestNormalizer) -> None: """ @@ -410,18 +392,41 @@ def test_assert_parsed_settings( normalizer: InstructRequestNormalizer, ) -> None: chat_completion_request = ChatCompletionRequest(messages=[UserMessage(content="B")]) - parsed_request: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request( - chat_completion_request - ) - assert parsed_request.settings == ModelSettings.none() + parsed_request = normalizer.from_chat_completion_request(chat_completion_request) + assert parsed_request == InstructRequest[ChatMessage, Tool](messages=[UserMessage(content="B")]) def test_continue_final_message_forwarded(self, normalizer: InstructRequestNormalizer) -> None: request = ChatCompletionRequest[ChatMessage]( messages=[UserMessage(content="a"), AssistantMessage(content="b")], continue_final_message=True, ) - result: InstructRequest[ChatMessage, Tool] = normalizer.from_chat_completion_request(request) - assert result.continue_final_message is True + result = normalizer.from_chat_completion_request(request) + assert result == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="a"), AssistantMessage(content="b")], + continue_final_message=True, + ) + + def test_json_normalizes_tool_content(self, normalizer: InstructRequestNormalizer) -> None: + r"""Base normalizer (v1-v3) JSON-normalizes tool message content.""" + messy_json = '{"key" : "value" , "num": 1}' + request = mock_chat_completion( + messages=[ + UserMessage(content="query"), + 
AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")]), + ToolMessage(content=messy_json, tool_call_id="c1"), + ], + ) + parsed = normalizer.from_chat_completion_request(request) + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[ + UserMessage(content="query"), + AssistantMessage( + content="", + tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")], + ), + ToolMessage(content='{"key": "value", "num": 1}', tool_call_id="c1"), + ], + ) class TestChatCompletionRequestNormalizationV7: @@ -440,19 +445,14 @@ def test_system_assistant_user_v7(self, normalizer_v7: InstructRequestNormalizer ] ) - parsed_request: InstructRequest = normalizer_v7.from_chat_completion_request(chat_completion_request) - - first_message = parsed_request.messages[0] - assert isinstance(first_message, SystemMessage) - assert first_message.content == "S" - - second_message = parsed_request.messages[1] - assert isinstance(second_message, AssistantMessage) - assert second_message.content == "A" - - assert parsed_request.system_prompt is None + parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request( + chat_completion_request + ) + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[SystemMessage(content="S"), AssistantMessage(content="A"), UserMessage(content="U")], + ) - def test_assistant_assistant_system_v7(self, normalizer_v7: InstructRequestNormalizer) -> None: + def test_assistant_assistant_system_v7(self, normalizer_v7: InstructRequestNormalizerV7) -> None: chat_completion_request = mock_chat_completion( messages=[ AssistantMessage(content="A"), @@ -460,21 +460,14 @@ def test_assistant_assistant_system_v7(self, normalizer_v7: InstructRequestNorma ] ) - parsed_request = normalizer_v7.from_chat_completion_request(chat_completion_request) - - assert parsed_request.system_prompt is None - - assert len(parsed_request.messages) == 2 - - first_message = 
parsed_request.messages[0] - assert isinstance(first_message, AssistantMessage) - assert first_message.content == "A" - - second_message = parsed_request.messages[1] - assert isinstance(second_message, SystemMessage) - assert second_message.content == "S" + parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request( + chat_completion_request + ) + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[AssistantMessage(content="A"), SystemMessage(content="S")], + ) - def test_assistant_content_with_tool_calls(self, normalizer_v7: InstructRequestNormalizer) -> None: + def test_assistant_content_with_tool_calls(self, normalizer_v7: InstructRequestNormalizerV7) -> None: chat_completion_request = mock_chat_completion( messages=[ AssistantMessage( @@ -483,28 +476,28 @@ def test_assistant_content_with_tool_calls(self, normalizer_v7: InstructRequestN ) ] ) - normalized_chat_req = normalizer_v7.from_chat_completion_request(chat_completion_request) - - assert normalized_chat_req.messages[0].content == "A", normalized_chat_req.messages[0].content - assert len(normalized_chat_req.messages[0].tool_calls) == 1, normalized_chat_req.messages[0].tool_calls - assert normalized_chat_req.messages[0].tool_calls[0].function.name == "tool1", ( - normalized_chat_req.messages[0].tool_calls[0].function.name + normalized_chat_req: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request( + chat_completion_request + ) + assert normalized_chat_req == InstructRequest[ChatMessage, Tool]( + messages=[ + AssistantMessage( + content="A", + tool_calls=[ToolCall(function=FunctionCall(name="tool1", arguments='{"input": "11"}'))], + ), + ], ) - def test_assistant_content_with_more_tool_calls(self, normalizer_v7: InstructRequestNormalizer) -> None: + def test_assistant_content_with_more_tool_calls(self, normalizer_v7: InstructRequestNormalizerV7) -> None: chat_completion_request = mock_chat_completion( messages=[ 
UserMessage(content="A1"), - AssistantMessage( - content="B1", - ), + AssistantMessage(content="B1"), AssistantMessage( content="B2", tool_calls=[ToolCall(function=FunctionCall(name="tool1", arguments='{"input": "1"}'))], ), - AssistantMessage( - content="B3", - ), + AssistantMessage(content="B3"), AssistantMessage( content="B4", tool_calls=[ @@ -512,24 +505,27 @@ def test_assistant_content_with_more_tool_calls(self, normalizer_v7: InstructReq ToolCall(function=FunctionCall(name="tool22", arguments='{"input": "22"}')), ], ), + AssistantMessage(content="B5"), + UserMessage(content="C1"), + ] + ) + normalized_chat_req: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request( + chat_completion_request + ) + assert normalized_chat_req == InstructRequest[ChatMessage, Tool]( + messages=[ + UserMessage(content="A1"), AssistantMessage( - content="B5", + content="B1\n\nB2\n\nB3\n\nB4\n\nB5", + tool_calls=[ + ToolCall(function=FunctionCall(name="tool1", arguments='{"input": "1"}')), + ToolCall(function=FunctionCall(name="tool21", arguments='{"input": "21"}')), + ToolCall(function=FunctionCall(name="tool22", arguments='{"input": "22"}')), + ], ), UserMessage(content="C1"), - ] + ], ) - normalized_chat_req = normalizer_v7.from_chat_completion_request(chat_completion_request) - - assert normalized_chat_req.messages[0].content == "A1" - assert normalized_chat_req.messages[1].content.split("\n\n") == [f"B{i}" for i in range(1, 6)] - - tool_calls = normalized_chat_req.messages[1].tool_calls - - assert len(tool_calls) == 3 - - tool_key = ["1", "21", "22"] - assert all([t.function.name == f"tool{tool_key[i]}" for i, t in enumerate(tool_calls)]) - assert all([json.loads(t.function.arguments)["input"] == tool_key[i] for i, t in enumerate(tool_calls)]) def test_assert_parsed_settings( self, @@ -539,7 +535,7 @@ def test_assert_parsed_settings( parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request( chat_completion_request 
) - assert parsed_request.settings == ModelSettings.none() + assert parsed_request == InstructRequest[ChatMessage, Tool](messages=[UserMessage(content="B")]) def test_continue_final_message_forwarded(self, normalizer_v7: InstructRequestNormalizerV7) -> None: request = ChatCompletionRequest[ChatMessage]( @@ -547,7 +543,10 @@ def test_continue_final_message_forwarded(self, normalizer_v7: InstructRequestNo continue_final_message=True, ) result: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(request) - assert result.continue_final_message is True + assert result == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="a"), AssistantMessage(content="b")], + continue_final_message=True, + ) @pytest.mark.parametrize("num_empty", [0, 1, 2]) def test_only_empty_text_chunks(self, normalizer_v7: InstructRequestNormalizerV7, num_empty: int) -> None: @@ -561,13 +560,9 @@ def test_only_empty_text_chunks(self, normalizer_v7: InstructRequestNormalizerV7 parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request( chat_completion_request ) - first_message = parsed_request.messages[0] - assert isinstance(first_message, UserMessage) - assert first_message.content == "" - second_message = parsed_request.messages[1] - assert isinstance(second_message, AssistantMessage) - # Empty string content is passed through directly - assert second_message.content == "" + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content=""), AssistantMessage(content="")], + ) def test_complex_user_aggregation(self, normalizer_v7: InstructRequestNormalizerV7) -> None: """Complex multi-user-message aggregation with mixed str, chunks, empty, and non-text chunks.""" @@ -594,13 +589,17 @@ def test_complex_user_aggregation(self, normalizer_v7: InstructRequestNormalizer parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request( chat_completion_request ) - 
first_message = parsed_request.messages[0] - assert isinstance(first_message, UserMessage) - assert first_message.content == [ - TextChunk(text="A\n\nB\n\nC\n\nD"), - ImageURLChunk(image_url="E"), - TextChunk(text="G\n\nH"), - ] + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[ + UserMessage( + content=[ + TextChunk(text="A\n\nB\n\nC\n\nD"), + ImageURLChunk(image_url="E"), + TextChunk(text="G\n\nH"), + ] + ), + ], + ) def test_complex_assistant_aggregation(self, normalizer_v7: InstructRequestNormalizerV7) -> None: """Complex multi-assistant-message aggregation with mixed str, chunks, and empty content.""" @@ -613,7 +612,6 @@ def test_complex_assistant_aggregation(self, normalizer_v7: InstructRequestNorma AssistantMessage(content=[TextChunk(text="")]), AssistantMessage( content=[ - ThinkChunk(thinking="T"), TextChunk(text="C"), TextChunk(text="D"), ] @@ -623,13 +621,62 @@ def test_complex_assistant_aggregation(self, normalizer_v7: InstructRequestNorma parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request( chat_completion_request ) - first_message = parsed_request.messages[0] - assert isinstance(first_message, AssistantMessage) - assert first_message.content == [ - TextChunk(text="A\n\nB"), - ThinkChunk(thinking="T"), - TextChunk(text="C\n\nD"), - ] + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[AssistantMessage(content="A\n\nB\n\nC\n\nD")], + ) + + def test_accepts_string_content(self, normalizer_v7: InstructRequestNormalizerV7) -> None: + r"""V7 normalizer accepts string content in assistant messages.""" + request = mock_chat_completion( + messages=[ + UserMessage(content="query"), + AssistantMessage(content="plain text"), + ], + ) + parsed: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(request) + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="query"), AssistantMessage(content="plain text")], + ) + 
+ def test_skips_json_normalization_on_tool_content(self, normalizer_v7: InstructRequestNormalizerV7) -> None: + r"""V7+ normalizers do not JSON-normalize tool message content.""" + messy_json = '{"key" : "value" , "num": 1}' + request = mock_chat_completion( + messages=[ + UserMessage(content="query"), + AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")]), + ToolMessage(content=messy_json, tool_call_id="c1"), + ], + ) + parsed: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(request) + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[ + UserMessage(content="query"), + AssistantMessage( + content="", + tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")], + ), + ToolMessage(content=messy_json, tool_call_id="c1"), + ], + ) + + def test_preserves_audio_in_system_message(self, normalizer_v7: InstructRequestNormalizerV7) -> None: + r"""V7 normalizer preserves AudioChunk in system messages.""" + request = mock_chat_completion( + messages=[ + SystemMessage(content=[TextChunk(text="hello"), AudioChunk(input_audio=b"fake_audio_data")]), + UserMessage(content="test"), + ] + ) + parsed: InstructRequest[ChatMessage, Tool] = normalizer_v7.from_chat_completion_request(request) + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[ + SystemMessage( + content=[TextChunk(text="hello"), AudioChunk(input_audio=b"fake_audio_data")], + ), + UserMessage(content="test"), + ], + ) class TestFineTuningNormalizer: @@ -720,18 +767,20 @@ def test_no_reorder_tool_messages(self, normalizer_v13: InstructRequestNormalize parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request( chat_completion_request ) - assert parsed_request.messages == [ - UserMessage(content="A"), - AssistantMessage( - content="B", - tool_calls=[ - ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")), - ToolCall(id="2", 
function=FunctionCall(name="foo", arguments="{}")), - ], - ), - ToolMessage(content="C", tool_call_id="1"), - ToolMessage(content="D", tool_call_id="2"), - ] + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[ + UserMessage(content="A"), + AssistantMessage( + content="B", + tool_calls=[ + ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")), + ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")), + ], + ), + ToolMessage(content="C", tool_call_id="1"), + ToolMessage(content="D", tool_call_id="2"), + ], + ) def test_reorder_last_tool_messages(self, normalizer_v13: InstructRequestNormalizerV13) -> None: chat_completion_request: ChatCompletionRequest = self._mock_chat_completion( @@ -751,18 +800,20 @@ def test_reorder_last_tool_messages(self, normalizer_v13: InstructRequestNormali parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request( chat_completion_request ) - assert parsed_request.messages == [ - UserMessage(content="A"), - AssistantMessage( - content="B", - tool_calls=[ - ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")), - ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")), - ], - ), - ToolMessage(content="C", tool_call_id="1"), - ToolMessage(content="D", tool_call_id="2"), - ] + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[ + UserMessage(content="A"), + AssistantMessage( + content="B", + tool_calls=[ + ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")), + ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")), + ], + ), + ToolMessage(content="C", tool_call_id="1"), + ToolMessage(content="D", tool_call_id="2"), + ], + ) def test_reorder_internal_tool_messages(self, normalizer_v13: InstructRequestNormalizerV13) -> None: chat_completion_request: ChatCompletionRequest = self._mock_chat_completion( @@ -783,19 +834,21 @@ def test_reorder_internal_tool_messages(self, normalizer_v13: 
InstructRequestNor parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request( chat_completion_request ) - assert parsed_request.messages == [ - UserMessage(content="A"), - AssistantMessage( - content="B", - tool_calls=[ - ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")), - ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")), - ], - ), - ToolMessage(content="C", tool_call_id="1"), - ToolMessage(content="D", tool_call_id="2"), - AssistantMessage(content="E"), - ] + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[ + UserMessage(content="A"), + AssistantMessage( + content="B", + tool_calls=[ + ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")), + ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")), + ], + ), + ToolMessage(content="C", tool_call_id="1"), + ToolMessage(content="D", tool_call_id="2"), + AssistantMessage(content="E"), + ], + ) def test_reorder_extra_tool_messages(self, normalizer_v13: InstructRequestNormalizerV13) -> None: chat_completion_request: ChatCompletionRequest = self._mock_chat_completion( @@ -814,17 +867,19 @@ def test_reorder_extra_tool_messages(self, normalizer_v13: InstructRequestNormal parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request( chat_completion_request ) - assert parsed_request.messages == [ - UserMessage(content="A"), - AssistantMessage( - content="B", - tool_calls=[ - ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")), - ], - ), - ToolMessage(content="C", tool_call_id="1"), - ToolMessage(content="D", tool_call_id="2"), - ] + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[ + UserMessage(content="A"), + AssistantMessage( + content="B", + tool_calls=[ + ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")), + ], + ), + ToolMessage(content="C", tool_call_id="1"), + ToolMessage(content="D", 
tool_call_id="2"), + ], + ) def test_reorder_only_from_latest_assistant_message(self, normalizer_v13: InstructRequestNormalizerV13) -> None: chat_completion_request: ChatCompletionRequest = self._mock_chat_completion( @@ -853,27 +908,29 @@ def test_reorder_only_from_latest_assistant_message(self, normalizer_v13: Instru parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request( chat_completion_request ) - assert parsed_request.messages == [ - UserMessage(content="A"), - AssistantMessage( - content="B", - tool_calls=[ - ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")), - ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")), - ], - ), - ToolMessage(content="C", tool_call_id="1"), - ToolMessage(content="D", tool_call_id="2"), - AssistantMessage( - content="E", - tool_calls=[ - ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")), - ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")), - ], - ), - ToolMessage(content="D", tool_call_id="2"), - ToolMessage(content="C", tool_call_id="1"), - ] + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[ + UserMessage(content="A"), + AssistantMessage( + content="B", + tool_calls=[ + ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")), + ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")), + ], + ), + ToolMessage(content="C", tool_call_id="1"), + ToolMessage(content="D", tool_call_id="2"), + AssistantMessage( + content="E", + tool_calls=[ + ToolCall(id="2", function=FunctionCall(name="foo", arguments="{}")), + ToolCall(id="1", function=FunctionCall(name="foo", arguments="{}")), + ], + ), + ToolMessage(content="D", tool_call_id="2"), + ToolMessage(content="C", tool_call_id="1"), + ], + ) @pytest.mark.parametrize( ["system_message", "expected_system_message"], @@ -919,7 +976,9 @@ def test_aggregate_system_prompt_content( parsed_request: InstructRequest[ChatMessage, Tool] = 
normalizer_v13.from_chat_completion_request( chat_completion_request ) - assert parsed_request.messages == [expected_system_message, UserMessage(content="B")] + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[expected_system_message, UserMessage(content="B")] + ) def test_system_messages_no_aggregation(self, normalizer_v13: InstructRequestNormalizerV13) -> None: """Consecutive system messages are NOT aggregated into one in V7+.""" @@ -934,12 +993,14 @@ def test_system_messages_no_aggregation(self, normalizer_v13: InstructRequestNor parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request( chat_completion_request ) - assert parsed_request.messages == [ - SystemMessage(content="A"), - SystemMessage(content="B"), - UserMessage(content="C"), - SystemMessage(content="D"), - ] + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[ + SystemMessage(content="A"), + SystemMessage(content="B"), + UserMessage(content="C"), + SystemMessage(content="D"), + ], + ) def test_system_messages_normalization(self, normalizer_v13: InstructRequestNormalizerV13) -> None: """System message chunks within the same message are aggregated with no separator.""" @@ -952,10 +1013,12 @@ def test_system_messages_normalization(self, normalizer_v13: InstructRequestNorm parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request( chat_completion_request ) - assert parsed_request.messages == [ - SystemMessage(content="A\n\nB"), - SystemMessage(content="C"), - ] + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[ + SystemMessage(content="A\n\nB"), + SystemMessage(content="C"), + ], + ) def test_assert_parsed_settings( self, @@ -965,7 +1028,7 @@ def test_assert_parsed_settings( parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request( chat_completion_request ) - assert parsed_request.settings == ModelSettings.none() + 
assert parsed_request == InstructRequest[ChatMessage, Tool](messages=[UserMessage(content="B")]) def test_continue_final_message_forwarded(self, normalizer_v13: InstructRequestNormalizerV13) -> None: request = ChatCompletionRequest[ChatMessage]( @@ -973,7 +1036,128 @@ def test_continue_final_message_forwarded(self, normalizer_v13: InstructRequestN continue_final_message=True, ) result: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(request) - assert result.continue_final_message is True + assert result == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="a"), AssistantMessage(content="b")], + continue_final_message=True, + ) + + def test_accepts_text_and_think_chunks(self, normalizer_v13: InstructRequestNormalizerV13) -> None: + r"""V13 normalizer accepts TextChunk and ThinkChunk in assistant messages.""" + request = mock_chat_completion( + messages=[ + UserMessage(content="query"), + AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]), + ], + ) + parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(request) + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[ + UserMessage(content="query"), + AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]), + ], + ) + + def test_accepts_string_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None: + r"""V13 normalizer accepts string content in assistant messages.""" + request = mock_chat_completion( + messages=[ + UserMessage(content="query"), + AssistantMessage(content="plain text"), + ], + ) + parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(request) + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="query"), AssistantMessage(content="plain text")], + ) + + def test_assistant_think_chunk_inter_message_aggregation( + self, normalizer_v13: 
InstructRequestNormalizerV13 + ) -> None: + r"""V13 normalizer preserves ThinkChunks across multiple assistant messages.""" + chat_completion_request = mock_chat_completion( + messages=[ + AssistantMessage(content="A"), + AssistantMessage(content=[TextChunk(text="B")]), + AssistantMessage( + content=[ + ThinkChunk(thinking="T"), + TextChunk(text="C"), + TextChunk(text="D"), + ] + ), + ] + ) + parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request( + chat_completion_request + ) + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[ + AssistantMessage( + content=[ + TextChunk(text="A\n\nB"), + ThinkChunk(thinking="T"), + TextChunk(text="C\n\nD"), + ] + ), + ], + ) + + def test_assistant_think_chunk_intra_message_aggregation( + self, normalizer_v13: InstructRequestNormalizerV13 + ) -> None: + r"""V13 normalizer coalesces TextChunks and preserves multiple ThinkChunks within a single message.""" + chat_completion_request = mock_chat_completion( + messages=[ + UserMessage(content="u"), + AssistantMessage( + content=[ + ThinkChunk(thinking="t1"), + ThinkChunk(thinking="t2"), + TextChunk(text="a1"), + TextChunk(text="a2"), + TextChunk(text="a3"), + ] + ), + UserMessage(content="u"), + ] + ) + parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request( + chat_completion_request + ) + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[ + UserMessage(content="u"), + AssistantMessage( + content=[ + ThinkChunk(thinking="t1"), + ThinkChunk(thinking="t2"), + TextChunk(text="a1\n\na2\n\na3"), + ] + ), + UserMessage(content="u"), + ], + ) + + def test_aggregates_text_tool_content(self, normalizer_v13: InstructRequestNormalizerV13) -> None: + r"""V13 normalizer aggregates TextChunks in tool messages to a string.""" + request = mock_chat_completion( + messages=[ + UserMessage(content="query"), + AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")]), 
+ ToolMessage(content=[TextChunk(text="hello"), TextChunk(text="world")], tool_call_id="c1"), + ], + ) + parsed: InstructRequest[ChatMessage, Tool] = normalizer_v13.from_chat_completion_request(request) + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[ + UserMessage(content="query"), + AssistantMessage( + content="", + tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")], + ), + ToolMessage(content="hello\n\nworld", tool_call_id="c1"), + ], + ) class TestChatCompletionRequestNormalizationV15: @@ -1002,7 +1186,10 @@ def test_assert_parsed_settings( parsed_request: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request( chat_completion_request ) - assert parsed_request.settings == ModelSettings(reasoning_effort=reasoning_effort) + assert parsed_request == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="B")], + settings=ModelSettings(reasoning_effort=reasoning_effort), + ) def test_continue_final_message_forwarded(self, normalizer_v15: InstructRequestNormalizerV15) -> None: request = ChatCompletionRequest[ChatMessage]( @@ -1011,7 +1198,11 @@ def test_continue_final_message_forwarded(self, normalizer_v15: InstructRequestN reasoning_effort=ReasoningEffort.high, ) result: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request) - assert result.continue_final_message is True + assert result == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="a"), AssistantMessage(content="b")], + continue_final_message=True, + settings=ModelSettings(reasoning_effort=ReasoningEffort.high), + ) def test_v15_intra_message_chunks_joined_without_separator( self, normalizer_v15: InstructRequestNormalizerV15 @@ -1025,12 +1216,10 @@ def test_v15_intra_message_chunks_joined_without_separator( reasoning_effort=ReasoningEffort.high, ) parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request) - user_msg = 
parsed.messages[0] - assert isinstance(user_msg, UserMessage) - assert user_msg.content == "AB" - assistant_msg = parsed.messages[1] - assert isinstance(assistant_msg, AssistantMessage) - assert assistant_msg.content == "CD" + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="AB"), AssistantMessage(content="CD")], + settings=ModelSettings(reasoning_effort=ReasoningEffort.high), + ) def test_v15_inter_message_join_still_uses_separator(self, normalizer_v15: InstructRequestNormalizerV15) -> None: r"""V15 still joins text across different messages with '\n\n'.""" @@ -1043,9 +1232,10 @@ def test_v15_inter_message_join_still_uses_separator(self, normalizer_v15: Instr reasoning_effort=ReasoningEffort.high, ) parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request) - user_msg = parsed.messages[0] - assert isinstance(user_msg, UserMessage) - assert user_msg.content == "First\n\nSecond" + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="First\n\nSecond"), AssistantMessage(content="Reply")], + settings=ModelSettings(reasoning_effort=ReasoningEffort.high), + ) def test_v15_mixed_intra_and_inter_message(self, normalizer_v15: InstructRequestNormalizerV15) -> None: r"""V15 combines intra-message ('') and inter-message ('\n\n') joining.""" @@ -1058,9 +1248,10 @@ def test_v15_mixed_intra_and_inter_message(self, normalizer_v15: InstructRequest reasoning_effort=ReasoningEffort.high, ) parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request) - user_msg = parsed.messages[0] - assert isinstance(user_msg, UserMessage) - assert user_msg.content == "AB\n\nCD" + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="AB\n\nCD"), AssistantMessage(content="Reply")], + settings=ModelSettings(reasoning_effort=ReasoningEffort.high), + ) def test_v15_mixed_intra_and_inter_assistant_messages(self, normalizer_v15: 
InstructRequestNormalizerV15) -> None: r"""V15 combines intra-message ('') and inter-message ('\n\n') joining for assistant messages.""" @@ -1073,9 +1264,10 @@ def test_v15_mixed_intra_and_inter_assistant_messages(self, normalizer_v15: Inst reasoning_effort=ReasoningEffort.high, ) parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request) - assistant_msg = parsed.messages[1] - assert isinstance(assistant_msg, AssistantMessage) - assert assistant_msg.content == "AB\n\nCD" + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="Hello"), AssistantMessage(content="AB\n\nCD")], + settings=ModelSettings(reasoning_effort=ReasoningEffort.high), + ) def test_v15_tool_message_text_chunks_joined_without_separator( self, normalizer_v15: InstructRequestNormalizerV15 @@ -1090,9 +1282,129 @@ def test_v15_tool_message_text_chunks_joined_without_separator( reasoning_effort=ReasoningEffort.high, ) parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request) - tool_msg = parsed.messages[2] - assert isinstance(tool_msg, ToolMessage) - assert tool_msg.content == "XY" + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[ + UserMessage(content="query"), + AssistantMessage( + content="", + tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")], + ), + ToolMessage(content="XY", tool_call_id="c1"), + ], + settings=ModelSettings(reasoning_effort=ReasoningEffort.high), + ) + + def test_accepts_text_and_think_chunks(self, normalizer_v15: InstructRequestNormalizerV15) -> None: + r"""V15 normalizer accepts TextChunk and ThinkChunk in assistant messages.""" + request = ChatCompletionRequest[ChatMessage]( + messages=[ + UserMessage(content="query"), + AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]), + ], + reasoning_effort=ReasoningEffort.high, + ) + parsed: InstructRequest[ChatMessage, Tool] = 
normalizer_v15.from_chat_completion_request(request) + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[ + UserMessage(content="query"), + AssistantMessage(content=[ThinkChunk(thinking="reasoning"), TextChunk(text="answer")]), + ], + settings=ModelSettings(reasoning_effort=ReasoningEffort.high), + ) + + def test_accepts_string_content(self, normalizer_v15: InstructRequestNormalizerV15) -> None: + r"""V15 normalizer accepts string content in assistant messages.""" + request = ChatCompletionRequest[ChatMessage]( + messages=[ + UserMessage(content="query"), + AssistantMessage(content="plain text"), + ], + reasoning_effort=ReasoningEffort.high, + ) + parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request) + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[UserMessage(content="query"), AssistantMessage(content="plain text")], + settings=ModelSettings(reasoning_effort=ReasoningEffort.high), + ) + + def test_preserves_non_text_tool_content(self, normalizer_v15: InstructRequestNormalizerV15) -> None: + r"""V15 normalizer preserves non-text chunks in tool messages.""" + image_chunk = ImageURLChunk(image_url="https://example.com/image.png") + request = ChatCompletionRequest( # type: ignore[type-var] + messages=[ + UserMessage(content="query"), + AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")]), + ToolMessage(content=[image_chunk], tool_call_id="c1"), + ], + reasoning_effort=ReasoningEffort.high, + ) + parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request) + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[ + UserMessage(content="query"), + AssistantMessage( + content="", + tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="c1")], + ), + ToolMessage(content=[image_chunk], tool_call_id="c1"), + ], + settings=ModelSettings(reasoning_effort=ReasoningEffort.high), + ) + + def 
test_sorts_multimodal_tool_messages(self, normalizer_v15: InstructRequestNormalizerV15) -> None: + r"""V15 normalizer sorts multimodal tool messages by tool call order.""" + image_chunk_1 = ImageURLChunk(image_url="https://example.com/img1.png") + image_chunk_2 = ImageURLChunk(image_url="https://example.com/img2.png") + request = ChatCompletionRequest( # type: ignore[type-var] + messages=[ + UserMessage(content="query"), + AssistantMessage( + tool_calls=[ + ToolCall(function=FunctionCall(name="fn1", arguments="{}"), id="c1"), + ToolCall(function=FunctionCall(name="fn2", arguments="{}"), id="c2"), + ] + ), + ToolMessage(content=[image_chunk_2], tool_call_id="c2"), + ToolMessage(content=[image_chunk_1], tool_call_id="c1"), + ], + reasoning_effort=ReasoningEffort.high, + ) + parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request) + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[ + UserMessage(content="query"), + AssistantMessage( + content="", + tool_calls=[ + ToolCall(function=FunctionCall(name="fn1", arguments="{}"), id="c1"), + ToolCall(function=FunctionCall(name="fn2", arguments="{}"), id="c2"), + ], + ), + ToolMessage(content=[image_chunk_1], tool_call_id="c1"), + ToolMessage(content=[image_chunk_2], tool_call_id="c2"), + ], + settings=ModelSettings(reasoning_effort=ReasoningEffort.high), + ) + + def test_preserves_audio_in_system_message(self, normalizer_v15: InstructRequestNormalizerV15) -> None: + r"""V15 normalizer preserves AudioChunk in system messages.""" + request = ChatCompletionRequest[ChatMessage]( + messages=[ + SystemMessage(content=[TextChunk(text="hello"), AudioChunk(input_audio=b"fake_audio_data")]), + UserMessage(content="test"), + ], + reasoning_effort=ReasoningEffort.high, + ) + parsed: InstructRequest[ChatMessage, Tool] = normalizer_v15.from_chat_completion_request(request) + assert parsed == InstructRequest[ChatMessage, Tool]( + messages=[ + SystemMessage( + 
content=[TextChunk(text="hello"), AudioChunk(input_audio=b"fake_audio_data")], + ), + UserMessage(content="test"), + ], + settings=ModelSettings(reasoning_effort=ReasoningEffort.high), + ) @pytest.mark.parametrize( diff --git a/tests/test_tokenizer_v13.py b/tests/test_tokenizer_v13.py index 15bb6f3a..64380ac3 100644 --- a/tests/test_tokenizer_v13.py +++ b/tests/test_tokenizer_v13.py @@ -237,13 +237,17 @@ def test_end_to_end_v13_wrong_order( def test_encode_tool_message(v13_tekkenizer: InstructTokenizerV13) -> None: tool_message = ToolMessage(content="R1", tool_call_id="123456789") assert isinstance(v13_tekkenizer, InstructTokenizerV13) - encoded = v13_tekkenizer.encode_tool_message(tool_message, is_before_last_user_message=False) + encoded, images, audios = v13_tekkenizer.encode_tool_message(tool_message, is_before_last_user_message=False) assert encoded == [7, 182, 149, 8] + assert images == [] + assert audios == [] tool_message = ToolMessage(content=[TextChunk(text="R1"), TextChunk(text="R2")], tool_call_id="123456789") assert isinstance(v13_tekkenizer, InstructTokenizerV13) - encoded = v13_tekkenizer.encode_tool_message(tool_message, is_before_last_user_message=False) + encoded, images, audios = v13_tekkenizer.encode_tool_message(tool_message, is_before_last_user_message=False) assert encoded == [7, 182, 149, 182, 150, 8] + assert images == [] + assert audios == [] def test_encode_think_chunk(v13_tekkenizer_think: InstructTokenizerV13) -> None: @@ -371,8 +375,9 @@ def test_tokenize_assistant_message_error(v13_tekkenizer: InstructTokenizerV13) def test_encode_system_message( v13_tekkenizer_think: InstructTokenizerV13, message: SystemMessage, expected: str ) -> None: - encoded = v13_tekkenizer_think.encode_system_message(message) + encoded, audios = v13_tekkenizer_think.encode_system_message(message) assert v13_tekkenizer_think.decode(encoded, special_token_policy=SpecialTokenPolicy.KEEP) == expected + assert audios == [] @pytest.mark.parametrize("audio_fixture", 
["audio_chunk", "audio_url_chunk"]) diff --git a/tests/test_tokenizer_v15.py b/tests/test_tokenizer_v15.py index fdd84c33..c8abe52c 100644 --- a/tests/test_tokenizer_v15.py +++ b/tests/test_tokenizer_v15.py @@ -1,7 +1,18 @@ +import base64 +from collections.abc import Callable +from io import BytesIO + import pytest +from PIL import Image from mistral_common.exceptions import InvalidRequestException, TokenizerException -from mistral_common.protocol.instruct.chunk import ThinkChunk +from mistral_common.protocol.instruct.chunk import ( + AudioChunk, + AudioURLChunk, + ImageURLChunk, + TextChunk, + ThinkChunk, +) from mistral_common.protocol.instruct.messages import ( AssistantMessage, ChatMessage, @@ -18,11 +29,14 @@ ) from mistral_common.protocol.instruct.tool_calls import Function, FunctionCall, Tool, ToolCall from mistral_common.protocol.instruct.validator import ValidationMode, get_validator -from mistral_common.tokens.tokenizers.base import TokenizerVersion +from mistral_common.tokens.tokenizers.audio import AudioConfig, AudioEncoder, AudioSpectrogramConfig, SpecialAudioIDs +from mistral_common.tokens.tokenizers.base import SpecialTokens, TokenizerVersion +from mistral_common.tokens.tokenizers.image import ImageConfig, ImageEncoder, SpecialImageIDs from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV15 from mistral_common.tokens.tokenizers.mistral import MistralTokenizer from mistral_common.tokens.tokenizers.model_settings_builder import EnumBuilder, ModelSettingsBuilder from mistral_common.tokens.tokenizers.tekken import Tekkenizer +from tests.fixtures.audio import get_dummy_audio_chunk, get_dummy_audio_url_chunk from tests.test_tekken import get_special_tokens, quick_vocab EXPECTED_TEXT_V15: str = ( @@ -49,6 +63,44 @@ r"[INST]U2[/INST]" ) +EXPECTED_TEXT_TOOL_AUDIO: str = ( + r"" + r'[AVAILABLE_TOOLS][{"type": "function", "function": {"name": "fn",' + r' "description": "test", "parameters": {}}}]' + 
r'[/AVAILABLE_TOOLS][MODEL_SETTINGS]{"reasoning_effort": "none"}[/MODEL_SETTINGS]' + r"[INST]Use the tool[/INST]" + r"[TOOL_CALLS]fn[ARGS]{}" + r"[TOOL_RESULTS]result[BEGIN_AUDIO][AUDIO][AUDIO][/TOOL_RESULTS]" +) + +EXPECTED_TEXT_TOOL_IMAGE: str = ( + r"" + r'[AVAILABLE_TOOLS][{"type": "function", "function": {"name": "fn",' + r' "description": "test", "parameters": {}}}]' + r'[/AVAILABLE_TOOLS][MODEL_SETTINGS]{"reasoning_effort": "none"}[/MODEL_SETTINGS]' + r"[INST]Use the tool[/INST]" + r"[TOOL_CALLS]fn[ARGS]{}" + r"[TOOL_RESULTS]result[IMG][IMG_END][/TOOL_RESULTS]" +) + +EXPECTED_TEXT_SYSTEM_AUDIO: str = ( + r"[SYSTEM_PROMPT]System with content[BEGIN_AUDIO][AUDIO][AUDIO][/SYSTEM_PROMPT]" + r'[MODEL_SETTINGS]{"reasoning_effort": "none"}[/MODEL_SETTINGS]' + r"[INST]Hello[/INST]" +) + +EXPECTED_TEXT_USER_AUDIO: str = ( + r"" + r'[MODEL_SETTINGS]{"reasoning_effort": "none"}[/MODEL_SETTINGS]' + r"[INST]Here is content[BEGIN_AUDIO][AUDIO][AUDIO][/INST]" +) + +EXPECTED_TEXT_USER_IMAGE: str = ( + r"" + r'[MODEL_SETTINGS]{"reasoning_effort": "none"}[/MODEL_SETTINGS]' + r"[INST][IMG][IMG_END]Here is content[/INST]" +) + def _build_v15_tekkenizer(model_settings_builder: ModelSettingsBuilder | None) -> Tekkenizer: r"""Build a v15 Tekkenizer with the given model settings builder. @@ -89,6 +141,103 @@ def get_v15_mistral_tokenizer( ) +def _build_model_settings_builder( + allowed_reasoning_effort: tuple[str, ...] | None, +) -> ModelSettingsBuilder: + """Build a ModelSettingsBuilder from allowed reasoning effort values. + + When `allowed_reasoning_effort` is `None`, returns `ModelSettingsBuilder.none()` + (all fields ignored). This matches the behavior of `Tekkenizer.from_file` when no + `model_settings_builder` key is present in the JSON. 
+ """ + if allowed_reasoning_effort is None: + return ModelSettingsBuilder.none() + return ModelSettingsBuilder( + reasoning_effort=EnumBuilder[ReasoningEffort]( + values=[ReasoningEffort(v) for v in allowed_reasoning_effort], + accepts_none=True, + default=ReasoningEffort(allowed_reasoning_effort[0]) if allowed_reasoning_effort else None, + ) + ) + + +def _get_dummy_image_url_chunk() -> ImageURLChunk: + r"""Build a small base64-encoded image URL chunk for testing.""" + img = Image.new("RGB", (4, 4), "red") + buf = BytesIO() + img.save(buf, "PNG") + data_url = f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}" + return ImageURLChunk(image_url=data_url) + + +def get_v15_mistral_tokenizer_with_audio() -> MistralTokenizer: + r"""Build a V15 MistralTokenizer with audio encoder.""" + builder = _build_model_settings_builder(tuple(ReasoningEffort)) + tekkenizer = Tekkenizer( + quick_vocab([b"a", b"b", b"c", b"f", b"de"]), + special_tokens=get_special_tokens(TokenizerVersion.v15, add_audio=True), + pattern=r".+", + vocab_size=256 + 100, + num_special_tokens=100, + version=TokenizerVersion.v15, + model_settings_builder=builder, + ) + audio_config = AudioConfig( + sampling_rate=24_000, + frame_rate=12.5, + encoding_config=AudioSpectrogramConfig( + num_mel_bins=128, + hop_length=160, + window_size=400, + ), + ) + special_audio_ids = SpecialAudioIDs( + audio=tekkenizer.get_special_token(SpecialTokens.audio.value), + begin_audio=tekkenizer.get_special_token(SpecialTokens.begin_audio.value), + streaming_pad=None, + text_to_audio=None, + audio_to_text=None, + ) + audio_encoder = AudioEncoder(audio_config, special_audio_ids) + instruct_tokenizer = InstructTokenizerV15(tekkenizer, audio_encoder=audio_encoder) + request_normalizer = get_normalizer(TokenizerVersion.v15, tekkenizer.model_settings_builder) + validator = get_validator(TokenizerVersion.v15, mode=ValidationMode.test) + return MistralTokenizer( + instruct_tokenizer=instruct_tokenizer, + 
validator=validator, + request_normalizer=request_normalizer, + ) + + +def get_v15_mistral_tokenizer_with_image() -> MistralTokenizer: + r"""Build a V15 MistralTokenizer with image encoder.""" + builder = _build_model_settings_builder(tuple(ReasoningEffort)) + tekkenizer = Tekkenizer( + quick_vocab([b"a", b"b", b"c", b"f", b"de"]), + special_tokens=get_special_tokens(TokenizerVersion.v15, add_think=True), + pattern=r".+", + vocab_size=256 + 100, + num_special_tokens=100, + version=TokenizerVersion.v15, + model_settings_builder=builder, + ) + image_config = ImageConfig(image_patch_size=16, max_image_size=1024) + special_image_ids = SpecialImageIDs( + img=tekkenizer.get_special_token(SpecialTokens.img.value), + img_break=tekkenizer.get_special_token(SpecialTokens.img_break.value), + img_end=tekkenizer.get_special_token(SpecialTokens.img_end.value), + ) + image_encoder = ImageEncoder(image_config, special_image_ids) + instruct_tokenizer = InstructTokenizerV15(tekkenizer, image_encoder=image_encoder) + request_normalizer = get_normalizer(TokenizerVersion.v15, tekkenizer.model_settings_builder) + validator = get_validator(TokenizerVersion.v15, mode=ValidationMode.test) + return MistralTokenizer( + instruct_tokenizer=instruct_tokenizer, + validator=validator, + request_normalizer=request_normalizer, + ) + + @pytest.fixture(scope="session") def v15_tekkenizer() -> InstructTokenizerV15: return get_v15_tekkenizer(_build_model_settings_builder(tuple(ReasoningEffort))) @@ -139,6 +288,61 @@ def messages() -> list[ChatMessage]: ] +@pytest.fixture(scope="session") +def audio_chunk() -> AudioChunk: + return get_dummy_audio_chunk() + + +# Multimodal content chunks and their corresponding tokenizer factories for parametrized tests. 
+_TOOL_MULTIMODAL_PARAMS = [ + pytest.param( + get_dummy_audio_chunk(), 1, 0, get_v15_mistral_tokenizer_with_audio, EXPECTED_TEXT_TOOL_AUDIO, id="audio" + ), + pytest.param( + get_dummy_audio_url_chunk(), + 1, + 0, + get_v15_mistral_tokenizer_with_audio, + EXPECTED_TEXT_TOOL_AUDIO, + id="audio_url", + ), + pytest.param( + _get_dummy_image_url_chunk(), + 0, + 1, + get_v15_mistral_tokenizer_with_image, + EXPECTED_TEXT_TOOL_IMAGE, + id="image_url", + ), +] +_SYSTEM_MULTIMODAL_PARAMS = [ + pytest.param( + get_dummy_audio_chunk(), 1, 0, get_v15_mistral_tokenizer_with_audio, EXPECTED_TEXT_SYSTEM_AUDIO, id="audio" + ), +] +_USER_MULTIMODAL_PARAMS = [ + pytest.param( + get_dummy_audio_chunk(), 1, 0, get_v15_mistral_tokenizer_with_audio, EXPECTED_TEXT_USER_AUDIO, id="audio" + ), + pytest.param( + get_dummy_audio_url_chunk(), + 1, + 0, + get_v15_mistral_tokenizer_with_audio, + EXPECTED_TEXT_USER_AUDIO, + id="audio_url", + ), + pytest.param( + _get_dummy_image_url_chunk(), + 0, + 1, + get_v15_mistral_tokenizer_with_image, + EXPECTED_TEXT_USER_IMAGE, + id="image_url", + ), +] + + def test_tools_and_reasoning_effort( v15_tekkenizer: InstructTokenizerV15, available_tools: list[Tool], messages: list[ChatMessage] ) -> None: @@ -177,30 +381,6 @@ def test_system_think_chunk_raises_v15(v15_tekkenizer: InstructTokenizerV15) -> v15_tekkenizer.encode_instruct(request) -def _build_model_settings_builder( - allowed_reasoning_effort: tuple[str, ...] | None, -) -> ModelSettingsBuilder: - """Build a ModelSettingsBuilder from allowed reasoning effort values. - - When `allowed_reasoning_effort` is `None`, returns `ModelSettingsBuilder.none()` - (all fields ignored). This matches the behavior of `Tekkenizer.from_file` when no - `model_settings_builder` key is present in the JSON. 
- """ - if allowed_reasoning_effort is None: - return ModelSettingsBuilder.none() - if not allowed_reasoning_effort: - return ModelSettingsBuilder( - reasoning_effort=EnumBuilder[ReasoningEffort](values=[], accepts_none=True, default=None) - ) - return ModelSettingsBuilder( - reasoning_effort=EnumBuilder[ReasoningEffort]( - values=[ReasoningEffort(v) for v in allowed_reasoning_effort], - accepts_none=True, - default=ReasoningEffort(allowed_reasoning_effort[0]) if allowed_reasoning_effort else None, - ) - ) - - @pytest.mark.parametrize( ("reasoning_effort", "allowed_reasoning_effort", "raises", "match"), [ @@ -294,3 +474,79 @@ def test_encode_chat_completion_continue_final_message() -> None: eos_id = tokenizer_v15.instruct_tokenizer.tokenizer.eos_id assert encoded.tokens[-1] != eos_id + + +@pytest.mark.parametrize( + ("content_chunk", "expected_audios", "expected_images", "tokenizer_factory", "expected_text"), + _TOOL_MULTIMODAL_PARAMS, +) +def test_encode_chat_completion_with_multimodal_tool( + content_chunk: AudioChunk | AudioURLChunk | ImageURLChunk, + expected_audios: int, + expected_images: int, + tokenizer_factory: Callable[[], MistralTokenizer], + expected_text: str, +) -> None: + mistral_tokenizer = tokenizer_factory() + chat_request: ChatCompletionRequest = ChatCompletionRequest( + messages=[ + UserMessage(content="Use the tool"), + AssistantMessage(tool_calls=[ToolCall(id="test12345", function=FunctionCall(name="fn", arguments="{}"))]), + ToolMessage( + content=[TextChunk(text="result"), content_chunk], + tool_call_id="test12345", + ), + ], + tools=[Tool(function=Function(name="fn", description="test", parameters={}))], + ) + encoded = mistral_tokenizer.encode_chat_completion(chat_request) + assert encoded.text == expected_text, encoded.text + assert len(encoded.audios) == expected_audios + assert len(encoded.images) == expected_images + + +@pytest.mark.parametrize( + ("content_chunk", "expected_audios", "expected_images", "tokenizer_factory", 
"expected_text"), + _SYSTEM_MULTIMODAL_PARAMS, +) +def test_encode_chat_completion_with_multimodal_system( + content_chunk: AudioChunk, + expected_audios: int, + expected_images: int, + tokenizer_factory: Callable[[], MistralTokenizer], + expected_text: str, +) -> None: + mistral_tokenizer = tokenizer_factory() + chat_request: ChatCompletionRequest = ChatCompletionRequest( + messages=[ + SystemMessage(content=[TextChunk(text="System with content"), content_chunk]), + UserMessage(content="Hello"), + ], + ) + encoded = mistral_tokenizer.encode_chat_completion(chat_request) + assert encoded.text == expected_text, encoded.text + assert len(encoded.audios) == expected_audios + assert len(encoded.images) == expected_images + + +@pytest.mark.parametrize( + ("content_chunk", "expected_audios", "expected_images", "tokenizer_factory", "expected_text"), + _USER_MULTIMODAL_PARAMS, +) +def test_encode_chat_completion_with_multimodal_user( + content_chunk: AudioChunk | AudioURLChunk | ImageURLChunk, + expected_audios: int, + expected_images: int, + tokenizer_factory: Callable[[], MistralTokenizer], + expected_text: str, +) -> None: + mistral_tokenizer = tokenizer_factory() + chat_request = ChatCompletionRequest( + messages=[ + UserMessage(content=[TextChunk(text="Here is content"), content_chunk]), + ], + ) + encoded = mistral_tokenizer.encode_chat_completion(chat_request) + assert encoded.text == expected_text, encoded.text + assert len(encoded.audios) == expected_audios + assert len(encoded.images) == expected_images diff --git a/tests/test_tokenizer_v7_audio.py b/tests/test_tokenizer_v7_audio.py index 10995957..4ef3d47c 100644 --- a/tests/test_tokenizer_v7_audio.py +++ b/tests/test_tokenizer_v7_audio.py @@ -11,8 +11,8 @@ AudioChunk, AudioURL, AudioURLChunk, + ContentChunk, TextChunk, - UserContentChunk, ) from mistral_common.protocol.instruct.messages import ( UATS, @@ -215,7 +215,7 @@ def test_tokenize_user_message(tekkenizer: InstructTokenizerV7, audio_first: boo text_chunk 
= TextChunk(text="a") num_expected_frames = int(np.ceil(duration * frame_rate)) - chunks: list[UserContentChunk] = [audio_chunk, text_chunk] if audio_first else [text_chunk, audio_chunk] + chunks: list[ContentChunk] = [audio_chunk, text_chunk] if audio_first else [text_chunk, audio_chunk] tokenized = tekkenizer.encode_instruct( InstructRequest( @@ -261,7 +261,7 @@ def test_tokenize_multi_turn(tekkenizer: InstructTokenizerV7) -> None: text_chunk = TextChunk(text="a") num_expected_frames = int(np.ceil(duration * frame_rate)) - chunks: list[UserContentChunk] = [audio_chunk, text_chunk] + chunks: list[ContentChunk] = [audio_chunk, text_chunk] tokenized = tekkenizer.encode_instruct( InstructRequest( diff --git a/tests/validation/test_chat_validation.py b/tests/validation/test_chat_validation.py index f4c8653c..c036e2a8 100644 --- a/tests/validation/test_chat_validation.py +++ b/tests/validation/test_chat_validation.py @@ -1,11 +1,19 @@ +from typing import Any + import pytest from mistral_common.exceptions import ( InvalidAssistantMessageException, InvalidMessageStructureException, InvalidRequestException, + InvalidSystemPromptException, + InvalidToolMessageException, + InvalidUserMessageException, +) +from mistral_common.protocol.instruct.chunk import ( + AudioChunk, + AudioURLChunk, ) -from mistral_common.protocol.instruct.chunk import AudioChunk, AudioURLChunk from mistral_common.protocol.instruct.messages import ( AssistantMessage, SystemMessage, @@ -26,6 +34,29 @@ ValidationMode, ) from tests.fixtures.audio import get_dummy_audio_chunk, get_dummy_audio_url_chunk +from tests.fixtures.chunks import get_content_chunks + +_Messages = list[UserMessage | AssistantMessage | SystemMessage | ToolMessage] + + +def _user_convo(content: "str | list[Any]") -> _Messages: + return [UserMessage(content=content)] + + +def _assistant_convo(content: "str | list[Any]") -> _Messages: + return [UserMessage(content="hi"), AssistantMessage(content=content), UserMessage(content="next")] + + 
+def _system_convo(content: "str | list[Any]") -> _Messages: + return [SystemMessage(content=content), UserMessage(content="hi")] + + +def _tool_convo(content: "str | list[Any]") -> _Messages: + return [ + UserMessage(content="hi"), + AssistantMessage(tool_calls=[ToolCall(function=FunctionCall(name="fn", arguments="{}"), id="test12345")]), + ToolMessage(content=content, tool_call_id="test12345"), + ] @pytest.fixture(scope="module") @@ -49,6 +80,16 @@ def validator(request: pytest.FixtureRequest) -> MistralRequestValidator: return request.param # type: ignore +@pytest.fixture +def validator_base() -> MistralRequestValidator: + return MistralRequestValidator(ValidationMode.serving) + + +@pytest.fixture +def validator_v3() -> MistralRequestValidatorV3: + return MistralRequestValidatorV3(ValidationMode.serving) + + @pytest.fixture def validator_v5() -> MistralRequestValidatorV5: return MistralRequestValidatorV5(ValidationMode.serving) @@ -353,6 +394,57 @@ def test_build_settings_raises_error(self, validator: MistralRequestValidator) - with pytest.raises(InvalidRequestException, match="reasoning_effort='none' is not supported for this model"): validator._validate_model_settings(request) + def test_allows_text_content_chunks(self, validator_base: MistralRequestValidator) -> None: + validator_base.validate_messages(_user_convo(get_content_chunks(("text",))), continue_final_message=False) + validator_base.validate_messages(_assistant_convo(get_content_chunks(("text",))), continue_final_message=False) + validator_base.validate_messages(_system_convo(get_content_chunks(("text",))), continue_final_message=False) + validator_base.validate_messages(_tool_convo(get_content_chunks(("text",))), continue_final_message=False) + + def test_rejects_non_text_in_user(self, validator_base: MistralRequestValidator) -> None: + for name in ("image", "image_url", "audio", "audio_url"): + with pytest.raises(InvalidUserMessageException, match="Unexpected content chunk types in user message"): 
+ validator_base.validate_messages(_user_convo(get_content_chunks((name,))), continue_final_message=False) + + def test_rejects_non_text_in_assistant(self, validator_base: MistralRequestValidator) -> None: + for name in ("think",): + with pytest.raises( + InvalidAssistantMessageException, match="Unexpected content chunk types in assistant message" + ): + validator_base.validate_messages( + _assistant_convo(get_content_chunks((name,))), continue_final_message=False + ) + + def test_rejects_non_text_in_system(self, validator_base: MistralRequestValidator) -> None: + for name in ("audio", "think"): + with pytest.raises(InvalidSystemPromptException, match="Unexpected content chunk types in system message"): + validator_base.validate_messages( + _system_convo(get_content_chunks((name,))), continue_final_message=False + ) + + def test_rejects_non_text_in_tool(self, validator_base: MistralRequestValidator) -> None: + for name in ("image", "image_url", "audio", "audio_url", "think"): + with pytest.raises(InvalidToolMessageException, match="Unexpected content chunk types in tool message"): + validator_base.validate_messages(_tool_convo(get_content_chunks((name,))), continue_final_message=False) + + def test_reports_sorted_unique_invalid_chunk_types(self, validator_base: MistralRequestValidator) -> None: + content = get_content_chunks(("audio", "image_url")) + with pytest.raises( + InvalidUserMessageException, + match=r"Unexpected content chunk types in user message: \['AudioChunk', 'ImageURLChunk'\]", + ): + validator_base.validate_messages(_user_convo(content), continue_final_message=False) + + +class TestChatValidationV3: + def test_allows_text_and_image_in_user(self, validator_v3: MistralRequestValidatorV3) -> None: + content = get_content_chunks(("text", "image", "image_url")) + validator_v3.validate_messages(_user_convo(content), continue_final_message=False) + + def test_rejects_audio_in_user(self, validator_v3: MistralRequestValidatorV3) -> None: + for name in 
("audio", "audio_url"): + with pytest.raises(InvalidUserMessageException, match="Unexpected content chunk types in user message"): + validator_v3.validate_messages(_user_convo(get_content_chunks((name,))), continue_final_message=False) + class TestChatValidationV5: @pytest.mark.parametrize("audio_fixture", ["audio_chunk", "audio_url_chunk"]) @@ -403,6 +495,23 @@ def test_build_settings_raises_error(self, validator: MistralRequestValidator) - with pytest.raises(InvalidRequestException, match="reasoning_effort='none' is not supported for this model"): validator._validate_model_settings(request) + def test_allows_text_image_audio_in_user(self, validator_v5: MistralRequestValidatorV5) -> None: + content = get_content_chunks(("text", "image", "image_url", "audio", "audio_url")) + validator_v5.validate_messages(_user_convo(content), continue_final_message=False) + + def test_allows_text_audio_think_in_system(self, validator_v5: MistralRequestValidatorV5) -> None: + content = get_content_chunks(("text", "audio", "think")) + validator_v5.validate_messages(_system_convo(content), continue_final_message=False) + + def test_rejects_think_in_assistant(self, validator_v5: MistralRequestValidatorV5) -> None: + for name in ("think",): + with pytest.raises( + InvalidAssistantMessageException, match="Unexpected content chunk types in assistant message" + ): + validator_v5.validate_messages( + _assistant_convo(get_content_chunks((name,))), continue_final_message=False + ) + class TestChatValidationV13: def test_right_number_results_invalid_id(self, validator_v13: MistralRequestValidatorV13) -> None: @@ -613,6 +722,15 @@ def test_audio_with_system_prompt_raises_ok( continue_final_message=False, ) + def test_allows_text_and_think_in_assistant(self, validator_v13: MistralRequestValidatorV13) -> None: + content = get_content_chunks(("text", "think")) + validator_v13.validate_messages(_assistant_convo(content), continue_final_message=False) + + def test_rejects_non_text_in_tool(self, 
validator_v13: MistralRequestValidatorV13) -> None: + for name in ("image", "image_url", "audio", "audio_url", "think"): + with pytest.raises(InvalidToolMessageException, match="Unexpected content chunk types in tool message"): + validator_v13.validate_messages(_tool_convo(get_content_chunks((name,))), continue_final_message=False) + class TestChatValidationV15: @pytest.mark.parametrize("reasoning_effort", [*list(ReasoningEffort), None]) @@ -621,3 +739,18 @@ def test_build_settings_v15_reasoning_effort( ) -> None: request = ChatCompletionRequest(messages=[UserMessage(content="Hello")], reasoning_effort=reasoning_effort) validator_v15._validate_model_settings(request) + + def test_allows_text_and_audio_in_system(self, validator_v15: MistralRequestValidatorV15) -> None: + content = get_content_chunks(("text", "audio")) + validator_v15.validate_messages(_system_convo(content), continue_final_message=False) + + def test_rejects_think_in_system(self, validator_v15: MistralRequestValidatorV15) -> None: + for name in ("think",): + with pytest.raises(InvalidSystemPromptException, match="Unexpected content chunk types in system message"): + validator_v15.validate_messages( + _system_convo(get_content_chunks((name,))), continue_final_message=False + ) + + def test_allows_all_chunk_types_in_tool(self, validator_v15: MistralRequestValidatorV15) -> None: + content = get_content_chunks(("text", "image", "image_url", "audio", "audio_url", "think")) + validator_v15.validate_messages(_tool_convo(content), continue_final_message=False)