Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion lib/cli_command_parser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
Counter,
Flag,
Option,
Param,
Parameter,
ParamGroup,
PassThru,
Expand All @@ -52,4 +53,4 @@
after_main,
before_main,
)
from .typing import Param, ParamOrGroup
from .typing import ParamOrGroup
66 changes: 39 additions & 27 deletions lib/cli_command_parser/command_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,46 +12,53 @@

from collections import defaultdict
from functools import cached_property
from typing import TYPE_CHECKING, Collection, Iterator
from typing import TYPE_CHECKING, Any, Collection, Iterator, Type, TypeAlias

from .config import AmbiguousComboMode, CommandConfig
from .exceptions import AmbiguousCombo, AmbiguousShortForm, CommandDefinitionError, ParameterDefinitionError
from .parameters import Action, ActionFlag, ParamGroup, PassThru, SubCommand, help_action
from .parameters import ActionFlag, ParamGroup, PassThru, help_action
from .parameters.base import BaseOption, BasePositional, ParamBase, Parameter
from .parameters.choice_map import Action, SubCommand

if TYPE_CHECKING:
from .commands import Command
from .context import Context
from .core import CommandMeta
from .formatting.commands import CommandHelpFormatter
from .typing import CommandCls, Strings
from .typing import Bool, Strings

CommandCls: TypeAlias = Type[Command] | CommandMeta
OptionMap = dict[str, BaseOption]
ActionFlags = list[ActionFlag]
Positionals = list[BasePositional] | tuple[()]

__all__ = ['CommandParameters']


class CommandParameters:
# fmt: off
command: CommandCls #: The Command associated with this CommandParameters object
formatter: CommandHelpFormatter #: The formatter used for this Command's help text
parent: CommandParameters | None #: The parent Command's CommandParameters
action: Action | None = None #: An Action Parameter, if specified
_pass_thru: PassThru | None = None #: A PassThru Parameter, if specified
sub_command: SubCommand | None = None #: A SubCommand Parameter, if specified
action_flags: ActionFlags #: List of action flags
split_action_flags: tuple[ActionFlags, ActionFlags] #: Action flags split by before/after main
options: list[BaseOption] #: List of optional Parameters
combo_option_map: OptionMap #: Mapping of {short opt: Parameter} (no dash characters)
groups: list[ParamGroup] #: List of ParamGroup objects
positionals: list[BasePositional] #: List of positional Parameters
_deferred_positionals: list[BasePositional] = () #: Positional Parameters that are deferred to sub commands
_deferred_positionals: Positionals = () #: Positional Parameters that are deferred to sub commands
option_map: OptionMap #: Mapping of {--opt / -opt: Parameter}
# fmt: on

def __init__(self, command: CommandCls, parent_params: CommandParameters | None, config: CommandConfig):
self.command = command
self.parent = parent_params
self.config = config
# fmt: off
# These are annotated here because mypy thinks they're invoked as descriptors when annotated at the class level
self.action: Action | None = None #: An Action Parameter, if specified
self.sub_command: SubCommand | None = None #: A SubCommand Parameter, if specified
self._pass_thru: PassThru | None = None #: A PassThru Parameter, if specified
# fmt: on
self._process_parameters()

def __repr__(self) -> str:
Expand All @@ -77,11 +84,8 @@ def has_nested_pass_thru(self) -> bool:

@cached_property
def all_positionals(self) -> list[BasePositional]:
try:
if not self.parent.sub_command:
return self.parent.all_positionals + self.positionals
except AttributeError:
pass
if self.parent and not self.parent.sub_command:
return self.parent.all_positionals + self.positionals
return self.positionals

def get_positionals_to_parse(self, ctx: Context) -> list[BasePositional]:
Expand All @@ -94,6 +98,7 @@ def get_positionals_to_parse(self, ctx: Context) -> list[BasePositional]:

@cached_property
def formatter(self) -> CommandHelpFormatter:
"""The formatter used for this Command's help text."""
from .formatting.commands import CommandHelpFormatter

formatter_factory = self.config.command_formatter or CommandHelpFormatter
Expand All @@ -105,13 +110,13 @@ def formatter(self) -> CommandHelpFormatter:
return formatter

@cached_property
def _has_help(self) -> bool:
def _has_help(self) -> Bool:
return help_action in self.action_flags or (self.parent and self.parent._has_help)

# region Initialization

def _iter_parameters(self) -> Iterator[ParamBase]:
name_param_map = {} # Allow subclasses to override names, but not within a given command
name_param_map: dict[str, Any] = {} # Allow subclasses to override names, but not within a given command
for item in self.command.__dict__.items():
attr, param = item
if attr.startswith('__') or not isinstance(param, ParamBase): # Name mangled Parameters are still processed
Expand Down Expand Up @@ -173,7 +178,9 @@ def _process_groups(self, groups: set[ParamGroup]):
self.groups = sorted(groups) if groups else []

