Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions jax_galsim/bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,25 @@ def from_galsim(cls, galsim_bounds):
else:
return _cls()

def to_galsim(self):
"""Create a galsim `BoundsD/I` from a `jax_galsim.BoundsD/I` object."""
if isinstance(self, BoundsI):
gs_class = _galsim.bounds.BoundsI
cast = int
else:
gs_class = _galsim.bounds.BoundsD
cast = float

if self.isDefined():
return gs_class(
cast(self.xmin),
cast(self.xmax),
cast(self.ymin),
cast(self.ymax),
)
else:
return gs_class()


@implements(_galsim.BoundsD, lax_description=BOUNDS_LAX_DESCR)
@register_pytree_node_class
Expand Down
14 changes: 14 additions & 0 deletions jax_galsim/position.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,20 @@ def from_galsim(cls, galsim_position):
)
return _cls(galsim_position.x, galsim_position.y)

def to_galsim(self):
"""Create a galsim `PositionD/I` from a `jax_galsim.PositionD/I` object."""
if isinstance(self, PositionI):
gs_class = _galsim.bounds.PositionI
cast = int
else:
gs_class = _galsim.bounds.PositionD
cast = float

return gs_class(
cast(self.x),
cast(self.y),
)


@implements(_galsim.PositionD)
@register_pytree_node_class
Expand Down
7 changes: 7 additions & 0 deletions tests/jax/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,10 @@ def _run_object_checks(obj, cls, kind):

# check that we can hash the object
hash(obj)
elif kind == "to-from-galsim":
gs_obj = obj.to_galsim()
jgs_obj = obj.from_galsim(gs_obj)
assert jgs_obj == obj
elif kind == "pickle-eval-repr-img" or kind == "pickle-eval-repr-nohash":
from numpy import array # noqa: F401

Expand Down Expand Up @@ -342,6 +346,7 @@ def _reg_fun(p):
"tree_flatten",
"tree_unflatten",
"from_galsim",
"to_galsim",
]:
# this deprecated method doesn't have consistent doc strings in galsim
if (
Expand Down Expand Up @@ -497,6 +502,7 @@ def _reg_sfun(g1):
def test_api_bounds(obj):
_run_object_checks(obj, obj.__class__, "docs-methods")
_run_object_checks(obj, obj.__class__, "pickle-eval-repr")
_run_object_checks(obj, obj.__class__, "to-from-galsim")

# JAX tracing should be an identity
assert obj.__class__.tree_unflatten(*((obj.tree_flatten())[::-1])) == obj
Expand Down Expand Up @@ -550,6 +556,7 @@ def _reg_sfun(g1):
def test_api_position(obj):
_run_object_checks(obj, obj.__class__, "docs-methods")
_run_object_checks(obj, obj.__class__, "pickle-eval-repr")
_run_object_checks(obj, obj.__class__, "to-from-galsim")

# JAX tracing should be an identity
assert obj.__class__.tree_unflatten(*((obj.tree_flatten())[::-1])) == obj
Expand Down