From e4d2704ce89caa4e7bd2a706827e49f4ad7c2e2b Mon Sep 17 00:00:00 2001 From: Dutch AI Agency Date: Sun, 3 May 2026 18:35:15 +0100 Subject: [PATCH] fix: resolve local models from configured base url --- api/config.py | 80 +++++++- ...ue1527_lmstudio_base_url_classification.py | 187 ++++++++++++++++++ 2 files changed, 261 insertions(+), 6 deletions(-) create mode 100644 tests/test_issue1527_lmstudio_base_url_classification.py diff --git a/api/config.py b/api/config.py index f71bc812..f6bd774f 100644 --- a/api/config.py +++ b/api/config.py @@ -1799,8 +1799,61 @@ def get_available_models() -> dict: if _pid_key in _PROVIDER_MODELS or _pid_key in cfg.get("providers", {}): detected_providers.add(_pid_key) + def _normalize_base_url_for_match(value: object) -> str: + url = str(value or "").strip().rstrip("/") + if not url: + return "" + parsed_url = urlparse(url if "://" in url else f"http://{url}") + scheme = (parsed_url.scheme or "http").lower() + netloc = (parsed_url.netloc or parsed_url.path).lower().rstrip("/") + path = parsed_url.path.rstrip("/") + if not parsed_url.netloc: + path = "" + return f"{scheme}://{netloc}{path}" + + def _configured_provider_for_base_url(base_url: object) -> str: + target = _normalize_base_url_for_match(base_url) + if not target: + return "" + + if isinstance(model_cfg, dict): + model_base_url = _normalize_base_url_for_match(model_cfg.get("base_url")) + if model_base_url == target: + provider_hint = _resolve_provider_alias(model_cfg.get("provider")) + if provider_hint: + return str(provider_hint).strip().lower() + + providers_cfg = cfg.get("providers", {}) + if isinstance(providers_cfg, dict): + for provider_key, provider_cfg in providers_cfg.items(): + if not isinstance(provider_cfg, dict): + continue + provider_base_url = _normalize_base_url_for_match( + provider_cfg.get("base_url") + ) + if provider_base_url == target: + provider_hint = _resolve_provider_alias(provider_key) + if provider_hint: + return str(provider_hint).strip().lower() + + custom_providers_cfg = cfg.get("custom_providers", []) + if isinstance(custom_providers_cfg, list): + for entry in custom_providers_cfg: + if not isinstance(entry, dict): + continue + entry_base_url = _normalize_base_url_for_match(entry.get("base_url")) + if entry_base_url != target: + continue + entry_name = str(entry.get("name") or "").strip() + if entry_name: + return "custom:" + entry_name.lower().replace(" ", "-") + return "custom" + + return "" + # 4. Fetch models from custom endpoint if base_url is configured auto_detected_models = [] + auto_detected_models_by_provider: dict[str, list[dict]] = {} if cfg_base_url: try: import ipaddress @@ -1812,11 +1865,13 @@ def get_available_models() -> dict: else: endpoint_url = base_url.rstrip("/") + "/v1/models" - provider = "custom" + configured_provider = _configured_provider_for_base_url(base_url) + provider = configured_provider or "custom" + provider_from_config = bool(configured_provider) parsed = urlparse(base_url if "://" in base_url else f"http://{base_url}") host = (parsed.netloc or parsed.path).lower() - if parsed.hostname: + if parsed.hostname and not provider_from_config: try: addr = ipaddress.ip_address(parsed.hostname) if addr.is_private or addr.is_loopback or addr.is_link_local: @@ -1939,8 +1994,11 @@ def get_available_models() -> dict: model_name = model.get("name", "") or model.get("model", "") or model_id if model_id and model_name: label = _format_ollama_label(model_id) if provider in ("ollama", "ollama-cloud") else model_name - auto_detected_models.append({"id": model_id, "label": label}) - detected_providers.add(provider.lower()) + auto_model = {"id": model_id, "label": label} + auto_detected_models.append(auto_model) + provider_key = provider.lower() + auto_detected_models_by_provider.setdefault(provider_key, []).append(auto_model) + detected_providers.add(provider_key) except Exception: logger.debug("Custom endpoint unreachable or misconfigured for provider: %s", provider) @@ -2053,6 +2111,9 @@ def get_available_models() -> dict: ) elif pid in _PROVIDER_MODELS or pid in cfg.get("providers", {}): raw_models = copy.deepcopy(_PROVIDER_MODELS.get(pid, [])) + detected_models = auto_detected_models_by_provider.get(pid, []) + if detected_models and not raw_models: + raw_models = copy.deepcopy(detected_models) provider_cfg = cfg.get("providers", {}).get(pid, {}) if isinstance(provider_cfg, dict) and "models" in provider_cfg: @@ -2070,7 +2131,14 @@ def get_available_models() -> dict: } ) else: - if auto_detected_models: + detected_models = auto_detected_models_by_provider.get(pid) + if detected_models: + models_for_group = copy.deepcopy(detected_models) + elif auto_detected_models: + models_for_group = copy.deepcopy(auto_detected_models) + else: + models_for_group = [] + if models_for_group: # Per-group deep copy so subsequent mutation by # _deduplicate_model_ids() (which prefixes ids with # @provider_id:) does not bleed into other groups @@ -2084,7 +2152,7 @@ def get_available_models() -> dict: { "provider": provider_name, "provider_id": pid, - "models": copy.deepcopy(auto_detected_models), + "models": models_for_group, } ) else: diff --git a/tests/test_issue1527_lmstudio_base_url_classification.py b/tests/test_issue1527_lmstudio_base_url_classification.py new file mode 100644 index 00000000..b42f7046 --- /dev/null +++ b/tests/test_issue1527_lmstudio_base_url_classification.py @@ -0,0 +1,187 @@ +"""Regression tests for #1527/#1530 LM Studio base_url ownership. + +When a local OpenAI-compatible endpoint is configured as LM Studio, model +discovery must trust the configured provider before guessing from the URL host. +LAN IPs, Tailscale names, and reverse proxies do not contain "lmstudio" in the +hostname, but the config block already says which provider owns that base_url. +""" + +from __future__ import annotations + +import json +import socket +import urllib.request + +import pytest + +import api.config as config +import api.profiles as profiles + + +_API_KEY_ENV_VARS = ( + "ANTHROPIC_API_KEY", + "OPENAI_API_KEY", + "OPENROUTER_API_KEY", + "GOOGLE_API_KEY", + "GEMINI_API_KEY", + "GLM_API_KEY", + "KIMI_API_KEY", + "DEEPSEEK_API_KEY", + "OPENCODE_ZEN_API_KEY", + "OPENCODE_GO_API_KEY", + "MINIMAX_API_KEY", + "MINIMAX_CN_API_KEY", + "XAI_API_KEY", + "MISTRAL_API_KEY", + "LM_API_KEY", + "LMSTUDIO_API_KEY", + "OLLAMA_API_KEY", + "LOCAL_API_KEY", + "API_KEY", +) + + +class _ModelsResponse: + def __init__(self, model_ids: list[str]): + self._model_ids = model_ids + + def __enter__(self): + return self + + def __exit__(self, *_args): + return None + + def read(self) -> bytes: + return json.dumps({"data": [{"id": mid} for mid in self._model_ids]}).encode() + + +@pytest.fixture(autouse=True) +def _isolate_config(monkeypatch, tmp_path): + old_cfg = dict(config.cfg) + old_mtime = config._cfg_mtime + monkeypatch.setattr(profiles, "get_active_hermes_home", lambda: tmp_path) + for var in _API_KEY_ENV_VARS: + monkeypatch.delenv(var, raising=False) + config.invalidate_models_cache() + yield + config.cfg.clear() + config.cfg.update(old_cfg) + config._cfg_mtime = old_mtime + config.invalidate_models_cache() + + +def _write_config(tmp_path, monkeypatch, text: str) -> None: + cfgfile = tmp_path / "config.yaml" + cfgfile.write_text(text, encoding="utf-8") + monkeypatch.setattr(config, "_get_config_path", lambda: cfgfile) + config.reload_config() + config.invalidate_models_cache() + + +def _mock_model_discovery(monkeypatch, model_ids: list[str], resolved_ip: str) -> None: + monkeypatch.setattr( + urllib.request, + "urlopen", + lambda *_args, **_kwargs: _ModelsResponse(model_ids), + ) + monkeypatch.setattr( + socket, + "getaddrinfo", + lambda *_args, **_kwargs: [ + (socket.AF_INET, socket.SOCK_STREAM, 6, "", (resolved_ip, 0)) + ], + ) + + +def _groups_by_id() -> dict[str, dict]: + return { + group["provider_id"]: group + for group in config.get_available_models()["groups"] + } + + +@pytest.mark.parametrize( + ("base_url", "resolved_ip"), + [ + ("http://192.168.1.22:1234/v1", "192.168.1.22"), + ("http://my-mac.tailnet.example:1234/v1", "192.168.1.22"), + ("https://lm.internal.example.com/v1", "192.168.1.22"), + ], +) +def test_lmstudio_configured_base_url_keeps_discovered_models( + tmp_path, + monkeypatch, + base_url: str, + resolved_ip: str, +): + _write_config( + tmp_path, + monkeypatch, + f""" +model: + provider: lmstudio + default: qwen3.6-35b-a3b@q6_k + base_url: {base_url} +providers: + lmstudio: + api_key: local-key +""", + ) + _mock_model_discovery( + monkeypatch, + ["qwen3.6-35b-a3b@q6_k", "second-lmstudio-model"], + resolved_ip, + ) + + groups = _groups_by_id() + assert "custom" not in groups + assert "lmstudio" in groups + model_ids = {model["id"] for model in groups["lmstudio"]["models"]} + assert {"qwen3.6-35b-a3b@q6_k", "second-lmstudio-model"} <= model_ids + + +def test_custom_configured_base_url_is_not_reclassified_as_ollama(tmp_path, monkeypatch): + _write_config( + tmp_path, + monkeypatch, + """ +model: + provider: custom + default: custom-model + base_url: http://localhost:4000/v1 +providers: + custom: + api_key: local-key +""", + ) + _mock_model_discovery(monkeypatch, ["custom-model", "custom-extra"], "127.0.0.1") + + groups = _groups_by_id() + assert "ollama" not in groups + assert "custom" in groups + model_ids = {model["id"] for model in groups["custom"]["models"]} + assert {"custom-model", "custom-extra"} <= model_ids + + +def test_lmstudio_session_model_resolves_to_configured_base_url(tmp_path, monkeypatch): + _write_config( + tmp_path, + monkeypatch, + """ +model: + provider: lmstudio + default: qwen3.6-35b-a3b@q6_k + base_url: http://192.168.1.22:1234/v1 +providers: + lmstudio: + api_key: local-key +""", + ) + + model, provider, base_url = config.resolve_model_provider( + "qwen3.6-35b-a3b@q6_k" + ) + + assert model == "qwen3.6-35b-a3b@q6_k" + assert provider == "lmstudio" + assert base_url == "http://192.168.1.22:1234/v1"