From 7cd95cca02865e702622bc928de8023bc914cf68 Mon Sep 17 00:00:00 2001 From: Francois Lanusse Date: Mon, 9 Feb 2026 00:11:01 +0100 Subject: [PATCH 1/2] update --- jax_galsim/bessel.py | 451 ++++++++++++++++++++++--------------------- 1 file changed, 229 insertions(+), 222 deletions(-) diff --git a/jax_galsim/bessel.py b/jax_galsim/bessel.py index 87d09291..23350375 100644 --- a/jax_galsim/bessel.py +++ b/jax_galsim/bessel.py @@ -1,6 +1,18 @@ import galsim as _galsim import jax import jax.numpy as jnp +import numpy as np +from jax._src import core, dtypes +from jax._src.interpreters import ad, mlir +from jax._src.lax.lax import ( + _const, + _float, + broadcast_in_dim, + broadcast_shapes, + convert_element_type, + standard_naryop, +) +from jax._src.lax.special import evaluate_chebyshev_polynomial from jax_galsim.core.utils import implements @@ -105,26 +117,47 @@ def si(x): # ===================================================================== -# Modified Bessel K_v(x) +# Modified Bessel K_v(x) — JAX primitive implementation +# +# Registered as a JAX primitive (bessel_kv_p) with JVP rules for +# automatic differentiation. Uses fori_loop with fixed iteration counts +# and operates element-wise (no vmap). Broadcasting is handled by +# _up_and_broadcast. # -# The functions below (_sqrt1px2, _evaluate_temme_coeffs, -# _temme_series_kve, _continued_fraction_kve, _olver_kve, _temme_kve, -# _kve_core) and the _ASYMPTOTIC_OLVER_COEFFICIENTS constant are -# derived from TensorFlow Probability's bessel.py: +# The algorithms (Temme series, Steed's continued fraction, Olver +# uniform asymptotic expansion) are derived from TensorFlow +# Probability's bessel.py: # https://github.com/tensorflow/probability # # Original copyright and license: # Copyright 2020 The TensorFlow Probability Authors. # Licensed under the Apache License, Version 2.0. -# -# Modifications from the original: -# - Ported from TensorFlow/TFP APIs to pure JAX (jax.numpy, jax.lax) -# - Removed I_v (bessel_ive) computation; only K_v is computed -# - Removed log-space output option -# - Removed negative-v correction for I_v -# - Simplified to scalar-only core (vectorization via jax.vmap) # ===================================================================== + +def _up_and_broadcast(doit): + """Broadcast args and upcast bf16/f16 to f32 before calling doit.""" + + def up_and_broadcast(*args): + broadcasted_shape = broadcast_shapes(*(a.shape for a in args)) + args = [ + broadcast_in_dim(a, broadcasted_shape, list(range(a.ndim))) for a in args + ] + a_dtype = args[0].dtype + needs_upcast = a_dtype == dtypes.bfloat16 or a_dtype == np.float16 + if needs_upcast: + args = [convert_element_type(a, np.float32) for a in args] + a_x_type = np.float32 + else: + a_x_type = a_dtype + result = doit(*args, dtype=a_x_type) + if needs_upcast: + result = convert_element_type(result, a_dtype) + return result + + return up_and_broadcast + + # Olver expansion polynomial coefficients (10 terms, up to 31 coefficients each) # fmt: off _ASYMPTOTIC_OLVER_COEFFICIENTS = [ @@ -173,25 +206,20 @@ def si(x): def _sqrt1px2(x): - """Numerically stable computation of sqrt(1 + x^2).""" - eps = jnp.finfo(jnp.float64).eps + """Numerically stable sqrt(1 + x^2).""" + eps = _const(x, jnp.finfo(jnp.float64).eps) return jnp.where( - jnp.abs(x) * jnp.sqrt(eps) <= 1.0, - jnp.exp(0.5 * jnp.log1p(x * x)), + jnp.abs(x) * jnp.sqrt(eps) <= _const(x, 1.0), + jnp.exp(_const(x, 0.5) * jnp.log1p(x * x)), jnp.abs(x), ) def _evaluate_temme_coeffs(v): - """Numerically stable computation of gamma-related coefficients for Temme's method. + """Chebyshev-based gamma coefficients for Temme's method. - Computes: - coeff1 = (1/Gamma(1-v) - 1/Gamma(1+v)) / (2v) - coeff2 = (1/Gamma(1-v) + 1/Gamma(1+v)) / 2 - gamma1pv = 1/Gamma(1+v) - gamma1mv = 1/Gamma(1-v) - - Uses Chebyshev expansions for numerical stability (avoids catastrophic cancellation). + Returns (coeff1, coeff2, gamma1pv_inv, gamma1mv_inv). + Uses JAX's evaluate_chebyshev_polynomial for Clenshaw recurrence. """ coeff1_coeffs = [ -1.142022680371168e0, @@ -212,26 +240,12 @@ def _evaluate_temme_coeffs(v): -1.702e-13, -1.49e-15, ] - w = 8.0 * v * v - 1.0 - y = 2.0 * w - - # Clenshaw's recurrence for coeff1 - prev = 0.0 - current = 0.0 - for i in reversed(range(1, len(coeff1_coeffs))): - temp = current - current = y * current - prev + coeff1_coeffs[i] - prev = temp - coeff1 = w * current - prev + 0.5 * coeff1_coeffs[0] - - # Clenshaw's recurrence for coeff2 - prev = 0.0 - current = 0.0 - for i in reversed(range(1, len(coeff2_coeffs))): - temp = current - current = y * current - prev + coeff2_coeffs[i] - prev = temp - coeff2 = w * current - prev + 0.5 * coeff2_coeffs[0] + y = _const(v, 2.0) * (_const(v, 8.0) * v * v - _const(v, 1.0)) + + # evaluate_chebyshev_polynomial(x, coeffs) evaluates sum c_k T_k(x/2), + # so passing y = 2w gives evaluation at w = 8v^2 - 1. + coeff1 = evaluate_chebyshev_polynomial(y, list(reversed(coeff1_coeffs))) + coeff2 = evaluate_chebyshev_polynomial(y, list(reversed(coeff2_coeffs))) gamma1pv = coeff2 - v * coeff1 gamma1mv = coeff2 + v * coeff1 @@ -239,87 +253,88 @@ def _evaluate_temme_coeffs(v): def _temme_series_kve(v, z): - """Compute Kve(v, z) and Kve(v+1, z) via Temme power series. + """Kve(v, z) and Kve(v+1, z) via Temme power series. - Assumes |v| < 0.5 and |z| <= 2 for fast convergence. - Returns exponentially scaled values: Kv(v,z)*exp(z). + Assumes |v| < 0.5 and |z| <= 2. Returns exponentially scaled values. + Uses fori_loop with fixed 15 iterations (empirically, max needed is 12 + for f64 across the valid domain). """ - tol = jnp.finfo(jnp.float64).eps + tol = _const(z, jnp.finfo(jnp.float64).eps) coeff1, coeff2, gamma1pv_inv, gamma1mv_inv = _evaluate_temme_coeffs(v) z_sq = z * z - logzo2 = jnp.log(z / 2.0) + logzo2 = jnp.log(z / _const(z, 2.0)) mu = -v * logzo2 - sinc_v = jnp.where(v == 0.0, 1.0, jnp.sin(jnp.pi * v) / (jnp.pi * v)) - sinhc_mu = jnp.where(mu == 0.0, 1.0, jnp.sinh(mu) / mu) + pi_v = _const(v, jnp.pi) * v + sinc_v = jnp.where(v == _const(v, 0.0), _const(v, 1.0), jnp.sin(pi_v) / pi_v) + sinhc_mu = jnp.where(mu == _const(mu, 0.0), _const(mu, 1.0), jnp.sinh(mu) / mu) initial_f = (coeff1 * jnp.cosh(mu) + coeff2 * (-logzo2) * sinhc_mu) / sinc_v - initial_p = 0.5 * jnp.exp(mu) / gamma1pv_inv - initial_q = 0.5 * jnp.exp(-mu) / gamma1mv_inv + initial_p = _const(v, 0.5) * jnp.exp(mu) / gamma1pv_inv + initial_q = _const(v, 0.5) * jnp.exp(-mu) / gamma1mv_inv - max_iterations = 1000 + max_iters = 15 - def body_fn(carry): - should_stop, index, f, p, q, coeff, kv_sum, kvp1_sum = carry - f = jnp.where( - should_stop, - f, - (index * f + p + q) / (index * index - v * v), - ) - p = jnp.where(should_stop, p, p / (index - v)) - q = jnp.where(should_stop, q, q / (index + v)) - h = p - index * f - coeff = jnp.where(should_stop, coeff, coeff * z_sq / (4.0 * index)) - kv_sum = jnp.where(should_stop, kv_sum, kv_sum + coeff * f) - kvp1_sum = jnp.where(should_stop, kvp1_sum, kvp1_sum + coeff * h) - index = index + 1.0 - should_stop = (jnp.abs(coeff * f) < jnp.abs(kv_sum) * tol) | ( - index > max_iterations + def body_fn(i, carry): + f, p, q, coeff, kv_sum, kvp1_sum, converged = carry + index = i + _const(v, 1.0) + + new_f = (index * f + p + q) / (index * index - v * v) + new_p = p / (index - v) + new_q = q / (index + v) + h = new_p - index * new_f + new_coeff = coeff * z_sq / (_const(z, 4.0) * index) + new_kv_sum = kv_sum + new_coeff * new_f + new_kvp1_sum = kvp1_sum + new_coeff * h + + new_converged = converged | ( + jnp.abs(new_coeff * new_f) < jnp.abs(new_kv_sum) * tol ) - return (should_stop, index, f, p, q, coeff, kv_sum, kvp1_sum) - def cond_fn(carry): - should_stop = carry[0] - return ~should_stop + f = jnp.where(converged, f, new_f) + p = jnp.where(converged, p, new_p) + q = jnp.where(converged, q, new_q) + coeff = jnp.where(converged, coeff, new_coeff) + kv_sum = jnp.where(converged, kv_sum, new_kv_sum) + kvp1_sum = jnp.where(converged, kvp1_sum, new_kvp1_sum) + + return (f, p, q, coeff, kv_sum, kvp1_sum, new_converged) init = ( - jnp.array(False), - 1.0, initial_f, initial_p, initial_q, - 1.0, + jnp.ones_like(z), initial_f, initial_p, + jnp.zeros_like(v, dtype=jnp.bool_), ) - _, _, _, _, _, _, kv_sum, kvp1_sum = jax.lax.while_loop(cond_fn, body_fn, init) + _, _, _, _, kv_sum, kvp1_sum, _ = jax.lax.fori_loop(0, max_iters, body_fn, init) - # Convert to exponentially scaled: kve = kv * exp(z) kve = kv_sum * jnp.exp(z) - kvep1 = 2.0 * kvp1_sum * jnp.exp(z) / z - + kvep1 = _const(z, 2.0) * kvp1_sum * jnp.exp(z) / z return kve, kvep1 def _continued_fraction_kve(v, z): - """Compute Kve(v, z) and Kve(v+1, z) via Steed's continued fraction. + """Kve(v, z) and Kve(v+1, z) via Steed's continued fraction. - Assumes |v| < 0.5 and |z| > 2. - Returns exponentially scaled values: Kv(v,z)*exp(z). + Assumes |v| < 0.5 and |z| > 2. Returns exponentially scaled values. + Uses fori_loop with fixed 80 iterations (empirically, max needed is 77 + for f64 at z~2). """ - tol = jnp.finfo(jnp.float64).eps - max_iterations = 1000 + tol = _const(z, jnp.finfo(jnp.float64).eps) - initial_numerator = v * v - 0.25 - initial_denominator = 2.0 * (z + 1.0) - initial_ratio = 1.0 / initial_denominator + initial_numerator = v * v - _const(v, 0.25) + initial_denominator = _const(z, 2.0) * (z + _const(z, 1.0)) + initial_ratio = _const(z, 1.0) / initial_denominator initial_seq = -initial_numerator - def steeds_body(carry): + max_iters = 80 + + def body_fn(i, carry): ( - should_stop, - index, partial_numerator, partial_denominator, denominator_ratio, @@ -330,40 +345,57 @@ def steeds_body(carry): c, q, hypergeometric_sum, + converged, ) = carry + index = i + _const(z, 2.0) + + new_partial_numerator = partial_numerator - _const(z, 2.0) * ( + index - _const(z, 1.0) + ) + new_c = -c * new_partial_numerator / index + next_k = (k_0 - partial_denominator * k_1) / new_partial_numerator + new_k_0 = k_1 + new_k_1 = next_k + new_q = q + new_c * next_k + new_partial_denominator = partial_denominator + _const(z, 2.0) + new_denominator_ratio = _const(z, 1.0) / ( + new_partial_denominator + new_partial_numerator * denominator_ratio + ) + new_convergent_difference = convergent_difference * ( + new_partial_denominator * new_denominator_ratio - _const(z, 1.0) + ) + new_hypergeometric_ratio = hypergeometric_ratio + new_convergent_difference + new_hypergeometric_sum = hypergeometric_sum + new_q * new_convergent_difference - partial_numerator = partial_numerator - 2.0 * (index - 1.0) - c = jnp.where(should_stop, c, -c * partial_numerator / index) - next_k = (k_0 - partial_denominator * k_1) / partial_numerator - k_0 = jnp.where(should_stop, k_0, k_1) - k_1 = jnp.where(should_stop, k_1, next_k) - q = jnp.where(should_stop, q, q + c * next_k) - partial_denominator = partial_denominator + 2.0 - denominator_ratio = 1.0 / ( - partial_denominator + partial_numerator * denominator_ratio + new_converged = converged | ( + jnp.abs(new_q * new_convergent_difference) + < jnp.abs(new_hypergeometric_sum) * tol + ) + + partial_numerator = jnp.where( + converged, partial_numerator, new_partial_numerator + ) + c = jnp.where(converged, c, new_c) + k_0 = jnp.where(converged, k_0, new_k_0) + k_1 = jnp.where(converged, k_1, new_k_1) + q = jnp.where(converged, q, new_q) + partial_denominator = jnp.where( + converged, partial_denominator, new_partial_denominator + ) + denominator_ratio = jnp.where( + converged, denominator_ratio, new_denominator_ratio ) convergent_difference = jnp.where( - should_stop, - convergent_difference, - convergent_difference * (partial_denominator * denominator_ratio - 1.0), + converged, convergent_difference, new_convergent_difference ) hypergeometric_ratio = jnp.where( - should_stop, - hypergeometric_ratio, - hypergeometric_ratio + convergent_difference, + converged, hypergeometric_ratio, new_hypergeometric_ratio ) hypergeometric_sum = jnp.where( - should_stop, - hypergeometric_sum, - hypergeometric_sum + q * convergent_difference, + converged, hypergeometric_sum, new_hypergeometric_sum ) - index = index + 1.0 - should_stop = ( - jnp.abs(q * convergent_difference) < jnp.abs(hypergeometric_sum) * tol - ) | (index > max_iterations) + return ( - should_stop, - index, partial_numerator, partial_denominator, denominator_ratio, @@ -374,194 +406,169 @@ def steeds_body(carry): c, q, hypergeometric_sum, + new_converged, ) - def cond_fn(carry): - return ~carry[0] - init = ( - jnp.array(False), - 2.0, initial_numerator, initial_denominator, initial_ratio, initial_ratio, initial_ratio, - 0.0, - 1.0, + jnp.zeros_like(z), + jnp.ones_like(z), initial_seq, initial_seq, - 1.0 - initial_numerator * initial_ratio, + jnp.ones_like(z) - initial_numerator * initial_ratio, + jnp.zeros_like(v, dtype=jnp.bool_), ) - result = jax.lax.while_loop(cond_fn, steeds_body, init) - hypergeometric_ratio = result[6] - hypergeometric_sum = result[11] + result = jax.lax.fori_loop(0, max_iters, body_fn, init) + hypergeometric_ratio = result[4] + hypergeometric_sum = result[9] - # log(kve) = 0.5*log(pi/(2z)) - log(hypergeometric_sum) - log_kve = 0.5 * jnp.log(jnp.pi / (2.0 * z)) - jnp.log(hypergeometric_sum) + log_kve = _const(z, 0.5) * jnp.log( + _const(z, jnp.pi) / (_const(z, 2.0) * z) + ) - jnp.log(hypergeometric_sum) log_kvp1e = ( log_kve - + jnp.log1p(2.0 * (v + z + initial_numerator * hypergeometric_ratio)) + + jnp.log1p(_const(z, 2.0) * (v + z + initial_numerator * hypergeometric_ratio)) - jnp.log(z) - - jnp.log(2.0) + - jnp.log(_const(z, 2.0)) ) return jnp.exp(log_kve), jnp.exp(log_kvp1e) def _olver_kve(v, z): - """Compute Kve(v, z) using Olver's uniform asymptotic expansion. + """Kve(v, z) using Olver's uniform asymptotic expansion. - Valid for |v| >= 50. Returns exponentially scaled value: Kv(v,z)*exp(z). + Valid for |v| >= 50. Returns Kv(v,z)*exp(z). """ v_abs = jnp.abs(v) w = z / v_abs - t = 1.0 / _sqrt1px2(w) + t = _const(z, 1.0) / _sqrt1px2(w) divisor = v_abs - kve_sum = 1.0 + kve_sum = _const(z, 1.0) - # Evaluate the Olver polynomial terms using Horner's method for i in range(len(_ASYMPTOTIC_OLVER_COEFFICIENTS)): - coeff = 0.0 + coeff = _const(z, 0.0) for c in _ASYMPTOTIC_OLVER_COEFFICIENTS[i]: - coeff = coeff * t + c + coeff = coeff * t + _const(z, c) term = coeff / divisor - # For K_v, signs alternate: (-1)^i kve_sum = kve_sum + (term if i % 2 == 1 else -term) divisor = divisor * v_abs - # log(kve) = 0.5*log(pi*t/(2*v_abs)) - v_abs*shared_prefactor - shared_prefactor = 1.0 / (_sqrt1px2(w) + w) + jnp.log(w) - jnp.log1p(1.0 / t) + shared_prefactor = ( + _const(z, 1.0) / (_sqrt1px2(w) + w) + jnp.log(w) - jnp.log1p(_const(z, 1.0) / t) + ) log_k_prefactor = ( - 0.5 * jnp.log(jnp.pi * t / (2.0 * v_abs)) - v_abs * shared_prefactor + _const(z, 0.5) * jnp.log(_const(z, jnp.pi) * t / (_const(z, 2.0) * v_abs)) + - v_abs * shared_prefactor ) - log_kve = log_k_prefactor + jnp.log(kve_sum) return jnp.exp(log_kve) def _temme_kve(v, x): - """Compute Kve(v, x) using Temme's method for |v| < 50. + """Kve(v, x) using Temme's method for |v| < 50. - Reduces to fractional order |u| <= 0.5, computes Kve(u, x) and Kve(u+1, x), - then uses forward recurrence to reach order v. - Returns exponentially scaled value: Kv(v,x)*exp(x). + Reduces to fractional order |u| <= 0.5, computes via series or CF, + then forward recurrence to reach order v. + Uses fori_loop with fixed 50 iterations for forward recurrence. """ v = jnp.abs(v) n = jnp.round(v) u = v - n - # Branchless: compute both methods with safe inputs, select with jnp.where - small_x = jnp.where(x <= 2.0, x, 0.1) - large_x = jnp.where(x > 2.0, x, 1000.0) + small_x = jnp.where(x <= _const(x, 2.0), x, _const(x, 0.1)) + large_x = jnp.where(x > _const(x, 2.0), x, _const(x, 1000.0)) temme_kue, temme_kuep1 = _temme_series_kve(u, small_x) cf_kue, cf_kuep1 = _continued_fraction_kve(u, large_x) - kue = jnp.where(x <= 2.0, temme_kue, cf_kue) - kuep1 = jnp.where(x <= 2.0, temme_kuep1, cf_kuep1) + kue = jnp.where(x <= _const(x, 2.0), temme_kue, cf_kue) + kuep1 = jnp.where(x <= _const(x, 2.0), temme_kuep1, cf_kuep1) - # Forward recurrence: K_{v+1}(z) = (2v/z)*K_v(z) + K_{v-1}(z) - # This recurrence is also satisfied by Kv*exp(z) (the exponentially scaled form). - def bessel_recurrence(carry): - index, kve, kvep1 = carry - next_kvep1 = 2.0 * (u + index) * kvep1 / x + kve - kve = jnp.where(index > n, kve, kvep1) - kvep1 = jnp.where(index > n, kvep1, next_kvep1) - return (index + 1.0, kve, kvep1) + max_recurrence = 50 - def recurrence_cond(carry): - index = carry[0] - return index <= n + def recurrence_body(i, carry): + kve, kvep1 = carry + index = i + _const(x, 1.0) + past_n = index > n + next_kvep1 = _const(x, 2.0) * (u + index) * kvep1 / x + kve + new_kve = jnp.where(past_n, kve, kvep1) + new_kvep1 = jnp.where(past_n, kvep1, next_kvep1) + return (new_kve, new_kvep1) - _, kve, _ = jax.lax.while_loop( - recurrence_cond, bessel_recurrence, (1.0, kue, kuep1) - ) + kve, _ = jax.lax.fori_loop(0, max_recurrence, recurrence_body, (kue, kuep1)) return kve def _kve_core(nu, x): - """Core dispatcher for Kve(nu, x) = Kv(nu, x) * exp(x). + """Core dispatcher: computes Kve(nu, x) = Kv(nu, x) * exp(x). - Branchless: computes both Olver and Temme with safe dummy inputs, - selects based on |nu| >= 50. + Branchless: computes both Olver and Temme, selects based on |nu| >= 50. """ nu = jnp.abs(nu) - # Safe inputs: avoid invalid regions for each method - small_nu = jnp.where(nu < 50.0, nu, 0.1) - large_nu = jnp.where(nu >= 50.0, nu, 1000.0) + small_nu = jnp.where(nu < _const(nu, 50.0), nu, _const(nu, 0.1)) + large_nu = jnp.where(nu >= _const(nu, 50.0), nu, _const(nu, 1000.0)) olver_result = _olver_kve(large_nu, x) temme_result = _temme_kve(small_nu, x) - return jnp.where(nu >= 50.0, olver_result, temme_result) + return jnp.where(nu >= _const(nu, 50.0), olver_result, temme_result) -def _kv_scalar(nu, x): - """Scalar implementation of K_v(x) using TFP-ported Temme + Olver algorithms.""" - nu = 1.0 * nu - x = 1.0 * x - nu = jnp.abs(nu) # K_{-v} = K_v +def _bessel_kv_impl(v, x, *, dtype): + """Element-wise implementation of K_v(x). - # Compute via exponentially scaled form for numerical stability - # Use a safe x for the core computation (avoid x=0 which causes issues) - safe_x = jnp.where(x > 0.0, x, 1.0) - kve = _kve_core(nu, safe_x) + The dtype kwarg is required by the _up_and_broadcast pattern. + """ + v = jnp.abs(v) # K_{-v} = K_v + + safe_x = jnp.where(x > _const(x, 0.0), x, _const(x, 1.0)) + kve = _kve_core(v, safe_x) result = kve * jnp.exp(-safe_x) - # Edge cases - result = jnp.where(x == 0.0, jnp.inf, result) - result = jnp.where(x < 0.0, jnp.nan, result) + result = jnp.where(x == _const(x, 0.0), _const(x, jnp.inf), result) + result = jnp.where(x < _const(x, 0.0), _const(x, jnp.nan), result) return result -@implements(_galsim.bessel.kv) -@jax.jit -def _kv_impl(nu, x): - """Modified Bessel function of the second kind K_v(x) - internal implementation. +# --- JAX primitive registration --- - Uses TFP-ported Temme + Olver algorithms for all orders. - Handles both scalar and array inputs via jax.vmap. - """ - nu = jnp.asarray(1.0 * nu) - x = jnp.asarray(1.0 * x) - out_shape = jnp.broadcast_shapes(jnp.shape(nu), jnp.shape(x)) - if out_shape == (): - return _kv_scalar(nu, x) - nu_bc, x_bc = jnp.broadcast_arrays(nu, x) - flat_nu = nu_bc.ravel() - flat_x = x_bc.ravel() - return jax.vmap(_kv_scalar)(flat_nu, flat_x).reshape(out_shape) - - -@jax.custom_vjp -def kv(nu, x): - """Modified Bessel function of the second kind K_v(x) with custom gradients. +_bessel_kv_p = standard_naryop([_float, _float], "bessel_kv") - Uses TFP-ported Temme + Olver algorithms. Custom gradients via: - dK_v/dx = -1/2 * (K_{v-1}(x) + K_{v+1}(x)) +mlir.register_lowering( + _bessel_kv_p, + mlir.lower_fun(_up_and_broadcast(_bessel_kv_impl), multiple_results=False), +) - Gradient w.r.t. v is not supported (returns zero). - """ - return _kv_impl(nu, x) +def _bessel_kv_jvp_v(g, v, x): + return jnp.zeros_like(v) * g -def _kv_fwd(nu, x): - kv_val = _kv_impl(nu, x) - kv_prev = _kv_impl(nu - 1.0, x) - kv_next = _kv_impl(nu + 1.0, x) - return kv_val, (nu, x, kv_prev, kv_next) +def _bessel_kv_jvp_x(g, v, x): + return g * _const(x, -0.5) * (kv(v - _const(v, 1.0), x) + kv(v + _const(v, 1.0), x)) -def _kv_bwd(residuals, g): - nu, x, kv_prev, kv_next = residuals - grad_x = -0.5 * (kv_prev + kv_next) * g - grad_nu = jnp.zeros_like(nu) - return (grad_nu, grad_x) +ad.defjvp(_bessel_kv_p, _bessel_kv_jvp_v, _bessel_kv_jvp_x) -kv.defvjp(_kv_fwd, _kv_bwd) + +@implements(_galsim.bessel.kv) +def kv(nu, x): + """Modified Bessel function of the second kind K_v(x). + + Registered as a JAX primitive with JVP rules. + Supports jit, vmap, and grad (w.r.t. x). + Gradient w.r.t. v is not supported (returns zero). + """ + nu = jnp.asarray(nu, dtype=float) + x = jnp.asarray(x, dtype=float) + nu, x = core.standard_insert_pvary(nu, x) + return _bessel_kv_p.bind(nu, x) @jax.jit From a5fea9d20c98e4c0e595ffd64bccb298c930fd4f Mon Sep 17 00:00:00 2001 From: Francois Lanusse Date: Mon, 9 Feb 2026 00:29:39 +0100 Subject: [PATCH 2/2] fix nans --- jax_galsim/bessel.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jax_galsim/bessel.py b/jax_galsim/bessel.py index 23350375..fd0eb8b3 100644 --- a/jax_galsim/bessel.py +++ b/jax_galsim/bessel.py @@ -207,7 +207,7 @@ def up_and_broadcast(*args): def _sqrt1px2(x): """Numerically stable sqrt(1 + x^2).""" - eps = _const(x, jnp.finfo(jnp.float64).eps) + eps = _const(x, jnp.finfo(x.dtype).eps) return jnp.where( jnp.abs(x) * jnp.sqrt(eps) <= _const(x, 1.0), jnp.exp(_const(x, 0.5) * jnp.log1p(x * x)), @@ -259,7 +259,7 @@ def _temme_series_kve(v, z): Uses fori_loop with fixed 15 iterations (empirically, max needed is 12 for f64 across the valid domain). """ - tol = _const(z, jnp.finfo(jnp.float64).eps) + tol = _const(z, jnp.finfo(z.dtype).eps) coeff1, coeff2, gamma1pv_inv, gamma1mv_inv = _evaluate_temme_coeffs(v) @@ -324,7 +324,7 @@ def _continued_fraction_kve(v, z): Uses fori_loop with fixed 80 iterations (empirically, max needed is 77 for f64 at z~2). """ - tol = _const(z, jnp.finfo(jnp.float64).eps) + tol = _const(z, jnp.finfo(z.dtype).eps) initial_numerator = v * v - _const(v, 0.25) initial_denominator = _const(z, 2.0) * (z + _const(z, 1.0))