From cfe77646100388187d23a6fbe1306dab48fcd6d1 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 9 Feb 2026 16:48:28 -0600 Subject: [PATCH] feat: add benchmark for gradient w/ moffat --- tests/jax/test_benchmarks.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/jax/test_benchmarks.py b/tests/jax/test_benchmarks.py index 5bf7b1f5..02a1dce2 100644 --- a/tests/jax/test_benchmarks.py +++ b/tests/jax/test_benchmarks.py @@ -336,3 +336,27 @@ def test_benchmark_moffat_conv(benchmark, kind): benchmark, kind, lambda: _run_moffat_bench_conv_jit().block_until_ready() ) print(f"time: {dt:0.4g} ms", end=" ") + + +def _run_moffat_bench_conv_grad(scale_radius): + obj = jgs.Spergel(nu=-0.6, scale_radius=scale_radius) + psf = jgs.Moffat(beta=2.5, fwhm=0.9) + obj = jgs.Convolve( + [obj, psf], + gsparams=jgs.GSParams(minimum_fft_size=2048, maximum_fft_size=2048), + ) + return jnp.sum(obj.drawImage(nx=50, ny=50, scale=0.2).array ** 2) + + +_run_moffat_bench_conv_grad_jit = jax.jit(jax.grad(_run_moffat_bench_conv_grad)) + + +@pytest.mark.parametrize("kind", ["run"]) +def test_benchmark_moffat_conv_grad(benchmark, kind): + scale_radius = jnp.array(0.5) + dt = _run_benchmarks( + benchmark, + kind, + lambda: _run_moffat_bench_conv_grad_jit(scale_radius).block_until_ready(), + ) + print(f"time: {dt:0.4g} ms", end=" ")