diff --git a/LICENSE b/LICENSE
index 344c4a4e..61250865 100644
--- a/LICENSE
+++ b/LICENSE
@@ -22,6 +22,24 @@
 lost profits, business interruption, or indirect special or consequential
 damages of any kind.
 
+Code in bessel.py (fractional-order Bessel K functions) is derived from
+TensorFlow Probability under this license:
+
+Copyright 2020 The TensorFlow Probability Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+
 Code in angle.py and celestial.py is based on LSSTDESC/Coord under this license:
 
 Copyright (c) 2013-2017 LSST Dark Energy Science Collaboration (DESC)
diff --git a/jax_galsim/bessel.py b/jax_galsim/bessel.py
index 0471dcc3..87d09291 100644
--- a/jax_galsim/bessel.py
+++ b/jax_galsim/bessel.py
@@ -1,7 +1,6 @@
 import galsim as _galsim
 import jax
 import jax.numpy as jnp
-from tensorflow_probability.substrates.jax.math import bessel_kve as _tfp_bessel_kve
 
 from jax_galsim.core.utils import implements
 
@@ -105,12 +104,490 @@ def si(x):
     )
 
 
+# =====================================================================
+# Modified Bessel K_v(x)
+#
+# The functions below (_sqrt1px2, _evaluate_temme_coeffs,
+# _temme_series_kve, _continued_fraction_kve, _olver_kve, _temme_kve,
+# _kve_core) and the _ASYMPTOTIC_OLVER_COEFFICIENTS constant are
+# derived from TensorFlow Probability's bessel.py:
+#   https://github.com/tensorflow/probability
+#
+# Original copyright and license:
+#   Copyright 2020 The TensorFlow Probability Authors.
+#   Licensed under the Apache License, Version 2.0.
+#
+# Modifications from the original:
+#   - Ported from TensorFlow/TFP APIs to pure JAX (jax.numpy, jax.lax)
+#   - Removed I_v (bessel_ive) computation; only K_v is computed
+#   - Removed log-space output option
+#   - Removed negative-v correction for I_v
+#   - Simplified to scalar-only core (vectorization via jax.vmap)
+# =====================================================================
+
+# Olver expansion polynomial coefficients (10 terms, up to 31 coefficients each)
+# fmt: off
+_ASYMPTOTIC_OLVER_COEFFICIENTS = [
+    [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
+     0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
+     -0.20833333333333334, 0., 0.125, 0.],
+    [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
+     0., 0., 0., 0., 0., 0., 0., 0., 0., 0.3342013888888889, 0.,
+     -0.40104166666666669, 0., 0.0703125, 0., 0.0],
+    [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
+     0., 0., 0., 0., 0., 0., -1.0258125964506173, 0., 1.8464626736111112,
+     0., -0.89121093750000002, 0., 0.0732421875, 0., 0., 0.],
+    [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
+     0., 0., 0., 4.6695844234262474, 0., -11.207002616222995, 0.,
+     8.78912353515625, 0., -2.3640869140624998, 0., 0.112152099609375,
+     0., 0., 0., 0.],
+    [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
+     -28.212072558200244, 0., 84.636217674600744, 0., -91.818241543240035,
+     0., 42.534998745388457, 0., -7.3687943594796312, 0., 0.22710800170898438,
+     0., 0., 0., 0., 0.],
+    [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 212.5701300392171, 0.,
+     -765.25246814118157, 0., 1059.9904525279999, 0., -699.57962737613275,
+     0., 218.19051174421159, 0., -26.491430486951554, 0., 0.57250142097473145,
+     0., 0., 0., 0., 0., 0.],
+    [0., 0., 0., 0., 0., 0., 0., 0., 0., -1919.4576623184068, 0.,
+     8061.7221817373083, 0., -13586.550006434136, 0., 11655.393336864536,
+     0., -5305.6469786134048, 0., 1200.9029132163525, 0.,
+     -108.09091978839464, 0., 1.7277275025844574, 0., 0., 0., 0., 0., 0., 0.],
+    [0., 0., 0., 0., 0., 0., 20204.291330966149, 0., -96980.598388637503, 0.,
+     192547.0012325315, 0., -203400.17728041555, 0., 122200.46498301747,
+     0., -41192.654968897557, 0., 7109.5143024893641, 0.,
+     -493.915304773088, 0., 6.074042001273483, 0., 0., 0., 0., 0.,
+     0., 0., 0.],
+    [0., 0., 0., -242919.18790055133, 0., 1311763.6146629769, 0.,
+     -2998015.9185381061, 0., 3763271.2976564039, 0., -2813563.2265865342, 0.,
+     1268365.2733216248, 0., -331645.17248456361, 0., 45218.768981362737, 0.,
+     -2499.8304818112092, 0., 24.380529699556064, 0., 0., 0., 0., 0.,
+     0., 0., 0., 0.0],
+    [3284469.8530720375, 0., -19706819.11843222, 0., 50952602.492664628,
+     0., -74105148.211532637, 0., 66344512.274729028, 0., -37567176.660763353,
+     0., 13288767.166421819, 0., -2785618.1280864552, 0., 308186.40461266245,
+     0., -13886.089753717039, 0., 110.01714026924674, 0., 0., 0., 0., 0.,
+     0., 0., 0., 0., 0.]
+]
+# fmt: on
+
+
+def _sqrt1px2(x):
+    """Numerically stable computation of sqrt(1 + x^2)."""
+    eps = jnp.finfo(jnp.float64).eps
+    return jnp.where(
+        jnp.abs(x) * jnp.sqrt(eps) <= 1.0,
+        jnp.exp(0.5 * jnp.log1p(x * x)),
+        jnp.abs(x),
+    )
+
+
+def _evaluate_temme_coeffs(v):
+    """Numerically stable computation of gamma-related coefficients for Temme's method.
+
+    Computes:
+        coeff1 = (1/Gamma(1-v) - 1/Gamma(1+v)) / (2v)
+        coeff2 = (1/Gamma(1-v) + 1/Gamma(1+v)) / 2
+        gamma1pv = 1/Gamma(1+v)
+        gamma1mv = 1/Gamma(1-v)
+
+    Uses Chebyshev expansions for numerical stability (avoids catastrophic cancellation).
+    """
+    coeff1_coeffs = [
+        -1.142022680371168e0,
+        6.5165112670737e-3,
+        3.087090173086e-4,
+        -3.4706269649e-6,
+        6.9437664e-9,
+        3.67795e-11,
+        -1.356e-13,
+    ]
+    coeff2_coeffs = [
+        1.843740587300905e0,
+        -7.68528408447867e-2,
+        1.2719271366546e-3,
+        -4.9717367042e-6,
+        -3.31261198e-8,
+        2.423096e-10,
+        -1.702e-13,
+        -1.49e-15,
+    ]
+    w = 8.0 * v * v - 1.0
+    y = 2.0 * w
+
+    # Clenshaw's recurrence for coeff1
+    prev = 0.0
+    current = 0.0
+    for i in reversed(range(1, len(coeff1_coeffs))):
+        temp = current
+        current = y * current - prev + coeff1_coeffs[i]
+        prev = temp
+    coeff1 = w * current - prev + 0.5 * coeff1_coeffs[0]
+
+    # Clenshaw's recurrence for coeff2
+    prev = 0.0
+    current = 0.0
+    for i in reversed(range(1, len(coeff2_coeffs))):
+        temp = current
+        current = y * current - prev + coeff2_coeffs[i]
+        prev = temp
+    coeff2 = w * current - prev + 0.5 * coeff2_coeffs[0]
+
+    gamma1pv = coeff2 - v * coeff1
+    gamma1mv = coeff2 + v * coeff1
+    return coeff1, coeff2, gamma1pv, gamma1mv
+
+
+def _temme_series_kve(v, z):
+    """Compute Kve(v, z) and Kve(v+1, z) via Temme power series.
+
+    Assumes |v| < 0.5 and |z| <= 2 for fast convergence.
+    Returns exponentially scaled values: Kv(v,z)*exp(z).
+    """
+    tol = jnp.finfo(jnp.float64).eps
+
+    coeff1, coeff2, gamma1pv_inv, gamma1mv_inv = _evaluate_temme_coeffs(v)
+
+    z_sq = z * z
+    logzo2 = jnp.log(z / 2.0)
+    mu = -v * logzo2
+    sinc_v = jnp.where(v == 0.0, 1.0, jnp.sin(jnp.pi * v) / (jnp.pi * v))
+    sinhc_mu = jnp.where(mu == 0.0, 1.0, jnp.sinh(mu) / mu)
+
+    initial_f = (coeff1 * jnp.cosh(mu) + coeff2 * (-logzo2) * sinhc_mu) / sinc_v
+    initial_p = 0.5 * jnp.exp(mu) / gamma1pv_inv
+    initial_q = 0.5 * jnp.exp(-mu) / gamma1mv_inv
+
+    max_iterations = 1000
+
+    def body_fn(carry):
+        should_stop, index, f, p, q, coeff, kv_sum, kvp1_sum = carry
+        f = jnp.where(
+            should_stop,
+            f,
+            (index * f + p + q) / (index * index - v * v),
+        )
+        p = jnp.where(should_stop, p, p / (index - v))
+        q = jnp.where(should_stop, q, q / (index + v))
+        h = p - index * f
+        coeff = jnp.where(should_stop, coeff, coeff * z_sq / (4.0 * index))
+        kv_sum = jnp.where(should_stop, kv_sum, kv_sum + coeff * f)
+        kvp1_sum = jnp.where(should_stop, kvp1_sum, kvp1_sum + coeff * h)
+        index = index + 1.0
+        should_stop = (jnp.abs(coeff * f) < jnp.abs(kv_sum) * tol) | (
+            index > max_iterations
+        )
+        return (should_stop, index, f, p, q, coeff, kv_sum, kvp1_sum)
+
+    def cond_fn(carry):
+        should_stop = carry[0]
+        return ~should_stop
+
+    init = (
+        jnp.array(False),
+        1.0,
+        initial_f,
+        initial_p,
+        initial_q,
+        1.0,
+        initial_f,
+        initial_p,
+    )
+    _, _, _, _, _, _, kv_sum, kvp1_sum = jax.lax.while_loop(cond_fn, body_fn, init)
+
+    # Convert to exponentially scaled: kve = kv * exp(z)
+    kve = kv_sum * jnp.exp(z)
+    kvep1 = 2.0 * kvp1_sum * jnp.exp(z) / z
+
+    return kve, kvep1
+
+
+def _continued_fraction_kve(v, z):
+    """Compute Kve(v, z) and Kve(v+1, z) via Steed's continued fraction.
+
+    Assumes |v| < 0.5 and |z| > 2.
+    Returns exponentially scaled values: Kv(v,z)*exp(z).
+    """
+    tol = jnp.finfo(jnp.float64).eps
+    max_iterations = 1000
+
+    initial_numerator = v * v - 0.25
+    initial_denominator = 2.0 * (z + 1.0)
+    initial_ratio = 1.0 / initial_denominator
+    initial_seq = -initial_numerator
+
+    def steeds_body(carry):
+        (
+            should_stop,
+            index,
+            partial_numerator,
+            partial_denominator,
+            denominator_ratio,
+            convergent_difference,
+            hypergeometric_ratio,
+            k_0,
+            k_1,
+            c,
+            q,
+            hypergeometric_sum,
+        ) = carry
+
+        partial_numerator = partial_numerator - 2.0 * (index - 1.0)
+        c = jnp.where(should_stop, c, -c * partial_numerator / index)
+        next_k = (k_0 - partial_denominator * k_1) / partial_numerator
+        k_0 = jnp.where(should_stop, k_0, k_1)
+        k_1 = jnp.where(should_stop, k_1, next_k)
+        q = jnp.where(should_stop, q, q + c * next_k)
+        partial_denominator = partial_denominator + 2.0
+        denominator_ratio = 1.0 / (
+            partial_denominator + partial_numerator * denominator_ratio
+        )
+        convergent_difference = jnp.where(
+            should_stop,
+            convergent_difference,
+            convergent_difference * (partial_denominator * denominator_ratio - 1.0),
+        )
+        hypergeometric_ratio = jnp.where(
+            should_stop,
+            hypergeometric_ratio,
+            hypergeometric_ratio + convergent_difference,
+        )
+        hypergeometric_sum = jnp.where(
+            should_stop,
+            hypergeometric_sum,
+            hypergeometric_sum + q * convergent_difference,
+        )
+        index = index + 1.0
+        should_stop = (
+            jnp.abs(q * convergent_difference) < jnp.abs(hypergeometric_sum) * tol
+        ) | (index > max_iterations)
+        return (
+            should_stop,
+            index,
+            partial_numerator,
+            partial_denominator,
+            denominator_ratio,
+            convergent_difference,
+            hypergeometric_ratio,
+            k_0,
+            k_1,
+            c,
+            q,
+            hypergeometric_sum,
+        )
+
+    def cond_fn(carry):
+        return ~carry[0]
+
+    init = (
+        jnp.array(False),
+        2.0,
+        initial_numerator,
+        initial_denominator,
+        initial_ratio,
+        initial_ratio,
+        initial_ratio,
+        0.0,
+        1.0,
+        initial_seq,
+        initial_seq,
+        1.0 - initial_numerator * initial_ratio,
+    )
+    result = jax.lax.while_loop(cond_fn, steeds_body, init)
+    hypergeometric_ratio = result[6]
+    hypergeometric_sum = result[11]
+
+    # log(kve) = 0.5*log(pi/(2z)) - log(hypergeometric_sum)
+    log_kve = 0.5 * jnp.log(jnp.pi / (2.0 * z)) - jnp.log(hypergeometric_sum)
+    log_kvp1e = (
+        log_kve
+        + jnp.log1p(2.0 * (v + z + initial_numerator * hypergeometric_ratio))
+        - jnp.log(z)
+        - jnp.log(2.0)
+    )
+    return jnp.exp(log_kve), jnp.exp(log_kvp1e)
+
+
+def _olver_kve(v, z):
+    """Compute Kve(v, z) using Olver's uniform asymptotic expansion.
+
+    Valid for |v| >= 50. Returns exponentially scaled value: Kv(v,z)*exp(z).
+    """
+    v_abs = jnp.abs(v)
+    w = z / v_abs
+    t = 1.0 / _sqrt1px2(w)
+
+    divisor = v_abs
+    kve_sum = 1.0
+
+    # Evaluate the Olver polynomial terms using Horner's method
+    for i in range(len(_ASYMPTOTIC_OLVER_COEFFICIENTS)):
+        coeff = 0.0
+        for c in _ASYMPTOTIC_OLVER_COEFFICIENTS[i]:
+            coeff = coeff * t + c
+        term = coeff / divisor
+        # For K_v, signs alternate: (-1)^i
+        kve_sum = kve_sum + (term if i % 2 == 1 else -term)
+        divisor = divisor * v_abs
+
+    # log(kve) = 0.5*log(pi*t/(2*v_abs)) - v_abs*shared_prefactor
+    shared_prefactor = 1.0 / (_sqrt1px2(w) + w) + jnp.log(w) - jnp.log1p(1.0 / t)
+    log_k_prefactor = (
+        0.5 * jnp.log(jnp.pi * t / (2.0 * v_abs)) - v_abs * shared_prefactor
+    )
+
+    log_kve = log_k_prefactor + jnp.log(kve_sum)
+    return jnp.exp(log_kve)
+
+
+def _temme_kve(v, x):
+    """Compute Kve(v, x) using Temme's method for |v| < 50.
+
+    Reduces to fractional order |u| <= 0.5, computes Kve(u, x) and Kve(u+1, x),
+    then uses forward recurrence to reach order v.
+    Returns exponentially scaled value: Kv(v,x)*exp(x).
+    """
+    v = jnp.abs(v)
+    n = jnp.round(v)
+    u = v - n
+
+    # Branchless: compute both methods with safe inputs, select with jnp.where
+    small_x = jnp.where(x <= 2.0, x, 0.1)
+    large_x = jnp.where(x > 2.0, x, 1000.0)
+
+    temme_kue, temme_kuep1 = _temme_series_kve(u, small_x)
+    cf_kue, cf_kuep1 = _continued_fraction_kve(u, large_x)
+
+    kue = jnp.where(x <= 2.0, temme_kue, cf_kue)
+    kuep1 = jnp.where(x <= 2.0, temme_kuep1, cf_kuep1)
+
+    # Forward recurrence: K_{v+1}(z) = (2v/z)*K_v(z) + K_{v-1}(z)
+    # This recurrence is also satisfied by Kv*exp(z) (the exponentially scaled form).
+    def bessel_recurrence(carry):
+        index, kve, kvep1 = carry
+        next_kvep1 = 2.0 * (u + index) * kvep1 / x + kve
+        kve = jnp.where(index > n, kve, kvep1)
+        kvep1 = jnp.where(index > n, kvep1, next_kvep1)
+        return (index + 1.0, kve, kvep1)
+
+    def recurrence_cond(carry):
+        index = carry[0]
+        return index <= n
+
+    _, kve, _ = jax.lax.while_loop(
+        recurrence_cond, bessel_recurrence, (1.0, kue, kuep1)
+    )
+    return kve
+
+
+def _kve_core(nu, x):
+    """Core dispatcher for Kve(nu, x) = Kv(nu, x) * exp(x).
+
+    Branchless: computes both Olver and Temme with safe dummy inputs,
+    selects based on |nu| >= 50.
+    """
+    nu = jnp.abs(nu)
+
+    # Safe inputs: avoid invalid regions for each method
+    small_nu = jnp.where(nu < 50.0, nu, 0.1)
+    large_nu = jnp.where(nu >= 50.0, nu, 1000.0)
+
+    olver_result = _olver_kve(large_nu, x)
+    temme_result = _temme_kve(small_nu, x)
+
+    return jnp.where(nu >= 50.0, olver_result, temme_result)
+
+
+def _kv_scalar(nu, x):
+    """Scalar implementation of K_v(x) using TFP-ported Temme + Olver algorithms.
+
+    Returns +inf at x == 0, 0 at x == +inf, and NaN for x < 0 or NaN inputs.
+    """
+    nu = 1.0 * nu
+    x = 1.0 * x
+    nu = jnp.abs(nu)  # K_{-v} = K_v
+
+    # Compute via exponentially scaled form for numerical stability.
+    # Use a safe x for the core computation: x <= 0, +inf and NaN are replaced
+    # by a dummy value here and patched up below.
+    valid_x = (x > 0.0) & jnp.isfinite(x)
+    safe_x = jnp.where(valid_x, x, 1.0)
+    kve = _kve_core(nu, safe_x)
+    result = kve * jnp.exp(-safe_x)
+
+    # Edge cases
+    result = jnp.where(x == 0.0, jnp.inf, result)
+    result = jnp.where(x == jnp.inf, 0.0, result)
+    result = jnp.where(x < 0.0, jnp.nan, result)
+    # NaN comparisons are all False, so NaN inputs must be propagated explicitly.
+    result = jnp.where(jnp.isnan(x) | jnp.isnan(nu), jnp.nan, result)
+    return result
+
+
 @implements(_galsim.bessel.kv)
 @jax.jit
