From ae833210b2ea352e222fce2c51d0c02d3ac07d62 Mon Sep 17 00:00:00 2001 From: Benjamin Remy Date: Mon, 2 Dec 2024 14:43:52 -0500 Subject: [PATCH 01/20] dev hsm --- jax_galsim/__init__.py | 2 + jax_galsim/hsm.py | 618 +++++++++++++++++++++++++++++++++ tests/galsim_tests_config.yaml | 7 +- 3 files changed, 624 insertions(+), 3 deletions(-) create mode 100644 jax_galsim/hsm.py diff --git a/jax_galsim/__init__.py b/jax_galsim/__init__.py index 5c16c447..ea30a0df 100644 --- a/jax_galsim/__init__.py +++ b/jax_galsim/__init__.py @@ -99,3 +99,5 @@ # this one is specific to jax_galsim from . import core + +from . import hsm diff --git a/jax_galsim/hsm.py b/jax_galsim/hsm.py new file mode 100644 index 00000000..bc6daa78 --- /dev/null +++ b/jax_galsim/hsm.py @@ -0,0 +1,618 @@ +# Copyright (c) 2012-2023 by the GalSim developers team on GitHub +# https://github.com/GalSim-developers +# +# This file is part of GalSim: The modular galaxy image simulation toolkit. +# https://github.com/GalSim-developers/GalSim +# +# GalSim is free software: redistribution and use in source and binary forms, +# with or without modification, are permitted provided that the following +# conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions, and the disclaimer given in the accompanying LICENSE +# file. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions, and the disclaimer given in the documentation +# and/or other materials provided with the distribution. +# + +from dataclasses import dataclass + +import jax.numpy as jnp + +import galsim as _galsim +from jax_galsim.core.utils import implements +from jax_galsim.position import PositionD +from jax_galsim.bounds import BoundsI +from jax_galsim.shear import Shear +from jax_galsim.image import Image, ImageI, ImageF, ImageD +from jax_galsim.errors import GalSimValueError, GalSimHSMError, GalSimIncompatibleValuesError +from jax_galsim.core.utils import cast_to_float, cast_to_int + +@implements(_galsim.hsm.ShapeData) +class ShapeData: + def __init__(self, image_bounds=BoundsI(), moments_status=-1, + observed_shape=Shear(), moments_sigma=-1.0, moments_amp=-1.0, + moments_centroid=PositionD(), moments_rho4=-1.0, moments_n_iter=0, + correction_status=-1, corrected_e1=-10., corrected_e2=-10., + corrected_g1=-10., corrected_g2=-10., meas_type="None", + corrected_shape_err=-1.0, correction_method="None", + resolution_factor=-1.0, psf_sigma=-1.0, + psf_shape=Shear(), error_message=""): + + # from https://github.com/GalSim-developers/GalSim/blob/releases/2.5/include/galsim/hsm/PSFCorr.h#L281 + # This representation of an object shape contains information about observed shapes and shape + # estimators after PSF correction. It also contains information about what PSF correction was + # used; if no PSF correction was carried out and only the observed moments were measured, the + # PSF correction method will be 'None'. Note that observed shapes are bounded to lie in the + # range |e| < 1 or |g| < 1, so they can be represented using a Shear object. In contrast, + # the PSF-corrected distortions and shears are not bounded at a maximum of 1 since they are + # shear estimators, and placing such a bound would bias the mean. Thus, the corrected results + # are not represented using Shear objects, since it may not be possible to make a meaningful + # per-object conversion from distortion to shear (e.g., if |e|>1). + + # Avoid empty string, which can caus problems in C++ layer. + if error_message == "": error_message = "None" + + if not isinstance(image_bounds, BoundsI): + raise TypeError("image_bounds must be a BoundsI instance") + + # The others will raise an appropriate TypeError from the call to _galsim.ShapeData + # when converting to int, float, etc. + # self._data = _galsim.ShapeData( + # image_bounds._b, int(moments_status), observed_shape.e1, observed_shape.e2, + # float(moments_sigma), float(moments_amp), moments_centroid._p, + # float(moments_rho4), int(moments_n_iter), int(correction_status), + # float(corrected_e1), float(corrected_e2), float(corrected_g1), float(corrected_g2), + # str(meas_type), float(corrected_shape_err), str(correction_method), + # float(resolution_factor), float(psf_sigma), psf_shape.e1, psf_shape.e2, + # str(error_message)) + + self._image_bounds = image_bounds + self._moments_status = cast_to_int(moments_status) + self._observed_e1 = observed_shape.e1 + self._observed_e2 = observed_shape.e2 + self._moments_sigma = cast_to_float(moments_sigma) + self._moments_amp = cast_to_float(moments_amp) + self._moments_centroid = moments_centroid + self._moments_rho4 = cast_to_float(moments_rho4) + self._moments_n_iter = cast_to_int(moments_n_iter) + self._correction_status = cast_to_int(correction_status) + self._corrected_e1 = cast_to_float(corrected_e1) + self._corrected_e2 = cast_to_float(corrected_e2) + self._corrected_g1 = cast_to_float(corrected_g1) + self._corrected_g2 = cast_to_float(corrected_g2) + self._meas_type = meas_type + self._corrected_shape_err = cast_to_float(corrected_shape_err) + self._correction_method = correction_method + self._resolution_factor = cast_to_float(resolution_factor) + self._psf_sigma = cast_to_float(psf_sigma) + self._psf_e1 = psf_shape.e1 + self._psf_e2 = psf_shape.e2 + self._error_message = error_message + + @property + def image_bounds(self): return BoundsI(self._image_bounds) + @property + def moments_status(self): return self._moments_status + + @property + def observed_e1(self): + return self._observed_e1 + + @property + def observed_e2(self): + return self._observed_e2 + + @property + def observed_shape(self): + return Shear(e1=self.observed_e1, e2=self.observed_e2) + + @property + def moments_sigma(self): return self._moments_sigma + @property + def moments_amp(self): return self._moments_amp + @property + def moments_centroid(self): return PositionD(self._moments_centroid) + @property + def moments_rho4(self): return self._moments_rho4 + @property + def moments_n_iter(self): return self._moments_n_iter + @property + def correction_status(self): return self._correction_status + @property + def corrected_e1(self): return self._corrected_e1 + @property + def corrected_e2(self): return self._corrected_e2 + @property + def corrected_g1(self): return self._corrected_g1 + @property + def corrected_g2(self): return self._corrected_g2 + @property + def meas_type(self): return self._meas_type + @property + def corrected_shape_err(self): return self._corrected_shape_err + @property + def correction_method(self): return self._correction_method + @property + def resolution_factor(self): return self._resolution_factor + @property + def psf_sigma(self): return self._psf_sigma + + @property + def psf_shape(self): + return Shear(e1=self._psf_e1, e2=self._psf_e2) + + @property + def error_message(self): + # We use "None" in C++ ShapeData to indicate no error messages to avoid problems on + # (some) Macs using zero-length strings. Here, we revert that back to "". + if self._error_message == "None": + return "" + else: + return self._error_message + + def __repr__(self): + s = 'galsim.hsm.ShapeData(' + if self.image_bounds.isDefined(): s += 'image_bounds=%r, '%self.image_bounds + if self.moments_status != -1: s += 'moments_status=%r, '%self.moments_status + # Always include this one: + s += 'observed_shape=%r'%self.observed_shape + if self.moments_sigma != -1: s += ', moments_sigma=%r'%self.moments_sigma + if self.moments_amp != -1: s += ', moments_amp=%r'%self.moments_amp + if self.moments_centroid != PositionD(): + s += ', moments_centroid=%r'%self.moments_centroid + if self.moments_rho4 != -1: s += ', moments_rho4=%r'%self.moments_rho4 + if self.moments_n_iter != 0: s += ', moments_n_iter=%r'%self.moments_n_iter + if self.correction_status != -1: s += ', correction_status=%r'%self.correction_status + if self.corrected_e1 != -10.: s += ', corrected_e1=%r'%self.corrected_e1 + if self.corrected_e2 != -10.: s += ', corrected_e2=%r'%self.corrected_e2 + if self.corrected_g1 != -10.: s += ', corrected_g1=%r'%self.corrected_g1 + if self.corrected_g2 != -10.: s += ', corrected_g2=%r'%self.corrected_g2 + if self.meas_type != 'None': s += ', meas_type=%r'%self.meas_type + if self.corrected_shape_err != -1.: + s += ', corrected_shape_err=%r'%self.corrected_shape_err + if self.correction_method != 'None': s += ', correction_method=%r'%self.correction_method + if self.resolution_factor != -1.: s += ', resolution_factor=%r'%self.resolution_factor + if self.psf_sigma != -1.: s += ', psf_sigma=%r'%self.psf_sigma + if self.psf_shape != Shear(): s += ', psf_shape=%r'%self.psf_shape + if self.error_message != "": s += ', error_message=%r'%self.error_message + s += ')' + return s + + def __eq__(self, other): + return (self is other or + (isinstance(other,ShapeData) and self._getinitargs() == other._getinitargs())) + def __ne__(self, other): return not self.__eq__(other) + def __hash__(self): return hash(("galsim.hsm.ShapeData", self._getinitargs())) + + def _getinitargs(self): + return (self.image_bounds, self.moments_status, self.observed_shape, + self.moments_sigma, self.moments_amp, self.moments_centroid, self.moments_rho4, + self.moments_n_iter, self.correction_status, self.corrected_e1, self.corrected_e2, + self.corrected_g1, self.corrected_g2, self.meas_type, self.corrected_shape_err, + self.correction_method, self.resolution_factor, self.psf_sigma, + self.psf_shape, self.error_message) + + def __getstate__(self): + return self._getinitargs() + + def __setstate__(self, state): + self.__init__(*state) + + @implements(_galsim.hsm.ShapeData.applyWCS) + def applyWCS(self, wcs, image_pos): + jac = wcs.jacobian(image_pos=image_pos) + scale, shear, theta, flip = jac.getDecomposition() + + # Fix moments_sigma + moments_sigma = self.moments_sigma * scale + + # Fix observed_shape + shape = self.observed_shape + # First the flip, if any. + if flip: + shape = Shear(g1 = -shape.g1, g2 = shape.g2) + # Next the rotation + shape = Shear(g = shape.g, beta = shape.beta + theta) + # Finally the shear + observed_shape = shear + shape + + # Fix moments_centroid + moments_centroid = jac.toWorld(self.moments_centroid) - jac.toWorld(image_pos) + + return ShapeData(image_bounds=self.image_bounds, + moments_status=self.moments_status, + observed_shape=observed_shape, + moments_sigma=moments_sigma, + moments_amp=self.moments_amp, + moments_centroid=moments_centroid, + moments_rho4=self.moments_rho4, + moments_n_iter=self.moments_n_iter, + error_message=self.error_message) + # The other values are reset to the defaults, since they are + # results from EstimateShear. + + +@implements(_galsim.hsm.HSMParams) +# @dataclass(repr=False) +class HSMParams: + nsig_rg: float = 3.0 + nsig_rg2: float = 3.6 + regauss_too_small: int = 1 + adapt_order: int = 2 + convergence_threshold: float = 1.e-6 + max_mom2_iter: int = 400 + num_iter_default: int = -1 + bound_correct_wt: float = 0.25 + max_amoment: float = 8000. + max_ashift: float = 15. + ksb_moments_max: int = 4 + ksb_sig_weight: float = 0.0 + ksb_sig_factor: float = 1.0 + failed_moments: float = -1000 + + def _getinitargs(self): + # TODO: For now, leave 3rd param as unused max_moment_nsig2. + # Remove it at version 3.0 to avoid changing C++ API yet. + return (self.nsig_rg, self.nsig_rg2, 0., self.regauss_too_small, + self.adapt_order, self.convergence_threshold, self.max_mom2_iter, + self.num_iter_default, self.bound_correct_wt, self.max_amoment, self.max_ashift, + self.ksb_moments_max, self.ksb_sig_weight, self.ksb_sig_factor, + self.failed_moments) + + @property + def nsig_rg(self): return self._nsig_rg + @property + def nsig_rg2(self): return self._nsig_rg2 + @property + def regauss_too_small(self): return self._regauss_too_small + @property + def adapt_order(self): return self._adapt_order + @property + def convergence_threshold(self): return self._convergence_threshold + @property + def max_mom2_iter(self): return self._max_mom2_iter + @property + def num_iter_default(self): return self._num_iter_default + @property + def bound_correct_wt(self): return self._bound_correct_wt + @property + def max_amoment(self): return self._max_amoment + @property + def max_ashift(self): return self._max_ashift + @property + def ksb_moments_max(self): return self._ksb_moments_max + @property + def ksb_sig_weight(self): return self._ksb_sig_weight + @property + def ksb_sig_factor(self): return self._ksb_sig_factor + @property + def failed_moments(self): return self._failed_moments + + @staticmethod + def check(hsmparams, default=None): + """Checks that hsmparams is either a valid HSMParams instance or None. + + In the former case, it returns hsmparams, in the latter it returns default + (HSMParams.default if no other default specified). + """ + if hsmparams is None: + return default if default is not None else HSMParams.default + elif not isinstance(hsmparams, HSMParams): + raise TypeError("Invalid HSMParams: %s"%hsmparams) + else: + return hsmparams + + def __repr__(self): + return ('galsim.hsm.HSMParams(' + 14*'%r,' + '%r)')%self._getinitargs() + + def __eq__(self, other): + return (self is other or + (isinstance(other, HSMParams) and self._getinitargs() == other._getinitargs())) + def __ne__(self, other): + return not self.__eq__(other) + def __hash__(self): + return hash(('galsim.hsm.HSMParams', self._getinitargs())) + + def __getstate__(self): + d = self.__dict__.copy() + del d['_hsmp'] + return d + + def __setstate__(self, d): + self.__dict__ = d + self._make_hsmp() + +# We use the default a lot, so make it a class attribute. +HSMParams.default = HSMParams() + + +# A helper function that checks if the weight and the badpix bounds are +# consistent with that of the image, and that the weight is non-negative. +def _checkWeightAndBadpix(image, weight=None, badpix=None): + # Check that the weight and badpix, if given, are sensible and compatible + # with the image. + if weight is not None: + if weight.bounds != image.bounds: + raise GalSimIncompatibleValuesError( + "Weight image does not have same bounds as the input Image.", + weight=weight, image=image) + # also make sure there are no negative values + + if jnp.any(weight.array < 0): + raise GalSimValueError("Weight image cannot contain negative values.", weight) + + if badpix is not None and badpix.bounds != image.bounds: + raise GalSimIncompatibleValuesError( + "Badpix image does not have the same bounds as the input Image.", + badpix=badpix, image=image) + + +# A helper function for taking input weight and badpix Images, and returning a weight Image in the +# format that the C++ functions want +def _convertMask(image, weight=None, badpix=None): + # Convert from input weight and badpix images to a single mask image needed by C++ functions. + # This is used by EstimateShear() and FindAdaptiveMom(). + + # if no weight image was supplied, make an int array (same size as gal image) filled with 1's + if weight is None: + mask = ImageI(bounds=image.bounds, init_value=1) + else: + # if weight is an ImageI, then we can use it as the mask image: + if weight.dtype == jnp.int32: + if not badpix: + mask = weight + else: + # If we need to mask bad pixels, we'll need a copy anyway. + mask = ImageI(weight) + + # otherwise, we need to convert it to the right type + else: + mask = ImageI(bounds=image.bounds, init_value=0) + mask.array[weight.array > 0.] = 1 + + # if badpix image was supplied, identify the nonzero (bad) pixels and set them to zero in weight + # image; also check bounds + if badpix is not None: + mask.array[badpix.array != 0] = 0 + + # if no pixels are used, raise an exception + if not jnp.any(mask.array): + raise GalSimHSMError("No pixels are being used!") + + # finally, return the Image for the weight map + return mask + + +# A simpler helper function to force images to be of type ImageF or ImageD +def _convertImage(image): + # Convert the given image to the correct format needed to pass to the C++ layer. + # This is used by EstimateShear() and FindAdaptiveMom(). + + # if weight is not of type float/double, convert to float/double + if (image.dtype == jnp.int16 or image.dtype == jnp.uint16): + image = ImageF(image) + elif (image.dtype == jnp.int32 or image.dtype == jnp.uint32): + image = ImageD(image) + + return image + +@implements(_galsim.hsm.EstimateShear) +def EstimateShear(gal_image, PSF_image, weight=None, badpix=None, sky_var=0.0, + shear_est="REGAUSS", recompute_flux="FIT", guess_sig_gal=5.0, + guess_sig_PSF=3.0, precision=1.0e-6, guess_centroid=None, + strict=True, check=True, hsmparams=None): + gal_image = _convertImage(gal_image) + PSF_image = _convertImage(PSF_image) + hsmparams = HSMParams.check(hsmparams) + if check: + _checkWeightAndBadpix(gal_image, weight=weight, badpix=badpix) + weight = _convertMask(gal_image, weight=weight, badpix=badpix) + + if guess_centroid is None: + guess_centroid = gal_image.true_center + try: + result = ShapeData() + EstimateShearView(result._data, + gal_image._image, PSF_image._image, weight._image, + float(sky_var), shear_est.upper(), recompute_flux.upper(), + float(guess_sig_gal), float(guess_sig_PSF), float(precision), + guess_centroid._p, hsmparams._hsmp) + return result + except RuntimeError as err: + if (strict == True): + raise GalSimHSMError(str(err)) from None + else: + return ShapeData(error_message = str(err)) + +@implements(_galsim.hsm.FindAdaptiveMom) +def FindAdaptiveMom(object_image, weight=None, badpix=None, guess_sig=5.0, precision=1.0e-6, + guess_centroid=None, strict=True, check=True, round_moments=False, hsmparams=None, + use_sky_coords=False): + """Measure adaptive moments of an object. + + This method estimates the best-fit elliptical Gaussian to the object (see Hirata & Seljak 2003 + for more discussion of adaptive moments). This elliptical Gaussian is computed iteratively + by initially guessing a circular Gaussian that is used as a weight function, computing the + weighted moments, recomputing the moments using the result of the previous step as the weight + function, and so on until the moments that are measured are the same as those used for the + weight function. `FindAdaptiveMom` can be used either as a free function, or as a method of the + `Image` class. + + By default, this routine computes moments in pixel coordinates, which generally use (x,y) + for the coordinate variables, so the underlying second moments are Ixx, Iyy, and Ixy. + If the WCS is (at least approximately) just a `PixelScale`, then this scale can be applied to + convert the moments' units from pixels to arcsec. The derived shapes are unaffected by + the pixel scale. + + However, there is also an option to apply a non-trivial WCS, which may potentially rotate + and/or shear the (x,y) moments to the local sky coordinates, which generally use (u,v) + for the coordinate variables. These coordinates are measured in arcsec and are oriented + such that +v is towards North and +u is towards West. In this case, the returned values are + all in arcsec, and are based instead on Iuu, Ivv, and Iuv. To enable this feature, use + ``use_sky_coords=True``. See also the method `ShapeData.applyWCS` for more details. + + .. note:: + + The application of the WCS implicitly assumes that the WCS is locally uniform across the + size of the object being measured. This is normally a very good approximation for most + applications of interest. + + Like `EstimateShear`, `FindAdaptiveMom` works on `Image` inputs, and fails if the object is + small compared to the pixel scale. For more details, see `EstimateShear`. + + Example:: + + >>> my_gaussian = galsim.Gaussian(flux=1.0, sigma=1.0) + >>> my_gaussian_image = my_gaussian.drawImage(scale=0.2, method='no_pixel') + >>> my_moments = galsim.hsm.FindAdaptiveMom(my_gaussian_image) + + or:: + + >>> my_moments = my_gaussian_image.FindAdaptiveMom() + + Assuming a successful measurement, the most relevant pieces of information are + ``my_moments.moments_sigma``, which is ``|det(M)|^(1/4)`` (= ``sigma`` for a circular Gaussian) + and ``my_moments.observed_shape``, which is a `Shear`. In this case, + ``my_moments.moments_sigma`` is precisely 5.0 (in units of pixels), and + ``my_moments.observed_shape`` is consistent with zero. + + Methods of the `Shear` class can be used to get the distortion ``e``, the shear ``g``, the + conformal shear ``eta``, and so on. + + As an example of how to use the optional ``hsmparams`` argument, consider cases where the input + images have unusual properties, such as being very large. This could occur when measuring the + properties of a very over-sampled image such as that generated using:: + + >>> my_gaussian = galsim.Gaussian(sigma=5.0) + >>> my_gaussian_image = my_gaussian.drawImage(scale=0.01, method='no_pixel') + + If the user attempts to measure the moments of this very large image using the standard syntax, + :: + + >>> my_moments = my_gaussian_image.FindAdaptiveMom() + + then the result will be a ``GalSimHSMError`` due to moment measurement failing because the + object is so large. While the list of all possible settings that can be changed is accessible + in the docstring of the `HSMParams` class, in this case we need to modify ``max_amoment`` which + is the maximum value of the moments in units of pixel^2. The following measurement, using the + default values for every parameter except for ``max_amoment``, will be + successful:: + + >>> new_params = galsim.hsm.HSMParams(max_amoment=5.0e5) + >>> my_moments = my_gaussian_image.FindAdaptiveMom(hsmparams=new_params) + + Parameters: + object_image: The `Image` for the object being measured. + weight: The optional weight image for the object being measured. Can be an int + or a float array. Currently, GalSim does not account for the variation + in non-zero weights, i.e., a weight map is converted to an image with 0 + and 1 for pixels that are not and are used. Full use of spatial + variation in non-zero weights will be included in a future version of + the code. [default: None] + badpix: The optional bad pixel mask for the image being used. Zero should be + used for pixels that are good, and any nonzero value indicates a bad + pixel. [default: None] + guess_sig: Optional argument with an initial guess for the Gaussian sigma of the + object (in pixels). [default: 5.0] + precision: The convergence criterion for the moments. [default: 1e-6] + guess_centroid: An initial guess for the object centroid (useful in case it is not + located at the center, which is used if this keyword is not set). The + convention for centroids is such that the center of the lower-left pixel + is (image.xmin, image.ymin). + [default: object_image.true_center] + strict: Whether to require success. If ``strict=True``, then there will be a + ``GalSimHSMError`` exception if shear estimation fails. If set to + ``False``, then information about failures will be silently stored in + the output ShapeData object. [default: True] + check: Check if the object_image, weight and badpix are in the correct format and valid. + [default: True] + round_moments: Use a circular weight function instead of elliptical. + [default: False] + hsmparams: The hsmparams keyword can be used to change the settings used by + FindAdaptiveMom when estimating moments; see `HSMParams` documentation + for more information. [default: None] + use_sky_coords: Whether to convert the measured moments to sky_coordinates. + Setting this to true is equivalent to running + ``applyWCS(object_image.wcs, image_pos=object_image.true_center)`` + on the result. [default: False] + + Returns: + a `ShapeData` object containing the results of moment measurement. + """ + # prepare inputs to C++ routines: ImageF or ImageD for galaxy, PSF, and ImageI for weight map + object_image = _convertImage(object_image) + hsmparams = HSMParams.check(hsmparams) + if check: + _checkWeightAndBadpix(object_image, weight=weight, badpix=badpix) + + weight = _convertMask(object_image, weight=weight, badpix=badpix) + + if guess_centroid is None: + guess_centroid = object_image.true_center + + try: + result = ShapeData() + FindAdaptiveMomView(result._data, + object_image._image, weight._image, + float(guess_sig), float(precision), guess_centroid._p, + bool(round_moments), hsmparams._hsmp) + + if use_sky_coords: + result = result.applyWCS(object_image.wcs, image_pos=object_image.true_center) + return result + except RuntimeError as err: + if (strict == True): + raise GalSimHSMError(str(err)) from None + else: + return ShapeData(error_message = str(err)) + +# make FindAdaptiveMom a method of Image class +Image.FindAdaptiveMom = FindAdaptiveMom + +def nonZeroBounds(image): + pass + +def MakeMaskedImage(image, mask): + b1 = image.nonZeroBounds() + b2 = mask.noneZeroBounds() + b = b1 & b2 + + masked_image = image[b] * mask[b] + + # return masked_image + return image + +def FindAdaptativeMomView(results, + object_image, + object_mask_image, + guess_sig, + precision, + guess_centroid, + round_moments, + hsmparams + ): + + tc = object_image.getBounds().trueCenter() + results.moments_centroid = jnp.where(guess_centroid!=-1000.0, + guess_centroid, + tc) + + m_xx = guess_sig*guess_sig + m_yy = m_xx + m_xy = 0. + + # Apply the mask + masked_object_image = MakeMaskedImage(object_image, object_mask_image) + + results.image_bounds = object_image.bounds + + # TODO: find_ellipmom_2 + # TODO: find_ellipmom_1 + + # def find_ellipmom_1(data, x0, y0, Mxx, Mxy, Myy, A, Bx, By, Cxx, Cxy, Cyy, rho4w, hsmparams): + # xmin = data.XMin() + # xmax = data.XMax() + # ymin = data.YMin() + # ymax = data.YMax() diff --git a/tests/galsim_tests_config.yaml b/tests/galsim_tests_config.yaml index eddafca4..c005e3bb 100644 --- a/tests/galsim_tests_config.yaml +++ b/tests/galsim_tests_config.yaml @@ -20,7 +20,8 @@ enabled_tests: - test_random.py - test_noise.py - test_image.py - - test_photon_array.py + - test_photon_array. + - test_hsm.py - "*" # means all tests from galsim coord: - test_angle.py @@ -83,7 +84,7 @@ allowed_failures: - "'Image' object has no attribute 'bin'" - "module 'jax_galsim' has no attribute 'InterpolatedKImage'" - "module 'jax_galsim' has no attribute 'CorrelatedNoise'" - - "'Image' object has no attribute 'FindAdaptiveMom'" + # - "'Image' object has no attribute 'FindAdaptiveMom'" - "CelestialCoord.precess is too slow" # cannot get jax to warmup but once it does it passes - "ValueError not raised by from_xyz" - "ValueError not raised by greatCirclePoint" @@ -107,7 +108,7 @@ allowed_failures: - "GSParams.__init__() got an unexpected keyword argument 'allowed_flux_variation'" - "module 'jax_galsim' has no attribute 'Atmosphere'" - "module 'jax_galsim' has no attribute 'RandomWalk'" - - "module 'jax_galsim' has no attribute 'hsm'" + # - "module 'jax_galsim' has no attribute 'hsm'" - "module 'jax_galsim' has no attribute 'des'" - "'Image' object has no attribute 'applyNonlinearity'" - "'Image' object has no attribute 'addReciprocityFailure'" From 1e56cc3db12cda290c38d5513d4a2bcd7ac57ed7 Mon Sep 17 00:00:00 2001 From: Benjamin Remy Date: Fri, 6 Feb 2026 14:39:08 -0500 Subject: [PATCH 02/20] ShapeData running --- jax_galsim/hsm.py | 384 +++++++++++++--------------------------------- 1 file changed, 107 insertions(+), 277 deletions(-) diff --git a/jax_galsim/hsm.py b/jax_galsim/hsm.py index bc6daa78..a61e9ab0 100644 --- a/jax_galsim/hsm.py +++ b/jax_galsim/hsm.py @@ -1,24 +1,4 @@ -# Copyright (c) 2012-2023 by the GalSim developers team on GitHub -# https://github.com/GalSim-developers -# -# This file is part of GalSim: The modular galaxy image simulation toolkit. -# https://github.com/GalSim-developers/GalSim -# -# GalSim is free software: redistribution and use in source and binary forms, -# with or without modification, are permitted provided that the following -# conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions, and the disclaimer given in the accompanying LICENSE -# file. -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions, and the disclaimer given in the documentation -# and/or other materials provided with the distribution. -# - -from dataclasses import dataclass - -import jax.numpy as jnp +import numpy as np import galsim as _galsim from jax_galsim.core.utils import implements @@ -27,9 +7,25 @@ from jax_galsim.shear import Shear from jax_galsim.image import Image, ImageI, ImageF, ImageD from jax_galsim.errors import GalSimValueError, GalSimHSMError, GalSimIncompatibleValuesError -from jax_galsim.core.utils import cast_to_float, cast_to_int -@implements(_galsim.hsm.ShapeData) +HSM_LAX_DOCS = """\ +Contrary to most other classes and objects in jax-galsim, the HSM +functionality is not implemented using JAX primitives. + +All HSM-related methods directly rely on the original GalSim +implementation and therefore: + - do not run on GPU or TPU + - are not JIT-compilable + - do not benefit from JAX transformations (vmap, grad, etc.) + +As a result, all computations are performed on the CPU using classical +GalSim code, and HSM should be considered outside the JAX execution model. +""" + +@implements( + _galsim.hsm.ShapeData, + lax_description=HSM_LAX_DOCS + ) class ShapeData: def __init__(self, image_bounds=BoundsI(), moments_status=-1, observed_shape=Shear(), moments_sigma=-1.0, moments_amp=-1.0, @@ -40,17 +36,6 @@ def __init__(self, image_bounds=BoundsI(), moments_status=-1, resolution_factor=-1.0, psf_sigma=-1.0, psf_shape=Shear(), error_message=""): - # from https://github.com/GalSim-developers/GalSim/blob/releases/2.5/include/galsim/hsm/PSFCorr.h#L281 - # This representation of an object shape contains information about observed shapes and shape - # estimators after PSF correction. It also contains information about what PSF correction was - # used; if no PSF correction was carried out and only the observed moments were measured, the - # PSF correction method will be 'None'. Note that observed shapes are bounded to lie in the - # range |e| < 1 or |g| < 1, so they can be represented using a Shear object. In contrast, - # the PSF-corrected distortions and shears are not bounded at a maximum of 1 since they are - # shear estimators, and placing such a bound would bias the mean. Thus, the corrected results - # are not represented using Shear objects, since it may not be possible to make a meaningful - # per-object conversion from distortion to shear (e.g., if |e|>1). - # Avoid empty string, which can caus problems in C++ layer. if error_message == "": error_message = "None" @@ -59,98 +44,77 @@ def __init__(self, image_bounds=BoundsI(), moments_status=-1, # The others will raise an appropriate TypeError from the call to _galsim.ShapeData # when converting to int, float, etc. - # self._data = _galsim.ShapeData( - # image_bounds._b, int(moments_status), observed_shape.e1, observed_shape.e2, - # float(moments_sigma), float(moments_amp), moments_centroid._p, - # float(moments_rho4), int(moments_n_iter), int(correction_status), - # float(corrected_e1), float(corrected_e2), float(corrected_g1), float(corrected_g2), - # str(meas_type), float(corrected_shape_err), str(correction_method), - # float(resolution_factor), float(psf_sigma), psf_shape.e1, psf_shape.e2, - # str(error_message)) - - self._image_bounds = image_bounds - self._moments_status = cast_to_int(moments_status) - self._observed_e1 = observed_shape.e1 - self._observed_e2 = observed_shape.e2 - self._moments_sigma = cast_to_float(moments_sigma) - self._moments_amp = cast_to_float(moments_amp) - self._moments_centroid = moments_centroid - self._moments_rho4 = cast_to_float(moments_rho4) - self._moments_n_iter = cast_to_int(moments_n_iter) - self._correction_status = cast_to_int(correction_status) - self._corrected_e1 = cast_to_float(corrected_e1) - self._corrected_e2 = cast_to_float(corrected_e2) - self._corrected_g1 = cast_to_float(corrected_g1) - self._corrected_g2 = cast_to_float(corrected_g2) - self._meas_type = meas_type - self._corrected_shape_err = cast_to_float(corrected_shape_err) - self._correction_method = correction_method - self._resolution_factor = cast_to_float(resolution_factor) - self._psf_sigma = cast_to_float(psf_sigma) - self._psf_e1 = psf_shape.e1 - self._psf_e2 = psf_shape.e2 - self._error_message = error_message - - @property - def image_bounds(self): return BoundsI(self._image_bounds) - @property - def moments_status(self): return self._moments_status + self._data = _galsim._galsim.ShapeData( + _galsim.BoundsI(image_bounds.xmin, image_bounds.xmax, image_bounds.ymin, image_bounds.ymax)._b, + int(moments_status), float(observed_shape.e1), float(observed_shape.e2), + float(moments_sigma), float(moments_amp), + _galsim.PositionD(moments_centroid.x, moments_centroid.y)._p, + float(moments_rho4), int(moments_n_iter), int(correction_status), + float(corrected_e1), float(corrected_e2), float(corrected_g1), float(corrected_g2), + str(meas_type), float(corrected_shape_err), str(correction_method), + float(resolution_factor), float(psf_sigma), float(psf_shape.e1), float(psf_shape.e2), + str(error_message)) + + @property + def image_bounds(self): return BoundsI(self._data.image_bounds) + @property + def moments_status(self): return self._data.moments_status @property def observed_e1(self): - return self._observed_e1 + return self._data.observed_e1 @property def observed_e2(self): - return self._observed_e2 + return self._data.observed_e2 @property def observed_shape(self): return Shear(e1=self.observed_e1, e2=self.observed_e2) @property - def moments_sigma(self): return self._moments_sigma + def moments_sigma(self): return self._data.moments_sigma @property - def moments_amp(self): return self._moments_amp + def moments_amp(self): return self._data.moments_amp @property - def moments_centroid(self): return PositionD(self._moments_centroid) + def moments_centroid(self): return PositionD(self._data.moments_centroid) @property - def moments_rho4(self): return self._moments_rho4 + def moments_rho4(self): return self._data.moments_rho4 @property - def moments_n_iter(self): return self._moments_n_iter + def moments_n_iter(self): return self._data.moments_n_iter @property - def correction_status(self): return self._correction_status + def correction_status(self): return self._data.correction_status @property - def corrected_e1(self): return self._corrected_e1 + def corrected_e1(self): return self._data.corrected_e1 @property - def corrected_e2(self): return self._corrected_e2 + def corrected_e2(self): return self._data.corrected_e2 @property - def corrected_g1(self): return self._corrected_g1 + def corrected_g1(self): return self._data.corrected_g1 @property - def corrected_g2(self): return self._corrected_g2 + def corrected_g2(self): return self._data.corrected_g2 @property - def meas_type(self): return self._meas_type + def meas_type(self): return self._data.meas_type @property - def corrected_shape_err(self): return self._corrected_shape_err + def corrected_shape_err(self): return self._data.corrected_shape_err @property - def correction_method(self): return self._correction_method + def correction_method(self): return self._data.correction_method @property - def resolution_factor(self): return self._resolution_factor + def resolution_factor(self): return self._data.resolution_factor @property - def psf_sigma(self): return self._psf_sigma + def psf_sigma(self): return self._data.psf_sigma @property def psf_shape(self): - return Shear(e1=self._psf_e1, e2=self._psf_e2) + return Shear(e1=self._data.psf_e1, e2=self._data.psf_e2) @property def error_message(self): # We use "None" in C++ ShapeData to indicate no error messages to avoid problems on # (some) Macs using zero-length strings. Here, we revert that back to "". - if self._error_message == "None": + if self._data.error_message == "None": return "" else: - return self._error_message + return self._data.error_message def __repr__(self): s = 'galsim.hsm.ShapeData(' @@ -233,24 +197,38 @@ def applyWCS(self, wcs, image_pos): # The other values are reset to the defaults, since they are # results from EstimateShear. - -@implements(_galsim.hsm.HSMParams) -# @dataclass(repr=False) +@implements( + _galsim.hsm.HSMParams, + lax_description=HSM_LAX_DOCS + ) class HSMParams: - nsig_rg: float = 3.0 - nsig_rg2: float = 3.6 - regauss_too_small: int = 1 - adapt_order: int = 2 - convergence_threshold: float = 1.e-6 - max_mom2_iter: int = 400 - num_iter_default: int = -1 - bound_correct_wt: float = 0.25 - max_amoment: float = 8000. - max_ashift: float = 15. - ksb_moments_max: int = 4 - ksb_sig_weight: float = 0.0 - ksb_sig_factor: float = 1.0 - failed_moments: float = -1000 + def __init__(self, nsig_rg=3.0, nsig_rg2=3.6, max_moment_nsig2=0, regauss_too_small=1, + adapt_order=2, convergence_threshold=1.e-6, max_mom2_iter=400, + num_iter_default=-1, bound_correct_wt=0.25, max_amoment=8000., max_ashift=15., + ksb_moments_max=4, ksb_sig_weight=0.0, ksb_sig_factor=1.0, failed_moments=-1000.): + + if max_moment_nsig2 != 0: + from .deprecated import depr + depr('max_moment_nsig2', 2.4, '', 'This parameter is no longer used.') + + self._nsig_rg = float(nsig_rg) + self._nsig_rg2 = float(nsig_rg2) + self._regauss_too_small = int(regauss_too_small) + self._adapt_order = int(adapt_order) + self._convergence_threshold = float(convergence_threshold) + self._max_mom2_iter = int(max_mom2_iter) + self._num_iter_default = int(num_iter_default) + self._bound_correct_wt = float(bound_correct_wt) + self._max_amoment = float(max_amoment) + self._max_ashift = float(max_ashift) + self._ksb_moments_max = int(ksb_moments_max) + self._ksb_sig_weight = float(ksb_sig_weight) + self._ksb_sig_factor = float(ksb_sig_factor) + self._failed_moments = float(failed_moments) + self._make_hsmp() + + def _make_hsmp(self): + self._hsmp = _galsim.hsm.HSMParams(*self._getinitargs()) def _getinitargs(self): # TODO: For now, leave 3rd param as unused max_moment_nsig2. @@ -266,6 +244,8 @@ def nsig_rg(self): return self._nsig_rg @property def nsig_rg2(self): return self._nsig_rg2 @property + def max_moment_nsig2(self): return 0. + @property def regauss_too_small(self): return self._regauss_too_small @property def adapt_order(self): return self._adapt_order @@ -340,7 +320,7 @@ def _checkWeightAndBadpix(image, weight=None, badpix=None): weight=weight, image=image) # also make sure there are no negative values - if jnp.any(weight.array < 0): + if np.any(weight.array < 0): raise GalSimValueError("Weight image cannot contain negative values.", weight) if badpix is not None and badpix.bounds != image.bounds: @@ -360,7 +340,7 @@ def _convertMask(image, weight=None, badpix=None): mask = ImageI(bounds=image.bounds, init_value=1) else: # if weight is an ImageI, then we can use it as the mask image: - if weight.dtype == jnp.int32: + if weight.dtype == np.int32: if not badpix: mask = weight else: @@ -378,7 +358,7 @@ def _convertMask(image, weight=None, badpix=None): mask.array[badpix.array != 0] = 0 # if no pixels are used, raise an exception - if not jnp.any(mask.array): + if not np.any(mask.array): raise GalSimHSMError("No pixels are being used!") # finally, return the Image for the weight map @@ -391,18 +371,22 @@ def _convertImage(image): # This is used by EstimateShear() and FindAdaptiveMom(). # if weight is not of type float/double, convert to float/double - if (image.dtype == jnp.int16 or image.dtype == jnp.uint16): + if (image.dtype == np.int16 or image.dtype == np.uint16): image = ImageF(image) - elif (image.dtype == jnp.int32 or image.dtype == jnp.uint32): + elif (image.dtype == np.int32 or image.dtype == np.uint32): image = ImageD(image) return image -@implements(_galsim.hsm.EstimateShear) +@implements( + _galsim.hsm.EstimateShear, + lax_description=HSM_LAX_DOCS + ) def EstimateShear(gal_image, PSF_image, weight=None, badpix=None, sky_var=0.0, shear_est="REGAUSS", recompute_flux="FIT", guess_sig_gal=5.0, guess_sig_PSF=3.0, precision=1.0e-6, guess_centroid=None, strict=True, check=True, hsmparams=None): + # prepare inputs to C++ routines: ImageF or ImageD for galaxy, PSF, and ImageI for weight map gal_image = _convertImage(gal_image) PSF_image = _convertImage(PSF_image) hsmparams = HSMParams.check(hsmparams) @@ -414,7 +398,7 @@ def EstimateShear(gal_image, PSF_image, weight=None, badpix=None, sky_var=0.0, guess_centroid = gal_image.true_center try: result = ShapeData() - EstimateShearView(result._data, + _galsim.EstimateShearView(result._data, gal_image._image, PSF_image._image, weight._image, float(sky_var), shear_est.upper(), recompute_flux.upper(), float(guess_sig_gal), float(guess_sig_PSF), float(precision), @@ -426,121 +410,13 @@ def EstimateShear(gal_image, PSF_image, weight=None, badpix=None, sky_var=0.0, else: return ShapeData(error_message = str(err)) -@implements(_galsim.hsm.FindAdaptiveMom) +@implements( + _galsim.hsm.FindAdaptiveMom, + lax_description=HSM_LAX_DOCS + ) def FindAdaptiveMom(object_image, weight=None, badpix=None, guess_sig=5.0, precision=1.0e-6, guess_centroid=None, strict=True, check=True, round_moments=False, hsmparams=None, use_sky_coords=False): - """Measure adaptive moments of an object. - - This method estimates the best-fit elliptical Gaussian to the object (see Hirata & Seljak 2003 - for more discussion of adaptive moments). This elliptical Gaussian is computed iteratively - by initially guessing a circular Gaussian that is used as a weight function, computing the - weighted moments, recomputing the moments using the result of the previous step as the weight - function, and so on until the moments that are measured are the same as those used for the - weight function. `FindAdaptiveMom` can be used either as a free function, or as a method of the - `Image` class. - - By default, this routine computes moments in pixel coordinates, which generally use (x,y) - for the coordinate variables, so the underlying second moments are Ixx, Iyy, and Ixy. - If the WCS is (at least approximately) just a `PixelScale`, then this scale can be applied to - convert the moments' units from pixels to arcsec. The derived shapes are unaffected by - the pixel scale. - - However, there is also an option to apply a non-trivial WCS, which may potentially rotate - and/or shear the (x,y) moments to the local sky coordinates, which generally use (u,v) - for the coordinate variables. These coordinates are measured in arcsec and are oriented - such that +v is towards North and +u is towards West. In this case, the returned values are - all in arcsec, and are based instead on Iuu, Ivv, and Iuv. To enable this feature, use - ``use_sky_coords=True``. See also the method `ShapeData.applyWCS` for more details. - - .. note:: - - The application of the WCS implicitly assumes that the WCS is locally uniform across the - size of the object being measured. This is normally a very good approximation for most - applications of interest. - - Like `EstimateShear`, `FindAdaptiveMom` works on `Image` inputs, and fails if the object is - small compared to the pixel scale. For more details, see `EstimateShear`. - - Example:: - - >>> my_gaussian = galsim.Gaussian(flux=1.0, sigma=1.0) - >>> my_gaussian_image = my_gaussian.drawImage(scale=0.2, method='no_pixel') - >>> my_moments = galsim.hsm.FindAdaptiveMom(my_gaussian_image) - - or:: - - >>> my_moments = my_gaussian_image.FindAdaptiveMom() - - Assuming a successful measurement, the most relevant pieces of information are - ``my_moments.moments_sigma``, which is ``|det(M)|^(1/4)`` (= ``sigma`` for a circular Gaussian) - and ``my_moments.observed_shape``, which is a `Shear`. In this case, - ``my_moments.moments_sigma`` is precisely 5.0 (in units of pixels), and - ``my_moments.observed_shape`` is consistent with zero. - - Methods of the `Shear` class can be used to get the distortion ``e``, the shear ``g``, the - conformal shear ``eta``, and so on. - - As an example of how to use the optional ``hsmparams`` argument, consider cases where the input - images have unusual properties, such as being very large. This could occur when measuring the - properties of a very over-sampled image such as that generated using:: - - >>> my_gaussian = galsim.Gaussian(sigma=5.0) - >>> my_gaussian_image = my_gaussian.drawImage(scale=0.01, method='no_pixel') - - If the user attempts to measure the moments of this very large image using the standard syntax, - :: - - >>> my_moments = my_gaussian_image.FindAdaptiveMom() - - then the result will be a ``GalSimHSMError`` due to moment measurement failing because the - object is so large. While the list of all possible settings that can be changed is accessible - in the docstring of the `HSMParams` class, in this case we need to modify ``max_amoment`` which - is the maximum value of the moments in units of pixel^2. The following measurement, using the - default values for every parameter except for ``max_amoment``, will be - successful:: - - >>> new_params = galsim.hsm.HSMParams(max_amoment=5.0e5) - >>> my_moments = my_gaussian_image.FindAdaptiveMom(hsmparams=new_params) - - Parameters: - object_image: The `Image` for the object being measured. - weight: The optional weight image for the object being measured. Can be an int - or a float array. Currently, GalSim does not account for the variation - in non-zero weights, i.e., a weight map is converted to an image with 0 - and 1 for pixels that are not and are used. Full use of spatial - variation in non-zero weights will be included in a future version of - the code. [default: None] - badpix: The optional bad pixel mask for the image being used. Zero should be - used for pixels that are good, and any nonzero value indicates a bad - pixel. [default: None] - guess_sig: Optional argument with an initial guess for the Gaussian sigma of the - object (in pixels). [default: 5.0] - precision: The convergence criterion for the moments. [default: 1e-6] - guess_centroid: An initial guess for the object centroid (useful in case it is not - located at the center, which is used if this keyword is not set). The - convention for centroids is such that the center of the lower-left pixel - is (image.xmin, image.ymin). - [default: object_image.true_center] - strict: Whether to require success. If ``strict=True``, then there will be a - ``GalSimHSMError`` exception if shear estimation fails. If set to - ``False``, then information about failures will be silently stored in - the output ShapeData object. [default: True] - check: Check if the object_image, weight and badpix are in the correct format and valid. - [default: True] - round_moments: Use a circular weight function instead of elliptical. - [default: False] - hsmparams: The hsmparams keyword can be used to change the settings used by - FindAdaptiveMom when estimating moments; see `HSMParams` documentation - for more information. [default: None] - use_sky_coords: Whether to convert the measured moments to sky_coordinates. - Setting this to true is equivalent to running - ``applyWCS(object_image.wcs, image_pos=object_image.true_center)`` - on the result. [default: False] - - Returns: - a `ShapeData` object containing the results of moment measurement. - """ # prepare inputs to C++ routines: ImageF or ImageD for galaxy, PSF, and ImageI for weight map object_image = _convertImage(object_image) hsmparams = HSMParams.check(hsmparams) @@ -554,10 +430,10 @@ def FindAdaptiveMom(object_image, weight=None, badpix=None, guess_sig=5.0, preci try: result = ShapeData() - FindAdaptiveMomView(result._data, - object_image._image, weight._image, - float(guess_sig), float(precision), guess_centroid._p, - bool(round_moments), hsmparams._hsmp) + _galsim._galsim.FindAdaptiveMomView(result._data, + object_image._image, weight._image, + float(guess_sig), float(precision), guess_centroid._p, + bool(round_moments), hsmparams._hsmp) if use_sky_coords: result = result.applyWCS(object_image.wcs, image_pos=object_image.true_center) @@ -569,50 +445,4 @@ def FindAdaptiveMom(object_image, weight=None, badpix=None, guess_sig=5.0, preci return ShapeData(error_message = str(err)) # make FindAdaptiveMom a method of Image class -Image.FindAdaptiveMom = FindAdaptiveMom - -def nonZeroBounds(image): - pass - -def MakeMaskedImage(image, mask): - b1 = image.nonZeroBounds() - b2 = mask.noneZeroBounds() - b = b1 & b2 - - masked_image = image[b] * mask[b] - - # return masked_image - return image - -def FindAdaptativeMomView(results, - object_image, - object_mask_image, - guess_sig, - precision, - guess_centroid, - round_moments, - hsmparams - ): - - tc = object_image.getBounds().trueCenter() - results.moments_centroid = jnp.where(guess_centroid!=-1000.0, - guess_centroid, - tc) - - m_xx = guess_sig*guess_sig - m_yy = m_xx - m_xy = 0. - - # Apply the mask - masked_object_image = MakeMaskedImage(object_image, object_mask_image) - - results.image_bounds = object_image.bounds - - # TODO: find_ellipmom_2 - # TODO: find_ellipmom_1 - - # def find_ellipmom_1(data, x0, y0, Mxx, Mxy, Myy, A, Bx, By, Cxx, Cxy, Cyy, rho4w, hsmparams): - # xmin = data.XMin() - # xmax = data.XMax() - # ymin = data.YMin() - # ymax = data.YMax() +Image.FindAdaptiveMom = FindAdaptiveMom \ No newline at end of file From 1e9caebfdb7b6914d4897d9f3ea1a914af8cbde4 Mon Sep 17 00:00:00 2001 From: Benjamin Remy Date: Fri, 6 Feb 2026 15:01:03 -0500 Subject: [PATCH 03/20] demo1.py runnning --- jax_galsim/bounds.py | 9 +++++++++ jax_galsim/hsm.py | 8 ++++---- jax_galsim/image.py | 20 ++++++++++++++++++++ jax_galsim/position.py | 8 ++++++++ 4 files changed, 41 insertions(+), 4 deletions(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index c30de770..c5e4d636 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -282,6 +282,11 @@ def __init__(self, *args, **kwargs): self.ymin = cast_to_float(self.ymin) self.ymax = cast_to_float(self.ymax) + @property + def _b(self): + return _galsim._galsim.BoundsD(cast_to_float(self.xmin), cast_to_float(self.xmax), + cast_to_float(self.ymin), cast_to_float(self.ymax)) + def _check_scalar(self, x, name): try: if ( @@ -331,6 +336,10 @@ def __init__(self, *args, **kwargs): self.ymin = cast_to_int(self.ymin) self.ymax = cast_to_int(self.ymax) + @property + def _b(self): + return _galsim._galsim.BoundsI(self.xmin, self.xmax, self.ymin, self.ymax) + def _check_scalar(self, x, name): try: if ( diff --git a/jax_galsim/hsm.py b/jax_galsim/hsm.py index a61e9ab0..40cd7004 100644 --- a/jax_galsim/hsm.py +++ b/jax_galsim/hsm.py @@ -45,10 +45,10 @@ def __init__(self, image_bounds=BoundsI(), moments_status=-1, # The others will raise an appropriate TypeError from the call to _galsim.ShapeData # when converting to int, float, etc. self._data = _galsim._galsim.ShapeData( - _galsim.BoundsI(image_bounds.xmin, image_bounds.xmax, image_bounds.ymin, image_bounds.ymax)._b, + image_bounds._b, int(moments_status), float(observed_shape.e1), float(observed_shape.e2), float(moments_sigma), float(moments_amp), - _galsim.PositionD(moments_centroid.x, moments_centroid.y)._p, + moments_centroid._p, float(moments_rho4), int(moments_n_iter), int(correction_status), float(corrected_e1), float(corrected_e2), float(corrected_g1), float(corrected_g2), str(meas_type), float(corrected_shape_err), str(correction_method), @@ -228,7 +228,7 @@ def __init__(self, nsig_rg=3.0, nsig_rg2=3.6, max_moment_nsig2=0, regauss_too_sm self._make_hsmp() def _make_hsmp(self): - self._hsmp = _galsim.hsm.HSMParams(*self._getinitargs()) + self._hsmp = _galsim._galsim.HSMParams(*self._getinitargs()) def _getinitargs(self): # TODO: For now, leave 3rd param as unused max_moment_nsig2. @@ -398,7 +398,7 @@ def EstimateShear(gal_image, PSF_image, weight=None, badpix=None, sky_var=0.0, guess_centroid = gal_image.true_center try: result = ShapeData() - _galsim.EstimateShearView(result._data, + _galsim._galsim.EstimateShearView(result._data, gal_image._image, PSF_image._image, weight._image, float(sky_var), shear_est.upper(), recompute_flux.upper(), float(guess_sig_gal), float(guess_sig_PSF), float(precision), diff --git a/jax_galsim/image.py b/jax_galsim/image.py index acd701cf..96fada82 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -28,6 +28,16 @@ ) @register_pytree_node_class class Image(object): + _cpp_type = { np.uint16 : _galsim._galsim.ImageViewUS, + np.uint32 : _galsim._galsim.ImageViewUI, + np.int16 : _galsim._galsim.ImageViewS, + np.int32 : _galsim._galsim.ImageViewI, + np.float32 : _galsim._galsim.ImageViewF, + np.float64 : _galsim._galsim.ImageViewD, + np.complex64 : _galsim._galsim.ImageViewCF, + np.complex128 : _galsim._galsim.ImageViewCD, + } + _alias_dtypes = { int: jnp.int32, # So that user gets what they would expect float: jnp.float64, # if using dtype=int or float or complex @@ -373,6 +383,16 @@ def iscontiguous(self): """ return True # In JAX all arrays are contiguous (almost) + @_galsim._utilities.lazy_property + def _image(self): + cls = self._cpp_type[self.dtype] + _array = np.asarray(self._array) + _data = _array.__array_interface__['data'][0] + return cls(_data, + _array.strides[1]//_array.itemsize, + _array.strides[0]//_array.itemsize, + self._bounds._b) + # Allow scale to work as a PixelScale wcs. @property def scale(self): diff --git a/jax_galsim/position.py b/jax_galsim/position.py index b805a4ab..7736ed84 100644 --- a/jax_galsim/position.py +++ b/jax_galsim/position.py @@ -188,6 +188,10 @@ def __init__(self, *args, **kwargs): self.x = cast_to_float(self.x) self.y = cast_to_float(self.y) + @property + def _p(self): + return _galsim._galsim.PositionD(self.x, self.y) + def _check_scalar(self, other, op): try: if ( @@ -213,6 +217,10 @@ def __init__(self, *args, **kwargs): self.x = cast_to_int(self.x) self.y = cast_to_int(self.y) + @property + def _p(self): + return _galsim._galsim.PositionI(self.x, self.y) + def _check_scalar(self, other, op): try: if ( From fb3905383e57020575c14b23cb38c80dd5aee678 Mon Sep 17 00:00:00 2001 From: Benjamin Remy Date: Fri, 6 Feb 2026 15:10:37 -0500 Subject: [PATCH 04/20] adding demo1.py demo2.py --- examples/demo1.py | 127 +++++++++++++++++++++++++++++++++++++ examples/demo2.py | 158 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 285 insertions(+) create mode 100644 examples/demo1.py create mode 100644 examples/demo2.py diff --git a/examples/demo1.py b/examples/demo1.py new file mode 100644 index 00000000..854c1549 --- /dev/null +++ b/examples/demo1.py @@ -0,0 +1,127 @@ +# Copyright (c) 2012-2026 by the GalSim developers team on GitHub +# https://github.com/GalSim-developers +# +# This file is part of GalSim: The modular galaxy image simulation toolkit. +# https://github.com/GalSim-developers/GalSim +# +# GalSim is free software: redistribution and use in source and binary forms, +# with or without modification, are permitted provided that the following +# conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions, and the disclaimer given in the accompanying LICENSE +# file. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions, and the disclaimer given in the documentation +# and/or other materials provided with the distribution. +# +""" +Demo #1 + +This is the first script in our tutorial about using GalSim in python scripts: examples/demo*.py. +(This file is designed to be viewed in a window 100 characters wide.) + +Each of these demo*.py files are designed to be equivalent to the corresponding demo*.yaml file +(or demo*.json -- found in the json directory). If you are new to python, you should probably +look at those files first as they will probably have a quicker learning curve for you. Then you +can look through these python scripts, which show how to do the same thing. Of course, experienced +pythonistas may prefer to start with these scripts and then look at the corresponding YAML files. + +To run this script, simply write: + + python demo1.py + + +This first script is about as simple as it gets. We draw an image of a single galaxy convolved +with a PSF and write it to disk. We use a circular Gaussian profile for both the PSF and the +galaxy, and add a constant level of Gaussian noise to the image. + +In each demo, we list the new features introduced in that demo file. These will differ somewhat +between the .py and .yaml (or .json) versions, since the two methods implement things in different +ways. (demo*.py are python scripts, while demo*.yaml and demo*.json are configuration files.) + +New features introduced in this demo: + +- obj = galsim.Gaussian(flux, sigma) +- obj = galsim.Convolve([list of objects]) +- image = obj.drawImage(scale) +- image.added_flux (Only present after a drawImage command.) +- noise = galsim.GaussianNoise(sigma) +- image.addNoise(noise) +- image.write(file_name) +- image.FindAdaptiveMom() +""" + +import sys +import os +import math +import logging +import jax_galsim as galsim + +def main(argv): + """ + About as simple as it gets: + - Use a circular Gaussian profile for the galaxy. + - Convolve it by a circular Gaussian PSF. + - Add Gaussian noise to the image. + """ + # In non-script code, use getLogger(__name__) at module scope instead. + logging.basicConfig(format="%(message)s", level=logging.INFO, stream=sys.stdout) + logger = logging.getLogger("demo1") + + gal_flux = 1.e5 # total counts on the image + gal_sigma = 2. # arcsec + psf_sigma = 1. # arcsec + pixel_scale = 0.2 # arcsec / pixel + noise = 30. # standard deviation of the counts in each pixel + + logger.info('Starting demo script 1 using:') + logger.info(' - circular Gaussian galaxy (flux = %.1e, sigma = %.1f),',gal_flux,gal_sigma) + logger.info(' - circular Gaussian PSF (sigma = %.1f),',psf_sigma) + logger.info(' - pixel scale = %.2f,',pixel_scale) + logger.info(' - Gaussian noise (sigma = %.2f).',noise) + + # Define the galaxy profile + gal = galsim.Gaussian(flux=gal_flux, sigma=gal_sigma) + logger.debug('Made galaxy profile') + + # Define the PSF profile + psf = galsim.Gaussian(flux=1., sigma=psf_sigma) # PSF flux should always = 1 + logger.debug('Made PSF profile') + + # Final profile is the convolution of these + # Can include any number of things in the list, all of which are convolved + # together to make the final flux profile. + final = galsim.Convolve([gal, psf]) + logger.debug('Convolved components into final profile') + + # Draw the image with a particular pixel scale, given in arcsec/pixel. + # The returned image has a member, added_flux, which is gives the total flux actually added to + # the image. One could use this value to check if the image is large enough for some desired + # accuracy level. Here, we just ignore it. + image = final.drawImage(scale=pixel_scale) + logger.debug('Made image of the profile: flux = %f, added_flux = %f',gal_flux,image.added_flux) + + # Add Gaussian noise to the image with specified sigma + image.addNoise(galsim.GaussianNoise(sigma=noise)) + logger.debug('Added Gaussian noise') + + # Write the image to a file + if not os.path.isdir('output'): + os.mkdir('output') + file_name = os.path.join('output','demo1.fits') + # Note: if the file already exists, this will overwrite it. + image.write(file_name) + logger.info('Wrote image to %r' % file_name) # using %r adds quotes around filename for us + + results = image.FindAdaptiveMom() + + logger.info('HSM reports that the image has observed shape and size:') + logger.info(' e1 = %.3f, e2 = %.3f, sigma = %.3f (pixels)', results.observed_shape.e1, + results.observed_shape.e2, results.moments_sigma) + logger.info('Expected values in the limit that pixel response and noise are negligible:') + logger.info(' e1 = %.3f, e2 = %.3f, sigma = %.3f', 0.0, 0.0, + math.sqrt(gal_sigma**2 + psf_sigma**2)/pixel_scale) + +if __name__ == "__main__": + main(sys.argv) \ No newline at end of file diff --git a/examples/demo2.py b/examples/demo2.py new file mode 100644 index 00000000..6330e0e8 --- /dev/null +++ b/examples/demo2.py @@ -0,0 +1,158 @@ +# Copyright (c) 2012-2026 by the GalSim developers team on GitHub +# https://github.com/GalSim-developers +# +# This file is part of GalSim: The modular galaxy image simulation toolkit. +# https://github.com/GalSim-developers/GalSim +# +# GalSim is free software: redistribution and use in source and binary forms, +# with or without modification, are permitted provided that the following +# conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions, and the disclaimer given in the accompanying LICENSE +# file. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions, and the disclaimer given in the documentation +# and/or other materials provided with the distribution. +# +""" +Demo #2 + +The second script in our tutorial about using GalSim in python scripts: examples/demo*.py. +(This file is designed to be viewed in a window 100 characters wide.) + +This script is a bit more sophisticated, but still pretty basic. We're still only making +a single image, but now the galaxy has an exponential radial profile and is sheared. +The PSF is a circular Moffat profile. The noise is drawn from a Poisson distribution +using the flux from both the object and a background sky level to determine the +variance in each pixel. + +New features introduced in this demo: + +- obj = galsim.Exponential(flux, scale_radius) +- obj = galsim.Moffat(beta, flux, half_light_radius) +- obj = obj.shear(g1, g2) -- with explanation of other ways to specify shear +- rng = galsim.BaseDeviate(seed) +- noise = galsim.PoissonNoise(rng, sky_level) +- galsim.hsm.EstimateShear(image, image_epsf) +""" + +import sys +import os +import logging +import jax_galsim as galsim + +def main(argv): + """ + A little bit more sophisticated, but still pretty basic: + - Use a sheared, exponential profile for the galaxy. + - Convolve it by a circular Moffat PSF. + - Add Poisson noise to the image. + """ + # In non-script code, use getLogger(__name__) at module scope instead. + logging.basicConfig(format="%(message)s", level=logging.INFO, stream=sys.stdout) + logger = logging.getLogger("demo2") + + gal_flux = 1.e5 # counts + gal_r0 = 2.7 # arcsec + g1 = 0.1 # + g2 = 0.2 # + psf_beta = 5 # + psf_re = 1.0 # arcsec + pixel_scale = 0.2 # arcsec / pixel + sky_level = 2.5e3 # counts / arcsec^2 + + # This time use a particular seed, so the image is deterministic. + # This is the same seed that is used in demo2.yaml, which means the images + # produced by the two methods will be precisely identical. + random_seed = 1534225 + + # The first thing the config layer does with the random seed is to scramble + # it a bit. Specifically, it makes a random number generator (BaseDeviate) + # using that seed and asks for a raw value. This becomes the seed that + # actually gets used. + # The reason for this extra step is that eventually (cf. demo4) the config + # layer will want to increment these seed values when building multiple + # objects or images. If the user is likewise incrementing seed values for + # multiple runs of a given config file, these can interfere leading to + # surprising (and typically bad) results. + random_seed = galsim.BaseDeviate(random_seed).raw() + + logger.info('Starting demo script 2 using:') + logger.info(' - sheared (%.2f,%.2f) exponential galaxy (flux = %.1e, scale radius = %.2f),', + g1, g2, gal_flux, gal_r0) + logger.info(' - circular Moffat PSF (beta = %.1f, re = %.2f),', psf_beta, psf_re) + logger.info(' - pixel scale = %.2f,', pixel_scale) + logger.info(' - Poisson noise (sky level = %.1e).', sky_level) + + # Initialize the (pseudo-)random number generator that we will be using below. + # For a technical reason that will be explained later (demo9.py), we add 1 to the + # given random seed here. + rng = galsim.BaseDeviate(random_seed+1) + + # Define the galaxy profile. + gal = galsim.Exponential(flux=gal_flux, scale_radius=gal_r0) + + # Shear the galaxy by some value. + # There are quite a few ways you can use to specify a shape. + # q, beta Axis ratio and position angle: q = b/a, 0 < q < 1 + # e, beta Ellipticity and position angle: |e| = (1-q^2)/(1+q^2) + # g, beta ("Reduced") Shear and position angle: |g| = (1-q)/(1+q) + # eta, beta Conformal shear and position angle: eta = ln(1/q) + # e1,e2 Ellipticity components: e1 = e cos(2 beta), e2 = e sin(2 beta) + # g1,g2 ("Reduced") shear components: g1 = g cos(2 beta), g2 = g sin(2 beta) + # eta1,eta2 Conformal shear components: eta1 = eta cos(2 beta), eta2 = eta sin(2 beta) + gal = gal.shear(g1=g1, g2=g2) + logger.debug('Made galaxy profile') + + # Define the PSF profile. + psf = galsim.Moffat(beta=psf_beta, flux=1., half_light_radius=psf_re) + logger.debug('Made PSF profile') + + # Final profile is the convolution of these. + final = galsim.Convolve([gal, psf]) + logger.debug('Convolved components into final profile') + + # Draw the image with a particular pixel scale. + image = final.drawImage(scale=pixel_scale) + # The "effective PSF" is the PSF as drawn on an image, which includes the convolution + # by the pixel response. We label it epsf here. + image_epsf = psf.drawImage(scale=pixel_scale) + logger.debug('Made image of the profile') + + # To get Poisson noise on the image, we will use a class called PoissonNoise. + # However, we want the noise to correspond to what you would get with a significant + # flux from tke sky. This is done by telling PoissonNoise to add noise from a + # sky level in addition to the counts currently in the image. + # + # One wrinkle here is that the PoissonNoise class needs the sky level in each pixel, + # while we have a sky_level in counts per arcsec^2. So we need to convert: + sky_level_pixel = sky_level * pixel_scale**2 + noise = galsim.PoissonNoise(rng, sky_level=sky_level_pixel) + image.addNoise(noise) + logger.debug('Added Poisson noise') + + # Write the image to a file. + if not os.path.isdir('output'): + os.mkdir('output') + file_name = os.path.join('output', 'demo2.fits') + file_name_epsf = os.path.join('output','demo2_epsf.fits') + image.write(file_name) + image_epsf.write(file_name_epsf) + logger.info('Wrote image to %r',file_name) + logger.info('Wrote effective PSF image to %r',file_name_epsf) + + results = galsim.hsm.EstimateShear(image, image_epsf) + + logger.info('HSM reports that the image has observed shape and size:') + logger.info(' e1 = %.3f, e2 = %.3f, sigma = %.3f (pixels)', results.observed_shape.e1, + results.observed_shape.e2, results.moments_sigma) + logger.info('When carrying out Regaussianization PSF correction, HSM reports distortions') + logger.info(' e1, e2 = %.3f, %.3f', + results.corrected_e1, results.corrected_e2) + logger.info('Expected values in the limit that noise and non-Gaussianity are negligible:') + exp_shear = galsim.Shear(g1=g1, g2=g2) + logger.info(' g1, g2 = %.3f, %.3f', exp_shear.e1,exp_shear.e2) + +if __name__ == "__main__": + main(sys.argv) \ No newline at end of file From 9f6525e3f49ce773ab9d07d37beeb6fa6b940154 Mon Sep 17 00:00:00 2001 From: Benjamin Remy Date: Fri, 6 Feb 2026 15:35:02 -0500 Subject: [PATCH 05/20] update tests --- tests/galsim_tests_config.yaml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/galsim_tests_config.yaml b/tests/galsim_tests_config.yaml index 4416da23..54e81bbf 100644 --- a/tests/galsim_tests_config.yaml +++ b/tests/galsim_tests_config.yaml @@ -20,7 +20,7 @@ enabled_tests: - test_random.py - test_noise.py - test_image.py - - test_photon_array. + - test_photon_array.py - test_hsm.py - "*" # means all tests from galsim coord: @@ -87,7 +87,6 @@ allowed_failures: - "'Image' object has no attribute 'bin'" - "module 'jax_galsim' has no attribute 'InterpolatedKImage'" - "module 'jax_galsim' has no attribute 'CorrelatedNoise'" - # - "'Image' object has no attribute 'FindAdaptiveMom'" - "CelestialCoord.precess is too slow" # cannot get jax to warmup but once it does it passes - "ValueError not raised by from_xyz" - "ValueError not raised by greatCirclePoint" @@ -111,7 +110,6 @@ allowed_failures: - "GSParams.__init__() got an unexpected keyword argument 'allowed_flux_variation'" - "module 'jax_galsim' has no attribute 'Atmosphere'" - "module 'jax_galsim' has no attribute 'RandomWalk'" - # - "module 'jax_galsim' has no attribute 'hsm'" - "module 'jax_galsim' has no attribute 'des'" - "'Image' object has no attribute 'applyNonlinearity'" - "'Image' object has no attribute 'addReciprocityFailure'" From f16dc2f63be426b8f3b57c73b15ea67bd971fba9 Mon Sep 17 00:00:00 2001 From: Benjamin Remy Date: Fri, 6 Feb 2026 15:40:42 -0500 Subject: [PATCH 06/20] apply back & ruff --- examples/demo1.py | 83 ++++--- examples/demo2.py | 87 ++++--- jax_galsim/bounds.py | 8 +- jax_galsim/hsm.py | 528 ++++++++++++++++++++++++++++++------------- jax_galsim/image.py | 35 +-- 5 files changed, 494 insertions(+), 247 deletions(-) diff --git a/examples/demo1.py b/examples/demo1.py index 854c1549..da8c60cd 100644 --- a/examples/demo1.py +++ b/examples/demo1.py @@ -52,12 +52,14 @@ - image.FindAdaptiveMom() """ -import sys -import os -import math import logging +import math +import os +import sys + import jax_galsim as galsim + def main(argv): """ About as simple as it gets: @@ -69,59 +71,80 @@ def main(argv): logging.basicConfig(format="%(message)s", level=logging.INFO, stream=sys.stdout) logger = logging.getLogger("demo1") - gal_flux = 1.e5 # total counts on the image - gal_sigma = 2. # arcsec - psf_sigma = 1. # arcsec + gal_flux = 1.0e5 # total counts on the image + gal_sigma = 2.0 # arcsec + psf_sigma = 1.0 # arcsec pixel_scale = 0.2 # arcsec / pixel - noise = 30. # standard deviation of the counts in each pixel - - logger.info('Starting demo script 1 using:') - logger.info(' - circular Gaussian galaxy (flux = %.1e, sigma = %.1f),',gal_flux,gal_sigma) - logger.info(' - circular Gaussian PSF (sigma = %.1f),',psf_sigma) - logger.info(' - pixel scale = %.2f,',pixel_scale) - logger.info(' - Gaussian noise (sigma = %.2f).',noise) + noise = 30.0 # standard deviation of the counts in each pixel + + logger.info("Starting demo script 1 using:") + logger.info( + " - circular Gaussian galaxy (flux = %.1e, sigma = %.1f),", + gal_flux, + gal_sigma, + ) + logger.info(" - circular Gaussian PSF (sigma = %.1f),", psf_sigma) + logger.info(" - pixel scale = %.2f,", pixel_scale) + logger.info(" - Gaussian noise (sigma = %.2f).", noise) # Define the galaxy profile gal = galsim.Gaussian(flux=gal_flux, sigma=gal_sigma) - logger.debug('Made galaxy profile') + logger.debug("Made galaxy profile") # Define the PSF profile - psf = galsim.Gaussian(flux=1., sigma=psf_sigma) # PSF flux should always = 1 - logger.debug('Made PSF profile') + psf = galsim.Gaussian(flux=1.0, sigma=psf_sigma) # PSF flux should always = 1 + logger.debug("Made PSF profile") # Final profile is the convolution of these # Can include any number of things in the list, all of which are convolved # together to make the final flux profile. final = galsim.Convolve([gal, psf]) - logger.debug('Convolved components into final profile') + logger.debug("Convolved components into final profile") # Draw the image with a particular pixel scale, given in arcsec/pixel. # The returned image has a member, added_flux, which is gives the total flux actually added to # the image. One could use this value to check if the image is large enough for some desired # accuracy level. Here, we just ignore it. image = final.drawImage(scale=pixel_scale) - logger.debug('Made image of the profile: flux = %f, added_flux = %f',gal_flux,image.added_flux) + logger.debug( + "Made image of the profile: flux = %f, added_flux = %f", + gal_flux, + image.added_flux, + ) # Add Gaussian noise to the image with specified sigma image.addNoise(galsim.GaussianNoise(sigma=noise)) - logger.debug('Added Gaussian noise') + logger.debug("Added Gaussian noise") # Write the image to a file - if not os.path.isdir('output'): - os.mkdir('output') - file_name = os.path.join('output','demo1.fits') + if not os.path.isdir("output"): + os.mkdir("output") + file_name = os.path.join("output", "demo1.fits") # Note: if the file already exists, this will overwrite it. image.write(file_name) - logger.info('Wrote image to %r' % file_name) # using %r adds quotes around filename for us + logger.info( + "Wrote image to %r" % file_name + ) # using %r adds quotes around filename for us results = image.FindAdaptiveMom() - logger.info('HSM reports that the image has observed shape and size:') - logger.info(' e1 = %.3f, e2 = %.3f, sigma = %.3f (pixels)', results.observed_shape.e1, - results.observed_shape.e2, results.moments_sigma) - logger.info('Expected values in the limit that pixel response and noise are negligible:') - logger.info(' e1 = %.3f, e2 = %.3f, sigma = %.3f', 0.0, 0.0, - math.sqrt(gal_sigma**2 + psf_sigma**2)/pixel_scale) + logger.info("HSM reports that the image has observed shape and size:") + logger.info( + " e1 = %.3f, e2 = %.3f, sigma = %.3f (pixels)", + results.observed_shape.e1, + results.observed_shape.e2, + results.moments_sigma, + ) + logger.info( + "Expected values in the limit that pixel response and noise are negligible:" + ) + logger.info( + " e1 = %.3f, e2 = %.3f, sigma = %.3f", + 0.0, + 0.0, + math.sqrt(gal_sigma**2 + psf_sigma**2) / pixel_scale, + ) + if __name__ == "__main__": - main(sys.argv) \ No newline at end of file + main(sys.argv) diff --git a/examples/demo2.py b/examples/demo2.py index 6330e0e8..7a5dfcde 100644 --- a/examples/demo2.py +++ b/examples/demo2.py @@ -37,11 +37,13 @@ - galsim.hsm.EstimateShear(image, image_epsf) """ -import sys -import os import logging +import os +import sys + import jax_galsim as galsim + def main(argv): """ A little bit more sophisticated, but still pretty basic: @@ -53,12 +55,12 @@ def main(argv): logging.basicConfig(format="%(message)s", level=logging.INFO, stream=sys.stdout) logger = logging.getLogger("demo2") - gal_flux = 1.e5 # counts - gal_r0 = 2.7 # arcsec - g1 = 0.1 # - g2 = 0.2 # - psf_beta = 5 # - psf_re = 1.0 # arcsec + gal_flux = 1.0e5 # counts + gal_r0 = 2.7 # arcsec + g1 = 0.1 # + g2 = 0.2 # + psf_beta = 5 # + psf_re = 1.0 # arcsec pixel_scale = 0.2 # arcsec / pixel sky_level = 2.5e3 # counts / arcsec^2 @@ -78,17 +80,22 @@ def main(argv): # surprising (and typically bad) results. random_seed = galsim.BaseDeviate(random_seed).raw() - logger.info('Starting demo script 2 using:') - logger.info(' - sheared (%.2f,%.2f) exponential galaxy (flux = %.1e, scale radius = %.2f),', - g1, g2, gal_flux, gal_r0) - logger.info(' - circular Moffat PSF (beta = %.1f, re = %.2f),', psf_beta, psf_re) - logger.info(' - pixel scale = %.2f,', pixel_scale) - logger.info(' - Poisson noise (sky level = %.1e).', sky_level) + logger.info("Starting demo script 2 using:") + logger.info( + " - sheared (%.2f,%.2f) exponential galaxy (flux = %.1e, scale radius = %.2f),", + g1, + g2, + gal_flux, + gal_r0, + ) + logger.info(" - circular Moffat PSF (beta = %.1f, re = %.2f),", psf_beta, psf_re) + logger.info(" - pixel scale = %.2f,", pixel_scale) + logger.info(" - Poisson noise (sky level = %.1e).", sky_level) # Initialize the (pseudo-)random number generator that we will be using below. # For a technical reason that will be explained later (demo9.py), we add 1 to the # given random seed here. - rng = galsim.BaseDeviate(random_seed+1) + rng = galsim.BaseDeviate(random_seed + 1) # Define the galaxy profile. gal = galsim.Exponential(flux=gal_flux, scale_radius=gal_r0) @@ -103,22 +110,22 @@ def main(argv): # g1,g2 ("Reduced") shear components: g1 = g cos(2 beta), g2 = g sin(2 beta) # eta1,eta2 Conformal shear components: eta1 = eta cos(2 beta), eta2 = eta sin(2 beta) gal = gal.shear(g1=g1, g2=g2) - logger.debug('Made galaxy profile') + logger.debug("Made galaxy profile") # Define the PSF profile. - psf = galsim.Moffat(beta=psf_beta, flux=1., half_light_radius=psf_re) - logger.debug('Made PSF profile') + psf = galsim.Moffat(beta=psf_beta, flux=1.0, half_light_radius=psf_re) + logger.debug("Made PSF profile") # Final profile is the convolution of these. final = galsim.Convolve([gal, psf]) - logger.debug('Convolved components into final profile') + logger.debug("Convolved components into final profile") # Draw the image with a particular pixel scale. image = final.drawImage(scale=pixel_scale) # The "effective PSF" is the PSF as drawn on an image, which includes the convolution # by the pixel response. We label it epsf here. image_epsf = psf.drawImage(scale=pixel_scale) - logger.debug('Made image of the profile') + logger.debug("Made image of the profile") # To get Poisson noise on the image, we will use a class called PoissonNoise. # However, we want the noise to correspond to what you would get with a significant @@ -130,29 +137,37 @@ def main(argv): sky_level_pixel = sky_level * pixel_scale**2 noise = galsim.PoissonNoise(rng, sky_level=sky_level_pixel) image.addNoise(noise) - logger.debug('Added Poisson noise') + logger.debug("Added Poisson noise") # Write the image to a file. - if not os.path.isdir('output'): - os.mkdir('output') - file_name = os.path.join('output', 'demo2.fits') - file_name_epsf = os.path.join('output','demo2_epsf.fits') + if not os.path.isdir("output"): + os.mkdir("output") + file_name = os.path.join("output", "demo2.fits") + file_name_epsf = os.path.join("output", "demo2_epsf.fits") image.write(file_name) image_epsf.write(file_name_epsf) - logger.info('Wrote image to %r',file_name) - logger.info('Wrote effective PSF image to %r',file_name_epsf) + logger.info("Wrote image to %r", file_name) + logger.info("Wrote effective PSF image to %r", file_name_epsf) results = galsim.hsm.EstimateShear(image, image_epsf) - logger.info('HSM reports that the image has observed shape and size:') - logger.info(' e1 = %.3f, e2 = %.3f, sigma = %.3f (pixels)', results.observed_shape.e1, - results.observed_shape.e2, results.moments_sigma) - logger.info('When carrying out Regaussianization PSF correction, HSM reports distortions') - logger.info(' e1, e2 = %.3f, %.3f', - results.corrected_e1, results.corrected_e2) - logger.info('Expected values in the limit that noise and non-Gaussianity are negligible:') + logger.info("HSM reports that the image has observed shape and size:") + logger.info( + " e1 = %.3f, e2 = %.3f, sigma = %.3f (pixels)", + results.observed_shape.e1, + results.observed_shape.e2, + results.moments_sigma, + ) + logger.info( + "When carrying out Regaussianization PSF correction, HSM reports distortions" + ) + logger.info(" e1, e2 = %.3f, %.3f", results.corrected_e1, results.corrected_e2) + logger.info( + "Expected values in the limit that noise and non-Gaussianity are negligible:" + ) exp_shear = galsim.Shear(g1=g1, g2=g2) - logger.info(' g1, g2 = %.3f, %.3f', exp_shear.e1,exp_shear.e2) + logger.info(" g1, g2 = %.3f, %.3f", exp_shear.e1, exp_shear.e2) + if __name__ == "__main__": - main(sys.argv) \ No newline at end of file + main(sys.argv) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index 333d8538..a1df6a08 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -279,8 +279,12 @@ def __init__(self, *args, **kwargs): @property def _b(self): - return _galsim._galsim.BoundsD(cast_to_float(self.xmin), cast_to_float(self.xmax), - cast_to_float(self.ymin), cast_to_float(self.ymax)) + return _galsim._galsim.BoundsD( + cast_to_float(self.xmin), + cast_to_float(self.xmax), + cast_to_float(self.ymin), + cast_to_float(self.ymax), + ) def _check_scalar(self, x, name): try: diff --git a/jax_galsim/hsm.py b/jax_galsim/hsm.py index 40cd7004..0a63c96d 100644 --- a/jax_galsim/hsm.py +++ b/jax_galsim/hsm.py @@ -1,12 +1,16 @@ +import galsim as _galsim import numpy as np -import galsim as _galsim +from jax_galsim.bounds import BoundsI from jax_galsim.core.utils import implements +from jax_galsim.errors import ( + GalSimHSMError, + GalSimIncompatibleValuesError, + GalSimValueError, +) +from jax_galsim.image import Image, ImageD, ImageF, ImageI from jax_galsim.position import PositionD -from jax_galsim.bounds import BoundsI from jax_galsim.shear import Shear -from jax_galsim.image import Image, ImageI, ImageF, ImageD -from jax_galsim.errors import GalSimValueError, GalSimHSMError, GalSimIncompatibleValuesError HSM_LAX_DOCS = """\ Contrary to most other classes and objects in jax-galsim, the HSM @@ -22,22 +26,35 @@ GalSim code, and HSM should be considered outside the JAX execution model. """ -@implements( - _galsim.hsm.ShapeData, - lax_description=HSM_LAX_DOCS - ) -class ShapeData: - def __init__(self, image_bounds=BoundsI(), moments_status=-1, - observed_shape=Shear(), moments_sigma=-1.0, moments_amp=-1.0, - moments_centroid=PositionD(), moments_rho4=-1.0, moments_n_iter=0, - correction_status=-1, corrected_e1=-10., corrected_e2=-10., - corrected_g1=-10., corrected_g2=-10., meas_type="None", - corrected_shape_err=-1.0, correction_method="None", - resolution_factor=-1.0, psf_sigma=-1.0, - psf_shape=Shear(), error_message=""): +@implements(_galsim.hsm.ShapeData, lax_description=HSM_LAX_DOCS) +class ShapeData: + def __init__( + self, + image_bounds=BoundsI(), + moments_status=-1, + observed_shape=Shear(), + moments_sigma=-1.0, + moments_amp=-1.0, + moments_centroid=PositionD(), + moments_rho4=-1.0, + moments_n_iter=0, + correction_status=-1, + corrected_e1=-10.0, + corrected_e2=-10.0, + corrected_g1=-10.0, + corrected_g2=-10.0, + meas_type="None", + corrected_shape_err=-1.0, + correction_method="None", + resolution_factor=-1.0, + psf_sigma=-1.0, + psf_shape=Shear(), + error_message="", + ): # Avoid empty string, which can caus problems in C++ layer. - if error_message == "": error_message = "None" + if error_message == "": + error_message = "None" if not isinstance(image_bounds, BoundsI): raise TypeError("image_bounds must be a BoundsI instance") @@ -46,19 +63,36 @@ def __init__(self, image_bounds=BoundsI(), moments_status=-1, # when converting to int, float, etc. self._data = _galsim._galsim.ShapeData( image_bounds._b, - int(moments_status), float(observed_shape.e1), float(observed_shape.e2), - float(moments_sigma), float(moments_amp), + int(moments_status), + float(observed_shape.e1), + float(observed_shape.e2), + float(moments_sigma), + float(moments_amp), moments_centroid._p, - float(moments_rho4), int(moments_n_iter), int(correction_status), - float(corrected_e1), float(corrected_e2), float(corrected_g1), float(corrected_g2), - str(meas_type), float(corrected_shape_err), str(correction_method), - float(resolution_factor), float(psf_sigma), float(psf_shape.e1), float(psf_shape.e2), - str(error_message)) + float(moments_rho4), + int(moments_n_iter), + int(correction_status), + float(corrected_e1), + float(corrected_e2), + float(corrected_g1), + float(corrected_g2), + str(meas_type), + float(corrected_shape_err), + str(correction_method), + float(resolution_factor), + float(psf_sigma), + float(psf_shape.e1), + float(psf_shape.e2), + str(error_message), + ) @property - def image_bounds(self): return BoundsI(self._data.image_bounds) + def image_bounds(self): + return BoundsI(self._data.image_bounds) + @property - def moments_status(self): return self._data.moments_status + def moments_status(self): + return self._data.moments_status @property def observed_e1(self): @@ -73,35 +107,64 @@ def observed_shape(self): return Shear(e1=self.observed_e1, e2=self.observed_e2) @property - def moments_sigma(self): return self._data.moments_sigma + def moments_sigma(self): + return self._data.moments_sigma + @property - def moments_amp(self): return self._data.moments_amp + def moments_amp(self): + return self._data.moments_amp + @property - def moments_centroid(self): return PositionD(self._data.moments_centroid) + def moments_centroid(self): + return PositionD(self._data.moments_centroid) + @property - def moments_rho4(self): return self._data.moments_rho4 + def moments_rho4(self): + return self._data.moments_rho4 + @property - def moments_n_iter(self): return self._data.moments_n_iter + def moments_n_iter(self): + return self._data.moments_n_iter + @property - def correction_status(self): return self._data.correction_status + def correction_status(self): + return self._data.correction_status + @property - def corrected_e1(self): return self._data.corrected_e1 + def corrected_e1(self): + return self._data.corrected_e1 + @property - def corrected_e2(self): return self._data.corrected_e2 + def corrected_e2(self): + return self._data.corrected_e2 + @property - def corrected_g1(self): return self._data.corrected_g1 + def corrected_g1(self): + return self._data.corrected_g1 + @property - def corrected_g2(self): return self._data.corrected_g2 + def corrected_g2(self): + return self._data.corrected_g2 + @property - def meas_type(self): return self._data.meas_type + def meas_type(self): + return self._data.meas_type + @property - def corrected_shape_err(self): return self._data.corrected_shape_err + def corrected_shape_err(self): + return self._data.corrected_shape_err + @property - def correction_method(self): return self._data.correction_method + def correction_method(self): + return self._data.correction_method + @property - def resolution_factor(self): return self._data.resolution_factor + def resolution_factor(self): + return self._data.resolution_factor + @property - def psf_sigma(self): return self._data.psf_sigma + def psf_sigma(self): + return self._data.psf_sigma @property def psf_shape(self): @@ -117,46 +180,84 @@ def error_message(self): return self._data.error_message def __repr__(self): - s = 'galsim.hsm.ShapeData(' - if self.image_bounds.isDefined(): s += 'image_bounds=%r, '%self.image_bounds - if self.moments_status != -1: s += 'moments_status=%r, '%self.moments_status + s = "galsim.hsm.ShapeData(" + if self.image_bounds.isDefined(): + s += "image_bounds=%r, " % self.image_bounds + if self.moments_status != -1: + s += "moments_status=%r, " % self.moments_status # Always include this one: - s += 'observed_shape=%r'%self.observed_shape - if self.moments_sigma != -1: s += ', moments_sigma=%r'%self.moments_sigma - if self.moments_amp != -1: s += ', moments_amp=%r'%self.moments_amp + s += "observed_shape=%r" % self.observed_shape + if self.moments_sigma != -1: + s += ", moments_sigma=%r" % self.moments_sigma + if self.moments_amp != -1: + s += ", moments_amp=%r" % self.moments_amp if self.moments_centroid != PositionD(): - s += ', moments_centroid=%r'%self.moments_centroid - if self.moments_rho4 != -1: s += ', moments_rho4=%r'%self.moments_rho4 - if self.moments_n_iter != 0: s += ', moments_n_iter=%r'%self.moments_n_iter - if self.correction_status != -1: s += ', correction_status=%r'%self.correction_status - if self.corrected_e1 != -10.: s += ', corrected_e1=%r'%self.corrected_e1 - if self.corrected_e2 != -10.: s += ', corrected_e2=%r'%self.corrected_e2 - if self.corrected_g1 != -10.: s += ', corrected_g1=%r'%self.corrected_g1 - if self.corrected_g2 != -10.: s += ', corrected_g2=%r'%self.corrected_g2 - if self.meas_type != 'None': s += ', meas_type=%r'%self.meas_type - if self.corrected_shape_err != -1.: - s += ', corrected_shape_err=%r'%self.corrected_shape_err - if self.correction_method != 'None': s += ', correction_method=%r'%self.correction_method - if self.resolution_factor != -1.: s += ', resolution_factor=%r'%self.resolution_factor - if self.psf_sigma != -1.: s += ', psf_sigma=%r'%self.psf_sigma - if self.psf_shape != Shear(): s += ', psf_shape=%r'%self.psf_shape - if self.error_message != "": s += ', error_message=%r'%self.error_message - s += ')' + s += ", moments_centroid=%r" % self.moments_centroid + if self.moments_rho4 != -1: + s += ", moments_rho4=%r" % self.moments_rho4 + if self.moments_n_iter != 0: + s += ", moments_n_iter=%r" % self.moments_n_iter + if self.correction_status != -1: + s += ", correction_status=%r" % self.correction_status + if self.corrected_e1 != -10.0: + s += ", corrected_e1=%r" % self.corrected_e1 + if self.corrected_e2 != -10.0: + s += ", corrected_e2=%r" % self.corrected_e2 + if self.corrected_g1 != -10.0: + s += ", corrected_g1=%r" % self.corrected_g1 + if self.corrected_g2 != -10.0: + s += ", corrected_g2=%r" % self.corrected_g2 + if self.meas_type != "None": + s += ", meas_type=%r" % self.meas_type + if self.corrected_shape_err != -1.0: + s += ", corrected_shape_err=%r" % self.corrected_shape_err + if self.correction_method != "None": + s += ", correction_method=%r" % self.correction_method + if self.resolution_factor != -1.0: + s += ", resolution_factor=%r" % self.resolution_factor + if self.psf_sigma != -1.0: + s += ", psf_sigma=%r" % self.psf_sigma + if self.psf_shape != Shear(): + s += ", psf_shape=%r" % self.psf_shape + if self.error_message != "": + s += ", error_message=%r" % self.error_message + s += ")" return s def __eq__(self, other): - return (self is other or - (isinstance(other,ShapeData) and self._getinitargs() == other._getinitargs())) - def __ne__(self, other): return not self.__eq__(other) - def __hash__(self): return hash(("galsim.hsm.ShapeData", self._getinitargs())) + return self is other or ( + isinstance(other, ShapeData) and self._getinitargs() == other._getinitargs() + ) + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + return hash(("galsim.hsm.ShapeData", self._getinitargs())) def _getinitargs(self): - return (self.image_bounds, self.moments_status, self.observed_shape, - self.moments_sigma, self.moments_amp, self.moments_centroid, self.moments_rho4, - self.moments_n_iter, self.correction_status, self.corrected_e1, self.corrected_e2, - self.corrected_g1, self.corrected_g2, self.meas_type, self.corrected_shape_err, - self.correction_method, self.resolution_factor, self.psf_sigma, - self.psf_shape, self.error_message) + return ( + self.image_bounds, + self.moments_status, + self.observed_shape, + self.moments_sigma, + self.moments_amp, + self.moments_centroid, + self.moments_rho4, + self.moments_n_iter, + self.correction_status, + self.corrected_e1, + self.corrected_e2, + self.corrected_g1, + self.corrected_g2, + self.meas_type, + self.corrected_shape_err, + self.correction_method, + self.resolution_factor, + self.psf_sigma, + self.psf_shape, + self.error_message, + ) def __getstate__(self): return self._getinitargs() @@ -176,40 +277,54 @@ def applyWCS(self, wcs, image_pos): shape = self.observed_shape # First the flip, if any. if flip: - shape = Shear(g1 = -shape.g1, g2 = shape.g2) + shape = Shear(g1=-shape.g1, g2=shape.g2) # Next the rotation - shape = Shear(g = shape.g, beta = shape.beta + theta) + shape = Shear(g=shape.g, beta=shape.beta + theta) # Finally the shear observed_shape = shear + shape # Fix moments_centroid moments_centroid = jac.toWorld(self.moments_centroid) - jac.toWorld(image_pos) - return ShapeData(image_bounds=self.image_bounds, - moments_status=self.moments_status, - observed_shape=observed_shape, - moments_sigma=moments_sigma, - moments_amp=self.moments_amp, - moments_centroid=moments_centroid, - moments_rho4=self.moments_rho4, - moments_n_iter=self.moments_n_iter, - error_message=self.error_message) - # The other values are reset to the defaults, since they are - # results from EstimateShear. - -@implements( - _galsim.hsm.HSMParams, - lax_description=HSM_LAX_DOCS - ) -class HSMParams: - def __init__(self, nsig_rg=3.0, nsig_rg2=3.6, max_moment_nsig2=0, regauss_too_small=1, - adapt_order=2, convergence_threshold=1.e-6, max_mom2_iter=400, - num_iter_default=-1, bound_correct_wt=0.25, max_amoment=8000., max_ashift=15., - ksb_moments_max=4, ksb_sig_weight=0.0, ksb_sig_factor=1.0, failed_moments=-1000.): + return ShapeData( + image_bounds=self.image_bounds, + moments_status=self.moments_status, + observed_shape=observed_shape, + moments_sigma=moments_sigma, + moments_amp=self.moments_amp, + moments_centroid=moments_centroid, + moments_rho4=self.moments_rho4, + moments_n_iter=self.moments_n_iter, + error_message=self.error_message, + ) + # The other values are reset to the defaults, since they are + # results from EstimateShear. + +@implements(_galsim.hsm.HSMParams, lax_description=HSM_LAX_DOCS) +class HSMParams: + def __init__( + self, + nsig_rg=3.0, + nsig_rg2=3.6, + max_moment_nsig2=0, + regauss_too_small=1, + adapt_order=2, + convergence_threshold=1.0e-6, + max_mom2_iter=400, + num_iter_default=-1, + bound_correct_wt=0.25, + max_amoment=8000.0, + max_ashift=15.0, + ksb_moments_max=4, + ksb_sig_weight=0.0, + ksb_sig_factor=1.0, + failed_moments=-1000.0, + ): if max_moment_nsig2 != 0: from .deprecated import depr - depr('max_moment_nsig2', 2.4, '', 'This parameter is no longer used.') + + depr("max_moment_nsig2", 2.4, "", "This parameter is no longer used.") self._nsig_rg = float(nsig_rg) self._nsig_rg2 = float(nsig_rg2) @@ -233,42 +348,83 @@ def _make_hsmp(self): def _getinitargs(self): # TODO: For now, leave 3rd param as unused max_moment_nsig2. # Remove it at version 3.0 to avoid changing C++ API yet. - return (self.nsig_rg, self.nsig_rg2, 0., self.regauss_too_small, - self.adapt_order, self.convergence_threshold, self.max_mom2_iter, - self.num_iter_default, self.bound_correct_wt, self.max_amoment, self.max_ashift, - self.ksb_moments_max, self.ksb_sig_weight, self.ksb_sig_factor, - self.failed_moments) + return ( + self.nsig_rg, + self.nsig_rg2, + 0.0, + self.regauss_too_small, + self.adapt_order, + self.convergence_threshold, + self.max_mom2_iter, + self.num_iter_default, + self.bound_correct_wt, + self.max_amoment, + self.max_ashift, + self.ksb_moments_max, + self.ksb_sig_weight, + self.ksb_sig_factor, + self.failed_moments, + ) @property - def nsig_rg(self): return self._nsig_rg + def nsig_rg(self): + return self._nsig_rg + @property - def nsig_rg2(self): return self._nsig_rg2 + def nsig_rg2(self): + return self._nsig_rg2 + @property - def max_moment_nsig2(self): return 0. + def max_moment_nsig2(self): + return 0.0 + @property - def regauss_too_small(self): return self._regauss_too_small + def regauss_too_small(self): + return self._regauss_too_small + @property - def adapt_order(self): return self._adapt_order + def adapt_order(self): + return self._adapt_order + @property - def convergence_threshold(self): return self._convergence_threshold + def convergence_threshold(self): + return self._convergence_threshold + @property - def max_mom2_iter(self): return self._max_mom2_iter + def max_mom2_iter(self): + return self._max_mom2_iter + @property - def num_iter_default(self): return self._num_iter_default + def num_iter_default(self): + return self._num_iter_default + @property - def bound_correct_wt(self): return self._bound_correct_wt + def bound_correct_wt(self): + return self._bound_correct_wt + @property - def max_amoment(self): return self._max_amoment + def max_amoment(self): + return self._max_amoment + @property - def max_ashift(self): return self._max_ashift + def max_ashift(self): + return self._max_ashift + @property - def ksb_moments_max(self): return self._ksb_moments_max + def ksb_moments_max(self): + return self._ksb_moments_max + @property - def ksb_sig_weight(self): return self._ksb_sig_weight + def ksb_sig_weight(self): + return self._ksb_sig_weight + @property - def ksb_sig_factor(self): return self._ksb_sig_factor + def ksb_sig_factor(self): + return self._ksb_sig_factor + @property - def failed_moments(self): return self._failed_moments + def failed_moments(self): + return self._failed_moments @staticmethod def check(hsmparams, default=None): @@ -280,30 +436,34 @@ def check(hsmparams, default=None): if hsmparams is None: return default if default is not None else HSMParams.default elif not isinstance(hsmparams, HSMParams): - raise TypeError("Invalid HSMParams: %s"%hsmparams) + raise TypeError("Invalid HSMParams: %s" % hsmparams) else: return hsmparams def __repr__(self): - return ('galsim.hsm.HSMParams(' + 14*'%r,' + '%r)')%self._getinitargs() + return ("galsim.hsm.HSMParams(" + 14 * "%r," + "%r)") % self._getinitargs() def __eq__(self, other): - return (self is other or - (isinstance(other, HSMParams) and self._getinitargs() == other._getinitargs())) + return self is other or ( + isinstance(other, HSMParams) and self._getinitargs() == other._getinitargs() + ) + def __ne__(self, other): return not self.__eq__(other) + def __hash__(self): - return hash(('galsim.hsm.HSMParams', self._getinitargs())) + return hash(("galsim.hsm.HSMParams", self._getinitargs())) def __getstate__(self): d = self.__dict__.copy() - del d['_hsmp'] + del d["_hsmp"] return d def __setstate__(self, d): self.__dict__ = d self._make_hsmp() + # We use the default a lot, so make it a class attribute. HSMParams.default = HSMParams() @@ -317,16 +477,22 @@ def _checkWeightAndBadpix(image, weight=None, badpix=None): if weight.bounds != image.bounds: raise GalSimIncompatibleValuesError( "Weight image does not have same bounds as the input Image.", - weight=weight, image=image) - # also make sure there are no negative values + weight=weight, + image=image, + ) + # also make sure there are no negative values if np.any(weight.array < 0): - raise GalSimValueError("Weight image cannot contain negative values.", weight) + raise GalSimValueError( + "Weight image cannot contain negative values.", weight + ) if badpix is not None and badpix.bounds != image.bounds: raise GalSimIncompatibleValuesError( "Badpix image does not have the same bounds as the input Image.", - badpix=badpix, image=image) + badpix=badpix, + image=image, + ) # A helper function for taking input weight and badpix Images, and returning a weight Image in the @@ -350,7 +516,7 @@ def _convertMask(image, weight=None, badpix=None): # otherwise, we need to convert it to the right type else: mask = ImageI(bounds=image.bounds, init_value=0) - mask.array[weight.array > 0.] = 1 + mask.array[weight.array > 0.0] = 1 # if badpix image was supplied, identify the nonzero (bad) pixels and set them to zero in weight # image; also check bounds @@ -371,21 +537,31 @@ def _convertImage(image): # This is used by EstimateShear() and FindAdaptiveMom(). # if weight is not of type float/double, convert to float/double - if (image.dtype == np.int16 or image.dtype == np.uint16): + if image.dtype == np.int16 or image.dtype == np.uint16: image = ImageF(image) - elif (image.dtype == np.int32 or image.dtype == np.uint32): + elif image.dtype == np.int32 or image.dtype == np.uint32: image = ImageD(image) return image -@implements( - _galsim.hsm.EstimateShear, - lax_description=HSM_LAX_DOCS - ) -def EstimateShear(gal_image, PSF_image, weight=None, badpix=None, sky_var=0.0, - shear_est="REGAUSS", recompute_flux="FIT", guess_sig_gal=5.0, - guess_sig_PSF=3.0, precision=1.0e-6, guess_centroid=None, - strict=True, check=True, hsmparams=None): + +@implements(_galsim.hsm.EstimateShear, lax_description=HSM_LAX_DOCS) +def EstimateShear( + gal_image, + PSF_image, + weight=None, + badpix=None, + sky_var=0.0, + shear_est="REGAUSS", + recompute_flux="FIT", + guess_sig_gal=5.0, + guess_sig_PSF=3.0, + precision=1.0e-6, + guess_centroid=None, + strict=True, + check=True, + hsmparams=None, +): # prepare inputs to C++ routines: ImageF or ImageD for galaxy, PSF, and ImageI for weight map gal_image = _convertImage(gal_image) PSF_image = _convertImage(PSF_image) @@ -398,25 +574,42 @@ def EstimateShear(gal_image, PSF_image, weight=None, badpix=None, sky_var=0.0, guess_centroid = gal_image.true_center try: result = ShapeData() - _galsim._galsim.EstimateShearView(result._data, - gal_image._image, PSF_image._image, weight._image, - float(sky_var), shear_est.upper(), recompute_flux.upper(), - float(guess_sig_gal), float(guess_sig_PSF), float(precision), - guess_centroid._p, hsmparams._hsmp) + _galsim._galsim.EstimateShearView( + result._data, + gal_image._image, + PSF_image._image, + weight._image, + float(sky_var), + shear_est.upper(), + recompute_flux.upper(), + float(guess_sig_gal), + float(guess_sig_PSF), + float(precision), + guess_centroid._p, + hsmparams._hsmp, + ) return result except RuntimeError as err: - if (strict == True): + if strict: raise GalSimHSMError(str(err)) from None else: - return ShapeData(error_message = str(err)) - -@implements( - _galsim.hsm.FindAdaptiveMom, - lax_description=HSM_LAX_DOCS - ) -def FindAdaptiveMom(object_image, weight=None, badpix=None, guess_sig=5.0, precision=1.0e-6, - guess_centroid=None, strict=True, check=True, round_moments=False, hsmparams=None, - use_sky_coords=False): + return ShapeData(error_message=str(err)) + + +@implements(_galsim.hsm.FindAdaptiveMom, lax_description=HSM_LAX_DOCS) +def FindAdaptiveMom( + object_image, + weight=None, + badpix=None, + guess_sig=5.0, + precision=1.0e-6, + guess_centroid=None, + strict=True, + check=True, + round_moments=False, + hsmparams=None, + use_sky_coords=False, +): # prepare inputs to C++ routines: ImageF or ImageD for galaxy, PSF, and ImageI for weight map object_image = _convertImage(object_image) hsmparams = HSMParams.check(hsmparams) @@ -430,19 +623,28 @@ def FindAdaptiveMom(object_image, weight=None, badpix=None, guess_sig=5.0, preci try: result = ShapeData() - _galsim._galsim.FindAdaptiveMomView(result._data, - object_image._image, weight._image, - float(guess_sig), float(precision), guess_centroid._p, - bool(round_moments), hsmparams._hsmp) + _galsim._galsim.FindAdaptiveMomView( + result._data, + object_image._image, + weight._image, + float(guess_sig), + float(precision), + guess_centroid._p, + bool(round_moments), + hsmparams._hsmp, + ) if use_sky_coords: - result = result.applyWCS(object_image.wcs, image_pos=object_image.true_center) + result = result.applyWCS( + object_image.wcs, image_pos=object_image.true_center + ) return result except RuntimeError as err: - if (strict == True): + if strict: raise GalSimHSMError(str(err)) from None else: - return ShapeData(error_message = str(err)) + return ShapeData(error_message=str(err)) + # make FindAdaptiveMom a method of Image class -Image.FindAdaptiveMom = FindAdaptiveMom \ No newline at end of file +Image.FindAdaptiveMom = FindAdaptiveMom diff --git a/jax_galsim/image.py b/jax_galsim/image.py index fef81288..46d0e2e7 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -28,16 +28,17 @@ ) @register_pytree_node_class class Image(object): - _cpp_type = { np.uint16 : _galsim._galsim.ImageViewUS, - np.uint32 : _galsim._galsim.ImageViewUI, - np.int16 : _galsim._galsim.ImageViewS, - np.int32 : _galsim._galsim.ImageViewI, - np.float32 : _galsim._galsim.ImageViewF, - np.float64 : _galsim._galsim.ImageViewD, - np.complex64 : _galsim._galsim.ImageViewCF, - np.complex128 : _galsim._galsim.ImageViewCD, - } - + _cpp_type = { + np.uint16: _galsim._galsim.ImageViewUS, + np.uint32: _galsim._galsim.ImageViewUI, + np.int16: _galsim._galsim.ImageViewS, + np.int32: _galsim._galsim.ImageViewI, + np.float32: _galsim._galsim.ImageViewF, + np.float64: _galsim._galsim.ImageViewD, + np.complex64: _galsim._galsim.ImageViewCF, + np.complex128: _galsim._galsim.ImageViewCD, + } + _alias_dtypes = { int: jnp.int32, # So that user gets what they would expect float: jnp.float64, # if using dtype=int or float or complex @@ -385,12 +386,14 @@ def iscontiguous(self): def _image(self): cls = self._cpp_type[self.dtype] _array = np.asarray(self._array) - _data = _array.__array_interface__['data'][0] - return cls(_data, - _array.strides[1]//_array.itemsize, - _array.strides[0]//_array.itemsize, - self._bounds._b) - + _data = _array.__array_interface__["data"][0] + return cls( + _data, + _array.strides[1] // _array.itemsize, + _array.strides[0] // _array.itemsize, + self._bounds._b, + ) + # Allow scale to work as a PixelScale wcs. @property @implements(_galsim.Image.scale) From f93fe46baafc77272f082dec5620369688131145 Mon Sep 17 00:00:00 2001 From: Benjamin Remy Date: Sat, 7 Feb 2026 12:53:26 -0500 Subject: [PATCH 07/20] fix test_hsm.py erros --- jax_galsim/bounds.py | 17 ++++++++++++++--- jax_galsim/hsm.py | 5 +++-- jax_galsim/position.py | 5 ++++- 3 files changed, 21 insertions(+), 6 deletions(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index a1df6a08..b21662ab 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -31,13 +31,18 @@ def _parse_args(self, *args, **kwargs): self._isdefined = False self.xmin = self.xmax = self.ymin = self.ymax = 0 elif len(args) == 1: - if isinstance(args[0], Bounds): + if isinstance( + args[0], (Bounds, _galsim._galsim.BoundsD, _galsim._galsim.BoundsI) + ): self._isdefined = True self.xmin = args[0].xmin self.xmax = args[0].xmax self.ymin = args[0].ymin self.ymax = args[0].ymax - elif isinstance(args[0], Position): + elif isinstance( + args[0], + (Position, _galsim._galsim.PositionD, _galsim._galsim.PositionI), + ): self._isdefined = True self.xmin = self.xmax = args[0].x self.ymin = self.ymax = args[0].y @@ -48,7 +53,13 @@ def _parse_args(self, *args, **kwargs): ) self._isdefined = True elif len(args) == 2: - if isinstance(args[0], Position) and isinstance(args[1], Position): + if isinstance( + args[0], + (Position, _galsim._galsim.PositionD, _galsim._galsim.PositionI), + ) and isinstance( + args[1], + (Position, _galsim._galsim.PositionD, _galsim._galsim.PositionI), + ): self._isdefined = True self.xmin = min(args[0].x, args[1].x) self.xmax = max(args[0].x, args[1].x) diff --git a/jax_galsim/hsm.py b/jax_galsim/hsm.py index 0a63c96d..91d78527 100644 --- a/jax_galsim/hsm.py +++ b/jax_galsim/hsm.py @@ -1,4 +1,5 @@ import galsim as _galsim +import jax.numpy as jnp import numpy as np from jax_galsim.bounds import BoundsI @@ -516,12 +517,12 @@ def _convertMask(image, weight=None, badpix=None): # otherwise, we need to convert it to the right type else: mask = ImageI(bounds=image.bounds, init_value=0) - mask.array[weight.array > 0.0] = 1 + mask._array = jnp.where(weight.array > 0.0, 1, mask._array) # if badpix image was supplied, identify the nonzero (bad) pixels and set them to zero in weight # image; also check bounds if badpix is not None: - mask.array[badpix.array != 0] = 0 + mask._array = jnp.where(badpix.array != 0, 0, mask._array) # if no pixels are used, raise an exception if not np.any(mask.array): diff --git a/jax_galsim/position.py b/jax_galsim/position.py index b515a550..678dd9d7 100644 --- a/jax_galsim/position.py +++ b/jax_galsim/position.py @@ -25,7 +25,10 @@ def _parse_args(self, *args, **kwargs): elif len(args) == 0: self.x = self.y = 0 elif len(args) == 1: - if isinstance(args[0], (Position,)): + if isinstance( + args[0], + (Position, _galsim._galsim.PositionD, _galsim._galsim.PositionI), + ): self.x = args[0].x self.y = args[0].y else: From 0ce58610a6d5459b227184c2b908a0c870286fa9 Mon Sep 17 00:00:00 2001 From: Benjamin Remy Date: Sat, 7 Feb 2026 12:58:36 -0500 Subject: [PATCH 08/20] import jax_galsim examples/demo*.py --- examples/demo1.py | 18 +++++++++--------- examples/demo2.py | 30 +++++++++++++++--------------- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/examples/demo1.py b/examples/demo1.py index da8c60cd..987d3e75 100644 --- a/examples/demo1.py +++ b/examples/demo1.py @@ -18,7 +18,7 @@ """ Demo #1 -This is the first script in our tutorial about using GalSim in python scripts: examples/demo*.py. +This is the first script in our tutorial about using JAX-GalSim in python scripts: examples/demo*.py. (This file is designed to be viewed in a window 100 characters wide.) Each of these demo*.py files are designed to be equivalent to the corresponding demo*.yaml file @@ -42,11 +42,11 @@ New features introduced in this demo: -- obj = galsim.Gaussian(flux, sigma) -- obj = galsim.Convolve([list of objects]) +- obj = jax_galsim.Gaussian(flux, sigma) +- obj = jax_galsim.Convolve([list of objects]) - image = obj.drawImage(scale) - image.added_flux (Only present after a drawImage command.) -- noise = galsim.GaussianNoise(sigma) +- noise = jax_galsim.GaussianNoise(sigma) - image.addNoise(noise) - image.write(file_name) - image.FindAdaptiveMom() @@ -57,7 +57,7 @@ import os import sys -import jax_galsim as galsim +import jax_galsim def main(argv): @@ -88,17 +88,17 @@ def main(argv): logger.info(" - Gaussian noise (sigma = %.2f).", noise) # Define the galaxy profile - gal = galsim.Gaussian(flux=gal_flux, sigma=gal_sigma) + gal = jax_galsim.Gaussian(flux=gal_flux, sigma=gal_sigma) logger.debug("Made galaxy profile") # Define the PSF profile - psf = galsim.Gaussian(flux=1.0, sigma=psf_sigma) # PSF flux should always = 1 + psf = jax_galsim.Gaussian(flux=1.0, sigma=psf_sigma) # PSF flux should always = 1 logger.debug("Made PSF profile") # Final profile is the convolution of these # Can include any number of things in the list, all of which are convolved # together to make the final flux profile. - final = galsim.Convolve([gal, psf]) + final = jax_galsim.Convolve([gal, psf]) logger.debug("Convolved components into final profile") # Draw the image with a particular pixel scale, given in arcsec/pixel. @@ -113,7 +113,7 @@ def main(argv): ) # Add Gaussian noise to the image with specified sigma - image.addNoise(galsim.GaussianNoise(sigma=noise)) + image.addNoise(jax_galsim.GaussianNoise(sigma=noise)) logger.debug("Added Gaussian noise") # Write the image to a file diff --git a/examples/demo2.py b/examples/demo2.py index 7a5dfcde..e1606d63 100644 --- a/examples/demo2.py +++ b/examples/demo2.py @@ -18,7 +18,7 @@ """ Demo #2 -The second script in our tutorial about using GalSim in python scripts: examples/demo*.py. +The second script in our tutorial about using JAX-GalSim in python scripts: examples/demo*.py. (This file is designed to be viewed in a window 100 characters wide.) This script is a bit more sophisticated, but still pretty basic. We're still only making @@ -29,19 +29,19 @@ New features introduced in this demo: -- obj = galsim.Exponential(flux, scale_radius) -- obj = galsim.Moffat(beta, flux, half_light_radius) +- obj = jax_galsim.Exponential(flux, scale_radius) +- obj = jax_galsim.Moffat(beta, flux, half_light_radius) - obj = obj.shear(g1, g2) -- with explanation of other ways to specify shear -- rng = galsim.BaseDeviate(seed) -- noise = galsim.PoissonNoise(rng, sky_level) -- galsim.hsm.EstimateShear(image, image_epsf) +- rng = jax_galsim.BaseDeviate(seed) +- noise = jax_galsim.PoissonNoise(rng, sky_level) +- jax_galsim.hsm.EstimateShear(image, image_epsf) """ import logging import os import sys -import jax_galsim as galsim +import jax_galsim def main(argv): @@ -78,7 +78,7 @@ def main(argv): # objects or images. If the user is likewise incrementing seed values for # multiple runs of a given config file, these can interfere leading to # surprising (and typically bad) results. - random_seed = galsim.BaseDeviate(random_seed).raw() + random_seed = jax_galsim.BaseDeviate(random_seed).raw() logger.info("Starting demo script 2 using:") logger.info( @@ -95,10 +95,10 @@ def main(argv): # Initialize the (pseudo-)random number generator that we will be using below. # For a technical reason that will be explained later (demo9.py), we add 1 to the # given random seed here. - rng = galsim.BaseDeviate(random_seed + 1) + rng = jax_galsim.BaseDeviate(random_seed + 1) # Define the galaxy profile. - gal = galsim.Exponential(flux=gal_flux, scale_radius=gal_r0) + gal = jax_galsim.Exponential(flux=gal_flux, scale_radius=gal_r0) # Shear the galaxy by some value. # There are quite a few ways you can use to specify a shape. @@ -113,11 +113,11 @@ def main(argv): logger.debug("Made galaxy profile") # Define the PSF profile. - psf = galsim.Moffat(beta=psf_beta, flux=1.0, half_light_radius=psf_re) + psf = jax_galsim.Moffat(beta=psf_beta, flux=1.0, half_light_radius=psf_re) logger.debug("Made PSF profile") # Final profile is the convolution of these. - final = galsim.Convolve([gal, psf]) + final = jax_galsim.Convolve([gal, psf]) logger.debug("Convolved components into final profile") # Draw the image with a particular pixel scale. @@ -135,7 +135,7 @@ def main(argv): # One wrinkle here is that the PoissonNoise class needs the sky level in each pixel, # while we have a sky_level in counts per arcsec^2. So we need to convert: sky_level_pixel = sky_level * pixel_scale**2 - noise = galsim.PoissonNoise(rng, sky_level=sky_level_pixel) + noise = jax_galsim.PoissonNoise(rng, sky_level=sky_level_pixel) image.addNoise(noise) logger.debug("Added Poisson noise") @@ -149,7 +149,7 @@ def main(argv): logger.info("Wrote image to %r", file_name) logger.info("Wrote effective PSF image to %r", file_name_epsf) - results = galsim.hsm.EstimateShear(image, image_epsf) + results = jax_galsim.hsm.EstimateShear(image, image_epsf) logger.info("HSM reports that the image has observed shape and size:") logger.info( @@ -165,7 +165,7 @@ def main(argv): logger.info( "Expected values in the limit that noise and non-Gaussianity are negligible:" ) - exp_shear = galsim.Shear(g1=g1, g2=g2) + exp_shear = jax_galsim.Shear(g1=g1, g2=g2) logger.info(" g1, g2 = %.3f, %.3f", exp_shear.e1, exp_shear.e2) From 12535daa3ef29b9d4a3bb21ad6643187a0aaa44e Mon Sep 17 00:00:00 2001 From: Benjamin Remy Date: Sun, 8 Feb 2026 14:12:38 -0500 Subject: [PATCH 09/20] add to-from image.py --- jax_galsim/image.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 46d0e2e7..8be7d854 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -1117,15 +1117,27 @@ def tree_unflatten(cls, aux_data, children): @classmethod def from_galsim(cls, galsim_image): """Create a `Image` from a `galsim.Image` instance.""" + wcs = ( + BaseWCS.from_galsim(galsim_image.wcs) + if galsim_image.wcs is not None + else None + ) im = cls( array=galsim_image.array, - wcs=BaseWCS.from_galsim(galsim_image.wcs), + wcs=wcs, bounds=Bounds.from_galsim(galsim_image.bounds), ) if hasattr(galsim_image, "header"): im.header = galsim_image.header return im + def to_galsim(self): + """Create a galsim `Image` from a `jax_galsim.Image` object.""" + wcs = self.wcs.to_galsim() if self.wcs is not None else None + return _galsim.Image( + np.asarray(self.array), bounds=self.bounds.to_galsim(), wcs=wcs + ) + @implements( _galsim._Image, From fcf5265cf1aa315cd2ad313e4453d4b63dd3a3bc Mon Sep 17 00:00:00 2001 From: Benjamin Remy Date: Sun, 8 Feb 2026 14:55:29 -0500 Subject: [PATCH 10/20] add FindAdaptiveMom to Image & remove hsm.py --- examples/demo2.py | 4 +- jax_galsim/__init__.py | 2 - jax_galsim/hsm.py | 651 ----------------------------------------- jax_galsim/image.py | 13 + 4 files changed, 16 insertions(+), 654 deletions(-) delete mode 100644 jax_galsim/hsm.py diff --git a/examples/demo2.py b/examples/demo2.py index e1606d63..72438f5d 100644 --- a/examples/demo2.py +++ b/examples/demo2.py @@ -41,6 +41,8 @@ import os import sys +import galsim + import jax_galsim @@ -149,7 +151,7 @@ def main(argv): logger.info("Wrote image to %r", file_name) logger.info("Wrote effective PSF image to %r", file_name_epsf) - results = jax_galsim.hsm.EstimateShear(image, image_epsf) + results = galsim.hsm.EstimateShear(image.to_galsim(), image_epsf.to_galsim()) logger.info("HSM reports that the image has observed shape and size:") logger.info( diff --git a/jax_galsim/__init__.py b/jax_galsim/__init__.py index 2a92690a..8e586d32 100644 --- a/jax_galsim/__init__.py +++ b/jax_galsim/__init__.py @@ -105,5 +105,3 @@ # this one is specific to jax_galsim from . import core - -from . import hsm diff --git a/jax_galsim/hsm.py b/jax_galsim/hsm.py deleted file mode 100644 index 91d78527..00000000 --- a/jax_galsim/hsm.py +++ /dev/null @@ -1,651 +0,0 @@ -import galsim as _galsim -import jax.numpy as jnp -import numpy as np - -from jax_galsim.bounds import BoundsI -from jax_galsim.core.utils import implements -from jax_galsim.errors import ( - GalSimHSMError, - GalSimIncompatibleValuesError, - GalSimValueError, -) -from jax_galsim.image import Image, ImageD, ImageF, ImageI -from jax_galsim.position import PositionD -from jax_galsim.shear import Shear - -HSM_LAX_DOCS = """\ -Contrary to most other classes and objects in jax-galsim, the HSM -functionality is not implemented using JAX primitives. - -All HSM-related methods directly rely on the original GalSim -implementation and therefore: - - do not run on GPU or TPU - - are not JIT-compilable - - do not benefit from JAX transformations (vmap, grad, etc.) - -As a result, all computations are performed on the CPU using classical -GalSim code, and HSM should be considered outside the JAX execution model. -""" - - -@implements(_galsim.hsm.ShapeData, lax_description=HSM_LAX_DOCS) -class ShapeData: - def __init__( - self, - image_bounds=BoundsI(), - moments_status=-1, - observed_shape=Shear(), - moments_sigma=-1.0, - moments_amp=-1.0, - moments_centroid=PositionD(), - moments_rho4=-1.0, - moments_n_iter=0, - correction_status=-1, - corrected_e1=-10.0, - corrected_e2=-10.0, - corrected_g1=-10.0, - corrected_g2=-10.0, - meas_type="None", - corrected_shape_err=-1.0, - correction_method="None", - resolution_factor=-1.0, - psf_sigma=-1.0, - psf_shape=Shear(), - error_message="", - ): - # Avoid empty string, which can caus problems in C++ layer. - if error_message == "": - error_message = "None" - - if not isinstance(image_bounds, BoundsI): - raise TypeError("image_bounds must be a BoundsI instance") - - # The others will raise an appropriate TypeError from the call to _galsim.ShapeData - # when converting to int, float, etc. - self._data = _galsim._galsim.ShapeData( - image_bounds._b, - int(moments_status), - float(observed_shape.e1), - float(observed_shape.e2), - float(moments_sigma), - float(moments_amp), - moments_centroid._p, - float(moments_rho4), - int(moments_n_iter), - int(correction_status), - float(corrected_e1), - float(corrected_e2), - float(corrected_g1), - float(corrected_g2), - str(meas_type), - float(corrected_shape_err), - str(correction_method), - float(resolution_factor), - float(psf_sigma), - float(psf_shape.e1), - float(psf_shape.e2), - str(error_message), - ) - - @property - def image_bounds(self): - return BoundsI(self._data.image_bounds) - - @property - def moments_status(self): - return self._data.moments_status - - @property - def observed_e1(self): - return self._data.observed_e1 - - @property - def observed_e2(self): - return self._data.observed_e2 - - @property - def observed_shape(self): - return Shear(e1=self.observed_e1, e2=self.observed_e2) - - @property - def moments_sigma(self): - return self._data.moments_sigma - - @property - def moments_amp(self): - return self._data.moments_amp - - @property - def moments_centroid(self): - return PositionD(self._data.moments_centroid) - - @property - def moments_rho4(self): - return self._data.moments_rho4 - - @property - def moments_n_iter(self): - return self._data.moments_n_iter - - @property - def correction_status(self): - return self._data.correction_status - - @property - def corrected_e1(self): - return self._data.corrected_e1 - - @property - def corrected_e2(self): - return self._data.corrected_e2 - - @property - def corrected_g1(self): - return self._data.corrected_g1 - - @property - def corrected_g2(self): - return self._data.corrected_g2 - - @property - def meas_type(self): - return self._data.meas_type - - @property - def corrected_shape_err(self): - return self._data.corrected_shape_err - - @property - def correction_method(self): - return self._data.correction_method - - @property - def resolution_factor(self): - return self._data.resolution_factor - - @property - def psf_sigma(self): - return self._data.psf_sigma - - @property - def psf_shape(self): - return Shear(e1=self._data.psf_e1, e2=self._data.psf_e2) - - @property - def error_message(self): - # We use "None" in C++ ShapeData to indicate no error messages to avoid problems on - # (some) Macs using zero-length strings. Here, we revert that back to "". - if self._data.error_message == "None": - return "" - else: - return self._data.error_message - - def __repr__(self): - s = "galsim.hsm.ShapeData(" - if self.image_bounds.isDefined(): - s += "image_bounds=%r, " % self.image_bounds - if self.moments_status != -1: - s += "moments_status=%r, " % self.moments_status - # Always include this one: - s += "observed_shape=%r" % self.observed_shape - if self.moments_sigma != -1: - s += ", moments_sigma=%r" % self.moments_sigma - if self.moments_amp != -1: - s += ", moments_amp=%r" % self.moments_amp - if self.moments_centroid != PositionD(): - s += ", moments_centroid=%r" % self.moments_centroid - if self.moments_rho4 != -1: - s += ", moments_rho4=%r" % self.moments_rho4 - if self.moments_n_iter != 0: - s += ", moments_n_iter=%r" % self.moments_n_iter - if self.correction_status != -1: - s += ", correction_status=%r" % self.correction_status - if self.corrected_e1 != -10.0: - s += ", corrected_e1=%r" % self.corrected_e1 - if self.corrected_e2 != -10.0: - s += ", corrected_e2=%r" % self.corrected_e2 - if self.corrected_g1 != -10.0: - s += ", corrected_g1=%r" % self.corrected_g1 - if self.corrected_g2 != -10.0: - s += ", corrected_g2=%r" % self.corrected_g2 - if self.meas_type != "None": - s += ", meas_type=%r" % self.meas_type - if self.corrected_shape_err != -1.0: - s += ", corrected_shape_err=%r" % self.corrected_shape_err - if self.correction_method != "None": - s += ", correction_method=%r" % self.correction_method - if self.resolution_factor != -1.0: - s += ", resolution_factor=%r" % self.resolution_factor - if self.psf_sigma != -1.0: - s += ", psf_sigma=%r" % self.psf_sigma - if self.psf_shape != Shear(): - s += ", psf_shape=%r" % self.psf_shape - if self.error_message != "": - s += ", error_message=%r" % self.error_message - s += ")" - return s - - def __eq__(self, other): - return self is other or ( - isinstance(other, ShapeData) and self._getinitargs() == other._getinitargs() - ) - - def __ne__(self, other): - return not self.__eq__(other) - - def __hash__(self): - return hash(("galsim.hsm.ShapeData", self._getinitargs())) - - def _getinitargs(self): - return ( - self.image_bounds, - self.moments_status, - self.observed_shape, - self.moments_sigma, - self.moments_amp, - self.moments_centroid, - self.moments_rho4, - self.moments_n_iter, - self.correction_status, - self.corrected_e1, - self.corrected_e2, - self.corrected_g1, - self.corrected_g2, - self.meas_type, - self.corrected_shape_err, - self.correction_method, - self.resolution_factor, - self.psf_sigma, - self.psf_shape, - self.error_message, - ) - - def __getstate__(self): - return self._getinitargs() - - def __setstate__(self, state): - self.__init__(*state) - - @implements(_galsim.hsm.ShapeData.applyWCS) - def applyWCS(self, wcs, image_pos): - jac = wcs.jacobian(image_pos=image_pos) - scale, shear, theta, flip = jac.getDecomposition() - - # Fix moments_sigma - moments_sigma = self.moments_sigma * scale - - # Fix observed_shape - shape = self.observed_shape - # First the flip, if any. - if flip: - shape = Shear(g1=-shape.g1, g2=shape.g2) - # Next the rotation - shape = Shear(g=shape.g, beta=shape.beta + theta) - # Finally the shear - observed_shape = shear + shape - - # Fix moments_centroid - moments_centroid = jac.toWorld(self.moments_centroid) - jac.toWorld(image_pos) - - return ShapeData( - image_bounds=self.image_bounds, - moments_status=self.moments_status, - observed_shape=observed_shape, - moments_sigma=moments_sigma, - moments_amp=self.moments_amp, - moments_centroid=moments_centroid, - moments_rho4=self.moments_rho4, - moments_n_iter=self.moments_n_iter, - error_message=self.error_message, - ) - # The other values are reset to the defaults, since they are - # results from EstimateShear. - - -@implements(_galsim.hsm.HSMParams, lax_description=HSM_LAX_DOCS) -class HSMParams: - def __init__( - self, - nsig_rg=3.0, - nsig_rg2=3.6, - max_moment_nsig2=0, - regauss_too_small=1, - adapt_order=2, - convergence_threshold=1.0e-6, - max_mom2_iter=400, - num_iter_default=-1, - bound_correct_wt=0.25, - max_amoment=8000.0, - max_ashift=15.0, - ksb_moments_max=4, - ksb_sig_weight=0.0, - ksb_sig_factor=1.0, - failed_moments=-1000.0, - ): - if max_moment_nsig2 != 0: - from .deprecated import depr - - depr("max_moment_nsig2", 2.4, "", "This parameter is no longer used.") - - self._nsig_rg = float(nsig_rg) - self._nsig_rg2 = float(nsig_rg2) - self._regauss_too_small = int(regauss_too_small) - self._adapt_order = int(adapt_order) - self._convergence_threshold = float(convergence_threshold) - self._max_mom2_iter = int(max_mom2_iter) - self._num_iter_default = int(num_iter_default) - self._bound_correct_wt = float(bound_correct_wt) - self._max_amoment = float(max_amoment) - self._max_ashift = float(max_ashift) - self._ksb_moments_max = int(ksb_moments_max) - self._ksb_sig_weight = float(ksb_sig_weight) - self._ksb_sig_factor = float(ksb_sig_factor) - self._failed_moments = float(failed_moments) - self._make_hsmp() - - def _make_hsmp(self): - self._hsmp = _galsim._galsim.HSMParams(*self._getinitargs()) - - def _getinitargs(self): - # TODO: For now, leave 3rd param as unused max_moment_nsig2. - # Remove it at version 3.0 to avoid changing C++ API yet. - return ( - self.nsig_rg, - self.nsig_rg2, - 0.0, - self.regauss_too_small, - self.adapt_order, - self.convergence_threshold, - self.max_mom2_iter, - self.num_iter_default, - self.bound_correct_wt, - self.max_amoment, - self.max_ashift, - self.ksb_moments_max, - self.ksb_sig_weight, - self.ksb_sig_factor, - self.failed_moments, - ) - - @property - def nsig_rg(self): - return self._nsig_rg - - @property - def nsig_rg2(self): - return self._nsig_rg2 - - @property - def max_moment_nsig2(self): - return 0.0 - - @property - def regauss_too_small(self): - return self._regauss_too_small - - @property - def adapt_order(self): - return self._adapt_order - - @property - def convergence_threshold(self): - return self._convergence_threshold - - @property - def max_mom2_iter(self): - return self._max_mom2_iter - - @property - def num_iter_default(self): - return self._num_iter_default - - @property - def bound_correct_wt(self): - return self._bound_correct_wt - - @property - def max_amoment(self): - return self._max_amoment - - @property - def max_ashift(self): - return self._max_ashift - - @property - def ksb_moments_max(self): - return self._ksb_moments_max - - @property - def ksb_sig_weight(self): - return self._ksb_sig_weight - - @property - def ksb_sig_factor(self): - return self._ksb_sig_factor - - @property - def failed_moments(self): - return self._failed_moments - - @staticmethod - def check(hsmparams, default=None): - """Checks that hsmparams is either a valid HSMParams instance or None. - - In the former case, it returns hsmparams, in the latter it returns default - (HSMParams.default if no other default specified). - """ - if hsmparams is None: - return default if default is not None else HSMParams.default - elif not isinstance(hsmparams, HSMParams): - raise TypeError("Invalid HSMParams: %s" % hsmparams) - else: - return hsmparams - - def __repr__(self): - return ("galsim.hsm.HSMParams(" + 14 * "%r," + "%r)") % self._getinitargs() - - def __eq__(self, other): - return self is other or ( - isinstance(other, HSMParams) and self._getinitargs() == other._getinitargs() - ) - - def __ne__(self, other): - return not self.__eq__(other) - - def __hash__(self): - return hash(("galsim.hsm.HSMParams", self._getinitargs())) - - def __getstate__(self): - d = self.__dict__.copy() - del d["_hsmp"] - return d - - def __setstate__(self, d): - self.__dict__ = d - self._make_hsmp() - - -# We use the default a lot, so make it a class attribute. -HSMParams.default = HSMParams() - - -# A helper function that checks if the weight and the badpix bounds are -# consistent with that of the image, and that the weight is non-negative. -def _checkWeightAndBadpix(image, weight=None, badpix=None): - # Check that the weight and badpix, if given, are sensible and compatible - # with the image. - if weight is not None: - if weight.bounds != image.bounds: - raise GalSimIncompatibleValuesError( - "Weight image does not have same bounds as the input Image.", - weight=weight, - image=image, - ) - # also make sure there are no negative values - - if np.any(weight.array < 0): - raise GalSimValueError( - "Weight image cannot contain negative values.", weight - ) - - if badpix is not None and badpix.bounds != image.bounds: - raise GalSimIncompatibleValuesError( - "Badpix image does not have the same bounds as the input Image.", - badpix=badpix, - image=image, - ) - - -# A helper function for taking input weight and badpix Images, and returning a weight Image in the -# format that the C++ functions want -def _convertMask(image, weight=None, badpix=None): - # Convert from input weight and badpix images to a single mask image needed by C++ functions. - # This is used by EstimateShear() and FindAdaptiveMom(). - - # if no weight image was supplied, make an int array (same size as gal image) filled with 1's - if weight is None: - mask = ImageI(bounds=image.bounds, init_value=1) - else: - # if weight is an ImageI, then we can use it as the mask image: - if weight.dtype == np.int32: - if not badpix: - mask = weight - else: - # If we need to mask bad pixels, we'll need a copy anyway. - mask = ImageI(weight) - - # otherwise, we need to convert it to the right type - else: - mask = ImageI(bounds=image.bounds, init_value=0) - mask._array = jnp.where(weight.array > 0.0, 1, mask._array) - - # if badpix image was supplied, identify the nonzero (bad) pixels and set them to zero in weight - # image; also check bounds - if badpix is not None: - mask._array = jnp.where(badpix.array != 0, 0, mask._array) - - # if no pixels are used, raise an exception - if not np.any(mask.array): - raise GalSimHSMError("No pixels are being used!") - - # finally, return the Image for the weight map - return mask - - -# A simpler helper function to force images to be of type ImageF or ImageD -def _convertImage(image): - # Convert the given image to the correct format needed to pass to the C++ layer. - # This is used by EstimateShear() and FindAdaptiveMom(). - - # if weight is not of type float/double, convert to float/double - if image.dtype == np.int16 or image.dtype == np.uint16: - image = ImageF(image) - elif image.dtype == np.int32 or image.dtype == np.uint32: - image = ImageD(image) - - return image - - -@implements(_galsim.hsm.EstimateShear, lax_description=HSM_LAX_DOCS) -def EstimateShear( - gal_image, - PSF_image, - weight=None, - badpix=None, - sky_var=0.0, - shear_est="REGAUSS", - recompute_flux="FIT", - guess_sig_gal=5.0, - guess_sig_PSF=3.0, - precision=1.0e-6, - guess_centroid=None, - strict=True, - check=True, - hsmparams=None, -): - # prepare inputs to C++ routines: ImageF or ImageD for galaxy, PSF, and ImageI for weight map - gal_image = _convertImage(gal_image) - PSF_image = _convertImage(PSF_image) - hsmparams = HSMParams.check(hsmparams) - if check: - _checkWeightAndBadpix(gal_image, weight=weight, badpix=badpix) - weight = _convertMask(gal_image, weight=weight, badpix=badpix) - - if guess_centroid is None: - guess_centroid = gal_image.true_center - try: - result = ShapeData() - _galsim._galsim.EstimateShearView( - result._data, - gal_image._image, - PSF_image._image, - weight._image, - float(sky_var), - shear_est.upper(), - recompute_flux.upper(), - float(guess_sig_gal), - float(guess_sig_PSF), - float(precision), - guess_centroid._p, - hsmparams._hsmp, - ) - return result - except RuntimeError as err: - if strict: - raise GalSimHSMError(str(err)) from None - else: - return ShapeData(error_message=str(err)) - - -@implements(_galsim.hsm.FindAdaptiveMom, lax_description=HSM_LAX_DOCS) -def FindAdaptiveMom( - object_image, - weight=None, - badpix=None, - guess_sig=5.0, - precision=1.0e-6, - guess_centroid=None, - strict=True, - check=True, - round_moments=False, - hsmparams=None, - use_sky_coords=False, -): - # prepare inputs to C++ routines: ImageF or ImageD for galaxy, PSF, and ImageI for weight map - object_image = _convertImage(object_image) - hsmparams = HSMParams.check(hsmparams) - if check: - _checkWeightAndBadpix(object_image, weight=weight, badpix=badpix) - - weight = _convertMask(object_image, weight=weight, badpix=badpix) - - if guess_centroid is None: - guess_centroid = object_image.true_center - - try: - result = ShapeData() - _galsim._galsim.FindAdaptiveMomView( - result._data, - object_image._image, - weight._image, - float(guess_sig), - float(precision), - guess_centroid._p, - bool(round_moments), - hsmparams._hsmp, - ) - - if use_sky_coords: - result = result.applyWCS( - object_image.wcs, image_pos=object_image.true_center - ) - return result - except RuntimeError as err: - if strict: - raise GalSimHSMError(str(err)) from None - else: - return ShapeData(error_message=str(err)) - - -# make FindAdaptiveMom a method of Image class -Image.FindAdaptiveMom = FindAdaptiveMom diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 8be7d854..b4d88e0a 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -1138,6 +1138,19 @@ def to_galsim(self): np.asarray(self.array), bounds=self.bounds.to_galsim(), wcs=wcs ) + @implements( + _galsim.Image.FindAdaptiveMom, + lax_description=( + "This method converts the current `jax_galsim.Image` to a native " + "`galsim.Image` and delegates the computation to " + "`galsim.hsm.FindAdaptiveMom`. The returned object is GalSim's " + "`ShapeData`." + ), + ) + def FindAdaptiveMom(self, *args, **kwargs): + gs_image = self.to_galsim() + return gs_image.FindAdaptiveMom(*args, **kwargs) + @implements( _galsim._Image, From cf13d6dcc63f88ca8d7f2e6408bad8a330c62cdc Mon Sep 17 00:00:00 2001 From: Benjamin Remy Date: Sun, 8 Feb 2026 15:05:20 -0500 Subject: [PATCH 11/20] clean branch and add from-to-galsim test --- jax_galsim/bounds.py | 30 +++--------------------------- jax_galsim/image.py | 23 ----------------------- jax_galsim/position.py | 13 +------------ tests/jax/test_api.py | 1 + 4 files changed, 5 insertions(+), 62 deletions(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index 84a2e520..ed5942af 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -31,18 +31,13 @@ def _parse_args(self, *args, **kwargs): self._isdefined = False self.xmin = self.xmax = self.ymin = self.ymax = 0 elif len(args) == 1: - if isinstance( - args[0], (Bounds, _galsim._galsim.BoundsD, _galsim._galsim.BoundsI) - ): + if isinstance(args[0], Bounds): self._isdefined = True self.xmin = args[0].xmin self.xmax = args[0].xmax self.ymin = args[0].ymin self.ymax = args[0].ymax - elif isinstance( - args[0], - (Position, _galsim._galsim.PositionD, _galsim._galsim.PositionI), - ): + elif isinstance(args[0], Position): self._isdefined = True self.xmin = self.xmax = args[0].x self.ymin = self.ymax = args[0].y @@ -53,13 +48,7 @@ def _parse_args(self, *args, **kwargs): ) self._isdefined = True elif len(args) == 2: - if isinstance( - args[0], - (Position, _galsim._galsim.PositionD, _galsim._galsim.PositionI), - ) and isinstance( - args[1], - (Position, _galsim._galsim.PositionD, _galsim._galsim.PositionI), - ): + if isinstance(args[0], Position) and isinstance(args[1], Position): self._isdefined = True self.xmin = min(args[0].x, args[1].x) self.xmax = max(args[0].x, args[1].x) @@ -307,15 +296,6 @@ def __init__(self, *args, **kwargs): self.ymin = cast_to_float(self.ymin) self.ymax = cast_to_float(self.ymax) - @property - def _b(self): - return _galsim._galsim.BoundsD( - cast_to_float(self.xmin), - cast_to_float(self.xmax), - cast_to_float(self.ymin), - cast_to_float(self.ymax), - ) - def _check_scalar(self, x, name): try: if ( @@ -365,10 +345,6 @@ def __init__(self, *args, **kwargs): self.ymin = cast_to_int(self.ymin) self.ymax = cast_to_int(self.ymax) - @property - def _b(self): - return _galsim._galsim.BoundsI(self.xmin, self.xmax, self.ymin, self.ymax) - def _check_scalar(self, x, name): try: if ( diff --git a/jax_galsim/image.py b/jax_galsim/image.py index b4d88e0a..22e0849c 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -28,17 +28,6 @@ ) @register_pytree_node_class class Image(object): - _cpp_type = { - np.uint16: _galsim._galsim.ImageViewUS, - np.uint32: _galsim._galsim.ImageViewUI, - np.int16: _galsim._galsim.ImageViewS, - np.int32: _galsim._galsim.ImageViewI, - np.float32: _galsim._galsim.ImageViewF, - np.float64: _galsim._galsim.ImageViewD, - np.complex64: _galsim._galsim.ImageViewCF, - np.complex128: _galsim._galsim.ImageViewCD, - } - _alias_dtypes = { int: jnp.int32, # So that user gets what they would expect float: jnp.float64, # if using dtype=int or float or complex @@ -382,18 +371,6 @@ def isinteger(self): def iscontiguous(self): return True # In JAX all arrays are contiguous (almost) - @_galsim._utilities.lazy_property - def _image(self): - cls = self._cpp_type[self.dtype] - _array = np.asarray(self._array) - _data = _array.__array_interface__["data"][0] - return cls( - _data, - _array.strides[1] // _array.itemsize, - _array.strides[0] // _array.itemsize, - self._bounds._b, - ) - # Allow scale to work as a PixelScale wcs. @property @implements(_galsim.Image.scale) diff --git a/jax_galsim/position.py b/jax_galsim/position.py index 1430f516..822797b8 100644 --- a/jax_galsim/position.py +++ b/jax_galsim/position.py @@ -25,10 +25,7 @@ def _parse_args(self, *args, **kwargs): elif len(args) == 0: self.x = self.y = 0 elif len(args) == 1: - if isinstance( - args[0], - (Position, _galsim._galsim.PositionD, _galsim._galsim.PositionI), - ): + if isinstance(args[0], (Position,)): self.x = args[0].x self.y = args[0].y else: @@ -196,10 +193,6 @@ def __init__(self, *args, **kwargs): self.x = cast_to_float(self.x) self.y = cast_to_float(self.y) - @property - def _p(self): - return _galsim._galsim.PositionD(self.x, self.y) - def _check_scalar(self, other, op): try: if ( @@ -225,10 +218,6 @@ def __init__(self, *args, **kwargs): self.x = cast_to_int(self.x) self.y = cast_to_int(self.y) - @property - def _p(self): - return _galsim._galsim.PositionI(self.x, self.y) - def _check_scalar(self, other, op): try: if ( diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index f7445a6b..a9a4e287 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -604,6 +604,7 @@ def _reg_sfun(g1): def test_api_image(obj): _run_object_checks(obj, obj.__class__, "docs-methods") _run_object_checks(obj, obj.__class__, "pickle-eval-repr-img") + # _run_object_checks(obj, obj.__class__, "to-from-galsim") # JAX tracing should be an identity assert obj.__class__.tree_unflatten(*((obj.tree_flatten())[::-1])) == obj From 1b144f865c11f37bf4e0f5931eb1adcc3d0210ec Mon Sep 17 00:00:00 2001 From: Benjamin Remy Date: Sun, 8 Feb 2026 15:09:23 -0500 Subject: [PATCH 12/20] enable image to-from-galsim test --- tests/jax/test_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index a9a4e287..65aa1b1c 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -604,7 +604,7 @@ def _reg_sfun(g1): def test_api_image(obj): _run_object_checks(obj, obj.__class__, "docs-methods") _run_object_checks(obj, obj.__class__, "pickle-eval-repr-img") - # _run_object_checks(obj, obj.__class__, "to-from-galsim") + _run_object_checks(obj, obj.__class__, "to-from-galsim") # JAX tracing should be an identity assert obj.__class__.tree_unflatten(*((obj.tree_flatten())[::-1])) == obj From e30bb90a8a00e9c52b37b7bfc12e9c1eb587c428 Mon Sep 17 00:00:00 2001 From: Benjamin Remy Date: Sun, 8 Feb 2026 15:14:24 -0500 Subject: [PATCH 13/20] remove demo2.py --- examples/demo2.py | 175 ---------------------------------------------- 1 file changed, 175 deletions(-) delete mode 100644 examples/demo2.py diff --git a/examples/demo2.py b/examples/demo2.py deleted file mode 100644 index 72438f5d..00000000 --- a/examples/demo2.py +++ /dev/null @@ -1,175 +0,0 @@ -# Copyright (c) 2012-2026 by the GalSim developers team on GitHub -# https://github.com/GalSim-developers -# -# This file is part of GalSim: The modular galaxy image simulation toolkit. -# https://github.com/GalSim-developers/GalSim -# -# GalSim is free software: redistribution and use in source and binary forms, -# with or without modification, are permitted provided that the following -# conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions, and the disclaimer given in the accompanying LICENSE -# file. -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions, and the disclaimer given in the documentation -# and/or other materials provided with the distribution. -# -""" -Demo #2 - -The second script in our tutorial about using JAX-GalSim in python scripts: examples/demo*.py. -(This file is designed to be viewed in a window 100 characters wide.) - -This script is a bit more sophisticated, but still pretty basic. We're still only making -a single image, but now the galaxy has an exponential radial profile and is sheared. -The PSF is a circular Moffat profile. The noise is drawn from a Poisson distribution -using the flux from both the object and a background sky level to determine the -variance in each pixel. - -New features introduced in this demo: - -- obj = jax_galsim.Exponential(flux, scale_radius) -- obj = jax_galsim.Moffat(beta, flux, half_light_radius) -- obj = obj.shear(g1, g2) -- with explanation of other ways to specify shear -- rng = jax_galsim.BaseDeviate(seed) -- noise = jax_galsim.PoissonNoise(rng, sky_level) -- jax_galsim.hsm.EstimateShear(image, image_epsf) -""" - -import logging -import os -import sys - -import galsim - -import jax_galsim - - -def main(argv): - """ - A little bit more sophisticated, but still pretty basic: - - Use a sheared, exponential profile for the galaxy. - - Convolve it by a circular Moffat PSF. - - Add Poisson noise to the image. - """ - # In non-script code, use getLogger(__name__) at module scope instead. - logging.basicConfig(format="%(message)s", level=logging.INFO, stream=sys.stdout) - logger = logging.getLogger("demo2") - - gal_flux = 1.0e5 # counts - gal_r0 = 2.7 # arcsec - g1 = 0.1 # - g2 = 0.2 # - psf_beta = 5 # - psf_re = 1.0 # arcsec - pixel_scale = 0.2 # arcsec / pixel - sky_level = 2.5e3 # counts / arcsec^2 - - # This time use a particular seed, so the image is deterministic. - # This is the same seed that is used in demo2.yaml, which means the images - # produced by the two methods will be precisely identical. - random_seed = 1534225 - - # The first thing the config layer does with the random seed is to scramble - # it a bit. Specifically, it makes a random number generator (BaseDeviate) - # using that seed and asks for a raw value. This becomes the seed that - # actually gets used. - # The reason for this extra step is that eventually (cf. demo4) the config - # layer will want to increment these seed values when building multiple - # objects or images. If the user is likewise incrementing seed values for - # multiple runs of a given config file, these can interfere leading to - # surprising (and typically bad) results. - random_seed = jax_galsim.BaseDeviate(random_seed).raw() - - logger.info("Starting demo script 2 using:") - logger.info( - " - sheared (%.2f,%.2f) exponential galaxy (flux = %.1e, scale radius = %.2f),", - g1, - g2, - gal_flux, - gal_r0, - ) - logger.info(" - circular Moffat PSF (beta = %.1f, re = %.2f),", psf_beta, psf_re) - logger.info(" - pixel scale = %.2f,", pixel_scale) - logger.info(" - Poisson noise (sky level = %.1e).", sky_level) - - # Initialize the (pseudo-)random number generator that we will be using below. - # For a technical reason that will be explained later (demo9.py), we add 1 to the - # given random seed here. - rng = jax_galsim.BaseDeviate(random_seed + 1) - - # Define the galaxy profile. - gal = jax_galsim.Exponential(flux=gal_flux, scale_radius=gal_r0) - - # Shear the galaxy by some value. - # There are quite a few ways you can use to specify a shape. - # q, beta Axis ratio and position angle: q = b/a, 0 < q < 1 - # e, beta Ellipticity and position angle: |e| = (1-q^2)/(1+q^2) - # g, beta ("Reduced") Shear and position angle: |g| = (1-q)/(1+q) - # eta, beta Conformal shear and position angle: eta = ln(1/q) - # e1,e2 Ellipticity components: e1 = e cos(2 beta), e2 = e sin(2 beta) - # g1,g2 ("Reduced") shear components: g1 = g cos(2 beta), g2 = g sin(2 beta) - # eta1,eta2 Conformal shear components: eta1 = eta cos(2 beta), eta2 = eta sin(2 beta) - gal = gal.shear(g1=g1, g2=g2) - logger.debug("Made galaxy profile") - - # Define the PSF profile. - psf = jax_galsim.Moffat(beta=psf_beta, flux=1.0, half_light_radius=psf_re) - logger.debug("Made PSF profile") - - # Final profile is the convolution of these. - final = jax_galsim.Convolve([gal, psf]) - logger.debug("Convolved components into final profile") - - # Draw the image with a particular pixel scale. - image = final.drawImage(scale=pixel_scale) - # The "effective PSF" is the PSF as drawn on an image, which includes the convolution - # by the pixel response. We label it epsf here. - image_epsf = psf.drawImage(scale=pixel_scale) - logger.debug("Made image of the profile") - - # To get Poisson noise on the image, we will use a class called PoissonNoise. - # However, we want the noise to correspond to what you would get with a significant - # flux from tke sky. This is done by telling PoissonNoise to add noise from a - # sky level in addition to the counts currently in the image. - # - # One wrinkle here is that the PoissonNoise class needs the sky level in each pixel, - # while we have a sky_level in counts per arcsec^2. So we need to convert: - sky_level_pixel = sky_level * pixel_scale**2 - noise = jax_galsim.PoissonNoise(rng, sky_level=sky_level_pixel) - image.addNoise(noise) - logger.debug("Added Poisson noise") - - # Write the image to a file. - if not os.path.isdir("output"): - os.mkdir("output") - file_name = os.path.join("output", "demo2.fits") - file_name_epsf = os.path.join("output", "demo2_epsf.fits") - image.write(file_name) - image_epsf.write(file_name_epsf) - logger.info("Wrote image to %r", file_name) - logger.info("Wrote effective PSF image to %r", file_name_epsf) - - results = galsim.hsm.EstimateShear(image.to_galsim(), image_epsf.to_galsim()) - - logger.info("HSM reports that the image has observed shape and size:") - logger.info( - " e1 = %.3f, e2 = %.3f, sigma = %.3f (pixels)", - results.observed_shape.e1, - results.observed_shape.e2, - results.moments_sigma, - ) - logger.info( - "When carrying out Regaussianization PSF correction, HSM reports distortions" - ) - logger.info(" e1, e2 = %.3f, %.3f", results.corrected_e1, results.corrected_e2) - logger.info( - "Expected values in the limit that noise and non-Gaussianity are negligible:" - ) - exp_shear = jax_galsim.Shear(g1=g1, g2=g2) - logger.info(" g1, g2 = %.3f, %.3f", exp_shear.e1, exp_shear.e2) - - -if __name__ == "__main__": - main(sys.argv) From 07d251ef9cac4c6e188577212e9e0babf6c969ac Mon Sep 17 00:00:00 2001 From: Benjamin Remy Date: Sun, 8 Feb 2026 15:34:03 -0500 Subject: [PATCH 14/20] add again hsm allowed failures --- tests/galsim_tests_config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/galsim_tests_config.yaml b/tests/galsim_tests_config.yaml index 54e81bbf..2d532570 100644 --- a/tests/galsim_tests_config.yaml +++ b/tests/galsim_tests_config.yaml @@ -21,7 +21,6 @@ enabled_tests: - test_noise.py - test_image.py - test_photon_array.py - - test_hsm.py - "*" # means all tests from galsim coord: - test_angle.py @@ -110,6 +109,7 @@ allowed_failures: - "GSParams.__init__() got an unexpected keyword argument 'allowed_flux_variation'" - "module 'jax_galsim' has no attribute 'Atmosphere'" - "module 'jax_galsim' has no attribute 'RandomWalk'" + - "module 'jax_galsim' has no attribute 'hsm'" - "module 'jax_galsim' has no attribute 'des'" - "'Image' object has no attribute 'applyNonlinearity'" - "'Image' object has no attribute 'addReciprocityFailure'" From 0e577340af9b6fa7f5a0c3943ca60f50dc68efd6 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 9 Feb 2026 14:24:00 -0600 Subject: [PATCH 15/20] fix: update test submodule --- tests/GalSim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/GalSim b/tests/GalSim index 2ed86695..2cd2f1ff 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 2ed86695df3669c4ff4de4cd3154e6fd76e206da +Subproject commit 2cd2f1ff70da4989d3c88604e7fc88ed47fbb3ea From 6e66bbbd5cf2dae04ff781dee5f74bb1804ea1e5 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 9 Feb 2026 14:33:52 -0600 Subject: [PATCH 16/20] test: add symlink for good measure --- tests/fits_file | 1 + 1 file changed, 1 insertion(+) create mode 120000 tests/fits_file diff --git a/tests/fits_file b/tests/fits_file new file mode 120000 index 00000000..5b03e34d --- /dev/null +++ b/tests/fits_file @@ -0,0 +1 @@ +GalSim/tests/fits_files \ No newline at end of file From ce5e148b565f705f0f83b4617e5865ac30277757 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 9 Feb 2026 14:41:53 -0600 Subject: [PATCH 17/20] fix: more fixes for test suite --- tests/GalSim | 2 +- tests/SBProfile_comparison_images | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) create mode 120000 tests/SBProfile_comparison_images diff --git a/tests/GalSim b/tests/GalSim index 2cd2f1ff..c41c4771 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 2cd2f1ff70da4989d3c88604e7fc88ed47fbb3ea +Subproject commit c41c477111574e0203e44ddba54f6888b950623c diff --git a/tests/SBProfile_comparison_images b/tests/SBProfile_comparison_images new file mode 120000 index 00000000..6e72d788 --- /dev/null +++ b/tests/SBProfile_comparison_images @@ -0,0 +1 @@ +GalSim/tests/SBProfile_comparison_images \ No newline at end of file From f561121513c477a68175c5509c7db1fa33c97c99 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 9 Feb 2026 14:48:21 -0600 Subject: [PATCH 18/20] fix: update submodule for more test fixes --- tests/GalSim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/GalSim b/tests/GalSim index c41c4771..3251a393 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit c41c477111574e0203e44ddba54f6888b950623c +Subproject commit 3251a393bf7ea94fe9ccda3508bc7db722eca1cf From f48e61812fd8e6888693724675bd949c4b27c90a Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 9 Feb 2026 14:51:48 -0600 Subject: [PATCH 19/20] fix: convert args and kwargs if needed --- jax_galsim/image.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 22e0849c..f6f0f518 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -1125,8 +1125,13 @@ def to_galsim(self): ), ) def FindAdaptiveMom(self, *args, **kwargs): + args_ = [arg.to_galsim() if hasattr(arg, "to_galsim") else arg for arg in args] + kwargs_ = { + key: val.to_galsim() if hasattr(val, "to_galsim") else val + for key, val in kwargs.items() + } gs_image = self.to_galsim() - return gs_image.FindAdaptiveMom(*args, **kwargs) + return gs_image.FindAdaptiveMom(*args_, **kwargs_) @implements( From 09094be82acd064a3ac13c08e1cb6659a5cb3f4b Mon Sep 17 00:00:00 2001 From: Benjamin Remy Date: Mon, 9 Feb 2026 16:19:24 -0500 Subject: [PATCH 20/20] add demo2.py --- examples/demo2.py | 175 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 175 insertions(+) create mode 100644 examples/demo2.py diff --git a/examples/demo2.py b/examples/demo2.py new file mode 100644 index 00000000..4f64b436 --- /dev/null +++ b/examples/demo2.py @@ -0,0 +1,175 @@ +# Copyright (c) 2012-2026 by the GalSim developers team on GitHub +# https://github.com/GalSim-developers +# +# This file is part of GalSim: The modular galaxy image simulation toolkit. +# https://github.com/GalSim-developers/GalSim +# +# GalSim is free software: redistribution and use in source and binary forms, +# with or without modification, are permitted provided that the following +# conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions, and the disclaimer given in the accompanying LICENSE +# file. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions, and the disclaimer given in the documentation +# and/or other materials provided with the distribution. +# +""" +Demo #2 + +The second script in our tutorial about using JAX-GalSim in python scripts: examples/demo*.py. +(This file is designed to be viewed in a window 100 characters wide.) + +This script is a bit more sophisticated, but still pretty basic. We're still only making +a single image, but now the galaxy has an exponential radial profile and is sheared. +The PSF is a circular Moffat profile. The noise is drawn from a Poisson distribution +using the flux from both the object and a background sky level to determine the +variance in each pixel. + +New features introduced in this demo: + +- obj = jax_galsim.Exponential(flux, scale_radius) +- obj = jax_galsim.Moffat(beta, flux, half_light_radius) +- obj = obj.shear(g1, g2) -- with explanation of other ways to specify shear +- rng = jax_galsim.BaseDeviate(seed) +- noise = jax_galsim.PoissonNoise(rng, sky_level) +- galsim.hsm.EstimateShear(image, image_epsf) +""" + +import logging +import os +import sys + +import galsim + +import jax_galsim + + +def main(argv): + """ + A little bit more sophisticated, but still pretty basic: + - Use a sheared, exponential profile for the galaxy. + - Convolve it by a circular Moffat PSF. + - Add Poisson noise to the image. + """ + # In non-script code, use getLogger(__name__) at module scope instead. + logging.basicConfig(format="%(message)s", level=logging.INFO, stream=sys.stdout) + logger = logging.getLogger("demo2") + + gal_flux = 1.0e5 # counts + gal_r0 = 2.7 # arcsec + g1 = 0.1 # + g2 = 0.2 # + psf_beta = 5 # + psf_re = 1.0 # arcsec + pixel_scale = 0.2 # arcsec / pixel + sky_level = 2.5e3 # counts / arcsec^2 + + # This time use a particular seed, so the image is deterministic. + # This is the same seed that is used in demo2.yaml, which means the images + # produced by the two methods will be precisely identical. + random_seed = 1534225 + + # The first thing the config layer does with the random seed is to scramble + # it a bit. Specifically, it makes a random number generator (BaseDeviate) + # using that seed and asks for a raw value. This becomes the seed that + # actually gets used. + # The reason for this extra step is that eventually (cf. demo4) the config + # layer will want to increment these seed values when building multiple + # objects or images. If the user is likewise incrementing seed values for + # multiple runs of a given config file, these can interfere leading to + # surprising (and typically bad) results. + random_seed = jax_galsim.BaseDeviate(random_seed).raw() + + logger.info("Starting demo script 2 using:") + logger.info( + " - sheared (%.2f,%.2f) exponential galaxy (flux = %.1e, scale radius = %.2f),", + g1, + g2, + gal_flux, + gal_r0, + ) + logger.info(" - circular Moffat PSF (beta = %.1f, re = %.2f),", psf_beta, psf_re) + logger.info(" - pixel scale = %.2f,", pixel_scale) + logger.info(" - Poisson noise (sky level = %.1e).", sky_level) + + # Initialize the (pseudo-)random number generator that we will be using below. + # For a technical reason that will be explained later (demo9.py), we add 1 to the + # given random seed here. + rng = jax_galsim.BaseDeviate(random_seed + 1) + + # Define the galaxy profile. + gal = jax_galsim.Exponential(flux=gal_flux, scale_radius=gal_r0) + + # Shear the galaxy by some value. + # There are quite a few ways you can use to specify a shape. + # q, beta Axis ratio and position angle: q = b/a, 0 < q < 1 + # e, beta Ellipticity and position angle: |e| = (1-q^2)/(1+q^2) + # g, beta ("Reduced") Shear and position angle: |g| = (1-q)/(1+q) + # eta, beta Conformal shear and position angle: eta = ln(1/q) + # e1,e2 Ellipticity components: e1 = e cos(2 beta), e2 = e sin(2 beta) + # g1,g2 ("Reduced") shear components: g1 = g cos(2 beta), g2 = g sin(2 beta) + # eta1,eta2 Conformal shear components: eta1 = eta cos(2 beta), eta2 = eta sin(2 beta) + gal = gal.shear(g1=g1, g2=g2) + logger.debug("Made galaxy profile") + + # Define the PSF profile. + psf = jax_galsim.Moffat(beta=psf_beta, flux=1.0, half_light_radius=psf_re) + logger.debug("Made PSF profile") + + # Final profile is the convolution of these. + final = jax_galsim.Convolve([gal, psf]) + logger.debug("Convolved components into final profile") + + # Draw the image with a particular pixel scale. + image = final.drawImage(scale=pixel_scale) + # The "effective PSF" is the PSF as drawn on an image, which includes the convolution + # by the pixel response. We label it epsf here. + image_epsf = psf.drawImage(scale=pixel_scale) + logger.debug("Made image of the profile") + + # To get Poisson noise on the image, we will use a class called PoissonNoise. + # However, we want the noise to correspond to what you would get with a significant + # flux from tke sky. This is done by telling PoissonNoise to add noise from a + # sky level in addition to the counts currently in the image. + # + # One wrinkle here is that the PoissonNoise class needs the sky level in each pixel, + # while we have a sky_level in counts per arcsec^2. So we need to convert: + sky_level_pixel = sky_level * pixel_scale**2 + noise = jax_galsim.PoissonNoise(rng, sky_level=sky_level_pixel) + image.addNoise(noise) + logger.debug("Added Poisson noise") + + # Write the image to a file. + if not os.path.isdir("output"): + os.mkdir("output") + file_name = os.path.join("output", "demo2.fits") + file_name_epsf = os.path.join("output", "demo2_epsf.fits") + image.write(file_name) + image_epsf.write(file_name_epsf) + logger.info("Wrote image to %r", file_name) + logger.info("Wrote effective PSF image to %r", file_name_epsf) + + results = galsim.hsm.EstimateShear(image.to_galsim(), image_epsf.to_galsim()) + + logger.info("HSM reports that the image has observed shape and size:") + logger.info( + " e1 = %.3f, e2 = %.3f, sigma = %.3f (pixels)", + results.observed_shape.e1, + results.observed_shape.e2, + results.moments_sigma, + ) + logger.info( + "When carrying out Regaussianization PSF correction, HSM reports distortions" + ) + logger.info(" e1, e2 = %.3f, %.3f", results.corrected_e1, results.corrected_e2) + logger.info( + "Expected values in the limit that noise and non-Gaussianity are negligible:" + ) + exp_shear = galsim.Shear(g1=g1, g2=g2) + logger.info(" g1, g2 = %.3f, %.3f", exp_shear.e1, exp_shear.e2) + + +if __name__ == "__main__": + main(sys.argv)