diff --git a/serialx/platforms/serial_win32.py b/serialx/platforms/serial_win32.py index 625b170..d8c046e 100644 --- a/serialx/platforms/serial_win32.py +++ b/serialx/platforms/serial_win32.py @@ -484,22 +484,37 @@ def __init__( self._closing: bool = False self._connect_in_progress: bool = False self._connection_made_waiter: asyncio.Future[None] | None = None + self._pending_connection_lost_exc: Exception | None = None def serial_close(self) -> None: - """Close the serial port.""" + """Release the handle off the event loop, then report connection lost.""" + if self._close_future is not None: + return - def _close_then_notify() -> None: - assert self._serial is not None - exc = None + serial = self._serial + self._serial = None + self._handle = None - try: - self._serial.close() - except Exception as e: - exc = e + if serial is None: + self._call_protocol_connection_lost(self._pending_connection_lost_exc) + return - self._loop.call_soon_threadsafe(self._call_protocol_connection_lost, exc) + self._close_future = self._loop.run_in_executor(None, serial.close) + self._close_future.add_done_callback(self._on_serial_closed) + + def _on_serial_closed(self, fut: asyncio.Future[None]) -> None: + # Consume the future's exception so it does not surface later as a noisy warning + if (exc := fut.exception()) is not None: + self._loop.call_exception_handler( + { + "message": "Unhandled exception while closing the serial port", + "exception": exc, + "transport": self, + "protocol": self._protocol, + } + ) - self._close_future = self._loop.run_in_executor(None, _close_then_notify) + self._call_protocol_connection_lost(self._pending_connection_lost_exc) def serial_shutdown(self, how: int) -> None: """Shutdown the serial connection.""" @@ -527,8 +542,8 @@ def protocol_connection_made(self, transport: asyncio.Transport) -> None: self._connection_made_waiter.set_result(None) def protocol_connection_lost(self, exc: Exception | None) -> None: - """Forward connection_lost to the protocol.""" - self._call_protocol_connection_lost(exc) + """Stash the connection-lost reason, `serial_close` dispatches it.""" + self._pending_connection_lost_exc = exc def protocol_pause_writing(self) -> None: """Forward pause_writing to the protocol.""" diff --git a/tests/test_async_lifecycle.py b/tests/test_async_lifecycle.py index 71dd3b8..9ca6a18 100644 --- a/tests/test_async_lifecycle.py +++ b/tests/test_async_lifecycle.py @@ -116,6 +116,18 @@ def total_received(self) -> bytes: return b"".join(self.data_received_chunks) +@pytest.fixture(autouse=True, params=["lazy_tasks", "eager_tasks"]) +async def task_factory(request: pytest.FixtureRequest) -> None: + """Run every lifecycle test under both the default and eager task factories.""" + if request.param == "eager_tasks": + if sys.version_info < (3, 12): + pytest.skip("Eager task factory requires Python 3.12+") + if sys.platform == "emscripten": + pytest.skip("Pyodide's WebLoop does not support custom task factories") + + asyncio.get_running_loop().set_task_factory(asyncio.eager_task_factory) + + # --- Successful lifecycle: callbacks fire exactly once --- @@ -138,6 +150,38 @@ async def test_lifecycle_normal_close_callbacks(serial_pair: SerialPair) -> None protocol.assert_clean() +async def test_lifecycle_port_released_before_connection_lost( + serial_pair: SerialPair, +) -> None: + """connection_lost must not fire until the port-releasing syscall returns.""" + loop = asyncio.get_running_loop() + protocol = RecordingProtocol() + + close_targets = ["os.close"] + if sys.platform == "win32": + close_targets.append("serialx.platforms.serial_win32.CloseHandle") + + transport, _ = await create_serial_connection( + loop, lambda: protocol, serial_pair.left, baudrate=115200 + ) + + with patch_slow(*close_targets) as (started, proceed, _completed): + transport.close() + + if not await loop.run_in_executor(None, started.wait, 1.0): + pytest.skip("Backend close path does not go through a patched syscall") + + # The releasing syscall is mid-flight, so the handle is still held. The + # protocol must not have been told the connection is lost. + assert protocol.state is ProtocolState.MADE + + proceed.set() + + await transport.wait_closed() + assert protocol.state is ProtocolState.LOST # type:ignore[comparison-overlap] + protocol.assert_clean() + + async def test_lifecycle_abort_callbacks(serial_pair: SerialPair) -> None: """Connect + abort: traverses INIT -> MADE -> LOST.""" loop = asyncio.get_running_loop()