+def _kv_impl(nu, x):
+    """Modified Bessel function of the second kind K_v(x) - internal implementation.
+
+    Uses TFP-ported Temme + Olver algorithms for all orders.
+    Handles both scalar and array inputs via jax.vmap.
+    """
+    nu = jnp.asarray(1.0 * nu)
+    x = jnp.asarray(1.0 * x)
+    out_shape = jnp.broadcast_shapes(jnp.shape(nu), jnp.shape(x))
+    if out_shape == ():
+        return _kv_scalar(nu, x)
+    nu_bc, x_bc = jnp.broadcast_arrays(nu, x)
+    flat_nu = nu_bc.ravel()
+    flat_x = x_bc.ravel()
+    return jax.vmap(_kv_scalar)(flat_nu, flat_x).reshape(out_shape)
+
+
+@jax.custom_vjp
 def kv(nu, x):
-    nu = 1.0 * nu
-    x = 1.0 * x
-    return _tfp_bessel_kve(nu, x) / jnp.exp(jnp.abs(x))
+    """Modified Bessel function of the second kind K_v(x) with custom gradients.
+
+    Uses TFP-ported Temme + Olver algorithms. Custom gradients via:
+        dK_v/dx = -1/2 * (K_{v-1}(x) + K_{v+1}(x))
+
+    Gradient w.r.t. v is not supported (returns zero).
+    """
+    return _kv_impl(nu, x)
+
+
+def _kv_fwd(nu, x):
+    kv_val = _kv_impl(nu, x)
+    kv_prev = _kv_impl(nu - 1.0, x)
+    kv_next = _kv_impl(nu + 1.0, x)
+    return kv_val, (nu, x, kv_prev, kv_next)
+
+
+def _unbroadcast(grad, shape):
+    """Sum ``grad`` over broadcast axes so that it has the given ``shape``."""
+    n_extra = grad.ndim - len(shape)
+    if n_extra > 0:
+        grad = grad.sum(axis=tuple(range(n_extra)))
+    stretched = tuple(
+        axis for axis, size in enumerate(shape) if size == 1 and grad.shape[axis] != 1
+    )
+    if stretched:
+        grad = grad.sum(axis=stretched, keepdims=True)
+    return grad
+
+
+def _kv_bwd(residuals, g):
+    """Backward rule for kv: returns the cotangents for (nu, x)."""
+    nu, x, kv_prev, kv_next = residuals
+    # dK_v/dx = -1/2 * (K_{v-1}(x) + K_{v+1}(x))
+    grad_x = -0.5 * (kv_prev + kv_next) * g
+    # kv broadcasts nu against x, but custom_vjp requires each cotangent to
+    # have the shape of its own primal input.
+    grad_x = _unbroadcast(grad_x, jnp.shape(x))
+    grad_nu = jnp.zeros_like(nu)
+    return (grad_nu, grad_x)
+
+
+kv.defvjp(_kv_fwd, _kv_bwd)
 
 
 @jax.jit
