diff --git a/README.md b/README.md index 0d24a7c..07903c5 100644 --- a/README.md +++ b/README.md @@ -88,6 +88,14 @@ Or add it directly to your Claude Desktop config (`claude_desktop_config.json`): This exposes 3 tools to your AI agent: `check_message_safety`, `get_session_risk`, and `list_recent_escalations`. +For HTTP MCP, the server binds to `127.0.0.1` by default. If you expose it +beyond localhost, set a bearer token first: + +```bash +export HUMANE_PROXY_ADMIN_KEY=your-secret-token +humane-proxy mcp-serve --transport http --host 0.0.0.0 --port 3000 +``` + --- ## Available On @@ -441,16 +449,27 @@ curl -X DELETE http://localhost:8000/admin/sessions/user-42 \ ```bash pip install humane-proxy[mcp] humane-proxy mcp-serve # stdio (default) -humane-proxy mcp-serve --transport http --port 3000 # HTTP +humane-proxy mcp-serve --transport http --port 3000 # HTTP on 127.0.0.1 ``` +HTTP MCP is local-only by default. To bind publicly, pass `--host 0.0.0.0` +explicitly and protect tool access with a bearer token: + +```bash +export HUMANE_PROXY_ADMIN_KEY=your-secret-token +humane-proxy mcp-serve --transport http --host 0.0.0.0 --port 3000 +``` + +Clients must send `Authorization: Bearer your-secret-token` when the token is +configured. Leave `HUMANE_PROXY_ADMIN_KEY` unset for stdio/local-only MCP. + Exposes three tools via Model Context Protocol: | Tool | Description | |---|---| | `check_message_safety` | Full pipeline classification | -| `get_session_risk` | Session trajectory (trend, spike, category counts) | -| `list_recent_escalations` | Audit log query | +| `get_session_risk` | Read-only session trajectory snapshot (trend, spike, category counts) | +| `list_recent_escalations` | Bounded audit log query | Available on the [Official MCP Registry](https://registry.modelcontextprotocol.io). diff --git a/humane_proxy/api/admin.py b/humane_proxy/api/admin.py index f83dea1..8ad5248 100644 --- a/humane_proxy/api/admin.py +++ b/humane_proxy/api/admin.py @@ -28,7 +28,7 @@ from datetime import datetime, timezone from typing import Any -from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi import APIRouter, Depends, HTTPException, Query, Response from fastapi.responses import StreamingResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer @@ -298,32 +298,21 @@ def get_session_risk( finally: conn.close() - from humane_proxy.risk.trajectory import analyze + from humane_proxy.risk.trajectory import snapshot - # Build trajectory by replaying each escalation. - trajectory = None - for row in rows: - rec = _row_to_dict(row) - trajectory = analyze( - session_id + "_admin_replay", # isolated session key - rec["risk_score"], - rec.get("category", "safe"), - ) + trajectory = snapshot(session_id) return { "session_id": session_id, "escalation_count": len(rows), "history": [_row_to_dict(r) for r in rows], - "trajectory": ( - { - "spike_detected": trajectory.spike_detected, - "trend": trajectory.trend, - "window_scores": trajectory.window_scores, - "category_counts": trajectory.category_counts, - } - if trajectory - else None - ), + "trajectory": { + "spike_detected": trajectory.spike_detected, + "trend": trajectory.trend, + "window_scores": trajectory.window_scores, + "category_counts": trajectory.category_counts, + "message_count": trajectory.message_count, + }, } @@ -381,11 +370,11 @@ def get_stats(_: str = Depends(_require_admin)) -> dict: } -@router.delete("/sessions/{session_id}", status_code=204) +@router.delete("/sessions/{session_id}", status_code=204, response_class=Response) def delete_session_data( session_id: str, _: str = Depends(_require_admin), -) -> None: +) -> Response: """Delete all escalation records for a session (privacy right to erasure).""" conn = _get_conn() try: @@ -397,3 +386,4 @@ def delete_session_data( conn.close() logger.info("Deleted %d records for session %s (admin request)", deleted, session_id) + return Response(status_code=204) diff --git a/humane_proxy/cli.py b/humane_proxy/cli.py index b1de617..c8282fa 100644 --- a/humane_proxy/cli.py +++ b/humane_proxy/cli.py @@ -555,13 +555,14 @@ async def _run_all(): @click.option("--transport", "-t", default="stdio", type=click.Choice(["stdio", "http"]), help="Transport mode: stdio (default) or http") -@click.option("--host", default="0.0.0.0", help="HTTP bind host (default: 0.0.0.0)") +@click.option("--host", default="127.0.0.1", help="HTTP bind host (default: 127.0.0.1)") @click.option("--port", "-p", default=3000, type=int, help="HTTP bind port (default: 3000)") def mcp_serve(transport: str, host: str, port: int) -> None: """Start the MCP server (requires [mcp] extra). Use --transport stdio (default) for local integration with agents. - Use --transport http for remote access and registry listing. + Use --transport http for HTTP access. Set HUMANE_PROXY_ADMIN_KEY + before exposing HTTP MCP beyond localhost. """ try: if transport == "http": diff --git a/humane_proxy/escalation/query.py b/humane_proxy/escalation/query.py new file mode 100644 index 0000000..19ac868 --- /dev/null +++ b/humane_proxy/escalation/query.py @@ -0,0 +1,29 @@ +"""Shared escalation query validation helpers.""" + +from __future__ import annotations + +ALLOWED_ESCALATION_CATEGORIES = frozenset({"self_harm", "criminal_intent"}) +DEFAULT_ESCALATION_LIMIT = 20 +MAX_ESCALATION_LIMIT = 100 + + +def normalize_escalation_query( + limit: int = DEFAULT_ESCALATION_LIMIT, + category: str | None = None, +) -> tuple[int, str | None]: + """Clamp escalation query size and validate optional category filters.""" + try: + normalized_limit = int(limit) + except (TypeError, ValueError): + normalized_limit = DEFAULT_ESCALATION_LIMIT + + normalized_limit = max(1, min(normalized_limit, MAX_ESCALATION_LIMIT)) + normalized_category = category.strip() if isinstance(category, str) else None + if normalized_category == "": + normalized_category = None + + if normalized_category and normalized_category not in ALLOWED_ESCALATION_CATEGORIES: + allowed = ", ".join(sorted(ALLOWED_ESCALATION_CATEGORIES)) + raise ValueError(f"category must be one of: {allowed}") + + return normalized_limit, normalized_category diff --git a/humane_proxy/integrations/autogen.py b/humane_proxy/integrations/autogen.py index f245869..6000d4c 100644 --- a/humane_proxy/integrations/autogen.py +++ b/humane_proxy/integrations/autogen.py @@ -59,16 +59,9 @@ def get_session_risk(session_id: str) -> str: Returns: JSON string with spike detection, trend, and category distribution. """ - from humane_proxy.risk.trajectory import analyze + from humane_proxy.risk.trajectory import snapshot, to_dict - result = analyze(session_id, 0.0, "safe") - return json.dumps({ - "spike_detected": result.spike_detected, - "trend": result.trend, - "window_scores": result.window_scores, - "category_counts": result.category_counts, - "message_count": result.message_count, - }, indent=2) + return json.dumps(to_dict(snapshot(session_id)), indent=2) def list_recent_escalations(limit: int = 20, category: str = "") -> str: @@ -81,11 +74,13 @@ def list_recent_escalations(limit: int = 20, category: str = "") -> str: Returns: JSON string with list of escalation records. """ + from humane_proxy.escalation.query import normalize_escalation_query from humane_proxy.storage.factory import get_store + limit, category = normalize_escalation_query(limit, category) store = get_store() results = store.query( - category=category if category else None, + category=category, limit=limit, ) return json.dumps(results, indent=2, default=str) diff --git a/humane_proxy/integrations/crewai.py b/humane_proxy/integrations/crewai.py index ba76632..cf92793 100644 --- a/humane_proxy/integrations/crewai.py +++ b/humane_proxy/integrations/crewai.py @@ -94,17 +94,10 @@ class GetSessionRiskTool(BaseTool): args_schema: Type[BaseModel] = SessionRiskInput def _run(self, session_id: str) -> str: - from humane_proxy.risk.trajectory import analyze + from humane_proxy.risk.trajectory import snapshot, to_dict import json - result = analyze(session_id, 0.0, "safe") - return json.dumps({ - "spike_detected": result.spike_detected, - "trend": result.trend, - "window_scores": result.window_scores, - "category_counts": result.category_counts, - "message_count": result.message_count, - }, indent=2) + return json.dumps(to_dict(snapshot(session_id)), indent=2) class ListEscalationsTool(BaseTool): name: str = "list_recent_escalations" @@ -112,12 +105,14 @@ class ListEscalationsTool(BaseTool): args_schema: Type[BaseModel] = ListEscalationsInput def _run(self, limit: int = 20, category: str = "") -> str: + from humane_proxy.escalation.query import normalize_escalation_query from humane_proxy.storage.factory import get_store import json + limit, category = normalize_escalation_query(limit, category) store = get_store() results = store.query( - category=category if category else None, + category=category, limit=limit, ) return json.dumps(results, indent=2, default=str) diff --git a/humane_proxy/integrations/llamaindex.py b/humane_proxy/integrations/llamaindex.py index 3658db0..c519dd5 100644 --- a/humane_proxy/integrations/llamaindex.py +++ b/humane_proxy/integrations/llamaindex.py @@ -56,16 +56,9 @@ def _get_session_risk(session_id: str) -> dict: dict ``{"spike_detected": bool, "trend": str, "window_scores": list, ...}`` """ - from humane_proxy.risk.trajectory import analyze + from humane_proxy.risk.trajectory import snapshot, to_dict - result = analyze(session_id, 0.0, "safe") - return { - "spike_detected": result.spike_detected, - "trend": result.trend, - "window_scores": result.window_scores, - "category_counts": result.category_counts, - "message_count": result.message_count, - } + return to_dict(snapshot(session_id)) def _list_recent_escalations(limit: int = 20, category: str | None = None) -> list[dict]: @@ -83,8 +76,10 @@ def _list_recent_escalations(limit: int = 20, category: str | None = None) -> li list[dict] List of escalation records. """ + from humane_proxy.escalation.query import normalize_escalation_query from humane_proxy.storage.factory import get_store + limit, category = normalize_escalation_query(limit, category) store = get_store() return store.query(category=category, limit=limit) diff --git a/humane_proxy/mcp_server.py b/humane_proxy/mcp_server.py index 04600f0..c6f8101 100644 --- a/humane_proxy/mcp_server.py +++ b/humane_proxy/mcp_server.py @@ -8,10 +8,52 @@ from __future__ import annotations +import ipaddress import logging +import os logger = logging.getLogger("humane_proxy.mcp") +MCP_TOKEN_ENV = "HUMANE_PROXY_ADMIN_KEY" +MCP_DEFAULT_HOST = "127.0.0.1" + + +def _is_public_bind_host(host: str) -> bool: + """Return whether an HTTP bind host may be reachable beyond localhost.""" + normalized = (host or "").strip() + if normalized.startswith("[") and normalized.endswith("]"): + normalized = normalized[1:-1] + + if not normalized: + return True + if normalized.lower() == "localhost": + return False + + try: + address = ipaddress.ip_address(normalized) + except ValueError: + return True + + return not address.is_loopback + + +def _get_mcp_auth_provider(): + """Return a FastMCP Bearer auth provider when HTTP MCP auth is configured.""" + token = os.environ.get(MCP_TOKEN_ENV, "").strip() + if not token: + return None + + try: + from fastmcp.server.auth import BearerTokenAuth # type: ignore[import] + except ImportError as exc: + raise RuntimeError( + f"{MCP_TOKEN_ENV} is set, but this FastMCP version does not expose " + "server Bearer token auth. Upgrade fastmcp to use HTTP MCP auth." + ) from exc + + return BearerTokenAuth(token=token) + + try: from fastmcp import FastMCP # type: ignore[import] _MCP_AVAILABLE = True @@ -24,8 +66,11 @@ # --------------------------------------------------------------------------- if _MCP_AVAILABLE: + auth_provider = _get_mcp_auth_provider() + mcp_kwargs = {"auth": auth_provider} if auth_provider is not None else {} mcp = FastMCP( - "humane-proxy" + "humane-proxy", + **mcp_kwargs, ) @mcp.tool() @@ -71,17 +116,9 @@ async def get_session_risk(session_id: str) -> dict: ``{"spike_detected": bool, "trend": str, "window_scores": list, "category_counts": dict, "message_count": int}`` """ - from humane_proxy.risk.trajectory import analyze - - # Analyze with a neutral message to get current state. - result = analyze(session_id, 0.0, "safe") - return { - "spike_detected": result.spike_detected, - "trend": result.trend, - "window_scores": result.window_scores, - "category_counts": result.category_counts, - "message_count": result.message_count, - } + from humane_proxy.risk.trajectory import snapshot, to_dict + + return to_dict(snapshot(session_id)) @mcp.tool() async def list_recent_escalations( @@ -105,8 +142,11 @@ async def list_recent_escalations( """ import json import sqlite3 + from humane_proxy.escalation.query import normalize_escalation_query from humane_proxy.escalation.local_db import _get_db_path + limit, category = normalize_escalation_query(limit, category) + conn = sqlite3.connect(_get_db_path(), check_same_thread=False) try: if category: @@ -148,7 +188,7 @@ def serve() -> None: mcp.run() -def serve_http(host: str = "0.0.0.0", port: int = 3000) -> None: +def serve_http(host: str = MCP_DEFAULT_HOST, port: int = 3000) -> None: """Start the MCP server in Streamable HTTP mode. This exposes the MCP tools over HTTP, making the server compatible @@ -158,7 +198,7 @@ def serve_http(host: str = "0.0.0.0", port: int = 3000) -> None: Parameters ---------- host: - Bind address (default ``"0.0.0.0"``). + Bind address (default ``"127.0.0.1"``). port: Bind port (default ``3000``). """ @@ -166,6 +206,13 @@ def serve_http(host: str = "0.0.0.0", port: int = 3000) -> None: raise RuntimeError( "MCP server requires fastmcp. Install with: pip install humane-proxy[mcp]" ) + if _is_public_bind_host(host) and not os.environ.get(MCP_TOKEN_ENV, "").strip(): + logger.warning( + "Starting HTTP MCP on public host %s without %s. " + "Set a bearer token before exposing this server beyond localhost.", + host, + MCP_TOKEN_ENV, + ) assert mcp is not None mcp.run(transport="http", host=host, port=port) diff --git a/humane_proxy/risk/trajectory.py b/humane_proxy/risk/trajectory.py index 1c63399..5fbf0bf 100644 --- a/humane_proxy/risk/trajectory.py +++ b/humane_proxy/risk/trajectory.py @@ -42,6 +42,7 @@ # Each entry is (score, timestamp_seconds). session_history: dict[str, deque[tuple[float, float]]] = {} _category_history: dict[str, deque[str]] = {} +_last_spike_by_session: dict[str, bool] = {} def _evict_oldest_sessions() -> None: @@ -55,6 +56,7 @@ def _evict_oldest_sessions() -> None: oldest_key = next(iter(session_history)) del session_history[oldest_key] _category_history.pop(oldest_key, None) + _last_spike_by_session.pop(oldest_key, None) # --------------------------------------------------------------------------- @@ -85,6 +87,30 @@ def _weighted_mean(history: deque[tuple[float, float]], now: float) -> float: return weighted_sum / total_weight if total_weight > 0 else 0.0 +def _trend_for_scores(scores: list[float]) -> str: + """Return the trend label for a list of recent raw scores.""" + if len(scores) < 4: + return "stable" + + mid = len(scores) // 2 + first_half_avg = sum(scores[:mid]) / mid + second_half_avg = sum(scores[mid:]) / (len(scores) - mid) + trend_delta = second_half_avg - first_half_avg + if trend_delta > 0.15: + return "escalating" + if trend_delta < -0.15: + return "declining" + return "stable" + + +def _category_counts(session_id: str) -> dict[str, int]: + """Return the category distribution for a tracked session.""" + cat_counts: dict[str, int] = {} + for c in _category_history.get(session_id, []): + cat_counts[c] = cat_counts.get(c, 0) + 1 + return cat_counts + + def detect_spike(session_id: str, current_score: float) -> bool: """Return ``True`` if the current score spikes above the recent average. @@ -169,6 +195,7 @@ def analyze( """ # Run spike detection (this also appends the score to session_history). spike = detect_spike(session_id, score) + _last_spike_by_session[session_id] = spike # Track category history. if session_id not in _category_history: @@ -179,27 +206,35 @@ def analyze( history = session_history.get(session_id, deque()) scores = [s for s, _ in history] - # Trend detection: compare first half vs second half of the window. - trend = "stable" - if len(scores) >= 4: - mid = len(scores) // 2 - first_half_avg = sum(scores[:mid]) / mid - second_half_avg = sum(scores[mid:]) / (len(scores) - mid) - trend_delta = second_half_avg - first_half_avg - if trend_delta > 0.15: - trend = "escalating" - elif trend_delta < -0.15: - trend = "declining" - - # Category distribution. - cat_counts: dict[str, int] = {} - for c in _category_history.get(session_id, []): - cat_counts[c] = cat_counts.get(c, 0) + 1 - return TrajectoryResult( spike_detected=spike, - trend=trend, + trend=_trend_for_scores(scores), window_scores=scores, - category_counts=cat_counts, + category_counts=_category_counts(session_id), message_count=len(scores), ) + + +def snapshot(session_id: str) -> TrajectoryResult: + """Return the current trajectory state without recording a new event.""" + history = session_history.get(session_id, deque()) + scores = [s for s, _ in history] + + return TrajectoryResult( + spike_detected=_last_spike_by_session.get(session_id, False), + trend=_trend_for_scores(scores), + window_scores=scores, + category_counts=_category_counts(session_id), + message_count=len(scores), + ) + + +def to_dict(result: TrajectoryResult) -> dict: + """Serialize a trajectory result for MCP and agent integrations.""" + return { + "spike_detected": result.spike_detected, + "trend": result.trend, + "window_scores": result.window_scores, + "category_counts": result.category_counts, + "message_count": result.message_count, + } diff --git a/tests/test_admin_api.py b/tests/test_admin_api.py index d31e70e..f4b8b5d 100644 --- a/tests/test_admin_api.py +++ b/tests/test_admin_api.py @@ -117,6 +117,30 @@ def test_stats_by_category(self, _seeded_db): assert by_cat["criminal_intent"] == 1 +class TestSessionRisk: + HEADERS = {"Authorization": "Bearer test-admin-secret"} + + def test_session_risk_is_read_only(self, _seeded_db): + from humane_proxy.risk.trajectory import analyze, session_history, _category_history + + sid = "sess-1" + analyze(sid, 0.1, "safe") + analyze(sid, 0.1, "safe") + analyze(sid, 0.1, "safe") + analyze(sid, 0.9, "self_harm") + before_score_count = len(session_history[sid]) + before_category_count = len(_category_history[sid]) + + resp = client.get(f"/admin/sessions/{sid}/risk", headers=self.HEADERS) + + assert resp.status_code == 200 + data = resp.json() + assert data["trajectory"]["message_count"] == before_score_count + assert data["trajectory"]["spike_detected"] is True + assert len(session_history[sid]) == before_score_count + assert len(_category_history[sid]) == before_category_count + + class TestDeleteSession: HEADERS = {"Authorization": "Bearer test-admin-secret"} diff --git a/tests/test_integrations_smoke.py b/tests/test_integrations_smoke.py index 3640b87..6b52172 100644 --- a/tests/test_integrations_smoke.py +++ b/tests/test_integrations_smoke.py @@ -3,6 +3,8 @@ import pytest from unittest.mock import patch, MagicMock +from humane_proxy.risk.trajectory import analyze, session_history, _category_history + def test_llamaindex_tools_creation(): """Verify LlamaIndex tools can be created when dependency exists.""" import sys @@ -74,3 +76,61 @@ def test_autogen_tools_creation(): assert mock_proxy.register_for_execution.call_count == 3 finally: del sys.modules["autogen"] + + +def test_autogen_session_risk_is_read_only(): + from humane_proxy.integrations.autogen import get_session_risk + + sid = "autogen-risk-read-only" + analyze(sid, 0.4, "safe") + before_count = len(session_history[sid]) + + get_session_risk(sid) + + assert len(session_history[sid]) == before_count + assert len(_category_history[sid]) == before_count + + +def test_llamaindex_session_risk_is_read_only(): + from humane_proxy.integrations.llamaindex import _get_session_risk + + sid = "llamaindex-risk-read-only" + analyze(sid, 0.6, "criminal_intent") + before_count = len(session_history[sid]) + + result = _get_session_risk(sid) + + assert result["message_count"] == before_count + assert len(session_history[sid]) == before_count + assert len(_category_history[sid]) == before_count + + +def test_crewai_session_risk_is_read_only(): + import json + import sys + + class MockBaseTool: + pass + + mock_crewai = MagicMock() + mock_crewai.tools.BaseTool = MockBaseTool + sys.modules["crewai"] = mock_crewai + sys.modules["crewai.tools"] = mock_crewai.tools + + try: + from humane_proxy.integrations.crewai import get_safety_tools + + sid = "crewai-risk-read-only" + analyze(sid, 0.7, "self_harm") + before_count = len(session_history[sid]) + + tools = get_safety_tools() + risk_tool = next(tool for tool in tools if tool.name == "get_session_risk") + result = json.loads(risk_tool._run(sid)) + + assert result["message_count"] == before_count + assert len(session_history[sid]) == before_count + assert len(_category_history[sid]) == before_count + finally: + del sys.modules["crewai"] + del sys.modules["crewai.tools"] diff --git a/tests/test_mcp_security.py b/tests/test_mcp_security.py new file mode 100644 index 0000000..d2b59d4 --- /dev/null +++ b/tests/test_mcp_security.py @@ -0,0 +1,94 @@ +"""Security-focused MCP helper tests.""" + +import sys +import types + +import pytest + +from humane_proxy.escalation.query import normalize_escalation_query +from humane_proxy.mcp_server import ( + MCP_DEFAULT_HOST, + MCP_TOKEN_ENV, + _get_mcp_auth_provider, + _is_public_bind_host, + serve_http, +) + + +def test_http_mcp_defaults_to_localhost(): + assert MCP_DEFAULT_HOST == "127.0.0.1" + assert serve_http.__defaults__[0] == "127.0.0.1" + + +def test_mcp_auth_provider_uses_configured_bearer_token(monkeypatch): + class FakeBearerTokenAuth: + def __init__(self, token: str): + self.token = token + + fastmcp_module = types.ModuleType("fastmcp") + server_module = types.ModuleType("fastmcp.server") + auth_module = types.ModuleType("fastmcp.server.auth") + auth_module.BearerTokenAuth = FakeBearerTokenAuth + + monkeypatch.setitem(sys.modules, "fastmcp", fastmcp_module) + monkeypatch.setitem(sys.modules, "fastmcp.server", server_module) + monkeypatch.setitem(sys.modules, "fastmcp.server.auth", auth_module) + monkeypatch.setenv(MCP_TOKEN_ENV, "test-mcp-secret") + + auth = _get_mcp_auth_provider() + + assert isinstance(auth, FakeBearerTokenAuth) + assert auth.token == "test-mcp-secret" + + +def test_mcp_auth_provider_is_optional(monkeypatch): + monkeypatch.delenv(MCP_TOKEN_ENV, raising=False) + assert _get_mcp_auth_provider() is None + + +@pytest.mark.parametrize( + ("raw_limit", "expected"), + [ + (0, 1), + (-50, 1), + (25, 25), + (500, 100), + ("not-a-number", 20), + ], +) +def test_escalation_query_limit_is_clamped(raw_limit, expected): + limit, category = normalize_escalation_query(raw_limit, "self_harm") + + assert limit == expected + assert category == "self_harm" + + +def test_escalation_query_rejects_unknown_categories(): + with pytest.raises(ValueError, match="category must be one of"): + normalize_escalation_query(20, "all_data") + + +def test_escalation_query_treats_whitespace_category_as_unfiltered(): + limit, category = normalize_escalation_query(20, " ") + + assert limit == 20 + assert category is None + + +@pytest.mark.parametrize( + ("host", "expected"), + [ + ("127.0.0.1", False), + ("::1", False), + ("localhost", False), + ("0.0.0.0", True), + ("::", True), + ("[::]", True), + ("192.168.1.10", True), + ("10.0.0.5", True), + ("mcp.example.com", True), + ("", True), + ], +) +def test_public_bind_detection_flags_non_loopback_hosts(host, expected): + assert _is_public_bind_host(host) is expected diff --git a/tests/test_trajectory.py b/tests/test_trajectory.py index 9edb6e7..d69421d 100644 --- a/tests/test_trajectory.py +++ b/tests/test_trajectory.py @@ -9,7 +9,9 @@ analyze, detect_spike, session_history, + snapshot, _category_history, + _last_spike_by_session, _weighted_mean, ) from humane_proxy.classifiers.models import TrajectoryResult @@ -136,6 +138,41 @@ def test_window_scores_are_raw(self): for s in result.window_scores: assert isinstance(s, float) + def test_snapshot_is_read_only(self): + sid = "snapshot-read-only-v3" + analyze(sid, 0.2, "safe") + analyze(sid, 0.8, "self_harm") + + first = snapshot(sid) + second = snapshot(sid) + + assert first.message_count == 2 + assert second.message_count == 2 + assert first.window_scores == [0.2, 0.8] + assert second.category_counts == {"safe": 1, "self_harm": 1} + assert len(session_history[sid]) == 2 + assert len(_category_history[sid]) == 2 + + def test_snapshot_preserves_last_spike_state(self): + sid = "snapshot-spike-state-v3" + for _ in range(3): + analyze(sid, 0.1, "safe") + analyze(sid, 0.9, "self_harm") + + result = snapshot(sid) + + assert result.spike_detected is True + assert _last_spike_by_session[sid] is True + + def test_snapshot_empty_session(self): + result = snapshot("snapshot-empty-v3") + + assert result.spike_detected is False + assert result.trend == "stable" + assert result.window_scores == [] + assert result.category_counts == {} + assert result.message_count == 0 + class TestTrendDetection: def test_escalating_trend(self):