Skip to content

Porting Bessel kv function implementation to reduce dependency on tfp#159

Merged
beckermr merged 14 commits into
mainfrom
jax_version
Feb 7, 2026
Merged

Porting Bessel kv function implementation to reduce dependency on tfp#159
beckermr merged 14 commits into
mainfrom
jax_version

Conversation

@EiffL

@EiffL EiffL commented Feb 6, 2026

Copy link
Copy Markdown
Member

This pull request introduces a new contributor guide for Claude, updates dependencies, and improves test coverage for JAX-GalSim. The most significant changes include adding a detailed CLAUDE.md file to help contributors and AI code assistants, updating the JAX dependency version, and enabling tests for the bessel.kn function by removing it from the list of allowed failures.

Contributor documentation and onboarding:

  • Added a comprehensive CLAUDE.md file with project overview, installation instructions, code formatting guidelines, architecture notes, testing infrastructure, and contribution workflow. This file is designed to help both human contributors and Claude code assistant understand and contribute to the project efficiently.

Dependency updates:

  • Updated the JAX dependency in pyproject.toml to require version >=0.7.0 instead of <0.7.0, ensuring compatibility with newer JAX releases.

Testing improvements:

  • Enabled testing for the bessel.kn function by removing it from the allowed_failures list in tests/galsim_tests_config.yaml, indicating that this function is now implemented and expected to pass its tests.

@EiffL EiffL linked an issue Feb 6, 2026 that may be closed by this pull request
@beckermr

beckermr commented Feb 6, 2026

Copy link
Copy Markdown
Collaborator

You won't be able to use the new Jax constraints until we fix the rng

@beckermr beckermr left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd be careful directly porting the functions from galsim. Tensorflow has some special casing of derivatives for some special functions to maintain numerical stability.

@EiffL

EiffL commented Feb 6, 2026

Copy link
Copy Markdown
Member Author

Yes, very fair, this is a WIP, I'm relearning a bit how everything works

@codspeed-hq

codspeed-hq Bot commented Feb 6, 2026

Copy link
Copy Markdown

Merging this PR will degrade performance by 36.81%

⚡ 1 improved benchmark
❌ 5 regressed benchmarks
✅ 25 untouched benchmarks

⚠️ Please fix the performance issues or acknowledge them on CodSpeed.

Performance Changes

Benchmark BASE HEAD Efficiency
test_benchmarks_lanczos_interp[xval-conserve_dc-run] 721.9 µs 1,142.5 µs -36.81%
test_benchmark_spergel_kvalue[compile] 688.3 ms 563.2 ms +22.21%
test_benchmarks_lanczos_interp[xval-no_conserve_dc-run] 677.3 µs 958.1 µs -29.31%
test_benchmark_spergel_xvalue[run] 24 s 33.6 s -28.45%
test_benchmark_spergel_xvalue[compile] 28.8 s 38.4 s -25.02%
test_benchmark_spergel_calcfluxrad[run] 1.1 ms 1.5 ms -26.05%

Comparing jax_version (acf2ad6) with main (40845dd)

Open in CodSpeed

@EiffL

EiffL commented Feb 7, 2026

Copy link
Copy Markdown
Member Author

loool, not really an improvement in efficiency

@EiffL EiffL changed the title Adding missing bessel functions to enable tfp-free install with latest jax versions Porting Bessel kv function implementation to reduce dependency on tfp Feb 7, 2026
@beckermr

beckermr commented Feb 7, 2026

Copy link
Copy Markdown
Collaborator

I wouldn't worry about the benchmarks. There is a lot of variability even for the same code version and compile times are not as critical.

@EiffL EiffL marked this pull request as ready for review February 7, 2026 13:31
@EiffL

EiffL commented Feb 7, 2026

Copy link
Copy Markdown
Member Author

Ok, looks like these modifications are on par with the main branch in terms of speed, and this removes entirely the dependency on TFP.

@beckermr beckermr left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you clarify relative to TFP what happens for orders 0 and 1?

Do they use special branches for those too? If so, do they use the same algorithms as the galsim code you ported?

@EiffL

EiffL commented Feb 7, 2026

Copy link
Copy Markdown
Member Author

In TFP there are no special cases for integer order:
https://github.com/tensorflow/probability/blob/65f265c62bb1e2d15ef3e25104afb245a6d52429/tensorflow_probability/python/math/bessel.py#L1106-L1132

Here, we don't actually need the _bessel_k0 and _bessel_k1 functions for bessel kv but I also added kn while I was at it. I think these implementations are probably more efficient than using kv for integer order.
This being said, we could remove all of this code, kn is not actually needed anywhere so far in jax-galsim.

@EiffL

EiffL commented Feb 7, 2026

Copy link
Copy Markdown
Member Author

ok, I cleaned things up. Only adding the ported code from TFP, remove additional implementation on kn functions. It's cleaner, we'll add them back when and if they are needed

@EiffL EiffL requested a review from beckermr February 7, 2026 15:30

@beckermr beckermr left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice work!

@beckermr beckermr merged commit 16e5e32 into main Feb 7, 2026
5 checks passed
@beckermr beckermr deleted the jax_version branch February 7, 2026 19:35
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

tfp and jax version issue

2 participants