diff --git a/src/commands/init.ts b/src/commands/init.ts index 9c5fc46..efbec11 100644 --- a/src/commands/init.ts +++ b/src/commands/init.ts @@ -8,6 +8,10 @@ import { installPrepareCommitMsgHook } from '../git/hook.js'; const CUSTOM_KEY = '__custom__'; +export function normalizeBaseUrl(value: string): string { + return value.replace(/\/+$/, ''); +} + export function buildApiKeyPrompt(existingKey: string, apiKeyEnv: string) { return { message: `Enter your API key (will be stored in config), or leave blank to use ${pc.cyan(`$${apiKeyEnv}`)} env var:`, @@ -78,7 +82,7 @@ export async function initCommand(options: { installHook?: boolean } = {}): Prom outro('Setup cancelled.'); return; } - baseUrl = urlResult; + baseUrl = normalizeBaseUrl(urlResult); apiKeyEnv = 'CUSTOM_API_KEY'; needsApiKey = true; } else { diff --git a/tests/init-base-url.test.mjs b/tests/init-base-url.test.mjs new file mode 100644 index 0000000..1c2f4f2 --- /dev/null +++ b/tests/init-base-url.test.mjs @@ -0,0 +1,56 @@ +import assert from 'node:assert/strict'; +import test from 'node:test'; + +import { initCommand, normalizeBaseUrl } from '../dist/commands/init.js'; + +test('handles single trailing slash', () => { + assert.equal(normalizeBaseUrl('https://api.example.com/v1/'), 'https://api.example.com/v1'); +}); + +test('trims multiple trailing slashes from custom base URLs', () => { + assert.equal(normalizeBaseUrl('https://api.example.com/v1///'), 'https://api.example.com/v1'); +}); + +test('preserves custom base URLs without trailing slashes', () => { + assert.equal(normalizeBaseUrl('https://api.example.com/v1'), 'https://api.example.com/v1'); +}); + +test('initCommand saves normalized custom provider base URL', async () => { + const savedConfigs = []; + const selectedValues = ['__custom__', 'custom-model']; + const textValues = ['https://api.example.com/v1/', 'test-api-key', '50']; + const spinnerStub = { + start() {}, + stop() {}, + }; + + await initCommand( + {}, + { + intro() {}, + outro() {}, + select: async () => selectedValues.shift(), + text: async () => textValues.shift(), + confirm: async () => false, + spinner: () => spinnerStub, + isCancel: () => false, + configExists: () => false, + fetchModels: async (provider, baseUrl) => { + assert.equal(provider, '__custom__'); + assert.equal(baseUrl, 'https://api.example.com/v1'); + return ['custom-model']; + }, + saveConfig: async (config) => { + savedConfigs.push(config); + }, + testConnection: async (config) => { + assert.equal(config.baseUrl, 'https://api.example.com/v1'); + return config.model; + }, + }, + ); + + assert.equal(savedConfigs.length, 1); + assert.equal(savedConfigs[0].provider, '__custom__'); + assert.equal(savedConfigs[0].baseUrl, 'https://api.example.com/v1'); +});