diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index 19f11841..ed5942af 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -264,6 +264,25 @@ def from_galsim(cls, galsim_bounds): else: return _cls() + def to_galsim(self): + """Create a galsim `BoundsD/I` from a `jax_galsim.BoundsD/I` object.""" + if isinstance(self, BoundsI): + gs_class = _galsim.bounds.BoundsI + cast = int + else: + gs_class = _galsim.bounds.BoundsD + cast = float + + if self.isDefined(): + return gs_class( + cast(self.xmin), + cast(self.xmax), + cast(self.ymin), + cast(self.ymax), + ) + else: + return gs_class() + @implements(_galsim.BoundsD, lax_description=BOUNDS_LAX_DESCR) @register_pytree_node_class diff --git a/jax_galsim/position.py b/jax_galsim/position.py index 0b8271a3..822797b8 100644 --- a/jax_galsim/position.py +++ b/jax_galsim/position.py @@ -168,6 +168,20 @@ def from_galsim(cls, galsim_position): ) return _cls(galsim_position.x, galsim_position.y) + def to_galsim(self): + """Create a galsim `PositionD/I` from a `jax_galsim.PositionD/I` object.""" + if isinstance(self, PositionI): + gs_class = _galsim.bounds.PositionI + cast = int + else: + gs_class = _galsim.bounds.PositionD + cast = float + + return gs_class( + cast(self.x), + cast(self.y), + ) + @implements(_galsim.PositionD) @register_pytree_node_class diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index 2d914b19..6095a81e 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -126,6 +126,10 @@ def _run_object_checks(obj, cls, kind): # check that we can hash the object hash(obj) + elif kind == "to-from-galsim": + gs_obj = obj.to_galsim() + jgs_obj = obj.from_galsim(gs_obj) + assert jgs_obj == obj elif kind == "pickle-eval-repr-img" or kind == "pickle-eval-repr-nohash": from numpy import array # noqa: F401 @@ -342,6 +346,7 @@ def _reg_fun(p): "tree_flatten", "tree_unflatten", "from_galsim", + "to_galsim", ]: # this deprecated method doesn't have consistent doc strings in galsim if ( @@ -497,6 +502,7 @@ def _reg_sfun(g1): def test_api_bounds(obj): _run_object_checks(obj, obj.__class__, "docs-methods") _run_object_checks(obj, obj.__class__, "pickle-eval-repr") + _run_object_checks(obj, obj.__class__, "to-from-galsim") # JAX tracing should be an identity assert obj.__class__.tree_unflatten(*((obj.tree_flatten())[::-1])) == obj @@ -550,6 +556,7 @@ def _reg_sfun(g1): def test_api_position(obj): _run_object_checks(obj, obj.__class__, "docs-methods") _run_object_checks(obj, obj.__class__, "pickle-eval-repr") + _run_object_checks(obj, obj.__class__, "to-from-galsim") # JAX tracing should be an identity assert obj.__class__.tree_unflatten(*((obj.tree_flatten())[::-1])) == obj