Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion rust-executor/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ argon2 = { version = "0.5.0", features = ["simple"] }
rand = "0.8.5"
base64 = "0.21.0"
rmcp = { version = "0.15.0", features = ["server", "transport-streamable-http-server"] }
axum = { version = "0.8", features = ["ws"] }
axum = { version = "0.8", features = ["ws", "multipart"] }
axum-server = { version = "0.7", features = ["tls-rustls"] }
tower-http = { version = "0.6", features = ["cors", "set-header", "catch-panic"] }
schemars = "1.0"
Expand Down
410 changes: 410 additions & 0 deletions rust-executor/src/ai_service/mod.rs

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions rust-executor/src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
pub mod auth;
pub mod errors;
pub mod events_ws;
pub mod openai_compat;
pub mod types;
pub mod ws_rpc;

Expand Down Expand Up @@ -99,6 +100,14 @@ pub fn api_router(state: AppState) -> Router {
post(ai_ws::feed_transcription_stream),
),
)
// ── OpenAI-compatible /v1 surface ──
//
// Mounted at both `/v1` (the canonical OpenAI path) and
// `/api/v1/openai/v1` (for proxies that hard-code the `/api/v1`
// prefix from the native AD4M surface). Both share the same
// handlers + AppState.
.nest("/v1", openai_compat::router())
.nest("/api/v1/openai/v1", openai_compat::router())
// ── State + Middleware ──
.with_state(state)
.layer(Extension(handler_map))
Expand Down
300 changes: 300 additions & 0 deletions rust-executor/src/api/openai_compat/audio.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,300 @@
//! `POST /v1/audio/transcriptions` (batch STT) + `POST /v1/audio/speech` (TTS).
//!
//! ## Transcription (Phase 3)
//!
//! Accepts a standard OpenAI multipart upload: `file`, `model`,
//! `response_format`, `language`, `temperature`. We decode the upload to
//! 16 kHz mono `f32` samples and drive the existing Whisper pipeline in
//! one shot via [`AIService::transcribe_buffer`].
//!
//! Audio decode is delegated to the new [`audio_decode`] helper. For
//! the first cut we accept raw 16 kHz mono PCM (wav) directly; for
//! other formats the helper returns an error pointing at the
//! `Content-Type` and the caller is expected to convert client-side
//! (matches the proposal — full container support via `symphonia` is
//! scoped to a follow-up to keep the dependency surface lean).
//!
//! ## TTS (Phase 4)
//!
//! `audio/speech` forwards the request to a remote OpenAI-compatible
//! upstream configured via a TTS-typed `Model` registration — see
//! [`super::tts_passthrough`] for the HTTP client. A local Kokoro
//! backend is in the design but not in this PR; the proposal allows
//! passthrough as a valid Phase-4 deliverable.

use axum::{extract::Multipart, http::header, response::Response, Json};

use super::errors::{OpenAIError, OpenAIResult};
use super::model_selector::resolve_model;
use super::tts_passthrough;
use super::types::{SpeechRequest, TranscriptionResponse};
use crate::agent::capabilities::{
check_capability, AI_PROMPT_CAPABILITY, AI_TRANSCRIBE_CAPABILITY,
};
use crate::ai_service::AIService;
use crate::api::auth::AuthContext;
use crate::billing::{bill_compute, BillingError};
use crate::types::ModelType;

pub async fn transcriptions(
auth: AuthContext,
mut multipart: Multipart,
) -> OpenAIResult<Json<TranscriptionResponse>> {
check_capability(&auth.capabilities, &AI_TRANSCRIBE_CAPABILITY)
.map_err(OpenAIError::forbidden)?;

let mut model_field: Option<String> = None;
let mut audio_bytes: Option<Vec<u8>> = None;
let mut content_type: Option<String> = None;
let mut response_format = "json".to_string();

while let Some(field) = multipart
.next_field()
.await
.map_err(|e| OpenAIError::invalid_request(format!("Multipart parse error: {e}")))?
{
match field.name().unwrap_or("") {
"model" => {
model_field = field
.text()
.await
.ok()
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty());
}
"file" => {
content_type = field
.content_type()
.map(|s| s.to_string())
.or_else(|| Some("application/octet-stream".to_string()));
let bytes = field
.bytes()
.await
.map_err(|e| OpenAIError::invalid_request(format!("File read error: {e}")))?;
audio_bytes = Some(bytes.to_vec());
}
"response_format" => {
if let Ok(v) = field.text().await {
response_format = v.trim().to_string();
}
}
_ => {
// Other OpenAI fields (`language`, `temperature`,
// `prompt`) are not yet plumbed through to Whisper — we
// accept and drop them so callers don't 400.
let _ = field.bytes().await;
}
}
}

