mirror of
https://github.com/nesquena/hermes-webui.git
synced 2026-05-26 19:50:15 +00:00
fix(profile): preserve context when starting chats
This commit is contained in:
+40
-4
@@ -184,6 +184,28 @@ else:
|
||||
_cfg_cache = {}
|
||||
_cfg_lock = threading.Lock()
|
||||
_cfg_mtime: float = 0.0 # last known mtime of config.yaml; 0 = never loaded
|
||||
_cfg_path: Path | None = None # active config.yaml path for the disk-loaded cache
|
||||
_cfg_fingerprint: str | None = None # serialized snapshot from the last disk load
|
||||
|
||||
|
||||
def _fingerprint_config(data: dict) -> str:
|
||||
"""Return a stable fingerprint for config dictionaries.
|
||||
|
||||
A few tests and legacy call sites still mutate ``cfg`` directly for
|
||||
in-memory overrides. Path-aware reloads should not immediately discard
|
||||
those overrides just because the active profile path differs from the last
|
||||
disk load, but an unchanged disk-loaded cache must still reload on profile
|
||||
switches.
|
||||
"""
|
||||
try:
|
||||
return json.dumps(data, sort_keys=True, separators=(",", ":"), default=str)
|
||||
except Exception:
|
||||
return repr(data)
|
||||
|
||||
|
||||
def _cfg_has_in_memory_overrides() -> bool:
|
||||
"""True when cfg was changed after the last successful reload_config()."""
|
||||
return _cfg_fingerprint is not None and _fingerprint_config(_cfg_cache) != _cfg_fingerprint
|
||||
|
||||
|
||||
def _get_config_path() -> Path:
|
||||
@@ -205,7 +227,13 @@ _DEFAULT_WEBUI_SESSION_SAVE_MODE = "deferred"
|
||||
|
||||
def get_config() -> dict:
|
||||
"""Return the cached config dict, loading from disk if needed."""
|
||||
if not _cfg_cache:
|
||||
config_path = _get_config_path()
|
||||
try:
|
||||
current_mtime = config_path.stat().st_mtime
|
||||
except OSError:
|
||||
current_mtime = 0.0
|
||||
cache_stale = current_mtime != _cfg_mtime or _cfg_path != config_path
|
||||
if not _cfg_cache or (cache_stale and not _cfg_has_in_memory_overrides()):
|
||||
reload_config()
|
||||
return _cfg_cache
|
||||
|
||||
@@ -234,13 +262,15 @@ def get_webui_session_save_mode(config_data: dict | None = None) -> str:
|
||||
|
||||
def reload_config() -> None:
|
||||
"""Reload config.yaml from the active profile's directory."""
|
||||
global _cfg_mtime
|
||||
global _cfg_mtime, _cfg_path, _cfg_fingerprint
|
||||
with _cfg_lock:
|
||||
_cfg_cache.clear()
|
||||
config_path = _get_config_path()
|
||||
# Remember the old mtime so we can tell whether config actually changed
|
||||
# vs. first-ever load (mtime == 0.0, e.g. server start or profile switch).
|
||||
_old_cfg_mtime = _cfg_mtime
|
||||
_cfg_path = config_path
|
||||
_cfg_mtime = 0.0
|
||||
try:
|
||||
import yaml as _yaml
|
||||
|
||||
@@ -254,6 +284,7 @@ def reload_config() -> None:
|
||||
_cfg_mtime = 0.0
|
||||
except Exception:
|
||||
logger.debug("Failed to load yaml config from %s", config_path)
|
||||
_cfg_fingerprint = _fingerprint_config(_cfg_cache)
|
||||
# Bust the models cache so the next request sees fresh config values.
|
||||
# Only delete the disk cache when config has actually changed -- not on
|
||||
# first-ever load (when _old_cfg_mtime == 0.0, i.e. server start or
|
||||
@@ -2083,10 +2114,15 @@ def get_available_models() -> dict:
|
||||
# Config mtime check — must come before any config reads.
|
||||
# (Test #585 verifies _current_mtime appears before active_provider = None)
|
||||
try:
|
||||
_current_mtime = Path(_get_config_path()).stat().st_mtime
|
||||
_current_path = _get_config_path()
|
||||
_current_mtime = _current_path.stat().st_mtime
|
||||
except OSError:
|
||||
_current_path = _get_config_path()
|
||||
_current_mtime = 0.0
|
||||
if _current_mtime != _cfg_mtime:
|
||||
if (
|
||||
(_current_mtime != _cfg_mtime or _current_path != _cfg_path)
|
||||
and not _cfg_has_in_memory_overrides()
|
||||
):
|
||||
reload_config()
|
||||
# ── COLD PATH helper ─────────────────────────────────────────────────────
|
||||
# Extracted so it runs inside _available_models_cache_lock (RLock) to
|
||||
|
||||
@@ -5940,6 +5940,27 @@ def _handle_chat_start(handler, body):
|
||||
s = get_session(body["session_id"])
|
||||
except KeyError:
|
||||
return bad(handler, "Session not found", 404)
|
||||
requested_profile = str(body.get("profile") or "").strip()
|
||||
if requested_profile:
|
||||
try:
|
||||
from api.profiles import _PROFILE_ID_RE
|
||||
|
||||
if requested_profile != "default" and not _PROFILE_ID_RE.fullmatch(requested_profile):
|
||||
return bad(handler, "invalid profile", 400)
|
||||
except ImportError:
|
||||
requested_profile = ""
|
||||
if requested_profile and not _profiles_match(getattr(s, "profile", None), requested_profile):
|
||||
has_persisted_turns = bool(
|
||||
getattr(s, "messages", None)
|
||||
or getattr(s, "context_messages", None)
|
||||
or getattr(s, "pending_user_message", None)
|
||||
)
|
||||
if not has_persisted_turns:
|
||||
# Empty sessions are placeholders. If the user switches profiles
|
||||
# before sending the first turn, run the placeholder under the
|
||||
# currently-selected profile instead of the stale one stamped at
|
||||
# creation time.
|
||||
s.profile = requested_profile
|
||||
msg = str(body.get("message", "")).strip()
|
||||
if not msg:
|
||||
return bad(handler, "message is required")
|
||||
|
||||
@@ -225,6 +225,7 @@ async function send(){
|
||||
session_id:activeSid,message:msgText,
|
||||
model:S.session.model||$('modelSelect').value,workspace:S.session.workspace,
|
||||
model_provider:S.session.model_provider||null,
|
||||
profile:S.activeProfile||S.session.profile||'default',
|
||||
attachments:uploaded.length?uploaded:undefined
|
||||
})});
|
||||
if(startData.effective_model && S.session){
|
||||
|
||||
@@ -389,3 +389,226 @@ def test_regression_switch_profile_returns_target_model():
|
||||
profiles._DEFAULT_HERMES_HOME = orig
|
||||
profiles._active_profile = orig_act
|
||||
profiles._tls.profile = None
|
||||
|
||||
|
||||
def test_get_config_reloads_when_request_profile_changes(tmp_path, monkeypatch):
|
||||
"""get_config() must follow the per-request profile, not stale global cache."""
|
||||
monkeypatch.delenv("HERMES_CONFIG_PATH", raising=False)
|
||||
import api.config as config
|
||||
import api.profiles as profiles
|
||||
|
||||
default_home = tmp_path / ".hermes"
|
||||
work_home = default_home / "profiles" / "work"
|
||||
work_home.mkdir(parents=True)
|
||||
default_home.mkdir(exist_ok=True)
|
||||
(default_home / "config.yaml").write_text(
|
||||
"model:\n provider: openai-codex\n default: gpt-5.5\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
(work_home / "config.yaml").write_text(
|
||||
"model:\n provider: openrouter\n default: google/gemini-3-flash-preview\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
same_mtime = 1_700_000_000
|
||||
os.utime(default_home / "config.yaml", (same_mtime, same_mtime))
|
||||
os.utime(work_home / "config.yaml", (same_mtime, same_mtime))
|
||||
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"_get_config_path",
|
||||
lambda: profiles.get_active_hermes_home() / "config.yaml",
|
||||
)
|
||||
|
||||
orig_default_home = profiles._DEFAULT_HERMES_HOME
|
||||
orig_active = profiles._active_profile
|
||||
orig_cache = dict(config._cfg_cache)
|
||||
orig_mtime = config._cfg_mtime
|
||||
orig_path = getattr(config, "_cfg_path", None)
|
||||
orig_fingerprint = getattr(config, "_cfg_fingerprint", None)
|
||||
profiles._tls.profile = None
|
||||
try:
|
||||
profiles._DEFAULT_HERMES_HOME = default_home
|
||||
profiles._active_profile = "default"
|
||||
config._cfg_cache.clear()
|
||||
config._cfg_mtime = 0.0
|
||||
if hasattr(config, "_cfg_path"):
|
||||
config._cfg_path = None
|
||||
if hasattr(config, "_cfg_fingerprint"):
|
||||
config._cfg_fingerprint = None
|
||||
|
||||
assert config.get_config()["model"]["provider"] == "openai-codex"
|
||||
profiles.set_request_profile("work")
|
||||
assert config._get_config_path() == work_home / "config.yaml"
|
||||
assert config.get_config()["model"]["provider"] == "openrouter"
|
||||
finally:
|
||||
profiles.clear_request_profile()
|
||||
profiles._DEFAULT_HERMES_HOME = orig_default_home
|
||||
profiles._active_profile = orig_active
|
||||
config._cfg_cache.clear()
|
||||
config._cfg_cache.update(orig_cache)
|
||||
config._cfg_mtime = orig_mtime
|
||||
if hasattr(config, "_cfg_path"):
|
||||
config._cfg_path = orig_path
|
||||
if hasattr(config, "_cfg_fingerprint"):
|
||||
config._cfg_fingerprint = orig_fingerprint
|
||||
|
||||
|
||||
def test_chat_start_retags_empty_session_to_request_profile(monkeypatch, tmp_path):
|
||||
"""An empty session created under profile A can be sent under profile B after a switch."""
|
||||
import api.routes as routes
|
||||
|
||||
class FakeSession:
|
||||
def __init__(self):
|
||||
self.session_id = "sid-profile-switch"
|
||||
self.profile = "default"
|
||||
self.workspace = str(tmp_path)
|
||||
self.model = "google/gemini-3-flash-preview"
|
||||
self.model_provider = "openrouter"
|
||||
self.messages = []
|
||||
self.context_messages = []
|
||||
self.tool_calls = []
|
||||
self.active_stream_id = None
|
||||
self.pending_user_message = None
|
||||
self.pending_attachments = []
|
||||
self.pending_started_at = None
|
||||
self.saved = False
|
||||
|
||||
def save(self):
|
||||
self.saved = True
|
||||
|
||||
fake = FakeSession()
|
||||
monkeypatch.setattr(routes, "get_session", lambda sid: fake)
|
||||
monkeypatch.setattr(routes, "resolve_trusted_workspace", lambda path: tmp_path)
|
||||
monkeypatch.setattr(
|
||||
routes,
|
||||
"_resolve_compatible_session_model_state",
|
||||
lambda model, provider: (model, provider, False),
|
||||
)
|
||||
monkeypatch.setattr(routes, "set_last_workspace", lambda workspace: None)
|
||||
monkeypatch.setattr(routes, "create_stream_channel", lambda: object())
|
||||
|
||||
started_threads = []
|
||||
|
||||
class FakeThread:
|
||||
def __init__(self, *args, **kwargs):
|
||||
started_threads.append((args, kwargs))
|
||||
|
||||
def start(self):
|
||||
pass
|
||||
|
||||
monkeypatch.setattr(routes.threading, "Thread", FakeThread)
|
||||
|
||||
payloads = []
|
||||
|
||||
class Handler:
|
||||
pass
|
||||
|
||||
def fake_j(handler, payload, status=200, **kwargs):
|
||||
payloads.append((status, payload))
|
||||
return payload
|
||||
|
||||
monkeypatch.setattr(routes, "j", fake_j)
|
||||
|
||||
body = {
|
||||
"session_id": fake.session_id,
|
||||
"message": "hello",
|
||||
"workspace": str(tmp_path),
|
||||
"model": fake.model,
|
||||
"model_provider": fake.model_provider,
|
||||
"profile": "work",
|
||||
}
|
||||
routes._handle_chat_start(Handler(), body)
|
||||
|
||||
assert fake.profile == "work"
|
||||
assert fake.saved is True
|
||||
assert started_threads, "chat_start should launch the stream after retagging"
|
||||
assert payloads and payloads[-1][0] == 200
|
||||
|
||||
|
||||
def test_chat_start_does_not_retag_non_empty_session(monkeypatch, tmp_path):
|
||||
"""Profile retagging is limited to empty placeholder sessions."""
|
||||
import api.routes as routes
|
||||
|
||||
class FakeSession:
|
||||
def __init__(self):
|
||||
self.session_id = "sid-profile-switch-non-empty"
|
||||
self.profile = "default"
|
||||
self.workspace = str(tmp_path)
|
||||
self.model = "google/gemini-3-flash-preview"
|
||||
self.model_provider = "openrouter"
|
||||
self.messages = [{"role": "user", "content": "previous turn"}]
|
||||
self.context_messages = []
|
||||
self.tool_calls = []
|
||||
self.active_stream_id = None
|
||||
self.pending_user_message = None
|
||||
self.pending_attachments = []
|
||||
self.pending_started_at = None
|
||||
self.saved = False
|
||||
|
||||
def save(self):
|
||||
self.saved = True
|
||||
|
||||
fake = FakeSession()
|
||||
monkeypatch.setattr(routes, "get_session", lambda sid: fake)
|
||||
monkeypatch.setattr(routes, "resolve_trusted_workspace", lambda path: tmp_path)
|
||||
monkeypatch.setattr(
|
||||
routes,
|
||||
"_resolve_compatible_session_model_state",
|
||||
lambda model, provider: (model, provider, False),
|
||||
)
|
||||
monkeypatch.setattr(routes, "set_last_workspace", lambda workspace: None)
|
||||
monkeypatch.setattr(routes, "create_stream_channel", lambda: object())
|
||||
|
||||
class FakeThread:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def start(self):
|
||||
pass
|
||||
|
||||
monkeypatch.setattr(routes.threading, "Thread", FakeThread)
|
||||
monkeypatch.setattr(routes, "j", lambda handler, payload, status=200, **kwargs: payload)
|
||||
|
||||
routes._handle_chat_start(
|
||||
object(),
|
||||
{
|
||||
"session_id": fake.session_id,
|
||||
"message": "hello",
|
||||
"workspace": str(tmp_path),
|
||||
"model": fake.model,
|
||||
"model_provider": fake.model_provider,
|
||||
"profile": "work",
|
||||
},
|
||||
)
|
||||
|
||||
assert fake.profile == "default"
|
||||
assert fake.saved is True
|
||||
|
||||
|
||||
def test_chat_start_rejects_invalid_request_profile(monkeypatch):
|
||||
"""chat_start validates the optional profile payload before retagging."""
|
||||
import api.routes as routes
|
||||
|
||||
class FakeSession:
|
||||
profile = "default"
|
||||
|
||||
monkeypatch.setattr(routes, "get_session", lambda sid: FakeSession())
|
||||
errors = []
|
||||
|
||||
def fake_bad(handler, message, status=400):
|
||||
errors.append((message, status))
|
||||
return {"error": message}
|
||||
|
||||
monkeypatch.setattr(routes, "bad", fake_bad)
|
||||
|
||||
result = routes._handle_chat_start(
|
||||
object(),
|
||||
{
|
||||
"session_id": "sid-invalid-profile",
|
||||
"message": "hello",
|
||||
"profile": "../etc",
|
||||
},
|
||||
)
|
||||
|
||||
assert result == {"error": "invalid profile"}
|
||||
assert errors == [("invalid profile", 400)]
|
||||
|
||||
Reference in New Issue
Block a user