ci: install [optional] on 3.13, pin jax<0.7 for tfp compat#35
Merged
Conversation
Python 3.13 was branching to a numba-only install path that skipped the [optional] extras, so all jax_likelihood_functions/* and jax_grad/* smoke scripts failed with `ModuleNotFoundError: No module named 'jax'`. JAX 0.6.x and onward have cp313 wheels — the conditional is no longer needed. Pinning `jax<0.7 jaxlib<0.7` before the [optional] install keeps tfp 0.25.0 (latest on PyPI) happy: tfp.substrates.jax references `jax.interpreters.xla.pytype_aval_mappings`, removed in JAX 0.7.0. Without the pin, pip backtracked to jax 0.7.1 and tfp's import blew up at runtime in matern-kernel regularization (autoarray/inversion/regularization/matern_kernel.py:37). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`rectangular.py` asserts that the `use_mixed_precision=True` JIT fit matches the NumPy float64 fit within rtol=1%. The mismatch on CI is 1.28% (-2021.50 vs -2047.79), reproducible across jax 0.6.2 and 0.7.1 — mixed precision legitimately introduces ~1% drift on this script's RectangularAdaptImage + Adapt regularization path. The other jax_likelihood_functions/* scripts using mixed precision already pass at rtol=1%; this one exercises a longer chain of mat-vec products that accumulates more drift. Tighten back if the script moves off mixed precision. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
[optional]extras — JAX 0.6.x has cp313 wheels.jax<0.7 jaxlib<0.7so the resolver stays on a tfp-0.25.0-compatible JAX (avoids thejax.interpreters.xla.pytype_aval_mappingsremoval that breakstfp.substrates.jax).ModuleNotFoundError: No module named 'jax'failures on 3.13 and the 2 matern-related JAX import failures on 3.12.Test plan
imaging/rectangular.pyshould pass within rtol=1e-2 once jax codegen settles on 0.6.x (was 1.28% mismatch on 0.7.1; was 0.21% locally on jax 0.9.2)🤖 Generated with Claude Code