diff --git a/lib/cli_command_parser/context.py b/lib/cli_command_parser/context.py index f728eedc..794b5d94 100644 --- a/lib/cli_command_parser/context.py +++ b/lib/cli_command_parser/context.py @@ -24,8 +24,9 @@ if TYPE_CHECKING: from .command_parameters import CommandParameters from .commands import Command + from .core import CommandMeta from .parameters import ActionFlag, BaseOption, Parameter - from .typing import AnyConfig, Bool, CommandObj, CommandType, OptStr, ParamOrGroup, PathLike, StrSeq # noqa + from .typing import AnyConfig, Bool, CommandCls, CommandObj, OptStr, ParamOrGroup, PathLike, StrSeq Argv = StrSeq | None @@ -53,7 +54,7 @@ class Context(AbstractContextManager): # Extending AbstractContextManager to ma def __init__( self, argv: Argv = None, - command_cls: CommandType | None = None, + command_cls: CommandCls | None = None, *, parent: Context | None = None, config: AnyConfig | None = None, @@ -100,7 +101,7 @@ def _set_argv(self, prog: OptStr, argv: Argv): self.remaining = list(self.argv) def _sub_context( - self, command_cls: CommandType, argv: Argv = None, command: CommandObj | None = None, **kwargs + self, command_cls: CommandCls, argv: Argv = None, command: CommandObj | None = None, **kwargs ) -> Context: return self.__class__( self.remaining if argv is None else argv, @@ -322,7 +323,7 @@ def iter_action_flags(self, phase: ActionPhase) -> Iterator[ActionFlag]: def _normalize_config( - config: AnyConfig, kwargs: dict[str, Any], parent: Context | None, command: CommandType | None + config: AnyConfig, kwargs: dict[str, Any], parent: Context | None, command: CommandMeta | None ) -> CommandConfig: if config is not None: if kwargs: @@ -443,7 +444,7 @@ def get_current_context(silent: bool = False) -> Context | None: def get_or_create_context( - command_cls: CommandType, argv: Argv = None, *, command: CommandObj | None = None, **kwargs + command_cls: CommandCls, argv: Argv = None, *, command: CommandObj | None = None, **kwargs ) -> Context: """ Used internally by Commands to re-use an existing user-activated Context, or to create a new Context if there was diff --git a/lib/cli_command_parser/conversion/argparse_ast.py b/lib/cli_command_parser/conversion/argparse_ast.py index 8128dcbe..3c448ad6 100644 --- a/lib/cli_command_parser/conversion/argparse_ast.py +++ b/lib/cli_command_parser/conversion/argparse_ast.py @@ -3,17 +3,26 @@ import ast import logging import sys +from abc import ABC from argparse import ArgumentParser from ast import AST, Assign, Call, withitem from functools import cached_property, partial from inspect import BoundArguments, Signature from pathlib import Path -from typing import TYPE_CHECKING, Callable, Collection, Generic, Iterator, Type, TypeVar +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Iterator, Literal, Type, TypeVar, overload +try: + from typing import Self +except ImportError: # added in 3.11 + Self = TypeVar('Self') # type: ignore[misc,assignment] + +from ..utils import _NotSet, _NotSetType from .argparse_utils import ArgumentParser as _ArgumentParser, SubParsersAction as _SubParsersAction from .utils import get_name_repr, iter_module_parents, unparse if TYPE_CHECKING: + from collections.abc import Collection + from cli_command_parser.typing import PathLike from .visitor import TrackedRef, TrackedRefMap @@ -25,19 +34,19 @@ OptCall = Call | None ParserCls = Type['AstArgumentParser'] ParserObj = TypeVar('ParserObj', bound='AstArgumentParser') -RepresentedCallable = TypeVar('RepresentedCallable', bound=Callable) +RepresentedCallable = Callable AC = TypeVar('AC', bound='AstCallable') D = TypeVar('D') -_NotSet = object() +VisitFunc = Callable[[InitNode, OptCall, 'TrackedRefMap'], AC] class Script: - _parser_classes = {} + mod_cls_to_ast_cls_map: dict[str, dict[str, ParserCls]] = {} path: Path | None def __init__(self, src_text: str, smart_loop_handling: bool = True, path: PathLike | None = None): self.smart_loop_handling = smart_loop_handling - self._parsers = [] + self._parsers: list[AstArgumentParser] = [] self.path = Path(path) if path else None self.src_text = src_text parse_args = (self.src_text, self.path.as_posix()) if self.path else (self.src_text,) @@ -48,31 +57,36 @@ def __repr__(self) -> str: location = f' @ {self.path.as_posix()}' if self.path else '' return f'<{self.__class__.__name__}[{parsers=}{location}]>' - @property - def mod_cls_to_ast_cls_map(self) -> dict[str, dict[str, ParserCls]]: - return self._parser_classes - @classmethod def _register_parser(cls, module: str, name: str, ast_cls: ParserCls): # Identify package-level exports that may have been defined for a custom ArgumentParser subclass modules = [module, *(parent for parent in iter_module_parents(module) if name in vars(sys.modules[parent]))] for module in modules: log.debug(f'Registering {module}.{name} -> {ast_cls}') - cls._parser_classes.setdefault(module, {})[name] = ast_cls + cls.mod_cls_to_ast_cls_map.setdefault(module, {})[name] = ast_cls @classmethod - def register_parser(cls, ast_cls: ParserCls): + def register_parser(cls, ast_cls: ParserCls) -> ParserCls: + """ + Register an AstArgumentParser class for tracking references to an :class:`argparse.ArgumentParser` or subclass + thereof. May be used as a decorator. + + :param ast_cls: :class:`AstArgumentParser` or a subclass thereof + :return: The decorated class, unmodified + """ real_cls = ast_cls.represents - cls._register_parser(real_cls.__module__, real_cls.__name__, ast_cls) + cls._register_parser(real_cls.__module__, real_cls.__name__, ast_cls) # type: ignore[union-attr] return ast_cls - def add_parser(self, ast_cls: ParserCls, node: InitNode, call: OptCall, tracked_refs: TrackedRefMap) -> ParserObj: + def add_parser( + self, ast_cls: Type[ParserObj], node: InitNode, call: OptCall, tracked_refs: TrackedRefMap + ) -> ParserObj: parser = ast_cls(node, self, tracked_refs, call) self._parsers.append(parser) return parser @cached_property - def parsers(self) -> list[ParserObj]: + def parsers(self) -> list[AstArgumentParser]: from .visitor import ScriptVisitor, TrackedRef # noqa: F811 track_refs = (TrackedRef('argparse.REMAINDER'), TrackedRef('argparse.SUPPRESS')) @@ -88,8 +102,8 @@ def parsers(self) -> list[ParserObj]: # region Decorators & Descriptors -class visit_func: - """A method that can be called by an AST visitor.""" +class visit_func: # noqa + """Decorator for AstCallable methods that can be called by an AST visitor.""" __slots__ = ('func',) @@ -101,7 +115,9 @@ def __set_name__(self, owner: Type[AstCallable], name: str): setattr(owner, name, self.func) # There's no need to keep the descriptor - replace self with func def __get__(self, instance, owner): - # This will never actually be called, but it makes PyCharm happy + # This is only ever called in contrived circumstances - because __set_name__ replaces this object with the + # decorated method, this __get__ method is never called when accessing the decorated method. + # Without this __get__ method, though, PyCharm doesn't understand that the decorated methods are still callable. return self if instance is None else partial(self.func, instance) @@ -111,35 +127,56 @@ class AddVisitedChild(Generic[AC]): __slots__ = ('child_cls', 'list_attr') def __init__(self, child_cls: Type[AC], attr: str): + """ + :param child_cls: The child class that should be used when adding a child entry to the parent AstCallable + instance in which this descriptor is an attribute. + :param attr: The name of the parent instance's attribute used to store a list of children. + """ self.child_cls = child_cls self.list_attr = attr def __set_name__(self, owner: Type[ArgCollection], name: str): owner._add_visit_func(name) - def __get__(self, instance: ArgCollection, owner) -> Callable[[InitNode, Call, TrackedRefMap], AC]: + @overload + def __get__(self, instance: Literal[None], owner: Any) -> Self: ... + + @overload + def __get__(self, instance: ArgCollection, owner: Any) -> VisitFunc: ... + + def __get__(self, instance: ArgCollection | None, owner: Any) -> Self | VisitFunc: if instance is None: - return self # noqa - return partial(instance._add_child, self.child_cls, getattr(instance, self.list_attr)) # noqa + return self + return partial(instance._add_child, self.child_cls, getattr(instance, self.list_attr)) # endregion -class AstCallable: - represents: RepresentedCallable - visit_funcs = set() +class AstCallable(ABC): + """ + Base class for classes that act as stand-ins for real classes, for tracking instances of those classes and methods + that were called on those instances while visiting AST nodes. + + Methods that should be tracked / should be called while visiting AST nodes must be registered with the + :class:`visit_func` decorator. + """ + + represents: ClassVar[RepresentedCallable] + visit_funcs: set[str] = set() _sig: Signature | None = None @classmethod def _add_visit_func(cls, name: str) -> bool: + """Register that this class has a method with the provided *name* that is a visitable function.""" try: - parent_visit_funcs = cls.__base__.visit_funcs # noqa + parent_visit_funcs = cls.__base__.visit_funcs # type: ignore[union-attr] except AttributeError: pass else: # Note: __init_subclass__ is called after __set_name__ is called for members if parent_visit_funcs is cls.visit_funcs: cls.visit_funcs = cls.visit_funcs.copy() + cls.visit_funcs.add(name) return True @@ -148,23 +185,24 @@ def __init_subclass__(cls, represents: RepresentedCallable | None = None, **kwar if represents: cls.represents = represents cls._sig = None + elif ABC not in cls.__bases__: + raise NotImplementedError(f'Missing required "represents" class param for {cls.__name__}') - def __init__( - self, node: InitNode, parent: AstCallable | Script, tracked_refs: TrackedRefMap, call: Call | None = None - ): + def __init__(self, node: InitNode, parent: AstCallable | Script, tracked_refs: TrackedRefMap, call: OptCall = None): self.init_node = node - if not call: - call = node.value if isinstance(node, Assign) else node # type: Call + self._init_call(call if call else node.value if isinstance(node, Assign) else node) # type: ignore[arg-type] + self._tracked_refs = tracked_refs + self.parent = parent + + def _init_call(self, call: Call): self.call_node = call self.call_args = call.args self.call_kwargs = call.keywords - self._tracked_refs = tracked_refs - self.parent = parent def __repr__(self) -> str: return f'<{self.__class__.__name__}[{self.init_call_repr()}]>' - def get_tracked_refs(self, module: str, name: str, default: D = _NotSet) -> set[str] | D: + def get_tracked_refs(self, module: str, name: str, default: D | _NotSetType = _NotSet) -> set[str] | D: for tracked_ref, refs in self._tracked_refs.items(): if tracked_ref.module == module and tracked_ref.name == name: return refs @@ -192,14 +230,14 @@ def init_func_name(self) -> str: @cached_property def _init_func_bound(self) -> BoundArguments: args = self.call_args if isinstance(self.represents, type) else ('self', *self.call_args) - return self.signature.bind(*args, **{kw.arg: kw.value for kw in self.call_kwargs}) + return self.signature.bind(*args, **{kw.arg: kw.value for kw in self.call_kwargs if kw.arg is not None}) @cached_property def init_func_args(self) -> list[str]: try: args = self._init_func_bound.args[1:] except (TypeError, AttributeError): # No represents func - args = self.call_args + args = self.call_args # type: ignore[assignment] return [unparse(arg) for arg in args] @cached_property @@ -207,7 +245,7 @@ def init_func_raw_kwargs(self) -> dict[str, AST]: try: kwargs = self._init_func_bound.arguments except (TypeError, AttributeError): # No represents func - return {kw.arg: kw.value for kw in self.call_kwargs} + return {kw.arg: kw.value for kw in self.call_kwargs if kw.arg is not None} else: kwargs = kwargs.copy() kwargs.pop('self', None) @@ -243,21 +281,19 @@ class ParserArg(AstCallable, represents=ArgumentParser.add_argument): parent: ArgCollection -class ArgCollection(AstCallable): +class ArgCollection(AstCallable, ABC): parent: ArgCollection | Script - _children = ('args', 'groups') + _children: tuple[str, ...] = ('args', 'groups') args: list[ParserArg] - groups: list[ArgGroup] - add_argument = AddVisitedChild(ParserArg, 'args') + groups: list[ArgGroup | MutuallyExclusiveGroup] + add_argument: AddVisitedChild[ParserArg] = AddVisitedChild(ParserArg, 'args') def __init_subclass__(cls, children: Collection[str] = (), **kwargs): super().__init_subclass__(**kwargs) if children: cls._children = (*cls._children, *children) - def __init__( - self, node: InitNode, parent: AstCallable | Script, tracked_refs: TrackedRefMap, call: Call | None = None - ): + def __init__(self, node: InitNode, parent: AstCallable | Script, tracked_refs: TrackedRefMap, call: OptCall = None): super().__init__(node, parent, tracked_refs, call) self.args = [] self.groups = [] @@ -265,22 +301,26 @@ def __init__( def __repr__(self) -> str: return f'<{self.__class__.__name__}: ``{self.init_call_repr()}``>' - def _add_child(self, cls: Type[AC], container: list[AC], node: InitNode, call: Call, refs: TrackedRefMap) -> AC: + def _add_child(self, cls: Type[AC], container: list[AC], node: InitNode, call: OptCall, refs: TrackedRefMap) -> AC: child = cls(node, self, refs, call) container.append(child) return child @visit_func - def add_mutually_exclusive_group(self, node: InitNode, call: Call, tracked_refs: TrackedRefMap): - return self._add_child(MutuallyExclusiveGroup, self.groups, node, call, tracked_refs) + def add_mutually_exclusive_group( + self, node: InitNode, call: OptCall, tracked_refs: TrackedRefMap + ) -> MutuallyExclusiveGroup: + return self._add_child( # type: ignore[return-value] + MutuallyExclusiveGroup, self.groups, node, call, tracked_refs + ) @visit_func - def add_argument_group(self, node: InitNode, call: Call, tracked_refs: TrackedRefMap): + def add_argument_group(self, node: InitNode, call: OptCall, tracked_refs: TrackedRefMap) -> ArgGroup: return self._add_child(ArgGroup, self.groups, node, call, tracked_refs) def grouped_children(self) -> Iterator[tuple[Type[AC], list[AC]]]: - yield ParserArg, self.args - yield ArgGroup, self.groups + yield ParserArg, self.args # type: ignore[misc] + yield ArgGroup, self.groups # type: ignore[misc] # region Output Methods @@ -296,18 +336,24 @@ def pprint(self, indent: int = 0): class ArgGroup(ArgCollection, represents=_ArgumentParser.add_argument_group): - pass + """A group containing zero or more arguments or other argument groups""" class MutuallyExclusiveGroup(ArgGroup, represents=_ArgumentParser.add_mutually_exclusive_group): - pass + """A mutually exclusive argument group""" class SubparsersAction(AstCallable, represents=_ArgumentParser.add_subparsers): - parent: ParserObj + """ + Represents a subparsers action obtained by calling ``parser.add_subparsers()``. Contrary to the way the represented + class behaves, when :meth:`.add_parser` is called, the subparser is stored directly on the parent parser rather + than within this instance. + """ + + parent: AstArgumentParser @visit_func - def add_parser(self, node: InitNode, call: Call, tracked_refs: TrackedRefMap): + def add_parser(self, node: InitNode, call: OptCall, tracked_refs: TrackedRefMap): sub_parser = self.parent._add_subparser(node, call, tracked_refs) sub_parser.sp_parent = self return sub_parser @@ -315,14 +361,13 @@ def add_parser(self, node: InitNode, call: Call, tracked_refs: TrackedRefMap): @Script.register_parser class AstArgumentParser(ArgCollection, represents=ArgumentParser, children=('sub_parsers',)): + parent: Script | AstArgumentParser sub_parsers: list[SubParser] - add_subparsers = AddVisitedChild(SubparsersAction, '_subparsers_actions') + add_subparsers: AddVisitedChild[SubparsersAction] = AddVisitedChild(SubparsersAction, '_subparsers_actions') - def __init__( - self, node: InitNode, parent: AstCallable | Script, tracked_refs: TrackedRefMap, call: Call | None = None - ): + def __init__(self, node: InitNode, parent: Script | ParserObj, tracked_refs: TrackedRefMap, call: OptCall = None): super().__init__(node, parent, tracked_refs, call) - self._subparsers_actions = [] + self._subparsers_actions: list[SubparsersAction] = [] # Note: sub_parsers aren't included in grouped_children since they need different handling during conversion self.sub_parsers = [] @@ -330,14 +375,28 @@ def __repr__(self) -> str: sub_parsers = len(self.sub_parsers) return f'<{self.__class__.__name__}[{sub_parsers=}]: ``{self.init_call_repr()}``>' + @overload + def _add_subparser( + self, node: InitNode, call: OptCall, tracked_refs: TrackedRefMap, sub_parser_cls: Literal[None] = None + ) -> SubParser: ... + + @overload + def _add_subparser( + self, node: InitNode, call: OptCall, tracked_refs: TrackedRefMap, sub_parser_cls: Type[ParserObj] + ) -> ParserObj: ... + def _add_subparser( - self, node: InitNode, call: Call, tracked_refs: TrackedRefMap, sub_parser_cls: ParserCls | None = None - ): + self, node: InitNode, call: OptCall, tracked_refs: TrackedRefMap, sub_parser_cls: Type[ParserObj] | None = None + ) -> SubParser | ParserObj: + """Add a subparser to this parser. Only meant to be called by :class:`SubparsersAction`.""" # Using default of None since the class hasn't been defined at the time it would need to be set as default - return self._add_child(sub_parser_cls or SubParser, self.sub_parsers, node, call, tracked_refs) + return self._add_child( # type: ignore[misc] + sub_parser_cls or SubParser, self.sub_parsers, node, call, tracked_refs + ) class SubParser(AstArgumentParser, represents=_SubParsersAction.add_parser): + parent: AstArgumentParser | SubParser sp_parent: SubparsersAction @cached_property diff --git a/lib/cli_command_parser/conversion/cli.py b/lib/cli_command_parser/conversion/cli.py index f226439c..0524e7b2 100644 --- a/lib/cli_command_parser/conversion/cli.py +++ b/lib/cli_command_parser/conversion/cli.py @@ -10,10 +10,11 @@ log = logging.getLogger(__name__) arg_parser = 'argparse.ArgumentParser' -cli_cp_cmd = 'cli-command-parser Command' -class ParserConverter(Command, description=f'Tool to convert an {arg_parser} into a {cli_cp_cmd}'): +class ParserConverter(Command): + """Tool to convert an argparse.ArgumentParser into a cli-command-parser Command""" + action = SubCommand() input: Path no_smart_for = Flag('-S', help='Disable "smart" for loop handling that attempts to dedupe common subparser params') @@ -35,6 +36,8 @@ def script(self): class Convert(ParserConverter): + """Print the cli-command-parser Commands that are equivalent to the discovered argparse ArgumentParsers""" + input: Path = Positional(type=IPath(type='file', exists=True), help=f'A file containing an {arg_parser}') add_methods = Flag('--no-methods', '-M', default=True, help='Do not include boilerplate methods in Commands') @@ -45,6 +48,8 @@ def main(self): class Pprint(ParserConverter): + """Print a tiered internal representation of the discovered argparse ArgumentParsers and their groups/arguments""" + input: Path = Positional(type=IPath(type='file', exists=True), help=f'A file containing an {arg_parser}') def main(self): diff --git a/lib/cli_command_parser/conversion/utils.py b/lib/cli_command_parser/conversion/utils.py index 277b5665..057fcf94 100644 --- a/lib/cli_command_parser/conversion/utils.py +++ b/lib/cli_command_parser/conversion/utils.py @@ -8,19 +8,28 @@ def get_name_repr(node: AST | expr) -> str: if isinstance(node, Call): + # Call nodes include the arguments passed to the func/callable being called - we want the name of the callable node = node.func - if isinstance(node, Name): - return node.id - elif isinstance(node, Attribute): - return f'{get_name_repr(node.value)}.{node.attr}' # noqa - elif isinstance(node, AST): - return unparse(node) - else: - raise TypeError(f'Only AST nodes are supported - found {node.__class__.__name__}') + match node: + case Name(): + # Name = a variable + return node.id # the name of the variable + case Attribute(): + # `foo.bar.baz` -> Attribute(value=Attribute(value=Name(id='foo'), attr='bar'), attr='baz') + return f'{get_name_repr(node.value)}.{node.attr}' + case AST(): + return unparse(node) # returns the original source code for the provided AST object + case _: + raise TypeError(f'Only AST nodes are supported - found {type(node).__name__}') def iter_module_parents(module: str) -> Iterator[str]: + """ + Given a nested module name, yields parent package names in ascending order. + + I.e., given ``foo.bar.baz``, this function will yield ``foo.bar`` and then ``foo``. + """ while True: parent = module.rsplit('.', 1)[0] if parent == module: @@ -30,9 +39,20 @@ def iter_module_parents(module: str) -> Iterator[str]: def collection_contents(node: AST) -> list[str]: - if isinstance(node, Dict): - return [unparse(key) for key in node.keys] # type: ignore[arg-type] - elif isinstance(node, (List, Set, Tuple)): - return [unparse(ele) for ele in node.elts] # noqa - else: - raise TypeError(f'Unexpected AST node type={node.__class__.__name__}') + """ + Returns a list of individually unparsed (original source code strings) elements that would be processed when + iterating over the specified node. + + Silently ignores any dictionaries that were expanded within a dict literal. + + :param node: An AST node representing a dict/list/set/tuple literal. + :return: List of elements as strings of source code. + """ + match node: + case Dict(): + # Dict expansion like `{'a': 1, **some_mapping}` results in key=None for each expanded mapping + return [unparse(key) for key in node.keys if key is not None] + case List() | Set() | Tuple(): + return [unparse(ele) for ele in node.elts] + case _: + raise TypeError(f'Unexpected AST node type={node.__class__.__name__}') diff --git a/lib/cli_command_parser/conversion/visitor.py b/lib/cli_command_parser/conversion/visitor.py index a831c104..bfc15fde 100644 --- a/lib/cli_command_parser/conversion/visitor.py +++ b/lib/cli_command_parser/conversion/visitor.py @@ -2,21 +2,35 @@ import logging import re -from ast import AST, Assign, Attribute, Call, For, Import, ImportFrom, Name, NodeVisitor, expr +from ast import AST, Assign, Attribute, Call, For, Import, ImportFrom, Name, NodeVisitor, withitem from collections import ChainMap, defaultdict +from enum import Enum from functools import partial, wraps -from typing import TYPE_CHECKING, Callable, Collection, Iterator +from typing import TYPE_CHECKING, Callable, Collection, Iterator, Literal, Union, overload -from .argparse_ast import AstArgumentParser +from .argparse_ast import AstArgumentParser, AstCallable, VisitFunc from .utils import get_name_repr if TYPE_CHECKING: TrackedRefMap = dict['TrackedRef', set[str]] + NameTrackedMap = dict[str, Union[Callable, 'TrackedRef']] + TrackedValue = Union['TrackedRef', VisitFunc, AstCallable] + RefName = str | AST __all__ = ['ScriptVisitor', 'TrackedRef'] log = logging.getLogger(__name__) -_NoCall = object() + +class _NoCallType(Enum): + """Provides the sentinel value for _NoCall in a way that is fully compatible with type checkers.""" + + _NoCall = '_NoCall' + + def __str__(self) -> str: + return self.name + + +_NoCall = _NoCallType._NoCall def scoped(func): @@ -39,6 +53,8 @@ def __get__(self, instance: ScriptVisitor, owner): class ScriptVisitor(NodeVisitor): + scopes: ChainMap[str, TrackedValue] + visit_FunctionDef = visit_AsyncFunctionDef = ScopedVisit() visit_Lambda = ScopedVisit() visit_ClassDef = ScopedVisit() @@ -47,8 +63,8 @@ class ScriptVisitor(NodeVisitor): def __init__(self, smart_loop_handling: bool = True, track_refs: Collection[TrackedRef] = ()): self.smart_loop_handling = smart_loop_handling self.scopes = ChainMap() # ChainMap that tracks the var/class/func/etc names available in a given scope - self._tracked_refs = set() - self._mod_name_tracked_map = defaultdict(dict) + self._tracked_refs: set[TrackedRef] = set() # References that are tracked, but not meant to be called + self._mod_name_tracked_map: dict[str, NameTrackedMap] = defaultdict(dict) # All tracked items by source module for ref in track_refs: self.track_refs_to(ref) @@ -56,6 +72,7 @@ def track_callable(self, module: str, name: str, cb: Callable): self._mod_name_tracked_map[module][name] = cb def track_refs_to(self, ref: TrackedRef): + """Register a reference that should be tracked.""" self._tracked_refs.add(ref) self._mod_name_tracked_map[ref.module][ref.name] = ref @@ -64,14 +81,21 @@ def get_tracked_refs(self) -> TrackedRefMap: for key, val in self.scopes.items(): if val in self._tracked_refs: tracked_refs[val].add(key) - tracked_refs.default_factory = None - return tracked_refs + + tracked_refs.default_factory = None # disable creation of new sets on future key misses + return tracked_refs # type: ignore[return-value] # region Imports def visit_Import(self, node: Import): + """ + Processes a ``import some_module`` or ``import some_module as other_name`` statement. + + :param node: The AST object representing the import statement. + """ for module_name, as_name in imp_names(node): if name_tracked_map := self._mod_name_tracked_map.get(module_name): + # One or more items in the specified module were registered to be tracked log.debug(f'Found module import: {module_name} as {as_name}') for name, tracked in name_tracked_map.items(): self.scopes[f'{as_name}.{name}'] = tracked @@ -87,18 +111,31 @@ def visit_ImportFrom(self, node: ImportFrom): for the source file or to resolve what the relative import's fully qualified module name would be. This may result in incorrect items being tracked if the name matched a tracked name in the matched tracked module. """ + # `from foo.bar import x,y,z` -> level = 0 - absolute import, so node.module = 'foo.bar' + # `from . import x,y,z` -> level = 1 - relative import with no module, so node.module = None + # `from ..foo.bar import x,y,z` -> level 2 - relative import with a module, so node.module = 'foo.bar' + if not node.module: + # Only relative imports (of any level) with no specific module name result in node.module = None + return # Tracking relative imports with no explicit module name is not supported + if level := node.level: - # For absolute imports, the level is 0 - # For relative imports, the level is a positive int, representing the number of relative package levels + # It's a relative import - attempt to fuzzily match tracked modules with the same depth that end with the + # specified module name matches = re.compile(r'[^.]+\.' * level + re.escape(node.module) + '$').search for module, name_tracked_map in self._mod_name_tracked_map.items(): if matches(module): log.debug(f'Found fuzzy relative name match for {"." * level + node.module!r} to {module=}') self._maybe_track_import_from(node, name_tracked_map) - elif name_tracked_map := self._mod_name_tracked_map.get(node.module): + elif name_tracked_map := self._mod_name_tracked_map.get(node.module): # type: ignore[assignment] + # It's an absolute import and the module name matches an item that is being tracked self._maybe_track_import_from(node, name_tracked_map) - def _maybe_track_import_from(self, node: ImportFrom, name_tracked_map): + def _maybe_track_import_from(self, node: ImportFrom, name_tracked_map: NameTrackedMap): + """ + If any of the items imported from the specified module match an item that is being tracked, add the local name + that is being used for it (the ``as other_name`` value, if specified, otherwise the original name) to the + current scope. + """ for name, as_name in imp_names(node): if tracked := name_tracked_map.get(name): log.debug(f'Found tracked import: {node.module}.{name} as {as_name}') @@ -110,11 +147,17 @@ def _maybe_track_import_from(self, node: ImportFrom, name_tracked_map): @scoped def visit_For(self, node: For): + """Visit a *for* loop.""" + # Given the loop, `for x in y:`, `For.target` -> `x` (the loop variable), and `For.iter` -> `y` (the iterable) if isinstance(node.target, Name): + # `For.target` -> the loop variable(s); if the target is a Name, then it is a single variable. + # For loops with multiple loop variables result in the target being a tuple of Names. try: - ele_names = [get_name_repr(ele) for ele in node.iter.elts] # noqa + # When the iterable is an in-line collection literal such as a tuple, `node.iter` will be the AST + # object representing that collection literal. The `elts` attr of Tuple/List/Set contains its elements. + ele_names = [get_name_repr(ele) for ele in node.iter.elts] # type: ignore[attr-defined] except (AttributeError, TypeError): - ele_names = () + ele_names = [] if ele_names and self.smart_loop_handling: self._visit_for_smart(node, node.target.id, ele_names) @@ -126,17 +169,36 @@ def visit_For(self, node: For): visit_AsyncFor = visit_For def _visit_for_smart(self, node: For, loop_var: str, ele_names: list[str]): + """ + Processes *for* loops that iterate over tracked parser objects for the purpose of registering common + arguments, etc. If not all of the identified elements are parsers, then this method falls back to the + generic :meth:`._visit_for_elements` handler instead. + + :param node: An AST node representing a *for* loop over an in-line tuple/list/set literal. + :param loop_var: Given ``for x in y:``, ``x``. + :param ele_names: The names of the items in the tuple/list/set literal. + """ log.debug(f'Attempting smart for loop visit for {loop_var=} in {ele_names=}') - refs = [ref for ref in (self.scopes.get(name) for name in ele_names) if ref] + refs: list[AstArgumentParser] = [ + ref # type: ignore[misc] # mypy doesn't seem to recognize the isinstance part of the condition + for name in ele_names + if (ref := self.scopes.get(name)) and isinstance(ref, AstArgumentParser) + ] # log.debug(f' > Found {len(refs)=}, {len(ele_names)=}') - if len(refs) == len(ele_names) and all(isinstance(ref, AstArgumentParser) for ref in refs): + if len(refs) == len(ele_names): # ele_names is confirmed to be non-empty before this method is called + # All elements are AstArgumentParser or SubParser (or subclasses thereof) objects parents = set(ref.parent for ref in refs) log.debug(f' > Found parents={len(parents)}') if len(parents) == 1: + # They all have the same parent parser or script parent = next(iter(parents)) if parent and set(getattr(parent, 'sub_parsers', ())) == set(refs): - self.scopes[loop_var] = parent + # They are all subparsers with the same parent parser, and the parent parser does not have any + # other subparsers that are not in scope for this loop. + # Pretend the parent is the target - ignore the subparsers when evaluating the loop, and add the + # common items to the parent parser. + self.scopes[loop_var] = parent # type: ignore[assignment] self.generic_visit(node) return @@ -156,55 +218,101 @@ def _visit_for_elements(self, node: For, loop_var: str, ele_names: list[str]): # endregion - def resolve_ref(self, name: str | AST | Attribute | Name | expr): - if isinstance(name, Attribute) and isinstance(name.value, Call): - obj = self.visit_Call(name.value) - attr = name.attr - else: - if not isinstance(name, str): - name = get_name_repr(name) - try: - return self.scopes[name] - except KeyError: - pass - try: - obj_name, attr = name.rsplit('.', 1) - obj = self.scopes[obj_name] - except (ValueError, KeyError): + # region Resolve Tracked References + + @overload + def resolve_ref(self, name: RefName, only_visitable: Literal[False] = False) -> VisitFunc | AstCallable | None: ... + + @overload + def resolve_ref(self, name: RefName, only_visitable: Literal[True]) -> VisitFunc | None: ... + + def resolve_ref(self, name: RefName, only_visitable: bool = False) -> VisitFunc | AstCallable | None: + """ + Resolve the given reference to a tracked item in the current scope. + + :param name: The name of a reference or an AST node that may contain the name of a reference + :param only_visitable: If True, then only return visit functions or None, otherwise include AstCallable objects + :return: The resolved reference + """ + obj, attr = self._resolve(name) + match obj: + case AstCallable(): + if attr: + return getattr(obj, attr) if attr in obj.visit_funcs else None + return None if only_visitable else obj + case None | TrackedRef(): return None + case _: + return obj if attr is None else None + + def _resolve(self, name: RefName) -> tuple[TrackedValue | None, str | None]: + """Resolves the given reference, but does not handle final attr lookup or type checking.""" + obj: TrackedRef | VisitFunc | AstCallable | None | _NoCallType + if isinstance(name, Attribute) and isinstance(name.value, Call): + if (obj := self.visit_Call(name.value)) is _NoCall: + return None, None + return obj, name.attr + + if not isinstance(name, str): + name = get_name_repr(name) + + if obj := self.scopes.get(name): + return obj, None try: - can_call = attr in obj.visit_funcs - except (AttributeError, TypeError): - return None - return getattr(obj, attr) if can_call else None + obj_name, attr = name.rsplit('.', 1) + obj = self.scopes[obj_name] + except (ValueError, KeyError): + return None, None + + return obj, attr - def visit_withitem(self, item): + # endregion + + def visit_withitem(self, item: withitem): + """ + Visit a single ``withitem`` / context expression within a ``with ...:`` statement that may include one or more + ``withitem``s / content expressions. + """ context_expr = item.context_expr - if func := self.resolve_ref(context_expr): + if func := self.resolve_ref(context_expr, True): + # Found a ``with foo(...):`` statement where *foo* is being tracked call = context_expr if isinstance(context_expr, Call) else None result = func(item, call, self.get_tracked_refs()) if as_name := item.optional_vars: self.scopes[get_name_repr(as_name)] = result def visit_Assign(self, node: Assign): - value = node.value - if isinstance(value, (Attribute, Name)): # Assigning an alias to a variable - if ref := self.resolve_ref(value): - for target in node.targets: - self.scopes[get_name_repr(target)] = ref - elif isinstance(value, Call): - if (result := self.visit_Call(value)) is not _NoCall: - for target in node.targets: - self.scopes[get_name_repr(target)] = result # noq - - def visit_Call(self, node: Call): - if func := self.resolve_ref(node.func): + """ + Visit an assignment statement where one or more variables (stored in ``Assign.targets``) are being assigned one + or more values (stored in ``Assign.value``). + """ + match node.value: + case Attribute() | Name(): + # Assigning an alias to a variable; e.g., `foo = bar` or `foo = bar.baz` + if ref := self.resolve_ref(node.value): + # The value was singular and referenced something being tracked + for target in node.targets: + self.scopes[get_name_repr(target)] = ref + # Not handled here: cases like `a = (1, 2); x, y = a` or `x, y = a, b` + case Call(): + # Storing the result of a function/similar call; e.g., `foo = bar()` or `foo = bar.baz()` + if (result := self.visit_Call(node.value)) is not _NoCall: + for target in node.targets: + self.scopes[get_name_repr(target)] = result + + def visit_Call(self, node: Call) -> AstCallable | _NoCallType: + if func := self.resolve_ref(node.func, True): return func(node, node, self.get_tracked_refs()) return _NoCall class TrackedRef: + """ + Represents any class/function/object/variable/etc. that may be imported from a specific module with a specific + name in that module. Used to track references to that item. + """ + __slots__ = ('module', 'name') def __init__(self, name: str): @@ -216,8 +324,8 @@ def __repr__(self) -> str: def __hash__(self) -> int: return hash(self.__class__) ^ hash(self.module) ^ hash(self.name) - def __eq__(self, other: TrackedRef) -> bool: - return self.name == other.name and self.module == other.module + def __eq__(self, other) -> bool: + return self.__class__ == other.__class__ and self.name == other.name and self.module == other.module def imp_names(imp: Import | ImportFrom) -> Iterator[tuple[str, str]]: diff --git a/lib/cli_command_parser/exceptions.py b/lib/cli_command_parser/exceptions.py index b5f27c6a..711ccc14 100644 --- a/lib/cli_command_parser/exceptions.py +++ b/lib/cli_command_parser/exceptions.py @@ -3,7 +3,6 @@ :author: Doug Skrypa """ -# pylint: disable=W0231 from __future__ import annotations diff --git a/lib/cli_command_parser/formatting/commands.py b/lib/cli_command_parser/formatting/commands.py index 57230031..1f6fc4bf 100644 --- a/lib/cli_command_parser/formatting/commands.py +++ b/lib/cli_command_parser/formatting/commands.py @@ -24,7 +24,7 @@ from ..core import CommandMeta from ..metadata import ProgramMetadata from ..parameters import BaseOption, BasePositional, Parameter, PassThru, SubCommand - from ..typing import Bool, CommandAny, CommandCls, CommandType, OptStr + from ..typing import Bool, CommandAny, CommandCls, OptStr __all__ = ['CommandHelpFormatter', 'get_formatter'] @@ -32,7 +32,7 @@ class CommandHelpFormatter: - def __init__(self, command: CommandType, params: CommandParameters): + def __init__(self, command: CommandMeta, params: CommandParameters): self.command = command self.params = params self.pos_group = ParamGroup(description='Positional arguments') @@ -122,13 +122,21 @@ def format_help(self, allow_sys_argv: Bool = True) -> str: # region RST Formatting def format_rst( - self, fix_name: Bool = True, fix_name_func: NameFunc = None, init_level: int = 1, allow_sys_argv: Bool = False + self, + fix_name: Bool = True, + fix_name_func: NameFunc | None = None, + init_level: int = 1, + allow_sys_argv: Bool = False, ) -> str: """Generate the RST content for the Command associated with this formatter and all of its subcommands""" return '\n'.join(self._format_rst(fix_name, fix_name_func, init_level, allow_sys_argv)) def _format_rst( - self, fix_name: Bool = True, fix_name_func: NameFunc = None, init_level: int = 1, allow_sys_argv: Bool = False + self, + fix_name: Bool = True, + fix_name_func: NameFunc | None = None, + init_level: int = 1, + allow_sys_argv: Bool = False, ) -> Iterator[str]: name = self._meta.doc_name if fix_name: diff --git a/lib/cli_command_parser/formatting/restructured_text.py b/lib/cli_command_parser/formatting/restructured_text.py index d0ca22bd..2d2ae6d6 100644 --- a/lib/cli_command_parser/formatting/restructured_text.py +++ b/lib/cli_command_parser/formatting/restructured_text.py @@ -6,15 +6,15 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Iterable, Iterator, Mapping, Sequence, TypeVar +from typing import TYPE_CHECKING, Any, Iterable, Iterator, Mapping, Sequence if TYPE_CHECKING: from ..typing import Bool, OptStr, Strings + RowMaps = Sequence[Mapping[OptStr, OptStr]] + __all__ = ['rst_bar', 'rst_list_table', 'RstTable'] -T = TypeVar('T') -RowMaps = Sequence[Mapping[T, 'OptStr']] # region Constants & Templates @@ -40,6 +40,9 @@ # endregion +# region RST Formatting Helpers + + def rst_bar(text: str | int, level: int = 1) -> str: bar_len = text if isinstance(text, int) else len(text) c = BAR_CHAR_ORDER[level] @@ -105,6 +108,9 @@ def rst_list_table(data: dict[str, str], value_pad: int = 20) -> str: return LIST_TABLE_TMPL.format(widths=widths, entries=entries) +# endregion + + class RstTable: """ :param title: The title for this table. Only displayed if ``show_title`` is True. @@ -131,7 +137,7 @@ def __init__( self.subtitle = subtitle self.show_title = show_title self.use_table_directive = use_table_directive - self._rows = [] + self._rows: list[Row] = [] self._widths = () self._updated = False if headers: @@ -139,7 +145,7 @@ def __init__( @classmethod def from_dicts( - cls, rows: RowMaps, columns: Sequence[T] | None = None, auto_headers: Bool = False, **kwargs + cls, rows: RowMaps, columns: Sequence[OptStr] | None = None, auto_headers: Bool = False, **kwargs ) -> RstTable: """ Initialize a RstTable using the given keyword arguments, and populate its rows using the given dicts and @@ -170,7 +176,7 @@ def widths(self) -> tuple[int, ...]: self._updated = False return self._widths - def add_dict_rows(self, rows: RowMaps, columns: Sequence[T] | None = None, add_header: Bool = False): + def add_dict_rows(self, rows: RowMaps, columns: Sequence[OptStr] | None = None, add_header: Bool = False): """Add a row for each dict in the given sequence of rows, where the keys represent the columns.""" if not columns: columns = list(rows[0]) diff --git a/lib/cli_command_parser/inputs/__init__.py b/lib/cli_command_parser/inputs/__init__.py index 494dd643..76a134f4 100644 --- a/lib/cli_command_parser/inputs/__init__.py +++ b/lib/cli_command_parser/inputs/__init__.py @@ -51,15 +51,15 @@ def normalize_input_type(type_func: InputTypeFunc, param_choices: ChoicesType) - match type_func: case None: - return Choices(param_choices) if choices_provided else type_func + return Choices(param_choices) if choices_provided else type_func # type: ignore[arg-type] case range(): return Range(type_func) case _Pattern(): return Regex(type_func) case _EnumMeta(): - enum_choices = EnumChoices(type_func) + enum_choices: EnumChoices = EnumChoices(type_func) if choices_provided: - return Choices(param_choices, enum_choices) + return Choices(param_choices, enum_choices) # type: ignore[arg-type] return enum_choices - return Choices(param_choices, type_func) if choices_provided else type_func + return Choices(param_choices, type_func) if choices_provided else type_func # type: ignore[arg-type] diff --git a/lib/cli_command_parser/inputs/base.py b/lib/cli_command_parser/inputs/base.py index 08e79a0e..1348d6f6 100644 --- a/lib/cli_command_parser/inputs/base.py +++ b/lib/cli_command_parser/inputs/base.py @@ -37,7 +37,7 @@ def is_valid_type(self, value: str) -> bool: # pylint: disable=W0613 """ return True - def fix_default(self, value: Any) -> T | None: + def fix_default(self, value: Any) -> T | str | None: return value def format_metavar(self, choice_delim: str = ',', sort_choices: bool = False) -> str: diff --git a/lib/cli_command_parser/inputs/choices.py b/lib/cli_command_parser/inputs/choices.py index 89f231be..231fc2a1 100644 --- a/lib/cli_command_parser/inputs/choices.py +++ b/lib/cli_command_parser/inputs/choices.py @@ -6,11 +6,11 @@ from __future__ import annotations +import sys from abc import ABC, abstractmethod from enum import Enum -from typing import TYPE_CHECKING, Any, Collection, Iterator, Mapping, Type, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Collection, Iterator, Mapping, Type, TypeVar -from ..typing import T, TypeFunc from .base import InputType from .exceptions import InvalidChoiceError @@ -19,13 +19,18 @@ __all__ = ['Choices', 'ChoiceMap', 'EnumChoices'] +if sys.version_info >= (3, 13): + T = TypeVar('T', default=str) +else: + T = TypeVar('T') + EnumT = TypeVar('EnumT', bound=Enum) class _ChoicesBase(InputType[T], ABC): __slots__ = ('choices', 'type', 'case_sensitive') - choices: Collection[T] - type: TypeFunc | None + choices: Collection + type: Callable[[str], T] | None case_sensitive: bool def __contains__(self, value: str) -> bool: @@ -54,21 +59,22 @@ def _choices_repr(self, delim: str = ',') -> str: def _normalize(self, value: str) -> T: if self.type is not None: try: - return self.type(value) # pylint: disable=E1102 + return self.type(value) except (ValueError, TypeError) as e: raise InvalidChoiceError(value, self.choices) from e - return value + return value # type: ignore[return-value] - def _iter_normalized(self, value: Any, choices: Collection | None = None) -> Iterator[T]: + def _iter_normalized(self, value: str | T, choices: Collection | None = None) -> Iterator[str | T]: yield value if not self.case_sensitive and (choices is None or isinstance(choices, (set, Mapping))): - yield value.lower() - yield value.upper() + # Choices validates that all members of `choices` are strings when not case_sensitive + yield value.lower() # type: ignore[misc,union-attr] + yield value.upper() # type: ignore[misc,union-attr] def _case_insensitive_map_choice(self, value: Any) -> T: if not self.case_sensitive: norm_value = value.casefold() - for choice, val in self.choices.items(): # noqa + for choice, val in self.choices.items(): # type: ignore[attr-defined] if norm_value == choice.casefold(): return val @@ -77,7 +83,7 @@ def _case_insensitive_map_choice(self, value: Any) -> T: def format_metavar(self, choice_delim: str = ',', sort_choices: bool = False) -> str: choices = map(str, self.choices) if sort_choices: - choices = sorted(choices) + choices = sorted(choices) # type: ignore[assignment] return f'{{{choice_delim.join(choices)}}}' @@ -93,11 +99,17 @@ class Choices(_ChoicesBase[T]): """ __slots__ = () + choices: Collection[T] - def __init__(self, choices: Collection[T], type: TypeFunc | None = None, case_sensitive: Bool = True): # noqa + def __init__( + self, + choices: Collection[T], + type: Callable[[str], T] | None = None, # noqa + case_sensitive: Bool = True, + ): if not case_sensitive and not all(isinstance(c, str) for c in choices): raise TypeError(f'Cannot combine case_sensitive=False with non-str {choices=}') - elif isinstance(type, EnumChoices) and not any(isinstance(c, type.type) for c in choices): + if isinstance(type, EnumChoices) and not any(isinstance(c, type.type) for c in choices): raise TypeError(f'Invalid {choices=} for {type=}') super().__init__() # fix_default is not implemented here, so it's not necessary to expose self.choices = choices @@ -106,21 +118,22 @@ def __init__(self, choices: Collection[T], type: TypeFunc | None = None, case_se def _choices_repr(self, delim: str = ',') -> str: try: - return delim.join(map(repr, sorted(self.choices))) + return delim.join(map(repr, sorted(self.choices))) # type: ignore[type-var] except TypeError: # The choice values are not sortable return delim.join(sorted(map(repr, self.choices))) def __call__(self, value: str) -> T: choices = self.choices - value = self._normalize(value) + value = self._normalize(value) # type: ignore[assignment] for val in self._iter_normalized(value, choices): if val in choices: - return value + return value # type: ignore[return-value] if not self.case_sensitive: - norm_value = value.casefold() + # choices/value are confirmed to be str in init when case_sensitive=False + norm_value = value.casefold() # type: ignore[attr-defined] for choice in choices: - if norm_value == choice.casefold(): + if norm_value == choice.casefold(): # type: ignore[attr-defined] return choice raise InvalidChoiceError(value, choices) @@ -146,7 +159,7 @@ def __init__(self, choices: Mapping[Any, T], *args, **kwargs): # TODO: Alternate ChoiceMap where values are used as help text, similar to SubCommand with local_choices def __call__(self, value: str) -> T: - value = self._normalize(value) + value = self._normalize(value) # type: ignore[assignment] for val in self._iter_normalized(value): try: return self.choices[val] @@ -172,12 +185,13 @@ class EnumChoices(_ChoicesBase[EnumT]): __slots__ = () type: Type[EnumT] + choices: Mapping[str, EnumT] def __init__(self, enum: Type[EnumT], case_sensitive: Bool = False): super().__init__() # fix_default is not implemented here, so it's not necessary to expose self.type = enum self.case_sensitive = case_sensitive - self.choices = enum._member_map_ + self.choices = enum._member_map_ # type: ignore[assignment] def _type_str(self) -> str: return f'type={self.type.__name__}, ' @@ -189,7 +203,7 @@ def __call__(self, value: str) -> EnumT: enum = self.type for val in self._iter_normalized(value): try: - return enum[val] + return enum[val] # type: ignore[index] except KeyError: pass try: diff --git a/lib/cli_command_parser/inputs/files.py b/lib/cli_command_parser/inputs/files.py index 0dfc4f2f..694fac23 100644 --- a/lib/cli_command_parser/inputs/files.py +++ b/lib/cli_command_parser/inputs/files.py @@ -9,8 +9,9 @@ import os from abc import ABC from pathlib import Path as _Path +from typing import IO -from ..typing import FP, Bool, Converter, OptStr, PathLike, T +from ..typing import Bool, Converter, OptStr, PathLike, T from .base import InputType from .exceptions import InputValidationError from .utils import FileWrapper, InputParam, StatMode, allows_write, fix_windows_path @@ -19,14 +20,14 @@ class FileInput(InputType[T], ABC): - exists: bool = InputParam(None) - expand: bool = InputParam(True) - resolve: bool = InputParam(False) - type: StatMode = InputParam(StatMode.ANY) - readable: bool = InputParam(False) - writable: bool = InputParam(False) - allow_dash: bool = InputParam(False) - use_windows_fix: bool = InputParam(True) + exists: InputParam[bool | None] = InputParam(None) + expand: InputParam[bool] = InputParam(True) + resolve: InputParam[bool] = InputParam(False) + type: InputParam[StatMode] = InputParam(StatMode.ANY) + readable: InputParam[bool] = InputParam(False) + writable: InputParam[bool] = InputParam(False) + allow_dash: InputParam[bool] = InputParam(False) + use_windows_fix: InputParam[bool] = InputParam(True) def __init__( self, @@ -45,7 +46,7 @@ def __init__( self.exists = exists self.expand = expand self.resolve = resolve - self.type = StatMode(type) # pylint: disable=E1120 + self.type = StatMode(type) self.readable = readable self.writable = writable self.allow_dash = allow_dash @@ -55,33 +56,38 @@ def __repr__(self) -> str: non_defaults = ', '.join(f'{k}={v!r}' for k, v in self.__dict__.items()) return f'<{self.__class__.__name__}({non_defaults})>' - def fix_default(self, value: T | None) -> T | None: + def fix_default(self, value: T | str | None) -> T | str | None: """ Fixes the default value to conform to the expected return type for this input. Allows the default value for a path to be provided as a string, for example. """ if value is None or not self._fix_default: return value - return self(value) + return self(value) # type: ignore[arg-type] def validated_path(self, path: PathLike) -> _Path: if not isinstance(path, _Path): if not (path := path.strip()): raise InputValidationError('A valid path is required') path = _Path(path) + if path.parts == ('-',): if not self.allow_dash: raise InputValidationError('Dash (-) is not supported for this parameter') return path + if self.use_windows_fix and os.name == 'nt': try: path = fix_windows_path(path) except OSError: pass + if self.expand: path = path.expanduser() + if self.resolve: path = path.resolve() + if self.exists is not None: if self.exists and not path.exists(): raise InputValidationError('the provided path does not exist') @@ -94,8 +100,10 @@ def validated_path(self, path: PathLike) -> _Path: if self.readable and not os.access(path, os.R_OK): raise InputValidationError('the provided path is not readable') + if self.writable and not os.access(path, os.W_OK): raise InputValidationError('the provided path is not writable') + return path @@ -132,12 +140,12 @@ class File(FileInput[FileWrapper | str | bytes]): :param kwargs: Additional keyword arguments to pass to :class:`.Path`. """ - mode: str = InputParam('r') - type: StatMode = InputParam(StatMode.FILE) - encoding: str = InputParam(None) - errors: str = InputParam(None) - lazy: bool = InputParam(True) - parents: bool = InputParam(False) + mode: InputParam[str] = InputParam('r') + type: InputParam[StatMode] = InputParam(StatMode.FILE) + encoding: InputParam[str | None] = InputParam(None) + errors: InputParam[str | None] = InputParam(None) + lazy: InputParam[bool] = InputParam(True) + parents: InputParam[bool] = InputParam(False) def __init__( self, @@ -182,8 +190,8 @@ class Serialized(File): :param kwargs: Additional keyword arguments to pass to :class:`.File` """ - converter: Converter = InputParam(None) - pass_file: bool = InputParam(False) + converter: InputParam[Converter | None] = InputParam(None) + pass_file: InputParam[bool] = InputParam(False) def __init__(self, converter: Converter, *, pass_file: Bool = False, **kwargs): super().__init__(**kwargs) @@ -212,7 +220,7 @@ def __init__(self, *, mode: str = 'rb', wrap_errors: bool = True, **kwargs): super().__init__(json.dump if write else self._load_json, mode=mode, **kwargs) self.wrap_errors = wrap_errors - def _load_json(self, f: FP): + def _load_json(self, f: IO): from json import JSONDecodeError, load try: diff --git a/lib/cli_command_parser/inputs/numeric.py b/lib/cli_command_parser/inputs/numeric.py index 314fa709..1608f6d0 100644 --- a/lib/cli_command_parser/inputs/numeric.py +++ b/lib/cli_command_parser/inputs/numeric.py @@ -3,7 +3,6 @@ :author: Doug Skrypa """ -# pylint: disable=W0622 from __future__ import annotations @@ -66,21 +65,29 @@ class Range(_RangeInput[NT]): """ type: NumType = int - range: _range | None + range: _range snap: bool - def __init__(self, range: RngType, snap: Bool = False, type: NumType = None, fix_default: Bool = True): # noqa + def __init__( + self, + range: RngType, # noqa + snap: Bool = False, + type: NumType | None = None, # noqa + fix_default: Bool = True, + ): super().__init__(fix_default) self.snap = snap - if isinstance(range, int): - self.range = _range(range) - elif not isinstance(range, _range): - self.range = _range(*range) # noqa - else: - self.range = range if type is not None: self.type = type + match range: + case int(): + self.range = _range(range) + case _range(): + self.range = range + case _: + self.range = _range(*range) + def __repr__(self) -> str: return f'<{self.__class__.__name__}({self.range!r}, snap={self.snap!r}, type={self.type!r})>' @@ -91,13 +98,12 @@ def _range_str(self, var: str = 'N') -> str: return base if step == 1 else f'{base}, {step=}' def __call__(self, value: str) -> NT: - value = self.type(value) - if value in self.range: - return value + num_val = self.type(value) + if num_val in self.range: + return num_val elif self.snap: - if (rng_min := min(self.range)) > value: - return rng_min - return max(self.range) + snapped = rng_min if (rng_min := min(self.range)) > num_val else max(self.range) + return self.type(snapped) if self.type is not int else snapped # type: ignore[return-value] raise InputValidationError(f'expected a value in the range {self._range_str()}') @@ -148,13 +154,15 @@ def __init__( if snap: if self.type is float: raise TypeError('Unable to snap to extrema with type=float') - real_min = min if include_min else min + 1 - real_max = max if include_max else max - 1 - if real_min >= real_max: - raise ValueError( - f'Invalid {min=} >= {max=} with snap=True, {include_min=},' - f' {include_max=} - snap would produce invalid values' - ) + + if min is not None and max is not None: + real_min = min if include_min else min + 1 + real_max = max if include_max else max - 1 + if real_min >= real_max: + raise ValueError( + f'Invalid {min=} >= {max=} with snap=True, {include_min=},' + f' {include_max=} - snap would produce invalid values' + ) self.snap = snap self.min = self.type(min) if min is not None else min # for floats especially, such as a range like 0~1, this @@ -168,7 +176,7 @@ def __repr__(self) -> str: def _range_str(self, var: str = 'N') -> str: return range_str(self.min, self.max, self.include_min, self.include_max, var) - def handle_invalid(self, bound: Number, inclusive: bool, snap_dir: int) -> Number: + def handle_invalid(self, bound: Number, inclusive: bool, snap_dir: int) -> NT: """ Handle calculating / returning a snap value or raise an exception if snapping to the bound is not allowed. @@ -179,21 +187,22 @@ def handle_invalid(self, bound: Number, inclusive: bool, snap_dir: int) -> Numbe :param snap_dir: The direction to adjust the bound if it is exclusive as ``+1`` or ``-1`` :return: The snap value if :attr:`.snap` is True, otherwise a :class:`python:ValueError` is raised """ - if self.snap: + if self.snap and bound is not None: return bound if inclusive else (bound + snap_dir) raise InputValidationError(f'expected a value in the range {self._range_str()}') def __call__(self, value: str) -> NT: - value = self.type(value) - if self.value_lt_min(value): + num_val = self.type(value) + # Note: if snap is enabled, it is applied by `handle_invalid` + if self.value_lt_min(num_val): return self.handle_invalid(self.min, self.include_min, 1) - elif self.value_gt_max(value): + elif self.value_gt_max(num_val): return self.handle_invalid(self.max, self.include_max, -1) else: - return value + return num_val -class Bytes(NumericInput[NT]): +class Bytes(NumericInput[int | float]): # type: ignore[type-var] """ A byte count/size. @@ -270,9 +279,9 @@ def _type_desc(self) -> str: parts.append('byte count/size') return ' '.join(parts) - def __call__(self, value: str) -> NT: + def __call__(self, value: str) -> int | float: try: - num, unit = self._pattern.match(value.strip()).groups() + num, unit = self._pattern.match(value.strip()).groups() # type: ignore[union-attr] except (TypeError, AttributeError): raise InputValidationError(f'expected {self._type_desc()} with optional unit') from None diff --git a/lib/cli_command_parser/inputs/patterns.py b/lib/cli_command_parser/inputs/patterns.py index ba172b9c..3501d9dd 100644 --- a/lib/cli_command_parser/inputs/patterns.py +++ b/lib/cli_command_parser/inputs/patterns.py @@ -97,9 +97,9 @@ class Regex(PatternInput[RegexResult]): def __init__( self, *patterns: str | Pattern, - group: str | int = None, - groups: Collection[str | int] = None, - mode: RegexMode | str = None, + group: str | int | None = None, + groups: Collection[str | int] | None = None, + mode: RegexMode | str | None = None, ): if not patterns: raise TypeError('At least one regex pattern is required') @@ -120,17 +120,17 @@ def __call__(self, value: str) -> RegexResult: raise InputValidationError(f'expected a value matching {self._describe_patterns()}') if (mode := self.mode) == RegexMode.STRING: - return value + return value # type: ignore[return-value] elif mode == RegexMode.MATCH: - return m + return m # type: ignore[return-value] elif mode == RegexMode.GROUP: - return m.group(self.group) + return m.group(self.group) # type: ignore[arg-type,return-value] elif mode == RegexMode.GROUPS: if self.groups: - return tuple(m.group(g) for g in self.groups) - return m.groups() + return tuple(m.group(g) for g in self.groups) # type: ignore[return-value] + return m.groups() # type: ignore[return-value] else: # mode == RegexMode.DICT - return m.groupdict() + return m.groupdict() # type: ignore[return-value] class Glob(PatternInput[str]): diff --git a/lib/cli_command_parser/inputs/time.py b/lib/cli_command_parser/inputs/time.py index 511b33f1..f9899581 100644 --- a/lib/cli_command_parser/inputs/time.py +++ b/lib/cli_command_parser/inputs/time.py @@ -22,7 +22,7 @@ from enum import Enum from locale import LC_ALL, setlocale from threading import RLock -from typing import Collection, Iterator, Literal, Sequence, Type, TypeVar, overload +from typing import Collection, Iterator, Literal, NoReturn, Sequence, Type, TypeVar, overload from ..typing import Bool, Locale, Number, OptStr, T, TimeBound from ..utils import MissingMixin @@ -32,7 +32,7 @@ __all__ = ['DTFormatMode', 'Day', 'Month', 'TimeDelta', 'DateTime', 'Date', 'Time'] -DT = TypeVar('DT') +DT = TypeVar('DT', datetime, date, time) TimeUnit = Literal['microseconds', 'milliseconds', 'seconds', 'minutes', 'hours', 'days', 'weeks'] _TIMEDELTA_UNITS = {'microseconds', 'milliseconds', 'seconds', 'minutes', 'hours', 'days', 'weeks'} DEFAULT_DATE_FMT = '%Y-%m-%d' @@ -75,7 +75,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): class DTInput(_FixedInputType[T], ABC): __slots__ = ('locale',) - dt_type: OptStr + dt_type: str locale: Locale | None def __init_subclass__(cls, dt_type: OptStr = None, **kwargs): @@ -83,9 +83,10 @@ def __init_subclass__(cls, dt_type: OptStr = None, **kwargs): :param dt_type: Used in InvalidChoiceError / ValueError messages """ super().__init_subclass__(**kwargs) - cls.dt_type = dt_type + if dt_type: + cls.dt_type = dt_type - def __init__(self, locale: Locale = None, fix_default: Bool = True): + def __init__(self, locale: Locale | None = None, fix_default: Bool = True): super().__init__(fix_default) self.locale = locale @@ -256,6 +257,9 @@ def __init__( fix_default: Bool = True, ): ... + @overload + def __init__(self: NoReturn) -> NoReturn: ... # type: ignore[misc] # Workaround for single overload def error + def __init__(self, *, iso: Bool = False, **kwargs): super().__init__(**kwargs) self.iso = iso @@ -311,6 +315,9 @@ def __init__( fix_default: Bool = True, ): ... + @overload + def __init__(self: NoReturn) -> NoReturn: ... # type: ignore[misc] # Workaround for single overload def error + def __init__(self, *, numeric: Bool = True, **kwargs): super().__init__(numeric=numeric, **kwargs) @@ -347,7 +354,7 @@ def __init__( int_only: Bool = False, fix_default: Bool = True, ): - unit = unit.lower() + unit = unit.lower() # type: ignore[assignment] if unit not in _TIMEDELTA_UNITS: raise TypeError(f'Invalid {unit=} - expected one of: {", ".join(sorted(_TIMEDELTA_UNITS))}') elif min is not None and max is not None and min >= max: @@ -374,7 +381,7 @@ def __call__(self, value: str | int | float) -> timedelta: elif self.int_only and int(value) != value: raise self._invalid(value, f'expected an integer, not a {value.__class__.__name__}') - return timedelta(**{self.unit: value}) + return timedelta(**{self.unit: value}) # type: ignore[misc] def _invalid(self, value: Number, message: str) -> InputValidationError: return InputValidationError(f'Invalid numeric {self.unit}={value!r} - {message}') @@ -384,7 +391,7 @@ def _range_str(self) -> str: def fix_default(self, value: int | float | timedelta | None) -> timedelta | None: if value is None or isinstance(value, timedelta) or not self._fix_default: - return value + return value # type: ignore[return-value] return self(value) def format_metavar(self, choice_delim: str = ',', sort_choices: bool = False) -> str: @@ -419,11 +426,11 @@ def __init__( self.latest = latest @classmethod - def _fix_type(cls, dt: datetime) -> DT: + def _fix_type(cls, dt: datetime | None) -> DT | None: try: return getattr(dt, cls.dt_type)() except AttributeError: - return dt + return dt # type: ignore[return-value] @property def earliest(self) -> DT | None: @@ -466,7 +473,7 @@ def parse_dt(self, value: str) -> datetime: ) def parse(self, value: str) -> DT: - return self._fix_type(self.parse_dt(value)) + return self._fix_type(self.parse_dt(value)) # type: ignore[return-value] def choice_str(self, choice_delim: str = ' | ', sort_choices: bool = False) -> str: return choice_delim.join(sorted(self.formats) if sort_choices else self.formats) @@ -484,11 +491,13 @@ def _validate_bounds(self, dt: DT): check_latest = latest is not None if not (check_earliest or check_latest): return - elif (check_earliest and dt < earliest) or (check_latest and dt > latest): + elif (check_earliest and dt < earliest) or (check_latest and dt > latest): # type: ignore[operator] if check_earliest and check_latest: - msg = f'between {dt_repr(earliest)} and {dt_repr(latest)} (inclusive)' + msg = f'between {dt_repr(earliest)} and {dt_repr(latest)} (inclusive)' # type: ignore[arg-type] + elif check_earliest: + msg = f'after {dt_repr(earliest)}' # type: ignore[arg-type] else: - msg = f'after {dt_repr(earliest)}' if check_earliest else f'before {dt_repr(latest)}' + msg = f'before {dt_repr(latest)}' # type: ignore[arg-type] raise InputValidationError(f'Invalid {self.dt_type}={dt_repr(dt)} - a {self.dt_type} {msg} is required') def __call__(self, value: str) -> DT: @@ -580,7 +589,7 @@ def __init__( def dt_repr(dt: datetime | date | time, use_repr: bool = True) -> str: try: - dt_str = dt.isoformat(' ') + dt_str = dt.isoformat(' ') # type: ignore[call-arg] except (TypeError, ValueError): # TypeError for date objects, ValueError for time objects dt_str = dt.isoformat() return repr(dt_str) if use_repr else dt_str diff --git a/lib/cli_command_parser/inputs/utils.py b/lib/cli_command_parser/inputs/utils.py index b4868ffb..830b7e2e 100644 --- a/lib/cli_command_parser/inputs/utils.py +++ b/lib/cli_command_parser/inputs/utils.py @@ -11,28 +11,41 @@ from contextlib import contextmanager from pathlib import Path from stat import S_IFBLK, S_IFCHR, S_IFDIR, S_IFIFO, S_IFLNK, S_IFMT, S_IFREG, S_IFSOCK -from typing import TYPE_CHECKING, Any, BinaryIO, Iterator, TextIO +from typing import IO, TYPE_CHECKING, Any, Generic, Iterator, Literal, TypeVar, overload from weakref import finalize +try: + from typing import Self +except ImportError: # added in 3.11 + Self = TypeVar('Self') # type: ignore[misc,assignment] + from ..utils import FixedFlag from .exceptions import InputValidationError if TYPE_CHECKING: - from ..typing import FP, Bool, Converter, Number, OptStr + from ..typing import Bool, Converter, Number, OptStr __all__ = ['InputParam', 'StatMode', 'FileWrapper', 'fix_windows_path', 'range_str', 'RangeMixin'] +_T = TypeVar('_T') + -class InputParam: +class InputParam(Generic[_T]): __slots__ = ('default', 'name') - def __init__(self, default: Any): + def __init__(self, default: _T): self.default = default def __set_name__(self, owner, name: str): self.name = name - def __get__(self, instance, owner) -> Any: + @overload + def __get__(self, instance: Literal[None], owner: Any) -> Self: ... + + @overload + def __get__(self, instance: object, owner: Any) -> _T: ... + + def __get__(self, instance: object, owner: Any) -> Self | _T: try: return instance.__dict__[self.name] except AttributeError: # instance is None @@ -40,17 +53,23 @@ def __get__(self, instance, owner) -> Any: except KeyError: return self.default - def __set__(self, instance, value: Any): + def __set__(self, instance: object, value: _T): if value != self.default: instance.__dict__[self.name] = value class StatMode(FixedFlag): - def __new__(cls, mode, friendly_name: OptStr = None): + if TYPE_CHECKING: + mode: int | None # type: ignore[misc] + friendly_name: str + + def __init__(self, mode: int | str | StatMode | None): ... + + def __new__(cls, mode: int | None, friendly_name: OptStr = None): # Defined __new__ to avoid juggling dicts for the stat mode values and names obj = object.__new__(cls) if friendly_name: - obj.mode = mode + obj.mode = mode # type: ignore[misc] obj.friendly_name = friendly_name if mode is None: # ANY obj._value_ = sum(m._value_ for m in cls.__members__.values()) @@ -76,11 +95,14 @@ def __str__(self) -> str: name = self.friendly_name except AttributeError: # Combined flags name = None + if name: return name + names = [part.friendly_name for part in self._decompose()] if len(names) == 2: return '{} or {}'.format(*names) + names[-1] = f'or {names[-1]}' return ', '.join(names) @@ -104,20 +126,20 @@ def __init__( self.converter = converter self.pass_file = pass_file self.parents = parents - self._fp: TextIO | BinaryIO | None = None - self._finalizer = None + self._fp: IO | None = None + self._finalizer: finalize | None = None - def __eq__(self, other: FileWrapper) -> bool: + def __eq__(self, other) -> bool: attrs = ('path', 'mode', 'binary', 'encoding', 'errors', 'converter', 'pass_file', 'parents') try: return all(getattr(self, a) == getattr(other, a) for a in attrs) - except AttributeError: + except AttributeError: # not a FileWrapper return NotImplemented def read(self) -> Any: with self._file() as f: if self.converter is not None: - return self.converter(f if self.pass_file else f.read()) + return self.converter(f if self.pass_file else f.read()) # type: ignore[call-arg] else: return f.read() @@ -125,13 +147,13 @@ def write(self, data: Any): with self._file() as f: if self.converter is not None: if self.pass_file: - self.converter(data, f) + self.converter(data, f) # type: ignore[call-arg] else: - f.write(self.converter(data)) + f.write(self.converter(data)) # type: ignore[call-arg] else: f.write(data) - def _open(self) -> FP: + def _open(self) -> IO: if self.path == Path('-'): stream = sys.stdin if 'r' in self.mode else sys.stdout return stream.buffer if self.binary else stream @@ -148,7 +170,7 @@ def _open(self) -> FP: return fp @classmethod - def _cleanup(cls, fp: FP, warn_msg: str): + def _cleanup(cls, fp: IO, warn_msg: str): fp.close() warnings.warn(warn_msg, ResourceWarning) @@ -165,17 +187,18 @@ def close(self): do_close = self._finalizer.detach() except AttributeError: do_close = False + if do_close: self._close() @contextmanager - def _file(self) -> Iterator[FP]: + def _file(self) -> Iterator[IO]: try: yield self._open() finally: self.close() - def __enter__(self) -> FP | FileWrapper: + def __enter__(self) -> IO | FileWrapper: if self.converter is not None: return self return self._open() diff --git a/lib/cli_command_parser/parameters/base.py b/lib/cli_command_parser/parameters/base.py index 041c02f0..cea419e8 100644 --- a/lib/cli_command_parser/parameters/base.py +++ b/lib/cli_command_parser/parameters/base.py @@ -7,12 +7,12 @@ from __future__ import annotations import re +import sys from abc import ABC, abstractmethod -from collections.abc import Collection from contextvars import ContextVar from functools import cached_property from itertools import chain -from typing import TYPE_CHECKING, Any, Callable, Generic, Iterator, Literal, NoReturn, Type, TypeVar, overload +from typing import TYPE_CHECKING, Any, Callable, Generic, Iterator, Type, TypeAlias, TypeVar, overload try: from typing import Self @@ -28,21 +28,35 @@ from ..inputs.exceptions import InputValidationError, InvalidChoiceError from ..inputs.numeric import NumericInput from ..nargs import REMAINDER, Nargs -from ..typing import CommandMethod, DefaultFunc, T_co -from ..utils import _NotSet +from ..utils import _NotSet, _NotSetType from .option_strings import OptionStrings if TYPE_CHECKING: - from ..core import CommandMeta + from collections.abc import Collection + from typing import Literal, NoReturn + + from ..commands import Command from ..formatting.params import ParamHelpFormatter - from ..typing import Bool, CommandAny, CommandCls, CommandObj, LeadingDash, OptStr, OptStrs, Strings + from ..typing import Bool, LeadingDash, OptStr, OptStrs, Strings from .actions import ParamAction from .groups import ParamGroup + _CmdCls = Type[Command] + _CmdObjOrCls: TypeAlias = Command | _CmdCls + __all__ = ['Parameter', 'BasePositional', 'BaseOption'] -_group_stack = ContextVar('cli_command_parser.parameters.base.group_stack') +_group_stack: ContextVar[list[ParamGroup]] = ContextVar('cli_command_parser.parameters.base.group_stack') _is_numeric = re.compile(r'^-\d+$|^-\d*\.\d+?$').match + +if sys.version_info >= (3, 13): + T = TypeVar('T', default=str) +else: + T = TypeVar('T') + +CommandMethod = Callable[['Command'], T] +DefaultFunc = Callable[[], T] | CommandMethod + TD = TypeVar('TD') @@ -64,9 +78,9 @@ class ParamBase(ABC): _attr_name: OptStr = None #: Always the name of the attr that points to this object _name: OptStr = None #: An explicitly provided name, or the name of the attr that points to this obj group: ParamGroup | None = None #: The group this object is a member of, if any - command: CommandMeta | None = None #: The :class:`.Command` this object is a member of + command: _CmdCls | None = None #: The :class:`.Command` this object is a member of required: Bool #: Whether this param/group is required - help: str #: The description for this param/group that will appear in ``--help`` text + help: OptStr #: The description for this param/group that will appear in ``--help`` text hide: Bool #: Whether this param/group should be hidden in ``--help`` text # fmt: on @@ -102,7 +116,7 @@ def name(self, value: OptStr): def _default_name(self) -> str: return f'{self.__class__.__name__}#{id(self)}' - def __set_name__(self, command: CommandCls, name: str): + def __set_name__(self, command: _CmdCls, name: str): self.command = command if self._name is None: self._name = name @@ -121,22 +135,30 @@ def __eq__(self, other) -> bool: def __hash__(self) -> int: return hash(self.__class__) ^ hash(self._attr_name) ^ hash(self._name) ^ hash(self.command) - def _ctx(self, command: CommandAny | None = None) -> Context | None: + @overload + def _ctx(self, command: _CmdObjOrCls) -> Context: ... + + @overload + def _ctx(self, command: Any) -> Context | None: ... + + def _ctx(self, command: _CmdObjOrCls | Any = None) -> Context | None: if context := get_current_context(True): return context + if command is None: command = self.command + try: - return command._Command__ctx + return command._Command__ctx # type: ignore[union-attr] # noqa except AttributeError: return None - def _config(self, command: CommandAny | None = None) -> CommandConfig: + def _config(self, command: _CmdCls | None = None) -> CommandConfig: if context := self._ctx(command): return context.config if command is None: command = self.command - return command.__class__.config(command, DEFAULT_CONFIG) + return command.__class__.config(command, DEFAULT_CONFIG) # type: ignore[union-attr] # region Usage / Help Text @@ -167,7 +189,7 @@ def format_help(self, *args, **kwargs) -> str: # endregion -class Parameter(ParamBase, Generic[T_co], ABC): +class Parameter(ParamBase, Generic[T], ABC): """ Base class for all other parameters. It is not meant to be used directly. @@ -205,7 +227,7 @@ class Parameter(ParamBase, Generic[T_co], ABC): # Instance attributes with class defaults metavar: OptStr = None nargs: Nargs # Expected to be set in subclasses - type: Callable[[str], T_co] | None = None # Expected to be set in subclasses + type: Callable[[str], T] | None = None # Expected to be set in subclasses allow_leading_dash: AllowLeadingDash = AllowLeadingDash.NUMERIC # Set in some subclasses default = _NotSet default_cb: DefaultCallback | None = None @@ -276,7 +298,7 @@ def _handle_bad_action(self, action: str) -> NoReturn: f'Invalid {action=} for {self.__class__.__name__} - valid actions: {sorted(self._action_map)}' ) - def __set_name__(self, command: CommandCls, name: str): + def __set_name__(self, command: _CmdCls, name: str): super().__set_name__(command, name) # If self.type is None, a type may still be inferred from an annotation, which happens in this method. if untyped_choices := self.type is not None: @@ -290,13 +312,13 @@ def __set_name__(self, command: CommandCls, name: str): if (annotated_type := get_descriptor_value_type(command, name)) is None: return elif untyped_choices: - self.type.type = annotated_type + self.type.type = annotated_type # type: ignore[union-attr] else: # self.type must be None # Choices present earlier would have already been converted self.type = normalize_input_type(annotated_type, None) @property - def has_choices(self) -> bool: + def has_choices(self) -> Bool: if self.type: return isinstance(self.type, _ChoicesBase) and self.type.choices return False @@ -330,7 +352,7 @@ def register_default_cb(self, method: CommandMethod) -> CommandMethod: def __repr__(self) -> str: names = ('action', 'const', 'default', 'default_cb', 'type', 'choices', 'required', 'hide', 'help') if self._repr_attrs: - names = chain(names, self._repr_attrs) + names = chain(names, self._repr_attrs) # type: ignore[assignment] skip = (None, _NotSet) attrs = ( @@ -346,12 +368,13 @@ def __repr__(self) -> str: def get_const(self, opt_str: OptStr = None): return _NotSet - def get_env_const(self, value: str, env_var: str) -> tuple[T_co, bool]: + def get_env_const(self, value: str, env_var: str) -> tuple[T | _NotSetType, bool]: return _NotSet, False - def prepare_value(self, value: str, short_combo: Bool = False, env_var: OptStr = None) -> T_co: + def prepare_value(self, value: str, short_combo: Bool = False, env_var: OptStr = None) -> T | str: if self.type is None: return value + try: return self.type(value) except InvalidChoiceError as e: @@ -366,13 +389,13 @@ def prepare_value(self, value: str, short_combo: Bool = False, env_var: OptStr = suffix = f' from env var={env_var!r}' if env_var else '' raise BadArgument(self, f'unable to cast {value=} to type={self.type!r}{suffix}') from e - def prepare_validation_value(self, value: str, short_combo: Bool = False) -> T_co: + def prepare_validation_value(self, value: str, short_combo: Bool = False) -> T | str: if self.type is None or (isinstance(self.type, InputType) and self.type.is_valid_type(value)): return value - else: - return self.prepare_value(value, short_combo) - def validate(self, value: T_co | None, joined: Bool = False): + return self.prepare_value(value, short_combo) + + def validate(self, value: T | None, joined: Bool = False): if not isinstance(value, str) or not value or not value[0] == '-': return elif self.allow_leading_dash == AllowLeadingDash.NUMERIC: @@ -399,27 +422,33 @@ def is_valid_arg(self, value: Any) -> bool: # region Parse Results / Argument Value Handling @overload - def __get__(self, command: Literal[None], owner: Type[object]) -> Self: ... + def __get__(self, command: Literal[None], owner: Any) -> Self: ... @overload - def __get__(self, command: object, owner: Type[object]) -> T_co | None: ... + def __get__(self, command: object, owner: Any) -> T | None: ... - def __get__(self, command, owner): + def __get__(self, command: object | None, owner: Any) -> Self | T | None: if command is None: return self - with self._ctx(command): - value = self.result(command) + if context := self._ctx(command): + with context: + value = self.result(command) + else: + # This would only ever happen if this Parameter was created in a non-Command class, but it makes mypy happy + value = None if self._attr_name: command.__dict__[self._attr_name] = value # Skip __get__ on subsequent accesses + return value - def result(self, command: CommandObj | None = None, missing_default: TD = _NotSet) -> T_co | TD | None: + def result(self, command: Command | Any = None, missing_default: TD | _NotSetType = _NotSet) -> T | TD | None: """The final result / parsed value for this Parameter that is returned upon access as a descriptor.""" if (value := ctx.get_parsed_value(self)) is not _NotSet: return self.action.finalize_value(value) - elif self.required: + + if self.required: if missing_default is _NotSet: raise MissingArgument(self) return missing_default @@ -447,7 +476,7 @@ def show_in_help(self) -> bool: # endregion -class BasePositional(Parameter[T_co], ABC): +class BasePositional(Parameter[T], ABC): """ Base class for :class:`.Positional`, :class:`.SubCommand`, :class:`.Action`, and any other parameters that are provided positionally, without prefixes. It is not meant to be used directly. @@ -474,7 +503,13 @@ def __init_subclass__(cls, default_ok: bool | None = None, **kwargs): # pylint: cls._default_ok = default_ok def __init__( - self, action: str, *, required: Bool = True, default: Any = _NotSet, default_cb: DefaultFunc = None, **kwargs + self, + action: str, + *, + required: Bool = True, + default: Any = _NotSet, + default_cb: DefaultFunc | None = None, + **kwargs, ): if not (self._default_ok and 0 in self.nargs): # Indicates that having a default is bad if not required: @@ -486,7 +521,7 @@ def __init__( super().__init__(action, default=default, required=required, default_cb=default_cb, **kwargs) -class BaseOption(Parameter[T_co], ABC): +class BaseOption(Parameter[T], ABC): """ Base class for :class:`.Option`, :class:`.Flag`, :class:`.Counter`, and any other keyword-like parameters that have ``--long`` and ``-short`` prefixes before values. @@ -530,7 +565,7 @@ def __init__( self, *option_strs: str, action: str, - name_mode: OptionNameMode | OptStr = _NotSet, + name_mode: OptionNameMode | OptStr | _NotSetType = _NotSet, env_var: OptStrs = None, strict_env: bool = True, use_env_value: Bool = None, @@ -552,7 +587,7 @@ def _handle_bad_action(self, action: str) -> NoReturn: raise ParameterDefinitionError(f'Invalid {action=} for {self.__class__.__name__} - did you mean {fixed!r}?') super()._handle_bad_action(action) - def __set_name__(self, command: CommandCls, name: str): + def __set_name__(self, command: _CmdCls, name: str): super().__set_name__(command, name) if not self.option_strs.name_mode: self.option_strs.name_mode = self._config(command).option_name_mode @@ -605,7 +640,7 @@ def __set__(self, instance: Parameter, value: LeadingDash): instance.__dict__[self.name] = value -class DefaultCallback: +class DefaultCallback(Generic[T]): __slots__ = ('func', 'use_cmd') def __init__(self, func: CommandMethod | DefaultFunc, use_cmd: bool = False): @@ -615,7 +650,7 @@ def __init__(self, func: CommandMethod | DefaultFunc, use_cmd: bool = False): def __repr__(self) -> str: return f'<{self.__class__.__name__}({self.func!r}, use_cmd={self.use_cmd})>' - def __call__(self, command: CommandObj | None) -> T_co: + def __call__(self, command: Command) -> T: # If the func isn't a method / doesn't accept the command, then `command` must not be None, but the default # callback is intentionally not called by ParamAction.get_default (and its subclasses) when command is None. - return self.func(command) if self.use_cmd else self.func() + return self.func(command) if self.use_cmd else self.func() # type: ignore[call-arg] diff --git a/lib/cli_command_parser/parameters/options.py b/lib/cli_command_parser/parameters/options.py index bd5488ea..a1b72e3f 100644 --- a/lib/cli_command_parser/parameters/options.py +++ b/lib/cli_command_parser/parameters/options.py @@ -13,14 +13,14 @@ from ..exceptions import BadArgument, CommandDefinitionError, ParameterDefinitionError, ParamUsageError, ParserExit from ..inputs import normalize_input_type from ..nargs import Nargs, NargsValue -from ..typing import T_co, TypeFunc +from ..typing import TypeFunc from ..utils import _NotSet, str_to_bool from .actions import Append, AppendConst, Count, Store, StoreConst -from .base import AllowLeadingDashProperty, BaseOption +from .base import AllowLeadingDashProperty, BaseOption, CommandMethod from .option_strings import TriFlagOptionStrings if TYPE_CHECKING: - from ..typing import Bool, ChoicesType, CommandCls, CommandMethod, CommandObj, InputTypeFunc, LeadingDash, OptStr + from ..typing import Bool, ChoicesType, CommandCls, CommandObj, InputTypeFunc, LeadingDash, OptStr __all__ = [ 'Option', @@ -35,13 +35,14 @@ ] log = logging.getLogger(__name__) +T = TypeVar('T') TD = TypeVar('TD') TC = TypeVar('TC') TA = TypeVar('TA') ConstAct = Literal['store_const', 'append_const'] -class Option(BaseOption[T_co | TD], actions=(Store, Append)): +class Option(BaseOption[T | TD], actions=(Store, Append)): """ A generic option that can be specified as ``--foo bar`` or by using other similar forms. diff --git a/lib/cli_command_parser/parameters/positionals.py b/lib/cli_command_parser/parameters/positionals.py index 76b7b0df..c734a9bf 100644 --- a/lib/cli_command_parser/parameters/positionals.py +++ b/lib/cli_command_parser/parameters/positionals.py @@ -13,10 +13,10 @@ from ..nargs import Nargs, NargsValue from ..utils import _NotSet from .actions import Append, Store -from .base import AllowLeadingDashProperty, BasePositional +from .base import AllowLeadingDashProperty, BasePositional, DefaultFunc if TYPE_CHECKING: - from ..typing import ChoicesType, DefaultFunc, InputTypeFunc, LeadingDash + from ..typing import ChoicesType, InputTypeFunc, LeadingDash __all__ = ['Positional'] diff --git a/lib/cli_command_parser/parser.py b/lib/cli_command_parser/parser.py index ecf040bd..40136029 100644 --- a/lib/cli_command_parser/parser.py +++ b/lib/cli_command_parser/parser.py @@ -28,8 +28,10 @@ if TYPE_CHECKING: from .command_parameters import CommandParameters + from .commands import Command from .config import CommandConfig - from .typing import Bool, CommandType, OptStr + from .core import CommandMeta + from .typing import Bool, OptStr __all__ = ['CommandParser', 'parse_args_and_get_next_cmd'] log = logging.getLogger(__name__) @@ -62,7 +64,7 @@ def __init__(self, ctx: Context, params: CommandParameters, config: CommandConfi PosNode.build_tree(ctx.command_cls) @classmethod - def parse_args_and_get_next_cmd(cls, ctx: Context) -> CommandType | None: + def parse_args_and_get_next_cmd(cls, ctx: Context) -> CommandMeta | None: try: return cls(ctx, ctx.params, ctx.config).get_next_cmd(ctx) except UsageError: @@ -70,7 +72,7 @@ def parse_args_and_get_next_cmd(cls, ctx: Context) -> CommandType | None: raise return None - def get_next_cmd(self, ctx: Context) -> CommandType | None: + def get_next_cmd(self, ctx: Context) -> CommandMeta | None: self._parse_args(ctx) self._validate_groups() missing = ctx.get_missing() diff --git a/lib/cli_command_parser/testing.py b/lib/cli_command_parser/testing.py index 58c86ed0..d5f967be 100644 --- a/lib/cli_command_parser/testing.py +++ b/lib/cli_command_parser/testing.py @@ -3,7 +3,6 @@ :author: Doug Skrypa """ -# pylint: disable=R0913,C0103 from __future__ import annotations @@ -25,7 +24,7 @@ from .parameters import help_action if TYPE_CHECKING: - from .typing import CommandCls, OptStr + from .typing import OptStr __all__ = [ 'ParserTest', @@ -50,6 +49,8 @@ ExcCases = Iterable[ExceptionCase] CallExceptionCase = tuple[Kwargs, ExcType] | tuple[Kwargs, ExcType, str] CallExceptionCases = Iterable[CallExceptionCase] +MaybeCallExcCases = Iterable[ExceptionCase | CallExceptionCase] +CommandCls = Type[Command] OPT_ENV_MOD = 'cli_command_parser.parser.environ' EXCLUDE_ACTIONS = (help_action,) @@ -84,13 +85,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): class ParserTest(TestCase): - # def setUp(self): - # print() - # - # def subTest(self, *args, **kwargs): - # print() - # return super().subTest(*args, **kwargs) - def assert_dict_equal(self, d1, d2, msg: OptStr = None): self.assertIsInstance(d1, dict, 'First argument is not a dictionary') self.assertIsInstance(d2, dict, 'Second argument is not a dictionary') @@ -139,7 +133,9 @@ def assert_parse_fails( with AssertRaisesWithStringContext(self, expected_exc, expected_pattern, msg): cmd_cls.parse(argv) - def assert_parse_fails_cases(self, cmd_cls: CommandCls, cases: ExcCases, exc: ExcType = None, msg: OptStr = None): + def assert_parse_fails_cases( + self, cmd_cls: CommandCls, cases: ExcCases, exc: ExcType | None = None, msg: OptStr = None + ): for argv, exc, pat in _iter_exc_cases(cases, exc): with self.subTest(expected='exception', argv=argv): with AssertRaisesWithStringContext(self, exc, pat, msg): @@ -171,16 +167,24 @@ def assert_call_fails_cases(self, func: Callable, cases: Iterable[CallExceptionC def assert_strings_equal( self, expected: str, actual: str, message: OptStr = None, diff_lines: int = 3, trim: bool = False ): + """ + Assert that two strings are equal. Primarily intended for cases involving multi-line text. + + :param expected: Expected string + :param actual: Actual string that was produced by code that is being tested + :param message: Message to be displayed if the strings do not match (default: a multi-line colored diff is + displayed) + :param diff_lines: Number of lines of context before/after differing lines that should be displayed in the diff + :param trim: Whether the expected and actual strings should be normalized by removing trailing whitespace + """ if trim: expected = expected.rstrip() actual = '\n'.join(line.rstrip() for line in actual.splitlines()) + if message: self.assertEqual(expected, actual, message) elif expected != actual: diff = format_diff(expected, actual, n=diff_lines) - # if not diff.strip(): - # self.assertEqual(expected, actual) - # else: self.fail('Strings did not match:\n' + diff) def assert_str_starts_with_line(self, prefix: str, text: str): @@ -198,14 +202,14 @@ def env_vars(self, case: str, **env_vars): yield -def _iter_exc_cases(cases: ExcCases | CallExceptionCases, exc: ExcType | None = None): +def _iter_exc_cases(cases: MaybeCallExcCases, exc: ExcType | None = None): if exc is not None: for args in cases: yield args, exc, None else: for case in cases: try: - args, exc = case + args, exc = case # type: ignore[assignment,misc] except ValueError: yield case # Assume it is a 3-tuple of ([argv|kwargs], exc, pattern) else: @@ -215,7 +219,7 @@ def _iter_exc_cases(cases: ExcCases | CallExceptionCases, exc: ExcType | None = # region Formatting -def _colored(text: str, color: int, end: str = '\n'): +def _colored(text: str, color: int, end: str = '\n') -> str: return f'\x1b[38;5;{color}m{text}\x1b[0m{end}' @@ -268,19 +272,21 @@ def format_dict_diff(a: dict[str, Any], b: dict[str, Any]) -> str: class RedirectStreams(AbstractContextManager): - _stdin: IO | str | bytes | None = None + _stdin: IO | None = None def __init__(self, stdin: IO | str | bytes | None = None): - self._old = {} + self._old: dict[str, IO] = {} if stdin is not None: if isinstance(stdin, bytes): self._stdin = BytesIO(stdin) - self._stdin.buffer = self._stdin # pretend to be the underlying buffer as well + # pretend to be the underlying buffer as well + self._stdin.buffer = self._stdin # type: ignore[attr-defined] elif isinstance(stdin, str): self._stdin = StringIO(stdin) - self._stdin.buffer = BytesIO(stdin.encode('utf-8')) + self._stdin.buffer = BytesIO(stdin.encode('utf-8')) # type: ignore[misc] else: self._stdin = stdin + self._stdout = StringIO() self._stderr = StringIO() @@ -293,12 +299,14 @@ def stderr(self) -> str: return self._stderr.getvalue() def __enter__(self) -> RedirectStreams: - streams = {'stdout': self._stdout, 'stderr': self._stderr} + streams: dict[str, IO] = {'stdout': self._stdout, 'stderr': self._stderr} if self._stdin is not None: - streams['stdin'] = self._stdin + streams['stdin'] = self._stdin # type: ignore[assignment] + for name, io in streams.items(): self._old[name] = getattr(sys, name) setattr(sys, name, io) + return self def __exit__(self, exc_type, exc_val, exc_tb): @@ -310,12 +318,12 @@ def __exit__(self, exc_type, exc_val, exc_tb): # region Help / Usage / RST Text -def get_usage_text(cmd: Type[Command]) -> str: +def get_usage_text(cmd: CommandCls) -> str: with cmd().ctx: return get_params(cmd).formatter.format_usage() -def get_help_text(cmd: Type[Command] | Command, terminal_width: int = 199) -> str: +def get_help_text(cmd: CommandCls | Command, terminal_width: int = 199) -> str: if not isinstance(cmd, Command): cmd = cmd() @@ -324,7 +332,7 @@ def get_help_text(cmd: Type[Command] | Command, terminal_width: int = 199) -> st return get_params(cmd).formatter.format_help() -def get_rst_text(cmd: Type[Command] | Command) -> str: +def get_rst_text(cmd: CommandCls | Command) -> str: if not isinstance(cmd, Command): cmd = cmd() @@ -336,7 +344,7 @@ def get_rst_text(cmd: Type[Command] | Command) -> str: # endregion -def sealed_mock(*args, **kwargs): +def sealed_mock(*args, **kwargs) -> Mock: kwargs.setdefault('return_value', None) mock = Mock(*args, **kwargs) seal(mock) diff --git a/lib/cli_command_parser/typing.py b/lib/cli_command_parser/typing.py index beaf9cfe..eab33203 100644 --- a/lib/cli_command_parser/typing.py +++ b/lib/cli_command_parser/typing.py @@ -7,18 +7,16 @@ from __future__ import annotations from typing import ( + IO, TYPE_CHECKING, Any, - BinaryIO, Callable, Collection, - Dict, Iterable, - List, Pattern, Sequence, - TextIO, Type, + TypeAlias, TypeVar, Union, ) @@ -26,6 +24,7 @@ if TYPE_CHECKING: from datetime import date, datetime, time, timedelta from enum import Enum + from numbers import Number as _Number from pathlib import Path from .commands import Command @@ -35,47 +34,41 @@ from .parameters import Parameter, ParamGroup T = TypeVar('T') -T_co = TypeVar('T_co', covariant=True) -TypeFunc = Callable[[str], T_co] +TypeFunc = Callable[[str], T] -NT = TypeVar('NT', bound=float, covariant=True) -Number = Union[NT, None] -NumType = Callable[[Union[str, float, int]], NT] -RngType = Union[range, int, Sequence[int]] +NT = TypeVar('NT', bound='_Number') +Number: TypeAlias = NT | None +NumType = Callable[[Any], NT] +RngType = range | int | Sequence[int] InputTypeFunc = Union[None, TypeFunc, 'InputType', range, Type['Enum'], Pattern] -ChoicesType = Union[Collection[Any], None] +ChoicesType = Collection[Any] | None -Bool = Union[bool, Any] +Bool = bool | Any StrSeq = Sequence[str] -Strs = Union[str, StrSeq] +Strs = str | StrSeq StrIter = Iterable[str] -IStrs = Union[str, StrIter] -OptStr = Union[str, None] -OptStrs = Union[Strs, None] +IStrs = str | StrIter +OptStr = str | None +OptStrs = Strs | None Strings = Collection[str] PathLike = Union[str, 'Path'] Locale = str | tuple[OptStr, OptStr] TimeBound = Union['datetime', 'date', 'time', 'timedelta', None] -FP = Union[TextIO, BinaryIO] -Deserializer = Callable[[Union[str, bytes, FP]], Any] -Serializer = Callable[..., Union[str, bytes, None]] -Converter = Union[Deserializer, Serializer] +Deserializer = Callable[[str | bytes | IO], Any] +Serializer = Callable[[Any, IO], None] | Callable[[Any], str | bytes] +Converter = Deserializer | Serializer Config = Union['CommandConfig', None] -AnyConfig = Union[Config, Dict[str, Any]] +AnyConfig = Config | dict[str, Any] LeadingDash = Union['AllowLeadingDash', str, bool] Param = TypeVar('Param', bound='Parameter') -ParamList = List[Param] +ParamList = list[Param] ParamOrGroup = Union[Param, 'ParamGroup'] CommandObj = TypeVar('CommandObj', bound='Command') -CommandType = TypeVar('CommandType', bound='CommandMeta') -CommandCls = Union[CommandType, Type[CommandObj]] -CommandAny = Union[CommandCls, CommandObj] - -CommandMethod = Callable[[CommandObj], T_co] -DefaultFunc = Union[Callable[[], T_co], CommandMethod] +CommandCls: TypeAlias = Type[CommandObj] +CommandAny: TypeAlias = CommandCls | CommandObj diff --git a/lib/cli_command_parser/utils.py b/lib/cli_command_parser/utils.py index ec50307b..84bc7b54 100644 --- a/lib/cli_command_parser/utils.py +++ b/lib/cli_command_parser/utils.py @@ -23,7 +23,7 @@ Self = TypeVar('Self') # type: ignore[misc,assignment] try: - from wcwidth import wcwidth + from wcwidth import wcwidth # type: ignore[import-untyped] except ImportError: wcwidth = len diff --git a/tests/test_conversion/test_convert_argparse.py b/tests/test_conversion/test_convert_argparse.py index a08a3268..dee67542 100755 --- a/tests/test_conversion/test_convert_argparse.py +++ b/tests/test_conversion/test_convert_argparse.py @@ -626,7 +626,7 @@ def tearDownClass(cls): ac_converter_map = Converter._ac_converter_map del ac_converter_map[next(ac for ac in ac_converter_map if ac.__name__ == 'ParserConstant')] for module in ('foo', 'foo.bar', 'foo.bar.baz'): - del Script._parser_classes[module] + del Script.mod_cls_to_ast_cls_map[module] def test_custom_parser_subclass(self): code = """ diff --git a/tests/test_inputs/test_numeric_inputs.py b/tests/test_inputs/test_numeric_inputs.py index bf4c220b..a5e4ea08 100755 --- a/tests/test_inputs/test_numeric_inputs.py +++ b/tests/test_inputs/test_numeric_inputs.py @@ -68,7 +68,7 @@ def test_num_range_str(self): def test_num_range_repr(self): self.assertEqual(", snap=False)[0 <= N]>", repr(NumRange(min=0))) - def test_num_range_requires_min_max(self): + def test_num_range_requires_min_or_max(self): with self.assert_raises_contains_str(ValueError, 'at least one of min and/or max values'): NumRange() @@ -84,6 +84,13 @@ def test_snap_rejects_float(self): with self.assert_raises_contains_str(TypeError, 'Unable to snap to extrema with type=float'): NumRange(snap=True, type=float, min=1) + def test_snap_no_min(self): + self.assertEqual(9, NumRange(snap=True, max=10)('20')) + self.assertEqual(10, NumRange(snap=True, max=10, include_max=True)('20')) + + def test_snap_no_max(self): + self.assertEqual(0, NumRange(snap=True, min=0)('-5')) + def test_num_range_auto_type(self): self.assertIs(int, NumRange(min=1, max=10).type) self.assertIs(float, NumRange(min=1.5, max=10).type)