let model = model_field
.ok_or_else(|| OpenAIError::invalid_request("Missing required form field: model"))?;
let bytes = audio_bytes
.ok_or_else(|| OpenAIError::invalid_request("Missing required form field: file"))?;
let model_id = resolve_model(&model, ModelType::Transcription).await?;

let samples = audio_decode(&bytes, content_type.as_deref())?;

let service = AIService::global_instance()
.await
.map_err(|e| OpenAIError::internal(e.to_string()))?;
let text = service
.transcribe_buffer(model_id, samples, auth.auth_token.clone())
.await
.map_err(|e| OpenAIError::internal(e.to_string()))?;

if let Some(email) = crate::agent::capabilities::user_email_from_token(auth.auth_token.clone())
{
// Mirror native transcription billing: 1 credit per request
// (the native path bills per-word but only after `delta` events
// — batch jobs round to a flat charge until per-word accounting
// is plumbed through).
if let Err(BillingError::InsufficientCredits) = bill_compute(
&email,
1.0,
"ai_transcription",
Some("v1/audio/transcriptions"),
) {
return Err(OpenAIError::insufficient_quota(
"Insufficient compute credits",
));
}
}

// `response_format` honoured: `json` (default) returns the envelope
// below. `text` / `verbose_json` / `srt` / `vtt` are accepted for
// request-compatibility but currently always produce the `json`
// envelope; widening the return to a raw `Response` so we can emit
// text/srt/vtt directly is a follow-up that touches the router's
// type signature.
let _ = response_format;
Ok(Json(TranscriptionResponse { text }))
}

pub async fn speech(
auth: AuthContext,
Json(req): Json<SpeechRequest>,
) -> Result<Response, OpenAIError> {
check_capability(&auth.capabilities, &AI_PROMPT_CAPABILITY).map_err(OpenAIError::forbidden)?;

// TTS isn't a registered ModelType yet (the proposal calls for a new
// `ModelType::TextToSpeech` in a follow-up); for now we accept the
// model name verbatim and forward to the configured passthrough
// upstream. See `tts_passthrough` for the configuration shape.
let response_format = req
.response_format
.clone()
.unwrap_or_else(|| "mp3".to_string());
let audio = tts_passthrough::synthesize(&auth, &req)
.await
.map_err(|e| match e {
tts_passthrough::TtsPassthroughError::NotConfigured => OpenAIError::not_implemented(
"No TTS backend is configured on this executor. \
Register a TTS upstream via the AD4M config or wait for the local Kokoro backend.",
),
tts_passthrough::TtsPassthroughError::Upstream(msg) => {
OpenAIError::internal(format!("Upstream TTS error: {msg}"))
}
})?;

if let Some(email) = crate::agent::capabilities::user_email_from_token(auth.auth_token.clone())
{
if let Err(BillingError::InsufficientCredits) =
bill_compute(&email, 1.0, "ai_tts", Some("v1/audio/speech"))
{
return Err(OpenAIError::insufficient_quota(
"Insufficient compute credits",
));
}
}

let content_type = match response_format.as_str() {
"mp3" => "audio/mpeg",
"wav" => "audio/wav",
"opus" => "audio/opus",
"flac" => "audio/flac",
"pcm" => "application/octet-stream",
_ => "application/octet-stream",
};
Response::builder()
.status(200)
.header(header::CONTENT_TYPE, content_type)
.body(audio.into())
.map_err(|e| OpenAIError::internal(format!("Response build error: {e}")))
}

