From af7bbc1e9ffd6e5835bab6d16efad640c7262e11 Mon Sep 17 00:00:00 2001 From: beckermr Date: Sun, 8 Feb 2026 09:17:38 -0600 Subject: [PATCH 1/4] feat: add to_galsim method on shear --- jax_galsim/shear.py | 4 ++++ tests/jax/test_api.py | 1 + 2 files changed, 5 insertions(+) diff --git a/jax_galsim/shear.py b/jax_galsim/shear.py index 5af34c41..f2b32a94 100644 --- a/jax_galsim/shear.py +++ b/jax_galsim/shear.py @@ -294,6 +294,10 @@ def tree_unflatten(cls, aux_data, children): def from_galsim(cls, galsim_shear): return cls(g1=galsim_shear.g1, g2=galsim_shear.g2) + def to_galsim(self): + """Create a galsim `Shear` from a `jax_galsim.Shear` object.""" + return _galsim.Shear(g1=float(self.g1), g2=float(self.g2)) + @implements(_galsim._Shear) def _Shear(shear): diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index 6095a81e..8957abe1 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -458,6 +458,7 @@ def test_api_gsobject(kind): def test_api_shear(obj): _run_object_checks(obj, jax_galsim.Shear, "docs-methods") _run_object_checks(obj, jax_galsim.Shear, "pickle-eval-repr") + _run_object_checks(obj, obj.__class__, "to-from-galsim") def _reg_sfun(g1): return (jax_galsim.Shear(g1=g1, g2=0.2) + jax_galsim.Shear(g1=g1, g2=-0.1)).eta1 From 6831e79027020845255df76307abbfd8deee64f6 Mon Sep 17 00:00:00 2001 From: beckermr Date: Sun, 8 Feb 2026 09:43:09 -0600 Subject: [PATCH 2/4] feat: add conversions for angle and celestial coords --- jax_galsim/angle.py | 9 +++++++++ jax_galsim/celestial.py | 7 +++++++ jax_galsim/shear.py | 1 + tests/jax/test_api.py | 2 ++ 4 files changed, 19 insertions(+) diff --git a/jax_galsim/angle.py b/jax_galsim/angle.py index a7d1230a..ed513315 100644 --- a/jax_galsim/angle.py +++ b/jax_galsim/angle.py @@ -391,6 +391,15 @@ def tree_unflatten(cls, aux_data, children): ret._rad = children[0] return ret + @staticmethod + def from_galsim(gs_angle): + """Create a jax_galsim `Angle` from a `galsim.Angle` object.""" + return _Angle(gs_angle._rad) + + def to_galsim(self): + """Create a galsim `Angle` from a `jax_galsim.Angle` object.""" + return _galsim.angle._Angle(float(self._rad)) + @implements(_galsim._Angle) def _Angle(theta): diff --git a/jax_galsim/celestial.py b/jax_galsim/celestial.py index 1e4f346f..1b6e992f 100644 --- a/jax_galsim/celestial.py +++ b/jax_galsim/celestial.py @@ -924,8 +924,15 @@ def _precess(from_epoch, to_epoch, _ra, _dec): @staticmethod def from_galsim(gcoord): + """Create a jax_galsim `CelestialCoord` from a `galsim.CelestialCoord` object.""" return _CelestialCoord(_Angle(gcoord.ra.rad), _Angle(gcoord.dec.rad)) + def to_galsim(self): + """Create a galsim `CelestialCoord` from a `jax_galsim.CelestialCoord` object.""" + return _galsim.celestial.CelestialCoord( + self.ra.to_galsim(), self.dec.to_galsim() + ) + @implements(_coord._CelestialCoord) def _CelestialCoord(ra, dec): diff --git a/jax_galsim/shear.py b/jax_galsim/shear.py index f2b32a94..074a762e 100644 --- a/jax_galsim/shear.py +++ b/jax_galsim/shear.py @@ -292,6 +292,7 @@ def tree_unflatten(cls, aux_data, children): @classmethod def from_galsim(cls, galsim_shear): + """Create a jax_galsim `Shear` from a `galsim.Shear` object.""" return cls(g1=galsim_shear.g1, g2=galsim_shear.g2) def to_galsim(self): diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index 8957abe1..3a4b7fc4 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -789,6 +789,7 @@ def test_api_angle(): obj = jax_galsim.Angle(jnp.array(0.1) * jax_galsim.degrees) _run_object_checks(obj, obj.__class__, "docs-methods") _run_object_checks(obj, obj.__class__, "pickle-eval-repr") + _run_object_checks(obj, obj.__class__, "to-from-galsim") # JAX tracing should be an identity assert obj.__class__.tree_unflatten(*((obj.tree_flatten())[::-1])) == obj @@ -826,6 +827,7 @@ def test_api_celestial_coord(): obj = jax_galsim.CelestialCoord(45 * jax_galsim.degrees, -30 * jax_galsim.degrees) _run_object_checks(obj, obj.__class__, "docs-methods") _run_object_checks(obj, obj.__class__, "pickle-eval-repr") + _run_object_checks(obj, obj.__class__, "to-from-galsim") # JAX tracing should be an identity assert obj.__class__.tree_unflatten(*((obj.tree_flatten())[::-1])) == obj From a378ef5c967f0876129c861752eb0814e243884d Mon Sep 17 00:00:00 2001 From: beckermr Date: Sun, 8 Feb 2026 12:08:36 -0600 Subject: [PATCH 3/4] feat: add to_galsim method for WCS classes --- jax_galsim/wcs.py | 52 +++++++++++++++++++++++++++++++++++++++++++ tests/jax/test_api.py | 1 + 2 files changed, 53 insertions(+) diff --git a/jax_galsim/wcs.py b/jax_galsim/wcs.py index 3a177d0b..fcddf18e 100644 --- a/jax_galsim/wcs.py +++ b/jax_galsim/wcs.py @@ -1,5 +1,6 @@ import galsim as _galsim import jax.numpy as jnp +import numpy as np from jax.tree_util import register_pytree_node_class from jax_galsim.angle import AngleUnit, arcsec, radians @@ -265,6 +266,57 @@ def from_galsim(cls, galsim_wcs): ], ) + def to_galsim(self): + """Create a galsim WCS object from a jax_galsim WCS object.""" + # keep this import here to avoid circular imports + from jax_galsim.fitswcs import GSFitsWCS + + if isinstance(self, PixelScale): + return _galsim.PixelScale(float(self.scale)) + elif isinstance(self, ShearWCS): + return _galsim.ShearWCS(float(self.scale), self.shear.to_galsim()) + elif isinstance(self, JacobianWCS): + return _galsim.JacobianWCS( + float(self.dudx), + float(self.dudy), + float(self.dvdx), + float(self.dvdy), + ) + elif isinstance(self, OffsetWCS): + return _galsim.OffsetWCS( + float(self.scale), + origin=self.origin.to_galsim(), + world_origin=self.world_origin.to_galsim(), + ) + elif isinstance(self, OffsetShearWCS): + return _galsim.OffsetShearWCS( + float(self.scale), + self.shear.to_galsim(), + origin=self.origin.to_galsim(), + world_origin=self.world_origin.to_galsim(), + ) + elif isinstance(self, AffineTransform): + return _galsim.AffineTransform( + float(self.dudx), + float(self.dudy), + float(self.dvdx), + float(self.dvdy), + origin=self.origin.to_galsim(), + world_origin=self.world_origin.to_galsim(), + ) + elif isinstance(self, GSFitsWCS): + return _galsim.GSFitsWCS( + _data=[ + self.wcs_type, + np.asarray(self.crpix), + np.asarray(self.cd), + self.center.to_galsim(), + np.asarray(self.pv) if self.pv is not None else None, + np.asarray(self.ab) if self.ab is not None else None, + np.asarray(self.abp) if self.abp is not None else None, + ], + ) + ######################################################################################### # diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index 3a4b7fc4..c9bf7170 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -735,6 +735,7 @@ def test_api_wcs(): tested.add(cls.__name__) _run_object_checks(obj, cls, "docs-methods") _run_object_checks(obj, cls, "pickle-eval-repr-wcs") + _run_object_checks(obj, cls, "to-from-galsim") if isinstance(obj, jax_galsim.wcs.CelestialWCS): _run_object_checks(obj, cls, "vmap-jit-grad-celestialwcs") else: From 8c40ea5a2fd45e3cdee77afc17fe09ce7327b174 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Sun, 8 Feb 2026 12:09:52 -0600 Subject: [PATCH 4/4] Apply suggestion from @beckermr --- tests/jax/test_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index c9bf7170..f7445a6b 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -458,7 +458,7 @@ def test_api_gsobject(kind): def test_api_shear(obj): _run_object_checks(obj, jax_galsim.Shear, "docs-methods") _run_object_checks(obj, jax_galsim.Shear, "pickle-eval-repr") - _run_object_checks(obj, obj.__class__, "to-from-galsim") + _run_object_checks(obj, jax_galsim.Shear, "to-from-galsim") def _reg_sfun(g1): return (jax_galsim.Shear(g1=g1, g2=0.2) + jax_galsim.Shear(g1=g1, g2=-0.1)).eta1