From 7c9ed07ff17a2d95acf70e12443ba2e4cf61c7b5 Mon Sep 17 00:00:00 2001 From: dskrypa Date: Sun, 22 Mar 2026 09:42:33 -0400 Subject: [PATCH 1/3] minor code cleanup for context usage and ChoiceMap._handle_duplicate_choice; refined/added some type annotations and comments --- lib/cli_command_parser/config.py | 4 +- lib/cli_command_parser/core.py | 14 ++++--- lib/cli_command_parser/error_handling/base.py | 22 ++++++----- .../error_handling/windows.py | 2 +- lib/cli_command_parser/formatting/commands.py | 3 +- lib/cli_command_parser/formatting/params.py | 19 +++++----- lib/cli_command_parser/metadata.py | 12 ++++-- lib/cli_command_parser/parameters/base.py | 21 ++++++----- .../parameters/choice_map.py | 37 +++++++++++-------- lib/cli_command_parser/parameters/groups.py | 5 ++- 10 files changed, 78 insertions(+), 61 deletions(-) diff --git a/lib/cli_command_parser/config.py b/lib/cli_command_parser/config.py index 2b43485d..9b561ad2 100644 --- a/lib/cli_command_parser/config.py +++ b/lib/cli_command_parser/config.py @@ -121,7 +121,7 @@ class OptionNameMode(FixedFlag): # fmt: on @classmethod - def _missing_(cls, value: str | int | None) -> OptionNameMode: # type: ignore[override] + def _missing_(cls, value: str | int | None) -> Self: # type: ignore[override] try: return OPT_NAME_MODE_ALIASES[value] # type: ignore[index] except KeyError: @@ -212,7 +212,7 @@ class AllowLeadingDash(Enum): # fmt: on @classmethod - def _missing_(cls, value) -> Self: + def _missing_(cls, value: str | bool) -> Self: # type: ignore[override] if isinstance(value, str): try: return cls._member_map_[value.upper()] # type: ignore[return-value] diff --git a/lib/cli_command_parser/core.py b/lib/cli_command_parser/core.py index 9c636eee..49c9f071 100644 --- a/lib/cli_command_parser/core.py +++ b/lib/cli_command_parser/core.py @@ -167,13 +167,15 @@ def _prepare_config(mcs, bases: Bases, config: AnyConfig, kwargs: dict[str, Any] return None - @overload - @classmethod - def config(mcs, cls: CommandAny, default: None = None) -> CommandConfig | None: ... + if TYPE_CHECKING: - @overload - @classmethod - def config(mcs, cls: CommandAny, default: T) -> CommandConfig | T: ... + @overload + @classmethod + def config(mcs, cls: CommandAny, default: None = None) -> CommandConfig | None: ... + + @overload + @classmethod + def config(mcs, cls: CommandAny, default: T) -> CommandConfig | T: ... @classmethod def config(mcs, cls: CommandAny, default: T | None = None) -> CommandConfig | T | None: diff --git a/lib/cli_command_parser/error_handling/base.py b/lib/cli_command_parser/error_handling/base.py index 2999be9e..5026ac29 100644 --- a/lib/cli_command_parser/error_handling/base.py +++ b/lib/cli_command_parser/error_handling/base.py @@ -6,14 +6,18 @@ import sys from collections import ChainMap -from typing import Callable, Iterator, Type, TypeVar +from typing import TYPE_CHECKING, Callable, Iterator, Type, TypeVar from ..exceptions import CommandParserException +if TYPE_CHECKING: + from ..typing import Self + __all__ = ['ErrorHandler', 'error_handler', 'extended_error_handler', 'no_exit_handler', 'NullErrorHandler'] E = TypeVar('E', bound=BaseException) HandlerFunc = Callable[[E], bool | int | None] +HandlerDecorator = Callable[[HandlerFunc], HandlerFunc] class ErrorHandler: @@ -38,16 +42,16 @@ def unregister(self, *exceptions: Type[BaseException]): except KeyError: pass - def __call__(self, *exceptions: Type[BaseException]): - def _handler(handler: HandlerFunc | staticmethod): + def __call__(self, *exceptions: Type[BaseException]) -> HandlerDecorator: + def _handler(handler: HandlerFunc) -> HandlerFunc: self.register(handler, *exceptions) return handler return _handler @classmethod - def cls_handler(cls, *exceptions: Type[E]): - def _cls_handler(handler: HandlerFunc | staticmethod): + def cls_handler(cls, *exceptions: Type[E]) -> HandlerDecorator: + def _cls_handler(handler: HandlerFunc) -> HandlerFunc: for exc in exceptions: cls._exc_handler_map[exc] = Handler(exc, handler) return handler @@ -66,7 +70,7 @@ def iter_handlers(self, exc_type: Type[BaseException], exc: BaseException) -> It for candidate in candidates: yield candidate.handler - def __enter__(self): + def __enter__(self) -> Self: return self def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb) -> bool: @@ -90,7 +94,7 @@ def copy(self) -> ErrorHandler: class NullErrorHandler: - def __enter__(self): + def __enter__(self) -> Self: return self def __exit__(self, exc_type, exc_val, exc_tb): @@ -110,9 +114,7 @@ def __init__(self, exc_cls: Type[BaseException], handler: HandlerFunc): self.handler = handler def __eq__(self, other) -> bool: - if not isinstance(other, Handler): - return False - return other.exc_cls == self.exc_cls and other.handler == self.handler + return isinstance(other, Handler) and other.exc_cls == self.exc_cls and other.handler == self.handler def __lt__(self, other: Handler) -> bool: return issubclass(self.exc_cls, other.exc_cls) diff --git a/lib/cli_command_parser/error_handling/windows.py b/lib/cli_command_parser/error_handling/windows.py index f504f35f..17183cf9 100644 --- a/lib/cli_command_parser/error_handling/windows.py +++ b/lib/cli_command_parser/error_handling/windows.py @@ -33,7 +33,7 @@ def handle_kb_interrupt(exc: KeyboardInterrupt) -> int: @extended_error_handler(OSError) -def handle_win_os_pipe_error(exc: OSError): +def handle_win_os_pipe_error(exc: OSError) -> bool: """ This is a workaround for `[Windows] I/O on a broken pipe may raise an EINVAL OSError instead of BrokenPipeError `_, which is a bug in the way that the diff --git a/lib/cli_command_parser/formatting/commands.py b/lib/cli_command_parser/formatting/commands.py index 84650f52..d02e0f36 100644 --- a/lib/cli_command_parser/formatting/commands.py +++ b/lib/cli_command_parser/formatting/commands.py @@ -98,7 +98,8 @@ def format_usage( allow_sys_argv: Bool = True, cont_indent: int = 4, ) -> str: - if (wrap_usage_str := ctx.config.wrap_usage_str) is True: + if (wrap_usage_str := ctx.config.wrap_usage_str) is True: # noqa + # `is True` is used because it supports True -> term width or an explicit width wrap_usage_str = ctx.terminal_width if usage := self._meta.usage: diff --git a/lib/cli_command_parser/formatting/params.py b/lib/cli_command_parser/formatting/params.py index f74cf08c..a731166e 100644 --- a/lib/cli_command_parser/formatting/params.py +++ b/lib/cli_command_parser/formatting/params.py @@ -315,6 +315,9 @@ class ChoiceGroup: __slots__ = ('choice_strs', 'choices') + choices: list[Choice[Any]] + choice_strs: list[str] + def __init__(self, choice: Choice): self.choices = [choice] self.choice_strs = [choice.choice] if choice.choice else [] @@ -351,10 +354,9 @@ def format(self, default_mode: CmdAliasMode, prefix: str = '') -> Iterator[str]: :return: Generator that yields formatted help text entries (strings) for the Choices in this group. """ for choice, usage, description in self.prepare(default_mode): - if usage is not None: - yield format_help_entry((usage,), description, lpad=4, prefix=prefix) + yield format_help_entry((usage,), description, lpad=4, prefix=prefix) - def prepare(self, default_mode: CmdAliasMode) -> Iterator[tuple[Choice, OptStr, OptStr]]: + def prepare(self, default_mode: CmdAliasMode) -> Iterator[tuple[Choice, str, OptStr]]: """ Prepares the choice values and descriptions to use for each Choice in this group based on the configured alias mode. @@ -364,9 +366,8 @@ def prepare(self, default_mode: CmdAliasMode) -> Iterator[tuple[Choice, OptStr, :return: Generator that yields 3-tuples containing the :class:`.Choice` object, the choice string value, and the help text / description for that choice / alias. """ - # If it's not a Command, get_config will return None. If it is a Command, then it will use its config. If the - # alias mode is not set on that target Command, but it is set on its parent, then this will use that parent's - # setting. + # If the target is a Command, its config will be used, otherwise, get_config will return None. If the alias + # mode is not set on that target Command, but it is set on its parent, then this will use that parent's setting. if config := get_config(self.choices[0].target): mode = config.cmd_alias_mode or default_mode else: @@ -383,7 +384,7 @@ def prepare(self, default_mode: CmdAliasMode) -> Iterator[tuple[Choice, OptStr, # Treat it as a format string yield from self.prepare_aliases(mode) - def prepare_combined(self) -> tuple[Choice, OptStr, OptStr]: + def prepare_combined(self) -> tuple[Choice, str, OptStr]: """ Prepare this group's Choices for inclusion in help text / documentation by combining all aliases into a single entry. @@ -399,7 +400,7 @@ def prepare_combined(self) -> tuple[Choice, OptStr, OptStr]: return first, usage, first.help - def prepare_aliases(self, format_str: str = 'Alias of: {choice}') -> Iterator[tuple[Choice, OptStr, OptStr]]: + def prepare_aliases(self, format_str: str = 'Alias of: {choice}') -> Iterator[tuple[Choice, str, OptStr]]: """ Prepare this group's Choices for inclusion in help text / documentation using an alternate description for aliases. @@ -431,7 +432,7 @@ def prepare_aliases(self, format_str: str = 'Alias of: {choice}') -> Iterator[tu for choice_str in choice_strs: yield first, choice_str, format_str.format(choice=first_str, alias=choice_str, help=help_str) - def prepare_repeated(self) -> Iterator[tuple[Choice, OptStr, OptStr]]: + def prepare_repeated(self) -> Iterator[tuple[Choice, str, OptStr]]: """ Prepare this group's Choices for inclusion in help text / documentation with no modifications. Choices that are considered aliases are simply repeated as if they were not aliases. diff --git a/lib/cli_command_parser/metadata.py b/lib/cli_command_parser/metadata.py index e5829447..fec94984 100644 --- a/lib/cli_command_parser/metadata.py +++ b/lib/cli_command_parser/metadata.py @@ -18,7 +18,8 @@ from typing import TYPE_CHECKING, Any, Callable, Generic, Iterator, Literal, Type, TypeVar, overload from urllib.parse import urlparse -from .context import NoActiveContext, ctx +from .context import get_current_context +from .exceptions import NoActiveContext if TYPE_CHECKING: from .core import CommandMeta @@ -314,6 +315,8 @@ def _repr(obj, indent=0) -> str: class ProgFinder: + """Helper to find the name of the current program when ``prog`` was not set explicitly.""" + @cached_property def mod_obj_prog_map(self) -> dict[str, dict[str, str]]: mod_obj_prog_map: dict[str, dict[str, str]] = defaultdict(dict) @@ -342,7 +345,7 @@ def normalize( if allow_sys_argv is None: try: - allow_sys_argv = ctx.allow_argv_prog + allow_sys_argv = get_current_context().allow_argv_prog except NoActiveContext: allow_sys_argv = True @@ -384,9 +387,10 @@ def _iter_entry_point_candidates(self, cmd_module: str): else: yield prog, obj, getattr(obj, '__module__', ''), getattr(obj, '__qualname__', '') - def _from_sys_argv(self) -> OptStr: + @classmethod + def _from_sys_argv(cls) -> OptStr: try: - ctx_prog = ctx.prog + ctx_prog = get_current_context().prog except NoActiveContext: return None diff --git a/lib/cli_command_parser/parameters/base.py b/lib/cli_command_parser/parameters/base.py index 29443217..700a2251 100644 --- a/lib/cli_command_parser/parameters/base.py +++ b/lib/cli_command_parser/parameters/base.py @@ -15,7 +15,7 @@ from ..annotations import get_descriptor_value_type from ..config import DEFAULT_CONFIG, AllowLeadingDash, CommandConfig, OptionNameMode -from ..context import Context, ctx, get_current_context +from ..context import get_current_context from ..exceptions import BadArgument, InvalidChoice, MissingArgument, ParameterDefinitionError from ..inputs import InputType, normalize_input_type from ..inputs.choices import _ChoicesBase @@ -31,6 +31,7 @@ from typing import Literal, NoReturn, TypeAlias from ..commands import Command + from ..context import Context from ..formatting.params import ParamHelpFormatter from ..typing import Bool, NormalizedType, OptStr, OptStrs, Self, Strings from ._typing import CommandMethod, DefaultFunc, LeadingDash @@ -507,21 +508,21 @@ def __get__(self, command: Command | object | None, owner: Any = None) -> Self | def result(self, command: Command | Any = None, missing_default: TD | _NotSetType = _NotSet) -> T | D | TD: """The final result / parsed value for this Parameter that is returned upon access as a descriptor.""" - if (value := ctx.get_parsed_value(self)) is not _NotSet: + if (value := get_current_context().get_parsed_value(self)) is not _NotSet: return self.action.finalize_value(value) if self.required: if missing_default is _NotSet: raise MissingArgument(self) return missing_default - else: - try: - return self.action.get_default(command, missing_default) - except InputValidationError as e: - # At this point, a default value was provided when this param was defined, but it wasn't acceptable - # TODO: Do any of the other cases handled by the `prepare_value` method need to be checked here? - # Need to test choices - a non-acceptable choice may make sense as the default in some cases - raise BadArgument(self, f'bad default value - {e}') from e + + try: + return self.action.get_default(command, missing_default) + except InputValidationError as e: + # At this point, a default value was provided when this param was defined, but it wasn't acceptable + # TODO: Do any of the other cases handled by the `prepare_value` method need to be checked here? + # Need to test choices - a non-acceptable choice may make sense as the default in some cases + raise BadArgument(self, f'bad default value - {e}') from e # endregion diff --git a/lib/cli_command_parser/parameters/choice_map.py b/lib/cli_command_parser/parameters/choice_map.py index 5ab39222..e0160ae8 100644 --- a/lib/cli_command_parser/parameters/choice_map.py +++ b/lib/cli_command_parser/parameters/choice_map.py @@ -43,6 +43,7 @@ class Choice(Generic[T]): __slots__ = ('choice', 'target', 'help', 'local') + choice: OptStr target: T def __init__(self, choice: OptStr, target: T, help: OptStr = None, local: bool = False): # noqa @@ -142,14 +143,16 @@ def register_choice(self, choice: str, target: T, help: OptStr = None): # noqa self._register_choice(choice, target, help) def _register_choice(self, choice: OptStr, target: T, help: OptStr = None, local: bool = False): # noqa - try: - existing = self.choices[choice] - except KeyError: - self.choices[choice] = Choice(choice, target, help, local) - self._update_nargs() - else: - prefix = 'Invalid default' if choice is None else f'Invalid {choice=} for' - raise CommandDefinitionError(f'{prefix} {target=} - already assigned to {existing}') + if existing := self.choices.get(choice): + self._handle_duplicate_choice(choice, target, existing) + + self.choices[choice] = Choice(choice, target, help, local) + self._update_nargs() + + @classmethod + def _handle_duplicate_choice(cls, choice: OptStr, target: T, existing: Choice): + prefix = 'Invalid default' if choice is None else f'Invalid {choice=} for' + raise CommandDefinitionError(f'{prefix} {target=} - already assigned to {existing}') def _no_choices_error(self) -> NoReturn: raise CommandDefinitionError(f'No choices were registered for {self}') @@ -260,18 +263,20 @@ def register_command(self, choice: OptStr, command: CommandCls, help: OptStr) -> if meta.description and (not meta.parent or meta.parent.description != meta.description): help = meta.description # noqa - try: - self.register_choice(choice, command, help) - except CommandDefinitionError: - from ..core import get_parent - - parent = get_parent(command) - msg = f'Invalid {choice=} for {command} with {parent=} - already assigned to {self.choices[choice].target}' - raise CommandDefinitionError(msg) from None + self.register_choice(choice, command, help) command._is_subcommand_ = True # This is used indirectly by ``main()`` to filter out non-top-level Commands return command + @classmethod + def _handle_duplicate_choice(cls, choice: OptStr, command: CommandCls, existing: Choice): # type: ignore[override] + from ..core import get_parent + + parent = get_parent(command) + raise CommandDefinitionError( + f'Invalid {choice=} for {command} with {parent=} - already assigned to {existing.target}' + ) + def register( self, command_or_choice: str | CommandCls | None = None, diff --git a/lib/cli_command_parser/parameters/groups.py b/lib/cli_command_parser/parameters/groups.py index 78bc05f0..292c5e8e 100644 --- a/lib/cli_command_parser/parameters/groups.py +++ b/lib/cli_command_parser/parameters/groups.py @@ -9,7 +9,7 @@ from itertools import count from typing import TYPE_CHECKING, Iterable, Iterator -from ..context import ctx +from ..context import get_current_context from ..exceptions import CommandDefinitionError, ParamConflict, ParameterDefinitionError, ParamsMissing from .base import BaseOption, BasePositional, ParamBase, _group_stack from .pass_thru import PassThru @@ -191,6 +191,7 @@ def _categorize_params(self) -> tuple[ParamList, ParamList]: """Called after parsing to group this group's members by whether they were provided or not.""" provided: ParamList = [] missing: ParamList = [] + ctx = get_current_context() for obj in self.members: if ctx.num_provided(obj): provided.append(obj) @@ -242,7 +243,7 @@ def validate(self): that nested groups are validated before any group that they are a member of. """ provided, missing = self._categorize_params() - ctx.record_action(self, len(provided)) + get_current_context().record_action(self, len(provided)) self._check_conflicts(provided, missing) if not missing: return From fac2888fc52dc6436b78fe0e80d8b67f21c8e570 Mon Sep 17 00:00:00 2001 From: dskrypa Date: Sun, 29 Mar 2026 08:39:13 -0400 Subject: [PATCH 2/3] removed type: ignore comments that were rendered unnecessary by other changes --- lib/cli_command_parser/conversion/visitor.py | 4 +--- lib/cli_command_parser/core.py | 10 +++++----- lib/cli_command_parser/formatting/params.py | 4 ++-- lib/cli_command_parser/inputs/base.py | 2 +- lib/cli_command_parser/inputs/choices.py | 16 ++++++++-------- lib/cli_command_parser/inputs/numeric.py | 2 +- lib/cli_command_parser/testing.py | 2 +- 7 files changed, 19 insertions(+), 21 deletions(-) diff --git a/lib/cli_command_parser/conversion/visitor.py b/lib/cli_command_parser/conversion/visitor.py index 6329c1ef..b4e9f83f 100644 --- a/lib/cli_command_parser/conversion/visitor.py +++ b/lib/cli_command_parser/conversion/visitor.py @@ -182,9 +182,7 @@ def _visit_for_smart(self, node: For, loop_var: str, ele_names: list[str]): """ log.debug(f'Attempting smart for loop visit for {loop_var=} in {ele_names=}') refs: list[AstArgumentParser] = [ - ref # type: ignore[misc] # mypy doesn't seem to recognize the isinstance part of the condition - for name in ele_names - if (ref := self.scopes.get(name)) and isinstance(ref, AstArgumentParser) + ref for name in ele_names if (ref := self.scopes.get(name)) and isinstance(ref, AstArgumentParser) ] # log.debug(f' > Found {len(refs)=}, {len(ele_names)=}') diff --git a/lib/cli_command_parser/core.py b/lib/cli_command_parser/core.py index 49c9f071..29460e81 100644 --- a/lib/cli_command_parser/core.py +++ b/lib/cli_command_parser/core.py @@ -136,7 +136,7 @@ def _maybe_register_sub_cmd(mcs, cls, choice: Choice, choices: OptChoices, help: if parent := mcs.parent(cls, False): if sub_cmd := mcs.params(parent).sub_command: for choice, choice_help in _choice_items(choice, choices): - sub_cmd.register_command(choice, cls, choice_help or help) # type: ignore[attr-defined] + sub_cmd.register_command(choice, cls, choice_help or help) elif choices or (choice is not None and choice is not _NotSet): _no_choices_registered_warning(choice, choices, cls, f'its {parent=} has no SubCommand parameter') elif choices or (choice is not None and choice is not _NotSet): @@ -180,7 +180,7 @@ def config(mcs, cls: CommandAny, default: T) -> CommandConfig | T: ... @classmethod def config(mcs, cls: CommandAny, default: T | None = None) -> CommandConfig | T | None: try: - return cls.__config # type: ignore[union-attr] # This attr is not overwritten for every subclass + return cls.__config # This attr is not overwritten for every subclass except AttributeError: # This means that the Command and all of its parents have no custom config return default @@ -199,7 +199,7 @@ def parent(mcs, cls: CommandAny, include_abc: bool = True) -> CommandMeta | None ``include_abc``). """ try: - first, parent = cls.__parents # type: ignore[union-attr] # Works for both Command objects and classes + first, parent = cls.__parents # Works for both Command objects and classes except TypeError: pass else: @@ -222,7 +222,7 @@ def parent(mcs, cls: CommandAny, include_abc: bool = True) -> CommandMeta | None def params(mcs, cls: CommandAny) -> CommandParameters: # Late initialization is necessary to allow late assignment of Parameters for now try: - params = cls.__params # type: ignore[union-attr] + params = cls.__params except AttributeError: raise TypeError('CommandParameters are only available for Command subclasses') from None @@ -247,7 +247,7 @@ def meta(mcs, cls: CommandMeta) -> ProgramMetadata: def _mro(cmd_or_cls: CommandAny) -> tuple[CommandMeta, list[type]]: # In the return value of type.mro(...), 0 is always the class itself, -1 is always object try: - return cmd_or_cls, type.mro(cmd_or_cls)[1:-1] # type: ignore[arg-type,return-value] + return cmd_or_cls, type.mro(cmd_or_cls)[1:-1] # type: ignore[return-value] except TypeError: # a Command object was provided instead of a Command class cmd_cls: CommandMeta = cmd_or_cls.__class__ # type: ignore[assignment] return cmd_cls, type.mro(cmd_cls)[1:-1] diff --git a/lib/cli_command_parser/formatting/params.py b/lib/cli_command_parser/formatting/params.py index a731166e..32e54cda 100644 --- a/lib/cli_command_parser/formatting/params.py +++ b/lib/cli_command_parser/formatting/params.py @@ -123,7 +123,7 @@ def format_metavar(self) -> str: config = ctx.config if (t := param.type) is not None: try: - metavar = t.format_metavar( # type: ignore[union-attr,attr-defined] + metavar = t.format_metavar( # type: ignore[union-attr] config.choice_delim, config.sort_choices ) except Exception: # noqa # pylint: disable=W0703 @@ -134,7 +134,7 @@ def format_metavar(self) -> str: if config.use_type_metavar and t is not None: try: - name = t.__name__ # type: ignore[union-attr,attr-defined] + name = t.__name__ # type: ignore[union-attr] except AttributeError: pass else: diff --git a/lib/cli_command_parser/inputs/base.py b/lib/cli_command_parser/inputs/base.py index 1348d6f6..b81eb1f1 100644 --- a/lib/cli_command_parser/inputs/base.py +++ b/lib/cli_command_parser/inputs/base.py @@ -48,7 +48,7 @@ def format_metavar(self, choice_delim: str = ',', sort_choices: bool = False) -> class _FixedInputType(InputType[T], ABC): __slots__ = () - def fix_default(self, value: str | T | None) -> str | T | None: # type: ignore[override] + def fix_default(self, value: str | T | None) -> str | T | None: if value is None or not isinstance(value, str) or not self._fix_default: return value return self(value) diff --git a/lib/cli_command_parser/inputs/choices.py b/lib/cli_command_parser/inputs/choices.py index 3ab8bdd7..907ae649 100644 --- a/lib/cli_command_parser/inputs/choices.py +++ b/lib/cli_command_parser/inputs/choices.py @@ -64,8 +64,8 @@ def _iter_normalized(self, value: str | T, choices: Collection | None = None) -> yield value if not self.case_sensitive and (choices is None or isinstance(choices, (set, Mapping))): # Choices validates that all members of `choices` are strings when not case_sensitive - yield value.lower() # type: ignore[misc,union-attr] - yield value.upper() # type: ignore[misc,union-attr] + yield value.lower() + yield value.upper() def _case_insensitive_map_choice(self, value: Any) -> T: if not self.case_sensitive: @@ -114,22 +114,22 @@ def __init__( def _choices_repr(self, delim: str = ',') -> str: try: - return delim.join(map(repr, sorted(self.choices))) # type: ignore[type-var] + return delim.join(map(repr, sorted(self.choices))) except TypeError: # The choice values are not sortable return delim.join(sorted(map(repr, self.choices))) def __call__(self, value: str) -> T: choices = self.choices - value = self._normalize(value) # type: ignore[assignment] + value = self._normalize(value) for val in self._iter_normalized(value, choices): if val in choices: - return value # type: ignore[return-value] + return value if not self.case_sensitive: # choices/value are confirmed to be str in init when case_sensitive=False - norm_value = value.casefold() # type: ignore[attr-defined] + norm_value = value.casefold() for choice in choices: - if norm_value == choice.casefold(): # type: ignore[attr-defined] + if norm_value == choice.casefold(): return choice raise InvalidChoiceError(value, choices) @@ -155,7 +155,7 @@ def __init__(self, choices: Mapping[Any, T], *args, **kwargs): # TODO: Alternate ChoiceMap where values are used as help text, similar to SubCommand with local_choices def __call__(self, value: str) -> T: - value = self._normalize(value) # type: ignore[assignment] + value = self._normalize(value) for val in self._iter_normalized(value): try: return self.choices[val] diff --git a/lib/cli_command_parser/inputs/numeric.py b/lib/cli_command_parser/inputs/numeric.py index f7a90e67..07443cdd 100644 --- a/lib/cli_command_parser/inputs/numeric.py +++ b/lib/cli_command_parser/inputs/numeric.py @@ -205,7 +205,7 @@ def __call__(self, value: str) -> N: return num_val -class Bytes(NumericInput[int | float]): # type: ignore[type-var] +class Bytes(NumericInput[int | float]): """ A byte count/size. diff --git a/lib/cli_command_parser/testing.py b/lib/cli_command_parser/testing.py index d5f967be..a5c92a7e 100644 --- a/lib/cli_command_parser/testing.py +++ b/lib/cli_command_parser/testing.py @@ -301,7 +301,7 @@ def stderr(self) -> str: def __enter__(self) -> RedirectStreams: streams: dict[str, IO] = {'stdout': self._stdout, 'stderr': self._stderr} if self._stdin is not None: - streams['stdin'] = self._stdin # type: ignore[assignment] + streams['stdin'] = self._stdin for name, io in streams.items(): self._old[name] = getattr(sys, name) From 9d11cf464371508601ed0476d832d7669c7acc5a Mon Sep 17 00:00:00 2001 From: dskrypa Date: Sun, 29 Mar 2026 09:15:33 -0400 Subject: [PATCH 3/3] minor breaking change to inputs.files.Serialized and related classes to support proper typing for #151 --- docs/_src/inputs.rst | 45 ++- lib/cli_command_parser/inputs/_typing.py | 68 +++- lib/cli_command_parser/inputs/files.py | 404 +++++++++++++++++++---- lib/cli_command_parser/inputs/utils.py | 170 +++++++--- tests/test_inputs/test_file_inputs.py | 20 +- 5 files changed, 575 insertions(+), 132 deletions(-) diff --git a/docs/_src/inputs.rst b/docs/_src/inputs.rst index 36747f93..4df4ae65 100644 --- a/docs/_src/inputs.rst +++ b/docs/_src/inputs.rst @@ -118,20 +118,28 @@ In addition to plain text or binary files, custom input handlers also exist for and a generic handler (:class:`.Serialized`) exists for any other serialization format. They all extend :ref:`inputs:File`, so the same options are accepted. -.. _serialized_init_params: +.. version-changed:: 2026-04-TBD + + A breaking change was made to the generic :class:`.Serialized` class to remove support for a single ``converter`` + callable that handled either serialization xor deserialization. Instead, it now requires a ``serializer`` that + provides an interface with ``load`` / ``dump`` and/or ``loads`` / ``dumps``, similar to the *json* and *pickle* + modules. + + Additionally, the ``pass_file`` parameter was removed from :class:`.Json`, :class:`.Pickle`, and + :class:`.Serialized`. When the provided ``serializer`` has a ``load`` or ``dump`` attribute, it will always be + preferred over the ``loads`` / ``dumps`` variants. -**Additional Serialized initialization parameters:** -:converter: The function to call to serialize or deserialize the content in the specified file -:pass_file: True to call the given function with the file, False to handle (de)serialization and read/write as - separate steps. If True, when reading, the converter will be called with the file as the only argument; when writing, - the converter will be called as ``converter(data, f)``. If False, when reading, the converter will be called with - the content from the file; when writing, the converter will be called before writing the data to the file. +.. _serialized_init_params: +**Additional Serialized initialization parameters:** -The JSON and Pickle handlers do not accept the above 2 parameters. The converter is automatically picked to be -``dump`` or ``load`` based on whether the provided ``mode`` is for reading or writing, and the ``pass_file`` -option will be overridden if provided. +:serializer: Class or module that provides ``load``/``dump`` and/or ``loads``/``dumps`` methods/functions for + deserialization and serialization, respectively. Expects them to follow the same interface as the *json* or + *pickle* modules, with :func:`python:json.loads`, :func:`python:json.dumps`, :func:`python:pickle.load`, etc. +:lazy: If True, a :class:`.SerializedFileWrapper` will be stored in the Parameter using this file, otherwise the file + will be eagerly read immediately upon parsing of the path argument. When planning to write serialized data to a file, + only the default ``lazy=True`` is supported - eager writes are not supported. Adding another snippet to the above :gh_examples:`example `:: @@ -156,16 +164,19 @@ We can see that the JSON content from stdin was automatically deserialized when [1] ('b', 2) -When using the generic :class:`.Serialized` directly, the specific (de)serialization function needs to be provided:: +When using the generic :class:`.Serialized` directly, the module/object needs to be provided:: + + Serialized(pickle, mode='rb', lazy=False) # Read pickled data eagerly upon accessing the attribute + + Serialized(json, lazy=False) # Read JSON data eagerly upon accessing the attribute - Serialized(pickle.loads, mode='rb', lazy=False) - Serialized(pickle.load, pass_file=True, mode='rb', lazy=False) + Serialized(json, mode='w') # Provides a file wrapper with a `write` method - Serialized(json.loads, lazy=False) - Serialized(json.load, pass_file=True, lazy=False) - Serialized(json.dumps, mode='w') - Serialized(json.dump, pass_file=True, mode='w') +Any module or object that provides an interface similar to the :mod:`python:json` and :mod:`python:pickle` stdlib +modules is accepted as a *serializer*. It must implement a subset of ``load`` / ``dump`` and/or ``loads`` / ``dumps`` +methods or functions. When reading, ``load`` is always used if it is present, and ``loads`` is used as a fallback +option. Similarly for writing, ``dump`` is preferred over ``dumps``. diff --git a/lib/cli_command_parser/inputs/_typing.py b/lib/cli_command_parser/inputs/_typing.py index fe73095c..24e4b87f 100644 --- a/lib/cli_command_parser/inputs/_typing.py +++ b/lib/cli_command_parser/inputs/_typing.py @@ -1,16 +1,13 @@ from __future__ import annotations -from typing import IO, TYPE_CHECKING, Any, Callable, Sequence, TypeAlias, TypeVar, Union +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, AnyStr, BinaryIO, Callable, Literal, Protocol, TextIO, TypeAlias, TypeVar, Union if TYPE_CHECKING: from datetime import date, datetime, time, timedelta from numbers import Number as _Number -Deserializer: TypeAlias = Callable[[str | bytes | IO], Any] -Serializer: TypeAlias = Callable[[Any, IO], None] | Callable[[Any], str | bytes] -Converter: TypeAlias = Deserializer | Serializer - Locale = str | tuple[str | None, str | None] TimeBound: TypeAlias = Union['datetime', 'date', 'time', 'timedelta', None] @@ -18,3 +15,64 @@ Number: TypeAlias = N | None NumType: TypeAlias = Callable[[str | Any], N] RngType: TypeAlias = range | int | Sequence[int] + +# There are many more compatible variations of file open modes, but enumerating every one results in poor readability +# of popup type hints. +OpenBinaryMode: TypeAlias = Literal['rb', 'wb'] +OpenTextMode: TypeAlias = Literal['r', 'w'] +OpenAnyMode: TypeAlias = str + + +# region Minimal IO Protocols + +_T_co = TypeVar('_T_co', covariant=True) +_T_contra = TypeVar('_T_contra', contravariant=True) + + +class SupportsRead(Protocol[_T_co]): + def read(self, length: int = ..., /) -> _T_co: ... + + +class SupportsReadLine(Protocol[_T_co]): + def read(self, length: int = ..., /) -> _T_co: ... + def readline(self) -> _T_co: ... + + +class SupportsWrite(Protocol[_T_contra]): + def write(self, s: _T_contra, /) -> object: ... + + +class SupportsRW(Protocol[AnyStr]): + def read(self, length: int = ..., /) -> AnyStr: ... + def readline(self) -> AnyStr: ... + def write(self, s: AnyStr, /) -> object: ... + + +# endregion + + +# region Serializer Protocols + + +class FileSerializer(Protocol[AnyStr]): + __slots__ = () + + def load(self, fp: SupportsRead[AnyStr] | SupportsReadLine[AnyStr]) -> Any: ... + def dump(self, obj: Any, fp: SupportsWrite[AnyStr]) -> None: ... + + +class AnyStrSerializer(Protocol[AnyStr]): + __slots__ = () + + def loads(self, data: AnyStr) -> Any: ... + def dumps(self, obj: Any) -> AnyStr: ... + + +class FullSerializer(FileSerializer[AnyStr], AnyStrSerializer[AnyStr], Protocol): + __slots__ = () + + +AnySerializer: TypeAlias = FileSerializer[AnyStr] | AnyStrSerializer[AnyStr] | FullSerializer[AnyStr] + + +# endregion diff --git a/lib/cli_command_parser/inputs/files.py b/lib/cli_command_parser/inputs/files.py index bbd39436..c166a7cc 100644 --- a/lib/cli_command_parser/inputs/files.py +++ b/lib/cli_command_parser/inputs/files.py @@ -9,19 +9,29 @@ import os from abc import ABC from pathlib import Path as _Path -from typing import IO, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, AnyStr, Literal, TypeVar, overload from ..typing import T from .base import InputType from .exceptions import InputValidationError -from .utils import FileWrapper, InputParam, StatMode, allows_write, fix_windows_path +from .utils import ( + FileWrapper, + InputParam, + JsonSerializer, + SerializedFileWrapper, + StatMode, + allows_write, + fix_windows_path, +) if TYPE_CHECKING: from ..typing import Bool, OptStr, PathLike - from ._typing import Converter + from ._typing import AnySerializer, OpenAnyMode, OpenBinaryMode, OpenTextMode __all__ = ['Path', 'File', 'Serialized', 'Json', 'Pickle'] +T_co = TypeVar('T_co', covariant=True) + class FileInput(InputType[T], ABC): exists: InputParam[bool | None] = InputParam(None) @@ -67,7 +77,7 @@ def fix_default(self, value: T | str | None) -> T | str | None: """ if value is None or not self._fix_default: return value - return self(value) # type: ignore[arg-type] + return self(value) def validated_path(self, path: PathLike) -> _Path: if not isinstance(path, _Path): @@ -133,7 +143,7 @@ def __call__(self, value: PathLike) -> _Path: return self.validated_path(value) -class File(FileInput[FileWrapper | str | bytes]): +class File(FileInput[T_co]): """ :param mode: The mode in which the file should be opened. For more info, see :func:`python:open`. :param encoding: The encoding to use when reading the file in text mode. Ignored if the parsed path is ``-``. @@ -144,16 +154,118 @@ class File(FileInput[FileWrapper | str | bytes]): :param kwargs: Additional keyword arguments to pass to :class:`.Path`. """ - mode: InputParam[str] = InputParam('r') + mode: InputParam[OpenAnyMode] = InputParam('r') type: InputParam[StatMode] = InputParam(StatMode.FILE) encoding: InputParam[str | None] = InputParam(None) errors: InputParam[str | None] = InputParam(None) lazy: InputParam[bool] = InputParam(True) parents: InputParam[bool] = InputParam(False) + if TYPE_CHECKING: + + @overload + def __init__( + self: File[FileWrapper[str]], + mode: OpenTextMode = 'r', + *, + exists: Bool = None, + expand: Bool = True, + resolve: Bool = False, + type: StatMode | str = StatMode.FILE, # noqa + readable: Bool = False, + writable: Bool = False, + allow_dash: Bool = False, + use_windows_fix: Bool = True, + fix_default: Bool = True, + encoding: OptStr = None, + errors: OptStr = None, + lazy: Literal[True] = True, + parents: Bool = False, + ): ... + + @overload + def __init__( + self: File[FileWrapper[bytes]], + mode: OpenBinaryMode, + *, + exists: Bool = None, + expand: Bool = True, + resolve: Bool = False, + type: StatMode | str = StatMode.FILE, # noqa + readable: Bool = False, + writable: Bool = False, + allow_dash: Bool = False, + use_windows_fix: Bool = True, + fix_default: Bool = True, + encoding: OptStr = None, + errors: OptStr = None, + lazy: Literal[True] = True, + parents: Bool = False, + ): ... + + @overload + def __init__( + self: File[str], + mode: OpenTextMode = 'r', + *, + exists: Bool = None, + expand: Bool = True, + resolve: Bool = False, + type: StatMode | str = StatMode.FILE, # noqa + readable: Bool = False, + writable: Bool = False, + allow_dash: Bool = False, + use_windows_fix: Bool = True, + fix_default: Bool = True, + encoding: OptStr = None, + errors: OptStr = None, + lazy: Literal[False], + parents: Bool = False, + ): ... + + @overload + def __init__( + self: File[bytes], + mode: OpenBinaryMode, + *, + exists: Bool = None, + expand: Bool = True, + resolve: Bool = False, + type: StatMode | str = StatMode.FILE, # noqa + readable: Bool = False, + writable: Bool = False, + allow_dash: Bool = False, + use_windows_fix: Bool = True, + fix_default: Bool = True, + encoding: OptStr = None, + errors: OptStr = None, + lazy: Literal[False], + parents: Bool = False, + ): ... + + @overload + def __init__( + self, + mode: OpenAnyMode = 'r', + *, + exists: Bool = None, + expand: Bool = True, + resolve: Bool = False, + type: StatMode | str = StatMode.FILE, # noqa + readable: Bool = False, + writable: Bool = False, + allow_dash: Bool = False, + use_windows_fix: Bool = True, + fix_default: Bool = True, + encoding: OptStr = None, + errors: OptStr = None, + lazy: Bool = True, + parents: Bool = False, + ): ... + def __init__( self, - mode: str = 'r', + mode: OpenAnyMode = 'r', *, encoding: OptStr = None, errors: OptStr = None, @@ -174,85 +286,255 @@ def __init__( self.parents = parents def _prep_file_wrapper(self, path: _Path) -> FileWrapper: - return FileWrapper(path, self.mode, self.encoding, self.errors, parents=self.parents) + return FileWrapper(path, self.mode, encoding=self.encoding, errors=self.errors, parents=self.parents) - def __call__(self, value: PathLike) -> FileWrapper | str | bytes: + def __call__(self, value: PathLike) -> T_co: wrapper = self._prep_file_wrapper(self.validated_path(value)) if self.lazy: - return wrapper + return wrapper # type: ignore[return-value] return wrapper.read() -class Serialized(File): +class Serialized(File[T_co]): """ - :param converter: Function to use to (de)serialize the given file, such as :func:`python:json.loads`, - :func:`python:json.dumps`, :func:`python:pickle.load`, etc. - :param pass_file: For reading, if True, call the converter with the file object, otherwise read the - file first and call the converter with the result. For writing, if True, call the converter with both the - data to be written and the file object, otherwise call the converter with only the data and then write the - result to the file. + :param serializer: Class or module that provides ``load``/``dump`` and/or ``loads``/``dumps`` methods/functions for + deserialization and serialization, respectively. Expects them to follow the same interface as the *json* or + *pickle* modules, with :func:`python:json.loads`, :func:`python:json.dumps`, :func:`python:pickle.load`, etc. :param kwargs: Additional keyword arguments to pass to :class:`.File` """ - converter: InputParam[Converter | None] = InputParam(None) - pass_file: InputParam[bool] = InputParam(False) - - def __init__(self, converter: Converter, *, pass_file: Bool = False, **kwargs): + serializer: AnySerializer + + if TYPE_CHECKING: + + @overload + def __init__( + self: Serialized[SerializedFileWrapper[AnyStr]], + serializer: AnySerializer[AnyStr], + *, + mode: OpenAnyMode = 'r', + exists: Bool = None, + expand: Bool = True, + resolve: Bool = False, + type: StatMode | str = StatMode.FILE, # noqa + readable: Bool = False, + writable: Bool = False, + allow_dash: Bool = False, + use_windows_fix: Bool = True, + fix_default: Bool = True, + encoding: OptStr = None, + errors: OptStr = None, + lazy: Literal[True] = True, + parents: Bool = False, + ): ... + + @overload + def __init__( + self: Serialized[Any], + serializer: AnySerializer, + *, + mode: OpenAnyMode = 'r', + exists: Bool = None, + expand: Bool = True, + resolve: Bool = False, + type: StatMode | str = StatMode.FILE, # noqa + readable: Bool = False, + writable: Bool = False, + allow_dash: Bool = False, + use_windows_fix: Bool = True, + fix_default: Bool = True, + encoding: OptStr = None, + errors: OptStr = None, + lazy: Literal[False], + parents: Bool = False, + ): ... + + @overload + def __init__( + self, + serializer: AnySerializer, + *, + mode: OpenAnyMode = 'r', + exists: Bool = None, + expand: Bool = True, + resolve: Bool = False, + type: StatMode | str = StatMode.FILE, # noqa + readable: Bool = False, + writable: Bool = False, + allow_dash: Bool = False, + use_windows_fix: Bool = True, + fix_default: Bool = True, + encoding: OptStr = None, + errors: OptStr = None, + lazy: Bool = True, + parents: Bool = False, + ): ... + + def __init__(self, serializer: AnySerializer, **kwargs): super().__init__(**kwargs) - self.converter = converter - self.pass_file = pass_file + self.serializer = serializer def __repr__(self) -> str: - non_defaults = ', '.join(f'{k}={v!r}' for k, v in self.__dict__.items() if k != 'converter') - # `converter` must be excluded to prevent infinite recursion when an instance method is stored in that attr + non_defaults = ', '.join(f'{k}={v!r}' for k, v in self.__dict__.items() if k != 'serializer') + # `serializer` must be excluded to prevent infinite recursion when an instance method is stored in that attr return f'<{self.__class__.__name__}({non_defaults})>' - def _prep_file_wrapper(self, path: _Path) -> FileWrapper: - return FileWrapper(path, self.mode, self.encoding, self.errors, self.converter, self.pass_file, self.parents) + def _prep_file_wrapper(self, path: _Path) -> SerializedFileWrapper[AnyStr]: + return SerializedFileWrapper( + path, + self.mode, + serializer=self.serializer, + encoding=self.encoding, + errors=self.errors, + parents=self.parents, + ) -class Json(Serialized): +class Json(Serialized[T_co]): """ :param kwargs: Additional keyword arguments to pass to :class:`.File` """ - def __init__(self, *, mode: str = 'rb', wrap_errors: bool = True, **kwargs): - import json - - write = allows_write(mode, True) - kwargs['pass_file'] = True - super().__init__(json.dump if write else self._load_json, mode=mode, **kwargs) - self.wrap_errors = wrap_errors - - def _load_json(self, f: IO): - from json import JSONDecodeError, load - - try: - return load(f) - except JSONDecodeError as e: - if self.wrap_errors: - if name := getattr(f, 'name', None): - msg = f'json from file={name!r} - are you sure it contains properly formatted json?' - else: - msg = "the provided json content - are you sure it's properly formatted json?" - raise InputValidationError(f'Unable to load {msg} - error: {e}') from e - else: - raise - - -class Pickle(Serialized): + if TYPE_CHECKING: + + @overload + def __init__( + self: Json[SerializedFileWrapper[str]], + *, + mode: OpenTextMode = 'r', + wrap_errors: Bool = True, + exists: Bool = None, + expand: Bool = True, + resolve: Bool = False, + type: StatMode | str = StatMode.FILE, # noqa + readable: Bool = False, + writable: Bool = False, + allow_dash: Bool = False, + use_windows_fix: Bool = True, + fix_default: Bool = True, + encoding: OptStr = None, + errors: OptStr = None, + lazy: Literal[True] = True, + parents: Bool = False, + ): ... + + @overload + def __init__( + self: Json[Any], + *, + mode: OpenTextMode = 'r', + wrap_errors: Bool = True, + exists: Bool = None, + expand: Bool = True, + resolve: Bool = False, + type: StatMode | str = StatMode.FILE, # noqa + readable: Bool = False, + writable: Bool = False, + allow_dash: Bool = False, + use_windows_fix: Bool = True, + fix_default: Bool = True, + encoding: OptStr = None, + errors: OptStr = None, + lazy: Literal[False], + parents: Bool = False, + ): ... + + @overload + def __init__( + self, + *, + mode: OpenTextMode = 'r', + wrap_errors: Bool = True, + exists: Bool = None, + expand: Bool = True, + resolve: Bool = False, + type: StatMode | str = StatMode.FILE, # noqa + readable: Bool = False, + writable: Bool = False, + allow_dash: Bool = False, + use_windows_fix: Bool = True, + fix_default: Bool = True, + encoding: OptStr = None, + errors: OptStr = None, + lazy: Bool = True, + parents: Bool = False, + ): ... + + def __init__(self, *, mode: OpenTextMode = 'r', wrap_errors: Bool = True, **kwargs): + super().__init__(JsonSerializer(wrap_errors), mode=mode, **kwargs) + + +class Pickle(Serialized[T_co]): """ :param kwargs: Additional keyword arguments to pass to :class:`.File` """ - def __init__(self, *, mode: str = 'rb', **kwargs): + if TYPE_CHECKING: + + @overload + def __init__( + self: Pickle[SerializedFileWrapper[bytes]], + *, + mode: OpenBinaryMode = 'rb', + exists: Bool = None, + expand: Bool = True, + resolve: Bool = False, + type: StatMode | str = StatMode.FILE, # noqa + readable: Bool = False, + writable: Bool = False, + allow_dash: Bool = False, + use_windows_fix: Bool = True, + fix_default: Bool = True, + encoding: OptStr = None, + errors: OptStr = None, + lazy: Literal[True] = True, + parents: Bool = False, + ): ... + + @overload + def __init__( + self: Pickle[Any], + *, + mode: OpenBinaryMode = 'rb', + exists: Bool = None, + expand: Bool = True, + resolve: Bool = False, + type: StatMode | str = StatMode.FILE, # noqa + readable: Bool = False, + writable: Bool = False, + allow_dash: Bool = False, + use_windows_fix: Bool = True, + fix_default: Bool = True, + encoding: OptStr = None, + errors: OptStr = None, + lazy: Literal[False], + parents: Bool = False, + ): ... + + @overload + def __init__( + self, + *, + mode: OpenBinaryMode = 'rb', + exists: Bool = None, + expand: Bool = True, + resolve: Bool = False, + type: StatMode | str = StatMode.FILE, # noqa + readable: Bool = False, + writable: Bool = False, + allow_dash: Bool = False, + use_windows_fix: Bool = True, + fix_default: Bool = True, + encoding: OptStr = None, + errors: OptStr = None, + lazy: Bool = True, + parents: Bool = False, + ): ... + + def __init__(self, *, mode: OpenBinaryMode = 'rb', **kwargs): import pickle - if 't' in mode: - raise ValueError(f'Invalid {mode=} - pickle does not read/write text') - if 'b' not in mode: - mode += 'b' + if 't' in mode or 'b' not in mode: + raise ValueError(f'Invalid {mode=} - pickle does not read/write text - it requires a binary open mode') - write = allows_write(mode, True) - kwargs['pass_file'] = True - super().__init__(pickle.dump if write else pickle.load, mode=mode, **kwargs) + super().__init__(pickle, mode=mode, **kwargs) diff --git a/lib/cli_command_parser/inputs/utils.py b/lib/cli_command_parser/inputs/utils.py index 2d3cc4fa..450a8dbe 100644 --- a/lib/cli_command_parser/inputs/utils.py +++ b/lib/cli_command_parser/inputs/utils.py @@ -6,30 +6,41 @@ from __future__ import annotations +import json import sys import warnings from contextlib import contextmanager from pathlib import Path from stat import S_IFBLK, S_IFCHR, S_IFDIR, S_IFIFO, S_IFLNK, S_IFMT, S_IFREG, S_IFSOCK -from typing import IO, TYPE_CHECKING, Any, Generic, Iterator, Literal, TypeVar, overload +from typing import IO, TYPE_CHECKING, Any, AnyStr, Generic, Iterator, Literal, TypeVar, overload from weakref import finalize from ..utils import FixedFlag +from ._typing import FileSerializer from .exceptions import InputValidationError if TYPE_CHECKING: from ..typing import Bool, OptStr, Self - from ._typing import Converter, Number + from ._typing import AnySerializer, Number, OpenAnyMode, OpenBinaryMode, OpenTextMode, SupportsRead, SupportsRW -__all__ = ['InputParam', 'StatMode', 'FileWrapper', 'fix_windows_path', 'range_str', 'RangeMixin'] +__all__ = [ + 'InputParam', + 'StatMode', + 'FileWrapper', + 'SerializedFileWrapper', + 'JsonSerializer', + 'fix_windows_path', + 'range_str', + 'RangeMixin', +] -_T = TypeVar('_T') +T = TypeVar('T') -class InputParam(Generic[_T]): +class InputParam(Generic[T]): __slots__ = ('default', 'name') - def __init__(self, default: _T): + def __init__(self, default: T): self.default = default def __set_name__(self, owner, name: str): @@ -39,9 +50,9 @@ def __set_name__(self, owner, name: str): def __get__(self, instance: Literal[None], owner: Any) -> Self: ... @overload - def __get__(self, instance: object, owner: Any) -> _T: ... + def __get__(self, instance: object, owner: Any) -> T: ... - def __get__(self, instance: object, owner: Any) -> Self | _T: + def __get__(self, instance: object, owner: Any) -> Self | T: try: return instance.__dict__[self.name] except AttributeError: # instance is None @@ -49,7 +60,7 @@ def __get__(self, instance: object, owner: Any) -> Self | _T: except KeyError: return self.default - def __set__(self, instance: object, value: _T): + def __set__(self, instance: object, value: T): if value != self.default: instance.__dict__[self.name] = value @@ -103,15 +114,49 @@ def __str__(self) -> str: return ', '.join(names) -class FileWrapper: +class FileWrapper(Generic[AnyStr]): + if TYPE_CHECKING: + + @overload + def __init__( + self: FileWrapper[str], + path: Path, + mode: OpenTextMode = 'r', + *, + encoding: OptStr = None, + errors: OptStr = None, + parents: Bool = False, + ): ... + + @overload + def __init__( + self: FileWrapper[bytes], + path: Path, + mode: OpenBinaryMode, + *, + encoding: OptStr = None, + errors: OptStr = None, + parents: Bool = False, + ): ... + + @overload + def __init__( + self, + path: Path, + mode: OpenAnyMode = 'r', + *, + encoding: OptStr = None, + errors: OptStr = None, + parents: Bool = False, + ): ... + def __init__( self, path: Path, - mode: str = 'r', + mode: OpenAnyMode = 'r', + *, encoding: OptStr = None, errors: OptStr = None, - converter: Converter | None = None, - pass_file: Bool = False, parents: Bool = False, ): self.path = path @@ -119,40 +164,29 @@ def __init__( self.binary = 'b' in mode self.encoding = encoding self.errors = errors - self.converter = converter - self.pass_file = pass_file self.parents = parents self._fp: IO | None = None self._finalizer: finalize | None = None def __eq__(self, other) -> bool: - attrs = ('path', 'mode', 'binary', 'encoding', 'errors', 'converter', 'pass_file', 'parents') + attrs = ('path', 'mode', 'binary', 'encoding', 'errors', 'parents') try: return all(getattr(self, a) == getattr(other, a) for a in attrs) except AttributeError: # not a FileWrapper return NotImplemented - def read(self) -> Any: - with self._file() as f: - if self.converter is not None: - return self.converter(f if self.pass_file else f.read()) # type: ignore[call-arg] - else: - return f.read() + def read(self) -> AnyStr: + with self as f: + return f.read() - def write(self, data: Any): - with self._file() as f: - if self.converter is not None: - if self.pass_file: - self.converter(data, f) # type: ignore[call-arg] - else: - f.write(self.converter(data)) # type: ignore[call-arg] - else: - f.write(data) + def write(self, data: AnyStr) -> None: + with self as f: + f.write(data) - def _open(self) -> IO: + def _open(self) -> SupportsRW[AnyStr]: if self.path == Path('-'): stream = sys.stdin if 'r' in self.mode else sys.stdout - return stream.buffer if self.binary else stream + return stream.buffer if self.binary else stream # type: ignore[return-value] if self.parents and allows_write(self.mode): self.path.parent.mkdir(parents=True, exist_ok=True) @@ -187,20 +221,76 @@ def close(self): if do_close: self._close() + def __enter__(self) -> SupportsRW[AnyStr]: + return self._open() + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + +class SerializedFileWrapper(FileWrapper[AnyStr]): + serializer: AnySerializer[AnyStr] + + def __init__( + self, + path: Path, + mode: OpenAnyMode, + serializer: AnySerializer[AnyStr], + *, + encoding: OptStr = None, + errors: OptStr = None, + parents: Bool = False, + ): + super().__init__(path, mode, encoding=encoding, errors=errors, parents=parents) + self.serializer = serializer + + def __eq__(self, other) -> bool: + attrs = ('__class__', 'path', 'mode', 'binary', 'encoding', 'errors', 'serializer', 'parents') + return all(getattr(self, a) == getattr(other, a) for a in attrs) + + def __enter__(self) -> Self: # type: ignore[override] + return self + @contextmanager - def _file(self) -> Iterator[IO]: + def _file(self) -> Iterator[SupportsRW[AnyStr]]: try: yield self._open() finally: self.close() - def __enter__(self) -> IO | FileWrapper: - if self.converter is not None: - return self - return self._open() + def read(self) -> Any: + with self._file() as f: + if hasattr(self.serializer, 'load'): + return self.serializer.load(f) + return self.serializer.loads(f.read()) - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() + def write(self, data: Any) -> None: + with self._file() as f: + if hasattr(self.serializer, 'dump'): + self.serializer.dump(data, f) + else: + f.write(self.serializer.dumps(data)) + + +class JsonSerializer(FileSerializer[str]): + __slots__ = ('wrap_errors',) + dump = staticmethod(json.dump) # noqa + + def __init__(self, wrap_errors: Bool = True): + self.wrap_errors = wrap_errors + + def load(self, fp: SupportsRead[str]) -> Any: + if not self.wrap_errors: + return json.load(fp) + + try: + return json.load(fp) + except json.JSONDecodeError as e: + if name := getattr(fp, 'name', None): + msg = f'json from file={name!r} - are you sure it contains properly formatted json?' + else: + msg = "the provided json content - are you sure it's properly formatted json?" + raise InputValidationError(f'Unable to load {msg} - error: {e}') from e def allows_write(mode: str, strict: bool = False) -> bool: diff --git a/tests/test_inputs/test_file_inputs.py b/tests/test_inputs/test_file_inputs.py index 9fb45355..5758caf1 100755 --- a/tests/test_inputs/test_file_inputs.py +++ b/tests/test_inputs/test_file_inputs.py @@ -15,7 +15,7 @@ from cli_command_parser.exceptions import BadArgument from cli_command_parser.inputs import File, Json, Path as PathInput, Pickle, Serialized, StatMode from cli_command_parser.inputs.exceptions import InputValidationError -from cli_command_parser.inputs.utils import FileWrapper, InputParam, fix_windows_path +from cli_command_parser.inputs.utils import FileWrapper, InputParam, SerializedFileWrapper, fix_windows_path from cli_command_parser.testing import ParserTest, RedirectStreams PKG = 'cli_command_parser.inputs' @@ -55,7 +55,7 @@ def fn_mock(return_value) -> Mock: # endregion -class FileInputTest(TestCase): +class FileInputTest(ParserTest): # region Stat Mode def test_invalid_stat_modes(self): @@ -177,10 +177,11 @@ def test_pickle_text_rejected(self): with self.subTest(case=case), self.assertRaises(ValueError): Pickle(mode=case) - def test_pickle_b_added(self): - self.assertEqual('rb', Pickle(mode='r').mode) - self.assertEqual('r+b', Pickle(mode='r+').mode) - self.assertEqual('wb', Pickle(mode='w').mode) + def test_pickle_no_b_rejected(self): + for mode in ('r', 'r+', 'w'): + with self.subTest(mode=mode): + with self.assert_raises_contains_str(ValueError, 'it requires a binary open mode'): + Pickle(mode=mode) # endregion @@ -336,7 +337,7 @@ def test_json_write_with(self): def test_serialized_json_write_with(self): with temp_path('a') as a: - with Serialized(json.dumps, mode='w')(a.as_posix()) as j: + with Serialized(json, mode='w')(a.as_posix()) as j: j.write({'a': 1}) self.assertEqual('{"a": 1}', a.read_text()) @@ -351,12 +352,13 @@ def test_plain_read_with(self): def test_serialized_pickle_read(self): with temp_path('a') as a: a.write_bytes(pickle.dumps({'a': 1})) - self.assertEqual({'a': 1}, Serialized(pickle.loads, mode='rb', lazy=False)(a.as_posix())) + self.assertEqual({'a': 1}, Serialized(pickle, mode='rb', lazy=False)(a.as_posix())) def test_serialized_pickle_read_with(self): with temp_path('a') as a: a.write_bytes(pickle.dumps({'a': 1})) - with Serialized(pickle.loads, mode='rb')(a.as_posix()) as f: + with Serialized(pickle, mode='rb')(a.as_posix()) as f: + self.assertIsInstance(f, SerializedFileWrapper) self.assertEqual({'a': 1}, f.read()) def test_pickle_read(self):