diff --git a/src/services/mcp/client.py b/src/services/mcp/client.py index 7a66e1bf..45136634 100644 --- a/src/services/mcp/client.py +++ b/src/services/mcp/client.py @@ -7,6 +7,8 @@ import time from typing import Any, Callable +from src.utils.abort_controller import AbortError, AbortSignal + from .errors import ( McpAuthError, McpSessionExpiredError, @@ -335,6 +337,11 @@ async def _receive_loop(self) -> None: break if msg.id is not None and msg.id in self._pending_requests: future = self._pending_requests.pop(msg.id) + if future.done(): + # Aborted (cancelled) request whose response raced + # the cleanup — resolving it would raise + # InvalidStateError and kill the receive loop. + continue if msg.error: future.set_exception( McpToolCallError( @@ -355,24 +362,81 @@ async def _send_request( self, method: str, params: dict[str, Any] | None = None, + abort_signal: AbortSignal | None = None, ) -> Any: if self._transport is None: raise RuntimeError("Transport not connected") + if abort_signal is not None: + abort_signal.throw_if_aborted() request_id = self._next_id() msg = JsonRpcMessage( method=method, params=params, id=request_id, ) - future: asyncio.Future[Any] = asyncio.get_event_loop().create_future() + loop = asyncio.get_running_loop() + future: asyncio.Future[Any] = loop.create_future() self._pending_requests[request_id] = future - await self._transport.send(msg) + + # ESC-cancel (#277): the abort listener fires on the aborting + # thread (TUI/REPL ESC handler), so hop onto this loop to cancel + # the pending future — wait_for then raises CancelledError, which + # we convert to AbortError after notifying the server. When user + # abort and a genuine task-cancel coincide, AbortError wins: this + # coroutine runs on a per-call asyncio.run loop in production, so + # no external task cancellation can reach it anyway. + registered_abort: Callable[[], None] | None = None + if abort_signal is not None: + def _on_abort() -> None: + loop.call_soon_threadsafe(future.cancel) + + registered_abort = abort_signal.add_listener(_on_abort, once=True) + if abort_signal.aborted: + # Abort fired between throw_if_aborted and add_listener — + # the listener will never fire. Skip the send entirely. + self._pending_requests.pop(request_id, None) + abort_signal.remove_listener(registered_abort) + raise AbortError(abort_signal.reason or "user_interrupt") + timeout_s = _get_tool_timeout_ms() / 1000.0 try: + # Inside the try so the finally's listener/pending cleanup + # also covers a send that raises (closed transport, broken + # pipe) or is cancelled mid-await. + await self._transport.send(msg) return await asyncio.wait_for(future, timeout=timeout_s) - except asyncio.TimeoutError: - self._pending_requests.pop(request_id, None) + except asyncio.CancelledError: + if abort_signal is not None and abort_signal.aborted: + # JSON-RPC cancellation per MCP spec: best-effort + # notification so a compliant server stops the work; a + # server that ignores it merely leaks one request — the + # client is already unblocked. Bounded so a wedged + # transport (the likely cause of the hang being escaped) + # cannot block the unblock path. + try: + await asyncio.wait_for( + self._send_notification( + "notifications/cancelled", + { + "requestId": request_id, + "reason": abort_signal.reason or "user_interrupt", + }, + ), + timeout=2, + ) + except Exception: + logger.debug( + "failed to send cancellation for request %s", request_id + ) + raise AbortError(abort_signal.reason or "user_interrupt") from None raise + finally: + # No-op on the success path (the receive loop pops when it + # resolves); guarantees no stranded never-resolving future on + # timeout, abort, task-cancel, or send failure. + self._pending_requests.pop(request_id, None) + if abort_signal is not None and registered_abort is not None: + abort_signal.remove_listener(registered_abort) async def _send_notification( self, @@ -411,6 +475,7 @@ async def call_tool( tool_name: str, arguments: dict[str, Any] | None = None, meta: dict[str, Any] | None = None, + abort_signal: AbortSignal | None = None, ) -> McpToolResult: params: dict[str, Any] = { "name": tool_name, @@ -425,7 +490,9 @@ async def call_tool( # cache is cleared on detection so the next request reconnects # against a fresh session rather than reusing the expired one. try: - result = await self._send_request("tools/call", params) + result = await self._send_request( + "tools/call", params, abort_signal=abort_signal + ) except McpToolCallError as err: if not is_mcp_session_expired_error(err): # Regular tool error (invalid params, server-rejected, etc.) — @@ -437,7 +504,9 @@ async def call_tool( # failed and re-raised). A second session-expired here means # the server is unstable / the retry hit a fresh session that # already expired — propagate so we don't loop indefinitely. - result = await self._send_request("tools/call", params) + result = await self._send_request( + "tools/call", params, abort_signal=abort_signal + ) if not result or not isinstance(result, dict): return McpToolResult() diff --git a/src/services/mcp/tool_wrapper.py b/src/services/mcp/tool_wrapper.py index 8888ef28..1ef4c539 100644 --- a/src/services/mcp/tool_wrapper.py +++ b/src/services/mcp/tool_wrapper.py @@ -32,6 +32,7 @@ from src.tool_system.build_tool import McpInfo, Tool, build_tool from src.tool_system.context import ToolContext from src.tool_system.protocol import ToolResult +from src.utils.abort_controller import AbortError from .client import McpClient from .mcp_string_utils import build_mcp_tool_name @@ -253,8 +254,18 @@ async def _async_call(args: dict[str, Any], ctx: ToolContext) -> ToolResult: is_error=True, ) + # getattr: duck-typed/mocked contexts (spec'd mocks don't expose + # default_factory dataclass fields) may lack abort_controller. + abort_controller = getattr(ctx, "abort_controller", None) + abort_signal = ( + abort_controller.signal if abort_controller is not None else None + ) try: - result = await client.call_tool(mcp_tool.name, args) + result = await client.call_tool( + mcp_tool.name, + args, + abort_signal=abort_signal, + ) content_blocks: list[dict[str, Any]] = list(result.content) if result.content else [] # WI-8.2: budget-truncate before rendering so the model never @@ -324,6 +335,10 @@ async def _async_call(args: dict[str, Any], ctx: ToolContext) -> ToolResult: is_error=False, mcp_meta=mcp_meta, ) + except AbortError: + # ESC-cancel (#277): propagate so the dispatch layer renders + # the user-cancel message instead of a generic tool error. + raise except Exception as e: return ToolResult( name=fully_qualified_name, diff --git a/tests/test_mcp_abort.py b/tests/test_mcp_abort.py new file mode 100644 index 00000000..80b5e8cb --- /dev/null +++ b/tests/test_mcp_abort.py @@ -0,0 +1,225 @@ +"""#277 — MCP tool calls honor abort_controller (ESC-cancel). + +A pending ``tools/call`` must unblock the moment the abort signal trips +(not at the multi-minute MCP request timeout), send the MCP +``notifications/cancelled`` notification so a compliant server stops the +work, and surface as ``AbortError`` so the dispatch layer renders the +user-cancel message instead of a generic tool error. +""" +from __future__ import annotations + +import asyncio +import threading +import time +from typing import Any + +import pytest + +from src.services.mcp.client import McpClient +from src.services.mcp.transport import JsonRpcMessage +from src.utils.abort_controller import AbortController, AbortError + + +class _HangingTransport: + """Records sends; never delivers a response (a hung MCP server).""" + + def __init__(self): + self.sent: list[JsonRpcMessage] = [] + self._closed = asyncio.Event() + + @property + def is_connected(self) -> bool: + return not self._closed.is_set() + + async def send(self, message: JsonRpcMessage) -> None: + self.sent.append(message) + + async def receive(self) -> JsonRpcMessage | None: + await self._closed.wait() + return None + + async def close(self) -> None: + self._closed.set() + + +def _make_client() -> tuple[McpClient, _HangingTransport]: + client = McpClient() + transport = _HangingTransport() + client._transport = transport + return client, transport + + +def _abort_after(controller: AbortController, delay_s: float) -> None: + t = threading.Timer(delay_s, lambda: controller.abort("user_interrupt")) + t.daemon = True + t.start() + + +class TestMcpAbort: + @pytest.mark.asyncio + async def test_abort_unblocks_pending_call_fast(self): + client, transport = _make_client() + controller = AbortController() + _abort_after(controller, 0.1) + + start = time.monotonic() + with pytest.raises(AbortError): + await client.call_tool( + "slow_tool", {}, abort_signal=controller.signal + ) + elapsed = time.monotonic() - start + assert elapsed < 2.0, f"abort took {elapsed:.2f}s" + + @pytest.mark.asyncio + async def test_abort_sends_cancellation_notification(self): + client, transport = _make_client() + controller = AbortController() + _abort_after(controller, 0.05) + + with pytest.raises(AbortError): + await client.call_tool("slow_tool", {}, abort_signal=controller.signal) + + call_msg = transport.sent[0] + assert call_msg.method == "tools/call" + cancels = [m for m in transport.sent if m.method == "notifications/cancelled"] + assert len(cancels) == 1 + assert cancels[0].params["requestId"] == call_msg.id + assert cancels[0].params["reason"] == "user_interrupt" + assert cancels[0].id is None # notification, not a request + + @pytest.mark.asyncio + async def test_abort_cleans_pending_request(self): + client, transport = _make_client() + controller = AbortController() + _abort_after(controller, 0.05) + + with pytest.raises(AbortError): + await client.call_tool("slow_tool", {}, abort_signal=controller.signal) + assert client._pending_requests == {} + + @pytest.mark.asyncio + async def test_pre_aborted_signal_raises_before_sending(self): + client, transport = _make_client() + controller = AbortController() + controller.abort("user_interrupt") + + with pytest.raises(AbortError): + await client.call_tool("slow_tool", {}, abort_signal=controller.signal) + assert transport.sent == [] + + @pytest.mark.asyncio + async def test_listener_removed_after_normal_completion(self): + client, transport = _make_client() + controller = AbortController() + + async def _respond(): + while not transport.sent: + await asyncio.sleep(0.01) + req = transport.sent[0] + future = client._pending_requests[req.id] + future.set_result({"content": [{"type": "text", "text": "ok"}]}) + + responder = asyncio.create_task(_respond()) + result = await client.call_tool("tool", {}, abort_signal=controller.signal) + await responder + + assert result.content[0]["text"] == "ok" + assert controller.signal._listeners == [] + + @pytest.mark.asyncio + async def test_send_failure_cleans_listener_and_pending(self): + """transport.send raising must not leak the abort listener or the + pending future (the finally covers the send, not just the wait).""" + client, transport = _make_client() + + async def _broken_send(message): + raise ConnectionError("broken pipe") + + transport.send = _broken_send # type: ignore[method-assign] + controller = AbortController() + + with pytest.raises(ConnectionError): + await client.call_tool("tool", {}, abort_signal=controller.signal) + assert controller.signal._listeners == [] + assert client._pending_requests == {} + + @pytest.mark.asyncio + async def test_without_signal_behavior_unchanged(self): + client, transport = _make_client() + + async def _respond(): + while not transport.sent: + await asyncio.sleep(0.01) + req = transport.sent[0] + client._pending_requests[req.id].set_result( + {"content": [{"type": "text", "text": "plain"}]} + ) + + responder = asyncio.create_task(_respond()) + result = await client.call_tool("tool", {}) + await responder + assert result.content[0]["text"] == "plain" + + +class TestReceiveLoopCancelledRace: + @pytest.mark.asyncio + async def test_late_response_for_cancelled_future_does_not_kill_loop(self): + """A response racing the abort cleanup must not raise + InvalidStateError inside the receive loop.""" + client = McpClient() + + class _ScriptedTransport(_HangingTransport): + def __init__(self): + super().__init__() + self.inbox: asyncio.Queue[JsonRpcMessage | None] = asyncio.Queue() + + async def receive(self) -> JsonRpcMessage | None: + return await self.inbox.get() + + transport = _ScriptedTransport() + client._transport = transport + receive_task = asyncio.create_task(client._receive_loop()) + + # A cancelled future still registered in pending (the race window). + loop = asyncio.get_event_loop() + cancelled_future: asyncio.Future[Any] = loop.create_future() + cancelled_future.cancel() + client._pending_requests[1] = cancelled_future + + # A live request that must still resolve afterwards. + live_future: asyncio.Future[Any] = loop.create_future() + client._pending_requests[2] = live_future + + await transport.inbox.put(JsonRpcMessage(id=1, result={"late": True})) + await transport.inbox.put(JsonRpcMessage(id=2, result={"ok": True})) + + assert await asyncio.wait_for(live_future, timeout=2) == {"ok": True} + await transport.inbox.put(None) + await asyncio.wait_for(receive_task, timeout=2) + + +class TestToolWrapperAbortPropagation: + @pytest.mark.asyncio + async def test_wrapper_reraises_abort_error(self, tmp_path): + """AbortError must escape the wrapper's except-Exception so the + dispatch layer renders the user-cancel message.""" + from src.permissions.types import ToolPermissionContext + from src.services.mcp.tool_wrapper import wrap_mcp_tool + from src.services.mcp.types import McpToolSchema + from src.tool_system.context import ToolContext + + class _AbortingClient: + async def call_tool(self, *args, **kwargs): + raise AbortError("user_interrupt") + + tool = wrap_mcp_tool( + "srv", + McpToolSchema(name="t", description="", input_schema={"type": "object"}), + _AbortingClient(), # type: ignore[arg-type] + ) + ctx = ToolContext( + workspace_root=tmp_path, + permission_context=ToolPermissionContext(mode="bypassPermissions"), + ) + with pytest.raises(AbortError): + tool.call({}, ctx) diff --git a/tests/test_mcp_client_full.py b/tests/test_mcp_client_full.py index d313b465..5c551fbe 100644 --- a/tests/test_mcp_client_full.py +++ b/tests/test_mcp_client_full.py @@ -229,7 +229,7 @@ async def test_session_expired_triggers_reconnect_and_retry_succeeds(self): # returns success. Tracker captures the call sequence. calls = [] - async def fake_send_request(method, params=None): + async def fake_send_request(method, params=None, abort_signal=None): calls.append(method) if len(calls) == 1: raise McpToolCallError( @@ -270,7 +270,7 @@ async def test_session_expired_retry_propagates_second_failure(self): calls = [] - async def fake_send_request(method, params=None): + async def fake_send_request(method, params=None, abort_signal=None): calls.append(method) raise McpToolCallError( '{"code":32600,"message":"Session terminated"}', @@ -305,7 +305,7 @@ async def test_non_session_expired_error_propagates_without_retry(self): calls = [] reconnect_called = [] - async def fake_send_request(method, params=None): + async def fake_send_request(method, params=None, abort_signal=None): calls.append(method) raise McpToolCallError( '{"code":-32602,"message":"Invalid params"}', @@ -350,7 +350,7 @@ async def test_session_expired_with_failed_reconnect_propagates_original(self): "Session expired", ) - async def fake_send_request(method, params=None): + async def fake_send_request(method, params=None, abort_signal=None): raise original async def fake_reconnect(): @@ -387,7 +387,7 @@ async def test_concurrent_session_expired_calls_share_one_reconnect(self): N_CALLERS = 10 call_history: list[str] = [] - async def fake_send_request(method, params=None): + async def fake_send_request(method, params=None, abort_signal=None): idx = len(call_history) call_history.append(method) if idx < N_CALLERS: