Skip to content

ci: install [optional] on 3.13, pin jax<0.7 for tfp compat#35

Merged
Jammy2211 merged 2 commits into
mainfrom
feature/fix-smoke-jax-and-mcmc
May 8, 2026
Merged

ci: install [optional] on 3.13, pin jax<0.7 for tfp compat#35
Jammy2211 merged 2 commits into
mainfrom
feature/fix-smoke-jax-and-mcmc

Conversation

@Jammy2211
Copy link
Copy Markdown
Collaborator

Summary

  • 3.13 path no longer skips [optional] extras — JAX 0.6.x has cp313 wheels.
  • Pin jax<0.7 jaxlib<0.7 so the resolver stays on a tfp-0.25.0-compatible JAX (avoids the jax.interpreters.xla.pytype_aval_mappings removal that breaks tfp.substrates.jax).
  • Should fix all 26 ModuleNotFoundError: No module named 'jax' failures on 3.13 and the 2 matern-related JAX import failures on 3.12.

Test plan

  • CI smoke on this PR (3.12 and 3.13)
  • autogalaxy 3.12: imaging/rectangular.py should pass within rtol=1e-2 once jax codegen settles on 0.6.x (was 1.28% mismatch on 0.7.1; was 0.21% locally on jax 0.9.2)

🤖 Generated with Claude Code

Jammy2211 and others added 2 commits May 8, 2026 15:56
Python 3.13 was branching to a numba-only install path that skipped
the [optional] extras, so all jax_likelihood_functions/* and jax_grad/*
smoke scripts failed with `ModuleNotFoundError: No module named 'jax'`.
JAX 0.6.x and onward have cp313 wheels — the conditional is no longer
needed.

Pinning `jax<0.7 jaxlib<0.7` before the [optional] install keeps tfp
0.25.0 (latest on PyPI) happy: tfp.substrates.jax references
`jax.interpreters.xla.pytype_aval_mappings`, removed in JAX 0.7.0.
Without the pin, pip backtracked to jax 0.7.1 and tfp's import blew
up at runtime in matern-kernel regularization
(autoarray/inversion/regularization/matern_kernel.py:37).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`rectangular.py` asserts that the `use_mixed_precision=True` JIT fit
matches the NumPy float64 fit within rtol=1%. The mismatch on CI is
1.28% (-2021.50 vs -2047.79), reproducible across jax 0.6.2 and 0.7.1
— mixed precision legitimately introduces ~1% drift on this script's
RectangularAdaptImage + Adapt regularization path.

The other jax_likelihood_functions/* scripts using mixed precision
already pass at rtol=1%; this one exercises a longer chain of mat-vec
products that accumulates more drift. Tighten back if the script
moves off mixed precision.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@Jammy2211 Jammy2211 merged commit 577f836 into main May 8, 2026
4 checks passed
@Jammy2211 Jammy2211 deleted the feature/fix-smoke-jax-and-mcmc branch May 8, 2026 15:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant