mirror of
https://github.com/nesquena/hermes-webui.git
synced 2026-05-25 03:00:23 +00:00
fix(config): custom named provider API key resolution in WebUI
- add robust custom provider credential/base_url resolver - apply fallback in streaming and routes agent init/self-heal paths - support slug normalization and config fallbacks for custom:* providers
This commit is contained in:
@@ -1596,6 +1596,102 @@ def resolve_model_provider(model_id: str) -> tuple:
|
||||
return model_id, config_provider, config_base_url
|
||||
|
||||
|
||||
def resolve_custom_provider_connection(provider_id: str) -> tuple[str | None, str | None]:
|
||||
"""Return (api_key, base_url) for a named ``custom:*`` provider.
|
||||
|
||||
Supports ``custom_providers[].api_key`` as either a literal key or
|
||||
``${ENV_VAR}``, and ``custom_providers[].key_env`` as an env-var hint.
|
||||
Returns ``(None, None)`` when no named custom provider matches.
|
||||
"""
|
||||
pid = str(provider_id or "").strip().lower()
|
||||
if not pid.startswith("custom:"):
|
||||
return None, None
|
||||
|
||||
def _slugify(value: str) -> str:
|
||||
s = str(value or "").strip().lower().replace("_", "-").replace(" ", "-")
|
||||
while "--" in s:
|
||||
s = s.replace("--", "-")
|
||||
return s.strip("-")
|
||||
|
||||
slug = _slugify(pid.split(":", 1)[1].strip())
|
||||
if not slug:
|
||||
return None, None
|
||||
|
||||
# Read the live config snapshot to avoid stale module-level cache edge
|
||||
# cases after profile switches or runtime config edits.
|
||||
cfg_data = get_config()
|
||||
|
||||
def _resolve_key(raw_api_key, raw_key_env) -> str | None:
|
||||
api_key = None
|
||||
if raw_api_key is not None:
|
||||
key_text = str(raw_api_key).strip()
|
||||
if key_text.startswith("${") and key_text.endswith("}") and len(key_text) > 3:
|
||||
api_key = os.getenv(key_text[2:-1], "").strip() or None
|
||||
elif key_text:
|
||||
api_key = key_text
|
||||
if not api_key:
|
||||
key_env = str(raw_key_env or "").strip()
|
||||
if key_env:
|
||||
api_key = os.getenv(key_env, "").strip() or None
|
||||
return api_key
|
||||
|
||||
custom_providers = cfg_data.get("custom_providers", [])
|
||||
if not isinstance(custom_providers, list):
|
||||
custom_providers = []
|
||||
|
||||
for entry in custom_providers:
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
name = str(entry.get("name") or "").strip()
|
||||
if not name:
|
||||
continue
|
||||
entry_slug = _slugify(name)
|
||||
if entry_slug != slug:
|
||||
continue
|
||||
|
||||
base_url = str(entry.get("base_url") or "").strip() or None
|
||||
api_key = _resolve_key(entry.get("api_key"), entry.get("key_env"))
|
||||
return api_key, base_url
|
||||
|
||||
# If exactly one custom provider is configured, use it as a pragmatic
|
||||
# fallback for mismatched slugs (e.g. punctuation differences).
|
||||
if len(custom_providers) == 1 and isinstance(custom_providers[0], dict):
|
||||
entry = custom_providers[0]
|
||||
return (
|
||||
_resolve_key(entry.get("api_key"), entry.get("key_env")),
|
||||
str(entry.get("base_url") or "").strip() or None,
|
||||
)
|
||||
|
||||
# Fallbacks for setups that don't use custom_providers names directly.
|
||||
providers_cfg = cfg_data.get("providers", {})
|
||||
provider_specific = providers_cfg.get(pid, {}) if isinstance(providers_cfg, dict) else {}
|
||||
provider_custom = providers_cfg.get("custom", {}) if isinstance(providers_cfg, dict) else {}
|
||||
|
||||
model_cfg = cfg_data.get("model", {})
|
||||
model_provider = str(model_cfg.get("provider") or "").strip().lower() if isinstance(model_cfg, dict) else ""
|
||||
|
||||
fallback_base = None
|
||||
for candidate in (provider_specific, provider_custom, model_cfg):
|
||||
if isinstance(candidate, dict):
|
||||
_base = str(candidate.get("base_url") or "").strip()
|
||||
if _base:
|
||||
fallback_base = _base
|
||||
break
|
||||
|
||||
fallback_key = None
|
||||
if isinstance(provider_specific, dict):
|
||||
fallback_key = _resolve_key(provider_specific.get("api_key"), provider_specific.get("key_env"))
|
||||
if not fallback_key and isinstance(provider_custom, dict):
|
||||
fallback_key = _resolve_key(provider_custom.get("api_key"), provider_custom.get("key_env"))
|
||||
if not fallback_key and isinstance(model_cfg, dict) and model_provider in {"custom", pid, slug}:
|
||||
fallback_key = _resolve_key(model_cfg.get("api_key"), model_cfg.get("key_env"))
|
||||
|
||||
if fallback_key or fallback_base:
|
||||
return fallback_key, fallback_base or None
|
||||
|
||||
return None, None
|
||||
|
||||
|
||||
def model_with_provider_context(model_id: str, model_provider: str | None = None) -> str:
|
||||
"""Return the model string to pass to ``resolve_model_provider()``.
|
||||
|
||||
|
||||
+24
-1
@@ -6639,7 +6639,10 @@ def _handle_chat_sync(handler, body):
|
||||
from run_agent import AIAgent
|
||||
|
||||
with CHAT_LOCK:
|
||||
from api.config import resolve_model_provider
|
||||
from api.config import (
|
||||
resolve_model_provider,
|
||||
resolve_custom_provider_connection,
|
||||
)
|
||||
|
||||
_model, _provider, _base_url = resolve_model_provider(
|
||||
model_with_provider_context(s.model, getattr(s, "model_provider", None))
|
||||
@@ -6665,6 +6668,12 @@ def _handle_chat_sync(handler, body):
|
||||
f"[webui] WARNING: resolve_runtime_provider failed: {_e}",
|
||||
flush=True,
|
||||
)
|
||||
if isinstance(_provider, str) and _provider.startswith("custom:"):
|
||||
_cp_key, _cp_base = resolve_custom_provider_connection(_provider)
|
||||
if not _api_key and _cp_key:
|
||||
_api_key = _cp_key
|
||||
if not _base_url and _cp_base:
|
||||
_base_url = _cp_base
|
||||
agent = AIAgent(
|
||||
model=_model,
|
||||
provider=_provider,
|
||||
@@ -7427,6 +7436,13 @@ def _handle_session_compress(handler, body):
|
||||
except Exception as _e:
|
||||
logger.warning("resolve_runtime_provider failed for compression: %s", _e)
|
||||
|
||||
if isinstance(resolved_provider, str) and resolved_provider.startswith("custom:"):
|
||||
_cp_key, _cp_base = _cfg.resolve_custom_provider_connection(resolved_provider)
|
||||
if not resolved_api_key and _cp_key:
|
||||
resolved_api_key = _cp_key
|
||||
if not resolved_base_url and _cp_base:
|
||||
resolved_base_url = _cp_base
|
||||
|
||||
if not resolved_api_key:
|
||||
return bad(handler, "No provider configured -- cannot compress.")
|
||||
|
||||
@@ -8041,6 +8057,13 @@ def _handle_handoff_summary(handler, body):
|
||||
except Exception as _e:
|
||||
logger.warning("resolve_runtime_provider failed for handoff summary: %s", _e)
|
||||
|
||||
if isinstance(resolved_provider, str) and resolved_provider.startswith("custom:"):
|
||||
_cp_key, _cp_base = _cfg.resolve_custom_provider_connection(resolved_provider)
|
||||
if not resolved_api_key and _cp_key:
|
||||
resolved_api_key = _cp_key
|
||||
if not resolved_base_url and _cp_base:
|
||||
resolved_base_url = _cp_base
|
||||
|
||||
if not resolved_api_key:
|
||||
summary_text = _fallback_handoff_summary(msgs)
|
||||
try:
|
||||
|
||||
@@ -26,6 +26,7 @@ from api.config import (
|
||||
_get_session_agent_lock, _set_thread_env, _clear_thread_env,
|
||||
SESSION_AGENT_LOCKS, SESSION_AGENT_LOCKS_LOCK,
|
||||
resolve_model_provider,
|
||||
resolve_custom_provider_connection,
|
||||
model_with_provider_context,
|
||||
)
|
||||
from api.helpers import redact_session_data, _redact_text
|
||||
@@ -2266,6 +2267,16 @@ def _run_agent_streaming(
|
||||
except Exception as _e:
|
||||
print(f"[webui] WARNING: resolve_runtime_provider failed: {_e}", flush=True)
|
||||
|
||||
# Named custom providers (custom:slug) may not be resolvable by
|
||||
# hermes_cli.runtime_provider directly. Fall back to config.yaml
|
||||
# custom_providers[] so WebUI can pass explicit creds/base_url.
|
||||
if isinstance(resolved_provider, str) and resolved_provider.startswith("custom:"):
|
||||
_cp_key, _cp_base = resolve_custom_provider_connection(resolved_provider)
|
||||
if not resolved_api_key and _cp_key:
|
||||
resolved_api_key = _cp_key
|
||||
if not resolved_base_url and _cp_base:
|
||||
resolved_base_url = _cp_base
|
||||
|
||||
# Read per-profile config at call time (not module-level snapshot)
|
||||
from api.config import get_config as _get_config
|
||||
_cfg = _get_config()
|
||||
@@ -2725,6 +2736,12 @@ def _run_agent_streaming(
|
||||
resolved_provider = _heal_rt.get('provider')
|
||||
if not resolved_base_url:
|
||||
resolved_base_url = _heal_rt.get('base_url')
|
||||
if isinstance(resolved_provider, str) and resolved_provider.startswith('custom:'):
|
||||
_cp_key, _cp_base = resolve_custom_provider_connection(resolved_provider)
|
||||
if not resolved_api_key and _cp_key:
|
||||
resolved_api_key = _cp_key
|
||||
if not resolved_base_url and _cp_base:
|
||||
resolved_base_url = _cp_base
|
||||
# Rebuild agent kwargs and create a fresh agent
|
||||
_agent_kwargs['api_key'] = resolved_api_key
|
||||
_agent_kwargs['base_url'] = resolved_base_url
|
||||
@@ -3284,6 +3301,12 @@ def _run_agent_streaming(
|
||||
resolved_provider = _heal_rt.get('provider')
|
||||
if not resolved_base_url:
|
||||
resolved_base_url = _heal_rt.get('base_url')
|
||||
if isinstance(resolved_provider, str) and resolved_provider.startswith('custom:'):
|
||||
_cp_key, _cp_base = resolve_custom_provider_connection(resolved_provider)
|
||||
if not resolved_api_key and _cp_key:
|
||||
resolved_api_key = _cp_key
|
||||
if not resolved_base_url and _cp_base:
|
||||
resolved_base_url = _cp_base
|
||||
# Build a fresh agent with the new credentials
|
||||
_heal_kwargs = dict(_agent_kwargs) if '_agent_kwargs' in dir() else {}
|
||||
_heal_kwargs['api_key'] = resolved_api_key
|
||||
|
||||
Reference in New Issue
Block a user