diff --git a/src/mistral_common/protocol/instruct/chunk.py b/src/mistral_common/protocol/instruct/chunk.py index 2ed61528..c616425b 100644 --- a/src/mistral_common/protocol/instruct/chunk.py +++ b/src/mistral_common/protocol/instruct/chunk.py @@ -20,6 +20,13 @@ from mistral_common.tokens.tokenizers.audio import Audio +def _strip_audio_data_url_prefix(data: str) -> str: + r"""Remove the optional base64 audio data URL prefix.""" + if re.match(r"^data:audio/\w+;base64,", data): + return data.split(",", 1)[1] + return data + + def _detect_audio_format(data: str | bytes) -> str: r"""Detect audio format from base64-encoded string or raw bytes. @@ -37,7 +44,7 @@ def _detect_audio_format(data: str | bytes) -> str: assert_soundfile_installed() if isinstance(data, str): - audio_bytes = base64.b64decode(data) + audio_bytes = base64.b64decode(_strip_audio_data_url_prefix(data)) else: audio_bytes = data @@ -379,7 +386,10 @@ def to_openai(self) -> dict[str, Any]: Returns: A dictionary representing the audio chunk in the OpenAI format. """ - content = self.input_audio.decode("utf-8") if isinstance(self.input_audio, bytes) else self.input_audio + if isinstance(self.input_audio, bytes): + content = base64.b64encode(self.input_audio).decode("utf-8") + else: + content = _strip_audio_data_url_prefix(self.input_audio) fmt = _detect_audio_format(self.input_audio) return { "type": self.type, diff --git a/tests/test_converters.py b/tests/test_converters.py index d093d945..e8f175da 100644 --- a/tests/test_converters.py +++ b/tests/test_converters.py @@ -1,3 +1,4 @@ +import base64 import copy import io import warnings @@ -1197,6 +1198,33 @@ def test_audio_chunk_to_openai_format_detection(fmt: str) -> None: assert AudioChunk.from_openai(result).input_audio == b64 +@pytest.mark.parametrize("fmt", ["wav", "flac"]) +def test_audio_chunk_to_openai_raw_bytes_format_detection(fmt: str) -> None: + audio = _make_fake_audio(0.5) + buffer = io.BytesIO() + sf.write(buffer, audio.audio_array, audio.sampling_rate, format=fmt) + raw_bytes = buffer.getvalue() + + result = AudioChunk(input_audio=raw_bytes).to_openai() + + assert result["input_audio"]["format"] == fmt + assert result["input_audio"]["data"] == base64.b64encode(raw_bytes).decode("utf-8") + assert AudioChunk.from_openai(result).input_audio == result["input_audio"]["data"] + + +@pytest.mark.parametrize("fmt", ["wav", "flac"]) +def test_audio_chunk_to_openai_strips_base64_data_url_prefix(fmt: str) -> None: + audio = _make_fake_audio(0.5) + b64 = audio.to_base64(fmt) + chunk = AudioChunk(input_audio=f"data:audio/{fmt};base64,{b64}") + + result = chunk.to_openai() + + assert result["input_audio"]["format"] == fmt + assert result["input_audio"]["data"] == b64 + assert AudioChunk.from_openai(result).input_audio == b64 + + @pytest.mark.parametrize("fmt", ["wav", "flac"]) def test_transcription_to_openai_format_detection(fmt: str) -> None: audio = _make_fake_audio(0.5)