[Torch][LinalgExt] Lower flex_attention masks before online_attention #24426
c3dcb86 to f418178
Lower flex_attention mask_mod callbacks with broadcasted tensor index operands and pass the resulting mask as the online_attention mask operand. This avoids emitting rank-0 tensor payloads inside the score region while leaving score_mod on the existing in-region path. Co-authored-by: Codex <noreply@openai.com> Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
f418178 to 0fcaf28
Use torch-mlir's shared broadcast shape helper when cloning mask_mod ops with tensor index operands. This keeps the flex_attention mask materialization path aligned with existing Torch broadcast handling. Co-authored-by: Codex <noreply@openai.com> Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Add a reduced generic-vectorization test for the broadcastable index tensor shape emitted by flex_attention mask_mod lowering. Co-authored-by: Codex <noreply@openai.com> Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
sommerlukas
left a comment
Implementation looks OK to me, just a few nits.
Would be good to get eyes on this from somebody with more experience in this area.
Co-authored-by: Codex <noreply@openai.com> Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
sommerlukas
left a comment
Implementation LGTM, please wait for another review from someone with experience in this area, e.g., @IanWood1 or @rsuderman.
rsuderman
left a comment
You need more tests. So far the additions are extremely minimal despite a substantial change to how the lowering works.
Factor the single-block mask callback clone/remap logic into a helper, add match-failure diagnostics for mask lowering failures, and extend lit coverage for mask broadcasting and GQA mask lowering. Co-authored-by: GPT-5 <noreply@openai.com> Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Run a module-level barrier after the flex_attention conversion pass so callback functions are not rewritten by later per-function Torch conversions before their users are processed. Add a full torch-to-iree test where mask_mod appears before the flex_attention user, covering the previous bool mask legalization failure. Co-authored-by: GPT-5 Codex <noreply@openai.com> Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Avoid making the module-level barrier comment specific to flex_attention; the barrier separates unstructured conversion from later per-function Torch conversions generally. Co-authored-by: GPT-5 Codex <noreply@openai.com> Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
c0f2974 to e23f612
rsuderman
left a comment
Look at how other passes handle failure. I see a lot of return failure(); which can make it difficult for future developers to debug what went wrong.
Use notifyMatchFailure for flex_attention callback and preprocessing rejects so failed rewrites produce actionable diagnostics instead of silent pattern failure. Co-authored-by: Codex <noreply@openai.com> Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
TorchInput::createConvertTorchUnstructuredToLinalgExtPass());
// Keep this as a module-level barrier so unstructured conversions finish
// across the module before later per-function Torch conversions run.
pm.addPass(createCanonicalizerPass());
This is a red flag to me. Why do you need a canonicalization pass here? Typically we treat adding passes as concerning.
I added the canonicalizer as a module-level barrier because I was worried about the nested per-function pipeline converting a flex_attention callback function before the function containing the hop_flex_attention user had lowered and inlined it.
But in the actual lowering, which is emitted top-down from torch's invocation, we're safe for now.
Remove the canonicalizer barrier that was only used to force module ordering between flex attention callbacks and their users. Also drop the torch-to-iree test that covered that artificial ordering case. Co-authored-by: OpenAI Codex <noreply@openai.com> Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Lower flex_attention mask_mod by evaluating the mask function over broadcastable index tensors and passing the resulting tensor mask into online_attention.
This fixes an existing GPU compilation failure that I hit when testing flex_attention with causal masks.
The old lowering built Q/K/V indices with iree_linalg_ext.index -> tensor.from_elements. After decomposition, this left a mask-update linalg.generic in the QK path that did not vectorize cleanly, introduced private memory between QK and the reductions, and led to the source layout being null downstream in VectorLayoutAnalysis.
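To make the difference between the two evaluation strategies concrete, here is a minimal plain-Python sketch. It is not the lowering itself and uses no IREE or torch-mlir APIs: `IndexTensor`, `mask_per_element`, and `mask_broadcast` are hypothetical names invented for illustration, the batch and head indices are fixed to 0, and only the 2-D Q/KV case is modeled. It shows a causal mask_mod being called once per scalar index pair versus once over broadcastable index tensors of shape (Q, 1) and (1, KV), and that both produce the same mask.

```python
# Illustrative only: a plain-Python stand-in for tensor broadcasting.
# None of these names come from the PR, IREE, or torch-mlir.


class IndexTensor:
    """Minimal 2-D tensor with size-1 broadcasting (hypothetical helper)."""

    def __init__(self, rows):
        self.rows = rows  # list of equally sized lists

    @property
    def shape(self):
        return (len(self.rows), len(self.rows[0]))

    def _binary(self, other, fn):
        # Each dimension must match or be 1, as in the usual broadcast rule.
        for a, b in zip(self.shape, other.shape):
            if a != b and a != 1 and b != 1:
                raise ValueError("shapes are not broadcastable")
        num_rows = max(self.shape[0], other.shape[0])
        num_cols = max(self.shape[1], other.shape[1])

        def at(t, i, j):
            # Size-1 dimensions are read at index 0 (broadcast).
            return t.rows[i if t.shape[0] > 1 else 0][j if t.shape[1] > 1 else 0]

        return IndexTensor(
            [[fn(at(self, i, j), at(other, i, j)) for j in range(num_cols)]
             for i in range(num_rows)]
        )

    def __le__(self, other):
        return self._binary(other, lambda a, b: a <= b)


def causal_mask_mod(b, h, q_idx, kv_idx):
    """Causal mask in the mask_mod style: keep positions with kv <= q."""
    return kv_idx <= q_idx


def mask_per_element(mask_mod, num_q, num_kv):
    """Call mask_mod once per (q, kv) pair of scalar indices."""
    return [[mask_mod(0, 0, q, kv) for kv in range(num_kv)]
            for q in range(num_q)]


def mask_broadcast(mask_mod, num_q, num_kv):
    """Call mask_mod once over broadcastable index tensors."""
    q_idx = IndexTensor([[q] for q in range(num_q)])   # shape (num_q, 1)
    kv_idx = IndexTensor([list(range(num_kv))])        # shape (1, num_kv)
    return mask_mod(0, 0, q_idx, kv_idx).rows          # shape (num_q, num_kv)


if __name__ == "__main__":
    for row in mask_broadcast(causal_mask_mod, 3, 4):
        print(row)
```

In the broadcast form the whole mask exists as one tensor before attention runs, which corresponds to the idea of passing it as the online_attention mask operand rather than computing it element by element inside the score region.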