Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 41 additions & 12 deletions src/libsemigroups_pybind11/detail/cxx_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,21 @@ def __init__(
optional_kwargs=(),
**kwargs,
) -> None:
if len(args) == 1 and len(kwargs) == 0 and type(args[0]) in self._all_wrapped_cxx_types:
if (
len(args) == 1
and len(kwargs) == 0
and type(args[0]) in self._all_wrapped_cxx_types
):
# Copy constructor like construction directly from cxx object
self._cxx_obj = args[0]
self.py_template_params = self.py_template_params_from_cxx_obj()
return

if not len(required_kwargs) <= len(kwargs) <= len(required_kwargs) + len(optional_kwargs):
if (
not len(required_kwargs)
<= len(kwargs)
<= len(required_kwargs) + len(optional_kwargs)
):
raise TypeError(
f"expected between {len(required_kwargs)} and "
f"{len(required_kwargs) + len(optional_kwargs)} "
Expand Down Expand Up @@ -111,7 +119,9 @@ def __getattr__(self: Self, name: str):
def cxx_fn_wrapper(*args) -> Any:
if len(args) == 1 and isinstance(args[0], list):
args = args[0]
return getattr(self._cxx_obj, name)([to_cxx(x) for x in args])
return getattr(self._cxx_obj, name)(
[to_cxx(x) for x in args]
)
return getattr(self._cxx_obj, name)(*(to_cxx(x) for x in args))

return cxx_fn_wrapper
Expand All @@ -126,14 +136,27 @@ def __copy__(self: Self) -> Self:
if self._cxx_obj is not None:
if hasattr(self._cxx_obj, "__copy__"):
return to_py(self._cxx_obj.__copy__())
raise NotImplementedError(f"{type(self._cxx_obj)} has no member named __copy__")
raise NotImplementedError(
f"{type(self._cxx_obj)} has no member named __copy__"
)
raise NameError("_cxx_obj has not been defined")

def __eq__(self: Self, that) -> bool:
def __eq__(self: Self, that: Self) -> bool:
if self._cxx_obj is not None:
if hasattr(self._cxx_obj, "__eq__"):
return self._cxx_obj.__eq__(that._cxx_obj)
raise NotImplementedError(f"{type(self._cxx_obj)} has no member named __eq__")
raise NotImplementedError(
f"{type(self._cxx_obj)} has no member named __eq__"
)
raise NameError("_cxx_obj has not been defined")

def __call__(self: Self, *args) -> Any:
if self._cxx_obj is not None:
if callable(self._cxx_obj):
return self._cxx_obj.__call__(*args)
raise NotImplementedError(
f"{type(self._cxx_obj)} has no member named __call__"
)
raise NameError("_cxx_obj has not been defined")

def py_template_params_from_cxx_obj(self: Self) -> tuple:
Expand All @@ -152,9 +175,9 @@ def init_cxx_obj(self: Self, *args) -> None:
defined.
"""
assert self.py_template_params is not None
self._cxx_obj = self._py_template_params_to_cxx_type[self.py_template_params](
*(to_cxx(x) for x in args)
)
self._cxx_obj = self._py_template_params_to_cxx_type[
self.py_template_params
](*(to_cxx(x) for x in args))


# TODO proper annotations
Expand All @@ -170,13 +193,17 @@ def cxx_mem_fn_wrapper(self, *args):
# TODO move the first if-clause into to_cxx?
if len(args) == 1 and isinstance(args[0], list):
args = [[to_cxx(x) for x in args[0]]]
result = getattr(to_cxx(self), cxx_mem_fn.__name__)(*(to_cxx(x) for x in args))
result = getattr(to_cxx(self), cxx_mem_fn.__name__)(
*(to_cxx(x) for x in args)
)
if result is to_cxx(self):
return self
if type(result) in _CXX_WRAPPED_TYPE_TO_PY_TYPE:
cached_val = f"_cached_return_value_{cxx_mem_fn.__name__}"
# TODO use args too in cached_val?
if hasattr(self, cached_val) and result is to_cxx(getattr(self, cached_val)):
if hasattr(self, cached_val) and result is to_cxx(
getattr(self, cached_val)
):
return getattr(self, cached_val)
result = _CXX_WRAPPED_TYPE_TO_PY_TYPE[type(result)](result)
setattr(self, cached_val, result)
Expand Down Expand Up @@ -210,7 +237,9 @@ def copy_cxx_mem_fns(cxx_class: pybind11_type, py_class: CxxWrapper) -> None:
that call the cxx member function on the _cxx_obj.
"""
for py_meth_name in dir(cxx_class):
if (not py_meth_name.startswith("_")) and py_meth_name not in dir(py_class):
if (not py_meth_name.startswith("_")) and py_meth_name not in dir(
py_class
):
setattr(
py_class,
py_meth_name,
Expand Down
6 changes: 6 additions & 0 deletions tests/test_sims.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,12 @@ def test_sims_refiner_faithful_return_policy():
assert srf.init([[0, 1], [0]]) is srf


def test_sims_refiner_faithful_call():
srf = SimsRefinerFaithful()
wg = WordGraph(2, [[0, 1]])
assert srf(wg)


def test_sims_refiner_ideals_return_policy():
sri = SimsRefinerIdeals()
assert sri.presentation() is sri.presentation()
Expand Down