From 6be7e462a5d7c0688ca486555d069134e60c77a2 Mon Sep 17 00:00:00 2001 From: James Mitchell Date: Tue, 17 Jun 2025 14:34:57 +0100 Subject: [PATCH] Add __call__ to CxxWrapper --- .../detail/cxx_wrapper.py | 53 ++++++++++++++----- tests/test_sims.py | 6 +++ 2 files changed, 47 insertions(+), 12 deletions(-) diff --git a/src/libsemigroups_pybind11/detail/cxx_wrapper.py b/src/libsemigroups_pybind11/detail/cxx_wrapper.py index b38a4e1a3..e9d590900 100644 --- a/src/libsemigroups_pybind11/detail/cxx_wrapper.py +++ b/src/libsemigroups_pybind11/detail/cxx_wrapper.py @@ -66,13 +66,21 @@ def __init__( optional_kwargs=(), **kwargs, ) -> None: - if len(args) == 1 and len(kwargs) == 0 and type(args[0]) in self._all_wrapped_cxx_types: + if ( + len(args) == 1 + and len(kwargs) == 0 + and type(args[0]) in self._all_wrapped_cxx_types + ): # Copy constructor like construction directly from cxx object self._cxx_obj = args[0] self.py_template_params = self.py_template_params_from_cxx_obj() return - if not len(required_kwargs) <= len(kwargs) <= len(required_kwargs) + len(optional_kwargs): + if ( + not len(required_kwargs) + <= len(kwargs) + <= len(required_kwargs) + len(optional_kwargs) + ): raise TypeError( f"expected between {len(required_kwargs)} and " f"{len(required_kwargs) + len(optional_kwargs)} " @@ -111,7 +119,9 @@ def __getattr__(self: Self, name: str): def cxx_fn_wrapper(*args) -> Any: if len(args) == 1 and isinstance(args[0], list): args = args[0] - return getattr(self._cxx_obj, name)([to_cxx(x) for x in args]) + return getattr(self._cxx_obj, name)( + [to_cxx(x) for x in args] + ) return getattr(self._cxx_obj, name)(*(to_cxx(x) for x in args)) return cxx_fn_wrapper @@ -126,14 +136,27 @@ def __copy__(self: Self) -> Self: if self._cxx_obj is not None: if hasattr(self._cxx_obj, "__copy__"): return to_py(self._cxx_obj.__copy__()) - raise NotImplementedError(f"{type(self._cxx_obj)} has no member named __copy__") + raise NotImplementedError( + f"{type(self._cxx_obj)} has no member named __copy__" + ) raise NameError("_cxx_obj has not been defined") - def __eq__(self: Self, that) -> bool: + def __eq__(self: Self, that: Self) -> bool: if self._cxx_obj is not None: if hasattr(self._cxx_obj, "__eq__"): return self._cxx_obj.__eq__(that._cxx_obj) - raise NotImplementedError(f"{type(self._cxx_obj)} has no member named __eq__") + raise NotImplementedError( + f"{type(self._cxx_obj)} has no member named __eq__" + ) + raise NameError("_cxx_obj has not been defined") + + def __call__(self: Self, *args) -> Any: + if self._cxx_obj is not None: + if callable(self._cxx_obj): + return self._cxx_obj.__call__(*args) + raise NotImplementedError( + f"{type(self._cxx_obj)} has no member named __call__" + ) raise NameError("_cxx_obj has not been defined") def py_template_params_from_cxx_obj(self: Self) -> tuple: @@ -152,9 +175,9 @@ def init_cxx_obj(self: Self, *args) -> None: defined. """ assert self.py_template_params is not None - self._cxx_obj = self._py_template_params_to_cxx_type[self.py_template_params]( - *(to_cxx(x) for x in args) - ) + self._cxx_obj = self._py_template_params_to_cxx_type[ + self.py_template_params + ](*(to_cxx(x) for x in args)) # TODO proper annotations @@ -170,13 +193,17 @@ def cxx_mem_fn_wrapper(self, *args): # TODO move the first if-clause into to_cxx? if len(args) == 1 and isinstance(args[0], list): args = [[to_cxx(x) for x in args[0]]] - result = getattr(to_cxx(self), cxx_mem_fn.__name__)(*(to_cxx(x) for x in args)) + result = getattr(to_cxx(self), cxx_mem_fn.__name__)( + *(to_cxx(x) for x in args) + ) if result is to_cxx(self): return self if type(result) in _CXX_WRAPPED_TYPE_TO_PY_TYPE: cached_val = f"_cached_return_value_{cxx_mem_fn.__name__}" # TODO use args too in cached_val? - if hasattr(self, cached_val) and result is to_cxx(getattr(self, cached_val)): + if hasattr(self, cached_val) and result is to_cxx( + getattr(self, cached_val) + ): return getattr(self, cached_val) result = _CXX_WRAPPED_TYPE_TO_PY_TYPE[type(result)](result) setattr(self, cached_val, result) @@ -210,7 +237,9 @@ def copy_cxx_mem_fns(cxx_class: pybind11_type, py_class: CxxWrapper) -> None: that call the cxx member function on the _cxx_obj. """ for py_meth_name in dir(cxx_class): - if (not py_meth_name.startswith("_")) and py_meth_name not in dir(py_class): + if (not py_meth_name.startswith("_")) and py_meth_name not in dir( + py_class + ): setattr( py_class, py_meth_name, diff --git a/tests/test_sims.py b/tests/test_sims.py index 8f0bfad34..141b7bed4 100644 --- a/tests/test_sims.py +++ b/tests/test_sims.py @@ -597,6 +597,12 @@ def test_sims_refiner_faithful_return_policy(): assert srf.init([[0, 1], [0]]) is srf +def test_sims_refiner_faithful_call(): + srf = SimsRefinerFaithful() + wg = WordGraph(2, [[0, 1]]) + assert srf(wg) + + def test_sims_refiner_ideals_return_policy(): sri = SimsRefinerIdeals() assert sri.presentation() is sri.presentation()