From 6667357584ae5d7b86bd7bdb104affe40589f5cb Mon Sep 17 00:00:00 2001 From: HexaField <10372036+HexaField@users.noreply.github.com> Date: Mon, 8 Jun 2026 23:36:05 +1000 Subject: [PATCH] feat(api): OpenAI-compatible /v1 endpoint surface (chat, embeddings, audio, realtime) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds an industry-standard, OpenAI-compatible HTTP/WS API mounted at \`/v1\` alongside the existing AD4M WS-RPC transport. Any tool that speaks OpenAI — the official \`openai\` SDKs, LangChain, LlamaIndex, the Vercel AI SDK, Cursor/Continue, Open WebUI, \`llm\` — can point \`base_url\` at the executor and just work. All five proposal phases land together. Surface ======= GET /v1/models list registered models POST /v1/chat/completions stateless chat, oneshot + SSE streaming POST /v1/completions legacy text-completion shim POST /v1/embeddings raw f32 arrays (no zlib+b64) POST /v1/audio/transcriptions multipart batch whisper POST /v1/audio/speech TTS, remote OpenAI-compatible passthrough GET /v1/realtime WS, OpenAI Realtime-style streaming STT Also mounted at \`/api/v1/openai/v1\` for proxies that hard-code the existing AD4M prefix. What changed ============ New module \`api/openai_compat/\` (11 files, ~1200 LOC): mod.rs module wiring router.rs axum router mounted at /v1 and /api/v1/openai/v1 types.rs serde structs mirroring OpenAI's wire schema errors.rs { error: { message, type, param, code } } envelope model_selector.rs resolve model strings → AD4M model_id (id → name → "default") models.rs GET /v1/models chat.rs POST /v1/chat/completions (oneshot + SSE) + completions embeddings.rs POST /v1/embeddings — raw f32 arrays audio.rs POST /v1/audio/transcriptions + speech tts_passthrough.rs OpenAI-compatible upstream passthrough for /v1/audio/speech realtime.rs GET /v1/realtime WebSocket (subset of OpenAI Realtime STT) AIService surface additions ============================ AIService::prompt_messages(model_id, messages: Vec<(role, content)>) Stateless ephemeral task: build → spawn → prompt → remove on the existing per-model LLM thread. No DB writes. AIService::prompt_messages_stream(model_id, messages) Same shape but the LLM thread forwards each Kalosm token through an mpsc::UnboundedReceiver. Backs stream: true. Local Kalosm uses ChatResponseBuilder's Stream impl; remote upstreams degrade to one terminal chunk (native streaming upstream client is a follow-up). AIService::transcribe_buffer(model_id, samples, auth_token) One-shot transcription: open → feed → drain broadcast → close. Backs the batch /v1/audio/transcriptions endpoint. LLMTaskRequest::PromptStream new variant in the per-model thread's command channel; carries token_sender + done_sender. PromptResult now derives Debug + Clone (needed for oneshot transit). Auth + billing =============== Reuses the existing JWT extractor (AuthContext from auth.rs) and capability framework verbatim. Authorization: Bearer (or ?token= for the realtime WS, same as events_ws). Capabilities map: chat/completions/embeddings → AI_PROMPT transcriptions + realtime → AI_TRANSCRIBE speech → AI_PROMPT models → AI_READ bill_compute deducts per-user credits with the existing operation labels (ai_prompt, ai_embedding, ai_transcription, ai_tts). InsufficientCredits → HTTP 429 with type: "insufficient_quota" matching OpenAI's quota-exhaustion shape so SDKs retrying on 429 do the right thing. Error envelope =============== Every error response carries: { "error": { "message", "type", "param", "code" } } With correct HTTP status (400 / 401 / 403 / 404 / 429 / 500 / 501). Audio decode (Phase 3 scope note) ================================== Batch transcription accepts 16 kHz mono PCM WAV out of the box (the simplest container that requires no new deps and matches Whisper's reference format). Other containers (mp3, m4a, ogg, flac) return a 400 pointing at the Content-Type so clients can transcode locally; full container support via \`symphonia\` + \`rubato\` is a self- contained follow-up that wasn't worth pulling into the same diff. TTS backend (Phase 4 scope note) ================================= /v1/audio/speech forwards to a configured OpenAI-compatible upstream via the existing reqwest dep (no new deps). The proposal allows passthrough as a valid Phase-4 deliverable alongside local Kokoro/Piper backends; the local engines pull in ort + voice model assets that deserve their own review and didn't make sense to ship in the same diff. No upstream configured → 501 with a clear message. Dependency surface =================== Only change: axum gains the \`multipart\` feature (already had \`ws\`). No new top-level deps. Native surface unchanged ========================= The WS-RPC ai.* methods, core/AIClient, Flux, and the launcher are untouched. Existing clients see no behaviour change. Refs: \`PROPOSAL_AI_OPENAI_COMPATIBLE_ENDPOINT.md\` --- Cargo.lock | 1 + rust-executor/Cargo.toml | 2 +- rust-executor/src/ai_service/mod.rs | 410 ++++++++++++++++++ rust-executor/src/api/mod.rs | 9 + rust-executor/src/api/openai_compat/audio.rs | 300 +++++++++++++ rust-executor/src/api/openai_compat/chat.rs | 294 +++++++++++++ .../src/api/openai_compat/embeddings.rs | 73 ++++ rust-executor/src/api/openai_compat/errors.rs | 128 ++++++ rust-executor/src/api/openai_compat/mod.rs | 36 ++ .../src/api/openai_compat/model_selector.rs | 78 ++++ rust-executor/src/api/openai_compat/models.rs | 48 ++ .../src/api/openai_compat/realtime.rs | 251 +++++++++++ rust-executor/src/api/openai_compat/router.rs | 35 ++ .../src/api/openai_compat/tts_passthrough.rs | 126 ++++++ rust-executor/src/api/openai_compat/types.rs | 317 ++++++++++++++ 15 files changed, 2107 insertions(+), 1 deletion(-) create mode 100644 rust-executor/src/api/openai_compat/audio.rs create mode 100644 rust-executor/src/api/openai_compat/chat.rs create mode 100644 rust-executor/src/api/openai_compat/embeddings.rs create mode 100644 rust-executor/src/api/openai_compat/errors.rs create mode 100644 rust-executor/src/api/openai_compat/mod.rs create mode 100644 rust-executor/src/api/openai_compat/model_selector.rs create mode 100644 rust-executor/src/api/openai_compat/models.rs create mode 100644 rust-executor/src/api/openai_compat/realtime.rs create mode 100644 rust-executor/src/api/openai_compat/router.rs create mode 100644 rust-executor/src/api/openai_compat/tts_passthrough.rs create mode 100644 rust-executor/src/api/openai_compat/types.rs diff --git a/Cargo.lock b/Cargo.lock index 1ca24b8a4..11d939864 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1197,6 +1197,7 @@ dependencies = [ "matchit 0.8.4", "memchr", "mime", + "multer 3.1.0", "percent-encoding", "pin-project-lite", "rustversion", diff --git a/rust-executor/Cargo.toml b/rust-executor/Cargo.toml index c654e3fe1..26b35d20c 100644 --- a/rust-executor/Cargo.toml +++ b/rust-executor/Cargo.toml @@ -68,7 +68,7 @@ argon2 = { version = "0.5.0", features = ["simple"] } rand = "0.8.5" base64 = "0.21.0" rmcp = { version = "0.15.0", features = ["server", "transport-streamable-http-server"] } -axum = { version = "0.8", features = ["ws"] } +axum = { version = "0.8", features = ["ws", "multipart"] } axum-server = { version = "0.7", features = ["tls-rustls"] } tower-http = { version = "0.6", features = ["cors", "set-header", "catch-panic"] } schemars = "1.0" diff --git a/rust-executor/src/ai_service/mod.rs b/rust-executor/src/ai_service/mod.rs index e67f9111a..33f602bc8 100644 --- a/rust-executor/src/ai_service/mod.rs +++ b/rust-executor/src/ai_service/mod.rs @@ -34,6 +34,7 @@ pub type Result = std::result::Result; /// Result of an LLM prompt call, with token counts for billing. /// Token counts are estimated (chars/4) with the current Kalosum backend. /// When Ollama is integrated, these will be exact from the API response. +#[derive(Debug, Clone)] pub struct PromptResult { pub text: String, pub prompt_tokens: usize, @@ -137,11 +138,34 @@ struct LLMTaskShutdownRequest { pub result_sender: oneshot::Sender<()>, } +/// Streaming-prompt variant for the OpenAI-compatible chat endpoint. +/// +/// Identical to [`LLMTaskPromptRequest`] but instead of a oneshot reply +/// the caller receives an async stream of token chunks. Used by +/// `prompt_messages_stream` to back `POST /v1/chat/completions` with +/// `stream: true`. +/// +/// The receiver side is an `mpsc::Receiver` so each `Token` arm +/// can push without awaiting consumer backpressure; the HTTP handler +/// converts each token into one SSE `chat.completion.chunk`. +/// +/// `done_sender` fires once the model has emitted its final token (or +/// errored) and carries `PromptResult` for the closing chunk's `usage`. +#[allow(dead_code)] +#[derive(Debug)] +struct LLMTaskPromptStreamRequest { + pub task_id: String, + pub prompt: String, + pub token_sender: mpsc::UnboundedSender, + pub done_sender: oneshot::Sender>, +} + #[allow(dead_code)] #[derive(Debug)] enum LLMTaskRequest { Spawn(LLMTaskSpawnRequest), Prompt(LLMTaskPromptRequest), + PromptStream(LLMTaskPromptStreamRequest), Remove(LLMTaskRemoveRequest), Shutdown(LLMTaskShutdownRequest), } @@ -844,6 +868,144 @@ impl AIService { } }, + LLMTaskRequest::PromptStream(stream_request) => match model { + LlmModel::Remote((ref mut remote_client, ref model_string)) => { + // Remote upstreams use the non-streaming + // chat call today; we deliver the full + // response as a single token chunk so + // SSE consumers still see the "stream" + // protocol (one delta + final usage). + // A native streaming upstream client is + // a follow-up. + if let Some(task) = + task_descriptions.get(&stream_request.task_id) + { + let mut messages = vec![Message { + role: Role::System, + content: task.system_prompt.clone(), + }]; + for example in task.prompt_examples.iter() { + messages.push(Message { + role: Role::User, + content: example.input.clone(), + }); + messages.push(Message { + role: Role::Assistant, + content: example.output.clone(), + }); + } + let prompt_clone = stream_request.prompt.clone(); + messages.push(Message { + role: Role::User, + content: prompt_clone.clone(), + }); + let chat_input = ChatInput { + model: chat_gpt_lib_rs::Model::Custom( + model_string.clone(), + ), + messages, + ..Default::default() + }; + match rt.block_on(remote_client.chat(chat_input)) { + Err(e) => { + let _ = + stream_request.done_sender.send(Err(anyhow!( + "Error connecting to remote LLM API: {:?}", + e + ))); + } + Ok(response) => { + let text = response + .choices + .first() + .map(|c| c.message.content.clone()) + .unwrap_or_default(); + let _ = + stream_request.token_sender.send(text.clone()); + let prompt_tokens = + estimate_token_count(&prompt_clone); + let completion_tokens = estimate_token_count(&text); + let _ = stream_request.done_sender.send(Ok( + PromptResult { + text, + prompt_tokens, + completion_tokens, + model_id: model_config.id.clone(), + }, + )); + } + } + } else { + let _ = stream_request.done_sender.send(Err(anyhow!( + "Task with ID {} not spawned", + stream_request.task_id + ))); + } + } + LlmModel::Local(_) => { + if let Some(task) = tasks.get(&stream_request.task_id) { + rt.block_on(publish_model_status( + model_config.id.clone(), + 100.0, + "Running inference...", + true, + true, + )); + + let prompt_clone = stream_request.prompt.clone(); + let prompt_tokens = estimate_token_count(&prompt_clone); + let token_sender = stream_request.token_sender.clone(); + // Forward each token through the + // mpsc back to the SSE handler. We + // accumulate the full text so the + // closing event still carries a + // single concatenated string for + // billing + usage parity with the + // non-stream path. + let result = rt.block_on(async { + use futures::StreamExt; + // `task.run(...)` returns a + // `ChatResponseBuilder` which + // implements `Stream`; + // polling it yields one token + // chunk at a time. + let mut stream = + Box::pin(task.run(prompt_clone.clone())); + let mut accumulated = String::new(); + while let Some(token) = stream.next().await { + accumulated.push_str(&token); + if token_sender.send(token).is_err() { + // consumer dropped — stop generating + break; + } + } + accumulated + }); + + rt.block_on(publish_model_status( + model_config.id.clone(), + 100.0, + "Ready", + true, + true, + )); + + let completion_tokens = estimate_token_count(&result); + let _ = stream_request.done_sender.send(Ok(PromptResult { + text: result, + prompt_tokens, + completion_tokens, + model_id: model_config.id.clone(), + })); + } else { + let _ = stream_request.done_sender.send(Err(anyhow!( + "Task with ID {} not spawned", + stream_request.task_id + ))); + } + } + }, + LLMTaskRequest::Remove(remove_request) => { let _ = tasks.remove(&remove_request.task_id); let _ = task_descriptions.remove(&remove_request.task_id); @@ -973,6 +1135,205 @@ impl AIService { Ok(()) } + /// Build an ephemeral, in-memory task from a list of `(system, user, + /// assistant)` messages and return its `task_id`. Used by the + /// OpenAI-compatible `chat.completions` handler so each request is + /// stateless: spawn → prompt → remove. Nothing touches the DB. + /// + /// `messages` is in OpenAI order — system message(s) at the front, + /// then alternating user/assistant turns ending with a final user + /// message. The final user message is returned separately so the + /// caller passes it as the prompt; everything before it becomes the + /// task's `system_prompt` + few-shot example pairs. + fn build_ephemeral_task(model_id: &str, messages: Vec<(String, String)>) -> (AITask, String) { + // Split: leading System messages → system_prompt; alternating + // User/Assistant pairs → examples; last User message → prompt. + let mut system_prompt = String::new(); + let mut examples: Vec = Vec::new(); + let mut current_user: Option = None; + let mut final_prompt = String::new(); + + for (role, content) in &messages { + match role.as_str() { + "system" => { + if !system_prompt.is_empty() { + system_prompt.push_str("\n"); + } + system_prompt.push_str(content); + } + "user" => { + if let Some(prev_user) = current_user.take() { + // user→user with no assistant in between; treat + // earlier user as a finalised final prompt slot + // (will be overwritten if more turns follow). + final_prompt = prev_user; + } + current_user = Some(content.clone()); + } + "assistant" => { + if let Some(user_msg) = current_user.take() { + examples.push(crate::types::AIPromptExamples { + input: user_msg, + output: content.clone(), + }); + } + } + _ => { /* tool/function/developer — silently dropped */ } + } + } + if let Some(last_user) = current_user { + final_prompt = last_user; + } + + let task_id = format!("openai-compat-{}", uuid::Uuid::new_v4()); + let now = chrono::Utc::now().to_rfc3339(); + let task = AITask { + name: format!("openai-compat-{}", &task_id[..8]), + task_id: task_id.clone(), + model_id: model_id.to_string(), + system_prompt, + prompt_examples: examples, + meta_data: None, + created_at: now.clone(), + updated_at: now, + }; + (task, final_prompt) + } + + /// Run a chat-style prompt on `model_id` with a stateless message + /// list and return the full response text + token counts. Backs + /// `POST /v1/chat/completions`. + pub async fn prompt_messages( + &self, + model_id: String, + messages: Vec<(String, String)>, + ) -> Result { + let resolved = Self::replace_model_variables(&model_id)?; + let (task, final_prompt) = Self::build_ephemeral_task(&resolved, messages); + + // Spawn + prompt + remove. We use the existing per-model thread + // command channels so model state stays local to its thread and + // we don't duplicate the build-llama machinery. + let task_id = task.task_id.clone(); + let (spawn_tx, spawn_rx) = oneshot::channel(); + { + let llm_channel = self.llm_channel.lock().await; + let sender = llm_channel + .get(&resolved) + .ok_or_else(|| anyhow!("Model '{}' not found in LLM channel", resolved))?; + sender.send(LLMTaskRequest::Spawn(LLMTaskSpawnRequest { + task, + result_sender: spawn_tx, + }))?; + } + spawn_rx.await??; + + let (prompt_tx, prompt_rx) = oneshot::channel(); + { + let llm_channel = self.llm_channel.lock().await; + let sender = llm_channel + .get(&resolved) + .ok_or_else(|| anyhow!("Model '{}' not found in LLM channel", resolved))?; + sender.send(LLMTaskRequest::Prompt(LLMTaskPromptRequest { + task_id: task_id.clone(), + prompt: final_prompt.clone(), + result_sender: prompt_tx, + }))?; + } + let prompt_tokens = estimate_token_count(&final_prompt); + let text = prompt_rx.await??; + let completion_tokens = estimate_token_count(&text); + + // Clean up the ephemeral task entry on the LLM thread. + let (remove_tx, _) = oneshot::channel(); + { + let llm_channel = self.llm_channel.lock().await; + if let Some(sender) = llm_channel.get(&resolved) { + let _ = sender.send(LLMTaskRequest::Remove(LLMTaskRemoveRequest { + task_id: task_id.clone(), + result_sender: remove_tx, + })); + } + } + + Ok(PromptResult { + text, + prompt_tokens, + completion_tokens, + model_id: resolved, + }) + } + + /// Streaming variant of [`Self::prompt_messages`]. Returns a token + /// stream + a oneshot for the final [`PromptResult`] (carries the + /// full text + token counts for the closing SSE event). + pub async fn prompt_messages_stream( + &self, + model_id: String, + messages: Vec<(String, String)>, + ) -> Result<( + mpsc::UnboundedReceiver, + oneshot::Receiver>, + )> { + let resolved = Self::replace_model_variables(&model_id)?; + let (task, final_prompt) = Self::build_ephemeral_task(&resolved, messages); + + let task_id = task.task_id.clone(); + let (spawn_tx, spawn_rx) = oneshot::channel(); + { + let llm_channel = self.llm_channel.lock().await; + let sender = llm_channel + .get(&resolved) + .ok_or_else(|| anyhow!("Model '{}' not found in LLM channel", resolved))?; + sender.send(LLMTaskRequest::Spawn(LLMTaskSpawnRequest { + task, + result_sender: spawn_tx, + }))?; + } + spawn_rx.await??; + + let (token_tx, token_rx) = mpsc::unbounded_channel::(); + let (done_tx, done_rx) = oneshot::channel::>(); + { + let llm_channel = self.llm_channel.lock().await; + let sender = llm_channel + .get(&resolved) + .ok_or_else(|| anyhow!("Model '{}' not found in LLM channel", resolved))?; + sender.send(LLMTaskRequest::PromptStream(LLMTaskPromptStreamRequest { + task_id: task_id.clone(), + prompt: final_prompt, + token_sender: token_tx, + done_sender: done_tx, + }))?; + } + + // Schedule a Remove for after the prompt completes — best-effort + // background cleanup so the caller doesn't have to await it. + let llm_channel_clone = self.llm_channel.clone(); + let resolved_clone = resolved.clone(); + tokio::spawn(async move { + // Wait a little for the prompt to finish before removing the + // task entry; if the model is local, the thread holds the + // task in its `tasks` map for the duration of the call, so + // removing too early would race the inference. + // + // We don't have a reliable signal here without restructuring + // the channel; the LLM thread serialises requests anyway, so + // a Remove queued immediately will run after the + // PromptStream completes. No sleep needed. + let (remove_tx, _) = oneshot::channel(); + let llm_channel = llm_channel_clone.lock().await; + if let Some(sender) = llm_channel.get(&resolved_clone) { + let _ = sender.send(LLMTaskRequest::Remove(LLMTaskRemoveRequest { + task_id, + result_sender: remove_tx, + })); + } + }); + + Ok((token_rx, done_rx)) + } + pub async fn prompt(&self, task_id: String, prompt: String) -> Result { let (result_sender, rx) = oneshot::channel(); @@ -1472,6 +1833,55 @@ impl AIService { } } + /// One-shot transcription: open a stream, feed the whole audio + /// buffer, drain the broadcast channel until idle, close. + /// + /// Backs `POST /v1/audio/transcriptions` (batch multipart upload). + /// The caller is responsible for decoding the upload to 16 kHz mono + /// `f32` samples — this method only handles the whisper-side + /// session lifecycle and assembly of the final transcript. + pub async fn transcribe_buffer( + &self, + model_id: String, + samples: Vec, + auth_token: String, + ) -> Result { + let stream_id = self + .open_transcription_stream(model_id, None, auth_token.clone()) + .await?; + + // Subscribe to the broadcast BEFORE feeding so we don't drop any + // partial events the whisper thread emits during the first + // window. + let mut rx = self + .feed_transcription_stream_with_broadcast(&stream_id, samples, &auth_token) + .await?; + + // Drain until the receiver lags or no message arrives within a + // short idle window. Whisper emits one event per finalised + // segment; for a one-shot buffer the model usually emits 1-3 + // segments and then goes quiet. + let mut transcript = String::new(); + let idle = Duration::from_millis(2500); + loop { + match tokio::time::timeout(idle, rx.recv()).await { + Ok(Ok(text)) => { + if !transcript.is_empty() && !text.is_empty() { + transcript.push(' '); + } + transcript.push_str(&text); + } + Ok(Err(_)) => break, // sender dropped + Err(_) => break, // idle timeout + } + } + + let _ = self + .close_transcription_stream(&stream_id, &auth_token) + .await; + Ok(transcript) + } + pub async fn close_transcription_stream( &self, stream_id: &String, diff --git a/rust-executor/src/api/mod.rs b/rust-executor/src/api/mod.rs index 3a48574b0..e5bed8aa5 100644 --- a/rust-executor/src/api/mod.rs +++ b/rust-executor/src/api/mod.rs @@ -6,6 +6,7 @@ pub mod auth; pub mod errors; pub mod events_ws; +pub mod openai_compat; pub mod types; pub mod ws_rpc; @@ -99,6 +100,14 @@ pub fn api_router(state: AppState) -> Router { post(ai_ws::feed_transcription_stream), ), ) + // ── OpenAI-compatible /v1 surface ── + // + // Mounted at both `/v1` (the canonical OpenAI path) and + // `/api/v1/openai/v1` (for proxies that hard-code the `/api/v1` + // prefix from the native AD4M surface). Both share the same + // handlers + AppState. + .nest("/v1", openai_compat::router()) + .nest("/api/v1/openai/v1", openai_compat::router()) // ── State + Middleware ── .with_state(state) .layer(Extension(handler_map)) diff --git a/rust-executor/src/api/openai_compat/audio.rs b/rust-executor/src/api/openai_compat/audio.rs new file mode 100644 index 000000000..c818416e8 --- /dev/null +++ b/rust-executor/src/api/openai_compat/audio.rs @@ -0,0 +1,300 @@ +//! `POST /v1/audio/transcriptions` (batch STT) + `POST /v1/audio/speech` (TTS). +//! +//! ## Transcription (Phase 3) +//! +//! Accepts a standard OpenAI multipart upload: `file`, `model`, +//! `response_format`, `language`, `temperature`. We decode the upload to +//! 16 kHz mono `f32` samples and drive the existing Whisper pipeline in +//! one shot via [`AIService::transcribe_buffer`]. +//! +//! Audio decode is delegated to the new [`audio_decode`] helper. For +//! the first cut we accept raw 16 kHz mono PCM (wav) directly; for +//! other formats the helper returns an error pointing at the +//! `Content-Type` and the caller is expected to convert client-side +//! (matches the proposal — full container support via `symphonia` is +//! scoped to a follow-up to keep the dependency surface lean). +//! +//! ## TTS (Phase 4) +//! +//! `audio/speech` forwards the request to a remote OpenAI-compatible +//! upstream configured via a TTS-typed `Model` registration — see +//! [`super::tts_passthrough`] for the HTTP client. A local Kokoro +//! backend is in the design but not in this PR; the proposal allows +//! passthrough as a valid Phase-4 deliverable. + +use axum::{extract::Multipart, http::header, response::Response, Json}; + +use super::errors::{OpenAIError, OpenAIResult}; +use super::model_selector::resolve_model; +use super::tts_passthrough; +use super::types::{SpeechRequest, TranscriptionResponse}; +use crate::agent::capabilities::{ + check_capability, AI_PROMPT_CAPABILITY, AI_TRANSCRIBE_CAPABILITY, +}; +use crate::ai_service::AIService; +use crate::api::auth::AuthContext; +use crate::billing::{bill_compute, BillingError}; +use crate::types::ModelType; + +pub async fn transcriptions( + auth: AuthContext, + mut multipart: Multipart, +) -> OpenAIResult> { + check_capability(&auth.capabilities, &AI_TRANSCRIBE_CAPABILITY) + .map_err(OpenAIError::forbidden)?; + + let mut model_field: Option = None; + let mut audio_bytes: Option> = None; + let mut content_type: Option = None; + let mut response_format = "json".to_string(); + + while let Some(field) = multipart + .next_field() + .await + .map_err(|e| OpenAIError::invalid_request(format!("Multipart parse error: {e}")))? + { + match field.name().unwrap_or("") { + "model" => { + model_field = field + .text() + .await + .ok() + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()); + } + "file" => { + content_type = field + .content_type() + .map(|s| s.to_string()) + .or_else(|| Some("application/octet-stream".to_string())); + let bytes = field + .bytes() + .await + .map_err(|e| OpenAIError::invalid_request(format!("File read error: {e}")))?; + audio_bytes = Some(bytes.to_vec()); + } + "response_format" => { + if let Ok(v) = field.text().await { + response_format = v.trim().to_string(); + } + } + _ => { + // Other OpenAI fields (`language`, `temperature`, + // `prompt`) are not yet plumbed through to Whisper — we + // accept and drop them so callers don't 400. + let _ = field.bytes().await; + } + } + } + + let model = model_field + .ok_or_else(|| OpenAIError::invalid_request("Missing required form field: model"))?; + let bytes = audio_bytes + .ok_or_else(|| OpenAIError::invalid_request("Missing required form field: file"))?; + let model_id = resolve_model(&model, ModelType::Transcription).await?; + + let samples = audio_decode(&bytes, content_type.as_deref())?; + + let service = AIService::global_instance() + .await + .map_err(|e| OpenAIError::internal(e.to_string()))?; + let text = service + .transcribe_buffer(model_id, samples, auth.auth_token.clone()) + .await + .map_err(|e| OpenAIError::internal(e.to_string()))?; + + if let Some(email) = crate::agent::capabilities::user_email_from_token(auth.auth_token.clone()) + { + // Mirror native transcription billing: 1 credit per request + // (the native path bills per-word but only after `delta` events + // — batch jobs round to a flat charge until per-word accounting + // is plumbed through). + if let Err(BillingError::InsufficientCredits) = bill_compute( + &email, + 1.0, + "ai_transcription", + Some("v1/audio/transcriptions"), + ) { + return Err(OpenAIError::insufficient_quota( + "Insufficient compute credits", + )); + } + } + + // `response_format` honoured: `json` (default) returns the envelope + // below. `text` / `verbose_json` / `srt` / `vtt` are accepted for + // request-compatibility but currently always produce the `json` + // envelope; widening the return to a raw `Response` so we can emit + // text/srt/vtt directly is a follow-up that touches the router's + // type signature. + let _ = response_format; + Ok(Json(TranscriptionResponse { text })) +} + +pub async fn speech( + auth: AuthContext, + Json(req): Json, +) -> Result { + check_capability(&auth.capabilities, &AI_PROMPT_CAPABILITY).map_err(OpenAIError::forbidden)?; + + // TTS isn't a registered ModelType yet (the proposal calls for a new + // `ModelType::TextToSpeech` in a follow-up); for now we accept the + // model name verbatim and forward to the configured passthrough + // upstream. See `tts_passthrough` for the configuration shape. + let response_format = req + .response_format + .clone() + .unwrap_or_else(|| "mp3".to_string()); + let audio = tts_passthrough::synthesize(&auth, &req) + .await + .map_err(|e| match e { + tts_passthrough::TtsPassthroughError::NotConfigured => OpenAIError::not_implemented( + "No TTS backend is configured on this executor. \ + Register a TTS upstream via the AD4M config or wait for the local Kokoro backend.", + ), + tts_passthrough::TtsPassthroughError::Upstream(msg) => { + OpenAIError::internal(format!("Upstream TTS error: {msg}")) + } + })?; + + if let Some(email) = crate::agent::capabilities::user_email_from_token(auth.auth_token.clone()) + { + if let Err(BillingError::InsufficientCredits) = + bill_compute(&email, 1.0, "ai_tts", Some("v1/audio/speech")) + { + return Err(OpenAIError::insufficient_quota( + "Insufficient compute credits", + )); + } + } + + let content_type = match response_format.as_str() { + "mp3" => "audio/mpeg", + "wav" => "audio/wav", + "opus" => "audio/opus", + "flac" => "audio/flac", + "pcm" => "application/octet-stream", + _ => "application/octet-stream", + }; + Response::builder() + .status(200) + .header(header::CONTENT_TYPE, content_type) + .body(audio.into()) + .map_err(|e| OpenAIError::internal(format!("Response build error: {e}"))) +} + +/// Decode an uploaded audio file to 16 kHz mono `f32` samples. +/// +/// The first-cut implementation supports raw WAV (16-bit PCM, mono, +/// 16 kHz) — the simplest case that requires no external dependencies +/// and matches what the existing transcription WS feed already expects. +/// For other containers, return a descriptive 400 so the client can +/// transcode locally. +/// +/// Full `symphonia` / `rubato` decode is a follow-up: it's the right +/// home for that work, but adding two heavy crates in the same PR as +/// the rest of the OpenAI surface would dominate the diff for a single +/// endpoint. +fn audio_decode(bytes: &[u8], content_type: Option<&str>) -> OpenAIResult> { + // Quick WAV header sniff — RIFF....WAVEfmt + if bytes.len() >= 44 && &bytes[0..4] == b"RIFF" && &bytes[8..12] == b"WAVE" { + return decode_pcm_wav(bytes); + } + Err(OpenAIError::invalid_request(format!( + "Unsupported audio format (content-type: {}). \ + The current /v1/audio/transcriptions endpoint accepts 16 kHz mono PCM WAV. \ + Full container decode (mp3/m4a/flac/ogg) is on the roadmap; \ + transcode client-side as a workaround.", + content_type.unwrap_or("(unknown)"), + ))) +} + +/// Minimal WAV decoder: 16-bit PCM, mono, 16 kHz. +/// +/// We don't attempt to handle every WAV variant — the goal is to clear +/// the most common upload path (whisper's reference format) without +/// pulling in `symphonia`. Mismatched format → 400 with a clear hint. +fn decode_pcm_wav(bytes: &[u8]) -> OpenAIResult> { + // Find the fmt and data chunks. WAV is a RIFF container: after + // the 12-byte RIFF header, chunks alternate {id (4 bytes), size + // (u32 LE), payload}. + let mut pos = 12; + let mut audio_format: u16 = 0; + let mut channels: u16 = 0; + let mut sample_rate: u32 = 0; + let mut bits_per_sample: u16 = 0; + let mut data_offset: Option = None; + let mut data_size: usize = 0; + + while pos + 8 <= bytes.len() { + let id = &bytes[pos..pos + 4]; + let size = u32::from_le_bytes([ + bytes[pos + 4], + bytes[pos + 5], + bytes[pos + 6], + bytes[pos + 7], + ]) as usize; + pos += 8; + match id { + b"fmt " => { + if pos + 16 > bytes.len() { + return Err(OpenAIError::invalid_request("Truncated WAV fmt chunk")); + } + audio_format = u16::from_le_bytes([bytes[pos], bytes[pos + 1]]); + channels = u16::from_le_bytes([bytes[pos + 2], bytes[pos + 3]]); + sample_rate = u32::from_le_bytes([ + bytes[pos + 4], + bytes[pos + 5], + bytes[pos + 6], + bytes[pos + 7], + ]); + bits_per_sample = u16::from_le_bytes([bytes[pos + 14], bytes[pos + 15]]); + pos += size; + } + b"data" => { + data_offset = Some(pos); + data_size = size; + pos += size; + } + _ => { + // Skip unknown chunks (LIST, INFO, etc). + pos += size; + } + } + if pos % 2 != 0 { + pos += 1; // RIFF chunks are word-aligned. + } + } + + let data_offset = + data_offset.ok_or_else(|| OpenAIError::invalid_request("WAV file has no data chunk"))?; + if audio_format != 1 { + return Err(OpenAIError::invalid_request(format!( + "Only uncompressed PCM WAV is supported (audio_format = {audio_format}). \ + Re-encode as PCM and retry." + ))); + } + if channels != 1 { + return Err(OpenAIError::invalid_request(format!( + "Only mono WAV is supported (channels = {channels}). Downmix to mono." + ))); + } + if sample_rate != 16_000 { + return Err(OpenAIError::invalid_request(format!( + "Only 16 kHz WAV is supported (sample_rate = {sample_rate}). Resample to 16 kHz." + ))); + } + if bits_per_sample != 16 { + return Err(OpenAIError::invalid_request(format!( + "Only 16-bit PCM WAV is supported (bits_per_sample = {bits_per_sample})." + ))); + } + + let payload = &bytes[data_offset..data_offset + data_size.min(bytes.len() - data_offset)]; + let mut samples = Vec::with_capacity(payload.len() / 2); + for chunk in payload.chunks_exact(2) { + let s = i16::from_le_bytes([chunk[0], chunk[1]]); + samples.push((s as f32) / (i16::MAX as f32)); + } + Ok(samples) +} diff --git a/rust-executor/src/api/openai_compat/chat.rs b/rust-executor/src/api/openai_compat/chat.rs new file mode 100644 index 000000000..34388653e --- /dev/null +++ b/rust-executor/src/api/openai_compat/chat.rs @@ -0,0 +1,294 @@ +//! `POST /v1/chat/completions` and `POST /v1/completions`. +//! +//! Both endpoints translate an OpenAI request into an ephemeral +//! `AIService::prompt_messages{,_stream}` call. No DB-backed task is +//! created; the model thread spawns the task in-memory for the duration +//! of the call. + +use std::convert::Infallible; +use std::time::SystemTime; + +use axum::{ + response::{sse::Event, IntoResponse, Sse}, + Json, +}; +use futures::Stream; +use uuid::Uuid; + +use super::errors::{OpenAIError, OpenAIResult}; +use super::model_selector::resolve_model; +use super::types::{ + ChatChoice, ChatChunkChoice, ChatChunkDelta, ChatCompletionChunk, ChatCompletionRequest, + ChatCompletionResponse, ChatResponseMessage, CompletionChoice, CompletionRequest, + CompletionResponse, Role, Usage, +}; +use crate::agent::capabilities::{check_capability, AI_PROMPT_CAPABILITY}; +use crate::ai_service::AIService; +use crate::api::auth::AuthContext; +use crate::billing::{bill_compute, BillingError}; +use crate::types::ModelType; + +/// `POST /v1/chat/completions` — handles both streaming (`stream: true`) +/// and non-streaming responses. +pub async fn chat_completions( + auth: AuthContext, + Json(req): Json, +) -> Result { + check_capability(&auth.capabilities, &AI_PROMPT_CAPABILITY).map_err(OpenAIError::forbidden)?; + + // Resolve the OpenAI `model` string to an AD4M model_id. + let model_id = resolve_model(&req.model, ModelType::Llm).await?; + + // Flatten messages to (role, content) pairs. + let messages: Vec<(String, String)> = req + .messages + .iter() + .map(|m| { + ( + role_to_str(&m.role).to_string(), + m.content.flatten_to_text(), + ) + }) + .collect(); + + if req.stream { + chat_stream(auth, req.model.clone(), model_id, messages).await + } else { + chat_oneshot(auth, req.model.clone(), model_id, messages).await + } +} + +/// `POST /v1/completions` (legacy text-completion). Treats `prompt` as a +/// single user message with no system prompt. +pub async fn completions( + auth: AuthContext, + Json(req): Json, +) -> OpenAIResult> { + check_capability(&auth.capabilities, &AI_PROMPT_CAPABILITY).map_err(OpenAIError::forbidden)?; + + let model_id = resolve_model(&req.model, ModelType::Llm).await?; + let messages = vec![("user".to_string(), req.prompt.first())]; + + let service = AIService::global_instance() + .await + .map_err(|e| OpenAIError::internal(e.to_string()))?; + let result = service + .prompt_messages(model_id, messages) + .await + .map_err(|e| OpenAIError::internal(e.to_string()))?; + + if let Some(email) = user_email(&auth) { + if let Err(BillingError::InsufficientCredits) = + bill_compute(&email, 1.0, "ai_prompt", Some("v1/completions")) + { + return Err(OpenAIError::insufficient_quota( + "Insufficient compute credits", + )); + } + } + + Ok(Json(CompletionResponse { + id: format!("cmpl-{}", Uuid::new_v4()), + object: "text_completion", + created: epoch_seconds(), + model: req.model, + choices: vec![CompletionChoice { + index: 0, + text: result.text, + finish_reason: "stop", + }], + usage: Usage { + prompt_tokens: result.prompt_tokens as u64, + completion_tokens: result.completion_tokens as u64, + total_tokens: (result.prompt_tokens + result.completion_tokens) as u64, + }, + })) +} + +async fn chat_oneshot( + auth: AuthContext, + requested_model: String, + model_id: String, + messages: Vec<(String, String)>, +) -> Result { + let service = AIService::global_instance() + .await + .map_err(|e| OpenAIError::internal(e.to_string()))?; + let result = service + .prompt_messages(model_id, messages) + .await + .map_err(|e| OpenAIError::internal(e.to_string()))?; + + if let Some(email) = user_email(&auth) { + if let Err(BillingError::InsufficientCredits) = + bill_compute(&email, 1.0, "ai_prompt", Some("v1/chat/completions")) + { + return Err(OpenAIError::insufficient_quota( + "Insufficient compute credits", + )); + } + } + + let body = ChatCompletionResponse { + id: format!("chatcmpl-{}", Uuid::new_v4()), + object: "chat.completion", + created: epoch_seconds(), + model: requested_model, + choices: vec![ChatChoice { + index: 0, + message: ChatResponseMessage { + role: "assistant", + content: result.text, + }, + finish_reason: "stop", + }], + usage: Usage { + prompt_tokens: result.prompt_tokens as u64, + completion_tokens: result.completion_tokens as u64, + total_tokens: (result.prompt_tokens + result.completion_tokens) as u64, + }, + }; + Ok(Json(body).into_response()) +} + +async fn chat_stream( + auth: AuthContext, + requested_model: String, + model_id: String, + messages: Vec<(String, String)>, +) -> Result { + let service = AIService::global_instance() + .await + .map_err(|e| OpenAIError::internal(e.to_string()))?; + let (token_rx, done_rx) = service + .prompt_messages_stream(model_id, messages) + .await + .map_err(|e| OpenAIError::internal(e.to_string()))?; + + let id = format!("chatcmpl-{}", Uuid::new_v4()); + let created = epoch_seconds(); + let stream_model = requested_model.clone(); + + // We construct the SSE stream by spawning a forwarder task that + // drains tokens from the LLM thread + emits one `Event` per chunk + // into a bounded channel. axum's `Sse` consumes the resulting + // receiver. This avoids pulling in `async-stream` for the sole + // sake of one `yield`-style generator. + let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::>(); + let auth_clone = auth.clone(); + + // Initial role event. + let role_chunk = ChatCompletionChunk { + id: id.clone(), + object: "chat.completion.chunk", + created, + model: stream_model.clone(), + choices: vec![ChatChunkChoice { + index: 0, + delta: ChatChunkDelta { + role: Some("assistant"), + content: None, + }, + finish_reason: None, + }], + }; + let _ = event_tx.send(Ok( + Event::default().data(serde_json::to_string(&role_chunk).unwrap()) + )); + + tokio::spawn({ + let event_tx = event_tx.clone(); + let stream_model = stream_model.clone(); + let id = id.clone(); + async move { + let mut token_rx = token_rx; + + while let Some(token) = token_rx.recv().await { + let chunk = ChatCompletionChunk { + id: id.clone(), + object: "chat.completion.chunk", + created, + model: stream_model.clone(), + choices: vec![ChatChunkChoice { + index: 0, + delta: ChatChunkDelta { + role: None, + content: Some(token), + }, + finish_reason: None, + }], + }; + if event_tx + .send(Ok( + Event::default().data(serde_json::to_string(&chunk).unwrap()) + )) + .is_err() + { + return; + } + } + + // Final event with finish_reason. + let final_chunk = ChatCompletionChunk { + id: id.clone(), + object: "chat.completion.chunk", + created, + model: stream_model.clone(), + choices: vec![ChatChunkChoice { + index: 0, + delta: ChatChunkDelta::default(), + finish_reason: Some("stop"), + }], + }; + let _ = event_tx.send(Ok( + Event::default().data(serde_json::to_string(&final_chunk).unwrap()) + )); + + // Billing — best-effort, charged once per completed stream. + // Per-token billing requires tokenizer-exact counts which the + // Kalosm backend doesn't expose today; a flat charge matches + // the non-stream `chat.completions` policy. + if let Some(email) = user_email(&auth_clone) { + let _ = bill_compute( + &email, + 1.0, + "ai_prompt", + Some("v1/chat/completions[stream]"), + ); + } + let _ = done_rx.await; + + // OpenAI SSE terminator. + let _ = event_tx.send(Ok(Event::default().data("[DONE]"))); + } + }); + + let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(event_rx); + Ok(Sse::new(stream).into_response()) +} + +fn role_to_str(role: &Role) -> &'static str { + match role { + Role::System => "system", + Role::User => "user", + Role::Assistant => "assistant", + Role::Tool => "tool", + Role::Function => "function", + Role::Developer => "developer", + } +} + +fn epoch_seconds() -> i64 { + SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .map(|d| d.as_secs() as i64) + .unwrap_or(0) +} + +fn user_email(auth: &AuthContext) -> Option { + crate::agent::capabilities::user_email_from_token(auth.auth_token.clone()) +} + +// Re-export a Stream alias for documentation purposes. +#[allow(dead_code)] +type EventStream = dyn Stream> + Send; diff --git a/rust-executor/src/api/openai_compat/embeddings.rs b/rust-executor/src/api/openai_compat/embeddings.rs new file mode 100644 index 000000000..850b71728 --- /dev/null +++ b/rust-executor/src/api/openai_compat/embeddings.rs @@ -0,0 +1,73 @@ +//! `POST /v1/embeddings`. +//! +//! Returns raw `Vec` arrays per the OpenAI spec — the zlib+b64 +//! wire format used by the native WS `ai.embed` path stays exclusive to +//! AD4M-native clients that consume it for bandwidth reasons. External +//! SDKs (`openai`, LangChain, …) expect plain JSON numbers and that's +//! what they get here. + +use axum::Json; + +use super::errors::{OpenAIError, OpenAIResult}; +use super::model_selector::resolve_model; +use super::types::{EmbeddingItem, EmbeddingRequest, EmbeddingResponse, EmbeddingUsage}; +use crate::agent::capabilities::{check_capability, AI_PROMPT_CAPABILITY}; +use crate::ai_service::AIService; +use crate::api::auth::AuthContext; +use crate::billing::{bill_compute, BillingError}; +use crate::types::ModelType; + +pub async fn embeddings( + auth: AuthContext, + Json(req): Json, +) -> OpenAIResult> { + check_capability(&auth.capabilities, &AI_PROMPT_CAPABILITY).map_err(OpenAIError::forbidden)?; + + let model_id = resolve_model(&req.model, ModelType::Embedding).await?; + let model_id_response = req.model.clone(); + let inputs = req.input.into_vec(); + + let service = AIService::global_instance() + .await + .map_err(|e| OpenAIError::internal(e.to_string()))?; + + let mut data: Vec = Vec::with_capacity(inputs.len()); + let mut total_tokens: u64 = 0; + + for (index, text) in inputs.into_iter().enumerate() { + let result = service + .embed(model_id.clone(), text) + .await + .map_err(|e| OpenAIError::internal(e.to_string()))?; + total_tokens += result.token_count as u64; + data.push(EmbeddingItem { + object: "embedding", + index, + embedding: result.embeddings, + }); + } + + // Billing — one credit per batch (matches the native ai.embed + // operation label). Per-token billing on embeddings is future work + // pending tokenizer-exact counts. + if let Some(email) = crate::agent::capabilities::user_email_from_token(auth.auth_token.clone()) + { + if let Err(BillingError::InsufficientCredits) = + bill_compute(&email, 1.0, "ai_embedding", Some("v1/embeddings")) + { + return Err(OpenAIError::insufficient_quota( + "Insufficient compute credits", + )); + } + } + + Ok(Json(EmbeddingResponse { + object: "list", + data, + model: model_id_response, + usage: EmbeddingUsage { + prompt_tokens: total_tokens, + total_tokens, + }, + })) +} diff --git a/rust-executor/src/api/openai_compat/errors.rs b/rust-executor/src/api/openai_compat/errors.rs new file mode 100644 index 000000000..e2b303012 --- /dev/null +++ b/rust-executor/src/api/openai_compat/errors.rs @@ -0,0 +1,128 @@ +//! OpenAI-shaped error envelope. +//! +//! Every fallible handler in this module returns [`OpenAIError`] so the +//! response always matches the spec: +//! +//! ```json +//! { "error": { "message": "...", "type": "...", "param": null, "code": "..." } } +//! ``` + +use axum::{ + http::StatusCode, + response::{IntoResponse, Response}, + Json, +}; +use serde::Serialize; + +#[derive(Debug, Serialize)] +pub struct OpenAIError { + #[serde(skip)] + pub status: StatusCode, + pub error: OpenAIErrorBody, +} + +#[derive(Debug, Serialize)] +pub struct OpenAIErrorBody { + pub message: String, + #[serde(rename = "type")] + pub error_type: String, + pub param: Option, + pub code: Option, +} + +impl OpenAIError { + pub fn new( + status: StatusCode, + message: impl Into, + error_type: impl Into, + code: Option<&str>, + ) -> Self { + Self { + status, + error: OpenAIErrorBody { + message: message.into(), + error_type: error_type.into(), + param: None, + code: code.map(|s| s.to_string()), + }, + } + } + + pub fn invalid_request(message: impl Into) -> Self { + Self::new( + StatusCode::BAD_REQUEST, + message, + "invalid_request_error", + Some("invalid_request"), + ) + } + + pub fn unauthorized(message: impl Into) -> Self { + Self::new( + StatusCode::UNAUTHORIZED, + message, + "invalid_request_error", + Some("invalid_api_key"), + ) + } + + pub fn forbidden(message: impl Into) -> Self { + Self::new( + StatusCode::FORBIDDEN, + message, + "invalid_request_error", + Some("insufficient_permissions"), + ) + } + + pub fn insufficient_quota(message: impl Into) -> Self { + // OpenAI returns 429 for quota exhaustion; we follow suit so + // libraries that retry on 429 do the right thing. + Self::new( + StatusCode::TOO_MANY_REQUESTS, + message, + "insufficient_quota", + Some("insufficient_quota"), + ) + } + + pub fn not_found(message: impl Into) -> Self { + Self::new( + StatusCode::NOT_FOUND, + message, + "invalid_request_error", + Some("not_found"), + ) + } + + pub fn internal(message: impl Into) -> Self { + Self::new( + StatusCode::INTERNAL_SERVER_ERROR, + message, + "server_error", + None, + ) + } + + pub fn not_implemented(message: impl Into) -> Self { + Self::new( + StatusCode::NOT_IMPLEMENTED, + message, + "server_error", + Some("not_implemented"), + ) + } +} + +impl IntoResponse for OpenAIError { + fn into_response(self) -> Response { + ( + self.status, + Json(serde_json::json!({ "error": self.error })), + ) + .into_response() + } +} + +/// Result alias used by all handler functions in this module. +pub type OpenAIResult = Result; diff --git a/rust-executor/src/api/openai_compat/mod.rs b/rust-executor/src/api/openai_compat/mod.rs new file mode 100644 index 000000000..1f01a0787 --- /dev/null +++ b/rust-executor/src/api/openai_compat/mod.rs @@ -0,0 +1,36 @@ +//! OpenAI-compatible HTTP/WS API surface mounted at `/v1`. +//! +//! Translates the OpenAI JSON schema to and from the existing +//! [`crate::ai_service::AIService`]. The native WS-RPC `ai.*` methods are +//! unchanged; this module is purely additive so existing clients keep +//! working. +//! +//! Surface covered (matches the proposal under +//! `~/workspaces/coasys/.specs/PROPOSAL_AI_OPENAI_COMPATIBLE_ENDPOINT.md`): +//! +//! * `GET /v1/models` — list registered models +//! * `POST /v1/chat/completions` — stateless chat, with optional SSE streaming +//! * `POST /v1/completions` — legacy text-completion shim +//! * `POST /v1/embeddings` — raw `f32` arrays (no zlib+b64) +//! * `POST /v1/audio/transcriptions` — batch multipart whisper +//! * `POST /v1/audio/speech` — TTS (remote OpenAI-compatible +//! passthrough; local TTS backend is +//! scoped to a follow-up) +//! * `GET /v1/realtime` — WS, OpenAI Realtime-style streaming STT +//! +//! Auth + billing reuse the existing JWT extractor and `bill_compute`. +//! Errors are wrapped in the canonical +//! `{ "error": { message, type, param, code } }` envelope. + +pub mod audio; +pub mod chat; +pub mod embeddings; +pub mod errors; +pub mod model_selector; +pub mod models; +pub mod realtime; +pub mod router; +pub mod tts_passthrough; +pub mod types; + +pub use router::router; diff --git a/rust-executor/src/api/openai_compat/model_selector.rs b/rust-executor/src/api/openai_compat/model_selector.rs new file mode 100644 index 000000000..8ddd6b8d0 --- /dev/null +++ b/rust-executor/src/api/openai_compat/model_selector.rs @@ -0,0 +1,78 @@ +//! Resolve an OpenAI-style `model` string to a registered AD4M model id. +//! +//! Order of resolution (matches the spec §4.4): +//! +//! 1. Exact match against any registered model's `id`. +//! 2. Exact match against any registered model's `name`. +//! 3. The literal `"default"` — returns the configured default for the +//! requested [`ModelType`]. +//! +//! When nothing matches we return a 404 so the client sees the OpenAI +//! `model_not_found` shape. + +use super::errors::{OpenAIError, OpenAIResult}; +use crate::db::Ad4mDb; +use crate::types::ModelType; + +pub async fn resolve_model(requested: &str, expected_type: ModelType) -> OpenAIResult { + if requested.is_empty() { + return Err(OpenAIError::invalid_request( + "Missing required parameter: model", + )); + } + + // Resolve `default` via the DB's per-type pointer. This mirrors + // `AIService::replace_model_variables` so OpenAI callers can use the + // same shorthand as native AD4M ones. + if requested == "default" { + return Ad4mDb::with_global_instance(|db| db.get_default_model(expected_type.clone())) + .map_err(|e| OpenAIError::internal(format!("Database error: {e}")))? + .ok_or_else(|| { + OpenAIError::not_found(format!( + "No default {} model configured on this executor", + type_label(&expected_type) + )) + }); + } + + let models = Ad4mDb::with_global_instance(|db| db.get_models()) + .map_err(|e| OpenAIError::internal(format!("Database error listing models: {e}")))?; + + // Pass 1: exact id match. + if let Some(m) = models.iter().find(|m| m.id == requested) { + if m.model_type != expected_type { + return Err(OpenAIError::invalid_request(format!( + "Model '{}' is a {} model; this endpoint requires a {}.", + requested, + type_label(&m.model_type), + type_label(&expected_type), + ))); + } + return Ok(m.id.clone()); + } + + // Pass 2: exact name match (case-sensitive — same as OpenAI). + if let Some(m) = models.iter().find(|m| m.name == requested) { + if m.model_type != expected_type { + return Err(OpenAIError::invalid_request(format!( + "Model '{}' is a {} model; this endpoint requires a {}.", + requested, + type_label(&m.model_type), + type_label(&expected_type), + ))); + } + return Ok(m.id.clone()); + } + + Err(OpenAIError::not_found(format!( + "Model '{requested}' not found. List registered models with GET /v1/models." + ))) +} + +fn type_label(t: &ModelType) -> &'static str { + match t { + ModelType::Llm => "LLM", + ModelType::Embedding => "embedding", + ModelType::Transcription => "transcription", + } +} diff --git a/rust-executor/src/api/openai_compat/models.rs b/rust-executor/src/api/openai_compat/models.rs new file mode 100644 index 000000000..3ab155388 --- /dev/null +++ b/rust-executor/src/api/openai_compat/models.rs @@ -0,0 +1,48 @@ +//! `GET /v1/models` — list every model registered with the executor. + +use axum::Json; + +use super::errors::{OpenAIError, OpenAIResult}; +use super::types::{ModelExtensions, ModelInfo, ModelListResponse}; +use crate::agent::capabilities::{check_capability, AI_READ_CAPABILITY}; +use crate::api::auth::AuthContext; +use crate::db::Ad4mDb; +use crate::types::Model; + +pub async fn list_models(auth: AuthContext) -> OpenAIResult> { + check_capability(&auth.capabilities, &AI_READ_CAPABILITY).map_err(OpenAIError::forbidden)?; + + let models = Ad4mDb::with_global_instance(|db| db.get_models()) + .map_err(|e| OpenAIError::internal(format!("Database error listing models: {e}")))?; + + let data = models.into_iter().map(model_to_info).collect(); + Ok(Json(ModelListResponse { + object: "list", + data, + })) +} + +/// Map an AD4M [`Model`] to an OpenAI-shaped [`ModelInfo`]. Local vs remote +/// backends are surfaced via the `ad4m.backend` extension so callers that +/// want to prefer local inference can filter on it. +fn model_to_info(m: Model) -> ModelInfo { + let backend = if m.api.is_some() { + Some("remote".to_string()) + } else if m.local.is_some() { + Some("local".to_string()) + } else { + None + }; + let model_type = format!("{}", m.model_type).to_lowercase(); + ModelInfo { + id: m.id, + object: "model", + created: 0, + owned_by: "ad4m", + extensions: ModelExtensions { + model_type, + name: m.name, + backend, + }, + } +} diff --git a/rust-executor/src/api/openai_compat/realtime.rs b/rust-executor/src/api/openai_compat/realtime.rs new file mode 100644 index 000000000..4a7d171a3 --- /dev/null +++ b/rust-executor/src/api/openai_compat/realtime.rs @@ -0,0 +1,251 @@ +//! `GET /v1/realtime` — OpenAI Realtime-style WebSocket for streaming +//! transcription. +//! +//! Protocol (subset of the OpenAI Realtime spec): +//! +//! * **Client → server** +//! - `{"type":"transcription_session.update","session":{"model":""}}` — +//! opens a Whisper session. Must arrive before any audio. +//! - `{"type":"input_audio_buffer.append","audio":""}` — +//! 16 kHz mono `f32` LE samples, base64-encoded. +//! - `{"type":"input_audio_buffer.commit"}` — flush any pending +//! segment. Optional; idle timeout flushes too. +//! - `{"type":"session.close"}` — close the session and the socket. +//! +//! * **Server → client** +//! - `{"type":"transcription_session.created","session_id":""}` +//! - `{"type":"conversation.item.input_audio_transcription.delta","delta":""}` +//! - `{"type":"conversation.item.input_audio_transcription.completed","transcript":""}` +//! emitted on each finalised whisper segment (the existing engine +//! emits one event per segment via the broadcast channel). +//! - `{"type":"error","error":{"message":"...","code":"..."}}` +//! +//! Auth: the WS extractor reads the bearer JWT from the +//! `Authorization` header *or* a `?token=` query param (per +//! `AuthContext`). Capabilities checked: `AI_TRANSCRIBE`. + +use axum::{ + extract::ws::{Message, WebSocket, WebSocketUpgrade}, + response::Response, +}; +use base64::Engine; +use futures::SinkExt; +use serde_json::{json, Value}; + +use super::model_selector::resolve_model; +use crate::agent::capabilities::{check_capability, AI_TRANSCRIBE_CAPABILITY}; +use crate::ai_service::AIService; +use crate::api::auth::AuthContext; +use crate::types::ModelType; + +pub async fn realtime_ws(auth: AuthContext, ws: WebSocketUpgrade) -> Response { + ws.on_upgrade(move |socket| handle_socket(auth, socket)) +} + +async fn handle_socket(auth: AuthContext, mut socket: WebSocket) { + // Capability check happens after upgrade — there's no clean way to + // reject a 403 mid-handshake under axum's `on_upgrade`, so we emit + // an `error` envelope and close. + if let Err(e) = check_capability(&auth.capabilities, &AI_TRANSCRIBE_CAPABILITY) { + let _ = send_error(&mut socket, "forbidden", &e).await; + let _ = socket.close().await; + return; + } + + let mut session_state: Option = None; + + while let Some(msg) = socket.recv().await { + let msg = match msg { + Ok(m) => m, + Err(_) => break, + }; + let text = match msg { + Message::Text(t) => t.to_string(), + Message::Binary(_) => { + let _ = send_error( + &mut socket, + "invalid_request", + "Binary frames not supported", + ) + .await; + continue; + } + Message::Close(_) => break, + // Ping/pong handled by axum automatically. + _ => continue, + }; + let parsed: Value = match serde_json::from_str(&text) { + Ok(v) => v, + Err(e) => { + let _ = + send_error(&mut socket, "invalid_request", &format!("JSON parse: {e}")).await; + continue; + } + }; + let event_type = parsed + .get("type") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + + match event_type.as_str() { + "transcription_session.update" => { + let model_str = parsed + .get("session") + .and_then(|s| s.get("model")) + .and_then(|m| m.as_str()) + .unwrap_or("default"); + let model_id = match resolve_model(model_str, ModelType::Transcription).await { + Ok(id) => id, + Err(e) => { + let _ = send_error(&mut socket, "invalid_request", &e.error.message).await; + continue; + } + }; + let service = match AIService::global_instance().await { + Ok(s) => s, + Err(e) => { + let _ = send_error(&mut socket, "server_error", &e.to_string()).await; + continue; + } + }; + let stream_id = match service + .open_transcription_stream(model_id.clone(), None, auth.auth_token.clone()) + .await + { + Ok(id) => id, + Err(e) => { + let _ = send_error(&mut socket, "server_error", &e.to_string()).await; + continue; + } + }; + let created = json!({ + "type": "transcription_session.created", + "session_id": stream_id, + "model": model_id, + }); + let _ = socket.send(Message::Text(created.to_string().into())).await; + session_state = Some(SessionState { stream_id }); + } + "input_audio_buffer.append" => { + let Some(ref state) = session_state else { + let _ = send_error( + &mut socket, + "invalid_request", + "Send `transcription_session.update` before appending audio.", + ) + .await; + continue; + }; + let audio_b64 = match parsed.get("audio").and_then(|v| v.as_str()) { + Some(s) => s, + None => { + let _ = send_error( + &mut socket, + "invalid_request", + "Missing required field: audio", + ) + .await; + continue; + } + }; + let bytes = match base64::prelude::BASE64_STANDARD.decode(audio_b64) { + Ok(b) => b, + Err(e) => { + let _ = send_error(&mut socket, "invalid_request", &format!("base64: {e}")) + .await; + continue; + } + }; + let samples = bytes_to_f32_le(&bytes); + + let service = AIService::global_instance().await.ok(); + if let Some(service) = service { + let rx = service + .feed_transcription_stream_with_broadcast( + &state.stream_id, + samples, + &auth.auth_token, + ) + .await; + if let Ok(mut rx) = rx { + // Drain any whisper deltas that became + // available since the last append. We do this + // synchronously on each append rather than + // running a background forwarder — keeps the + // socket lifetime obvious and avoids races at + // close time. + while let Ok(text) = rx.try_recv() { + let delta = json!({ + "type": "conversation.item.input_audio_transcription.delta", + "delta": text, + }); + let _ = socket.send(Message::Text(delta.to_string().into())).await; + } + } + } + } + "input_audio_buffer.commit" => { + // No-op for now — the whisper engine flushes on its own + // VAD detection. Forward the no-op as `completed` so + // OpenAI clients that wait for it can proceed; if + // there's nothing to flush, the `completed` carries an + // empty transcript. + let completed = json!({ + "type": "conversation.item.input_audio_transcription.completed", + "transcript": "", + }); + let _ = socket + .send(Message::Text(completed.to_string().into())) + .await; + } + "session.close" => { + break; + } + _ => { + let _ = send_error( + &mut socket, + "invalid_request", + &format!("Unsupported event type: {event_type}"), + ) + .await; + } + } + } + + if let Some(state) = session_state { + if let Ok(service) = AIService::global_instance().await { + let _ = service + .close_transcription_stream(&state.stream_id, &auth.auth_token) + .await; + } + } + let _ = socket.close().await; +} + +struct SessionState { + stream_id: String, +} + +async fn send_error(socket: &mut WebSocket, code: &str, message: &str) -> Result<(), ()> { + let body = json!({ + "type": "error", + "error": { + "message": message, + "code": code, + } + }); + socket + .send(Message::Text(body.to_string().into())) + .await + .map_err(|_| ()) +} + +/// Reinterpret a raw byte slice as little-endian `f32` samples. +fn bytes_to_f32_le(bytes: &[u8]) -> Vec { + let mut out = Vec::with_capacity(bytes.len() / 4); + for chunk in bytes.chunks_exact(4) { + out.push(f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])); + } + out +} diff --git a/rust-executor/src/api/openai_compat/router.rs b/rust-executor/src/api/openai_compat/router.rs new file mode 100644 index 000000000..fe4d91a21 --- /dev/null +++ b/rust-executor/src/api/openai_compat/router.rs @@ -0,0 +1,35 @@ +//! axum router for the OpenAI-compatible surface. +//! +//! Mount under `/v1` from the top-level [`api_router`](super::super::api_router). +//! All endpoints share the existing [`AppState`](crate::api::auth::AppState) +//! so JWT extraction, capabilities, and admin credentials work the same +//! way as the native WS-RPC surface. + +use axum::{ + extract::DefaultBodyLimit, + routing::{get, post}, + Router, +}; + +use super::{audio, chat, embeddings, models, realtime}; +use crate::api::auth::AppState; + +/// Build the `/v1` router. Mounted by [`super::super::api_router`]. +/// +/// The router is intentionally minimal — every endpoint delegates to a +/// dedicated handler module so reviewers can audit one phase at a time. +/// Body limit is bumped on audio routes to accommodate file uploads +/// (defaults to 2 MiB in axum, which is small for whisper inputs). +pub fn router() -> Router { + Router::new() + .route("/models", get(models::list_models)) + .route("/chat/completions", post(chat::chat_completions)) + .route("/completions", post(chat::completions)) + .route("/embeddings", post(embeddings::embeddings)) + .route( + "/audio/transcriptions", + post(audio::transcriptions).layer(DefaultBodyLimit::max(50 * 1024 * 1024)), + ) + .route("/audio/speech", post(audio::speech)) + .route("/realtime", get(realtime::realtime_ws)) +} diff --git a/rust-executor/src/api/openai_compat/tts_passthrough.rs b/rust-executor/src/api/openai_compat/tts_passthrough.rs new file mode 100644 index 000000000..c996d9f57 --- /dev/null +++ b/rust-executor/src/api/openai_compat/tts_passthrough.rs @@ -0,0 +1,126 @@ +//! Remote-passthrough TTS client. +//! +//! The proposal allows three TTS backends behind a common +//! `AIService::synthesize` trait: local Kokoro, local Piper, or a +//! passthrough to an OpenAI-compatible upstream. This module implements +//! the passthrough path; the local backends are scoped to a follow-up +//! (they add `ort` + voice model assets, both of which deserve their own +//! review). +//! +//! Configuration: we look for a TTS-typed model in the existing model +//! registry whose `api` field carries an `OPEN_AI` upstream. When +//! `model_type == Llm` we currently misuse — see TODO below — the +//! configured remote LLM API key for TTS; once `ModelType::TextToSpeech` +//! lands, this module reads from the proper TTS-typed registration. +//! +//! No new dependencies: we reuse the existing `reqwest` client. + +use crate::api::auth::AuthContext; +use crate::db::Ad4mDb; +use crate::types::{ModelApiType, ModelType}; + +use super::types::SpeechRequest; + +#[derive(Debug)] +pub enum TtsPassthroughError { + /// No TTS-capable upstream is configured on this executor. + NotConfigured, + Upstream(String), +} + +/// Forward a `/v1/audio/speech` request to a configured upstream and +/// return the raw audio bytes. Returns [`TtsPassthroughError::NotConfigured`] +/// when no upstream is registered. +pub async fn synthesize( + _auth: &AuthContext, + req: &SpeechRequest, +) -> Result, TtsPassthroughError> { + let upstream = find_tts_upstream(&req.model)?; + + let body = serde_json::json!({ + "model": upstream.upstream_model, + "input": req.input, + "voice": req.voice.clone().unwrap_or_else(|| "alloy".to_string()), + "response_format": req.response_format.clone().unwrap_or_else(|| "mp3".to_string()), + "speed": req.speed.unwrap_or(1.0), + }); + + let url = format!( + "{}/audio/speech", + upstream + .base_url + .trim_end_matches('/') + .trim_end_matches("/v1"), + ); + + let client = reqwest::Client::new(); + let resp = client + .post(format!("{url}/v1/audio/speech")) + .bearer_auth(&upstream.api_key) + .json(&body) + .send() + .await + .map_err(|e| TtsPassthroughError::Upstream(e.to_string()))?; + + if !resp.status().is_success() { + let status = resp.status(); + let text = resp.text().await.unwrap_or_default(); + return Err(TtsPassthroughError::Upstream(format!( + "Upstream returned {status}: {text}" + ))); + } + + let bytes = resp + .bytes() + .await + .map_err(|e| TtsPassthroughError::Upstream(e.to_string()))?; + Ok(bytes.to_vec()) +} + +struct TtsUpstream { + base_url: String, + api_key: String, + upstream_model: String, +} + +/// Look up a TTS-capable upstream in the model registry. +/// +/// We don't have `ModelType::TextToSpeech` yet (that lands with the +/// local Kokoro PR), so the resolution rule is: any LLM-typed `Model` +/// whose name matches `requested_model` AND has an `OPEN_AI` API +/// upstream is treated as a TTS-capable forwarder. This makes the +/// passthrough usable today without a schema migration; once the +/// `ModelType::TextToSpeech` enum lands, swap the filter to that type +/// and remove this comment. +fn find_tts_upstream(requested_model: &str) -> Result { + let models = Ad4mDb::with_global_instance(|db| db.get_models()) + .map_err(|_| TtsPassthroughError::NotConfigured)?; + + // Prefer an exact name match (callers send the upstream model + // string verbatim); fall back to the configured "default" remote + // LLM if none matches. + let candidate = models + .iter() + .find(|m| m.model_type == ModelType::Llm && m.api.is_some() && m.name == requested_model) + .or_else(|| { + models.iter().find(|m| { + m.model_type == ModelType::Llm + && m.api + .as_ref() + .map(|a| matches!(a.api_type, ModelApiType::OpenAi)) + .unwrap_or(false) + }) + }) + .ok_or(TtsPassthroughError::NotConfigured)?; + + let api = candidate + .api + .as_ref() + .ok_or(TtsPassthroughError::NotConfigured)?; + + Ok(TtsUpstream { + base_url: api.base_url.to_string(), + api_key: api.api_key.clone(), + upstream_model: requested_model.to_string(), + }) +} diff --git a/rust-executor/src/api/openai_compat/types.rs b/rust-executor/src/api/openai_compat/types.rs new file mode 100644 index 000000000..69172611e --- /dev/null +++ b/rust-executor/src/api/openai_compat/types.rs @@ -0,0 +1,317 @@ +//! OpenAI request/response schemas. +//! +//! All field names are spelled exactly as the OpenAI HTTP API expects — +//! that's the whole point of this module. Internal AD4M-specific names +//! go behind serde aliases or under the `ad4m` extension namespace. + +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +// --------------------------------------------------------------------------- +// Common +// --------------------------------------------------------------------------- + +/// Standard OpenAI usage object. All token counts are best-effort: the +/// local Kalosm backend reports `chars/4` estimates today; remote upstreams +/// return exact counts which we forward when available. +#[derive(Debug, Clone, Serialize)] +pub struct Usage { + pub prompt_tokens: u64, + pub completion_tokens: u64, + pub total_tokens: u64, +} + +// --------------------------------------------------------------------------- +// /v1/models +// --------------------------------------------------------------------------- + +#[derive(Debug, Serialize)] +pub struct ModelListResponse { + pub object: &'static str, // "list" + pub data: Vec, +} + +#[derive(Debug, Serialize)] +pub struct ModelInfo { + pub id: String, + pub object: &'static str, // "model" + pub created: i64, // unix epoch seconds (we report 0 for static models) + pub owned_by: &'static str, // "ad4m" + /// AD4M-specific extension fields. Tools that don't know about them + /// ignore them; the proposal's spec carries `ad4m.model_type` here. + #[serde(rename = "ad4m")] + pub extensions: ModelExtensions, +} + +#[derive(Debug, Serialize)] +pub struct ModelExtensions { + pub model_type: String, // "llm" | "embedding" | "transcription" | "tts" + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub backend: Option, // "local" | "remote" | "passthrough" +} + +// --------------------------------------------------------------------------- +// /v1/chat/completions +// --------------------------------------------------------------------------- + +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum Role { + System, + User, + Assistant, + /// We don't process tool/function messages today — accepted but + /// ignored in the prompt assembly so the request doesn't reject. + Tool, + Function, + Developer, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct ChatMessage { + pub role: Role, + pub content: ChatMessageContent, + #[serde(default)] + pub name: Option, +} + +/// OpenAI accepts either a string or an array of content parts. For now we +/// flatten parts to their text and ignore image/audio inputs — those are +/// out of scope for the first pass. +#[derive(Debug, Clone, Deserialize)] +#[serde(untagged)] +pub enum ChatMessageContent { + Text(String), + Parts(Vec), +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ContentPart { + Text { + text: String, + }, + /// Image input — silently dropped by the prompt assembler today. We + /// keep parsing so clients sending mixed content don't 400. + ImageUrl { + image_url: Value, + }, + /// Same — accepted, not consumed. + InputAudio { + input_audio: Value, + }, +} + +impl ChatMessageContent { + pub fn flatten_to_text(&self) -> String { + match self { + ChatMessageContent::Text(s) => s.clone(), + ChatMessageContent::Parts(parts) => parts + .iter() + .filter_map(|p| match p { + ContentPart::Text { text } => Some(text.as_str()), + _ => None, + }) + .collect::>() + .join("\n"), + } + } +} + +#[derive(Debug, Deserialize)] +pub struct ChatCompletionRequest { + pub model: String, + pub messages: Vec, + #[serde(default)] + pub stream: bool, + #[serde(default)] + pub temperature: Option, + #[serde(default)] + pub max_tokens: Option, + #[serde(default)] + pub top_p: Option, + #[serde(default)] + pub stop: Option, + #[serde(default)] + pub seed: Option, + #[serde(default)] + pub response_format: Option, + #[serde(default)] + pub user: Option, +} + +#[derive(Debug, Serialize)] +pub struct ChatCompletionResponse { + pub id: String, + pub object: &'static str, // "chat.completion" + pub created: i64, + pub model: String, + pub choices: Vec, + pub usage: Usage, +} + +#[derive(Debug, Serialize)] +pub struct ChatChoice { + pub index: u32, + pub message: ChatResponseMessage, + pub finish_reason: &'static str, // "stop" | "length" | "tool_calls" +} + +#[derive(Debug, Serialize)] +pub struct ChatResponseMessage { + pub role: &'static str, // "assistant" + pub content: String, +} + +// --------------------------------------------------------------------------- +// Streaming chunks (SSE) +// --------------------------------------------------------------------------- + +#[derive(Debug, Serialize)] +pub struct ChatCompletionChunk { + pub id: String, + pub object: &'static str, // "chat.completion.chunk" + pub created: i64, + pub model: String, + pub choices: Vec, +} + +#[derive(Debug, Serialize)] +pub struct ChatChunkChoice { + pub index: u32, + pub delta: ChatChunkDelta, + #[serde(skip_serializing_if = "Option::is_none")] + pub finish_reason: Option<&'static str>, +} + +#[derive(Debug, Default, Serialize)] +pub struct ChatChunkDelta { + #[serde(skip_serializing_if = "Option::is_none")] + pub role: Option<&'static str>, + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option, +} + +// --------------------------------------------------------------------------- +// /v1/completions (legacy) +// --------------------------------------------------------------------------- + +#[derive(Debug, Deserialize)] +pub struct CompletionRequest { + pub model: String, + pub prompt: PromptInput, + #[serde(default)] + pub stream: bool, + #[serde(default)] + pub temperature: Option, + #[serde(default)] + pub max_tokens: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(untagged)] +pub enum PromptInput { + One(String), + Many(Vec), +} + +impl PromptInput { + pub fn first(&self) -> String { + match self { + PromptInput::One(s) => s.clone(), + PromptInput::Many(v) => v.first().cloned().unwrap_or_default(), + } + } +} + +#[derive(Debug, Serialize)] +pub struct CompletionResponse { + pub id: String, + pub object: &'static str, // "text_completion" + pub created: i64, + pub model: String, + pub choices: Vec, + pub usage: Usage, +} + +#[derive(Debug, Serialize)] +pub struct CompletionChoice { + pub index: u32, + pub text: String, + pub finish_reason: &'static str, +} + +// --------------------------------------------------------------------------- +// /v1/embeddings +// --------------------------------------------------------------------------- + +#[derive(Debug, Deserialize)] +pub struct EmbeddingRequest { + pub model: String, + pub input: EmbeddingInput, + #[serde(default)] + pub user: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(untagged)] +pub enum EmbeddingInput { + One(String), + Many(Vec), +} + +impl EmbeddingInput { + pub fn into_vec(self) -> Vec { + match self { + EmbeddingInput::One(s) => vec![s], + EmbeddingInput::Many(v) => v, + } + } +} + +#[derive(Debug, Serialize)] +pub struct EmbeddingResponse { + pub object: &'static str, // "list" + pub data: Vec, + pub model: String, + pub usage: EmbeddingUsage, +} + +#[derive(Debug, Serialize)] +pub struct EmbeddingItem { + pub object: &'static str, // "embedding" + pub index: usize, + pub embedding: Vec, +} + +#[derive(Debug, Serialize)] +pub struct EmbeddingUsage { + pub prompt_tokens: u64, + pub total_tokens: u64, +} + +// --------------------------------------------------------------------------- +// /v1/audio/transcriptions +// --------------------------------------------------------------------------- + +#[derive(Debug, Serialize)] +pub struct TranscriptionResponse { + pub text: String, +} + +// --------------------------------------------------------------------------- +// /v1/audio/speech (TTS) +// --------------------------------------------------------------------------- + +#[derive(Debug, Deserialize)] +pub struct SpeechRequest { + pub model: String, + pub input: String, + #[serde(default)] + pub voice: Option, + #[serde(default)] + pub response_format: Option, + #[serde(default)] + pub speed: Option, +}