Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 27 additions & 12 deletions serialx/platforms/serial_win32.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,22 +484,37 @@ def __init__(
self._closing: bool = False
self._connect_in_progress: bool = False
self._connection_made_waiter: asyncio.Future[None] | None = None
self._pending_connection_lost_exc: Exception | None = None

def serial_close(self) -> None:
"""Close the serial port."""
"""Release the handle off the event loop, then report connection lost."""
if self._close_future is not None:
return

def _close_then_notify() -> None:
assert self._serial is not None
exc = None
serial = self._serial
self._serial = None
self._handle = None

try:
self._serial.close()
except Exception as e:
exc = e
if serial is None:
self._call_protocol_connection_lost(self._pending_connection_lost_exc)
return

self._loop.call_soon_threadsafe(self._call_protocol_connection_lost, exc)
self._close_future = self._loop.run_in_executor(None, serial.close)
self._close_future.add_done_callback(self._on_serial_closed)

def _on_serial_closed(self, fut: asyncio.Future[None]) -> None:
# Consume the future's exception so it does not surface later as a noisy warning
if (exc := fut.exception()) is not None:
self._loop.call_exception_handler(
{
"message": "Unhandled exception while closing the serial port",
"exception": exc,
"transport": self,
"protocol": self._protocol,
}
)

self._close_future = self._loop.run_in_executor(None, _close_then_notify)
self._call_protocol_connection_lost(self._pending_connection_lost_exc)

def serial_shutdown(self, how: int) -> None:
"""Shutdown the serial connection."""
Expand Down Expand Up @@ -527,8 +542,8 @@ def protocol_connection_made(self, transport: asyncio.Transport) -> None:
self._connection_made_waiter.set_result(None)

def protocol_connection_lost(self, exc: Exception | None) -> None:
"""Forward connection_lost to the protocol."""
self._call_protocol_connection_lost(exc)
"""Stash the connection-lost reason, `serial_close` dispatches it."""
self._pending_connection_lost_exc = exc

def protocol_pause_writing(self) -> None:
"""Forward pause_writing to the protocol."""
Expand Down
44 changes: 44 additions & 0 deletions tests/test_async_lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,18 @@ def total_received(self) -> bytes:
return b"".join(self.data_received_chunks)


@pytest.fixture(autouse=True, params=["lazy_tasks", "eager_tasks"])
async def task_factory(request: pytest.FixtureRequest) -> None:
"""Run every lifecycle test under both the default and eager task factories."""
if request.param == "eager_tasks":
if sys.version_info < (3, 12):
pytest.skip("Eager task factory requires Python 3.12+")
if sys.platform == "emscripten":
pytest.skip("Pyodide's WebLoop does not support custom task factories")

asyncio.get_running_loop().set_task_factory(asyncio.eager_task_factory)


# --- Successful lifecycle: callbacks fire exactly once ---


Expand All @@ -138,6 +150,38 @@ async def test_lifecycle_normal_close_callbacks(serial_pair: SerialPair) -> None
protocol.assert_clean()


async def test_lifecycle_port_released_before_connection_lost(
serial_pair: SerialPair,
) -> None:
"""connection_lost must not fire until the port-releasing syscall returns."""
loop = asyncio.get_running_loop()
protocol = RecordingProtocol()

close_targets = ["os.close"]
if sys.platform == "win32":
close_targets.append("serialx.platforms.serial_win32.CloseHandle")

transport, _ = await create_serial_connection(
loop, lambda: protocol, serial_pair.left, baudrate=115200
)

with patch_slow(*close_targets) as (started, proceed, _completed):
transport.close()

if not await loop.run_in_executor(None, started.wait, 1.0):
pytest.skip("Backend close path does not go through a patched syscall")

# The releasing syscall is mid-flight, so the handle is still held. The
# protocol must not have been told the connection is lost.
assert protocol.state is ProtocolState.MADE

proceed.set()

await transport.wait_closed()
assert protocol.state is ProtocolState.LOST # type:ignore[comparison-overlap]
Comment thread
puddly marked this conversation as resolved.
protocol.assert_clean()


async def test_lifecycle_abort_callbacks(serial_pair: SerialPair) -> None:
"""Connect + abort: traverses INIT -> MADE -> LOST."""
loop = asyncio.get_running_loop()
Expand Down
Loading