Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions prek.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ types_or = ["python", "pyi"]
exclude = '^tests/esphome/external_components/'
require_serial = true

[[repos.hooks]]
id = "pylint"
name = "pylint"
entry = "uv run --frozen pylint serialx/"
language = "system"
types = ["python"]
pass_filenames = false
require_serial = true

[[repos.hooks]]
id = "cargo-fmt"
name = "Cargo fmt"
Expand Down
51 changes: 51 additions & 0 deletions pylint/plugins/pylint_serialx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Pylint plugin for serialx."""

from __future__ import annotations

from astroid.nodes import Arguments, AssignName, FunctionDef
from pylint.checkers import BaseChecker
from pylint.lint import PyLinter


class ParamReassignChecker(BaseChecker):
"""Flag rebinding a function parameter to a new value."""

name = "serialx-param-reassign"
msgs = {
# serialx reserves pylint message base 90: codes are {C,W,E,R}90xx.
"W9001": (
"Reassigning function parameter %r; bind a new local instead",
"serialx-reassigned-parameter",
"Rebinding a parameter overloads one name with two meanings and lets "
"a rewritten value leak into later uses of the original.",
),
}

def visit_assignname(self, node: AssignName) -> None:
"""Flag an assignment target that rebinds an enclosing function parameter."""
# Skip the parameter *definition* itself (its target lives under Arguments).
if isinstance(node.parent, Arguments):
return

scope = node.scope()
if not isinstance(scope, FunctionDef):
return

args = scope.args
names: set[str] = set()
for group in (args.posonlyargs, args.args, args.kwonlyargs):
names.update(arg.name for arg in group or [])
if args.vararg:
names.add(args.vararg)
if args.kwarg:
names.add(args.kwarg)

if node.name in names:
self.add_message(
"serialx-reassigned-parameter", node=node, args=(node.name,)
)


def register(linter: PyLinter) -> None:
"""Register the serialx checkers with pylint."""
linter.register_checker(ParamReassignChecker(linter))
14 changes: 14 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ dev = [
"uv>=0.11.14",
"ruff>=0.14.6",
"mypy>=2.1.0",
"pylint>=3.3.0",
"codespell>=2.4.2",
"tomli>=2.3.0 ; python_version < '3.11'",
"pytest>=9.0.1",
Expand Down Expand Up @@ -173,6 +174,10 @@ disable_error_code = [
"untyped-decorator",
]

[[tool.mypy.overrides]]
module = ["astroid.*"]
ignore_missing_imports = true


[tool.coverage.run]
source = ["serialx"]
Expand All @@ -181,6 +186,14 @@ relative_files = true
[tool.coverage.paths]
source = ["serialx/", "serialx\\"]

[tool.pylint.MAIN]
init-hook = "import sys; sys.path.append('pylint/plugins')"
load-plugins = ["pylint_serialx"]

[tool.pylint."MESSAGES CONTROL"]
disable = ["all"]
enable = ["serialx-reassigned-parameter"]

[tool.ruff]
target-version = "py310"

Expand Down Expand Up @@ -319,6 +332,7 @@ max-complexity = 25

[tool.codespell]
skip = "tests/data/*"
ignore-words-list = "astroid"

[tool.coverage.report]
show_missing = true
Expand Down
14 changes: 7 additions & 7 deletions serialx/async_serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,18 +273,18 @@ async def create_serial_connection(
**kwargs: Any,
) -> tuple[BaseSerialTransport, asyncio.Protocol]:
"""Create a serial port connection with asyncio."""
if transport_cls is None:
if url is None:
raise ValueError("One of `url` or `transport_cls` must be provided.")

if transport_cls is not None:
resolved_cls = transport_cls
elif url is None:
raise ValueError("One of `url` or `transport_cls` must be provided.")
else:
handler = await asyncio.get_running_loop().run_in_executor(
None, get_uri_handler, url
)
transport_cls = handler.async_transport_cls
resolved_cls = handler.async_transport_cls

assert transport_cls is not None
protocol = protocol_factory()
transport = transport_cls(loop=loop, protocol=protocol)
transport = resolved_cls(loop=loop, protocol=protocol)

await transport.connect(
path=url,
Expand Down
71 changes: 40 additions & 31 deletions serialx/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,21 +339,24 @@ def __init__(
"""Initialize serial port configuration."""
super().__init__()

if not isinstance(stopbits, StopBits):
stopbits = StopBits(stopbits)
normalized_stopbits = (
stopbits if isinstance(stopbits, StopBits) else StopBits(stopbits)
)

if parity is None:
parity = Parity.NONE
elif not isinstance(parity, Parity):
parity = Parity(parity)
normalized_parity = Parity.NONE
elif isinstance(parity, Parity):
normalized_parity = parity
else:
normalized_parity = Parity(parity)

self._path = path
self._baudrate = baudrate
self._stopbits = stopbits
self._stopbits = normalized_stopbits
self._xonxoff = xonxoff
self._rtscts = rtscts
self._dsrdtr = dsrdtr
self._parity = parity
self._parity = normalized_parity
self._byte_size = byte_size
self._exclusive = exclusive
self._read_timeout = read_timeout
Expand Down Expand Up @@ -397,9 +400,12 @@ def _check_broken(self) -> None:
def from_url(cls, url: str, *args: Any, **kwargs: Any) -> BaseSerial:
"""Create the appropriate serial port subclass for the given URL."""
handler = get_uri_handler(url)
target = url
if handler.strip_uri_scheme:
url = url.removeprefix(handler.scheme).removeprefix(handler.unique_scheme)
return handler.sync_cls(url, *args, **kwargs)
target = url.removeprefix(handler.scheme).removeprefix(
handler.unique_scheme
)
return handler.sync_cls(target, *args, **kwargs)

@maybe_wrap_exceptions
def open(self) -> None:
Expand Down Expand Up @@ -471,8 +477,9 @@ def set_modem_pins(
"""Set modem control bits."""
self._check_broken()

if modem_pins is None:
modem_pins = ModemPins(
pins = modem_pins
if pins is None:
pins = ModemPins(
le=PinState.convert(le),
dtr=PinState.convert(dtr),
rts=PinState.convert(rts),
Expand All @@ -484,7 +491,7 @@ def set_modem_pins(
dsr=PinState.convert(dsr),
)

return self._set_modem_pins(modem_pins)
return self._set_modem_pins(pins)

@abstractmethod
def _get_modem_pins(self) -> ModemPins:
Expand All @@ -500,8 +507,8 @@ def _set_modem_pins(self, modem_pins: ModemPins) -> None:
def readinto(self, b: Buffer, *, timeout: float | None = None) -> int:
"""Read bytes from serial port into buffer."""
self._check_broken()
timeout = self._read_timeout if timeout is None else timeout
return self._readinto(b, timeout=timeout)
effective_timeout = self._read_timeout if timeout is None else timeout
return self._readinto(b, timeout=effective_timeout)

@abstractmethod
def _readinto(self, b: Buffer, *, timeout: float | None) -> int:
Expand All @@ -512,8 +519,8 @@ def _readinto(self, b: Buffer, *, timeout: float | None) -> int:
def write(self, data: Buffer, *, timeout: float | None = None) -> int:
"""Write bytes to serial port."""
self._check_broken()
timeout = self._write_timeout if timeout is None else timeout
return self._write(data, timeout=timeout)
effective_timeout = self._write_timeout if timeout is None else timeout
return self._write(data, timeout=effective_timeout)

@abstractmethod
def _write(self, data: Buffer, *, timeout: float | None) -> int:
Expand Down Expand Up @@ -581,14 +588,14 @@ def readexactly(self, n: int, *, timeout: float | None = None) -> bytes:
buffer = bytearray(n)
view = memoryview(buffer)
remaining = n
timeout = self.read_timeout if timeout is None else timeout
remaining_timeout = self.read_timeout if timeout is None else timeout

while remaining > 0:
with measure_time() as get_elapsed:
read = self.readinto(view, timeout=timeout)
read = self.readinto(view, timeout=remaining_timeout)

if timeout is not None:
timeout -= get_elapsed()
if remaining_timeout is not None:
remaining_timeout -= get_elapsed()

view = view[read:]
remaining -= read
Expand All @@ -611,14 +618,14 @@ def read_until(
"""Read until the expected sequence is found."""
buffer = bytearray()
expected_len = len(expected)
timeout = self.read_timeout if timeout is None else timeout
remaining_timeout = self.read_timeout if timeout is None else timeout

while True:
with measure_time() as get_elapsed:
byte = self.readexactly(1, timeout=timeout)
byte = self.readexactly(1, timeout=remaining_timeout)

if timeout is not None:
timeout -= get_elapsed()
if remaining_timeout is not None:
remaining_timeout -= get_elapsed()

if not byte:
break
Expand Down Expand Up @@ -1009,15 +1016,16 @@ async def connect(
**kwargs: Unpack[ConnectKwargs],
) -> None:
"""Connect to serial port."""
if path is not None:
handler = get_uri_handler(path)
target = path
if target is not None:
handler = get_uri_handler(target)
if handler.strip_uri_scheme:
path = path.removeprefix(handler.scheme).removeprefix(
target = target.removeprefix(handler.scheme).removeprefix(
handler.unique_scheme
)

try:
await self._connect(path=path, **kwargs)
await self._connect(path=target, **kwargs)
except BaseException:
# Intentionally catch cancellation too: callers should only observe
# connect failure/cancel after transport resources are released.
Expand Down Expand Up @@ -1058,8 +1066,9 @@ async def set_modem_pins(
"""Set modem control bits."""
self._check_broken()

if modem_pins is None:
modem_pins = ModemPins(
pins = modem_pins
if pins is None:
pins = ModemPins(
le=PinState.convert(le),
dtr=PinState.convert(dtr),
rts=PinState.convert(rts),
Expand All @@ -1071,7 +1080,7 @@ async def set_modem_pins(
dsr=PinState.convert(dsr),
)

return await self._set_modem_pins(modem_pins)
return await self._set_modem_pins(pins)

async def flush(self) -> None:
"""Flush write buffers, waiting until all data is written."""
Expand Down
18 changes: 10 additions & 8 deletions serialx/descriptor_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ def get_write_buffer_limits(self) -> tuple[int, int]:
def _set_write_buffer_limits(
self, high: int | None = None, low: int | None = None
) -> None:
# pylint: disable=serialx-reassigned-parameter
if high is None:
if low is None: # noqa: SIM108
high = 64 * 1024
Expand Down Expand Up @@ -257,9 +258,10 @@ def write(self, data: bytes | bytearray | memoryview) -> None:

self._check_broken()

if isinstance(data, bytearray):
data = memoryview(data)
if not data:
buf: bytes | memoryview = (
memoryview(data) if isinstance(data, bytearray) else data
)
if not buf:
return

if self._closing or self._conn_lost_count > 0:
Expand All @@ -273,7 +275,7 @@ def write(self, data: bytes | bytearray | memoryview) -> None:
if not self._buffer:
# Attempt to send it right away first.
try:
n = os.write(self._fileno, data)
n = os.write(self._fileno, buf)
except (BlockingIOError, InterruptedError):
n = 0
except (SystemExit, KeyboardInterrupt):
Expand All @@ -288,17 +290,17 @@ def write(self, data: bytes | bytearray | memoryview) -> None:
)
return

len_data = len(data)
len_data = len(buf)
LOGGER.debug("Sent %d of %d bytes", n, len_data)

if n == len_data:
return
elif n > 0:
data = memoryview(data)[n:]
buf = memoryview(buf)[n:]
self._loop.add_writer(self._fileno, self._write_ready)

LOGGER.debug("Buffering %r", data)
self._buffer += data
LOGGER.debug("Buffering %r", buf)
self._buffer += buf
self._maybe_pause_protocol()

def _write_ready(self) -> None:
Expand Down
9 changes: 4 additions & 5 deletions serialx/platforms/serial_pyodide/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,24 +214,23 @@ async def _connect( # type: ignore[override]
else:
raise UnsupportedSetting(f"Unsupported byte_size: {byte_size!r}")

if js_port is None:
js_port = _REGISTERED_JS_PORTS.get(path)
port = js_port if js_port is not None else _REGISTERED_JS_PORTS.get(path)

if js_port is None:
if port is None:
raise SerialException(
f"No JS serial port registered for {path!r}; call "
f"`register_js_port(path, js_port)` or pass `js_port=` to `connect`"
)

await js_port.open(
await port.open(
baudRate=self._serial.baudrate,
dataBits=data_bits,
flowControl=flow_control,
parity=_PARITY_MAP[self._serial.parity],
stopBits=_STOPBITS_MAP[self._serial.stopbits],
)

self._js_port = js_port
self._js_port = port
assert self._js_port is not None

if self._serial.rtsdtr_on_open is not PinState.UNDEFINED:
Expand Down
Loading
Loading