Overview
Add a 2D regular-grid bilinear interpolator helper to PyAutoArray with both NumPy and JAX paths, then wire DatasetInterp in PyAutoGalaxy to use it. Step 4 of 7 in the ellipse_fitting_jax feature decomposition (PyAutoPrompt/z_features/ellipse_fitting_jax.md). DatasetInterp today uses scipy.interpolate.RegularGridInterpolator, which is the first hard JAX blocker for AnalysisEllipse.log_likelihood_function — replacing it with an xp-dispatched helper unblocks prompts 6 and 7.
Plan
- Add a new autoarray/numerics/ subpackage with interp_2d.py containing two private paths (_interp_2d_numpy, _interp_2d_jax) and one public dispatcher interp_2d(points, x_axis, y_axis, values, fill_value=0.0, xp=np).
- NumPy path delegates to scipy.interpolate.RegularGridInterpolator(bounds_error=False, fill_value=0.0), preserving current semantics byte-for-byte.
- JAX path uses jax.scipy.ndimage.map_coordinates(values, coords, order=1, cval=fill_value) with manual world→pixel-fractional coordinate conversion, followed by an explicit in-bounds mask so that out-of-bounds points return fill_value exactly, as scipy does. Assumes regularly spaced axes (true for grids derived from mask.derive_grid.all_false).
- Surface the dispatcher via aa.numerics.interp_2d. Mirror the 1D dispatch style of _interp1d_jax / _interp1d_numpy in autoarray/inversion/mesh/interpolator/rectangular_spline.py.
- Unit tests cover numpy↔JAX parity on random in-bounds points (rtol=1e-6), out-of-bounds queries returning fill_value, and single-point queries returning shape (1,); JAX-only tests are gated by pytest.importorskip("jax").
- In PyAutoGalaxy, refactor DatasetInterp.{data,noise_map,mask}_interp from cached properties returning RegularGridInterpolator instances into methods (points, xp=np) that call aa.numerics.interp_2d. All call sites in fit_ellipse.py continue to work: they don't pass xp, so they default to the numpy path.
- Behaviour on the numpy path is unchanged. The if self.interp.mask_interp is not None: check in fit_ellipse.py stays as-is. It is always True both before and after the refactor (afterwards mask_interp is a bound method, which is never None); changing the gate is left to prompt 6.
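One subtlety in matching scipy's fill behaviour: with order=1 and mode='constant', jax.scipy.ndimage.map_coordinates appears to treat each of the two neighbouring samples independently, substituting cval only for the neighbour that falls off the grid. A query up to one pixel outside the grid therefore returns a blend of cval and the edge value rather than cval itself. The NumPy sketch below emulates that per-neighbour rule in 1D (our reading of JAX's implementation, worth re-checking against the pinned JAX version) next to the hard cutoff that RegularGridInterpolator(bounds_error=False) applies. Both function names are hypothetical.

```python
import numpy as np


def linear_constant_mode_1d(values, coord, cval=0.0):
    """Emulate order=1, mode='constant' sampling at fractional pixel indices.

    Each of the two neighbours contributes its own sample if it lies on the grid
    and cval otherwise, so a point just outside the grid is blended with cval.
    """
    size = values.shape[0]
    lower = np.floor(coord).astype(int)
    upper_weight = coord - lower
    total = np.zeros(coord.shape)
    for index, weight in ((lower, 1.0 - upper_weight), (lower + 1, upper_weight)):
        valid = (index >= 0) & (index < size)
        sample = np.where(valid, values[np.clip(index, 0, size - 1)], cval)
        total = total + weight * sample
    return total


def hard_fill_1d(values, coord, fill_value=0.0):
    """Emulate RegularGridInterpolator(bounds_error=False): any index outside
    [0, size - 1] gets fill_value outright, with no blending."""
    size = values.shape[0]
    in_bounds = (coord >= 0) & (coord <= size - 1)
    blended = linear_constant_mode_1d(values, coord, cval=fill_value)
    return np.where(in_bounds, blended, fill_value)


values = np.array([4.0, 4.0, 4.0])
coords = np.array([-0.5, 0.0, 1.5, 2.5])
print(linear_constant_mode_1d(values, coords))  # [2. 4. 4. 2.]
print(hard_fill_1d(values, coords))             # [0. 4. 4. 0.]
```

The np.where(in_bounds, ..., fill_value) step in hard_fill_1d corresponds to the extra masking a JAX path needs if out-of-bounds rows must equal fill_value exactly, as the planned tests require.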
Detailed implementation plan
Affected Repositories
- PyAutoArray (primary: adds the interpolator)
- PyAutoGalaxy (wires DatasetInterp to the new helper)
Work Classification
Library (both repos)
Branch Survey
| Repository | Current Branch | Dirty? |
|---|---|---|
| ./PyAutoArray | main | clean |
| ./PyAutoGalaxy | main | clean |
Suggested branch: feature/jax-interp-2d (same name on both repos)
Worktree root: ~/Code/PyAutoLabs-wt/jax-interp-2d/ (created later by /start_library)
Implementation Steps
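- PyAutoArray: create autoarray/numerics/__init__.py and autoarray/numerics/interp_2d.py with three functions:
  - _interp_2d_numpy(points, x_axis, y_axis, values, fill_value=0.0): wraps scipy.interpolate.RegularGridInterpolator(points=(x_axis, y_axis), values=values, bounds_error=False, fill_value=fill_value) and evaluates it at points (an (N, 2) array).
  - _interp_2d_jax(points, x_axis, y_axis, values, fill_value=0.0): convert world coordinates to pixel-fractional indices via (coord - axis[0]) / (axis[1] - axis[0]). Using the first sample and the signed spacing is correct for descending axes as well as ascending ones; a plain (point - axis_min) / axis_spacing would mis-map a descending axis such as PyAutoArray's y axis, which typically decreases down the rows. Column k of points must index axis k of values, in the same order the numpy path passes the axes to RegularGridInterpolator. Call jax.scipy.ndimage.map_coordinates(values, jnp.stack([coords_0, coords_1]), order=1, cval=fill_value, mode='constant'), then apply jnp.where(in_bounds, result, fill_value): in constant mode, map_coordinates blends cval with edge samples for points up to one pixel outside the grid instead of returning cval outright. Import jax locally inside the function per the xp pattern.
  - interp_2d(points, x_axis, y_axis, values, fill_value=0.0, xp=np): dispatcher.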
- Update autoarray/__init__.py to surface aa.numerics (so that aa.numerics.interp_2d(...) resolves).
- PyAutoArray: add tests at test_autoarray/numerics/test_interp_2d.py:
  - Random (N, 2) query points inside the grid → numpy and JAX paths agree to rtol=1e-6. This assumes jax_enable_x64 is on for the JAX path; under JAX's float32 default, rtol=1e-6 is too tight to rely on.
  - Out-of-bounds query points → both paths return fill_value exactly.
  - Single-point query → result has shape (1,).
  - JAX-only tests gated by pytest.importorskip("jax").
- PyAutoGalaxy: refactor autogalaxy/ellipse/dataset_interp.py:
  - Keep points_interp as a cached property; it is still the regular axis tuple.
  - Replace the mask_interp, data_interp and noise_map_interp cached properties with methods that take (points, xp=np) and call aa.numerics.interp_2d directly, passing the relevant values array (mask / data / noise_map).
  - The dispatch is internal: call sites pass (points) only and get the numpy default.
- Smoke-check that the fit_ellipse.py call sites still run by executing test_autogalaxy/ellipse/; no script edits are expected.
- Run the workspace_test scripts shipped in prompt 2 (scripts/jax_likelihood_functions/ellipse/fit.py, multipoles.py) and confirm the reference numbers are unchanged to rtol=1e-10 on the numpy path.
Key Files
PyAutoArray/autoarray/numerics/__init__.py — new.
PyAutoArray/autoarray/numerics/interp_2d.py — new.
PyAutoArray/autoarray/__init__.py — surface numerics.
PyAutoArray/test_autoarray/numerics/__init__.py — new.
PyAutoArray/test_autoarray/numerics/test_interp_2d.py — new.
PyAutoGalaxy/autogalaxy/ellipse/dataset_interp.py — refactor three properties into methods.
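A minimal sketch of the refactored class follows, assuming a numpy-only stand-in for aa.numerics.interp_2d and a hypothetical constructor that takes the axes and arrays directly (the real DatasetInterp builds them from the dataset). The private _interp helper is likewise hypothetical. It shows the shape of the change: the axis tuple stays a cached property, while the three interpolators become methods with an xp argument.

```python
from functools import cached_property

import numpy as np
from scipy.interpolate import RegularGridInterpolator


def interp_2d(points, x_axis, y_axis, values, fill_value=0.0, xp=np):
    """Numpy-only stand-in for aa.numerics.interp_2d, for illustration."""
    if xp is not np:
        raise NotImplementedError("this stand-in only implements the numpy path")
    interpolator = RegularGridInterpolator(
        (x_axis, y_axis), values, bounds_error=False, fill_value=fill_value
    )
    return interpolator(points)


class DatasetInterp:
    def __init__(self, x_axis, y_axis, data, noise_map, mask):
        # Hypothetical constructor: the real class derives the axes and arrays
        # from the dataset instead of taking them directly.
        self._x_axis = np.asarray(x_axis, dtype=float)
        self._y_axis = np.asarray(y_axis, dtype=float)
        self.data = np.asarray(data, dtype=float)
        self.noise_map = np.asarray(noise_map, dtype=float)
        # Store the boolean mask as floats so it can be interpolated.
        self.mask = np.asarray(mask, dtype=float)

    @cached_property
    def points_interp(self):
        # Still cached: the regular axis tuple does not depend on the query points.
        return (self._x_axis, self._y_axis)

    def _interp(self, values, points, xp):
        x_axis, y_axis = self.points_interp
        return interp_2d(points, x_axis, y_axis, values, fill_value=0.0, xp=xp)

    def data_interp(self, points, xp=np):
        return self._interp(self.data, points, xp)

    def noise_map_interp(self, points, xp=np):
        return self._interp(self.noise_map, points, xp)

    def mask_interp(self, points, xp=np):
        return self._interp(self.mask, points, xp)


interp = DatasetInterp(
    x_axis=[0.0, 1.0, 2.0],
    y_axis=[0.0, 1.0],
    data=np.arange(6.0).reshape(3, 2),
    noise_map=np.ones((3, 2)),
    mask=np.zeros((3, 2), dtype=bool),
)
print(interp.data_interp(np.array([[0.5, 0.5]])))  # [1.5]
```

Existing call sites such as self.interp.data_interp(self._points_from_major_axis) keep the same syntax, because calling the new method looks the same as calling the old cached interpolator. One side effect: mask_interp is now a bound method, so self.interp.mask_interp is not None is always true.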
Testing Approach
python -m pytest test_autoarray/numerics/test_interp_2d.py -v — new tests pass.
python -m pytest test_autogalaxy/ellipse/ -v — all existing tests still pass (including the masked-loop tests from prompt 3, which pin points_from_major_axis reference arrays to rtol=1e-12; if those drift, the numpy path semantics have changed and we have a bug).
python scripts/jax_likelihood_functions/ellipse/fit.py and multipoles.py from autogalaxy_workspace_test/ — reference numbers unchanged to rtol=1e-10.
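Beyond numpy↔JAX parity, the new tests could also check both paths against an exact answer: bilinear interpolation reproduces any affine field a*x + b*y + c exactly, so no reference numbers need to be stored. The fixture below is a hypothetical helper; its name, its defaults and the descending second axis are assumptions, the last chosen to mimic a y axis that decreases down the rows.

```python
import numpy as np


def affine_grid(shape=(7, 5), coeffs=(1.5, -2.0, 0.25), seed=0, n_points=50):
    """Hypothetical fixture: regular grid sampling a*x + b*y + c.

    Axis 0 ascends and axis 1 descends, so a test exercises both spacing signs.
    Returns the axes, the sampled values, random in-bounds query points and the
    exact answer at those points.
    """
    a, b, c = coeffs
    axis_0 = np.linspace(-1.0, 1.0, shape[0])
    axis_1 = np.linspace(2.0, -2.0, shape[1])
    grid_0, grid_1 = np.meshgrid(axis_0, axis_1, indexing="ij")
    values = a * grid_0 + b * grid_1 + c

    rng = np.random.default_rng(seed)
    points = np.column_stack(
        [
            rng.uniform(axis_0.min(), axis_0.max(), n_points),
            rng.uniform(axis_1.min(), axis_1.max(), n_points),
        ]
    )
    truth = a * points[:, 0] + b * points[:, 1] + c
    return axis_0, axis_1, values, points, truth


def assert_matches_truth(interp_fn, rtol=1e-10, atol=1e-12):
    """Check any interp_2d-like callable against the analytic affine answer."""
    axis_0, axis_1, values, points, truth = affine_grid()
    result = np.asarray(interp_fn(points, axis_0, axis_1, values))
    # atol guards rows where the affine truth is close to zero.
    np.testing.assert_allclose(result, truth, rtol=rtol, atol=atol)
```

In test_interp_2d.py this could be called as assert_matches_truth(interp_2d) for the numpy default and, inside a test guarded by pytest.importorskip("jax"), as assert_matches_truth(lambda *args: interp_2d(*args, xp=jnp)). The default rtol=1e-10 assumes float64 on both paths.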
Original Prompt
Step 4 of the ellipse-JAX series. DatasetInterp in @PyAutoGalaxy/autogalaxy/ellipse/dataset_interp.py uses scipy.interpolate.RegularGridInterpolator for the data, noise-map, and mask. scipy is numpy-only, so this is the first hard JAX blocker for AnalysisEllipse.log_likelihood_function. There is no JAX-compatible 2D interpolator anywhere in the codebase — the only precedent is the 1D _interp1d_jax in @PyAutoArray/autoarray/inversion/mesh/interpolator/rectangular_spline.py:86-106. We need a 2D analogue.
Please:
- Add a 2D regular-grid bilinear interpolator helper to PyAutoArray. Suggested location: @PyAutoArray/autoarray/numerics/interp_2d.py (create the numerics/ subpackage if it doesn't already exist; check @PyAutoArray/autoarray/__init__.py for the right import surface). Two paths:
  - _interp_2d_numpy(points, x_axis, y_axis, values, fill_value=0.0): matches the current RegularGridInterpolator(bounds_error=False, fill_value=0.0) semantics in dataset_interp.py. A direct call to scipy.interpolate.RegularGridInterpolator is fine here.
  - _interp_2d_jax(points, x_axis, y_axis, values, fill_value=0.0): uses jax.scipy.ndimage.map_coordinates(values, coords, order=1, cval=fill_value). Translate (y, x) world coordinates to pixel-fractional coordinates using x_axis, y_axis (assume regularly spaced: the existing scipy points_interp is built from mask.derive_grid.all_false, which is regular).
- Public dispatcher interp_2d(points, x_axis, y_axis, values, fill_value=0.0, xp=np) that picks the path. Mirror the dispatch style in rectangular_spline.py.
- Unit tests in @PyAutoArray/test_autoarray/numerics/test_interp_2d.py:
  - Random (N, 2) query points inside the grid: assert numpy and JAX paths agree to rtol=1e-6.
  - Out-of-bounds query points: assert both paths return fill_value for those rows.
  - Single-point query: assert shape is (1,) not ().
  - xp=np is the default; JAX-only tests are gated by pytest.importorskip("jax") per @PyAutoArray/CLAUDE.md testing conventions.
- Wire DatasetInterp to the new helper. In @PyAutoGalaxy/autogalaxy/ellipse/dataset_interp.py:
  - Drop the cached data_interp, noise_map_interp, mask_interp properties that return RegularGridInterpolator instances.
  - Replace them with methods data_interp(points, xp=np), noise_map_interp(points, xp=np), mask_interp(points, xp=np) that call aa.numerics.interp_2d(...) directly. The interp axes (points_interp) can stay cached.
  - Keep the existing call sites in fit_ellipse.py working: self.interp.data_interp(self._points_from_major_axis) continues to take an (N, 2) array.
- Do not touch FitEllipse.points_from_major_axis_from's 300-iteration loop in this prompt; that's prompt 6. The mask interpolation calls inside the loop continue to use the numpy path because the surrounding code is still numpy-only, so pass xp=np explicitly at those call sites.
- Test bar:
  - python -m pytest test_autoarray/numerics/test_interp_2d.py -v passes.
  - python -m pytest test_autogalaxy/ellipse/ -v still passes (no behavioural change on the numpy path: same fill_value=0.0, same regular-grid semantics).
  - The reference numbers from prompt 2's workspace_test scripts are unchanged to rtol=1e-10.