def _process_positionals(self, params: list[BasePositional]):
unfollowable = action_or_sub_cmd = split_index = None
unfollowable: BasePositional | None = None
action_or_sub_cmd: SubCommand | Action | None = None
split_index: int = 0
if self.parent and (deferred := self.parent._deferred_positionals):
params = deferred + params

Expand All @@ -186,26 +193,28 @@ def _process_positionals(self, params: list[BasePositional]):
raise CommandDefinitionError(
f'Additional Positional parameters cannot follow {unfollowable} {why} - {param=} is invalid'
)
elif isinstance(param, (SubCommand, Action)):

if isinstance(param, (SubCommand, Action)):
if action_or_sub_cmd:
raise CommandDefinitionError(
f'Only 1 Action xor SubCommand is allowed in a given Command - {self.command.__name__} cannot'
f' contain both {action_or_sub_cmd} and {param}'
)
elif isinstance(param, SubCommand):

if isinstance(param, SubCommand):
self.sub_command = action_or_sub_cmd = param
split_index = i + 1
if param.has_choices and 0 in param.nargs: # It has local choices or is not required
unfollowable = param
else: # It's an Action
self.action = action_or_sub_cmd = param # type: ignore
self.action = action_or_sub_cmd = param
if not param.has_choices:
raise CommandDefinitionError(f'No choices were registered for {self.action}')
elif 0 in param.nargs or (param.nargs.variable and not param.has_choices):
unfollowable = param

if split_index:
if self.sub_command.has_local_choices:
if self.sub_command.has_local_choices: # type: ignore[union-attr]
self._deferred_positionals = params[split_index:]
else:
params, self._deferred_positionals = params[:split_index], params[split_index:]
Expand Down Expand Up @@ -252,19 +261,22 @@ def _process_option_strs(self, param: BaseOption, opt_type: str, opt_strs: Strin
f'{opt_type} {option=} conflict for command={self.command!r} between {existing} and {param}'
)

def _process_action_flags(self):
action_flags = sorted(p for p in self.options if isinstance(p, ActionFlag))
grouped_ordered_flags = {True: defaultdict(list), False: defaultdict(list)}
def _process_action_flags(self) -> None:
action_flags: list[ActionFlag] = sorted(p for p in self.options if isinstance(p, ActionFlag))
grouped_ordered_flags: dict[bool, dict[int | float, ActionFlags]] = {
True: defaultdict(list),
False: defaultdict(list),
}
for param in action_flags:
if param.func is None:
raise ParameterDefinitionError(f'No function was registered for {param=}')
grouped_ordered_flags[param.before_main][param.order].append(param)