/// Decode an uploaded audio file to 16 kHz mono `f32` samples.
///
/// The first-cut implementation supports raw WAV (16-bit PCM, mono,
/// 16 kHz) — the simplest case that requires no external dependencies
/// and matches what the existing transcription WS feed already expects.
/// For other containers, return a descriptive 400 so the client can
/// transcode locally.
///
/// Full `symphonia` / `rubato` decode is a follow-up: it's the right
/// home for that work, but adding two heavy crates in the same PR as
/// the rest of the OpenAI surface would dominate the diff for a single
/// endpoint.
fn audio_decode(bytes: &[u8], content_type: Option<&str>) -> OpenAIResult<Vec<f32>> {
// Quick WAV header sniff — RIFF....WAVEfmt
if bytes.len() >= 44 && &bytes[0..4] == b"RIFF" && &bytes[8..12] == b"WAVE" {
return decode_pcm_wav(bytes);
}
Err(OpenAIError::invalid_request(format!(
"Unsupported audio format (content-type: {}). \
The current /v1/audio/transcriptions endpoint accepts 16 kHz mono PCM WAV. \
Full container decode (mp3/m4a/flac/ogg) is on the roadmap; \
transcode client-side as a workaround.",
content_type.unwrap_or("(unknown)"),
)))
}

