fix: resolve local models from configured base url

This commit is contained in:
Dutch AI Agency
2026-05-03 18:35:15 +01:00
committed by Hermes Bot
parent 3964339a58
commit e4d2704ce8
2 changed files with 261 additions and 6 deletions
+74 -6
View File
@@ -1799,8 +1799,61 @@ def get_available_models() -> dict:
if _pid_key in _PROVIDER_MODELS or _pid_key in cfg.get("providers", {}):
detected_providers.add(_pid_key)
def _normalize_base_url_for_match(value: object) -> str:
url = str(value or "").strip().rstrip("/")
if not url:
return ""
parsed_url = urlparse(url if "://" in url else f"http://{url}")
scheme = (parsed_url.scheme or "http").lower()
netloc = (parsed_url.netloc or parsed_url.path).lower().rstrip("/")
path = parsed_url.path.rstrip("/")
if not parsed_url.netloc:
path = ""
return f"{scheme}://{netloc}{path}"
def _configured_provider_for_base_url(base_url: object) -> str:
target = _normalize_base_url_for_match(base_url)
if not target:
return ""
if isinstance(model_cfg, dict):
model_base_url = _normalize_base_url_for_match(model_cfg.get("base_url"))
if model_base_url == target:
provider_hint = _resolve_provider_alias(model_cfg.get("provider"))
if provider_hint:
return str(provider_hint).strip().lower()
providers_cfg = cfg.get("providers", {})
if isinstance(providers_cfg, dict):
for provider_key, provider_cfg in providers_cfg.items():
if not isinstance(provider_cfg, dict):
continue
provider_base_url = _normalize_base_url_for_match(
provider_cfg.get("base_url")
)
if provider_base_url == target:
provider_hint = _resolve_provider_alias(provider_key)
if provider_hint:
return str(provider_hint).strip().lower()
custom_providers_cfg = cfg.get("custom_providers", [])
if isinstance(custom_providers_cfg, list):
for entry in custom_providers_cfg:
if not isinstance(entry, dict):
continue
entry_base_url = _normalize_base_url_for_match(entry.get("base_url"))
if entry_base_url != target:
continue
entry_name = str(entry.get("name") or "").strip()
if entry_name:
return "custom:" + entry_name.lower().replace(" ", "-")
return "custom"
return ""
# 4. Fetch models from custom endpoint if base_url is configured
auto_detected_models = []
auto_detected_models_by_provider: dict[str, list[dict]] = {}
if cfg_base_url:
try:
import ipaddress
@@ -1812,11 +1865,13 @@ def get_available_models() -> dict:
else:
endpoint_url = base_url.rstrip("/") + "/v1/models"
provider = "custom"
configured_provider = _configured_provider_for_base_url(base_url)
provider = configured_provider or "custom"
provider_from_config = bool(configured_provider)
parsed = urlparse(base_url if "://" in base_url else f"http://{base_url}")
host = (parsed.netloc or parsed.path).lower()
if parsed.hostname:
if parsed.hostname and not provider_from_config:
try:
addr = ipaddress.ip_address(parsed.hostname)
if addr.is_private or addr.is_loopback or addr.is_link_local:
@@ -1939,8 +1994,11 @@ def get_available_models() -> dict:
model_name = model.get("name", "") or model.get("model", "") or model_id
if model_id and model_name:
label = _format_ollama_label(model_id) if provider in ("ollama", "ollama-cloud") else model_name
auto_detected_models.append({"id": model_id, "label": label})
detected_providers.add(provider.lower())
auto_model = {"id": model_id, "label": label}
auto_detected_models.append(auto_model)
provider_key = provider.lower()
auto_detected_models_by_provider.setdefault(provider_key, []).append(auto_model)
detected_providers.add(provider_key)
except Exception:
logger.debug("Custom endpoint unreachable or misconfigured for provider: %s", provider)
@@ -2053,6 +2111,9 @@ def get_available_models() -> dict:
)
elif pid in _PROVIDER_MODELS or pid in cfg.get("providers", {}):
raw_models = copy.deepcopy(_PROVIDER_MODELS.get(pid, []))
detected_models = auto_detected_models_by_provider.get(pid, [])
if detected_models and not raw_models:
raw_models = copy.deepcopy(detected_models)
provider_cfg = cfg.get("providers", {}).get(pid, {})
if isinstance(provider_cfg, dict) and "models" in provider_cfg:
@@ -2070,7 +2131,14 @@ def get_available_models() -> dict:
}
)
else:
if auto_detected_models:
detected_models = auto_detected_models_by_provider.get(pid)
if detected_models:
models_for_group = copy.deepcopy(detected_models)
elif auto_detected_models:
models_for_group = copy.deepcopy(auto_detected_models)
else:
models_for_group = []
if models_for_group:
# Per-group deep copy so subsequent mutation by
# _deduplicate_model_ids() (which prefixes ids with
# @provider_id:) does not bleed into other groups
@@ -2084,7 +2152,7 @@ def get_available_models() -> dict:
{
"provider": provider_name,
"provider_id": pid,
"models": copy.deepcopy(auto_detected_models),
"models": models_for_group,
}
)
else:
@@ -0,0 +1,187 @@
"""Regression tests for #1527/#1530 LM Studio base_url ownership.
When a local OpenAI-compatible endpoint is configured as LM Studio, model
discovery must trust the configured provider before guessing from the URL host.
LAN IPs, Tailscale names, and reverse proxies do not contain "lmstudio" in the
hostname, but the config block already says which provider owns that base_url.
"""
from __future__ import annotations
import json
import socket
import urllib.request
import pytest
import api.config as config
import api.profiles as profiles
_API_KEY_ENV_VARS = (
"ANTHROPIC_API_KEY",
"OPENAI_API_KEY",
"OPENROUTER_API_KEY",
"GOOGLE_API_KEY",
"GEMINI_API_KEY",
"GLM_API_KEY",
"KIMI_API_KEY",
"DEEPSEEK_API_KEY",
"OPENCODE_ZEN_API_KEY",
"OPENCODE_GO_API_KEY",
"MINIMAX_API_KEY",
"MINIMAX_CN_API_KEY",
"XAI_API_KEY",
"MISTRAL_API_KEY",
"LM_API_KEY",
"LMSTUDIO_API_KEY",
"OLLAMA_API_KEY",
"LOCAL_API_KEY",
"API_KEY",
)
class _ModelsResponse:
def __init__(self, model_ids: list[str]):
self._model_ids = model_ids
def __enter__(self):
return self
def __exit__(self, *_args):
return None
def read(self) -> bytes:
return json.dumps({"data": [{"id": mid} for mid in self._model_ids]}).encode()
@pytest.fixture(autouse=True)
def _isolate_config(monkeypatch, tmp_path):
old_cfg = dict(config.cfg)
old_mtime = config._cfg_mtime
monkeypatch.setattr(profiles, "get_active_hermes_home", lambda: tmp_path)
for var in _API_KEY_ENV_VARS:
monkeypatch.delenv(var, raising=False)
config.invalidate_models_cache()
yield
config.cfg.clear()
config.cfg.update(old_cfg)
config._cfg_mtime = old_mtime
config.invalidate_models_cache()
def _write_config(tmp_path, monkeypatch, text: str) -> None:
cfgfile = tmp_path / "config.yaml"
cfgfile.write_text(text, encoding="utf-8")
monkeypatch.setattr(config, "_get_config_path", lambda: cfgfile)
config.reload_config()
config.invalidate_models_cache()
def _mock_model_discovery(monkeypatch, model_ids: list[str], resolved_ip: str) -> None:
monkeypatch.setattr(
urllib.request,
"urlopen",
lambda *_args, **_kwargs: _ModelsResponse(model_ids),
)
monkeypatch.setattr(
socket,
"getaddrinfo",
lambda *_args, **_kwargs: [
(socket.AF_INET, socket.SOCK_STREAM, 6, "", (resolved_ip, 0))
],
)
def _groups_by_id() -> dict[str, dict]:
return {
group["provider_id"]: group
for group in config.get_available_models()["groups"]
}
@pytest.mark.parametrize(
("base_url", "resolved_ip"),
[
("http://192.168.1.22:1234/v1", "192.168.1.22"),
("http://my-mac.tailnet.example:1234/v1", "192.168.1.22"),
("https://lm.internal.example.com/v1", "192.168.1.22"),
],
)
def test_lmstudio_configured_base_url_keeps_discovered_models(
tmp_path,
monkeypatch,
base_url: str,
resolved_ip: str,
):
_write_config(
tmp_path,
monkeypatch,
f"""
model:
provider: lmstudio
default: qwen3.6-35b-a3b@q6_k
base_url: {base_url}
providers:
lmstudio:
api_key: local-key
""",
)
_mock_model_discovery(
monkeypatch,
["qwen3.6-35b-a3b@q6_k", "second-lmstudio-model"],
resolved_ip,
)
groups = _groups_by_id()
assert "custom" not in groups
assert "lmstudio" in groups
model_ids = {model["id"] for model in groups["lmstudio"]["models"]}
assert {"qwen3.6-35b-a3b@q6_k", "second-lmstudio-model"} <= model_ids
def test_custom_configured_base_url_is_not_reclassified_as_ollama(tmp_path, monkeypatch):
_write_config(
tmp_path,
monkeypatch,
"""
model:
provider: custom
default: custom-model
base_url: http://localhost:4000/v1
providers:
custom:
api_key: local-key
""",
)
_mock_model_discovery(monkeypatch, ["custom-model", "custom-extra"], "127.0.0.1")
groups = _groups_by_id()
assert "ollama" not in groups
assert "custom" in groups
model_ids = {model["id"] for model in groups["custom"]["models"]}
assert {"custom-model", "custom-extra"} <= model_ids
def test_lmstudio_session_model_resolves_to_configured_base_url(tmp_path, monkeypatch):
_write_config(
tmp_path,
monkeypatch,
"""
model:
provider: lmstudio
default: qwen3.6-35b-a3b@q6_k
base_url: http://192.168.1.22:1234/v1
providers:
lmstudio:
api_key: local-key
""",
)
model, provider, base_url = config.resolve_model_provider(
"qwen3.6-35b-a3b@q6_k"
)
assert model == "qwen3.6-35b-a3b@q6_k"
assert provider == "lmstudio"
assert base_url == "http://192.168.1.22:1234/v1"