diff --git a/lib/cli_command_parser/__init__.py b/lib/cli_command_parser/__init__.py index acfff218..b0116d10 100644 --- a/lib/cli_command_parser/__init__.py +++ b/lib/cli_command_parser/__init__.py @@ -42,6 +42,7 @@ Counter, Flag, Option, + Param, Parameter, ParamGroup, PassThru, @@ -52,4 +53,4 @@ after_main, before_main, ) -from .typing import Param, ParamOrGroup +from .typing import ParamOrGroup diff --git a/lib/cli_command_parser/command_parameters.py b/lib/cli_command_parser/command_parameters.py index fa2ec8ff..a47aa998 100644 --- a/lib/cli_command_parser/command_parameters.py +++ b/lib/cli_command_parser/command_parameters.py @@ -12,20 +12,25 @@ from collections import defaultdict from functools import cached_property -from typing import TYPE_CHECKING, Collection, Iterator +from typing import TYPE_CHECKING, Any, Collection, Iterator, Type, TypeAlias from .config import AmbiguousComboMode, CommandConfig from .exceptions import AmbiguousCombo, AmbiguousShortForm, CommandDefinitionError, ParameterDefinitionError -from .parameters import Action, ActionFlag, ParamGroup, PassThru, SubCommand, help_action +from .parameters import ActionFlag, ParamGroup, PassThru, help_action from .parameters.base import BaseOption, BasePositional, ParamBase, Parameter +from .parameters.choice_map import Action, SubCommand if TYPE_CHECKING: + from .commands import Command from .context import Context + from .core import CommandMeta from .formatting.commands import CommandHelpFormatter - from .typing import CommandCls, Strings + from .typing import Bool, Strings + CommandCls: TypeAlias = Type[Command] | CommandMeta OptionMap = dict[str, BaseOption] ActionFlags = list[ActionFlag] + Positionals = list[BasePositional] | tuple[()] __all__ = ['CommandParameters'] @@ -33,18 +38,14 @@ class CommandParameters: # fmt: off command: CommandCls #: The Command associated with this CommandParameters object - formatter: CommandHelpFormatter #: The formatter used for this Command's help text parent: CommandParameters | None #: The parent Command's CommandParameters - action: Action | None = None #: An Action Parameter, if specified - _pass_thru: PassThru | None = None #: A PassThru Parameter, if specified - sub_command: SubCommand | None = None #: A SubCommand Parameter, if specified action_flags: ActionFlags #: List of action flags split_action_flags: tuple[ActionFlags, ActionFlags] #: Action flags split by before/after main options: list[BaseOption] #: List of optional Parameters combo_option_map: OptionMap #: Mapping of {short opt: Parameter} (no dash characters) groups: list[ParamGroup] #: List of ParamGroup objects positionals: list[BasePositional] #: List of positional Parameters - _deferred_positionals: list[BasePositional] = () #: Positional Parameters that are deferred to sub commands + _deferred_positionals: Positionals = () #: Positional Parameters that are deferred to sub commands option_map: OptionMap #: Mapping of {--opt / -opt: Parameter} # fmt: on @@ -52,6 +53,12 @@ def __init__(self, command: CommandCls, parent_params: CommandParameters | None, self.command = command self.parent = parent_params self.config = config + # fmt: off + # These are annotated here because mypy thinks they're invoked as descriptors when annotated at the class level + self.action: Action | None = None #: An Action Parameter, if specified + self.sub_command: SubCommand | None = None #: A SubCommand Parameter, if specified + self._pass_thru: PassThru | None = None #: A PassThru Parameter, if specified + # fmt: on self._process_parameters() def __repr__(self) -> str: @@ -77,11 +84,8 @@ def has_nested_pass_thru(self) -> bool: @cached_property def all_positionals(self) -> list[BasePositional]: - try: - if not self.parent.sub_command: - return self.parent.all_positionals + self.positionals - except AttributeError: - pass + if self.parent and not self.parent.sub_command: + return self.parent.all_positionals + self.positionals return self.positionals def get_positionals_to_parse(self, ctx: Context) -> list[BasePositional]: @@ -94,6 +98,7 @@ def get_positionals_to_parse(self, ctx: Context) -> list[BasePositional]: @cached_property def formatter(self) -> CommandHelpFormatter: + """The formatter used for this Command's help text.""" from .formatting.commands import CommandHelpFormatter formatter_factory = self.config.command_formatter or CommandHelpFormatter @@ -105,13 +110,13 @@ def formatter(self) -> CommandHelpFormatter: return formatter @cached_property - def _has_help(self) -> bool: + def _has_help(self) -> Bool: return help_action in self.action_flags or (self.parent and self.parent._has_help) # region Initialization def _iter_parameters(self) -> Iterator[ParamBase]: - name_param_map = {} # Allow subclasses to override names, but not within a given command + name_param_map: dict[str, Any] = {} # Allow subclasses to override names, but not within a given command for item in self.command.__dict__.items(): attr, param = item if attr.startswith('__') or not isinstance(param, ParamBase): # Name mangled Parameters are still processed @@ -173,7 +178,9 @@ def _process_groups(self, groups: set[ParamGroup]): self.groups = sorted(groups) if groups else [] def _process_positionals(self, params: list[BasePositional]): - unfollowable = action_or_sub_cmd = split_index = None + unfollowable: BasePositional | None = None + action_or_sub_cmd: SubCommand | Action | None = None + split_index: int = 0 if self.parent and (deferred := self.parent._deferred_positionals): params = deferred + params @@ -186,26 +193,28 @@ def _process_positionals(self, params: list[BasePositional]): raise CommandDefinitionError( f'Additional Positional parameters cannot follow {unfollowable} {why} - {param=} is invalid' ) - elif isinstance(param, (SubCommand, Action)): + + if isinstance(param, (SubCommand, Action)): if action_or_sub_cmd: raise CommandDefinitionError( f'Only 1 Action xor SubCommand is allowed in a given Command - {self.command.__name__} cannot' f' contain both {action_or_sub_cmd} and {param}' ) - elif isinstance(param, SubCommand): + + if isinstance(param, SubCommand): self.sub_command = action_or_sub_cmd = param split_index = i + 1 if param.has_choices and 0 in param.nargs: # It has local choices or is not required unfollowable = param else: # It's an Action - self.action = action_or_sub_cmd = param # type: ignore + self.action = action_or_sub_cmd = param if not param.has_choices: raise CommandDefinitionError(f'No choices were registered for {self.action}') elif 0 in param.nargs or (param.nargs.variable and not param.has_choices): unfollowable = param if split_index: - if self.sub_command.has_local_choices: + if self.sub_command.has_local_choices: # type: ignore[union-attr] self._deferred_positionals = params[split_index:] else: params, self._deferred_positionals = params[:split_index], params[split_index:] @@ -252,19 +261,22 @@ def _process_option_strs(self, param: BaseOption, opt_type: str, opt_strs: Strin f'{opt_type} {option=} conflict for command={self.command!r} between {existing} and {param}' ) - def _process_action_flags(self): - action_flags = sorted(p for p in self.options if isinstance(p, ActionFlag)) - grouped_ordered_flags = {True: defaultdict(list), False: defaultdict(list)} + def _process_action_flags(self) -> None: + action_flags: list[ActionFlag] = sorted(p for p in self.options if isinstance(p, ActionFlag)) + grouped_ordered_flags: dict[bool, dict[int | float, ActionFlags]] = { + True: defaultdict(list), + False: defaultdict(list), + } for param in action_flags: if param.func is None: raise ParameterDefinitionError(f'No function was registered for {param=}') grouped_ordered_flags[param.before_main][param.order].append(param) found_non_always = False - invalid = {} + invalid: dict[tuple[bool, int | float], ActionFlags | ActionFlag] = {} for before_main, prio_params in grouped_ordered_flags.items(): for prio, params in prio_params.items(): - param: ActionFlag = params[0] # Don't pop and check `if params` - all are needed for the group check + param = params[0] # Don't pop and check `if params` - all are needed for the group check if found_non_always and param.always_available: invalid[(before_main, prio)] = param elif not param.always_available: @@ -296,7 +308,7 @@ def _process_action_flags(self): @cached_property def _classified_combo_options(self) -> tuple[OptionMap, OptionMap]: """Tuple of (single char short:Option map, multi-char short:Option map) for options available in this command""" - multi_char_combos = {} + multi_char_combos: OptionMap = {} items = self.combo_option_map.items() for combo, param in items: if len(combo) == 1: # combo_option_map is sorted in reverse length order, so all following will be 1 char @@ -368,7 +380,7 @@ def short_option_to_param_value_pairs(self, option: str) -> tuple[list[tuple[str # Note: if the option is not in this Command's option_map, the KeyError is handled by CommandParser return [(option, self.option_map[option], value)], True else: - value = None + value = None # type: ignore[assignment] try: param = self.option_map[option] diff --git a/lib/cli_command_parser/commands.py b/lib/cli_command_parser/commands.py index 723aaf7c..c7bb5acd 100644 --- a/lib/cli_command_parser/commands.py +++ b/lib/cli_command_parser/commands.py @@ -9,10 +9,10 @@ import logging from abc import ABC from contextlib import ExitStack -from typing import TYPE_CHECKING, Sequence, TextIO, Type, overload +from typing import TYPE_CHECKING, Sequence, TextIO, Type from .context import ActionPhase, Context, get_or_create_context -from .core import CommandMeta, get_params, get_top_level_commands +from .core import CommandMeta, get_metadata, get_params, get_top_level_commands from .exceptions import ParamConflict, ParserExit from .parser import parse_args_and_get_next_cmd from .utils import maybe_await @@ -32,34 +32,25 @@ class Command(ABC, metaclass=CommandMeta): #: The parsing Context used for this Command. Provided here for convenience - this reference to it is not used by #: any CLI Command Parser internals, so it is safe for subclasses to redefine / overwrite it. ctx: Context + __ctx: Context - def __new__(cls): + def __new__(cls) -> Command: # By storing the Context here instead of __init__, every single subclass won't need to # call super().__init__(...) from their own __init__ for this step self = super().__new__(cls) self.__ctx = ctx = get_or_create_context(cls, command=self) if not hasattr(self, 'ctx'): - self.ctx: Context = ctx # noqa # PyCharm complains this is invalid, but doesn't understand it without it + self.ctx = ctx # noqa # PyCharm complains this is invalid, but doesn't understand it without it return self def __repr__(self) -> str: cls = self.__class__ - return f'<{cls.__name__} in prog={cls.__class__.meta(cls).prog!r}>' + return f'<{cls.__name__} in prog={get_metadata(cls).prog!r}>' # region Parse & Run @classmethod - @overload - def parse_and_run(cls: Type[CommandObj], argv: Argv = None, **kwargs) -> CommandObj | None: - # These overloads indicate that an instance of the same type or another may be returned - ... - - @classmethod - @overload - def parse_and_run(cls, argv: Argv = None, **kwargs) -> CommandObj | None: ... - - @classmethod - def parse_and_run(cls, argv=None, **kwargs): + def parse_and_run(cls: Type[CommandObj], argv: Argv | None = None, **kwargs) -> CommandObj | None: """ Primary entry point for parsing arguments, resolving subcommands, and running a command. @@ -90,15 +81,7 @@ def parse_and_run(cls, argv=None, **kwargs): # region Parse @classmethod - @overload - def parse(cls: Type[CommandObj], argv: Argv = None) -> CommandObj: ... - - @classmethod - @overload - def parse(cls, argv: Argv = None) -> CommandObj: ... - - @classmethod - def parse(cls, argv=None): + def parse(cls: Type[CommandObj], argv: Argv | None = None) -> CommandObj: """ Parses the specified arguments (or :data:`sys.argv`), and resolves the final subcommand class based on the parsed arguments, if necessary. Initializes the Command, but does not call any of its other methods. @@ -111,7 +94,7 @@ def parse(cls, argv=None): with ExitStack() as stack: stack.enter_context(ctx) while sub_cmd := parse_args_and_get_next_cmd(ctx): - cmd_cls = sub_cmd + cmd_cls = sub_cmd # type: ignore[assignment] ctx = stack.enter_context(ctx._sub_context(cmd_cls)) return cmd_cls() @@ -308,14 +291,14 @@ async def parse_and_await(cls, argv=None, **kwargs): await maybe_await(self(**kwargs)) return self - async def __call__(self, *args, **kwargs) -> int: + async def __call__(self, *args, **kwargs) -> int: # type: ignore[override] """Asynchronous version of :meth:`Command.__call__`.""" - with self._Command__ctx as ctx, ctx.get_error_handler(): # noqa + with self._Command__ctx as ctx, ctx.get_error_handler(): # type: ignore[attr-defined] await maybe_await(self._pre_init_actions_(*args, **kwargs)) await maybe_await(self._init_command_(*args, **kwargs)) await maybe_await(self._before_main_(*args, **kwargs)) try: - await maybe_await(self.main(*args, **kwargs)) + await maybe_await(self.main(*args, **kwargs)) # type: ignore[arg-type] except BaseException: if ctx.config.always_run_after_main: log.debug('Caught exception - running _after_main_ before propagating', exc_info=True) @@ -328,7 +311,7 @@ async def __call__(self, *args, **kwargs) -> int: async def _run_actions_(self, phase: ActionPhase, args: tuple, kwargs: dict): """Asynchronous version of :meth:`Command._run_actions_`.""" - for param in self._Command__ctx.iter_action_flags(phase): # noqa + for param in self._Command__ctx.iter_action_flags(phase): # type: ignore[attr-defined] await maybe_await(param.func(self, *args, **kwargs)) async def _pre_init_actions_(self, *args, **kwargs): @@ -340,9 +323,9 @@ async def _before_main_(self, *args, **kwargs): """Asynchronous version of :meth:`Command._before_main_`.""" await self._run_actions_(ActionPhase.BEFORE_MAIN, args, kwargs) - async def main(self, *args, **kwargs) -> int | None: + async def main(self, *args, **kwargs) -> int | None: # type: ignore[override] """Asynchronous version of :meth:`Command.main`.""" - with self._Command__ctx as ctx: # noqa + with self._Command__ctx as ctx: # type: ignore[attr-defined] action = get_params(self).action if action is not None and (ctx.actions_taken == 0 or ctx.config.action_after_action_flags): ctx.actions_taken += 1 @@ -355,7 +338,7 @@ async def _after_main_(self, *args, **kwargs): await self._run_actions_(ActionPhase.AFTER_MAIN, args, kwargs) -def main(argv: Argv = None, return_command: Bool = False, **kwargs) -> CommandObj | None: +def main(argv: Argv | None = None, return_command: Bool = False, **kwargs) -> CommandObj | None: """ Convenience function that can be used as the main entry point for a program. diff --git a/lib/cli_command_parser/context.py b/lib/cli_command_parser/context.py index 794b5d94..3d27cd1d 100644 --- a/lib/cli_command_parser/context.py +++ b/lib/cli_command_parser/context.py @@ -3,18 +3,18 @@ :author: Doug Skrypa """ -# pylint: disable=R0801 from __future__ import annotations import sys from collections import defaultdict +from collections.abc import Collection from contextlib import AbstractContextManager from contextvars import ContextVar from enum import Enum from functools import cached_property from inspect import Parameter as _Parameter, Signature -from typing import TYPE_CHECKING, Any, Callable, Collection, Iterator, Sequence, cast +from typing import TYPE_CHECKING, Any, Callable, Iterator, Literal, Sequence, Type, TypeAlias, cast, overload from .config import DEFAULT_CONFIG, CommandConfig from .error_handling import ErrorHandler, NullErrorHandler, extended_error_handler @@ -26,13 +26,14 @@ from .commands import Command from .core import CommandMeta from .parameters import ActionFlag, BaseOption, Parameter - from .typing import AnyConfig, Bool, CommandCls, CommandObj, OptStr, ParamOrGroup, PathLike, StrSeq + from .typing import AnyConfig, Bool, OptStr, ParamOrGroup, PathLike, StrSeq Argv = StrSeq | None + CommandCls: TypeAlias = Type[Command] | CommandMeta __all__ = ['Context', 'ctx', 'get_current_context', 'get_or_create_context', 'get_context', 'get_parsed', 'get_raw_arg'] -_context_stack = ContextVar('cli_command_parser.context.stack') +_context_stack: ContextVar[list[Context]] = ContextVar('cli_command_parser.context.stack') _TERMINAL = Terminal() @@ -47,9 +48,10 @@ class Context(AbstractContextManager): # Extending AbstractContextManager to ma config: CommandConfig prog: OptStr = None allow_argv_prog: Bool = True - _command_obj: CommandObj | None = None + _command_obj: Command | None = None _terminal_width: int | None _provided: dict[ParamOrGroup, int] + _parsed: dict[ParamOrGroup, Any] def __init__( self, @@ -60,7 +62,7 @@ def __init__( config: AnyConfig | None = None, terminal_width: int | None = None, allow_argv_prog: Bool = None, - command: CommandObj | None = None, + command: Command | None = None, **kwargs, ): self.command_cls = command_cls @@ -87,7 +89,7 @@ def __init__( @classmethod def for_prog(cls, prog: PathLike, *args, **kwargs) -> Context: self = cls(*args, **kwargs) - self.prog = getattr(prog, 'name', prog) + self.prog = getattr(prog, 'name', prog) # type: ignore[arg-type] return self def _set_argv(self, prog: OptStr, argv: Argv): @@ -101,7 +103,7 @@ def _set_argv(self, prog: OptStr, argv: Argv): self.remaining = list(self.argv) def _sub_context( - self, command_cls: CommandCls, argv: Argv = None, command: CommandObj | None = None, **kwargs + self, command_cls: CommandCls, argv: Argv = None, command: Command | None = None, **kwargs ) -> Context: return self.__class__( self.remaining if argv is None else argv, @@ -249,13 +251,16 @@ def num_provided(self, param: ParamOrGroup) -> int: def get_missing(self) -> list[Parameter]: """Not intended to be called by users. Used during parsing to determine if any Parameters are missing.""" - return [p for p in self.params.required_check_params() if not self._provided[p]] + if self.params: + return [p for p in self.params.required_check_params() if not self._provided[p]] + return [] def missing_options_with_env_var(self) -> Iterator[BaseOption]: """Yields Option parameters that have an environment variable configured, and did not have any CLI values.""" - for param in self.params.options: - if param.env_var and not self._provided[param]: - yield param + if self.params: + for param in self.params.options: + if param.env_var and not self._provided[param]: + yield param # endregion @@ -268,7 +273,8 @@ def _parsed_action_flags(self) -> tuple[int, list[ActionFlag], list[ActionFlag]] action flags to run before main, and the action flags to run after main. """ try: - before_main, after_main = self.params.split_action_flags # Each part is already sorted + # Each part is already sorted + before_main, after_main = self.params.split_action_flags # type: ignore[union-attr] except AttributeError: # self.command_cls is None return 0, [], [] @@ -427,6 +433,14 @@ def config(self) -> CommandConfig: # region Public / Semi-Public Functions +@overload +def get_current_context(silent: Literal[False] = False) -> Context: ... + + +@overload +def get_current_context(silent: Literal[True]) -> Context | None: ... + + def get_current_context(silent: bool = False) -> Context | None: """ Get the currently active parsing context. @@ -444,7 +458,7 @@ def get_current_context(silent: bool = False) -> Context | None: def get_or_create_context( - command_cls: CommandCls, argv: Argv = None, *, command: CommandObj | None = None, **kwargs + command_cls: CommandCls, argv: Argv = None, *, command: Command | None = None, **kwargs ) -> Context: """ Used internally by Commands to re-use an existing user-activated Context, or to create a new Context if there was @@ -464,7 +478,7 @@ def get_context(command: Command) -> Context: :return: The Context associated with the given Command """ try: - return command._Command__ctx # noqa + return command._Command__ctx # type: ignore[attr-defined] except AttributeError as e: raise TypeError('get_context only supports Command objects') from e diff --git a/lib/cli_command_parser/conversion/argparse_ast.py b/lib/cli_command_parser/conversion/argparse_ast.py index 3c448ad6..ae321d11 100644 --- a/lib/cli_command_parser/conversion/argparse_ast.py +++ b/lib/cli_command_parser/conversion/argparse_ast.py @@ -9,7 +9,7 @@ from functools import cached_property, partial from inspect import BoundArguments, Signature from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Iterator, Literal, Type, TypeVar, overload +from typing import TYPE_CHECKING, Callable, ClassVar, Generic, Literal, Type, TypeAlias, TypeVar, overload try: from typing import Self @@ -22,8 +22,9 @@ if TYPE_CHECKING: from collections.abc import Collection + from typing import Any, Iterator - from cli_command_parser.typing import PathLike + from cli_command_parser.typing import OptStr, PathLike from .visitor import TrackedRef, TrackedRefMap @@ -36,6 +37,7 @@ ParserObj = TypeVar('ParserObj', bound='AstArgumentParser') RepresentedCallable = Callable AC = TypeVar('AC', bound='AstCallable') +ACGroup: TypeAlias = tuple[Type[AC], list[AC]] D = TypeVar('D') VisitFunc = Callable[[InitNode, OptCall, 'TrackedRefMap'], AC] @@ -255,11 +257,11 @@ def init_func_raw_kwargs(self) -> dict[str, AST]: kwargs.update(kwargs.pop('kwargs')) return kwargs - def _init_func_kwargs(self) -> dict[str, str]: + def _init_func_kwargs(self) -> dict[str, OptStr]: return {key: unparse(val) for key, val in self.init_func_raw_kwargs.items()} @cached_property - def init_func_kwargs(self) -> dict[str, str]: + def init_func_kwargs(self) -> dict[str, OptStr]: return self._init_func_kwargs() def init_call_repr(self) -> str: @@ -318,9 +320,9 @@ def add_mutually_exclusive_group( def add_argument_group(self, node: InitNode, call: OptCall, tracked_refs: TrackedRefMap) -> ArgGroup: return self._add_child(ArgGroup, self.groups, node, call, tracked_refs) - def grouped_children(self) -> Iterator[tuple[Type[AC], list[AC]]]: - yield ParserArg, self.args # type: ignore[misc] - yield ArgGroup, self.groups # type: ignore[misc] + def grouped_children(self) -> Iterator[ACGroup]: + yield ParserArg, self.args + yield ArgGroup, self.groups # region Output Methods @@ -400,7 +402,7 @@ class SubParser(AstArgumentParser, represents=_SubParsersAction.add_parser): sp_parent: SubparsersAction @cached_property - def init_func_kwargs(self) -> dict[str, str]: + def init_func_kwargs(self) -> dict[str, OptStr]: kwargs = self.sp_parent.init_func_kwargs.copy() kwargs.update(self._init_func_kwargs()) return kwargs diff --git a/lib/cli_command_parser/conversion/cli.py b/lib/cli_command_parser/conversion/cli.py index 0524e7b2..81a8943a 100644 --- a/lib/cli_command_parser/conversion/cli.py +++ b/lib/cli_command_parser/conversion/cli.py @@ -4,7 +4,7 @@ from functools import cached_property from pathlib import Path -from cli_command_parser import Command, Counter, Flag, ParamGroup, Positional, SubCommand, main +from cli_command_parser import Command, Counter, Flag, Param, ParamGroup, Positional, SubCommand, main from cli_command_parser.inputs import Path as IPath log = logging.getLogger(__name__) @@ -16,11 +16,13 @@ class ParserConverter(Command): """Tool to convert an argparse.ArgumentParser into a cli-command-parser Command""" action = SubCommand() - input: Path - no_smart_for = Flag('-S', help='Disable "smart" for loop handling that attempts to dedupe common subparser params') + input: Param[Path] + no_smart_for: Param[bool] = Flag( + '-S', help='Disable "smart" for loop handling that attempts to dedupe common subparser params' + ) with ParamGroup('Common'): verbose = Counter('-v', help='Increase logging verbosity (can specify multiple times)') - dry_run = Flag('-D', help='Print the actions that would be taken instead of taking them') + dry_run: Flag[bool] = Flag('-D', help='Print the actions that would be taken instead of taking them') def _init_command_(self): log_fmt = '%(asctime)s %(levelname)s %(name)s %(lineno)d %(message)s' if self.verbose > 1 else '%(message)s' @@ -38,7 +40,7 @@ def script(self): class Convert(ParserConverter): """Print the cli-command-parser Commands that are equivalent to the discovered argparse ArgumentParsers""" - input: Path = Positional(type=IPath(type='file', exists=True), help=f'A file containing an {arg_parser}') + input: Param[Path] = Positional(type=IPath(type='file', exists=True), help=f'A file containing an {arg_parser}') add_methods = Flag('--no-methods', '-M', default=True, help='Do not include boilerplate methods in Commands') def main(self): @@ -50,7 +52,7 @@ def main(self): class Pprint(ParserConverter): """Print a tiered internal representation of the discovered argparse ArgumentParsers and their groups/arguments""" - input: Path = Positional(type=IPath(type='file', exists=True), help=f'A file containing an {arg_parser}') + input: Param[Path] = Positional(type=IPath(type='file', exists=True), help=f'A file containing an {arg_parser}') def main(self): for parser in self.script.parsers: diff --git a/lib/cli_command_parser/conversion/command_builder.py b/lib/cli_command_parser/conversion/command_builder.py index 491de05b..34100da3 100644 --- a/lib/cli_command_parser/conversion/command_builder.py +++ b/lib/cli_command_parser/conversion/command_builder.py @@ -7,24 +7,26 @@ from dataclasses import dataclass, fields from functools import cached_property from itertools import count -from typing import TYPE_CHECKING, Generic, Iterable, Iterator, Type, TypeVar +from typing import TYPE_CHECKING, Generic, Iterable, Iterator, MutableMapping, Type, TypeVar from cli_command_parser.nargs import Nargs -from .argparse_ast import AC, ArgGroup, AstArgumentParser, MutuallyExclusiveGroup, ParserArg, Script +from .argparse_ast import ArgGroup, AstArgumentParser, AstCallable, MutuallyExclusiveGroup, ParserArg, Script from .utils import collection_contents, unparse if TYPE_CHECKING: - from cli_command_parser.typing import OptStr + from cli_command_parser.typing import Bool, OptStr from .argparse_ast import ArgCollection __all__ = ['convert_script'] log = logging.getLogger(__name__) +AC = TypeVar('AC', bound=AstCallable | Script) +ACol = TypeVar('ACol', bound='ArgCollection') C = TypeVar('C', bound='Converter') -RESERVED = set(keyword.kwlist) | set(getattr(keyword, 'softkwlist', ('_', 'case', 'match'))) # soft was added in 3.9 +RESERVED = set(keyword.kwlist) | set(keyword.softkwlist) # TODO: Handle argparse.SUPPRESS ('==SUPPRESS==') @@ -35,7 +37,7 @@ def convert_script(script: Script, add_methods: bool = False) -> str: class Converter(Generic[AC], ABC): converts: Type[AC] | None = None newline_between_members: bool = False - _ac_converter_map = {} + _ac_converter_map = {} # type: ignore[var-annotated] def __init_subclass__(cls, converts: Type[AC] | None = None, newline_between_members: bool | None = None, **kwargs): super().__init_subclass__(**kwargs) @@ -51,16 +53,15 @@ def __init__(self, ast_obj: AC | Script, parent: Converter | None = None): @classmethod def for_ast_callable(cls, ast_obj: AC | Type[AC]) -> Type[Converter[AC]]: - if not isinstance(ast_obj, type): - ast_obj = ast_obj.__class__ + ast_cls = ast_obj if isinstance(ast_obj, type) else ast_obj.__class__ try: - return cls._ac_converter_map[ast_obj] + return cls._ac_converter_map[ast_cls] except KeyError: pass for converts_cls, converter_cls in cls._ac_converter_map.items(): - if issubclass(ast_obj, converts_cls): + if issubclass(ast_cls, converts_cls): return converter_cls - raise TypeError(f'No Converter is registered for {ast_obj.__class__.__name__} objects') + raise TypeError(f'No Converter is registered for {ast_cls.__class__.__name__} objects') @classmethod def init_for_ast_callable(cls, ast_obj: AC, *args, **kwargs) -> Converter[AC]: @@ -119,10 +120,10 @@ def format_lines(self, indent: int = 0) -> Iterator[str]: yield from ParserConverter(parser, counter=counter, add_methods=self.add_methods).format_lines() -class CollectionConverter(Converter[AC], ABC): - ast_obj: ArgCollection +class CollectionConverter(Converter[ACol], ABC): + ast_obj: ACol parent: CollectionConverter | None - _name_mode = None + _name_mode: OptStr = None @cached_property def name_mode(self) -> str | None: @@ -130,7 +131,10 @@ def name_mode(self) -> str | None: @cached_property def grouped_children(self) -> list[ConverterGroup[ParamConverter | GroupConverter | Converter]]: - return [self.for_ast_callable(cg_cls).init_group(self, cg) for cg_cls, cg in self.ast_obj.grouped_children()] + return [ + self.for_ast_callable(cg_cls).init_group(self, cg) # type: ignore[misc] + for cg_cls, cg in self.ast_obj.grouped_children() + ] def descendant_args(self) -> Iterator[ParamConverter]: for child_group in self.grouped_children: @@ -138,9 +142,9 @@ def descendant_args(self) -> Iterator[ParamConverter]: continue elif hasattr(child_group[0], 'descendant_args'): for child in child_group: - yield from child.descendant_args() + yield from child.descendant_args() # type: ignore[union-attr] elif isinstance(child_group[0], ParamConverter): - yield from child_group + yield from child_group # type: ignore[misc] def format_members(self, prefix: str, indent: int = 4) -> Iterator[str]: last = False @@ -284,7 +288,7 @@ def name_mode(self) -> str | None: return self._name_mode or (self.parent.name_mode if self.parent else None) @cached_property - def _name_mode(self) -> str | None: + def _name_mode(self) -> str | None: # type: ignore[override] if self.parent and self.parent._name_mode: return None name_modes = {pc._name_mode for pc in self.descendant_args() if pc.is_option and '_' in pc.attr_name} @@ -330,8 +334,8 @@ def __init__(self, arg: ParserArg, parent: CollectionConverter, num: int): super().__init__(arg, parent) self.num = num - def __eq__(self, other: ParamConverter) -> bool: - return self.ast_obj == other.ast_obj and self.num == other.num + def __eq__(self, other) -> bool: + return isinstance(other, ParamConverter) and self.ast_obj == other.ast_obj and self.num == other.num def __lt__(self, other: ParamConverter) -> bool: if self.is_positional and not other.is_positional: @@ -360,7 +364,7 @@ def attr_name(self) -> str: @cached_property def name_mode(self) -> str | None: - return None if self.parent.name_mode else self._name_mode + return None if self.parent and self.parent.name_mode else self._name_mode @cached_property def _name_mode(self) -> str | None: @@ -452,12 +456,12 @@ def is_pass_thru(self) -> bool: return nargs in self.ast_obj.get_tracked_refs('argparse', 'REMAINDER', ()) @cached_property - def is_positional(self) -> bool: + def is_positional(self) -> Bool: long, short, plain = self._grouped_opt_strs return plain and not long and not short @cached_property - def is_option(self) -> bool: + def is_option(self) -> Bool: long, short, plain = self._grouped_opt_strs return (long or short) and not plain @@ -635,10 +639,10 @@ class FlagArgs(OptionArgs): const: OptStr = None @classmethod - def init_flag(cls, action: str, const: OptStr = None, default: OptStr = None, **kwargs): + def init_flag(cls, action: OptStr, const: OptStr = None, default: OptStr = None, **kwargs): values = {'store_true': ('True', 'False'), 'store_false': ('False', 'True')} try: - value, opposite = values[action] + value, opposite = values[action] # type: ignore[index] except KeyError: if action == 'store_const': action = None @@ -671,10 +675,10 @@ def init_counter(cls, const: OptStr = None, default: OptStr = None, **kwargs): # endregion -def literal_eval_or_none(expr: str) -> str | None: +def literal_eval_or_none(expr: str | None) -> str | None: try: - return literal_eval(expr) - except ValueError: + return literal_eval(expr) # type: ignore[arg-type] + except ValueError: # expr could not be evaluated, or it was None return None diff --git a/lib/cli_command_parser/core.py b/lib/cli_command_parser/core.py index dad4839f..38420741 100644 --- a/lib/cli_command_parser/core.py +++ b/lib/cli_command_parser/core.py @@ -19,11 +19,9 @@ from .utils import _NotSet, _NotSetType if TYPE_CHECKING: - from .commands import Command - from .typing import AnyConfig, CommandCls, Config, OptStr + from .typing import AnyConfig, CommandAny, CommandCls, Config, OptStr Bases = tuple[type, ...] - CommandAny = Union['CommandMeta', Command] Choice = str | None | _NotSetType Choices = Mapping[str, str | None] | Collection[str] OptChoices = Choices | None diff --git a/lib/cli_command_parser/documentation.py b/lib/cli_command_parser/documentation.py index 2693d531..81021f11 100644 --- a/lib/cli_command_parser/documentation.py +++ b/lib/cli_command_parser/documentation.py @@ -13,7 +13,7 @@ from fnmatch import fnmatch from importlib.util import module_from_spec, spec_from_file_location from pathlib import Path -from typing import TYPE_CHECKING, Any, Iterable, Mapping +from typing import TYPE_CHECKING, Any, Iterable, Mapping, Type from .commands import Command from .context import Context @@ -22,8 +22,9 @@ from .formatting.restructured_text import MODULE_TEMPLATE, rst_header, rst_toc_tree if TYPE_CHECKING: - from .typing import Bool, CommandCls, OptStr, PathLike, Strings + from .typing import Bool, OptStr, PathLike, Strings + CommandCls = Type[Command] Commands = dict[str, CommandCls] __all__ = ['render_script_rst', 'render_command_rst', 'load_commands', 'RstWriter'] @@ -79,7 +80,7 @@ def load_commands(path: PathLike, top_only: Bool = False, include_abc: Bool = Fa Load all of the commands from the file with the given path and return them as a dict of ``{name: Command}``. If an :class:`python:OSError` or a subclass thereof is encountered while attempting to load the file (due to the - path not existing, or a permission error, etc), it will be allowed to propagate. An :class:`python:ImportError` + path not existing, or a permission error, etc.), it will be allowed to propagate. An :class:`python:ImportError` may be raised by :func:`import_module` if the specified path cannot be imported. :param path: The path to a file containing one or more :class:`.Command` classes @@ -136,7 +137,12 @@ def import_module(path: PathLike): if path.is_dir(): path /= '__init__.py' - spec = spec_from_file_location(name, path) + if not (spec := spec_from_file_location(name, path)): + path_str = path.as_posix() + raise ImportError( + f'Unable to find module={name!r} at path={path_str!r} - are you sure it is a Python module?', path=path_str + ) + try: module = module_from_spec(spec) except AttributeError as e: @@ -145,7 +151,7 @@ def import_module(path: PathLike): sys.modules[spec.name] = module # This is required for the program metadata introspection try: - spec.loader.exec_module(module) + spec.loader.exec_module(module) # type: ignore[union-attr] except Exception: del sys.modules[spec.name] raise @@ -264,7 +270,7 @@ def document_scripts( ): names = [self.document_script(path, subdir, top_only=top_only, **kwargs) for path in paths] if index_name or index_header or index_subdir: - name = index_name or subdir + name: str = index_name or subdir or index_header or index_subdir # type: ignore[assignment] self.write_index( name, index_header or name.title(), names, content_subdir=subdir, caption=caption, subdir=index_subdir ) @@ -310,7 +316,7 @@ def document_package( index_subdir = content_subdir = f'{subdir}/{name}' if subdir else name else: index_subdir = None - content_subdir = subdir + content_subdir = subdir # type: ignore[assignment] # TODO: This needs improvement for multi-package repos contents = self._generate_code_rsts(pkg_name, pkg_path, content_subdir, max_depth=max_depth) diff --git a/lib/cli_command_parser/formatting/commands.py b/lib/cli_command_parser/formatting/commands.py index 1f6fc4bf..84650f52 100644 --- a/lib/cli_command_parser/formatting/commands.py +++ b/lib/cli_command_parser/formatting/commands.py @@ -8,7 +8,7 @@ from functools import cached_property from textwrap import TextWrapper -from typing import TYPE_CHECKING, Callable, Iterable, Iterator, Type +from typing import TYPE_CHECKING, Callable, Iterable, Iterator, Type, TypeAlias from ..context import NoActiveContext, ctx from ..core import get_metadata, get_params @@ -20,11 +20,14 @@ if TYPE_CHECKING: from ..command_parameters import CommandParameters + from ..commands import Command from ..config import CommandConfig from ..core import CommandMeta from ..metadata import ProgramMetadata from ..parameters import BaseOption, BasePositional, Parameter, PassThru, SubCommand - from ..typing import Bool, CommandAny, CommandCls, OptStr + from ..typing import Bool, OptStr + + CommandCls: TypeAlias = Type[Command] | CommandMeta __all__ = ['CommandHelpFormatter', 'get_formatter'] @@ -32,7 +35,7 @@ class CommandHelpFormatter: - def __init__(self, command: CommandMeta, params: CommandParameters): + def __init__(self, command: CommandCls, params: CommandParameters): self.command = command self.params = params self.pos_group = ParamGroup(description='Positional arguments') @@ -229,21 +232,22 @@ def _fix_name(name: str) -> str: return camel_to_snake_case(name).replace('_', ' ').title() -def get_formatter(command: CommandAny) -> CommandHelpFormatter: +def get_formatter(command: CommandCls | Command) -> CommandHelpFormatter: """Get the :class:`CommandHelpFormatter` for the given Command""" return get_params(command).formatter -def get_usage_sub_cmds(command: CommandCls): - cmd_mcs: Type[CommandMeta] = command.__class__ # Using metaclass to avoid potentially overwritten attrs +def get_usage_sub_cmds(command: CommandCls | CommandMeta): + # Using metaclass to avoid potentially overwritten attrs + cmd_mcs: Type[CommandMeta] = command.__class__ # type: ignore[assignment] - parent: CommandMeta + parent: CommandMeta | None if not (parent := cmd_mcs.parent(command, False)): return yield from get_usage_sub_cmds(parent) - sub_cmd_param: SubCommand + sub_cmd_param: SubCommand | None if not (sub_cmd_param := cmd_mcs.params(parent).sub_command): return diff --git a/lib/cli_command_parser/formatting/params.py b/lib/cli_command_parser/formatting/params.py index 9e43624f..297d1b0a 100644 --- a/lib/cli_command_parser/formatting/params.py +++ b/lib/cli_command_parser/formatting/params.py @@ -3,18 +3,18 @@ :author: Doug Skrypa """ -# pylint: disable=W0613 from __future__ import annotations +from abc import ABC, abstractmethod from functools import cached_property -from typing import TYPE_CHECKING, Callable, Iterable, Iterator, Type +from typing import TYPE_CHECKING, Callable, ClassVar, Generic, Iterable, Iterator, Type, TypeVar from ..config import CmdAliasMode, SubcommandAliasHelpMode from ..context import ctx from ..core import get_config from ..parameters import ParamGroup, PassThru, TriFlag -from ..parameters.base import BaseOption, BasePositional +from ..parameters.base import BaseOption, BasePositional, ParamBase, Parameter from ..parameters.choice_map import Choice, ChoiceMap from .restructured_text import Cell, Row, RstTable from .utils import _should_add_default, format_help_entry @@ -22,23 +22,30 @@ if TYPE_CHECKING: from ..nargs import Nargs from ..parameters.option_strings import TriFlagOptionStrings - from ..typing import Bool, OptStr, ParamOrGroup + from ..typing import Bool, OptStr BoolFormatterMap = dict[bool, Callable[[str], str]] -class ParamHelpFormatter: +BaseP = TypeVar('BaseP', bound=ParamBase) +ParamP = TypeVar('ParamP', bound=Parameter) +PosP = TypeVar('PosP', bound=BasePositional) +OptP = TypeVar('OptP', bound=BaseOption) + + +class ParamHelpFormatter(ABC, Generic[BaseP]): __slots__ = ('param',) - _param_cls_fmt_cls_map = {} + + _param_cls_fmt_cls_map: ClassVar[dict[Type[ParamBase], Type[ParamHelpFormatter]]] = {} required_formatter_map: BoolFormatterMap = {False: '[{}]'.format} - def __init_subclass__(cls, param_cls: Type[ParamOrGroup] = None, **kwargs): + def __init_subclass__(cls, param_cls: Type[BaseP] | None = None, **kwargs): super().__init_subclass__(**kwargs) if param_cls is not None: cls._param_cls_fmt_cls_map[param_cls] = cls @classmethod - def for_param_cls(cls, param_cls: Type[ParamOrGroup]): + def for_param_cls(cls, param_cls: Type[BaseP]) -> Type[ParamHelpFormatter[BaseP]]: try: return cls._param_cls_fmt_cls_map[param_cls] except KeyError: @@ -48,13 +55,13 @@ def for_param_cls(cls, param_cls: Type[ParamOrGroup]): if issubclass(param_cls, p_cls): return f_cls - return ParamHelpFormatter + raise ValueError(f'No help formatter is available for {param_cls.__name__} objects') - def __new__(cls, param: ParamOrGroup): + def __new__(cls, param): return super().__new__(cls.for_param_cls(param.__class__) if cls is ParamHelpFormatter else cls) - def __init__(self, param: ParamOrGroup): - self.param = param + def __init__(self, param: BaseP): + self.param: BaseP = param def __getnewargs__(self): return (self.param,) @@ -69,6 +76,45 @@ def maybe_wrap_usage(self, text: str) -> str: except KeyError: return text + def format_basic_usage(self) -> str: + """Format the Parameter for use in the ``usage:`` line""" + return self.maybe_wrap_usage(self.format_usage(True)) + + @abstractmethod + def format_usage(self, include_meta: Bool = False, full: Bool = False, delim: str = ', ') -> str: + """Format the Parameter for use in both the ``usage:`` line and in the list of Parameters""" + raise NotImplementedError + + def iter_usage_parts(self, include_meta: Bool = False, full: Bool = False) -> Iterator[str]: + """Format the Parameter for use in the list of Parameters with their ``help='...'`` descriptions""" + yield self.format_usage(include_meta=include_meta, full=full) + + @abstractmethod + def format_description(self, rst: Bool = False, *, description: OptStr = None) -> str: + raise NotImplementedError + + def format_help(self, prefix: str = '') -> str: + usage_iter = self.iter_usage_parts(include_meta=True, full=True) + return format_help_entry(usage_iter, self.format_description(), prefix) + + # region RST + + def rst_usage(self) -> str: + return '``' + self.format_usage(include_meta=True, full=True) + '``' + + def rst_row(self) -> tuple[str, str]: + """Returns a tuple of (usage, description)""" + return self.rst_usage(), self.format_description(rst=True) + + def rst_rows(self) -> Iterator[tuple[str, str]]: + yield self.rst_row() + + # endregion + + +class ParameterHelpFormatter(ParamHelpFormatter[ParamP], param_cls=Parameter): + __slots__ = () + def format_metavar(self) -> str: param = self.param if param.metavar and param.action.accepts_values: @@ -77,7 +123,7 @@ def format_metavar(self) -> str: config = ctx.config if (t := param.type) is not None: try: - metavar = t.format_metavar(config.choice_delim, config.sort_choices) # noqa + metavar = t.format_metavar(config.choice_delim, config.sort_choices) # type: ignore[attr-defined] except Exception: # noqa # pylint: disable=W0703 pass else: @@ -108,19 +154,11 @@ def _format_usage_metavar(self, full: Bool = True) -> str: return f'{metavar} [{metavar} ...]' return metavar - def format_basic_usage(self) -> str: - """Format the Parameter for use in the ``usage:`` line""" - return self.maybe_wrap_usage(self.format_usage(True)) - def format_usage(self, include_meta: Bool = False, full: Bool = False, delim: str = ', ') -> str: """Format the Parameter for use in both the ``usage:`` line and in the list of Parameters""" return self.format_metavar() - def iter_usage_parts(self, include_meta: Bool = False, full: Bool = False) -> Iterator[str]: - """Format the Parameter for use in the list of Parameters with their ``help='...'`` descriptions""" - yield self.format_usage(include_meta=include_meta, full=full) - - def format_description(self, rst: Bool = False, description: OptStr = None) -> str: + def format_description(self, rst: Bool = False, *, description: OptStr = None) -> str: param = self.param if description is None: description = param.help or '' @@ -130,34 +168,16 @@ def format_description(self, rst: Bool = False, description: OptStr = None) -> s return description - def format_help(self, prefix: str = '') -> str: - usage_iter = self.iter_usage_parts(include_meta=True, full=True) - return format_help_entry(usage_iter, self.format_description(), prefix) - - # region RST - - def rst_usage(self) -> str: - return '``' + self.format_usage(include_meta=True, full=True) + '``' - - def rst_row(self) -> tuple[str, str]: - """Returns a tuple of (usage, description)""" - return self.rst_usage(), self.format_description(rst=True) - - def rst_rows(self) -> Iterator[tuple[str, str]]: - yield self.rst_row() - - # endregion - -class PositionalHelpFormatter(ParamHelpFormatter, param_cls=BasePositional): - param: BasePositional +class PositionalHelpFormatter(ParameterHelpFormatter[PosP], param_cls=BasePositional): + __slots__ = () def format_usage(self, include_meta: Bool = False, full: Bool = False, delim: str = ', ') -> str: return self._format_usage_metavar(full) -class OptionHelpFormatter(ParamHelpFormatter, param_cls=BaseOption): - param: BaseOption +class OptionHelpFormatter(ParameterHelpFormatter[OptP], param_cls=BaseOption): + __slots__ = () def iter_usage_parts(self, include_meta: Bool = False, full: Bool = False) -> Iterator[str]: opts = self.param.option_strs @@ -170,8 +190,8 @@ def iter_usage_parts(self, include_meta: Bool = False, full: Bool = False) -> It metavar = self._format_usage_metavar() yield from (f'{opt} {metavar}' for opt in opts.option_strs()) - def format_description(self, rst: Bool = False, description: OptStr = None) -> str: - description = super().format_description(rst, description) + def format_description(self, rst: Bool = False, *, description: OptStr = None) -> str: + description = super().format_description(rst, description=description) param: BaseOption = self.param if param.env_var and (param.show_env_var or (param.show_env_var is None and ctx.config.show_env_vars)): pad, quote = _pad_and_quote(description, rst) @@ -196,7 +216,9 @@ def rst_usage(self) -> str: return ', '.join(f'``{part}``' for part in self.iter_usage_parts()) -class TriFlagHelpFormatter(OptionHelpFormatter, param_cls=TriFlag): +class TriFlagHelpFormatter(OptionHelpFormatter[TriFlag], param_cls=TriFlag): + __slots__ = () + def format_usage(self, include_meta: Bool = False, full: Bool = False, delim: str = ', ') -> str: opts: TriFlagOptionStrings = self.param.option_strs if full: @@ -204,11 +226,11 @@ def format_usage(self, include_meta: Bool = False, full: Bool = False, delim: st else: return f'{opts.get_usage_opt(False)} | {opts.get_usage_opt(True)}' - def format_description(self, rst: Bool = False, alt: bool = False) -> str: + def format_description(self, rst: Bool = False, *, description: OptStr = None, alt: bool = False) -> str: if not alt: - return super().format_description(rst=rst) + return super().format_description(rst=rst, description=description) elif self.param.alt_help: - return super().format_description(rst=rst, description=self.param.alt_help) + return super().format_description(rst=rst, description=description or self.param.alt_help) return '' def format_help(self, prefix: str = '') -> str: @@ -225,11 +247,9 @@ def rst_rows(self) -> Iterator[tuple[str, str]]: yield usage, self.format_description(rst=True, alt=alt) -class ChoiceMapHelpFormatter(ParamHelpFormatter, param_cls=ChoiceMap): +class ChoiceMapHelpFormatter(ParameterHelpFormatter[ChoiceMap], param_cls=ChoiceMap): """Formatter for :class:`SubCommand` and :class:`Action` parameters (and any other params that extend ChoiceMap)""" - param: ChoiceMap - @cached_property def choice_groups(self) -> Iterable[ChoiceGroup]: return ChoiceGroup.group_choices(self.param.choices.values()) @@ -239,7 +259,7 @@ def format_metavar(self) -> str: config = ctx.config choices = (str(c) for c in (c.choice for cg in self.choice_groups for c in cg.choices) if c is not None) if config.sort_choices: - choices = sorted(choices) + choices = sorted(choices) # type: ignore[assignment] return f'{{{config.choice_delim.join(choices)}}}' else: return self.param.metavar or self.param.name.upper() @@ -248,7 +268,7 @@ def format_help(self, prefix: str = '') -> str: help_entry = format_help_entry(self.iter_usage_parts(), self.param.description, prefix, lpad=2) choices = self._format_choices(prefix) if ctx.config.sort_choices: - choices = sorted(choices) + choices = sorted(choices) # type: ignore[assignment] parts = ( f'{prefix}{self.param.title or self.param._default_title}:', @@ -266,7 +286,7 @@ def _format_choices(self, prefix: str = '') -> Iterator[str]: def rst_table(self) -> RstTable: rows = self._format_rst_rows() if ctx.config.sort_choices: - rows = sorted(rows) + rows = sorted(rows) # type: ignore[assignment] table = RstTable(self.param.title or self.param._default_title, self.param.description) table.add_rows(rows) @@ -306,7 +326,7 @@ def group_choices(cls, choices: Iterable[Choice]) -> Iterable[ChoiceGroup]: :param choices: The :class:`.Choice` objects that may contain aliases of each other. :return: The :class:`.ChoiceGroup` objects containing the grouped Choices. """ - target_choice_map = {} + target_choice_map = {} # type: ignore[var-annotated] for n, choice in enumerate(choices): key = (choice.target, n if choice.local else None, choice.help) try: @@ -329,7 +349,8 @@ def format(self, default_mode: CmdAliasMode, prefix: str = '') -> Iterator[str]: :return: Generator that yields formatted help text entries (strings) for the Choices in this group. """ for choice, usage, description in self.prepare(default_mode): - yield format_help_entry((usage,), description, lpad=4, prefix=prefix) + if usage is not None: + yield format_help_entry((usage,), description, lpad=4, prefix=prefix) def prepare(self, default_mode: CmdAliasMode) -> Iterator[tuple[Choice, OptStr, OptStr]]: """ @@ -349,14 +370,16 @@ def prepare(self, default_mode: CmdAliasMode) -> Iterator[tuple[Choice, OptStr, else: mode = default_mode - if mode == SubcommandAliasHelpMode.ALIAS: - yield from self.prepare_aliases() - elif mode == SubcommandAliasHelpMode.REPEAT: - yield from self.prepare_repeated() - elif mode == SubcommandAliasHelpMode.COMBINE: - yield self.prepare_combined() - else: # Treat as a format string - yield from self.prepare_aliases(mode) + match mode: + case SubcommandAliasHelpMode.ALIAS: + yield from self.prepare_aliases() + case SubcommandAliasHelpMode.REPEAT: + yield from self.prepare_repeated() + case SubcommandAliasHelpMode.COMBINE: + yield self.prepare_combined() + case str(): + # Treat it as a format string + yield from self.prepare_aliases(mode) def prepare_combined(self) -> tuple[Choice, OptStr, OptStr]: """ @@ -415,12 +438,13 @@ def prepare_repeated(self) -> Iterator[tuple[Choice, OptStr, OptStr]]: yield choice, choice.format_usage(), choice.help -class PassThruHelpFormatter(ParamHelpFormatter, param_cls=PassThru): +class PassThruHelpFormatter(ParameterHelpFormatter[PassThru], param_cls=PassThru): + __slots__ = () required_formatter_map = {True: '-- {}'.format, False: '[-- {}]'.format} -class GroupHelpFormatter(ParamHelpFormatter, param_cls=ParamGroup): # noqa # pylint: disable=W0223 - param: ParamGroup +class GroupHelpFormatter(ParamHelpFormatter[ParamGroup], param_cls=ParamGroup): + __slots__ = () required_formatter_map: BoolFormatterMap = {True: '{{{}}}'.format, False: '[{}]'.format} def _get_choice_delim(self) -> str: diff --git a/lib/cli_command_parser/parameters/__init__.py b/lib/cli_command_parser/parameters/__init__.py index d132b313..cd5466db 100644 --- a/lib/cli_command_parser/parameters/__init__.py +++ b/lib/cli_command_parser/parameters/__init__.py @@ -1,4 +1,4 @@ -from .base import BaseOption, BasePositional, Parameter +from .base import BaseOption, BasePositional, Param, Parameter from .choice_map import Action, SubCommand from .groups import ParamGroup from .options import ActionFlag, Counter, Flag, Option, TriFlag, action_flag, after_main, before_main, help_action diff --git a/lib/cli_command_parser/parameters/actions.py b/lib/cli_command_parser/parameters/actions.py index 7c16a121..4835989a 100644 --- a/lib/cli_command_parser/parameters/actions.py +++ b/lib/cli_command_parser/parameters/actions.py @@ -6,16 +6,18 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Callable, Generic, Iterable, Iterator, NoReturn, Sequence, TypeVar, Union +from enum import Enum +from typing import TYPE_CHECKING, ClassVar, Generic, NoReturn, TypeVar, Union from ..context import ctx from ..exceptions import BadArgument, InvalidChoice, MissingArgument, ParamConflict, ParamUsageError, TooManyArguments from ..inputs import InputType from ..nargs import Nargs -from ..utils import _NotSet, camel_to_snake_case +from ..utils import _NotSet, _NotSetType, camel_to_snake_case if TYPE_CHECKING: from .base import Parameter # noqa + from .options import Counter from ..typing import Bool, CommandObj, OptStr __all__ = [ @@ -30,16 +32,26 @@ 'StoreAll', ] -_PANotSet = object() -Param = TypeVar('Param', bound='Parameter') +class _PANotSetType(Enum): + """Provides the sentinel value for _PANotSet in a way that is fully compatible with type checkers.""" + + _PANotSet = '_PANotSet' + + def __str__(self) -> str: + return self.name + + +_PANotSet = _PANotSetType._PANotSet + +P = TypeVar('P', bound='Parameter') Found = Union[int, NoReturn] -class ParamAction(ABC, Generic[Param]): +class ParamAction(ABC, Generic[P]): __slots__ = ('param',) name: str - param: Param + # param: P default = _NotSet accepts_values: bool = False accepts_consts: bool = False @@ -56,8 +68,8 @@ def __init_subclass__( if accepts_consts is not None: cls.accepts_consts = accepts_consts - def __init__(self, param: Param): - self.param = param + def __init__(self, param: P): + self.param: P = param def __str__(self) -> str: return self.name @@ -171,10 +183,8 @@ def finalize_value(self, value): # region Mixins -class ValueMixin: +class _ValueAction(ParamAction[P], ABC): __slots__ = () - param: Param # noqa - get_default: Callable def set_value(self, value): if (prev := ctx.get_parsed_value(self.param)) is not _NotSet: @@ -204,12 +214,10 @@ def append_value(self, value): # parsed.extend(values) -class ConstMixin: +class _ConstAction(ParamAction[P], ABC): __slots__ = () - param: Param # noqa - get_default: Callable - add_const: Callable - add_value: Callable + + _append: ClassVar[bool] def __init_subclass__(cls, append: bool = False, **kwargs): super().__init_subclass__(**kwargs) @@ -249,15 +257,16 @@ def add_env_value(self, value: str, env_var: str) -> Found: self.append_const(const) else: self.set_const(const) - return 1 elif const: return self.add_const() + return 1 + # endregion -class Store(ValueMixin, ParamAction, default=None, accepts_values=True): +class Store(_ValueAction, default=None, accepts_values=True): __slots__ = () default_nargs = Nargs(1) @@ -293,7 +302,7 @@ def would_accept(self, value: str, combo: bool = False) -> bool: # endregion -class Append(ValueMixin, ParamAction, accepts_values=True): +class Append(_ValueAction, accepts_values=True): __slots__ = () default_nargs = Nargs('+') @@ -388,7 +397,7 @@ def finalize_value(self, value): # endregion -class BasicConstAction(ConstMixin, ParamAction, ABC, accepts_consts=True): +class BasicConstAction(_ConstAction, ABC, accepts_consts=True): __slots__ = () default_nargs = Nargs(0) @@ -485,9 +494,9 @@ def add_const(self, *, opt: OptStr = None, combo: bool = False) -> Found: def add_value(self, value: str, *, combo: bool = False, joined: Bool = False, env_var: OptStr = None) -> Found: ctx.record_action(self.param) - value = self.param.prepare_value(value, combo, env_var) - self.param.validate(value, joined) - self._add(value) + int_val = self.param.prepare_value(value, combo, env_var) + self.param.validate(int_val, joined) + self._add(int_val) return 1 # endregion diff --git a/lib/cli_command_parser/parameters/base.py b/lib/cli_command_parser/parameters/base.py index cea419e8..354ed702 100644 --- a/lib/cli_command_parser/parameters/base.py +++ b/lib/cli_command_parser/parameters/base.py @@ -44,7 +44,7 @@ _CmdCls = Type[Command] _CmdObjOrCls: TypeAlias = Command | _CmdCls -__all__ = ['Parameter', 'BasePositional', 'BaseOption'] +__all__ = ['Param', 'Parameter', 'BasePositional', 'BaseOption'] _group_stack: ContextVar[list[ParamGroup]] = ContextVar('cli_command_parser.parameters.base.group_stack') _is_numeric = re.compile(r'^-\d+$|^-\d*\.\d+?$').match @@ -60,6 +60,20 @@ TD = TypeVar('TD') +class Param(Generic[T]): + __slots__ = () + + if TYPE_CHECKING: + + @overload + def __get__(self, command: Literal[None], owner: Any = None) -> Self: ... + + @overload + def __get__(self, command: object, owner: Any = None) -> T | None: ... + + def __get__(self, command: object | None, owner: Any = None) -> Self | T | None: ... + + class ParamBase(ABC): """ Base class for :class:`Parameter` and :class:`.ParamGroup`. @@ -95,7 +109,6 @@ def __init__( self.required = required self.help = help self.hide = hide - # TODO: Make the --help flag a counter and allow some `hide=True` params to be shown with `-hh` or similar? self.name = name if param_groups := _group_stack.get(None): # If truthy, there's at least 1 active ParamGroup param_groups[-1].register(self) # This sets self.group = group @@ -171,7 +184,7 @@ def formatter(self) -> ParamHelpFormatter: except AttributeError: # self.command is None formatter_factory = ParamHelpFormatter - return formatter_factory(self) # noqa + return formatter_factory(self) # type: ignore[abstract] # __new__ automatically handles child selection @property @abstractmethod @@ -189,7 +202,7 @@ def format_help(self, *args, **kwargs) -> str: # endregion -class Parameter(ParamBase, Generic[T], ABC): +class Parameter(ParamBase, Param[T], ABC): """ Base class for all other parameters. It is not meant to be used directly. @@ -223,13 +236,13 @@ class Parameter(ParamBase, Generic[T], ABC): # fmt: off # Class attributes _action_map: dict[str, Type[ParamAction]] = {} - _repr_attrs: Strings | None = None #: Attributes to include in ``repr()`` output + _repr_attrs: Strings = () #: Attributes to include in ``repr()`` output # Instance attributes with class defaults metavar: OptStr = None nargs: Nargs # Expected to be set in subclasses type: Callable[[str], T] | None = None # Expected to be set in subclasses allow_leading_dash: AllowLeadingDash = AllowLeadingDash.NUMERIC # Set in some subclasses - default = _NotSet + default: T | _NotSetType = _NotSet default_cb: DefaultCallback | None = None show_default: Bool = None strict_default: Bool = False @@ -259,7 +272,7 @@ def __init__( # pylint: disable=R0913 metavar: OptStr = None, name: OptStr = None, required: Bool = False, - default: Any = _NotSet, + default: T | _NotSetType = _NotSet, default_cb: DefaultFunc | None = None, cb_with_cmd: Bool = False, show_default: Bool = None, @@ -356,7 +369,7 @@ def __repr__(self) -> str: skip = (None, _NotSet) attrs = ( - (a, str(v) if a == 'action' else v) + (a, str(v) if a == 'action' else tuple(v) if a == 'choices' else v) # type: ignore[arg-type] for a in names if (v := getattr(self, a, None)) not in skip and not (a == 'hide' and not v) ) @@ -395,7 +408,7 @@ def prepare_validation_value(self, value: str, short_combo: Bool = False) -> T | return self.prepare_value(value, short_combo) - def validate(self, value: T | None, joined: Bool = False): + def validate(self, value: Any, joined: Bool = False): if not isinstance(value, str) or not value or not value[0] == '-': return elif self.allow_leading_dash == AllowLeadingDash.NUMERIC: @@ -422,12 +435,12 @@ def is_valid_arg(self, value: Any) -> bool: # region Parse Results / Argument Value Handling @overload - def __get__(self, command: Literal[None], owner: Any) -> Self: ... + def __get__(self, command: Literal[None], owner: Any = None) -> Self: ... @overload - def __get__(self, command: object, owner: Any) -> T | None: ... + def __get__(self, command: object, owner: Any = None) -> T | None: ... - def __get__(self, command: object | None, owner: Any) -> Self | T | None: + def __get__(self, command: object | None, owner: Any = None) -> Self | T | None: if command is None: return self @@ -559,7 +572,7 @@ class - it is not meant to be used directly. show_env_var: Bool = None strict_env: Bool use_env_value: Bool - const = _NotSet + const: T | _NotSetType = _NotSet def __init__( self, @@ -617,12 +630,18 @@ def __init__(self, default: AllowLeadingDash = AllowLeadingDash.NUMERIC): def __set_name__(self, owner, name: str): self.name = name - def __get__(self, instance: Parameter | None, owner) -> AllowLeadingDash | AllowLeadingDashProperty: + @overload + def __get__(self, instance: None, owner: Any) -> AllowLeadingDashProperty: ... + + @overload + def __get__(self, instance: Parameter, owner: Any) -> AllowLeadingDash: ... + + def __get__(self, instance: Parameter | None, owner: Any) -> AllowLeadingDash | AllowLeadingDashProperty: if instance is None: return self return instance.__dict__.get(self.name, self.default) - def __set__(self, instance: Parameter, value: LeadingDash): + def __set__(self, instance: Parameter, value: LeadingDash | None): if value is not None: value = AllowLeadingDash(value) diff --git a/lib/cli_command_parser/parameters/choice_map.py b/lib/cli_command_parser/parameters/choice_map.py index 41f38b9a..979a95c4 100644 --- a/lib/cli_command_parser/parameters/choice_map.py +++ b/lib/cli_command_parser/parameters/choice_map.py @@ -43,15 +43,9 @@ class Choice(Generic[T]): target: T - def __init__( - self, - choice: OptStr, - target: T | _NotSetType = _NotSet, - help: str | None = None, # noqa - local: bool = False, - ): + def __init__(self, choice: OptStr, target: T, help: OptStr = None, local: bool = False): # noqa self.choice = choice - self.target = choice if target is _NotSet else target + self.target = target self.help = help self.local = local @@ -68,7 +62,7 @@ def format_help(self, lpad: int = 4, prefix: str = '') -> str: return format_help_entry((self.format_usage(),), self.help, prefix, lpad=lpad) -class ChoiceMap(BasePositional[str], Generic[T], actions=(Concatenate,)): +class ChoiceMap(BasePositional[str | None], Generic[T], actions=(Concatenate,)): """ Base class for :class:`SubCommand` and :class:`Action`. It is not meant to be used directly. @@ -87,7 +81,7 @@ class ChoiceMap(BasePositional[str], Generic[T], actions=(Concatenate,)): :param kwargs: Additional keyword arguments to pass to :class:`.BasePositional`. """ - _choice_validation_exc = ParameterDefinitionError + _choice_validation_exc: Type[Exception] = ParameterDefinitionError _default_title: str = 'Choices' nargs = Nargs('+') choices: dict[OptStr, Choice[T]] @@ -95,9 +89,7 @@ class ChoiceMap(BasePositional[str], Generic[T], actions=(Concatenate,)): description: OptStr formatter: ChoiceMapHelpFormatter - def __init_subclass__( # pylint: disable=W0222 - cls, title: OptStr = None, choice_validation_exc: Type[Exception] = None, **kwargs - ): + def __init_subclass__(cls, title: OptStr = None, choice_validation_exc: Type[Exception] | None = None, **kwargs): """ :param title: Default title to use for help text sections containing the choices for this parameter. :param choice_validation_exc: The type of exception to raise when validating defined choices. @@ -143,17 +135,11 @@ def _validate_positional(cls, value: str): if bad := {c for c in value if (c in whitespace and c != ' ') or c not in printable}: raise cls._choice_validation_exc(f'Invalid {cls.__name__} choice={value!r} - invalid characters: {bad}') - def register_choice(self, choice: str, target: T = _NotSet, help: OptStr = None): # noqa + def register_choice(self, choice: str, target: T, help: OptStr = None): # noqa self._validate_positional(choice) self._register_choice(choice, target, help) - def _register_choice( - self, - choice: OptStr, - target: T | None | _NotSetType = _NotSet, - help: OptStr = None, # noqa - local: bool = False, - ): + def _register_choice(self, choice: OptStr, target: T, help: OptStr = None, local: bool = False): # noqa try: existing = self.choices[choice] except KeyError: @@ -182,14 +168,14 @@ def validate(self, value: str | Sequence[str], joined: Bool = False): if (choice := ' '.join(values)) in self.choices: return - elif len(values) > self.nargs.max: + elif len(values) > self.nargs.max: # type: ignore[operator] # it's guaranteed to be bound / have a max here raise BadArgument(self, 'too many values') prefix = choice + ' ' if not any(c.startswith(prefix) for c in self.choices if c): raise InvalidChoice(self, prefix[:-1], self.choices) - def result(self, command: CommandObj | None = None, missing_default: TD = _NotSet) -> OptStr | TD: + def result(self, command: CommandObj | None = None, missing_default: TD | _NotSetType = _NotSet) -> OptStr | TD: if not self.choices: self._no_choices_error() return super().result(command, missing_default) @@ -208,7 +194,7 @@ def show_in_help(self) -> bool: # endregion -class SubCommand(ChoiceMap[CommandCls], title='Subcommands', choice_validation_exc=CommandDefinitionError): +class SubCommand(ChoiceMap[CommandCls | None], title='Subcommands', choice_validation_exc=CommandDefinitionError): """ Used to indicate the position where a choice that results in delegating execution of the program to a sub-command should be provided. @@ -252,7 +238,7 @@ def has_local_choices(self) -> bool: def _register_local_choices(self, local_choices: Mapping[str, str] | Collection[str]): try: - choice_help_iter = local_choices.items() + choice_help_iter = local_choices.items() # type: ignore[union-attr] except AttributeError: choice_help_iter = ((choice, None) for choice in local_choices) @@ -267,7 +253,7 @@ def register_command(self, choice: OptStr, command: CommandCls, help: OptStr) -> if help is None: # This approach was used because importing get_metadata from core would result in a circular dependency - meta: ProgramMetadata = command.__class__.meta(command) + meta: ProgramMetadata = command.__class__.meta(command) # type: ignore[attr-defined] # print(f'Registering {choice=} -> {command=} w/ {meta.description=}, {meta.parent=}') if meta.description and (not meta.parent or meta.parent.description != meta.description): help = meta.description # noqa diff --git a/lib/cli_command_parser/parameters/options.py b/lib/cli_command_parser/parameters/options.py index a1b72e3f..e7c5ac8c 100644 --- a/lib/cli_command_parser/parameters/options.py +++ b/lib/cli_command_parser/parameters/options.py @@ -7,14 +7,15 @@ from __future__ import annotations import logging +import sys from functools import partial, update_wrapper -from typing import TYPE_CHECKING, Any, Callable, Literal, NoReturn, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Literal, NoReturn, TypeVar, overload from ..exceptions import BadArgument, CommandDefinitionError, ParameterDefinitionError, ParamUsageError, ParserExit from ..inputs import normalize_input_type from ..nargs import Nargs, NargsValue from ..typing import TypeFunc -from ..utils import _NotSet, str_to_bool +from ..utils import _NotSet, _NotSetType, str_to_bool from .actions import Append, AppendConst, Count, Store, StoreConst from .base import AllowLeadingDashProperty, BaseOption, CommandMethod from .option_strings import TriFlagOptionStrings @@ -35,14 +36,19 @@ ] log = logging.getLogger(__name__) -T = TypeVar('T') -TD = TypeVar('TD') -TC = TypeVar('TC') -TA = TypeVar('TA') +if sys.version_info >= (3, 13): + T = TypeVar('T', default=str) + B = TypeVar('B', default=bool) + TF = TypeVar('TF', default=bool | None) +else: + T = TypeVar('T') + B = TypeVar('B') + TF = TypeVar('TF') + ConstAct = Literal['store_const', 'append_const'] -class Option(BaseOption[T | TD], actions=(Store, Append)): +class Option(BaseOption[T | None], actions=(Store, Append)): """ A generic option that can be specified as ``--foo bar`` or by using other similar forms. @@ -72,7 +78,7 @@ class Option(BaseOption[T | TD], actions=(Store, Append)): :param kwargs: Additional keyword arguments to pass to :class:`.BaseOption`. """ - default: TD + default: T allow_leading_dash = AllowLeadingDashProperty() def __init__( @@ -80,36 +86,24 @@ def __init__( *option_strs: str, nargs: NargsValue | None = None, action: Literal['store', 'append'] | None = None, - default: TD = _NotSet, + default: T | _NotSetType = _NotSet, required: Bool = False, type: InputTypeFunc = None, # noqa choices: ChoicesType = None, allow_leading_dash: LeadingDash | None = None, **kwargs, ): - if nargs_provided := nargs is not None: - nargs = Nargs(nargs) - if 0 in nargs: - nargs = nargs._orig - details = 'use Flag or Counter for Options that can be specified without a value' - if isinstance(nargs, range) and nargs.start == 0 and nargs.step != nargs.stop: - suffix = f', {nargs.step}' if nargs.step != 1 else '' - details = f'try using range({nargs.step}, {nargs.stop}{suffix}) instead, or {details}' - raise ParameterDefinitionError(f'Invalid {nargs=} - {details}') - + _nargs: Nargs | None = _validate_option_nargs(nargs) if not action: - if nargs_provided: - action = 'store' if nargs == 1 else 'append' + if _nargs is not None: + action = 'store' if _nargs == 1 else 'append' else: action = 'store' - elif nargs_provided and action == 'store' and nargs != 1: - raise ParameterDefinitionError(f'Invalid {nargs=} for {action=}') + elif _nargs is not None and action == 'store' and _nargs != 1: + raise ParameterDefinitionError(f'Invalid nargs={_nargs} for {action=}') super().__init__(*option_strs, action=action, default=default, required=required, **kwargs) - if not nargs_provided: - nargs = self.action.default_nargs - - self.nargs = nargs + self.nargs = self.action.default_nargs if _nargs is None else _nargs self.type = normalize_input_type(type, choices) self.allow_leading_dash = allow_leading_dash @@ -119,7 +113,22 @@ def _handle_bad_action(self, action: str) -> NoReturn: super()._handle_bad_action(action) -class Flag(BaseOption[TD | TC], actions=(StoreConst, AppendConst)): +def _validate_option_nargs(nargs_val: NargsValue | None) -> Nargs | None: + if nargs_val is None: + return None + + nargs = Nargs(nargs_val) + if 0 not in nargs: + return nargs + + details = 'use Flag or Counter for Options that can be specified without a value' + if isinstance(nargs_val, range) and nargs_val.start == 0 and nargs_val.step != nargs_val.stop: + suffix = f', {nargs_val.step}' if nargs_val.step != 1 else '' + details = f'try using range({nargs_val.step}, {nargs_val.stop}{suffix}) instead, or {details}' + raise ParameterDefinitionError(f'Invalid nargs={nargs_val} - {details}') + + +class Flag(BaseOption[B], actions=(StoreConst, AppendConst)): """ A (typically boolean) option that does not accept any values. @@ -150,45 +159,47 @@ class Flag(BaseOption[TD | TC], actions=(StoreConst, AppendConst)): """ nargs = Nargs(0) - type = staticmethod(str_to_bool) # Without staticmethod, this would be interpreted as a normal method + type: TypeFunc = staticmethod(str_to_bool) # Without staticmethod, this would be interpreted as a normal method use_env_value: bool = False __default_const_map = {True: False, False: True, _NotSet: True} - default: TD - const: TC + default: B + const: B def __init__( self, *option_strs: str, action: ConstAct = 'store_const', - default: TD = _NotSet, + default: B | _NotSetType = _NotSet, default_cb=_NotSet, - const: TC = _NotSet, + const: B | _NotSetType = _NotSet, type: TypeFunc | None = None, # noqa **kwargs, ): if const is _NotSet: try: - const = self.__default_const_map[default] + const = self.__default_const_map[default] # type: ignore[assignment] except KeyError as e: raise ParameterDefinitionError( f"A 'const' value is required for {self.__class__.__name__} since {default=} is not True or False" ) from e + if default_cb is not _NotSet: cls_name = self.__class__.__name__ raise ParameterDefinitionError(f"The 'default_cb' arg is not supported for {cls_name} parameters") + if default is _NotSet: - default = self.__default_const_map.get(const, _NotSet) # will be True or False + default = self.__default_const_map.get(const, _NotSet) # type: ignore[assignment] # will be True or False if default is False: # Avoid surprises for custom non-truthy values kwargs.setdefault('show_default', False) super().__init__(*option_strs, action=action, default=default, **kwargs) - self.const = const + self.const = const # type: ignore[assignment] if type is not None: self.type = type def register_default_cb(self, method): raise ParameterDefinitionError(f'{self.__class__.__name__}s do not support default callback methods') - def get_env_const(self, value: str, env_var: str) -> tuple[TC | TD, bool]: + def get_env_const(self, value: str, env_var: str) -> tuple[B, bool]: try: parsed = self.type(value) except Exception as e: @@ -198,7 +209,7 @@ def get_env_const(self, value: str, env_var: str) -> tuple[TC | TD, bool]: return parsed, self.use_env_value -class TriFlag(BaseOption[TD | TC | TA], actions=(StoreConst, AppendConst)): +class TriFlag(BaseOption[TF], actions=(StoreConst, AppendConst)): """ A trinary / ternary Flag. While :class:`.Flag` only supports 1 constant when provided, with 1 default if not provided, this class accepts a pair of constants for the primary and alternate values to store, along with a @@ -234,26 +245,26 @@ class TriFlag(BaseOption[TD | TC | TA], actions=(StoreConst, AppendConst)): """ nargs = Nargs(0) - type = staticmethod(str_to_bool) # Without staticmethod, this would be interpreted as a normal method + type: TypeFunc = staticmethod(str_to_bool) # Without staticmethod, this would be interpreted as a normal method use_env_value: bool = False _default_cb_ok = True _opt_str_cls = TriFlagOptionStrings option_strs: TriFlagOptionStrings alt_help: OptStr = None - default: TD - consts: tuple[TC, TA] + default: TF | _NotSetType + consts: tuple[TF, TF] def __init__( self, *option_strs: str, - consts: tuple[TC, TA] = (True, False), + consts: tuple[TF, TF] = (True, False), # type: ignore[assignment] alt_prefix: OptStr = None, alt_long: OptStr = None, alt_short: OptStr = None, alt_help: OptStr = None, action: ConstAct = 'store_const', - default: TD = _NotSet, - default_cb: Callable[[], TD] | None = None, + default: TF | _NotSetType = _NotSet, + default_cb: Callable[[], TF] | None = None, type: TypeFunc | None = None, # noqa **kwargs, ): @@ -272,9 +283,10 @@ def __init__( if default is _NotSet and default_cb is None: if not kwargs.get('required', False): - default = None + default = None # type: ignore[assignment] else: self._default_cb_ok = False + if default in consts: raise ParameterDefinitionError( f'Invalid {default=} with {consts=} - the default must not match either value' @@ -298,13 +310,13 @@ def register_default_cb(self, method: CommandMethod) -> CommandMethod: self.default = _NotSet # The default was set by __init__ - remove it so the method can be registered return super().register_default_cb(method) - def get_const(self, opt_str: OptStr = None) -> TC | TA: + def get_const(self, opt_str: OptStr = None) -> TF: if opt_str in self.option_strs.alt_allowed: return self.consts[1] else: return self.consts[0] - def get_env_const(self, value: str, env_var: str) -> tuple[TC | TA | TD, bool]: + def get_env_const(self, value: str, env_var: str) -> tuple[TF, bool]: try: parsed = self.type(value) except Exception as e: @@ -321,7 +333,7 @@ def get_env_const(self, value: str, env_var: str) -> tuple[TC | TA | TD, bool]: # region Action Flag -class ActionFlag(Flag, repr_attrs=('order', 'before_main')): +class ActionFlag(Flag[bool], repr_attrs=('order', 'before_main')): """ A :class:`.Flag` that triggers the execution of a function / method / other callable when specified. @@ -380,7 +392,7 @@ def __hash__(self) -> int: result ^= hash(attr) return result - def __eq__(self, other: ActionFlag) -> bool: + def __eq__(self, other) -> bool: if not isinstance(other, ActionFlag): return NotImplemented return all(getattr(self, a) == getattr(other, a) for a in ('name', 'func', 'command', 'order', 'before_main')) @@ -403,7 +415,13 @@ def __call__(self, func: Callable) -> ActionFlag: self.func = func return self - def __get__(self, command: CommandObj | None, owner: CommandCls) -> ActionFlag | Callable: + @overload # type: ignore[override] + def __get__(self, command: None, owner: Any = None) -> ActionFlag: ... + + @overload + def __get__(self, command: object, owner: Any = None) -> Callable: ... + + def __get__(self, command: object | None, owner: Any = None) -> ActionFlag | Callable: # Allow the method to be called, regardless of whether it was specified if command is None: return self @@ -477,7 +495,7 @@ def __init__( action: str = 'count', init: int = 0, const: int = 1, - default: int = _NotSet, + default: int | _NotSetType = _NotSet, default_cb: Callable[[], int] | None = None, required: bool = False, **kwargs, @@ -500,13 +518,14 @@ def register_default_cb(self, method: CommandMethod) -> CommandMethod: self.default_cb = None return super().register_default_cb(method) - def prepare_value(self, value: str | None, short_combo: bool = False, env_var: OptStr = None) -> int: + def prepare_value(self, value: OptStr, short_combo: Bool = False, env_var: OptStr = None) -> int: try: - return self.type(value) + return self.type(value) # type: ignore[arg-type] except (ValueError, TypeError) as e: combinable = self.option_strs.combinable - if short_combo and combinable and all(c in combinable for c in value): - return len(value) + 1 # +1 for the -short that preceded this value + if short_combo and combinable and all(c in combinable for c in value): # type: ignore[union-attr] + return len(value) + 1 # type: ignore[arg-type] # +1 for the -short that preceded this value + suffix = f' from env var={env_var!r}' if env_var else '' raise BadArgument(self, f'bad counter {value=}{suffix}') from e diff --git a/lib/cli_command_parser/parameters/positionals.py b/lib/cli_command_parser/parameters/positionals.py index c734a9bf..f63664ab 100644 --- a/lib/cli_command_parser/parameters/positionals.py +++ b/lib/cli_command_parser/parameters/positionals.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, TypeVar from ..exceptions import ParameterDefinitionError from ..inputs import normalize_input_type @@ -20,8 +20,10 @@ __all__ = ['Positional'] +T = TypeVar('T') -class Positional(BasePositional, default_ok=True, actions=(Store, Append)): + +class Positional(BasePositional[T], default_ok=True, actions=(Store, Append)): """ A parameter that must be provided positionally. @@ -65,23 +67,23 @@ def __init__( **kwargs, ): if nargs_provided := nargs is not None: - self.nargs = nargs = Nargs(nargs) - if nargs == 0: + self.nargs = Nargs(nargs) + if self.nargs == 0: raise ParameterDefinitionError( f'Invalid {nargs=} - {self.__class__.__name__} must allow at least 1 value' ) else: - self.nargs = nargs = Nargs(1) + self.nargs = Nargs(1) if not action: if nargs_provided: - action = 'store' if nargs == 1 or nargs == Nargs('?') else 'append' + action = 'store' if self.nargs == 1 or self.nargs == Nargs('?') else 'append' else: action = 'store' - elif nargs_provided and action == 'store' and nargs.max != 1: - raise ParameterDefinitionError(f'Invalid {action=} for {nargs=}') + elif nargs_provided and action == 'store' and self.nargs.max != 1: + raise ParameterDefinitionError(f'Invalid {action=} for nargs={self.nargs}') - if (required := 0 not in nargs) and (default is not _NotSet or default_cb is not None): + if (required := 0 not in self.nargs) and (default is not _NotSet or default_cb is not None): raise ParameterDefinitionError( f'Invalid {default=} or {default_cb=} - only allowed for Positional parameters when nargs=? or nargs=*' ) diff --git a/lib/cli_command_parser/parse_tree.py b/lib/cli_command_parser/parse_tree.py index bf1dea5d..d33e9335 100644 --- a/lib/cli_command_parser/parse_tree.py +++ b/lib/cli_command_parser/parse_tree.py @@ -4,19 +4,22 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Collection, Iterable, Iterator, MutableMapping, TypeAlias +from typing import TYPE_CHECKING, Collection, Iterable, Iterator, MutableMapping, Type, TypeAlias +from .core import get_params from .exceptions import AmbiguousParseTree from .utils import _parse_tree_target_repr if TYPE_CHECKING: + from .commands import Command from .core import CommandMeta from .nargs import Nargs from .parameters.base import BasePositional from .parameters.choice_map import Choice, ChoiceMap from .typing import OptStr - Target: TypeAlias = BasePositional | CommandMeta | None + CommandCls: TypeAlias = Type[Command] | CommandMeta + Target: TypeAlias = BasePositional | CommandCls | None __all__ = ['PosNode'] @@ -258,9 +261,9 @@ def __iter__(self) -> Iterator[Word]: # region Build Tree @classmethod - def build_tree(cls, command: CommandMeta) -> PosNode: + def build_tree(cls, command: CommandCls) -> PosNode: root = cls(None, None, target=command) - _process_params(command, [root], command.__class__.params(command).all_positionals) + _process_params(command, [root], get_params(command).all_positionals) return root def update_node(self, word: Word, param: BasePositional, target: Target) -> PosNode: @@ -335,7 +338,7 @@ def _has_upper_bound(node) -> bool: def _process_params( - command: CommandMeta, nodes: Iterable[PosNode], params: Iterable[BasePositional] + command: CommandCls, nodes: Iterable[PosNode], params: Iterable[BasePositional] ) -> Iterable[PosNode]: for param in params: nodes = _process_param(command, nodes, param) @@ -343,15 +346,13 @@ def _process_params( return nodes -def _process_param(command: CommandMeta, nodes: Iterable[PosNode], param: BasePositional | ChoiceMap) -> set[PosNode]: +def _process_param(command: CommandCls, nodes: Iterable[PosNode], param: BasePositional | ChoiceMap) -> set[PosNode]: # At each step, the number of branches grows try: - choices: dict[OptStr, Choice[CommandMeta]] = param.choices # type: ignore[union-attr] + choices: dict[OptStr, Choice[CommandCls]] = param.choices # type: ignore[union-attr] except AttributeError: # It was not a ChoiceMap param pass else: - get_params = command.__class__.params - new_nodes: set[PosNode] = set() for choice in choices.values(): target = choice.target diff --git a/lib/cli_command_parser/parser.py b/lib/cli_command_parser/parser.py index 40136029..083d7a32 100644 --- a/lib/cli_command_parser/parser.py +++ b/lib/cli_command_parser/parser.py @@ -9,7 +9,7 @@ import logging from collections import deque from os import environ -from typing import TYPE_CHECKING, Deque, Sequence +from typing import TYPE_CHECKING, Deque, Sequence, Type, TypeAlias from .context import ActionPhase, Context from .core import get_parent @@ -30,9 +30,11 @@ from .command_parameters import CommandParameters from .commands import Command from .config import CommandConfig - from .core import CommandMeta from .typing import Bool, OptStr + CommandCls: TypeAlias = Type[Command] + Positionals = Sequence[BasePositional] + __all__ = ['CommandParser', 'parse_args_and_get_next_cmd'] log = logging.getLogger(__name__) @@ -47,32 +49,31 @@ class CommandParser: __slots__ = ('_last', 'arg_deque', 'ctx', 'config', 'deferred', 'params', 'positionals') - arg_deque: Deque[str] | None + arg_deque: Deque[str] config: CommandConfig - deferred: list[str] | None + deferred: list[str] params: CommandParameters positionals: list[BasePositional] - _last: Parameter | None def __init__(self, ctx: Context, params: CommandParameters, config: CommandConfig): - self._last = None + self._last: Parameter | None = None self.ctx = ctx self.params = params self.positionals = params.get_positionals_to_parse(ctx) self.config = config if config.reject_ambiguous_pos_combos: - PosNode.build_tree(ctx.command_cls) + PosNode.build_tree(ctx.command_cls) # type: ignore[arg-type] @classmethod - def parse_args_and_get_next_cmd(cls, ctx: Context) -> CommandMeta | None: + def parse_args_and_get_next_cmd(cls, ctx: Context) -> CommandCls | None: try: - return cls(ctx, ctx.params, ctx.config).get_next_cmd(ctx) + return cls(ctx, ctx.params, ctx.config).get_next_cmd(ctx) # type: ignore[arg-type] except UsageError: if not ctx.categorized_action_flags[_PRE_INIT]: raise return None - def get_next_cmd(self, ctx: Context) -> CommandMeta | None: + def get_next_cmd(self, ctx: Context) -> CommandCls | None: self._parse_args(ctx) self._validate_groups() missing = ctx.get_missing() @@ -115,7 +116,8 @@ def _parse_args(self, ctx: Context): self._parse_env_vars(ctx) - def _parse_env_vars(self, ctx: Context): + @classmethod + def _parse_env_vars(cls, ctx: Context): for param in ctx.missing_options_with_env_var(): for env_var in param.env_vars(): try: @@ -298,16 +300,12 @@ def _maybe_backtrack(self, param: Parameter, found: int) -> int: else: return found - def _get_backtrack_count( - self, param: Parameter, extras: Sequence[str] = (), positionals: Sequence[BasePositional] = () - ) -> int: + def _get_backtrack_count(self, param: Parameter, extras: Sequence[str] = (), positionals: Positionals = ()) -> int: if poppable_groups := param.action.get_maybe_poppable_values(): return next((len(g) for g in poppable_groups if self._should_backtrack(g, extras, positionals)), 0) return 0 - def _should_backtrack( - self, group: list[str], extras: Sequence[str] = (), positionals: Sequence[BasePositional] = () - ) -> bool: + def _should_backtrack(self, group: list[str], extras: Sequence[str] = (), positionals: Positionals = ()) -> bool: args = [*group, *extras, *self.arg_deque] for pos_param in positionals or self.positionals: n = pos_param.nargs.min @@ -329,7 +327,7 @@ def _maybe_backtrack_last_positional(self, param: BasePositional): By the time this method is called, it has already been discovered that `found` does not satisfy `param`'s nargs requirements. """ - if not self.config.allow_backtrack: + if not self.config.allow_backtrack or not self._last: # This method is called extremely rarely & it's cleaner to have this check here than in _finalize_consume return diff --git a/lib/cli_command_parser/typing.py b/lib/cli_command_parser/typing.py index eab33203..a57bd706 100644 --- a/lib/cli_command_parser/typing.py +++ b/lib/cli_command_parser/typing.py @@ -65,9 +65,9 @@ AnyConfig = Config | dict[str, Any] LeadingDash = Union['AllowLeadingDash', str, bool] -Param = TypeVar('Param', bound='Parameter') -ParamList = list[Param] -ParamOrGroup = Union[Param, 'ParamGroup'] +P = TypeVar('P', bound='Parameter') +ParamList = list[P] +ParamOrGroup = Union[P, 'ParamGroup'] CommandObj = TypeVar('CommandObj', bound='Command') CommandCls: TypeAlias = Type[CommandObj] diff --git a/mypy.ini b/mypy.ini index 456e7374..f8b1844b 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,6 +1 @@ [mypy] -disable_error_code = import-untyped - -# Equivalent to --follow-imports=silent -# Mypy will check imported files but won't report errors in them -follow_imports = silent diff --git a/tests/test_documentation/test_help_text.py b/tests/test_documentation/test_help_text.py index a44092a2..36094cf4 100755 --- a/tests/test_documentation/test_help_text.py +++ b/tests/test_documentation/test_help_text.py @@ -14,7 +14,12 @@ from cli_command_parser.core import CommandMeta from cli_command_parser.exceptions import MissingArgument from cli_command_parser.formatting.commands import CommandHelpFormatter, get_usage_sub_cmds -from cli_command_parser.formatting.params import ChoiceGroup, ParamHelpFormatter, PositionalHelpFormatter +from cli_command_parser.formatting.params import ( + ChoiceGroup, + ParameterHelpFormatter, + ParamHelpFormatter, + PositionalHelpFormatter, +) from cli_command_parser.formatting.restructured_text import RstTable from cli_command_parser.inputs import Date, Day from cli_command_parser.parameters import Counter, Flag, Option, ParamGroup, PassThru, Positional, TriFlag, action_flag @@ -924,9 +929,9 @@ def test_non_base_formatter_cls_does_not_lookup_subclass(self): # Would be ChoiceMapHelpFormatter if it looked up the subclass self.assertIs(formatter.__class__, PositionalHelpFormatter) - def test_default_formatter_class_returned(self): - formatter = ParamHelpFormatter.for_param_cls(int) # noqa - self.assertIs(formatter, ParamHelpFormatter) + def test_invalid_param_class(self): + with self.assertRaisesRegex(ValueError, 'No help formatter is available for int objects'): + ParamHelpFormatter.for_param_cls(int) # type: ignore def test_formatter_uses_cmd_ctx(self): class Foo(Command): @@ -937,7 +942,7 @@ class Foo(Command): foo.bar # noqa def test_custom_formatter(self): - class CustomFormatter(ParamHelpFormatter): + class CustomFormatter(ParameterHelpFormatter): def format_help(self, *args, **kwargs): return 'test help' @@ -954,8 +959,8 @@ class Foo(ABC, metaclass=CommandMeta): self.assertIsInstance(CommandMeta.params(Foo).formatter, CommandHelpFormatter) def test_choice_group_add_no_str(self): - group = ChoiceGroup(Choice('')) - group.add(Choice(None)) + group = ChoiceGroup(Choice('', None)) + group.add(Choice(None, None)) self.assertEqual(0, len(group.choice_strs)) def test_group_desc_override(self): diff --git a/tests/test_parameters/test_choice_maps.py b/tests/test_parameters/test_choice_maps.py index b24e1e07..f94de66c 100755 --- a/tests/test_parameters/test_choice_maps.py +++ b/tests/test_parameters/test_choice_maps.py @@ -1,7 +1,6 @@ #!/usr/bin/env python from unittest import main -from unittest.mock import Mock from cli_command_parser import Command, Context from cli_command_parser.exceptions import BadArgument, CommandDefinitionError, InvalidChoice, ParameterDefinitionError @@ -173,7 +172,7 @@ def test_allow_leading_dash_not_allowed_action(self): # endregion def test_choice_format_help(self): - choice = Choice('test', help='Example choice') + choice = Choice('test', None, help='Example choice') self.assertEqual(' test Example choice', choice.format_help()) def test_default_when_missing(self): diff --git a/tests/test_parameters/test_misc.py b/tests/test_parameters/test_misc.py index 4385be18..4015017e 100755 --- a/tests/test_parameters/test_misc.py +++ b/tests/test_parameters/test_misc.py @@ -10,6 +10,7 @@ ChoiceMapHelpFormatter, GroupHelpFormatter, OptionHelpFormatter, + ParameterHelpFormatter, ParamHelpFormatter, PassThruHelpFormatter, PositionalHelpFormatter, @@ -63,7 +64,7 @@ class Foo(Command): def test_formatter_class(self): param_fmt_cls_map = { - Parameter: ParamHelpFormatter, + Parameter: ParameterHelpFormatter, PassThru: PassThruHelpFormatter, BasePositional: PositionalHelpFormatter, Positional: PositionalHelpFormatter,