From db91d48b74e1953569e84f78b79353ff298f02f7 Mon Sep 17 00:00:00 2001 From: Eric Lee Date: Thu, 11 Jun 2026 17:13:48 -0700 Subject: [PATCH] web tools: honor abort_controller in WebFetch/WebSearch (#276) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ESC during a slow fetch blocked the agent until the urllib socket timeout (15-20s) — the longest interactive stall left in the tool set. New src/utils/abortable_net.py primitives: - call_with_abort: blocking call on a daemon worker thread, abort polled at 50ms; the caller unblocks immediately on ESC while the worker dies at its (bounded) socket timeout - abortable_read: chunked body read with an abort listener that shuts the socket down (close() alone does not interrupt a recv blocked on another thread) so mid-body stalls cancel instantly WebFetch threads the signal through the redirect loop, the open, the body read, and the Cloudflare-UA retry; WebSearch wraps the Tavily request. Both raise AbortError, which dispatch already renders as the user-cancel message. 
Closes #276, closes #170 Co-Authored-By: Claude Opus 4.7 --- src/tool_system/tools/web_fetch.py | 27 +++- src/tool_system/tools/web_search.py | 20 ++- src/utils/abortable_net.py | 124 ++++++++++++++++ tests/test_web_abort.py | 221 ++++++++++++++++++++++++++++ 4 files changed, 381 insertions(+), 11 deletions(-) create mode 100644 src/utils/abortable_net.py create mode 100644 tests/test_web_abort.py diff --git a/src/tool_system/tools/web_fetch.py b/src/tool_system/tools/web_fetch.py index 5ff02feb..9b46a5f0 100644 --- a/src/tool_system/tools/web_fetch.py +++ b/src/tool_system/tools/web_fetch.py @@ -34,6 +34,7 @@ PermissionPassthroughResult, PermissionResult, ) +from src.utils.abortable_net import abortable_read, call_with_abort # -- HTML to Markdown ---------------------------------------------------------- @@ -231,9 +232,9 @@ def _charset_from_content_type(content_type: str) -> str | None: return match.group(1).strip().strip('"\'') if match else None -def _read_response_body(resp) -> str: +def _read_response_body(resp, abort_signal=None) -> str: """Read, transparently decompress (gzip/deflate), and decode a response body.""" - raw = resp.read(_MAX_FETCH_BYTES) + raw = abortable_read(resp, _MAX_FETCH_BYTES, abort_signal) encoding = (resp.headers.get("Content-Encoding") or "").lower() if "gzip" in encoding: try: @@ -264,20 +265,30 @@ def _is_cloudflare_challenge(e: urllib.error.HTTPError) -> bool: def _fetch_with_redirect_handling( - url: str, timeout: float = 15, fmt: str = "markdown", user_agent: str = _BROWSER_UA + url: str, + timeout: float = 15, + fmt: str = "markdown", + user_agent: str = _BROWSER_UA, + abort_signal=None, ) -> tuple[str, str, int]: opener = urllib.request.build_opener(_NoRedirectHandler) current_url = url for _ in range(_MAX_REDIRECTS): + if abort_signal is not None: + abort_signal.throw_if_aborted() req = urllib.request.Request(current_url, headers=_request_headers(fmt, user_agent)) try: - resp = opener.open(req, timeout=timeout) + resp = 
call_with_abort( + lambda: opener.open(req, timeout=timeout), abort_signal + ) content_type = resp.headers.get("Content-Type", "") - return _read_response_body(resp), content_type, resp.status + return _read_response_body(resp, abort_signal), content_type, resp.status except urllib.error.HTTPError as e: # Cloudflare challenged the browser UA -> retry once with a bot UA. if _is_cloudflare_challenge(e) and user_agent != _FALLBACK_UA: - return _fetch_with_redirect_handling(url, timeout, fmt, _FALLBACK_UA) + return _fetch_with_redirect_handling( + url, timeout, fmt, _FALLBACK_UA, abort_signal + ) if e.code in (301, 302, 303, 307, 308): redirect_url = e.headers.get("Location", "") if not redirect_url: @@ -476,7 +487,9 @@ def _web_fetch_call(tool_input: dict[str, Any], context: ToolContext) -> ToolRes if cached: content, content_type, status = cached else: - raw, content_type, status = _fetch_with_redirect_handling(url, fmt=fmt) + raw, content_type, status = _fetch_with_redirect_handling( + url, fmt=fmt, abort_signal=context.abort_controller.signal + ) content = _convert(raw, content_type, fmt) _cache_set(cache_key, content, content_type, status) diff --git a/src/tool_system/tools/web_search.py b/src/tool_system/tools/web_search.py index d4afee0f..91bf86d8 100644 --- a/src/tool_system/tools/web_search.py +++ b/src/tool_system/tools/web_search.py @@ -9,6 +9,8 @@ from typing import Any from urllib.parse import urlparse +from src.utils.abortable_net import call_with_abort + from ..build_tool import Tool, ValidationResult, build_tool from ..context import ToolContext from ..errors import ToolInputError @@ -130,7 +132,9 @@ def is_web_search_configured() -> bool: return _tavily_api_key() is not None -def _tavily_search(query: str, num: int = 15) -> list[dict[str, str]]: +def _tavily_search( + query: str, num: int = 15, abort_signal=None +) -> list[dict[str, str]]: """Search the web via Tavily. 
Raises ``ToolInputError`` when ``TAVILY_API_KEY`` is unset (so the model and @@ -153,9 +157,15 @@ def _tavily_search(query: str, num: int = 15) -> list[dict[str, str]]: method="POST", headers={"Content-Type": "application/json", "Authorization": f"Bearer {key}"}, ) - try: + + def _request() -> str: with urllib.request.urlopen(req, timeout=20) as resp: - raw = resp.read(2_000_000).decode("utf-8", errors="replace") + return resp.read(2_000_000).decode("utf-8", errors="replace") + + try: + # ESC unblocks the caller immediately; the worker dies at the + # 20s socket timeout (#276). + raw = call_with_abort(_request, abort_signal) except urllib.error.HTTPError as exc: detail = "" try: @@ -333,7 +343,9 @@ def _web_search_call(tool_input: dict[str, Any], context: ToolContext) -> ToolRe start_time = time.monotonic() # Search via Tavily (requires TAVILY_API_KEY). - results = _tavily_search(query, num=15) + results = _tavily_search( + query, num=15, abort_signal=context.abort_controller.signal + ) # Apply domain filters results = _apply_domain_filters( diff --git a/src/utils/abortable_net.py b/src/utils/abortable_net.py new file mode 100644 index 00000000..a7d71060 --- /dev/null +++ b/src/utils/abortable_net.py @@ -0,0 +1,124 @@ +"""Abort-aware wrappers for blocking ``urllib`` network calls. + +ESC-cancel support for WebFetch/WebSearch (#276): ``urllib`` has no +cancellation primitive, so cancellation is built from two mechanisms: + +- ``call_with_abort``: run the blocking call on a daemon worker thread and + poll the abort signal from the caller; on abort the CALLER unblocks + immediately (raises ``AbortError``) while the worker dies at its socket + timeout. A late-arriving response is closed so the socket isn't leaked. +- ``abortable_read``: chunked body read with an abort listener that shuts + down and closes the response socket — ``close()`` alone does not interrupt + a ``read()`` stalled between bytes, and polling alone cannot either. 
+""" + +from __future__ import annotations + +import socket +import threading +from typing import Any, Callable, TypeVar + +from .abort_controller import AbortError, AbortSignal + +T = TypeVar("T") + +_POLL_INTERVAL_S = 0.05 +_READ_CHUNK_BYTES = 65536 + + +def _safe_close(obj: Any) -> None: + # ``close()`` alone does NOT interrupt a ``recv`` blocked on another + # thread — the fd stays referenced by the in-flight read. Shut the + # socket down first (http.client internals, best-effort) so the + # blocked read raises immediately instead of waiting out the timeout. + try: + sock = obj.fp.raw._sock + sock.shutdown(socket.SHUT_RDWR) + except Exception: + pass + try: + obj.close() + except Exception: + pass + + +def call_with_abort(fn: Callable[[], T], abort_signal: AbortSignal | None) -> T: + """Run blocking ``fn`` and return its result, raising ``AbortError`` + the moment ``abort_signal`` trips. + + On abort the worker thread is abandoned (it exits at its socket + timeout, bounded by the caller's ``timeout=`` argument to urllib); if + its result arrives after the abort it is closed and discarded. 
+ """ + if abort_signal is None: + return fn() + abort_signal.throw_if_aborted() + + result: list[T] = [] + error: list[BaseException] = [] + done = threading.Event() + + def _worker() -> None: + try: + value = fn() + if abort_signal.aborted: + _safe_close(value) + else: + result.append(value) + except BaseException as exc: # noqa: BLE001 — relayed to the caller + error.append(exc) + finally: + done.set() + + thread = threading.Thread( + target=_worker, name="abortable-net-call", daemon=True + ) + thread.start() + while not done.wait(_POLL_INTERVAL_S): + if abort_signal.aborted: + raise AbortError(abort_signal.reason or "user_interrupt") + if abort_signal.aborted: + raise AbortError(abort_signal.reason or "user_interrupt") + if error: + raise error[0] + return result[0] + + +def abortable_read( + resp: Any, max_bytes: int, abort_signal: AbortSignal | None +) -> bytes: + """Read up to ``max_bytes`` from ``resp`` in chunks, raising + ``AbortError`` if ``abort_signal`` trips mid-read. + + An abort listener closes ``resp`` so a read blocked between bytes + unblocks immediately instead of waiting out the socket timeout. 
+ """ + if abort_signal is None: + return resp.read(max_bytes) + abort_signal.throw_if_aborted() + + def _close_on_abort() -> None: + _safe_close(resp) + + registered = abort_signal.add_listener(_close_on_abort, once=True) + chunks: list[bytes] = [] + remaining = max_bytes + try: + while remaining > 0: + abort_signal.throw_if_aborted() + try: + chunk = resp.read(min(_READ_CHUNK_BYTES, remaining)) + except Exception: + if abort_signal.aborted: + raise AbortError( + abort_signal.reason or "user_interrupt" + ) from None + raise + if not chunk: + break + chunks.append(chunk) + remaining -= len(chunk) + abort_signal.throw_if_aborted() + finally: + abort_signal.remove_listener(registered) + return b"".join(chunks) diff --git a/tests/test_web_abort.py b/tests/test_web_abort.py new file mode 100644 index 00000000..e7fbb57e --- /dev/null +++ b/tests/test_web_abort.py @@ -0,0 +1,221 @@ +"""#276 — WebFetch / WebSearch honor abort_controller (ESC-cancel). + +Covers the ``src.utils.abortable_net`` primitives and their wiring into +the web tools: an abort mid-connect or mid-read must unblock the caller +in ~poll-interval time (not the 15-20s socket timeout) and surface as +``AbortError`` so the dispatch layer renders the user-cancel message. 
+""" +from __future__ import annotations + +import http.server +import threading +import time +import urllib.request + +import pytest + +from src.utils.abort_controller import AbortController, AbortError +from src.utils.abortable_net import abortable_read, call_with_abort + + +def _abort_after(controller: AbortController, delay_s: float) -> threading.Thread: + t = threading.Timer(delay_s, lambda: controller.abort("user_interrupt")) + t.daemon = True + t.start() + return t + + +class TestCallWithAbort: + def test_returns_result_without_signal(self): + assert call_with_abort(lambda: 42, None) == 42 + + def test_returns_result_with_untripped_signal(self): + assert call_with_abort(lambda: "ok", AbortController().signal) == "ok" + + def test_pre_aborted_raises_immediately_without_calling_fn(self): + controller = AbortController() + controller.abort("user_interrupt") + called = [] + with pytest.raises(AbortError): + call_with_abort(lambda: called.append(1), controller.signal) + assert called == [] + + def test_abort_mid_call_unblocks_fast(self): + controller = AbortController() + release = threading.Event() + + def _slow(): + release.wait(10) + return "late" + + _abort_after(controller, 0.1) + start = time.monotonic() + with pytest.raises(AbortError): + call_with_abort(_slow, controller.signal) + elapsed = time.monotonic() - start + release.set() + assert elapsed < 2.0, f"abort took {elapsed:.2f}s — should be ~0.1s" + + def test_worker_exception_propagates(self): + with pytest.raises(ValueError, match="boom"): + call_with_abort( + lambda: (_ for _ in ()).throw(ValueError("boom")), + AbortController().signal, + ) + + def test_late_result_after_abort_is_closed(self): + controller = AbortController() + closed = threading.Event() + + class _Resource: + def close(self): + closed.set() + + release = threading.Event() + + def _slow(): + release.wait(5) + return _Resource() + + _abort_after(controller, 0.05) + with pytest.raises(AbortError): + call_with_abort(_slow, 
controller.signal) + release.set() + assert closed.wait(2.0), "late-arriving resource was not closed" + + +class _FakeResponse: + """Chunked reader that blocks until closed after the first chunk.""" + + def __init__(self): + self._sent_first = False + self._closed = threading.Event() + + def read(self, n: int) -> bytes: + if not self._sent_first: + self._sent_first = True + return b"x" * min(n, 10) + # Block like a stalled socket until close() unblocks us. + self._closed.wait(10) + raise OSError("read on closed connection") + + def close(self): + self._closed.set() + + +class TestAbortableRead: + def test_reads_fully_without_signal(self): + class _R: + def read(self, n): + return b"abc" + + assert abortable_read(_R(), 3, None) == b"abc" + + def test_reads_fully_with_untripped_signal(self): + chunks = [b"aa", b"bb", b""] + + class _R: + def read(self, n): + return chunks.pop(0) + + out = abortable_read(_R(), 1000, AbortController().signal) + assert out == b"aabb" + + def test_abort_mid_read_closes_resp_and_raises_fast(self): + controller = AbortController() + resp = _FakeResponse() + _abort_after(controller, 0.1) + start = time.monotonic() + with pytest.raises(AbortError): + abortable_read(resp, 1_000_000, controller.signal) + elapsed = time.monotonic() - start + assert elapsed < 2.0, f"abort took {elapsed:.2f}s" + assert resp._closed.is_set() + + def test_listener_removed_on_success(self): + controller = AbortController() + + class _R: + def read(self, n): + return b"" + + abortable_read(_R(), 100, controller.signal) + assert controller.signal._listeners == [] + + +class _StallingHandler(http.server.BaseHTTPRequestHandler): + """Sends headers then stalls the body — a hung server. 
+ + The stall must exceed every elapsed-time assertion below: a passing + test proves the abort unblocked the client while the server was + still stalling (teardown waits out the remainder, keep it short).""" + + stall_s = 3.0 + + def do_GET(self): # noqa: N802 — BaseHTTPRequestHandler API + self.send_response(200) + self.send_header("Content-Type", "text/plain") + self.send_header("Content-Length", "1000") + self.end_headers() + self.wfile.write(b"partial") + self.wfile.flush() + time.sleep(self.stall_s) + + def do_POST(self): # noqa: N802 — Tavily search POSTs + self.do_GET() + + def log_message(self, *args): # silence test output + pass + + +@pytest.fixture +def stalling_server(): + server = http.server.HTTPServer(("127.0.0.1", 0), _StallingHandler) + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + yield f"http://127.0.0.1:{server.server_port}/" + server.shutdown() + server.server_close() + + +class TestWebFetchAbortIntegration: + def test_abort_mid_body_unblocks_fast(self, stalling_server): + from src.tool_system.tools.web_fetch import _fetch_with_redirect_handling + + controller = AbortController() + _abort_after(controller, 0.2) + start = time.monotonic() + with pytest.raises(AbortError): + _fetch_with_redirect_handling( + stalling_server, timeout=8, abort_signal=controller.signal + ) + elapsed = time.monotonic() - start + assert elapsed < 2.0, f"abort took {elapsed:.2f}s — stalled to timeout" + + def test_pre_aborted_signal_raises_before_any_io(self): + from src.tool_system.tools.web_fetch import _fetch_with_redirect_handling + + controller = AbortController() + controller.abort("user_interrupt") + with pytest.raises(AbortError): + # Unroutable TEST-NET address: if this ever attempts I/O the + # test hangs instead of failing fast — the raise must come first. 
+ _fetch_with_redirect_handling( + "http://192.0.2.1/", timeout=1, abort_signal=controller.signal + ) + + +class TestWebSearchAbortIntegration: + def test_abort_unblocks_tavily_request(self, stalling_server, monkeypatch): + from src.tool_system.tools import web_search as ws + + monkeypatch.setattr(ws, "_TAVILY_URL", stalling_server) + monkeypatch.setattr(ws, "_tavily_api_key", lambda: "tvly-test") + + controller = AbortController() + _abort_after(controller, 0.2) + start = time.monotonic() + with pytest.raises(AbortError): + ws._tavily_search("query", abort_signal=controller.signal) + elapsed = time.monotonic() - start + assert elapsed < 2.0, f"abort took {elapsed:.2f}s"