feat: flex-attention backend #1161

Draft
cathalobrien wants to merge 4 commits into main from feat/flex-attention
@cathalobrien

Contributor

This PR adds a flex attention backend to the attention layers.

Flex attention has been available in PyTorch since ~2.6, but it wasn't stable for our use case. It now seems more stable (I tested aifs-single v2 with torch 2.11 and it worked fine).

The specific use case here was a user running on CPU who had installed the triton cpu build and wanted to use flex attention rather than SDPA.

I have also enabled support for the flash attention v4 backend of flex attention: the code checks whether flash attention v4 is installed and, if so, passes the right kernel options to flex attention so it can use it.
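The detection step described above could look roughly like the following. This is a minimal sketch, not the code in this PR: the helper name `select_flex_kernel_options`, the module name `flash_attn.cute`, and the `{"BACKEND": "FLASH"}` option are all assumptions made for illustration, so check the installed PyTorch and flash attention versions for the exact module path and kernel option keys.

```python
import importlib.util


def select_flex_kernel_options(module_name: str = "flash_attn.cute") -> dict:
    """Pick kernel options for flex attention based on what is installed.

    ASSUMPTION: `module_name` and the "BACKEND" key are illustrative
    placeholders, not names confirmed by this PR.
    """
    try:
        # find_spec only looks the module up; it does not import the module itself.
        found = importlib.util.find_spec(module_name) is not None
    except (ImportError, ValueError):
        # Raised when a parent package is missing or a module has no usable spec.
        found = False
    # An empty dict leaves flex attention on its default backend.
    return {"BACKEND": "FLASH"} if found else {}
```

The returned dictionary would then be passed through the `kernel_options` argument of `torch.nn.attention.flex_attention.flex_attention`. Returning an empty dictionary when the module is missing keeps the call site identical in both cases.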

Performance is slightly slower than flash attention v2 (0.34 it/s vs 0.38 it/s), but much faster than SDPA.

I added a correctness test which compares against SDPA, and a small benchmark test which compares all 3 attention kernels standalone.

anemoi-core/models/tests/integration/attention/test_benchmark_attention.py::test_attention_backend_benchmark[forward-global] Attention benchmark: window_size=None, dtype=torch.float16, shape=(1, 16, 40320, 32)
  flex: 36.25 ms
  sdpa: 27.32 ms
  flash: 21.16 ms
PASSED
anemoi-core/models/tests/integration/attention/test_benchmark_attention.py::test_attention_backend_benchmark[forward-sliding_window] Attention benchmark: window_size=1120, dtype=torch.float16, shape=(1, 16, 40320, 32)
  flex: 4.88 ms
  sdpa: 112.95 ms
  flash: 1.95 ms
PASSED
anemoi-core/models/tests/integration/attention/test_benchmark_attention.py::test_attention_backend_benchmark[fowrward & backward-global] Attention benchmark: window_size=None, dtype=torch.float16, shape=(1, 16, 40320, 32)
  flex: 28.18 ms
  sdpa: 26.16 ms
  flash: 21.11 ms
PASSED
anemoi-core/models/tests/integration/attention/test_benchmark_attention.py::test_attention_backend_benchmark[fowrward & backward-sliding_window] Attention benchmark: window_size=1120, dtype=torch.float16, shape=(1, 16, 40320, 32)
  flex: 5.04 ms
  sdpa: 114.19 ms
  flash: 1.96 ms
PASSED
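
For the sliding-window cases above, flex attention expresses the window as a mask function over query and key/value indices. A minimal sketch of such a predicate follows. The name `make_sliding_window_mask` and the symmetric convention (each query attends to keys at most `window_size` positions away on either side) are assumptions and may differ from the convention used in this PR.

```python
def make_sliding_window_mask(window_size: int):
    """Build a mask function with the (b, h, q_idx, kv_idx) signature.

    ASSUMPTION: the window is symmetric, covering `window_size` positions
    on each side of the query. This is a hypothetical helper.
    """

    def mask_mod(b, h, q_idx, kv_idx):
        # Uses comparisons and `&` only, so it accepts plain ints as well as
        # index tensors. The batch and head indices are unused here.
        return (q_idx - kv_idx <= window_size) & (kv_idx - q_idx <= window_size)

    return mask_mod
```

A function like this can be handed to `torch.nn.attention.flex_attention.create_block_mask` to build the block mask. Under the assumed convention, `window_size=1120` lets each query see at most 2 * 1120 + 1 = 2241 of the 40320 keys, which is consistent with the sliding-window runs being much cheaper than the global ones.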

@github-project-automation github-project-automation Bot moved this to To be triaged in Anemoi-dev Jun 2, 2026
@github-actions github-actions Bot added models and removed models labels Jun 2, 2026
@cathalobrien cathalobrien changed the title flex attention implemention. supports sliding window and optional fla… feat/flex-attention backend Jun 2, 2026
@cathalobrien cathalobrien requested a review from HCookie June 2, 2026 14:45
@cathalobrien cathalobrien added models ATS Approval Not Needed No approval needed by ATS labels Jun 2, 2026
@cathalobrien cathalobrien changed the title feat/flex-attention backend (Feat):flex-attention backend Jun 2, 2026
@HCookie HCookie changed the title (Feat):flex-attention backend feat: flex-attention backend Jun 2, 2026
@HCookie HCookie moved this from To be triaged to Reviewers needed in Anemoi-dev Jun 2, 2026

@HCookie HCookie left a comment

Member

Awesome work, looks good to me


# Try import flash attention v4
# if this is avilable it can be used as a backend for flex attention which gives approx 2x performance
# One reason to use flex attention with the flash attewntion v4 backend, ratehr then using flash attention v4 directly, is

Member

Suggested change
# One reason to use flex attention with the flash attewntion v4 backend, ratehr then using flash attention v4 directly, is
# One reason to use flex attention with the flash attention v4 backend, rather than using flash attention v4 directly, is

# if this is avilable it can be used as a backend for flex attention which gives approx 2x performance
# One reason to use flex attention with the flash attewntion v4 backend, ratehr then using flash attention v4 directly, is
# flex attentions support for custom block masks.
# if flash attention is not available then the trion backend will be used for flex attention

Member

Suggested change
# if flash attention is not available then the trion backend will be used for flex attention
# if flash attention is not available then the triton backend will be used for flex attention

Labels

ATS Approval Not Needed No approval needed by ATS models

Projects

Status: Reviewers needed

3 participants