Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/mistral_common/protocol/instruct/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@
from mistral_common.tokens.tokenizers.audio import Audio


def _strip_audio_data_url_prefix(data: str) -> str:
r"""Remove the optional base64 audio data URL prefix."""
if re.match(r"^data:audio/\w+;base64,", data):
return data.split(",", 1)[1]
return data


def _detect_audio_format(data: str | bytes) -> str:
r"""Detect audio format from base64-encoded string or raw bytes.

Expand All @@ -37,7 +44,7 @@ def _detect_audio_format(data: str | bytes) -> str:
assert_soundfile_installed()

if isinstance(data, str):
audio_bytes = base64.b64decode(data)
audio_bytes = base64.b64decode(_strip_audio_data_url_prefix(data))
else:
audio_bytes = data

Expand Down Expand Up @@ -379,7 +386,10 @@ def to_openai(self) -> dict[str, Any]:
Returns:
A dictionary representing the audio chunk in the OpenAI format.
"""
content = self.input_audio.decode("utf-8") if isinstance(self.input_audio, bytes) else self.input_audio
if isinstance(self.input_audio, bytes):
content = base64.b64encode(self.input_audio).decode("utf-8")
else:
content = _strip_audio_data_url_prefix(self.input_audio)
fmt = _detect_audio_format(self.input_audio)
return {
"type": self.type,
Expand Down
28 changes: 28 additions & 0 deletions tests/test_converters.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
import copy
import io
import warnings
Expand Down Expand Up @@ -1197,6 +1198,33 @@ def test_audio_chunk_to_openai_format_detection(fmt: str) -> None:
assert AudioChunk.from_openai(result).input_audio == b64


@pytest.mark.parametrize("fmt", ["wav", "flac"])
def test_audio_chunk_to_openai_raw_bytes_format_detection(fmt: str) -> None:
    r"""Check the OpenAI dict produced from a chunk built with raw encoded audio bytes."""
    fake = _make_fake_audio(0.5)
    # Serialize the fake audio in the requested container format, entirely in memory.
    with io.BytesIO() as sink:
        sf.write(sink, fake.audio_array, fake.sampling_rate, format=fmt)
        payload = sink.getvalue()
    expected_b64 = base64.b64encode(payload).decode("utf-8")

    openai_dict = AudioChunk(input_audio=payload).to_openai()

    assert openai_dict["input_audio"]["format"] == fmt
    assert openai_dict["input_audio"]["data"] == expected_b64
    # Round trip: the chunk rebuilt from the dict carries the emitted base64 text.
    assert AudioChunk.from_openai(openai_dict).input_audio == openai_dict["input_audio"]["data"]


@pytest.mark.parametrize("fmt", ["wav", "flac"])
def test_audio_chunk_to_openai_strips_base64_data_url_prefix(fmt: str) -> None:
    r"""Check that a ``data:audio/...;base64,`` prefix is absent from the OpenAI dict."""
    b64 = _make_fake_audio(0.5).to_base64(fmt)
    prefixed_chunk = AudioChunk(input_audio=f"data:audio/{fmt};base64,{b64}")

    openai_dict = prefixed_chunk.to_openai()
    restored = AudioChunk.from_openai(openai_dict)

    assert openai_dict["input_audio"]["format"] == fmt
    # Only the bare base64 payload is emitted and survives the round trip.
    assert openai_dict["input_audio"]["data"] == b64
    assert restored.input_audio == b64


@pytest.mark.parametrize("fmt", ["wav", "flac"])
def test_transcription_to_openai_format_detection(fmt: str) -> None:
audio = _make_fake_audio(0.5)
Expand Down
Loading