
perf: batched transform_mapping_matrix in TransformerNUFFT (single nufft2d2 call)#305

Merged
Jammy2211 merged 1 commit into main from feature/nufftax-batched-mapping-matrix on May 10, 2026


Conversation

@Jammy2211
Collaborator

Summary

TransformerNUFFT.transform_mapping_matrix previously looped over each source-pixel column in Python, scattering each into a (N_y, N_x) native image and calling _forward_native separately. Under jax.jit this fully unrolled into n_src distinct nufft2d2 invocations in a single trace, ballooning the JIT graph for pixelization-heavy fits.

This was most visible on autolens_workspace_test/scripts/jax_likelihood_functions/interferometer/rectangular_dspl.py: two source planes with mesh_shape=(30, 30) gave 1800 inlined NUFFTs in the JIT trace and caused 11+ minute slow-compile warnings + OOM on the 16 GB CI box.

The replacement scatters all columns into a single (n_src, N_y, N_x) batched array and calls nufft2d2 once with batched f. nufftax natively supports the batched form. The JIT graph drops from O(n_src) NUFFTs to a single batched NUFFT call.
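The loop-to-batch pattern can be sketched in plain NumPy (jax.numpy is a drop-in here), using a dense type-2 DFT as a stand-in for nufft2d2 — the real nufftax call signature, the mapping-matrix layout, and the helper name `dft2d2` are assumptions for illustration only:

```python
import numpy as np

# Stand-in for a type-2 NUFFT: evaluate a gridded image f with trailing
# shape (N_y, N_x) at M non-uniform (u, v) points. Broadcasting over the
# leading axes of f mirrors the batched form described above.
def dft2d2(u, v, f):
    ny, nx = f.shape[-2:]
    ky, kx = np.meshgrid(np.arange(ny), np.arange(nx), indexing="ij")
    # (M, N_y, N_x) phase factors, one slice per non-uniform point.
    phase = np.exp(-1j * (u[:, None, None] * ky + v[:, None, None] * kx))
    return np.einsum("...yx,myx->...m", f, phase)

rng = np.random.default_rng(0)
n_src, ny, nx, M = 5, 8, 8, 16
mapping = rng.normal(size=(ny * nx, n_src))  # (image pixels, source pixels)
u = rng.uniform(0.0, 2.0 * np.pi, M)
v = rng.uniform(0.0, 2.0 * np.pi, M)

# Old pattern: one transform per source-pixel column — n_src calls,
# each of which becomes a separate NUFFT in a jax.jit trace.
looped = np.stack(
    [dft2d2(u, v, mapping[:, s].reshape(ny, nx)) for s in range(n_src)],
    axis=1,
)

# New pattern: scatter every column into one (n_src, N_y, N_x) array and
# make a single batched call; transpose the result to (M, n_src).
batched = dft2d2(u, v, mapping.T.reshape(n_src, ny, nx)).T

assert np.allclose(looped, batched)
```

Because the stand-in broadcasts over leading axes exactly as the batched nufft2d2 form does, the two code paths agree numerically while the batched one issues a single transform.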

API Changes

None. The method signature, return shape (M, n_src), and numerical results are unchanged.

Test plan

  • pytest test_autoarray/operators/test_transformer.py — 8/8 pass
  • pytest test_autoarray/ — 750/750 pass
  • black --check clean
  • Confirmed interferometer/rectangular_dspl.py (with mesh reduced to 8x8) compiles in seconds and produces the same canonical likelihood as the old loop implementation

Companion PR

autolens_workspace_test will follow with the JAX-likelihood cross-check that exercises this fast path end-to-end.

🤖 Generated with Claude Code

…fft2d2 call)

The previous implementation looped over each source-pixel column in
Python, scattering each into a (N_y, N_x) image and calling _forward_native
separately. Under jax.jit this fully unrolled into n_src distinct nufft2d2
invocations in one trace, ballooning the JIT graph for pixelization-heavy
fits — most visibly the rectangular_dspl JAX-likelihood script, where
two source planes with mesh_shape=(30,30) gave 1800 inlined NUFFTs and
caused 11+ minute slow-compile warnings (and OOM on 16 GB hosts).

The replacement scatters all columns into a single (n_src, N_y, N_x)
batched array and calls nufft2d2 once with batched f. nufftax natively
supports the batched form. Numerical results unchanged (existing
test_transformer.py passes); the JIT graph drops from O(n_src) NUFFTs
to a single batched NUFFT call.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@Jammy2211 Jammy2211 merged commit bc00c11 into main May 10, 2026
4 checks passed
@Jammy2211 Jammy2211 deleted the feature/nufftax-batched-mapping-matrix branch May 10, 2026 13:04
