Skip to content

[Torch][LinalgExt] Lower flex_attention masks before online_attention#24426

Open
keshavvinayak01 wants to merge 12 commits into
iree-org:mainfrom
keshavvinayak01:users/keshavvinayak01/scalarize-online-attention-mask-region
Open

[Torch][LinalgExt] Lower flex_attention masks before online_attention#24426
keshavvinayak01 wants to merge 12 commits into
iree-org:mainfrom
keshavvinayak01:users/keshavvinayak01/scalarize-online-attention-mask-region

Conversation

@keshavvinayak01
Copy link
Copy Markdown
Contributor

@keshavvinayak01 keshavvinayak01 commented May 9, 2026

Lower flex_attention mask_mod by evaluating the mask function over broadcastable index tensors and passing the resulting tensor mask into online_attention.

This is done to fix existing GPU compilation failure when I tested flex_attention with causal masks.

The old lowering built Q/K/V indices with iree_linalg_ext.index -> tensor.from_elements. After decomposition
this left a mask-update linalg.generic in the QK path that did not vectorize cleanly, introduced private-memory between QK and the reductions, and lead to the source layout being null downstream in VectorLayoutAnalysis

@keshavvinayak01 keshavvinayak01 force-pushed the users/keshavvinayak01/scalarize-online-attention-mask-region branch from c3dcb86 to f418178 Compare May 11, 2026 06:25
Lower flex_attention mask_mod callbacks with broadcasted tensor index operands and pass the resulting mask as the online_attention mask operand. This avoids emitting rank-0 tensor payloads inside the score region while leaving score_mod on the existing in-region path.

Co-authored-by: Codex <noreply@openai.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
@keshavvinayak01 keshavvinayak01 force-pushed the users/keshavvinayak01/scalarize-online-attention-mask-region branch from f418178 to 0fcaf28 Compare May 11, 2026 07:02
keshavvinayak01 and others added 3 commits May 11, 2026 12:50
Use torch-mlir's shared broadcast shape helper when cloning mask_mod ops with tensor index operands. This keeps the flex_attention mask materialization path aligned with existing Torch broadcast handling.\n\nCo-authored-by: Codex <noreply@openai.com>

Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Add a reduced generic-vectorization test for the broadcastable index tensor shape emitted by flex_attention mask_mod lowering.

Co-authored-by: Codex <noreply@openai.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
@keshavvinayak01 keshavvinayak01 marked this pull request as ready for review May 11, 2026 09:12
Copy link
Copy Markdown
Contributor

@sommerlukas sommerlukas left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Implementation looks OK to me, just a few nits.

Would be good to get eyes on this from somebody with more experience in this area.

Co-authored-by: Codex <noreply@openai.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Copy link
Copy Markdown
Contributor

@sommerlukas sommerlukas left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Implementation LGTM, please wait for another review from someone with experience in this area, e.g., @IanWood1 or @rsuderman.

Copy link
Copy Markdown
Collaborator

@rsuderman rsuderman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need more tests. So far the additions are extremely minimal despite a substantial change to how the lowering works.

Factor the single-block mask callback clone/remap logic into a helper, add match-failure diagnostics for mask lowering failures, and extend lit coverage for mask broadcasting and GQA mask lowering.

Co-authored-by: GPT-5 <noreply@openai.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
@keshavvinayak01 keshavvinayak01 requested a review from rsuderman May 13, 2026 05:50
keshavvinayak01 and others added 2 commits May 13, 2026 17:05
Run a module-level barrier after the flex_attention conversion pass so callback functions are not rewritten by later per-function Torch conversions before their users are processed.

Add a full torch-to-iree test where mask_mod appears before the flex_attention user, covering the previous bool mask legalization failure.

Co-authored-by: GPT-5 Codex <noreply@openai.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Avoid making the module-level barrier comment specific to flex_attention; the barrier separates unstructured conversion from later per-function Torch conversions generally.

Co-authored-by: GPT-5 Codex <noreply@openai.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
@keshavvinayak01 keshavvinayak01 force-pushed the users/keshavvinayak01/scalarize-online-attention-mask-region branch from c0f2974 to e23f612 Compare May 13, 2026 17:30
Copy link
Copy Markdown
Collaborator

@rsuderman rsuderman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Look at how other passes handle failure. I see a lot of return failure(); which can make it difficult for future developers to debug what went wrong.

Use notifyMatchFailure for flex_attention callback and preprocessing rejects so failed rewrites produce actionable diagnostics instead of silent pattern failure.

Co-authored-by: Codex <noreply@openai.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
@keshavvinayak01 keshavvinayak01 requested a review from zjgarvey as a code owner May 14, 2026 17:09
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
@keshavvinayak01 keshavvinayak01 requested a review from rsuderman May 14, 2026 21:39
TorchInput::createConvertTorchUnstructuredToLinalgExtPass());
// Keep this as a module-level barrier so unstructured conversions finish
// across the module before later per-function Torch conversions run.
pm.addPass(createCanonicalizerPass());
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a red flag to me. Why are you needing a canonicalization pass here? Typically we note adding passes as cocnerning.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added the canonicalizer as a module-level barrier because I was worried about the nested per-function pipeline converting a flex_attention callback function before the function containing the hop_flex_attention user had lowered and inlined it.

But in the actual lowering emitted top down from torch's invokation, we're safe for now.

Remove the canonicalizer barrier that was only used to force module ordering between flex attention callbacks and their users. Also drop the torch-to-iree test that covered that artificial ordering case.

Co-authored-by: OpenAI Codex <noreply@openai.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
@keshavvinayak01 keshavvinayak01 requested a review from rsuderman May 18, 2026 18:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants