async def main() -> None:
    """Round-trip a payload through a sandbox using chunked file streams.

    A ``DATA_SIZE`` payload is written to a local temporary file, streamed
    into the sandbox filesystem in ``CHUNK_SIZE`` pieces, streamed back into a
    second local file, and finally compared with the original bytes.

    Raises:
        RuntimeError: If the downloaded bytes differ from the uploaded payload.
    """
    name = f"vercel-py-streaming-{uuid4().hex[:12]}"
    # Build the payload once; it seeds the upload source and is reused for the
    # final comparison instead of allocating a second 1 MiB buffer.
    payload = b"\x01" * DATA_SIZE
    with TemporaryDirectory() as directory:
        source_path = anyio.Path(directory) / "source.bin"
        target_path = anyio.Path(directory) / "target.bin"
        await source_path.write_bytes(payload)

        async with sandbox.create_sandbox(
            name=name,
            runtime="python3.13",
            execution_time_limit=timedelta(minutes=2),
        ) as box:
            # Upload: local file -> sandbox file, one chunk at a time.
            async with (
                await anyio.open_file(source_path, "rb") as source,
                box.fs.open("workspace/reference.bin", "wb", permissions=0o600) as target,
            ):
                while chunk := await source.read(CHUNK_SIZE):
                    await target.write(chunk)

            # Download: sandbox file -> local file, counting bytes as they land.
            copied = 0
            async with (
                box.fs.open("workspace/reference.bin", "rb") as source,
                await anyio.open_file(target_path, "wb") as target,
            ):
                while chunk := await source.read(CHUNK_SIZE):
                    await target.write(chunk)
                    copied += len(chunk)
            print(f"Downloaded {copied} bytes")

        # An explicit check (rather than ``assert``) still runs under
        # ``python -O``, so the example never reports success unverified.
        if await target_path.read_bytes() != payload:
            raise RuntimeError("downloaded file does not match the uploaded payload")

    print("Streaming transfer complete")
"websockets>=12.0", "cbor2>=5.8.0,<6", diff --git a/src/vercel/_internal/byte_stream.py b/src/vercel/_internal/byte_stream.py new file mode 100644 index 0000000..5f17b6d --- /dev/null +++ b/src/vercel/_internal/byte_stream.py @@ -0,0 +1,253 @@ +"""Adapt byte sources for business logic shared by sync and async APIs. + +Shared code consumes the async-shaped ``ReadableByteStream`` and +``StagingByteFile`` protocols. Callers select the runtime matching their public +API, then use its factories instead of constructing the private adapters directly. +The sync runtime never suspends, while the async runtime awaits or offloads I/O as +appropriate. +""" + +import inspect +import io +import tempfile +from collections.abc import AsyncIterator +from contextlib import AbstractAsyncContextManager, asynccontextmanager +from typing import Protocol, TypeAlias, cast + +import anyio +from typing_extensions import Buffer + + +class SyncByteReader(Protocol): + """Caller-provided byte source with a blocking ``read`` method.""" + + def read(self, size: int = -1, /) -> bytes: ... + + +class AsyncByteReader(Protocol): + """Caller-provided byte source with an asynchronous ``read`` method. + + This is structurally identical to ``ReadableByteStream`` but describes an + input that has not yet been normalized by a byte-stream runtime. + """ + + async def read(self, size: int = -1, /) -> bytes: ... + + +BytesLike: TypeAlias = bytes | bytearray | memoryview +SyncByteSource: TypeAlias = BytesLike | SyncByteReader +AsyncByteSource: TypeAlias = AsyncByteReader +RawByteSource: TypeAlias = SyncByteSource | AsyncByteSource + + +class ReadableByteStream(Protocol): + """Normalized readable stream consumed by shared internal workflows. + + Its async shape hides whether a runtime performs the read inline, awaits an + async source, or moves a blocking read to a worker thread. + """ + + async def read(self, size: int = -1, /) -> bytes: ... 
+ + +class StagingByteFile(ReadableByteStream, Protocol): + """SDK-owned temporary byte file used by shared staging workflows. + + The runtime's temporary-file context manager owns the lifetime of streams + implementing this protocol. + """ + + async def write(self, data: bytes, /) -> int: ... + + async def readinto(self, buffer: Buffer, /) -> int: + """Read bytes into a writable buffer. + + ``Buffer`` cannot express mutability, so read-only buffers are rejected + at runtime. + """ + ... + + async def flush(self) -> None: ... + + async def tell(self) -> int: ... + + async def seek(self, offset: int, whence: int = 0, /) -> int: ... + + async def truncate(self, size: int | None = None, /) -> int: ... + + +class StagingFileRuntime(Protocol): + """Runtime-specific temporary-file capability for shared business logic.""" + + def temporary_file(self) -> AbstractAsyncContextManager[StagingByteFile]: ... + + +def _bytes_result(value: object) -> bytes: + if isinstance(value, bytes): + return value + raise TypeError(f"byte stream returned {type(value).__name__}, expected bytes") + + +class _SyncReader: + """Expose a blocking reader through an async-shaped, non-suspending method.""" + + def __init__(self, source: SyncByteReader) -> None: + self._source = source + + async def read(self, size: int = -1, /) -> bytes: + return _bytes_result(self._source.read(size)) + + +class _MemoryReader: + """Give an immutable bytes snapshot a stateful, non-suspending read cursor.""" + + def __init__(self, data: BytesLike) -> None: + self._data = memoryview(bytes(data)) + self._offset = 0 + + async def read(self, size: int = -1, /) -> bytes: + remaining = self._data[self._offset :] + if size < 0: + self._offset = len(self._data) + return bytes(remaining) + chunk = bytes(remaining[:size]) + self._offset += len(chunk) + return chunk + + +class _AsyncReader: + """Normalize a genuinely asynchronous reader and validate its results.""" + + def __init__(self, source: AsyncByteReader) -> None: + 
self._source = source + + async def read(self, size: int = -1, /) -> bytes: + return _bytes_result(await self._source.read(size)) + + +class _ThreadedSyncReader: + """Run a blocking reader on a worker thread for use by async workflows.""" + + def __init__(self, source: SyncByteReader) -> None: + self._source = source + + async def read(self, size: int = -1, /) -> bytes: + return _bytes_result(await anyio.to_thread.run_sync(self._source.read, size)) + + +class _SyncTemporaryFile: + """Expose a blocking temporary file through non-suspending async methods.""" + + def __init__(self) -> None: + self._file = cast(io.BufferedRandom, tempfile.TemporaryFile("w+b")) + + def _ensure_open(self) -> None: + if self._file.closed: + raise anyio.ClosedResourceError + + async def read(self, size: int = -1, /) -> bytes: + self._ensure_open() + return self._file.read(size) + + async def write(self, data: bytes, /) -> int: + self._ensure_open() + return self._file.write(data) + + async def readinto(self, buffer: Buffer, /) -> int: + """Read bytes into a writable buffer. + + ``Buffer`` cannot express mutability, so read-only buffers are rejected + at runtime. + """ + self._ensure_open() + return self._file.readinto(buffer) + + async def flush(self) -> None: + self._ensure_open() + self._file.flush() + + async def tell(self) -> int: + self._ensure_open() + return self._file.tell() + + async def seek(self, offset: int, whence: int = 0, /) -> int: + self._ensure_open() + return self._file.seek(offset, whence) + + async def truncate(self, size: int | None = None, /) -> int: + self._ensure_open() + return self._file.truncate(size) + + def close(self) -> None: + self._file.close() + + +@asynccontextmanager +async def _sync_temporary_file() -> AsyncIterator[StagingByteFile]: + file = _SyncTemporaryFile() + try: + yield file + finally: + file.close() + + +class SyncByteStreamRuntime: + """Adapt blocking byte primitives for shared async-shaped workflows. 
class AsyncByteStreamRuntime:
    """Byte-stream runtime for code that executes under AnyIO.

    Sources whose ``read`` is a coroutine function are awaited directly; any
    other reader is treated as blocking and is read on a worker thread.
    """

    @staticmethod
    def reader(source: RawByteSource) -> ReadableByteStream:
        """Wrap ``source`` in the adapter that matches its kind.

        Raises:
            TypeError: If ``source`` is neither bytes-like nor an object with
                a callable ``read`` attribute.
        """
        if isinstance(source, (bytes, bytearray, memoryview)):
            return _MemoryReader(source)

        read_method = getattr(source, "read", None)
        if not callable(read_method):
            raise TypeError("byte source must provide a callable read method")

        # Only ``async def`` readers are detected as asynchronous; everything
        # else is assumed to block and is therefore moved off the event loop.
        if not inspect.iscoroutinefunction(read_method):
            return _ThreadedSyncReader(cast(SyncByteReader, source))
        return _AsyncReader(cast(AsyncByteReader, source))

    def temporary_file(self) -> AbstractAsyncContextManager[StagingByteFile]:
        """Return an async context manager yielding a binary staging file."""
        # NOTE(review): ``anyio.TemporaryFile`` looks like a fairly recent AnyIO
        # addition — confirm the project's minimum AnyIO version provides it.
        staging_file = anyio.TemporaryFile("w+b")
        return cast(AbstractAsyncContextManager[StagingByteFile], staging_file)
+++ b/src/vercel/_internal/http/__init__.py @@ -17,6 +17,8 @@ RawBody, ReadResponsePolicy, RequestBody, + StreamingRequest, + StreamingResponse, SyncTransport, TransportOptions, extract_structured_error, @@ -34,6 +36,8 @@ "RawBody", "ReadResponsePolicy", "RequestBody", + "StreamingRequest", + "StreamingResponse", "RetryPolicy", "SleepFn", "create_base_client", diff --git a/src/vercel/_internal/http/transport.py b/src/vercel/_internal/http/transport.py index 33de7f6..d303b2c 100644 --- a/src/vercel/_internal/http/transport.py +++ b/src/vercel/_internal/http/transport.py @@ -4,11 +4,17 @@ import abc import json +import queue +import threading +from collections.abc import AsyncIterator, Iterator +from contextlib import AbstractAsyncContextManager, asynccontextmanager from dataclasses import dataclass from datetime import timedelta from types import TracebackType from typing import Any +import anyio +import anyio.abc import httpx from httpx import USE_CLIENT_DEFAULT from httpx._types import HeaderTypes, QueryParamTypes @@ -98,6 +104,39 @@ async def send( ) -> httpx.Response: raise NotImplementedError() + def request_stream( + self, + method: str, + path: str, + *, + token: str | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + timeout: timedelta | None = None, + follow_redirects: bool | None = None, + read_response: ReadResponsePolicy = ReadResponsePolicy.NON_SUCCESS_ONLY, + response_chunk_size: int | None = None, + ) -> AbstractAsyncContextManager[StreamingRequest]: + """Open a lexical scope for an incrementally supplied request body.""" + raise NotImplementedError() + + async def open_response_stream( + self, + method: str, + path: str, + *, + token: str | None = None, + params: QueryParamTypes | None = None, + body: RequestBody = None, + headers: HeaderTypes | None = None, + timeout: timedelta | None = None, + follow_redirects: bool | None = None, + read_response: ReadResponsePolicy = ReadResponsePolicy.NON_SUCCESS_ONLY, 
class StreamingResponse(abc.ABC):
    """An owned streaming response with async-shaped iteration.

    Instances are their own async iterator: ``__aiter__`` returns ``self`` and
    subclasses supply ``__anext__``, ``aiter_lines`` and ``aclose``.
    """

    # The underlying httpx response; concrete subclasses assign it.
    response: httpx.Response

    def __aiter__(self) -> StreamingResponse:
        return self

    async def read(self) -> bytes:
        """Consume and close the remaining response body."""
        try:
            pieces = [piece async for piece in self]
        finally:
            # Close whether iteration finished or failed part-way through.
            await self.aclose()
        return b"".join(pieces)

    @abc.abstractmethod
    async def __anext__(self) -> bytes:
        """Return the next body chunk or raise ``StopAsyncIteration``."""
        raise NotImplementedError()

    @abc.abstractmethod
    def aiter_lines(self) -> AsyncIterator[str]:
        """Iterate over the body as text lines."""
        raise NotImplementedError()

    @abc.abstractmethod
    async def aclose(self) -> None:
        """Release the response."""
        raise NotImplementedError()
class _SyncStreamingRequest(StreamingRequest):
    """Streaming upload driven by a blocking httpx client on a worker thread.

    The constructor starts a daemon thread that sends the request. Callers
    feed body chunks through the shared ``chunks`` queue; the ``_STREAM_EOF``
    and ``_STREAM_ABORT`` sentinels end the body normally or abort it. The
    coroutine methods never suspend, so they can be driven without an event
    loop.
    """

    def __init__(
        self,
        *,
        client: httpx.Client,
        request: httpx.Request,
        chunks: queue.Queue[bytes | object],
        follow_redirects: bool | None,
        read_response: ReadResponsePolicy,
        chunk_size: int | None,
    ) -> None:
        self._client = client
        self._request = request
        self._chunks = chunks
        self._follow_redirects = follow_redirects
        self._read_response = read_response
        self._chunk_size = chunk_size
        self._response: httpx.Response | None = None
        self._error: BaseException | None = None
        self._closed = False
        self._aborted = False
        self._completed = False
        # Set by the worker when it exits, whatever the outcome.
        self._finished = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def _run(self) -> None:
        """Worker body: send the request and record a response or an error."""
        response: httpx.Response | None = None
        try:
            response = self._client.send(
                self._request,
                stream=True,
                follow_redirects=self._follow_redirects
                if self._follow_redirects is not None
                else USE_CLIENT_DEFAULT,
            )
            _read_sync_response(response, self._read_response)
            self._response = response
        except _RequestStreamAborted:
            # An abort nobody requested means the body stream broke.
            if not self._aborted:
                self._error = anyio.BrokenResourceError()
        except BaseException as exc:
            self._error = exc
            if response is not None:
                try:
                    response.close()
                except BaseException:
                    pass
        finally:
            self._finished.set()

    def _raise_worker_error(self) -> None:
        """Re-raise the worker's failure, if it has recorded one."""
        if self._error is not None:
            raise self._error
        if self._finished.is_set() and self._response is None and not self._aborted:
            raise anyio.BrokenResourceError

    def _put(self, item: bytes | object) -> None:
        """Enqueue ``item``, polling so a dead worker is noticed promptly."""
        while True:
            self._raise_worker_error()
            if self._finished.is_set():
                # Re-check: the worker may have stored its error between the
                # call above and the event being set.
                self._raise_worker_error()
                raise anyio.BrokenResourceError
            try:
                self._chunks.put(item, timeout=0.05)
                return
            except queue.Full:
                continue

    async def write(self, data: bytes) -> None:
        """Queue one body chunk; empty chunks only surface worker errors.

        Raises:
            anyio.ClosedResourceError: After ``finish`` or ``abort``.
            TypeError: If ``data`` is not bytes-like.
        """
        if self._closed:
            raise anyio.ClosedResourceError
        if not isinstance(data, (bytes, bytearray, memoryview)):
            raise TypeError(f"a bytes-like object is required, not {type(data).__name__}")
        chunk = bytes(data)
        if chunk:
            self._put(chunk)
        else:
            self._raise_worker_error()

    async def finish(self) -> StreamingResponse:
        """End the body, wait for the worker, and return the response.

        Raises:
            anyio.ClosedResourceError: If the request was already closed.
            anyio.BrokenResourceError: If the worker ended without a response.
        """
        if self._closed:
            raise anyio.ClosedResourceError
        self._closed = True
        try:
            self._put(_STREAM_EOF)
        except BaseException:
            # The EOF sentinel may never have been enqueued (for example when
            # the put is interrupted), leaving the worker blocked on the
            # queue. Abort signals it before joining, so the join cannot hang
            # waiting for a sentinel that will not arrive.
            await self.abort()
            raise
        self._thread.join()
        self._raise_worker_error()
        if self._response is None:
            raise anyio.BrokenResourceError
        self._completed = True
        return _SyncStreamingResponse(self._response, self._chunk_size)

    async def abort(self) -> None:
        """Stop the upload, wait for the worker, and drop any response."""
        if self._aborted:
            return
        self._closed = True
        self._aborted = True
        # Keep offering the abort sentinel until it is queued or the worker
        # has already exited on its own.
        while not self._finished.is_set():
            try:
                self._chunks.put(_STREAM_ABORT, timeout=0.05)
                break
            except queue.Full:
                continue
        self._thread.join()
        if not self._completed and self._response is not None:
            try:
                self._response.close()
            except BaseException:
                pass
self._closed: + self._closed = True + try: + self.response.close() + except BaseException: + pass + + +class _AsyncRequestBody: + def __init__(self, receive: anyio.abc.ObjectReceiveStream[bytes]) -> None: + self._receive = receive + + async def __aiter__(self) -> AsyncIterator[bytes]: + async with self._receive: + async for chunk in self._receive: + yield chunk + + +class _AsyncStreamingRequest(StreamingRequest): + def __init__( + self, + *, + client: httpx.AsyncClient, + request: httpx.Request, + send: anyio.abc.ObjectSendStream[bytes], + receive: anyio.abc.ObjectReceiveStream[bytes], + follow_redirects: bool | None, + read_response: ReadResponsePolicy, + chunk_size: int | None, + ) -> None: + self._client = client + self._request = request + self._send = send + self._receive = receive + self._follow_redirects = follow_redirects + self._read_response = read_response + self._chunk_size = chunk_size + self._response: httpx.Response | None = None + self._error: BaseException | None = None + self._closed = False + self._aborted = False + self._completed = False + self._cancel_scope: anyio.CancelScope | None = None + self._done = anyio.Event() + + async def _run(self) -> None: + response: httpx.Response | None = None + try: + with anyio.CancelScope() as cancel_scope: + self._cancel_scope = cancel_scope + if self._aborted: + cancel_scope.cancel() + else: + response = await self._client.send( + self._request, + stream=True, + follow_redirects=self._follow_redirects + if self._follow_redirects is not None + else USE_CLIENT_DEFAULT, + ) + await _read_async_response(response, self._read_response) + self._response = response + if self._aborted and response is not None and self._response is None: + with anyio.CancelScope(shield=True): + await response.aclose() + except BaseException as exc: + if not self._aborted: + self._error = exc + if response is not None: + try: + with anyio.CancelScope(shield=True): + await response.aclose() + except BaseException: + pass + finally: + 
try: + with anyio.CancelScope(shield=True): + await self._receive.aclose() + except BaseException as exc: + if not self._aborted and self._error is None: + self._error = exc + self._done.set() + + def _raise_worker_error(self) -> None: + if self._error is not None: + raise self._error + + async def write(self, data: bytes) -> None: + if self._closed: + raise anyio.ClosedResourceError + if not isinstance(data, (bytes, bytearray, memoryview)): + raise TypeError(f"a bytes-like object is required, not {type(data).__name__}") + chunk = bytes(data) + self._raise_worker_error() + if not chunk: + return + try: + await self._send.send(chunk) + except anyio.get_cancelled_exc_class(): + with anyio.CancelScope(shield=True): + await self.abort() + raise + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + self._raise_worker_error() + raise + self._raise_worker_error() + + async def finish(self) -> StreamingResponse: + if self._closed: + raise anyio.ClosedResourceError + self._closed = True + try: + await self._send.aclose() + await self._done.wait() + except BaseException: + with anyio.CancelScope(shield=True): + await self.abort() + raise + self._raise_worker_error() + if self._response is None: + raise anyio.BrokenResourceError + self._completed = True + return _AsyncStreamingResponse(self._response, self._chunk_size) + + async def abort(self) -> None: + if self._aborted: + return + self._closed = True + self._aborted = True + if self._cancel_scope is not None: + self._cancel_scope.cancel() + with anyio.CancelScope(shield=True): + await self._send.aclose() + await self._done.wait() + if not self._completed and self._response is not None: + try: + await self._response.aclose() + except BaseException: + pass + + +class _AsyncStreamingResponse(StreamingResponse): + def __init__(self, response: httpx.Response, chunk_size: int | None) -> None: + self.response = response + self._iterator = response.aiter_bytes(chunk_size) + self._closed = False + + async def 
__anext__(self) -> bytes: + if self._closed: + raise StopAsyncIteration + try: + return await anext(self._iterator) + except StopAsyncIteration: + await self.aclose() + raise + except BaseException: + with anyio.CancelScope(shield=True): + await self.aclose() + raise + + async def aiter_lines(self) -> AsyncIterator[str]: + try: + lines = self.response.aiter_lines() + while not self._closed: + try: + yield await anext(lines) + except StopAsyncIteration: + return + finally: + with anyio.CancelScope(shield=True): + await self.aclose() + + async def aclose(self) -> None: + if not self._closed: + self._closed = True + try: + with anyio.CancelScope(shield=True): + await self.response.aclose() + except BaseException: + pass + + class SyncTransport(BaseTransport): - """Sync transport with async interface for use with iter_coroutine().""" + """Sync transport with a non-suspending async-shaped interface.""" _client: httpx.Client @@ -177,12 +626,81 @@ async def send( if follow_redirects is not None else USE_CLIENT_DEFAULT, ) - if read_response is ReadResponsePolicy.ALWAYS or ( - read_response is ReadResponsePolicy.NON_SUCCESS_ONLY and not response.is_success - ): - response.read() + _read_sync_response(response, read_response) return response + @asynccontextmanager + async def request_stream( + self, + method: str, + path: str, + *, + token: str | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + timeout: timedelta | None = None, + follow_redirects: bool | None = None, + read_response: ReadResponsePolicy = ReadResponsePolicy.NON_SUCCESS_ONLY, + response_chunk_size: int | None = None, + ) -> AsyncIterator[StreamingRequest]: + chunks: queue.Queue[bytes | object] = queue.Queue(maxsize=1) + request = self._build_request( + method, + path, + token=token, + params=params, + body=RawBody(_SyncRequestBody(chunks)), + headers=headers, + timeout=timeout, + ) + streaming_request = _SyncStreamingRequest( + client=self._client, + 
request=request, + chunks=chunks, + follow_redirects=follow_redirects, + read_response=read_response, + chunk_size=response_chunk_size, + ) + try: + yield streaming_request + except BaseException: + try: + await streaming_request.abort() + except BaseException: + pass + raise + else: + if not streaming_request._completed: + await streaming_request.abort() + + async def open_response_stream( + self, + method: str, + path: str, + *, + token: str | None = None, + params: QueryParamTypes | None = None, + body: RequestBody = None, + headers: HeaderTypes | None = None, + timeout: timedelta | None = None, + follow_redirects: bool | None = None, + read_response: ReadResponsePolicy = ReadResponsePolicy.NON_SUCCESS_ONLY, + chunk_size: int | None = None, + ) -> StreamingResponse: + response = await self.send( + method, + path, + token=token, + params=params, + body=body, + headers=headers, + timeout=timeout, + follow_redirects=follow_redirects, + stream=True, + read_response=read_response, + ) + return _SyncStreamingResponse(response, chunk_size) + def close(self) -> None: self._client.close() @@ -228,12 +746,97 @@ async def send( if follow_redirects is not None else USE_CLIENT_DEFAULT, ) - if read_response is ReadResponsePolicy.ALWAYS or ( - read_response is ReadResponsePolicy.NON_SUCCESS_ONLY and not response.is_success - ): - await response.aread() + await _read_async_response(response, read_response) return response + @asynccontextmanager + async def request_stream( + self, + method: str, + path: str, + *, + token: str | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + timeout: timedelta | None = None, + follow_redirects: bool | None = None, + read_response: ReadResponsePolicy = ReadResponsePolicy.NON_SUCCESS_ONLY, + response_chunk_size: int | None = None, + ) -> AsyncIterator[StreamingRequest]: + send, receive = anyio.create_memory_object_stream[bytes](1) + request = self._build_request( + method, + path, + token=token, + 
params=params, + body=RawBody(_AsyncRequestBody(receive)), + headers=headers, + timeout=timeout, + ) + streaming_request = _AsyncStreamingRequest( + client=self._client, + request=request, + send=send, + receive=receive, + follow_redirects=follow_redirects, + read_response=read_response, + chunk_size=response_chunk_size, + ) + scope_error: BaseException | None = None + scope_traceback: TracebackType | None = None + try: + async with anyio.create_task_group() as tasks: + tasks.start_soon(streaming_request._run) + try: + await anyio.lowlevel.checkpoint() + yield streaming_request + except BaseException as exc: + scope_error = exc + scope_traceback = exc.__traceback__ + with anyio.CancelScope(shield=True): + try: + await streaming_request.abort() + except BaseException: + pass + else: + if not streaming_request._completed: + with anyio.CancelScope(shield=True): + await streaming_request.abort() + finally: + with anyio.CancelScope(shield=True): + await send.aclose() + await receive.aclose() + if scope_error is not None: + raise scope_error.with_traceback(scope_traceback) + + async def open_response_stream( + self, + method: str, + path: str, + *, + token: str | None = None, + params: QueryParamTypes | None = None, + body: RequestBody = None, + headers: HeaderTypes | None = None, + timeout: timedelta | None = None, + follow_redirects: bool | None = None, + read_response: ReadResponsePolicy = ReadResponsePolicy.NON_SUCCESS_ONLY, + chunk_size: int | None = None, + ) -> StreamingResponse: + response = await self.send( + method, + path, + token=token, + params=params, + body=body, + headers=headers, + timeout=timeout, + follow_redirects=follow_redirects, + stream=True, + read_response=read_response, + ) + return _AsyncStreamingResponse(response, chunk_size) + async def aclose(self) -> None: await self._client.aclose() @@ -293,5 +896,7 @@ def extract_structured_error(response: httpx.Response) -> tuple[str, object]: "RawBody", "ReadResponsePolicy", "RequestBody", + 
"StreamingRequest", + "StreamingResponse", "extract_structured_error", ] diff --git a/src/vercel/_internal/unstable/sandbox/api_client.py b/src/vercel/_internal/unstable/sandbox/api_client.py index c206f69..c4f18f7 100644 --- a/src/vercel/_internal/unstable/sandbox/api_client.py +++ b/src/vercel/_internal/unstable/sandbox/api_client.py @@ -1,17 +1,15 @@ """Internal Sandbox v2 API client.""" -import io import json import platform -import posixpath import sys -import tarfile from collections.abc import AsyncIterator, Mapping, Sequence +from contextlib import asynccontextmanager from datetime import timedelta from importlib.metadata import version as _pkg_version from typing import Literal, TypeVar, cast -from httpx import AsyncByteStream, Response +from httpx import Response from httpx._types import QueryParamTypes from pydantic import ( AliasChoices, @@ -25,10 +23,11 @@ from vercel._internal.http import ( BaseTransport, - BytesBody, JSONBody, ReadResponsePolicy, RequestBody, + StreamingRequest, + StreamingResponse, extract_structured_error, ) from vercel._internal.time import MILLISECOND, parse_duration, to_ms_int @@ -54,7 +53,6 @@ _Omitted, _parse_network_policy, _serialize_network_policy, - _WriteFile, ) from vercel._internal.unstable.sandbox.options import ( SandboxCredentials, @@ -89,6 +87,27 @@ ResponseModelT = TypeVar("ResponseModelT", bound=BaseModel) +class _WriteFilesUpload: + """Own one sandbox filesystem upload through server acceptance.""" + + def __init__(self, request: StreamingRequest) -> None: + self._request = request + + async def write(self, data: bytes) -> None: + await self._request.write(data) + + async def finish(self) -> None: + stream = await self._request.finish() + await stream.read() + response = stream.response + if not response.is_success: + message, data = extract_structured_error(response) + raise SandboxApiError(response, message, data=data) + + async def abort(self) -> None: + await self._request.abort() + + class 
_ApiModel(BaseModel): model_config = ConfigDict( extra="ignore", frozen=True, populate_by_name=True, serialize_by_alias=True @@ -697,49 +716,6 @@ def _drop_none(data: Mapping[str, JSONValue | None]) -> JSONObject: return {key: value for key, value in data.items() if value is not None} -def _normalize_mode(mode: object) -> int | None: - match mode: - case None: - return None - case bool(): - raise TypeError("mode must be an integer between 0 and 0o777") - case int() if 0 <= mode <= 0o777: - return mode - case int(): - raise ValueError("mode must be an integer between 0 and 0o777") - case _: - raise TypeError("mode must be an integer between 0 and 0o777") - - -def _normalize_tar_path(path: str, *, cwd: str) -> str: - if not posixpath.isabs(cwd): - raise ValueError("cwd must be an absolute path") - if posixpath.isabs(path): - absolute_path = posixpath.normpath(path) - else: - absolute_path = posixpath.normpath(posixpath.join(cwd, path)) - return posixpath.relpath(absolute_path, "/") - - -def _build_write_files_tarball( - files: Sequence[_WriteFile], - *, - cwd: str, -) -> bytes: - buffer = io.BytesIO() - with tarfile.open(fileobj=buffer, mode="w:gz") as tar: - for file in files: - info = tarfile.TarInfo(name=_normalize_tar_path(file.path, cwd=cwd)) - mode = _normalize_mode(file.mode) - if mode is not None: - info.mode = mode - info.size = len(file.content) - tar.addfile(info, io.BytesIO(file.content)) - # BytesBody currently requires bytes, so finalizing the in-memory archive - # makes one additional copy. Streaming uploads are intentionally deferred. 
- return buffer.getvalue() - - def _validate_response(model: type[ResponseModelT], data: JSONObject) -> ResponseModelT: try: return model.model_validate(data) @@ -766,22 +742,6 @@ def _parse_run_process_record(line: str) -> JSONObject: return cast(JSONObject, record) -async def _response_lines(response: Response) -> AsyncIterator[str]: - if isinstance(response.stream, AsyncByteStream): - async for line in response.aiter_lines(): - yield line - else: - for line in response.iter_lines(): - yield line - - -async def _close_stream_response(response: Response) -> None: - if isinstance(response.stream, AsyncByteStream): - await response.aclose() - else: - response.close() - - class SandboxApiClient: def __init__( self, @@ -789,10 +749,12 @@ def __init__( base_url: str, credentials_factory: SandboxCredentialsFactory, transport: BaseTransport, + file_transfer_timeout: timedelta, ) -> None: self._credentials_factory = credentials_factory self._base_url = base_url self._transport = transport + self._file_transfer_timeout = file_transfer_timeout def _url(self, path: str) -> str: return self._base_url.rstrip("/") + "/" + path.lstrip("/") @@ -806,6 +768,7 @@ async def _request( body: RequestBody = None, params: Mapping[str, JSONValue | None] | None = None, headers: Mapping[str, str] | None = None, + timeout: timedelta | None = None, ) -> Response: query = cast( QueryParamTypes, @@ -827,6 +790,7 @@ async def _request( params=query, body=body, headers=request_headers, + timeout=timeout, read_response=ReadResponsePolicy.ALWAYS, ) @@ -845,7 +809,8 @@ async def _request_stream( body: RequestBody = None, params: Mapping[str, JSONValue | None] | None = None, headers: Mapping[str, str] | None = None, - ) -> Response: + timeout: timedelta | None = None, + ) -> StreamingResponse: query = cast( QueryParamTypes, _drop_none( @@ -859,22 +824,25 @@ async def _request_stream( "user-agent": USER_AGENT, **dict(headers or {}), } - response = await self._transport.send( + response = await 
self._transport.open_response_stream( method, self._url(path), token=credentials.token, params=query, body=body, headers=request_headers, - stream=True, + timeout=timeout, read_response=ReadResponsePolicy.NON_SUCCESS_ONLY, ) - if response.is_success: + if response.response.is_success: return response - message, data = extract_structured_error(response) - raise SandboxApiError(response, message, data=data) + try: + message, data = extract_structured_error(response.response) + raise SandboxApiError(response.response, message, data=data) + finally: + await response.aclose() async def _request_json( self, @@ -1310,7 +1278,7 @@ async def run_process( initial: ProcessState | None = None final: ProcessState | None = None try: - async for line in _response_lines(response): + async for line in response.aiter_lines(): if not line: continue record = _parse_run_process_record(line) @@ -1347,7 +1315,7 @@ async def run_process( data=record, ) finally: - await _close_stream_response(response) + await response.aclose() if initial is None: raise SandboxResponseError("Sandbox process response is missing initial metadata") @@ -1413,39 +1381,52 @@ async def mkdir( body=JSONBody(request.to_api_dict()), ) - async def read_bytes( + async def open_read_response( self, *, session_id: str, path: str, cwd: str | None = None, - ) -> bytes: + ) -> StreamingResponse: credentials = await self._credentials_factory() request = _FilesystemPathRequest(path=path, cwd=cwd) - response = await self._request( + return await self._request_stream( "POST", format_url_path("v2/sandboxes/sessions/{session_id}/fs/read", session_id=session_id), credentials=credentials, body=JSONBody(request.to_api_dict()), + timeout=self._file_transfer_timeout, ) - return response.content - async def write_files( + @asynccontextmanager + async def write_files_request( self, *, session_id: str, - files: Sequence[_WriteFile], - cwd: str, - ) -> None: + ) -> AsyncIterator[_WriteFilesUpload]: credentials = await 
self._credentials_factory() - payload = _build_write_files_tarball(files, cwd=cwd) - await self._request( - "POST", - format_url_path("v2/sandboxes/sessions/{session_id}/fs/write", session_id=session_id), - credentials=credentials, - body=BytesBody(payload, "application/gzip"), - headers={"x-cwd": "/"}, + query = cast( + QueryParamTypes, + _drop_none({"teamId": credentials.team_id}), ) + async with self._transport.request_stream( + "POST", + self._url( + format_url_path( + "v2/sandboxes/sessions/{session_id}/fs/write", session_id=session_id + ) + ), + token=credentials.token, + params=query, + headers={ + "user-agent": USER_AGENT, + "x-cwd": "/", + "content-type": "application/gzip", + }, + timeout=self._file_transfer_timeout, + read_response=ReadResponsePolicy.NON_SUCCESS_ONLY, + ) as request: + yield _WriteFilesUpload(request) async def kill_command( self, @@ -1472,7 +1453,7 @@ async def command_logs_response( *, session_id: str, command_id: str, - ) -> Response: + ) -> StreamingResponse: credentials = await self._credentials_factory() return await self._request_stream( "GET", diff --git a/src/vercel/_internal/unstable/sandbox/async_filesystem_handle.py b/src/vercel/_internal/unstable/sandbox/async_filesystem_handle.py new file mode 100644 index 0000000..6709a34 --- /dev/null +++ b/src/vercel/_internal/unstable/sandbox/async_filesystem_handle.py @@ -0,0 +1,243 @@ +"""Asynchronous facades for shared Sandbox file-handle state machines.""" + +from collections.abc import AsyncIterator, Iterable +from contextlib import AbstractAsyncContextManager +from types import TracebackType + +import anyio + +from vercel._internal.unstable.sandbox.filesystem_handle_core import ( + BinaryReaderCore, + BinaryWriterCore, + TextReaderCore, + TextWriterCore, +) + + +class _AsyncHandle: + _core: BinaryReaderCore | TextReaderCore | BinaryWriterCore | TextWriterCore + + @property + def name(self) -> str: + return self._core.name + + @property + def mode(self) -> str: + return 
class SandboxBinaryReader(_AsyncHandle, AsyncIterator[bytes]):
    """Asynchronous binary read handle backed by a ``BinaryReaderCore``.

    All read operations are delegated to the core. An ``anyio.ResourceGuard``
    rejects overlapping reads from concurrent tasks. Iterating the handle
    yields lines as returned by ``readline`` until an empty result is seen.
    """

    _core: BinaryReaderCore

    def __init__(self, core: BinaryReaderCore) -> None:
        self._core = core
        self._guard = anyio.ResourceGuard("reading from")

    async def __aenter__(self) -> "SandboxBinaryReader":
        await self._core.enter()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        """Close the handle, never masking an exception raised by the body.

        On a clean exit, errors from closing propagate. When the body raised
        (including cancellation), the close runs inside a shielded cancel
        scope so that it is not itself cancelled at its first checkpoint, and
        any error it raises is dropped in favour of the body's exception.
        This mirrors the error path of the writer handles in this module.
        """
        if exc_type is None:
            await self.aclose()
            return
        try:
            with anyio.CancelScope(shield=True):
                await self.aclose()
        except BaseException:
            # Deliberate best effort: the body's exception is the one to surface.
            pass

    async def read(self, size: int = -1) -> bytes:
        """Read up to ``size`` bytes; the default ``-1`` reads to end of stream."""
        with self._guard:
            return await self._core.read(size)

    async def readline(self, size: int = -1) -> bytes:
        """Read one line, optionally capped at ``size`` bytes."""
        with self._guard:
            return await self._core.readline(size)

    async def readinto(self, buffer: object) -> int:
        """Fill a writable bytes-like ``buffer`` and return the byte count."""
        with self._guard:
            return await self._core.readinto(buffer)

    def __aiter__(self) -> "SandboxBinaryReader":
        return self

    async def __anext__(self) -> bytes:
        line = await self.readline()
        if not line:
            raise StopAsyncIteration
        return line

    async def aclose(self) -> None:
        """Close the underlying core handle."""
        await self._core.close()
class SandboxBinaryWriter(_AsyncHandle):
    """Asynchronous binary write handle backed by a ``BinaryWriterCore``.

    Entering the handle enters the core's ``lifecycle()`` context manager and
    keeps it so that ``__aexit__`` or ``aclose`` can exit it exactly once.
    An ``anyio.ResourceGuard`` rejects overlapping writes and flushes.
    """

    _core: BinaryWriterCore

    def __init__(self, core: BinaryWriterCore) -> None:
        self._core = core
        self._guard = anyio.ResourceGuard("writing to")
        # Active core lifecycle; ``None`` before entry and after exit.
        self._lifecycle: AbstractAsyncContextManager[None] | None = None

    async def __aenter__(self) -> "SandboxBinaryWriter":
        self._lifecycle = manager = self._core.lifecycle()
        try:
            await manager.__aenter__()
        except BaseException:
            # Entry failed, so there is nothing for a later exit to unwind.
            self._lifecycle = None
            raise
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        manager = self._lifecycle
        self._lifecycle = None
        if manager is None:
            return
        if exc_type is None:
            await manager.__aexit__(None, None, None)
            return
        # The body failed: unwind under a shield and let the body's exception
        # win over anything raised while unwinding.
        try:
            with anyio.CancelScope(shield=True):
                await manager.__aexit__(exc_type, exc, traceback)
        except BaseException:
            pass

    async def write(self, data: bytes, /) -> int:
        """Write ``data`` and return the core's reported byte count."""
        with self._guard:
            return await self._core.write(data)

    async def writelines(self, lines: Iterable[bytes], /) -> None:
        """Write each item of ``lines`` in order; no separators are added."""
        for item in lines:
            await self.write(item)

    async def flush(self) -> None:
        with self._guard:
            await self._core.flush()

    async def aclose(self) -> None:
        """Exit the active lifecycle if there is one, else close the core."""
        manager, self._lifecycle = self._lifecycle, None
        if manager is not None:
            await manager.__aexit__(None, None, None)
        else:
            await self._core.close()
= core + self._guard = anyio.ResourceGuard("writing to") + self._lifecycle: AbstractAsyncContextManager[None] | None = None + + async def __aenter__(self) -> "SandboxTextWriter": + lifecycle = self._core.lifecycle() + self._lifecycle = lifecycle + try: + await lifecycle.__aenter__() + except BaseException: + self._lifecycle = None + raise + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + traceback: TracebackType | None, + ) -> None: + lifecycle, self._lifecycle = self._lifecycle, None + if lifecycle is None: + return + if exc_type is not None: + try: + with anyio.CancelScope(shield=True): + await lifecycle.__aexit__(exc_type, exc, traceback) + except BaseException: + pass + else: + await lifecycle.__aexit__(None, None, None) + + async def write(self, text: str, /) -> int: + with self._guard: + return await self._core.write(text) + + async def writelines(self, lines: Iterable[str], /) -> None: + for line in lines: + await self.write(line) + + async def flush(self) -> None: + with self._guard: + await self._core.flush() + + async def aclose(self) -> None: + lifecycle, self._lifecycle = self._lifecycle, None + if lifecycle is None: + await self._core.close() + else: + await lifecycle.__aexit__(None, None, None) diff --git a/src/vercel/_internal/unstable/sandbox/async_runtime.py b/src/vercel/_internal/unstable/sandbox/async_runtime.py index 6235df4..fb202bd 100644 --- a/src/vercel/_internal/unstable/sandbox/async_runtime.py +++ b/src/vercel/_internal/unstable/sandbox/async_runtime.py @@ -7,16 +7,30 @@ from dataclasses import dataclass from datetime import timedelta from types import TracebackType -from typing import Any, TextIO +from typing import Any, Literal, TextIO, overload +from vercel._internal.byte_stream import AsyncByteStreamRuntime from vercel._internal.polyfills import Self from vercel._internal.time import parse_duration_seconds, parse_required_duration_seconds +from 
vercel._internal.unstable.sandbox.async_filesystem_handle import ( + SandboxBinaryReader, + SandboxBinaryWriter, + SandboxTextReader, + SandboxTextWriter, +) from vercel._internal.unstable.sandbox.errors import ( SandboxCleanupError, SandboxResponseError, SandboxTerminalStateError, ) -from vercel._internal.unstable.sandbox.log_stream import _parse_command_log_record +from vercel._internal.unstable.sandbox.filesystem_handle_common import _validate_open_options +from vercel._internal.unstable.sandbox.filesystem_handle_core import ( + BinaryReaderCore, + BinaryWriterCore, + FilesystemHandleBinding, + TextReaderCore, + TextWriterCore, +) from vercel._internal.unstable.sandbox.models import ( _OMITTED, CompletedProcess, @@ -52,9 +66,12 @@ SandboxHandleBase, SnapshotHandleBase, _coerce_remote_path, + _normalize_tar_path, _ProcessHandleState, _SandboxFilesystemBatchBase, _signal_number, + _UploadFileEntry, + _validate_file_mode, ) from vercel._internal.unstable.sandbox.service import SandboxService, _SandboxTerminalState from vercel._internal.unstable.sandbox.state import ( @@ -202,11 +219,108 @@ def __init__( self._session_id = session_id self._write_files_cwd = write_files_cwd + @overload + def open( + self, + path: RemotePath, + mode: Literal["r"] = "r", + *, + cwd: RemotePath | None = None, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, + size: None = None, + permissions: None = None, + ) -> SandboxTextReader: ... + + @overload + def open( + self, + path: RemotePath, + mode: Literal["rb"], + *, + cwd: RemotePath | None = None, + encoding: None = None, + errors: None = None, + newline: None = None, + size: None = None, + permissions: None = None, + ) -> SandboxBinaryReader: ... 
+ + @overload + def open( + self, + path: RemotePath, + mode: Literal["w"], + *, + cwd: RemotePath | None = None, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, + size: None = None, + permissions: int | None = None, + ) -> SandboxTextWriter: ... + + @overload + def open( + self, + path: RemotePath, + mode: Literal["wb"], + *, + cwd: RemotePath | None = None, + encoding: None = None, + errors: None = None, + newline: None = None, + size: int | None = None, + permissions: int | None = None, + ) -> SandboxBinaryWriter: ... + + def open( + self, + path: RemotePath, + mode: str = "r", + *, + cwd: RemotePath | None = None, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, + size: int | None = None, + permissions: int | None = None, + ) -> SandboxBinaryReader | SandboxTextReader | SandboxBinaryWriter | SandboxTextWriter: + """Create a lazy, single-use sequential file handle.""" + path, mode, encoding, errors, newline, size, permissions = _validate_open_options( + path, + mode, + encoding=encoding, + errors=errors, + newline=newline, + size=size, + permissions=permissions, + ) + normalized_cwd = None if cwd is None else _coerce_remote_path(cwd) + binding = FilesystemHandleBinding( + service=self._service, + runtime=self._service.staging_file_runtime, + session_id=self._session_id, + write_files_cwd=self._write_files_cwd, + path=path, + cwd=normalized_cwd, + ) + if mode == "rb": + return SandboxBinaryReader(BinaryReaderCore(binding)) + if mode == "r": + return SandboxTextReader(TextReaderCore(binding, encoding, errors, newline)) + if mode == "wb": + return SandboxBinaryWriter( + BinaryWriterCore(binding.write_target_source(size=size, permissions=permissions)) + ) + return SandboxTextWriter(TextWriterCore(binding, encoding, errors, newline, permissions)) + async def _collect_output(self, command: ProcessState) -> tuple[str, str]: stdout: list[str] = [] stderr: list[str] = [] - async for event 
in _process_logs( - self._service, session_id=command.session_id, process_id=command.id + async for event in self._service.process_logs( + session_id=command.session_id, process_id=command.id ): if event.stream == "stdout": stdout.append(event.data) @@ -249,6 +363,7 @@ async def read_bytes(self, path: RemotePath, *, cwd: RemotePath | None = None) - SandboxPathNotFoundError: If the file does not exist. """ return await self._service.read_bytes( + operation="read_bytes", session_id=self._session_id(), path=_coerce_remote_path(path), cwd=None if cwd is None else _coerce_remote_path(cwd), @@ -338,10 +453,24 @@ async def write_text( async def _write_files( self, files: Sequence[_WriteFile], *, cwd: RemotePath | None = None ) -> None: - await self._service.write_files( + for file in files: + _validate_file_mode(file.mode) + resolved_cwd = self._write_files_cwd(cwd) + entries = [ + _UploadFileEntry( + path=f.path, + size=len(f.content), + source=AsyncByteStreamRuntime.reader(f.content), + mode=f.mode, + archive_path=_normalize_tar_path(f.path, cwd=resolved_cwd), + ) + for f in files + ] + await self._service.write_stream_archive( session_id=self._session_id(), - files=files, - cwd=self._write_files_cwd(cwd), + entries=entries, + paths=tuple(entry.path for entry in entries), + cwd=resolved_cwd, ) def batch(self, *, cwd: RemotePath | None = None) -> "SandboxFilesystemBatch": @@ -1313,15 +1442,4 @@ async def get_snapshot(service: SandboxService, *, snapshot_id: str) -> Snapshot def _process_logs( service: SandboxService, *, session_id: str, process_id: str ) -> AsyncIterator[ProcessLog]: - async def iterate() -> AsyncIterator[ProcessLog]: - response = await service.process_logs_response(session_id=session_id, process_id=process_id) - try: - async for line in response.aiter_lines(): - if line: - event = _parse_command_log_record(line) - if event is not None: - yield event - finally: - await response.aclose() - - return iterate() + return 
class SandboxUploadSizeMismatchError(SandboxFilesystemTransferError):
    """Declared upload size does not match the bytes the source produced.

    Attributes are set from the constructor arguments:
    ``path``, ``declared``, ``consumed`` and ``early_end``.
    """

    def __init__(
        self,
        path: str,
        *,
        declared: int,
        consumed: int,
        early_end: bool,
    ) -> None:
        # The trailing note says which way the mismatch went.
        if early_end:
            reason = " (source ended early)"
        else:
            reason = " (source produced trailing data)"
        message = (
            f"Sandbox upload size mismatch for {path!r}: "
            f"declared {declared} bytes, consumed {consumed} bytes"
            f"{reason}"
        )
        super().__init__(message)
        self.path = path
        self.declared = declared
        self.consumed = consumed
        self.early_end = early_end
def _validate_open_options(
    path: RemotePath,
    mode: str,
    *,
    encoding: str | None,
    errors: str | None,
    newline: str | None,
    size: int | None,
    permissions: int | None,
) -> tuple[str, OpenMode, str, str, str | None, int | None, int | None]:
    """Validate and normalize the arguments of a filesystem ``open`` call.

    Checks run in a fixed order: path coercion, mode, binary-mode text
    options, newline, size, permissions, then (text modes only) codec and
    error-handler lookup.

    Returns:
        ``(path, mode, encoding, errors, newline, size, permissions)`` with
        the path coerced, ``encoding`` defaulting to ``"utf-8"`` and
        ``errors`` defaulting to ``"strict"``.

    Raises:
        ValueError: For an unsupported mode, text options in binary mode, an
            illegal newline, ``size`` outside ``"wb"`` mode, or
            ``permissions`` in a read mode.
        LookupError: From ``codecs`` when a text-mode encoding or error
            handler is unknown.
    """
    remote_path = _coerce_remote_path(path)

    if mode not in ("r", "rb", "w", "wb"):
        raise ValueError("mode must be 'r', 'rb', 'w', or 'wb'")
    is_binary = "b" in mode
    is_read = mode.startswith("r")

    text_options = (encoding, errors, newline)
    if is_binary and any(option is not None for option in text_options):
        raise ValueError("encoding, errors, and newline are not supported in binary mode")
    if newline not in (None, "", "\n", "\r", "\r\n"):
        raise ValueError("illegal newline value")

    if size is not None:
        if mode != "wb":
            raise ValueError("size is only supported in 'wb' mode")
        size = _validate_transfer_size(size)

    if permissions is not None:
        if is_read:
            raise ValueError("permissions are not supported in read mode")
        permissions = _validate_file_mode(permissions)

    resolved_encoding = encoding if encoding is not None else "utf-8"
    resolved_errors = errors if errors is not None else "strict"
    if not is_binary:
        # Fail early on unknown codecs or handlers instead of at first I/O.
        codecs.lookup(resolved_encoding)
        codecs.lookup_error(resolved_errors)

    return (
        remote_path,
        mode,  # type: ignore[return-value]
        resolved_encoding,
        resolved_errors,
        newline,
        size,
        permissions,
    )
self._name, + declared=self._declared, + consumed=consumed, + early_end=False, + ) + + def record_write(self, size: int) -> None: + self._written += size + + def validate_close(self) -> None: + if self._written != self._declared: + raise SandboxUploadSizeMismatchError( + self._name, + declared=self._declared, + consumed=self._written, + early_end=True, + ) + + +class _HandleInfo: + __slots__ = ("_state", "mode", "name") + + def __init__(self, name: str, mode: OpenMode) -> None: + self.name = name + self.mode = mode + self._state = _HandleState.CREATED + + @property + def closed(self) -> bool: + return self._state is _HandleState.CLOSED + + def readable(self) -> bool: + return self.mode.startswith("r") + + def writable(self) -> bool: + return self.mode.startswith("w") + + def seekable(self) -> bool: + return False + + def _enter(self) -> None: + if self._state is not _HandleState.CREATED: + raise ValueError("I/O operation on closed or already-entered file") + self._state = _HandleState.ACTIVE + + def _ensure_active(self) -> None: + if self._state is not _HandleState.ACTIVE: + raise ValueError("I/O operation on closed or unopened file") + + def _mark_closed(self) -> None: + self._state = _HandleState.CLOSED + + +class _TextReadBuffer: + __slots__ = ("_buffer", "_decoder", "_eof", "_newline") + + def __init__(self, encoding: str, errors: str, newline: str | None) -> None: + decoder: Any = codecs.getincrementaldecoder(encoding)(errors) + if newline in (None, ""): + decoder = io.IncrementalNewlineDecoder(decoder, newline is None) + self._decoder = decoder + self._newline = newline + self._buffer = "" + self._eof = False + + def feed(self, data: bytes, *, final: bool = False) -> None: + self._buffer += self._decoder.decode(data, final) + self._eof = final + + def take(self, size: int) -> str: + if size < 0: + result, self._buffer = self._buffer, "" + else: + result, self._buffer = self._buffer[:size], self._buffer[size:] + return result + + def line_end(self, size: int = 
-1) -> int | None: + limit = len(self._buffer) if size < 0 else min(size, len(self._buffer)) + text = self._buffer[:limit] + if self._newline is None: + index = text.find("\n") + return None if index < 0 else index + 1 + if self._newline == "": + for index, char in enumerate(text): + if char == "\n": + return index + 1 + if char == "\r": + if index + 1 < len(text): + return index + (2 if text[index + 1] == "\n" else 1) + if self._eof or limit < len(self._buffer): + return index + 1 + return None + return None + index = text.find(self._newline) + return None if index < 0 else index + len(self._newline) + + +class _TextEncoder: + __slots__ = ("_encoder", "_newline") + + def __init__(self, encoding: str, errors: str, newline: str | None) -> None: + self._encoder = codecs.getincrementalencoder(encoding)(errors) + self._newline = os.linesep if newline is None else newline + + def encode(self, text: str, *, final: bool = False) -> bytes: + if not isinstance(text, str): + raise TypeError(f"write() argument must be str, not {type(text).__name__}") + if self._newline not in ("", "\n"): + text = text.replace("\n", self._newline) + return self._encoder.encode(text, final) diff --git a/src/vercel/_internal/unstable/sandbox/filesystem_handle_core.py b/src/vercel/_internal/unstable/sandbox/filesystem_handle_core.py new file mode 100644 index 0000000..6a87464 --- /dev/null +++ b/src/vercel/_internal/unstable/sandbox/filesystem_handle_core.py @@ -0,0 +1,341 @@ +"""Shared state machines for sync and async Sandbox file handles.""" + +from collections.abc import AsyncIterator, Callable +from contextlib import asynccontextmanager + +from vercel._internal.byte_stream import ( + StagingFileRuntime, +) +from vercel._internal.http import StreamingResponse +from vercel._internal.unstable.sandbox.filesystem_handle_common import ( + _HandleInfo, + _TextEncoder, + _TextReadBuffer, + _validate_read_size, +) +from vercel._internal.unstable.sandbox.filesystem_write import ( + _BoundWrite, + 
class BinaryReaderCore(_HandleInfo):
    """State machine for a sequential binary read of one remote file.

    ``enter`` opens a streaming response through the binding. Chunks are
    pulled from it on demand into an internal ``bytearray``, and the response
    is closed as soon as it is exhausted, fails, or the handle is closed.
    """

    def __init__(self, binding: FilesystemHandleBinding) -> None:
        super().__init__(binding.path, "rb")
        self._binding = binding
        self._response: StreamingResponse | None = None
        self._buffer = bytearray()
        self._eof = False

    async def enter(self) -> None:
        """Activate the handle and open the read response."""
        self._enter()
        try:
            self._response = await self._binding.open_response()
        except BaseException:
            # A handle whose open failed must not look usable afterwards.
            self._mark_closed()
            raise

    async def _pump(self) -> None:
        """Append the next response chunk to the buffer, or record EOF."""
        if self._eof:
            return
        assert self._response is not None
        try:
            chunk = await anext(self._response)
        except StopAsyncIteration:
            self._eof = True
            await self._close_response()
        except BaseException:
            # Any failure ends the stream for good; release the response.
            self._eof = True
            await self._close_response()
            raise
        else:
            self._buffer.extend(chunk)

    async def read(self, size: int = -1) -> bytes:
        """Return up to ``size`` bytes, or everything remaining for ``-1``."""
        self._ensure_active()
        _validate_read_size(size)
        while not self._eof and (size < 0 or len(self._buffer) < size):
            await self._pump()
        if size < 0:
            result = bytes(self._buffer)
            self._buffer.clear()
            return result
        result = bytes(self._buffer[:size])
        del self._buffer[:size]
        return result

    async def readline(self, size: int = -1) -> bytes:
        """Return one line including its ``b"\\n"``, capped at ``size`` bytes."""
        self._ensure_active()
        _validate_read_size(size)
        while True:
            limit = len(self._buffer) if size < 0 else min(size, len(self._buffer))
            newline = self._buffer.find(b"\n", 0, limit)
            if newline >= 0:
                return await self.read(newline + 1)
            if (size >= 0 and len(self._buffer) >= size) or self._eof:
                return await self.read(limit)
            await self._pump()

    async def readinto(self, buffer: object) -> int:
        """Copy up to ``len(buffer)`` bytes into ``buffer``; return the count.

        Raises:
            TypeError: If ``buffer`` is read-only.
        """
        self._ensure_active()
        view = memoryview(buffer)  # type: ignore[arg-type]
        if view.readonly:
            raise TypeError("readinto() argument must be read-write bytes-like object")
        data = await self.read(view.nbytes)
        view.cast("B")[: len(data)] = data
        return len(data)

    async def _close_response(self) -> None:
        # Detach first so the response is never closed twice.
        response, self._response = self._response, None
        if response is not None:
            await response.aclose()

    async def close(self) -> None:
        """Discard buffered data, release the response and mark closed.

        The handle is marked closed even when releasing the response raises.
        Otherwise a failed close would leave a handle that still reports
        ``closed == False`` although its response is already gone.
        """
        if self.closed:
            return
        self._buffer.clear()
        self._eof = True
        try:
            await self._close_response()
        finally:
            self._mark_closed()
class BinaryWriterCore(_HandleInfo):
    """State machine for a sequential binary write to one remote file.

    The write target is acquired from the source for the duration of
    ``lifecycle()``. A clean exit finishes the target; any failure aborts it.
    """

    def __init__(self, source: _WriteTargetSource) -> None:
        super().__init__(source.name, "wb")
        self._source = source
        self._target: _WriteTarget | None = None

    async def _abort_quietly(self) -> None:
        """Abort, discarding any error so the caller's exception survives."""
        try:
            await self.abort()
        except BaseException:
            pass

    @asynccontextmanager
    async def lifecycle(self) -> AsyncIterator[None]:
        """Hold the write target open for the duration of the context."""
        self._enter()
        try:
            async with self._source.acquire() as target:
                self._target = target
                try:
                    yield
                except BaseException:
                    # The body failed: discard the upload, keep the error.
                    await self._abort_quietly()
                    raise
                else:
                    await self.close()
                finally:
                    self._target = None
        except BaseException:
            # Covers failures while acquiring or releasing the target.
            if not self.closed:
                await self._abort_quietly()
            raise

    async def write(self, data: bytes) -> int:
        """Write a bytes-like object and return the number of bytes written.

        Raises:
            TypeError: If ``data`` is not ``bytes``, ``bytearray`` or
                ``memoryview``.
        """
        self._ensure_active()
        if not isinstance(data, (bytes, bytearray, memoryview)):
            raise TypeError(f"a bytes-like object is required, not {type(data).__name__}")
        payload = bytes(data)
        assert self._target is not None
        await self._target.write(payload)
        return len(payload)

    async def flush(self) -> None:
        self._ensure_active()
        assert self._target is not None
        await self._target.flush()

    async def close(self) -> None:
        """Finish the target; on failure abort it and re-raise."""
        if self.closed:
            return
        self._ensure_active()
        try:
            assert self._target is not None
            await self._target.finish()
        except BaseException:
            await self._abort_quietly()
            raise
        finally:
            self._mark_closed()

    async def abort(self) -> None:
        """Abort the target if one is held and mark the handle closed."""
        if self.closed:
            return
        target = self._target
        try:
            if target is not None:
                await target.abort()
        finally:
            self._mark_closed()
b/src/vercel/_internal/unstable/sandbox/filesystem_write.py new file mode 100644 index 0000000..cf326ea --- /dev/null +++ b/src/vercel/_internal/unstable/sandbox/filesystem_write.py @@ -0,0 +1,182 @@ +"""Write targets for streaming Sandbox filesystem handles.""" + +from collections.abc import AsyncIterator, Callable +from contextlib import AbstractAsyncContextManager, asynccontextmanager +from typing import Protocol + +from vercel._internal.byte_stream import ( + ReadableByteStream, + StagingByteFile, +) +from vercel._internal.unstable.sandbox.filesystem_handle_common import ( + _ExactSizeValidator, +) +from vercel._internal.unstable.sandbox.runtime_common import _UploadFileEntry +from vercel._internal.unstable.sandbox.service import SandboxArchiveUpload, SandboxService + + +class _WriteTarget(Protocol): + async def write(self, data: bytes) -> None: ... + + async def flush(self) -> None: ... + + async def finish(self) -> None: ... + + async def abort(self) -> None: ... + + +class _WriteTargetSource(Protocol): + name: str + + def acquire(self) -> AbstractAsyncContextManager[_WriteTarget]: ... + + +class _WriteBinding(Protocol): + async def publish(self, source: ReadableByteStream, size: int, mode: int | None) -> None: ... + + def open_upload( + self, size: int, mode: int | None + ) -> AbstractAsyncContextManager[_WriteTarget]: ... + + +class _TemporaryFileRuntime(Protocol): + def temporary_file(self) -> AbstractAsyncContextManager[StagingByteFile]: ... 
+ + +class _BoundWrite: + def __init__( + self, + *, + service: SandboxService, + session_id: str, + path: str, + cwd: str, + archive_path: str, + ) -> None: + self._service = service + self._session_id = session_id + self._path = path + self._cwd = cwd + self._archive_path = archive_path + + async def publish(self, source: ReadableByteStream, size: int, mode: int | None) -> None: + await self._service.write_stream_archive( + session_id=self._session_id, + entries=[ + _UploadFileEntry( + path=self._path, + size=size, + source=source, + mode=mode, + archive_path=self._archive_path, + ) + ], + paths=(self._path,), + cwd=self._cwd, + ) + + def open_upload( + self, size: int, mode: int | None + ) -> AbstractAsyncContextManager[SandboxArchiveUpload]: + @asynccontextmanager + async def open_upload() -> AsyncIterator[SandboxArchiveUpload]: + async with self._service.open_archive_upload( + session_id=self._session_id, + paths=(self._path,), + cwd=self._cwd, + ) as upload: + await upload.start_entry(self._archive_path, size, mode) + yield upload + + return open_upload() + + +class _FilesystemWriteTargetSource: + def __init__( + self, + *, + name: str, + runtime: _TemporaryFileRuntime, + bind: Callable[[], _WriteBinding], + size: int | None, + permissions: int | None, + ) -> None: + self.name = name + self._runtime = runtime + self._bind = bind + self._size = size + self._permissions = permissions + + def acquire(self) -> AbstractAsyncContextManager[_WriteTarget]: + return _acquire_write_target( + runtime=self._runtime, + bound=self._bind(), + name=self.name, + size=self._size, + permissions=self._permissions, + ) + + +class _SpooledWriteTarget: + def __init__( + self, + spool: StagingByteFile, + bound: _WriteBinding, + permissions: int | None, + ) -> None: + self._spool = spool + self._bound = bound + self._permissions = permissions + + async def write(self, data: bytes) -> None: + await self._spool.write(data) + + async def flush(self) -> None: + await self._spool.flush() 
+ + async def finish(self) -> None: + await self._spool.flush() + size = await self._spool.tell() + await self._spool.seek(0) + await self._bound.publish(self._spool, size, self._permissions) + + async def abort(self) -> None: + pass + + +class _ExactSizeWriteTarget: + def __init__(self, target: _WriteTarget, *, name: str, size: int) -> None: + self._target = target + self._validator = _ExactSizeValidator(name, size) + + async def write(self, data: bytes) -> None: + self._validator.validate_write(len(data)) + await self._target.write(data) + self._validator.record_write(len(data)) + + async def flush(self) -> None: + await self._target.flush() + + async def finish(self) -> None: + self._validator.validate_close() + await self._target.finish() + + async def abort(self) -> None: + await self._target.abort() + + +@asynccontextmanager +async def _acquire_write_target( + *, + runtime: _TemporaryFileRuntime, + bound: _WriteBinding, + name: str, + size: int | None, + permissions: int | None, +) -> AsyncIterator[_WriteTarget]: + if size is None: + async with runtime.temporary_file() as spool: + yield _SpooledWriteTarget(spool, bound, permissions) + else: + async with bound.open_upload(size, permissions) as upload: + yield _ExactSizeWriteTarget(upload, name=name, size=size) diff --git a/src/vercel/_internal/unstable/sandbox/options.py b/src/vercel/_internal/unstable/sandbox/options.py index 3b14ca9..d66fec2 100644 --- a/src/vercel/_internal/unstable/sandbox/options.py +++ b/src/vercel/_internal/unstable/sandbox/options.py @@ -1,12 +1,14 @@ """Sandbox service options.""" from dataclasses import dataclass +from datetime import timedelta from typing import Protocol from vercel._internal.unstable.options import ServiceOptions from vercel._internal.unstable.sandbox.errors import SandboxCredentialsError DEFAULT_SANDBOX_API_BASE_URL = "https://vercel.com/api" +_DEFAULT_FILE_TRANSFER_TIMEOUT = timedelta(minutes=5) @dataclass(frozen=True, slots=True) @@ -49,12 +51,14 @@ class 
SandboxServiceOptions(ServiceOptions): base_url: str credentials_factory: SandboxCredentialsFactory + file_transfer_timeout: timedelta def __init__( self, *, base_url: str | None = None, credentials_factory: SandboxCredentialsFactory | None = None, + file_transfer_timeout: timedelta | None = None, ) -> None: object.__setattr__( self, @@ -66,3 +70,10 @@ def __init__( "credentials_factory", credentials_factory or _default_sandbox_credentials_factory(), ) + object.__setattr__( + self, + "file_transfer_timeout", + file_transfer_timeout + if file_transfer_timeout is not None + else _DEFAULT_FILE_TRANSFER_TIMEOUT, + ) diff --git a/src/vercel/_internal/unstable/sandbox/runtime_common.py b/src/vercel/_internal/unstable/sandbox/runtime_common.py index 285e81c..f55ada1 100644 --- a/src/vercel/_internal/unstable/sandbox/runtime_common.py +++ b/src/vercel/_internal/unstable/sandbox/runtime_common.py @@ -4,12 +4,13 @@ import posixpath import signal as signal_module from collections.abc import Callable, Sequence -from dataclasses import replace +from dataclasses import dataclass, replace from datetime import timedelta from enum import Enum, auto from pathlib import PurePosixPath from typing import Generic, Literal, TypeAlias, TypeVar +from vercel._internal.byte_stream import ReadableByteStream from vercel._internal.unstable.sandbox.errors import SandboxResponseError from vercel._internal.unstable.sandbox.models import ( JSONObject, @@ -29,6 +30,19 @@ RuntimeSessionHandleT = TypeVar("RuntimeSessionHandleT", bound="RuntimeSessionHandleBase") RemotePath: TypeAlias = str | PurePosixPath +_SourceT = TypeVar("_SourceT") + + +@dataclass(frozen=True, slots=True) +class _UploadFileEntry(Generic[_SourceT]): + path: str + size: int + source: _SourceT + mode: int | None = None + archive_path: str | None = None + + +_StreamUploadFileEntry: TypeAlias = _UploadFileEntry[ReadableByteStream] class _FilesystemBatchState(Enum): @@ -47,6 +61,7 @@ def __init__(self) -> None: def _stage(self, file: 
_WriteFile) -> None: if self._state is not _FilesystemBatchState.ACTIVE: raise RuntimeError("filesystem batch staging is only allowed inside its context") + _validate_file_mode(file.mode) self._files.append(file) def write_bytes(self, path: RemotePath, data: bytes, *, mode: int | None = None) -> None: @@ -118,6 +133,43 @@ def _resolve_write_files_cwd(cwd: RemotePath | None, *, default: str) -> str: return posixpath.normpath(posixpath.join(default, normalized_cwd)) +def _normalize_tar_path(path: str, *, cwd: str) -> str: + if not posixpath.isabs(cwd): + raise ValueError("cwd must be an absolute path") + absolute_path = ( + posixpath.normpath(path) + if posixpath.isabs(path) + else posixpath.normpath(posixpath.join(cwd, path)) + ) + return posixpath.relpath(absolute_path, "/") + + +def _validate_transfer_size(size: object) -> int: + if isinstance(size, bool) or not isinstance(size, int): + raise TypeError("size must be an integer >= 0") + if size < 0: + raise ValueError("size must be >= 0") + return size + + +def _validate_chunk_size(chunk_size: object) -> int: + if isinstance(chunk_size, bool) or not isinstance(chunk_size, int): + raise TypeError("chunk_size must be a positive integer") + if chunk_size < 1: + raise ValueError("chunk_size must be positive") + return chunk_size + + +def _validate_file_mode(mode: object) -> int | None: + if mode is None: + return None + if isinstance(mode, bool) or not isinstance(mode, int): + raise TypeError("file mode must be an integer or None") + if not 0 <= mode <= 0o777: + raise ValueError("file mode must be between 0 and 0o777") + return mode + + def _signal_number(value: int | str | signal_module.Signals | None) -> int: if value is None: return int(signal_module.Signals.SIGTERM) diff --git a/src/vercel/_internal/unstable/sandbox/service.py b/src/vercel/_internal/unstable/sandbox/service.py index 8b37e32..e35a6a6 100644 --- a/src/vercel/_internal/unstable/sandbox/service.py +++ b/src/vercel/_internal/unstable/sandbox/service.py 
@@ -1,12 +1,15 @@ """Neutral orchestration for unstable Sandbox operations.""" -from collections.abc import Awaitable, Callable, Mapping, Sequence +from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping, Sequence +from contextlib import asynccontextmanager from dataclasses import dataclass from datetime import timedelta from typing import TYPE_CHECKING, Literal, cast -import httpx - +from vercel._internal.byte_stream import ( + StagingFileRuntime, +) +from vercel._internal.http import StreamingResponse from vercel._internal.unstable.sandbox.api_client import SandboxApiClient from vercel._internal.unstable.sandbox.errors import ( SandboxApiError, @@ -14,11 +17,14 @@ SandboxFilesystemWriteError, SandboxPathNotFoundError, SandboxResponseError, + SandboxUploadSizeMismatchError, ) +from vercel._internal.unstable.sandbox.log_stream import _parse_command_log_record from vercel._internal.unstable.sandbox.models import ( _OMITTED, DirectoryEntry, NetworkPolicy, + ProcessLog, SandboxQuery, SandboxQueryByCreatedAt, SandboxQueryByCurrentSnapshotId, @@ -31,10 +37,12 @@ SnapshotRetention, SnapshotRetentionUpdate, TagFilter, - _WriteFile, ) from vercel._internal.unstable.sandbox.options import SandboxServiceOptions from vercel._internal.unstable.sandbox.process_output import ProcessOutputRouter +from vercel._internal.unstable.sandbox.runtime_common import ( + _StreamUploadFileEntry, +) from vercel._internal.unstable.sandbox.state import ( CompletedProcessState, ProcessState, @@ -46,6 +54,7 @@ SnapshotsPageState, SnapshotState, ) +from vercel._internal.unstable.sandbox.streaming_archive import ArchiveRequestWriter if TYPE_CHECKING: from vercel._internal.unstable.session import _BaseSdkSession @@ -190,11 +199,13 @@ def __init__( options: SandboxServiceOptions, ensure_open: Callable[[], None], sleep: AsyncSleep, + staging_file_runtime: StagingFileRuntime, ) -> None: self._api_client = api_client self._options = options self._ensure_open = ensure_open self._sleep = 
sleep + self._staging_file_runtime = staging_file_runtime @property def api_client(self) -> SandboxApiClient: @@ -204,6 +215,10 @@ def api_client(self) -> SandboxApiClient: def options(self) -> SandboxServiceOptions: return self._options + @property + def staging_file_runtime(self) -> StagingFileRuntime: + return self._staging_file_runtime + async def _wait_for_ready_sandbox( self, sandbox: SandboxState, *, project_id: str | None = None ) -> SandboxState: @@ -562,31 +577,110 @@ async def mkdir( ) from error raise - async def read_bytes(self, *, session_id: str, path: str, cwd: str | None = None) -> bytes: + async def write_stream_archive( + self, + *, + session_id: str, + entries: Sequence[_StreamUploadFileEntry], + paths: tuple[str, ...], + cwd: str, + ) -> None: + self._ensure_open() + await self._write_stream_archive( + session_id=session_id, + entries=entries, + paths=paths, + cwd=cwd, + ) + + async def _write_stream_archive( + self, + *, + session_id: str, + entries: Sequence[_StreamUploadFileEntry], + paths: tuple[str, ...], + cwd: str, + ) -> None: + if not entries: + return + async with self.open_archive_upload( + session_id=session_id, + paths=paths, + cwd=cwd, + ) as upload: + for entry in entries: + await upload.add_source(entry) + + @asynccontextmanager + async def open_archive_upload( + self, + *, + session_id: str, + paths: tuple[str, ...], + cwd: str, + ) -> AsyncGenerator["SandboxArchiveUpload", None]: self._ensure_open() try: - return await self._api_client.read_bytes(session_id=session_id, path=path, cwd=cwd) + async with self._api_client.write_files_request(session_id=session_id) as request: + writer = ArchiveRequestWriter(request, 64 * 1024) + upload = SandboxArchiveUpload( + writer=writer, + paths=paths, + cwd=cwd, + ) + try: + yield upload + except BaseException: + if not upload.finished: + await upload.abort() + raise + else: + if not upload.finished: + await upload.finish() + except SandboxApiError as error: + raise 
SandboxFilesystemWriteError(paths=paths, cwd=cwd, cause=error) from error + + async def open_read_response( + self, + *, + operation: str, + session_id: str, + path: str, + cwd: str | None = None, + ) -> StreamingResponse: + self._ensure_open() + try: + return await self._api_client.open_read_response( + session_id=session_id, path=path, cwd=cwd + ) except SandboxApiError as error: if error.code in _MISSING_PATH_ERROR_CODES: raise SandboxPathNotFoundError( - path, operation="read_bytes", cwd=cwd, cause=error + path, operation=operation, cwd=cwd, cause=error ) from error raise - async def write_files( + async def read_bytes( self, *, + operation: str, session_id: str, - files: Sequence[_WriteFile], - cwd: str, - ) -> None: - self._ensure_open() + path: str, + cwd: str | None, + ) -> bytes: + response = await self.open_read_response( + operation=operation, + session_id=session_id, + path=path, + cwd=cwd, + ) + data = bytearray() try: - await self._api_client.write_files(session_id=session_id, files=files, cwd=cwd) - except SandboxApiError as error: - raise SandboxFilesystemWriteError( - paths=tuple(file.path for file in files), cwd=cwd, cause=error - ) from error + async for chunk in response: + data.extend(chunk) + finally: + await response.aclose() + return bytes(data) async def _filesystem_command( self, @@ -775,12 +869,111 @@ async def send_process_signal( session_id=session_id, command_id=process_id, signal=signal ) - async def process_logs_response(self, *, session_id: str, process_id: str) -> httpx.Response: + async def process_logs_response(self, *, session_id: str, process_id: str) -> StreamingResponse: self._ensure_open() return await self._api_client.command_logs_response( session_id=session_id, command_id=process_id ) + async def process_logs( + self, *, session_id: str, process_id: str + ) -> AsyncGenerator[ProcessLog, None]: + response = await self.process_logs_response(session_id=session_id, process_id=process_id) + try: + async for line in 
response.aiter_lines(): + if line: + event = _parse_command_log_record(line) + if event is not None: + yield event + finally: + await response.aclose() + + +class SandboxArchiveUpload: + """Service-owned lifecycle for one multi-entry archive upload.""" + + _CHUNK_SIZE = 64 * 1024 + + def __init__( + self, + *, + writer: ArchiveRequestWriter, + paths: tuple[str, ...], + cwd: str, + ) -> None: + self._writer = writer + self._paths = paths + self._cwd = cwd + self._finished = False + + @property + def finished(self) -> bool: + return self._finished + + async def start_entry(self, archive_path: str, size: int, mode: int | None) -> None: + await self._writer.start_entry(archive_path, size, mode) + + async def finish_entry(self) -> None: + await self._writer.finish_entry() + + async def add_source(self, entry: _StreamUploadFileEntry) -> None: + source = entry.source + await self.start_entry(entry.archive_path or entry.path, entry.size, entry.mode) + remaining = entry.size + while remaining > 0: + chunk = await source.read(min(self._CHUNK_SIZE, remaining)) + if not isinstance(chunk, bytes): + raise TypeError(f"Source produced non-bytes chunk of type {type(chunk).__name__}") + if not chunk: + raise SandboxUploadSizeMismatchError( + entry.path, + declared=entry.size, + consumed=entry.size - remaining, + early_end=True, + ) + consumed = entry.size - remaining + len(chunk) + if len(chunk) > remaining: + raise SandboxUploadSizeMismatchError( + entry.path, + declared=entry.size, + consumed=consumed, + early_end=False, + ) + await self.write(chunk) + remaining -= len(chunk) + + trailing = await source.read(1) + if not isinstance(trailing, bytes): + raise TypeError(f"Source produced non-bytes chunk of type {type(trailing).__name__}") + if trailing: + raise SandboxUploadSizeMismatchError( + entry.path, + declared=entry.size, + consumed=entry.size + len(trailing), + early_end=False, + ) + await self.finish_entry() + + async def write(self, data: bytes) -> None: + await 
self._writer.write(data) + + async def flush(self) -> None: + await self._writer.write(b"") + + async def finish(self) -> None: + try: + await self._writer.finish() + except SandboxApiError as error: + raise SandboxFilesystemWriteError( + paths=self._paths, cwd=self._cwd, cause=error + ) from error + finally: + self._finished = True + + async def abort(self) -> None: + self._finished = True + await self._writer.abort() + def get_sandbox_service(session: "_BaseSdkSession") -> SandboxService: def factory() -> SandboxService: @@ -790,10 +983,12 @@ def factory() -> SandboxService: base_url=options.base_url, credentials_factory=options.credentials_factory, transport=session.get_transport(), + file_transfer_timeout=options.file_transfer_timeout, ), options=options, ensure_open=session.check_open, sleep=session.sleep, + staging_file_runtime=session.get_staging_file_runtime(), ) return session.get_or_create_service(SandboxService, factory) diff --git a/src/vercel/_internal/unstable/sandbox/streaming_archive.py b/src/vercel/_internal/unstable/sandbox/streaming_archive.py new file mode 100644 index 0000000..d3154b8 --- /dev/null +++ b/src/vercel/_internal/unstable/sandbox/streaming_archive.py @@ -0,0 +1,177 @@ +"""Bounded-memory streaming tar+gzip encoder for sandbox filesystem transfers.""" + +import tarfile +import zlib +from collections.abc import Generator, Iterator +from typing import Protocol + +import anyio + +from vercel._internal.unstable.sandbox.runtime_common import _validate_file_mode + + +class _TarGzipEncoder: + """Synchronous state machine that builds a gzipped tar archive chunk by chunk. + + Each entry is added sequentially: call ``add_entry``, feed data with + ``write_entry_data``, then ``finish_entry``. Compressed chunks can be drained + at any time via ``next_chunk``. After all entries, call ``finalize`` to get + the remaining chunks (including the trailer and gzip flush). 
+ """ + + __slots__ = ("_chunk_size", "_compressor", "_buffer", "_current_entry", "_finalized") + + def __init__(self, chunk_size: int) -> None: + if chunk_size < 1: + raise ValueError("chunk_size must be a positive integer") + self._chunk_size = chunk_size + self._compressor = zlib.compressobj(wbits=31) + self._buffer: bytearray = bytearray() + self._current_entry: _CurrentEntry | None = None + self._finalized = False + + def add_entry(self, path: str, size: int, mode: int | None = None) -> int: + if self._finalized: + raise RuntimeError("Encoder already finalized") + if self._current_entry is not None: + raise RuntimeError("Previous entry not finished") + + info = tarfile.TarInfo(name=path) + info.size = size + normalized_mode = _validate_file_mode(mode) + info.mode = 0o644 if normalized_mode is None else normalized_mode + info.uid = 0 + info.gid = 0 + info.mtime = 0 + info.uname = "" + info.gname = "" + + header = info.tobuf(format=tarfile.PAX_FORMAT) + self._compress(header) + self._current_entry = _CurrentEntry(size=size) + return size + + def write_entry_data(self, data: bytes) -> None: + if self._current_entry is None: + raise RuntimeError("No active entry") + entry = self._current_entry + if entry.written + len(data) > entry.size: + raise ValueError("Trailing data: would exceed declared entry size") + self._compress(data) + entry.written += len(data) + + def finish_entry(self) -> None: + if self._current_entry is None: + raise RuntimeError("No active entry") + entry = self._current_entry + if entry.written < entry.size: + raise ValueError("Early end: entry not fully written") + + remainder = entry.size % 512 + if remainder > 0: + self._compress(b"\0" * (512 - remainder)) + + self._current_entry = None + + def finalize(self) -> Iterator[bytes]: + if self._finalized: + raise RuntimeError("Encoder already finalized") + if self._current_entry is not None: + raise RuntimeError("Cannot finalize with active entry") + + self._finalized = True + self._compress(b"\0" 
* 1024) + + flushed = self._compressor.flush() + if flushed: + self._buffer.extend(flushed) + + chunks = list(self.drain()) + if self._buffer: + chunks.append(bytes(self._buffer)) + self._buffer.clear() + return iter(chunks) + + def drain(self) -> Generator[bytes, None, None]: + """Drain all output currently available in bounded chunks.""" + while len(self._buffer) >= self._chunk_size: + chunk = bytes(self._buffer[: self._chunk_size]) + del self._buffer[: self._chunk_size] + yield chunk + if self._buffer: + yield bytes(self._buffer) + self._buffer.clear() + + def next_chunk(self) -> bytes | None: + return next(self.drain(), None) + + def _compress(self, data: bytes) -> None: + compressed = self._compressor.compress(data) + if compressed: + self._buffer.extend(compressed) + + +class _ArchiveUpload(Protocol): + async def write(self, data: bytes) -> None: ... + + async def finish(self) -> None: ... + + async def abort(self) -> None: ... + + +class ArchiveRequestWriter: + """Push raw entry data through the archive encoder into one request.""" + + def __init__(self, request: _ArchiveUpload, chunk_size: int) -> None: + self._request = request + self._encoder = _TarGzipEncoder(chunk_size) + self._entry_open = False + self._closed = False + + async def _drain(self) -> None: + for chunk in self._encoder.drain(): + await self._request.write(chunk) + + async def start_entry(self, path: str, size: int, mode: int | None) -> None: + if self._closed: + raise anyio.ClosedResourceError + self._encoder.add_entry(path, size, mode) + self._entry_open = True + await self._drain() + + async def write(self, data: bytes) -> None: + if self._closed: + raise anyio.ClosedResourceError + self._encoder.write_entry_data(data) + await self._drain() + + async def finish_entry(self) -> None: + if self._closed: + raise anyio.ClosedResourceError + self._encoder.finish_entry() + self._entry_open = False + await self._drain() + + async def finish(self) -> None: + if self._closed: + raise 
anyio.ClosedResourceError + if self._entry_open: + await self.finish_entry() + self._closed = True + await self._drain() + for chunk in self._encoder.finalize(): + await self._request.write(chunk) + await self._request.finish() + + async def abort(self) -> None: + if not self._closed: + self._closed = True + await self._request.abort() + + +class _CurrentEntry: + __slots__ = ("size", "written") + + def __init__(self, size: int) -> None: + self.size = size + self.written = 0 diff --git a/src/vercel/_internal/unstable/sandbox/sync_filesystem_handle.py b/src/vercel/_internal/unstable/sandbox/sync_filesystem_handle.py new file mode 100644 index 0000000..dbc3c52 --- /dev/null +++ b/src/vercel/_internal/unstable/sandbox/sync_filesystem_handle.py @@ -0,0 +1,227 @@ +"""Synchronous facades for shared Sandbox file-handle state machines.""" + +from collections.abc import Iterable, Iterator +from contextlib import AbstractAsyncContextManager +from types import TracebackType + +from vercel._internal.iter_coroutine import iter_coroutine +from vercel._internal.unstable.sandbox.filesystem_handle_core import ( + BinaryReaderCore, + BinaryWriterCore, + TextReaderCore, + TextWriterCore, +) + + +class _SyncHandle: + _core: BinaryReaderCore | TextReaderCore | BinaryWriterCore | TextWriterCore + + @property + def name(self) -> str: + return self._core.name + + @property + def mode(self) -> str: + return self._core.mode + + @property + def closed(self) -> bool: + return self._core.closed + + def readable(self) -> bool: + return self._core.readable() + + def writable(self) -> bool: + return self._core.writable() + + def seekable(self) -> bool: + return False + + +class SyncSandboxBinaryReader(_SyncHandle, Iterator[bytes]): + _core: BinaryReaderCore + + def __init__(self, core: BinaryReaderCore) -> None: + self._core = core + + def __enter__(self) -> "SyncSandboxBinaryReader": + iter_coroutine(self._core.enter()) + return self + + def __exit__( + self, + exc_type: type[BaseException] | 
None, + exc: BaseException | None, + traceback: TracebackType | None, + ) -> None: + try: + self.close() + except BaseException: + if exc_type is None: + raise + + def read(self, size: int = -1) -> bytes: + return iter_coroutine(self._core.read(size)) + + def readline(self, size: int = -1) -> bytes: + return iter_coroutine(self._core.readline(size)) + + def readinto(self, buffer: object) -> int: + return iter_coroutine(self._core.readinto(buffer)) + + def __iter__(self) -> "SyncSandboxBinaryReader": + return self + + def __next__(self) -> bytes: + line = self.readline() + if not line: + raise StopIteration + return line + + def close(self) -> None: + iter_coroutine(self._core.close()) + + +class SyncSandboxTextReader(_SyncHandle, Iterator[str]): + _core: TextReaderCore + + def __init__(self, core: TextReaderCore) -> None: + self._core = core + + def __enter__(self) -> "SyncSandboxTextReader": + iter_coroutine(self._core.enter()) + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + traceback: TracebackType | None, + ) -> None: + try: + self.close() + except BaseException: + if exc_type is None: + raise + + def read(self, size: int = -1) -> str: + return iter_coroutine(self._core.read(size)) + + def readline(self, size: int = -1) -> str: + return iter_coroutine(self._core.readline(size)) + + def __iter__(self) -> "SyncSandboxTextReader": + return self + + def __next__(self) -> str: + line = self.readline() + if not line: + raise StopIteration + return line + + def close(self) -> None: + iter_coroutine(self._core.close()) + + +class SyncSandboxBinaryWriter(_SyncHandle): + _core: BinaryWriterCore + + def __init__(self, core: BinaryWriterCore) -> None: + self._core = core + self._lifecycle: AbstractAsyncContextManager[None] | None = None + + def __enter__(self) -> "SyncSandboxBinaryWriter": + lifecycle = self._core.lifecycle() + self._lifecycle = lifecycle + try: + iter_coroutine(lifecycle.__aenter__()) + except 
BaseException: + self._lifecycle = None + raise + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + traceback: TracebackType | None, + ) -> None: + lifecycle, self._lifecycle = self._lifecycle, None + if lifecycle is None: + return + if exc_type is not None: + try: + iter_coroutine(lifecycle.__aexit__(exc_type, exc, traceback)) + except BaseException: + pass + else: + iter_coroutine(lifecycle.__aexit__(None, None, None)) + + def write(self, data: bytes, /) -> int: + return iter_coroutine(self._core.write(data)) + + def writelines(self, lines: Iterable[bytes], /) -> None: + for line in lines: + self.write(line) + + def flush(self) -> None: + iter_coroutine(self._core.flush()) + + def close(self) -> None: + lifecycle, self._lifecycle = self._lifecycle, None + if lifecycle is None: + iter_coroutine(self._core.close()) + else: + iter_coroutine(lifecycle.__aexit__(None, None, None)) + + +class SyncSandboxTextWriter(_SyncHandle): + _core: TextWriterCore + + def __init__(self, core: TextWriterCore) -> None: + self._core = core + self._lifecycle: AbstractAsyncContextManager[None] | None = None + + def __enter__(self) -> "SyncSandboxTextWriter": + lifecycle = self._core.lifecycle() + self._lifecycle = lifecycle + try: + iter_coroutine(lifecycle.__aenter__()) + except BaseException: + self._lifecycle = None + raise + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + traceback: TracebackType | None, + ) -> None: + lifecycle, self._lifecycle = self._lifecycle, None + if lifecycle is None: + return + if exc_type is not None: + try: + iter_coroutine(lifecycle.__aexit__(exc_type, exc, traceback)) + except BaseException: + pass + else: + iter_coroutine(lifecycle.__aexit__(None, None, None)) + + def write(self, text: str, /) -> int: + return iter_coroutine(self._core.write(text)) + + def writelines(self, lines: Iterable[str], /) -> None: + for line in lines: + 
self.write(line) + + def flush(self) -> None: + iter_coroutine(self._core.flush()) + + def close(self) -> None: + lifecycle, self._lifecycle = self._lifecycle, None + if lifecycle is None: + iter_coroutine(self._core.close()) + else: + iter_coroutine(lifecycle.__aexit__(None, None, None)) diff --git a/src/vercel/_internal/unstable/sandbox/sync_runtime.py b/src/vercel/_internal/unstable/sandbox/sync_runtime.py index 0271e62..f0ca876 100644 --- a/src/vercel/_internal/unstable/sandbox/sync_runtime.py +++ b/src/vercel/_internal/unstable/sandbox/sync_runtime.py @@ -5,8 +5,9 @@ from collections.abc import Callable, Iterator, Mapping, Sequence from datetime import timedelta from types import TracebackType -from typing import Any, TextIO +from typing import Any, Literal, TextIO, overload +from vercel._internal.byte_stream import SyncByteStreamRuntime from vercel._internal.iter_coroutine import iter_coroutine from vercel._internal.polyfills import Self from vercel._internal.time import parse_duration_seconds, parse_required_duration_seconds @@ -15,7 +16,14 @@ SandboxResponseError, SandboxTerminalStateError, ) -from vercel._internal.unstable.sandbox.log_stream import _parse_command_log_record +from vercel._internal.unstable.sandbox.filesystem_handle_common import _validate_open_options +from vercel._internal.unstable.sandbox.filesystem_handle_core import ( + BinaryReaderCore, + BinaryWriterCore, + FilesystemHandleBinding, + TextReaderCore, + TextWriterCore, +) from vercel._internal.unstable.sandbox.models import ( _OMITTED, CompletedProcess, @@ -50,9 +58,12 @@ SandboxHandleBase, SnapshotHandleBase, _coerce_remote_path, + _normalize_tar_path, _ProcessHandleState, _SandboxFilesystemBatchBase, _signal_number, + _UploadFileEntry, + _validate_file_mode, ) from vercel._internal.unstable.sandbox.service import SandboxService, _SandboxTerminalState from vercel._internal.unstable.sandbox.state import ( @@ -61,6 +72,12 @@ SandboxState, SnapshotState, ) +from 
vercel._internal.unstable.sandbox.sync_filesystem_handle import ( + SyncSandboxBinaryReader, + SyncSandboxBinaryWriter, + SyncSandboxTextReader, + SyncSandboxTextWriter, +) from vercel._internal.unstable.sandbox.text_reader import SyncTextReader, _sync_text_readers @@ -97,9 +114,7 @@ def __init__( super().__init__(payload) self._service = service self.stdout, self.stderr = _sync_text_readers( - lambda: iter_coroutine( - service.process_logs_response(session_id=self._session_id, process_id=self.id) - ), + lambda: service.process_logs_response(session_id=self._session_id, process_id=self.id), stdout=stdout, stderr=stderr, ) @@ -206,11 +221,115 @@ def __init__( self._session_id = session_id self._write_files_cwd = write_files_cwd + @overload + def open( + self, + path: RemotePath, + mode: Literal["r"] = "r", + *, + cwd: RemotePath | None = None, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, + size: None = None, + permissions: None = None, + ) -> SyncSandboxTextReader: ... + + @overload + def open( + self, + path: RemotePath, + mode: Literal["rb"], + *, + cwd: RemotePath | None = None, + encoding: None = None, + errors: None = None, + newline: None = None, + size: None = None, + permissions: None = None, + ) -> SyncSandboxBinaryReader: ... + + @overload + def open( + self, + path: RemotePath, + mode: Literal["w"], + *, + cwd: RemotePath | None = None, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, + size: None = None, + permissions: int | None = None, + ) -> SyncSandboxTextWriter: ... + + @overload + def open( + self, + path: RemotePath, + mode: Literal["wb"], + *, + cwd: RemotePath | None = None, + encoding: None = None, + errors: None = None, + newline: None = None, + size: int | None = None, + permissions: int | None = None, + ) -> SyncSandboxBinaryWriter: ... 
+ + def open( + self, + path: RemotePath, + mode: str = "r", + *, + cwd: RemotePath | None = None, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, + size: int | None = None, + permissions: int | None = None, + ) -> ( + SyncSandboxBinaryReader + | SyncSandboxTextReader + | SyncSandboxBinaryWriter + | SyncSandboxTextWriter + ): + """Create a lazy, single-use sequential file handle.""" + path, mode, encoding, errors, newline, size, permissions = _validate_open_options( + path, + mode, + encoding=encoding, + errors=errors, + newline=newline, + size=size, + permissions=permissions, + ) + normalized_cwd = None if cwd is None else _coerce_remote_path(cwd) + binding = FilesystemHandleBinding( + service=self._service, + runtime=self._service.staging_file_runtime, + session_id=self._session_id, + write_files_cwd=self._write_files_cwd, + path=path, + cwd=normalized_cwd, + ) + if mode == "rb": + return SyncSandboxBinaryReader(BinaryReaderCore(binding)) + if mode == "r": + return SyncSandboxTextReader(TextReaderCore(binding, encoding, errors, newline)) + if mode == "wb": + return SyncSandboxBinaryWriter( + BinaryWriterCore(binding.write_target_source(size=size, permissions=permissions)) + ) + return SyncSandboxTextWriter( + TextWriterCore(binding, encoding, errors, newline, permissions) + ) + async def _collect_output(self, command: ProcessState) -> tuple[str, str]: stdout: list[str] = [] stderr: list[str] = [] - for event in _process_logs( - self._service, session_id=command.session_id, process_id=command.id + async for event in self._service.process_logs( + session_id=command.session_id, process_id=command.id ): if event.stream == "stdout": stdout.append(event.data) @@ -256,6 +375,7 @@ def read_bytes(self, path: RemotePath, *, cwd: RemotePath | None = None) -> byte """ return iter_coroutine( self._service.read_bytes( + operation="read_bytes", session_id=self._session_id(), path=_coerce_remote_path(path), cwd=None if cwd is None else 
_coerce_remote_path(cwd), @@ -344,11 +464,25 @@ def write_text( ) def _write_files(self, files: Sequence[_WriteFile], *, cwd: RemotePath | None = None) -> None: + for file in files: + _validate_file_mode(file.mode) + resolved_cwd = self._write_files_cwd(cwd) + entries = [ + _UploadFileEntry( + path=f.path, + size=len(f.content), + source=SyncByteStreamRuntime.reader(f.content), + mode=f.mode, + archive_path=_normalize_tar_path(f.path, cwd=resolved_cwd), + ) + for f in files + ] iter_coroutine( - self._service.write_files( + self._service.write_stream_archive( session_id=self._session_id(), - files=files, - cwd=self._write_files_cwd(cwd), + entries=entries, + paths=tuple(entry.path for entry in entries), + cwd=resolved_cwd, ) ) @@ -1213,14 +1347,16 @@ def get_snapshot(service: SandboxService, *, snapshot_id: str) -> SyncSnapshot: def _process_logs( service: SandboxService, *, session_id: str, process_id: str ) -> Iterator[ProcessLog]: - response = iter_coroutine( - service.process_logs_response(session_id=session_id, process_id=process_id) - ) + stream = service.process_logs(session_id=session_id, process_id=process_id) + + async def next_log() -> ProcessLog: + return await anext(stream) + try: - for line in response.iter_lines(): - if line: - event = _parse_command_log_record(line) - if event is not None: - yield event + while True: + try: + yield iter_coroutine(next_log()) + except StopAsyncIteration: + return finally: - response.close() + iter_coroutine(stream.aclose()) diff --git a/src/vercel/_internal/unstable/sandbox/text_reader.py b/src/vercel/_internal/unstable/sandbox/text_reader.py index f1c99dd..c93c189 100644 --- a/src/vercel/_internal/unstable/sandbox/text_reader.py +++ b/src/vercel/_internal/unstable/sandbox/text_reader.py @@ -1,21 +1,25 @@ -"""Text reader contracts and private process log stream implementations.""" +"""Process log readers backed by one async-shaped streaming core.""" +import inspect import subprocess +import threading from abc import 
ABC, abstractmethod from collections import deque from collections.abc import AsyncIterator, Awaitable, Callable, Iterator, Mapping from types import TracebackType +from typing import Protocol, TypeAlias import anyio -import httpx +from vercel._internal.http import StreamingResponse +from vercel._internal.iter_coroutine import iter_coroutine from vercel._internal.unstable.sandbox.log_stream import _parse_command_log_record from vercel._internal.unstable.sandbox.models import ProcessLogStream +_OpenResponse: TypeAlias = Callable[[], StreamingResponse | Awaitable[StreamingResponse]] -class _TextBuffer: - __slots__ = ("_chunks", "_head", "_size", "eof") +class _TextBuffer: def __init__(self) -> None: self._chunks: deque[str] = deque() self._head = 0 @@ -35,34 +39,22 @@ def clear(self) -> None: self._head = 0 self._size = 0 - def _prefix(self, size: int) -> str: - remaining = size - parts: list[str] = [] - for index, chunk in enumerate(self._chunks): - start = self._head if index == 0 else 0 - part = chunk[start : start + remaining] - parts.append(part) - remaining -= len(part) - if remaining == 0: - break - return "".join(parts) - def take(self, size: int) -> str: size = self._size if size < 0 else min(size, self._size) - value = self._prefix(size) remaining = size + parts: list[str] = [] while remaining: chunk = self._chunks[0] available = len(chunk) - self._head - if remaining < available: - self._head += remaining - remaining = 0 - else: - remaining -= available + count = min(remaining, available) + parts.append(chunk[self._head : self._head + count]) + self._head += count + remaining -= count + if self._head == len(chunk): self._chunks.popleft() self._head = 0 self._size -= size - return value + return "".join(parts) def take_line(self) -> str | None: seen = 0 @@ -78,58 +70,30 @@ def take_line(self) -> str | None: class TextReader(anyio.abc.ObjectReceiveStream[str], ABC): - """Read one remote process output stream asynchronously. 
- - A reader consumes its stream once. Use ``read`` for buffered reads, - ``readline`` or async iteration for line-oriented reads, and ``aclose`` to - release the underlying response early. - """ - @property @abstractmethod - def closed(self) -> bool: - """Return whether this reader has been closed.""" - ... + def closed(self) -> bool: ... @abstractmethod - async def read(self, size: int = -1) -> str: - """Read up to ``size`` characters, or the remaining stream.""" - ... + async def read(self, size: int = -1) -> str: ... @abstractmethod - async def readline(self) -> str: - """Read through the next newline or return the remaining text at EOF.""" - ... + async def readline(self) -> str: ... class SyncTextReader(ABC): - """Read one remote process output stream synchronously. - - A reader consumes its stream once. Use ``read`` for buffered reads, - ``readline`` or iteration for line-oriented reads, and ``close`` to release - the underlying response early. - """ - @property @abstractmethod - def closed(self) -> bool: - """Return whether this reader has been closed.""" - ... + def closed(self) -> bool: ... @abstractmethod - def read(self, size: int = -1) -> str: - """Read up to ``size`` characters, or the remaining stream.""" - ... + def read(self, size: int = -1) -> str: ... @abstractmethod - def readline(self) -> str: - """Read through the next newline or return the remaining text at EOF.""" - ... + def readline(self) -> str: ... @abstractmethod - def close(self) -> None: - """Close the reader and discard unread output from this stream.""" - ... + def close(self) -> None: ... 
def __iter__(self) -> Iterator[str]: while line := self.readline(): @@ -149,40 +113,78 @@ def __exit__( self.close() -def _distinct_buffers( - routes: Mapping[ProcessLogStream, "_TextBuffer | None"], -) -> "list[_TextBuffer]": +def _distinct_buffers(routes: Mapping[ProcessLogStream, _TextBuffer | None]) -> list[_TextBuffer]: return list({id(buffer): buffer for buffer in routes.values() if buffer is not None}.values()) -class _AsyncTextTransport: - __slots__ = ("_broken", "_lines", "_live", "_lock", "_open_response", "_response", "_routes") +class _PumpLock(Protocol): + async def acquire(self) -> None: ... + + def release(self) -> None: ... + +class _SyncPumpLock: + def __init__(self) -> None: + self._lock = threading.Lock() + + async def acquire(self) -> None: + self._lock.acquire() + + def release(self) -> None: + self._lock.release() + + +class _AsyncPumpLock: + def __init__(self) -> None: + self._lock = anyio.Lock() + + async def acquire(self) -> None: + await self._lock.acquire() + + def release(self) -> None: + self._lock.release() + + +def _normalize_open_response( + open_response: _OpenResponse, +) -> Callable[[], Awaitable[StreamingResponse]]: + async def open_stream() -> StreamingResponse: + result = open_response() + if inspect.isawaitable(result): + result = await result + if isinstance(result, StreamingResponse): + return result + raise TypeError("open_response must return an HTTP streaming response") + + return open_stream + + +class _TextTransportCore: def __init__( self, - open_response: Callable[[], Awaitable[httpx.Response]], + open_response: Callable[[], Awaitable[StreamingResponse]], routes: Mapping[ProcessLogStream, _TextBuffer | None], + lock: _PumpLock, ) -> None: self._open_response = open_response - self._response: httpx.Response | None = None + self._response: StreamingResponse | None = None self._lines: AsyncIterator[str] | None = None self._routes = dict(routes) self._live = len(_distinct_buffers(routes)) self._broken = False - self._lock 
= anyio.Lock() + self._lock = lock async def _cleanup(self) -> None: response, self._response = self._response, None self._lines = None if response is not None: - with anyio.CancelScope(shield=True): - await response.aclose() + await response.aclose() async def pump(self) -> None: - # Fail fast without acquiring a lock if self._broken: raise anyio.BrokenResourceError - async with self._lock: + await self._lock.acquire() + try: if self._broken: raise anyio.BrokenResourceError try: @@ -212,56 +214,76 @@ async def pump(self) -> None: buffer.eof = True await self._cleanup() raise + finally: + self._lock.release() async def close(self, buffer: _TextBuffer) -> None: - buffer.clear() - buffer.eof = True - for stream, target in self._routes.items(): - if target is buffer: - self._routes[stream] = None - self._live -= 1 - if self._live == 0: - await self._cleanup() - - -class _TextReader(TextReader): - __slots__ = ("_buffer", "_closed", "_guard", "_transport") + await self._lock.acquire() + try: + buffer.clear() + buffer.eof = True + for stream, target in self._routes.items(): + if target is buffer: + self._routes[stream] = None + self._live -= 1 + if self._live == 0: + await self._cleanup() + finally: + self._lock.release() - def __init__(self, transport: _AsyncTextTransport, buffer: _TextBuffer) -> None: - self._transport = transport - self._buffer = buffer - self._closed = False - self._guard = anyio.ResourceGuard("reading from") - @property - def closed(self) -> bool: - return self._closed +class _ReaderCore: + def __init__(self, transport: _TextTransportCore, buffer: _TextBuffer) -> None: + self.transport = transport + self.buffer = buffer + self.closed = False - def _ensure_open(self) -> None: - if self._closed: + def ensure_open(self) -> None: + if self.closed: raise anyio.ClosedResourceError - if self._transport._broken: + if self.transport._broken: raise anyio.BrokenResourceError async def read(self, size: int = -1) -> str: - self._ensure_open() + 
self.ensure_open() if size < -1: raise ValueError("size must be -1 or non-negative") if size == 0: return "" + while not self.buffer.eof and (size < 0 or len(self.buffer) < size): + await self.transport.pump() + return self.buffer.take(size) + + async def readline(self) -> str: + self.ensure_open() + while True: + line = self.buffer.take_line() + if line is not None: + return line + await self.transport.pump() + + async def close(self) -> None: + if not self.closed: + self.closed = True + await self.transport.close(self.buffer) + + +class _TextReader(TextReader): + def __init__(self, core: _ReaderCore) -> None: + self._core = core + self._guard = anyio.ResourceGuard("reading from") + + @property + def closed(self) -> bool: + return self._core.closed + + async def read(self, size: int = -1) -> str: with self._guard: - while not self._buffer.eof and (size < 0 or len(self._buffer) < size): - await self._transport.pump() - return self._buffer.take(size) + return await self._core.read(size) async def readline(self) -> str: - self._ensure_open() with self._guard: - while True: - line = self._buffer.take_line() - if line is not None: - return line - await self._transport.pump() + return await self._core.readline() async def receive(self) -> str: line = await self.readline() @@ -270,125 +292,29 @@ async def receive(self) -> str: return line async def aclose(self) -> None: - if not self._closed: - self._closed = True - await self._transport.close(self._buffer) - - -class _SyncTextTransport: - __slots__ = ("_broken", "_lines", "_live", "_open_response", "_response", "_routes") - - def __init__( - self, - open_response: Callable[[], httpx.Response], - routes: Mapping[ProcessLogStream, _TextBuffer | None], - ) -> None: - self._open_response = open_response - self._response: httpx.Response | None = None - self._lines: Iterator[str] | None = None - self._routes = dict(routes) - self._live = len(_distinct_buffers(routes)) - self._broken = False - - def _cleanup(self) -> None: - 
response, self._response = self._response, None - self._lines = None - if response is not None: - response.close() - - def pump(self) -> None: - if self._broken: - raise anyio.BrokenResourceError - try: - if self._response is None: - self._response = self._open_response() - self._lines = self._response.iter_lines() - assert self._lines is not None - while True: - try: - line = next(self._lines) - except StopIteration: - for buffer in _distinct_buffers(self._routes): - buffer.eof = True - self._cleanup() - return - if not line: - continue - event = _parse_command_log_record(line) - if event is not None: - target = self._routes[event.stream] - if target is not None: - target.append(event.data) - return - except BaseException: - self._broken = True - for buffer in _distinct_buffers(self._routes): - buffer.eof = True - self._cleanup() - raise - - def close(self, buffer: _TextBuffer) -> None: - buffer.clear() - buffer.eof = True - for stream, target in self._routes.items(): - if target is buffer: - self._routes[stream] = None - self._live -= 1 - if self._live == 0: - self._cleanup() + with anyio.CancelScope(shield=True): + await self._core.close() class _SyncTextReader(SyncTextReader): - __slots__ = ("_buffer", "_closed", "_guard", "_transport") - - def __init__(self, transport: _SyncTextTransport, buffer: _TextBuffer) -> None: - self._transport = transport - self._buffer = buffer - self._closed = False - self._guard = anyio.ResourceGuard("reading from") + def __init__(self, core: _ReaderCore) -> None: + self._core = core @property def closed(self) -> bool: - return self._closed - - def _ensure_open(self) -> None: - if self._closed: - raise anyio.ClosedResourceError - if self._transport._broken: - raise anyio.BrokenResourceError + return self._core.closed def read(self, size: int = -1) -> str: - self._ensure_open() - if size < -1: - raise ValueError("size must be -1 or non-negative") - if size == 0: - return "" - with self._guard: - while not self._buffer.eof and (size 
< 0 or len(self._buffer) < size): - self._transport.pump() - return self._buffer.take(size) + return iter_coroutine(self._core.read(size)) def readline(self) -> str: - self._ensure_open() - with self._guard: - while True: - line = self._buffer.take_line() - if line is not None: - return line - self._transport.pump() + return iter_coroutine(self._core.readline()) def close(self) -> None: - if not self._closed: - self._closed = True - self._transport.close(self._buffer) + iter_coroutine(self._core.close()) def _reader_buffers(stdout: int, stderr: int) -> tuple[_TextBuffer | None, _TextBuffer | None]: - """Resolve validated Popen-style destinations to per-stream buffers. - - ``subprocess.STDOUT`` makes stderr share stdout's buffer (or its absence), - and ``subprocess.DEVNULL`` drops a stream entirely. - """ stdout_buffer = _TextBuffer() if stdout == subprocess.PIPE else None if stderr == subprocess.STDOUT: stderr_buffer = stdout_buffer @@ -399,43 +325,53 @@ def _reader_buffers(stdout: int, stderr: int) -> tuple[_TextBuffer | None, _Text return stdout_buffer, stderr_buffer -def _text_readers( - open_response: Callable[[], Awaitable[httpx.Response]], - *, - stdout: int = subprocess.PIPE, - stderr: int = subprocess.PIPE, -) -> tuple[TextReader | None, TextReader | None]: +def _cores( + open_response: Callable[[], Awaitable[StreamingResponse]], + stdout: int, + stderr: int, + lock: _PumpLock, +) -> tuple[_ReaderCore | None, _ReaderCore | None]: stdout_buffer, stderr_buffer = _reader_buffers(stdout, stderr) if stdout_buffer is None and stderr_buffer is None: return None, None - transport = _AsyncTextTransport( + transport = _TextTransportCore( open_response, {ProcessLogStream.STDOUT: stdout_buffer, ProcessLogStream.STDERR: stderr_buffer}, + lock, ) return ( - None if stdout_buffer is None else _TextReader(transport, stdout_buffer), + None if stdout_buffer is None else _ReaderCore(transport, stdout_buffer), None if stderr_buffer is None or stderr_buffer is stdout_buffer - 
else _TextReader(transport, stderr_buffer), + else _ReaderCore(transport, stderr_buffer), + ) + + +def _text_readers( + open_response: _OpenResponse, + *, + stdout: int = subprocess.PIPE, + stderr: int = subprocess.PIPE, +) -> tuple[TextReader | None, TextReader | None]: + stdout_core, stderr_core = _cores( + _normalize_open_response(open_response), stdout, stderr, _AsyncPumpLock() + ) + return ( + None if stdout_core is None else _TextReader(stdout_core), + None if stderr_core is None else _TextReader(stderr_core), ) def _sync_text_readers( - open_response: Callable[[], httpx.Response], + open_response: _OpenResponse, *, stdout: int = subprocess.PIPE, stderr: int = subprocess.PIPE, ) -> tuple[SyncTextReader | None, SyncTextReader | None]: - stdout_buffer, stderr_buffer = _reader_buffers(stdout, stderr) - if stdout_buffer is None and stderr_buffer is None: - return None, None - transport = _SyncTextTransport( - open_response, - {ProcessLogStream.STDOUT: stdout_buffer, ProcessLogStream.STDERR: stderr_buffer}, + stdout_core, stderr_core = _cores( + _normalize_open_response(open_response), stdout, stderr, _SyncPumpLock() ) return ( - None if stdout_buffer is None else _SyncTextReader(transport, stdout_buffer), - None - if stderr_buffer is None or stderr_buffer is stdout_buffer - else _SyncTextReader(transport, stderr_buffer), + None if stdout_core is None else _SyncTextReader(stdout_core), + None if stderr_core is None else _SyncTextReader(stderr_core), ) diff --git a/src/vercel/_internal/unstable/session.py b/src/vercel/_internal/unstable/session.py index 849afa2..9d3823f 100644 --- a/src/vercel/_internal/unstable/session.py +++ b/src/vercel/_internal/unstable/session.py @@ -15,6 +15,11 @@ from vercel._internal.unstable.options import ServiceOptions, merge_service_options if TYPE_CHECKING: + from vercel._internal.byte_stream import ( + AsyncByteStreamRuntime, + StagingFileRuntime, + SyncByteStreamRuntime, + ) from vercel._internal.http import AsyncTransport, 
BaseTransport, SyncTransport ServiceOptionsT = TypeVar("ServiceOptionsT", bound=ServiceOptions) @@ -81,6 +86,9 @@ def get_or_create_service( def get_transport(self) -> "BaseTransport": raise NotImplementedError() + def get_staging_file_runtime(self) -> "StagingFileRuntime": + raise NotImplementedError() + async def sleep(self, seconds: float) -> None: raise NotImplementedError() @@ -104,6 +112,7 @@ def __init__( httpx_client_factory=httpx_client_factory, ) self._transport: AsyncTransport | None = None + self._staging_file_runtime: AsyncByteStreamRuntime | None = None @classmethod def default(cls) -> Self: @@ -169,6 +178,14 @@ def get_transport(self) -> "AsyncTransport": self._transport = AsyncTransport(client) return self._transport + def get_staging_file_runtime(self) -> "AsyncByteStreamRuntime": + from vercel._internal.byte_stream import AsyncByteStreamRuntime + + self.check_open() + if self._staging_file_runtime is None: + self._staging_file_runtime = AsyncByteStreamRuntime() + return self._staging_file_runtime + async def sleep(self, seconds: float) -> None: await anyio.sleep(seconds) @@ -198,6 +215,7 @@ def __init__( httpx_client_factory=httpx_client_factory, ) self._transport: SyncTransport | None = None + self._staging_file_runtime: SyncByteStreamRuntime | None = None @classmethod def default(cls) -> Self: @@ -271,6 +289,14 @@ def get_transport(self) -> "SyncTransport": self._transport = SyncTransport(client) return self._transport + def get_staging_file_runtime(self) -> "SyncByteStreamRuntime": + from vercel._internal.byte_stream import SyncByteStreamRuntime + + self.check_open() + if self._staging_file_runtime is None: + self._staging_file_runtime = SyncByteStreamRuntime() + return self._staging_file_runtime + async def sleep(self, seconds: float) -> None: time.sleep(seconds) diff --git a/src/vercel/unstable/README.md b/src/vercel/unstable/README.md index f8928d1..822fcb9 100644 --- a/src/vercel/unstable/README.md +++ b/src/vercel/unstable/README.md @@ 
-417,13 +417,30 @@ entries = await sandbox_.fs.listdir("workspace") handle on every operation. It follows a replacement current session only after new sandbox state has been applied to that handle. `SandboxRuntimeSession.fs` remains bound to that specific historical session identity. The async -`SandboxFilesystem` and sync `SyncSandboxFilesystem` expose `mkdir`, +`SandboxFilesystem` and sync `SyncSandboxFilesystem` expose `open`, `mkdir`, `read_bytes`, `read_text`, `write_bytes`, `write_text`, `batch`, `exists`, -`is_file`, `is_dir`, `listdir`, `remove`, and `rename`. A batch stages files -synchronously inside its context and submits one tarball on clean exit. +`is_file`, `is_dir`, `listdir`, `remove`, and `rename`. +A batch stages files synchronously inside its context and submits one tarball on clean exit. `listdir()` returns sorted `DirectoryEntry(path=..., kind=...)` values, where `kind` is `file`, `directory`, `symlink`, or `other`. +`open()` returns a lazy, single-use sequential handle for `"r"`, `"rb"`, `"w"`, +or `"wb"`. Reads stream the response in bounded chunks. Unsized binary and text +writes spool locally and publish on successful close; `"wb"` accepts an exact +`size` to stream directly. `read_bytes()`, `read_text()`, `write_bytes()`, and +`write_text()` remain whole-file conveniences. + +```python +async with await anyio.open_file("input.csv", "rb") as source, box.fs.open( + "workspace/input.csv", "wb" +) as target: + while chunk := await source.read(64 * 1024): + await target.write(chunk) + +async with box.fs.open("workspace/result.json", "rb") as source: + result = await source.read() +``` + `create_process(...)` accepts the `subprocess.Popen` output sentinels. 
`stdout` accepts `subprocess.PIPE` (default) or `subprocess.DEVNULL`; `stderr` additionally accepts `subprocess.STDOUT`, which merges stderr output into the diff --git a/src/vercel/unstable/sandbox/__init__.py b/src/vercel/unstable/sandbox/__init__.py index 7e613f1..008dd5b 100644 --- a/src/vercel/unstable/sandbox/__init__.py +++ b/src/vercel/unstable/sandbox/__init__.py @@ -2,6 +2,12 @@ from collections.abc import AsyncIterator, Mapping +from vercel._internal.unstable.sandbox.async_filesystem_handle import ( + SandboxBinaryReader, + SandboxBinaryWriter, + SandboxTextReader, + SandboxTextWriter, +) from vercel._internal.unstable.sandbox.async_runtime import ( CreateSandboxOperation, Process, @@ -26,12 +32,14 @@ SandboxError, SandboxFilesystemCommandError, SandboxFilesystemError, + SandboxFilesystemTransferError, SandboxFilesystemWriteError, SandboxInvalidHandleError, SandboxPathNotFoundError, SandboxResponseError, SandboxStreamError, SandboxTerminalStateError, + SandboxUploadSizeMismatchError, ) from vercel._internal.unstable.sandbox.models import ( CompletedProcess, @@ -306,6 +314,10 @@ async def get_snapshot(*, snapshot_id: str) -> Snapshot: __all__ = [ + "SandboxBinaryReader", + "SandboxBinaryWriter", + "SandboxTextReader", + "SandboxTextWriter", "Sandbox", "SandboxApiError", "SandboxCleanupError", @@ -318,9 +330,11 @@ async def get_snapshot(*, snapshot_id: str) -> Snapshot: "SandboxFilesystemBatch", "SandboxFilesystemCommandError", "SandboxFilesystemError", + "SandboxFilesystemTransferError", "SandboxFilesystemWriteError", "SandboxInvalidHandleError", "SandboxPathNotFoundError", + "SandboxUploadSizeMismatchError", "NetworkPolicy", "NetworkPolicyKeyValueMatcher", "NetworkPolicyMatcher", diff --git a/src/vercel/unstable/sandbox/sync.py b/src/vercel/unstable/sandbox/sync.py index 9f60678..a25aeee 100644 --- a/src/vercel/unstable/sandbox/sync.py +++ b/src/vercel/unstable/sandbox/sync.py @@ -9,12 +9,14 @@ SandboxError, SandboxFilesystemCommandError, 
SandboxFilesystemError, + SandboxFilesystemTransferError, SandboxFilesystemWriteError, SandboxInvalidHandleError, SandboxPathNotFoundError, SandboxResponseError, SandboxStreamError, SandboxTerminalStateError, + SandboxUploadSizeMismatchError, ) from vercel._internal.unstable.sandbox.models import ( CompletedProcess, @@ -47,6 +49,12 @@ from vercel._internal.unstable.sandbox.options import SandboxServiceOptions from vercel._internal.unstable.sandbox.service import SandboxService, get_sandbox_service from vercel._internal.unstable.sandbox.state import SnapshotRetentionState +from vercel._internal.unstable.sandbox.sync_filesystem_handle import ( + SyncSandboxBinaryReader, + SyncSandboxBinaryWriter, + SyncSandboxTextReader, + SyncSandboxTextWriter, +) from vercel._internal.unstable.sandbox.sync_runtime import ( SyncProcess, SyncSandbox, @@ -310,9 +318,11 @@ def get_snapshot(*, snapshot_id: str) -> SyncSnapshot: "SandboxError", "SandboxFilesystemCommandError", "SandboxFilesystemError", + "SandboxFilesystemTransferError", "SandboxFilesystemWriteError", "SandboxInvalidHandleError", "SandboxPathNotFoundError", + "SandboxUploadSizeMismatchError", "NetworkPolicy", "NetworkPolicyKeyValueMatcher", "NetworkPolicyMatcher", @@ -339,10 +349,14 @@ def get_snapshot(*, snapshot_id: str) -> SyncSnapshot: "SnapshotRetentionState", "SnapshotSource", "SyncSandbox", + "SyncSandboxBinaryReader", + "SyncSandboxBinaryWriter", "SyncProcess", "SyncSandboxFilesystem", "SyncSandboxFilesystemBatch", "SyncSandboxRuntimeSession", + "SyncSandboxTextReader", + "SyncSandboxTextWriter", "SyncSnapshot", "TagFilter", "TarballSource", diff --git a/tests/integration/test_http_transport_read_response.py b/tests/integration/test_http_transport_read_response.py index c68d24f..1a69ad7 100644 --- a/tests/integration/test_http_transport_read_response.py +++ b/tests/integration/test_http_transport_read_response.py @@ -22,18 +22,19 @@ async def __aiter__(self) -> AsyncIterator[bytes]: @pytest.mark.parametrize( - 
("read_response", "status_code", "expected_consumed"), + ("read_response", "status_code", "expected_consumed", "expected_closed"), [ - (None, 200, False), - (ReadResponsePolicy.ALWAYS, 200, True), - (ReadResponsePolicy.NON_SUCCESS_ONLY, 200, False), - (ReadResponsePolicy.NON_SUCCESS_ONLY, 400, True), + (None, 200, False, False), + (ReadResponsePolicy.ALWAYS, 200, True, True), + (ReadResponsePolicy.NON_SUCCESS_ONLY, 200, False, False), + (ReadResponsePolicy.NON_SUCCESS_ONLY, 400, True, True), ], ) def test_sync_transport_read_response_policy( read_response: ReadResponsePolicy | None, status_code: int, expected_consumed: bool, + expected_closed: bool, ) -> None: client = httpx.Client( transport=httpx.MockTransport( @@ -51,6 +52,7 @@ def test_sync_transport_read_response_policy( ) ) assert response.is_stream_consumed is expected_consumed + assert response.is_closed is expected_closed if expected_consumed: assert response.content == PAYLOAD finally: @@ -58,18 +60,19 @@ def test_sync_transport_read_response_policy( @pytest.mark.parametrize( - ("read_response", "status_code", "expected_consumed"), + ("read_response", "status_code", "expected_consumed", "expected_closed"), [ - (None, 200, False), - (ReadResponsePolicy.ALWAYS, 200, True), - (ReadResponsePolicy.NON_SUCCESS_ONLY, 200, False), - (ReadResponsePolicy.NON_SUCCESS_ONLY, 400, True), + (None, 200, False, False), + (ReadResponsePolicy.ALWAYS, 200, True, True), + (ReadResponsePolicy.NON_SUCCESS_ONLY, 200, False, False), + (ReadResponsePolicy.NON_SUCCESS_ONLY, 400, True, True), ], ) async def test_async_transport_read_response_policy( read_response: ReadResponsePolicy | None, status_code: int, expected_consumed: bool, + expected_closed: bool, ) -> None: client = httpx.AsyncClient( transport=httpx.MockTransport( @@ -85,6 +88,7 @@ async def test_async_transport_read_response_policy( "GET", "https://example.com", stream=True, read_response=read_response ) assert response.is_stream_consumed is expected_consumed + assert 
response.is_closed is expected_closed if expected_consumed: assert response.content == PAYLOAD finally: diff --git a/tests/integration/test_http_transport_streaming.py b/tests/integration/test_http_transport_streaming.py new file mode 100644 index 0000000..c17bea5 --- /dev/null +++ b/tests/integration/test_http_transport_streaming.py @@ -0,0 +1,548 @@ +import threading +from typing import cast + +import anyio +import httpx +import pytest + +from vercel._internal.http import ( + AsyncTransport, + ReadResponsePolicy, + StreamingResponse, + SyncTransport, +) +from vercel._internal.iter_coroutine import iter_coroutine + + +class _SyncChunks(httpx.SyncByteStream): + def __init__(self, chunks: list[bytes]) -> None: + self.chunks = chunks + self.closed = False + + def __iter__(self): # type: ignore[no-untyped-def] + yield from self.chunks + + def close(self) -> None: + self.closed = True + + +class _AsyncChunks(httpx.AsyncByteStream): + def __init__(self, chunks: list[bytes]) -> None: + self.chunks = chunks + self.closed = False + + async def __aiter__(self): # type: ignore[no-untyped-def] + for chunk in self.chunks: + yield chunk + + async def aclose(self) -> None: + self.closed = True + + +class _FailingSyncChunks(httpx.SyncByteStream): + def __init__(self, error: BaseException) -> None: + self.error = error + self.closed = False + + def __iter__(self): # type: ignore[no-untyped-def] + yield b"first\n" + raise self.error + + def close(self) -> None: + self.closed = True + + +class _FailingAsyncChunks(httpx.AsyncByteStream): + def __init__(self, error: BaseException) -> None: + self.error = error + self.closed = False + + async def __aiter__(self): # type: ignore[no-untyped-def] + yield b"first\n" + raise self.error + + async def aclose(self) -> None: + self.closed = True + + +def _encoded_line_chunks() -> list[bytes]: + content = "café\r\nsecond\rthird\nlast".encode("utf-16-le") + return [content[offset : offset + 3] for offset in range(0, len(content), 3)] + + 
+@pytest.mark.parametrize( + ("policy", "status", "is_consumed"), + [ + (ReadResponsePolicy.NON_SUCCESS_ONLY, 400, True), + (ReadResponsePolicy.NEVER, 201, False), + ], +) +def test_sync_request_stream_finishes_under_one_iter_coroutine( + policy: ReadResponsePolicy, status: int, is_consumed: bool +) -> None: + received: list[bytes] = [] + + def handler(request: httpx.Request) -> httpx.Response: + received.extend(cast(httpx.SyncByteStream, request.stream)) + return httpx.Response(status, stream=_SyncChunks([b"response"])) + + transport = SyncTransport(httpx.Client(transport=httpx.MockTransport(handler))) + + async def operation() -> StreamingResponse: + async with transport.request_stream( + "POST", + "https://example.com/upload", + read_response=policy, + response_chunk_size=3, + ) as request: + await request.write(b"first") + await request.write(memoryview(b"second")) # type: ignore[arg-type] + response = await request.finish() + assert response.response.status_code == status + with pytest.raises(anyio.ClosedResourceError): + await request.write(b"after finish") + return response + + response = iter_coroutine(operation()) + assert isinstance(response, StreamingResponse) + assert b"".join(received) == b"firstsecond" + assert response.response.is_stream_consumed is is_consumed + + async def consume() -> list[bytes]: + return [chunk async for chunk in response] + + assert iter_coroutine(consume()) == [b"res", b"pon", b"se"] + assert response.response.is_closed + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ("policy", "status", "is_consumed"), + [ + (ReadResponsePolicy.NON_SUCCESS_ONLY, 400, True), + (ReadResponsePolicy.NEVER, 201, False), + ], +) +async def test_async_request_stream_finishes( + policy: ReadResponsePolicy, status: int, is_consumed: bool +) -> None: + received: list[bytes] = [] + + async def handler(request: httpx.Request) -> httpx.Response: + async for chunk in cast(httpx.AsyncByteStream, request.stream): + received.append(chunk) + return 
httpx.Response(status, stream=_AsyncChunks([b"response"])) + + transport = AsyncTransport(httpx.AsyncClient(transport=httpx.MockTransport(handler))) + async with transport.request_stream( + "POST", + "https://example.com/upload", + read_response=policy, + response_chunk_size=3, + ) as request: + await request.write(b"first") + await request.write(bytearray(b"second")) # type: ignore[arg-type] + response = await request.finish() + assert response.response.status_code == status + with pytest.raises(anyio.ClosedResourceError): + await request.write(b"after finish") + + assert isinstance(response, StreamingResponse) + assert b"".join(received) == b"firstsecond" + assert response.response.is_stream_consumed is is_consumed + assert [chunk async for chunk in response] == [b"res", b"pon", b"se"] + assert response.response.is_closed + + +@pytest.mark.parametrize( + ("chunks", "preconsume", "expected"), + [ + ([], 0, b""), + ([b"single"], 0, b"single"), + ([b"one", b"two", b"three"], 0, b"onetwothree"), + ([b"one", b"two", b"three"], 1, b"twothree"), + ([b"one", b"two"], 2, b""), + ], +) +def test_sync_response_read_consumes_remaining_body_and_closes( + chunks: list[bytes], preconsume: int, expected: bytes +) -> None: + body = _SyncChunks(chunks) + + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(202, headers={"x-result": "ok"}, stream=body) + + transport = SyncTransport(httpx.Client(transport=httpx.MockTransport(handler))) + + async def operation() -> bytes: + stream = await transport.open_response_stream("GET", "https://example.com/result") + for _ in range(preconsume): + await anext(stream) + result = await stream.read() + assert stream.response.status_code == 202 + assert stream.response.headers["x-result"] == "ok" + return result + + assert iter_coroutine(operation()) == expected + assert body.closed + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ("chunks", "preconsume", "expected"), + [ + ([], 0, b""), + ([b"single"], 0, 
b"single"), + ([b"one", b"two", b"three"], 0, b"onetwothree"), + ([b"one", b"two", b"three"], 1, b"twothree"), + ([b"one", b"two"], 2, b""), + ], +) +async def test_async_response_read_consumes_remaining_body_and_closes( + chunks: list[bytes], preconsume: int, expected: bytes +) -> None: + body = _AsyncChunks(chunks) + + async def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(202, headers={"x-result": "ok"}, stream=body) + + transport = AsyncTransport(httpx.AsyncClient(transport=httpx.MockTransport(handler))) + stream = await transport.open_response_stream("GET", "https://example.com/result") + for _ in range(preconsume): + await anext(stream) + assert await stream.read() == expected + assert stream.response.status_code == 202 + assert stream.response.headers["x-result"] == "ok" + assert body.closed + + +def test_sync_response_read_closes_on_failure() -> None: + error = RuntimeError("stream failed") + body = _FailingSyncChunks(error) + + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(200, stream=body) + + transport = SyncTransport(httpx.Client(transport=httpx.MockTransport(handler))) + + async def operation() -> None: + stream = await transport.open_response_stream("GET", "https://example.com/result") + with pytest.raises(RuntimeError) as exc_info: + await stream.read() + assert exc_info.value is error + + iter_coroutine(operation()) + assert body.closed + + +@pytest.mark.anyio +async def test_async_response_read_closes_on_failure() -> None: + error = RuntimeError("stream failed") + body = _FailingAsyncChunks(error) + + async def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(200, stream=body) + + transport = AsyncTransport(httpx.AsyncClient(transport=httpx.MockTransport(handler))) + stream = await transport.open_response_stream("GET", "https://example.com/result") + with pytest.raises(RuntimeError) as exc_info: + await stream.read() + assert exc_info.value is error + assert 
body.closed + + +def test_sync_request_scope_implicitly_aborts() -> None: + body_closed = threading.Event() + + class BlockingClient(httpx.Client): + def send(self, request: httpx.Request, **kwargs): # type: ignore[no-untyped-def, override] + try: + list(cast(httpx.SyncByteStream, request.stream)) + finally: + body_closed.set() + return httpx.Response(204) + + transport = SyncTransport(BlockingClient()) + retained = None + + async def operation() -> None: + nonlocal retained + async with transport.request_stream("POST", "https://example.com/upload") as request: + retained = request + + iter_coroutine(operation()) + assert body_closed.is_set() + assert retained is not None + with pytest.raises(anyio.ClosedResourceError): + iter_coroutine(retained.write(b"after context")) + + +@pytest.mark.anyio +@pytest.mark.parametrize("raise_in_body", [False, True]) +async def test_async_request_scope_implicitly_aborts(raise_in_body: bool) -> None: + body_closed = anyio.Event() + body_error = RuntimeError("body failed") + retained = None + + class BlockingClient(httpx.AsyncClient): + async def send( + self, + request: httpx.Request, + **kwargs, # type: ignore[no-untyped-def, override] + ) -> httpx.Response: + try: + async for _ in cast(httpx.AsyncByteStream, request.stream): + pass + finally: + body_closed.set() + return httpx.Response(204) + + transport = AsyncTransport(BlockingClient()) + + async def operation() -> None: + nonlocal retained + async with transport.request_stream("POST", "https://example.com/upload") as request: + retained = request + await request.write(b"first") + if raise_in_body: + raise body_error + + if raise_in_body: + with pytest.raises(RuntimeError) as exc_info: + await operation() + assert exc_info.value is body_error + else: + await operation() + assert body_closed.is_set() + assert retained is not None + with pytest.raises(anyio.ClosedResourceError): + await retained.write(b"after context") + + +@pytest.mark.anyio +async def 
test_async_request_entry_cancellation_cleans_up_worker() -> None: + stopped = anyio.Event() + + class BlockingTransport(httpx.AsyncBaseTransport): + async def handle_async_request(self, request: httpx.Request) -> httpx.Response: + try: + await anyio.sleep_forever() + finally: + stopped.set() + raise AssertionError("unreachable") + + transport = AsyncTransport(httpx.AsyncClient(transport=BlockingTransport())) + entered = False + with anyio.move_on_after(0) as scope: + async with transport.request_stream("POST", "https://example.com/upload"): + entered = True + assert scope.cancel_called + assert not entered + assert stopped.is_set() + + +@pytest.mark.anyio +async def test_async_request_cancellation_during_blocked_write_aborts() -> None: + class BlockingTransport(httpx.AsyncBaseTransport): + async def handle_async_request(self, request: httpx.Request) -> httpx.Response: + await anyio.sleep_forever() + raise AssertionError("unreachable") + + transport = AsyncTransport(httpx.AsyncClient(transport=BlockingTransport())) + async with transport.request_stream("POST", "https://example.com/upload") as request: + await request.write(b"first") + with anyio.move_on_after(0.01) as scope: + await request.write(b"second") + assert scope.cancel_called + with pytest.raises(anyio.ClosedResourceError): + await request.write(b"after cancellation") + + +@pytest.mark.anyio +async def test_async_request_stream_has_one_chunk_backpressure() -> None: + consume = anyio.Event() + second_started = anyio.Event() + second_finished = anyio.Event() + + class BlockingClient(httpx.AsyncClient): + async def send( + self, + request: httpx.Request, + **kwargs, # type: ignore[no-untyped-def, override] + ) -> httpx.Response: + await consume.wait() + async for _ in cast(httpx.AsyncByteStream, request.stream): + pass + return httpx.Response(204) + + transport = AsyncTransport(BlockingClient()) + async with transport.request_stream("POST", "https://example.com/upload") as request: + await 
request.write(b"first") + + async def write_second() -> None: + second_started.set() + await request.write(b"second") + second_finished.set() + + async with anyio.create_task_group() as tasks: + tasks.start_soon(write_second) + await second_started.wait() + await anyio.lowlevel.checkpoint() + assert not second_finished.is_set() + consume.set() + response = await request.finish() + await response.aclose() + + +def test_sync_request_stream_has_one_chunk_backpressure() -> None: + consume = threading.Event() + second_started = threading.Event() + second_finished = threading.Event() + errors: list[BaseException] = [] + + class BlockingClient(httpx.Client): + def send(self, request: httpx.Request, **kwargs): # type: ignore[no-untyped-def, override] + consume.wait() + list(cast(httpx.SyncByteStream, request.stream)) + return httpx.Response(204) + + transport = SyncTransport(BlockingClient()) + + async def operation() -> None: + async with transport.request_stream("POST", "https://example.com/upload") as request: + await request.write(b"first") + second_started.set() + await request.write(b"second") + second_finished.set() + response = await request.finish() + await response.aclose() + + def run() -> None: + try: + iter_coroutine(operation()) + except BaseException as error: + errors.append(error) + + worker = threading.Thread(target=run) + worker.start() + assert second_started.wait(timeout=1) + assert not second_finished.is_set() + consume.set() + worker.join(timeout=1) + assert not worker.is_alive() + assert second_finished.is_set() + assert errors == [] + + +def test_sync_request_stream_preserves_worker_error_identity() -> None: + error = RuntimeError("worker failed") + + def handler(request: httpx.Request) -> httpx.Response: + raise error + + transport = SyncTransport(httpx.Client(transport=httpx.MockTransport(handler))) + + async def operation() -> None: + async with transport.request_stream("POST", "https://example.com/upload") as request: + with 
pytest.raises(RuntimeError) as exc_info: + await request.finish() + assert exc_info.value is error + await request.abort() + await request.abort() + + iter_coroutine(operation()) + + +@pytest.mark.anyio +async def test_async_request_stream_preserves_worker_error_identity() -> None: + error = RuntimeError("worker failed") + + async def handler(request: httpx.Request) -> httpx.Response: + raise error + + transport = AsyncTransport(httpx.AsyncClient(transport=httpx.MockTransport(handler))) + async with transport.request_stream("POST", "https://example.com/upload") as request: + with pytest.raises(RuntimeError) as exc_info: + await request.finish() + assert exc_info.value is error + await request.abort() + await request.abort() + with pytest.raises(anyio.ClosedResourceError): + await request.write(b"after abort") + + +def test_sync_response_line_stream_uses_httpx_decoding_and_closes() -> None: + body = _SyncChunks(_encoded_line_chunks()) + + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response( + 200, + headers={"content-type": "text/plain; charset=utf-16-le"}, + stream=body, + ) + + transport = SyncTransport(httpx.Client(transport=httpx.MockTransport(handler))) + + async def operation() -> list[str]: + stream = await transport.open_response_stream("GET", "https://example.com/logs") + return [line async for line in stream.aiter_lines()] + + assert iter_coroutine(operation()) == ["café", "second", "third", "last"] + assert body.closed + + +@pytest.mark.anyio +async def test_async_response_line_stream_uses_httpx_decoding_and_closes() -> None: + body = _AsyncChunks(_encoded_line_chunks()) + + async def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response( + 200, + headers={"content-type": "text/plain; charset=utf-16-le"}, + stream=body, + ) + + transport = AsyncTransport(httpx.AsyncClient(transport=httpx.MockTransport(handler))) + stream = await transport.open_response_stream("GET", "https://example.com/logs") + assert [line 
async for line in stream.aiter_lines()] == [ + "café", + "second", + "third", + "last", + ] + assert body.closed + + +def test_sync_response_line_stream_closes_on_failure() -> None: + error = RuntimeError("stream failed") + body = _FailingSyncChunks(error) + + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(200, stream=body) + + transport = SyncTransport(httpx.Client(transport=httpx.MockTransport(handler))) + + async def operation() -> None: + stream = await transport.open_response_stream("GET", "https://example.com/logs") + with pytest.raises(RuntimeError) as exc_info: + [line async for line in stream.aiter_lines()] + assert exc_info.value is error + + iter_coroutine(operation()) + assert body.closed + + +@pytest.mark.anyio +async def test_async_response_line_stream_closes_on_failure() -> None: + error = RuntimeError("stream failed") + body = _FailingAsyncChunks(error) + + async def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(200, stream=body) + + transport = AsyncTransport(httpx.AsyncClient(transport=httpx.MockTransport(handler))) + stream = await transport.open_response_stream("GET", "https://example.com/logs") + with pytest.raises(RuntimeError) as exc_info: + [line async for line in stream.aiter_lines()] + assert exc_info.value is error + assert body.closed diff --git a/tests/live/_unstable_scenarios.py b/tests/live/_unstable_scenarios.py index b58819a..9d74dc8 100644 --- a/tests/live/_unstable_scenarios.py +++ b/tests/live/_unstable_scenarios.py @@ -1,12 +1,15 @@ """Shared live scenarios for the experimental Sandbox public API.""" import asyncio +import hashlib import subprocess +import tempfile import time from collections.abc import AsyncIterator from contextlib import asynccontextmanager from dataclasses import dataclass from datetime import timedelta +from pathlib import Path from typing import Any from vercel import unstable as vercel @@ -70,6 +73,14 @@ class ProcessFilesystemObservation: 
invalid_write_failed: bool +@dataclass(frozen=True, slots=True) +class StreamingTransferObservation: + digest_matches: bool + empty_matches: bool + explicit_mode: str + missing_download_failed: bool + + @dataclass(frozen=True, slots=True) class NetworkPolicyObservation: allow_all_created: bool @@ -143,6 +154,14 @@ async def read_bytes(self, box: Any, path: str) -> bytes: async def write_bytes(self, box: Any, path: str, data: bytes) -> None: raise NotImplementedError + async def write_path( + self, box: Any, remote: str, local: Path, *, mode: int | None = None + ) -> None: + raise NotImplementedError + + async def read_path(self, box: Any, remote: str, local: Path) -> int: + raise NotImplementedError + async def exists(self, box: Any, path: str) -> bool: raise NotImplementedError @@ -280,7 +299,8 @@ async def read_text(self, box: Any, path: str) -> str: return await box.fs.read_text(path) async def write_text(self, box: Any, path: str, text: str) -> None: - await box.fs.write_text(path, text) + async with box.fs.open(path, "w") as target: + await target.write(text) async def read_bytes(self, box: Any, path: str) -> bytes: return await box.fs.read_bytes(path) @@ -288,6 +308,31 @@ async def read_bytes(self, box: Any, path: str) -> bytes: async def write_bytes(self, box: Any, path: str, data: bytes) -> None: await box.fs.write_bytes(path, data) + async def write_path( + self, box: Any, remote: str, local: Path, *, mode: int | None = None + ) -> None: + import anyio + + async with ( + await anyio.open_file(local, "rb") as source, + box.fs.open(remote, "wb", permissions=mode) as target, + ): + while chunk := await source.read(64 * 1024): + await target.write(chunk) + + async def read_path(self, box: Any, remote: str, local: Path) -> int: + import anyio + + copied = 0 + async with ( + box.fs.open(remote, "rb") as source, + await anyio.open_file(local, "wb") as target, + ): + while chunk := await source.read(64 * 1024): + await target.write(chunk) + copied += len(chunk) + 
return copied + async def exists(self, box: Any, path: str) -> bool: return await box.fs.exists(path) @@ -454,7 +499,8 @@ async def read_text(self, box: Any, path: str) -> str: return box.fs.read_text(path) async def write_text(self, box: Any, path: str, text: str) -> None: - box.fs.write_text(path, text) + with box.fs.open(path, "w") as target: + target.write(text) async def read_bytes(self, box: Any, path: str) -> bytes: return box.fs.read_bytes(path) @@ -462,6 +508,24 @@ async def read_bytes(self, box: Any, path: str) -> bytes: async def write_bytes(self, box: Any, path: str, data: bytes) -> None: box.fs.write_bytes(path, data) + async def write_path( + self, box: Any, remote: str, local: Path, *, mode: int | None = None + ) -> None: + with ( + local.open("rb") as source, + box.fs.open(remote, "wb", permissions=mode) as target, + ): + while chunk := source.read(64 * 1024): + target.write(chunk) + + async def read_path(self, box: Any, remote: str, local: Path) -> int: + copied = 0 + with box.fs.open(remote, "rb") as source, local.open("wb") as target: + while chunk := source.read(64 * 1024): + target.write(chunk) + copied += len(chunk) + return copied + async def exists(self, box: Any, path: str) -> bool: return box.fs.exists(path) @@ -642,6 +706,65 @@ async def process_filesystem_flow( ) +async def streaming_transfer_flow( + driver: _ScenarioDriver, name: str +) -> StreamingTransferObservation: + payload = bytes(range(256)) * 1025 + expected_digest = hashlib.sha256(payload).digest() + remote_paths = ("large.bin", "empty.bin") + + with tempfile.TemporaryDirectory() as directory: + source = Path(directory) / "source.bin" + empty_source = Path(directory) / "empty.bin" + target = Path(directory) / "target.bin" + empty_target = Path(directory) / "empty-target.bin" + missing_target = Path(directory) / "missing.bin" + source.write_bytes(payload) + empty_source.write_bytes(b"") + + async with driver.session(): + async with driver.ephemeral_sandbox(name) as box: + try: + 
await driver.write_path(box, remote_paths[0], source, mode=0o600) + await driver.write_path(box, remote_paths[1], empty_source) + copied = await driver.read_path(box, remote_paths[0], target) + empty_copied = await driver.read_path(box, remote_paths[1], empty_target) + + command = await driver.create_process( + box, + "python", + [ + "-c", + "import os; print(oct(os.stat('large.bin').st_mode & 0o777))", + ], + ) + stdout, _ = await driver.read_process_streams(command) + await driver.wait(command) + + try: + await driver.read_path(box, "missing.bin", missing_target) + except SandboxPathNotFoundError: + missing_download_failed = True + else: + missing_download_failed = False + finally: + for remote_path in remote_paths: + try: + await driver.remove(box, remote_path) + except SandboxPathNotFoundError: + pass + + return StreamingTransferObservation( + digest_matches=( + copied == len(payload) + and hashlib.sha256(target.read_bytes()).digest() == expected_digest + ), + empty_matches=empty_copied == 0 and empty_target.read_bytes() == b"", + explicit_mode=stdout.strip(), + missing_download_failed=missing_download_failed, + ) + + async def network_policy_flow(driver: _ScenarioDriver, name: str) -> NetworkPolicyObservation: box = None allow_all_created = False diff --git a/tests/live/test_unstable_sandbox_live.py b/tests/live/test_unstable_sandbox_live.py index 19a6a28..b70ff45 100644 --- a/tests/live/test_unstable_sandbox_live.py +++ b/tests/live/test_unstable_sandbox_live.py @@ -9,11 +9,13 @@ NetworkPolicyObservation, PersistentObservation, ProcessFilesystemObservation, + StreamingTransferObservation, SyncDriver, WorkspaceObservation, network_policy_flow, persistent_snapshot_flow, process_filesystem_flow, + streaming_transfer_flow, workspace_command_flow, ) from .conftest import requires_sandbox_credentials @@ -97,6 +99,22 @@ async def test_process_filesystem_flow_has_sync_async_semantic_parity() -> None: _assert_process_filesystem(sync_result) 
+@requires_sandbox_credentials +@pytest.mark.live +@pytest.mark.asyncio +async def test_streaming_transfer_flow() -> None: + expected = StreamingTransferObservation( + digest_matches=True, + empty_matches=True, + explicit_mode="0o600", + missing_download_failed=True, + ) + async_result = await streaming_transfer_flow(AsyncDriver(), _name("transfer", "async")) + sync_result = await streaming_transfer_flow(SyncDriver(), _name("transfer", "sync")) + assert async_result == expected + assert sync_result == expected + + @requires_sandbox_credentials @pytest.mark.live @pytest.mark.asyncio diff --git a/tests/unit/test_byte_stream_runtime.py b/tests/unit/test_byte_stream_runtime.py new file mode 100644 index 0000000..a3f469e --- /dev/null +++ b/tests/unit/test_byte_stream_runtime.py @@ -0,0 +1,169 @@ +import io +import threading + +import anyio +import pytest + +from vercel._internal.byte_stream import ( + AsyncByteStreamRuntime, + StagingFileRuntime, + SyncByteStreamRuntime, +) +from vercel._internal.iter_coroutine import iter_coroutine + + +class _SyncReader: + def __init__(self, data: bytes) -> None: + self._source = io.BytesIO(data) + + def read(self, size: int = -1, /) -> bytes: + return self._source.read(size) + + +class _AsyncReader: + def __init__(self, data: bytes) -> None: + self._source = io.BytesIO(data) + + async def read(self, size: int = -1, /) -> bytes: + return self._source.read(size) + + +async def _assert_bytes_like_readers(runtime: SyncByteStreamRuntime) -> None: + for value in (b"bytes", bytearray(b"bytearray"), memoryview(b"memoryview")): + source = runtime.reader(value) + assert await source.read(4) == bytes(value)[:4] + assert await source.read() == bytes(value)[4:] + + +def test_sync_runtime_reader_operations_never_suspend() -> None: + runtime = SyncByteStreamRuntime() + + async def operation() -> None: + await _assert_bytes_like_readers(runtime) + sync_source = runtime.reader(_SyncReader(b"sync")) + assert await sync_source.read(2) == b"sy" + 
assert await sync_source.read() == b"nc" + + iter_coroutine(operation()) + + +def test_sync_runtime_rejects_async_reader() -> None: + with pytest.raises(TypeError, match="does not support async readers"): + SyncByteStreamRuntime().reader(_AsyncReader(b"async")) # type: ignore[arg-type] + + +@pytest.mark.anyio +async def test_async_runtime_adapts_async_readers() -> None: + runtime = AsyncByteStreamRuntime() + async_source = runtime.reader(_AsyncReader(b"async")) + assert await async_source.read(2) == b"as" + assert await async_source.read() == b"ync" + + +@pytest.mark.anyio +async def test_async_runtime_runs_sync_reader_on_worker_thread() -> None: + caller_thread = threading.get_ident() + reader_thread: int | None = None + + class Reader(_SyncReader): + def read(self, size: int = -1, /) -> bytes: + nonlocal reader_thread + reader_thread = threading.get_ident() + return super().read(size) + + runtime = AsyncByteStreamRuntime() + sync_source = runtime.reader(Reader(b"sync")) + assert await sync_source.read(2) == b"sy" + assert await sync_source.read() == b"nc" + assert reader_thread is not None + assert reader_thread != caller_thread + + +def test_sync_runtime_rejects_invalid_and_non_bytes_readers() -> None: + class MissingReader: + pass + + class NonCallableReader: + read = b"not callable" + + class BadSyncReader: + def read(self, size: int = -1, /) -> str: + return "not bytes" + + runtime = SyncByteStreamRuntime() + for missing in (MissingReader(), NonCallableReader()): + with pytest.raises(TypeError, match="callable read method"): + runtime.reader(missing) # type: ignore[arg-type] + + source = runtime.reader(BadSyncReader()) # type: ignore[arg-type] + + async def operation() -> None: + with pytest.raises(TypeError, match="returned str, expected bytes"): + await source.read() + + iter_coroutine(operation()) + + +@pytest.mark.anyio +async def test_async_runtime_rejects_invalid_and_non_bytes_readers() -> None: + class MissingReader: + pass + + class NonCallableReader: 
+ read = b"not callable" + + class BadSyncReader: + def read(self, size: int = -1, /) -> str: + return "not bytes" + + class BadAsyncReader: + async def read(self, size: int = -1, /) -> str: + return "not bytes" + + runtime = AsyncByteStreamRuntime() + for missing in (MissingReader(), NonCallableReader()): + with pytest.raises(TypeError, match="callable read method"): + runtime.reader(missing) # type: ignore[arg-type] + + for bad in (BadSyncReader(), BadAsyncReader()): + source = runtime.reader(bad) # type: ignore[arg-type] + with pytest.raises(TypeError, match="returned str, expected bytes"): + await source.read() + + +def test_sync_temporary_file_context_never_suspends_and_owns_cleanup() -> None: + runtime: StagingFileRuntime = SyncByteStreamRuntime() + + async def operation() -> None: + async with runtime.temporary_file() as temporary: + await temporary.write(b"temporary") + + with pytest.raises(anyio.ClosedResourceError): + await temporary.read() + + with pytest.raises(ValueError, match="stop"): + async with runtime.temporary_file() as failed: + raise ValueError("stop") + + with pytest.raises(anyio.ClosedResourceError): + await failed.read() + + iter_coroutine(operation()) + + +@pytest.mark.anyio +async def test_async_temporary_file_context_owns_cleanup() -> None: + runtime: StagingFileRuntime = AsyncByteStreamRuntime() + + async with runtime.temporary_file() as temporary: + await temporary.write(b"temporary") + + with pytest.raises((anyio.ClosedResourceError, ValueError)): + await temporary.read() + + with pytest.raises(ValueError, match="stop"): + async with runtime.temporary_file() as failed: + raise ValueError("stop") + + with pytest.raises((anyio.ClosedResourceError, ValueError)): + await failed.read() diff --git a/tests/unstable/test_sandbox_api_client.py b/tests/unstable/test_sandbox_api_client.py index b28eb02..a41d339 100644 --- a/tests/unstable/test_sandbox_api_client.py +++ b/tests/unstable/test_sandbox_api_client.py @@ -4,9 +4,15 @@ import pytest 
from httpx._types import HeaderTypes, QueryParamTypes -from vercel._internal.http import BaseTransport, ReadResponsePolicy, RequestBody -from vercel._internal.unstable.sandbox.api_client import SandboxApiClient -from vercel._internal.unstable.sandbox.errors import SandboxResponseError +from vercel._internal.http import ( + BaseTransport, + ReadResponsePolicy, + RequestBody, + StreamingRequest, + StreamingResponse, +) +from vercel._internal.unstable.sandbox.api_client import SandboxApiClient, _WriteFilesUpload +from vercel._internal.unstable.sandbox.errors import SandboxApiError, SandboxResponseError from vercel._internal.unstable.sandbox.options import SandboxCredentials from vercel._internal.url import format_url_path @@ -37,7 +43,37 @@ async def send( ) -async def test_invalid_json_response_raises_response_error(mock_env_clear: None) -> None: +class _CompletedResponse(StreamingResponse): + def __init__(self, response: httpx.Response) -> None: + self.response = response + self.closed = False + + async def __anext__(self) -> bytes: + raise StopAsyncIteration + + async def aiter_lines(self): # type: ignore[no-untyped-def] + if False: + yield "" + + async def aclose(self) -> None: + self.closed = True + + +class _CompletedRequest(StreamingRequest): + def __init__(self, response: _CompletedResponse) -> None: + self.response = response + + async def write(self, data: bytes) -> None: + raise NotImplementedError + + async def finish(self) -> StreamingResponse: + return self.response + + async def abort(self) -> None: + raise NotImplementedError + + +def _sandbox_client(transport: BaseTransport) -> SandboxApiClient: async def credentials_factory() -> SandboxCredentials: return SandboxCredentials( token="token", @@ -45,18 +81,42 @@ async def credentials_factory() -> SandboxCredentials: project_id="prj_123", ) - transport = InvalidJsonTransport() - client = SandboxApiClient( + return SandboxApiClient( base_url="https://sandbox.test", credentials_factory=credentials_factory, 
transport=transport, + file_transfer_timeout=timedelta(minutes=5), ) + +async def test_invalid_json_response_raises_response_error(mock_env_clear: None) -> None: + transport = InvalidJsonTransport() + client = _sandbox_client(transport) + with pytest.raises(SandboxResponseError): await client.get_sandbox(name="preview") assert transport.paths == ["https://sandbox.test/v2/sandboxes/preview"] +@pytest.mark.parametrize("status", [204, 400]) +async def test_write_files_upload_finish_closes_stream(status: int) -> None: + raw_response = httpx.Response( + status, + json={"error": {"message": "upload failed"}}, + request=httpx.Request("POST", "https://sandbox.test/upload"), + ) + stream = _CompletedResponse(raw_response) + upload = _WriteFilesUpload(_CompletedRequest(stream)) + + if status < 400: + await upload.finish() + else: + with pytest.raises(SandboxApiError): + await upload.finish() + + assert stream.closed + + def test_format_url_path_quotes_placeholder_values() -> None: assert format_url_path( "v2/sandboxes/{name}/{command_id}", diff --git a/tests/unstable/test_sandbox_filesystem.py b/tests/unstable/test_sandbox_filesystem.py index d8183f4..76eb12a 100644 --- a/tests/unstable/test_sandbox_filesystem.py +++ b/tests/unstable/test_sandbox_filesystem.py @@ -1,12 +1,14 @@ import io import json import tarfile +from collections.abc import AsyncIterator, Iterator from pathlib import PurePosixPath from typing import cast import httpx import pytest import respx +from hypothesis import HealthCheck, given, settings, strategies as st from vercel import unstable as vercel from vercel._internal.unstable.sandbox.options import SandboxCredentials @@ -18,10 +20,46 @@ SandboxFilesystemWriteError, SandboxPathNotFoundError, SandboxServiceOptions, + SandboxUploadSizeMismatchError, sync as sandbox_sync, ) +async def _read_as_chunks(source: bytes, chunk_size: int) -> AsyncIterator[bytes]: + offset = 0 + while offset < len(source): + chunk = source[offset : offset + chunk_size] + yield 
chunk + offset += chunk_size + + +class _SyncByteReader: + def __init__(self, data: bytes) -> None: + self._data = data + self._offset = 0 + + def read(self, n: int = -1) -> bytes: + if n < 0: + chunk = self._data[self._offset :] + self._offset = len(self._data) + else: + chunk = self._data[self._offset : self._offset + n] + self._offset += len(chunk) + return chunk + + +class _SyncByteWriter: + def __init__(self) -> None: + self.written: list[bytes] = [] + self.closed = False + + def write(self, data: bytes) -> None: + self.written.append(data) + + def close(self) -> None: + self.closed = True + + def _sandbox_response(session_id: str = "sbx_1") -> dict[str, object]: return { "sandbox": { @@ -83,7 +121,37 @@ def _tar_entries(content: bytes) -> dict[str, tuple[bytes, int]]: extracted = archive.extractfile(member) assert extracted is not None entries[member.name] = (extracted.read(), member.mode) - return entries + return entries + + +@pytest.mark.parametrize( + ("mode", "options", "error_type"), + [ + ("invalid", {}, ValueError), + ("rb", {"encoding": "utf-8"}, ValueError), + ("r", {"size": 1}, ValueError), + ("rb", {"permissions": 0o600}, ValueError), + ("wb", {"size": -1}, ValueError), + ("wb", {"size": True}, TypeError), + ("wb", {"permissions": 0o1000}, ValueError), + ("wb", {"permissions": True}, TypeError), + ], +) +@respx.mock +async def test_filesystem_open_rejects_invalid_options( + mock_env_clear: None, + mode: str, + options: dict[str, object], + error_type: type[Exception], +) -> None: + respx.post("https://sandbox.test/v2/sandboxes").mock( + return_value=httpx.Response(200, json=_sandbox_response()) + ) + + async with vercel.session(service_options=_session_options()): + box = await sandbox.create_sandbox(name="preview", runtime="python3.13") + with pytest.raises(error_type): + box.fs.open("data.bin", mode, **options) # type: ignore[call-overload] @respx.mock @@ -138,6 +206,56 @@ async def test_async_filesystem_native_operations_and_write_composition( 
} +@respx.mock +async def test_async_unknown_size_writer_publishes_temporary_spool( + mock_env_clear: None, +) -> None: + respx.post("https://sandbox.test/v2/sandboxes").mock( + return_value=httpx.Response(200, json=_sandbox_response()) + ) + write = respx.post("https://sandbox.test/v2/sandboxes/sessions/sbx_1/fs/write").mock( + return_value=httpx.Response(204) + ) + + async with vercel.session(service_options=_session_options()): + box = await sandbox.create_sandbox(name="preview", runtime="python3.13") + async with box.fs.open("spooled.bin", "wb") as writer: + await writer.write(b"spooled") + await writer.write(b" data") + assert write.call_count == 0 + + assert _tar_entries(write.calls[0].request.content) == { + "vercel/sandbox/spooled.bin": (b"spooled data", 0o644) + } + + +@pytest.mark.parametrize("content", [b"", b"abc"]) +@respx.mock +async def test_async_binary_writer_rejects_incomplete_declared_size( + mock_env_clear: None, content: bytes +) -> None: + respx.post("https://sandbox.test/v2/sandboxes").mock( + return_value=httpx.Response(200, json=_sandbox_response()) + ) + respx.post("https://sandbox.test/v2/sandboxes/sessions/sbx_1/fs/write").mock( + return_value=httpx.Response(204) + ) + + async with vercel.session(service_options=_session_options()): + box = await sandbox.create_sandbox(name="preview", runtime="python3.13") + with pytest.raises(SandboxUploadSizeMismatchError) as exc_info: + async with box.fs.open("data.bin", "wb", size=4) as writer: + await writer.write(content) + + error = exc_info.value + assert (error.path, error.declared, error.consumed, error.early_end) == ( + "data.bin", + 4, + len(content), + True, + ) + + @respx.mock async def test_filesystem_write_wraps_api_error(mock_env_clear: None) -> None: respx.post("https://sandbox.test/v2/sandboxes").mock( @@ -178,6 +296,18 @@ async def test_filesystem_target_binding_tracks_sandbox_but_not_runtime_session( second = 
respx.post("https://sandbox.test/v2/sandboxes/sessions/sbx_2/fs/mkdir").mock( return_value=httpx.Response(204) ) + first_write = respx.post("https://sandbox.test/v2/sandboxes/sessions/sbx_1/fs/write").mock( + return_value=httpx.Response(204) + ) + second_write = respx.post("https://sandbox.test/v2/sandboxes/sessions/sbx_2/fs/write").mock( + return_value=httpx.Response(204) + ) + first_read = respx.post("https://sandbox.test/v2/sandboxes/sessions/sbx_1/fs/read").mock( + return_value=httpx.Response(200, content=b"original") + ) + second_read = respx.post("https://sandbox.test/v2/sandboxes/sessions/sbx_2/fs/read").mock( + return_value=httpx.Response(200, content=b"current") + ) async with vercel.session(service_options=_session_options()): box = await sandbox.create_sandbox(name="preview", runtime="python3.13") @@ -187,9 +317,21 @@ async def test_filesystem_target_binding_tracks_sandbox_but_not_runtime_session( await box.update(current_snapshot_id="snap_1") await retained_box_fs.mkdir("current") await retained_session_fs.mkdir("original") + async with retained_box_fs.open("current.bin", "wb", size=7) as current_writer: + await current_writer.write(b"current") + async with retained_session_fs.open("original.bin", "wb", size=8) as original_writer: + await original_writer.write(b"original") + async with retained_box_fs.open("current.bin", "rb") as current_reader: + assert await current_reader.read() == b"current" + async with retained_session_fs.open("original.bin", "rb") as original_reader: + assert await original_reader.read() == b"original" assert second.call_count == 1 assert first.call_count == 1 + assert second_write.call_count == 1 + assert first_write.call_count == 1 + assert second_read.call_count == 1 + assert first_read.call_count == 1 @respx.mock @@ -347,6 +489,54 @@ def test_sync_filesystem_capability_uses_sync_boundary(mock_env_clear: None) -> assert not hasattr(box.current_session, method) +@respx.mock +def 
test_sync_unknown_size_writer_publishes_temporary_spool(mock_env_clear: None) -> None: + respx.post("https://sandbox.test/v2/sandboxes").mock( + return_value=httpx.Response(200, json=_sandbox_response()) + ) + write = respx.post("https://sandbox.test/v2/sandboxes/sessions/sbx_1/fs/write").mock( + return_value=httpx.Response(204) + ) + + with vercel.session(service_options=_session_options()): + box = sandbox_sync.create_sandbox(name="preview", runtime="python3.13") + with box.fs.open("spooled.bin", "wb") as writer: + writer.write(b"spooled") + writer.write(b" data") + assert write.call_count == 0 + + assert _tar_entries(write.calls[0].request.content) == { + "vercel/sandbox/spooled.bin": (b"spooled data", 0o644) + } + + +@pytest.mark.parametrize("content", [b"", b"abc"]) +@respx.mock +def test_sync_binary_writer_rejects_incomplete_declared_size( + mock_env_clear: None, content: bytes +) -> None: + respx.post("https://sandbox.test/v2/sandboxes").mock( + return_value=httpx.Response(200, json=_sandbox_response()) + ) + respx.post("https://sandbox.test/v2/sandboxes/sessions/sbx_1/fs/write").mock( + return_value=httpx.Response(204) + ) + + with vercel.session(service_options=_session_options()): + box = sandbox_sync.create_sandbox(name="preview", runtime="python3.13") + with pytest.raises(SandboxUploadSizeMismatchError) as exc_info: + with box.fs.open("data.bin", "wb", size=4) as writer: + writer.write(content) + + error = exc_info.value + assert (error.path, error.declared, error.consumed, error.early_end) == ( + "data.bin", + 4, + len(content), + True, + ) + + @respx.mock def test_sync_filesystem_batch_stages_one_request(mock_env_clear: None) -> None: respx.post("https://sandbox.test/v2/sandboxes").mock( @@ -371,3 +561,213 @@ def test_sync_filesystem_batch_stages_one_request(mock_env_clear: None) -> None: "tmp/data.bin": (b"\x01", 0o600), "tmp/message.txt": (b"hello", 0o644), } + + +class _TrackedAsyncStream(httpx.AsyncByteStream): + def __init__(self, chunks: 
list[bytes], *, failure: BaseException | None = None) -> None: + self._chunks = chunks + self._failure = failure + self.aclose_called = False + + async def __aiter__(self) -> "AsyncIterator[bytes]": + for chunk in self._chunks: + yield chunk + if self._failure is not None: + raise self._failure + + async def aclose(self) -> None: + self.aclose_called = True + + +class _TrackedSyncStream(httpx.SyncByteStream): + def __init__(self, chunks: list[bytes], *, failure: BaseException | None = None) -> None: + self._chunks = chunks + self._failure = failure + self.close_called = False + + def __iter__(self) -> Iterator[bytes]: + yield from self._chunks + if self._failure is not None: + raise self._failure + + def close(self) -> None: + self.close_called = True + + +@given( + prefix=st.text( + alphabet=st.characters(min_codepoint=0, max_codepoint=0xD7FF, exclude_characters="\r\n") + ), + suffix=st.text( + alphabet=st.characters(min_codepoint=0, max_codepoint=0xD7FF, exclude_characters="\r\n") + ), + chunk_size=st.integers(min_value=1, max_value=8), +) +@settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) +async def test_text_reader_preserves_crlf_split_across_chunks( + mock_env_clear: None, + prefix: str, + suffix: str, + chunk_size: int, +) -> None: + prefix_bytes = prefix.encode() + suffix_bytes = suffix.encode() + chunks = [ + *( + prefix_bytes[offset : offset + chunk_size] + for offset in range(0, len(prefix_bytes), chunk_size) + ), + b"\r", + b"\n", + *( + suffix_bytes[offset : offset + chunk_size] + for offset in range(0, len(suffix_bytes), chunk_size) + ), + ] + + with respx.mock: + respx.post("https://sandbox.test/v2/sandboxes").mock( + return_value=httpx.Response(200, json=_sandbox_response()) + ) + respx.post("https://sandbox.test/v2/sandboxes/sessions/sbx_1/fs/read").mock( + return_value=httpx.Response(200, stream=_TrackedAsyncStream(chunks)) + ) + + async with vercel.session(service_options=_session_options()): + box = await 
sandbox.create_sandbox(name="preview", runtime="python3.13") + async with box.fs.open("data.txt", "r", newline="") as reader: + assert await reader.readline() == f"{prefix}\r\n" + assert await reader.read() == suffix + + +@respx.mock +async def test_read_bytes_response_closed_after_streaming_read(mock_env_clear: None) -> None: + respx.post("https://sandbox.test/v2/sandboxes").mock( + return_value=httpx.Response(200, json=_sandbox_response()) + ) + stream = _TrackedAsyncStream([b"bytes"]) + respx.post("https://sandbox.test/v2/sandboxes/sessions/sbx_1/fs/read").mock( + return_value=httpx.Response(200, stream=stream) + ) + + async with vercel.session(service_options=_session_options()): + box = await sandbox.create_sandbox(name="preview", runtime="python3.13") + result = await box.fs.read_bytes(PurePosixPath("data.bin")) + assert result == b"bytes" + + assert stream.aclose_called + + +@respx.mock +def test_sync_read_bytes_uses_and_closes_unread_stream(mock_env_clear: None) -> None: + respx.post("https://sandbox.test/v2/sandboxes").mock( + return_value=httpx.Response(200, json=_sandbox_response()) + ) + stream = _TrackedSyncStream([b"abc", b"", b"def"]) + respx.post("https://sandbox.test/v2/sandboxes/sessions/sbx_1/fs/read").mock( + return_value=httpx.Response(200, stream=stream) + ) + + with vercel.session(service_options=_session_options()): + box = sandbox_sync.create_sandbox(name="preview", runtime="python3.13") + assert box.fs.read_bytes("data.bin") == b"abcdef" + + assert stream.close_called + + +@respx.mock +def test_sync_read_bytes_closes_response_after_stream_failure(mock_env_clear: None) -> None: + respx.post("https://sandbox.test/v2/sandboxes").mock( + return_value=httpx.Response(200, json=_sandbox_response()) + ) + failure = RuntimeError("stream failed") + stream = _TrackedSyncStream([b"partial"], failure=failure) + respx.post("https://sandbox.test/v2/sandboxes/sessions/sbx_1/fs/read").mock( + return_value=httpx.Response(200, stream=stream) + ) + + with 
vercel.session(service_options=_session_options()): + box = sandbox_sync.create_sandbox(name="preview", runtime="python3.13") + with pytest.raises(RuntimeError) as exc_info: + box.fs.read_bytes("data.bin") + assert exc_info.value is failure + + assert stream.close_called + + +@respx.mock +async def test_read_bytes_multiple_chunks(mock_env_clear: None) -> None: + respx.post("https://sandbox.test/v2/sandboxes").mock( + return_value=httpx.Response(200, json=_sandbox_response()) + ) + respx.post("https://sandbox.test/v2/sandboxes/sessions/sbx_1/fs/read").mock( + return_value=httpx.Response(200, stream=_TrackedAsyncStream([b"abc", b"def", b"ghi"])) + ) + + async with vercel.session(service_options=_session_options()): + box = await sandbox.create_sandbox(name="preview", runtime="python3.13") + result = await box.fs.read_bytes(PurePosixPath("data.bin")) + assert result == b"abcdefghi" + + +@respx.mock +async def test_read_bytes_empty_file(mock_env_clear: None) -> None: + respx.post("https://sandbox.test/v2/sandboxes").mock( + return_value=httpx.Response(200, json=_sandbox_response()) + ) + respx.post("https://sandbox.test/v2/sandboxes/sessions/sbx_1/fs/read").mock( + return_value=httpx.Response(200, stream=_TrackedAsyncStream([])) + ) + + async with vercel.session(service_options=_session_options()): + box = await sandbox.create_sandbox(name="preview", runtime="python3.13") + result = await box.fs.read_bytes(PurePosixPath("empty.txt")) + assert result == b"" + + +@respx.mock +async def test_read_bytes_missing_path(mock_env_clear: None) -> None: + respx.post("https://sandbox.test/v2/sandboxes").mock( + return_value=httpx.Response(200, json=_sandbox_response()) + ) + respx.post("https://sandbox.test/v2/sandboxes/sessions/sbx_1/fs/read").mock( + return_value=httpx.Response( + 404, json={"error": {"code": "not_found", "message": "missing"}} + ) + ) + + async with vercel.session(service_options=_session_options()): + box = await sandbox.create_sandbox(name="preview", 
runtime="python3.13") + with pytest.raises(SandboxPathNotFoundError) as exc_info: + await box.fs.read_bytes(PurePosixPath("missing.txt")) + assert exc_info.value.path == "missing.txt" + assert exc_info.value.operation == "read_bytes" + assert exc_info.value.cause.code == "not_found" + + +@respx.mock +async def test_read_bytes_and_read_text_still_work(mock_env_clear: None) -> None: + respx.post("https://sandbox.test/v2/sandboxes").mock( + return_value=httpx.Response(200, json=_sandbox_response()) + ) + response_count = 0 + + def stream_response(request: httpx.Request) -> httpx.Response: + nonlocal response_count + response_count += 1 + return httpx.Response(200, stream=_TrackedAsyncStream([b"content"])) + + respx.post("https://sandbox.test/v2/sandboxes/sessions/sbx_1/fs/read").mock( + side_effect=stream_response + ) + + async with vercel.session(service_options=_session_options()): + box = await sandbox.create_sandbox(name="preview", runtime="python3.13") + + raw = await box.fs.read_bytes(PurePosixPath("data.bin")) + assert raw == b"content" + + text = await box.fs.read_text("message.txt") + assert text == "content" + + assert response_count == 2 diff --git a/tests/unstable/test_sandbox_filesystem_write.py b/tests/unstable/test_sandbox_filesystem_write.py new file mode 100644 index 0000000..33e02ee --- /dev/null +++ b/tests/unstable/test_sandbox_filesystem_write.py @@ -0,0 +1,89 @@ +import pytest +from hypothesis import given, settings, strategies as st + +from vercel._internal.unstable.sandbox.filesystem_write import _ExactSizeWriteTarget +from vercel.unstable.sandbox import SandboxUploadSizeMismatchError + + +class _RecordingUpload: + def __init__(self) -> None: + self.writes: list[bytes] = [] + self.finishes = 0 + self.write_error: BaseException | None = None + + async def write(self, data: bytes) -> None: + if self.write_error is not None: + raise self.write_error + self.writes.append(data) + + async def flush(self) -> None: + pass + + async def finish(self) -> 
None: + self.finishes += 1 + + async def abort(self) -> None: + pass + + +@pytest.mark.asyncio +@settings(deadline=None) +@given( + declared=st.integers(min_value=0, max_value=64), + writes=st.lists( + st.tuples(st.binary(max_size=16), st.booleans()), + max_size=8, + ), +) +async def test_exact_size_target_accounting( + declared: int, + writes: list[tuple[bytes, bool]], +) -> None: + upload = _RecordingUpload() + target = _ExactSizeWriteTarget(upload, name="data.bin", size=declared) + forwarded: list[bytes] = [] + consumed = 0 + + for chunk, inject_failure in writes: + attempted = consumed + len(chunk) + if attempted > declared: + with pytest.raises(SandboxUploadSizeMismatchError) as overflow_info: + await target.write(chunk) + error = overflow_info.value + assert (error.path, error.declared, error.consumed, error.early_end) == ( + "data.bin", + declared, + attempted, + False, + ) + assert upload.writes == forwarded + break + + if inject_failure: + write_error = RuntimeError("write failed") + upload.write_error = write_error + with pytest.raises(RuntimeError) as write_info: + await target.write(chunk) + assert write_info.value is write_error + upload.write_error = None + assert upload.writes == forwarded + + await target.write(chunk) + forwarded.append(chunk) + consumed = attempted + + assert upload.writes == forwarded + if consumed == declared: + await target.finish() + assert upload.finishes == 1 + else: + with pytest.raises(SandboxUploadSizeMismatchError) as underflow_info: + await target.finish() + error = underflow_info.value + assert (error.path, error.declared, error.consumed, error.early_end) == ( + "data.bin", + declared, + consumed, + True, + ) + assert upload.finishes == 0 diff --git a/tests/unstable/test_sandbox_process.py b/tests/unstable/test_sandbox_process.py index b60baa7..dabf511 100644 --- a/tests/unstable/test_sandbox_process.py +++ b/tests/unstable/test_sandbox_process.py @@ -102,24 +102,25 @@ def flush(self) -> None: class 
_TrackingAsyncStream(httpx.AsyncByteStream): - def __init__(self, content: bytes) -> None: - self.content = content + def __init__(self, content: bytes | list[bytes]) -> None: + self.chunks = content if isinstance(content, list) else [content] self.closed = False async def __aiter__(self) -> AsyncIterator[bytes]: - yield self.content + for chunk in self.chunks: + yield chunk async def aclose(self) -> None: self.closed = True class _TrackingSyncStream(httpx.SyncByteStream): - def __init__(self, content: bytes) -> None: - self.content = content + def __init__(self, content: bytes | list[bytes]) -> None: + self.chunks = content if isinstance(content, list) else [content] self.closed = False def __iter__(self) -> Iterator[bytes]: - yield self.content + yield from self.chunks def close(self) -> None: self.closed = True @@ -134,6 +135,11 @@ def _completed_body() -> bytes: return "".join(json.dumps(record) + "\n" for record in records).encode() +def _chunked_ndjson(*records: object) -> list[bytes]: + content = "\r\n\r\n".join(json.dumps(record, ensure_ascii=False) for record in records).encode() + return [content[offset : offset + 1] for offset in range(len(content))] + + def _session_options() -> list[SandboxServiceOptions]: async def credentials_factory() -> SandboxCredentials: return SandboxCredentials(token="token", team_id="team_1", project_id="prj_1") @@ -553,6 +559,118 @@ def test_sync_run_process_routes_and_captures( assert run.call_count == 2 +@respx.mock +async def test_async_run_process_reads_chunked_ndjson(mock_env_clear: None) -> None: + respx.post("https://sandbox.test/v2/sandboxes").mock( + return_value=httpx.Response(200, json=_sandbox_response()) + ) + stream = _TrackingAsyncStream( + _chunked_ndjson( + _process_response(), + {"stream": "stdout", "data": "café\n"}, + {"stream": "stderr", "data": "雪\n"}, + _process_response(0), + ) + ) + respx.post("https://sandbox.test/v2/sandboxes/sessions/sbx_1/cmd").mock( + return_value=httpx.Response(200, 
stream=stream) + ) + + async with vercel.session(service_options=_session_options()): + box = await sandbox.create_sandbox(name="preview", runtime="python3.13") + result = await box.run_process("python", capture_output=True) + + assert result.stdout == "café\n" + assert result.stderr == "雪\n" + assert stream.closed + + +@respx.mock +def test_sync_run_process_reads_chunked_ndjson(mock_env_clear: None) -> None: + respx.post("https://sandbox.test/v2/sandboxes").mock( + return_value=httpx.Response(200, json=_sandbox_response()) + ) + stream = _TrackingSyncStream( + _chunked_ndjson( + _process_response(), + {"stream": "stdout", "data": "café\n"}, + {"stream": "stderr", "data": "雪\n"}, + _process_response(0), + ) + ) + respx.post("https://sandbox.test/v2/sandboxes/sessions/sbx_1/cmd").mock( + return_value=httpx.Response(200, stream=stream) + ) + + with vercel.session(service_options=_session_options()): + box = sandbox_sync.create_sandbox(name="preview", runtime="python3.13") + result = box.run_process("python", capture_output=True) + + assert result.stdout == "café\n" + assert result.stderr == "雪\n" + assert stream.closed + + +@respx.mock +async def test_async_process_readers_read_chunked_ndjson(mock_env_clear: None) -> None: + respx.post("https://sandbox.test/v2/sandboxes").mock( + return_value=httpx.Response(200, json=_sandbox_response()) + ) + respx.post("https://sandbox.test/v2/sandboxes/sessions/sbx_1/cmd").mock( + return_value=httpx.Response(200, json=_process_response()) + ) + respx.get("https://sandbox.test/v2/sandboxes/sessions/sbx_1/cmd/cmd_1").mock( + return_value=httpx.Response(200, json=_process_response(0)) + ) + stream = _TrackingAsyncStream( + _chunked_ndjson( + {"stream": "stdout", "data": "café\n"}, + {"stream": "stderr", "data": "雪\n"}, + ) + ) + respx.get("https://sandbox.test/v2/sandboxes/sessions/sbx_1/cmd/cmd_1/logs").mock( + return_value=httpx.Response(200, stream=stream) + ) + + async with vercel.session(service_options=_session_options()): + 
box = await sandbox.create_sandbox(name="preview", runtime="python3.13") + process = await box.create_process("python") + output = await process.communicate() + + assert output == ("café\n", "雪\n") + assert stream.closed + + +@respx.mock +def test_sync_process_readers_read_chunked_ndjson(mock_env_clear: None) -> None: + respx.post("https://sandbox.test/v2/sandboxes").mock( + return_value=httpx.Response(200, json=_sandbox_response()) + ) + respx.post("https://sandbox.test/v2/sandboxes/sessions/sbx_1/cmd").mock( + return_value=httpx.Response(200, json=_process_response()) + ) + respx.get("https://sandbox.test/v2/sandboxes/sessions/sbx_1/cmd/cmd_1").mock( + return_value=httpx.Response(200, json=_process_response(0)) + ) + stream = _TrackingSyncStream( + _chunked_ndjson( + {"stream": "stdout", "data": "café\n"}, + {"stream": "stderr", "data": "雪\n"}, + ) + ) + respx.get("https://sandbox.test/v2/sandboxes/sessions/sbx_1/cmd/cmd_1/logs").mock( + return_value=httpx.Response(200, stream=stream) + ) + + with vercel.session(service_options=_session_options()): + box = sandbox_sync.create_sandbox(name="preview", runtime="python3.13") + process = box.create_process("python") + output = process.communicate() + + assert output == ("café\n", "雪\n") + assert stream.closed + + @pytest.mark.parametrize( "kwargs", [ diff --git a/tests/unstable/test_sandbox_streaming_archive.py b/tests/unstable/test_sandbox_streaming_archive.py new file mode 100644 index 0000000..7ce622a --- /dev/null +++ b/tests/unstable/test_sandbox_streaming_archive.py @@ -0,0 +1,241 @@ +import gzip +import io +import random +import tarfile + +import pytest +from hypothesis import given, settings, strategies as st + +from vercel._internal.byte_stream import AsyncByteStreamRuntime, SyncByteStreamRuntime +from vercel._internal.iter_coroutine import iter_coroutine +from vercel._internal.unstable.sandbox.errors import SandboxUploadSizeMismatchError +from vercel._internal.unstable.sandbox.runtime_common import 
_UploadFileEntry +from vercel._internal.unstable.sandbox.service import SandboxArchiveUpload +from vercel._internal.unstable.sandbox.streaming_archive import ArchiveRequestWriter + + +class _CollectRequest: + def __init__(self) -> None: + self.chunks: list[bytes] = [] + self.finishes = 0 + + async def write(self, data: bytes) -> None: + self.chunks.append(data) + + async def finish(self) -> None: + self.finishes += 1 + + async def abort(self) -> None: + pass + + +def _sync_entries(entries: list[_UploadFileEntry]) -> list[_UploadFileEntry]: + runtime = SyncByteStreamRuntime() + return [ + _UploadFileEntry( + path=entry.path, + size=entry.size, + source=runtime.reader(entry.source), + mode=entry.mode, + archive_path=entry.archive_path, + ) + for entry in entries + ] + + +def sync_archive_body(entries: list[_UploadFileEntry], chunk_size: int): # type: ignore[no-untyped-def] + request = _CollectRequest() + + async def upload_entries() -> None: + upload = SandboxArchiveUpload( + writer=ArchiveRequestWriter(request, chunk_size), paths=(), cwd="/" + ) + for entry in _sync_entries(entries): + await upload.add_source(entry) + await upload.finish() + + iter_coroutine(upload_entries()) + return iter(request.chunks) + + +async def async_archive_body(entries: list[_UploadFileEntry], chunk_size: int): # type: ignore[no-untyped-def] + runtime = AsyncByteStreamRuntime() + normalized = [ + _UploadFileEntry( + path=entry.path, + size=entry.size, + source=runtime.reader(entry.source), + mode=entry.mode, + archive_path=entry.archive_path, + ) + for entry in entries + ] + request = _CollectRequest() + upload = SandboxArchiveUpload( + writer=ArchiveRequestWriter(request, chunk_size), paths=(), cwd="/" + ) + for entry in normalized: + await upload.add_source(entry) + await upload.finish() + for chunk in request.chunks: + yield chunk + + +def _read_tar(data: bytes) -> list[tuple[str, bytes, int]]: + decompressed = gzip.decompress(data) + result: list[tuple[str, bytes, int]] = [] + with 
tarfile.open(fileobj=io.BytesIO(decompressed), mode="r:") as tar: + for member in tar.getmembers(): + f = tar.extractfile(member) + content = f.read() if f else b"" + result.append((member.name, content, member.mode)) + return result + + +_PATH_SEGMENT = st.text( + alphabet="abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-é雪", + min_size=1, + max_size=40, +) +_RELATIVE_PATH = st.lists(_PATH_SEGMENT, min_size=1, max_size=3).map("/".join) +_ARCHIVE_PATH = st.one_of( + _RELATIVE_PATH, + _RELATIVE_PATH.map(lambda path: f"/{path}"), + st.text(alphabet="abcdefghijklmnopqrstuvwxyz", min_size=101, max_size=180).map( + lambda path: f"{path}.bin" + ), +) +_ARCHIVE_DATA = st.one_of( + st.sampled_from([b"", b"x" * 511, b"x" * 512, b"x" * 513]), + st.binary(max_size=8192), +) + + +class TestBodyIterators: + @staticmethod + def _make_entries(*specs: tuple[str, bytes, int | None]) -> list[_UploadFileEntry]: + return [ + _UploadFileEntry(path=name, size=len(data), source=data, mode=mode) + for name, data, mode in specs + ] + + def _verify_entries(self, data: bytes, *specs: tuple[str, bytes, int | None]) -> None: + entries = _read_tar(data) + assert len(entries) == len(specs) + for (name, content, mode), (exp_name, exp_data, exp_mode) in zip( + entries, specs, strict=True + ): + assert name == exp_name + assert content == exp_data + assert mode == (exp_mode if exp_mode is not None else 0o644) + + @pytest.mark.anyio + @settings(max_examples=50, deadline=None) + @given( + specs=st.lists( + st.tuples( + _ARCHIVE_PATH, + _ARCHIVE_DATA, + st.sampled_from([None, 0, 0o600, 0o644, 0o755, 0o777]), + ), + max_size=5, + unique_by=lambda spec: spec[0], + ), + chunk_size=st.sampled_from([1, 64, 511, 512, 4096, 65536]), + ) + async def test_sync_and_async_archive_parity( + self, + specs: list[tuple[str, bytes, int | None]], + chunk_size: int, + ) -> None: + entries = self._make_entries(*specs) + sync_data = b"".join(sync_archive_body(entries, chunk_size)) + async_chunks: 
list[bytes] = [] + async for chunk in async_archive_body(entries, chunk_size): + async_chunks.append(chunk) + assert sync_data == b"".join(async_chunks) + self._verify_entries(sync_data, *specs) + + @pytest.mark.anyio + async def test_async_trailing_data_raises(self) -> None: + entry = _UploadFileEntry(path="f", size=3, source=b"extra") + with pytest.raises(SandboxUploadSizeMismatchError) as exc_info: + async for _ in async_archive_body([entry], 4096): + pass + assert exc_info.value.consumed == 4 + assert not exc_info.value.early_end + + def test_sync_emits_before_exhausting_source_with_bounded_reads(self) -> None: + size = 2 * 1024 * 1024 + data = random.Random(1).randbytes(size) + + class Reader: + def __init__(self) -> None: + self.offset = 0 + self.max_requested = 0 + + def read(self, size: int = -1, /) -> bytes: + assert 0 < size <= 65536 + self.max_requested = max(self.max_requested, size) + chunk = data[self.offset : self.offset + size] + self.offset += len(chunk) + return chunk + + reader = Reader() + + class ObservingRequest(_CollectRequest): + def __init__(self) -> None: + super().__init__() + self.first_write_offset: int | None = None + + async def write(self, chunk: bytes) -> None: + if self.first_write_offset is None: + self.first_write_offset = reader.offset + await super().write(chunk) + + request = ObservingRequest() + entries = _sync_entries([_UploadFileEntry("remote.bin", len(data), reader)]) + + async def upload_entry() -> None: + upload = SandboxArchiveUpload( + writer=ArchiveRequestWriter(request, 65536), paths=(), cwd="/" + ) + await upload.add_source(entries[0]) + await upload.finish() + + iter_coroutine(upload_entry()) + + assert request.first_write_offset is not None + assert request.first_write_offset < size + assert 0 < reader.max_requested <= 65536 + assert _read_tar(b"".join(request.chunks))[0][1] == data + + def test_sync_early_end_error_fields(self) -> None: + entry = _UploadFileEntry(path="visible/path", size=5, 
source=io.BytesIO(b"abc")) + with pytest.raises(SandboxUploadSizeMismatchError) as exc_info: + list(sync_archive_body([entry], 2)) + error = exc_info.value + assert (error.path, error.declared, error.consumed, error.early_end) == ( + "visible/path", + 5, + 3, + True, + ) + + def test_sync_non_bytes_and_source_errors_propagate(self) -> None: + class BadReader: + def read(self, size: int = -1, /) -> bytes: + return bytearray(b"x") # type: ignore[return-value] + + with pytest.raises(TypeError, match="expected bytes"): + list(sync_archive_body([_UploadFileEntry("f", 1, BadReader())], 4)) + + failure = RuntimeError("source failed") + + class FailingReader: + def read(self, size: int = -1, /) -> bytes: + raise failure + + with pytest.raises(RuntimeError) as exc_info: + list(sync_archive_body([_UploadFileEntry("f", 1, FailingReader())], 4)) + assert exc_info.value is failure diff --git a/tests/unstable/test_sandbox_text_reader.py b/tests/unstable/test_sandbox_text_reader.py index b9d787f..82063b3 100644 --- a/tests/unstable/test_sandbox_text_reader.py +++ b/tests/unstable/test_sandbox_text_reader.py @@ -6,12 +6,53 @@ import httpx import pytest +from vercel._internal.http import StreamingResponse from vercel._internal.unstable.sandbox.errors import SandboxStreamError from vercel._internal.unstable.sandbox.text_reader import _sync_text_readers, _text_readers -def _logs_response(*records: object) -> httpx.Response: - return httpx.Response(200, text="\n".join(json.dumps(record) for record in records) + "\n") +class _TestStreamingResponse(StreamingResponse): + def __init__(self, response: httpx.Response) -> None: + self.response = response + self._sync_iterator: Iterator[bytes] | None = None + self._async_iterator: AsyncIterator[bytes] | None = None + if isinstance(response.stream, httpx.AsyncByteStream): + self._async_iterator = response.aiter_bytes() + else: + self._sync_iterator = response.iter_bytes() + + async def __anext__(self) -> bytes: + if self._async_iterator is not 
None: + return await anext(self._async_iterator) + assert self._sync_iterator is not None + try: + return next(self._sync_iterator) + except StopIteration: + raise StopAsyncIteration from None + + async def aiter_lines(self) -> AsyncIterator[str]: + if isinstance(self.response.stream, httpx.AsyncByteStream): + async for line in self.response.aiter_lines(): + yield line + else: + for line in self.response.iter_lines(): + yield line + + async def aclose(self) -> None: + if isinstance(self.response.stream, httpx.AsyncByteStream): + await self.response.aclose() + else: + self.response.close() + + +def _streaming(response: httpx.Response) -> StreamingResponse: + return _TestStreamingResponse(response) + + +def _logs_response(*records: object) -> StreamingResponse: + return _streaming( + httpx.Response(200, text="\n".join(json.dumps(record) for record in records) + "\n") + ) def _logs_body(*records: object) -> bytes: @@ -45,7 +86,7 @@ def close(self) -> None: @pytest.mark.anyio @pytest.mark.parametrize("anyio_backend", ["asyncio", "trio"]) async def test_async_text_reader_lines_shared_cursor_eof_and_close(anyio_backend: str) -> None: - async def open_response() -> httpx.Response: + async def open_response() -> StreamingResponse: return _logs_response( {"stream": "stdout", "data": "first\nsecond"}, {"stream": "stderr", "data": "ignored\n"}, @@ -82,8 +123,8 @@ async def __aiter__(self) -> AsyncIterator[bytes]: await release.wait() yield b'{"stream":"stdout","data":"done\\n"}\n' - async def open_response() -> httpx.Response: - return httpx.Response(200, stream=PendingStream()) + async def open_response() -> StreamingResponse: + return _streaming(httpx.Response(200, stream=PendingStream())) reader, _ = _text_readers(open_response) assert reader is not None @@ -98,7 +139,7 @@ async def open_response() -> httpx.Response: @pytest.mark.anyio @pytest.mark.parametrize("anyio_backend", ["asyncio", "trio"]) async def test_async_text_reader_propagates_in_band_errors(anyio_backend: 
str) -> None: - async def open_response() -> httpx.Response: + async def open_response() -> StreamingResponse: return _logs_response( {"stream": "stdout", "data": "before\n"}, {"stream": "error", "data": {"code": "stopped", "message": "process stopped"}}, @@ -124,8 +165,8 @@ async def __aiter__(self) -> AsyncIterator[bytes]: raise httpx.ReadError("connection failed") yield b"" # pragma: no cover - async def open_response() -> httpx.Response: - return httpx.Response(200, stream=FailedStream()) + async def open_response() -> StreamingResponse: + return _streaming(httpx.Response(200, stream=FailedStream())) reader, peer = _text_readers(open_response) assert reader is not None and peer is not None @@ -143,8 +184,8 @@ async def __aiter__(self) -> AsyncIterator[bytes]: await anyio.sleep_forever() yield b"" # pragma: no cover - async def open_response() -> httpx.Response: - return httpx.Response(200, stream=PendingStream()) + async def open_response() -> StreamingResponse: + return _streaming(httpx.Response(200, stream=PendingStream())) reader, peer = _text_readers(open_response) assert reader is not None and peer is not None @@ -203,8 +244,8 @@ async def test_async_text_reader_merges_stderr_in_arrival_order(anyio_backend: s ) ) - async def open_response() -> httpx.Response: - return httpx.Response(200, stream=stream) + async def open_response() -> StreamingResponse: + return _streaming(httpx.Response(200, stream=stream)) reader, peer = _text_readers(open_response, stderr=subprocess.STDOUT) assert peer is None @@ -219,7 +260,7 @@ async def open_response() -> httpx.Response: @pytest.mark.anyio @pytest.mark.parametrize("anyio_backend", ["asyncio", "trio"]) async def test_async_text_reader_drops_devnull_stream(anyio_backend: str) -> None: - async def open_response() -> httpx.Response: + async def open_response() -> StreamingResponse: return _logs_response( {"stream": "stdout", "data": "dropped\n"}, {"stream": "stderr", "data": "kept\n"}, @@ -239,7 +280,7 @@ async def 
test_async_text_readers_with_no_streams_never_open_response( ) -> None: opened = 0 - async def open_response() -> httpx.Response: + async def open_response() -> StreamingResponse: nonlocal opened opened += 1 return _logs_response() @@ -252,7 +293,7 @@ async def open_response() -> httpx.Response: @pytest.mark.anyio @pytest.mark.parametrize("anyio_backend", ["asyncio", "trio"]) async def test_async_merged_reader_propagates_in_band_errors(anyio_backend: str) -> None: - async def open_response() -> httpx.Response: + async def open_response() -> StreamingResponse: return _logs_response( {"stream": "stderr", "data": "before\n"}, {"stream": "error", "data": {"code": "stopped", "message": "process stopped"}}, @@ -277,7 +318,7 @@ def test_sync_text_reader_merges_stderr_in_arrival_order() -> None: ) reader, peer = _sync_text_readers( - lambda: httpx.Response(200, stream=stream), stderr=subprocess.STDOUT + lambda: _streaming(httpx.Response(200, stream=stream)), stderr=subprocess.STDOUT ) assert peer is None assert reader is not None @@ -305,7 +346,7 @@ def test_sync_text_reader_drops_devnull_stream() -> None: def test_sync_text_readers_with_no_streams_never_open_response(stderr: int) -> None: opened = 0 - def open_response() -> httpx.Response: + def open_response() -> StreamingResponse: nonlocal opened opened += 1 return _logs_response()