diff --git a/jax_galsim/angle.py b/jax_galsim/angle.py index a7d1230a..ed513315 100644 --- a/jax_galsim/angle.py +++ b/jax_galsim/angle.py @@ -391,6 +391,15 @@ def tree_unflatten(cls, aux_data, children): ret._rad = children[0] return ret + @staticmethod + def from_galsim(gs_angle): + """Create a jax_galsim `Angle` from a `galsim.Angle` object.""" + return _Angle(gs_angle._rad) + + def to_galsim(self): + """Create a galsim `Angle` from a `jax_galsim.Angle` object.""" + return _galsim.angle._Angle(float(self._rad)) + @implements(_galsim._Angle) def _Angle(theta): diff --git a/jax_galsim/celestial.py b/jax_galsim/celestial.py index 1e4f346f..1b6e992f 100644 --- a/jax_galsim/celestial.py +++ b/jax_galsim/celestial.py @@ -924,8 +924,15 @@ def _precess(from_epoch, to_epoch, _ra, _dec): @staticmethod def from_galsim(gcoord): + """Create a jax_galsim `CelestialCoord` from a `galsim.CelestialCoord` object.""" return _CelestialCoord(_Angle(gcoord.ra.rad), _Angle(gcoord.dec.rad)) + def to_galsim(self): + """Create a galsim `CelestialCoord` from a `jax_galsim.CelestialCoord` object.""" + return _galsim.celestial.CelestialCoord( + self.ra.to_galsim(), self.dec.to_galsim() + ) + @implements(_coord._CelestialCoord) def _CelestialCoord(ra, dec): diff --git a/jax_galsim/shear.py b/jax_galsim/shear.py index 5af34c41..074a762e 100644 --- a/jax_galsim/shear.py +++ b/jax_galsim/shear.py @@ -292,8 +292,13 @@ def tree_unflatten(cls, aux_data, children): @classmethod def from_galsim(cls, galsim_shear): + """Create a jax_galsim `Shear` from a `galsim.Shear` object.""" return cls(g1=galsim_shear.g1, g2=galsim_shear.g2) + def to_galsim(self): + """Create a galsim `Shear` from a `jax_galsim.Shear` object.""" + return _galsim.Shear(g1=float(self.g1), g2=float(self.g2)) + @implements(_galsim._Shear) def _Shear(shear): diff --git a/jax_galsim/wcs.py b/jax_galsim/wcs.py index 3a177d0b..fcddf18e 100644 --- a/jax_galsim/wcs.py +++ b/jax_galsim/wcs.py @@ -1,5 +1,6 @@ import galsim as _galsim import jax.numpy as jnp +import numpy as np from jax.tree_util import register_pytree_node_class from jax_galsim.angle import AngleUnit, arcsec, radians @@ -265,6 +266,57 @@ def from_galsim(cls, galsim_wcs): ], ) + def to_galsim(self): + """Create a galsim WCS object from a jax_galsim WCS object.""" + # keep this import here to avoid circular imports + from jax_galsim.fitswcs import GSFitsWCS + + if isinstance(self, PixelScale): + return _galsim.PixelScale(float(self.scale)) + elif isinstance(self, ShearWCS): + return _galsim.ShearWCS(float(self.scale), self.shear.to_galsim()) + elif isinstance(self, JacobianWCS): + return _galsim.JacobianWCS( + float(self.dudx), + float(self.dudy), + float(self.dvdx), + float(self.dvdy), + ) + elif isinstance(self, OffsetWCS): + return _galsim.OffsetWCS( + float(self.scale), + origin=self.origin.to_galsim(), + world_origin=self.world_origin.to_galsim(), + ) + elif isinstance(self, OffsetShearWCS): + return _galsim.OffsetShearWCS( + float(self.scale), + self.shear.to_galsim(), + origin=self.origin.to_galsim(), + world_origin=self.world_origin.to_galsim(), + ) + elif isinstance(self, AffineTransform): + return _galsim.AffineTransform( + float(self.dudx), + float(self.dudy), + float(self.dvdx), + float(self.dvdy), + origin=self.origin.to_galsim(), + world_origin=self.world_origin.to_galsim(), + ) + elif isinstance(self, GSFitsWCS): + return _galsim.GSFitsWCS( + _data=[ + self.wcs_type, + np.asarray(self.crpix), + np.asarray(self.cd), + self.center.to_galsim(), + np.asarray(self.pv) if self.pv is not None else None, + np.asarray(self.ab) if self.ab is not None else None, + np.asarray(self.abp) if self.abp is not None else None, + ], + ) + ######################################################################################### # diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index 6095a81e..f7445a6b 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -458,6 +458,7 @@ def test_api_gsobject(kind): def test_api_shear(obj): _run_object_checks(obj, jax_galsim.Shear, "docs-methods") _run_object_checks(obj, jax_galsim.Shear, "pickle-eval-repr") + _run_object_checks(obj, jax_galsim.Shear, "to-from-galsim") def _reg_sfun(g1): return (jax_galsim.Shear(g1=g1, g2=0.2) + jax_galsim.Shear(g1=g1, g2=-0.1)).eta1 @@ -734,6 +735,7 @@ def test_api_wcs(): tested.add(cls.__name__) _run_object_checks(obj, cls, "docs-methods") _run_object_checks(obj, cls, "pickle-eval-repr-wcs") + _run_object_checks(obj, cls, "to-from-galsim") if isinstance(obj, jax_galsim.wcs.CelestialWCS): _run_object_checks(obj, cls, "vmap-jit-grad-celestialwcs") else: @@ -788,6 +790,7 @@ def test_api_angle(): obj = jax_galsim.Angle(jnp.array(0.1) * jax_galsim.degrees) _run_object_checks(obj, obj.__class__, "docs-methods") _run_object_checks(obj, obj.__class__, "pickle-eval-repr") + _run_object_checks(obj, obj.__class__, "to-from-galsim") # JAX tracing should be an identity assert obj.__class__.tree_unflatten(*((obj.tree_flatten())[::-1])) == obj @@ -825,6 +828,7 @@ def test_api_celestial_coord(): obj = jax_galsim.CelestialCoord(45 * jax_galsim.degrees, -30 * jax_galsim.degrees) _run_object_checks(obj, obj.__class__, "docs-methods") _run_object_checks(obj, obj.__class__, "pickle-eval-repr") + _run_object_checks(obj, obj.__class__, "to-from-galsim") # JAX tracing should be an identity assert obj.__class__.tree_unflatten(*((obj.tree_flatten())[::-1])) == obj