diff --git a/api/routes.py b/api/routes.py index 6a538b20..3b548304 100644 --- a/api/routes.py +++ b/api/routes.py @@ -335,17 +335,20 @@ def _clear_stale_stream_state(session) -> bool: stream_alive = stream_id in STREAMS if stream_alive: return False - session.active_stream_id = None - if hasattr(session, "pending_user_message"): - session.pending_user_message = None - if hasattr(session, "pending_attachments"): - session.pending_attachments = [] - if hasattr(session, "pending_started_at"): - session.pending_started_at = None - try: - session.save() - except Exception: - pass + with _get_session_agent_lock(session.session_id): + if getattr(session, "active_stream_id", None) != stream_id: + return False + session.active_stream_id = None + if hasattr(session, "pending_user_message"): + session.pending_user_message = None + if hasattr(session, "pending_attachments"): + session.pending_attachments = [] + if hasattr(session, "pending_started_at"): + session.pending_started_at = None + try: + session.save() + except Exception: + pass return True # ── CSRF: validate Origin/Referer on POST ──────────────────────────────────── diff --git a/tests/test_stale_stream_cleanup.py b/tests/test_stale_stream_cleanup.py index fe117d01..5f294789 100644 --- a/tests/test_stale_stream_cleanup.py +++ b/tests/test_stale_stream_cleanup.py @@ -1,11 +1,48 @@ +import queue +import threading from pathlib import Path +import api.config as config +import api.routes as routes + REPO = Path(__file__).resolve().parents[1] ROUTES_SRC = (REPO / "api" / "routes.py").read_text(encoding="utf-8") SESSIONS_SRC = (REPO / "static" / "sessions.js").read_text(encoding="utf-8") SW_SRC = (REPO / "static" / "sw.js").read_text(encoding="utf-8") +class _GateLock: + def __init__(self): + self._lock = threading.Lock() + self.lookup_finished = threading.Event() + self.writer_finished = threading.Event() + + def __enter__(self): + self._lock.acquire() + return self + + def __exit__(self, exc_type, exc, tb): + self._lock.release() + if not self.lookup_finished.is_set(): + self.lookup_finished.set() + assert self.writer_finished.wait(2), "writer did not finish race setup" + return False + + +class _FakeSession: + session_id = "issue1533-session" + + def __init__(self): + self.active_stream_id = "stale-stream" + self.pending_user_message = "old prompt" + self.pending_attachments = ["old.txt"] + self.pending_started_at = 123 + self.saved_stream_ids = [] + + def save(self): + self.saved_stream_ids.append(self.active_stream_id) + + def test_stale_stream_cleanup_helper_exists(): assert "def _clear_stale_stream_state(session)" in ROUTES_SRC assert "stream_id in STREAMS" in ROUTES_SRC @@ -30,6 +67,53 @@ def test_chat_start_clears_stale_pending_state_not_only_active_id(): assert stale_comment_pos < cleanup_pos < stream_id_pos +def test_stale_stream_cleanup_does_not_clobber_concurrent_chat_start(monkeypatch): + """Regression for #1533: stale cleanup must not erase a new stream id. + + The gate lock pauses the cleanup thread after it has decided that the old + stream id is stale, then lets a chat_start-like writer register and persist + a new active_stream_id for the same session. + """ + config.STREAMS.clear() + config.SESSION_AGENT_LOCKS.clear() + gate_lock = _GateLock() + session = _FakeSession() + new_stream_id = "new-stream" + result = {} + + monkeypatch.setattr(routes, "STREAMS_LOCK", gate_lock) + + def cleanup_stale_stream(): + result["cleared"] = routes._clear_stale_stream_state(session) + + def start_new_stream(): + assert gate_lock.lookup_finished.wait(2), "cleanup did not reach race point" + with routes.STREAMS_LOCK: + routes.STREAMS[new_stream_id] = queue.Queue() + with routes._get_session_agent_lock(session.session_id): + session.active_stream_id = new_stream_id + session.pending_user_message = "new prompt" + session.pending_attachments = ["new.txt"] + session.pending_started_at = 456 + session.save() + gate_lock.writer_finished.set() + + cleanup_thread = threading.Thread(target=cleanup_stale_stream) + writer_thread = threading.Thread(target=start_new_stream) + cleanup_thread.start() + writer_thread.start() + cleanup_thread.join(2) + writer_thread.join(2) + + assert not cleanup_thread.is_alive() + assert not writer_thread.is_alive() + assert result["cleared"] is False + assert session.active_stream_id == new_stream_id + assert session.pending_user_message == "new prompt" + assert session.pending_attachments == ["new.txt"] + assert session.pending_started_at == 456 + + def test_frontend_drops_inflight_cache_when_server_session_is_idle(): marker = "If the server says the session is idle, discard any browser-side inflight" marker_pos = SESSIONS_SRC.index(marker)