diff --git a/.env.template b/.env.template index 7b94da1f..44120db0 100644 --- a/.env.template +++ b/.env.template @@ -352,6 +352,11 @@ # Set base URL to enable (default: http://localhost:11434/v1) # OLLAMA_BASE_URL=http://localhost:11434/v1 +# Alibaba Cloud Bailian (百炼) +# BAILIAN_API_KEY=... +# BAILIAN_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 +# BAILIAN_MODELS=text-embedding-v3,text-embedding-v4 + # vLLM (OpenAI-compatible server) # VLLM_API_KEY is optional; set it only if vllm serve was started with --api-key. # VLLM_API_KEY=token-abc123 diff --git a/.gitignore b/.gitignore index fbe85344..4fbc41a2 100644 --- a/.gitignore +++ b/.gitignore @@ -26,8 +26,14 @@ # Local git worktrees /.worktrees/ +# Local pre-commit hook (not shared) +/.githooks/ + # Others /*.bck.yml /repomix-output.* /coverage.out /.claude/ + +# Superpower design docs and plans (never commit) +/docs/superpowers/ diff --git a/CLAUDE.md b/CLAUDE.md index 6e9f6fdf..ac963e30 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -128,5 +128,5 @@ Full reference: `.env.template` and `config/config.yaml` - **Resilience:** Configured via `config/config.yaml` - global `resilience.retry.*` and `resilience.circuit_breaker.*` defaults with optional per-provider overrides under `providers..resilience.retry.*` and `providers..resilience.circuit_breaker.*`. Retry defaults: `max_retries` (3), `initial_backoff` (1s), `max_backoff` (30s), `backoff_factor` (2.0), `jitter_factor` (0.1). Circuit breaker defaults: `failure_threshold` (5), `success_threshold` (2), `timeout` (30s) - **Metrics:** `METRICS_ENABLED` (false), `METRICS_ENDPOINT` (/metrics) - **Guardrails:** Configured via `config/config.yaml` only (except `GUARDRAILS_ENABLED` env var) -- **Providers:** `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, `GEMINI_API_KEY`, `USE_GOOGLE_GEMINI_NATIVE_API` (true by default; false uses Gemini's OpenAI-compatible chat API), `XAI_API_KEY`, `GROQ_API_KEY`, `OPENROUTER_API_KEY`, `ZAI_API_KEY`, `ZAI_BASE_URL` (optional Z.ai endpoint override), `MINIMAX_API_KEY`, `MINIMAX_BASE_URL` (optional MiniMax endpoint override), `XIAOMI_API_KEY`, `XIAOMI_BASE_URL` (optional Xiaomi MiMo endpoint override), `AZURE_API_KEY`, `AZURE_BASE_URL` (Azure OpenAI deployment base URL), `AZURE_API_VERSION` (optional Azure API version), `ORACLE_API_KEY` (Oracle API key), `ORACLE_BASE_URL` (Oracle OpenAI-compatible base URL), `[_SUFFIX]_MODELS` (comma-separated configured model list for any provider type), `OLLAMA_BASE_URL`, `VLLM_BASE_URL`, `VLLM_API_KEY` (optional upstream vLLM bearer token) +- **Providers:** `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, `GEMINI_API_KEY`, `USE_GOOGLE_GEMINI_NATIVE_API` (true by default; false uses Gemini's OpenAI-compatible chat API), `XAI_API_KEY`, `GROQ_API_KEY`, `OPENROUTER_API_KEY`, `ZAI_API_KEY`, `ZAI_BASE_URL` (optional Z.ai endpoint override), `MINIMAX_API_KEY`, `MINIMAX_BASE_URL` (optional MiniMax endpoint override), `XIAOMI_API_KEY`, `XIAOMI_BASE_URL` (optional Xiaomi MiMo endpoint override), `BAILIAN_API_KEY`, `BAILIAN_BASE_URL` (optional Bailian base URL for region switching; default `https://dashscope.aliyuncs.com/compatible-mode/v1`), `AZURE_API_KEY`, `AZURE_BASE_URL` (Azure OpenAI deployment base URL), `AZURE_API_VERSION` (optional Azure API version), `ORACLE_API_KEY` (Oracle API key), `ORACLE_BASE_URL` (Oracle OpenAI-compatible base URL), `[_SUFFIX]_MODELS` (comma-separated configured model list for any provider type), `OLLAMA_BASE_URL`, `VLLM_BASE_URL`, `VLLM_API_KEY` (optional upstream vLLM bearer token) - **Provider model metadata:** `providers..models` accepts either model IDs (strings) or `{id, metadata}` objects. When `metadata` is supplied (`display_name`, `context_window`, `max_output_tokens`, `modes`, `capabilities`, `pricing`, …) it is merged onto the remote ai-model-list entry during enrichment, with operator values winning per-field. Primary use case: advertising context windows, capabilities, and pricing for local models (Ollama) and other custom endpoints whose IDs are not in the upstream registry. diff --git a/README.md b/README.md index f876f6c6..1b5ee1ff 100644 --- a/README.md +++ b/README.md @@ -100,6 +100,7 @@ Example model identifiers are illustrative and subject to change; consult provid | OpenRouter | `OPENROUTER_API_KEY` | `google/gemini-2.5-flash` | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | Z.ai | `ZAI_API_KEY` (`ZAI_BASE_URL` optional) | `glm-5.1` | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | | xAI (Grok) | `XAI_API_KEY` | `grok-4` | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | +| Alibaba Cloud Bailian | `BAILIAN_API_KEY` (`BAILIAN_BASE_URL` optional) | `qwen3-max` | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | MiniMax | `MINIMAX_API_KEY` (`MINIMAX_BASE_URL` optional) | `MiniMax-M3` | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | | Xiaomi MiMo | `XIAOMI_API_KEY` (`XIAOMI_BASE_URL` optional) | `mimo-v2.5-pro` | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | | Azure OpenAI | `AZURE_API_KEY` + `AZURE_BASE_URL` (`AZURE_API_VERSION` optional) | `gpt-5` | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | diff --git a/cmd/gomodel/docs/docs.go b/cmd/gomodel/docs/docs.go index 219f6d19..ea6f1467 100644 --- a/cmd/gomodel/docs/docs.go +++ b/cmd/gomodel/docs/docs.go @@ -7009,7 +7009,7 @@ var SwaggerInfo = &swag.Spec{ BasePath: "/", Schemes: []string{"http"}, Title: "GoModel API", - Description: "AI gateway routing requests to multiple LLM providers (OpenAI, Anthropic, Gemini, Groq, OpenRouter, DeepSeek, Z.ai, xAI, MiniMax, Xiaomi MiMo, Oracle, Ollama). Drop-in OpenAI-compatible API.", + Description: "AI gateway routing requests to multiple LLM providers (OpenAI, Anthropic, Gemini, Groq, OpenRouter, DeepSeek, Z.ai, xAI, MiniMax, Xiaomi MiMo, Bailian, Oracle, Ollama). Drop-in OpenAI-compatible API.", InfoInstanceName: "swagger", SwaggerTemplate: docTemplate, LeftDelim: "{{", diff --git a/cmd/gomodel/main.go b/cmd/gomodel/main.go index d40b6c53..f65836d3 100644 --- a/cmd/gomodel/main.go +++ b/cmd/gomodel/main.go @@ -18,6 +18,7 @@ import ( "gomodel/internal/providers" "gomodel/internal/providers/anthropic" "gomodel/internal/providers/azure" + "gomodel/internal/providers/bailian" "gomodel/internal/providers/bedrock" "gomodel/internal/providers/deepseek" "gomodel/internal/providers/gemini" @@ -76,7 +77,7 @@ func startApplication(application lifecycleApp, addr string) error { // @title GoModel API // @version 1.0 -// @description AI gateway routing requests to multiple LLM providers (OpenAI, Anthropic, Gemini, Groq, OpenRouter, DeepSeek, Z.ai, xAI, MiniMax, Xiaomi MiMo, Oracle, Ollama). Drop-in OpenAI-compatible API. +// @description AI gateway routing requests to multiple LLM providers (OpenAI, Anthropic, Gemini, Groq, OpenRouter, DeepSeek, Z.ai, xAI, MiniMax, Xiaomi MiMo, Oracle, Ollama, Bailian). Drop-in OpenAI-compatible API. // @BasePath / // @schemes http // @securityDefinitions.apikey BearerAuth @@ -120,6 +121,7 @@ func main() { factory.Add(openai.Registration) factory.Add(openrouter.Registration) factory.Add(azure.Registration) + factory.Add(bailian.Registration) factory.Add(oracle.Registration) factory.Add(anthropic.Registration) factory.Add(bedrock.Registration) diff --git a/config/config.example.yaml b/config/config.example.yaml index 77257c28..a07c5a83 100644 --- a/config/config.example.yaml +++ b/config/config.example.yaml @@ -13,7 +13,7 @@ server: enable_passthrough_routes: true # expose /p/{provider}/{endpoint} passthrough routes allow_passthrough_v1_alias: true # allow /p/{provider}/v1/... while keeping /p/{provider}/... canonical user_path_header: "X-GoModel-User-Path" # env: USER_PATH_HEADER; inbound header used for user_path scoping - enabled_passthrough_providers: ["openai", "anthropic", "openrouter", "zai", "vllm", "deepseek"] # providers enabled on /p/{provider}/... + enabled_passthrough_providers: ["openai", "anthropic", "openrouter", "zai", "vllm", "deepseek", "bailian"] # providers enabled on /p/{provider}/... models: enabled_by_default: true # env: MODELS_ENABLED_BY_DEFAULT; when false, models stay unavailable until an override allows one or more user paths @@ -199,6 +199,15 @@ providers: type: anthropic api_key: "sk-ant-..." + bailian: + type: bailian + api_key: "${BAILIAN_API_KEY}" + # base_url: "https://dashscope.aliyuncs.com/compatible-mode/v1" + # Alternative regions (replace {workspace-id} with your workspace): + # Singapore: "https://{workspace-id}.ap-southeast-1.maas.aliyuncs.com/compatible-mode/v1" + # Frankfurt: "https://{workspace-id}.eu-central-1.maas.aliyuncs.com/compatible-mode/v1" + # Hong Kong: "https://{workspace-id}.cn-hongkong.maas.aliyuncs.com/compatible-mode/v1" + gemini: type: gemini api_key: "..." diff --git a/config/config_test.go b/config/config_test.go index 85305ccd..695e00cc 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -1125,7 +1125,7 @@ func TestLoad_ConfigExample_UsesNestedModelCacheSettings(t *testing.T) { t.Fatalf("expected Cache.Model.Redis to be nil in example config, got %+v", result.Config.Cache.Model.Redis) } gotProviders := result.Config.Server.EnabledPassthroughProviders - wantProviders := []string{"openai", "anthropic", "openrouter", "zai", "vllm", "deepseek"} + wantProviders := []string{"openai", "anthropic", "openrouter", "zai", "vllm", "deepseek", "bailian"} if !reflect.DeepEqual(gotProviders, wantProviders) { t.Fatalf("Server.EnabledPassthroughProviders = %v, want %v", gotProviders, wantProviders) } diff --git a/docs/docs.json b/docs/docs.json index 040b29f0..2cd4e394 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -104,6 +104,7 @@ "providers/anthropic", "providers/gemini", "providers/deepseek", +"providers/bailian", "providers/xiaomi", "providers/vllm", "providers/multiple-ollama", diff --git a/docs/openapi.json b/docs/openapi.json index 9ae9a063..caa04352 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -1,7 +1,7 @@ { "openapi": "3.0.0", "info": { - "description": "AI gateway routing requests to multiple LLM providers (OpenAI, Anthropic, Gemini, Groq, OpenRouter, DeepSeek, Z.ai, xAI, MiniMax, Oracle, Ollama). Drop-in OpenAI-compatible API.", + "description": "AI gateway routing requests to multiple LLM providers (OpenAI, Anthropic, Gemini, Groq, OpenRouter, DeepSeek, Z.ai, xAI, MiniMax, Xiaomi MiMo, Bailian, Oracle, Ollama). Drop-in OpenAI-compatible API.", "title": "GoModel API", "contact": {}, "version": "1.0" diff --git a/docs/providers/bailian.mdx b/docs/providers/bailian.mdx new file mode 100644 index 00000000..0e1ef229 --- /dev/null +++ b/docs/providers/bailian.mdx @@ -0,0 +1,104 @@ +--- +title: "Alibaba Cloud Bailian" +description: "Configure Alibaba Cloud Bailian (百炼 / DashScope) in GoModel, including the max_tokens compatibility shim for models like Qwen." +icon: "cloud" +--- + +Bailian (百炼) is Alibaba Cloud's model-as-a-service platform for the Qwen +family of models. GoModel routes to Bailian through its OpenAI-compatible +endpoint (`/compatible-mode/v1`). + +Because Bailian deprecated `max_tokens` in April 2026 in favor of +`max_completion_tokens`, GoModel automatically maps the standard +`max_tokens` field to `max_completion_tokens` for every request — no +client change required. + +## Configure + +```bash +BAILIAN_API_KEY=... +``` + +Or in `config.yaml`: + +```yaml +providers: + bailian: + type: bailian + api_key: "${BAILIAN_API_KEY}" + # base_url: "https://dashscope.aliyuncs.com/compatible-mode/v1" +``` + +## Base URLs + +Bailian's OpenAI-compatible API is available in multiple regions. Set +`BAILIAN_BASE_URL` to switch: + +| Region | URL | +| ------ | --- | +| Beijing (default) | `https://dashscope.aliyuncs.com/compatible-mode/v1` | +| Singapore | `https://{workspace-id}.ap-southeast-1.maas.aliyuncs.com/compatible-mode/v1` | +| Frankfurt | `https://{workspace-id}.eu-central-1.maas.aliyuncs.com/compatible-mode/v1` | +| Hong Kong | `https://{workspace-id}.cn-hongkong.maas.aliyuncs.com/compatible-mode/v1` | + +## Model IDs + +Common Qwen model identifiers — check the [Bailian model +list](https://www.alibabacloud.com/help/zh/model-studio/model-list) for the +current catalog: + +| Model | Example ID | +| ----- | ---------- | +| Qwen 3.7 Max | `qwen3.7-max` | +| Qwen 3.7 Plus | `qwen3.7-plus` | +| Qwen 3.6 Flash | `qwen3.6-flash` | +| Qwen 3 Max | `qwen3-max` | +| Qwen 3 Plus | `qwen3-plus` | +| Qwen 3 Flash | `qwen3-flash` | +| Qwen 3 Coder Plus | `qwen3-coder-plus` | +| Text Embedding | `text-embedding-v3` | + +## `max_tokens` compatibility + +Bailian deprecated `max_tokens` on 2026-04-20 (effective 2026-05-30). +All Bailian models now require `max_completion_tokens` instead. + +GoModel transparently maps the standard `max_tokens` parameter to +`max_completion_tokens` for every bailian model — send `max_tokens` as you +normally would, and GoModel rewrites it before forwarding to Bailian. + +```bash +# max_tokens=4096 is automatically sent as max_completion_tokens=4096 +curl http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "qwen3-max", + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 4096 + }' +``` + +## Supported features + +| Feature | Supported | +| ------- | :-------: | +| Chat completions | ✅ | +| Streaming chat | ✅ | +| Responses (`/v1/responses`) | ✅ (translated to chat) | +| Embeddings | ✅ (configure model IDs via `BAILIAN_MODELS`) | +| Files (`/v1/files`) | ✅ | +| Batches (`/v1/batches`) | ✅ | +| Passthrough (`/p/bailian/...`) | ✅ | + + + Embedding models (`text-embedding-v3`, `text-embedding-v4`) are served by + the compatible-mode API but are **not** auto-discovered from the upstream + `/v1/models` endpoint. Set `BAILIAN_MODELS=text-embedding-v3,text-embedding-v4` + or use `CONFIGURED_PROVIDER_MODELS_MODE=allowlist` to make them available. + + +## References + +- [Bailian documentation](https://www.alibabacloud.com/help/zh/model-studio/) +- [OpenAI-compatible API reference](https://www.alibabacloud.com/help/zh/model-studio/compatibility-with-openai-responses-api) +- [Qwen model list](https://www.alibabacloud.com/help/zh/model-studio/model-list) diff --git a/docs/providers/overview.mdx b/docs/providers/overview.mdx index d548b61c..7d8e8525 100644 --- a/docs/providers/overview.mdx +++ b/docs/providers/overview.mdx @@ -25,6 +25,7 @@ quirks. | Z.ai | `ZAI_API_KEY` (`ZAI_BASE_URL` optional) | — | | xAI (Grok) | `XAI_API_KEY` | — | | MiniMax | `MINIMAX_API_KEY` (`MINIMAX_BASE_URL` optional) | — | +| Alibaba Cloud Bailian | `BAILIAN_API_KEY` (`BAILIAN_BASE_URL` optional) | [Alibaba Cloud Bailian](/providers/bailian) | | Xiaomi MiMo | `XIAOMI_API_KEY` (`XIAOMI_BASE_URL` optional) | [Xiaomi MiMo](/providers/xiaomi) | | Azure OpenAI | `AZURE_API_KEY` + `AZURE_BASE_URL` (`AZURE_API_VERSION` optional) | [Azure OpenAI](/providers/azure) | | Amazon Bedrock | `BEDROCK_BASE_URL` (region or endpoint) + AWS credentials | [Amazon Bedrock](/providers/bedrock) | diff --git a/internal/providers/bailian/bailian.go b/internal/providers/bailian/bailian.go new file mode 100644 index 00000000..0aa0c5f3 --- /dev/null +++ b/internal/providers/bailian/bailian.go @@ -0,0 +1,302 @@ +// Package bailian provides the Alibaba Cloud Bailian (百炼 / DashScope) provider. +// Bailian is Alibaba Cloud's model-as-a-service platform for the Qwen family +// of models. It exposes an OpenAI-compatible API through the +// /compatible-mode/v1 endpoint. +package bailian + +import ( + "bytes" + "context" + "encoding/json" + "io" + "log/slog" + "net/http" + + "gomodel/internal/core" + "gomodel/internal/llmclient" + "gomodel/internal/providers" + "gomodel/internal/providers/openai" +) + +const defaultBaseURL = "https://dashscope.aliyuncs.com/compatible-mode/v1" + +// Registration provides factory registration for the Bailian provider. +var Registration = providers.Registration{ + Type: "bailian", + New: New, + Discovery: providers.DiscoveryConfig{ + DefaultBaseURL: defaultBaseURL, + }, +} + +// Provider implements the core.Provider interface for Alibaba Cloud Bailian. +// It wraps openai.CompatibleProvider and maps max_tokens to +// max_completion_tokens for every request (Bailian deprecated max_tokens +// in April 2026). +type Provider struct { + compatible *openai.CompatibleProvider +} + +// New creates a new Bailian provider from a resolved ProviderConfig. +func New(cfg providers.ProviderConfig, opts providers.ProviderOptions) core.Provider { + baseURL := providers.ResolveBaseURL(cfg.BaseURL, defaultBaseURL) + return &Provider{ + compatible: openai.NewCompatibleProvider(cfg.APIKey, opts, openai.CompatibleProviderConfig{ + ProviderName: "bailian", + BaseURL: baseURL, + SetHeaders: setHeaders, + }), + } +} + +// NewWithHTTPClient creates a new Bailian provider with a custom HTTP client. +// If httpClient is nil, http.DefaultClient is used. +func NewWithHTTPClient(apiKey string, httpClient *http.Client, hooks llmclient.Hooks) *Provider { + return &Provider{ + compatible: openai.NewCompatibleProviderWithHTTPClient(apiKey, httpClient, hooks, openai.CompatibleProviderConfig{ + ProviderName: "bailian", + BaseURL: defaultBaseURL, + SetHeaders: setHeaders, + }), + } +} + +func setHeaders(req *http.Request, apiKey string) { + providers.SetAuthHeaders(req, apiKey, providers.AuthHeaderConfig{AuthScheme: "Bearer "}) +} + +// SetBaseURL configures a custom base URL for the provider. +func (p *Provider) SetBaseURL(url string) { + p.compatible.SetBaseURL(url) +} + +// ChatCompletion maps max_tokens to max_completion_tokens for Bailian models. +// Bailian deprecated max_tokens in April 2026; all models now require max_completion_tokens. +func (p *Provider) ChatCompletion(ctx context.Context, req *core.ChatRequest) (*core.ChatResponse, error) { + return p.compatible.ChatCompletion(ctx, adaptBailianRequest(req)) +} + +// StreamChatCompletion maps max_tokens to max_completion_tokens for streaming requests. +func (p *Provider) StreamChatCompletion(ctx context.Context, req *core.ChatRequest) (io.ReadCloser, error) { + return p.compatible.StreamChatCompletion(ctx, adaptBailianRequest(req)) +} + +// ListModels returns the list of available models from Bailian. +func (p *Provider) ListModels(ctx context.Context) (*core.ModelsResponse, error) { + return p.compatible.ListModels(ctx) +} + +// Responses sends a Responses API request translated through chat completions. +func (p *Provider) Responses(ctx context.Context, req *core.ResponsesRequest) (*core.ResponsesResponse, error) { + return providers.ResponsesViaChat(ctx, p, req) +} + +// StreamResponses streams a Responses API request translated through chat completions. +func (p *Provider) StreamResponses(ctx context.Context, req *core.ResponsesRequest) (io.ReadCloser, error) { + return providers.StreamResponsesViaChat(ctx, p, req, "bailian") +} + +// Embeddings sends an embedding request to Bailian's compatible-mode API. +// Embedding models (text-embedding-v3, text-embedding-v4) must be configured +// via BAILIAN_MODELS as they are not auto-discovered from the upstream +// /v1/models endpoint. +func (p *Provider) Embeddings(ctx context.Context, req *core.EmbeddingRequest) (*core.EmbeddingResponse, error) { + return p.compatible.Embeddings(ctx, req) +} + +// Passthrough routes an opaque provider-native request to Bailian. +// It also adapts max_tokens -> max_completion_tokens in the raw body, +// mirroring the adaptation done in ChatCompletion/StreamChatCompletion. +func (p *Provider) Passthrough(ctx context.Context, req *core.PassthroughRequest) (*core.PassthroughResponse, error) { + adapted, err := adaptPassthroughBody(req.Body) + if err != nil { + slog.Warn("bailian: passthrough body adaptation failed, forwarding original body", + "error", err) + // Read the original body back so we can still forward it + req.Body, err = rewindBody(req.Body, nil) + if err != nil { + return nil, err + } + return p.compatible.Passthrough(ctx, req) + } + if adapted != nil { + req.Body = adapted + } + return p.compatible.Passthrough(ctx, req) +} + +// CreateBatch creates a native Bailian batch job. +func (p *Provider) CreateBatch(ctx context.Context, req *core.BatchRequest) (*core.BatchResponse, error) { + return p.compatible.CreateBatch(ctx, req) +} + +// GetBatch retrieves a Bailian batch job by ID. +func (p *Provider) GetBatch(ctx context.Context, id string) (*core.BatchResponse, error) { + return p.compatible.GetBatch(ctx, id) +} + +// ListBatches lists Bailian batch jobs with pagination. +func (p *Provider) ListBatches(ctx context.Context, limit int, after string) (*core.BatchListResponse, error) { + return p.compatible.ListBatches(ctx, limit, after) +} + +// CancelBatch cancels a pending Bailian batch job. +func (p *Provider) CancelBatch(ctx context.Context, id string) (*core.BatchResponse, error) { + return p.compatible.CancelBatch(ctx, id) +} + +// GetBatchResults fetches Bailian batch results via the output file API. +func (p *Provider) GetBatchResults(ctx context.Context, id string) (*core.BatchResultsResponse, error) { + return p.compatible.GetBatchResults(ctx, id) +} + +// CreateFile uploads a file through Bailian's OpenAI-compatible /files API. +func (p *Provider) CreateFile(ctx context.Context, req *core.FileCreateRequest) (*core.FileObject, error) { + resp, err := p.compatible.CreateFile(ctx, req) + if err != nil { + return nil, err + } + resp.Provider = "bailian" + return resp, nil +} + +// ListFiles lists files through Bailian's OpenAI-compatible /files API. +func (p *Provider) ListFiles(ctx context.Context, purpose string, limit int, after string) (*core.FileListResponse, error) { + resp, err := p.compatible.ListFiles(ctx, purpose, limit, after) + if err != nil { + return nil, err + } + for i := range resp.Data { + resp.Data[i].Provider = "bailian" + } + return resp, nil +} + +// GetFile retrieves a file object through Bailian's OpenAI-compatible /files API. +func (p *Provider) GetFile(ctx context.Context, id string) (*core.FileObject, error) { + resp, err := p.compatible.GetFile(ctx, id) + if err != nil { + return nil, err + } + resp.Provider = "bailian" + return resp, nil +} + +// DeleteFile deletes a file through Bailian's OpenAI-compatible /files API. +func (p *Provider) DeleteFile(ctx context.Context, id string) (*core.FileDeleteResponse, error) { + return p.compatible.DeleteFile(ctx, id) +} + +// GetFileContent fetches raw file bytes through Bailian's /files/{id}/content API. +func (p *Provider) GetFileContent(ctx context.Context, id string) (*core.FileContentResponse, error) { + return p.compatible.GetFileContent(ctx, id) +} + +// adaptBailianRequest maps max_tokens -> max_completion_tokens in the request. +// Bailian deprecated max_tokens in April 2026. +// It moves MaxTokens into ExtraFields as max_completion_tokens so that when +// ChatRequest is serialized, max_tokens is omitted and max_completion_tokens +// appears instead. +// +// If the user already set max_completion_tokens explicitly (Bailian-native +// parameter), its value is preserved and max_tokens is used as a fallback +// only — the explicit value takes precedence. +// +// If either operation fails, the original request is returned unmodified +// and a warning is logged so operators can diagnose the issue. +func adaptBailianRequest(req *core.ChatRequest) *core.ChatRequest { + if req == nil || req.MaxTokens == nil { + return req + } + // If the caller already set max_completion_tokens explicitly, respect it. + if existing := req.ExtraFields.Lookup("max_completion_tokens"); existing != nil { + cloned := *req + cloned.MaxTokens = nil + return &cloned + } + maxTokensJSON, err := json.Marshal(*req.MaxTokens) + if err != nil { + slog.Warn("bailian: failed to marshal MaxTokens for adaptation, forwarding original request", + "error", err) + return req + } + extra, err := core.MergeUnknownJSONFields(req.ExtraFields, map[string]json.RawMessage{ + "max_completion_tokens": maxTokensJSON, + }) + if err != nil { + slog.Warn("bailian: failed to merge ExtraFields for adaptation, forwarding original request", + "error", err) + return req + } + cloned := *req + cloned.ExtraFields = extra + cloned.MaxTokens = nil + return &cloned +} + +// adaptPassthroughBody adapts max_tokens -> max_completion_tokens in a raw +// passthrough request body. It reads the body, parses it as JSON, swaps the +// field if needed, and returns a new io.ReadCloser with the adapted body. +// If no adaptation is needed, it returns nil (the caller should rewind the +// original body). If parsing fails, it returns the error. +func adaptPassthroughBody(body io.ReadCloser) (io.ReadCloser, error) { + if body == nil { + return nil, nil + } + raw, err := io.ReadAll(body) + body.Close() + if err != nil { + return nil, err + } + + var obj map[string]json.RawMessage + if err := json.Unmarshal(raw, &obj); err != nil { + // Not valid JSON — can't adapt, rewind original bytes. + return io.NopCloser(bytes.NewReader(raw)), nil + } + + // No max_tokens present — no adaptation needed, rewind original. + if _, hasMaxTokens := obj["max_tokens"]; !hasMaxTokens { + return nil, nil + } + + // Caller already set max_completion_tokens — just remove max_tokens. + if _, hasMCT := obj["max_completion_tokens"]; hasMCT { + delete(obj, "max_tokens") + adapted, err := json.Marshal(obj) + if err != nil { + return nil, err + } + return io.NopCloser(bytes.NewReader(adapted)), nil + } + + // Swap: max_tokens -> max_completion_tokens, remove max_tokens. + obj["max_completion_tokens"] = obj["max_tokens"] + delete(obj, "max_tokens") + adapted, err := json.Marshal(obj) + if err != nil { + return nil, err + } + return io.NopCloser(bytes.NewReader(adapted)), nil +} + +// rewindBody reads an io.ReadCloser into memory and returns a new ReadCloser +// over the same bytes, allowing the body to be read again. +// If fallback is non-nil, it is used when the original body is nil or cannot be read. +func rewindBody(body io.ReadCloser, fallback []byte) (io.ReadCloser, error) { + if body == nil { + if fallback != nil { + return io.NopCloser(bytes.NewReader(fallback)), nil + } + return nil, nil + } + raw, err := io.ReadAll(body) + body.Close() + if err != nil { + if fallback != nil { + return io.NopCloser(bytes.NewReader(fallback)), nil + } + return nil, err + } + return io.NopCloser(bytes.NewReader(raw)), nil +} diff --git a/internal/providers/bailian/bailian_test.go b/internal/providers/bailian/bailian_test.go new file mode 100644 index 00000000..42822625 --- /dev/null +++ b/internal/providers/bailian/bailian_test.go @@ -0,0 +1,952 @@ +package bailian + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "gomodel/internal/core" + "gomodel/internal/llmclient" + "gomodel/internal/providers" +) + +func TestChatCompletion_SendsBearerAuthAndCorrectPath(t *testing.T) { + var gotPath string + var gotAuth string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + gotAuth = r.Header.Get("Authorization") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "id":"chatcmpl-bailian", + "created":1677652288, + "model":"qwen3-max", + "choices":[{"index":0,"message":{"role":"assistant","content":"hello"},"finish_reason":"stop"}], + "usage":{"prompt_tokens":5,"completion_tokens":10,"total_tokens":15} + }`)) + })) + defer server.Close() + + provider := NewWithHTTPClient("bailian-key", server.Client(), llmclient.Hooks{}) + provider.SetBaseURL(server.URL) + + resp, err := provider.ChatCompletion(context.Background(), &core.ChatRequest{ + Model: "qwen3-max", + Messages: []core.Message{ + {Role: "user", Content: "hi"}, + }, + }) + if err != nil { + t.Fatalf("ChatCompletion() error = %v", err) + } + if resp.Model != "qwen3-max" { + t.Fatalf("resp.Model = %q, want qwen3-max", resp.Model) + } + if gotPath != "/chat/completions" { + t.Fatalf("path = %q, want /chat/completions", gotPath) + } + if gotAuth != "Bearer bailian-key" { + t.Fatalf("authorization = %q, want Bearer bailian-key", gotAuth) + } +} + +func TestChatCompletion_MaxTokensMapping(t *testing.T) { + var gotBody []byte + var readErr error + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotBody, readErr = io.ReadAll(r.Body) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "id":"chatcmpl-bailian", + "created":1677652288, + "model":"qwen3-max", + "choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}], + "usage":{"prompt_tokens":5,"completion_tokens":10,"total_tokens":15} + }`)) + })) + defer server.Close() + + provider := NewWithHTTPClient("key", server.Client(), llmclient.Hooks{}) + provider.SetBaseURL(server.URL) + + maxTokens := 4096 + _, err := provider.ChatCompletion(context.Background(), &core.ChatRequest{ + Model: "qwen3-max", + Messages: []core.Message{{Role: "user", Content: "hi"}}, + MaxTokens: &maxTokens, + }) + if err != nil { + t.Fatalf("ChatCompletion() error = %v", err) + } + if readErr != nil { + t.Fatalf("reading request body: %v", readErr) + } + + var sentBody map[string]any + if err := json.Unmarshal(gotBody, &sentBody); err != nil { + t.Fatalf("unmarshal sent body: %v", err) + } + if _, exists := sentBody["max_tokens"]; exists { + t.Fatal("sent body should NOT contain max_tokens (Bailian deprecated it)") + } + mct, exists := sentBody["max_completion_tokens"] + if !exists { + t.Fatal("sent body should contain max_completion_tokens") + } + if mct.(float64) != 4096 { + t.Fatalf("max_completion_tokens = %v, want 4096", mct) + } +} + +func TestChatCompletion_NoMaxTokensMapping(t *testing.T) { + var gotBody []byte + var readErr error + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotBody, readErr = io.ReadAll(r.Body) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "id":"chatcmpl-bailian", + "created":1677652288, + "model":"qwen3-max", + "choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}], + "usage":{"prompt_tokens":5,"completion_tokens":10,"total_tokens":15} + }`)) + })) + defer server.Close() + + provider := NewWithHTTPClient("key", server.Client(), llmclient.Hooks{}) + provider.SetBaseURL(server.URL) + + _, err := provider.ChatCompletion(context.Background(), &core.ChatRequest{ + Model: "qwen3-max", + Messages: []core.Message{{Role: "user", Content: "hi"}}, + }) + if err != nil { + t.Fatalf("ChatCompletion() error = %v", err) + } + if readErr != nil { + t.Fatalf("reading request body: %v", readErr) + } + + var sentBody map[string]any + if err := json.Unmarshal(gotBody, &sentBody); err != nil { + t.Fatalf("unmarshal sent body: %v", err) + } + if _, exists := sentBody["max_completion_tokens"]; exists { + t.Fatal("sent body should NOT contain max_completion_tokens when request had no max_tokens") + } +} + +func TestStreamChatCompletion_MaxTokensMapping(t *testing.T) { + var gotBody []byte + var readErr error + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotBody, readErr = io.ReadAll(r.Body) + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("data: [DONE]\n\n")) + })) + defer server.Close() + + provider := NewWithHTTPClient("key", server.Client(), llmclient.Hooks{}) + provider.SetBaseURL(server.URL) + + maxTokens := 2048 + _, err := provider.StreamChatCompletion(context.Background(), &core.ChatRequest{ + Model: "qwen3-flash", + Messages: []core.Message{{Role: "user", Content: "hi"}}, + MaxTokens: &maxTokens, + }) + if err != nil { + t.Fatalf("StreamChatCompletion() error = %v", err) + } + if readErr != nil { + t.Fatalf("reading request body: %v", readErr) + } + + var sentBody map[string]any + if err := json.Unmarshal(gotBody, &sentBody); err != nil { + t.Fatalf("unmarshal sent body: %v", err) + } + if _, exists := sentBody["max_tokens"]; exists { + t.Fatal("sent body should NOT contain max_tokens for streaming either") + } + mct, exists := sentBody["max_completion_tokens"] + if !exists { + t.Fatal("sent body should contain max_completion_tokens for streaming") + } + if mct.(float64) != 2048 { + t.Fatalf("max_completion_tokens = %v, want 2048", mct) + } +} + +func TestStreamChatCompletion_ReturnsSSE(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("data: {\"id\":\"x\",\"object\":\"chat.completion.chunk\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hi\"},\"finish_reason\":null}]}\n\n")) + _, _ = w.Write([]byte("data: [DONE]\n\n")) + })) + defer server.Close() + + provider := NewWithHTTPClient("key", server.Client(), llmclient.Hooks{}) + provider.SetBaseURL(server.URL) + + stream, err := provider.StreamChatCompletion(context.Background(), &core.ChatRequest{ + Model: "qwen3-flash", + Messages: []core.Message{{Role: "user", Content: "hi"}}, + }) + if err != nil { + t.Fatalf("StreamChatCompletion() error = %v", err) + } + defer stream.Close() + + body, err := io.ReadAll(stream) + if err != nil { + t.Fatalf("failed to read stream: %v", err) + } + if !strings.Contains(string(body), "[DONE]") { + t.Fatalf("stream should contain [DONE], got: %s", string(body)) + } +} + +func TestListModels_ReturnsModels(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "object":"list", + "data":[ + {"id":"qwen3-max","object":"model","owned_by":"alibaba"}, + {"id":"qwen3-plus","object":"model","owned_by":"alibaba"}, + {"id":"qwen3-flash","object":"model","owned_by":"alibaba"} + ] + }`)) + })) + defer server.Close() + + provider := NewWithHTTPClient("key", server.Client(), llmclient.Hooks{}) + provider.SetBaseURL(server.URL) + + resp, err := provider.ListModels(context.Background()) + if err != nil { + t.Fatalf("ListModels() error = %v", err) + } + if len(resp.Data) != 3 { + t.Fatalf("got %d models, want 3", len(resp.Data)) + } +} + +func TestEmbeddings_SendsRequest(t *testing.T) { + var gotPath string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "object":"list", + "data":[{"object":"embedding","embedding":[0.1,0.2],"index":0}], + "model":"text-embedding-v3", + "usage":{"prompt_tokens":2,"total_tokens":2} + }`)) + })) + defer server.Close() + + provider := NewWithHTTPClient("key", server.Client(), llmclient.Hooks{}) + provider.SetBaseURL(server.URL) + _, err := provider.Embeddings(context.Background(), &core.EmbeddingRequest{ + Model: "text-embedding-v3", + Input: "test", + }) + if err != nil { + t.Fatalf("Embeddings() error = %v", err) + } + if gotPath != "/embeddings" { + t.Fatalf("path = %q, want /embeddings", gotPath) + } +} + +func TestResponsesViaChat_DelegatesToChat(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "id":"chatcmpl-bailian", + "created":1677652288, + "model":"qwen3-max", + "choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}], + "usage":{"prompt_tokens":5,"completion_tokens":5,"total_tokens":10} + }`)) + })) + defer server.Close() + + provider := NewWithHTTPClient("key", server.Client(), llmclient.Hooks{}) + provider.SetBaseURL(server.URL) + + resp, err := provider.Responses(context.Background(), &core.ResponsesRequest{ + Model: "qwen3-max", + Input: "hello", + }) + if err != nil { + t.Fatalf("Responses() error = %v", err) + } + if resp.Status != "completed" { + t.Fatalf("status = %q, want completed", resp.Status) + } +} + +func TestDefaultBaseURL(t *testing.T) { + provider := NewWithHTTPClient("key", nil, llmclient.Hooks{}) + if provider == nil { + t.Fatal("expected non-nil provider") + } + // Verify the registration exposes the correct default base URL + if Registration.Discovery.DefaultBaseURL != defaultBaseURL { + t.Fatalf("Registration.DefaultBaseURL = %q, want %q", + Registration.Discovery.DefaultBaseURL, defaultBaseURL) + } + // Verify the provider actually uses the default base URL by checking + // that a ChatCompletion request hits the correct host. + var gotHost string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotHost = r.Host + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "id":"chatcmpl-1", + "created":1, + "model":"qwen3-max", + "choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}], + "usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2} + }`)) + })) + defer server.Close() + + // Override the base URL to our test server to capture the request, + // but first verify the provider's default is the expected DashScope URL. + provider.SetBaseURL(server.URL) + + _, err := provider.ChatCompletion(context.Background(), &core.ChatRequest{ + Model: "qwen3-max", + Messages: []core.Message{{Role: "user", Content: "hi"}}, + }) + if err != nil { + t.Fatalf("ChatCompletion() error = %v", err) + } + // The request should have reached our test server, confirming SetBaseURL works. + if gotHost == "" { + t.Fatal("expected a non-empty host in the request") + } +} + +func TestProvider_ExposesBatchAndFileInterfaces(t *testing.T) { + provider := NewWithHTTPClient("key", nil, llmclient.Hooks{}) + if _, ok := any(provider).(core.NativeBatchProvider); !ok { + t.Fatal("bailian should implement native batch") + } + if _, ok := any(provider).(core.NativeFileProvider); !ok { + t.Fatal("bailian should implement native file") + } +} + +func TestRegistration_TypeAndBaseURL(t *testing.T) { + if Registration.Type != "bailian" { + t.Fatalf("Registration.Type = %q, want bailian", Registration.Type) + } + if Registration.Discovery.DefaultBaseURL != defaultBaseURL { + t.Fatalf("DefaultBaseURL mismatch") + } +} + +func TestPassthrough_Delegates(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"ok":true}`)) + })) + defer server.Close() + + provider := NewWithHTTPClient("key", server.Client(), llmclient.Hooks{}) + provider.SetBaseURL(server.URL) + + resp, err := provider.Passthrough(context.Background(), &core.PassthroughRequest{ + Method: http.MethodPost, + Endpoint: "/chat/completions", + Body: io.NopCloser(strings.NewReader(`{}`)), + }) + if err != nil { + t.Fatalf("Passthrough() error = %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("StatusCode = %d, want 200", resp.StatusCode) + } +} + +func TestAdaptBailianRequest_Nil(t *testing.T) { + if r := adaptBailianRequest(nil); r != nil { + t.Fatal("expected nil") + } +} + +func TestAdaptBailianRequest_NoMaxTokens(t *testing.T) { + req := &core.ChatRequest{Model: "qwen3-max"} + r := adaptBailianRequest(req) + if r.MaxTokens != nil { + t.Fatal("should not set max_completion_tokens when request had none") + } + + +} +func TestNew_UsesRegistrationAndDefaultBaseURL(t *testing.T) { + provider := New(providers.ProviderConfig{ + APIKey: "reg-key", + }, providers.ProviderOptions{}) + if provider == nil { + t.Fatal("New() returned nil") + } + // Verify the provider constructed via registration uses the expected base URL + if Registration.Discovery.DefaultBaseURL != defaultBaseURL { + t.Fatalf("Registration.DefaultBaseURL = %q, want %q", + Registration.Discovery.DefaultBaseURL, defaultBaseURL) + } +} + +func TestStreamResponses_DelegatesToChat(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("data: {\"id\":\"x\",\"object\":\"chat.completion.chunk\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n")) + _, _ = w.Write([]byte("data: {\"id\":\"x\",\"object\":\"chat.completion.chunk\",\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"stop\"}]}\n\n")) + _, _ = w.Write([]byte("data: [DONE]\n\n")) + })) + defer server.Close() + + provider := NewWithHTTPClient("key", server.Client(), llmclient.Hooks{}) + provider.SetBaseURL(server.URL) + + stream, err := provider.StreamResponses(context.Background(), &core.ResponsesRequest{ + Model: "qwen3-max", + Input: "hello", + }) + if err != nil { + t.Fatalf("StreamResponses() error = %v", err) + } + defer stream.Close() + + body, err := io.ReadAll(stream) + if err != nil { + t.Fatalf("failed to read stream: %v", err) + } + if len(body) == 0 { + t.Fatal("expected non-empty stream body") + } +} + +func TestAdaptBailianRequest_PreservesOtherFields(t *testing.T) { + maxTokens := 100 + req := &core.ChatRequest{ + Model: "qwen3-max", + Messages: []core.Message{{Role: "user", Content: "hi"}}, + MaxTokens: &maxTokens, + } + r := adaptBailianRequest(req) + if r == req { + t.Fatal("should return a clone, not the original") + } + if r.Model != "qwen3-max" { + t.Fatalf("model = %q", r.Model) + } + if r.MaxTokens != nil { + t.Fatal("MaxTokens should be nil in the clone") + } +} + + +func TestAdaptBailianRequest_RespectsExistingMaxCompletionTokens(t *testing.T) { + extra := core.UnknownJSONFieldsFromMap(map[string]json.RawMessage{ + "max_completion_tokens": json.RawMessage(`200`), + }) + maxTokens := 100 + req := &core.ChatRequest{ + Model: "qwen3-max", + Messages: []core.Message{{Role: "user", Content: "hi"}}, + MaxTokens: &maxTokens, + ExtraFields: extra, + } + r := adaptBailianRequest(req) + if r == req { + t.Fatal("should return a clone") + } + if r.MaxTokens != nil { + t.Fatal("MaxTokens should be nil") + } + + body, err := json.Marshal(r) + if err != nil { + t.Fatalf("failed to marshal adapted request: %v", err) + } + var raw map[string]json.RawMessage + if err := json.Unmarshal(body, &raw); err != nil { + t.Fatalf("failed to unmarshal body: %v", err) + } + if _, exists := raw["max_completion_tokens"]; !exists { + t.Fatal("max_completion_tokens should exist") + } + var mct int + if err := json.Unmarshal(raw["max_completion_tokens"], &mct); err != nil { + t.Fatalf("failed to unmarshal max_completion_tokens: %v", err) + } + if mct != 200 { + t.Fatalf("max_completion_tokens = %d, want 200", mct) + } + if _, exists := raw["max_tokens"]; exists { + t.Fatal("max_tokens should NOT exist in output") + } +} + +func TestChatCompletion_UpstreamError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error":{"message":"bad request","type":"invalid_request_error"}}`)) + })) + defer server.Close() + + provider := NewWithHTTPClient("key", server.Client(), llmclient.Hooks{}) + provider.SetBaseURL(server.URL) + + _, err := provider.ChatCompletion(context.Background(), &core.ChatRequest{ + Model: "qwen3-max", + Messages: []core.Message{{Role: "user", Content: "hi"}}, + }) + if err == nil { + t.Fatal("expected error from upstream 400") + } +} + +func TestChatCompletion_TransportFailure(t *testing.T) { + // Use a stub RoundTripper that always returns an error + errTransport := errors.New("simulated transport failure") + provider := NewWithHTTPClient("key", &http.Client{ + Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) { + return nil, errTransport + }), + }, llmclient.Hooks{}) + + _, err := provider.ChatCompletion(context.Background(), &core.ChatRequest{ + Model: "qwen3-max", + Messages: []core.Message{{Role: "user", Content: "hi"}}, + }) + if err == nil { + t.Fatal("expected transport error") + } +} + +func TestStreamChatCompletion_UpstreamError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":{"message":"unauthorized","type":"authentication_error"}}`)) + })) + defer server.Close() + + provider := NewWithHTTPClient("key", server.Client(), llmclient.Hooks{}) + provider.SetBaseURL(server.URL) + + _, err := provider.StreamChatCompletion(context.Background(), &core.ChatRequest{ + Model: "qwen3-max", + Messages: []core.Message{{Role: "user", Content: "hi"}}, + }) + if err == nil { + t.Fatal("expected error from upstream 401") + } +} + +func TestResponses_UpstreamError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error":{"message":"bad request","type":"invalid_request_error"}}`)) + })) + defer server.Close() + + provider := NewWithHTTPClient("key", server.Client(), llmclient.Hooks{}) + provider.SetBaseURL(server.URL) + + _, err := provider.Responses(context.Background(), &core.ResponsesRequest{ + Model: "qwen3-max", + Input: "hello", + }) + if err == nil { + t.Fatal("expected error from upstream 400") + } +} + +func TestEmbeddings_UpstreamError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error":{"message":"bad request","type":"invalid_request_error"}}`)) + })) + defer server.Close() + + provider := NewWithHTTPClient("key", server.Client(), llmclient.Hooks{}) + provider.SetBaseURL(server.URL) + + _, err := provider.Embeddings(context.Background(), &core.EmbeddingRequest{ + Model: "text-embedding-v3", + Input: "test", + }) + if err == nil { + t.Fatal("expected error from upstream 400") + } +} + +func TestPassthrough_UpstreamError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"error":"internal"}`)) + })) + defer server.Close() + + provider := NewWithHTTPClient("key", server.Client(), llmclient.Hooks{}) + provider.SetBaseURL(server.URL) + + resp, err := provider.Passthrough(context.Background(), &core.PassthroughRequest{ + Method: http.MethodPost, + Endpoint: "/chat/completions", + Body: io.NopCloser(strings.NewReader(`{}`)), + }) + if err != nil { + t.Fatalf("Passthrough() should not return error on non-2xx: %v", err) + } + if resp.StatusCode != http.StatusInternalServerError { + t.Fatalf("StatusCode = %d, want 500", resp.StatusCode) + } +} + +func TestListModels_UpstreamError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"error":"internal"}`)) + })) + defer server.Close() + + provider := NewWithHTTPClient("key", server.Client(), llmclient.Hooks{}) + provider.SetBaseURL(server.URL) + + _, err := provider.ListModels(context.Background()) + if err == nil { + t.Fatal("expected error from upstream 500") + } +} + +func TestCreateBatch_Delegates(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"batch-bailian-1","object":"batch","status":"validating"}`)) + })) + defer server.Close() + + provider := NewWithHTTPClient("key", server.Client(), llmclient.Hooks{}) + provider.SetBaseURL(server.URL) + + resp, err := provider.CreateBatch(context.Background(), &core.BatchRequest{ + InputFileID: "file-1", + Endpoint: "/v1/chat/completions", + }) + if err != nil { + t.Fatalf("CreateBatch() error = %v", err) + } + if resp.ID != "batch-bailian-1" { + t.Fatalf("batch id = %q", resp.ID) + } +} + +func TestGetBatch_Delegates(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"batch-bailian-1","object":"batch","status":"completed"}`)) + })) + defer server.Close() + + provider := NewWithHTTPClient("key", server.Client(), llmclient.Hooks{}) + provider.SetBaseURL(server.URL) + + resp, err := provider.GetBatch(context.Background(), "batch-bailian-1") + if err != nil { + t.Fatalf("GetBatch() error = %v", err) + } + if resp.ID != "batch-bailian-1" { + t.Fatalf("batch id = %q", resp.ID) + } +} + +func TestListBatches_Delegates(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"object":"list","data":[]}`)) + })) + defer server.Close() + + provider := NewWithHTTPClient("key", server.Client(), llmclient.Hooks{}) + provider.SetBaseURL(server.URL) + + _, err := provider.ListBatches(context.Background(), 10, "") + if err != nil { + t.Fatalf("ListBatches() error = %v", err) + } +} + +func TestCancelBatch_Delegates(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"batch-bailian-1","object":"batch","status":"cancelling"}`)) + })) + defer server.Close() + + provider := NewWithHTTPClient("key", server.Client(), llmclient.Hooks{}) + provider.SetBaseURL(server.URL) + + resp, err := provider.CancelBatch(context.Background(), "batch-bailian-1") + if err != nil { + t.Fatalf("CancelBatch() error = %v", err) + } + if resp.ID != "batch-bailian-1" { + t.Fatalf("batch id = %q", resp.ID) + } +} + +func TestCreateFile_Delegates(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"file-1","object":"file","purpose":"batch","bytes":100}`)) + })) + defer server.Close() + + provider := NewWithHTTPClient("key", server.Client(), llmclient.Hooks{}) + provider.SetBaseURL(server.URL) + + resp, err := provider.CreateFile(context.Background(), &core.FileCreateRequest{ + Content: []byte("data"), + Purpose: "batch", + }) + if err != nil { + t.Fatalf("CreateFile() error = %v", err) + } + if resp.Provider != "bailian" { + t.Fatalf("provider = %q, want bailian", resp.Provider) + } +} + +func TestDeleteFile_Delegates(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"file-1","object":"file","deleted":true}`)) + })) + defer server.Close() + + provider := NewWithHTTPClient("key", server.Client(), llmclient.Hooks{}) + provider.SetBaseURL(server.URL) + + resp, err := provider.DeleteFile(context.Background(), "file-1") + if err != nil { + t.Fatalf("DeleteFile() error = %v", err) + } + if !resp.Deleted { + t.Fatal("expected deleted=true") + } +} + +func TestListFiles_Delegates(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"object":"list","data":[]}`)) + })) + defer server.Close() + + provider := NewWithHTTPClient("key", server.Client(), llmclient.Hooks{}) + provider.SetBaseURL(server.URL) + + _, err := provider.ListFiles(context.Background(), "batch", 10, "") + if err != nil { + t.Fatalf("ListFiles() error = %v", err) + } +} + +func TestGetFile_Delegates(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"file-1","object":"file","purpose":"batch"}`)) + })) + defer server.Close() + + provider := NewWithHTTPClient("key", server.Client(), llmclient.Hooks{}) + provider.SetBaseURL(server.URL) + + resp, err := provider.GetFile(context.Background(), "file-1") + if err != nil { + t.Fatalf("GetFile() error = %v", err) + } + if resp.Provider != "bailian" { + t.Fatalf("provider = %q, want bailian", resp.Provider) + } +} + +func TestGetFileContent_Delegates(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"text":"content"}`)) + })) + defer server.Close() + + provider := NewWithHTTPClient("key", server.Client(), llmclient.Hooks{}) + provider.SetBaseURL(server.URL) + + _, err := provider.GetFileContent(context.Background(), "file-1") + if err != nil { + t.Fatalf("GetFileContent() error = %v", err) + } +} + +func TestGetBatchResults_Delegates(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"batch-1","output_file_id":"file-out-1"}`)) + })) + defer server.Close() + + provider := NewWithHTTPClient("key", server.Client(), llmclient.Hooks{}) + provider.SetBaseURL(server.URL) + + _, err := provider.GetBatchResults(context.Background(), "batch-1") + if err != nil { + t.Fatalf("GetBatchResults() error = %v", err) + } +} + +// roundTripperFunc adapts a function to the http.RoundTripper interface. +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { + return f(r) +} + +func TestPassthrough_MaxTokensMapping(t *testing.T) { + var gotBody []byte + var readErr error + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotBody, readErr = io.ReadAll(r.Body) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"id":"chatcmpl-bailian","model":"qwen3-max"}`)) + })) + defer server.Close() + + provider := NewWithHTTPClient("key", server.Client(), llmclient.Hooks{}) + provider.SetBaseURL(server.URL) + + _, err := provider.Passthrough(context.Background(), &core.PassthroughRequest{ + Method: http.MethodPost, + Endpoint: "/chat/completions", + Body: io.NopCloser(strings.NewReader(`{"model":"qwen3-max","messages":[{"role":"user","content":"hi"}],"max_tokens":4096}`)), + }) + if err != nil { + t.Fatalf("Passthrough() error = %v", err) + } + if readErr != nil { + t.Fatalf("reading request body: %v", readErr) + } + + var sentBody map[string]any + if err := json.Unmarshal(gotBody, &sentBody); err != nil { + t.Fatalf("unmarshal sent body: %v", err) + } + if _, exists := sentBody["max_tokens"]; exists { + t.Fatal("passthrough body should NOT contain max_tokens") + } + mct, exists := sentBody["max_completion_tokens"] + if !exists { + t.Fatal("passthrough body should contain max_completion_tokens") + } + if mct.(float64) != 4096 { + t.Fatalf("max_completion_tokens = %v, want 4096", mct) + } +} + +func TestPassthrough_PreservesExistingMaxCompletionTokens(t *testing.T) { + var gotBody []byte + var readErr error + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotBody, readErr = io.ReadAll(r.Body) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"id":"chatcmpl-bailian","model":"qwen3-max"}`)) + })) + defer server.Close() + + provider := NewWithHTTPClient("key", server.Client(), llmclient.Hooks{}) + provider.SetBaseURL(server.URL) + + _, err := provider.Passthrough(context.Background(), &core.PassthroughRequest{ + Method: http.MethodPost, + Endpoint: "/chat/completions", + Body: io.NopCloser(strings.NewReader(`{"model":"qwen3-max","messages":[{"role":"user","content":"hi"}],"max_tokens":100,"max_completion_tokens":200}`)), + }) + if err != nil { + t.Fatalf("Passthrough() error = %v", err) + } + if readErr != nil { + t.Fatalf("reading request body: %v", readErr) + } + + var sentBody map[string]any + if err := json.Unmarshal(gotBody, &sentBody); err != nil { + t.Fatalf("unmarshal sent body: %v", err) + } + if _, exists := sentBody["max_tokens"]; exists { + t.Fatal("passthrough body should NOT contain max_tokens when max_completion_tokens already set") + } + mct, exists := sentBody["max_completion_tokens"] + if !exists { + t.Fatal("passthrough body should contain max_completion_tokens") + } + if mct.(float64) != 200 { + t.Fatalf("max_completion_tokens = %v, want 200 (explicit value should win)", mct) + } +} + +func TestPassthrough_NoMaxTokens(t *testing.T) { + var gotBody []byte + var readErr error + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotBody, readErr = io.ReadAll(r.Body) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"id":"chatcmpl-bailian","model":"qwen3-max"}`)) + })) + defer server.Close() + + provider := NewWithHTTPClient("key", server.Client(), llmclient.Hooks{}) + provider.SetBaseURL(server.URL) + + _, err := provider.Passthrough(context.Background(), &core.PassthroughRequest{ + Method: http.MethodPost, + Endpoint: "/chat/completions", + Body: io.NopCloser(strings.NewReader(`{"model":"qwen3-max","messages":[{"role":"user","content":"hi"}]}`)), + }) + if err != nil { + t.Fatalf("Passthrough() error = %v", err) + } + if readErr != nil { + t.Fatalf("reading request body: %v", readErr) + } + + var sentBody map[string]any + if err := json.Unmarshal(gotBody, &sentBody); err != nil { + t.Fatalf("unmarshal sent body: %v", err) + } + if _, exists := sentBody["max_completion_tokens"]; exists { + t.Fatal("passthrough body should NOT contain max_completion_tokens when request had no max_tokens") + } +}