From 80071879a828faf74a04751d9dd3ca566ea94e11 Mon Sep 17 00:00:00 2001 From: dskrypa Date: Tue, 17 Mar 2026 07:15:02 -0400 Subject: [PATCH 1/4] simplified AstCallable init by removing the semi-redundant call param --- .../conversion/argparse_ast.py | 73 +++++++++++-------- lib/cli_command_parser/conversion/visitor.py | 27 ++++--- .../test_conversion/test_convert_argparse.py | 27 ++++++- 3 files changed, 81 insertions(+), 46 deletions(-) diff --git a/lib/cli_command_parser/conversion/argparse_ast.py b/lib/cli_command_parser/conversion/argparse_ast.py index ae321d11..a552ef30 100644 --- a/lib/cli_command_parser/conversion/argparse_ast.py +++ b/lib/cli_command_parser/conversion/argparse_ast.py @@ -39,7 +39,7 @@ AC = TypeVar('AC', bound='AstCallable') ACGroup: TypeAlias = tuple[Type[AC], list[AC]] D = TypeVar('D') -VisitFunc = Callable[[InitNode, OptCall, 'TrackedRefMap'], AC] +VisitFunc = Callable[[InitNode, 'TrackedRefMap'], AC] class Script: @@ -80,10 +80,8 @@ def register_parser(cls, ast_cls: ParserCls) -> ParserCls: cls._register_parser(real_cls.__module__, real_cls.__name__, ast_cls) # type: ignore[union-attr] return ast_cls - def add_parser( - self, ast_cls: Type[ParserObj], node: InitNode, call: OptCall, tracked_refs: TrackedRefMap - ) -> ParserObj: - parser = ast_cls(node, self, tracked_refs, call) + def add_parser(self, ast_cls: Type[ParserObj], node: InitNode, tracked_refs: TrackedRefMap) -> ParserObj: + parser = ast_cls(node, self, tracked_refs) self._parsers.append(parser) return parser @@ -190,17 +188,18 @@ def __init_subclass__(cls, represents: RepresentedCallable | None = None, **kwar elif ABC not in cls.__bases__: raise NotImplementedError(f'Missing required "represents" class param for {cls.__name__}') - def __init__(self, node: InitNode, parent: AstCallable | Script, tracked_refs: TrackedRefMap, call: OptCall = None): + def __init__(self, node: InitNode, parent: AstCallable | Script, tracked_refs: TrackedRefMap): self.init_node = node - self._init_call(call if call else node.value if isinstance(node, Assign) else node) # type: ignore[arg-type] + if call := _get_call(node): + self.call_node = call + self.call_args = call.args + self.call_kwargs = call.keywords + else: + raise ValueError(f'Unexpected {node=}') + self._tracked_refs = tracked_refs self.parent = parent - def _init_call(self, call: Call): - self.call_node = call - self.call_args = call.args - self.call_kwargs = call.keywords - def __repr__(self) -> str: return f'<{self.__class__.__name__}[{self.init_call_repr()}]>' @@ -276,6 +275,20 @@ def pprint(self, indent: int = 0): print(f'{" " * indent} - {self!r}') +def _get_call(node: InitNode) -> Call | None: + match node: + case Call(): + return node + case withitem(): + if isinstance(node.context_expr, Call): + return node.context_expr + case Assign(): + if isinstance(node.value, Call): + return node.value + + return None + + # region Stdlib Argparse Wrappers @@ -295,30 +308,26 @@ def __init_subclass__(cls, children: Collection[str] = (), **kwargs): if children: cls._children = (*cls._children, *children) - def __init__(self, node: InitNode, parent: AstCallable | Script, tracked_refs: TrackedRefMap, call: OptCall = None): - super().__init__(node, parent, tracked_refs, call) + def __init__(self, node: InitNode, parent: AstCallable | Script, tracked_refs: TrackedRefMap): + super().__init__(node, parent, tracked_refs) self.args = [] self.groups = [] def __repr__(self) -> str: return f'<{self.__class__.__name__}: ``{self.init_call_repr()}``>' - def _add_child(self, cls: Type[AC], container: list[AC], node: InitNode, call: OptCall, refs: TrackedRefMap) -> AC: - child = cls(node, self, refs, call) + def _add_child(self, cls: Type[AC], container: list[AC], node: InitNode, refs: TrackedRefMap) -> AC: + child = cls(node, self, refs) container.append(child) return child @visit_func - def add_mutually_exclusive_group( - self, node: InitNode, call: OptCall, tracked_refs: TrackedRefMap - ) -> MutuallyExclusiveGroup: - return self._add_child( # type: ignore[return-value] - MutuallyExclusiveGroup, self.groups, node, call, tracked_refs - ) + def add_mutually_exclusive_group(self, node: InitNode, tracked_refs: TrackedRefMap) -> MutuallyExclusiveGroup: + return self._add_child(MutuallyExclusiveGroup, self.groups, node, tracked_refs) # type: ignore[return-value] @visit_func - def add_argument_group(self, node: InitNode, call: OptCall, tracked_refs: TrackedRefMap) -> ArgGroup: - return self._add_child(ArgGroup, self.groups, node, call, tracked_refs) + def add_argument_group(self, node: InitNode, tracked_refs: TrackedRefMap) -> ArgGroup: + return self._add_child(ArgGroup, self.groups, node, tracked_refs) def grouped_children(self) -> Iterator[ACGroup]: yield ParserArg, self.args @@ -355,8 +364,8 @@ class behaves, when :meth:`.add_parser` is called, the subparser is stored direc parent: AstArgumentParser @visit_func - def add_parser(self, node: InitNode, call: OptCall, tracked_refs: TrackedRefMap): - sub_parser = self.parent._add_subparser(node, call, tracked_refs) + def add_parser(self, node: InitNode, tracked_refs: TrackedRefMap): + sub_parser = self.parent._add_subparser(node, tracked_refs) sub_parser.sp_parent = self return sub_parser @@ -367,8 +376,8 @@ class AstArgumentParser(ArgCollection, represents=ArgumentParser, children=('sub sub_parsers: list[SubParser] add_subparsers: AddVisitedChild[SubparsersAction] = AddVisitedChild(SubparsersAction, '_subparsers_actions') - def __init__(self, node: InitNode, parent: Script | ParserObj, tracked_refs: TrackedRefMap, call: OptCall = None): - super().__init__(node, parent, tracked_refs, call) + def __init__(self, node: InitNode, parent: Script | ParserObj, tracked_refs: TrackedRefMap): + super().__init__(node, parent, tracked_refs) self._subparsers_actions: list[SubparsersAction] = [] # Note: sub_parsers aren't included in grouped_children since they need different handling during conversion self.sub_parsers = [] @@ -379,21 +388,21 @@ def __repr__(self) -> str: @overload def _add_subparser( - self, node: InitNode, call: OptCall, tracked_refs: TrackedRefMap, sub_parser_cls: Literal[None] = None + self, node: InitNode, tracked_refs: TrackedRefMap, sub_parser_cls: Literal[None] = None ) -> SubParser: ... @overload def _add_subparser( - self, node: InitNode, call: OptCall, tracked_refs: TrackedRefMap, sub_parser_cls: Type[ParserObj] + self, node: InitNode, tracked_refs: TrackedRefMap, sub_parser_cls: Type[ParserObj] ) -> ParserObj: ... def _add_subparser( - self, node: InitNode, call: OptCall, tracked_refs: TrackedRefMap, sub_parser_cls: Type[ParserObj] | None = None + self, node: InitNode, tracked_refs: TrackedRefMap, sub_parser_cls: Type[ParserObj] | None = None ) -> SubParser | ParserObj: """Add a subparser to this parser. Only meant to be called by :class:`SubparsersAction`.""" # Using default of None since the class hasn't been defined at the time it would need to be set as default return self._add_child( # type: ignore[misc] - sub_parser_cls or SubParser, self.sub_parsers, node, call, tracked_refs + sub_parser_cls or SubParser, self.sub_parsers, node, tracked_refs ) diff --git a/lib/cli_command_parser/conversion/visitor.py b/lib/cli_command_parser/conversion/visitor.py index bfc15fde..2078bcb5 100644 --- a/lib/cli_command_parser/conversion/visitor.py +++ b/lib/cli_command_parser/conversion/visitor.py @@ -6,15 +6,15 @@ from collections import ChainMap, defaultdict from enum import Enum from functools import partial, wraps -from typing import TYPE_CHECKING, Callable, Collection, Iterator, Literal, Union, overload +from typing import TYPE_CHECKING, Callable, Collection, Iterator, Literal, Type, Union, overload from .argparse_ast import AstArgumentParser, AstCallable, VisitFunc from .utils import get_name_repr if TYPE_CHECKING: TrackedRefMap = dict['TrackedRef', set[str]] - NameTrackedMap = dict[str, Union[Callable, 'TrackedRef']] TrackedValue = Union['TrackedRef', VisitFunc, AstCallable] + NameTrackedMap = dict[str, TrackedValue] RefName = str | AST __all__ = ['ScriptVisitor', 'TrackedRef'] @@ -48,8 +48,10 @@ def _scoped_method(self: ScriptVisitor, *args, **kwargs): class ScopedVisit: __slots__ = () - def __get__(self, instance: ScriptVisitor, owner): - return self if instance is None else partial(scoped(owner.generic_visit), instance) + def __get__(self, instance: ScriptVisitor | None, owner: Type[ScriptVisitor]): + if instance is None: + return self + return partial(scoped(owner.generic_visit), instance) class ScriptVisitor(NodeVisitor): @@ -64,6 +66,7 @@ def __init__(self, smart_loop_handling: bool = True, track_refs: Collection[Trac self.smart_loop_handling = smart_loop_handling self.scopes = ChainMap() # ChainMap that tracks the var/class/func/etc names available in a given scope self._tracked_refs: set[TrackedRef] = set() # References that are tracked, but not meant to be called + self._mod_name_tracked_map: dict[str, NameTrackedMap] = defaultdict(dict) # All tracked items by source module for ref in track_refs: self.track_refs_to(ref) @@ -206,6 +209,12 @@ def _visit_for_smart(self, node: For, loop_var: str, ele_names: list[str]): self._visit_for_elements(node, loop_var, ele_names) def _visit_for_elements(self, node: For, loop_var: str, ele_names: list[str]): + """ + Iterates over the discovered elements, calling :meth:`.generic_visit` for each iteration where one of the + elements is an item that is being tracked. + + When visiting a *for* loop with :meth:`.generic_visit`, the body of the loop is only visited once. + """ visited_any = False for name in ele_names: if ref := self.scopes.get(name): @@ -274,11 +283,9 @@ def visit_withitem(self, item: withitem): Visit a single ``withitem`` / context expression within a ``with ...:`` statement that may include one or more ``withitem``s / content expressions. """ - context_expr = item.context_expr - if func := self.resolve_ref(context_expr, True): - # Found a ``with foo(...):`` statement where *foo* is being tracked - call = context_expr if isinstance(context_expr, Call) else None - result = func(item, call, self.get_tracked_refs()) + if func := self.resolve_ref(item.context_expr, True): + # Found a `with foo(...):` statement where `foo` is being tracked or a `with bar:` where `bar = foo(...)` + result = func(item, self.get_tracked_refs()) if as_name := item.optional_vars: self.scopes[get_name_repr(as_name)] = result @@ -303,7 +310,7 @@ def visit_Assign(self, node: Assign): def visit_Call(self, node: Call) -> AstCallable | _NoCallType: if func := self.resolve_ref(node.func, True): - return func(node, node, self.get_tracked_refs()) + return func(node, self.get_tracked_refs()) return _NoCall diff --git a/tests/test_conversion/test_convert_argparse.py b/tests/test_conversion/test_convert_argparse.py index dee67542..3159d19d 100755 --- a/tests/test_conversion/test_convert_argparse.py +++ b/tests/test_conversion/test_convert_argparse.py @@ -484,8 +484,9 @@ def test_add_visit_func_attr_error(self): AstCallable.visit_funcs = original def test_ast_callable_misc(self): - ac = AstCallable(Mock(args=123, keywords=456), Mock(), {}) - self.assertEqual(123, ac.call_args) + call = ast.parse('foo(123, bar=456)').body[0].value + ac = AstCallable(call, Mock(), {}) + self.assertEqual(123, ac.call_args[0].value) self.assertIsNone(ac.get_tracked_refs('foo', 'bar', None)) with self.assertRaises(KeyError): self.assertIsNone(ac.get_tracked_refs('foo', 'bar')) @@ -582,8 +583,8 @@ def __init__(self, *args, **kwargs): self.constants = [] @visit_func - def add_subparser(self, node: InitNode, call: ast.Call, tracked_refs: TrackedRefMap): - return self._add_subparser(node, call, tracked_refs, SubParserShortcut) + def add_subparser(self, node: InitNode, tracked_refs: TrackedRefMap): + return self._add_subparser(node, tracked_refs, SubParserShortcut) def grouped_children(self): yield ParserConstant, self.constants @@ -645,6 +646,24 @@ def test_custom_parser_subclass(self): ] self.assert_strings_equal('\n\n\n'.join(cmds), convert_script(Script(code))) + def test_custom_parser_subclass_with_no_call(self): + code = """ +from argparse import SUPPRESS as hide +from foo.bar import ArgParser +parser = ArgParser(description='Parse args') +parser.add_constant('abc', 123) +sp1 = parser.add_subparser('action', 'one') +with sp1: + sp1.add_argument('--foo', help=hide) +sp2 = parser.add_subparser('action', 'two') + """ + cmds = [ + prep_expected('abc = 123', 'action = SubCommand()', description="'Parse args'"), + prep_cmd('foo = Option(hide=True)', name='One', base=CMD0), + prep_cmd('pass', name='Two', base=CMD0), + ] + self.assert_strings_equal('\n\n\n'.join(cmds), convert_script(Script(code))) + def test_converter_for_ast_callable_subclass(self): code = "from foo import ArgParser\np = ArgParser()\nsp = p.add_subparser(name='one')\nsp.add_argument('--foo')" self.assertEqual(Converter.for_ast_callable(Script(code).parsers[0]), ParserConverter) From d072c939a5a3dc8c8f23b24aaa33a7b70137cfb7 Mon Sep 17 00:00:00 2001 From: dskrypa Date: Sat, 21 Mar 2026 14:53:53 -0400 Subject: [PATCH 2/4] removed import shim in bin for argparse_to_command; cleaned up tag.py; added py.typed indicator --- MANIFEST.in | 1 + bin/argparse_to_command.py | 6 ------ bin/tag.py | 36 ++++++++++++++++----------------- lib/cli_command_parser/py.typed | 0 4 files changed, 18 insertions(+), 25 deletions(-) delete mode 100755 bin/argparse_to_command.py create mode 100644 lib/cli_command_parser/py.typed diff --git a/MANIFEST.in b/MANIFEST.in index 7dc909a6..a8824931 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,5 @@ include *.rst include *.txt +include *.typed include LICENSE include MANIFEST.in diff --git a/bin/argparse_to_command.py b/bin/argparse_to_command.py deleted file mode 100755 index 6f2f537c..00000000 --- a/bin/argparse_to_command.py +++ /dev/null @@ -1,6 +0,0 @@ -#!/usr/bin/env python - -from cli_command_parser.conversion.cli import main - -if __name__ == '__main__': - main() diff --git a/bin/tag.py b/bin/tag.py index 4796dc47..c04c8f22 100755 --- a/bin/tag.py +++ b/bin/tag.py @@ -2,7 +2,7 @@ import logging import re -from datetime import datetime +from datetime import UTC, datetime from pathlib import Path from subprocess import check_call, check_output from tempfile import TemporaryDirectory @@ -14,8 +14,8 @@ class TagUpdater(Command): - version_file_path: Path = Option( - '-p', metavar='PATH', default=DEFAULT_PATH, help='Path to the __version__.py file to update' + version_file_path = Option( + '-p', metavar='PATH', type=Path, default=DEFAULT_PATH, help='Path to the __version__.py file to update' ) verbose = Counter('-v', help='Increase logging verbosity (can specify multiple times)') dry_run = Flag('-D', help='Print the actions that would be taken instead of taking them') @@ -23,7 +23,7 @@ class TagUpdater(Command): '-S', help='Always include a suffix (default: only when multiple versions are created on the same day)' ) - def main(self): + def main(self) -> None: log_fmt = '%(asctime)s %(levelname)s %(name)s %(lineno)d %(message)s' if self.verbose > 1 else '%(message)s' logging.basicConfig(level=logging.DEBUG if self.verbose else logging.INFO, format=log_fmt) @@ -44,37 +44,35 @@ def main(self): check_call(['git', 'tag', next_version]) check_call(['git', 'push', '--tags']) - def update_version(self) -> str | None: + def update_version(self) -> str: version_pat = re.compile(r'^(\s*__version__\s?=\s?)(["\'])(\d{4}\.\d{2}\.\d{2}(?:-\d+)?)\2$') path = self.version_file_path - found = False - new_ver = None + new_ver: str | None = None with TemporaryDirectory() as tmp_dir: tmp_path = Path(tmp_dir).joinpath('tmp.txt') log.debug(f'Writing updated file to temp file={tmp_path}') with path.open('r', encoding='utf-8') as f_in, tmp_path.open('w', encoding='utf-8', newline='\n') as f_out: for line in f_in: - if found: + if new_ver: f_out.write(line) elif m := version_pat.match(line): - found = True new_ver, new_line = self._updated_version_line(m.groups()) f_out.write(new_line) else: f_out.write(line) - if found: + if new_ver: if self.dry_run: log.info(f'[DRY RUN] Would replace original file={path.as_posix()} with modified version') else: log.info(f'Replacing original file={path.as_posix()} with modified version') tmp_path.replace(path) - else: - raise RuntimeError(f'No valid version was found in {path.as_posix()}') - return new_ver + return new_ver - def _updated_version_line(self, groups): + raise RuntimeError(f'No valid version was found in {path.as_posix()}') + + def _updated_version_line(self, groups) -> tuple[str, str]: var, quote, old_ver = groups new_ver = get_next_version(old_ver, self.force_suffix) prefix = '[DRY RUN] Would replace' if self.dry_run else 'Replacing' @@ -83,25 +81,25 @@ def _updated_version_line(self, groups): return new_ver, new_line -def get_latest_tag(): +def get_latest_tag() -> str: stdout: str = check_output(['git', 'tag', '--list'], text=True) versions = [] for line in stdout.splitlines(): try: - date, suffix = line.split('-') + date, str_suffix = line.split('-') except ValueError: date = line suffix = 0 else: - suffix = int(suffix) + suffix = int(str_suffix) versions.append((date, suffix)) date, suffix = max(versions) return f'{date}-{suffix}' -def get_next_version(old_ver: str, force_suffix: bool = False): +def get_next_version(old_ver: str, force_suffix: bool = False) -> str: try: old_date_str, old_suffix = old_ver.split('-') except ValueError: @@ -111,7 +109,7 @@ def get_next_version(old_ver: str, force_suffix: bool = False): new_suffix = int(old_suffix) + 1 old_date = datetime.strptime(old_date_str, '%Y.%m.%d').date() - today = datetime.utcnow().date() + today = datetime.now(UTC).date() today_str = today.strftime('%Y.%m.%d') if old_date < today: if force_suffix: diff --git a/lib/cli_command_parser/py.typed b/lib/cli_command_parser/py.typed new file mode 100644 index 00000000..e69de29b From 5478206dcc5a6660139f3d9ce390d679c42f4bfe Mon Sep 17 00:00:00 2001 From: dskrypa Date: Sat, 21 Mar 2026 15:00:01 -0400 Subject: [PATCH 3/4] added stricted nargs str typing; added tuple expansion support for argparse conversion; misc minor typing cleanup --- lib/cli_command_parser/command_parameters.py | 2 +- .../conversion/command_builder.py | 4 ++-- lib/cli_command_parser/conversion/visitor.py | 20 ++++++++++++------- lib/cli_command_parser/core.py | 10 +++++++--- lib/cli_command_parser/nargs.py | 7 +++++-- lib/cli_command_parser/parse_tree.py | 6 ++---- .../test_conversion/test_convert_argparse.py | 11 ++++++---- 7 files changed, 37 insertions(+), 23 deletions(-) diff --git a/lib/cli_command_parser/command_parameters.py b/lib/cli_command_parser/command_parameters.py index a47aa998..2dbd2a1f 100644 --- a/lib/cli_command_parser/command_parameters.py +++ b/lib/cli_command_parser/command_parameters.py @@ -262,7 +262,7 @@ def _process_option_strs(self, param: BaseOption, opt_type: str, opt_strs: Strin ) def _process_action_flags(self) -> None: - action_flags: list[ActionFlag] = sorted(p for p in self.options if isinstance(p, ActionFlag)) + action_flags: ActionFlags = sorted(p for p in self.options if isinstance(p, ActionFlag)) # type: ignore[misc] grouped_ordered_flags: dict[bool, dict[int | float, ActionFlags]] = { True: defaultdict(list), False: defaultdict(list), diff --git a/lib/cli_command_parser/conversion/command_builder.py b/lib/cli_command_parser/conversion/command_builder.py index 34100da3..391cf557 100644 --- a/lib/cli_command_parser/conversion/command_builder.py +++ b/lib/cli_command_parser/conversion/command_builder.py @@ -7,7 +7,7 @@ from dataclasses import dataclass, fields from functools import cached_property from itertools import count -from typing import TYPE_CHECKING, Generic, Iterable, Iterator, MutableMapping, Type, TypeVar +from typing import TYPE_CHECKING, Generic, Iterable, Iterator, Type, TypeVar from cli_command_parser.nargs import Nargs @@ -597,7 +597,7 @@ class ParamArgs(ParamBaseArgs): def init_positional(cls, action: OptStr = None, nargs: OptStr = None, **kwargs): if nargs is not None: if (parsed := literal_eval_or_none(nargs)) is not None: - nargs_obj = Nargs(parsed) + nargs_obj = Nargs(parsed) # type: ignore[arg-type] if action in ('store', None) and nargs_obj == 1: action = nargs = None else: diff --git a/lib/cli_command_parser/conversion/visitor.py b/lib/cli_command_parser/conversion/visitor.py index 2078bcb5..6329c1ef 100644 --- a/lib/cli_command_parser/conversion/visitor.py +++ b/lib/cli_command_parser/conversion/visitor.py @@ -2,11 +2,11 @@ import logging import re -from ast import AST, Assign, Attribute, Call, For, Import, ImportFrom, Name, NodeVisitor, withitem +from ast import AST, Assign, Attribute, Call, For, Import, ImportFrom, List, Name, NodeVisitor, Tuple, withitem from collections import ChainMap, defaultdict from enum import Enum from functools import partial, wraps -from typing import TYPE_CHECKING, Callable, Collection, Iterator, Literal, Type, Union, overload +from typing import TYPE_CHECKING, Collection, Iterator, Literal, Type, Union, overload from .argparse_ast import AstArgumentParser, AstCallable, VisitFunc from .utils import get_name_repr @@ -66,12 +66,11 @@ def __init__(self, smart_loop_handling: bool = True, track_refs: Collection[Trac self.smart_loop_handling = smart_loop_handling self.scopes = ChainMap() # ChainMap that tracks the var/class/func/etc names available in a given scope self._tracked_refs: set[TrackedRef] = set() # References that are tracked, but not meant to be called - self._mod_name_tracked_map: dict[str, NameTrackedMap] = defaultdict(dict) # All tracked items by source module for ref in track_refs: self.track_refs_to(ref) - def track_callable(self, module: str, name: str, cb: Callable): + def track_callable(self, module: str, name: str, cb: VisitFunc | AstCallable): self._mod_name_tracked_map[module][name] = cb def track_refs_to(self, ref: TrackedRef): @@ -235,7 +234,7 @@ def resolve_ref(self, name: RefName, only_visitable: Literal[False] = False) -> @overload def resolve_ref(self, name: RefName, only_visitable: Literal[True]) -> VisitFunc | None: ... - def resolve_ref(self, name: RefName, only_visitable: bool = False) -> VisitFunc | AstCallable | None: + def resolve_ref(self, name: RefName, only_visitable: bool = False) -> TrackedValue | None: """ Resolve the given reference to a tracked item in the current scope. @@ -250,7 +249,7 @@ def resolve_ref(self, name: RefName, only_visitable: bool = False) -> VisitFunc return getattr(obj, attr) if attr in obj.visit_funcs else None return None if only_visitable else obj case None | TrackedRef(): - return None + return None if only_visitable else obj case _: return obj if attr is None else None @@ -294,7 +293,8 @@ def visit_Assign(self, node: Assign): Visit an assignment statement where one or more variables (stored in ``Assign.targets``) are being assigned one or more values (stored in ``Assign.value``). """ - match node.value: + # Note: node.targets only contains multiple elements for chained assignments like `a = b = c` + match node.value: # The value on the right side of `=` case Attribute() | Name(): # Assigning an alias to a variable; e.g., `foo = bar` or `foo = bar.baz` if ref := self.resolve_ref(node.value): @@ -307,6 +307,12 @@ def visit_Assign(self, node: Assign): if (result := self.visit_Call(node.value)) is not _NoCall: for target in node.targets: self.scopes[get_name_repr(target)] = result + case List() | Tuple(): + for target in node.targets: + if isinstance(target, (List, Tuple)) and len(target.elts) == len(node.value.elts): + for target_var, value in zip(target.elts, node.value.elts): + if ref := self.resolve_ref(value): + self.scopes[get_name_repr(target_var)] = ref def visit_Call(self, node: Call) -> AstCallable | _NoCallType: if func := self.resolve_ref(node.func, True): diff --git a/lib/cli_command_parser/core.py b/lib/cli_command_parser/core.py index 38420741..9c636eee 100644 --- a/lib/cli_command_parser/core.py +++ b/lib/cli_command_parser/core.py @@ -8,7 +8,7 @@ from __future__ import annotations from abc import ABC, ABCMeta -from typing import TYPE_CHECKING, Any, Callable, Collection, Iterable, Iterator, Mapping, TypeVar, Union, overload +from typing import TYPE_CHECKING, overload from warnings import warn from weakref import WeakSet @@ -19,8 +19,12 @@ from .utils import _NotSet, _NotSetType if TYPE_CHECKING: - from .typing import AnyConfig, CommandAny, CommandCls, Config, OptStr + from collections.abc import Collection + from typing import Any, Callable, Iterable, Iterator, Mapping, TypeVar + from .typing import CommandAny, CommandCls, OptStr + + AnyConfig = CommandConfig | dict[str, Any] | None Bases = tuple[type, ...] Choice = str | None | _NotSetType Choices = Mapping[str, str | None] | Collection[str] @@ -148,7 +152,7 @@ def _from_parent(mcs, meth: Callable[[CommandMeta], T], bases: Bases | Iterable[ # region Config Methods @classmethod - def _prepare_config(mcs, bases: Bases, config: AnyConfig, kwargs: dict[str, Any]) -> Config: + def _prepare_config(mcs, bases: Bases, config: AnyConfig, kwargs: dict[str, Any]) -> CommandConfig | None: if config is not None: if kwargs: raise CommandDefinitionError(f'Cannot combine {config=} with keyword config arguments={kwargs}') diff --git a/lib/cli_command_parser/nargs.py b/lib/cli_command_parser/nargs.py index d6d2dd4a..c7b0679c 100644 --- a/lib/cli_command_parser/nargs.py +++ b/lib/cli_command_parser/nargs.py @@ -8,7 +8,7 @@ from collections.abc import Sequence from enum import Enum -from typing import Any, Collection, FrozenSet, TypeAlias +from typing import Any, Collection, FrozenSet, Literal, TypeAlias __all__ = ['Nargs', 'NargsValue', 'REMAINDER'] @@ -29,7 +29,10 @@ def __str__(self) -> str: SET_ERROR_FMT = 'Invalid nargs={!r} set - expected non-empty set where all values are integers >= 0' SEQ_ERROR_FMT = 'Invalid nargs={!r} sequence - expected 2 ints where 0 <= a <= b or b is None' -NargsValue: TypeAlias = str | int | tuple[int, _Max] | Sequence[int] | set[int] | FrozenSet[int] | range | _Remainder +NargsStr = Literal['?', '*', '+', 'REMAINDER'] +NargsValue: TypeAlias = ( + NargsStr | int | tuple[int, _Max] | Sequence[int] | set[int] | FrozenSet[int] | range | _Remainder +) class Nargs: diff --git a/lib/cli_command_parser/parse_tree.py b/lib/cli_command_parser/parse_tree.py index d33e9335..760ed78f 100644 --- a/lib/cli_command_parser/parse_tree.py +++ b/lib/cli_command_parser/parse_tree.py @@ -65,17 +65,15 @@ class PosNode(MutableMapping[Word, 'PosNode']): __slots__ = ('links', 'param', 'parent', 'target', 'word', '_any_word', '_any_node') links: dict[Word, PosNode] - param: BasePositional | None parent: PosNode | None - target: Target word: Word _any_word: AnyWord def __init__(self, word: Word, param: BasePositional | None, target: Target = None, parent: PosNode | None = None): self.links = {} - self.param = param + self.param: BasePositional | None = param self.parent = parent - self.target = target + self.target: Target = target self.word = word if parent is not None: parent[word] = self diff --git a/tests/test_conversion/test_convert_argparse.py b/tests/test_conversion/test_convert_argparse.py index 3159d19d..0368ae08 100755 --- a/tests/test_conversion/test_convert_argparse.py +++ b/tests/test_conversion/test_convert_argparse.py @@ -486,7 +486,7 @@ def test_add_visit_func_attr_error(self): def test_ast_callable_misc(self): call = ast.parse('foo(123, bar=456)').body[0].value ac = AstCallable(call, Mock(), {}) - self.assertEqual(123, ac.call_args[0].value) + self.assertEqual(123, ac.call.args[0].value) self.assertIsNone(ac.get_tracked_refs('foo', 'bar', None)) with self.assertRaises(KeyError): self.assertIsNone(ac.get_tracked_refs('foo', 'bar')) @@ -633,15 +633,17 @@ def test_custom_parser_subclass(self): code = """ from argparse import SUPPRESS as hide from foo.bar import ArgParser +a = hide parser = ArgParser(description='Parse args') parser.add_constant('abc', 123) with parser.add_subparser('action', 'one') as sp1: - sp1.add_argument('--foo', help=hide) + sp1.add_argument('--foo', help=a) + sp1.add_argument('--bar', help=hide) sp2 = parser.add_subparser('action', 'two') """ cmds = [ prep_expected('abc = 123', 'action = SubCommand()', description="'Parse args'"), - prep_cmd('foo = Option(hide=True)', name='One', base=CMD0), + prep_cmd('foo = Option(hide=True)', 'bar = Option(hide=True)', name='One', base=CMD0), prep_cmd('pass', name='Two', base=CMD0), ] self.assert_strings_equal('\n\n\n'.join(cmds), convert_script(Script(code))) @@ -650,11 +652,12 @@ def test_custom_parser_subclass_with_no_call(self): code = """ from argparse import SUPPRESS as hide from foo.bar import ArgParser +a, b = hide, 123 parser = ArgParser(description='Parse args') parser.add_constant('abc', 123) sp1 = parser.add_subparser('action', 'one') with sp1: - sp1.add_argument('--foo', help=hide) + sp1.add_argument('--foo', help=a) sp2 = parser.add_subparser('action', 'two') """ cmds = [ From d885bb5a2a96100268353f66f7305c803ad233d2 Mon Sep 17 00:00:00 2001 From: dskrypa Date: Sat, 21 Mar 2026 15:10:18 -0400 Subject: [PATCH 4/4] #151 refactored Flag types to have a new BaseFlag base; overhauled internal type helper modules; fixed param value type detection for mypy for 3.10~3.14 --- lib/cli_command_parser/commands.py | 8 +- lib/cli_command_parser/config.py | 7 +- lib/cli_command_parser/context.py | 13 +- .../conversion/argparse_ast.py | 42 ++- lib/cli_command_parser/conversion/cli.py | 10 +- lib/cli_command_parser/formatting/params.py | 16 +- lib/cli_command_parser/inputs/__init__.py | 46 ++- lib/cli_command_parser/inputs/_typing.py | 20 ++ lib/cli_command_parser/inputs/choices.py | 22 +- lib/cli_command_parser/inputs/files.py | 8 +- lib/cli_command_parser/inputs/numeric.py | 21 +- lib/cli_command_parser/inputs/patterns.py | 87 +++++- lib/cli_command_parser/inputs/time.py | 10 +- lib/cli_command_parser/inputs/utils.py | 8 +- lib/cli_command_parser/metadata.py | 7 +- lib/cli_command_parser/parameters/_typing.py | 15 + lib/cli_command_parser/parameters/actions.py | 17 +- lib/cli_command_parser/parameters/base.py | 292 ++++++++++++------ .../parameters/choice_map.py | 22 +- lib/cli_command_parser/parameters/groups.py | 8 +- lib/cli_command_parser/parameters/options.py | 212 ++++++++++--- .../parameters/pass_thru.py | 6 +- .../parameters/positionals.py | 24 +- lib/cli_command_parser/typing.py | 78 +++-- lib/cli_command_parser/utils.py | 10 +- tests/test_parameters/test_counters.py | 4 + tests/test_parameters/test_positionals.py | 6 - 27 files changed, 689 insertions(+), 330 deletions(-) create mode 100644 lib/cli_command_parser/inputs/_typing.py create mode 100644 lib/cli_command_parser/parameters/_typing.py diff --git a/lib/cli_command_parser/commands.py b/lib/cli_command_parser/commands.py index c7bb5acd..87237f52 100644 --- a/lib/cli_command_parser/commands.py +++ b/lib/cli_command_parser/commands.py @@ -18,7 +18,7 @@ from .utils import maybe_await if TYPE_CHECKING: - from .typing import Bool, CommandObj + from .typing import Bool, Self __all__ = ['Command', 'AsyncCommand', 'main', 'print_help'] log = logging.getLogger(__name__) @@ -50,7 +50,7 @@ def __repr__(self) -> str: # region Parse & Run @classmethod - def parse_and_run(cls: Type[CommandObj], argv: Argv | None = None, **kwargs) -> CommandObj | None: + def parse_and_run(cls, argv: Argv | None = None, **kwargs) -> Self | None: """ Primary entry point for parsing arguments, resolving subcommands, and running a command. @@ -81,7 +81,7 @@ def parse_and_run(cls: Type[CommandObj], argv: Argv | None = None, **kwargs) -> # region Parse @classmethod - def parse(cls: Type[CommandObj], argv: Argv | None = None) -> CommandObj: + def parse(cls, argv: Argv | None = None) -> Self: """ Parses the specified arguments (or :data:`sys.argv`), and resolves the final subcommand class based on the parsed arguments, if necessary. Initializes the Command, but does not call any of its other methods. @@ -338,7 +338,7 @@ async def _after_main_(self, *args, **kwargs): await self._run_actions_(ActionPhase.AFTER_MAIN, args, kwargs) -def main(argv: Argv | None = None, return_command: Bool = False, **kwargs) -> CommandObj | None: +def main(argv: Argv | None = None, return_command: Bool = False, **kwargs) -> Command | None: """ Convenience function that can be used as the main entry point for a program. diff --git a/lib/cli_command_parser/config.py b/lib/cli_command_parser/config.py index fa65e958..2b43485d 100644 --- a/lib/cli_command_parser/config.py +++ b/lib/cli_command_parser/config.py @@ -11,11 +11,6 @@ from string import whitespace from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Sequence, Type, TypeAlias, TypeVar, overload -try: - from typing import Self -except ImportError: # added in 3.11 - Self = TypeVar('Self') # type: ignore[misc,assignment] - from .exceptions import CommandDefinitionError from .utils import FixedFlag, MissingMixin, _NotSet, _NotSetType, positive_int @@ -25,7 +20,7 @@ from .error_handling import ErrorHandler from .formatting.commands import CommandHelpFormatter from .formatting.params import ParamHelpFormatter - from .typing import Bool, ParamOrGroup + from .typing import Bool, ParamOrGroup, Self _CmdHelpFormatter: TypeAlias = Callable[[CommandMeta, CommandParameters], CommandHelpFormatter] _ParamHelpFormatter: TypeAlias = Callable[[ParamOrGroup], ParamHelpFormatter] diff --git a/lib/cli_command_parser/context.py b/lib/cli_command_parser/context.py index 3d27cd1d..a1025de5 100644 --- a/lib/cli_command_parser/context.py +++ b/lib/cli_command_parser/context.py @@ -26,9 +26,10 @@ from .commands import Command from .core import CommandMeta from .parameters import ActionFlag, BaseOption, Parameter - from .typing import AnyConfig, Bool, OptStr, ParamOrGroup, PathLike, StrSeq + from .typing import Bool, OptStr, ParamOrGroup, PathLike, StrSeq Argv = StrSeq | None + AnyConfig = CommandConfig | dict[str, Any] | None CommandCls: TypeAlias = Type[Command] | CommandMeta __all__ = ['Context', 'ctx', 'get_current_context', 'get_or_create_context', 'get_context', 'get_parsed', 'get_raw_arg'] @@ -59,7 +60,7 @@ def __init__( command_cls: CommandCls | None = None, *, parent: Context | None = None, - config: AnyConfig | None = None, + config: AnyConfig = None, terminal_width: int | None = None, allow_argv_prog: Bool = None, command: Command | None = None, @@ -130,7 +131,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): def __contains__(self, param: ParamOrGroup | str | Any) -> bool: try: - self._parsed[param] + self._parsed[param] # type: ignore[index] except KeyError: if isinstance(param, str): try: @@ -156,7 +157,7 @@ def get_parsed( self, command: Command | None = None, *, - exclude: Collection[Parameter] = (), + exclude: Collection[Parameter[Any, Any]] = (), recursive: Bool = True, default: Any = None, include_defaults: Bool = True, @@ -441,6 +442,10 @@ def get_current_context(silent: Literal[False] = False) -> Context: ... def get_current_context(silent: Literal[True]) -> Context | None: ... +@overload +def get_current_context(silent: bool) -> Context | None: ... + + def get_current_context(silent: bool = False) -> Context | None: """ Get the currently active parsing context. diff --git a/lib/cli_command_parser/conversion/argparse_ast.py b/lib/cli_command_parser/conversion/argparse_ast.py index a552ef30..3697c275 100644 --- a/lib/cli_command_parser/conversion/argparse_ast.py +++ b/lib/cli_command_parser/conversion/argparse_ast.py @@ -11,11 +11,6 @@ from pathlib import Path from typing import TYPE_CHECKING, Callable, ClassVar, Generic, Literal, Type, TypeAlias, TypeVar, overload -try: - from typing import Self -except ImportError: # added in 3.11 - Self = TypeVar('Self') # type: ignore[misc,assignment] - from ..utils import _NotSet, _NotSetType from .argparse_utils import ArgumentParser as _ArgumentParser, SubParsersAction as _SubParsersAction from .utils import get_name_repr, iter_module_parents, unparse @@ -24,8 +19,7 @@ from collections.abc import Collection from typing import Any, Iterator - from cli_command_parser.typing import OptStr, PathLike - + from ..typing import OptStr, PathLike, Self from .visitor import TrackedRef, TrackedRefMap __all__ = ['ParserArg', 'ArgGroup', 'MutuallyExclusiveGroup', 'AstArgumentParser', 'SubParser', 'Script'] @@ -189,14 +183,13 @@ def __init_subclass__(cls, represents: RepresentedCallable | None = None, **kwar raise NotImplementedError(f'Missing required "represents" class param for {cls.__name__}') def __init__(self, node: InitNode, parent: AstCallable | Script, tracked_refs: TrackedRefMap): + """ + :param node: The AST node that this object represents. + :param parent: The parent script or AstCallable in which the node that this object represents exists. + :param tracked_refs: Mapping of :class:`~TrackedRef` objects to the set of variable names that are references + to the tracked item. + """ self.init_node = node - if call := _get_call(node): - self.call_node = call - self.call_args = call.args - self.call_kwargs = call.keywords - else: - raise ValueError(f'Unexpected {node=}') - self._tracked_refs = tracked_refs self.parent = parent @@ -204,6 +197,7 @@ def __repr__(self) -> str: return f'<{self.__class__.__name__}[{self.init_call_repr()}]>' def get_tracked_refs(self, module: str, name: str, default: D | _NotSetType = _NotSet) -> set[str] | D: + """Get the set of variable names that are references to the specified tracked item.""" for tracked_ref, refs in self._tracked_refs.items(): if tracked_ref.module == module and tracked_ref.name == name: return refs @@ -223,22 +217,28 @@ def _signature(cls) -> Signature: def signature(self) -> Signature: return self._signature() + @cached_property + def call(self) -> Call: + if call := _get_call(self.init_node): + return call + raise AttributeError(f'Unable to determine call for node={self.init_node}') + @cached_property def init_func_name(self) -> str: """The name or alias of the function/callable that was used to initialize this object""" - return get_name_repr(self.call_node.func) + return get_name_repr(self.call.func) @cached_property def _init_func_bound(self) -> BoundArguments: - args = self.call_args if isinstance(self.represents, type) else ('self', *self.call_args) - return self.signature.bind(*args, **{kw.arg: kw.value for kw in self.call_kwargs if kw.arg is not None}) + args = self.call.args if isinstance(self.represents, type) else ('self', *self.call.args) + return self.signature.bind(*args, **{kw.arg: kw.value for kw in self.call.keywords if kw.arg is not None}) @cached_property def init_func_args(self) -> list[str]: try: args = self._init_func_bound.args[1:] except (TypeError, AttributeError): # No represents func - args = self.call_args # type: ignore[assignment] + args = self.call.args # type: ignore[assignment] return [unparse(arg) for arg in args] @cached_property @@ -246,7 +246,7 @@ def init_func_raw_kwargs(self) -> dict[str, AST]: try: kwargs = self._init_func_bound.arguments except (TypeError, AttributeError): # No represents func - return {kw.arg: kw.value for kw in self.call_kwargs if kw.arg is not None} + return {kw.arg: kw.value for kw in self.call.keywords if kw.arg is not None} else: kwargs = kwargs.copy() kwargs.pop('self', None) @@ -401,9 +401,7 @@ def _add_subparser( ) -> SubParser | ParserObj: """Add a subparser to this parser. Only meant to be called by :class:`SubparsersAction`.""" # Using default of None since the class hasn't been defined at the time it would need to be set as default - return self._add_child( # type: ignore[misc] - sub_parser_cls or SubParser, self.sub_parsers, node, tracked_refs - ) + return self._add_child(sub_parser_cls or SubParser, self.sub_parsers, node, tracked_refs) # type: ignore[misc] class SubParser(AstArgumentParser, represents=_SubParsersAction.add_parser): diff --git a/lib/cli_command_parser/conversion/cli.py b/lib/cli_command_parser/conversion/cli.py index 81a8943a..b1d05560 100644 --- a/lib/cli_command_parser/conversion/cli.py +++ b/lib/cli_command_parser/conversion/cli.py @@ -17,12 +17,10 @@ class ParserConverter(Command): action = SubCommand() input: Param[Path] - no_smart_for: Param[bool] = Flag( - '-S', help='Disable "smart" for loop handling that attempts to dedupe common subparser params' - ) + no_smart_for = Flag('-S', help='Disable "smart" for loop handling that attempts to dedupe common subparser params') with ParamGroup('Common'): verbose = Counter('-v', help='Increase logging verbosity (can specify multiple times)') - dry_run: Flag[bool] = Flag('-D', help='Print the actions that would be taken instead of taking them') + dry_run = Flag('-D', help='Print the actions that would be taken instead of taking them') def _init_command_(self): log_fmt = '%(asctime)s %(levelname)s %(name)s %(lineno)d %(message)s' if self.verbose > 1 else '%(message)s' @@ -40,7 +38,7 @@ def script(self): class Convert(ParserConverter): """Print the cli-command-parser Commands that are equivalent to the discovered argparse ArgumentParsers""" - input: Param[Path] = Positional(type=IPath(type='file', exists=True), help=f'A file containing an {arg_parser}') + input = Positional(type=IPath(type='file', exists=True), help=f'A file containing an {arg_parser}') add_methods = Flag('--no-methods', '-M', default=True, help='Do not include boilerplate methods in Commands') def main(self): @@ -52,7 +50,7 @@ def main(self): class Pprint(ParserConverter): """Print a tiered internal representation of the discovered argparse ArgumentParsers and their groups/arguments""" - input: Param[Path] = Positional(type=IPath(type='file', exists=True), help=f'A file containing an {arg_parser}') + input = Positional(type=IPath(type='file', exists=True), help=f'A file containing an {arg_parser}') def main(self): for parser in self.script.parsers: diff --git a/lib/cli_command_parser/formatting/params.py b/lib/cli_command_parser/formatting/params.py index 297d1b0a..f74cf08c 100644 --- a/lib/cli_command_parser/formatting/params.py +++ b/lib/cli_command_parser/formatting/params.py @@ -8,7 +8,7 @@ from abc import ABC, abstractmethod from functools import cached_property -from typing import TYPE_CHECKING, Callable, ClassVar, Generic, Iterable, Iterator, Type, TypeVar +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Iterable, Iterator, Type, TypeVar from ..config import CmdAliasMode, SubcommandAliasHelpMode from ..context import ctx @@ -28,9 +28,9 @@ BaseP = TypeVar('BaseP', bound=ParamBase) -ParamP = TypeVar('ParamP', bound=Parameter) -PosP = TypeVar('PosP', bound=BasePositional) -OptP = TypeVar('OptP', bound=BaseOption) +ParamP = TypeVar('ParamP', bound=Parameter[Any, Any]) +PosP = TypeVar('PosP', bound=BasePositional[Any, Any]) +OptP = TypeVar('OptP', bound=BaseOption[Any, Any]) class ParamHelpFormatter(ABC, Generic[BaseP]): @@ -123,7 +123,9 @@ def format_metavar(self) -> str: config = ctx.config if (t := param.type) is not None: try: - metavar = t.format_metavar(config.choice_delim, config.sort_choices) # type: ignore[attr-defined] + metavar = t.format_metavar( # type: ignore[union-attr,attr-defined] + config.choice_delim, config.sort_choices + ) except Exception: # noqa # pylint: disable=W0703 pass else: @@ -132,7 +134,7 @@ def format_metavar(self) -> str: if config.use_type_metavar and t is not None: try: - name = t.__name__ + name = t.__name__ # type: ignore[union-attr,attr-defined] except AttributeError: pass else: @@ -527,7 +529,7 @@ def rst_table(self) -> RstTable: formatter = member.formatter try: - sub_table: RstTable = formatter.rst_table() # noqa + sub_table: RstTable = formatter.rst_table() # type: ignore[attr-defined] except AttributeError: table.add_rows(formatter.rst_rows()) else: diff --git a/lib/cli_command_parser/inputs/__init__.py b/lib/cli_command_parser/inputs/__init__.py index 76a134f4..184f9152 100644 --- a/lib/cli_command_parser/inputs/__init__.py +++ b/lib/cli_command_parser/inputs/__init__.py @@ -21,7 +21,9 @@ from .utils import FileWrapper, StatMode if _t.TYPE_CHECKING: - from ..typing import ChoicesType, InputTypeFunc, TypeFunc + from ..typing import ChoicesType, InputTypeFunc, NormalizedType, T, TypeFunc + + TypeT: _t.TypeAlias = _t.Union[_t.Type[T], TypeFunc[T], InputType[T]] # fmt: off __all__ = [ @@ -38,8 +40,32 @@ _INVALID_TYPES_WITH_CHOICES = (Range, range, Regex, _Pattern, Glob) -def normalize_input_type(type_func: InputTypeFunc, param_choices: ChoicesType) -> TypeFunc | None: - if choices_provided := param_choices is not None: +if _t.TYPE_CHECKING: + + @_t.overload + def normalize_input_type(type_func: None, param_choices: None) -> None: ... + + @_t.overload + def normalize_input_type(type_func: InputType[T], param_choices: None) -> InputType[T]: ... + + @_t.overload + def normalize_input_type(type_func: TypeFunc[T], param_choices: None) -> TypeFunc[T]: ... + + @_t.overload + def normalize_input_type(type_func: _t.Type[T], param_choices: None) -> _t.Type[T]: ... + + @_t.overload + def normalize_input_type(type_func: TypeT[T] | None, param_choices: _t.Collection[T]) -> Choices[T]: ... + + @_t.overload + def normalize_input_type(type_func: range, param_choices: _t.Any) -> Range[int]: ... + + @_t.overload + def normalize_input_type(type_func: _Pattern, param_choices: _t.Any) -> Regex[str]: ... + + +def normalize_input_type(type_func: InputTypeFunc[T], param_choices: ChoicesType[T]) -> NormalizedType[T]: + if param_choices is not None: if not param_choices: raise _ParamDefinitionError(f'Invalid choices={param_choices!r} - when specified, choices cannot be empty') elif isinstance(param_choices, range): @@ -51,15 +77,19 @@ def normalize_input_type(type_func: InputTypeFunc, param_choices: ChoicesType) - match type_func: case None: - return Choices(param_choices) if choices_provided else type_func # type: ignore[arg-type] + if param_choices is None: + return type_func + return Choices(param_choices) case range(): return Range(type_func) case _Pattern(): return Regex(type_func) case _EnumMeta(): enum_choices: EnumChoices = EnumChoices(type_func) - if choices_provided: - return Choices(param_choices, enum_choices) # type: ignore[arg-type] - return enum_choices + if param_choices is None: + return enum_choices + return Choices(param_choices, enum_choices) - return Choices(param_choices, type_func) if choices_provided else type_func # type: ignore[arg-type] + if param_choices is None: + return type_func + return Choices(param_choices, type_func) diff --git a/lib/cli_command_parser/inputs/_typing.py b/lib/cli_command_parser/inputs/_typing.py new file mode 100644 index 00000000..fe73095c --- /dev/null +++ b/lib/cli_command_parser/inputs/_typing.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from typing import IO, TYPE_CHECKING, Any, Callable, Sequence, TypeAlias, TypeVar, Union + +if TYPE_CHECKING: + from datetime import date, datetime, time, timedelta + from numbers import Number as _Number + + +Deserializer: TypeAlias = Callable[[str | bytes | IO], Any] +Serializer: TypeAlias = Callable[[Any, IO], None] | Callable[[Any], str | bytes] +Converter: TypeAlias = Deserializer | Serializer + +Locale = str | tuple[str | None, str | None] +TimeBound: TypeAlias = Union['datetime', 'date', 'time', 'timedelta', None] + +N = TypeVar('N', bound=Union['_Number', int, float]) +Number: TypeAlias = N | None +NumType: TypeAlias = Callable[[str | Any], N] +RngType: TypeAlias = range | int | Sequence[int] diff --git a/lib/cli_command_parser/inputs/choices.py b/lib/cli_command_parser/inputs/choices.py index 231fc2a1..3ab8bdd7 100644 --- a/lib/cli_command_parser/inputs/choices.py +++ b/lib/cli_command_parser/inputs/choices.py @@ -6,25 +6,21 @@ from __future__ import annotations -import sys from abc import ABC, abstractmethod -from enum import Enum from typing import TYPE_CHECKING, Any, Callable, Collection, Iterator, Mapping, Type, TypeVar +from ..typing import T from .base import InputType from .exceptions import InvalidChoiceError if TYPE_CHECKING: + from enum import Enum + from ..typing import Bool __all__ = ['Choices', 'ChoiceMap', 'EnumChoices'] -if sys.version_info >= (3, 13): - T = TypeVar('T', default=str) -else: - T = TypeVar('T') - -EnumT = TypeVar('EnumT', bound=Enum) +E = TypeVar('E', bound='Enum') class _ChoicesBase(InputType[T], ABC): @@ -175,7 +171,7 @@ def fix_default(self, value: Any) -> T: return self(value) -class EnumChoices(_ChoicesBase[EnumT]): +class EnumChoices(_ChoicesBase[E]): """ Similar to :class:`ChoiceMap`, but uses an Enum to validate / normalize input instead of the keys in a dict. @@ -184,10 +180,10 @@ class EnumChoices(_ChoicesBase[EnumT]): """ __slots__ = () - type: Type[EnumT] - choices: Mapping[str, EnumT] + type: Type[E] + choices: Mapping[str, E] - def __init__(self, enum: Type[EnumT], case_sensitive: Bool = False): + def __init__(self, enum: Type[E], case_sensitive: Bool = False): super().__init__() # fix_default is not implemented here, so it's not necessary to expose self.type = enum self.case_sensitive = case_sensitive @@ -199,7 +195,7 @@ def _type_str(self) -> str: def _choices_repr(self, delim: str = ',') -> str: return delim.join(self.type._member_map_) - def __call__(self, value: str) -> EnumT: + def __call__(self, value: str) -> E: enum = self.type for val in self._iter_normalized(value): try: diff --git a/lib/cli_command_parser/inputs/files.py b/lib/cli_command_parser/inputs/files.py index 694fac23..bbd39436 100644 --- a/lib/cli_command_parser/inputs/files.py +++ b/lib/cli_command_parser/inputs/files.py @@ -9,13 +9,17 @@ import os from abc import ABC from pathlib import Path as _Path -from typing import IO +from typing import IO, TYPE_CHECKING -from ..typing import Bool, Converter, OptStr, PathLike, T +from ..typing import T from .base import InputType from .exceptions import InputValidationError from .utils import FileWrapper, InputParam, StatMode, allows_write, fix_windows_path +if TYPE_CHECKING: + from ..typing import Bool, OptStr, PathLike + from ._typing import Converter + __all__ = ['Path', 'File', 'Serialized', 'Json', 'Pickle'] diff --git a/lib/cli_command_parser/inputs/numeric.py b/lib/cli_command_parser/inputs/numeric.py index 1608f6d0..f7a90e67 100644 --- a/lib/cli_command_parser/inputs/numeric.py +++ b/lib/cli_command_parser/inputs/numeric.py @@ -8,23 +8,26 @@ import re from abc import ABC, abstractmethod -from typing import Literal +from typing import TYPE_CHECKING, Literal -from ..typing import NT, Bool, Number, NumType, RngType +from ._typing import N, Number, NumType, RngType from .base import _FixedInputType from .exceptions import InputValidationError from .utils import RangeMixin, range_str +if TYPE_CHECKING: + from ..typing import Bool + __all__ = ['NumericInput', 'Range', 'NumRange', 'Bytes'] _range = range -class NumericInput(_FixedInputType[NT], ABC): +class NumericInput(_FixedInputType[N], ABC): __slots__ = () -class _RangeInput(NumericInput[NT], ABC): +class _RangeInput(NumericInput[N], ABC): type: NumType def is_valid_type(self, value: str) -> bool: @@ -50,7 +53,7 @@ def format_metavar(self, choice_delim: str = ',', sort_choices: bool = False) -> return f'{{{self._range_str()}}}' -class Range(_RangeInput[NT]): +class Range(_RangeInput[N]): """ A range of integers that uses the builtin :class:`python:range`. If a range object is passed to a :class:`.Parameter` as the ``type=`` value, it will automatically be wrapped by this class. @@ -97,7 +100,7 @@ def _range_str(self, var: str = 'N') -> str: base = f'{rng_min} <= {var} <= {rng_max}' return base if step == 1 else f'{base}, {step=}' - def __call__(self, value: str) -> NT: + def __call__(self, value: str) -> N: num_val = self.type(value) if num_val in self.range: return num_val @@ -108,7 +111,7 @@ def __call__(self, value: str) -> NT: raise InputValidationError(f'expected a value in the range {self._range_str()}') -class NumRange(RangeMixin, _RangeInput[NT]): +class NumRange(RangeMixin, _RangeInput[N]): """ A range of integers or floats, optionally only bounded on one side. @@ -176,7 +179,7 @@ def __repr__(self) -> str: def _range_str(self, var: str = 'N') -> str: return range_str(self.min, self.max, self.include_min, self.include_max, var) - def handle_invalid(self, bound: Number, inclusive: bool, snap_dir: int) -> NT: + def handle_invalid(self, bound: Number, inclusive: bool, snap_dir: int) -> N: """ Handle calculating / returning a snap value or raise an exception if snapping to the bound is not allowed. @@ -191,7 +194,7 @@ def handle_invalid(self, bound: Number, inclusive: bool, snap_dir: int) -> NT: return bound if inclusive else (bound + snap_dir) raise InputValidationError(f'expected a value in the range {self._range_str()}') - def __call__(self, value: str) -> NT: + def __call__(self, value: str) -> N: num_val = self.type(value) # Note: if snap is enabled, it is applied by `handle_invalid` if self.value_lt_min(num_val): diff --git a/lib/cli_command_parser/inputs/patterns.py b/lib/cli_command_parser/inputs/patterns.py index 3501d9dd..2b9cac61 100644 --- a/lib/cli_command_parser/inputs/patterns.py +++ b/lib/cli_command_parser/inputs/patterns.py @@ -11,7 +11,7 @@ from abc import ABC from enum import Enum from fnmatch import translate -from typing import Collection, Match, Pattern, Sequence, TypeVar +from typing import TYPE_CHECKING, Collection, Literal, Match, Pattern, Sequence, TypeVar, overload from ..utils import MissingMixin from .base import InputType, T @@ -19,7 +19,10 @@ __all__ = ['Regex', 'RegexMode', 'Glob'] -RegexResult = TypeVar('RegexResult', str, Match, tuple[str, ...], dict[str, str]) +_Pat = str | Pattern +GroupsResult = tuple[str, ...] +DictResult = dict[str, str] +RegexResult = TypeVar('RegexResult', str, Match, GroupsResult, DictResult) class PatternInput(InputType[T], ABC): @@ -94,9 +97,60 @@ class Regex(PatternInput[RegexResult]): __slots__ = ('mode', 'group', 'groups') + # region Init Overloads + + if TYPE_CHECKING: + + @overload + def __init__( + self: Regex[str], + *patterns: _Pat, + group: str | int, + groups: Collection[str | int] | None = None, + mode: Literal['group', RegexMode.GROUP] | None = None, + ): ... + + @overload + def __init__( + self: Regex[GroupsResult], + *patterns: _Pat, + group: None = None, + groups: Collection[str | int], + mode: Literal['groups', RegexMode.GROUPS] | None = None, + ): ... + + @overload + def __init__( + self: Regex[str], + *patterns: _Pat, + group: None = None, + groups: None = None, + mode: Literal['string', RegexMode.STRING] | None = None, + ): ... + + @overload + def __init__( + self: Regex[Match], + *patterns: _Pat, + group: None = None, + groups: None = None, + mode: Literal['match', RegexMode.MATCH], + ): ... + + @overload + def __init__( + self: Regex[DictResult], + *patterns: _Pat, + group: None = None, + groups: None = None, + mode: Literal['dict', RegexMode.DICT], + ): ... + + # endregion + def __init__( self, - *patterns: str | Pattern, + *patterns: _Pat, group: str | int | None = None, groups: Collection[str | int] | None = None, mode: RegexMode | str | None = None, @@ -119,18 +173,19 @@ def __call__(self, value: str) -> RegexResult: if not (m := next((pm for p in self.patterns if (pm := p.search(value))), None)): raise InputValidationError(f'expected a value matching {self._describe_patterns()}') - if (mode := self.mode) == RegexMode.STRING: - return value # type: ignore[return-value] - elif mode == RegexMode.MATCH: - return m # type: ignore[return-value] - elif mode == RegexMode.GROUP: - return m.group(self.group) # type: ignore[arg-type,return-value] - elif mode == RegexMode.GROUPS: - if self.groups: - return tuple(m.group(g) for g in self.groups) # type: ignore[return-value] - return m.groups() # type: ignore[return-value] - else: # mode == RegexMode.DICT - return m.groupdict() # type: ignore[return-value] + match self.mode: + case RegexMode.STRING: + return value # type: ignore[return-value] + case RegexMode.MATCH: + return m # type: ignore[return-value] + case RegexMode.GROUP: + return m.group(self.group) # type: ignore[return-value,arg-type] + case RegexMode.GROUPS: + if self.groups: + return tuple(m.group(g) for g in self.groups) # type: ignore[return-value] + return m.groups() # type: ignore[return-value] + case _: # mode == RegexMode.DICT + return m.groupdict() # type: ignore[return-value] class Glob(PatternInput[str]): @@ -138,7 +193,7 @@ class Glob(PatternInput[str]): Validates that values match one of the provided glob / :doc:`fnmatch ` patterns. :param patterns: One or more glob pattern strings. - :param match_case: Whether matches should be case sensitive or not (default: False). + :param match_case: Whether matches should be case-sensitive or not (default: False). :param normcase: Whether :func:`python:os.path.normcase` should be called on patterns and values (default: False). """ diff --git a/lib/cli_command_parser/inputs/time.py b/lib/cli_command_parser/inputs/time.py index f9899581..19a1f396 100644 --- a/lib/cli_command_parser/inputs/time.py +++ b/lib/cli_command_parser/inputs/time.py @@ -22,14 +22,18 @@ from enum import Enum from locale import LC_ALL, setlocale from threading import RLock -from typing import Collection, Iterator, Literal, NoReturn, Sequence, Type, TypeVar, overload +from typing import TYPE_CHECKING, ClassVar, Collection, Iterator, Literal, NoReturn, Sequence, Type, TypeVar, overload -from ..typing import Bool, Locale, Number, OptStr, T, TimeBound +from ..typing import T from ..utils import MissingMixin from .base import InputType, _FixedInputType from .exceptions import InputValidationError, InvalidChoiceError from .utils import RangeMixin, range_str +if TYPE_CHECKING: + from ..typing import Bool, OptStr + from ._typing import Locale, Number, TimeBound + __all__ = ['DTFormatMode', 'Day', 'Month', 'TimeDelta', 'DateTime', 'Date', 'Time'] DT = TypeVar('DT', datetime, date, time) @@ -403,7 +407,7 @@ def format_metavar(self, choice_delim: str = ',', sort_choices: bool = False) -> class DateTimeInput(DTInput[DT], ABC): formats: Collection[str] - _type: Type[DT] + _type: ClassVar[Type[DT]] _earliest: TimeBound = None _latest: TimeBound = None # TODO: Add usage examples to the more user-friendly docs diff --git a/lib/cli_command_parser/inputs/utils.py b/lib/cli_command_parser/inputs/utils.py index 830b7e2e..2d3cc4fa 100644 --- a/lib/cli_command_parser/inputs/utils.py +++ b/lib/cli_command_parser/inputs/utils.py @@ -14,16 +14,12 @@ from typing import IO, TYPE_CHECKING, Any, Generic, Iterator, Literal, TypeVar, overload from weakref import finalize -try: - from typing import Self -except ImportError: # added in 3.11 - Self = TypeVar('Self') # type: ignore[misc,assignment] - from ..utils import FixedFlag from .exceptions import InputValidationError if TYPE_CHECKING: - from ..typing import Bool, Converter, Number, OptStr + from ..typing import Bool, OptStr, Self + from ._typing import Converter, Number __all__ = ['InputParam', 'StatMode', 'FileWrapper', 'fix_windows_path', 'range_str', 'RangeMixin'] diff --git a/lib/cli_command_parser/metadata.py b/lib/cli_command_parser/metadata.py index 0914b45f..e5829447 100644 --- a/lib/cli_command_parser/metadata.py +++ b/lib/cli_command_parser/metadata.py @@ -18,16 +18,11 @@ from typing import TYPE_CHECKING, Any, Callable, Generic, Iterator, Literal, Type, TypeVar, overload from urllib.parse import urlparse -try: - from typing import Self -except ImportError: # added in 3.11 - Self = TypeVar('Self') # type: ignore[misc,assignment] - from .context import NoActiveContext, ctx if TYPE_CHECKING: from .core import CommandMeta - from .typing import Bool, OptStr + from .typing import Bool, OptStr, Self __all__ = ['ProgramMetadata'] diff --git a/lib/cli_command_parser/parameters/_typing.py b/lib/cli_command_parser/parameters/_typing.py new file mode 100644 index 00000000..3fef1190 --- /dev/null +++ b/lib/cli_command_parser/parameters/_typing.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, TypeAlias, Union + +from ..typing import D + +if TYPE_CHECKING: + from cli_command_parser.commands import Command + from cli_command_parser.config import AllowLeadingDash + + +CommandMethod: TypeAlias = Callable[['Command'], D] +DefaultFunc: TypeAlias = Callable[[], D] | CommandMethod[D] + +LeadingDash: TypeAlias = Union['AllowLeadingDash', str, bool] diff --git a/lib/cli_command_parser/parameters/actions.py b/lib/cli_command_parser/parameters/actions.py index 4835989a..b52df027 100644 --- a/lib/cli_command_parser/parameters/actions.py +++ b/lib/cli_command_parser/parameters/actions.py @@ -13,12 +13,12 @@ from ..exceptions import BadArgument, InvalidChoice, MissingArgument, ParamConflict, ParamUsageError, TooManyArguments from ..inputs import InputType from ..nargs import Nargs -from ..utils import _NotSet, _NotSetType, camel_to_snake_case +from ..utils import _NotSet, camel_to_snake_case if TYPE_CHECKING: - from .base import Parameter # noqa - from .options import Counter - from ..typing import Bool, CommandObj, OptStr + from ..commands import Command + from ..typing import Bool, OptStr + from .base import BaseFlag, Parameter __all__ = [ 'ParamAction', @@ -45,6 +45,7 @@ def __str__(self) -> str: _PANotSet = _PANotSetType._PANotSet P = TypeVar('P', bound='Parameter') +F = TypeVar('F', bound='BaseFlag') Found = Union[int, NoReturn] @@ -160,7 +161,7 @@ def can_reset(self) -> bool: # region Parsed Value / Default Finalization - def get_default(self, command: CommandObj | None = None, missing_default=_NotSet): + def get_default(self, command: Command | None = None, missing_default=_NotSet): if (default := self.param.default) is not _NotSet: return self.finalize_default(default) elif (default_cb := self.param.default_cb) and command is not None: @@ -214,7 +215,7 @@ def append_value(self, value): # parsed.extend(values) -class _ConstAction(ParamAction[P], ABC): +class _ConstAction(ParamAction[F], ABC): __slots__ = () _append: ClassVar[bool] @@ -359,7 +360,7 @@ def can_reset(self) -> bool: # region Parsed Value / Default Finalization - def get_default(self, command: CommandObj | None = None, missing_default=_NotSet): + def get_default(self, command: Command | None = None, missing_default=_NotSet): if (default := self.param.default) is not _NotSet: return self.finalize_default(default) elif (default_cb := self.param.default_cb) and command is not None: @@ -448,7 +449,7 @@ def add_const(self, *, opt: OptStr = None, combo: bool = False) -> Found: # region Parsed Value / Default Finalization - def get_default(self, command: CommandObj | None = None, missing_default=_NotSet): + def get_default(self, command: Command | None = None, missing_default=_NotSet): return [] # endregion diff --git a/lib/cli_command_parser/parameters/base.py b/lib/cli_command_parser/parameters/base.py index 354ed702..29443217 100644 --- a/lib/cli_command_parser/parameters/base.py +++ b/lib/cli_command_parser/parameters/base.py @@ -7,17 +7,11 @@ from __future__ import annotations import re -import sys from abc import ABC, abstractmethod from contextvars import ContextVar from functools import cached_property from itertools import chain -from typing import TYPE_CHECKING, Any, Callable, Generic, Iterator, Type, TypeAlias, TypeVar, overload - -try: - from typing import Self -except ImportError: # added in 3.11 - Self = TypeVar('Self') # type: ignore[misc,assignment] +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Iterator, Type, TypeVar, overload from ..annotations import get_descriptor_value_type from ..config import DEFAULT_CONFIG, AllowLeadingDash, CommandConfig, OptionNameMode @@ -28,35 +22,30 @@ from ..inputs.exceptions import InputValidationError, InvalidChoiceError from ..inputs.numeric import NumericInput from ..nargs import REMAINDER, Nargs +from ..typing import D, T from ..utils import _NotSet, _NotSetType from .option_strings import OptionStrings if TYPE_CHECKING: from collections.abc import Collection - from typing import Literal, NoReturn + from typing import Literal, NoReturn, TypeAlias from ..commands import Command from ..formatting.params import ParamHelpFormatter - from ..typing import Bool, LeadingDash, OptStr, OptStrs, Strings + from ..typing import Bool, NormalizedType, OptStr, OptStrs, Self, Strings + from ._typing import CommandMethod, DefaultFunc, LeadingDash from .actions import ParamAction from .groups import ParamGroup + _ActCls = Type[ParamAction] _CmdCls = Type[Command] _CmdObjOrCls: TypeAlias = Command | _CmdCls -__all__ = ['Param', 'Parameter', 'BasePositional', 'BaseOption'] +__all__ = ['Param', 'Parameter', 'BasePositional', 'BaseOption', 'BaseFlag'] _group_stack: ContextVar[list[ParamGroup]] = ContextVar('cli_command_parser.parameters.base.group_stack') _is_numeric = re.compile(r'^-\d+$|^-\d*\.\d+?$').match -if sys.version_info >= (3, 13): - T = TypeVar('T', default=str) -else: - T = TypeVar('T') - -CommandMethod = Callable[['Command'], T] -DefaultFunc = Callable[[], T] | CommandMethod - TD = TypeVar('TD') @@ -69,9 +58,12 @@ class Param(Generic[T]): def __get__(self, command: Literal[None], owner: Any = None) -> Self: ... @overload - def __get__(self, command: object, owner: Any = None) -> T | None: ... + def __get__(self, command: Command, owner: Any = None) -> T | None: ... - def __get__(self, command: object | None, owner: Any = None) -> Self | T | None: ... + @overload + def __get__(self, command: object, owner: Any = None) -> Self: ... + + def __get__(self, command: Command | object | None, owner: Any = None) -> Self | T | None: ... class ParamBase(ABC): @@ -85,10 +77,11 @@ class ParamBase(ABC): :param hide: If ``True``, this parameter will not be included in usage / help messages. Defaults to ``False``. """ + missing_hint: ClassVar[OptStr] = None #: Hint to provide in exceptions if this param/group is missing + + # region Instance Attributes + # fmt: off - # Class Attributes - missing_hint: OptStr = None #: Hint to provide if this param/group is missing - # Instance Attributes _attr_name: OptStr = None #: Always the name of the attr that points to this object _name: OptStr = None #: An explicitly provided name, or the name of the attr that points to this obj group: ParamGroup | None = None #: The group this object is a member of, if any @@ -98,13 +91,9 @@ class ParamBase(ABC): hide: Bool #: Whether this param/group should be hidden in ``--help`` text # fmt: on - def __init__( - self, - name: OptStr = None, - required: Bool = False, - help: OptStr = None, # noqa - hide: Bool = False, - ): + # endregion + + def __init__(self, name: OptStr = None, required: Bool = False, help: OptStr = None, hide: Bool = False): # noqa self.__doc__ = help # Prevent this class's docstring from showing up for params in generated documentation self.required = required self.help = help @@ -137,24 +126,18 @@ def __set_name__(self, command: _CmdCls, name: str): # endregion - def __eq__(self, other) -> bool: - return ( - self.__class__ == other.__class__ - and self._attr_name == other._attr_name - and self._name == other._name - and self.command == other.command - ) + # region Internal Context & Config - def __hash__(self) -> int: - return hash(self.__class__) ^ hash(self._attr_name) ^ hash(self._name) ^ hash(self.command) + @overload + def _ctx(self, command: _CmdObjOrCls, strict: bool = False) -> Context: ... @overload - def _ctx(self, command: _CmdObjOrCls) -> Context: ... + def _ctx(self, command: Any, strict: Literal[True]) -> Context: ... @overload - def _ctx(self, command: Any) -> Context | None: ... + def _ctx(self, command: Any, strict: bool = False) -> Context | None: ... - def _ctx(self, command: _CmdObjOrCls | Any = None) -> Context | None: + def _ctx(self, command: _CmdObjOrCls | Any = None, strict: bool = False) -> Context | None: if context := get_current_context(True): return context @@ -164,6 +147,8 @@ def _ctx(self, command: _CmdObjOrCls | Any = None) -> Context | None: try: return command._Command__ctx # type: ignore[union-attr] # noqa except AttributeError: + if strict: + return get_current_context() return None def _config(self, command: _CmdCls | None = None) -> CommandConfig: @@ -173,6 +158,8 @@ def _config(self, command: _CmdCls | None = None) -> CommandConfig: command = self.command return command.__class__.config(command, DEFAULT_CONFIG) # type: ignore[union-attr] + # endregion + # region Usage / Help Text @cached_property @@ -202,7 +189,7 @@ def format_help(self, *args, **kwargs) -> str: # endregion -class Parameter(ParamBase, Param[T], ABC): +class Parameter(ParamBase, Param[T | D], ABC): """ Base class for all other parameters. It is not meant to be used directly. @@ -231,26 +218,12 @@ class Parameter(ParamBase, Param[T], ABC): :param hide: If ``True``, this parameter will not be included in usage / help messages. Defaults to ``False``. """ - # region Attributes & Initialization + # region Class Attributes & Initialization - # fmt: off - # Class attributes - _action_map: dict[str, Type[ParamAction]] = {} - _repr_attrs: Strings = () #: Attributes to include in ``repr()`` output - # Instance attributes with class defaults - metavar: OptStr = None - nargs: Nargs # Expected to be set in subclasses - type: Callable[[str], T] | None = None # Expected to be set in subclasses - allow_leading_dash: AllowLeadingDash = AllowLeadingDash.NUMERIC # Set in some subclasses - default: T | _NotSetType = _NotSet - default_cb: DefaultCallback | None = None - show_default: Bool = None - strict_default: Bool = False - # fmt: on + _action_map: ClassVar[dict[str, _ActCls]] = {} + _repr_attrs: ClassVar[Strings] = () #: Attributes to include in ``repr()`` output - def __init_subclass__( - cls, repr_attrs: Strings | None = None, actions: Collection[Type[ParamAction]] | None = None, **kwargs - ): + def __init_subclass__(cls, repr_attrs: Strings | None = None, actions: Collection[_ActCls] | None = None, **kwargs): """ :param repr_attrs: Additional attributes to include in the repr. :param actions: Collection of ParamAction classes that this type of Parameter supports @@ -263,7 +236,66 @@ def __init_subclass__( if repr_attrs: cls._repr_attrs = repr_attrs - def __init__( # pylint: disable=R0913 + # endregion + + # region Instance Attributes & Initialization + + # fmt: off + metavar: OptStr = None + allow_leading_dash: AllowLeadingDash = AllowLeadingDash.NUMERIC # Set in some subclasses + nargs: Nargs # Expected to be set in subclasses + + type: NormalizedType[T] = None + + # Expected to be set in subclasses + + default: D | _NotSetType = _NotSet #: Default value when no arg is provided + default_cb: DefaultCallback | None = None #: Callback that provides the default value + show_default: Bool = None #: Whether the default should be shown in help + strict_default: Bool = False #: Whether the default should be fixed by type + # fmt: on + + # region Init Overloads + + if TYPE_CHECKING: + + @overload + def __init__( + self: Parameter[T, _NotSetType], + action: str, + *, + help: OptStr = None, # noqa + hide: Bool = False, + metavar: OptStr = None, + name: OptStr = None, + required: Literal[True], + default: _NotSetType = _NotSet, + default_cb: None = None, + cb_with_cmd: Bool = False, + show_default: Bool = None, + strict_default: Bool = False, + ): ... + + @overload + def __init__( + self: Parameter[T, D], + action: str, + *, + help: OptStr = None, # noqa + hide: Bool = False, + metavar: OptStr = None, + name: OptStr = None, + required: Bool = False, + default: D | _NotSetType = _NotSet, + default_cb: DefaultFunc[D] | None = None, + cb_with_cmd: Bool = False, + show_default: Bool = None, + strict_default: Bool = False, + ): ... + + # endregion + + def __init__( self, action: str, *, @@ -272,8 +304,8 @@ def __init__( # pylint: disable=R0913 metavar: OptStr = None, name: OptStr = None, required: Bool = False, - default: T | _NotSetType = _NotSet, - default_cb: DefaultFunc | None = None, + default: D | _NotSetType = _NotSet, + default_cb: DefaultFunc[D] | None = None, cb_with_cmd: Bool = False, show_default: Bool = None, strict_default: Bool = False, @@ -336,7 +368,7 @@ def has_choices(self) -> Bool: return isinstance(self.type, _ChoicesBase) and self.type.choices return False - def register_default_cb(self, method: CommandMethod) -> CommandMethod: + def register_default_cb(self, method: CommandMethod[D]) -> CommandMethod[D]: """ Intended to be used as a decorator to register a method in a Command to be used as the default callback for this Parameter. The method will only be called during parsing if no value was explicitly provided for this @@ -357,11 +389,14 @@ def register_default_cb(self, method: CommandMethod) -> CommandMethod: raise ParameterDefinitionError( f'Cannot register a default callback method for {self} because it already has {problem}' ) + self.default_cb = DefaultCallback(method, True) return method # endregion + # region Boilerplate Methods + def __repr__(self) -> str: names = ('action', 'const', 'default', 'default_cb', 'type', 'choices', 'required', 'hide', 'help') if self._repr_attrs: @@ -376,13 +411,20 @@ def __repr__(self) -> str: kwargs = ', '.join(f'{a}={v!r}' for a, v in attrs) return f'{self.__class__.__name__}({self.name!r}, {kwargs})' - # region Parsing / Argument Handling + def __eq__(self, other) -> bool: + return ( + self.__class__ == other.__class__ + and self._attr_name == other._attr_name + and self._name == other._name + and self.command == other.command + ) - def get_const(self, opt_str: OptStr = None): - return _NotSet + def __hash__(self) -> int: + return hash(self.__class__) ^ hash(self._attr_name) ^ hash(self._name) ^ hash(self.command) - def get_env_const(self, value: str, env_var: str) -> tuple[T | _NotSetType, bool]: - return _NotSet, False + # endregion + + # region Parsing / Argument Handling def prepare_value(self, value: str, short_combo: Bool = False, env_var: OptStr = None) -> T | str: if self.type is None: @@ -434,29 +476,36 @@ def is_valid_arg(self, value: Any) -> bool: # region Parse Results / Argument Value Handling - @overload - def __get__(self, command: Literal[None], owner: Any = None) -> Self: ... + if TYPE_CHECKING: - @overload - def __get__(self, command: object, owner: Any = None) -> T | None: ... + @overload # type: ignore[override] + def __get__(self, command: Literal[None], owner: Any = None) -> Self: ... + + @overload + def __get__(self: Parameter[T, _NotSetType], command: Command, owner: Any = None) -> T: ... - def __get__(self, command: object | None, owner: Any = None) -> Self | T | None: + @overload + def __get__(self: Parameter[T, D], command: Command, owner: Any = None) -> T | D: ... + + @overload + def __get__(self, command: object, owner: Any = None) -> Self: ... + + def __get__(self, command: Command | object | None, owner: Any = None) -> Self | T | D: if command is None: return self - if context := self._ctx(command): - with context: - value = self.result(command) - else: - # This would only ever happen if this Parameter was created in a non-Command class, but it makes mypy happy - value = None + with self._ctx(command, True): + value = self.result(command) + # If `_attr_name` is set, it indicates that this parameter was present when the Command was initially defined. + # If it was not set, it means this parameter was added to the class late. Such cases are supported, but they + # do not benefit from parsed value caching within the command instance's `__dict__`. if self._attr_name: command.__dict__[self._attr_name] = value # Skip __get__ on subsequent accesses return value - def result(self, command: Command | Any = None, missing_default: TD | _NotSetType = _NotSet) -> T | TD | None: + def result(self, command: Command | Any = None, missing_default: TD | _NotSetType = _NotSet) -> T | D | TD: """The final result / parsed value for this Parameter that is returned upon access as a descriptor.""" if (value := ctx.get_parsed_value(self)) is not _NotSet: return self.action.finalize_value(value) @@ -489,7 +538,7 @@ def show_in_help(self) -> bool: # endregion -class BasePositional(Parameter[T], ABC): +class BasePositional(Parameter[T, D], ABC): """ Base class for :class:`.Positional`, :class:`.SubCommand`, :class:`.Action`, and any other parameters that are provided positionally, without prefixes. It is not meant to be used directly. @@ -504,7 +553,7 @@ class BasePositional(Parameter[T], ABC): :param kwargs: Additional keyword arguments to pass to :class:`Parameter`. """ - _default_ok: bool = False + _default_ok: ClassVar[bool] = False def __init_subclass__(cls, default_ok: bool | None = None, **kwargs): # pylint: disable=W0222 """ @@ -521,20 +570,22 @@ def __init__( *, required: Bool = True, default: Any = _NotSet, - default_cb: DefaultFunc | None = None, + default_cb: Any = None, **kwargs, ): if not (self._default_ok and 0 in self.nargs): # Indicates that having a default is bad if not required: cls_name = self.__class__.__name__ raise ParameterDefinitionError(f'All {cls_name} parameters must be required - invalid {required=}') - elif kw := ('default' if default is not _NotSet else 'default_cb' if default_cb is not None else None): + + if kw := ('default' if default is not _NotSet else 'default_cb' if default_cb is not None else None): cls_name = self.__class__.__name__ raise ParameterDefinitionError(f'The {kw!r} arg is not supported for {cls_name} parameters') + super().__init__(action, default=default, required=required, default_cb=default_cb, **kwargs) -class BaseOption(Parameter[T], ABC): +class BaseOption(Parameter[T, D], ABC): """ Base class for :class:`.Option`, :class:`.Flag`, :class:`.Counter`, and any other keyword-like parameters that have ``--long`` and ``-short`` prefixes before values. @@ -566,13 +617,63 @@ class - it is not meant to be used directly. :param kwargs: Additional keyword arguments to pass to :class:`Parameter`. """ - _opt_str_cls: Type[OptionStrings] = OptionStrings + _opt_str_cls: ClassVar[Type[OptionStrings]] = OptionStrings + option_strs: OptionStrings env_var: OptStrs = None show_env_var: Bool = None strict_env: Bool use_env_value: Bool - const: T | _NotSetType = _NotSet + + # region Init Overloads + + if TYPE_CHECKING: + + @overload + def __init__( + self: BaseOption[T, _NotSetType], + *option_strs: str, + action: str, + help: OptStr = None, # noqa + hide: Bool = False, + metavar: OptStr = None, + name: OptStr = None, + name_mode: OptionNameMode | OptStr | _NotSetType = _NotSet, + required: Literal[True], + default: _NotSetType = _NotSet, + default_cb: None = None, + cb_with_cmd: Bool = False, + show_default: Bool = None, + strict_default: Bool = False, + env_var: OptStrs = None, + strict_env: bool = True, + use_env_value: Bool = None, + show_env_var: Bool = None, + ): ... + + @overload + def __init__( + self: BaseOption[T, D], + *option_strs: str, + action: str, + help: OptStr = None, # noqa + hide: Bool = False, + metavar: OptStr = None, + name: OptStr = None, + name_mode: OptionNameMode | OptStr | _NotSetType = _NotSet, + required: Bool = False, + default: D | _NotSetType = _NotSet, + default_cb: DefaultFunc[D] | None = None, + cb_with_cmd: Bool = False, + show_default: Bool = None, + strict_default: Bool = False, + env_var: OptStrs = None, + strict_env: bool = True, + use_env_value: Bool = None, + show_env_var: Bool = None, + ): ... + + # endregion def __init__( self, @@ -613,9 +714,18 @@ def env_vars(self) -> Iterator[str]: else: yield from self.env_var + +class BaseFlag(BaseOption[T, D], ABC): + const: T | _NotSetType = _NotSet + def get_const(self, opt_str: OptStr = None): return self.const + def get_env_const(self, value: str, env_var: str) -> tuple[T | _NotSetType, bool]: + # Counter is the only flag-like param that doesn't override this / use this for handling env variables, but + # making this abstract would complicate things elsewhere + return _NotSet, False + class AllowLeadingDashProperty: """ @@ -662,7 +772,9 @@ def __set__(self, instance: Parameter, value: LeadingDash | None): class DefaultCallback(Generic[T]): __slots__ = ('func', 'use_cmd') - def __init__(self, func: CommandMethod | DefaultFunc, use_cmd: bool = False): + func: DefaultFunc[T] + + def __init__(self, func: DefaultFunc[T], use_cmd: bool = False): self.func = func self.use_cmd = use_cmd diff --git a/lib/cli_command_parser/parameters/choice_map.py b/lib/cli_command_parser/parameters/choice_map.py index 979a95c4..5ab39222 100644 --- a/lib/cli_command_parser/parameters/choice_map.py +++ b/lib/cli_command_parser/parameters/choice_map.py @@ -8,8 +8,7 @@ from functools import partial from string import printable, whitespace -from types import MethodType -from typing import TYPE_CHECKING, Callable, Collection, Generic, Mapping, NoReturn, Sequence, Type, TypeVar +from typing import TYPE_CHECKING, Callable, Collection, Generic, Mapping, NoReturn, ParamSpec, Sequence, Type, TypeVar from ..context import ctx from ..exceptions import BadArgument, CommandDefinitionError, InvalidChoice, ParameterDefinitionError @@ -21,14 +20,17 @@ from .base import BasePositional if TYPE_CHECKING: + from ..commands import Command from ..formatting.params import ChoiceMapHelpFormatter from ..metadata import ProgramMetadata - from ..typing import Bool, CommandObj, OptStr + from ..typing import Bool, OptStr __all__ = ['SubCommand', 'Action', 'Choice', 'ChoiceMap'] T = TypeVar('T') TD = TypeVar('TD') +P = ParamSpec('P') +Method = Callable[P, T] # TODO: Combine SubCommand and Action, replacing `local_choices` with stackable decorators on the target method, # optionally injecting the selected choice into positional args for the decorated method, which may be main? @@ -62,7 +64,7 @@ def format_help(self, lpad: int = 4, prefix: str = '') -> str: return format_help_entry((self.format_usage(),), self.help, prefix, lpad=lpad) -class ChoiceMap(BasePositional[str | None], Generic[T], actions=(Concatenate,)): +class ChoiceMap(BasePositional[str, None], Generic[T], actions=(Concatenate,)): """ Base class for :class:`SubCommand` and :class:`Action`. It is not meant to be used directly. @@ -175,7 +177,7 @@ def validate(self, value: str | Sequence[str], joined: Bool = False): if not any(c.startswith(prefix) for c in self.choices if c): raise InvalidChoice(self, prefix[:-1], self.choices) - def result(self, command: CommandObj | None = None, missing_default: TD | _NotSetType = _NotSet) -> OptStr | TD: + def result(self, command: Command | None = None, missing_default: TD | _NotSetType = _NotSet) -> OptStr | TD: if not self.choices: self._no_choices_error() return super().result(command, missing_default) @@ -308,7 +310,7 @@ def _no_choices_error(self) -> NoReturn: raise CommandDefinitionError(f'{ctx.command_cls}.{self.name} = {self} has no sub Commands') -class Action(ChoiceMap[MethodType], title='Actions'): +class Action(ChoiceMap[Method], title='Actions'): """ Actions are similar to :class:`.SubCommand` parameters, but allow methods in :class:`.Command` classes to be registered as a callable to be executed based on a user's choice instead of separate sub Commands. @@ -321,10 +323,10 @@ class Action(ChoiceMap[MethodType], title='Actions'): def register_action( self, choice: OptStr, - method: MethodType, + method: Method, help: OptStr = None, # noqa default: Bool = False, - ) -> MethodType: + ) -> Method: if help is None: try: help = method.__doc__ # noqa @@ -346,12 +348,12 @@ def register_action( def register( self, - method_or_choice: str | MethodType | None = None, + method_or_choice: str | Method | None = None, *, choice: OptStr = None, help: OptStr = None, # noqa default: Bool = False, - ) -> MethodType | Callable[[MethodType], MethodType]: + ) -> Method | Callable[[Method], Method]: """ Decorator that registers the wrapped method to be called when the given choice is specified for this parameter. Methods may also be registered by decorating them with the instantiated Action parameter directly - doing so diff --git a/lib/cli_command_parser/parameters/groups.py b/lib/cli_command_parser/parameters/groups.py index dfa37e66..78bc05f0 100644 --- a/lib/cli_command_parser/parameters/groups.py +++ b/lib/cli_command_parser/parameters/groups.py @@ -16,7 +16,9 @@ if TYPE_CHECKING: from ..formatting.params import GroupHelpFormatter - from ..typing import Bool, ParamList, ParamOrGroup + from ..typing import Bool, ParamOrGroup + + ParamList = list[ParamOrGroup] __all__ = ['ParamGroup'] @@ -187,8 +189,8 @@ def register_all(self, params: Iterable[ParamOrGroup]): def _categorize_params(self) -> tuple[ParamList, ParamList]: """Called after parsing to group this group's members by whether they were provided or not.""" - provided = [] - missing = [] + provided: ParamList = [] + missing: ParamList = [] for obj in self.members: if ctx.num_provided(obj): provided.append(obj) diff --git a/lib/cli_command_parser/parameters/options.py b/lib/cli_command_parser/parameters/options.py index e7c5ac8c..c76d425c 100644 --- a/lib/cli_command_parser/parameters/options.py +++ b/lib/cli_command_parser/parameters/options.py @@ -7,21 +7,23 @@ from __future__ import annotations import logging -import sys from functools import partial, update_wrapper -from typing import TYPE_CHECKING, Any, Callable, Literal, NoReturn, TypeVar, overload +from typing import TYPE_CHECKING, Any, Callable, Literal, NoReturn, Type, overload from ..exceptions import BadArgument, CommandDefinitionError, ParameterDefinitionError, ParamUsageError, ParserExit from ..inputs import normalize_input_type from ..nargs import Nargs, NargsValue -from ..typing import TypeFunc +from ..typing import B, D, T from ..utils import _NotSet, _NotSetType, str_to_bool from .actions import Append, AppendConst, Count, Store, StoreConst -from .base import AllowLeadingDashProperty, BaseOption, CommandMethod +from .base import AllowLeadingDashProperty, BaseFlag, BaseOption from .option_strings import TriFlagOptionStrings if TYPE_CHECKING: - from ..typing import Bool, ChoicesType, CommandCls, CommandObj, InputTypeFunc, LeadingDash, OptStr + from ..commands import Command + from ..config import OptionNameMode + from ..typing import Bool, ChoicesType, InputTypeFunc, OptStr, OptStrs, TypeFunc + from ._typing import CommandMethod, DefaultFunc, LeadingDash __all__ = [ 'Option', @@ -36,19 +38,11 @@ ] log = logging.getLogger(__name__) -if sys.version_info >= (3, 13): - T = TypeVar('T', default=str) - B = TypeVar('B', default=bool) - TF = TypeVar('TF', default=bool | None) -else: - T = TypeVar('T') - B = TypeVar('B') - TF = TypeVar('TF') - +OptAct = Literal['store', 'append'] | None ConstAct = Literal['store_const', 'append_const'] -class Option(BaseOption[T | None], actions=(Store, Append)): +class Option(BaseOption[T, D], actions=(Store, Append)): """ A generic option that can be specified as ``--foo bar`` or by using other similar forms. @@ -78,18 +72,76 @@ class Option(BaseOption[T | None], actions=(Store, Append)): :param kwargs: Additional keyword arguments to pass to :class:`.BaseOption`. """ - default: T + default: D allow_leading_dash = AllowLeadingDashProperty() + # region Init Overloads + + if TYPE_CHECKING: + + @overload + def __init__( + self: Option[T, _NotSetType], + *option_strs: str, + required: Literal[True], + type: InputTypeFunc[T] = None, # noqa + choices: ChoicesType[T] = None, + default: _NotSetType = _NotSet, + default_cb: None = None, + nargs: NargsValue | None = None, + action: OptAct = None, + help: OptStr = None, # noqa + hide: Bool = False, + metavar: OptStr = None, + name: OptStr = None, + name_mode: OptionNameMode | OptStr | _NotSetType = _NotSet, + allow_leading_dash: LeadingDash | None = None, + cb_with_cmd: Bool = False, + show_default: Bool = None, + strict_default: Bool = False, + env_var: OptStrs = None, + strict_env: bool = True, + use_env_value: Bool = None, + show_env_var: Bool = None, + ): ... + + @overload + def __init__( + self: Option[T, D], + *option_strs: str, + required: Bool = False, + type: InputTypeFunc[T] = None, # noqa + choices: ChoicesType[T] = None, + default: D | _NotSetType = _NotSet, + default_cb: DefaultFunc[D] | None = None, + nargs: NargsValue | None = None, + action: OptAct = None, + help: OptStr = None, # noqa + hide: Bool = False, + metavar: OptStr = None, + name: OptStr = None, + name_mode: OptionNameMode | OptStr | _NotSetType = _NotSet, + allow_leading_dash: LeadingDash | None = None, + cb_with_cmd: Bool = False, + show_default: Bool = None, + strict_default: Bool = False, + env_var: OptStrs = None, + strict_env: bool = True, + use_env_value: Bool = None, + show_env_var: Bool = None, + ): ... + + # endregion + def __init__( self, *option_strs: str, nargs: NargsValue | None = None, - action: Literal['store', 'append'] | None = None, - default: T | _NotSetType = _NotSet, + action: OptAct = None, + default: D | _NotSetType = _NotSet, required: Bool = False, - type: InputTypeFunc = None, # noqa - choices: ChoicesType = None, + type: InputTypeFunc[T] = None, # noqa + choices: ChoicesType[T] = None, allow_leading_dash: LeadingDash | None = None, **kwargs, ): @@ -104,7 +156,7 @@ def __init__( super().__init__(*option_strs, action=action, default=default, required=required, **kwargs) self.nargs = self.action.default_nargs if _nargs is None else _nargs - self.type = normalize_input_type(type, choices) + self.type = normalize_input_type(type, choices) # type: ignore[assignment] self.allow_leading_dash = allow_leading_dash def _handle_bad_action(self, action: str) -> NoReturn: @@ -128,7 +180,7 @@ def _validate_option_nargs(nargs_val: NargsValue | None) -> Nargs | None: raise ParameterDefinitionError(f'Invalid nargs={nargs_val} - {details}') -class Flag(BaseOption[B], actions=(StoreConst, AppendConst)): +class Flag(BaseFlag[B, B], actions=(StoreConst, AppendConst)): """ A (typically boolean) option that does not accept any values. @@ -159,20 +211,96 @@ class Flag(BaseOption[B], actions=(StoreConst, AppendConst)): """ nargs = Nargs(0) - type: TypeFunc = staticmethod(str_to_bool) # Without staticmethod, this would be interpreted as a normal method + # Without staticmethod, this would be interpreted as a normal method + type: TypeFunc[B] = staticmethod(str_to_bool) # type: ignore[arg-type] use_env_value: bool = False __default_const_map = {True: False, False: True, _NotSet: True} default: B const: B + # region Init Overloads + + if TYPE_CHECKING: + + @overload + def __init__( + self: Flag[bool], + *option_strs: str, + action: ConstAct = 'store_const', + default: _NotSetType = _NotSet, + default_cb: _NotSetType = _NotSet, + const: _NotSetType = _NotSet, + type: None = None, # noqa + name_mode: OptionNameMode | OptStr | _NotSetType = _NotSet, + env_var: OptStrs = None, + strict_env: bool = True, + use_env_value: Bool = None, + show_env_var: Bool = None, + help: OptStr = None, # noqa + hide: Bool = False, + metavar: OptStr = None, + name: OptStr = None, + required: Bool = False, + show_default: Bool = None, + strict_default: Bool = False, + ): ... + + @overload + def __init__( + self: Flag[B], + *option_strs: str, + action: ConstAct = 'store_const', + default: B, + default_cb: _NotSetType = _NotSet, + const: B | _NotSetType = _NotSet, + type: TypeFunc[B] | None = None, # noqa + name_mode: OptionNameMode | OptStr | _NotSetType = _NotSet, + env_var: OptStrs = None, + strict_env: bool = True, + use_env_value: Bool = None, + show_env_var: Bool = None, + help: OptStr = None, # noqa + hide: Bool = False, + metavar: OptStr = None, + name: OptStr = None, + required: Bool = False, + show_default: Bool = None, + strict_default: Bool = False, + ): ... + + @overload + def __init__( + self: Flag[B], + *option_strs: str, + action: ConstAct = 'store_const', + default: B | _NotSetType = _NotSet, + default_cb: _NotSetType = _NotSet, + const: B, + type: TypeFunc[B] | None = None, # noqa + name_mode: OptionNameMode | OptStr | _NotSetType = _NotSet, + env_var: OptStrs = None, + strict_env: bool = True, + use_env_value: Bool = None, + show_env_var: Bool = None, + help: OptStr = None, # noqa + hide: Bool = False, + metavar: OptStr = None, + name: OptStr = None, + required: Bool = False, + show_default: Bool = None, + strict_default: Bool = False, + ): ... + + # endregion + def __init__( self, *option_strs: str, action: ConstAct = 'store_const', default: B | _NotSetType = _NotSet, - default_cb=_NotSet, + default_cb: _NotSetType = _NotSet, const: B | _NotSetType = _NotSet, - type: TypeFunc | None = None, # noqa + type: TypeFunc[B] | None = None, # noqa **kwargs, ): if const is _NotSet: @@ -209,7 +337,7 @@ def get_env_const(self, value: str, env_var: str) -> tuple[B, bool]: return parsed, self.use_env_value -class TriFlag(BaseOption[TF], actions=(StoreConst, AppendConst)): +class TriFlag(BaseFlag[B, D], actions=(StoreConst, AppendConst)): """ A trinary / ternary Flag. While :class:`.Flag` only supports 1 constant when provided, with 1 default if not provided, this class accepts a pair of constants for the primary and alternate values to store, along with a @@ -245,27 +373,28 @@ class TriFlag(BaseOption[TF], actions=(StoreConst, AppendConst)): """ nargs = Nargs(0) - type: TypeFunc = staticmethod(str_to_bool) # Without staticmethod, this would be interpreted as a normal method + # Without staticmethod, this would be interpreted as a normal method + type: TypeFunc[B] = staticmethod(str_to_bool) # type: ignore[arg-type] use_env_value: bool = False _default_cb_ok = True _opt_str_cls = TriFlagOptionStrings option_strs: TriFlagOptionStrings alt_help: OptStr = None - default: TF | _NotSetType - consts: tuple[TF, TF] + default: D | _NotSetType + consts: tuple[B, B] def __init__( self, *option_strs: str, - consts: tuple[TF, TF] = (True, False), # type: ignore[assignment] + consts: tuple[B, B] = (True, False), # type: ignore[assignment] alt_prefix: OptStr = None, alt_long: OptStr = None, alt_short: OptStr = None, alt_help: OptStr = None, action: ConstAct = 'store_const', - default: TF | _NotSetType = _NotSet, - default_cb: Callable[[], TF] | None = None, - type: TypeFunc | None = None, # noqa + default: D | _NotSetType = _NotSet, + default_cb: Callable[[], D] | None = None, + type: TypeFunc[B] | None = None, # noqa **kwargs, ): if alt_short and '-' in alt_short[1:]: @@ -301,26 +430,27 @@ def __init__( if type is not None: self.type = type - def __set_name__(self, command: CommandCls, name: str): + def __set_name__(self, command: Type[Command], name: str): super().__set_name__(command, name) self.option_strs.update_alts(name) - def register_default_cb(self, method: CommandMethod) -> CommandMethod: + def register_default_cb(self, method: CommandMethod[D]) -> CommandMethod[D]: if self._default_cb_ok and self.default is not _NotSet: self.default = _NotSet # The default was set by __init__ - remove it so the method can be registered return super().register_default_cb(method) - def get_const(self, opt_str: OptStr = None) -> TF: + def get_const(self, opt_str: OptStr = None) -> B: if opt_str in self.option_strs.alt_allowed: return self.consts[1] else: return self.consts[0] - def get_env_const(self, value: str, env_var: str) -> tuple[TF, bool]: + def get_env_const(self, value: str, env_var: str) -> tuple[B, bool]: try: parsed = self.type(value) except Exception as e: raise ParamUsageError(self, f'unable to parse {value=} from env var={env_var!r}: {e}') from e + if self.use_env_value: if parsed not in self.consts and parsed != self.default: raise BadArgument(self, f'invalid value={parsed!r} from env var={env_var!r}') @@ -465,7 +595,7 @@ def help_action(self): # endregion -class Counter(BaseOption[int], actions=(Count,)): +class Counter(BaseFlag[int, int], actions=(Count,)): """ A :class:`.Flag`-like option that counts the number of times it was specified. Supports an optional integer value to explicitly increase the stored value by that amount. @@ -496,7 +626,7 @@ def __init__( init: int = 0, const: int = 1, default: int | _NotSetType = _NotSet, - default_cb: Callable[[], int] | None = None, + default_cb: DefaultFunc[int] | None = None, required: bool = False, **kwargs, ): @@ -513,7 +643,7 @@ def __init__( self.init = init self.const = const - def register_default_cb(self, method: CommandMethod) -> CommandMethod: + def register_default_cb(self, method: CommandMethod[int]) -> CommandMethod[int]: if self.default_cb and self.default_cb.func is _counter_default: self.default_cb = None return super().register_default_cb(method) @@ -542,5 +672,5 @@ def validate(self, value: Any, joined: Bool = False): return -def _counter_default(): +def _counter_default() -> int: return 0 diff --git a/lib/cli_command_parser/parameters/pass_thru.py b/lib/cli_command_parser/parameters/pass_thru.py index 3c1e6309..639a3b66 100644 --- a/lib/cli_command_parser/parameters/pass_thru.py +++ b/lib/cli_command_parser/parameters/pass_thru.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import Literal +from typing import ClassVar, Literal from ..nargs import Nargs from .actions import StoreAll @@ -24,8 +24,10 @@ class PassThru(Parameter, actions=(StoreAll,)): :param kwargs: Additional keyword arguments to pass to :class:`.Parameter`. """ + action: StoreAll nargs = Nargs('REMAINDER') - missing_hint: str = " (missing pass thru args separated from others with '--')" # leading space is intentional + # missing_hint: Hint to provide in exceptions if this param/group is missing; the leading space is intentional + missing_hint: ClassVar[str] = " (missing pass thru args separated from others with '--')" def __init__(self, action: Literal['store_all'] = 'store_all', **kwargs): super().__init__(action=action, **kwargs) diff --git a/lib/cli_command_parser/parameters/positionals.py b/lib/cli_command_parser/parameters/positionals.py index f63664ab..28975f4a 100644 --- a/lib/cli_command_parser/parameters/positionals.py +++ b/lib/cli_command_parser/parameters/positionals.py @@ -6,24 +6,24 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal, TypeVar +from typing import TYPE_CHECKING, Literal from ..exceptions import ParameterDefinitionError from ..inputs import normalize_input_type from ..nargs import Nargs, NargsValue -from ..utils import _NotSet +from ..typing import D, T +from ..utils import _NotSet, _NotSetType from .actions import Append, Store -from .base import AllowLeadingDashProperty, BasePositional, DefaultFunc +from .base import AllowLeadingDashProperty, BasePositional if TYPE_CHECKING: - from ..typing import ChoicesType, InputTypeFunc, LeadingDash + from ..typing import ChoicesType, InputTypeFunc + from ._typing import DefaultFunc, LeadingDash __all__ = ['Positional'] -T = TypeVar('T') - -class Positional(BasePositional[T], default_ok=True, actions=(Store, Append)): +class Positional(BasePositional[T, D], default_ok=True, actions=(Store, Append)): """ A parameter that must be provided positionally. @@ -58,11 +58,11 @@ def __init__( self, nargs: NargsValue | None = None, action: Literal['store', 'append'] | None = None, - type: InputTypeFunc = None, # noqa - default: Any = _NotSet, + type: InputTypeFunc[T] = None, # noqa + default: D | _NotSetType = _NotSet, *, - default_cb: DefaultFunc | None = None, - choices: ChoicesType = None, + default_cb: DefaultFunc[D] | None = None, + choices: ChoicesType[T] = None, allow_leading_dash: LeadingDash | None = None, **kwargs, ): @@ -89,5 +89,5 @@ def __init__( ) kwargs.setdefault('required', required) super().__init__(action=action, default=default, default_cb=default_cb, **kwargs) - self.type = normalize_input_type(type, choices) + self.type = normalize_input_type(type, choices) # type: ignore[assignment] self.allow_leading_dash = allow_leading_dash diff --git a/lib/cli_command_parser/typing.py b/lib/cli_command_parser/typing.py index a57bd706..5b76c4e9 100644 --- a/lib/cli_command_parser/typing.py +++ b/lib/cli_command_parser/typing.py @@ -6,14 +6,14 @@ from __future__ import annotations +import sys +from collections.abc import Collection from typing import ( - IO, TYPE_CHECKING, Any, - Callable, - Collection, Iterable, Pattern, + Protocol, Sequence, Type, TypeAlias, @@ -21,54 +21,52 @@ Union, ) +try: + from typing import Self +except ImportError: # added in 3.11 + Self = TypeVar('Self') # type: ignore[misc,assignment] + if TYPE_CHECKING: - from datetime import date, datetime, time, timedelta - from enum import Enum - from numbers import Number as _Number from pathlib import Path from .commands import Command - from .config import AllowLeadingDash, CommandConfig - from .core import CommandMeta - from .inputs import InputType + from .inputs import InputType, Regex from .parameters import Parameter, ParamGroup + from .parameters.base import ParamBase + -T = TypeVar('T') -TypeFunc = Callable[[str], T] +Bool: TypeAlias = bool | Any +StrSeq: TypeAlias = Sequence[str] +Strs: TypeAlias = str | StrSeq +StrIter: TypeAlias = Iterable[str] +IStrs: TypeAlias = str | StrIter +OptStr: TypeAlias = str | None +OptStrs: TypeAlias = Strs | None +Strings: TypeAlias = Collection[str] +PathLike: TypeAlias = Union[str, 'Path'] -NT = TypeVar('NT', bound='_Number') -Number: TypeAlias = NT | None -NumType = Callable[[Any], NT] -RngType = range | int | Sequence[int] +CommandObj = TypeVar('CommandObj', bound='Command') +CommandCls: TypeAlias = Type[CommandObj] +CommandAny: TypeAlias = CommandCls | CommandObj -InputTypeFunc = Union[None, TypeFunc, 'InputType', range, Type['Enum'], Pattern] -ChoicesType = Collection[Any] | None +ParamOrGroup: TypeAlias = Union['Parameter', 'ParamGroup', 'ParamBase'] -Bool = bool | Any -StrSeq = Sequence[str] -Strs = str | StrSeq -StrIter = Iterable[str] -IStrs = str | StrIter -OptStr = str | None -OptStrs = Strs | None -Strings = Collection[str] -PathLike = Union[str, 'Path'] -Locale = str | tuple[OptStr, OptStr] -TimeBound = Union['datetime', 'date', 'time', 'timedelta', None] +if sys.version_info >= (3, 13): + T = TypeVar('T', default=str, covariant=True, bound=Any) + D = TypeVar('D', default=None, covariant=True, bound=Any) + B = TypeVar('B', default=bool, covariant=True, bound=Any) +else: + T = TypeVar('T', covariant=True, bound=Any) + D = TypeVar('D', covariant=True, bound=Any) + B = TypeVar('B', covariant=True, bound=Any) -Deserializer = Callable[[str | bytes | IO], Any] -Serializer = Callable[[Any, IO], None] | Callable[[Any], str | bytes] -Converter = Deserializer | Serializer -Config = Union['CommandConfig', None] -AnyConfig = Config | dict[str, Any] -LeadingDash = Union['AllowLeadingDash', str, bool] +class TypeFunc(Protocol[T]): + def __call__(self, value: str, /) -> T: + pass -P = TypeVar('P', bound='Parameter') -ParamList = list[P] -ParamOrGroup = Union[P, 'ParamGroup'] -CommandObj = TypeVar('CommandObj', bound='Command') -CommandCls: TypeAlias = Type[CommandObj] -CommandAny: TypeAlias = CommandCls | CommandObj +ChoicesType: TypeAlias = Collection[T] | None +InputTypeFunc: TypeAlias = Union[Type[T], TypeFunc[T], 'InputType[T]', range, Pattern, None] +NormalizedType: TypeAlias = Union[Type[T], TypeFunc[T], 'InputType[T]', 'Regex[str]', None] diff --git a/lib/cli_command_parser/utils.py b/lib/cli_command_parser/utils.py index 84bc7b54..e20e57b8 100644 --- a/lib/cli_command_parser/utils.py +++ b/lib/cli_command_parser/utils.py @@ -10,23 +10,21 @@ from inspect import isawaitable from shutil import get_terminal_size from time import monotonic -from typing import Any, Awaitable, Callable, TypeVar +from typing import TYPE_CHECKING, Any, Awaitable, Callable, TypeVar try: from enum import CONFORM except ImportError: # added in 3.11 CONFORM = None # type: ignore[misc,assignment] -try: - from typing import Self -except ImportError: # added in 3.11 - Self = TypeVar('Self') # type: ignore[misc,assignment] - try: from wcwidth import wcwidth # type: ignore[import-untyped] except ImportError: wcwidth = len +if TYPE_CHECKING: + from .typing import Self + # region Text Processing / Formatting diff --git a/tests/test_parameters/test_counters.py b/tests/test_parameters/test_counters.py index 45b08d8e..2f885e78 100755 --- a/tests/test_parameters/test_counters.py +++ b/tests/test_parameters/test_counters.py @@ -161,6 +161,10 @@ class Foo(Command): with self.assert_raises_contains_str(BadArgument, "bad counter value='foo' from env var='BAR'"): Foo.parse([]) + def test_default_unused_get_env_const(self): + # This is a contrived test - this method is not actually used during parsing of Counter env variables + self.assertEqual((_NotSet, False), Counter().get_env_const('123456', '')) + # endregion diff --git a/tests/test_parameters/test_positionals.py b/tests/test_parameters/test_positionals.py index 44b5d94f..4b3204cb 100755 --- a/tests/test_parameters/test_positionals.py +++ b/tests/test_parameters/test_positionals.py @@ -125,12 +125,6 @@ def test_bad_leading_dash_with_remainder_rejected(self): class Foo(Command): bar = Positional(nargs='REMAINDER', allow_leading_dash=allow_leading_dash) - def test_default_get_const(self): - self.assertIs(_NotSet, Positional().get_const()) - - def test_default_normalize_env_val(self): - self.assertEqual((_NotSet, False), Positional().get_env_const('123456', '')) - def test_too_many_arguments(self): with Context(): param = Positional(nargs=1, action='append')