/// Minimal WAV decoder: 16-bit PCM, mono, 16 kHz.
///
/// We don't attempt to handle every WAV variant — the goal is to clear
/// the most common upload path (whisper's reference format) without
/// pulling in `symphonia`. Mismatched format → 400 with a clear hint.
fn decode_pcm_wav(bytes: &[u8]) -> OpenAIResult<Vec<f32>> {
// Find the fmt and data chunks. WAV is a RIFF container: after
// the 12-byte RIFF header, chunks alternate {id (4 bytes), size
// (u32 LE), payload}.
let mut pos = 12;
let mut audio_format: u16 = 0;
let mut channels: u16 = 0;
let mut sample_rate: u32 = 0;
let mut bits_per_sample: u16 = 0;
let mut data_offset: Option<usize> = None;
let mut data_size: usize = 0;

while pos + 8 <= bytes.len() {
let id = &bytes[pos..pos + 4];
let size = u32::from_le_bytes([
bytes[pos + 4],
bytes[pos + 5],
bytes[pos + 6],
bytes[pos + 7],
]) as usize;
pos += 8;
match id {
b"fmt " => {
if pos + 16 > bytes.len() {
return Err(OpenAIError::invalid_request("Truncated WAV fmt chunk"));
}
audio_format = u16::from_le_bytes([bytes[pos], bytes[pos + 1]]);
channels = u16::from_le_bytes([bytes[pos + 2], bytes[pos + 3]]);
sample_rate = u32::from_le_bytes([
bytes[pos + 4],
bytes[pos + 5],
bytes[pos + 6],
bytes[pos + 7],
]);
bits_per_sample = u16::from_le_bytes([bytes[pos + 14], bytes[pos + 15]]);
pos += size;
}
b"data" => {
data_offset = Some(pos);
data_size = size;
pos += size;
}
_ => {
// Skip unknown chunks (LIST, INFO, etc).
pos += size;
}
}
if pos % 2 != 0 {
pos += 1; // RIFF chunks are word-aligned.
}
}
Comment on lines +229 to +267

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Chunk size validation could read past buffer bounds.

The loop checks pos + 8 <= bytes.len() (line 229) but doesn't validate that pos + size stays within bounds before advancing. A malformed WAV with an inflated chunk size could cause pos to overflow or wrap, potentially reading garbage or causing a panic on subsequent iterations.

Proposed fix
     while pos + 8 <= bytes.len() {
         let id = &bytes[pos..pos + 4];
         let size = u32::from_le_bytes([
             bytes[pos + 4],
             bytes[pos + 5],
             bytes[pos + 6],
             bytes[pos + 7],
         ]) as usize;
         pos += 8;
+        // Clamp size to remaining bytes to handle malformed/truncated files
+        let size = size.min(bytes.len().saturating_sub(pos));
         match id {
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
while pos + 8 <= bytes.len() {
let id = &bytes[pos..pos + 4];
let size = u32::from_le_bytes([
bytes[pos + 4],
bytes[pos + 5],
bytes[pos + 6],
bytes[pos + 7],
]) as usize;
pos += 8;
match id {
b"fmt " => {
if pos + 16 > bytes.len() {
return Err(OpenAIError::invalid_request("Truncated WAV fmt chunk"));
}
audio_format = u16::from_le_bytes([bytes[pos], bytes[pos + 1]]);
channels = u16::from_le_bytes([bytes[pos + 2], bytes[pos + 3]]);
sample_rate = u32::from_le_bytes([
bytes[pos + 4],
bytes[pos + 5],
bytes[pos + 6],
bytes[pos + 7],
]);
bits_per_sample = u16::from_le_bytes([bytes[pos + 14], bytes[pos + 15]]);
pos += size;
}
b"data" => {
data_offset = Some(pos);
data_size = size;
pos += size;
}
_ => {
// Skip unknown chunks (LIST, INFO, etc).
pos += size;
}
}
if pos % 2 != 0 {
pos += 1; // RIFF chunks are word-aligned.
}
}
while pos + 8 <= bytes.len() {
let id = &bytes[pos..pos + 4];
let size = u32::from_le_bytes([
bytes[pos + 4],
bytes[pos + 5],
bytes[pos + 6],
bytes[pos + 7],
]) as usize;
pos += 8;
// Clamp size to remaining bytes to handle malformed/truncated files
let size = size.min(bytes.len().saturating_sub(pos));
match id {
b"fmt " => {
if pos + 16 > bytes.len() {
return Err(OpenAIError::invalid_request("Truncated WAV fmt chunk"));
}
audio_format = u16::from_le_bytes([bytes[pos], bytes[pos + 1]]);
channels = u16::from_le_bytes([bytes[pos + 2], bytes[pos + 3]]);
sample_rate = u32::from_le_bytes([
bytes[pos + 4],
bytes[pos + 5],
bytes[pos + 6],
bytes[pos + 7],
]);
bits_per_sample = u16::from_le_bytes([bytes[pos + 14], bytes[pos + 15]]);
pos += size;
}
b"data" => {
data_offset = Some(pos);
data_size = size;
pos += size;
}
_ => {
// Skip unknown chunks (LIST, INFO, etc).
pos += size;
}
}
if pos % 2 != 0 {
pos += 1; // RIFF chunks are word-aligned.
}
}
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rust-executor/src/api/openai_compat/audio.rs` around lines 229 - 267, The WAV
chunk parsing loop validates that the header can be read with `pos + 8 <=
bytes.len()` but fails to validate that the entire chunk data fits within the
buffer before advancing `pos` by the chunk size. After reading the chunk size
value with `u32::from_le_bytes`, add a bounds check to ensure `pos + size <=
bytes.len()` before processing any chunk data. This check should occur
immediately after parsing the size and before the match statement, so that any
chunk with an inflated size is rejected early rather than allowing `pos` to
advance beyond the buffer bounds, which could cause panics or undefined behavior
on subsequent iterations.


let data_offset =
data_offset.ok_or_else(|| OpenAIError::invalid_request("WAV file has no data chunk"))?;
if audio_format != 1 {
return Err(OpenAIError::invalid_request(format!(
"Only uncompressed PCM WAV is supported (audio_format = {audio_format}). \
Re-encode as PCM and retry."
)));
}
if channels != 1 {
return Err(OpenAIError::invalid_request(format!(
"Only mono WAV is supported (channels = {channels}). Downmix to mono."
)));
}
if sample_rate != 16_000 {
return Err(OpenAIError::invalid_request(format!(
"Only 16 kHz WAV is supported (sample_rate = {sample_rate}). Resample to 16 kHz."
)));
}
if bits_per_sample != 16 {
return Err(OpenAIError::invalid_request(format!(
"Only 16-bit PCM WAV is supported (bits_per_sample = {bits_per_sample})."
)));
}

let payload = &bytes[data_offset..data_offset + data_size.min(bytes.len() - data_offset)];
let mut samples = Vec::with_capacity(payload.len() / 2);
for chunk in payload.chunks_exact(2) {
let s = i16::from_le_bytes([chunk[0], chunk[1]]);
samples.push((s as f32) / (i16::MAX as f32));
}
Ok(samples)
}
Loading