feat: 2D regular-grid interpolator with JAX path; wire DatasetInterp #306

@Jammy2211

Overview

Add a 2D regular-grid bilinear interpolator helper to PyAutoArray with both NumPy and JAX paths, then wire DatasetInterp in PyAutoGalaxy to use it. Step 4 of 7 in the ellipse_fitting_jax feature decomposition (PyAutoPrompt/z_features/ellipse_fitting_jax.md). DatasetInterp today uses scipy.interpolate.RegularGridInterpolator, which is the first hard JAX blocker for AnalysisEllipse.log_likelihood_function — replacing it with an xp-dispatched helper unblocks prompts 6 and 7.

Plan

  • Add a new autoarray/numerics/ subpackage with interp_2d.py containing two private paths (_interp_2d_numpy, _interp_2d_jax) and one public dispatcher interp_2d(points, x_axis, y_axis, values, fill_value=0.0, xp=np).
  • NumPy path delegates to scipy.interpolate.RegularGridInterpolator(bounds_error=False, fill_value=fill_value); with the default fill_value=0.0 this reproduces the current DatasetInterp behaviour exactly.
  • JAX path uses jax.scipy.ndimage.map_coordinates(values, coords, order=1, cval=fill_value) with manual world→pixel-fractional coordinate conversion. Assumes regularly spaced axes (true for mask.derive_grid.all_false-derived grids).
  • Surface the dispatcher via aa.numerics.interp_2d. Mirror the 1D dispatch style of _interp1d_jax / _interp1d_numpy in autoarray/inversion/mesh/interpolator/rectangular_spline.py.
  • Unit tests cover: numpy↔JAX parity on random in-bounds points (rtol=1e-6), out-of-bounds returns fill_value, single-point query has shape (1,), JAX-only tests gated by pytest.importorskip("jax").
  • In PyAutoGalaxy, refactor DatasetInterp.{data,noise_map,mask}_interp from cached properties returning RegularGridInterpolator instances to methods (points, xp=np) that call aa.numerics.interp_2d. All call sites in fit_ellipse.py continue to work — they don't pass xp, defaulting to the numpy path.
  • Behaviour on the numpy path is unchanged. The if self.interp.mask_interp is not None: check in fit_ellipse.py stays as-is: it is always True both before and after this change (afterwards because mask_interp is a bound method, which is never None); the gate change is for prompt 6.
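
To make the world→pixel-fractional conversion in the JAX bullet concrete, here is a minimal NumPy-only sketch of bilinear interpolation on a regular grid. It is an illustration rather than the proposed implementation. The name bilinear_on_regular_grid is hypothetical. It assumes that the first column of points pairs with x_axis and with axis 0 of values (the convention RegularGridInterpolator applies to its points tuple), and that each axis has at least two evenly spaced entries.

```python
import numpy as np


def bilinear_on_regular_grid(points, x_axis, y_axis, values, fill_value=0.0):
    """Hypothetical helper: bilinear interpolation on evenly spaced axes."""
    points = np.atleast_2d(np.asarray(points, dtype=float))
    values = np.asarray(values, dtype=float)
    n0, n1 = values.shape

    # World coordinate -> fractional pixel index. Using axis[0] and the first
    # spacing (rather than a min and an absolute spacing) also handles an axis
    # that decreases with index.
    f0 = (points[:, 0] - x_axis[0]) / (x_axis[1] - x_axis[0])
    f1 = (points[:, 1] - y_axis[0]) / (y_axis[1] - y_axis[0])

    # Points outside the grid get fill_value, like bounds_error=False in SciPy.
    in_bounds = (f0 >= 0) & (f0 <= n0 - 1) & (f1 >= 0) & (f1 <= n1 - 1)

    # Lower corner of the enclosing cell, clipped so that i + 1 stays valid.
    i0 = np.clip(np.floor(f0).astype(int), 0, n0 - 2)
    i1 = np.clip(np.floor(f1).astype(int), 0, n1 - 2)
    t0 = f0 - i0
    t1 = f1 - i1

    interpolated = (
        values[i0, i1] * (1 - t0) * (1 - t1)
        + values[i0 + 1, i1] * t0 * (1 - t1)
        + values[i0, i1 + 1] * (1 - t0) * t1
        + values[i0 + 1, i1 + 1] * t0 * t1
    )
    return np.where(in_bounds, interpolated, fill_value)


x_axis = np.array([0.0, 1.0, 2.0])
y_axis = np.array([0.0, 0.5, 1.0, 1.5])
values = np.arange(12, dtype=float).reshape(3, 4)

# The first query sits at the centre of the cell with corner values 0, 1, 4, 5,
# so it interpolates to their mean, 2.5. The second query is out of bounds.
result = bilinear_on_regular_grid(
    np.array([[0.5, 0.25], [5.0, 5.0]]), x_axis, y_axis, values
)
```

The explicit in_bounds mask is what guarantees an exact fill_value outside the grid; the interpolation weights alone do not.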

Detailed implementation plan

Affected Repositories

  • PyAutoArray (primary — adds the interpolator)
  • PyAutoGalaxy (wires DatasetInterp to the new helper)

Work Classification

Library (both repos)

Branch Survey

Repository       Current Branch   Dirty?
./PyAutoArray    main             clean
./PyAutoGalaxy   main             clean

Suggested branch: feature/jax-interp-2d (same name on both repos)
Worktree root: ~/Code/PyAutoLabs-wt/jax-interp-2d/ (created later by /start_library)

Implementation Steps

  1. PyAutoArray: create autoarray/numerics/__init__.py and autoarray/numerics/interp_2d.py with three functions:
    • _interp_2d_numpy(points, x_axis, y_axis, values, fill_value=0.0) — wraps scipy.interpolate.RegularGridInterpolator(points=(x_axis, y_axis), values=values, bounds_error=False, fill_value=fill_value) and evaluates at points (a (N, 2) array).
    • _interp_2d_jax(points, x_axis, y_axis, values, fill_value=0.0) — convert (y, x) world coords to pixel-fractional indices via (point - axis_min) / axis_spacing; call jax.scipy.ndimage.map_coordinates(values, jnp.stack([row_coords, col_coords]), order=1, cval=fill_value, mode='constant'). Import jax locally inside the function per the xp pattern.
    • interp_2d(points, x_axis, y_axis, values, fill_value=0.0, xp=np) — dispatcher.
  2. Update autoarray/__init__.py to surface aa.numerics (so aa.numerics.interp_2d(...) resolves).
  3. PyAutoArray: add tests at test_autoarray/numerics/test_interp_2d.py:
    • Random (N, 2) query points inside the grid → numpy and JAX paths agree to rtol=1e-6.
    • Out-of-bounds query points → both paths return fill_value exactly.
    • Single-point query → result has shape (1,).
    • JAX-only tests gated by pytest.importorskip("jax").
  4. PyAutoGalaxy: refactor autogalaxy/ellipse/dataset_interp.py:
    • Keep points_interp as a cached property — that's still the regular axis tuple.
    • Replace mask_interp, data_interp, noise_map_interp cached properties with methods that take (points, xp=np) and call aa.numerics.interp_2d directly, passing the relevant values array (mask / data / noise_map).
    • The dispatch is internal — call sites pass (points) only, getting the numpy default.
  5. Smoke-check fit_ellipse.py call sites compile by running test_autogalaxy/ellipse/ — no script edits expected.
  6. Run the workspace_test scripts shipped in prompt 2 (scripts/jax_likelihood_functions/ellipse/fit.py, multipoles.py) and confirm the reference numbers are unchanged to rtol=1e-10 on the numpy path.
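
The three functions in step 1 could be sketched as follows. This is a hedged outline under stated assumptions, not the final module. The xp is np test is one possible dispatch rule and may not match the style in rectangular_spline.py. The first column of points is assumed to pair with x_axis and with axis 0 of values. The explicit in_bounds mask on the JAX path is an addition: as far as I understand JAX's map_coordinates, mode='constant' blends edge values with cval for coordinates less than one pixel outside the grid, so cval alone would not return fill_value exactly there.

```python
import numpy as np


def _interp_2d_numpy(points, x_axis, y_axis, values, fill_value=0.0):
    # SciPy is imported locally, mirroring the local JAX import below.
    from scipy.interpolate import RegularGridInterpolator

    interpolator = RegularGridInterpolator(
        points=(x_axis, y_axis),
        values=values,
        bounds_error=False,
        fill_value=fill_value,
    )
    # atleast_2d keeps a single (2,) query as one row, so the result is (1,).
    return interpolator(np.atleast_2d(points))


def _interp_2d_jax(points, x_axis, y_axis, values, fill_value=0.0):
    # JAX is imported locally so the module still imports without it.
    import jax.numpy as jnp
    from jax.scipy.ndimage import map_coordinates

    points = jnp.atleast_2d(jnp.asarray(points))
    x_axis = jnp.asarray(x_axis)
    y_axis = jnp.asarray(y_axis)
    values = jnp.asarray(values)

    # World coordinates -> fractional pixel indices (assumes regular spacing).
    idx_0 = (points[:, 0] - x_axis[0]) / (x_axis[1] - x_axis[0])
    idx_1 = (points[:, 1] - y_axis[0]) / (y_axis[1] - y_axis[0])

    interpolated = map_coordinates(
        values, [idx_0, idx_1], order=1, mode="constant", cval=fill_value
    )

    # Assumption: mask explicitly so every out-of-bounds row is exactly
    # fill_value, rather than relying on cval near the grid edge.
    in_bounds = (
        (idx_0 >= 0)
        & (idx_0 <= values.shape[0] - 1)
        & (idx_1 >= 0)
        & (idx_1 <= values.shape[1] - 1)
    )
    return jnp.where(in_bounds, interpolated, fill_value)


def interp_2d(points, x_axis, y_axis, values, fill_value=0.0, xp=np):
    # Assumed dispatch rule: NumPy module -> SciPy path, anything else -> JAX.
    if xp is np:
        return _interp_2d_numpy(points, x_axis, y_axis, values, fill_value)
    return _interp_2d_jax(points, x_axis, y_axis, values, fill_value)
```

JAX computes in float32 unless 64-bit mode is enabled with jax.config.update("jax_enable_x64", True), so the rtol=1e-6 parity target may be sensitive to that setting.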

Key Files

  • PyAutoArray/autoarray/numerics/__init__.py — new.
  • PyAutoArray/autoarray/numerics/interp_2d.py — new.
  • PyAutoArray/autoarray/__init__.py — surface numerics.
  • PyAutoArray/test_autoarray/numerics/__init__.py — new.
  • PyAutoArray/test_autoarray/numerics/test_interp_2d.py — new.
  • PyAutoGalaxy/autogalaxy/ellipse/dataset_interp.py — refactor three properties into methods.
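
The refactor of the three properties into methods might look roughly like the sketch below. Everything in it is illustrative: DatasetInterpSketch, its constructor arguments and the _interp helper are hypothetical, and the local interp_2d is a NumPy-only stand-in for aa.numerics.interp_2d. The real DatasetInterp derives its axes and arrays from a dataset, which is not reproduced here.

```python
from functools import cached_property

import numpy as np


def interp_2d(points, x_axis, y_axis, values, fill_value=0.0, xp=np):
    """Stand-in for aa.numerics.interp_2d; only the NumPy path is sketched."""
    if xp is not np:
        raise NotImplementedError("JAX path omitted from this sketch")
    from scipy.interpolate import RegularGridInterpolator

    interpolator = RegularGridInterpolator(
        points=(x_axis, y_axis),
        values=values,
        bounds_error=False,
        fill_value=fill_value,
    )
    return interpolator(np.atleast_2d(points))


class DatasetInterpSketch:
    """Hypothetical stand-in for DatasetInterp (constructor is assumed)."""

    def __init__(self, data, noise_map, mask, x_axis, y_axis):
        self.data = np.asarray(data, dtype=float)
        self.noise_map = np.asarray(noise_map, dtype=float)
        # The mask is cast to float so it can be interpolated like the data.
        self.mask = np.asarray(mask, dtype=float)
        self._x_axis = np.asarray(x_axis, dtype=float)
        self._y_axis = np.asarray(y_axis, dtype=float)

    @cached_property
    def points_interp(self):
        # Still cached: the regular axis tuple does not depend on the query.
        return (self._x_axis, self._y_axis)

    def _interp(self, values, points, xp):
        x_axis, y_axis = self.points_interp
        return interp_2d(points, x_axis, y_axis, values, fill_value=0.0, xp=xp)

    # Methods instead of cached interpolator objects; xp defaults to NumPy,
    # so call sites of the form interp.data_interp(points) keep working.
    def data_interp(self, points, xp=np):
        return self._interp(self.data, points, xp)

    def noise_map_interp(self, points, xp=np):
        return self._interp(self.noise_map, points, xp)

    def mask_interp(self, points, xp=np):
        return self._interp(self.mask, points, xp)
```

Because the call syntax interp.data_interp(points) is the same for a callable interpolator object and for a method, existing call sites that pass only points should not need editing.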

Testing Approach

  • python -m pytest test_autoarray/numerics/test_interp_2d.py -v — new tests pass.
  • python -m pytest test_autogalaxy/ellipse/ -v — all existing tests still pass (including the masked-loop tests from prompt 3, which pin points_from_major_axis reference arrays to rtol=1e-12; if those drift, the numpy path semantics have changed and we have a bug).
  • python scripts/jax_likelihood_functions/ellipse/fit.py and multipoles.py from autogalaxy_workspace_test/ — reference numbers unchanged to rtol=1e-10.

Original Prompt

Step 4 of the ellipse-JAX series. DatasetInterp in @PyAutoGalaxy/autogalaxy/ellipse/dataset_interp.py uses scipy.interpolate.RegularGridInterpolator for the data, noise-map, and mask. scipy is numpy-only, so this is the first hard JAX blocker for AnalysisEllipse.log_likelihood_function. There is no JAX-compatible 2D interpolator anywhere in the codebase — the only precedent is the 1D _interp1d_jax in @PyAutoArray/autoarray/inversion/mesh/interpolator/rectangular_spline.py:86-106. We need a 2D analogue.

Please:

  1. Add a 2D regular-grid bilinear interpolator helper to PyAutoArray. Suggested location: @PyAutoArray/autoarray/numerics/interp_2d.py (create the numerics/ subpackage if it doesn't already exist; check @PyAutoArray/autoarray/__init__.py for the right import surface). Two paths:

    • _interp_2d_numpy(points, x_axis, y_axis, values, fill_value=0.0) — matches the current RegularGridInterpolator(bounds_error=False, fill_value=0.0) semantics in dataset_interp.py. A direct call to scipy.interpolate.RegularGridInterpolator is fine here.
    • _interp_2d_jax(points, x_axis, y_axis, values, fill_value=0.0) — uses jax.scipy.ndimage.map_coordinates(values, coords, order=1, cval=fill_value). Translate (y, x) world coordinates to pixel-fractional coordinates using x_axis, y_axis (assume regularly spaced — the existing scipy points_interp is built from mask.derive_grid.all_false, which is regular).
    • Public dispatcher interp_2d(points, x_axis, y_axis, values, fill_value=0.0, xp=np) that picks the path. Mirror the dispatch style in rectangular_spline.py.
  2. Unit tests in @PyAutoArray/test_autoarray/numerics/test_interp_2d.py:

    • Random (N, 2) query points inside the grid: assert numpy and JAX paths agree to rtol=1e-6.
    • Out-of-bounds query points: assert both paths return fill_value for those rows.
    • Single-point query: assert shape is (1,) not ().
    • xp=np is the default — JAX-only tests gated by pytest.importorskip("jax") per @PyAutoArray/CLAUDE.md testing conventions.
  3. Wire DatasetInterp to the new helper. In @PyAutoGalaxy/autogalaxy/ellipse/dataset_interp.py:

    • Drop the cached data_interp, noise_map_interp, mask_interp properties that return RegularGridInterpolator instances.
    • Replace with methods data_interp(points, xp=np), noise_map_interp(points, xp=np), mask_interp(points, xp=np) that call aa.numerics.interp_2d(...) directly. The interp axes (points_interp) can stay cached.
    • Keep the existing call sites in fit_ellipse.py working: self.interp.data_interp(self._points_from_major_axis) continues to take a (N, 2) array.
  4. Do not touch FitEllipse.points_from_major_axis_from's 300-iteration loop in this prompt. That's prompt 6. The mask interpolation calls inside the loop continue to use the numpy path because the surrounding code is still numpy-only — pass xp=np explicitly at those call sites.

  5. Test bar:

    • python -m pytest test_autoarray/numerics/test_interp_2d.py -v passes.
    • python -m pytest test_autogalaxy/ellipse/ -v still passes (no behavioural change on the numpy path — same fill_value=0.0, same regular-grid semantics).
    • The reference numbers from prompt 2's workspace_test scripts are unchanged to rtol=1e-10.