found_non_always = False
invalid = {}
invalid: dict[tuple[bool, int | float], ActionFlags | ActionFlag] = {}
for before_main, prio_params in grouped_ordered_flags.items():
for prio, params in prio_params.items():
param: ActionFlag = params[0] # Don't pop and check `if params` - all are needed for the group check
param = params[0] # Don't pop and check `if params` - all are needed for the group check
if found_non_always and param.always_available:
invalid[(before_main, prio)] = param
elif not param.always_available:
Expand Down Expand Up @@ -296,7 +308,7 @@ def _process_action_flags(self):
@cached_property
def _classified_combo_options(self) -> tuple[OptionMap, OptionMap]:
"""Tuple of (single char short:Option map, multi-char short:Option map) for options available in this command"""
multi_char_combos = {}
multi_char_combos: OptionMap = {}
items = self.combo_option_map.items()
for combo, param in items:
if len(combo) == 1: # combo_option_map is sorted in reverse length order, so all following will be 1 char
Expand Down Expand Up @@ -368,7 +380,7 @@ def short_option_to_param_value_pairs(self, option: str) -> tuple[list[tuple[str
# Note: if the option is not in this Command's option_map, the KeyError is handled by CommandParser
return [(option, self.option_map[option], value)], True
else:
value = None
value = None # type: ignore[assignment]

try:
param = self.option_map[option]
Expand Down
49 changes: 16 additions & 33 deletions lib/cli_command_parser/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
import logging
from abc import ABC
from contextlib import ExitStack
from typing import TYPE_CHECKING, Sequence, TextIO, Type, overload
from typing import TYPE_CHECKING, Sequence, TextIO, Type

from .context import ActionPhase, Context, get_or_create_context
from .core import CommandMeta, get_params, get_top_level_commands
from .core import CommandMeta, get_metadata, get_params, get_top_level_commands
from .exceptions import ParamConflict, ParserExit
from .parser import parse_args_and_get_next_cmd
from .utils import maybe_await
Expand All @@ -32,34 +32,25 @@ class Command(ABC, metaclass=CommandMeta):
#: The parsing Context used for this Command. Provided here for convenience - this reference to it is not used by
#: any CLI Command Parser internals, so it is safe for subclasses to redefine / overwrite it.
ctx: Context
__ctx: Context

def __new__(cls):
def __new__(cls) -> Command:
# By storing the Context here instead of __init__, every single subclass won't need to
# call super().__init__(...) from their own __init__ for this step
self = super().__new__(cls)
self.__ctx = ctx = get_or_create_context(cls, command=self)
if not hasattr(self, 'ctx'):
self.ctx: Context = ctx # noqa # PyCharm complains this is invalid, but doesn't understand it without it
self.ctx = ctx # noqa # PyCharm complains this is invalid, but doesn't understand it without it
return self

def __repr__(self) -> str:
cls = self.__class__
return f'<{cls.__name__} in prog={cls.__class__.meta(cls).prog!r}>'
return f'<{cls.__name__} in prog={get_metadata(cls).prog!r}>'

# region Parse & Run

@classmethod
@overload
def parse_and_run(cls: Type[CommandObj], argv: Argv = None, **kwargs) -> CommandObj | None:
# These overloads indicate that an instance of the same type or another may be returned
...

@classmethod
@overload
def parse_and_run(cls, argv: Argv = None, **kwargs) -> CommandObj | None: ...

@classmethod
def parse_and_run(cls, argv=None, **kwargs):
def parse_and_run(cls: Type[CommandObj], argv: Argv | None = None, **kwargs) -> CommandObj | None:
"""
Primary entry point for parsing arguments, resolving subcommands, and running a command.

Expand Down Expand Up @@ -90,15 +81,7 @@ def parse_and_run(cls, argv=None, **kwargs):
# region Parse

@classmethod
@overload
def parse(cls: Type[CommandObj], argv: Argv = None) -> CommandObj: ...

@classmethod
@overload
def parse(cls, argv: Argv = None) -> CommandObj: ...

@classmethod
def parse(cls, argv=None):
def parse(cls: Type[CommandObj], argv: Argv | None = None) -> CommandObj:
"""
Parses the specified arguments (or :data:`sys.argv`), and resolves the final subcommand class based on the
parsed arguments, if necessary. Initializes the Command, but does not call any of its other methods.
Expand All @@ -111,7 +94,7 @@ def parse(cls, argv=None):
with ExitStack() as stack:
stack.enter_context(ctx)
while sub_cmd := parse_args_and_get_next_cmd(ctx):
cmd_cls = sub_cmd
cmd_cls = sub_cmd # type: ignore[assignment]
ctx = stack.enter_context(ctx._sub_context(cmd_cls))

return cmd_cls()
Expand Down Expand Up @@ -308,14 +291,14 @@ async def parse_and_await(cls, argv=None, **kwargs):
await maybe_await(self(**kwargs))
return self

async def __call__(self, *args, **kwargs) -> int:
async def __call__(self, *args, **kwargs) -> int: # type: ignore[override]
"""Asynchronous version of :meth:`Command.__call__`."""
with self._Command__ctx as ctx, ctx.get_error_handler(): # noqa
with self._Command__ctx as ctx, ctx.get_error_handler(): # type: ignore[attr-defined]
await maybe_await(self._pre_init_actions_(*args, **kwargs))
await maybe_await(self._init_command_(*args, **kwargs))
await maybe_await(self._before_main_(*args, **kwargs))
try:
await maybe_await(self.main(*args, **kwargs))
await maybe_await(self.main(*args, **kwargs)) # type: ignore[arg-type]
except BaseException:
if ctx.config.always_run_after_main:
log.debug('Caught exception - running _after_main_ before propagating', exc_info=True)
Expand All @@ -328,7 +311,7 @@ async def __call__(self, *args, **kwargs) -> int:

async def _run_actions_(self, phase: ActionPhase, args: tuple, kwargs: dict):
"""Asynchronous version of :meth:`Command._run_actions_`."""
for param in self._Command__ctx.iter_action_flags(phase): # noqa
for param in self._Command__ctx.iter_action_flags(phase): # type: ignore[attr-defined]
await maybe_await(param.func(self, *args, **kwargs))

async def _pre_init_actions_(self, *args, **kwargs):
Expand All @@ -340,9 +323,9 @@ async def _before_main_(self, *args, **kwargs):
"""Asynchronous version of :meth:`Command._before_main_`."""
await self._run_actions_(ActionPhase.BEFORE_MAIN, args, kwargs)

async def main(self, *args, **kwargs) -> int | None:
async def main(self, *args, **kwargs) -> int | None: # type: ignore[override]
"""Asynchronous version of :meth:`Command.main`."""
with self._Command__ctx as ctx: # noqa
with self._Command__ctx as ctx: # type: ignore[attr-defined]
action = get_params(self).action
if action is not None and (ctx.actions_taken == 0 or ctx.config.action_after_action_flags):
ctx.actions_taken += 1
Expand All @@ -355,7 +338,7 @@ async def _after_main_(self, *args, **kwargs):
await self._run_actions_(ActionPhase.AFTER_MAIN, args, kwargs)


def main(argv: Argv = None, return_command: Bool = False, **kwargs) -> CommandObj | None:
def main(argv: Argv | None = None, return_command: Bool = False, **kwargs) -> CommandObj | None:
"""
Convenience function that can be used as the main entry point for a program.

Expand Down
Loading
Loading