diff --git a/pyproject.toml b/pyproject.toml
index 6780330a..fcac16e3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,7 +19,6 @@ dependencies = [
     "galsim >=2.3.0",
     "jax >=0.8.0",
     "astropy >=2.0",
-    "tfp-nightly",
     "quadax",
 ]
 
diff --git a/tests/jax/test_bessel_gradients.py b/tests/jax/test_bessel_gradients.py
new file mode 100644
index 00000000..83578d2d
--- /dev/null
+++ b/tests/jax/test_bessel_gradients.py
@@ -0,0 +1,255 @@
+"""Tests for custom gradients of Bessel functions in JAX-GalSim.
+
+This module tests the custom gradient implementation for the modified Bessel
+function K_ν(x), which uses analytical derivative formulas based on Bessel
+recurrence relations instead of automatic differentiation.
+
+To run these tests:
+    pytest tests/jax/test_bessel_gradients.py \\
+        --ignore=tests/jax/test_image_wrapping.py \\
+        --ignore=tests/jax/test_interpolant_jax.py -v
+"""
+
+import jax
+import jax.numpy as jnp
+import pytest
+
+from jax_galsim.bessel import kv
+
+
+class TestKvGradients:
+    """Test suite for kv custom gradients."""
+
+    @pytest.mark.parametrize(
+        "nu,x",
+        [
+            (0.5, 1.0),
+            (1.0, 5.0),
+            (2.5, 5.0),
+            (10.0, 10.0),
+            (39.8, 40.0),
+            (300.9, 500.0),
+        ],
+    )
+    def test_gradient_wrt_x_compiles(self, nu, x):
+        """Test that gradient w.r.t. x compiles with jax.jit."""
+        grad_fn = jax.jit(jax.grad(lambda x_val: kv(nu, x_val)))
+        grad_val = grad_fn(x)
+        # Just verify it's a finite value
+        assert jnp.isfinite(grad_val), f"Gradient is not finite: {grad_val}"
+
+    @pytest.mark.parametrize(
+        "nu,x",
+        [
+            (0.5, 1.0),
+            (1.0, 5.0),
+            (2.5, 5.0),
+            (10.0, 10.0),
+            (39.8, 40.0),
+        ],
+    )
+    def test_gradient_vs_finite_differences(self, nu, x):
+        """Test that custom gradients match finite differences."""
+        grad_fn = jax.grad(lambda x_val: kv(nu, x_val))
+        analytical_grad = grad_fn(x)
+
+        # Compute finite difference gradient
+        eps = 1e-5
+        numerical_grad = (kv(nu, x + eps) - kv(nu, x - eps)) / (2 * eps)
+
+        # Allow slightly looser tolerance for numerical differentiation
+        relative_error = jnp.abs((analytical_grad - numerical_grad) / numerical_grad)
+        assert relative_error < 1e-6, (
+            f"Gradient error too large: {relative_error} at nu={nu}, x={x}"
+        )
+
+    @pytest.mark.parametrize(
+        "nu,x",
+        [
+            (0.5, 1.0),
+            (1.0, 5.0),
+            (2.5, 5.0),
+            (10.0, 10.0),
+            (39.8, 40.0),
+        ],
+    )
+    def test_gradient_analytical_formula(self, nu, x):
+        """Test that gradients match the analytical formula: ∂K_ν/∂x = -1/2 * (K_{ν-1}(x) + K_{ν+1}(x))."""
+        grad_fn = jax.grad(lambda x_val: kv(nu, x_val))
+        computed_grad = grad_fn(x)
+
+        # Compute expected gradient using analytical formula
+        kv_prev = kv(nu - 1.0, x)
+        kv_next = kv(nu + 1.0, x)
+        expected_grad = -0.5 * (kv_prev + kv_next)
+
+        # Should match very closely since we use the same formula
+        assert jnp.allclose(computed_grad, expected_grad, rtol=1e-10), (
+            f"Gradient doesn't match analytical formula at nu={nu}, x={x}"
+        )
+
+    def test_gradient_vectorization(self):
+        """Test that gradients work with vmap."""
+        grad_fn = jax.vmap(jax.grad(lambda x_val: kv(2.5, x_val)))
+        x_array = jnp.array([1.0, 5.0, 10.0, 50.0])
+        grads = grad_fn(x_array)
+
+        # Verify all gradients are finite
+        assert jnp.all(jnp.isfinite(grads)), "Some vectorized gradients are not finite"
+
+        # Verify shapes match
+        assert grads.shape == x_array.shape, "Gradient shape mismatch"
+
+    @pytest.mark.skip(
+        reason="custom_vjp doesn't support higher-order derivatives by default"
+    )
+    def test_second_derivative(self):
+        """Test that second derivatives (Hessian) work.
+
+        Note: This test is skipped because custom_vjp doesn't support higher-order
+        differentiation. This is a known limitation and not critical for first-order
+        gradient-based optimization.
+        """
+        hessian_fn = jax.grad(jax.grad(lambda x_val: kv(2.5, x_val)))
+        hess_val = hessian_fn(5.0)
+
+        # Just verify it's finite - we're not testing accuracy of second derivative
+        assert jnp.isfinite(hess_val), f"Second derivative is not finite: {hess_val}"
+
+    @pytest.mark.parametrize(
+        "nu,x",
+        [
+            (2.5, 0.1),  # Small x
+            (2.5, 100.0),  # Large x
+            (300.9, 500.0),  # Large nu
+            (0.5, 0.1),  # Small x, small nu
+        ],
+    )
+    def test_gradient_edge_cases(self, nu, x):
+        """Test gradients at edge cases (small x, large x, large nu)."""
+        grad_fn = jax.grad(lambda x_val: kv(nu, x_val))
+        grad_val = grad_fn(x)
+
+        # Verify gradient is finite and reasonable
+        assert jnp.isfinite(grad_val), f"Gradient not finite at nu={nu}, x={x}"
+
+        # Verify against analytical formula
+        kv_prev = kv(nu - 1.0, x)
+        kv_next = kv(nu + 1.0, x)
+        expected_grad = -0.5 * (kv_prev + kv_next)
+
+        assert jnp.allclose(grad_val, expected_grad, rtol=1e-8), (
+            f"Gradient mismatch at edge case nu={nu}, x={x}"
+        )
+
+    def test_gradient_wrt_nu_is_zero(self):
+        """Test that gradient w.r.t. nu returns zero (not supported)."""
+        # Gradient w.r.t. first argument (nu)
+        grad_nu_fn = jax.grad(lambda nu_val: kv(nu_val, 5.0), argnums=0)
+        grad_nu = grad_nu_fn(2.5)
+
+        # Should be zero since we don't support gradient w.r.t. order
+        assert grad_nu == 0.0, "Gradient w.r.t. nu should be zero"
+
+    def test_gradient_negative_nu(self):
+        """Test that gradients work with negative nu (should use abs(nu))."""
+        # K_{-nu}(x) = K_nu(x), so gradient should be the same
+        grad_pos = jax.grad(lambda x_val: kv(2.5, x_val))(5.0)
+        grad_neg = jax.grad(lambda x_val: kv(-2.5, x_val))(5.0)
+
+        assert jnp.allclose(grad_pos, grad_neg, rtol=1e-10), (
+            "Gradient should be same for positive and negative nu"
+        )
+
+    def test_gradient_integer_order(self):
+        """Test gradients for integer orders (special case)."""
+        # Integer orders give a zero fractional part (u = 0) in the Temme reduction
+        for nu in [0, 1, 2, 5]:
+            grad_fn = jax.grad(lambda x_val: kv(nu, x_val))
+            grad_val = grad_fn(5.0)
+
+            # Verify against analytical formula
+            kv_prev = kv(nu - 1.0, 5.0) if nu > 0 else kv(1.0, 5.0)  # K_{-1} = K_1
+            kv_next = kv(nu + 1.0, 5.0)
+            expected_grad = -0.5 * (kv_prev + kv_next)
+
+            assert jnp.allclose(grad_val, expected_grad, rtol=1e-8), (
+                f"Gradient mismatch for integer order nu={nu}"
+            )
+
+    def test_gradient_half_integer_order(self):
+        """Test gradients for half-integer orders (special case)."""
+        # Half-integer orders sit at the |u| = 0.5 boundary of the Temme reduction
+        for nu in [0.5, 1.5, 2.5, 3.5]:
+            grad_fn = jax.grad(lambda x_val: kv(nu, x_val))
+            grad_val = grad_fn(5.0)
+
+            # Verify against analytical formula
+            kv_prev = kv(nu - 1.0, 5.0)
+            kv_next = kv(nu + 1.0, 5.0)
+            expected_grad = -0.5 * (kv_prev + kv_next)
+
+            assert jnp.allclose(grad_val, expected_grad, rtol=1e-8), (
+                f"Gradient mismatch for half-integer order nu={nu}"
+            )
+
+    def test_gradient_jit_compilation(self):
+        """Test that gradient compilation works without errors."""
+        # Compile gradient function
+        grad_fn = jax.jit(jax.grad(lambda x_val: kv(2.5, x_val)))
+
+        # Should compile without errors
+        _ = grad_fn(5.0)
+
+        # Call again to verify compiled version works
+        grad_val = grad_fn(10.0)
+        assert jnp.isfinite(grad_val), "Compiled gradient not finite"
+
+    def test_gradient_multiple_calls(self):
+        """Test that gradients are consistent across multiple calls."""
+        grad_fn = jax.grad(lambda x_val: kv(2.5, x_val))
+
+        # Call multiple times with same input
+        grads = [grad_fn(5.0) for _ in range(5)]
+
+        # All should be identical
+        for grad_val in grads[1:]:
+            assert jnp.allclose(grad_val, grads[0], rtol=1e-12), (
+                "Gradients inconsistent across calls"
+            )
+
+    def test_gradient_batch_computation(self):
+        """Test gradients with batched inputs using vmap."""
+        nu_values = jnp.array([0.5, 1.0, 2.5, 10.0])
+        x_values = jnp.array([1.0, 5.0, 10.0, 50.0])
+
+        # Compute gradients for all combinations
+        grad_fn = jax.vmap(
+            jax.vmap(
+                jax.grad(lambda x_val, nu_val: kv(nu_val, x_val)), in_axes=(None, 0)
+            ),
+            in_axes=(0, None),
+        )
+
+        grads = grad_fn(x_values, nu_values)
+
+        # Verify shape: (len(x_values), len(nu_values))
+        assert grads.shape == (len(x_values), len(nu_values))
+
+        # Verify all are finite
+        assert jnp.all(jnp.isfinite(grads)), "Some batch gradients are not finite"
+
+    def test_gradient_broadcast_nu(self):
+        """Test that the x-gradient keeps x's shape when nu is an array."""
+        nu_values = jnp.array([0.5, 1.5, 2.5])
+        grad_val = jax.grad(lambda x_val: jnp.sum(kv(nu_values, x_val)))(5.0)
+
+        # Each order contributes -1/2 * (K_{nu-1} + K_{nu+1}) to d/dx of the sum
+        kv_prev = kv(nu_values - 1.0, 5.0)
+        kv_next = kv(nu_values + 1.0, 5.0)
+        expected_grad = jnp.sum(-0.5 * (kv_prev + kv_next))
+
+        assert jnp.shape(grad_val) == (), "Gradient w.r.t. scalar x must be scalar"
+        assert jnp.allclose(grad_val, expected_grad, rtol=1e-8), (
+            "Broadcast gradient doesn't match analytical formula"
+        )