diff --git a/api/streaming.py b/api/streaming.py index 8d5865d8..2e54193a 100644 --- a/api/streaming.py +++ b/api/streaming.py @@ -35,6 +35,7 @@ from api.config import ( from api.helpers import redact_session_data, _redact_text from api.compression_anchor import visible_messages_for_anchor from api.metering import meter +from api.turn_journal import append_turn_journal_event_for_stream # Global lock for os.environ writes. Per-session locks (_agent_lock) prevent # concurrent runs of the SAME session, but two DIFFERENT sessions can still @@ -2016,6 +2017,15 @@ def _run_agent_streaming( provider=model_provider, ephemeral=bool(ephemeral), ) + if not ephemeral: + try: + append_turn_journal_event_for_stream( + session_id, + stream_id, + {"event": "worker_started", "created_at": time.time()}, + ) + except Exception: + logger.debug("Failed to append worker_started turn journal event", exc_info=True) s = None _rt = {} old_cwd = None @@ -3512,7 +3522,44 @@ def _run_agent_streaming( # Older hermes-agent builds may not expose this helper. # Better to leave context_length=0 than crash the save. pass + if not ephemeral and s.messages: + _latest_assistant_idx = next( + (idx for idx in range(len(s.messages) - 1, -1, -1) + if isinstance(s.messages[idx], dict) and s.messages[idx].get('role') == 'assistant'), + None, + ) + if _latest_assistant_idx is not None: + _latest_assistant = s.messages[_latest_assistant_idx] + try: + append_turn_journal_event_for_stream( + s.session_id, + stream_id, + { + "event": "assistant_started", + "created_at": float(_latest_assistant.get('timestamp') or time.time()), + "assistant_message_index": _latest_assistant_idx, + }, + ) + except Exception: + logger.debug("Failed to append assistant_started turn journal event", exc_info=True) s.save() + if not ephemeral: + try: + append_turn_journal_event_for_stream( + s.session_id, + stream_id, + { + "event": "completed", + "created_at": time.time(), + "assistant_message_index": next( + (idx for idx in range(len(s.messages) - 1, -1, -1) + if isinstance(s.messages[idx], dict) and s.messages[idx].get('role') == 'assistant'), + None, + ), + }, + ) + except Exception: + logger.debug("Failed to append completed turn journal event", exc_info=True) # Sync to state.db for /insights (opt-in setting) try: from api.config import load_settings as _load_settings @@ -3882,6 +3929,19 @@ def _run_agent_streaming( s.save() except Exception: pass + if not ephemeral: + try: + append_turn_journal_event_for_stream( + s.session_id, + stream_id, + { + "event": "interrupted", + "created_at": time.time(), + "reason": _exc_type, + }, + ) + except Exception: + logger.debug("Failed to append interrupted turn journal event", exc_info=True) put('apperror', _error_payload) finally: # Stop the periodic checkpoint thread before the final recovery path. diff --git a/api/turn_journal.py b/api/turn_journal.py index 3f116d23..f25268a4 100644 --- a/api/turn_journal.py +++ b/api/turn_journal.py @@ -117,6 +117,40 @@ def derive_turn_journal_states(events: Iterable[dict]) -> dict[str, dict]: return states +def _latest_turn_id_for_stream(events: Iterable[dict], stream_id: str) -> str | None: + stream = str(stream_id or "").strip() + if not stream: + return None + latest: str | None = None + for event in events: + if not isinstance(event, dict): + continue + if str(event.get("stream_id") or "") != stream: + continue + turn_id = str(event.get("turn_id") or "").strip() + if turn_id: + latest = turn_id + return latest + + +def append_turn_journal_event_for_stream( + session_id: str, + stream_id: str, + event: dict, + *, + session_dir: Path | None = None, +) -> dict: + """Append a lifecycle event for the turn associated with ``stream_id``.""" + payload = dict(event) + payload["stream_id"] = str(stream_id) + if not payload.get("turn_id"): + journal = read_turn_journal(session_id, session_dir=session_dir) + turn_id = _latest_turn_id_for_stream(journal.get("events") or [], stream_id) + if turn_id: + payload["turn_id"] = turn_id + return append_turn_journal_event(session_id, payload, session_dir=session_dir) + + def iter_turn_journal_session_ids(session_dir: Path) -> list[str]: journal_dir = Path(session_dir) / TURN_JOURNAL_DIR_NAME if not journal_dir.exists(): diff --git a/tests/test_pr1341_context_window_persistence.py b/tests/test_pr1341_context_window_persistence.py index 70d59950..2311250c 100644 --- a/tests/test_pr1341_context_window_persistence.py +++ b/tests/test_pr1341_context_window_persistence.py @@ -38,12 +38,11 @@ def test_streaming_persists_context_fields_on_session_before_save(): # Save call follows shortly after save_call = src.find("\n s.save()", block_start) assert save_call != -1, "s.save() not found after the post-merge marker" - # Limit bumped to 7000 in #1896 fix — the context_length fallback grew to - # accept config_context_length / provider / custom_providers kwargs and a - # legacy 2-arg fallback for older hermes-agent builds. The block is still - # focused: it's a single fallback resolver call with arg-prep scaffold and - # commentary explaining the failure mode it prevents. - assert save_call - block_start < 7000, ( + # Limit bumped to 8200 by turn-journal lifecycle events: the block now also + # records `assistant_started` immediately before the durable final save. + # The context_length fallback is still a single focused resolver call with + # arg-prep scaffold and commentary explaining the failure mode it prevents. + assert save_call - block_start < 8200, ( "s.save() should be close to the post-merge marker — block expanded unexpectedly. " "If you've added a new pre-save mutation block here, bump this limit." ) diff --git a/tests/test_turn_journal_lifecycle.py b/tests/test_turn_journal_lifecycle.py new file mode 100644 index 00000000..ea7ae704 --- /dev/null +++ b/tests/test_turn_journal_lifecycle.py @@ -0,0 +1,38 @@ +from api.turn_journal import ( + append_turn_journal_event, + append_turn_journal_event_for_stream, + derive_turn_journal_states, +) + + +def test_append_turn_journal_event_for_stream_reuses_submitted_turn_id(tmp_path): + submitted = append_turn_journal_event( + "sid-1", + {"event": "submitted", "turn_id": "turn-1", "stream_id": "stream-1", "content": "hello"}, + session_dir=tmp_path, + ) + + worker = append_turn_journal_event_for_stream( + "sid-1", + "stream-1", + {"event": "worker_started"}, + session_dir=tmp_path, + ) + + assert submitted["turn_id"] == "turn-1" + assert worker["turn_id"] == "turn-1" + states = derive_turn_journal_states([submitted, worker]) + assert states["turn-1"]["event"] == "worker_started" + + +def test_append_turn_journal_event_for_stream_falls_back_to_new_turn_for_missing_stream(tmp_path): + event = append_turn_journal_event_for_stream( + "sid-1", + "stream-missing", + {"event": "interrupted", "reason": "no submitted event found"}, + session_dir=tmp_path, + ) + + assert event["stream_id"] == "stream-missing" + assert event["turn_id"] + assert event["event"] == "interrupted" diff --git a/tests/test_turn_journal_lifecycle_callsite.py b/tests/test_turn_journal_lifecycle_callsite.py new file mode 100644 index 00000000..29e51eb3 --- /dev/null +++ b/tests/test_turn_journal_lifecycle_callsite.py @@ -0,0 +1,47 @@ +from pathlib import Path + + +def test_streaming_appends_worker_started_before_running_phase(): + src = Path("api/streaming.py").read_text(encoding="utf-8") + run_idx = src.index("def _run_agent_streaming(") + worker_idx = src.index('"event": "worker_started"', run_idx) + running_idx = src.index('update_active_run(stream_id, phase="running"', run_idx) + + assert worker_idx < running_idx + + +def test_streaming_appends_assistant_started_before_final_save(): + src = Path("api/streaming.py").read_text(encoding="utf-8") + block_idx = src.index("if not ephemeral and s.messages:") + assistant_idx = src.index('"event": "assistant_started"', block_idx) + save_idx = src.index("s.save()", assistant_idx) + + assert block_idx < assistant_idx < save_idx + + +def test_streaming_assistant_started_uses_latest_assistant_message(): + src = Path("api/streaming.py").read_text(encoding="utf-8") + block_idx = src.index("if not ephemeral and s.messages:") + assistant_idx = src.index('"event": "assistant_started"', block_idx) + block = src[block_idx:assistant_idx] + + assert "range(len(s.messages) - 1, -1, -1)" in block + assert '"assistant_message_index": _latest_assistant_idx' in src[assistant_idx:src.index("s.save()", assistant_idx)] + + +def test_streaming_appends_completed_after_final_save(): + src = Path("api/streaming.py").read_text(encoding="utf-8") + assistant_idx = src.index('"event": "assistant_started"') + save_idx = src.index("s.save()", assistant_idx) + completed_idx = src.index('"event": "completed"', save_idx) + + assert save_idx < completed_idx + + +def test_streaming_appends_interrupted_on_provider_error_path(): + src = Path("api/streaming.py").read_text(encoding="utf-8") + err_idx = src.index("err_str = str(e)") + interrupted_idx = src.index('"event": "interrupted"', err_idx) + apperror_idx = src.index("put('apperror'", interrupted_idx) + + assert err_idx < interrupted_idx < apperror_idx