diff --git a/Cargo.lock b/Cargo.lock index 1ca24b8a4..11d939864 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1197,6 +1197,7 @@ dependencies = [ "matchit 0.8.4", "memchr", "mime", + "multer 3.1.0", "percent-encoding", "pin-project-lite", "rustversion", diff --git a/rust-executor/Cargo.toml b/rust-executor/Cargo.toml index c654e3fe1..26b35d20c 100644 --- a/rust-executor/Cargo.toml +++ b/rust-executor/Cargo.toml @@ -68,7 +68,7 @@ argon2 = { version = "0.5.0", features = ["simple"] } rand = "0.8.5" base64 = "0.21.0" rmcp = { version = "0.15.0", features = ["server", "transport-streamable-http-server"] } -axum = { version = "0.8", features = ["ws"] } +axum = { version = "0.8", features = ["ws", "multipart"] } axum-server = { version = "0.7", features = ["tls-rustls"] } tower-http = { version = "0.6", features = ["cors", "set-header", "catch-panic"] } schemars = "1.0" diff --git a/rust-executor/src/ai_service/mod.rs b/rust-executor/src/ai_service/mod.rs index e67f9111a..33f602bc8 100644 --- a/rust-executor/src/ai_service/mod.rs +++ b/rust-executor/src/ai_service/mod.rs @@ -34,6 +34,7 @@ pub type Result = std::result::Result; /// Result of an LLM prompt call, with token counts for billing. /// Token counts are estimated (chars/4) with the current Kalosum backend. /// When Ollama is integrated, these will be exact from the API response. +#[derive(Debug, Clone)] pub struct PromptResult { pub text: String, pub prompt_tokens: usize, @@ -137,11 +138,34 @@ struct LLMTaskShutdownRequest { pub result_sender: oneshot::Sender<()>, } +/// Streaming-prompt variant for the OpenAI-compatible chat endpoint. +/// +/// Identical to [`LLMTaskPromptRequest`] but instead of a oneshot reply +/// the caller receives an async stream of token chunks. Used by +/// `prompt_messages_stream` to back `POST /v1/chat/completions` with +/// `stream: true`. +/// +/// The receiver side is an `mpsc::Receiver` so each `Token` arm +/// can push without awaiting consumer backpressure; the HTTP handler +/// converts each token into one SSE `chat.completion.chunk`. +/// +/// `done_sender` fires once the model has emitted its final token (or +/// errored) and carries `PromptResult` for the closing chunk's `usage`. +#[allow(dead_code)] +#[derive(Debug)] +struct LLMTaskPromptStreamRequest { + pub task_id: String, + pub prompt: String, + pub token_sender: mpsc::UnboundedSender, + pub done_sender: oneshot::Sender>, +} + #[allow(dead_code)] #[derive(Debug)] enum LLMTaskRequest { Spawn(LLMTaskSpawnRequest), Prompt(LLMTaskPromptRequest), + PromptStream(LLMTaskPromptStreamRequest), Remove(LLMTaskRemoveRequest), Shutdown(LLMTaskShutdownRequest), } @@ -844,6 +868,144 @@ impl AIService { } }, + LLMTaskRequest::PromptStream(stream_request) => match model { + LlmModel::Remote((ref mut remote_client, ref model_string)) => { + // Remote upstreams use the non-streaming + // chat call today; we deliver the full + // response as a single token chunk so + // SSE consumers still see the "stream" + // protocol (one delta + final usage). + // A native streaming upstream client is + // a follow-up. + if let Some(task) = + task_descriptions.get(&stream_request.task_id) + { + let mut messages = vec![Message { + role: Role::System, + content: task.system_prompt.clone(), + }]; + for example in task.prompt_examples.iter() { + messages.push(Message { + role: Role::User, + content: example.input.clone(), + }); + messages.push(Message { + role: Role::Assistant, + content: example.output.clone(), + }); + } + let prompt_clone = stream_request.prompt.clone(); + messages.push(Message { + role: Role::User, + content: prompt_clone.clone(), + }); + let chat_input = ChatInput { + model: chat_gpt_lib_rs::Model::Custom( + model_string.clone(), + ), + messages, + ..Default::default() + }; + match rt.block_on(remote_client.chat(chat_input)) { + Err(e) => { + let _ = + stream_request.done_sender.send(Err(anyhow!( + "Error connecting to remote LLM API: {:?}", + e + ))); + } + Ok(response) => { + let text = response + .choices + .first() + .map(|c| c.message.content.clone()) + .unwrap_or_default(); + let _ = + stream_request.token_sender.send(text.clone()); + let prompt_tokens = + estimate_token_count(&prompt_clone); + let completion_tokens = estimate_token_count(&text); + let _ = stream_request.done_sender.send(Ok( + PromptResult { + text, + prompt_tokens, + completion_tokens, + model_id: model_config.id.clone(), + }, + )); + } + } + } else { + let _ = stream_request.done_sender.send(Err(anyhow!( + "Task with ID {} not spawned", + stream_request.task_id + ))); + } + } + LlmModel::Local(_) => { + if let Some(task) = tasks.get(&stream_request.task_id) { + rt.block_on(publish_model_status( + model_config.id.clone(), + 100.0, + "Running inference...", + true, + true, + )); + + let prompt_clone = stream_request.prompt.clone(); + let prompt_tokens = estimate_token_count(&prompt_clone); + let token_sender = stream_request.token_sender.clone(); + // Forward each token through the + // mpsc back to the SSE handler. We + // accumulate the full text so the + // closing event still carries a + // single concatenated string for + // billing + usage parity with the + // non-stream path. + let result = rt.block_on(async { + use futures::StreamExt; + // `task.run(...)` returns a + // `ChatResponseBuilder` which + // implements `Stream`; + // polling it yields one token + // chunk at a time. + let mut stream = + Box::pin(task.run(prompt_clone.clone())); + let mut accumulated = String::new(); + while let Some(token) = stream.next().await { + accumulated.push_str(&token); + if token_sender.send(token).is_err() { + // consumer dropped — stop generating + break; + } + } + accumulated + }); + + rt.block_on(publish_model_status( + model_config.id.clone(), + 100.0, + "Ready", + true, + true, + )); + + let completion_tokens = estimate_token_count(&result); + let _ = stream_request.done_sender.send(Ok(PromptResult { + text: result, + prompt_tokens, + completion_tokens, + model_id: model_config.id.clone(), + })); + } else { + let _ = stream_request.done_sender.send(Err(anyhow!( + "Task with ID {} not spawned", + stream_request.task_id + ))); + } + } + }, + LLMTaskRequest::Remove(remove_request) => { let _ = tasks.remove(&remove_request.task_id); let _ = task_descriptions.remove(&remove_request.task_id); @@ -973,6 +1135,205 @@ impl AIService { Ok(()) } + /// Build an ephemeral, in-memory task from a list of `(system, user, + /// assistant)` messages and return its `task_id`. Used by the + /// OpenAI-compatible `chat.completions` handler so each request is + /// stateless: spawn → prompt → remove. Nothing touches the DB. + /// + /// `messages` is in OpenAI order — system message(s) at the front, + /// then alternating user/assistant turns ending with a final user + /// message. The final user message is returned separately so the + /// caller passes it as the prompt; everything before it becomes the + /// task's `system_prompt` + few-shot example pairs. + fn build_ephemeral_task(model_id: &str, messages: Vec<(String, String)>) -> (AITask, String) { + // Split: leading System messages → system_prompt; alternating + // User/Assistant pairs → examples; last User message → prompt. + let mut system_prompt = String::new(); + let mut examples: Vec = Vec::new(); + let mut current_user: Option = None; + let mut final_prompt = String::new(); + + for (role, content) in &messages { + match role.as_str() { + "system" => { + if !system_prompt.is_empty() { + system_prompt.push_str("\n"); + } + system_prompt.push_str(content); + } + "user" => { + if let Some(prev_user) = current_user.take() { + // user→user with no assistant in between; treat + // earlier user as a finalised final prompt slot + // (will be overwritten if more turns follow). + final_prompt = prev_user; + } + current_user = Some(content.clone()); + } + "assistant" => { + if let Some(user_msg) = current_user.take() { + examples.push(crate::types::AIPromptExamples { + input: user_msg, + output: content.clone(), + }); + } + } + _ => { /* tool/function/developer — silently dropped */ } + } + } + if let Some(last_user) = current_user { + final_prompt = last_user; + } + + let task_id = format!("openai-compat-{}", uuid::Uuid::new_v4()); + let now = chrono::Utc::now().to_rfc3339(); + let task = AITask { + name: format!("openai-compat-{}", &task_id[..8]), + task_id: task_id.clone(), + model_id: model_id.to_string(), + system_prompt, + prompt_examples: examples, + meta_data: None, + created_at: now.clone(), + updated_at: now, + }; + (task, final_prompt) + } + + /// Run a chat-style prompt on `model_id` with a stateless message + /// list and return the full response text + token counts. Backs + /// `POST /v1/chat/completions`. + pub async fn prompt_messages( + &self, + model_id: String, + messages: Vec<(String, String)>, + ) -> Result { + let resolved = Self::replace_model_variables(&model_id)?; + let (task, final_prompt) = Self::build_ephemeral_task(&resolved, messages); + + // Spawn + prompt + remove. We use the existing per-model thread + // command channels so model state stays local to its thread and + // we don't duplicate the build-llama machinery. + let task_id = task.task_id.clone(); + let (spawn_tx, spawn_rx) = oneshot::channel(); + { + let llm_channel = self.llm_channel.lock().await; + let sender = llm_channel + .get(&resolved) + .ok_or_else(|| anyhow!("Model '{}' not found in LLM channel", resolved))?; + sender.send(LLMTaskRequest::Spawn(LLMTaskSpawnRequest { + task, + result_sender: spawn_tx, + }))?; + } + spawn_rx.await??; + + let (prompt_tx, prompt_rx) = oneshot::channel(); + { + let llm_channel = self.llm_channel.lock().await; + let sender = llm_channel + .get(&resolved) + .ok_or_else(|| anyhow!("Model '{}' not found in LLM channel", resolved))?; + sender.send(LLMTaskRequest::Prompt(LLMTaskPromptRequest { + task_id: task_id.clone(), + prompt: final_prompt.clone(), + result_sender: prompt_tx, + }))?; + } + let prompt_tokens = estimate_token_count(&final_prompt); + let text = prompt_rx.await??; + let completion_tokens = estimate_token_count(&text); + + // Clean up the ephemeral task entry on the LLM thread. + let (remove_tx, _) = oneshot::channel(); + { + let llm_channel = self.llm_channel.lock().await; + if let Some(sender) = llm_channel.get(&resolved) { + let _ = sender.send(LLMTaskRequest::Remove(LLMTaskRemoveRequest { + task_id: task_id.clone(), + result_sender: remove_tx, + })); + } + } + + Ok(PromptResult { + text, + prompt_tokens, + completion_tokens, + model_id: resolved, + }) + } + + /// Streaming variant of [`Self::prompt_messages`]. Returns a token + /// stream + a oneshot for the final [`PromptResult`] (carries the + /// full text + token counts for the closing SSE event). + pub async fn prompt_messages_stream( + &self, + model_id: String, + messages: Vec<(String, String)>, + ) -> Result<( + mpsc::UnboundedReceiver, + oneshot::Receiver>, + )> { + let resolved = Self::replace_model_variables(&model_id)?; + let (task, final_prompt) = Self::build_ephemeral_task(&resolved, messages); + + let task_id = task.task_id.clone(); + let (spawn_tx, spawn_rx) = oneshot::channel(); + { + let llm_channel = self.llm_channel.lock().await; + let sender = llm_channel + .get(&resolved) + .ok_or_else(|| anyhow!("Model '{}' not found in LLM channel", resolved))?; + sender.send(LLMTaskRequest::Spawn(LLMTaskSpawnRequest { + task, + result_sender: spawn_tx, + }))?; + } + spawn_rx.await??; + + let (token_tx, token_rx) = mpsc::unbounded_channel::(); + let (done_tx, done_rx) = oneshot::channel::>(); + { + let llm_channel = self.llm_channel.lock().await; + let sender = llm_channel + .get(&resolved) + .ok_or_else(|| anyhow!("Model '{}' not found in LLM channel", resolved))?; + sender.send(LLMTaskRequest::PromptStream(LLMTaskPromptStreamRequest { + task_id: task_id.clone(), + prompt: final_prompt, + token_sender: token_tx, + done_sender: done_tx, + }))?; + } + + // Schedule a Remove for after the prompt completes — best-effort + // background cleanup so the caller doesn't have to await it. + let llm_channel_clone = self.llm_channel.clone(); + let resolved_clone = resolved.clone(); + tokio::spawn(async move { + // Wait a little for the prompt to finish before removing the + // task entry; if the model is local, the thread holds the + // task in its `tasks` map for the duration of the call, so + // removing too early would race the inference. + // + // We don't have a reliable signal here without restructuring + // the channel; the LLM thread serialises requests anyway, so + // a Remove queued immediately will run after the + // PromptStream completes. No sleep needed. + let (remove_tx, _) = oneshot::channel(); + let llm_channel = llm_channel_clone.lock().await; + if let Some(sender) = llm_channel.get(&resolved_clone) { + let _ = sender.send(LLMTaskRequest::Remove(LLMTaskRemoveRequest { + task_id, + result_sender: remove_tx, + })); + } + }); + + Ok((token_rx, done_rx)) + } + pub async fn prompt(&self, task_id: String, prompt: String) -> Result { let (result_sender, rx) = oneshot::channel(); @@ -1472,6 +1833,55 @@ impl AIService { } } + /// One-shot transcription: open a stream, feed the whole audio + /// buffer, drain the broadcast channel until idle, close. + /// + /// Backs `POST /v1/audio/transcriptions` (batch multipart upload). + /// The caller is responsible for decoding the upload to 16 kHz mono + /// `f32` samples — this method only handles the whisper-side + /// session lifecycle and assembly of the final transcript. + pub async fn transcribe_buffer( + &self, + model_id: String, + samples: Vec, + auth_token: String, + ) -> Result { + let stream_id = self + .open_transcription_stream(model_id, None, auth_token.clone()) + .await?; + + // Subscribe to the broadcast BEFORE feeding so we don't drop any + // partial events the whisper thread emits during the first + // window. + let mut rx = self + .feed_transcription_stream_with_broadcast(&stream_id, samples, &auth_token) + .await?; + + // Drain until the receiver lags or no message arrives within a + // short idle window. Whisper emits one event per finalised + // segment; for a one-shot buffer the model usually emits 1-3 + // segments and then goes quiet. + let mut transcript = String::new(); + let idle = Duration::from_millis(2500); + loop { + match tokio::time::timeout(idle, rx.recv()).await { + Ok(Ok(text)) => { + if !transcript.is_empty() && !text.is_empty() { + transcript.push(' '); + } + transcript.push_str(&text); + } + Ok(Err(_)) => break, // sender dropped + Err(_) => break, // idle timeout + } + } + + let _ = self + .close_transcription_stream(&stream_id, &auth_token) + .await; + Ok(transcript) + } + pub async fn close_transcription_stream( &self, stream_id: &String, diff --git a/rust-executor/src/api/mod.rs b/rust-executor/src/api/mod.rs index 3a48574b0..e5bed8aa5 100644 --- a/rust-executor/src/api/mod.rs +++ b/rust-executor/src/api/mod.rs @@ -6,6 +6,7 @@ pub mod auth; pub mod errors; pub mod events_ws; +pub mod openai_compat; pub mod types; pub mod ws_rpc; @@ -99,6 +100,14 @@ pub fn api_router(state: AppState) -> Router { post(ai_ws::feed_transcription_stream), ), ) + // ── OpenAI-compatible /v1 surface ── + // + // Mounted at both `/v1` (the canonical OpenAI path) and + // `/api/v1/openai/v1` (for proxies that hard-code the `/api/v1` + // prefix from the native AD4M surface). Both share the same + // handlers + AppState. + .nest("/v1", openai_compat::router()) + .nest("/api/v1/openai/v1", openai_compat::router()) // ── State + Middleware ── .with_state(state) .layer(Extension(handler_map)) diff --git a/rust-executor/src/api/openai_compat/audio.rs b/rust-executor/src/api/openai_compat/audio.rs new file mode 100644 index 000000000..c818416e8 --- /dev/null +++ b/rust-executor/src/api/openai_compat/audio.rs @@ -0,0 +1,300 @@ +//! `POST /v1/audio/transcriptions` (batch STT) + `POST /v1/audio/speech` (TTS). +//! +//! ## Transcription (Phase 3) +//! +//! Accepts a standard OpenAI multipart upload: `file`, `model`, +//! `response_format`, `language`, `temperature`. We decode the upload to +//! 16 kHz mono `f32` samples and drive the existing Whisper pipeline in +//! one shot via [`AIService::transcribe_buffer`]. +//! +//! Audio decode is delegated to the new [`audio_decode`] helper. For +//! the first cut we accept raw 16 kHz mono PCM (wav) directly; for +//! other formats the helper returns an error pointing at the +//! `Content-Type` and the caller is expected to convert client-side +//! (matches the proposal — full container support via `symphonia` is +//! scoped to a follow-up to keep the dependency surface lean). +//! +//! ## TTS (Phase 4) +//! +//! `audio/speech` forwards the request to a remote OpenAI-compatible +//! upstream configured via a TTS-typed `Model` registration — see +//! [`super::tts_passthrough`] for the HTTP client. A local Kokoro +//! backend is in the design but not in this PR; the proposal allows +//! passthrough as a valid Phase-4 deliverable. + +use axum::{extract::Multipart, http::header, response::Response, Json}; + +use super::errors::{OpenAIError, OpenAIResult}; +use super::model_selector::resolve_model; +use super::tts_passthrough; +use super::types::{SpeechRequest, TranscriptionResponse}; +use crate::agent::capabilities::{ + check_capability, AI_PROMPT_CAPABILITY, AI_TRANSCRIBE_CAPABILITY, +}; +use crate::ai_service::AIService; +use crate::api::auth::AuthContext; +use crate::billing::{bill_compute, BillingError}; +use crate::types::ModelType; + +pub async fn transcriptions( + auth: AuthContext, + mut multipart: Multipart, +) -> OpenAIResult> { + check_capability(&auth.capabilities, &AI_TRANSCRIBE_CAPABILITY) + .map_err(OpenAIError::forbidden)?; + + let mut model_field: Option = None; + let mut audio_bytes: Option> = None; + let mut content_type: Option = None; + let mut response_format = "json".to_string(); + + while let Some(field) = multipart + .next_field() + .await + .map_err(|e| OpenAIError::invalid_request(format!("Multipart parse error: {e}")))? + { + match field.name().unwrap_or("") { + "model" => { + model_field = field + .text() + .await + .ok() + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()); + } + "file" => { + content_type = field + .content_type() + .map(|s| s.to_string()) + .or_else(|| Some("application/octet-stream".to_string())); + let bytes = field + .bytes() + .await + .map_err(|e| OpenAIError::invalid_request(format!("File read error: {e}")))?; + audio_bytes = Some(bytes.to_vec()); + } + "response_format" => { + if let Ok(v) = field.text().await { + response_format = v.trim().to_string(); + } + } + _ => { + // Other OpenAI fields (`language`, `temperature`, + // `prompt`) are not yet plumbed through to Whisper — we + // accept and drop them so callers don't 400. + let _ = field.bytes().await; + } + } + } + + let model = model_field + .ok_or_else(|| OpenAIError::invalid_request("Missing required form field: model"))?; + let bytes = audio_bytes + .ok_or_else(|| OpenAIError::invalid_request("Missing required form field: file"))?; + let model_id = resolve_model(&model, ModelType::Transcription).await?; + + let samples = audio_decode(&bytes, content_type.as_deref())?; + + let service = AIService::global_instance() + .await + .map_err(|e| OpenAIError::internal(e.to_string()))?; + let text = service + .transcribe_buffer(model_id, samples, auth.auth_token.clone()) + .await + .map_err(|e| OpenAIError::internal(e.to_string()))?; + + if let Some(email) = crate::agent::capabilities::user_email_from_token(auth.auth_token.clone()) + { + // Mirror native transcription billing: 1 credit per request + // (the native path bills per-word but only after `delta` events + // — batch jobs round to a flat charge until per-word accounting + // is plumbed through). + if let Err(BillingError::InsufficientCredits) = bill_compute( + &email, + 1.0, + "ai_transcription", + Some("v1/audio/transcriptions"), + ) { + return Err(OpenAIError::insufficient_quota( + "Insufficient compute credits", + )); + } + } + + // `response_format` honoured: `json` (default) returns the envelope + // below. `text` / `verbose_json` / `srt` / `vtt` are accepted for + // request-compatibility but currently always produce the `json` + // envelope; widening the return to a raw `Response` so we can emit + // text/srt/vtt directly is a follow-up that touches the router's + // type signature. + let _ = response_format; + Ok(Json(TranscriptionResponse { text })) +} + +pub async fn speech( + auth: AuthContext, + Json(req): Json, +) -> Result { + check_capability(&auth.capabilities, &AI_PROMPT_CAPABILITY).map_err(OpenAIError::forbidden)?; + + // TTS isn't a registered ModelType yet (the proposal calls for a new + // `ModelType::TextToSpeech` in a follow-up); for now we accept the + // model name verbatim and forward to the configured passthrough + // upstream. See `tts_passthrough` for the configuration shape. + let response_format = req + .response_format + .clone() + .unwrap_or_else(|| "mp3".to_string()); + let audio = tts_passthrough::synthesize(&auth, &req) + .await + .map_err(|e| match e { + tts_passthrough::TtsPassthroughError::NotConfigured => OpenAIError::not_implemented( + "No TTS backend is configured on this executor. \ + Register a TTS upstream via the AD4M config or wait for the local Kokoro backend.", + ), + tts_passthrough::TtsPassthroughError::Upstream(msg) => { + OpenAIError::internal(format!("Upstream TTS error: {msg}")) + } + })?; + + if let Some(email) = crate::agent::capabilities::user_email_from_token(auth.auth_token.clone()) + { + if let Err(BillingError::InsufficientCredits) = + bill_compute(&email, 1.0, "ai_tts", Some("v1/audio/speech")) + { + return Err(OpenAIError::insufficient_quota( + "Insufficient compute credits", + )); + } + } + + let content_type = match response_format.as_str() { + "mp3" => "audio/mpeg", + "wav" => "audio/wav", + "opus" => "audio/opus", + "flac" => "audio/flac", + "pcm" => "application/octet-stream", + _ => "application/octet-stream", + }; + Response::builder() + .status(200) + .header(header::CONTENT_TYPE, content_type) + .body(audio.into()) + .map_err(|e| OpenAIError::internal(format!("Response build error: {e}"))) +} + +/// Decode an uploaded audio file to 16 kHz mono `f32` samples. +/// +/// The first-cut implementation supports raw WAV (16-bit PCM, mono, +/// 16 kHz) — the simplest case that requires no external dependencies +/// and matches what the existing transcription WS feed already expects. +/// For other containers, return a descriptive 400 so the client can +/// transcode locally. +/// +/// Full `symphonia` / `rubato` decode is a follow-up: it's the right +/// home for that work, but adding two heavy crates in the same PR as +/// the rest of the OpenAI surface would dominate the diff for a single +/// endpoint. +fn audio_decode(bytes: &[u8], content_type: Option<&str>) -> OpenAIResult> { + // Quick WAV header sniff — RIFF....WAVEfmt + if bytes.len() >= 44 && &bytes[0..4] == b"RIFF" && &bytes[8..12] == b"WAVE" { + return decode_pcm_wav(bytes); + } + Err(OpenAIError::invalid_request(format!( + "Unsupported audio format (content-type: {}). \ + The current /v1/audio/transcriptions endpoint accepts 16 kHz mono PCM WAV. \ + Full container decode (mp3/m4a/flac/ogg) is on the roadmap; \ + transcode client-side as a workaround.", + content_type.unwrap_or("(unknown)"), + ))) +} + +/// Minimal WAV decoder: 16-bit PCM, mono, 16 kHz. +/// +/// We don't attempt to handle every WAV variant — the goal is to clear +/// the most common upload path (whisper's reference format) without +/// pulling in `symphonia`. Mismatched format → 400 with a clear hint. +fn decode_pcm_wav(bytes: &[u8]) -> OpenAIResult> { + // Find the fmt and data chunks. WAV is a RIFF container: after + // the 12-byte RIFF header, chunks alternate {id (4 bytes), size + // (u32 LE), payload}. + let mut pos = 12; + let mut audio_format: u16 = 0; + let mut channels: u16 = 0; + let mut sample_rate: u32 = 0; + let mut bits_per_sample: u16 = 0; + let mut data_offset: Option = None; + let mut data_size: usize = 0; + + while pos + 8 <= bytes.len() { + let id = &bytes[pos..pos + 4]; + let size = u32::from_le_bytes([ + bytes[pos + 4], + bytes[pos + 5], + bytes[pos + 6], + bytes[pos + 7], + ]) as usize; + pos += 8; + match id { + b"fmt " => { + if pos + 16 > bytes.len() { + return Err(OpenAIError::invalid_request("Truncated WAV fmt chunk")); + } + audio_format = u16::from_le_bytes([bytes[pos], bytes[pos + 1]]); + channels = u16::from_le_bytes([bytes[pos + 2], bytes[pos + 3]]); + sample_rate = u32::from_le_bytes([ + bytes[pos + 4], + bytes[pos + 5], + bytes[pos + 6], + bytes[pos + 7], + ]); + bits_per_sample = u16::from_le_bytes([bytes[pos + 14], bytes[pos + 15]]); + pos += size; + } + b"data" => { + data_offset = Some(pos); + data_size = size; + pos += size; + } + _ => { + // Skip unknown chunks (LIST, INFO, etc). + pos += size; + } + } + if pos % 2 != 0 { + pos += 1; // RIFF chunks are word-aligned. + } + } + + let data_offset = + data_offset.ok_or_else(|| OpenAIError::invalid_request("WAV file has no data chunk"))?; + if audio_format != 1 { + return Err(OpenAIError::invalid_request(format!( + "Only uncompressed PCM WAV is supported (audio_format = {audio_format}). \ + Re-encode as PCM and retry." + ))); + } + if channels != 1 { + return Err(OpenAIError::invalid_request(format!( + "Only mono WAV is supported (channels = {channels}). Downmix to mono." + ))); + } + if sample_rate != 16_000 { + return Err(OpenAIError::invalid_request(format!( + "Only 16 kHz WAV is supported (sample_rate = {sample_rate}). Resample to 16 kHz." + ))); + } + if bits_per_sample != 16 { + return Err(OpenAIError::invalid_request(format!( + "Only 16-bit PCM WAV is supported (bits_per_sample = {bits_per_sample})." + ))); + } + + let payload = &bytes[data_offset..data_offset + data_size.min(bytes.len() - data_offset)]; + let mut samples = Vec::with_capacity(payload.len() / 2); + for chunk in payload.chunks_exact(2) { + let s = i16::from_le_bytes([chunk[0], chunk[1]]); + samples.push((s as f32) / (i16::MAX as f32)); + } + Ok(samples) +} diff --git a/rust-executor/src/api/openai_compat/chat.rs b/rust-executor/src/api/openai_compat/chat.rs new file mode 100644 index 000000000..34388653e --- /dev/null +++ b/rust-executor/src/api/openai_compat/chat.rs @@ -0,0 +1,294 @@ +//! `POST /v1/chat/completions` and `POST /v1/completions`. +//! +//! Both endpoints translate an OpenAI request into an ephemeral +//! `AIService::prompt_messages{,_stream}` call. No DB-backed task is +//! created; the model thread spawns the task in-memory for the duration +//! of the call. + +use std::convert::Infallible; +use std::time::SystemTime; + +use axum::{ + response::{sse::Event, IntoResponse, Sse}, + Json, +}; +use futures::Stream; +use uuid::Uuid; + +use super::errors::{OpenAIError, OpenAIResult}; +use super::model_selector::resolve_model; +use super::types::{ + ChatChoice, ChatChunkChoice, ChatChunkDelta, ChatCompletionChunk, ChatCompletionRequest, + ChatCompletionResponse, ChatResponseMessage, CompletionChoice, CompletionRequest, + CompletionResponse, Role, Usage, +}; +use crate::agent::capabilities::{check_capability, AI_PROMPT_CAPABILITY}; +use crate::ai_service::AIService; +use crate::api::auth::AuthContext; +use crate::billing::{bill_compute, BillingError}; +use crate::types::ModelType; + +/// `POST /v1/chat/completions` — handles both streaming (`stream: true`) +/// and non-streaming responses. +pub async fn chat_completions( + auth: AuthContext, + Json(req): Json, +) -> Result { + check_capability(&auth.capabilities, &AI_PROMPT_CAPABILITY).map_err(OpenAIError::forbidden)?; + + // Resolve the OpenAI `model` string to an AD4M model_id. + let model_id = resolve_model(&req.model, ModelType::Llm).await?; + + // Flatten messages to (role, content) pairs. + let messages: Vec<(String, String)> = req + .messages + .iter() + .map(|m| { + ( + role_to_str(&m.role).to_string(), + m.content.flatten_to_text(), + ) + }) + .collect(); + + if req.stream { + chat_stream(auth, req.model.clone(), model_id, messages).await + } else { + chat_oneshot(auth, req.model.clone(), model_id, messages).await + } +} + +/// `POST /v1/completions` (legacy text-completion). Treats `prompt` as a +/// single user message with no system prompt. +pub async fn completions( + auth: AuthContext, + Json(req): Json, +) -> OpenAIResult> { + check_capability(&auth.capabilities, &AI_PROMPT_CAPABILITY).map_err(OpenAIError::forbidden)?; + + let model_id = resolve_model(&req.model, ModelType::Llm).await?; + let messages = vec![("user".to_string(), req.prompt.first())]; + + let service = AIService::global_instance() + .await + .map_err(|e| OpenAIError::internal(e.to_string()))?; + let result = service + .prompt_messages(model_id, messages) + .await + .map_err(|e| OpenAIError::internal(e.to_string()))?; + + if let Some(email) = user_email(&auth) { + if let Err(BillingError::InsufficientCredits) = + bill_compute(&email, 1.0, "ai_prompt", Some("v1/completions")) + { + return Err(OpenAIError::insufficient_quota( + "Insufficient compute credits", + )); + } + } + + Ok(Json(CompletionResponse { + id: format!("cmpl-{}", Uuid::new_v4()), + object: "text_completion", + created: epoch_seconds(), + model: req.model, + choices: vec![CompletionChoice { + index: 0, + text: result.text, + finish_reason: "stop", + }], + usage: Usage { + prompt_tokens: result.prompt_tokens as u64, + completion_tokens: result.completion_tokens as u64, + total_tokens: (result.prompt_tokens + result.completion_tokens) as u64, + }, + })) +} + +async fn chat_oneshot( + auth: AuthContext, + requested_model: String, + model_id: String, + messages: Vec<(String, String)>, +) -> Result { + let service = AIService::global_instance() + .await + .map_err(|e| OpenAIError::internal(e.to_string()))?; + let result = service + .prompt_messages(model_id, messages) + .await + .map_err(|e| OpenAIError::internal(e.to_string()))?; + + if let Some(email) = user_email(&auth) { + if let Err(BillingError::InsufficientCredits) = + bill_compute(&email, 1.0, "ai_prompt", Some("v1/chat/completions")) + { + return Err(OpenAIError::insufficient_quota( + "Insufficient compute credits", + )); + } + } + + let body = ChatCompletionResponse { + id: format!("chatcmpl-{}", Uuid::new_v4()), + object: "chat.completion", + created: epoch_seconds(), + model: requested_model, + choices: vec![ChatChoice { + index: 0, + message: ChatResponseMessage { + role: "assistant", + content: result.text, + }, + finish_reason: "stop", + }], + usage: Usage { + prompt_tokens: result.prompt_tokens as u64, + completion_tokens: result.completion_tokens as u64, + total_tokens: (result.prompt_tokens + result.completion_tokens) as u64, + }, + }; + Ok(Json(body).into_response()) +} + +async fn chat_stream( + auth: AuthContext, + requested_model: String, + model_id: String, + messages: Vec<(String, String)>, +) -> Result { + let service = AIService::global_instance() + .await + .map_err(|e| OpenAIError::internal(e.to_string()))?; + let (token_rx, done_rx) = service + .prompt_messages_stream(model_id, messages) + .await + .map_err(|e| OpenAIError::internal(e.to_string()))?; + + let id = format!("chatcmpl-{}", Uuid::new_v4()); + let created = epoch_seconds(); + let stream_model = requested_model.clone(); + + // We construct the SSE stream by spawning a forwarder task that + // drains tokens from the LLM thread + emits one `Event` per chunk + // into a bounded channel. axum's `Sse` consumes the resulting + // receiver. This avoids pulling in `async-stream` for the sole + // sake of one `yield`-style generator. + let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::>(); + let auth_clone = auth.clone(); + + // Initial role event. + let role_chunk = ChatCompletionChunk { + id: id.clone(), + object: "chat.completion.chunk", + created, + model: stream_model.clone(), + choices: vec![ChatChunkChoice { + index: 0, + delta: ChatChunkDelta { + role: Some("assistant"), + content: None, + }, + finish_reason: None, + }], + }; + let _ = event_tx.send(Ok( + Event::default().data(serde_json::to_string(&role_chunk).unwrap()) + )); + + tokio::spawn({ + let event_tx = event_tx.clone(); + let stream_model = stream_model.clone(); + let id = id.clone(); + async move { + let mut token_rx = token_rx; + + while let Some(token) = token_rx.recv().await { + let chunk = ChatCompletionChunk { + id: id.clone(), + object: "chat.completion.chunk", + created, + model: stream_model.clone(), + choices: vec![ChatChunkChoice { + index: 0, + delta: ChatChunkDelta { + role: None, + content: Some(token), + }, + finish_reason: None, + }], + }; + if event_tx + .send(Ok( + Event::default().data(serde_json::to_string(&chunk).unwrap()) + )) + .is_err() + { + return; + } + } + + // Final event with finish_reason. + let final_chunk = ChatCompletionChunk { + id: id.clone(), + object: "chat.completion.chunk", + created, + model: stream_model.clone(), + choices: vec![ChatChunkChoice { + index: 0, + delta: ChatChunkDelta::default(), + finish_reason: Some("stop"), + }], + }; + let _ = event_tx.send(Ok( + Event::default().data(serde_json::to_string(&final_chunk).unwrap()) + )); + + // Billing — best-effort, charged once per completed stream. + // Per-token billing requires tokenizer-exact counts which the + // Kalosm backend doesn't expose today; a flat charge matches + // the non-stream `chat.completions` policy. + if let Some(email) = user_email(&auth_clone) { + let _ = bill_compute( + &email, + 1.0, + "ai_prompt", + Some("v1/chat/completions[stream]"), + ); + } + let _ = done_rx.await; + + // OpenAI SSE terminator. + let _ = event_tx.send(Ok(Event::default().data("[DONE]"))); + } + }); + + let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(event_rx); + Ok(Sse::new(stream).into_response()) +} + +fn role_to_str(role: &Role) -> &'static str { + match role { + Role::System => "system", + Role::User => "user", + Role::Assistant => "assistant", + Role::Tool => "tool", + Role::Function => "function", + Role::Developer => "developer", + } +} + +fn epoch_seconds() -> i64 { + SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .map(|d| d.as_secs() as i64) + .unwrap_or(0) +} + +fn user_email(auth: &AuthContext) -> Option { + crate::agent::capabilities::user_email_from_token(auth.auth_token.clone()) +} + +// Re-export a Stream alias for documentation purposes. +#[allow(dead_code)] +type EventStream = dyn Stream> + Send; diff --git a/rust-executor/src/api/openai_compat/embeddings.rs b/rust-executor/src/api/openai_compat/embeddings.rs new file mode 100644 index 000000000..850b71728 --- /dev/null +++ b/rust-executor/src/api/openai_compat/embeddings.rs @@ -0,0 +1,73 @@ +//! `POST /v1/embeddings`. +//! +//! Returns raw `Vec` arrays per the OpenAI spec — the zlib+b64 +//! wire format used by the native WS `ai.embed` path stays exclusive to +//! AD4M-native clients that consume it for bandwidth reasons. External +//! SDKs (`openai`, LangChain, …) expect plain JSON numbers and that's +//! what they get here. + +use axum::Json; + +use super::errors::{OpenAIError, OpenAIResult}; +use super::model_selector::resolve_model; +use super::types::{EmbeddingItem, EmbeddingRequest, EmbeddingResponse, EmbeddingUsage}; +use crate::agent::capabilities::{check_capability, AI_PROMPT_CAPABILITY}; +use crate::ai_service::AIService; +use crate::api::auth::AuthContext; +use crate::billing::{bill_compute, BillingError}; +use crate::types::ModelType; + +pub async fn embeddings( + auth: AuthContext, + Json(req): Json, +) -> OpenAIResult> { + check_capability(&auth.capabilities, &AI_PROMPT_CAPABILITY).map_err(OpenAIError::forbidden)?; + + let model_id = resolve_model(&req.model, ModelType::Embedding).await?; + let model_id_response = req.model.clone(); + let inputs = req.input.into_vec(); + + let service = AIService::global_instance() + .await + .map_err(|e| OpenAIError::internal(e.to_string()))?; + + let mut data: Vec = Vec::with_capacity(inputs.len()); + let mut total_tokens: u64 = 0; + + for (index, text) in inputs.into_iter().enumerate() { + let result = service + .embed(model_id.clone(), text) + .await + .map_err(|e| OpenAIError::internal(e.to_string()))?; + total_tokens += result.token_count as u64; + data.push(EmbeddingItem { + object: "embedding", + index, + embedding: result.embeddings, + }); + } + + // Billing — one credit per batch (matches the native ai.embed + // operation label). Per-token billing on embeddings is future work + // pending tokenizer-exact counts. + if let Some(email) = crate::agent::capabilities::user_email_from_token(auth.auth_token.clone()) + { + if let Err(BillingError::InsufficientCredits) = + bill_compute(&email, 1.0, "ai_embedding", Some("v1/embeddings")) + { + return Err(OpenAIError::insufficient_quota( + "Insufficient compute credits", + )); + } + } + + Ok(Json(EmbeddingResponse { + object: "list", + data, + model: model_id_response, + usage: EmbeddingUsage { + prompt_tokens: total_tokens, + total_tokens, + }, + })) +} diff --git a/rust-executor/src/api/openai_compat/errors.rs b/rust-executor/src/api/openai_compat/errors.rs new file mode 100644 index 000000000..e2b303012 --- /dev/null +++ b/rust-executor/src/api/openai_compat/errors.rs @@ -0,0 +1,128 @@ +//! OpenAI-shaped error envelope. +//! +//! Every fallible handler in this module returns [`OpenAIError`] so the +//! response always matches the spec: +//! +//! ```json +//! { "error": { "message": "...", "type": "...", "param": null, "code": "..." } } +//! ``` + +use axum::{ + http::StatusCode, + response::{IntoResponse, Response}, + Json, +}; +use serde::Serialize; + +#[derive(Debug, Serialize)] +pub struct OpenAIError { + #[serde(skip)] + pub status: StatusCode, + pub error: OpenAIErrorBody, +} + +#[derive(Debug, Serialize)] +pub struct OpenAIErrorBody { + pub message: String, + #[serde(rename = "type")] + pub error_type: String, + pub param: Option, + pub code: Option, +} + +impl OpenAIError { + pub fn new( + status: StatusCode, + message: impl Into, + error_type: impl Into, + code: Option<&str>, + ) -> Self { + Self { + status, + error: OpenAIErrorBody { + message: message.into(), + error_type: error_type.into(), + param: None, + code: code.map(|s| s.to_string()), + }, + } + } + + pub fn invalid_request(message: impl Into) -> Self { + Self::new( + StatusCode::BAD_REQUEST, + message, + "invalid_request_error", + Some("invalid_request"), + ) + } + + pub fn unauthorized(message: impl Into) -> Self { + Self::new( + StatusCode::UNAUTHORIZED, + message, + "invalid_request_error", + Some("invalid_api_key"), + ) + } + + pub fn forbidden(message: impl Into) -> Self { + Self::new( + StatusCode::FORBIDDEN, + message, + "invalid_request_error", + Some("insufficient_permissions"), + ) + } + + pub fn insufficient_quota(message: impl Into) -> Self { + // OpenAI returns 429 for quota exhaustion; we follow suit so + // libraries that retry on 429 do the right thing. + Self::new( + StatusCode::TOO_MANY_REQUESTS, + message, + "insufficient_quota", + Some("insufficient_quota"), + ) + } + + pub fn not_found(message: impl Into) -> Self { + Self::new( + StatusCode::NOT_FOUND, + message, + "invalid_request_error", + Some("not_found"), + ) + } + + pub fn internal(message: impl Into) -> Self { + Self::new( + StatusCode::INTERNAL_SERVER_ERROR, + message, + "server_error", + None, + ) + } + + pub fn not_implemented(message: impl Into) -> Self { + Self::new( + StatusCode::NOT_IMPLEMENTED, + message, + "server_error", + Some("not_implemented"), + ) + } +} + +impl IntoResponse for OpenAIError { + fn into_response(self) -> Response { + ( + self.status, + Json(serde_json::json!({ "error": self.error })), + ) + .into_response() + } +} + +/// Result alias used by all handler functions in this module. +pub type OpenAIResult = Result; diff --git a/rust-executor/src/api/openai_compat/mod.rs b/rust-executor/src/api/openai_compat/mod.rs new file mode 100644 index 000000000..1f01a0787 --- /dev/null +++ b/rust-executor/src/api/openai_compat/mod.rs @@ -0,0 +1,36 @@ +//! OpenAI-compatible HTTP/WS API surface mounted at `/v1`. +//! +//! Translates the OpenAI JSON schema to and from the existing +//! [`crate::ai_service::AIService`]. The native WS-RPC `ai.*` methods are +//! unchanged; this module is purely additive so existing clients keep +//! working. +//! +//! Surface covered (matches the proposal under +//! `~/workspaces/coasys/.specs/PROPOSAL_AI_OPENAI_COMPATIBLE_ENDPOINT.md`): +//! +//! * `GET /v1/models` — list registered models +//! * `POST /v1/chat/completions` — stateless chat, with optional SSE streaming +//! * `POST /v1/completions` — legacy text-completion shim +//! * `POST /v1/embeddings` — raw `f32` arrays (no zlib+b64) +//! * `POST /v1/audio/transcriptions` — batch multipart whisper +//! * `POST /v1/audio/speech` — TTS (remote OpenAI-compatible +//! passthrough; local TTS backend is +//! scoped to a follow-up) +//! * `GET /v1/realtime` — WS, OpenAI Realtime-style streaming STT +//! +//! Auth + billing reuse the existing JWT extractor and `bill_compute`. +//! Errors are wrapped in the canonical +//! `{ "error": { message, type, param, code } }` envelope. + +pub mod audio; +pub mod chat; +pub mod embeddings; +pub mod errors; +pub mod model_selector; +pub mod models; +pub mod realtime; +pub mod router; +pub mod tts_passthrough; +pub mod types; + +pub use router::router; diff --git a/rust-executor/src/api/openai_compat/model_selector.rs b/rust-executor/src/api/openai_compat/model_selector.rs new file mode 100644 index 000000000..8ddd6b8d0 --- /dev/null +++ b/rust-executor/src/api/openai_compat/model_selector.rs @@ -0,0 +1,78 @@ +//! Resolve an OpenAI-style `model` string to a registered AD4M model id. +//! +//! Order of resolution (matches the spec §4.4): +//! +//! 1. Exact match against any registered model's `id`. +//! 2. Exact match against any registered model's `name`. +//! 3. The literal `"default"` — returns the configured default for the +//! requested [`ModelType`]. +//! +//! When nothing matches we return a 404 so the client sees the OpenAI +//! `model_not_found` shape. + +use super::errors::{OpenAIError, OpenAIResult}; +use crate::db::Ad4mDb; +use crate::types::ModelType; + +pub async fn resolve_model(requested: &str, expected_type: ModelType) -> OpenAIResult { + if requested.is_empty() { + return Err(OpenAIError::invalid_request( + "Missing required parameter: model", + )); + } + + // Resolve `default` via the DB's per-type pointer. This mirrors + // `AIService::replace_model_variables` so OpenAI callers can use the + // same shorthand as native AD4M ones. + if requested == "default" { + return Ad4mDb::with_global_instance(|db| db.get_default_model(expected_type.clone())) + .map_err(|e| OpenAIError::internal(format!("Database error: {e}")))? + .ok_or_else(|| { + OpenAIError::not_found(format!( + "No default {} model configured on this executor", + type_label(&expected_type) + )) + }); + } + + let models = Ad4mDb::with_global_instance(|db| db.get_models()) + .map_err(|e| OpenAIError::internal(format!("Database error listing models: {e}")))?; + + // Pass 1: exact id match. + if let Some(m) = models.iter().find(|m| m.id == requested) { + if m.model_type != expected_type { + return Err(OpenAIError::invalid_request(format!( + "Model '{}' is a {} model; this endpoint requires a {}.", + requested, + type_label(&m.model_type), + type_label(&expected_type), + ))); + } + return Ok(m.id.clone()); + } + + // Pass 2: exact name match (case-sensitive — same as OpenAI). + if let Some(m) = models.iter().find(|m| m.name == requested) { + if m.model_type != expected_type { + return Err(OpenAIError::invalid_request(format!( + "Model '{}' is a {} model; this endpoint requires a {}.", + requested, + type_label(&m.model_type), + type_label(&expected_type), + ))); + } + return Ok(m.id.clone()); + } + + Err(OpenAIError::not_found(format!( + "Model '{requested}' not found. List registered models with GET /v1/models." + ))) +} + +fn type_label(t: &ModelType) -> &'static str { + match t { + ModelType::Llm => "LLM", + ModelType::Embedding => "embedding", + ModelType::Transcription => "transcription", + } +} diff --git a/rust-executor/src/api/openai_compat/models.rs b/rust-executor/src/api/openai_compat/models.rs new file mode 100644 index 000000000..3ab155388 --- /dev/null +++ b/rust-executor/src/api/openai_compat/models.rs @@ -0,0 +1,48 @@ +//! `GET /v1/models` — list every model registered with the executor. + +use axum::Json; + +use super::errors::{OpenAIError, OpenAIResult}; +use super::types::{ModelExtensions, ModelInfo, ModelListResponse}; +use crate::agent::capabilities::{check_capability, AI_READ_CAPABILITY}; +use crate::api::auth::AuthContext; +use crate::db::Ad4mDb; +use crate::types::Model; + +pub async fn list_models(auth: AuthContext) -> OpenAIResult> { + check_capability(&auth.capabilities, &AI_READ_CAPABILITY).map_err(OpenAIError::forbidden)?; + + let models = Ad4mDb::with_global_instance(|db| db.get_models()) + .map_err(|e| OpenAIError::internal(format!("Database error listing models: {e}")))?; + + let data = models.into_iter().map(model_to_info).collect(); + Ok(Json(ModelListResponse { + object: "list", + data, + })) +} + +/// Map an AD4M [`Model`] to an OpenAI-shaped [`ModelInfo`]. Local vs remote +/// backends are surfaced via the `ad4m.backend` extension so callers that +/// want to prefer local inference can filter on it. +fn model_to_info(m: Model) -> ModelInfo { + let backend = if m.api.is_some() { + Some("remote".to_string()) + } else if m.local.is_some() { + Some("local".to_string()) + } else { + None + }; + let model_type = format!("{}", m.model_type).to_lowercase(); + ModelInfo { + id: m.id, + object: "model", + created: 0, + owned_by: "ad4m", + extensions: ModelExtensions { + model_type, + name: m.name, + backend, + }, + } +} diff --git a/rust-executor/src/api/openai_compat/realtime.rs b/rust-executor/src/api/openai_compat/realtime.rs new file mode 100644 index 000000000..4a7d171a3 --- /dev/null +++ b/rust-executor/src/api/openai_compat/realtime.rs @@ -0,0 +1,251 @@ +//! `GET /v1/realtime` — OpenAI Realtime-style WebSocket for streaming +//! transcription. +//! +//! Protocol (subset of the OpenAI Realtime spec): +//! +//! * **Client → server** +//! - `{"type":"transcription_session.update","session":{"model":""}}` — +//! opens a Whisper session. Must arrive before any audio. +//! - `{"type":"input_audio_buffer.append","audio":""}` — +//! 16 kHz mono `f32` LE samples, base64-encoded. +//! - `{"type":"input_audio_buffer.commit"}` — flush any pending +//! segment. Optional; idle timeout flushes too. +//! - `{"type":"session.close"}` — close the session and the socket. +//! +//! * **Server → client** +//! - `{"type":"transcription_session.created","session_id":""}` +//! - `{"type":"conversation.item.input_audio_transcription.delta","delta":""}` +//! - `{"type":"conversation.item.input_audio_transcription.completed","transcript":""}` +//! emitted on each finalised whisper segment (the existing engine +//! emits one event per segment via the broadcast channel). +//! - `{"type":"error","error":{"message":"...","code":"..."}}` +//! +//! Auth: the WS extractor reads the bearer JWT from the +//! `Authorization` header *or* a `?token=` query param (per +//! `AuthContext`). Capabilities checked: `AI_TRANSCRIBE`. + +use axum::{ + extract::ws::{Message, WebSocket, WebSocketUpgrade}, + response::Response, +}; +use base64::Engine; +use futures::SinkExt; +use serde_json::{json, Value}; + +use super::model_selector::resolve_model; +use crate::agent::capabilities::{check_capability, AI_TRANSCRIBE_CAPABILITY}; +use crate::ai_service::AIService; +use crate::api::auth::AuthContext; +use crate::types::ModelType; + +pub async fn realtime_ws(auth: AuthContext, ws: WebSocketUpgrade) -> Response { + ws.on_upgrade(move |socket| handle_socket(auth, socket)) +} + +async fn handle_socket(auth: AuthContext, mut socket: WebSocket) { + // Capability check happens after upgrade — there's no clean way to + // reject a 403 mid-handshake under axum's `on_upgrade`, so we emit + // an `error` envelope and close. + if let Err(e) = check_capability(&auth.capabilities, &AI_TRANSCRIBE_CAPABILITY) { + let _ = send_error(&mut socket, "forbidden", &e).await; + let _ = socket.close().await; + return; + } + + let mut session_state: Option = None; + + while let Some(msg) = socket.recv().await { + let msg = match msg { + Ok(m) => m, + Err(_) => break, + }; + let text = match msg { + Message::Text(t) => t.to_string(), + Message::Binary(_) => { + let _ = send_error( + &mut socket, + "invalid_request", + "Binary frames not supported", + ) + .await; + continue; + } + Message::Close(_) => break, + // Ping/pong handled by axum automatically. + _ => continue, + }; + let parsed: Value = match serde_json::from_str(&text) { + Ok(v) => v, + Err(e) => { + let _ = + send_error(&mut socket, "invalid_request", &format!("JSON parse: {e}")).await; + continue; + } + }; + let event_type = parsed + .get("type") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + + match event_type.as_str() { + "transcription_session.update" => { + let model_str = parsed + .get("session") + .and_then(|s| s.get("model")) + .and_then(|m| m.as_str()) + .unwrap_or("default"); + let model_id = match resolve_model(model_str, ModelType::Transcription).await { + Ok(id) => id, + Err(e) => { + let _ = send_error(&mut socket, "invalid_request", &e.error.message).await; + continue; + } + }; + let service = match AIService::global_instance().await { + Ok(s) => s, + Err(e) => { + let _ = send_error(&mut socket, "server_error", &e.to_string()).await; + continue; + } + }; + let stream_id = match service + .open_transcription_stream(model_id.clone(), None, auth.auth_token.clone()) + .await + { + Ok(id) => id, + Err(e) => { + let _ = send_error(&mut socket, "server_error", &e.to_string()).await; + continue; + } + }; + let created = json!({ + "type": "transcription_session.created", + "session_id": stream_id, + "model": model_id, + }); + let _ = socket.send(Message::Text(created.to_string().into())).await; + session_state = Some(SessionState { stream_id }); + } + "input_audio_buffer.append" => { + let Some(ref state) = session_state else { + let _ = send_error( + &mut socket, + "invalid_request", + "Send `transcription_session.update` before appending audio.", + ) + .await; + continue; + }; + let audio_b64 = match parsed.get("audio").and_then(|v| v.as_str()) { + Some(s) => s, + None => { + let _ = send_error( + &mut socket, + "invalid_request", + "Missing required field: audio", + ) + .await; + continue; + } + }; + let bytes = match base64::prelude::BASE64_STANDARD.decode(audio_b64) { + Ok(b) => b, + Err(e) => { + let _ = send_error(&mut socket, "invalid_request", &format!("base64: {e}")) + .await; + continue; + } + }; + let samples = bytes_to_f32_le(&bytes); + + let service = AIService::global_instance().await.ok(); + if let Some(service) = service { + let rx = service + .feed_transcription_stream_with_broadcast( + &state.stream_id, + samples, + &auth.auth_token, + ) + .await; + if let Ok(mut rx) = rx { + // Drain any whisper deltas that became + // available since the last append. We do this + // synchronously on each append rather than + // running a background forwarder — keeps the + // socket lifetime obvious and avoids races at + // close time. + while let Ok(text) = rx.try_recv() { + let delta = json!({ + "type": "conversation.item.input_audio_transcription.delta", + "delta": text, + }); + let _ = socket.send(Message::Text(delta.to_string().into())).await; + } + } + } + } + "input_audio_buffer.commit" => { + // No-op for now — the whisper engine flushes on its own + // VAD detection. Forward the no-op as `completed` so + // OpenAI clients that wait for it can proceed; if + // there's nothing to flush, the `completed` carries an + // empty transcript. + let completed = json!({ + "type": "conversation.item.input_audio_transcription.completed", + "transcript": "", + }); + let _ = socket + .send(Message::Text(completed.to_string().into())) + .await; + } + "session.close" => { + break; + } + _ => { + let _ = send_error( + &mut socket, + "invalid_request", + &format!("Unsupported event type: {event_type}"), + ) + .await; + } + } + } + + if let Some(state) = session_state { + if let Ok(service) = AIService::global_instance().await { + let _ = service + .close_transcription_stream(&state.stream_id, &auth.auth_token) + .await; + } + } + let _ = socket.close().await; +} + +struct SessionState { + stream_id: String, +} + +async fn send_error(socket: &mut WebSocket, code: &str, message: &str) -> Result<(), ()> { + let body = json!({ + "type": "error", + "error": { + "message": message, + "code": code, + } + }); + socket + .send(Message::Text(body.to_string().into())) + .await + .map_err(|_| ()) +} + +/// Reinterpret a raw byte slice as little-endian `f32` samples. +fn bytes_to_f32_le(bytes: &[u8]) -> Vec { + let mut out = Vec::with_capacity(bytes.len() / 4); + for chunk in bytes.chunks_exact(4) { + out.push(f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])); + } + out +} diff --git a/rust-executor/src/api/openai_compat/router.rs b/rust-executor/src/api/openai_compat/router.rs new file mode 100644 index 000000000..fe4d91a21 --- /dev/null +++ b/rust-executor/src/api/openai_compat/router.rs @@ -0,0 +1,35 @@ +//! axum router for the OpenAI-compatible surface. +//! +//! Mount under `/v1` from the top-level [`api_router`](super::super::api_router). +//! All endpoints share the existing [`AppState`](crate::api::auth::AppState) +//! so JWT extraction, capabilities, and admin credentials work the same +//! way as the native WS-RPC surface. + +use axum::{ + extract::DefaultBodyLimit, + routing::{get, post}, + Router, +}; + +use super::{audio, chat, embeddings, models, realtime}; +use crate::api::auth::AppState; + +/// Build the `/v1` router. Mounted by [`super::super::api_router`]. +/// +/// The router is intentionally minimal — every endpoint delegates to a +/// dedicated handler module so reviewers can audit one phase at a time. +/// Body limit is bumped on audio routes to accommodate file uploads +/// (defaults to 2 MiB in axum, which is small for whisper inputs). +pub fn router() -> Router { + Router::new() + .route("/models", get(models::list_models)) + .route("/chat/completions", post(chat::chat_completions)) + .route("/completions", post(chat::completions)) + .route("/embeddings", post(embeddings::embeddings)) + .route( + "/audio/transcriptions", + post(audio::transcriptions).layer(DefaultBodyLimit::max(50 * 1024 * 1024)), + ) + .route("/audio/speech", post(audio::speech)) + .route("/realtime", get(realtime::realtime_ws)) +} diff --git a/rust-executor/src/api/openai_compat/tts_passthrough.rs b/rust-executor/src/api/openai_compat/tts_passthrough.rs new file mode 100644 index 000000000..c996d9f57 --- /dev/null +++ b/rust-executor/src/api/openai_compat/tts_passthrough.rs @@ -0,0 +1,126 @@ +//! Remote-passthrough TTS client. +//! +//! The proposal allows three TTS backends behind a common +//! `AIService::synthesize` trait: local Kokoro, local Piper, or a +//! passthrough to an OpenAI-compatible upstream. This module implements +//! the passthrough path; the local backends are scoped to a follow-up +//! (they add `ort` + voice model assets, both of which deserve their own +//! review). +//! +//! Configuration: we look for a TTS-typed model in the existing model +//! registry whose `api` field carries an `OPEN_AI` upstream. When +//! `model_type == Llm` we currently misuse — see TODO below — the +//! configured remote LLM API key for TTS; once `ModelType::TextToSpeech` +//! lands, this module reads from the proper TTS-typed registration. +//! +//! No new dependencies: we reuse the existing `reqwest` client. + +use crate::api::auth::AuthContext; +use crate::db::Ad4mDb; +use crate::types::{ModelApiType, ModelType}; + +use super::types::SpeechRequest; + +#[derive(Debug)] +pub enum TtsPassthroughError { + /// No TTS-capable upstream is configured on this executor. + NotConfigured, + Upstream(String), +} + +/// Forward a `/v1/audio/speech` request to a configured upstream and +/// return the raw audio bytes. Returns [`TtsPassthroughError::NotConfigured`] +/// when no upstream is registered. +pub async fn synthesize( + _auth: &AuthContext, + req: &SpeechRequest, +) -> Result, TtsPassthroughError> { + let upstream = find_tts_upstream(&req.model)?; + + let body = serde_json::json!({ + "model": upstream.upstream_model, + "input": req.input, + "voice": req.voice.clone().unwrap_or_else(|| "alloy".to_string()), + "response_format": req.response_format.clone().unwrap_or_else(|| "mp3".to_string()), + "speed": req.speed.unwrap_or(1.0), + }); + + let url = format!( + "{}/audio/speech", + upstream + .base_url + .trim_end_matches('/') + .trim_end_matches("/v1"), + ); + + let client = reqwest::Client::new(); + let resp = client + .post(format!("{url}/v1/audio/speech")) + .bearer_auth(&upstream.api_key) + .json(&body) + .send() + .await + .map_err(|e| TtsPassthroughError::Upstream(e.to_string()))?; + + if !resp.status().is_success() { + let status = resp.status(); + let text = resp.text().await.unwrap_or_default(); + return Err(TtsPassthroughError::Upstream(format!( + "Upstream returned {status}: {text}" + ))); + } + + let bytes = resp + .bytes() + .await + .map_err(|e| TtsPassthroughError::Upstream(e.to_string()))?; + Ok(bytes.to_vec()) +} + +struct TtsUpstream { + base_url: String, + api_key: String, + upstream_model: String, +} + +/// Look up a TTS-capable upstream in the model registry. +/// +/// We don't have `ModelType::TextToSpeech` yet (that lands with the +/// local Kokoro PR), so the resolution rule is: any LLM-typed `Model` +/// whose name matches `requested_model` AND has an `OPEN_AI` API +/// upstream is treated as a TTS-capable forwarder. This makes the +/// passthrough usable today without a schema migration; once the +/// `ModelType::TextToSpeech` enum lands, swap the filter to that type +/// and remove this comment. +fn find_tts_upstream(requested_model: &str) -> Result { + let models = Ad4mDb::with_global_instance(|db| db.get_models()) + .map_err(|_| TtsPassthroughError::NotConfigured)?; + + // Prefer an exact name match (callers send the upstream model + // string verbatim); fall back to the configured "default" remote + // LLM if none matches. + let candidate = models + .iter() + .find(|m| m.model_type == ModelType::Llm && m.api.is_some() && m.name == requested_model) + .or_else(|| { + models.iter().find(|m| { + m.model_type == ModelType::Llm + && m.api + .as_ref() + .map(|a| matches!(a.api_type, ModelApiType::OpenAi)) + .unwrap_or(false) + }) + }) + .ok_or(TtsPassthroughError::NotConfigured)?; + + let api = candidate + .api + .as_ref() + .ok_or(TtsPassthroughError::NotConfigured)?; + + Ok(TtsUpstream { + base_url: api.base_url.to_string(), + api_key: api.api_key.clone(), + upstream_model: requested_model.to_string(), + }) +} diff --git a/rust-executor/src/api/openai_compat/types.rs b/rust-executor/src/api/openai_compat/types.rs new file mode 100644 index 000000000..69172611e --- /dev/null +++ b/rust-executor/src/api/openai_compat/types.rs @@ -0,0 +1,317 @@ +//! OpenAI request/response schemas. +//! +//! All field names are spelled exactly as the OpenAI HTTP API expects — +//! that's the whole point of this module. Internal AD4M-specific names +//! go behind serde aliases or under the `ad4m` extension namespace. + +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +// --------------------------------------------------------------------------- +// Common +// --------------------------------------------------------------------------- + +/// Standard OpenAI usage object. All token counts are best-effort: the +/// local Kalosm backend reports `chars/4` estimates today; remote upstreams +/// return exact counts which we forward when available. +#[derive(Debug, Clone, Serialize)] +pub struct Usage { + pub prompt_tokens: u64, + pub completion_tokens: u64, + pub total_tokens: u64, +} + +// --------------------------------------------------------------------------- +// /v1/models +// --------------------------------------------------------------------------- + +#[derive(Debug, Serialize)] +pub struct ModelListResponse { + pub object: &'static str, // "list" + pub data: Vec, +} + +#[derive(Debug, Serialize)] +pub struct ModelInfo { + pub id: String, + pub object: &'static str, // "model" + pub created: i64, // unix epoch seconds (we report 0 for static models) + pub owned_by: &'static str, // "ad4m" + /// AD4M-specific extension fields. Tools that don't know about them + /// ignore them; the proposal's spec carries `ad4m.model_type` here. + #[serde(rename = "ad4m")] + pub extensions: ModelExtensions, +} + +#[derive(Debug, Serialize)] +pub struct ModelExtensions { + pub model_type: String, // "llm" | "embedding" | "transcription" | "tts" + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub backend: Option, // "local" | "remote" | "passthrough" +} + +// --------------------------------------------------------------------------- +// /v1/chat/completions +// --------------------------------------------------------------------------- + +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum Role { + System, + User, + Assistant, + /// We don't process tool/function messages today — accepted but + /// ignored in the prompt assembly so the request doesn't reject. + Tool, + Function, + Developer, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct ChatMessage { + pub role: Role, + pub content: ChatMessageContent, + #[serde(default)] + pub name: Option, +} + +/// OpenAI accepts either a string or an array of content parts. For now we +/// flatten parts to their text and ignore image/audio inputs — those are +/// out of scope for the first pass. +#[derive(Debug, Clone, Deserialize)] +#[serde(untagged)] +pub enum ChatMessageContent { + Text(String), + Parts(Vec), +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ContentPart { + Text { + text: String, + }, + /// Image input — silently dropped by the prompt assembler today. We + /// keep parsing so clients sending mixed content don't 400. + ImageUrl { + image_url: Value, + }, + /// Same — accepted, not consumed. + InputAudio { + input_audio: Value, + }, +} + +impl ChatMessageContent { + pub fn flatten_to_text(&self) -> String { + match self { + ChatMessageContent::Text(s) => s.clone(), + ChatMessageContent::Parts(parts) => parts + .iter() + .filter_map(|p| match p { + ContentPart::Text { text } => Some(text.as_str()), + _ => None, + }) + .collect::>() + .join("\n"), + } + } +} + +#[derive(Debug, Deserialize)] +pub struct ChatCompletionRequest { + pub model: String, + pub messages: Vec, + #[serde(default)] + pub stream: bool, + #[serde(default)] + pub temperature: Option, + #[serde(default)] + pub max_tokens: Option, + #[serde(default)] + pub top_p: Option, + #[serde(default)] + pub stop: Option, + #[serde(default)] + pub seed: Option, + #[serde(default)] + pub response_format: Option, + #[serde(default)] + pub user: Option, +} + +#[derive(Debug, Serialize)] +pub struct ChatCompletionResponse { + pub id: String, + pub object: &'static str, // "chat.completion" + pub created: i64, + pub model: String, + pub choices: Vec, + pub usage: Usage, +} + +#[derive(Debug, Serialize)] +pub struct ChatChoice { + pub index: u32, + pub message: ChatResponseMessage, + pub finish_reason: &'static str, // "stop" | "length" | "tool_calls" +} + +#[derive(Debug, Serialize)] +pub struct ChatResponseMessage { + pub role: &'static str, // "assistant" + pub content: String, +} + +// --------------------------------------------------------------------------- +// Streaming chunks (SSE) +// --------------------------------------------------------------------------- + +#[derive(Debug, Serialize)] +pub struct ChatCompletionChunk { + pub id: String, + pub object: &'static str, // "chat.completion.chunk" + pub created: i64, + pub model: String, + pub choices: Vec, +} + +#[derive(Debug, Serialize)] +pub struct ChatChunkChoice { + pub index: u32, + pub delta: ChatChunkDelta, + #[serde(skip_serializing_if = "Option::is_none")] + pub finish_reason: Option<&'static str>, +} + +#[derive(Debug, Default, Serialize)] +pub struct ChatChunkDelta { + #[serde(skip_serializing_if = "Option::is_none")] + pub role: Option<&'static str>, + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option, +} + +// --------------------------------------------------------------------------- +// /v1/completions (legacy) +// --------------------------------------------------------------------------- + +#[derive(Debug, Deserialize)] +pub struct CompletionRequest { + pub model: String, + pub prompt: PromptInput, + #[serde(default)] + pub stream: bool, + #[serde(default)] + pub temperature: Option, + #[serde(default)] + pub max_tokens: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(untagged)] +pub enum PromptInput { + One(String), + Many(Vec), +} + +impl PromptInput { + pub fn first(&self) -> String { + match self { + PromptInput::One(s) => s.clone(), + PromptInput::Many(v) => v.first().cloned().unwrap_or_default(), + } + } +} + +#[derive(Debug, Serialize)] +pub struct CompletionResponse { + pub id: String, + pub object: &'static str, // "text_completion" + pub created: i64, + pub model: String, + pub choices: Vec, + pub usage: Usage, +} + +#[derive(Debug, Serialize)] +pub struct CompletionChoice { + pub index: u32, + pub text: String, + pub finish_reason: &'static str, +} + +// --------------------------------------------------------------------------- +// /v1/embeddings +// --------------------------------------------------------------------------- + +#[derive(Debug, Deserialize)] +pub struct EmbeddingRequest { + pub model: String, + pub input: EmbeddingInput, + #[serde(default)] + pub user: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(untagged)] +pub enum EmbeddingInput { + One(String), + Many(Vec), +} + +impl EmbeddingInput { + pub fn into_vec(self) -> Vec { + match self { + EmbeddingInput::One(s) => vec![s], + EmbeddingInput::Many(v) => v, + } + } +} + +#[derive(Debug, Serialize)] +pub struct EmbeddingResponse { + pub object: &'static str, // "list" + pub data: Vec, + pub model: String, + pub usage: EmbeddingUsage, +} + +#[derive(Debug, Serialize)] +pub struct EmbeddingItem { + pub object: &'static str, // "embedding" + pub index: usize, + pub embedding: Vec, +} + +#[derive(Debug, Serialize)] +pub struct EmbeddingUsage { + pub prompt_tokens: u64, + pub total_tokens: u64, +} + +// --------------------------------------------------------------------------- +// /v1/audio/transcriptions +// --------------------------------------------------------------------------- + +#[derive(Debug, Serialize)] +pub struct TranscriptionResponse { + pub text: String, +} + +// --------------------------------------------------------------------------- +// /v1/audio/speech (TTS) +// --------------------------------------------------------------------------- + +#[derive(Debug, Deserialize)] +pub struct SpeechRequest { + pub model: String, + pub input: String, + #[serde(default)] + pub voice: Option, + #[serde(default)] + pub response_format: Option, + #[serde(default)] + pub speed: Option, +}