mirror of
https://github.com/nesquena/hermes-webui.git
synced 2026-05-24 18:50:15 +00:00
fix: resolve local models from configured base url
This commit is contained in:
committed by
Hermes Bot
parent
3964339a58
commit
e4d2704ce8
+74
-6
@@ -1799,8 +1799,61 @@ def get_available_models() -> dict:
|
||||
if _pid_key in _PROVIDER_MODELS or _pid_key in cfg.get("providers", {}):
|
||||
detected_providers.add(_pid_key)
|
||||
|
||||
def _normalize_base_url_for_match(value: object) -> str:
|
||||
url = str(value or "").strip().rstrip("/")
|
||||
if not url:
|
||||
return ""
|
||||
parsed_url = urlparse(url if "://" in url else f"http://{url}")
|
||||
scheme = (parsed_url.scheme or "http").lower()
|
||||
netloc = (parsed_url.netloc or parsed_url.path).lower().rstrip("/")
|
||||
path = parsed_url.path.rstrip("/")
|
||||
if not parsed_url.netloc:
|
||||
path = ""
|
||||
return f"{scheme}://{netloc}{path}"
|
||||
|
||||
def _configured_provider_for_base_url(base_url: object) -> str:
|
||||
target = _normalize_base_url_for_match(base_url)
|
||||
if not target:
|
||||
return ""
|
||||
|
||||
if isinstance(model_cfg, dict):
|
||||
model_base_url = _normalize_base_url_for_match(model_cfg.get("base_url"))
|
||||
if model_base_url == target:
|
||||
provider_hint = _resolve_provider_alias(model_cfg.get("provider"))
|
||||
if provider_hint:
|
||||
return str(provider_hint).strip().lower()
|
||||
|
||||
providers_cfg = cfg.get("providers", {})
|
||||
if isinstance(providers_cfg, dict):
|
||||
for provider_key, provider_cfg in providers_cfg.items():
|
||||
if not isinstance(provider_cfg, dict):
|
||||
continue
|
||||
provider_base_url = _normalize_base_url_for_match(
|
||||
provider_cfg.get("base_url")
|
||||
)
|
||||
if provider_base_url == target:
|
||||
provider_hint = _resolve_provider_alias(provider_key)
|
||||
if provider_hint:
|
||||
return str(provider_hint).strip().lower()
|
||||
|
||||
custom_providers_cfg = cfg.get("custom_providers", [])
|
||||
if isinstance(custom_providers_cfg, list):
|
||||
for entry in custom_providers_cfg:
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
entry_base_url = _normalize_base_url_for_match(entry.get("base_url"))
|
||||
if entry_base_url != target:
|
||||
continue
|
||||
entry_name = str(entry.get("name") or "").strip()
|
||||
if entry_name:
|
||||
return "custom:" + entry_name.lower().replace(" ", "-")
|
||||
return "custom"
|
||||
|
||||
return ""
|
||||
|
||||
# 4. Fetch models from custom endpoint if base_url is configured
|
||||
auto_detected_models = []
|
||||
auto_detected_models_by_provider: dict[str, list[dict]] = {}
|
||||
if cfg_base_url:
|
||||
try:
|
||||
import ipaddress
|
||||
@@ -1812,11 +1865,13 @@ def get_available_models() -> dict:
|
||||
else:
|
||||
endpoint_url = base_url.rstrip("/") + "/v1/models"
|
||||
|
||||
provider = "custom"
|
||||
configured_provider = _configured_provider_for_base_url(base_url)
|
||||
provider = configured_provider or "custom"
|
||||
provider_from_config = bool(configured_provider)
|
||||
parsed = urlparse(base_url if "://" in base_url else f"http://{base_url}")
|
||||
host = (parsed.netloc or parsed.path).lower()
|
||||
|
||||
if parsed.hostname:
|
||||
if parsed.hostname and not provider_from_config:
|
||||
try:
|
||||
addr = ipaddress.ip_address(parsed.hostname)
|
||||
if addr.is_private or addr.is_loopback or addr.is_link_local:
|
||||
@@ -1939,8 +1994,11 @@ def get_available_models() -> dict:
|
||||
model_name = model.get("name", "") or model.get("model", "") or model_id
|
||||
if model_id and model_name:
|
||||
label = _format_ollama_label(model_id) if provider in ("ollama", "ollama-cloud") else model_name
|
||||
auto_detected_models.append({"id": model_id, "label": label})
|
||||
detected_providers.add(provider.lower())
|
||||
auto_model = {"id": model_id, "label": label}
|
||||
auto_detected_models.append(auto_model)
|
||||
provider_key = provider.lower()
|
||||
auto_detected_models_by_provider.setdefault(provider_key, []).append(auto_model)
|
||||
detected_providers.add(provider_key)
|
||||
except Exception:
|
||||
logger.debug("Custom endpoint unreachable or misconfigured for provider: %s", provider)
|
||||
|
||||
@@ -2053,6 +2111,9 @@ def get_available_models() -> dict:
|
||||
)
|
||||
elif pid in _PROVIDER_MODELS or pid in cfg.get("providers", {}):
|
||||
raw_models = copy.deepcopy(_PROVIDER_MODELS.get(pid, []))
|
||||
detected_models = auto_detected_models_by_provider.get(pid, [])
|
||||
if detected_models and not raw_models:
|
||||
raw_models = copy.deepcopy(detected_models)
|
||||
|
||||
provider_cfg = cfg.get("providers", {}).get(pid, {})
|
||||
if isinstance(provider_cfg, dict) and "models" in provider_cfg:
|
||||
@@ -2070,7 +2131,14 @@ def get_available_models() -> dict:
|
||||
}
|
||||
)
|
||||
else:
|
||||
if auto_detected_models:
|
||||
detected_models = auto_detected_models_by_provider.get(pid)
|
||||
if detected_models:
|
||||
models_for_group = copy.deepcopy(detected_models)
|
||||
elif auto_detected_models:
|
||||
models_for_group = copy.deepcopy(auto_detected_models)
|
||||
else:
|
||||
models_for_group = []
|
||||
if models_for_group:
|
||||
# Per-group deep copy so subsequent mutation by
|
||||
# _deduplicate_model_ids() (which prefixes ids with
|
||||
# @provider_id:) does not bleed into other groups
|
||||
@@ -2084,7 +2152,7 @@ def get_available_models() -> dict:
|
||||
{
|
||||
"provider": provider_name,
|
||||
"provider_id": pid,
|
||||
"models": copy.deepcopy(auto_detected_models),
|
||||
"models": models_for_group,
|
||||
}
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -0,0 +1,187 @@
|
||||
"""Regression tests for #1527/#1530 LM Studio base_url ownership.
|
||||
|
||||
When a local OpenAI-compatible endpoint is configured as LM Studio, model
|
||||
discovery must trust the configured provider before guessing from the URL host.
|
||||
LAN IPs, Tailscale names, and reverse proxies do not contain "lmstudio" in the
|
||||
hostname, but the config block already says which provider owns that base_url.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import socket
|
||||
import urllib.request
|
||||
|
||||
import pytest
|
||||
|
||||
import api.config as config
|
||||
import api.profiles as profiles
|
||||
|
||||
|
||||
_API_KEY_ENV_VARS = (
|
||||
"ANTHROPIC_API_KEY",
|
||||
"OPENAI_API_KEY",
|
||||
"OPENROUTER_API_KEY",
|
||||
"GOOGLE_API_KEY",
|
||||
"GEMINI_API_KEY",
|
||||
"GLM_API_KEY",
|
||||
"KIMI_API_KEY",
|
||||
"DEEPSEEK_API_KEY",
|
||||
"OPENCODE_ZEN_API_KEY",
|
||||
"OPENCODE_GO_API_KEY",
|
||||
"MINIMAX_API_KEY",
|
||||
"MINIMAX_CN_API_KEY",
|
||||
"XAI_API_KEY",
|
||||
"MISTRAL_API_KEY",
|
||||
"LM_API_KEY",
|
||||
"LMSTUDIO_API_KEY",
|
||||
"OLLAMA_API_KEY",
|
||||
"LOCAL_API_KEY",
|
||||
"API_KEY",
|
||||
)
|
||||
|
||||
|
||||
class _ModelsResponse:
|
||||
def __init__(self, model_ids: list[str]):
|
||||
self._model_ids = model_ids
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *_args):
|
||||
return None
|
||||
|
||||
def read(self) -> bytes:
|
||||
return json.dumps({"data": [{"id": mid} for mid in self._model_ids]}).encode()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _isolate_config(monkeypatch, tmp_path):
|
||||
old_cfg = dict(config.cfg)
|
||||
old_mtime = config._cfg_mtime
|
||||
monkeypatch.setattr(profiles, "get_active_hermes_home", lambda: tmp_path)
|
||||
for var in _API_KEY_ENV_VARS:
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
config.invalidate_models_cache()
|
||||
yield
|
||||
config.cfg.clear()
|
||||
config.cfg.update(old_cfg)
|
||||
config._cfg_mtime = old_mtime
|
||||
config.invalidate_models_cache()
|
||||
|
||||
|
||||
def _write_config(tmp_path, monkeypatch, text: str) -> None:
|
||||
cfgfile = tmp_path / "config.yaml"
|
||||
cfgfile.write_text(text, encoding="utf-8")
|
||||
monkeypatch.setattr(config, "_get_config_path", lambda: cfgfile)
|
||||
config.reload_config()
|
||||
config.invalidate_models_cache()
|
||||
|
||||
|
||||
def _mock_model_discovery(monkeypatch, model_ids: list[str], resolved_ip: str) -> None:
|
||||
monkeypatch.setattr(
|
||||
urllib.request,
|
||||
"urlopen",
|
||||
lambda *_args, **_kwargs: _ModelsResponse(model_ids),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
socket,
|
||||
"getaddrinfo",
|
||||
lambda *_args, **_kwargs: [
|
||||
(socket.AF_INET, socket.SOCK_STREAM, 6, "", (resolved_ip, 0))
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def _groups_by_id() -> dict[str, dict]:
|
||||
return {
|
||||
group["provider_id"]: group
|
||||
for group in config.get_available_models()["groups"]
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("base_url", "resolved_ip"),
|
||||
[
|
||||
("http://192.168.1.22:1234/v1", "192.168.1.22"),
|
||||
("http://my-mac.tailnet.example:1234/v1", "192.168.1.22"),
|
||||
("https://lm.internal.example.com/v1", "192.168.1.22"),
|
||||
],
|
||||
)
|
||||
def test_lmstudio_configured_base_url_keeps_discovered_models(
|
||||
tmp_path,
|
||||
monkeypatch,
|
||||
base_url: str,
|
||||
resolved_ip: str,
|
||||
):
|
||||
_write_config(
|
||||
tmp_path,
|
||||
monkeypatch,
|
||||
f"""
|
||||
model:
|
||||
provider: lmstudio
|
||||
default: qwen3.6-35b-a3b@q6_k
|
||||
base_url: {base_url}
|
||||
providers:
|
||||
lmstudio:
|
||||
api_key: local-key
|
||||
""",
|
||||
)
|
||||
_mock_model_discovery(
|
||||
monkeypatch,
|
||||
["qwen3.6-35b-a3b@q6_k", "second-lmstudio-model"],
|
||||
resolved_ip,
|
||||
)
|
||||
|
||||
groups = _groups_by_id()
|
||||
assert "custom" not in groups
|
||||
assert "lmstudio" in groups
|
||||
model_ids = {model["id"] for model in groups["lmstudio"]["models"]}
|
||||
assert {"qwen3.6-35b-a3b@q6_k", "second-lmstudio-model"} <= model_ids
|
||||
|
||||
|
||||
def test_custom_configured_base_url_is_not_reclassified_as_ollama(tmp_path, monkeypatch):
|
||||
_write_config(
|
||||
tmp_path,
|
||||
monkeypatch,
|
||||
"""
|
||||
model:
|
||||
provider: custom
|
||||
default: custom-model
|
||||
base_url: http://localhost:4000/v1
|
||||
providers:
|
||||
custom:
|
||||
api_key: local-key
|
||||
""",
|
||||
)
|
||||
_mock_model_discovery(monkeypatch, ["custom-model", "custom-extra"], "127.0.0.1")
|
||||
|
||||
groups = _groups_by_id()
|
||||
assert "ollama" not in groups
|
||||
assert "custom" in groups
|
||||
model_ids = {model["id"] for model in groups["custom"]["models"]}
|
||||
assert {"custom-model", "custom-extra"} <= model_ids
|
||||
|
||||
|
||||
def test_lmstudio_session_model_resolves_to_configured_base_url(tmp_path, monkeypatch):
|
||||
_write_config(
|
||||
tmp_path,
|
||||
monkeypatch,
|
||||
"""
|
||||
model:
|
||||
provider: lmstudio
|
||||
default: qwen3.6-35b-a3b@q6_k
|
||||
base_url: http://192.168.1.22:1234/v1
|
||||
providers:
|
||||
lmstudio:
|
||||
api_key: local-key
|
||||
""",
|
||||
)
|
||||
|
||||
model, provider, base_url = config.resolve_model_provider(
|
||||
"qwen3.6-35b-a3b@q6_k"
|
||||
)
|
||||
|
||||
assert model == "qwen3.6-35b-a3b@q6_k"
|
||||
assert provider == "lmstudio"
|
||||
assert base_url == "http://192.168.1.22:1234/v1"
|
||||
Reference in New Issue
Block a user