From 8e732768ae0c1875644ddcf1f11057ba7031545e Mon Sep 17 00:00:00 2001 From: EiffL Date: Fri, 6 Feb 2026 19:23:54 +0100 Subject: [PATCH 01/11] Adding missing bessel functions --- CLAUDE.md | 295 ++++++++++++++++++++++ jax_galsim/bessel.py | 431 ++++++++++++++++++++++++++++++++- pyproject.toml | 4 +- tests/galsim_tests_config.yaml | 1 - 4 files changed, 725 insertions(+), 6 deletions(-) create mode 100644 CLAUDE.md diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 00000000..d8f2386c --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,295 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +JAX-GalSim is a JAX port of GalSim (Galaxy Image Simulation toolkit) that enables parallelized, GPU-accelerated, and differentiable galaxy image simulations. This is an early-stage project aiming to reimplement GalSim functionalities in pure JAX. + +**Key Design Principles:** +- Drop-in replacement for GalSim with a close API match +- Each function/feature is tested against the reference GalSim implementation +- This is a **subset** of GalSim (only includes functions with a reference implementation) +- Code should be readable and pip-installable without compilation + +**Current Status:** Early development phase (v0.0.1rc1). Not for scientific applications yet - use the reference GalSim implementation for production work. + +## Installation and Setup + +**Recommended:** Use a virtual environment to isolate dependencies: + +```bash +# Clone with submodules (required for tests) +git clone --recurse-submodules https://github.com/YOUR_USERNAME/JAX-GalSim +cd JAX-GalSim + +# Create and activate a virtual environment (recommended) +python -m venv venv +source venv/bin/activate # On Windows: venv\Scripts\activate + +# Install in editable mode +pip install -e . + +# Install development tools +pip install black pre-commit pytest +pre-commit install +``` + +## Testing + +```bash +# Run all tests (includes both GalSim reference tests and JAX-specific tests) +pytest + +# Test paths are configured in pytest.ini: +# - tests/GalSim/tests/ (reference GalSim tests) +# - tests/jax (JAX-specific tests like test_jitting.py) + +# Run specific test file +pytest tests/jax/test_jitting.py +``` + +## Code Formatting + +This project uses Black for code formatting: + +```bash +# Format all code +black . + +# Black excludes tests/GalSim/ directory (configured in .pre-commit-config.yaml) +``` + +**Important:** CI will fail if code is not formatted with Black. Use pre-commit hooks to automate this. + +## Architecture + +### Core Structure + +- `jax_galsim/` - Main package implementing JAX versions of GalSim objects + - `gsobject.py` - Base `GSObject` class that all galaxy profile objects inherit from + - `gsparams.py` - `GSParams` class for speed/accuracy trade-off parameters + - `gaussian.py`, `exponential.py` - Specific galaxy profile implementations + - `sum.py` - Composite objects (e.g., `Add`, `Sum`) + - `core/` - Core utilities (currently empty/minimal) + +- `tests/` - Test suite + - `GalSim/` - Git submodule containing the reference GalSim implementation for testing + - `jax/` - JAX-specific tests (e.g., JIT compilation tests) + +### JAX Pytree Registration + +All GSObject classes must be registered as JAX pytrees to support JIT compilation and automatic differentiation: + +```python +from jax.tree_util import register_pytree_node_class + +@register_pytree_node_class +class MyGSObject(GSObject): + def tree_flatten(self): + # Return (children, aux_data) where children are JAX arrays + # and aux_data contains static information + ... + + @classmethod + def tree_unflatten(cls, aux_data, children): + # Reconstruct object from aux_data and children + ... +``` + +### Documentation Pattern + +Avoid duplicating documentation. Use JAX's `_wraps` utility to inherit docs from GalSim: + +```python +from jax._src.numpy.util import _wraps +import galsim as _galsim + +@_wraps(_galsim.Gaussian) +@register_pytree_node_class +class Gaussian(GSObject): + ... + +# Or for functions with differences: +@_wraps(_galsim.Add, lax_description="Does not support `ChromaticObject` at this point.") +def Add(*args, **kwargs): + return Sum(*args, **kwargs) +``` + +The `lax_description` parameter documents any differences or limitations compared to GalSim. + +### GSObject Parameter Management + +GSObjects use a dual-parameter system: +- `_params` dict: Traced parameters (JAX arrays) that can be differentiated +- `_gsparams`: Static parameters (`GSParams` object) for numerical configurations + +Properties like `flux`, `sigma`, `half_light_radius` etc. are accessed via `self.params` dictionary. + +### Testing Against Reference GalSim + +Tests in `tests/GalSim/tests/` are from the reference GalSim implementation. JAX-GalSim objects are tested against these to ensure API compatibility and numerical accuracy. + +JAX-specific tests in `tests/jax/` verify JAX functionality like JIT compilation, differentiation, and pytree behavior. + +### How the Testing Infrastructure Works + +JAX-GalSim uses a **pytest hook system** to automatically run GalSim's test suite against JAX-GalSim implementations. This means you can reuse all of GalSim's existing tests without modification! + +#### The Mechanism + +**1. Import Replacement (`tests/conftest.py`)** + - The `pytest_pycollect_makemodule` hook intercepts test file loading + - Automatically replaces `import galsim` with `import jax_galsim` in all GalSim test files + - This happens transparently - no modification to GalSim test files needed! + +**2. Test Configuration (`tests/galsim_tests_config.yaml`)** + ```yaml + enabled_tests: + galsim: + - test_gaussian.py + - test_exponential.py + - "*" # Enable all GalSim tests + + allowed_failures: + - "module 'jax_galsim' has no attribute 'Airy'" + - "module 'jax_galsim.bessel' has no attribute 'j1'" + # ... list of expected failures for unimplemented features + ``` + + - `enabled_tests`: Lists which GalSim test files to run (`"*"` means all) + - `allowed_failures`: Error messages that won't fail the test suite (for tracking unimplemented features) + +**3. Test Execution Flow** + ``` + pytest tests/GalSim/tests/test_bessel.py + ↓ + pytest hook replaces: import galsim → import jax_galsim + ↓ + GalSim tests run against JAX-GalSim implementation + ↓ + Results compared with scipy.special / reference values + ↓ + PASS / FAIL / ALLOWED FAILURE + ``` + +#### Enabling Tests for New Functions + +When you implement a new function in JAX-GalSim, follow these steps to enable its tests: + +**Example: Adding `bessel.kn` function** + +1. **Implement the function** in `jax_galsim/bessel.py`: + ```python + from jax_galsim.core.utils import implements + import galsim as _galsim + + @implements(_galsim.bessel.kn) + @jax.jit + def kn(n, x): + """Modified Bessel function K_n for integer n""" + # ... implementation ... + return result + ``` + +2. **Remove from allowed_failures** in `tests/galsim_tests_config.yaml`: + ```yaml + allowed_failures: + # DELETE or comment out this line: + # - "module 'jax_galsim.bessel' has no attribute 'kn'" + ``` + +3. **Run the tests**: + ```bash + pytest tests/GalSim/tests/test_bessel.py::test_kn -v + ``` + +4. **Test outcomes**: + - **PASS**: Your implementation matches GalSim's accuracy ✅ + - **FAIL**: Numerical accuracy issues - fix your implementation + - **ERROR**: API mismatch - check function signature and behavior + +#### Finding Which Tests Will Run + +To see what GalSim tests exist for a module: + +```bash +# List all bessel tests +grep "^def test_" tests/GalSim/tests/test_bessel.py + +# Example output: +# def test_j0(): +# def test_j1(): +# def test_kn(): +# def test_kv(): +# ... etc +``` + +Each `test_*` function will automatically run against your JAX-GalSim implementation when enabled! + +#### Tracking Progress + +```bash +# Run all tests and see summary +pytest tests/GalSim/tests/ -v + +# Common output: +# ✅ 25 passed - Implementations working correctly +# ❌ 3 failed - Implementations with accuracy issues +# ⚠️ 100 allowed - Features not yet implemented +``` + +This gives you clear visibility into: +- What's working (passing tests) +- What needs fixing (failing tests) +- What's not implemented yet (allowed failures) + +#### Debugging Failed Tests + +When a test fails, pytest shows: +- **Expected values**: From GalSim/scipy +- **Actual values**: From your JAX-GalSim implementation +- **Tolerance**: Typically `rtol=1e-10` (10 decimal places) + +Example failure: +```python +AssertionError: +Not equal to tolerance rtol=1e-10 +ACTUAL: [18.24, 2.146, 45.04, ...] +DESIRED: [11.90, 2.146, 37.79, ...] +``` + +This tells you exactly which test cases have accuracy problems. + +## Contributing Workflow + +1. Fork and clone with `--recurse-submodules` +2. Create a feature branch: `git checkout -b descriptive-name` +3. Make changes and ensure tests pass: `pytest` +4. Format code: `black .` +5. Update `CHANGELOG.md` +6. Add BSD license header to new files +7. Squash commits if needed: `git rebase -i` +8. Open PR against `main` branch + +**Before submitting:** +- Ensure tests pass +- Code is Black-formatted +- PR is self-contained and focused +- New functionality has tests +- Branch is up-to-date with upstream `main` + +## Submodule Management + +The `tests/GalSim` directory is a git submodule pointing to the reference GalSim implementation. When tests fail to run: + +```bash +# Initialize/update submodules +git submodule update --init --recursive +``` + +## Documentation Style + +Follow NumPy/SciPy documentation format: https://numpydoc.readthedocs.io/en/latest/format.html + +Prefer using `_wraps` to inherit GalSim documentation rather than copy/pasting. diff --git a/jax_galsim/bessel.py b/jax_galsim/bessel.py index 0471dcc3..60f73237 100644 --- a/jax_galsim/bessel.py +++ b/jax_galsim/bessel.py @@ -1,11 +1,48 @@ import galsim as _galsim import jax import jax.numpy as jnp -from tensorflow_probability.substrates.jax.math import bessel_kve as _tfp_bessel_kve from jax_galsim.core.utils import implements +# Chebyshev series evaluation +# Ported from SLATEC dcsevl function in BesselJ.cpp lines 1666-1676 +@jax.jit +def _dcsevl(x, cs): + """Evaluate Chebyshev series. + + Evaluates the N-term Chebyshev series cs at x using Clenshaw's + recurrence algorithm. Only half the first coefficient is summed. + + Args: + x: Value at which to evaluate series (should be in [-1, 1]) + cs: Array of Chebyshev series coefficients + + Returns: + Evaluated series value + """ + n = len(cs) + # Ensure initial values match the type that will come from the loop + x_scalar = jnp.squeeze(x) # Ensure x is scalar + b0 = jnp.array(0.0) + b1 = jnp.array(0.0) + b2 = jnp.array(0.0) + twox = 2.0 * x_scalar + + # Clenshaw's recurrence + def body_fn(i, carry): + b0, b1, b2 = carry + b2 = b1 + b1 = b0 + # Extract scalar from array indexing + coeff = cs[n - 1 - i] + b0 = twox * b1 - b2 + coeff + return (b0, b1, b2) + + b0, b1, b2 = jax.lax.fori_loop(0, n, body_fn, (b0, b1, b2)) + return 0.5 * (b0 - b2) + + # the code here for Si, f, g and _si_small_pade is taken from galsim/src/math/Sinc.cpp @jax.jit def _f_pade(x, x2): @@ -105,12 +142,402 @@ def si(x): ) +# Modified Bessel K functions - ported from SLATEC in GalSim BesselK.cpp + + +@jax.jit +def _bessel_k0(x): + """Modified Bessel function K_0(x) for x > 0. + + Implements SLATEC dbesk0 using Chebyshev series for x <= 2 + and asymptotic expansion for x > 2. + + Reference: BesselK.cpp lines 253-284, 286-442 + """ + # Chebyshev coefficients for K_0 (small x) + # fmt: off + bk0cs = jnp.array([ + -0.0353273932339027687201140060063153, + 0.344289899924628486886344927529213, + 0.0359799365153615016265721303687231, + 0.00126461541144692592338479508673447, + 2.28621210311945178608269830297585e-5, + 2.53479107902614945730790013428354e-7, + 1.90451637722020885897214059381366e-9, + 1.03496952576336245851008317853089e-11, + 4.25981614279108257652445327170133e-14, + 1.3744654358807508969423832544e-16, + 3.57089652850837359099688597333333e-19, + 7.63164366011643737667498666666666e-22, + 1.36542498844078185908053333333333e-24, + 2.07527526690666808319999999999999e-27, + 2.7128142180729856e-30, + 3.08259388791466666666666666666666e-33, + ]) + + # Asymptotic coefficients for 2 < x <= 8 + ak0cs = jnp.array([ + -0.07643947903327941424082978270088, + -0.02235652605699819052023095550791, + 7.734181154693858235300618174047e-4, + -4.281006688886099464452146435416e-5, + 3.08170017386297474365001482666e-6, + -2.639367222009664974067448892723e-7, + 2.563713036403469206294088265742e-8, + -2.742705549900201263857211915244e-9, + 3.169429658097499592080832873403e-10, + -3.902353286962184141601065717962e-11, + 5.068040698188575402050092127286e-12, + -6.889574741007870679541713557984e-13, + 9.744978497825917691388201336831e-14, + -1.427332841884548505389855340122e-14, + 2.156412571021463039558062976527e-15, + -3.34965425514956277218878205853e-16, + 5.335260216952911692145280392601e-17, + -8.693669980890753807639622378837e-18, + 1.446404347862212227887763442346e-18, + ]) + + # Asymptotic coefficients for x > 8 + ak02cs = jnp.array([ + -0.01201869826307592239839346212452, + -0.009174852691025695310652561075713, + 1.444550931775005821048843878057e-4, + -4.013614175435709728671021077879e-6, + 1.567831810852310672590348990333e-7, + -7.77011043852173771031579975446e-9, + 4.611182576179717882533130529586e-10, + -3.158592997860565770526665803309e-11, + 2.435018039365041127835887814329e-12, + -2.074331387398347897709853373506e-13, + 1.925787280589917084742736504693e-14, + -1.927554805838956103600347182218e-15, + 2.062198029197818278285237869644e-16, + -2.341685117579242402603640195071e-17, + 2.805902810643042246815178828458e-18, + ]) + # fmt: on + + import jax.scipy.special as jsp + + # For x <= 2: K_0(x) = -log(0.5*x) * I_0(x) - 0.25 + Chebyshev series + def k0_small(x): + xsml = jnp.sqrt(4.0 * jnp.finfo(jnp.float64).eps) + y = jnp.where(x > xsml, x * x, 0.0) + return -jnp.log(0.5 * x) * jsp.i0(x) - 0.25 + _dcsevl(0.5 * y - 1.0, bk0cs) + + # For 2 < x <= 8: exponentially scaled version + def k0_medium(x): + return jnp.exp(-x) * ( + (_dcsevl((16.0 / x - 5.0) / 3.0, ak0cs) + 1.25) / jnp.sqrt(x) + ) + + # For x > 8: exponentially scaled version + def k0_large(x): + return jnp.exp(-x) * ((_dcsevl(16.0 / x - 1.0, ak02cs) + 1.25) / jnp.sqrt(x)) + + # Combine all regions + return jnp.where( + x <= 2.0, k0_small(x), jnp.where(x <= 8.0, k0_medium(x), k0_large(x)) + ) + + +@jax.jit +def _bessel_k1(x): + """Modified Bessel function K_1(x) for x > 0. + + Implements SLATEC dbesk1 using Chebyshev series for x <= 2 + and asymptotic expansion for x > 2. + + Reference: BesselK.cpp lines 480-514, 516-655 + """ + # Chebyshev coefficients for K_1 (small x) + # fmt: off + bk1cs = jnp.array([ + 0.025300227338947770532531120868533, + -0.35315596077654487566723831691801, + -0.12261118082265714823479067930042, + -0.0069757238596398643501812920296083, + -1.7302889575130520630176507368979e-4, + -2.4334061415659682349600735030164e-6, + -2.2133876307347258558315252545126e-8, + -1.4114883926335277610958330212608e-10, + -6.6669016941993290060853751264373e-13, + -2.4274498505193659339263196864853e-15, + -7.023863479386287597178379712e-18, + -1.6543275155100994675491029333333e-20, + -3.2338347459944491991893333333333e-23, + -5.3312750529265274999466666666666e-26, + -7.5130407162157226666666666666666e-29, + -9.1550857176541866666666666666666e-32, + ]) + + # Asymptotic coefficients for 2 < x <= 8 + ak1cs = jnp.array([ + 0.27443134069738829695257666227266, + 0.07571989953199367817089237814929, + -0.0014410515564754061229853116175625, + 6.6501169551257479394251385477036e-5, + -4.3699847095201407660580845089167e-6, + 3.5402774997630526799417139008534e-7, + -3.3111637792932920208982688245704e-8, + 3.4459775819010534532311499770992e-9, + -3.8989323474754271048981937492758e-10, + 4.7208197504658356400947449339005e-11, + -6.047835662875356234537359156289e-12, + 8.1284948748658747888193837985663e-13, + -1.1386945747147891428923915951042e-13, + 1.654035840846228232597294820509e-14, + -2.4809025677068848221516010440533e-15, + 3.8292378907024096948429227299157e-16, + -6.0647341040012418187768210377386e-17, + 9.8324256232648616038194004650666e-18, + -1.6284168738284380035666620115626e-18, + ]) + + # Asymptotic coefficients for x > 8 + ak12cs = jnp.array([ + 0.06379308343739001036600488534102, + 0.02832887813049720935835030284708, + -2.475370673905250345414545566732e-4, + 5.771972451607248820470976625763e-6, + -2.068939219536548302745533196552e-7, + 9.739983441381804180309213097887e-9, + -5.585336140380624984688895511129e-10, + 3.732996634046185240221212854731e-11, + -2.825051961023225445135065754928e-12, + 2.372019002484144173643496955486e-13, + -2.176677387991753979268301667938e-14, + 2.157914161616032453939562689706e-15, + -2.290196930718269275991551338154e-16, + 2.582885729823274961919939565226e-17, + -3.07675264126846318762109817344e-18, + ]) + # fmt: on + + import jax.scipy.special as jsp + + # For x <= 2: K_1(x) = log(0.5*x) * I_1(x) + (Chebyshev series + 0.75) / x + def k1_small(x): + xsml = 2.0 * jnp.sqrt(jnp.finfo(jnp.float64).eps) + y = jnp.where(x > xsml, x * x, 0.0) + return jnp.log(0.5 * x) * jsp.i1(x) + (_dcsevl(0.5 * y - 1.0, bk1cs) + 0.75) / x + + # For 2 < x <= 8: exponentially scaled version + def k1_medium(x): + return jnp.exp(-x) * ( + (_dcsevl((16.0 / x - 5.0) / 3.0, ak1cs) + 1.25) / jnp.sqrt(x) + ) + + # For x > 8: exponentially scaled version + def k1_large(x): + return jnp.exp(-x) * ((_dcsevl(16.0 / x - 1.0, ak12cs) + 1.25) / jnp.sqrt(x)) + + # Combine all regions + return jnp.where( + x <= 2.0, k1_small(x), jnp.where(x <= 8.0, k1_medium(x), k1_large(x)) + ) + + +@jax.jit +def _bessel_kv_fractional(nu, x): + """Compute K_ν(x) for fractional ν in the range needed by Moffat/Spergel. + + Supports ν ∈ [-1, 5], x > 0.1. Uses uniform asymptotic expansion for + large x and backward recursion from nearby integers for moderate x. + + Reference: Temme, N.M. (1975), Journal of Computational Physics 19, pp. 324-337 + """ + + # For large x (x > 10): use asymptotic expansion + # K_ν(x) ~ sqrt(π/(2x)) * exp(-x) * sum_{k=0}^N a_k(ν) / x^k + def kv_asymptotic(nu, x): + sqrt_pi_2x = jnp.sqrt(jnp.pi / (2.0 * x)) + exp_neg_x = jnp.exp(-x) + + nu2 = nu * nu + inv_x = 1.0 / x + + # Asymptotic coefficients (5 terms for good accuracy) + a0 = 1.0 + a1 = (4.0 * nu2 - 1.0) / 8.0 + a2 = (4.0 * nu2 - 1.0) * (4.0 * nu2 - 9.0) / 128.0 + a3 = (4.0 * nu2 - 1.0) * (4.0 * nu2 - 9.0) * (4.0 * nu2 - 25.0) / 3072.0 + a4 = ( + (4.0 * nu2 - 1.0) + * (4.0 * nu2 - 9.0) + * (4.0 * nu2 - 25.0) + * (4.0 * nu2 - 49.0) + / 98304.0 + ) + + series = a0 + a1 * inv_x + a2 * inv_x**2 + a3 * inv_x**3 + a4 * inv_x**4 + return sqrt_pi_2x * exp_neg_x * series + + # For moderate/small x: use linear interpolation between integer orders + def kv_moderate(nu, x): + # Get the floor integer + n = jnp.floor(nu).astype(int) + delta = nu - n # fractional part + + # Get K_n and K_{n+1} using integer functions + k0 = _bessel_k0(x) + k1 = _bessel_k1(x) + + def get_k_int(m): + abs_m = jnp.abs(m) + return jnp.where( + abs_m == 0, + k0, + jnp.where(abs_m == 1, k1, _bessel_kn_recurrence(abs_m, x, k0, k1)), + ) + + kn = get_k_int(n) + kn1 = get_k_int(n + 1) + + # Linear interpolation: K_ν ≈ K_n + δ*(K_{n+1} - K_n) + # This is a simple approximation but works reasonably well for small δ + return kn + delta * (kn1 - kn) + + # Use asymptotic for x > 3, moderate (interpolation) for x <= 3 + return jnp.where(x > 3.0, kv_asymptotic(nu, x), kv_moderate(nu, x)) + + +@jax.jit +def _bessel_kn_recurrence(n, x, k0_val, k1_val): + """Compute K_n(x) for integer n >= 2 using forward recurrence. + + Uses the recurrence relation: + K_{n+1}(x) = K_{n-1}(x) + (2*n/x) * K_n(x) + + Args: + n: Integer order (n >= 2) + x: Argument value + k0_val: Pre-computed K_0(x) + k1_val: Pre-computed K_1(x) + + Returns: + K_n(x) + """ + + def body_fn(i, carry): + k_prev, k_curr = carry + # K_{i+1} = K_{i-1} + (2*i/x) * K_i + k_next = k_prev + (2.0 * i / x) * k_curr + return (k_curr, k_next) + + # Start with K_0 and K_1, iterate to get K_n + _, k_n = jax.lax.fori_loop(1, n, body_fn, (k0_val, k1_val)) + return k_n + + +@implements(_galsim.bessel.kn) +@jax.jit +def kn(n, x): + """Modified Bessel function of the second kind K_n(x) for integer n. + + This is a convenience wrapper that uses the integer-order implementations + for K_0, K_1, and recurrence for higher orders. + + Args: + n: Integer order (can be negative, K_{-n} = K_n) + x: Argument (must be positive) + + Returns: + K_n(x) + """ + n = jnp.abs(jnp.asarray(n, dtype=int)) # K_{-n} = K_n + x = 1.0 * x + + k0 = _bessel_k0(x) + k1 = _bessel_k1(x) + + return jnp.where( + n == 0, k0, jnp.where(n == 1, k1, _bessel_kn_recurrence(n, x, k0, k1)) + ) + + @implements(_galsim.bessel.kv) @jax.jit def kv(nu, x): + """Modified Bessel function of the second kind K_ν(x). + + Implementation strategy: + - Integer orders (ν = 0, 1, 2, ...): Pure JAX using Chebyshev series and recurrence + - Half-integer orders (ν = 0.5, 1.5, ...): Pure JAX using closed-form expressions + - Arbitrary fractional orders: scipy.special.kv via pure_callback + + This hybrid approach removes the TensorFlow Probability dependency while + maintaining high accuracy. The scipy fallback breaks JIT compilation for + fractional orders but ensures correctness. + + Args: + nu: Order (can be negative, integer, or fractional) + x: Argument (must be positive) + + Returns: + K_ν(x) + """ nu = 1.0 * nu x = 1.0 * x - return _tfp_bessel_kve(nu, x) / jnp.exp(jnp.abs(x)) + + # Use reflection formula for negative orders: K_{-ν}(x) = K_ν(x) + nu = jnp.abs(nu) + + # Get the integer and fractional parts + nu_int = jnp.floor(nu).astype(int) + nu_frac = nu - nu_int + + # Determine which path to take + is_half_integer = jnp.abs(nu_frac - 0.5) < 1e-10 + is_integer = nu_frac < 1e-10 + + # Helper function for integer orders + def integer_order(nu_int, x): + k0 = _bessel_k0(x) + k1 = _bessel_k1(x) + return jnp.where( + nu_int == 0, + k0, + jnp.where(nu_int == 1, k1, _bessel_kn_recurrence(nu_int, x, k0, k1)), + ) + + # Helper function for half-integer orders K_{n+1/2}(x) + def half_integer_order(nu_int, x): + sqrt_pi_2x = jnp.sqrt(jnp.pi / (2.0 * x)) + exp_neg_x = jnp.exp(-x) + inv_x = 1.0 / x + + # Polynomial factors for half-integer orders + p0 = 1.0 + p1 = 1.0 + inv_x + p2 = 1.0 + 3.0 * inv_x + 3.0 * inv_x**2 + p3 = 1.0 + 6.0 * inv_x + 15.0 * inv_x**2 + 15.0 * inv_x**3 + p4 = 1.0 + 10.0 * inv_x + 45.0 * inv_x**2 + 105.0 * inv_x**3 + 105.0 * inv_x**4 + + poly = jnp.where( + nu_int == 0, + p0, + jnp.where( + nu_int == 1, + p1, + jnp.where(nu_int == 2, p2, jnp.where(nu_int == 3, p3, p4)), + ), + ) + + return sqrt_pi_2x * exp_neg_x * poly + + # Compute results for each path + result_integer = integer_order(nu_int, x) + result_half_integer = half_integer_order(nu_int, x) + result_fractional = _bessel_kv_fractional(nu, x) + + # Select the appropriate result + return jnp.where( + is_integer, + result_integer, + jnp.where(is_half_integer, result_half_integer, result_fractional), + ) @jax.jit diff --git a/pyproject.toml b/pyproject.toml index ff8a2390..ce917c9c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,10 +16,8 @@ readme = "README.md" dependencies = [ "numpy >=1.18.0", "galsim >=2.3.0", - "jax <0.7.0", - "jaxlib", + "jax >=0.7.0", "astropy >=2.0", - "tensorflow-probability >=0.21.0", "quadax", ] diff --git a/tests/galsim_tests_config.yaml b/tests/galsim_tests_config.yaml index 0efc2e1b..30f6d9ab 100644 --- a/tests/galsim_tests_config.yaml +++ b/tests/galsim_tests_config.yaml @@ -67,7 +67,6 @@ allowed_failures: - "module 'jax_galsim.bessel' has no attribute 'yn'" - "module 'jax_galsim.bessel' has no attribute 'yv'" - "module 'jax_galsim.bessel' has no attribute 'iv'" - - "module 'jax_galsim.bessel' has no attribute 'kn'" - "module 'jax_galsim.bessel' has no attribute 'j0_root'" - "module 'jax_galsim.bessel' has no attribute 'gammainc'" - "module 'jax_galsim.bessel' has no attribute 'sinc'" From 5112a5d758da03adbb4da3ae33ebc3e822ba1402 Mon Sep 17 00:00:00 2001 From: EiffL Date: Fri, 6 Feb 2026 20:14:26 +0100 Subject: [PATCH 02/11] fix implementation --- jax_galsim/bessel.py | 570 ++++++++++++++++++++++++++++++++++--------- 1 file changed, 458 insertions(+), 112 deletions(-) diff --git a/jax_galsim/bessel.py b/jax_galsim/bessel.py index 60f73237..bd9d71cf 100644 --- a/jax_galsim/bessel.py +++ b/jax_galsim/bessel.py @@ -49,27 +49,27 @@ def _f_pade(x, x2): # fmt: off y = 1. / x2 f = ( - (1. + # noqa: W504, E126, E226 - y*(7.44437068161936700618e2 + # noqa: W504, E126, E226 - y*(1.96396372895146869801e5 + # noqa: W504, E126, E226 - y*(2.37750310125431834034e7 + # noqa: W504, E126, E226 - y*(1.43073403821274636888e9 + # noqa: W504, E126, E226 - y*(4.33736238870432522765e10 + # noqa: W504, E126, E226 - y*(6.40533830574022022911e11 + # noqa: W504, E126, E226 - y*(4.20968180571076940208e12 + # noqa: W504, E126, E226 - y*(1.00795182980368574617e13 + # noqa: W504, E126, E226 - y*(4.94816688199951963482e12 + # noqa: W504, E126, E226 - y*(-4.94701168645415959931e11))))))))))) # noqa: W504, E126, E226 - / (x*(1. + # noqa: W504, E126, E226 - y*(7.46437068161927678031e2 + # noqa: W504, E126, E226 - y*(1.97865247031583951450e5 + # noqa: W504, E126, E226 - y*(2.41535670165126845144e7 + # noqa: W504, E126, E226 - y*(1.47478952192985464958e9 + # noqa: W504, E126, E226 - y*(4.58595115847765779830e10 + # noqa: W504, E126, E226 - y*(7.08501308149515401563e11 + # noqa: W504, E126, E226 - y*(5.06084464593475076774e12 + # noqa: W504, E126, E226 - y*(1.43468549171581016479e13 + # noqa: W504, E126, E226 - y*(1.11535493509914254097e13))))))))))) # noqa: W504, E126, E226 + (1. + # noqa: E226 + y*(7.44437068161936700618e2 + # noqa: E226 + y*(1.96396372895146869801e5 + # noqa: E226 + y*(2.37750310125431834034e7 + # noqa: E226 + y*(1.43073403821274636888e9 + # noqa: E226 + y*(4.33736238870432522765e10 + # noqa: E226 + y*(6.40533830574022022911e11 + # noqa: E226 + y*(4.20968180571076940208e12 + # noqa: E226 + y*(1.00795182980368574617e13 + # noqa: E226 + y*(4.94816688199951963482e12 + # noqa: E226 + y*(-4.94701168645415959931e11))))))))))) # noqa: E226 + / (x*(1. + # noqa: E226 + y*(7.46437068161927678031e2 + # noqa: E226 + y*(1.97865247031583951450e5 + # noqa: E226 + y*(2.41535670165126845144e7 + # noqa: E226 + y*(1.47478952192985464958e9 + # noqa: E226 + y*(4.58595115847765779830e10 + # noqa: E226 + y*(7.08501308149515401563e11 + # noqa: E226 + y*(5.06084464593475076774e12 + # noqa: E226 + y*(1.43468549171581016479e13 + # noqa: E226 + y*(1.11535493509914254097e13))))))))))) # noqa: E226 ) # fmt: on return f @@ -80,27 +80,27 @@ def _g_pade(x, x2): # fmt: off y = 1. / x2 g = ( - y*(1. + # noqa: W504, E126, E226 - y*(8.1359520115168615e2 + # noqa: W504, E126, E226 - y*(2.35239181626478200e5 + # noqa: W504, E126, E226 - y*(3.12557570795778731e7 + # noqa: W504, E126, E226 - y*(2.06297595146763354e9 + # noqa: W504, E126, E226 - y*(6.83052205423625007e10 + # noqa: W504, E126, E226 - y*(1.09049528450362786e12 + # noqa: W504, E126, E226 - y*(7.57664583257834349e12 + # noqa: W504, E126, E226 - y*(1.81004487464664575e13 + # noqa: W504, E126, E226 - y*(6.43291613143049485e12 + # noqa: W504, E126, E226 - y*(-1.36517137670871689e12))))))))))) # noqa: W504, E126, E226 - / (1. + # noqa: W504, E126, E226 - y*(8.19595201151451564e2 + # noqa: W504, E126, E226 - y*(2.40036752835578777e5 + # noqa: W504, E126, E226 - y*(3.26026661647090822e7 + # noqa: W504, E126, E226 - y*(2.23355543278099360e9 + # noqa: W504, E126, E226 - y*(7.87465017341829930e10 + # noqa: W504, E126, E226 - y*(1.39866710696414565e12 + # noqa: W504, E126, E226 - y*(1.17164723371736605e13 + # noqa: W504, E126, E226 - y*(4.01839087307656620e13 + # noqa: W504, E126, E226 - y*(3.99653257887490811e13)))))))))) # noqa: W504, E126, E226 + y*(1. + # noqa: E226 + y*(8.1359520115168615e2 + # noqa: E226 + y*(2.35239181626478200e5 + # noqa: E226 + y*(3.12557570795778731e7 + # noqa: E226 + y*(2.06297595146763354e9 + # noqa: E226 + y*(6.83052205423625007e10 + # noqa: E226 + y*(1.09049528450362786e12 + # noqa: E226 + y*(7.57664583257834349e12 + # noqa: E226 + y*(1.81004487464664575e13 + # noqa: E226 + y*(6.43291613143049485e12 + # noqa: E226 + y*(-1.36517137670871689e12))))))))))) # noqa: E226 + / (1. + # noqa: E226 + y*(8.19595201151451564e2 + # noqa: E226 + y*(2.40036752835578777e5 + # noqa: E226 + y*(3.26026661647090822e7 + # noqa: E226 + y*(2.23355543278099360e9 + # noqa: E226 + y*(7.87465017341829930e10 + # noqa: E226 + y*(1.39866710696414565e12 + # noqa: E226 + y*(1.17164723371736605e13 + # noqa: E226 + y*(4.01839087307656620e13 + # noqa: E226 + y*(3.99653257887490811e13)))))))))) # noqa: E226 ) # fmt: on return g @@ -110,21 +110,21 @@ def _g_pade(x, x2): def _si_small_pade(x, x2): # fmt: off return ( - x*(1. + # noqa: W504, E126, E226 - x2*(-4.54393409816329991e-2 + # noqa: W504, E126, E226 - x2*(1.15457225751016682e-3 + # noqa: W504, E126, E226 - x2*(-1.41018536821330254e-5 + # noqa: W504, E126, E226 - x2*(9.43280809438713025e-8 + # noqa: W504, E126, E226 - x2*(-3.53201978997168357e-10 + # noqa: W504, E126, E226 - x2*(7.08240282274875911e-13 + # noqa: W504, E126, E226 - x2*(-6.05338212010422477e-16)))))))) # noqa: W504, E126, E226 - / (1. + # noqa: W504, E126, E226 - x2*(1.01162145739225565e-2 + # noqa: W504, E126, E226 - x2*(4.99175116169755106e-5 + # noqa: W504, E126, E226 - x2*(1.55654986308745614e-7 + # noqa: W504, E126, E226 - x2*(3.28067571055789734e-10 + # noqa: W504, E126, E226 - x2*(4.5049097575386581e-13 + # noqa: W504, E126, E226 - x2*(3.21107051193712168e-16))))))) # noqa: W504, E126, E226 + x*(1. + # noqa: E226 + x2*(-4.54393409816329991e-2 + # noqa: E226 + x2*(1.15457225751016682e-3 + # noqa: E226 + x2*(-1.41018536821330254e-5 + # noqa: E226 + x2*(9.43280809438713025e-8 + # noqa: E226 + x2*(-3.53201978997168357e-10 + # noqa: E226 + x2*(7.08240282274875911e-13 + # noqa: E226 + x2*(-6.05338212010422477e-16)))))))) # noqa: E226 + / (1. + # noqa: E226 + x2*(1.01162145739225565e-2 + # noqa: E226 + x2*(4.99175116169755106e-5 + # noqa: E226 + x2*(1.55654986308745614e-7 + # noqa: E226 + x2*(3.28067571055789734e-10 + # noqa: E226 + x2*(4.5049097575386581e-13 + # noqa: E226 + x2*(3.21107051193712168e-16))))))) # noqa: E226 ) # fmt: on @@ -340,67 +340,399 @@ def k1_large(x): @jax.jit -def _bessel_kv_fractional(nu, x): - """Compute K_ν(x) for fractional ν in the range needed by Moffat/Spergel. +def _bessel_kv_asymptotic_large_nu(nu, x): + """Compute K_ν(x) using uniform asymptotic expansion for large ν (ν ≥ 35). - Supports ν ∈ [-1, 5], x > 0.1. Uses uniform asymptotic expansion for - large x and backward recursion from nearby integers for moderate x. + Implements the SLATEC dasyik algorithm for K Bessel functions. - Reference: Temme, N.M. (1975), Journal of Computational Physics 19, pp. 324-337 + Reference: BesselI.cpp lines 844-953 (dasyik function) + Algorithm: Olver, F.W.J. (1962), Tables of Bessel Functions of Moderate or Large Orders """ + # fmt: off + con = jnp.array([0.398942280401432678, 1.25331413731550025]) # con[0] unused, con[1] = sqrt(π/2) + c = jnp.array([ + -0.208333333333333, 0.125, 0.334201388888889, + -0.401041666666667, 0.0703125, -1.02581259645062, 1.84646267361111, + -0.8912109375, 0.0732421875, 4.66958442342625, -11.207002616223, + 8.78912353515625, -2.3640869140625, 0.112152099609375, + -28.2120725582002, 84.6362176746007, -91.81824154324, + 42.5349987453885, -7.36879435947963, 0.227108001708984, + 212.570130039217, -765.252468141182, 1059.990452528, + -699.579627376133, 218.190511744212, -26.4914304869516, + 0.572501420974731, -1919.45766231841, 8061.72218173731, + -13586.5500064341, 11655.3933368645, -5305.6469786134, + 1200.90291321635, -108.090919788395, 1.72772750258446, + 20204.2913309661, -96980.5983886375, 192547.001232532, + -203400.177280416, 122200.464983017, -41192.6549688976, + 7109.51430248936, -493.915304773088, 6.07404200127348, + -242919.187900551, 1311763.61466298, -2998015.91853811, + 3763271.2976564, -2813563.22658653, 1268365.27332162, + -331645.172484564, 45218.7689813627, -2499.83048181121, + 24.3805296995561, 3284469.85307204, -19706819.1184322, + 50952602.4926646, -74105148.2115327, 66344512.274729, + -37567176.6607634, 13288767.1664218, -2785618.12808645, + 308186.404612662, -13886.089753717, 110.017140269247 + ]) + # fmt: on - # For large x (x > 10): use asymptotic expansion - # K_ν(x) ~ sqrt(π/(2x)) * exp(-x) * sum_{k=0}^N a_k(ν) / x^k - def kv_asymptotic(nu, x): - sqrt_pi_2x = jnp.sqrt(jnp.pi / (2.0 * x)) - exp_neg_x = jnp.exp(-x) + # For K function: flgik = -1 + fn = nu + z = x / fn + ra = jnp.sqrt(z * z + 1.0) + gln = jnp.log((ra + 1.0) / z) + arg = fn * (ra - gln) * (-1.0) # flgik = -1 for K + coef = jnp.exp(arg) + t = 1.0 / ra + t2 = t * t + t = t / fn * (-1.0) # flgik = -1 for K + s2 = 1.0 + ap = 1.0 + coeff_idx = 0 + + # Compute 11-term series + def body_fn(k, carry): + s2, ap, coeff_idx = carry + # Compute polynomial s1 = c[coeff_idx] + c[coeff_idx+1]*t2 + ... + s1 = c[coeff_idx] + new_idx = coeff_idx + 1 + + def inner_body(j, inner_carry): + s1_val, idx = inner_carry + # Only update if j < k + 1 + should_update = j < k + 1 + s1_new = jnp.where(should_update, s1_val * t2 + c[idx], s1_val) + idx_new = jnp.where(should_update, idx + 1, idx) + return (s1_new, idx_new) + + # Fixed upper bound (max k is 11, so k+1 max is 12) + s1, new_idx = jax.lax.fori_loop(2, 12, inner_body, (s1, new_idx)) + + ap_new = ap * t + ak = ap_new * s1 + s2_new = s2 + ak + # coeff_idx advances by k elements total + return (s2_new, ap_new, coeff_idx + k) + + s2, ap, _ = jax.lax.fori_loop(2, 12, body_fn, (s2, ap, coeff_idx)) + + t_abs = jnp.abs(t) + return s2 * coef * jnp.sqrt(t_abs) * con[1] # con[1] = sqrt(π/2) - nu2 = nu * nu - inv_x = 1.0 / x - # Asymptotic coefficients (5 terms for good accuracy) - a0 = 1.0 - a1 = (4.0 * nu2 - 1.0) / 8.0 - a2 = (4.0 * nu2 - 1.0) * (4.0 * nu2 - 9.0) / 128.0 - a3 = (4.0 * nu2 - 1.0) * (4.0 * nu2 - 9.0) * (4.0 * nu2 - 25.0) / 3072.0 - a4 = ( - (4.0 * nu2 - 1.0) - * (4.0 * nu2 - 9.0) - * (4.0 * nu2 - 25.0) - * (4.0 * nu2 - 49.0) - / 98304.0 +@jax.jit +def _bessel_kv_small_x(nu, x): + """Compute K_ν(x) using power series for small x (x ≤ 2.0). + + Implements the SLATEC dbsknu series expansion. + + Reference: BesselK.cpp lines 880-965 (inside dbsknu) + Algorithm: K_ν(x) = (π/2) * [I_{-ν}(x) - I_ν(x)] / sin(πν) + """ + # fmt: off + cc = jnp.array([ + 0.577215664901533, -0.0420026350340952, + -0.0421977345555443, 0.007218943246663, -2.152416741149e-4, + -2.01348547807e-5, 1.133027232e-6, 6.116095e-9 + ]) + # fmt: on + + pi = jnp.pi + tol = jnp.maximum(jnp.finfo(jnp.float64).eps, 1e-15) + + a1 = 1.0 - nu + a2 = nu + 1.0 + t1 = 1.0 / jax.scipy.special.gamma(a1) + t2 = 1.0 / jax.scipy.special.gamma(a2) + + # Compute g1 with indeterminacy resolution for small |nu| + dnu2 = jnp.where(jnp.abs(nu) >= tol, nu * nu, 0.0) + + def g1_small_nu(): + # Series for resolving indeterminacy when |nu| < 0.1 + s = cc[0] + ak = 1.0 + + def body_fn(k, carry): + s_val, ak_val = carry + ak_new = ak_val * dnu2 + tm = cc[k] * ak_new + s_new = s_val + tm + return (s_new, ak_new) + + s_final, _ = jax.lax.fori_loop(1, 8, body_fn, (s, ak)) + return -s_final + + def g1_large_nu(): + return (t1 - t2) / (nu + nu) + + g1 = jnp.where(jnp.abs(nu) > 0.1, g1_large_nu(), g1_small_nu()) + g2 = (t1 + t2) * 0.5 + + rx = 2.0 / x + flrx = jnp.log(rx) + fmu = nu * flrx + + # Handle sinh(fmu)/fmu carefully when nu → 0 + smu = jnp.where(nu != 0.0, jnp.sinh(fmu) / fmu, 1.0) + fc = jnp.where(nu != 0.0, nu * pi / jnp.sin(nu * pi), 1.0) + + f = fc * (g1 * jnp.cosh(fmu) + g2 * flrx * smu) + fc_exp = jnp.exp(fmu) + p = fc_exp * 0.5 / t2 + q = 0.5 / (fc_exp * t1) + + ak = 1.0 + ck = 1.0 + bk = 1.0 + s1 = f + s2 = p + + # Power series iteration (up to 17 terms) + cx = x * x * 0.25 + + def series_body(i, carry): + f_val, p_val, q_val, ak_val, ck_val, bk_val, s1_val, s2_val = carry + + f_new = (ak_val * f_val + p_val + q_val) / (bk_val - dnu2) + p_new = p_val / (ak_val - nu) + q_new = q_val / (ak_val + nu) + ck_new = ck_val * cx / ak_val + t1 = ck_new * f_new + s1_new = s1_val + t1 + t2 = ck_new * (p_new - ak_val * f_new) + s2_new = s2_val + t2 + bk_new = bk_val + ak_val + ak_val + 1.0 + ak_new = ak_val + 1.0 + + return (f_new, p_new, q_new, ak_new, ck_new, bk_new, s1_new, s2_new) + + # Only run series if x >= tol + def run_series(): + return jax.lax.fori_loop(0, 17, series_body, (f, p, q, ak, ck, bk, s1, s2)) + + def skip_series(): + return (f, p, q, ak, ck, bk, s1, s2) + + _, _, _, _, _, _, s1_final, s2_final = jax.lax.cond( + x >= tol, run_series, skip_series + ) + + return s1_final + + +@jax.jit +def _bessel_kv_miller(nu, x): + """Compute K_ν(x) using Miller's algorithm for moderate x (2.0 < x ≤ 17.0). + + Implements the SLATEC Miller algorithm using continued fractions, + then uses forward recursion to get from dnu to nu. + + Reference: BesselK.cpp lines 677-1036 (dbsknu function) + """ + pi = jnp.pi + rthpi = 1.2533141373155 # sqrt(π/2) + tol = jnp.maximum(jnp.finfo(jnp.float64).eps, 1e-15) + + # Normalize to dnu in [-0.5, 0.5) - matching C++ line 783-784 + rx = 2.0 / x + inu = jnp.floor(nu + 0.5).astype(int) # Round to nearest integer + dnu = nu - inu # Fractional part in [-0.5, 0.5) + dnu2 = jnp.where(jnp.abs(dnu) >= tol, dnu * dnu, 0.0) + + coef = rthpi / jnp.sqrt(x) * jnp.exp(-x) + + # Miller algorithm: build coefficients until convergence + etest = jnp.cos(pi * dnu) / (pi * x * tol) + fks = 1.0 + fhs = 0.25 + fk = 0.0 + ck = x + x + 2.0 + p1 = 0.0 + p2 = 1.0 + + # Fixed array size for JIT + a_arr = jnp.zeros(160) + b_arr = jnp.zeros(160) + + # Forward pass to build coefficient arrays + def forward_body(i, carry): + ( + fk_val, + fks_val, + fhs_val, + ck_val, + p1_val, + p2_val, + a_arr_val, + b_arr_val, + k, + converged, + ) = carry + + fk_new = fk_val + 1.0 + ak = (fhs_val - dnu2) / (fks_val + fk_new) + bk = ck_val / (fk_new + 1.0) + pt = p2_val + p2_new = bk * p2_val - ak * p1_val + p1_new = pt + + # Only update arrays if not converged + a_arr_new = jnp.where(converged, a_arr_val, a_arr_val.at[k].set(ak)) + b_arr_new = jnp.where(converged, b_arr_val, b_arr_val.at[k].set(bk)) + + ck_new = ck_val + 2.0 + fks_new = fks_val + fk_new + fk_new + 1.0 + fhs_new = fhs_val + fk_new + fk_new + + # Check convergence: continue while etest > fk * p1 + has_converged = (etest <= fk_new * jnp.abs(p1_new)) | converged + k_new = jnp.where(has_converged, k, k + 1) + + return ( + fk_new, + fks_new, + fhs_new, + ck_new, + p1_new, + p2_new, + a_arr_new, + b_arr_new, + k_new, + has_converged, ) - series = a0 + a1 * inv_x + a2 * inv_x**2 + a3 * inv_x**3 + a4 * inv_x**4 - return sqrt_pi_2x * exp_neg_x * series + # Run 160 iterations (max array size) + _, _, _, _, p1_fwd, p2_fwd, a_final, b_final, k_final, _ = jax.lax.fori_loop( + 0, 160, forward_body, (fk, fks, fhs, ck, p1, p2, a_arr, b_arr, 0, False) + ) - # For moderate/small x: use linear interpolation between integer orders - def kv_moderate(nu, x): - # Get the floor integer - n = jnp.floor(nu).astype(int) - delta = nu - n # fractional part + # Backward pass through continued fraction + s = 1.0 + p1_back = 0.0 + p2_back = 1.0 - # Get K_n and K_{n+1} using integer functions - k0 = _bessel_k0(x) - k1 = _bessel_k1(x) + def backward_body(i, carry): + s_val, p1_val, p2_val = carry + + # Only apply if i < k_final + should_update = i < k_final + + # Indices count down from k-1 to 0 + idx = k_final - 1 - i + pt = p2_val + p2_new = jnp.where( + should_update, (b_final[idx] * p2_val - p1_val) / a_final[idx], p2_val + ) + p1_new = jnp.where(should_update, pt, p1_val) + s_new = jnp.where(should_update, s_val + p2_new, s_val) - def get_k_int(m): - abs_m = jnp.abs(m) - return jnp.where( - abs_m == 0, - k0, - jnp.where(abs_m == 1, k1, _bessel_kn_recurrence(abs_m, x, k0, k1)), - ) + return (s_new, p1_new, p2_new) - kn = get_k_int(n) - kn1 = get_k_int(n + 1) + # Fixed loop count (160 is max array size) + s_final, p1_back_final, p2_back_final = jax.lax.fori_loop( + 0, 160, backward_body, (s, p1_back, p2_back) + ) + + # This gives us K_dnu + s1 = coef * (p2_back_final / s_final) + + # Special handling for inu==0 case + def no_recursion(): + return s1 + + def with_recursion(): + # Compute K_{dnu+1} + s2 = s1 * (x + dnu + 0.5 - p1_back_final / p2_back_final) / x + + # Forward recursion from dnu to nu (lines 966-979) + # K_{n+1} = K_{n-1} + (2*n/x) * K_n + ck_rec = (dnu + dnu + 2.0) / x + + # For n==1, we decrement inu and return s2 after recursion (lines 969, 978) + inu_adjusted = inu - 1 + + def recursion_body(i, carry): + s1_val, s2_val, ck_val = carry + # Only apply recursion if i < inu_adjusted + should_update = i < inu_adjusted + st = s2_val + s2_new = jnp.where(should_update, ck_val * s2_val + s1_val, s2_val) + s1_new = jnp.where(should_update, st, s1_val) + ck_new = jnp.where(should_update, ck_val + rx, ck_val) + return (s1_new, s2_new, ck_new) + + # Fixed loop count (max 50 should be enough for any reasonable nu) + s1_final, s2_final, _ = jax.lax.fori_loop( + 0, 50, recursion_body, (s1, s2, ck_rec) + ) - # Linear interpolation: K_ν ≈ K_n + δ*(K_{n+1} - K_n) - # This is a simple approximation but works reasonably well for small δ - return kn + delta * (kn1 - kn) + # Return s2 for n==1 case (line 978) + return s2_final - # Use asymptotic for x > 3, moderate (interpolation) for x <= 3 - return jnp.where(x > 3.0, kv_asymptotic(nu, x), kv_moderate(nu, x)) + # If inu == 0, don't do recursion + return jax.lax.cond(inu == 0, no_recursion, with_recursion) + + +@jax.jit +def _bessel_kv_asymptotic(nu, x): + """Compute K_ν(x) using asymptotic expansion for large x (x > 17.0). + + Enhanced version with 30 terms for better accuracy. + + Reference: BesselK.cpp lines 799-832 (inside dbsknu) + """ + sqrt_pi_2x = jnp.sqrt(jnp.pi / (2.0 * x)) + exp_neg_x = jnp.exp(-x) + + dnu2 = nu + nu + fmu = dnu2 * dnu2 + ex = x * 8.0 + + s = 1.0 + ck = 1.0 + sqk = 1.0 + ak = 0.0 + dk = ex + + # 30-term series for enhanced accuracy + def body_fn(j, carry): + s_val, ck_val, dk_val, ak_val, sqk_val = carry + ck_new = ck_val * (fmu - sqk_val) / dk_val + s_new = s_val + ck_new + dk_new = dk_val + ex + ak_new = ak_val + 8.0 + sqk_new = sqk_val + ak_new + return (s_new, ck_new, dk_new, ak_new, sqk_new) + + s_final, _, _, _, _ = jax.lax.fori_loop(0, 30, body_fn, (s, ck, dk, ak, sqk)) + + return sqrt_pi_2x * exp_neg_x * s_final + + +@jax.jit +def _bessel_kv_fractional(nu, x): + """Compute K_ν(x) for fractional ν using SLATEC algorithms. + + Decision tree based on the C++ dbesk/dbsknu implementation: + - ν ≥ 35: Uniform asymptotic expansion (large ν) + - x ≤ 2: Power series expansion + - 2 < x ≤ 17: Miller's algorithm + - x > 17: Asymptotic expansion for large x + + Reference: BesselK.cpp lines 62-215, 677-1036 + """ + # Decision tree matching C++ logic + return jnp.where( + nu >= 35.0, + _bessel_kv_asymptotic_large_nu(nu, x), + jnp.where( + x <= 2.0, + _bessel_kv_small_x(nu, x), + jnp.where( + x <= 17.0, _bessel_kv_miller(nu, x), _bessel_kv_asymptotic(nu, x) + ), + ), + ) @jax.jit @@ -422,12 +754,16 @@ def _bessel_kn_recurrence(n, x, k0_val, k1_val): def body_fn(i, carry): k_prev, k_curr = carry + # Only update if i < n + should_update = i < n # K_{i+1} = K_{i-1} + (2*i/x) * K_i - k_next = k_prev + (2.0 * i / x) * k_curr - return (k_curr, k_next) + k_next = jnp.where(should_update, k_prev + (2.0 * i / x) * k_curr, k_curr) + k_prev_new = jnp.where(should_update, k_curr, k_prev) + return (k_prev_new, k_next) # Start with K_0 and K_1, iterate to get K_n - _, k_n = jax.lax.fori_loop(1, n, body_fn, (k0_val, k1_val)) + # Fixed upper bound (max order needed is ~350 based on test suite) + _, k_n = jax.lax.fori_loop(1, 400, body_fn, (k0_val, k1_val)) return k_n @@ -465,11 +801,10 @@ def kv(nu, x): Implementation strategy: - Integer orders (ν = 0, 1, 2, ...): Pure JAX using Chebyshev series and recurrence - Half-integer orders (ν = 0.5, 1.5, ...): Pure JAX using closed-form expressions - - Arbitrary fractional orders: scipy.special.kv via pure_callback + - Arbitrary fractional orders: Pure JAX using SLATEC algorithms - This hybrid approach removes the TensorFlow Probability dependency while - maintaining high accuracy. The scipy fallback breaks JIT compilation for - fractional orders but ensures correctness. + All implementations are fully JIT-compatible and ported from the NETLIB SLATEC + library via the C++ GalSim reference implementation. Args: nu: Order (can be negative, integer, or fractional) @@ -477,6 +812,8 @@ def kv(nu, x): Returns: K_ν(x) + + Reference: BesselK.cpp in GalSim C++ source """ nu = 1.0 * nu x = 1.0 * x @@ -527,10 +864,19 @@ def half_integer_order(nu_int, x): return sqrt_pi_2x * exp_neg_x * poly + # Helper function for very small x limit + def very_small_x_limit(nu, x): + # K_ν(x → 0) ~ 2^(ν-1) * Γ(ν) / x^ν + return jnp.power(2.0, nu - 1.0) * jax.scipy.special.gamma(nu) / jnp.power(x, nu) + # Compute results for each path result_integer = integer_order(nu_int, x) result_half_integer = half_integer_order(nu_int, x) - result_fractional = _bessel_kv_fractional(nu, x) + + # Use very small x limit for x < 1e-10, otherwise use fractional algorithm + result_fractional = jnp.where( + x < 1e-10, very_small_x_limit(nu, x), _bessel_kv_fractional(nu, x) + ) # Select the appropriate result return jnp.where( From 132837f29fb22ceebe740e508fc379ec6a6930e6 Mon Sep 17 00:00:00 2001 From: EiffL Date: Fri, 6 Feb 2026 20:22:24 +0100 Subject: [PATCH 03/11] simplifying implementation --- jax_galsim/bessel.py | 168 ++++++++++--------------------------------- 1 file changed, 38 insertions(+), 130 deletions(-) diff --git a/jax_galsim/bessel.py b/jax_galsim/bessel.py index bd9d71cf..dd99512d 100644 --- a/jax_galsim/bessel.py +++ b/jax_galsim/bessel.py @@ -1,6 +1,7 @@ import galsim as _galsim import jax import jax.numpy as jnp +import jax.scipy.special as jsp from jax_galsim.core.utils import implements @@ -22,21 +23,16 @@ def _dcsevl(x, cs): Evaluated series value """ n = len(cs) - # Ensure initial values match the type that will come from the loop - x_scalar = jnp.squeeze(x) # Ensure x is scalar b0 = jnp.array(0.0) b1 = jnp.array(0.0) b2 = jnp.array(0.0) - twox = 2.0 * x_scalar + twox = 2.0 * jnp.squeeze(x) - # Clenshaw's recurrence def body_fn(i, carry): b0, b1, b2 = carry b2 = b1 b1 = b0 - # Extract scalar from array indexing - coeff = cs[n - 1 - i] - b0 = twox * b1 - b2 + coeff + b0 = twox * b1 - b2 + cs[n - 1 - i] return (b0, b1, b2) b0, b1, b2 = jax.lax.fori_loop(0, n, body_fn, (b0, b1, b2)) @@ -48,7 +44,7 @@ def body_fn(i, carry): def _f_pade(x, x2): # fmt: off y = 1. / x2 - f = ( + return ( (1. + # noqa: E226 y*(7.44437068161936700618e2 + # noqa: E226 y*(1.96396372895146869801e5 + # noqa: E226 @@ -72,14 +68,13 @@ def _f_pade(x, x2): y*(1.11535493509914254097e13))))))))))) # noqa: E226 ) # fmt: on - return f @jax.jit def _g_pade(x, x2): # fmt: off y = 1. / x2 - g = ( + return ( y*(1. + # noqa: E226 y*(8.1359520115168615e2 + # noqa: E226 y*(2.35239181626478200e5 + # noqa: E226 @@ -103,7 +98,6 @@ def _g_pade(x, x2): y*(3.99653257887490811e13)))))))))) # noqa: E226 ) # fmt: on - return g @jax.jit @@ -218,25 +212,19 @@ def _bessel_k0(x): ]) # fmt: on - import jax.scipy.special as jsp - - # For x <= 2: K_0(x) = -log(0.5*x) * I_0(x) - 0.25 + Chebyshev series def k0_small(x): xsml = jnp.sqrt(4.0 * jnp.finfo(jnp.float64).eps) y = jnp.where(x > xsml, x * x, 0.0) return -jnp.log(0.5 * x) * jsp.i0(x) - 0.25 + _dcsevl(0.5 * y - 1.0, bk0cs) - # For 2 < x <= 8: exponentially scaled version def k0_medium(x): return jnp.exp(-x) * ( (_dcsevl((16.0 / x - 5.0) / 3.0, ak0cs) + 1.25) / jnp.sqrt(x) ) - # For x > 8: exponentially scaled version def k0_large(x): return jnp.exp(-x) * ((_dcsevl(16.0 / x - 1.0, ak02cs) + 1.25) / jnp.sqrt(x)) - # Combine all regions return jnp.where( x <= 2.0, k0_small(x), jnp.where(x <= 8.0, k0_medium(x), k0_large(x)) ) @@ -315,25 +303,19 @@ def _bessel_k1(x): ]) # fmt: on - import jax.scipy.special as jsp - - # For x <= 2: K_1(x) = log(0.5*x) * I_1(x) + (Chebyshev series + 0.75) / x def k1_small(x): xsml = 2.0 * jnp.sqrt(jnp.finfo(jnp.float64).eps) y = jnp.where(x > xsml, x * x, 0.0) return jnp.log(0.5 * x) * jsp.i1(x) + (_dcsevl(0.5 * y - 1.0, bk1cs) + 0.75) / x - # For 2 < x <= 8: exponentially scaled version def k1_medium(x): return jnp.exp(-x) * ( (_dcsevl((16.0 / x - 5.0) / 3.0, ak1cs) + 1.25) / jnp.sqrt(x) ) - # For x > 8: exponentially scaled version def k1_large(x): return jnp.exp(-x) * ((_dcsevl(16.0 / x - 1.0, ak12cs) + 1.25) / jnp.sqrt(x)) - # Combine all regions return jnp.where( x <= 2.0, k1_small(x), jnp.where(x <= 8.0, k1_medium(x), k1_large(x)) ) @@ -349,7 +331,7 @@ def _bessel_kv_asymptotic_large_nu(nu, x): Algorithm: Olver, F.W.J. (1962), Tables of Bessel Functions of Moderate or Large Orders """ # fmt: off - con = jnp.array([0.398942280401432678, 1.25331413731550025]) # con[0] unused, con[1] = sqrt(π/2) + sqrt_half_pi = 1.25331413731550025 # sqrt(pi/2) c = jnp.array([ -0.208333333333333, 0.125, 0.334201388888889, -0.401041666666667, 0.0703125, -1.02581259645062, 1.84646267361111, @@ -375,48 +357,39 @@ def _bessel_kv_asymptotic_large_nu(nu, x): ]) # fmt: on - # For K function: flgik = -1 fn = nu z = x / fn ra = jnp.sqrt(z * z + 1.0) gln = jnp.log((ra + 1.0) / z) - arg = fn * (ra - gln) * (-1.0) # flgik = -1 for K - coef = jnp.exp(arg) + coef = jnp.exp(-fn * (ra - gln)) t = 1.0 / ra t2 = t * t - t = t / fn * (-1.0) # flgik = -1 for K + t = -t / fn s2 = 1.0 ap = 1.0 coeff_idx = 0 - # Compute 11-term series def body_fn(k, carry): s2, ap, coeff_idx = carry - # Compute polynomial s1 = c[coeff_idx] + c[coeff_idx+1]*t2 + ... s1 = c[coeff_idx] new_idx = coeff_idx + 1 def inner_body(j, inner_carry): s1_val, idx = inner_carry - # Only update if j < k + 1 should_update = j < k + 1 s1_new = jnp.where(should_update, s1_val * t2 + c[idx], s1_val) idx_new = jnp.where(should_update, idx + 1, idx) return (s1_new, idx_new) - # Fixed upper bound (max k is 11, so k+1 max is 12) s1, new_idx = jax.lax.fori_loop(2, 12, inner_body, (s1, new_idx)) ap_new = ap * t - ak = ap_new * s1 - s2_new = s2 + ak - # coeff_idx advances by k elements total + s2_new = s2 + ap_new * s1 return (s2_new, ap_new, coeff_idx + k) s2, ap, _ = jax.lax.fori_loop(2, 12, body_fn, (s2, ap, coeff_idx)) - t_abs = jnp.abs(t) - return s2 * coef * jnp.sqrt(t_abs) * con[1] # con[1] = sqrt(π/2) + return s2 * coef * jnp.sqrt(jnp.abs(t)) * sqrt_half_pi @jax.jit @@ -441,22 +414,19 @@ def _bessel_kv_small_x(nu, x): a1 = 1.0 - nu a2 = nu + 1.0 - t1 = 1.0 / jax.scipy.special.gamma(a1) - t2 = 1.0 / jax.scipy.special.gamma(a2) + t1 = 1.0 / jsp.gamma(a1) + t2 = 1.0 / jsp.gamma(a2) - # Compute g1 with indeterminacy resolution for small |nu| dnu2 = jnp.where(jnp.abs(nu) >= tol, nu * nu, 0.0) def g1_small_nu(): - # Series for resolving indeterminacy when |nu| < 0.1 s = cc[0] ak = 1.0 def body_fn(k, carry): s_val, ak_val = carry ak_new = ak_val * dnu2 - tm = cc[k] * ak_new - s_new = s_val + tm + s_new = s_val + cc[k] * ak_new return (s_new, ak_new) s_final, _ = jax.lax.fori_loop(1, 8, body_fn, (s, ak)) @@ -472,7 +442,6 @@ def g1_large_nu(): flrx = jnp.log(rx) fmu = nu * flrx - # Handle sinh(fmu)/fmu carefully when nu → 0 smu = jnp.where(nu != 0.0, jnp.sinh(fmu) / fmu, 1.0) fc = jnp.where(nu != 0.0, nu * pi / jnp.sin(nu * pi), 1.0) @@ -497,24 +466,18 @@ def series_body(i, carry): p_new = p_val / (ak_val - nu) q_new = q_val / (ak_val + nu) ck_new = ck_val * cx / ak_val - t1 = ck_new * f_new - s1_new = s1_val + t1 - t2 = ck_new * (p_new - ak_val * f_new) - s2_new = s2_val + t2 + s1_new = s1_val + ck_new * f_new + s2_new = s2_val + ck_new * (p_new - ak_val * f_new) bk_new = bk_val + ak_val + ak_val + 1.0 ak_new = ak_val + 1.0 return (f_new, p_new, q_new, ak_new, ck_new, bk_new, s1_new, s2_new) - # Only run series if x >= tol - def run_series(): - return jax.lax.fori_loop(0, 17, series_body, (f, p, q, ak, ck, bk, s1, s2)) - - def skip_series(): - return (f, p, q, ak, ck, bk, s1, s2) - - _, _, _, _, _, _, s1_final, s2_final = jax.lax.cond( - x >= tol, run_series, skip_series + init = (f, p, q, ak, ck, bk, s1, s2) + _, _, _, _, _, _, s1_final, _ = jax.lax.cond( + x >= tol, + lambda: jax.lax.fori_loop(0, 17, series_body, init), + lambda: init, ) return s1_final @@ -530,18 +493,16 @@ def _bessel_kv_miller(nu, x): Reference: BesselK.cpp lines 677-1036 (dbsknu function) """ pi = jnp.pi - rthpi = 1.2533141373155 # sqrt(π/2) + sqrt_half_pi = 1.2533141373155 # sqrt(pi/2) tol = jnp.maximum(jnp.finfo(jnp.float64).eps, 1e-15) - # Normalize to dnu in [-0.5, 0.5) - matching C++ line 783-784 rx = 2.0 / x - inu = jnp.floor(nu + 0.5).astype(int) # Round to nearest integer - dnu = nu - inu # Fractional part in [-0.5, 0.5) + inu = jnp.floor(nu + 0.5).astype(int) + dnu = nu - inu dnu2 = jnp.where(jnp.abs(dnu) >= tol, dnu * dnu, 0.0) - coef = rthpi / jnp.sqrt(x) * jnp.exp(-x) + coef = sqrt_half_pi / jnp.sqrt(x) * jnp.exp(-x) - # Miller algorithm: build coefficients until convergence etest = jnp.cos(pi * dnu) / (pi * x * tol) fks = 1.0 fhs = 0.25 @@ -550,11 +511,9 @@ def _bessel_kv_miller(nu, x): p1 = 0.0 p2 = 1.0 - # Fixed array size for JIT a_arr = jnp.zeros(160) b_arr = jnp.zeros(160) - # Forward pass to build coefficient arrays def forward_body(i, carry): ( fk_val, @@ -576,7 +535,6 @@ def forward_body(i, carry): p2_new = bk * p2_val - ak * p1_val p1_new = pt - # Only update arrays if not converged a_arr_new = jnp.where(converged, a_arr_val, a_arr_val.at[k].set(ak)) b_arr_new = jnp.where(converged, b_arr_val, b_arr_val.at[k].set(bk)) @@ -584,7 +542,6 @@ def forward_body(i, carry): fks_new = fks_val + fk_new + fk_new + 1.0 fhs_new = fhs_val + fk_new + fk_new - # Check convergence: continue while etest > fk * p1 has_converged = (etest <= fk_new * jnp.abs(p1_new)) | converged k_new = jnp.where(has_converged, k, k + 1) @@ -601,23 +558,17 @@ def forward_body(i, carry): has_converged, ) - # Run 160 iterations (max array size) - _, _, _, _, p1_fwd, p2_fwd, a_final, b_final, k_final, _ = jax.lax.fori_loop( + _, _, _, _, _, _, a_final, b_final, k_final, _ = jax.lax.fori_loop( 0, 160, forward_body, (fk, fks, fhs, ck, p1, p2, a_arr, b_arr, 0, False) ) - # Backward pass through continued fraction s = 1.0 p1_back = 0.0 p2_back = 1.0 def backward_body(i, carry): s_val, p1_val, p2_val = carry - - # Only apply if i < k_final should_update = i < k_final - - # Indices count down from k-1 to 0 idx = k_final - 1 - i pt = p2_val p2_new = jnp.where( @@ -628,32 +579,22 @@ def backward_body(i, carry): return (s_new, p1_new, p2_new) - # Fixed loop count (160 is max array size) s_final, p1_back_final, p2_back_final = jax.lax.fori_loop( 0, 160, backward_body, (s, p1_back, p2_back) ) - # This gives us K_dnu s1 = coef * (p2_back_final / s_final) - # Special handling for inu==0 case def no_recursion(): return s1 def with_recursion(): - # Compute K_{dnu+1} s2 = s1 * (x + dnu + 0.5 - p1_back_final / p2_back_final) / x - - # Forward recursion from dnu to nu (lines 966-979) - # K_{n+1} = K_{n-1} + (2*n/x) * K_n ck_rec = (dnu + dnu + 2.0) / x - - # For n==1, we decrement inu and return s2 after recursion (lines 969, 978) inu_adjusted = inu - 1 def recursion_body(i, carry): s1_val, s2_val, ck_val = carry - # Only apply recursion if i < inu_adjusted should_update = i < inu_adjusted st = s2_val s2_new = jnp.where(should_update, ck_val * s2_val + s1_val, s2_val) @@ -661,15 +602,9 @@ def recursion_body(i, carry): ck_new = jnp.where(should_update, ck_val + rx, ck_val) return (s1_new, s2_new, ck_new) - # Fixed loop count (max 50 should be enough for any reasonable nu) - s1_final, s2_final, _ = jax.lax.fori_loop( - 0, 50, recursion_body, (s1, s2, ck_rec) - ) - - # Return s2 for n==1 case (line 978) + _, s2_final, _ = jax.lax.fori_loop(0, 50, recursion_body, (s1, s2, ck_rec)) return s2_final - # If inu == 0, don't do recursion return jax.lax.cond(inu == 0, no_recursion, with_recursion) @@ -684,8 +619,8 @@ def _bessel_kv_asymptotic(nu, x): sqrt_pi_2x = jnp.sqrt(jnp.pi / (2.0 * x)) exp_neg_x = jnp.exp(-x) - dnu2 = nu + nu - fmu = dnu2 * dnu2 + two_nu = nu + nu + fmu = two_nu * two_nu ex = x * 8.0 s = 1.0 @@ -694,8 +629,7 @@ def _bessel_kv_asymptotic(nu, x): ak = 0.0 dk = ex - # 30-term series for enhanced accuracy - def body_fn(j, carry): + def body_fn(_, carry): s_val, ck_val, dk_val, ak_val, sqk_val = carry ck_new = ck_val * (fmu - sqk_val) / dk_val s_new = s_val + ck_new @@ -721,7 +655,6 @@ def _bessel_kv_fractional(nu, x): Reference: BesselK.cpp lines 62-215, 677-1036 """ - # Decision tree matching C++ logic return jnp.where( nu >= 35.0, _bessel_kv_asymptotic_large_nu(nu, x), @@ -754,15 +687,11 @@ def _bessel_kn_recurrence(n, x, k0_val, k1_val): def body_fn(i, carry): k_prev, k_curr = carry - # Only update if i < n should_update = i < n - # K_{i+1} = K_{i-1} + (2*i/x) * K_i k_next = jnp.where(should_update, k_prev + (2.0 * i / x) * k_curr, k_curr) k_prev_new = jnp.where(should_update, k_curr, k_prev) return (k_prev_new, k_next) - # Start with K_0 and K_1, iterate to get K_n - # Fixed upper bound (max order needed is ~350 based on test suite) _, k_n = jax.lax.fori_loop(1, 400, body_fn, (k0_val, k1_val)) return k_n @@ -783,7 +712,7 @@ def kn(n, x): K_n(x) """ n = jnp.abs(jnp.asarray(n, dtype=int)) # K_{-n} = K_n - x = 1.0 * x + x = 1.0 * x # promote to float k0 = _bessel_k0(x) k1 = _bessel_k1(x) @@ -815,21 +744,18 @@ def kv(nu, x): Reference: BesselK.cpp in GalSim C++ source """ - nu = 1.0 * nu + nu = 1.0 * nu # promote to float x = 1.0 * x - # Use reflection formula for negative orders: K_{-ν}(x) = K_ν(x) + # K_{-nu}(x) = K_nu(x) nu = jnp.abs(nu) - # Get the integer and fractional parts nu_int = jnp.floor(nu).astype(int) nu_frac = nu - nu_int - # Determine which path to take is_half_integer = jnp.abs(nu_frac - 0.5) < 1e-10 is_integer = nu_frac < 1e-10 - # Helper function for integer orders def integer_order(nu_int, x): k0 = _bessel_k0(x) k1 = _bessel_k1(x) @@ -839,46 +765,34 @@ def integer_order(nu_int, x): jnp.where(nu_int == 1, k1, _bessel_kn_recurrence(nu_int, x, k0, k1)), ) - # Helper function for half-integer orders K_{n+1/2}(x) def half_integer_order(nu_int, x): sqrt_pi_2x = jnp.sqrt(jnp.pi / (2.0 * x)) exp_neg_x = jnp.exp(-x) inv_x = 1.0 / x - # Polynomial factors for half-integer orders p0 = 1.0 p1 = 1.0 + inv_x p2 = 1.0 + 3.0 * inv_x + 3.0 * inv_x**2 p3 = 1.0 + 6.0 * inv_x + 15.0 * inv_x**2 + 15.0 * inv_x**3 p4 = 1.0 + 10.0 * inv_x + 45.0 * inv_x**2 + 105.0 * inv_x**3 + 105.0 * inv_x**4 - poly = jnp.where( - nu_int == 0, - p0, - jnp.where( - nu_int == 1, - p1, - jnp.where(nu_int == 2, p2, jnp.where(nu_int == 3, p3, p4)), - ), + poly = jnp.select( + [nu_int == 0, nu_int == 1, nu_int == 2, nu_int == 3], + [p0, p1, p2, p3], + default=p4, ) return sqrt_pi_2x * exp_neg_x * poly - # Helper function for very small x limit def very_small_x_limit(nu, x): - # K_ν(x → 0) ~ 2^(ν-1) * Γ(ν) / x^ν - return jnp.power(2.0, nu - 1.0) * jax.scipy.special.gamma(nu) / jnp.power(x, nu) + return jnp.power(2.0, nu - 1.0) * jsp.gamma(nu) / jnp.power(x, nu) - # Compute results for each path result_integer = integer_order(nu_int, x) result_half_integer = half_integer_order(nu_int, x) - - # Use very small x limit for x < 1e-10, otherwise use fractional algorithm result_fractional = jnp.where( x < 1e-10, very_small_x_limit(nu, x), _bessel_kv_fractional(nu, x) ) - # Select the appropriate result return jnp.where( is_integer, result_integer, @@ -886,17 +800,11 @@ def very_small_x_limit(nu, x): ) -@jax.jit -def _R(z, num, denom): - return jnp.polyval(num, z) / jnp.polyval(denom, z) - - @jax.jit def _evaluate_rational(z, num, denom): - return _R(z, num[::-1], denom[::-1]) + return jnp.polyval(num[::-1], z) / jnp.polyval(denom[::-1], z) -# jitted & vectorized version _v_rational = jax.jit(jax.vmap(_evaluate_rational, in_axes=(0, None, None))) From 1a97d3bd2871a7d01ee30f8123b4b851fe90ffe4 Mon Sep 17 00:00:00 2001 From: EiffL Date: Fri, 6 Feb 2026 20:54:55 +0100 Subject: [PATCH 04/11] adding gradients --- jax_galsim/bessel.py | 75 ++++++++- tests/jax/test_bessel_gradients.py | 240 +++++++++++++++++++++++++++++ 2 files changed, 313 insertions(+), 2 deletions(-) create mode 100644 tests/jax/test_bessel_gradients.py diff --git a/jax_galsim/bessel.py b/jax_galsim/bessel.py index dd99512d..a032f4fd 100644 --- a/jax_galsim/bessel.py +++ b/jax_galsim/bessel.py @@ -724,8 +724,8 @@ def kn(n, x): @implements(_galsim.bessel.kv) @jax.jit -def kv(nu, x): - """Modified Bessel function of the second kind K_ν(x). +def _kv_impl(nu, x): + """Modified Bessel function of the second kind K_ν(x) - internal implementation. Implementation strategy: - Integer orders (ν = 0, 1, 2, ...): Pure JAX using Chebyshev series and recurrence @@ -800,6 +800,77 @@ def very_small_x_limit(nu, x): ) +@jax.custom_vjp +def kv(nu, x): + """Modified Bessel function of the second kind K_ν(x) with custom gradients. + + Implementation strategy: + - Integer orders (ν = 0, 1, 2, ...): Pure JAX using Chebyshev series and recurrence + - Half-integer orders (ν = 0.5, 1.5, ...): Pure JAX using closed-form expressions + - Arbitrary fractional orders: Pure JAX using SLATEC algorithms + + All implementations are fully JIT-compatible and ported from the NETLIB SLATEC + library via the C++ GalSim reference implementation. + + Custom gradients are implemented using analytical derivative formulas based on + Bessel function recurrence relations: + ∂K_ν/∂x = -1/2 * (K_{ν-1}(x) + K_{ν+1}(x)) + + This is derived from the modified Bessel recurrence relation: + K_{ν-1}(x) + K_{ν+1}(x) = -2 * K'_ν(x) + + Note: Gradients with respect to the order parameter ν are not supported and will + return zero. This follows the approach used in TensorFlow Probability, as gradients + with respect to the order are rarely needed in practice. + + Args: + nu: Order (can be negative, integer, or fractional) + x: Argument (must be positive) + + Returns: + K_ν(x) + + Reference: + - BesselK.cpp in GalSim C++ source + - TensorFlow Probability bessel.py for custom gradient approach + - Abramowitz & Stegun 9.6.26 for derivative recurrence relations + """ + return _kv_impl(nu, x) + + +def _kv_fwd(nu, x): + """Forward pass for kv with custom gradients. + + Computes K_ν(x) and saves K_{ν-1}(x) and K_{ν+1}(x) for use in the backward pass. + """ + kv_val = _kv_impl(nu, x) + kv_prev = _kv_impl(nu - 1.0, x) + kv_next = _kv_impl(nu + 1.0, x) + return kv_val, (nu, x, kv_prev, kv_next) + + +def _kv_bwd(residuals, g): + """Backward pass for kv with custom gradients. + + Uses the analytical derivative formula from Bessel recurrence relations: + ∂K_ν/∂x = -1/2 * (K_{ν-1}(x) + K_{ν+1}(x)) + + This formula comes from the modified Bessel function recurrence relation: + K_{ν-1}(x) + K_{ν+1}(x) = -2 * K'_ν(x) + + Gradient with respect to ν is not supported and returns zero. + """ + nu, x, kv_prev, kv_next = residuals + # Analytical gradient: ∂K_ν/∂x = -1/2 * (K_{ν-1}(x) + K_{ν+1}(x)) + grad_x = -0.5 * (kv_prev + kv_next) * g + # Gradient w.r.t. ν not supported (following TFP approach) + grad_nu = jnp.zeros_like(nu) + return (grad_nu, grad_x) + + +kv.defvjp(_kv_fwd, _kv_bwd) + + @jax.jit def _evaluate_rational(z, num, denom): return jnp.polyval(num[::-1], z) / jnp.polyval(denom[::-1], z) diff --git a/tests/jax/test_bessel_gradients.py b/tests/jax/test_bessel_gradients.py new file mode 100644 index 00000000..83578d2d --- /dev/null +++ b/tests/jax/test_bessel_gradients.py @@ -0,0 +1,240 @@ +"""Tests for custom gradients of Bessel functions in JAX-GalSim. + +This module tests the custom gradient implementation for the modified Bessel +function K_ν(x), which uses analytical derivative formulas based on Bessel +recurrence relations instead of automatic differentiation. + +To run these tests: + pytest tests/jax/test_bessel_gradients.py \\ + --ignore=tests/jax/test_image_wrapping.py \\ + --ignore=tests/jax/test_interpolant_jax.py -v +""" + +import jax +import jax.numpy as jnp +import pytest + +from jax_galsim.bessel import kv + + +class TestKvGradients: + """Test suite for kv custom gradients.""" + + @pytest.mark.parametrize( + "nu,x", + [ + (0.5, 1.0), + (1.0, 5.0), + (2.5, 5.0), + (10.0, 10.0), + (39.8, 40.0), + (300.9, 500.0), + ], + ) + def test_gradient_wrt_x_compiles(self, nu, x): + """Test that gradient w.r.t. x compiles with jax.jit.""" + grad_fn = jax.jit(jax.grad(lambda x_val: kv(nu, x_val))) + grad_val = grad_fn(x) + # Just verify it's a finite value + assert jnp.isfinite(grad_val), f"Gradient is not finite: {grad_val}" + + @pytest.mark.parametrize( + "nu,x", + [ + (0.5, 1.0), + (1.0, 5.0), + (2.5, 5.0), + (10.0, 10.0), + (39.8, 40.0), + ], + ) + def test_gradient_vs_finite_differences(self, nu, x): + """Test that custom gradients match finite differences.""" + grad_fn = jax.grad(lambda x_val: kv(nu, x_val)) + analytical_grad = grad_fn(x) + + # Compute finite difference gradient + eps = 1e-5 + numerical_grad = (kv(nu, x + eps) - kv(nu, x - eps)) / (2 * eps) + + # Allow slightly looser tolerance for numerical differentiation + relative_error = jnp.abs((analytical_grad - numerical_grad) / numerical_grad) + assert relative_error < 1e-6, ( + f"Gradient error too large: {relative_error} at nu={nu}, x={x}" + ) + + @pytest.mark.parametrize( + "nu,x", + [ + (0.5, 1.0), + (1.0, 5.0), + (2.5, 5.0), + (10.0, 10.0), + (39.8, 40.0), + ], + ) + def test_gradient_analytical_formula(self, nu, x): + """Test that gradients match the analytical formula: ∂K_ν/∂x = -1/2 * (K_{ν-1}(x) + K_{ν+1}(x)).""" + grad_fn = jax.grad(lambda x_val: kv(nu, x_val)) + computed_grad = grad_fn(x) + + # Compute expected gradient using analytical formula + kv_prev = kv(nu - 1.0, x) + kv_next = kv(nu + 1.0, x) + expected_grad = -0.5 * (kv_prev + kv_next) + + # Should match very closely since we use the same formula + assert jnp.allclose(computed_grad, expected_grad, rtol=1e-10), ( + f"Gradient doesn't match analytical formula at nu={nu}, x={x}" + ) + + def test_gradient_vectorization(self): + """Test that gradients work with vmap.""" + grad_fn = jax.vmap(jax.grad(lambda x_val: kv(2.5, x_val))) + x_array = jnp.array([1.0, 5.0, 10.0, 50.0]) + grads = grad_fn(x_array) + + # Verify all gradients are finite + assert jnp.all(jnp.isfinite(grads)), "Some vectorized gradients are not finite" + + # Verify shapes match + assert grads.shape == x_array.shape, "Gradient shape mismatch" + + @pytest.mark.skip( + reason="custom_vjp doesn't support higher-order derivatives by default" + ) + def test_second_derivative(self): + """Test that second derivatives (Hessian) work. + + Note: This test is skipped because custom_vjp doesn't support higher-order + differentiation. This is a known limitation and not critical for first-order + gradient-based optimization. + """ + hessian_fn = jax.grad(jax.grad(lambda x_val: kv(2.5, x_val))) + hess_val = hessian_fn(5.0) + + # Just verify it's finite - we're not testing accuracy of second derivative + assert jnp.isfinite(hess_val), f"Second derivative is not finite: {hess_val}" + + @pytest.mark.parametrize( + "nu,x", + [ + (2.5, 0.1), # Small x + (2.5, 100.0), # Large x + (300.9, 500.0), # Large nu + (0.5, 0.1), # Small x, small nu + ], + ) + def test_gradient_edge_cases(self, nu, x): + """Test gradients at edge cases (small x, large x, large nu).""" + grad_fn = jax.grad(lambda x_val: kv(nu, x_val)) + grad_val = grad_fn(x) + + # Verify gradient is finite and reasonable + assert jnp.isfinite(grad_val), f"Gradient not finite at nu={nu}, x={x}" + + # Verify against analytical formula + kv_prev = kv(nu - 1.0, x) + kv_next = kv(nu + 1.0, x) + expected_grad = -0.5 * (kv_prev + kv_next) + + assert jnp.allclose(grad_val, expected_grad, rtol=1e-8), ( + f"Gradient mismatch at edge case nu={nu}, x={x}" + ) + + def test_gradient_wrt_nu_is_zero(self): + """Test that gradient w.r.t. nu returns zero (not supported).""" + # Gradient w.r.t. first argument (nu) + grad_nu_fn = jax.grad(lambda nu_val: kv(nu_val, 5.0), argnums=0) + grad_nu = grad_nu_fn(2.5) + + # Should be zero since we don't support gradient w.r.t. order + assert grad_nu == 0.0, "Gradient w.r.t. nu should be zero" + + def test_gradient_negative_nu(self): + """Test that gradients work with negative nu (should use abs(nu)).""" + # K_{-nu}(x) = K_nu(x), so gradient should be the same + grad_pos = jax.grad(lambda x_val: kv(2.5, x_val))(5.0) + grad_neg = jax.grad(lambda x_val: kv(-2.5, x_val))(5.0) + + assert jnp.allclose(grad_pos, grad_neg, rtol=1e-10), ( + "Gradient should be same for positive and negative nu" + ) + + def test_gradient_integer_order(self): + """Test gradients for integer orders (special case).""" + # Integer orders use different implementation path + for nu in [0, 1, 2, 5]: + grad_fn = jax.grad(lambda x_val: kv(nu, x_val)) + grad_val = grad_fn(5.0) + + # Verify against analytical formula + kv_prev = kv(nu - 1.0, 5.0) if nu > 0 else kv(1.0, 5.0) # K_{-1} = K_1 + kv_next = kv(nu + 1.0, 5.0) + expected_grad = -0.5 * (kv_prev + kv_next) + + assert jnp.allclose(grad_val, expected_grad, rtol=1e-8), ( + f"Gradient mismatch for integer order nu={nu}" + ) + + def test_gradient_half_integer_order(self): + """Test gradients for half-integer orders (special case).""" + # Half-integer orders use closed-form expressions + for nu in [0.5, 1.5, 2.5, 3.5]: + grad_fn = jax.grad(lambda x_val: kv(nu, x_val)) + grad_val = grad_fn(5.0) + + # Verify against analytical formula + kv_prev = kv(nu - 1.0, 5.0) + kv_next = kv(nu + 1.0, 5.0) + expected_grad = -0.5 * (kv_prev + kv_next) + + assert jnp.allclose(grad_val, expected_grad, rtol=1e-8), ( + f"Gradient mismatch for half-integer order nu={nu}" + ) + + def test_gradient_jit_compilation(self): + """Test that gradient compilation works without errors.""" + # Compile gradient function + grad_fn = jax.jit(jax.grad(lambda x_val: kv(2.5, x_val))) + + # Should compile without errors + _ = grad_fn(5.0) + + # Call again to verify compiled version works + grad_val = grad_fn(10.0) + assert jnp.isfinite(grad_val), "Compiled gradient not finite" + + def test_gradient_multiple_calls(self): + """Test that gradients are consistent across multiple calls.""" + grad_fn = jax.grad(lambda x_val: kv(2.5, x_val)) + + # Call multiple times with same input + grads = [grad_fn(5.0) for _ in range(5)] + + # All should be identical + for grad_val in grads[1:]: + assert jnp.allclose(grad_val, grads[0], rtol=1e-12), ( + "Gradients inconsistent across calls" + ) + + def test_gradient_batch_computation(self): + """Test gradients with batched inputs using vmap.""" + nu_values = jnp.array([0.5, 1.0, 2.5, 10.0]) + x_values = jnp.array([1.0, 5.0, 10.0, 50.0]) + + # Compute gradients for all combinations + grad_fn = jax.vmap( + jax.vmap( + jax.grad(lambda x_val, nu_val: kv(nu_val, x_val)), in_axes=(None, 0) + ), + in_axes=(0, None), + ) + + grads = grad_fn(x_values, nu_values) + + # Verify shape: (len(x_values), len(nu_values)) + assert grads.shape == (len(x_values), len(nu_values)) + + # Verify all are finite + assert jnp.all(jnp.isfinite(grads)), "Some batch gradients are not finite" From 5bbbcbd3f750a7b83d33d96292257d3e301b28bc Mon Sep 17 00:00:00 2001 From: EiffL Date: Fri, 6 Feb 2026 22:49:42 +0100 Subject: [PATCH 05/11] improving implementation --- jax_galsim/bessel.py | 165 ++++++++++++++++++++++++++++--------------- pyproject.toml | 2 +- 2 files changed, 109 insertions(+), 58 deletions(-) diff --git a/jax_galsim/bessel.py b/jax_galsim/bessel.py index a032f4fd..11291f2d 100644 --- a/jax_galsim/bessel.py +++ b/jax_galsim/bessel.py @@ -225,9 +225,8 @@ def k0_medium(x): def k0_large(x): return jnp.exp(-x) * ((_dcsevl(16.0 / x - 1.0, ak02cs) + 1.25) / jnp.sqrt(x)) - return jnp.where( - x <= 2.0, k0_small(x), jnp.where(x <= 8.0, k0_medium(x), k0_large(x)) - ) + idx = jnp.where(x <= 2.0, 0, jnp.where(x <= 8.0, 1, 2)) + return jax.lax.switch(idx, [k0_small, k0_medium, k0_large], x) @jax.jit @@ -316,9 +315,8 @@ def k1_medium(x): def k1_large(x): return jnp.exp(-x) * ((_dcsevl(16.0 / x - 1.0, ak12cs) + 1.25) / jnp.sqrt(x)) - return jnp.where( - x <= 2.0, k1_small(x), jnp.where(x <= 8.0, k1_medium(x), k1_large(x)) - ) + idx = jnp.where(x <= 2.0, 0, jnp.where(x <= 8.0, 1, 2)) + return jax.lax.switch(idx, [k1_small, k1_medium, k1_large], x) @jax.jit @@ -655,16 +653,19 @@ def _bessel_kv_fractional(nu, x): Reference: BesselK.cpp lines 62-215, 677-1036 """ - return jnp.where( - nu >= 35.0, - _bessel_kv_asymptotic_large_nu(nu, x), - jnp.where( - x <= 2.0, - _bessel_kv_small_x(nu, x), - jnp.where( - x <= 17.0, _bessel_kv_miller(nu, x), _bessel_kv_asymptotic(nu, x) - ), - ), + branch_index = jnp.where( + nu >= 35.0, 0, jnp.where(x <= 2.0, 1, jnp.where(x <= 17.0, 2, 3)) + ) + return jax.lax.switch( + branch_index, + [ + lambda nu, x: _bessel_kv_asymptotic_large_nu(nu, x), + lambda nu, x: _bessel_kv_small_x(nu, x), + lambda nu, x: _bessel_kv_miller(nu, x), + lambda nu, x: _bessel_kv_asymptotic(nu, x), + ], + nu, + x, ) @@ -675,6 +676,9 @@ def _bessel_kn_recurrence(n, x, k0_val, k1_val): Uses the recurrence relation: K_{n+1}(x) = K_{n-1}(x) + (2*n/x) * K_n(x) + For n <= 5, uses direct computation without a loop. + For n > 5, uses fori_loop with 399 iterations. + Args: n: Integer order (n >= 2) x: Argument value @@ -685,15 +689,27 @@ def _bessel_kn_recurrence(n, x, k0_val, k1_val): K_n(x) """ - def body_fn(i, carry): - k_prev, k_curr = carry - should_update = i < n - k_next = jnp.where(should_update, k_prev + (2.0 * i / x) * k_curr, k_curr) - k_prev_new = jnp.where(should_update, k_curr, k_prev) - return (k_prev_new, k_next) + def small_n(): + k2 = k0_val + (2.0 / x) * k1_val + k3 = k1_val + (4.0 / x) * k2 + k4 = k2 + (6.0 / x) * k3 + k5 = k3 + (8.0 / x) * k4 + return jnp.select( + [n == 2, n == 3, n == 4, n == 5], [k2, k3, k4, k5], default=k5 + ) + + def large_n(): + def body_fn(i, carry): + k_prev, k_curr = carry + should_update = i < n + k_next = jnp.where(should_update, k_prev + (2.0 * i / x) * k_curr, k_curr) + k_prev_new = jnp.where(should_update, k_curr, k_prev) + return (k_prev_new, k_next) - _, k_n = jax.lax.fori_loop(1, 400, body_fn, (k0_val, k1_val)) - return k_n + _, k_n = jax.lax.fori_loop(1, 400, body_fn, (k0_val, k1_val)) + return k_n + + return jax.lax.cond(n <= 5, small_n, large_n) @implements(_galsim.bessel.kn) @@ -717,33 +733,23 @@ def kn(n, x): k0 = _bessel_k0(x) k1 = _bessel_k1(x) - return jnp.where( - n == 0, k0, jnp.where(n == 1, k1, _bessel_kn_recurrence(n, x, k0, k1)) + idx = jnp.where(n == 0, 0, jnp.where(n == 1, 1, 2)) + return jax.lax.switch( + idx, + [ + lambda n, x, k0, k1: k0, + lambda n, x, k0, k1: k1, + lambda n, x, k0, k1: _bessel_kn_recurrence(n, x, k0, k1), + ], + n, + x, + k0, + k1, ) -@implements(_galsim.bessel.kv) -@jax.jit -def _kv_impl(nu, x): - """Modified Bessel function of the second kind K_ν(x) - internal implementation. - - Implementation strategy: - - Integer orders (ν = 0, 1, 2, ...): Pure JAX using Chebyshev series and recurrence - - Half-integer orders (ν = 0.5, 1.5, ...): Pure JAX using closed-form expressions - - Arbitrary fractional orders: Pure JAX using SLATEC algorithms - - All implementations are fully JIT-compatible and ported from the NETLIB SLATEC - library via the C++ GalSim reference implementation. - - Args: - nu: Order (can be negative, integer, or fractional) - x: Argument (must be positive) - - Returns: - K_ν(x) - - Reference: BesselK.cpp in GalSim C++ source - """ +def _kv_scalar(nu, x): + """Scalar implementation of K_ν(x). Inputs must be scalar (0-d arrays).""" nu = 1.0 * nu # promote to float x = 1.0 * x @@ -787,19 +793,64 @@ def half_integer_order(nu_int, x): def very_small_x_limit(nu, x): return jnp.power(2.0, nu - 1.0) * jsp.gamma(nu) / jnp.power(x, nu) - result_integer = integer_order(nu_int, x) - result_half_integer = half_integer_order(nu_int, x) - result_fractional = jnp.where( - x < 1e-10, very_small_x_limit(nu, x), _bessel_kv_fractional(nu, x) - ) + branch_index = jnp.where(is_integer, 0, jnp.where(is_half_integer, 1, 2)) - return jnp.where( - is_integer, - result_integer, - jnp.where(is_half_integer, result_half_integer, result_fractional), + def _branch_integer(nu, nu_int, x): + return integer_order(nu_int, x) + + def _branch_half_integer(nu, nu_int, x): + return half_integer_order(nu_int, x) + + def _branch_fractional(nu, nu_int, x): + return jnp.where( + x < 1e-10, very_small_x_limit(nu, x), _bessel_kv_fractional(nu, x) + ) + + return jax.lax.switch( + branch_index, + [_branch_integer, _branch_half_integer, _branch_fractional], + nu, + nu_int, + x, ) +@implements(_galsim.bessel.kv) +@jax.jit +def _kv_impl(nu, x): + """Modified Bessel function of the second kind K_ν(x) - internal implementation. + + Implementation strategy: + - Integer orders (ν = 0, 1, 2, ...): Pure JAX using Chebyshev series and recurrence + - Half-integer orders (ν = 0.5, 1.5, ...): Pure JAX using closed-form expressions + - Arbitrary fractional orders: Pure JAX using SLATEC algorithms + + All implementations are fully JIT-compatible and ported from the NETLIB SLATEC + library via the C++ GalSim reference implementation. + + Handles both scalar and array inputs. Internal algorithms are scalar; + array inputs are handled via jax.vmap. + + Args: + nu: Order (can be negative, integer, or fractional) + x: Argument (must be positive) + + Returns: + K_ν(x) + + Reference: BesselK.cpp in GalSim C++ source + """ + nu = jnp.asarray(1.0 * nu) + x = jnp.asarray(1.0 * x) + out_shape = jnp.broadcast_shapes(jnp.shape(nu), jnp.shape(x)) + if out_shape == (): + return _kv_scalar(nu, x) + nu_bc, x_bc = jnp.broadcast_arrays(nu, x) + flat_nu = nu_bc.ravel() + flat_x = x_bc.ravel() + return jax.vmap(_kv_scalar)(flat_nu, flat_x).reshape(out_shape) + + @jax.custom_vjp def kv(nu, x): """Modified Bessel function of the second kind K_ν(x) with custom gradients. diff --git a/pyproject.toml b/pyproject.toml index ce917c9c..77fa3762 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ readme = "README.md" dependencies = [ "numpy >=1.18.0", "galsim >=2.3.0", - "jax >=0.7.0", + "jax >=0.5.0", "astropy >=2.0", "quadax", ] From bb604f385a76514ff1aaabcd1985f114b93266d2 Mon Sep 17 00:00:00 2001 From: EiffL Date: Fri, 6 Feb 2026 22:51:53 +0100 Subject: [PATCH 06/11] fix wcs --- jax_galsim/fitswcs.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/jax_galsim/fitswcs.py b/jax_galsim/fitswcs.py index 5870d755..6fcd8953 100644 --- a/jax_galsim/fitswcs.py +++ b/jax_galsim/fitswcs.py @@ -834,16 +834,28 @@ def __eq__(self, other): and jnp.array_equal(self.cd, other.cd) and self.center == other.center and ( - jnp.array_equal(self.pv, other.pv) - or (self.pv is None and other.pv is None) + (self.pv is None and other.pv is None) + or ( + self.pv is not None + and other.pv is not None + and jnp.array_equal(self.pv, other.pv) + ) ) and ( - jnp.array_equal(self.ab, other.ab) - or (self.ab is None and other.ab is None) + (self.ab is None and other.ab is None) + or ( + self.ab is not None + and other.ab is not None + and jnp.array_equal(self.ab, other.ab) + ) ) and ( - jnp.array_equal(self.abp, other.abp) - or (self.abp is None and other.abp is None) + (self.abp is None and other.abp is None) + or ( + self.abp is not None + and other.abp is not None + and jnp.array_equal(self.abp, other.abp) + ) ) ) From a7fc859c4299c4850eebebc245f9fe757c70d96c Mon Sep 17 00:00:00 2001 From: EiffL Date: Fri, 6 Feb 2026 23:04:01 +0100 Subject: [PATCH 07/11] upddsting exponemtial --- jax_galsim/integ.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/integ.py b/jax_galsim/integ.py index c2e19a42..19ad5c4b 100644 --- a/jax_galsim/integ.py +++ b/jax_galsim/integ.py @@ -40,7 +40,7 @@ def int1d( @jax.jit def _func(x): rdt = jax.ShapeDtypeStruct(x.shape, x.dtype) - return jax.pure_callback(func, rdt, x) + return jax.pure_callback(func, rdt, x, vmap_method="sequential") else: _func = func From 2835a6c5a574c520e041b94515aa50df4e3cdcbb Mon Sep 17 00:00:00 2001 From: Francois Lanusse Date: Sat, 7 Feb 2026 12:05:03 +0100 Subject: [PATCH 08/11] ported implementation from tfp --- jax_galsim/bessel.py | 783 +++++++++++++++++++------------------------ 1 file changed, 352 insertions(+), 431 deletions(-) diff --git a/jax_galsim/bessel.py b/jax_galsim/bessel.py index 11291f2d..7ccee187 100644 --- a/jax_galsim/bessel.py +++ b/jax_galsim/bessel.py @@ -319,354 +319,383 @@ def k1_large(x): return jax.lax.switch(idx, [k1_small, k1_medium, k1_large], x) -@jax.jit -def _bessel_kv_asymptotic_large_nu(nu, x): - """Compute K_ν(x) using uniform asymptotic expansion for large ν (ν ≥ 35). - - Implements the SLATEC dasyik algorithm for K Bessel functions. - - Reference: BesselI.cpp lines 844-953 (dasyik function) - Algorithm: Olver, F.W.J. (1962), Tables of Bessel Functions of Moderate or Large Orders - """ - # fmt: off - sqrt_half_pi = 1.25331413731550025 # sqrt(pi/2) - c = jnp.array([ - -0.208333333333333, 0.125, 0.334201388888889, - -0.401041666666667, 0.0703125, -1.02581259645062, 1.84646267361111, - -0.8912109375, 0.0732421875, 4.66958442342625, -11.207002616223, - 8.78912353515625, -2.3640869140625, 0.112152099609375, - -28.2120725582002, 84.6362176746007, -91.81824154324, - 42.5349987453885, -7.36879435947963, 0.227108001708984, - 212.570130039217, -765.252468141182, 1059.990452528, - -699.579627376133, 218.190511744212, -26.4914304869516, - 0.572501420974731, -1919.45766231841, 8061.72218173731, - -13586.5500064341, 11655.3933368645, -5305.6469786134, - 1200.90291321635, -108.090919788395, 1.72772750258446, - 20204.2913309661, -96980.5983886375, 192547.001232532, - -203400.177280416, 122200.464983017, -41192.6549688976, - 7109.51430248936, -493.915304773088, 6.07404200127348, - -242919.187900551, 1311763.61466298, -2998015.91853811, - 3763271.2976564, -2813563.22658653, 1268365.27332162, - -331645.172484564, 45218.7689813627, -2499.83048181121, - 24.3805296995561, 3284469.85307204, -19706819.1184322, - 50952602.4926646, -74105148.2115327, 66344512.274729, - -37567176.6607634, 13288767.1664218, -2785618.12808645, - 308186.404612662, -13886.089753717, 110.017140269247 - ]) - # fmt: on - - fn = nu - z = x / fn - ra = jnp.sqrt(z * z + 1.0) - gln = jnp.log((ra + 1.0) / z) - coef = jnp.exp(-fn * (ra - gln)) - t = 1.0 / ra - t2 = t * t - t = -t / fn - s2 = 1.0 - ap = 1.0 - coeff_idx = 0 - - def body_fn(k, carry): - s2, ap, coeff_idx = carry - s1 = c[coeff_idx] - new_idx = coeff_idx + 1 - - def inner_body(j, inner_carry): - s1_val, idx = inner_carry - should_update = j < k + 1 - s1_new = jnp.where(should_update, s1_val * t2 + c[idx], s1_val) - idx_new = jnp.where(should_update, idx + 1, idx) - return (s1_new, idx_new) - - s1, new_idx = jax.lax.fori_loop(2, 12, inner_body, (s1, new_idx)) - - ap_new = ap * t - s2_new = s2 + ap_new * s1 - return (s2_new, ap_new, coeff_idx + k) - - s2, ap, _ = jax.lax.fori_loop(2, 12, body_fn, (s2, ap, coeff_idx)) - - return s2 * coef * jnp.sqrt(jnp.abs(t)) * sqrt_half_pi +# ===================================================================== +# Modified Bessel K_v(x) for fractional order +# Ported from TensorFlow Probability's bessel.py (Apache 2.0 License) +# Uses Temme's method (|v| < 50) + Olver's uniform asymptotic (|v| >= 50) +# ===================================================================== + +# Olver expansion polynomial coefficients (10 terms, up to 31 coefficients each) +# fmt: off +_ASYMPTOTIC_OLVER_COEFFICIENTS = [ + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + -0.20833333333333334, 0., 0.125, 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.3342013888888889, 0., + -0.40104166666666669, 0., 0.0703125, 0., 0.0], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., -1.0258125964506173, 0., 1.8464626736111112, + 0., -0.89121093750000002, 0., 0.0732421875, 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 4.6695844234262474, 0., -11.207002616222995, 0., + 8.78912353515625, 0., -2.3640869140624998, 0., 0.112152099609375, + 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + -28.212072558200244, 0., 84.636217674600744, 0., -91.818241543240035, + 0., 42.534998745388457, 0., -7.3687943594796312, 0., 0.22710800170898438, + 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 212.5701300392171, 0., + -765.25246814118157, 0., 1059.9904525279999, 0., -699.57962737613275, + 0., 218.19051174421159, 0., -26.491430486951554, 0., 0.57250142097473145, + 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., -1919.4576623184068, 0., + 8061.7221817373083, 0., -13586.550006434136, 0., 11655.393336864536, + 0., -5305.6469786134048, 0., 1200.9029132163525, 0., + -108.09091978839464, 0., 1.7277275025844574, 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 20204.291330966149, 0., -96980.598388637503, 0., + 192547.0012325315, 0., -203400.17728041555, 0., 122200.46498301747, + 0., -41192.654968897557, 0., 7109.5143024893641, 0., + -493.915304773088, 0., 6.074042001273483, 0., 0., 0., 0., 0., + 0., 0., 0.], + [0., 0., 0., -242919.18790055133, 0., 1311763.6146629769, 0., + -2998015.9185381061, 0., 3763271.2976564039, 0., -2813563.2265865342, 0., + 1268365.2733216248, 0., -331645.17248456361, 0., 45218.768981362737, 0., + -2499.8304818112092, 0., 24.380529699556064, 0., 0., 0., 0., 0., + 0., 0., 0., 0.0], + [3284469.8530720375, 0., -19706819.11843222, 0., 50952602.492664628, + 0., -74105148.211532637, 0., 66344512.274729028, 0., -37567176.660763353, + 0., 13288767.166421819, 0., -2785618.1280864552, 0., 308186.40461266245, + 0., -13886.089753717039, 0., 110.01714026924674, 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0.] +] +# fmt: on + + +def _sqrt1px2(x): + """Numerically stable computation of sqrt(1 + x^2).""" + eps = jnp.finfo(jnp.float64).eps + return jnp.where( + jnp.abs(x) * jnp.sqrt(eps) <= 1.0, + jnp.exp(0.5 * jnp.log1p(x * x)), + jnp.abs(x), + ) -@jax.jit -def _bessel_kv_small_x(nu, x): - """Compute K_ν(x) using power series for small x (x ≤ 2.0). +def _evaluate_temme_coeffs(v): + """Numerically stable computation of gamma-related coefficients for Temme's method. - Implements the SLATEC dbsknu series expansion. + Computes: + coeff1 = (1/Gamma(1-v) - 1/Gamma(1+v)) / (2v) + coeff2 = (1/Gamma(1-v) + 1/Gamma(1+v)) / 2 + gamma1pv = 1/Gamma(1+v) + gamma1mv = 1/Gamma(1-v) - Reference: BesselK.cpp lines 880-965 (inside dbsknu) - Algorithm: K_ν(x) = (π/2) * [I_{-ν}(x) - I_ν(x)] / sin(πν) + Uses Chebyshev expansions for numerical stability (avoids catastrophic cancellation). """ - # fmt: off - cc = jnp.array([ - 0.577215664901533, -0.0420026350340952, - -0.0421977345555443, 0.007218943246663, -2.152416741149e-4, - -2.01348547807e-5, 1.133027232e-6, 6.116095e-9 - ]) - # fmt: on - - pi = jnp.pi - tol = jnp.maximum(jnp.finfo(jnp.float64).eps, 1e-15) - - a1 = 1.0 - nu - a2 = nu + 1.0 - t1 = 1.0 / jsp.gamma(a1) - t2 = 1.0 / jsp.gamma(a2) - - dnu2 = jnp.where(jnp.abs(nu) >= tol, nu * nu, 0.0) - - def g1_small_nu(): - s = cc[0] - ak = 1.0 - - def body_fn(k, carry): - s_val, ak_val = carry - ak_new = ak_val * dnu2 - s_new = s_val + cc[k] * ak_new - return (s_new, ak_new) - - s_final, _ = jax.lax.fori_loop(1, 8, body_fn, (s, ak)) - return -s_final - - def g1_large_nu(): - return (t1 - t2) / (nu + nu) - - g1 = jnp.where(jnp.abs(nu) > 0.1, g1_large_nu(), g1_small_nu()) - g2 = (t1 + t2) * 0.5 - - rx = 2.0 / x - flrx = jnp.log(rx) - fmu = nu * flrx - - smu = jnp.where(nu != 0.0, jnp.sinh(fmu) / fmu, 1.0) - fc = jnp.where(nu != 0.0, nu * pi / jnp.sin(nu * pi), 1.0) - - f = fc * (g1 * jnp.cosh(fmu) + g2 * flrx * smu) - fc_exp = jnp.exp(fmu) - p = fc_exp * 0.5 / t2 - q = 0.5 / (fc_exp * t1) - - ak = 1.0 - ck = 1.0 - bk = 1.0 - s1 = f - s2 = p + coeff1_coeffs = [ + -1.142022680371168e0, + 6.5165112670737e-3, + 3.087090173086e-4, + -3.4706269649e-6, + 6.9437664e-9, + 3.67795e-11, + -1.356e-13, + ] + coeff2_coeffs = [ + 1.843740587300905e0, + -7.68528408447867e-2, + 1.2719271366546e-3, + -4.9717367042e-6, + -3.31261198e-8, + 2.423096e-10, + -1.702e-13, + -1.49e-15, + ] + w = 8.0 * v * v - 1.0 + y = 2.0 * w + + # Clenshaw's recurrence for coeff1 + prev = 0.0 + current = 0.0 + for i in reversed(range(1, len(coeff1_coeffs))): + temp = current + current = y * current - prev + coeff1_coeffs[i] + prev = temp + coeff1 = w * current - prev + 0.5 * coeff1_coeffs[0] + + # Clenshaw's recurrence for coeff2 + prev = 0.0 + current = 0.0 + for i in reversed(range(1, len(coeff2_coeffs))): + temp = current + current = y * current - prev + coeff2_coeffs[i] + prev = temp + coeff2 = w * current - prev + 0.5 * coeff2_coeffs[0] + + gamma1pv = coeff2 - v * coeff1 + gamma1mv = coeff2 + v * coeff1 + return coeff1, coeff2, gamma1pv, gamma1mv + + +def _temme_series_kve(v, z): + """Compute Kve(v, z) and Kve(v+1, z) via Temme power series. + + Assumes |v| < 0.5 and |z| <= 2 for fast convergence. + Returns exponentially scaled values: Kv(v,z)*exp(z). + """ + tol = jnp.finfo(jnp.float64).eps - # Power series iteration (up to 17 terms) - cx = x * x * 0.25 + coeff1, coeff2, gamma1pv_inv, gamma1mv_inv = _evaluate_temme_coeffs(v) - def series_body(i, carry): - f_val, p_val, q_val, ak_val, ck_val, bk_val, s1_val, s2_val = carry + z_sq = z * z + logzo2 = jnp.log(z / 2.0) + mu = -v * logzo2 + sinc_v = jnp.where(v == 0.0, 1.0, jnp.sin(jnp.pi * v) / (jnp.pi * v)) + sinhc_mu = jnp.where(mu == 0.0, 1.0, jnp.sinh(mu) / mu) - f_new = (ak_val * f_val + p_val + q_val) / (bk_val - dnu2) - p_new = p_val / (ak_val - nu) - q_new = q_val / (ak_val + nu) - ck_new = ck_val * cx / ak_val - s1_new = s1_val + ck_new * f_new - s2_new = s2_val + ck_new * (p_new - ak_val * f_new) - bk_new = bk_val + ak_val + ak_val + 1.0 - ak_new = ak_val + 1.0 + initial_f = (coeff1 * jnp.cosh(mu) + coeff2 * (-logzo2) * sinhc_mu) / sinc_v + initial_p = 0.5 * jnp.exp(mu) / gamma1pv_inv + initial_q = 0.5 * jnp.exp(-mu) / gamma1mv_inv - return (f_new, p_new, q_new, ak_new, ck_new, bk_new, s1_new, s2_new) + max_iterations = 1000 - init = (f, p, q, ak, ck, bk, s1, s2) - _, _, _, _, _, _, s1_final, _ = jax.lax.cond( - x >= tol, - lambda: jax.lax.fori_loop(0, 17, series_body, init), - lambda: init, + def body_fn(carry): + should_stop, index, f, p, q, coeff, kv_sum, kvp1_sum = carry + f = jnp.where( + should_stop, + f, + (index * f + p + q) / (index * index - v * v), + ) + p = jnp.where(should_stop, p, p / (index - v)) + q = jnp.where(should_stop, q, q / (index + v)) + h = p - index * f + coeff = jnp.where(should_stop, coeff, coeff * z_sq / (4.0 * index)) + kv_sum = jnp.where(should_stop, kv_sum, kv_sum + coeff * f) + kvp1_sum = jnp.where(should_stop, kvp1_sum, kvp1_sum + coeff * h) + index = index + 1.0 + should_stop = (jnp.abs(coeff * f) < jnp.abs(kv_sum) * tol) | ( + index > max_iterations + ) + return (should_stop, index, f, p, q, coeff, kv_sum, kvp1_sum) + + def cond_fn(carry): + should_stop = carry[0] + return ~should_stop + + init = ( + jnp.array(False), + 1.0, + initial_f, + initial_p, + initial_q, + 1.0, + initial_f, + initial_p, ) + _, _, _, _, _, _, kv_sum, kvp1_sum = jax.lax.while_loop(cond_fn, body_fn, init) - return s1_final + # Convert to exponentially scaled: kve = kv * exp(z) + kve = kv_sum * jnp.exp(z) + kvep1 = 2.0 * kvp1_sum * jnp.exp(z) / z + return kve, kvep1 -@jax.jit -def _bessel_kv_miller(nu, x): - """Compute K_ν(x) using Miller's algorithm for moderate x (2.0 < x ≤ 17.0). - Implements the SLATEC Miller algorithm using continued fractions, - then uses forward recursion to get from dnu to nu. +def _continued_fraction_kve(v, z): + """Compute Kve(v, z) and Kve(v+1, z) via Steed's continued fraction. - Reference: BesselK.cpp lines 677-1036 (dbsknu function) + Assumes |v| < 0.5 and |z| > 2. + Returns exponentially scaled values: Kv(v,z)*exp(z). """ - pi = jnp.pi - sqrt_half_pi = 1.2533141373155 # sqrt(pi/2) - tol = jnp.maximum(jnp.finfo(jnp.float64).eps, 1e-15) - - rx = 2.0 / x - inu = jnp.floor(nu + 0.5).astype(int) - dnu = nu - inu - dnu2 = jnp.where(jnp.abs(dnu) >= tol, dnu * dnu, 0.0) - - coef = sqrt_half_pi / jnp.sqrt(x) * jnp.exp(-x) - - etest = jnp.cos(pi * dnu) / (pi * x * tol) - fks = 1.0 - fhs = 0.25 - fk = 0.0 - ck = x + x + 2.0 - p1 = 0.0 - p2 = 1.0 + tol = jnp.finfo(jnp.float64).eps + max_iterations = 1000 - a_arr = jnp.zeros(160) - b_arr = jnp.zeros(160) + initial_numerator = v * v - 0.25 + initial_denominator = 2.0 * (z + 1.0) + initial_ratio = 1.0 / initial_denominator + initial_seq = -initial_numerator - def forward_body(i, carry): + def steeds_body(carry): ( - fk_val, - fks_val, - fhs_val, - ck_val, - p1_val, - p2_val, - a_arr_val, - b_arr_val, - k, - converged, + should_stop, + index, + partial_numerator, + partial_denominator, + denominator_ratio, + convergent_difference, + hypergeometric_ratio, + k_0, + k_1, + c, + q, + hypergeometric_sum, ) = carry - fk_new = fk_val + 1.0 - ak = (fhs_val - dnu2) / (fks_val + fk_new) - bk = ck_val / (fk_new + 1.0) - pt = p2_val - p2_new = bk * p2_val - ak * p1_val - p1_new = pt - - a_arr_new = jnp.where(converged, a_arr_val, a_arr_val.at[k].set(ak)) - b_arr_new = jnp.where(converged, b_arr_val, b_arr_val.at[k].set(bk)) - - ck_new = ck_val + 2.0 - fks_new = fks_val + fk_new + fk_new + 1.0 - fhs_new = fhs_val + fk_new + fk_new - - has_converged = (etest <= fk_new * jnp.abs(p1_new)) | converged - k_new = jnp.where(has_converged, k, k + 1) - + partial_numerator = partial_numerator - 2.0 * (index - 1.0) + c = jnp.where(should_stop, c, -c * partial_numerator / index) + next_k = (k_0 - partial_denominator * k_1) / partial_numerator + k_0 = jnp.where(should_stop, k_0, k_1) + k_1 = jnp.where(should_stop, k_1, next_k) + q = jnp.where(should_stop, q, q + c * next_k) + partial_denominator = partial_denominator + 2.0 + denominator_ratio = 1.0 / ( + partial_denominator + partial_numerator * denominator_ratio + ) + convergent_difference = jnp.where( + should_stop, + convergent_difference, + convergent_difference * (partial_denominator * denominator_ratio - 1.0), + ) + hypergeometric_ratio = jnp.where( + should_stop, + hypergeometric_ratio, + hypergeometric_ratio + convergent_difference, + ) + hypergeometric_sum = jnp.where( + should_stop, + hypergeometric_sum, + hypergeometric_sum + q * convergent_difference, + ) + index = index + 1.0 + should_stop = ( + jnp.abs(q * convergent_difference) < jnp.abs(hypergeometric_sum) * tol + ) | (index > max_iterations) return ( - fk_new, - fks_new, - fhs_new, - ck_new, - p1_new, - p2_new, - a_arr_new, - b_arr_new, - k_new, - has_converged, + should_stop, + index, + partial_numerator, + partial_denominator, + denominator_ratio, + convergent_difference, + hypergeometric_ratio, + k_0, + k_1, + c, + q, + hypergeometric_sum, ) - _, _, _, _, _, _, a_final, b_final, k_final, _ = jax.lax.fori_loop( - 0, 160, forward_body, (fk, fks, fhs, ck, p1, p2, a_arr, b_arr, 0, False) + def cond_fn(carry): + return ~carry[0] + + init = ( + jnp.array(False), + 2.0, + initial_numerator, + initial_denominator, + initial_ratio, + initial_ratio, + initial_ratio, + 0.0, + 1.0, + initial_seq, + initial_seq, + 1.0 - initial_numerator * initial_ratio, ) - - s = 1.0 - p1_back = 0.0 - p2_back = 1.0 - - def backward_body(i, carry): - s_val, p1_val, p2_val = carry - should_update = i < k_final - idx = k_final - 1 - i - pt = p2_val - p2_new = jnp.where( - should_update, (b_final[idx] * p2_val - p1_val) / a_final[idx], p2_val - ) - p1_new = jnp.where(should_update, pt, p1_val) - s_new = jnp.where(should_update, s_val + p2_new, s_val) - - return (s_new, p1_new, p2_new) - - s_final, p1_back_final, p2_back_final = jax.lax.fori_loop( - 0, 160, backward_body, (s, p1_back, p2_back) + result = jax.lax.while_loop(cond_fn, steeds_body, init) + hypergeometric_ratio = result[6] + hypergeometric_sum = result[11] + + # log(kve) = 0.5*log(pi/(2z)) - log(hypergeometric_sum) + log_kve = 0.5 * jnp.log(jnp.pi / (2.0 * z)) - jnp.log(hypergeometric_sum) + log_kvp1e = ( + log_kve + + jnp.log1p(2.0 * (v + z + initial_numerator * hypergeometric_ratio)) + - jnp.log(z) + - jnp.log(2.0) ) + return jnp.exp(log_kve), jnp.exp(log_kvp1e) - s1 = coef * (p2_back_final / s_final) - - def no_recursion(): - return s1 - - def with_recursion(): - s2 = s1 * (x + dnu + 0.5 - p1_back_final / p2_back_final) / x - ck_rec = (dnu + dnu + 2.0) / x - inu_adjusted = inu - 1 - def recursion_body(i, carry): - s1_val, s2_val, ck_val = carry - should_update = i < inu_adjusted - st = s2_val - s2_new = jnp.where(should_update, ck_val * s2_val + s1_val, s2_val) - s1_new = jnp.where(should_update, st, s1_val) - ck_new = jnp.where(should_update, ck_val + rx, ck_val) - return (s1_new, s2_new, ck_new) +def _olver_kve(v, z): + """Compute Kve(v, z) using Olver's uniform asymptotic expansion. - _, s2_final, _ = jax.lax.fori_loop(0, 50, recursion_body, (s1, s2, ck_rec)) - return s2_final - - return jax.lax.cond(inu == 0, no_recursion, with_recursion) + Valid for |v| >= 50. Returns exponentially scaled value: Kv(v,z)*exp(z). + """ + v_abs = jnp.abs(v) + w = z / v_abs + t = 1.0 / _sqrt1px2(w) + + divisor = v_abs + kve_sum = 1.0 + + # Evaluate the Olver polynomial terms using Horner's method + for i in range(len(_ASYMPTOTIC_OLVER_COEFFICIENTS)): + coeff = 0.0 + for c in _ASYMPTOTIC_OLVER_COEFFICIENTS[i]: + coeff = coeff * t + c + term = coeff / divisor + # For K_v, signs alternate: (-1)^i + kve_sum = kve_sum + (term if i % 2 == 1 else -term) + divisor = divisor * v_abs + + # log(kve) = 0.5*log(pi*t/(2*v_abs)) - v_abs*shared_prefactor + shared_prefactor = 1.0 / (_sqrt1px2(w) + w) + jnp.log(w) - jnp.log1p(1.0 / t) + log_k_prefactor = ( + 0.5 * jnp.log(jnp.pi * t / (2.0 * v_abs)) - v_abs * shared_prefactor + ) + log_kve = log_k_prefactor + jnp.log(kve_sum) + return jnp.exp(log_kve) -@jax.jit -def _bessel_kv_asymptotic(nu, x): - """Compute K_ν(x) using asymptotic expansion for large x (x > 17.0). - Enhanced version with 30 terms for better accuracy. +def _temme_kve(v, x): + """Compute Kve(v, x) using Temme's method for |v| < 50. - Reference: BesselK.cpp lines 799-832 (inside dbsknu) + Reduces to fractional order |u| <= 0.5, computes Kve(u, x) and Kve(u+1, x), + then uses forward recurrence to reach order v. + Returns exponentially scaled value: Kv(v,x)*exp(x). """ - sqrt_pi_2x = jnp.sqrt(jnp.pi / (2.0 * x)) - exp_neg_x = jnp.exp(-x) - - two_nu = nu + nu - fmu = two_nu * two_nu - ex = x * 8.0 - - s = 1.0 - ck = 1.0 - sqk = 1.0 - ak = 0.0 - dk = ex - - def body_fn(_, carry): - s_val, ck_val, dk_val, ak_val, sqk_val = carry - ck_new = ck_val * (fmu - sqk_val) / dk_val - s_new = s_val + ck_new - dk_new = dk_val + ex - ak_new = ak_val + 8.0 - sqk_new = sqk_val + ak_new - return (s_new, ck_new, dk_new, ak_new, sqk_new) + v = jnp.abs(v) + n = jnp.round(v) + u = v - n + + # Branchless: compute both methods with safe inputs, select with jnp.where + small_x = jnp.where(x <= 2.0, x, 0.1) + large_x = jnp.where(x > 2.0, x, 1000.0) + + temme_kue, temme_kuep1 = _temme_series_kve(u, small_x) + cf_kue, cf_kuep1 = _continued_fraction_kve(u, large_x) + + kue = jnp.where(x <= 2.0, temme_kue, cf_kue) + kuep1 = jnp.where(x <= 2.0, temme_kuep1, cf_kuep1) + + # Forward recurrence: K_{v+1}(z) = (2v/z)*K_v(z) + K_{v-1}(z) + # This recurrence is also satisfied by Kv*exp(z) (the exponentially scaled form). + def bessel_recurrence(carry): + index, kve, kvep1 = carry + next_kvep1 = 2.0 * (u + index) * kvep1 / x + kve + kve = jnp.where(index > n, kve, kvep1) + kvep1 = jnp.where(index > n, kvep1, next_kvep1) + return (index + 1.0, kve, kvep1) + + def recurrence_cond(carry): + index = carry[0] + return index <= n + + _, kve, _ = jax.lax.while_loop( + recurrence_cond, bessel_recurrence, (1.0, kue, kuep1) + ) + return kve - s_final, _, _, _, _ = jax.lax.fori_loop(0, 30, body_fn, (s, ck, dk, ak, sqk)) - return sqrt_pi_2x * exp_neg_x * s_final +def _kve_core(nu, x): + """Core dispatcher for Kve(nu, x) = Kv(nu, x) * exp(x). + Branchless: computes both Olver and Temme with safe dummy inputs, + selects based on |nu| >= 50. + """ + nu = jnp.abs(nu) -@jax.jit -def _bessel_kv_fractional(nu, x): - """Compute K_ν(x) for fractional ν using SLATEC algorithms. + # Safe inputs: avoid invalid regions for each method + small_nu = jnp.where(nu < 50.0, nu, 0.1) + large_nu = jnp.where(nu >= 50.0, nu, 1000.0) - Decision tree based on the C++ dbesk/dbsknu implementation: - - ν ≥ 35: Uniform asymptotic expansion (large ν) - - x ≤ 2: Power series expansion - - 2 < x ≤ 17: Miller's algorithm - - x > 17: Asymptotic expansion for large x + olver_result = _olver_kve(large_nu, x) + temme_result = _temme_kve(small_nu, x) - Reference: BesselK.cpp lines 62-215, 677-1036 - """ - branch_index = jnp.where( - nu >= 35.0, 0, jnp.where(x <= 2.0, 1, jnp.where(x <= 17.0, 2, 3)) - ) - return jax.lax.switch( - branch_index, - [ - lambda nu, x: _bessel_kv_asymptotic_large_nu(nu, x), - lambda nu, x: _bessel_kv_small_x(nu, x), - lambda nu, x: _bessel_kv_miller(nu, x), - lambda nu, x: _bessel_kv_asymptotic(nu, x), - ], - nu, - x, - ) + return jnp.where(nu >= 50.0, olver_result, temme_result) @jax.jit @@ -749,70 +778,21 @@ def kn(n, x): def _kv_scalar(nu, x): - """Scalar implementation of K_ν(x). Inputs must be scalar (0-d arrays).""" - nu = 1.0 * nu # promote to float + """Scalar implementation of K_ν(x) using TFP-ported Temme + Olver algorithms.""" + nu = 1.0 * nu x = 1.0 * x + nu = jnp.abs(nu) # K_{-v} = K_v - # K_{-nu}(x) = K_nu(x) - nu = jnp.abs(nu) - - nu_int = jnp.floor(nu).astype(int) - nu_frac = nu - nu_int - - is_half_integer = jnp.abs(nu_frac - 0.5) < 1e-10 - is_integer = nu_frac < 1e-10 - - def integer_order(nu_int, x): - k0 = _bessel_k0(x) - k1 = _bessel_k1(x) - return jnp.where( - nu_int == 0, - k0, - jnp.where(nu_int == 1, k1, _bessel_kn_recurrence(nu_int, x, k0, k1)), - ) - - def half_integer_order(nu_int, x): - sqrt_pi_2x = jnp.sqrt(jnp.pi / (2.0 * x)) - exp_neg_x = jnp.exp(-x) - inv_x = 1.0 / x - - p0 = 1.0 - p1 = 1.0 + inv_x - p2 = 1.0 + 3.0 * inv_x + 3.0 * inv_x**2 - p3 = 1.0 + 6.0 * inv_x + 15.0 * inv_x**2 + 15.0 * inv_x**3 - p4 = 1.0 + 10.0 * inv_x + 45.0 * inv_x**2 + 105.0 * inv_x**3 + 105.0 * inv_x**4 - - poly = jnp.select( - [nu_int == 0, nu_int == 1, nu_int == 2, nu_int == 3], - [p0, p1, p2, p3], - default=p4, - ) - - return sqrt_pi_2x * exp_neg_x * poly - - def very_small_x_limit(nu, x): - return jnp.power(2.0, nu - 1.0) * jsp.gamma(nu) / jnp.power(x, nu) - - branch_index = jnp.where(is_integer, 0, jnp.where(is_half_integer, 1, 2)) - - def _branch_integer(nu, nu_int, x): - return integer_order(nu_int, x) - - def _branch_half_integer(nu, nu_int, x): - return half_integer_order(nu_int, x) - - def _branch_fractional(nu, nu_int, x): - return jnp.where( - x < 1e-10, very_small_x_limit(nu, x), _bessel_kv_fractional(nu, x) - ) + # Compute via exponentially scaled form for numerical stability + # Use a safe x for the core computation (avoid x=0 which causes issues) + safe_x = jnp.where(x > 0.0, x, 1.0) + kve = _kve_core(nu, safe_x) + result = kve * jnp.exp(-safe_x) - return jax.lax.switch( - branch_index, - [_branch_integer, _branch_half_integer, _branch_fractional], - nu, - nu_int, - x, - ) + # Edge cases + result = jnp.where(x == 0.0, jnp.inf, result) + result = jnp.where(x < 0.0, jnp.nan, result) + return result @implements(_galsim.bessel.kv) @@ -820,25 +800,8 @@ def _branch_fractional(nu, nu_int, x): def _kv_impl(nu, x): """Modified Bessel function of the second kind K_ν(x) - internal implementation. - Implementation strategy: - - Integer orders (ν = 0, 1, 2, ...): Pure JAX using Chebyshev series and recurrence - - Half-integer orders (ν = 0.5, 1.5, ...): Pure JAX using closed-form expressions - - Arbitrary fractional orders: Pure JAX using SLATEC algorithms - - All implementations are fully JIT-compatible and ported from the NETLIB SLATEC - library via the C++ GalSim reference implementation. - - Handles both scalar and array inputs. Internal algorithms are scalar; - array inputs are handled via jax.vmap. - - Args: - nu: Order (can be negative, integer, or fractional) - x: Argument (must be positive) - - Returns: - K_ν(x) - - Reference: BesselK.cpp in GalSim C++ source + Uses TFP-ported Temme + Olver algorithms for all orders. + Handles both scalar and array inputs via jax.vmap. """ nu = jnp.asarray(1.0 * nu) x = jnp.asarray(1.0 * x) @@ -855,45 +818,15 @@ def _kv_impl(nu, x): def kv(nu, x): """Modified Bessel function of the second kind K_ν(x) with custom gradients. - Implementation strategy: - - Integer orders (ν = 0, 1, 2, ...): Pure JAX using Chebyshev series and recurrence - - Half-integer orders (ν = 0.5, 1.5, ...): Pure JAX using closed-form expressions - - Arbitrary fractional orders: Pure JAX using SLATEC algorithms - - All implementations are fully JIT-compatible and ported from the NETLIB SLATEC - library via the C++ GalSim reference implementation. - - Custom gradients are implemented using analytical derivative formulas based on - Bessel function recurrence relations: + Uses TFP-ported Temme + Olver algorithms. Custom gradients via: ∂K_ν/∂x = -1/2 * (K_{ν-1}(x) + K_{ν+1}(x)) - This is derived from the modified Bessel recurrence relation: - K_{ν-1}(x) + K_{ν+1}(x) = -2 * K'_ν(x) - - Note: Gradients with respect to the order parameter ν are not supported and will - return zero. This follows the approach used in TensorFlow Probability, as gradients - with respect to the order are rarely needed in practice. - - Args: - nu: Order (can be negative, integer, or fractional) - x: Argument (must be positive) - - Returns: - K_ν(x) - - Reference: - - BesselK.cpp in GalSim C++ source - - TensorFlow Probability bessel.py for custom gradient approach - - Abramowitz & Stegun 9.6.26 for derivative recurrence relations + Gradient w.r.t. ν is not supported (returns zero). """ return _kv_impl(nu, x) def _kv_fwd(nu, x): - """Forward pass for kv with custom gradients. - - Computes K_ν(x) and saves K_{ν-1}(x) and K_{ν+1}(x) for use in the backward pass. - """ kv_val = _kv_impl(nu, x) kv_prev = _kv_impl(nu - 1.0, x) kv_next = _kv_impl(nu + 1.0, x) @@ -901,20 +834,8 @@ def _kv_fwd(nu, x): def _kv_bwd(residuals, g): - """Backward pass for kv with custom gradients. - - Uses the analytical derivative formula from Bessel recurrence relations: - ∂K_ν/∂x = -1/2 * (K_{ν-1}(x) + K_{ν+1}(x)) - - This formula comes from the modified Bessel function recurrence relation: - K_{ν-1}(x) + K_{ν+1}(x) = -2 * K'_ν(x) - - Gradient with respect to ν is not supported and returns zero. - """ nu, x, kv_prev, kv_next = residuals - # Analytical gradient: ∂K_ν/∂x = -1/2 * (K_{ν-1}(x) + K_{ν+1}(x)) grad_x = -0.5 * (kv_prev + kv_next) * g - # Gradient w.r.t. ν not supported (following TFP approach) grad_nu = jnp.zeros_like(nu) return (grad_nu, grad_x) From 900013af0c3fa10b26f0217bab757b9cc4d64a07 Mon Sep 17 00:00:00 2001 From: Francois Lanusse Date: Sat, 7 Feb 2026 12:08:53 +0100 Subject: [PATCH 09/11] adding licensing --- LICENSE | 18 ++++++++++++++++++ jax_galsim/bessel.py | 19 +++++++++++++++++-- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/LICENSE b/LICENSE index 344c4a4e..61250865 100644 --- a/LICENSE +++ b/LICENSE @@ -22,6 +22,24 @@ lost profits, business interruption, or indirect special or consequential damages of any kind. +Code in bessel.py (fractional-order Bessel K functions) is derived from +TensorFlow Probability under this license: + +Copyright 2020 The TensorFlow Probability Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + + Code in angle.py and celestial.py is based on LSSTDESC/Coord under this license: Copyright (c) 2013-2017 LSST Dark Energy Science Collaboration (DESC) diff --git a/jax_galsim/bessel.py b/jax_galsim/bessel.py index 7ccee187..f2f7c759 100644 --- a/jax_galsim/bessel.py +++ b/jax_galsim/bessel.py @@ -321,8 +321,23 @@ def k1_large(x): # ===================================================================== # Modified Bessel K_v(x) for fractional order -# Ported from TensorFlow Probability's bessel.py (Apache 2.0 License) -# Uses Temme's method (|v| < 50) + Olver's uniform asymptotic (|v| >= 50) +# +# The functions below (_sqrt1px2, _evaluate_temme_coeffs, +# _temme_series_kve, _continued_fraction_kve, _olver_kve, _temme_kve, +# _kve_core) and the _ASYMPTOTIC_OLVER_COEFFICIENTS constant are +# derived from TensorFlow Probability's bessel.py: +# https://github.com/tensorflow/probability +# +# Original copyright and license: +# Copyright 2020 The TensorFlow Probability Authors. +# Licensed under the Apache License, Version 2.0. +# +# Modifications from the original: +# - Ported from TensorFlow/TFP APIs to pure JAX (jax.numpy, jax.lax) +# - Removed I_v (bessel_ive) computation; only K_v is computed +# - Removed log-space output option +# - Removed negative-v correction for I_v +# - Simplified to scalar-only core (vectorization via jax.vmap) # ===================================================================== # Olver expansion polynomial coefficients (10 terms, up to 31 coefficients each) From adb064d1d873d4cd0d736d4b6ee50878c8d38c12 Mon Sep 17 00:00:00 2001 From: Francois Lanusse Date: Sat, 7 Feb 2026 14:31:01 +0100 Subject: [PATCH 10/11] Delete CLAUDE.md --- CLAUDE.md | 295 ------------------------------------------------------ 1 file changed, 295 deletions(-) delete mode 100644 CLAUDE.md diff --git a/CLAUDE.md b/CLAUDE.md deleted file mode 100644 index d8f2386c..00000000 --- a/CLAUDE.md +++ /dev/null @@ -1,295 +0,0 @@ -# CLAUDE.md - -This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. - -## Project Overview - -JAX-GalSim is a JAX port of GalSim (Galaxy Image Simulation toolkit) that enables parallelized, GPU-accelerated, and differentiable galaxy image simulations. This is an early-stage project aiming to reimplement GalSim functionalities in pure JAX. - -**Key Design Principles:** -- Drop-in replacement for GalSim with a close API match -- Each function/feature is tested against the reference GalSim implementation -- This is a **subset** of GalSim (only includes functions with a reference implementation) -- Code should be readable and pip-installable without compilation - -**Current Status:** Early development phase (v0.0.1rc1). Not for scientific applications yet - use the reference GalSim implementation for production work. - -## Installation and Setup - -**Recommended:** Use a virtual environment to isolate dependencies: - -```bash -# Clone with submodules (required for tests) -git clone --recurse-submodules https://github.com/YOUR_USERNAME/JAX-GalSim -cd JAX-GalSim - -# Create and activate a virtual environment (recommended) -python -m venv venv -source venv/bin/activate # On Windows: venv\Scripts\activate - -# Install in editable mode -pip install -e . - -# Install development tools -pip install black pre-commit pytest -pre-commit install -``` - -## Testing - -```bash -# Run all tests (includes both GalSim reference tests and JAX-specific tests) -pytest - -# Test paths are configured in pytest.ini: -# - tests/GalSim/tests/ (reference GalSim tests) -# - tests/jax (JAX-specific tests like test_jitting.py) - -# Run specific test file -pytest tests/jax/test_jitting.py -``` - -## Code Formatting - -This project uses Black for code formatting: - -```bash -# Format all code -black . - -# Black excludes tests/GalSim/ directory (configured in .pre-commit-config.yaml) -``` - -**Important:** CI will fail if code is not formatted with Black. Use pre-commit hooks to automate this. - -## Architecture - -### Core Structure - -- `jax_galsim/` - Main package implementing JAX versions of GalSim objects - - `gsobject.py` - Base `GSObject` class that all galaxy profile objects inherit from - - `gsparams.py` - `GSParams` class for speed/accuracy trade-off parameters - - `gaussian.py`, `exponential.py` - Specific galaxy profile implementations - - `sum.py` - Composite objects (e.g., `Add`, `Sum`) - - `core/` - Core utilities (currently empty/minimal) - -- `tests/` - Test suite - - `GalSim/` - Git submodule containing the reference GalSim implementation for testing - - `jax/` - JAX-specific tests (e.g., JIT compilation tests) - -### JAX Pytree Registration - -All GSObject classes must be registered as JAX pytrees to support JIT compilation and automatic differentiation: - -```python -from jax.tree_util import register_pytree_node_class - -@register_pytree_node_class -class MyGSObject(GSObject): - def tree_flatten(self): - # Return (children, aux_data) where children are JAX arrays - # and aux_data contains static information - ... - - @classmethod - def tree_unflatten(cls, aux_data, children): - # Reconstruct object from aux_data and children - ... -``` - -### Documentation Pattern - -Avoid duplicating documentation. Use JAX's `_wraps` utility to inherit docs from GalSim: - -```python -from jax._src.numpy.util import _wraps -import galsim as _galsim - -@_wraps(_galsim.Gaussian) -@register_pytree_node_class -class Gaussian(GSObject): - ... - -# Or for functions with differences: -@_wraps(_galsim.Add, lax_description="Does not support `ChromaticObject` at this point.") -def Add(*args, **kwargs): - return Sum(*args, **kwargs) -``` - -The `lax_description` parameter documents any differences or limitations compared to GalSim. - -### GSObject Parameter Management - -GSObjects use a dual-parameter system: -- `_params` dict: Traced parameters (JAX arrays) that can be differentiated -- `_gsparams`: Static parameters (`GSParams` object) for numerical configurations - -Properties like `flux`, `sigma`, `half_light_radius` etc. are accessed via `self.params` dictionary. - -### Testing Against Reference GalSim - -Tests in `tests/GalSim/tests/` are from the reference GalSim implementation. JAX-GalSim objects are tested against these to ensure API compatibility and numerical accuracy. - -JAX-specific tests in `tests/jax/` verify JAX functionality like JIT compilation, differentiation, and pytree behavior. - -### How the Testing Infrastructure Works - -JAX-GalSim uses a **pytest hook system** to automatically run GalSim's test suite against JAX-GalSim implementations. This means you can reuse all of GalSim's existing tests without modification! - -#### The Mechanism - -**1. Import Replacement (`tests/conftest.py`)** - - The `pytest_pycollect_makemodule` hook intercepts test file loading - - Automatically replaces `import galsim` with `import jax_galsim` in all GalSim test files - - This happens transparently - no modification to GalSim test files needed! - -**2. Test Configuration (`tests/galsim_tests_config.yaml`)** - ```yaml - enabled_tests: - galsim: - - test_gaussian.py - - test_exponential.py - - "*" # Enable all GalSim tests - - allowed_failures: - - "module 'jax_galsim' has no attribute 'Airy'" - - "module 'jax_galsim.bessel' has no attribute 'j1'" - # ... list of expected failures for unimplemented features - ``` - - - `enabled_tests`: Lists which GalSim test files to run (`"*"` means all) - - `allowed_failures`: Error messages that won't fail the test suite (for tracking unimplemented features) - -**3. Test Execution Flow** - ``` - pytest tests/GalSim/tests/test_bessel.py - ↓ - pytest hook replaces: import galsim → import jax_galsim - ↓ - GalSim tests run against JAX-GalSim implementation - ↓ - Results compared with scipy.special / reference values - ↓ - PASS / FAIL / ALLOWED FAILURE - ``` - -#### Enabling Tests for New Functions - -When you implement a new function in JAX-GalSim, follow these steps to enable its tests: - -**Example: Adding `bessel.kn` function** - -1. **Implement the function** in `jax_galsim/bessel.py`: - ```python - from jax_galsim.core.utils import implements - import galsim as _galsim - - @implements(_galsim.bessel.kn) - @jax.jit - def kn(n, x): - """Modified Bessel function K_n for integer n""" - # ... implementation ... - return result - ``` - -2. **Remove from allowed_failures** in `tests/galsim_tests_config.yaml`: - ```yaml - allowed_failures: - # DELETE or comment out this line: - # - "module 'jax_galsim.bessel' has no attribute 'kn'" - ``` - -3. **Run the tests**: - ```bash - pytest tests/GalSim/tests/test_bessel.py::test_kn -v - ``` - -4. **Test outcomes**: - - **PASS**: Your implementation matches GalSim's accuracy ✅ - - **FAIL**: Numerical accuracy issues - fix your implementation - - **ERROR**: API mismatch - check function signature and behavior - -#### Finding Which Tests Will Run - -To see what GalSim tests exist for a module: - -```bash -# List all bessel tests -grep "^def test_" tests/GalSim/tests/test_bessel.py - -# Example output: -# def test_j0(): -# def test_j1(): -# def test_kn(): -# def test_kv(): -# ... etc -``` - -Each `test_*` function will automatically run against your JAX-GalSim implementation when enabled! - -#### Tracking Progress - -```bash -# Run all tests and see summary -pytest tests/GalSim/tests/ -v - -# Common output: -# ✅ 25 passed - Implementations working correctly -# ❌ 3 failed - Implementations with accuracy issues -# ⚠️ 100 allowed - Features not yet implemented -``` - -This gives you clear visibility into: -- What's working (passing tests) -- What needs fixing (failing tests) -- What's not implemented yet (allowed failures) - -#### Debugging Failed Tests - -When a test fails, pytest shows: -- **Expected values**: From GalSim/scipy -- **Actual values**: From your JAX-GalSim implementation -- **Tolerance**: Typically `rtol=1e-10` (10 decimal places) - -Example failure: -```python -AssertionError: -Not equal to tolerance rtol=1e-10 -ACTUAL: [18.24, 2.146, 45.04, ...] -DESIRED: [11.90, 2.146, 37.79, ...] -``` - -This tells you exactly which test cases have accuracy problems. - -## Contributing Workflow - -1. Fork and clone with `--recurse-submodules` -2. Create a feature branch: `git checkout -b descriptive-name` -3. Make changes and ensure tests pass: `pytest` -4. Format code: `black .` -5. Update `CHANGELOG.md` -6. Add BSD license header to new files -7. Squash commits if needed: `git rebase -i` -8. Open PR against `main` branch - -**Before submitting:** -- Ensure tests pass -- Code is Black-formatted -- PR is self-contained and focused -- New functionality has tests -- Branch is up-to-date with upstream `main` - -## Submodule Management - -The `tests/GalSim` directory is a git submodule pointing to the reference GalSim implementation. When tests fail to run: - -```bash -# Initialize/update submodules -git submodule update --init --recursive -``` - -## Documentation Style - -Follow NumPy/SciPy documentation format: https://numpydoc.readthedocs.io/en/latest/format.html - -Prefer using `_wraps` to inherit GalSim documentation rather than copy/pasting. From c92305ef76af5a6b5413f4580b289a0ba9065ccb Mon Sep 17 00:00:00 2001 From: Francois Lanusse Date: Sat, 7 Feb 2026 15:46:00 +0100 Subject: [PATCH 11/11] remove non-necessary code --- jax_galsim/bessel.py | 436 ++++++--------------------------- tests/galsim_tests_config.yaml | 1 + 2 files changed, 75 insertions(+), 362 deletions(-) diff --git a/jax_galsim/bessel.py b/jax_galsim/bessel.py index f2f7c759..87d09291 100644 --- a/jax_galsim/bessel.py +++ b/jax_galsim/bessel.py @@ -1,124 +1,92 @@ import galsim as _galsim import jax import jax.numpy as jnp -import jax.scipy.special as jsp from jax_galsim.core.utils import implements -# Chebyshev series evaluation -# Ported from SLATEC dcsevl function in BesselJ.cpp lines 1666-1676 -@jax.jit -def _dcsevl(x, cs): - """Evaluate Chebyshev series. - - Evaluates the N-term Chebyshev series cs at x using Clenshaw's - recurrence algorithm. Only half the first coefficient is summed. - - Args: - x: Value at which to evaluate series (should be in [-1, 1]) - cs: Array of Chebyshev series coefficients - - Returns: - Evaluated series value - """ - n = len(cs) - b0 = jnp.array(0.0) - b1 = jnp.array(0.0) - b2 = jnp.array(0.0) - twox = 2.0 * jnp.squeeze(x) - - def body_fn(i, carry): - b0, b1, b2 = carry - b2 = b1 - b1 = b0 - b0 = twox * b1 - b2 + cs[n - 1 - i] - return (b0, b1, b2) - - b0, b1, b2 = jax.lax.fori_loop(0, n, body_fn, (b0, b1, b2)) - return 0.5 * (b0 - b2) - - # the code here for Si, f, g and _si_small_pade is taken from galsim/src/math/Sinc.cpp @jax.jit def _f_pade(x, x2): # fmt: off y = 1. / x2 - return ( - (1. + # noqa: E226 - y*(7.44437068161936700618e2 + # noqa: E226 - y*(1.96396372895146869801e5 + # noqa: E226 - y*(2.37750310125431834034e7 + # noqa: E226 - y*(1.43073403821274636888e9 + # noqa: E226 - y*(4.33736238870432522765e10 + # noqa: E226 - y*(6.40533830574022022911e11 + # noqa: E226 - y*(4.20968180571076940208e12 + # noqa: E226 - y*(1.00795182980368574617e13 + # noqa: E226 - y*(4.94816688199951963482e12 + # noqa: E226 - y*(-4.94701168645415959931e11))))))))))) # noqa: E226 - / (x*(1. + # noqa: E226 - y*(7.46437068161927678031e2 + # noqa: E226 - y*(1.97865247031583951450e5 + # noqa: E226 - y*(2.41535670165126845144e7 + # noqa: E226 - y*(1.47478952192985464958e9 + # noqa: E226 - y*(4.58595115847765779830e10 + # noqa: E226 - y*(7.08501308149515401563e11 + # noqa: E226 - y*(5.06084464593475076774e12 + # noqa: E226 - y*(1.43468549171581016479e13 + # noqa: E226 - y*(1.11535493509914254097e13))))))))))) # noqa: E226 + f = ( + (1. + # noqa: W504, E126, E226 + y*(7.44437068161936700618e2 + # noqa: W504, E126, E226 + y*(1.96396372895146869801e5 + # noqa: W504, E126, E226 + y*(2.37750310125431834034e7 + # noqa: W504, E126, E226 + y*(1.43073403821274636888e9 + # noqa: W504, E126, E226 + y*(4.33736238870432522765e10 + # noqa: W504, E126, E226 + y*(6.40533830574022022911e11 + # noqa: W504, E126, E226 + y*(4.20968180571076940208e12 + # noqa: W504, E126, E226 + y*(1.00795182980368574617e13 + # noqa: W504, E126, E226 + y*(4.94816688199951963482e12 + # noqa: W504, E126, E226 + y*(-4.94701168645415959931e11))))))))))) # noqa: W504, E126, E226 + / (x*(1. + # noqa: W504, E126, E226 + y*(7.46437068161927678031e2 + # noqa: W504, E126, E226 + y*(1.97865247031583951450e5 + # noqa: W504, E126, E226 + y*(2.41535670165126845144e7 + # noqa: W504, E126, E226 + y*(1.47478952192985464958e9 + # noqa: W504, E126, E226 + y*(4.58595115847765779830e10 + # noqa: W504, E126, E226 + y*(7.08501308149515401563e11 + # noqa: W504, E126, E226 + y*(5.06084464593475076774e12 + # noqa: W504, E126, E226 + y*(1.43468549171581016479e13 + # noqa: W504, E126, E226 + y*(1.11535493509914254097e13))))))))))) # noqa: W504, E126, E226 ) # fmt: on + return f @jax.jit def _g_pade(x, x2): # fmt: off y = 1. / x2 - return ( - y*(1. + # noqa: E226 - y*(8.1359520115168615e2 + # noqa: E226 - y*(2.35239181626478200e5 + # noqa: E226 - y*(3.12557570795778731e7 + # noqa: E226 - y*(2.06297595146763354e9 + # noqa: E226 - y*(6.83052205423625007e10 + # noqa: E226 - y*(1.09049528450362786e12 + # noqa: E226 - y*(7.57664583257834349e12 + # noqa: E226 - y*(1.81004487464664575e13 + # noqa: E226 - y*(6.43291613143049485e12 + # noqa: E226 - y*(-1.36517137670871689e12))))))))))) # noqa: E226 - / (1. + # noqa: E226 - y*(8.19595201151451564e2 + # noqa: E226 - y*(2.40036752835578777e5 + # noqa: E226 - y*(3.26026661647090822e7 + # noqa: E226 - y*(2.23355543278099360e9 + # noqa: E226 - y*(7.87465017341829930e10 + # noqa: E226 - y*(1.39866710696414565e12 + # noqa: E226 - y*(1.17164723371736605e13 + # noqa: E226 - y*(4.01839087307656620e13 + # noqa: E226 - y*(3.99653257887490811e13)))))))))) # noqa: E226 + g = ( + y*(1. + # noqa: W504, E126, E226 + y*(8.1359520115168615e2 + # noqa: W504, E126, E226 + y*(2.35239181626478200e5 + # noqa: W504, E126, E226 + y*(3.12557570795778731e7 + # noqa: W504, E126, E226 + y*(2.06297595146763354e9 + # noqa: W504, E126, E226 + y*(6.83052205423625007e10 + # noqa: W504, E126, E226 + y*(1.09049528450362786e12 + # noqa: W504, E126, E226 + y*(7.57664583257834349e12 + # noqa: W504, E126, E226 + y*(1.81004487464664575e13 + # noqa: W504, E126, E226 + y*(6.43291613143049485e12 + # noqa: W504, E126, E226 + y*(-1.36517137670871689e12))))))))))) # noqa: W504, E126, E226 + / (1. + # noqa: W504, E126, E226 + y*(8.19595201151451564e2 + # noqa: W504, E126, E226 + y*(2.40036752835578777e5 + # noqa: W504, E126, E226 + y*(3.26026661647090822e7 + # noqa: W504, E126, E226 + y*(2.23355543278099360e9 + # noqa: W504, E126, E226 + y*(7.87465017341829930e10 + # noqa: W504, E126, E226 + y*(1.39866710696414565e12 + # noqa: W504, E126, E226 + y*(1.17164723371736605e13 + # noqa: W504, E126, E226 + y*(4.01839087307656620e13 + # noqa: W504, E126, E226 + y*(3.99653257887490811e13)))))))))) # noqa: W504, E126, E226 ) # fmt: on + return g @jax.jit def _si_small_pade(x, x2): # fmt: off return ( - x*(1. + # noqa: E226 - x2*(-4.54393409816329991e-2 + # noqa: E226 - x2*(1.15457225751016682e-3 + # noqa: E226 - x2*(-1.41018536821330254e-5 + # noqa: E226 - x2*(9.43280809438713025e-8 + # noqa: E226 - x2*(-3.53201978997168357e-10 + # noqa: E226 - x2*(7.08240282274875911e-13 + # noqa: E226 - x2*(-6.05338212010422477e-16)))))))) # noqa: E226 - / (1. + # noqa: E226 - x2*(1.01162145739225565e-2 + # noqa: E226 - x2*(4.99175116169755106e-5 + # noqa: E226 - x2*(1.55654986308745614e-7 + # noqa: E226 - x2*(3.28067571055789734e-10 + # noqa: E226 - x2*(4.5049097575386581e-13 + # noqa: E226 - x2*(3.21107051193712168e-16))))))) # noqa: E226 + x*(1. + # noqa: W504, E126, E226 + x2*(-4.54393409816329991e-2 + # noqa: W504, E126, E226 + x2*(1.15457225751016682e-3 + # noqa: W504, E126, E226 + x2*(-1.41018536821330254e-5 + # noqa: W504, E126, E226 + x2*(9.43280809438713025e-8 + # noqa: W504, E126, E226 + x2*(-3.53201978997168357e-10 + # noqa: W504, E126, E226 + x2*(7.08240282274875911e-13 + # noqa: W504, E126, E226 + x2*(-6.05338212010422477e-16)))))))) # noqa: W504, E126, E226 + / (1. + # noqa: W504, E126, E226 + x2*(1.01162145739225565e-2 + # noqa: W504, E126, E226 + x2*(4.99175116169755106e-5 + # noqa: W504, E126, E226 + x2*(1.55654986308745614e-7 + # noqa: W504, E126, E226 + x2*(3.28067571055789734e-10 + # noqa: W504, E126, E226 + x2*(4.5049097575386581e-13 + # noqa: W504, E126, E226 + x2*(3.21107051193712168e-16))))))) # noqa: W504, E126, E226 ) # fmt: on @@ -136,191 +104,8 @@ def si(x): ) -# Modified Bessel K functions - ported from SLATEC in GalSim BesselK.cpp - - -@jax.jit -def _bessel_k0(x): - """Modified Bessel function K_0(x) for x > 0. - - Implements SLATEC dbesk0 using Chebyshev series for x <= 2 - and asymptotic expansion for x > 2. - - Reference: BesselK.cpp lines 253-284, 286-442 - """ - # Chebyshev coefficients for K_0 (small x) - # fmt: off - bk0cs = jnp.array([ - -0.0353273932339027687201140060063153, - 0.344289899924628486886344927529213, - 0.0359799365153615016265721303687231, - 0.00126461541144692592338479508673447, - 2.28621210311945178608269830297585e-5, - 2.53479107902614945730790013428354e-7, - 1.90451637722020885897214059381366e-9, - 1.03496952576336245851008317853089e-11, - 4.25981614279108257652445327170133e-14, - 1.3744654358807508969423832544e-16, - 3.57089652850837359099688597333333e-19, - 7.63164366011643737667498666666666e-22, - 1.36542498844078185908053333333333e-24, - 2.07527526690666808319999999999999e-27, - 2.7128142180729856e-30, - 3.08259388791466666666666666666666e-33, - ]) - - # Asymptotic coefficients for 2 < x <= 8 - ak0cs = jnp.array([ - -0.07643947903327941424082978270088, - -0.02235652605699819052023095550791, - 7.734181154693858235300618174047e-4, - -4.281006688886099464452146435416e-5, - 3.08170017386297474365001482666e-6, - -2.639367222009664974067448892723e-7, - 2.563713036403469206294088265742e-8, - -2.742705549900201263857211915244e-9, - 3.169429658097499592080832873403e-10, - -3.902353286962184141601065717962e-11, - 5.068040698188575402050092127286e-12, - -6.889574741007870679541713557984e-13, - 9.744978497825917691388201336831e-14, - -1.427332841884548505389855340122e-14, - 2.156412571021463039558062976527e-15, - -3.34965425514956277218878205853e-16, - 5.335260216952911692145280392601e-17, - -8.693669980890753807639622378837e-18, - 1.446404347862212227887763442346e-18, - ]) - - # Asymptotic coefficients for x > 8 - ak02cs = jnp.array([ - -0.01201869826307592239839346212452, - -0.009174852691025695310652561075713, - 1.444550931775005821048843878057e-4, - -4.013614175435709728671021077879e-6, - 1.567831810852310672590348990333e-7, - -7.77011043852173771031579975446e-9, - 4.611182576179717882533130529586e-10, - -3.158592997860565770526665803309e-11, - 2.435018039365041127835887814329e-12, - -2.074331387398347897709853373506e-13, - 1.925787280589917084742736504693e-14, - -1.927554805838956103600347182218e-15, - 2.062198029197818278285237869644e-16, - -2.341685117579242402603640195071e-17, - 2.805902810643042246815178828458e-18, - ]) - # fmt: on - - def k0_small(x): - xsml = jnp.sqrt(4.0 * jnp.finfo(jnp.float64).eps) - y = jnp.where(x > xsml, x * x, 0.0) - return -jnp.log(0.5 * x) * jsp.i0(x) - 0.25 + _dcsevl(0.5 * y - 1.0, bk0cs) - - def k0_medium(x): - return jnp.exp(-x) * ( - (_dcsevl((16.0 / x - 5.0) / 3.0, ak0cs) + 1.25) / jnp.sqrt(x) - ) - - def k0_large(x): - return jnp.exp(-x) * ((_dcsevl(16.0 / x - 1.0, ak02cs) + 1.25) / jnp.sqrt(x)) - - idx = jnp.where(x <= 2.0, 0, jnp.where(x <= 8.0, 1, 2)) - return jax.lax.switch(idx, [k0_small, k0_medium, k0_large], x) - - -@jax.jit -def _bessel_k1(x): - """Modified Bessel function K_1(x) for x > 0. - - Implements SLATEC dbesk1 using Chebyshev series for x <= 2 - and asymptotic expansion for x > 2. - - Reference: BesselK.cpp lines 480-514, 516-655 - """ - # Chebyshev coefficients for K_1 (small x) - # fmt: off - bk1cs = jnp.array([ - 0.025300227338947770532531120868533, - -0.35315596077654487566723831691801, - -0.12261118082265714823479067930042, - -0.0069757238596398643501812920296083, - -1.7302889575130520630176507368979e-4, - -2.4334061415659682349600735030164e-6, - -2.2133876307347258558315252545126e-8, - -1.4114883926335277610958330212608e-10, - -6.6669016941993290060853751264373e-13, - -2.4274498505193659339263196864853e-15, - -7.023863479386287597178379712e-18, - -1.6543275155100994675491029333333e-20, - -3.2338347459944491991893333333333e-23, - -5.3312750529265274999466666666666e-26, - -7.5130407162157226666666666666666e-29, - -9.1550857176541866666666666666666e-32, - ]) - - # Asymptotic coefficients for 2 < x <= 8 - ak1cs = jnp.array([ - 0.27443134069738829695257666227266, - 0.07571989953199367817089237814929, - -0.0014410515564754061229853116175625, - 6.6501169551257479394251385477036e-5, - -4.3699847095201407660580845089167e-6, - 3.5402774997630526799417139008534e-7, - -3.3111637792932920208982688245704e-8, - 3.4459775819010534532311499770992e-9, - -3.8989323474754271048981937492758e-10, - 4.7208197504658356400947449339005e-11, - -6.047835662875356234537359156289e-12, - 8.1284948748658747888193837985663e-13, - -1.1386945747147891428923915951042e-13, - 1.654035840846228232597294820509e-14, - -2.4809025677068848221516010440533e-15, - 3.8292378907024096948429227299157e-16, - -6.0647341040012418187768210377386e-17, - 9.8324256232648616038194004650666e-18, - -1.6284168738284380035666620115626e-18, - ]) - - # Asymptotic coefficients for x > 8 - ak12cs = jnp.array([ - 0.06379308343739001036600488534102, - 0.02832887813049720935835030284708, - -2.475370673905250345414545566732e-4, - 5.771972451607248820470976625763e-6, - -2.068939219536548302745533196552e-7, - 9.739983441381804180309213097887e-9, - -5.585336140380624984688895511129e-10, - 3.732996634046185240221212854731e-11, - -2.825051961023225445135065754928e-12, - 2.372019002484144173643496955486e-13, - -2.176677387991753979268301667938e-14, - 2.157914161616032453939562689706e-15, - -2.290196930718269275991551338154e-16, - 2.582885729823274961919939565226e-17, - -3.07675264126846318762109817344e-18, - ]) - # fmt: on - - def k1_small(x): - xsml = 2.0 * jnp.sqrt(jnp.finfo(jnp.float64).eps) - y = jnp.where(x > xsml, x * x, 0.0) - return jnp.log(0.5 * x) * jsp.i1(x) + (_dcsevl(0.5 * y - 1.0, bk1cs) + 0.75) / x - - def k1_medium(x): - return jnp.exp(-x) * ( - (_dcsevl((16.0 / x - 5.0) / 3.0, ak1cs) + 1.25) / jnp.sqrt(x) - ) - - def k1_large(x): - return jnp.exp(-x) * ((_dcsevl(16.0 / x - 1.0, ak12cs) + 1.25) / jnp.sqrt(x)) - - idx = jnp.where(x <= 2.0, 0, jnp.where(x <= 8.0, 1, 2)) - return jax.lax.switch(idx, [k1_small, k1_medium, k1_large], x) - - # ===================================================================== -# Modified Bessel K_v(x) for fractional order +# Modified Bessel K_v(x) # # The functions below (_sqrt1px2, _evaluate_temme_coeffs, # _temme_series_kve, _continued_fraction_kve, _olver_kve, _temme_kve, @@ -713,87 +498,8 @@ def _kve_core(nu, x): return jnp.where(nu >= 50.0, olver_result, temme_result) -@jax.jit -def _bessel_kn_recurrence(n, x, k0_val, k1_val): - """Compute K_n(x) for integer n >= 2 using forward recurrence. - - Uses the recurrence relation: - K_{n+1}(x) = K_{n-1}(x) + (2*n/x) * K_n(x) - - For n <= 5, uses direct computation without a loop. - For n > 5, uses fori_loop with 399 iterations. - - Args: - n: Integer order (n >= 2) - x: Argument value - k0_val: Pre-computed K_0(x) - k1_val: Pre-computed K_1(x) - - Returns: - K_n(x) - """ - - def small_n(): - k2 = k0_val + (2.0 / x) * k1_val - k3 = k1_val + (4.0 / x) * k2 - k4 = k2 + (6.0 / x) * k3 - k5 = k3 + (8.0 / x) * k4 - return jnp.select( - [n == 2, n == 3, n == 4, n == 5], [k2, k3, k4, k5], default=k5 - ) - - def large_n(): - def body_fn(i, carry): - k_prev, k_curr = carry - should_update = i < n - k_next = jnp.where(should_update, k_prev + (2.0 * i / x) * k_curr, k_curr) - k_prev_new = jnp.where(should_update, k_curr, k_prev) - return (k_prev_new, k_next) - - _, k_n = jax.lax.fori_loop(1, 400, body_fn, (k0_val, k1_val)) - return k_n - - return jax.lax.cond(n <= 5, small_n, large_n) - - -@implements(_galsim.bessel.kn) -@jax.jit -def kn(n, x): - """Modified Bessel function of the second kind K_n(x) for integer n. - - This is a convenience wrapper that uses the integer-order implementations - for K_0, K_1, and recurrence for higher orders. - - Args: - n: Integer order (can be negative, K_{-n} = K_n) - x: Argument (must be positive) - - Returns: - K_n(x) - """ - n = jnp.abs(jnp.asarray(n, dtype=int)) # K_{-n} = K_n - x = 1.0 * x # promote to float - - k0 = _bessel_k0(x) - k1 = _bessel_k1(x) - - idx = jnp.where(n == 0, 0, jnp.where(n == 1, 1, 2)) - return jax.lax.switch( - idx, - [ - lambda n, x, k0, k1: k0, - lambda n, x, k0, k1: k1, - lambda n, x, k0, k1: _bessel_kn_recurrence(n, x, k0, k1), - ], - n, - x, - k0, - k1, - ) - - def _kv_scalar(nu, x): - """Scalar implementation of K_ν(x) using TFP-ported Temme + Olver algorithms.""" + """Scalar implementation of K_v(x) using TFP-ported Temme + Olver algorithms.""" nu = 1.0 * nu x = 1.0 * x nu = jnp.abs(nu) # K_{-v} = K_v @@ -813,7 +519,7 @@ def _kv_scalar(nu, x): @implements(_galsim.bessel.kv) @jax.jit def _kv_impl(nu, x): - """Modified Bessel function of the second kind K_ν(x) - internal implementation. + """Modified Bessel function of the second kind K_v(x) - internal implementation. Uses TFP-ported Temme + Olver algorithms for all orders. Handles both scalar and array inputs via jax.vmap. @@ -831,12 +537,12 @@ def _kv_impl(nu, x): @jax.custom_vjp def kv(nu, x): - """Modified Bessel function of the second kind K_ν(x) with custom gradients. + """Modified Bessel function of the second kind K_v(x) with custom gradients. Uses TFP-ported Temme + Olver algorithms. Custom gradients via: - ∂K_ν/∂x = -1/2 * (K_{ν-1}(x) + K_{ν+1}(x)) + dK_v/dx = -1/2 * (K_{v-1}(x) + K_{v+1}(x)) - Gradient w.r.t. ν is not supported (returns zero). + Gradient w.r.t. v is not supported (returns zero). """ return _kv_impl(nu, x) @@ -858,11 +564,17 @@ def _kv_bwd(residuals, g): kv.defvjp(_kv_fwd, _kv_bwd) +@jax.jit +def _R(z, num, denom): + return jnp.polyval(num, z) / jnp.polyval(denom, z) + + @jax.jit def _evaluate_rational(z, num, denom): - return jnp.polyval(num[::-1], z) / jnp.polyval(denom[::-1], z) + return _R(z, num[::-1], denom[::-1]) +# jitted & vectorized version _v_rational = jax.jit(jax.vmap(_evaluate_rational, in_axes=(0, None, None))) diff --git a/tests/galsim_tests_config.yaml b/tests/galsim_tests_config.yaml index 30f6d9ab..0efc2e1b 100644 --- a/tests/galsim_tests_config.yaml +++ b/tests/galsim_tests_config.yaml @@ -67,6 +67,7 @@ allowed_failures: - "module 'jax_galsim.bessel' has no attribute 'yn'" - "module 'jax_galsim.bessel' has no attribute 'yv'" - "module 'jax_galsim.bessel' has no attribute 'iv'" + - "module 'jax_galsim.bessel' has no attribute 'kn'" - "module 'jax_galsim.bessel' has no attribute 'j0_root'" - "module 'jax_galsim.bessel' has no attribute 'gammainc'" - "module 'jax_galsim.bessel' has no attribute 'sinc'"