diff --git a/.env.example b/.env.example index 77ca0f3a3..a56452709 100644 --- a/.env.example +++ b/.env.example @@ -39,6 +39,9 @@ # OPENROUTER_API_KEY=sk-or-... # OPENROUTER_MODEL=anthropic/claude-sonnet-4-20250514 +# OPENROUTER_BASE_URL=https://openrouter.ai/api/v1 # Override for OpenRouter-compatible proxies + +# GEMINI_BASE_URL=https://generativelanguage.googleapis.com/v1beta # Override for Gemini-compatible proxies # MINIMAX_API_KEY=... # MINIMAX_MODEL=MiniMax-M2.7 diff --git a/README.md b/README.md index c4ec2c1e0..dbcaa8e7b 100644 --- a/README.md +++ b/README.md @@ -1382,7 +1382,9 @@ Create `~/.agentmemory/.env`: # ANTHROPIC_API_KEY=sk-ant-... # ANTHROPIC_BASE_URL=... # Optional: Anthropic-compatible proxy / Azure # GEMINI_API_KEY=... +# GEMINI_BASE_URL=https://generativelanguage.googleapis.com/v1beta # Optional: Gemini API / proxy base URL # OPENROUTER_API_KEY=... +# OPENROUTER_BASE_URL=https://openrouter.ai/api/v1 # Optional: OpenRouter API / proxy base URL # MINIMAX_API_KEY=... # OPENAI_API_KEY=*** # NOTE: this same key auto-activates BOTH the # # OpenAI LLM provider (here) AND the OpenAI @@ -1419,6 +1421,8 @@ Create `~/.agentmemory/.env`: # OPENAI_BASE_URL=https://api.openai.com # Override for Azure / vLLM / LM Studio / proxies # OPENAI_EMBEDDING_MODEL=text-embedding-3-small # OPENAI_EMBEDDING_DIMENSIONS=1536 # Required when the model is not in the known-models table +# GEMINI_BASE_URL=https://generativelanguage.googleapis.com/v1beta # Also used by Gemini embeddings +# OPENROUTER_BASE_URL=https://openrouter.ai/api/v1 # Also used by OpenRouter embeddings # Outbound LLM / embedding timeout # AGENTMEMORY_LLM_TIMEOUT_MS=60000 # Default: 60 000 ms (60 s). Applies to every diff --git a/src/config.ts b/src/config.ts index f68da2e31..d8deed2ad 100644 --- a/src/config.ts +++ b/src/config.ts @@ -90,6 +90,7 @@ function detectProvider(env: Record): ProviderConfig { provider: "gemini", model: env["GEMINI_MODEL"] || "gemini-2.5-flash", maxTokens, + baseURL: env["GEMINI_BASE_URL"], }; } if (hasRealValue(env["OPENROUTER_API_KEY"])) { @@ -120,6 +121,7 @@ function detectProvider(env: Record): ProviderConfig { provider: "openrouter", model, maxTokens, + baseURL: env["OPENROUTER_BASE_URL"], }; } diff --git a/src/providers/embedding/gemini.ts b/src/providers/embedding/gemini.ts index 140164693..1eb1a2030 100644 --- a/src/providers/embedding/gemini.ts +++ b/src/providers/embedding/gemini.ts @@ -4,16 +4,20 @@ import { fetchWithTimeout } from "../_fetch.js"; const BATCH_LIMIT = 100; const MODEL = "models/gemini-embedding-001"; -const API_BASE = `https://generativelanguage.googleapis.com/v1beta/${MODEL}:batchEmbedContents`; +const DEFAULT_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"; export class GeminiEmbeddingProvider implements EmbeddingProvider { readonly name = "gemini"; readonly dimensions = 768; private apiKey: string; + private endpoint: string; - constructor(apiKey?: string) { + constructor(apiKey?: string, baseURL?: string) { this.apiKey = apiKey || getEnvVar("GEMINI_API_KEY") || ""; if (!this.apiKey) throw new Error("GEMINI_API_KEY is required"); + const baseUrl = (baseURL || getEnvVar("GEMINI_BASE_URL") || DEFAULT_BASE_URL) + .replace(/\/+$/, ""); + this.endpoint = `${baseUrl}/${MODEL}:batchEmbedContents`; } async embed(text: string): Promise { @@ -26,7 +30,7 @@ export class GeminiEmbeddingProvider implements EmbeddingProvider { for (let i = 0; i < texts.length; i += BATCH_LIMIT) { const chunk = texts.slice(i, i + BATCH_LIMIT); - const response = await fetchWithTimeout(`${API_BASE}?key=${this.apiKey}`, { + const response = await fetchWithTimeout(`${this.endpoint}?key=${this.apiKey}`, { method: "POST", headers: { "Content-Type": "application/json" }, body: JSON.stringify({ diff --git a/src/providers/embedding/openrouter.ts b/src/providers/embedding/openrouter.ts index 46999e559..10195d70f 100644 --- a/src/providers/embedding/openrouter.ts +++ b/src/providers/embedding/openrouter.ts @@ -2,20 +2,24 @@ import type { EmbeddingProvider } from "../../types.js"; import { getEnvVar } from "../../config.js"; import { fetchWithTimeout } from "../_fetch.js"; -const API_URL = "https://openrouter.ai/api/v1/embeddings"; +const DEFAULT_BASE_URL = "https://openrouter.ai/api/v1"; export class OpenRouterEmbeddingProvider implements EmbeddingProvider { readonly name = "openrouter"; readonly dimensions = 1536; private apiKey: string; private model: string; + private endpoint: string; - constructor(apiKey?: string) { + constructor(apiKey?: string, baseURL?: string) { this.apiKey = apiKey || getEnvVar("OPENROUTER_API_KEY") || ""; if (!this.apiKey) throw new Error("OPENROUTER_API_KEY is required"); this.model = getEnvVar("OPENROUTER_EMBEDDING_MODEL") || "openai/text-embedding-3-small"; + const baseUrl = (baseURL || getEnvVar("OPENROUTER_BASE_URL") || DEFAULT_BASE_URL) + .replace(/\/+$/, ""); + this.endpoint = `${baseUrl}/embeddings`; } async embed(text: string): Promise { @@ -24,7 +28,7 @@ export class OpenRouterEmbeddingProvider implements EmbeddingProvider { } async embedBatch(texts: string[]): Promise { - const response = await fetchWithTimeout(API_URL, { + const response = await fetchWithTimeout(this.endpoint, { method: "POST", headers: { Authorization: `Bearer ${this.apiKey}`, diff --git a/src/providers/gemini.ts b/src/providers/gemini.ts new file mode 100644 index 000000000..06c1ec607 --- /dev/null +++ b/src/providers/gemini.ts @@ -0,0 +1,73 @@ +import type { MemoryProvider } from "../types.js"; +import { getEnvVar } from "../config.js"; +import { fetchWithTimeout } from "./_fetch.js"; + +const DEFAULT_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"; + +export class GeminiProvider implements MemoryProvider { + name = "gemini"; + private apiKey: string; + private model: string; + private maxTokens: number; + private endpoint: string; + + constructor( + apiKey: string, + model: string, + maxTokens: number, + baseURL?: string, + ) { + this.apiKey = apiKey; + this.model = model; + this.maxTokens = maxTokens; + const baseUrl = (baseURL || getEnvVar("GEMINI_BASE_URL") || DEFAULT_BASE_URL) + .replace(/\/+$/, ""); + this.endpoint = `${baseUrl}/openai/chat/completions`; + } + + async compress(systemPrompt: string, userPrompt: string): Promise { + return this.call(systemPrompt, userPrompt); + } + + async summarize(systemPrompt: string, userPrompt: string): Promise { + return this.call(systemPrompt, userPrompt); + } + + private async call( + systemPrompt: string, + userPrompt: string, + ): Promise { + const response = await fetchWithTimeout(this.endpoint, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${this.apiKey}`, + }, + body: JSON.stringify({ + model: this.model, + max_tokens: this.maxTokens, + messages: [ + { role: "system", content: systemPrompt }, + { role: "user", content: userPrompt }, + ], + }), + }); + + if (!response.ok) { + const text = await response.text(); + throw new Error(`Gemini API error (${response.status}): ${text}`); + } + + const data = (await response.json()) as Record; + const choices = data.choices as + | Array<{ message: { content: string } }> + | undefined; + const content = choices?.[0]?.message?.content; + if (!content) { + throw new Error( + `Gemini returned unexpected response: ${JSON.stringify(data).slice(0, 200)}`, + ); + } + return content; + } +} diff --git a/src/providers/index.ts b/src/providers/index.ts index 0ec3feba0..f0e0d023b 100644 --- a/src/providers/index.ts +++ b/src/providers/index.ts @@ -9,6 +9,7 @@ import { MinimaxProvider } from "./minimax.js"; import { NoopProvider } from "./noop.js"; import { OpenAIProvider } from "./openai.js"; import { OpenRouterProvider } from "./openrouter.js"; +import { GeminiProvider } from "./gemini.js"; import { ResilientProvider } from "./resilient.js"; import { FallbackChainProvider } from "./fallback-chain.js"; import { getEnvVar } from "../config.js"; @@ -115,11 +116,11 @@ function createBaseProvider(config: ProviderConfig): MemoryProvider { "GEMINI_API_KEY (or GOOGLE_API_KEY) is required for the gemini provider", ); } - return new OpenRouterProvider( + return new GeminiProvider( geminiKey, config.model, config.maxTokens, - "https://generativelanguage.googleapis.com/v1beta/openai/chat/completions", + config.baseURL, ); } case "openrouter": @@ -127,7 +128,7 @@ function createBaseProvider(config: ProviderConfig): MemoryProvider { requireEnvVar("OPENROUTER_API_KEY"), config.model, config.maxTokens, - "https://openrouter.ai/api/v1/chat/completions", + config.baseURL, ); case "openai": { const openaiKey = getEnvVar("OPENAI_API_KEY"); diff --git a/src/providers/openrouter.ts b/src/providers/openrouter.ts index 5c47bb0a8..8475e2a27 100644 --- a/src/providers/openrouter.ts +++ b/src/providers/openrouter.ts @@ -1,24 +1,29 @@ import type { MemoryProvider } from "../types.js"; +import { getEnvVar } from "../config.js"; import { fetchWithTimeout } from "./_fetch.js"; +const DEFAULT_BASE_URL = "https://openrouter.ai/api/v1"; + export class OpenRouterProvider implements MemoryProvider { name: string; private apiKey: string; private model: string; private maxTokens: number; - private baseUrl: string; + private endpoint: string; constructor( apiKey: string, model: string, maxTokens: number, - baseUrl: string, + baseURL?: string, ) { this.apiKey = apiKey; this.model = model; this.maxTokens = maxTokens; - this.baseUrl = baseUrl; - this.name = baseUrl.includes("openrouter") ? "openrouter" : "gemini"; + const baseUrl = (baseURL || getEnvVar("OPENROUTER_BASE_URL") || DEFAULT_BASE_URL) + .replace(/\/+$/, ""); + this.endpoint = `${baseUrl}/chat/completions`; + this.name = "openrouter"; } async compress(systemPrompt: string, userPrompt: string): Promise { @@ -33,14 +38,12 @@ export class OpenRouterProvider implements MemoryProvider { systemPrompt: string, userPrompt: string, ): Promise { - const response = await fetchWithTimeout(this.baseUrl, { + const response = await fetchWithTimeout(this.endpoint, { method: "POST", headers: { "Content-Type": "application/json", Authorization: `Bearer ${this.apiKey}`, - ...(this.baseUrl.includes("openrouter") - ? { "HTTP-Referer": "https://github.com/rohitg00/agentmemory" } - : {}), + "HTTP-Referer": "https://github.com/rohitg00/agentmemory", }, body: JSON.stringify({ model: this.model, diff --git a/test/fallback-model-resolution.test.ts b/test/fallback-model-resolution.test.ts index 91a821161..605cd0fa3 100644 --- a/test/fallback-model-resolution.test.ts +++ b/test/fallback-model-resolution.test.ts @@ -24,11 +24,23 @@ vi.mock("../src/providers/openai.js", () => ({ vi.mock("../src/providers/openrouter.js", () => ({ OpenRouterProvider: class { name = "openrouter"; - constructor(_key: string, model: string, _max: number, url?: string) { - captured.push({ - provider: url?.includes("googleapis") ? "gemini" : "openrouter", - model, - }); + constructor(_key: string, model: string) { + captured.push({ provider: "openrouter", model }); + } + async compress() { + return ""; + } + async summarize() { + return ""; + } + }, +})); + +vi.mock("../src/providers/gemini.js", () => ({ + GeminiProvider: class { + name = "gemini"; + constructor(_key: string, model: string) { + captured.push({ provider: "gemini", model }); } async compress() { return ""; diff --git a/test/fetch-timeout.test.ts b/test/fetch-timeout.test.ts index 5b2cd7c9c..e2ff34625 100644 --- a/test/fetch-timeout.test.ts +++ b/test/fetch-timeout.test.ts @@ -2,6 +2,7 @@ import { describe, it, expect, vi, afterEach, beforeEach } from "vitest"; import { fetchWithTimeout } from "../src/providers/_fetch.js"; import { MinimaxProvider } from "../src/providers/minimax.js"; import { OpenRouterProvider } from "../src/providers/openrouter.js"; +import { GeminiProvider } from "../src/providers/gemini.js"; import { OpenAIProvider } from "../src/providers/openai.js"; import { GeminiEmbeddingProvider } from "../src/providers/embedding/gemini.js"; import { OpenAIEmbeddingProvider } from "../src/providers/embedding/openai.js"; @@ -96,7 +97,7 @@ describe("Provider hang regression — MinimaxProvider", () => { }); }); -describe("Provider hang regression — OpenRouterProvider (covers Gemini LLM path)", () => { +describe("Provider hang regression — OpenRouterProvider", () => { beforeEach(() => { vi.spyOn(globalThis, "fetch").mockImplementation(hangingFetch as typeof fetch); process.env["AGENTMEMORY_LLM_TIMEOUT_MS"] = "50"; @@ -111,12 +112,28 @@ describe("Provider hang regression — OpenRouterProvider (covers Gemini LLM pat "test-key", "gemini-2.5-flash", 1024, - "https://generativelanguage.googleapis.com/v1beta/openai/chat/completions", + "https://openrouter.ai/api/v1", ); await expect(provider.compress("system", "user")).rejects.toThrow(); }); }); +describe("Provider hang regression — GeminiProvider", () => { + beforeEach(() => { + vi.spyOn(globalThis, "fetch").mockImplementation(hangingFetch as typeof fetch); + process.env["AGENTMEMORY_LLM_TIMEOUT_MS"] = "50"; + }); + afterEach(() => { + vi.restoreAllMocks(); + delete process.env["AGENTMEMORY_LLM_TIMEOUT_MS"]; + }); + + it("compress() aborts after timeout when upstream hangs", async () => { + const provider = new GeminiProvider("test-key", "gemini-2.5-flash", 1024); + await expect(provider.compress("system", "user")).rejects.toThrow(); + }); +}); + describe("Provider hang regression — GeminiEmbeddingProvider", () => { beforeEach(() => { vi.spyOn(globalThis, "fetch").mockImplementation(hangingFetch as typeof fetch); @@ -343,4 +360,3 @@ describe("OpenAIProvider thinking-model fallback (#627)", () => { expect(out).toBe("real content"); }); }); - diff --git a/test/gemini-provider.test.ts b/test/gemini-provider.test.ts new file mode 100644 index 000000000..1a28db608 --- /dev/null +++ b/test/gemini-provider.test.ts @@ -0,0 +1,69 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; +import { GeminiProvider } from "../src/providers/gemini.js"; +import { GeminiEmbeddingProvider } from "../src/providers/embedding/gemini.js"; + +const originalBaseUrl = process.env["GEMINI_BASE_URL"]; + +function chatResponse(): Response { + return new Response( + JSON.stringify({ choices: [{ message: { content: "result" } }] }), + { status: 200 }, + ); +} + +function embeddingResponse(): Response { + return new Response( + JSON.stringify({ embeddings: [{ values: [0.1, 0.2] }] }), + { status: 200 }, + ); +} + +describe("GeminiProvider base URL", () => { + beforeEach(() => { + delete process.env["GEMINI_BASE_URL"]; + vi.spyOn(globalThis, "fetch").mockResolvedValue(chatResponse()); + }); + + afterEach(() => { + if (originalBaseUrl === undefined) { + delete process.env["GEMINI_BASE_URL"]; + } else { + process.env["GEMINI_BASE_URL"] = originalBaseUrl; + } + vi.restoreAllMocks(); + }); + + it("uses the default base URL for chat and embeddings", async () => { + await new GeminiProvider("key", "model", 100).compress("system", "user"); + expect(globalThis.fetch).toHaveBeenLastCalledWith( + "https://generativelanguage.googleapis.com/v1beta/openai/chat/completions", + expect.objectContaining({ + headers: expect.not.objectContaining({ "HTTP-Referer": expect.anything() }), + }), + ); + + vi.mocked(globalThis.fetch).mockResolvedValueOnce(embeddingResponse()); + await new GeminiEmbeddingProvider("key").embed("text"); + expect(globalThis.fetch).toHaveBeenLastCalledWith( + "https://generativelanguage.googleapis.com/v1beta/models/gemini-embedding-001:batchEmbedContents?key=key", + expect.any(Object), + ); + }); + + it("uses GEMINI_BASE_URL for chat and embeddings", async () => { + process.env["GEMINI_BASE_URL"] = "https://gemini-proxy.example/v1beta/"; + + await new GeminiProvider("key", "model", 100).compress("system", "user"); + expect(globalThis.fetch).toHaveBeenLastCalledWith( + "https://gemini-proxy.example/v1beta/openai/chat/completions", + expect.any(Object), + ); + + vi.mocked(globalThis.fetch).mockResolvedValueOnce(embeddingResponse()); + await new GeminiEmbeddingProvider("key").embed("text"); + expect(globalThis.fetch).toHaveBeenLastCalledWith( + "https://gemini-proxy.example/v1beta/models/gemini-embedding-001:batchEmbedContents?key=key", + expect.any(Object), + ); + }); +}); diff --git a/test/openrouter-provider.test.ts b/test/openrouter-provider.test.ts new file mode 100644 index 000000000..1d0e9c1ea --- /dev/null +++ b/test/openrouter-provider.test.ts @@ -0,0 +1,70 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; +import { OpenRouterProvider } from "../src/providers/openrouter.js"; +import { OpenRouterEmbeddingProvider } from "../src/providers/embedding/openrouter.js"; + +const originalBaseUrl = process.env["OPENROUTER_BASE_URL"]; + +function chatResponse(): Response { + return new Response( + JSON.stringify({ choices: [{ message: { content: "result" } }] }), + { status: 200 }, + ); +} + +function embeddingResponse(): Response { + return new Response(JSON.stringify({ data: [{ embedding: [0.1, 0.2] }] }), { + status: 200, + }); +} + +describe("OpenRouterProvider base URL", () => { + beforeEach(() => { + delete process.env["OPENROUTER_BASE_URL"]; + vi.spyOn(globalThis, "fetch").mockResolvedValue(chatResponse()); + }); + + afterEach(() => { + if (originalBaseUrl === undefined) { + delete process.env["OPENROUTER_BASE_URL"]; + } else { + process.env["OPENROUTER_BASE_URL"] = originalBaseUrl; + } + vi.restoreAllMocks(); + }); + + it("uses the default base URL for chat and embeddings", async () => { + await new OpenRouterProvider("key", "model", 100).compress("system", "user"); + expect(globalThis.fetch).toHaveBeenLastCalledWith( + "https://openrouter.ai/api/v1/chat/completions", + expect.objectContaining({ + headers: expect.objectContaining({ + "HTTP-Referer": "https://github.com/rohitg00/agentmemory", + }), + }), + ); + + vi.mocked(globalThis.fetch).mockResolvedValueOnce(embeddingResponse()); + await new OpenRouterEmbeddingProvider("key").embed("text"); + expect(globalThis.fetch).toHaveBeenLastCalledWith( + "https://openrouter.ai/api/v1/embeddings", + expect.any(Object), + ); + }); + + it("uses OPENROUTER_BASE_URL for chat and embeddings", async () => { + process.env["OPENROUTER_BASE_URL"] = "https://openrouter-proxy.example/v1/"; + + await new OpenRouterProvider("key", "model", 100).compress("system", "user"); + expect(globalThis.fetch).toHaveBeenLastCalledWith( + "https://openrouter-proxy.example/v1/chat/completions", + expect.any(Object), + ); + + vi.mocked(globalThis.fetch).mockResolvedValueOnce(embeddingResponse()); + await new OpenRouterEmbeddingProvider("key").embed("text"); + expect(globalThis.fetch).toHaveBeenLastCalledWith( + "https://openrouter-proxy.example/v1/embeddings", + expect.any(Object), + ); + }); +});