Add CUDA ring FFT backend for SHT #1165

Draft
ssmmnn11 wants to merge 2 commits into main from feat/cu_fft_interface

Conversation

@ssmmnn11 ssmmnn11 commented Jun 3, 2026

Member

Comment thread training/docs/modules/losses.rst Outdated
Base automatically changed from feat/graphs_sht_make_callable to main June 8, 2026 13:54
Comment thread models/src/anemoi/models/layers/spectral_transforms.py Outdated
Comment thread models/src/anemoi/models/layers/spectral_transforms.py Outdated
Comment thread models/src/anemoi/models/layers/spectral_transforms.py Outdated
Comment thread models/src/anemoi/models/layers/spectral_helpers.py Outdated
Comment thread models/src/anemoi/models/layers/spectral_helpers.py Outdated
Comment thread models/src/anemoi/models/layers/spectral_helpers.py Outdated
@samhatfield

Collaborator

@ssmmnn11 any tips for the JIT compilation? I tried using SphericalHarmonicTransform with use_cuda_ring_fft=True but I get a gargantuan quantity of compilation errors, e.g.

RuntimeError: Error building extension 'anemoi_ring_fft': [1/3] nvc++ -MMD -MF ring_fft.o.d -DTORCH_EXTENSION_NAME=anemoi_ring_fft -DTORCH_API_INCLUDE_EXTENSION_H -I/path/to/nvidia/25.3/Linux_aarch64/25.3/math_libs/include -isystem /path/to/python_envs/ag/lib/python3.12/site-packages/torch/include -isystem /path/to/python_envs/ag/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -isystem /path/to/nvidia/25.3/Linux_aarch64/25.3/compilers/include -isystem /path/to/python3/3.12.11-01/include/python3.12 -fPIC -std=c++17 -O3 -DANEMOI_RING_FFT_ENABLE_CUFFT -c /path/to/anemoi-core/models/src/anemoi/models/layers/cuda/ring_fft.cpp -o ring_fft.o
FAILED: [code=2] ring_fft.o
nvc++ -MMD -MF ring_fft.o.d -DTORCH_EXTENSION_NAME=anemoi_ring_fft -DTORCH_API_INCLUDE_EXTENSION_H -I/path/to/nvidia/25.3/Linux_aarch64/25.3/math_libs/include -isystem /path/to/python_envs/ag/lib/python3.12/site-packages/torch/include -isystem /path/to/python_envs/ag/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -isystem /path/to/nvidia/25.3/Linux_aarch64/25.3/compilers/include -isystem /path/to/python3/3.12.11-01/include/python3.12 -fPIC -std=c++17 -O3 -DANEMOI_RING_FFT_ENABLE_CUFFT -c /path/to/anemoi-core/models/src/anemoi/models/layers/cuda/ring_fft.cpp -o ring_fft.o
"/path/to/python_envs/ag/lib/python3.12/site-packages/torch/include/torch/headeronly/util/Half.h", line 85: error: identifier "float16_t" is undefined
    inline Half(float16_t value);
                ^

"/path/to/python_envs/ag/lib/python3.12/site-packages/torch/include/torch/headeronly/util/Half.h", line 86: error: expected an operator
    inline operator float16_t() const;
                    ^

"/path/to/python_envs/ag/lib/python3.12/site-packages/torch/include/torch/headeronly/util/Half.h", line 103: error: no suitable conversion function from "const c10::Half" to "float" exists
    out << (float)value;
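
The log above shows the extension being compiled with nvc++, and the errors come from torch's own Half.h rather than from ring_fft.cpp. One thing that may be worth trying, as an untested assumption rather than a confirmed fix, is to build with a different host compiler. PyTorch's JIT extension builder reads the CXX environment variable when it generates the build, so it has to be set before the extension is first loaded. The helper names below are hypothetical and not part of anemoi:

```python
import os
import shutil


def pick_host_compiler(candidates=("g++", "clang++")):
    """Return the full path of the first candidate found on PATH, or None."""
    for name in candidates:
        path = shutil.which(name)
        if path is not None:
            return path
    return None


def use_host_compiler_for_jit(candidates=("g++", "clang++"), env=os.environ):
    """Point CXX at a host compiler that exists on this machine.

    Leaves `env` untouched and returns None if no candidate is found.
    Hypothetical helper: call it before anything triggers the JIT build.
    """
    compiler = pick_host_compiler(candidates)
    if compiler is not None:
        env["CXX"] = compiler
    return compiler
```

Calling use_host_compiler_for_jit() before constructing SphericalHarmonicTransform with use_cuda_ring_fft=True would make the build pick up that compiler instead of nvc++. Whether this resolves the float16_t errors on this aarch64 system is not something the log can confirm.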

@samhatfield samhatfield force-pushed the feat/cu_fft_interface branch 2 times, most recently from efc1853 to 6ed66b6 Compare June 26, 2026 14:50
@samhatfield

Collaborator

I've rebased the main commit of this branch on top of main to simplify the history. At the same time, I temporarily removed some functionality to make reviewing easier. The CUDA backend is now off by default and must be activated by explicitly passing use_cuda_ring_fft=True. This is currently overridden by use_graphed_rfft=True, but we can discuss the priority of these two options later, including how to let the user control both through env vars. I've also removed the "grouped" transform option, as that was just an idea and is not necessarily relevant now that we have better options.
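
For reference, the current priority described above can be sketched as follows. This is only an illustration of the stated behaviour, not the code in the branch: the function name and the returned backend labels are hypothetical, and only the two flag names come from the PR.

```python
def select_rfft_backend(use_cuda_ring_fft: bool = False,
                        use_graphed_rfft: bool = False) -> str:
    """Return a label for the FFT path that would be used.

    Mirrors the priority described in the comment above:
    use_graphed_rfft wins over use_cuda_ring_fft, and the CUDA
    ring FFT backend is off unless explicitly requested.
    The labels are placeholders for illustration only.
    """
    if use_graphed_rfft:
        return "graphed_rfft"
    if use_cuda_ring_fft:
        return "cuda_ring_fft"
    return "default"
```

With this ordering, passing both flags selects the graphed path, which is the behaviour up for discussion.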

@samhatfield samhatfield force-pushed the feat/cu_fft_interface branch from 6ed66b6 to d201895 Compare June 26, 2026 14:59
This currently fails, indicating an issue with the CUDA backend.