diff --git a/jax_galsim/fitswcs.py b/jax_galsim/fitswcs.py index 5870d755..232801d6 100644 --- a/jax_galsim/fitswcs.py +++ b/jax_galsim/fitswcs.py @@ -834,16 +834,16 @@ def __eq__(self, other): and jnp.array_equal(self.cd, other.cd) and self.center == other.center and ( - jnp.array_equal(self.pv, other.pv) - or (self.pv is None and other.pv is None) + (self.pv is None and other.pv is None) + or jnp.array_equal(self.pv, other.pv) ) and ( - jnp.array_equal(self.ab, other.ab) - or (self.ab is None and other.ab is None) + (self.ab is None and other.ab is None) + or jnp.array_equal(self.ab, other.ab) ) and ( - jnp.array_equal(self.abp, other.abp) - or (self.abp is None and other.abp is None) + (self.abp is None and other.abp is None) + or jnp.array_equal(self.abp, other.abp) ) ) diff --git a/jax_galsim/integ.py b/jax_galsim/integ.py index c2e19a42..19ad5c4b 100644 --- a/jax_galsim/integ.py +++ b/jax_galsim/integ.py @@ -40,7 +40,7 @@ def int1d( @jax.jit def _func(x): rdt = jax.ShapeDtypeStruct(x.shape, x.dtype) - return jax.pure_callback(func, rdt, x) + return jax.pure_callback(func, rdt, x, vmap_method="sequential") else: _func = func diff --git a/pyproject.toml b/pyproject.toml index ff8a2390..1921ba52 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,10 +16,9 @@ readme = "README.md" dependencies = [ "numpy >=1.18.0", "galsim >=2.3.0", - "jax <0.7.0", - "jaxlib", + "jax >=0.8.0", "astropy >=2.0", - "tensorflow-probability >=0.21.0", + "tfp-nightly", "quadax", ]