"""Crash-safe WebUI turn journal helpers. The journal is deliberately tiny: one JSONL file per session, append-only events, and read helpers that tolerate malformed lines. Recovery and repair can then reason about submitted turns without depending on in-memory stream state. """ from __future__ import annotations import json import os import re import time import uuid from contextlib import contextmanager from pathlib import Path from typing import Iterable try: # pragma: no cover - fcntl is unavailable on Windows. import fcntl as _fcntl except ImportError: # pragma: no cover _fcntl = None TURN_JOURNAL_DIR_NAME = "_turn_journal" _TERMINAL_EVENTS = {"completed", "interrupted"} _SESSION_ID_RE = re.compile(r"^[A-Za-z0-9_.-]+$") def _default_session_dir() -> Path: from api.models import SESSION_DIR return Path(SESSION_DIR) def _journal_path(session_id: str, session_dir: Path | None = None) -> Path: sid = str(session_id or "").strip() if not sid or "/" in sid or "\\" in sid or not _SESSION_ID_RE.fullmatch(sid): raise ValueError("invalid session_id") root = Path(session_dir) if session_dir is not None else _default_session_dir() return root / TURN_JOURNAL_DIR_NAME / f"{sid}.jsonl" def _make_turn_id() -> str: return f"{time.strftime('%Y%m%dT%H%M%SZ', time.gmtime())}-{uuid.uuid4().hex[:12]}" @contextmanager def _journal_file_lock(file_obj): """Serialize multi-process journal writes when advisory locks exist. ``O_APPEND`` keeps normal same-process appends simple, but a long JSONL event can exceed POSIX's small atomic-write boundary. On Unix, take an advisory lock around the single event write+fsync so two WebUI worker processes cannot interleave large submitted-message payloads into corrupted JSONL. Platforms without ``fcntl`` keep the previous best-effort append behavior. """ if _fcntl is None: yield return _fcntl.flock(file_obj.fileno(), _fcntl.LOCK_EX) try: yield finally: _fcntl.flock(file_obj.fileno(), _fcntl.LOCK_UN) def append_turn_journal_event( session_id: str, event: dict, *, session_dir: Path | None = None, ) -> dict: """Append one turn journal event and fsync it before returning. The returned event is the exact payload written, with default ``version``, ``session_id``, ``turn_id``, and ``created_at`` fields filled in. """ if not isinstance(event, dict): raise TypeError("event must be a dict") event_name = str(event.get("event") or "").strip() if not event_name: raise ValueError("event is required") payload = dict(event) payload.setdefault("version", 1) payload["session_id"] = str(session_id) payload.setdefault("turn_id", _make_turn_id()) payload.setdefault("created_at", time.time()) path = _journal_path(session_id, session_dir=session_dir) path.parent.mkdir(parents=True, exist_ok=True) line = json.dumps(payload, ensure_ascii=False, separators=(",", ":")) + "\n" fd = os.open(path, os.O_CREAT | os.O_APPEND | os.O_WRONLY, 0o600) with os.fdopen(fd, "a", encoding="utf-8") as fh: with _journal_file_lock(fh): fh.write(line) fh.flush() os.fsync(fh.fileno()) try: dir_fd = os.open(path.parent, os.O_DIRECTORY) try: os.fsync(dir_fd) finally: os.close(dir_fd) except OSError: pass return payload def read_turn_journal(session_id: str, *, session_dir: Path | None = None) -> dict: """Read a session journal, returning valid events plus malformed lines.""" path = _journal_path(session_id, session_dir=session_dir) events: list[dict] = [] malformed: list[dict] = [] try: lines = path.read_text(encoding="utf-8").splitlines() except FileNotFoundError: return {"session_id": str(session_id), "events": [], "malformed": []} for line_no, raw in enumerate(lines, start=1): if not raw.strip(): continue try: event = json.loads(raw) except json.JSONDecodeError: malformed.append({"line": line_no, "raw": raw}) continue if isinstance(event, dict): events.append(event) else: malformed.append({"line": line_no, "raw": raw}) return {"session_id": str(session_id), "events": events, "malformed": malformed} def derive_turn_journal_states(events: Iterable[dict]) -> tuple[dict[str, dict], list[dict]]: '''Return the latest event per ``turn_id`` and any terminal-collision entries. The first element is the latest event per turn_id (same overwrite-by-timestamp behaviour as before). The second element is a list of collision records, one per turn_id that had more than one terminal event. Each collision record contains ``turn_id`` and the ``events`` list (in ascending created_at order). A collision means the same logical turn recorded both ``completed`` and ``interrupted`` terminal events -- the derived state still picks the latest by timestamp, but callers can now detect and audit the double-terminal situation explicitly rather than having it silently collapse. ''' states: dict[str, dict] = {} # Collect all terminal events per turn_id to detect collisions terminal_events: dict[str, list[dict]] = {} for event in events: if not isinstance(event, dict): continue turn_id = str(event.get('turn_id') or '').strip() if not turn_id: continue # Track terminal events for collision detection if is_terminal_turn_event(event): terminal_events.setdefault(turn_id, []).append(event) # Existing latest-by-timestamp derivation previous = states.get(turn_id) if previous is None or float(event.get('created_at') or 0) >= float(previous.get('created_at') or 0): states[turn_id] = event # Build collision list: turn_ids with more than one terminal event collisions = [ {'turn_id': tid, 'events': sorted(evts, key=lambda e: float(e.get('created_at') or 0))} for tid, evts in terminal_events.items() if len(evts) > 1 ] return states, collisions def _latest_turn_id_for_stream(events: Iterable[dict], stream_id: str) -> str | None: stream = str(stream_id or "").strip() if not stream: return None latest: str | None = None for event in events: if not isinstance(event, dict): continue if str(event.get("stream_id") or "") != stream: continue turn_id = str(event.get("turn_id") or "").strip() if turn_id: latest = turn_id return latest def append_turn_journal_event_for_stream( session_id: str, stream_id: str, event: dict, *, session_dir: Path | None = None, ) -> dict: """Append a lifecycle event for the turn associated with ``stream_id``.""" payload = dict(event) payload["stream_id"] = str(stream_id) if not payload.get("turn_id"): journal = read_turn_journal(session_id, session_dir=session_dir) turn_id = _latest_turn_id_for_stream(journal.get("events") or [], stream_id) if turn_id: payload["turn_id"] = turn_id return append_turn_journal_event(session_id, payload, session_dir=session_dir) def iter_turn_journal_session_ids(session_dir: Path) -> list[str]: journal_dir = Path(session_dir) / TURN_JOURNAL_DIR_NAME if not journal_dir.exists(): return [] return sorted(path.stem for path in journal_dir.glob("*.jsonl") if path.is_file()) def is_terminal_turn_event(event: dict) -> bool: return str((event or {}).get("event") or "") in _TERMINAL_EVENTS