Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 12 additions & 32 deletions src/libsemigroups_pybind11/detail/cxx_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,22 +65,14 @@ def __init__(
required_kwargs=(),
optional_kwargs=(),
**kwargs,
):
if (
len(args) == 1
and len(kwargs) == 0
and type(args[0]) in self._all_wrapped_cxx_types
):
) -> None:
if len(args) == 1 and len(kwargs) == 0 and type(args[0]) in self._all_wrapped_cxx_types:
# Copy constructor like construction directly from cxx object
self._cxx_obj = args[0]
self.py_template_params = self.py_template_params_from_cxx_obj()
return

if (
not len(required_kwargs)
<= len(kwargs)
<= len(required_kwargs) + len(optional_kwargs)
):
if not len(required_kwargs) <= len(kwargs) <= len(required_kwargs) + len(optional_kwargs):
raise TypeError(
f"expected between {len(required_kwargs)} and "
f"{len(required_kwargs) + len(optional_kwargs)} "
Expand Down Expand Up @@ -119,9 +111,7 @@ def __getattr__(self: Self, name: str):
def cxx_fn_wrapper(*args) -> Any:
if len(args) == 1 and isinstance(args[0], list):
args = args[0]
return getattr(self._cxx_obj, name)(
[to_cxx(x) for x in args]
)
return getattr(self._cxx_obj, name)([to_cxx(x) for x in args])
return getattr(self._cxx_obj, name)(*(to_cxx(x) for x in args))

return cxx_fn_wrapper
Expand All @@ -136,18 +126,14 @@ def __copy__(self: Self) -> Self:
if self._cxx_obj is not None:
if hasattr(self._cxx_obj, "__copy__"):
return to_py(self._cxx_obj.__copy__())
raise NotImplementedError(
f"{type(self._cxx_obj)} has no member named __copy__"
)
raise NotImplementedError(f"{type(self._cxx_obj)} has no member named __copy__")
raise NameError("_cxx_obj has not been defined")

def __eq__(self: Self, that) -> bool:
if self._cxx_obj is not None:
if hasattr(self._cxx_obj, "__eq__"):
return self._cxx_obj.__eq__(that._cxx_obj)
raise NotImplementedError(
f"{type(self._cxx_obj)} has no member named __eq__"
)
raise NotImplementedError(f"{type(self._cxx_obj)} has no member named __eq__")
raise NameError("_cxx_obj has not been defined")

def py_template_params_from_cxx_obj(self: Self) -> tuple:
Expand All @@ -166,9 +152,9 @@ def init_cxx_obj(self: Self, *args) -> None:
defined.
"""
assert self.py_template_params is not None
self._cxx_obj = self._py_template_params_to_cxx_type[
self.py_template_params
](*(to_cxx(x) for x in args))
self._cxx_obj = self._py_template_params_to_cxx_type[self.py_template_params](
*(to_cxx(x) for x in args)
)


# TODO proper annotations
Expand All @@ -184,17 +170,13 @@ def cxx_mem_fn_wrapper(self, *args):
# TODO move the first if-clause into to_cxx?
if len(args) == 1 and isinstance(args[0], list):
args = [[to_cxx(x) for x in args[0]]]
result = getattr(to_cxx(self), cxx_mem_fn.__name__)(
*(to_cxx(x) for x in args)
)
result = getattr(to_cxx(self), cxx_mem_fn.__name__)(*(to_cxx(x) for x in args))
if result is to_cxx(self):
return self
if type(result) in _CXX_WRAPPED_TYPE_TO_PY_TYPE:
cached_val = f"_cached_return_value_{cxx_mem_fn.__name__}"
# TODO use args too in cached_val?
if hasattr(self, cached_val) and result is to_cxx(
getattr(self, cached_val)
):
if hasattr(self, cached_val) and result is to_cxx(getattr(self, cached_val)):
return getattr(self, cached_val)
result = _CXX_WRAPPED_TYPE_TO_PY_TYPE[type(result)](result)
setattr(self, cached_val, result)
Expand Down Expand Up @@ -228,9 +210,7 @@ def copy_cxx_mem_fns(cxx_class: pybind11_type, py_class: CxxWrapper) -> None:
that call the cxx member function on the _cxx_obj.
"""
for py_meth_name in dir(cxx_class):
if (not py_meth_name.startswith("_")) and py_meth_name not in dir(
py_class
):
if (not py_meth_name.startswith("_")) and py_meth_name not in dir(py_class):
setattr(
py_class,
py_meth_name,
Expand Down
2 changes: 1 addition & 1 deletion src/libsemigroups_pybind11/detail/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def copydoc(original):
for example:

@copydoc(Transf1.__init__)
def __init___(self):
def __init___(self) -> None:
pass
"""

Expand Down
27 changes: 7 additions & 20 deletions src/libsemigroups_pybind11/presentation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class Presentation(_CxxWrapper): # pylint: disable=missing-class-docstring
def __eq__(self: Self, other: Self):
return _to_cxx(self) == _to_cxx(other)

def __init__(self: Self, *args, **kwargs):
def __init__(self: Self, *args, **kwargs) -> None:
"""
Construct a Presentation instance of the type specified by its argument.
"""
Expand All @@ -114,10 +114,7 @@ def __init__(self: Self, *args, **kwargs):
if len(args) == 1:
if not (
isinstance(args[0], (str, list))
or (
isinstance(self, InversePresentation)
and isinstance(self, Presentation)
)
or (isinstance(self, InversePresentation) and isinstance(self, Presentation))
):
extra = ""
if isinstance(self, InversePresentation):
Expand All @@ -126,12 +123,8 @@ def __init__(self: Self, *args, **kwargs):
f"expected the argument to have type one of (str, list{extra}) "
f"but found {type(args[0])}"
)
if isinstance(args[0], list) and not all(
isinstance(x, int) for x in args[0]
):
raise ValueError(
"expected the argument to consist of int values"
)
if isinstance(args[0], list) and not all(isinstance(x, int) for x in args[0]):
raise ValueError("expected the argument to consist of int values")
if isinstance(args[0], str):
self.py_template_params = (str,)
if isinstance(args[0], list):
Expand Down Expand Up @@ -171,9 +164,7 @@ class InversePresentation(Presentation):
_py_template_params_to_cxx_type = {
(List[int],): _InversePresentationWords,
(str,): _InversePresentationStrings,
(Presentation,): Union[
_InversePresentationWords, _InversePresentationStrings
],
(Presentation,): Union[_InversePresentationWords, _InversePresentationStrings],
}

_cxx_type_to_py_template_params = dict(
Expand Down Expand Up @@ -223,9 +214,7 @@ def __init__(self: Self, *args, **kwargs) -> None:
length = _wrap_cxx_free_fn(_length)
longest_rule = _wrap_cxx_free_fn(_longest_rule)
longest_rule_length = _wrap_cxx_free_fn(_longest_rule_length)
longest_subword_reducing_length = _wrap_cxx_free_fn(
_longest_subword_reducing_length
)
longest_subword_reducing_length = _wrap_cxx_free_fn(_longest_subword_reducing_length)
make_semigroup = _wrap_cxx_free_fn(_make_semigroup)
normalize_alphabet = _wrap_cxx_free_fn(_normalize_alphabet)
reduce_complements = _wrap_cxx_free_fn(_reduce_complements)
Expand All @@ -235,9 +224,7 @@ def __init__(self: Self, *args, **kwargs) -> None:
remove_trivial_rules = _wrap_cxx_free_fn(_remove_trivial_rules)
replace_subword = _wrap_cxx_free_fn(_replace_subword)
replace_word = _wrap_cxx_free_fn(_replace_word)
replace_word_with_new_generator = _wrap_cxx_free_fn(
_replace_word_with_new_generator
)
replace_word_with_new_generator = _wrap_cxx_free_fn(_replace_word_with_new_generator)
reverse = _wrap_cxx_free_fn(_reverse)
shortest_rule = _wrap_cxx_free_fn(_shortest_rule)
shortest_rule_length = _wrap_cxx_free_fn(_shortest_rule_length)
Expand Down
14 changes: 4 additions & 10 deletions src/libsemigroups_pybind11/transf.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,7 @@ def __repr__(self: Self) -> str:
result = str(self)
if len(result) < 72:
return result
return (
f"<transformation of degree {self.degree()} and rank {self.rank()}>"
)
return f"<transformation of degree {self.degree()} and rank {self.rank()}>"

# We retain a separate __repr__ so that we can distinguish the cxx objects
# and their python counterparts.
Expand Down Expand Up @@ -253,7 +251,7 @@ def _cxx_type_from_degree(n: int):
return _PPerm4

@_copydoc(_PPerm1.__init__)
def __init__(self: Self, *args):
def __init__(self: Self, *args) -> None:
if len(args) < 3:
super().__init__(*args)
return
Expand All @@ -269,9 +267,7 @@ def __repr__(self: Self) -> str:
result = str(self)
if len(result) < 72:
return result
return (
f"<partial perm of degree {self.degree()} and rank {self.rank()}>"
)
return f"<partial perm of degree {self.degree()} and rank {self.rank()}>"

# We retain a separate __str__ so that we can distinguish the cxx objects
# and their python counterparts.
Expand Down Expand Up @@ -353,9 +349,7 @@ def increase_degree_by(self: Self, n: int) -> Self:
@staticmethod
@_copydoc(_Perm1.one)
def one(n: int) -> Self:
result_type = Perm._py_template_params_to_cxx_type[
Perm._py_template_params_from_degree(n)
]
result_type = Perm._py_template_params_to_cxx_type[Perm._py_template_params_from_degree(n)]
return _to_py(result_type.one(n))


Expand Down