diff --git a/src/providers/_stream_worker.py b/src/providers/_stream_worker.py new file mode 100644 index 00000000..bcb5b80f --- /dev/null +++ b/src/providers/_stream_worker.py @@ -0,0 +1,148 @@ +"""Worker-thread + bounded-queue stream consumption shared by providers. + +Extracted from ``openai_compatible.py`` (#279) so AnthropicProvider and +MinimaxProvider get the same ESC-unwind guarantees: the SDKs' sync +``httpx`` reads don't reliably honor a cross-thread ``response.close()`` +behind buffering proxies (LiteLLM, corporate proxies, mitmproxy), so the +blocking iteration runs on a daemon worker thread and the calling thread +polls a bounded queue, re-checking the abort signal between ticks. + +Guarantees (pinned by tests/test_openai_compat_abort_signal.py and the +provider-specific abort tests): + +- ESC unblocks the caller within ~100 ms regardless of SDK behavior. +- Items received BEFORE the abort are still delivered to ``on_item``; + nothing is delivered after (the worker stops enqueueing the moment + the abort trips). +- The queue is bounded (#278): a proxy that keeps sending after ESC + cannot grow memory; the worker stops READING within one put-poll. +- A consumer that dies for a non-abort reason (``on_item`` raising) + releases the worker via ``consumer_gone`` instead of leaving it + retrying a full queue forever. +""" + +from __future__ import annotations + +import contextlib +import logging +import queue +import threading +from typing import Any, Callable, TypeVar + +from ._stream_abort import StreamAbortGuard + +logger = logging.getLogger(__name__) + +T = TypeVar("T") + +_DONE = object() +_QUEUE_MAXSIZE = 64 +_PUT_POLL_S = 0.25 +_GET_POLL_S = 0.1 + +# emit(item) -> bool: False means "stop producing" (abort/consumer gone). +Emit = Callable[[Any], bool] + + +def run_stream_on_worker( + produce: Callable[[Emit], T], + on_item: Callable[[Any], None], + guard: StreamAbortGuard, + *, + thread_name: str = "provider-stream", +) -> T | None: + """Run ``produce(emit)`` on a daemon worker thread, delivering each + emitted item to ``on_item`` on the calling thread. + + Returns ``produce``'s return value. Raises ``AbortError`` promptly + when the guard's signal trips; re-raises ``produce``'s exception + otherwise (translated through ``guard.reraise_if_aborted`` first, so + an SDK error caused by the close-on-abort listener surfaces as + ``AbortError`` with the original as its cause). + + ``produce`` must treat ``emit(...) is False`` as "stop now": the + consumer is gone or the user aborted, and nothing further will be + drained. + """ + chunk_queue: queue.Queue = queue.Queue(maxsize=_QUEUE_MAXSIZE) + consumer_gone = threading.Event() + + def _put_or_drop(item: Any) -> bool: + while True: + if guard.aborted or consumer_gone.is_set(): + return False + try: + chunk_queue.put(item, timeout=_PUT_POLL_S) + return True + except queue.Full: + continue + + def _emit(item: Any) -> bool: + return _put_or_drop(("item", item)) + + def _worker() -> None: + try: + value = produce(_emit) + _put_or_drop(("result", value)) + except BaseException as exc: # noqa: BLE001 — relayed to the consumer + if not _put_or_drop(("error", exc)): + # Abort won the race against a genuine error; the + # consumer raises AbortError, so keep the loser + # visible somewhere. + logger.debug("stream error dropped after abort", exc_info=exc) + finally: + _put_or_drop(_DONE) + + worker = threading.Thread(target=_worker, daemon=True, name=thread_name) + + outcome: tuple[str, Any] | None = None + with contextlib.ExitStack() as consumer_scope: + # Releases the worker (sets consumer_gone) no matter how the + # consumer loop exits — abort, on_item error, or natural break — + # so a blocked put never outlives its consumer. + consumer_scope.callback(consumer_gone.set) + worker.start() + while True: + try: + msg = chunk_queue.get(timeout=_GET_POLL_S) + except queue.Empty: + # The 100 ms tick bounds how long the user waits between + # ESC and the prompt returning, regardless of how slow / + # blocked the underlying SDK iteration is. + if guard.aborted: + guard.raise_if_post_aborted() + continue + + if msg is _DONE: + break + kind, payload = msg + if kind == "item": + on_item(payload) + # Check abort AFTER processing so any already-delivered + # item is preserved; we just don't take the next one. + if guard.aborted: + guard.raise_if_post_aborted() + continue + # "result" / "error" — terminal; the _DONE sentinel follows. + outcome = (kind, payload) + + # Error outcomes FIRST (before the post-abort check) so a relayed + # exception keeps its chain — an abort racing in after the relay + # surfaces as ``AbortError from payload`` via reraise_if_aborted, + # and a relayed KeyboardInterrupt/SystemExit is re-raised as-is so + # the outer signal-handling story stays intact (pre-refactor + # semantics: the error was raised at dequeue time). + if outcome is not None and outcome[0] == "error": + payload = outcome[1] + if isinstance(payload, Exception): + guard.reraise_if_aborted(payload) + raise payload + + # The signal may have fired between the worker finishing and here. + guard.raise_if_post_aborted() + + if outcome is not None: + return outcome[1] + # Defensive: an outcome-less _DONE should not occur (a dropped + # result implies a tripped abort, which raises above). + return None diff --git a/src/providers/anthropic_provider.py b/src/providers/anthropic_provider.py index e6fbd6e5..d7760989 100644 --- a/src/providers/anthropic_provider.py +++ b/src/providers/anthropic_provider.py @@ -346,9 +346,15 @@ def _fallback_to_chat() -> ChatResponse: max_tokens=max_tokens, ) + from ._stream_worker import run_stream_on_worker + streamed_text = "" watchdog_fired = False final_message = None + # Written by the worker (inside ``_produce``'s finally, BEFORE + # the exception/result is relayed to this thread) and read by + # the except handler / post-stream branch below — never raced. + _watchdog_state = {"fired": False} try: with client.messages.stream( model=model, @@ -362,32 +368,55 @@ def _fallback_to_chat() -> ChatResponse: # (see ``_stream_abort.py`` for the race-safe ordering # and the close-via-stream.response.close mechanism). # The provider keeps the watchdog and fallback logic - # local: they aren't abort-related. + # local: they aren't abort-related. The ITERATION runs on + # a worker thread (#279, see ``_stream_worker.py``) so + # ESC unwinds promptly even when a buffering proxy keeps + # the SDK's blocking read alive after the listener's + # close. watchdog = StreamWatchdog(stream) watchdog.arm() - try: - for text in stream.text_stream: - # Each chunk pushes the deadline forward. - watchdog.reset() - if not text: - continue - streamed_text += text - if on_text_chunk is not None: - on_text_chunk(text) + + def _produce(emit): try: - final_message = stream.get_final_message() - except Exception: - final_message = None + for text in stream.text_stream: + # Each chunk pushes the deadline forward. + watchdog.reset() + if not text: + continue + if not emit(text): + return None # abort/consumer gone + try: + return stream.get_final_message() + except Exception: + return None + finally: + # Snapshot watchdog state BEFORE the result or + # exception is relayed to the consumer (critic + # B1 lineage: the except handler below reads it). + _watchdog_state["fired"] = watchdog.fired + watchdog.disarm() + + def _on_text(text: str) -> None: + nonlocal streamed_text + streamed_text += text + if on_text_chunk is not None: + on_text_chunk(text) + + try: + final_message = run_stream_on_worker( + _produce, _on_text, guard, thread_name="anthropic-stream" + ) finally: - # Snapshot watchdog state INSIDE the finally so it - # survives an exception propagating through the - # iterator (close() raises mid-stream). Critic B1 - # caught this — otherwise the assignment was on a - # line never reached during the exception path and - # the fallback branch below ran with watchdog_fired - # still False. - watchdog_fired = watchdog.fired + # Consumer-side disarm guarantee: on the abort path + # against a stuck stream, the worker (and _produce's + # finally) may never unblock — without this, the + # armed 90s timer would leak per ESC. disarm() is + # idempotent; get_final_message racing the with- + # block __exit__ on the worker is benign (httpx + # raises cleanly on cross-thread close; the result + # is dropped post-abort). watchdog.disarm() + watchdog_fired = _watchdog_state["fired"] except Exception as streaming_exc: # Abort path FIRST: a user cancel must win over the # watchdog fallback (the abort listener may also have @@ -399,8 +428,10 @@ def _fallback_to_chat() -> ChatResponse: # WI-5.2 fallback path: stream interrupted by the idle # watchdog. Fall back to non-streaming so the user still # gets an answer. If the failure is something else - # (network/auth/etc.), re-raise the original. - if watchdog_fired: + # (network/auth/etc.), re-raise the original. Read the + # shared state, not the local: the local assignment after + # run_stream_on_worker never ran on this path. + if _watchdog_state["fired"]: try: return _fallback_to_chat() except Exception as fallback_exc: diff --git a/src/providers/minimax_provider.py b/src/providers/minimax_provider.py index 03e7d477..37dbbbcb 100644 --- a/src/providers/minimax_provider.py +++ b/src/providers/minimax_provider.py @@ -192,6 +192,8 @@ def chat_stream_response( if tools: extra_kwargs["tools"] = tools + from ._stream_worker import run_stream_on_worker + streamed_text = "" final_message: Any = None try: @@ -203,16 +205,30 @@ def chat_stream_response( **extra_kwargs, **{k: v for k, v in kwargs.items() if k not in ["model", "max_tokens", "tools"]}, ) as stream, guard.attach(stream): - for text in stream.text_stream: - if not text: - continue + # Iteration runs on a worker thread (#279, see + # ``_stream_worker.py``) so ESC unwinds promptly even + # when a buffering proxy keeps the SDK's blocking read + # alive after the listener's close. + def _produce(emit): + for text in stream.text_stream: + if not text: + continue + if not emit(text): + return None # abort/consumer gone + try: + return stream.get_final_message() + except Exception: + return None + + def _on_text(text: str) -> None: + nonlocal streamed_text streamed_text += text if on_text_chunk is not None: on_text_chunk(text) - try: - final_message = stream.get_final_message() - except Exception: - final_message = None + + final_message = run_stream_on_worker( + _produce, _on_text, guard, thread_name="minimax-stream" + ) except Exception as streaming_exc: guard.reraise_if_aborted(streaming_exc) raise diff --git a/src/providers/openai_compatible.py b/src/providers/openai_compatible.py index b91076cc..189008ff 100644 --- a/src/providers/openai_compatible.py +++ b/src/providers/openai_compatible.py @@ -632,170 +632,76 @@ def chat_stream_response( usage_obj: Any = None tool_calls_by_index: dict[int, dict[str, str]] = {} - # Worker-thread iteration. The OpenAI Python SDK uses sync - # ``httpx`` for streaming, and ``response.close()`` from another - # thread is best-effort — for LiteLLM-proxied connections (and - # some other httpx configurations) the SDK's blocking socket - # read doesn't actually return when the response is closed. - # Unlike JavaScript's native ``fetch + AbortSignal`` integration - # (which the TypeScript reference uses), Python has no portable - # way to make a sync blocking read honor an abort from another - # thread. - # - # Workaround: hoist the iteration onto a daemon worker thread - # that pushes chunks into a bounded queue. The main thread polls - # the queue with a short timeout and re-checks ``guard.aborted`` - # each tick. On abort we raise ``AbortError`` immediately; the - # worker notices the abort (or the consumer's exit) at its next - # put attempt and stops reading the stream within one 0.25s - # poll. The benefit is that the user's prompt comes back in - # ~100 ms regardless of LiteLLM/httpx behavior. - import queue as _queue - import threading as _threading - - _DONE = object() - # Bounded (#278): after ESC the consumer stops draining, and a - # proxy that keeps sending bytes without closing the iterator - # would otherwise grow the queue without limit. 64 bounds the - # post-abort staleness to a trivial drain while giving the - # producer slack against transient consumer pauses. - chunk_queue: _queue.Queue = _queue.Queue(maxsize=64) - # Set when the consumer loop exits for ANY reason. Without it, a - # consumer that unwinds for a non-abort reason (on_text_chunk - # raising, KeyboardInterrupt) would leave the worker retrying a - # full queue forever — an immortal thread pinning the httpx - # connection open. - consumer_gone = _threading.Event() - - def _put_or_drop_on_abort(item: Any) -> bool: - """Block until ``item`` is enqueued, or drop it once the - abort trips or the consumer exits (either way nobody will - drain it; keeping nothing alive is the point). Returns - False when dropped.""" - while True: - if guard.aborted or consumer_gone.is_set(): - return False - try: - chunk_queue.put(item, timeout=0.25) - return True - except _queue.Full: - continue - - def _drain_stream() -> None: - try: - for c in stream: - if not _put_or_drop_on_abort(c): - return # stop reading; orphaned socket dies upstream - except BaseException as exc: # noqa: BLE001 — surface to consumer - if not _put_or_drop_on_abort(exc): - # Abort won the race against a genuine error; the - # consumer raises AbortError, so keep the loser - # visible somewhere. - logger.debug( - "stream error dropped after abort", exc_info=exc - ) - finally: - _put_or_drop_on_abort(_DONE) - - worker = _threading.Thread( - target=_drain_stream, - daemon=True, - name=f"openai-stream-{id(stream)}", - ) - - import contextlib as _contextlib - - with _contextlib.ExitStack() as _consumer_scope: - # Releases the worker (sets consumer_gone) no matter how the - # consumer loop exits — abort, callback error, or natural - # break — so a blocked put never outlives its consumer. - _consumer_scope.callback(consumer_gone.set) - _consumer_scope.enter_context(guard.attach(stream)) - worker.start() - while True: - try: - item = chunk_queue.get(timeout=0.1) - except _queue.Empty: - # No chunk available right now — check abort and - # loop. The 100 ms tick bounds how long the user - # waits between pressing ESC and the prompt - # returning, regardless of how slow / blocked the - # underlying SDK iteration is. - if guard.aborted: - # Use ``raise_if_post_aborted`` so the abort - # reason from the controller is preserved - # (rather than hardcoding ``"user_interrupt"``, - # which would silently downgrade a non-default - # reason like a future ``"rate_limit_backoff"``). - guard.raise_if_post_aborted() - continue - - if item is _DONE: - break - if isinstance(item, BaseException): - if isinstance(item, Exception): - guard.reraise_if_aborted(item) - raise item - # KeyboardInterrupt/SystemExit from the worker - # path — re-raise as-is so the outer signal- - # handling story stays intact. - raise item - - chunk = item - response_model = getattr(chunk, "model", response_model) - usage_candidate = getattr(chunk, "usage", None) - if usage_candidate is not None: - usage_obj = usage_candidate - - choices = getattr(chunk, "choices", None) or [] - if choices: - choice = choices[0] - if getattr(choice, "finish_reason", None): - finish_reason = choice.finish_reason - - delta = getattr(choice, "delta", None) - if delta is not None: - content_piece = getattr(delta, "content", None) - if content_piece: - piece = str(content_piece) - content_parts.append(piece) - if on_text_chunk is not None: - on_text_chunk(piece) - - reasoning_piece = getattr(delta, "reasoning_content", None) - if reasoning_piece: - reasoning_parts.append(str(reasoning_piece)) - - tool_call_deltas = getattr(delta, "tool_calls", None) or [] - for tc in tool_call_deltas: - idx = getattr(tc, "index", 0) - entry = tool_calls_by_index.setdefault(idx, {"id": "", "name": "", "arguments": ""}) - - tc_id = getattr(tc, "id", None) - if tc_id: - entry["id"] = str(tc_id) - - function = getattr(tc, "function", None) - if function is not None: - fn_name = getattr(function, "name", None) - if fn_name: - entry["name"] += str(fn_name) - fn_args = getattr(function, "arguments", None) - if fn_args: - entry["arguments"] += str(fn_args) - - # Check abort AFTER processing this chunk so any - # already-delivered content is preserved (matches the - # in-loop-check semantics from the old implementation: - # the chunk-list test pins that the chunk we received - # before the abort gets processed; we just don't take - # the next one). - if guard.aborted: - guard.raise_if_post_aborted() - - # Stream completed naturally OR the in-loop check broke out. - # In the latter case the signal is already tripped; raise so - # the caller bails at the same place every other path does. + # Worker-thread iteration (see ``_stream_worker.py``): the OpenAI + # Python SDK uses sync ``httpx`` for streaming, and + # ``response.close()`` from another thread is best-effort — for + # LiteLLM-proxied connections (and some other httpx + # configurations) the SDK's blocking socket read doesn't return + # when the response is closed. The worker+bounded-queue pattern + # unblocks ESC in ~100 ms regardless (#278/#279). + from ._stream_worker import run_stream_on_worker + + def _produce(emit): + for c in stream: + if not emit(c): + return None # abort/consumer gone — stop reading + return None + + def _on_chunk(chunk: Any) -> None: + nonlocal response_model, finish_reason, usage_obj + response_model = getattr(chunk, "model", response_model) + usage_candidate = getattr(chunk, "usage", None) + if usage_candidate is not None: + usage_obj = usage_candidate + + choices = getattr(chunk, "choices", None) or [] + if choices: + choice = choices[0] + if getattr(choice, "finish_reason", None): + finish_reason = choice.finish_reason + + delta = getattr(choice, "delta", None) + if delta is not None: + content_piece = getattr(delta, "content", None) + if content_piece: + piece = str(content_piece) + content_parts.append(piece) + if on_text_chunk is not None: + on_text_chunk(piece) + + reasoning_piece = getattr(delta, "reasoning_content", None) + if reasoning_piece: + reasoning_parts.append(str(reasoning_piece)) + + tool_call_deltas = getattr(delta, "tool_calls", None) or [] + for tc in tool_call_deltas: + idx = getattr(tc, "index", 0) + entry = tool_calls_by_index.setdefault(idx, {"id": "", "name": "", "arguments": ""}) + + tc_id = getattr(tc, "id", None) + if tc_id: + entry["id"] = str(tc_id) + + function = getattr(tc, "function", None) + if function is not None: + fn_name = getattr(function, "name", None) + if fn_name: + entry["name"] += str(fn_name) + fn_args = getattr(function, "arguments", None) + if fn_args: + entry["arguments"] += str(fn_args) + + with guard.attach(stream): + run_stream_on_worker( + _produce, + _on_chunk, + guard, + thread_name=f"openai-stream-{id(stream)}", + ) + + # Stream completed naturally OR the abort check broke out. In + # the latter case the signal is already tripped; raise so the + # caller bails at the same place every other path does. guard.raise_if_post_aborted() tool_uses: list[dict[str, Any]] = [] diff --git a/tests/test_minimax_abort_signal.py b/tests/test_minimax_abort_signal.py index 94980670..eb4f836c 100644 --- a/tests/test_minimax_abort_signal.py +++ b/tests/test_minimax_abort_signal.py @@ -146,3 +146,51 @@ def test_listener_detached_after_normal_completion() -> None: ) assert controller.signal._listeners == [] + + +class _StuckMinimaxStream: + """Minimax (anthropic-SDK-shaped) stream whose iterator never honors + ``response.close()`` — the buffering-proxy scenario (#279).""" + + def __init__(self) -> None: + self.response = MagicMock() + self._never_set = threading.Event() + self._iter_entered = threading.Event() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + @property + def text_stream(self): + self._iter_entered.set() + self._never_set.wait() # blocks forever, even after close() + return + yield # pragma: no cover + + def get_final_message(self): # pragma: no cover — never reached + raise AssertionError("unreachable") + + +def test_abort_unwinds_promptly_even_when_iterator_never_returns() -> None: + """#279: the worker+queue decoupling ported from the OpenAI path.""" + controller = AbortController() + stream = _StuckMinimaxStream() + provider = _provider_with_stream(stream) + + def _trip_after_worker_starts() -> None: + assert stream._iter_entered.wait(timeout=2.0), "worker never entered iterator" + controller.abort("user_interrupt") + + threading.Thread(target=_trip_after_worker_starts, daemon=True).start() + + start = time.monotonic() + with pytest.raises(AbortError): + provider.chat_stream_response( + messages=[{"role": "user", "content": "hi"}], + abort_signal=controller.signal, + ) + elapsed = time.monotonic() - start + assert elapsed < 3.0, f"abort took {elapsed:.2f}s against a stuck iterator" diff --git a/tests/test_provider_abort_signal.py b/tests/test_provider_abort_signal.py index 1cca5983..b9d213af 100644 --- a/tests/test_provider_abort_signal.py +++ b/tests/test_provider_abort_signal.py @@ -234,3 +234,56 @@ def test_listener_detached_after_normal_completion() -> None: # No listeners should remain attached after the call completes. assert controller.signal._listeners == [] + + +class _StuckAnthropicStream: + """Anthropic-SDK-shaped stream whose iterator never honors + ``response.close()`` — the buffering-proxy scenario (#279). The + worker-thread iteration must not rely on the iterator unblocking.""" + + def __init__(self) -> None: + self.response = MagicMock() + self._never_set = threading.Event() + self._iter_entered = threading.Event() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + @property + def text_stream(self): + self._iter_entered.set() + self._never_set.wait() # blocks forever, even after close() + return + yield # pragma: no cover + + def get_final_message(self): # pragma: no cover — never reached + raise AssertionError("unreachable") + + +def test_abort_unwinds_promptly_even_when_iterator_never_returns() -> None: + """#279: the worker+queue decoupling ported from the OpenAI path. + + Without it, ESC against a proxy that ignores ``response.close()`` + left the caller blocked inside ``stream.text_stream`` until the + connection died on its own.""" + controller = AbortController() + stream = _StuckAnthropicStream() + provider = _provider_with_stream(stream) + + def _trip_after_worker_starts() -> None: + assert stream._iter_entered.wait(timeout=2.0), "worker never entered iterator" + controller.abort("user_interrupt") + + threading.Thread(target=_trip_after_worker_starts, daemon=True).start() + + start = time.monotonic() + with pytest.raises(AbortError): + provider.chat_stream_response( + messages=[{"role": "user", "content": "hi"}], + abort_signal=controller.signal, + ) + elapsed = time.monotonic() - start + assert elapsed < 3.0, f"abort took {elapsed:.2f}s against a stuck iterator"