[AIROCMLIR-707] Fix split-kv attention masking and sweep RMS for attention configs #2371

Open
bogdan-petkovic wants to merge 33 commits into ROCm:develop from bogdan-petkovic:bogdan-petkovic/attn-splitkv-sweep-fix

@bogdan-petkovic bogdan-petkovic commented May 11, 2026

Contributor

Motivation

Attention performance sweeps were failing on split-KV configurations (split_kv > 1) with RMS validation errors, NaNs, OOM crashes, or invalid results. The failures showed up across causal masking, GQA, KV-cache style current_seqlen, and mixed dtypes (f16, bf16, i8), including cases with trans_q and bias.

The goal is to make split-KV attention sweeps reliable without weakening the default kernel verifier policy. Fixes target:

  • GPU split-KV softmax updates that produced NaNs on fully masked or empty key partitions.
  • GPU split-KV iteration math that yielded zero iterations per split (and downstream 0/0 in scaleFinalOutput) for small seq_len_k with large split_kv.
  • Host reference masking and combine logic in rocmlir-gen that did not match split-KV behavior under causal masking.
  • Sweep validation tolerance so attention configs are checked against a realistic numeric band instead of failing on reference-path noise.
  • Sweep generator behavior so memory-heavy split-KV cases don't surface in CI as OOM-crashed sweep samples.

rocmlir-gen's default -RMS_threshold for fp16/bf16 without an explicit override stays 0.001. Only the attention sweep driver gets a separate default when the config does not set a threshold.

Technical Details

Kernel path: guard -inf - (-inf) in split-KV softmax updates
In mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp, split-KV attention rewrites update row-wise softmax state across KV partitions. When a partition has no valid keys, both the running row max and the partition max can be -inf, so exp2(score - max) and exp2(old_max - new_max) can see -inf - (-inf) and become NaN. That poisons row sums and downstream combine.

The change detects the case where both operands are -inf and uses 0 before exp2, so empty or fully masked partitions contribute zero instead of NaN. This is a correctness fix on the kernel path, not a verifier relaxation.
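
For illustration only, the guard can be sketched in Python. The names exp2_diff_guarded and update_row_state are hypothetical and do not appear in the PR. The sketch zeroes the exponential term when both operands are -inf, which is one way to make an empty partition contribute zero; the emitted IR may place the select differently.

```python
import math

NEG_INF = float("-inf")


def exp2_diff_guarded(a: float, b: float) -> float:
    """Return 2**(a - b), but 0.0 when both are -inf (a - b would be NaN)."""
    if a == NEG_INF and b == NEG_INF:
        return 0.0
    return 2.0 ** (a - b)


def update_row_state(row_max: float, row_sum: float, scores):
    """One online-softmax step over a KV partition (base-2 exponent).

    Masked keys are represented by a score of -inf.
    """
    new_max = max([row_max] + list(scores))
    # Rescale the running sum to the new max, then add this partition's terms.
    new_sum = row_sum * exp2_diff_guarded(row_max, new_max)
    new_sum += sum(exp2_diff_guarded(s, new_max) for s in scores)
    return new_max, new_sum


# Unguarded, the subtraction alone already yields NaN:
unguarded = NEG_INF - NEG_INF
# Guarded, a fully masked partition leaves the state at (-inf, 0.0):
masked_state = update_row_state(NEG_INF, 0.0, [NEG_INF, NEG_INF])
```

A later partition with valid keys then updates the state normally, because the rescale factor multiplies a running sum that is still zero.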

Kernel path: non-causal split-KV iteration math + scaleFinalOutput 0/0 guard
The non-causal / non-KV-cache split-KV branch in GridwiseAttentionAccelRewritePattern computed per-split iterations as gemm0M / (gemm0MPerBlock * splitKV) using truncating integer division. When gemm0MBlocks < splitKV (small seq_len_k with large split_kv), this evaluates to 0 for every split, so every split-block skips the softmax loop entirely. The kernel then divides the (zero) output by sum = 0 in scaleFinalOutput, producing NaNs that propagate through the host combine stage.

This is fixed by:

  • Using ceil-division to compute iterations per split, then clamping end to gemm0MBlocks so trailing splits where start >= gemm0MBlocks become cleanly empty.
  • A defensive 0/0 guard in scaleFinalOutput: when a row's sum is exactly zero, the per-split output stores 0 instead of NaN. The host combine stage already tolerates -inf max, but only when each per-split partial output is finite.
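
The iteration math and the 0/0 guard can be sketched as follows. This is a simplified Python model, not the C++ rewrite pattern: the function names are hypothetical, and the block count stands in for gemm0M / gemm0MPerBlock.

```python
def iters_truncating(gemm0_m_blocks: int, split_kv: int) -> int:
    """Old behavior: truncating division, 0 when blocks < split_kv."""
    return gemm0_m_blocks // split_kv


def iters_for_split(gemm0_m_blocks: int, split_kv: int, split_idx: int) -> int:
    """New behavior: ceil-division per split, with end clamped to the block count."""
    per_split = (gemm0_m_blocks + split_kv - 1) // split_kv
    start = split_idx * per_split
    end = min(start + per_split, gemm0_m_blocks)
    return max(end - start, 0)  # trailing splits past the end run 0 iterations


def scale_final_output(acc, row_sum):
    """Normalize one row; store 0 when the row sum is exactly zero."""
    if row_sum == 0.0:
        return [0.0 for _ in acc]
    return [a / row_sum for a in acc]


# 2 key blocks spread over split_kv = 8
old = [iters_truncating(2, 8) for _ in range(8)]
new = [iters_for_split(2, 8, i) for i in range(8)]
```

In this example the old formula assigns zero iterations to every split, while the new one assigns one block each to the first two splits and leaves the remaining six empty, so every key block is visited exactly once.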

Host reference: causal valid-split masking
In mlir/tools/rocmlir-gen/rocmlir-gen.cpp, computeValidSplitKV() used a causal branch that forced currSeqLen = 0, which made every split invalid and did not reflect per-query causal reach. Causal masking now builds a per-(batch-head, query-row) valid split count from each row's effective key length (including prefix_offset when present). Non-causal configs keep the per-batch-head path. The usePerRowMask predicate also triggers for non-empty prefix_offset, mirroring how the kernel treats prefix-causal as causal.

createMaskSplitKV() accepts either per-batch-head or per-(batch-head, query-row) validSplitKV layouts and broadcasts the threshold tensor accordingly. computeFinalAttentionStage() still masks invalid splits on both the partial output and LSE tensors before the split-KV combine; its size assertion now accepts either layout.
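
A rough Python sketch of the per-row idea follows. Everything in it is an assumption made for illustration: the function names, the effective key length min(seq_len_k, prefix_offset + row + 1), the even partitioning of keys across splits, and the "split index below threshold" comparison. The real code works on block-granular partitions and tensors.

```python
def valid_splits_per_row(seq_len_q: int, seq_len_k: int, split_kv: int,
                         prefix_offset: int = 0):
    """Number of splits that hold at least one visible key, per query row."""
    keys_per_split = -(-seq_len_k // split_kv)  # ceil division
    counts = []
    for row in range(seq_len_q):
        # Assumed causal reach of this row, capped at the key length.
        effective_len = min(seq_len_k, prefix_offset + row + 1)
        counts.append(min(split_kv, -(-effective_len // keys_per_split)))
    return counts


def split_mask(counts, split_kv: int):
    """Broadcast per-row thresholds into a [row][split] validity mask."""
    return [[split < count for split in range(split_kv)] for count in counts]


causal_counts = valid_splits_per_row(seq_len_q=4, seq_len_k=4, split_kv=2)
prefix_counts = valid_splits_per_row(4, 4, 2, prefix_offset=2)
```

Under these assumptions the first two query rows only reach the first split, which a single per-batch-head count cannot express.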

Host reference: f32 combine for narrow floats
For f16 and bf16, the split-KV combine stage (reduce-max, exp, weighted sum, normalization) runs in f32 and casts the final combined result back to storage type. That reduces accumulation error in the reference path when comparing against the GPU kernel during sweeps.
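
As a minimal sketch of that combine, assuming each split contributes a normalized partial output and a log-sum-exp (LSE) value, the stages can be modeled in plain Python. The helper names are hypothetical, natural exp is an assumption, and rounding through struct only simulates f16 and f32 storage.

```python
import math
import struct


def to_f16(x: float) -> float:
    """Round a Python float to IEEE half precision."""
    return struct.unpack("e", struct.pack("e", x))[0]


def to_f32(x: float) -> float:
    """Round a Python float to IEEE single precision."""
    return struct.unpack("f", struct.pack("f", x))[0]


def combine_splits(partials, lses, compute_round=to_f32, storage_round=to_f16):
    """Combine per-split partial outputs for one output element.

    Assumes at least one split has a finite LSE; masked splits carry -inf.
    """
    row_max = max(lses)                                    # reduce-max
    weights = [compute_round(math.exp(compute_round(lse - row_max)))
               for lse in lses]                            # exp
    num = 0.0
    den = 0.0
    for w, o in zip(weights, partials):                    # weighted sum
        num = compute_round(num + compute_round(w * o))
        den = compute_round(den + w)
    return storage_round(compute_round(num / den))         # normalize, cast back


# A masked split (LSE = -inf) gets weight 0 and drops out of the result.
masked_result = combine_splits([0.5, 123.0], [1.0, float("-inf")])
even_result = combine_splits([0.25, 0.75], [0.0, 0.0])
```

Passing compute_round=to_f16 instead models a combine that stays in storage precision, which is the behavior the f32 path replaces.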

Sweep policy: default attention RMS 0.005
In mlir/utils/performance/parameterSweeps.py, attention sweeps without an explicit -RMS_threshold now append a default: -RMS_threshold 0.005 for every attention dtype except bf16, which keeps 0.01. The 0.005 band covers observed sweep disagreement (including i8 with trans_q and high split_kv) while staying tighter than the old bf16-only default.

Configs that set their own -RMS_threshold are unchanged.

Sweep policy: widen RMS band for unscaled large-head_dim attention
Sampled configs with with_attn_scale=False and head_dim_qk > 64 saturate softmax: without the attention scale, |QK| grows like O(sqrt(d)) and the softmax output collapses to near one-hot. In that regime, CPU vs GPU float-arithmetic ordering inside exp/accumulate dominates the diff (observed RMS up to ~6% in bf16, ~1% in f16/i8), independent of split_kv. For this regime only, test_config widens -RMS_threshold to 0.15 so the verifier still catches NaN/crash regressions but no longer false-fails on known float associativity in saturated softmax. All other configs are unaffected.
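
The two sweep-policy rules can be summarized in a short Python sketch. The function name and argument layout are hypothetical and do not mirror test_config; the sketch also assumes that an explicit -RMS_threshold in the config takes precedence over both defaults.

```python
def default_rms_threshold_args(dtype: str, with_attn_scale: bool,
                               head_dim_qk: int, config_args):
    """Return the extra arguments to append, or [] if the config sets its own."""
    if "-RMS_threshold" in config_args:
        return []  # assumption: an explicit threshold is never overridden
    if not with_attn_scale and head_dim_qk > 64:
        threshold = "0.15"   # saturated-softmax regime
    elif dtype == "bf16":
        threshold = "0.01"   # bf16 keeps its wider default
    else:
        threshold = "0.005"  # default for the other attention dtypes
    return ["-RMS_threshold", threshold]


f16_args = default_rms_threshold_args("f16", True, 64, [])
saturated_args = default_rms_threshold_args("bf16", False, 128, [])
```

Because the branches are mutually exclusive, at most one -RMS_threshold pair is ever appended.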

Sweep generator: device-memory-aware split-KV prefilter (folds in draft #2366)
In mlir/utils/performance/attentionSweeps.py, the sweep generator now estimates extra split-KV temporary storage for each sampled shape and rejects samples above a budget before generating MLIR. The budget defaults to deviceMem / 8, clamped to [1 GiB, 8 GiB], with a 1.5 GiB fallback if the HIP query fails. A --splitkv-extra-bytes-limit CLI override is available. Filter-out reasons are now tracked separately (MAX_TOKENS vs splitKV extra-storage) and reported cumulatively across initial and refill batches.
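
The budget rule above can be written out as a small Python sketch. The function name and the use of None for a failed HIP query are assumptions for illustration; the storage estimate itself is not reproduced here.

```python
GIB = 1 << 30


def splitkv_extra_bytes_budget(device_mem_bytes=None, cli_override=None) -> int:
    """Budget for extra split-KV temporary storage, in bytes."""
    if cli_override is not None:
        return cli_override            # --splitkv-extra-bytes-limit wins
    if device_mem_bytes is None:
        return (3 * GIB) // 2          # 1.5 GiB fallback when the query fails
    # deviceMem / 8, clamped to [1 GiB, 8 GiB]
    return min(max(device_mem_bytes // 8, 1 * GIB), 8 * GIB)


def passes_prefilter(estimated_extra_bytes: int, budget: int) -> bool:
    """Keep a sampled shape only if its estimate fits the budget."""
    return estimated_extra_bytes <= budget


budget_16g = splitkv_extra_bytes_budget(16 * GIB)
```

For a 16 GiB device this gives a 2 GiB budget, so a sample estimated above that is dropped before any MLIR is generated.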

No compiler/verifier logic changes for this part; scope is limited to sweep sampling / filtering behavior.

Test Plan

  • Added targeted deterministic tests (lit + E2E) for each kernel/host fix; see Test Coverage below. PR CI runs the full build, these new lit tests, the new PrAttention{F16,BF16,F32} E2E configs (GPU), Python format/lint, the Python performance script tests, and the AttentionSweeps job.
  • Locally re-ran the original 8 failing May-3 sweep configs plus 5 newer May-14 sweep failures on gfx1201 (16 GiB) with all changes in place. 6/13 PASS on the kernel path. Of the remaining 7: 6 are correctly filtered out by the new split-KV prefilter (they were OOM-bound on a 16 GiB GPU because split_kv = 128 × large g × num_heads_q × head_dim_v blew past VRAM), and 1 (split_kv=1, bf16 trans-all) trips only the per-element relDiff check on near-zero outputs while passing both RMS and abs-diff — out of scope here and tracked separately.

Test Coverage

New and updated tests pin each fix above:

Lit — kernel path (IR lowering)

  • mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir and
    toblockwise_attention_accel_lowering.mlir (updated): add explicit
    CHECKs for the -inf - (-inf) softmax guard and the scaleFinalOutput
    0/0 guard, so the guarded select is verified in the emitted IR instead
    of the old unguarded store.
  • mlir/test/Dialect/Rock/gridwise_attention_splitkv_noncausal_lowering.mlir
    (new): pins the non-causal split-KV ceil-division + end clamp iteration
    math, covering the gemm0MBlocks < splitKV case.

Lit — host reference (rocmlir-gen)

  • mlir/test/rocmlir-gen/attention-splitkv-host-perrow-mask.mlir (new):
    pins the per-(batch-head, query-row) valid-split mask under causal +
    prefix_offset.
  • mlir/test/rocmlir-gen/attention-splitkv-host-f32-combine.mlir (new):
    pins the f16/bf16 → f32 combine and cast-back in the reference path.

E2E — GPU numeric (run in PR CI)

  • mlir/test/e2e/PrAttention{F16,BF16,F32}.toml (new configs): non-causal
    split-KV (basic, trailing/padded splits, GQA) and causal split-KV with
    seq_len_q > 1 (per-row mask, GQA, prefix_offset), exercising the
    guards and combine on real hardware.

The sweep-policy (RMS band) and split-KV memory prefilter changes remain
covered by the existing mlir/utils/performance pytest suite and the
AttentionSweeps job.

Test Result

  • PR CI
  • Nightly CI
  • Weekly CI

Submission Checklist

Signed-off-by: bogdan-petkovic <bogdan.petkovic@htecgroup.com>

Copilot AI left a comment

Contributor

Pull request overview

This PR aims to make split-KV attention performance sweeps reliable by fixing NaN production in the split-KV kernel softmax update path, aligning rocmlir-gen’s host-side split-KV masking/combine behavior with causal/prefix-causal semantics, and updating the attention sweep driver’s default RMS threshold when none is explicitly provided.

Changes:

  • Add -inf - (-inf) guards in split-KV softmax state updates to prevent NaNs in fully-masked/empty KV partitions.
  • Update rocmlir-gen split-KV validity masking to support per-row causal masking and do split-KV combine math in f32 for fp16/bf16 storage.
  • Change attention sweep default -RMS_threshold injection to 0.005 (for attention sweeps without an explicit threshold).

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.

File | Description
mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp | Prevent NaNs in split-KV softmax updates by guarding -inf - (-inf) cases.
mlir/tools/rocmlir-gen/rocmlir-gen.cpp | Rework split-KV valid-split masking for causal cases and do f32 combine for fp16/bf16.
mlir/utils/performance/parameterSweeps.py | Set a new default attention sweep RMS threshold (0.005) when none is specified.

Comment thread mlir/tools/rocmlir-gen/rocmlir-gen.cpp Outdated
Comment thread mlir/tools/rocmlir-gen/rocmlir-gen.cpp Outdated
Comment thread mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp Outdated
Comment thread mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp Outdated
Comment thread mlir/tools/rocmlir-gen/rocmlir-gen.cpp
bogdan-petkovic and others added 7 commits May 14, 2026 12:40
Signed-off-by: bogdan-petkovic <bogdan.petkovic@htecgroup.com>
Signed-off-by: bogdan-petkovic <bpetkovi@amd.com>
Signed-off-by: bogdan-petkovic <bogdan.petkovic@htecgroup.com>
Signed-off-by: bogdan-petkovic <bpetkovi@amd.com>
Signed-off-by: bogdan-petkovic <bpetkovi@amd.com>
@bogdan-petkovic bogdan-petkovic marked this pull request as ready for review May 15, 2026 12:18
@bogdan-petkovic bogdan-petkovic requested a review from causten as a code owner May 15, 2026 12:18
@umangyadav umangyadav added the claude-review Trigger automated PR review by claude[bot]; auto-removed after the run. label May 28, 2026
@rocmlir-pr-reviewer rocmlir-pr-reviewer Bot removed the claude-review Trigger automated PR review by claude[bot]; auto-removed after the run. label May 28, 2026
@rocmlir-pr-reviewer

Claude auto-review skipped (fork PR)

This PR is from a fork, so the label-triggered Claude review can't run --
fork PRs don't have access to the LLM gateway secrets and can't be
reviewed by the standard label flow.

A maintainer can still produce a review one of two ways:

  1. Manually dispatch the review: open Actions → Claude Auto Review → Run workflow,
    enter 2371 as the PR number, and click Run workflow. The dispatch
    path runs from the default branch's context (so it has the gateway
    secrets) and posts review comments on this PR.
  2. Mirror the branch internally: push this PR's branch to a branch in
    this repo, open an internal PR from it, and apply the claude-review
    label on the internal PR.

The claude-review label has been removed from this PR. Re-applying it on a fork PR will just produce this same notice again.

Comment thread mlir/utils/performance/attentionSweeps.py
Comment thread mlir/tools/rocmlir-gen/rocmlir-gen.cpp
@rocmlir-pr-reviewer

Reviewed 4 files in PR 2371 (split-KV attention correctness + sweep tolerance + memory-aware prefilter). Kernel-side -inf - (-inf) and 0/0 guards, ceil-div with end-clamp for non-causal split-KV iteration math, per-row causal validSplitKV layout, f32 host combine, and the new device-memory-aware sweep prefilter all look sound. Prior Copilot findings (per-row mask key, prefix currSeqLen clamp, zero/negInf hoisting, computeFinalAttentionStage assertion) appear addressed in the current diff. Posted 2 minor inline comments. Verdict: COMMENT.

Comment thread mlir/utils/performance/parameterSweeps.py
@rocmlir-pr-reviewer

Reviewed 4 changed files (split-KV attention masking/iter-math kernel fixes, host-reference per-row valid-split + f32 combine, and sweep RMS / OOM prefilter policy). 1 new inline finding (Minor). 2 prior Claude findings (unmapped-dtype silent bypass in attentionSweeps.py; unreachable itersPerBlock==0 guard in rocmlir-gen.cpp) remain present and unchanged with no developer reply — suppressed per dedup rules. Verdict: COMMENT.

@rocmlir-pr-reviewer rocmlir-pr-reviewer Bot left a comment

Verdict: COMMENT  ·  New findings: 0 (0 Critical, 0 Major, 0 Minor)


Scope

Re-review of split-KV attention masking + sweep RMS fixes across the kernel rewrite (GridwiseGemmToBlockwise.cpp), the host reference generator (rocmlir-gen.cpp), and the attention sweep driver (attentionSweeps.py, parameterSweeps.py).

Findings

No new blocking issues found in this revision. The three previously-flagged minor issues (silent dtype bypass in the splitKV extra-bytes estimator, unreachable itersPerBlock == 0 ternary guard, and the possibility of appending two -RMS_threshold entries when both branches fire) are still present in the diff but were already raised on prior runs; the existing inline threads remain visible.

Notes

  • The -inf - (-inf) NaN guards in expSubstractMaxFromGemm0 / updateRowSum are now correctly factored into selectZeroIfBothNegInf with loop-invariant negInfF / zeroF constants hoisted outside the loop body, addressing the earlier shadowing concern.
  • The per-row vs per-batch-head validSplitKV layout split, the usePerRowMask = causalMasking || !prefixOffset.empty() predicate, and the currSeqLen clamp removal all line up with how the kernel treats prefix-causal as causal.
  • f32 combine for fp16/bf16 in computeFinalAttentionStage is gated by storage type and the cast lambdas correctly no-op when types match.
  • Sweep default -RMS_threshold policy is dtype-conditioned and configs that set their own threshold are left untouched.

CI status

No non-self failing checks in /tmp/pr/checks.json; Jenkins / ml-ci / parameter sweeps / MIGraphX are still in-progress or pending. Build and Test, Python format/lint, and Python performance script tests have passed.

Signed-off-by: bogdan-petkovic <bpetkovi@amd.com>

@rocmlir-pr-reviewer rocmlir-pr-reviewer Bot left a comment

Verdict: APPROVE -- submitted as COMMENT (automated reviews are advisory)  ·  New findings: 0 (0 Critical, 0 Major, 0 Minor)


Scope

Split-KV attention correctness + sweep tolerance fixes across GridwiseGemmToBlockwise.cpp (kernel: NaN guards in softmax updates, ceil-div iteration math, 0/0 guard in scaleFinalOutput), rocmlir-gen.cpp (host reference: per-row causal/prefix-causal valid-split masking, f32 combine for narrow floats, updated assertions), attentionSweeps.py (device-memory-aware split-KV prefilter, separate filter-count buckets, CLI override), and parameterSweeps.py (dtype-aware default RMS thresholds + widened band for saturated softmax).

Findings

No new blocking issues found on this revision.

Notes

All three prior Claude inline comments have been addressed in a44bf2224:

  • Unmapped-dtype gap in _estimate_splitkv_extra_bytes is now caught by a module-load assert all(dt in DTYPE_BYTES for dt in DATA_TYPES_ATTENTION).
  • The unreachable itersPerBlock == 0 ? 0 : ... guard in computeValidSplitKV was removed and replaced with an explicit assert(mPerBlock > 0, ...) at function entry, making the invariant the contract relies on explicit.
  • test_config in parameterSweeps.py is now structured as an if/elif chain (saturated-softmax case first, then dtype-specific defaults), so exactly one -RMS_threshold is appended and precedence is explicit in Python rather than relying on cl::opt<double> last-wins semantics.

Spot-checked: hoisted negInfF/zeroF constants are correctly outside the loop body in both expSubstractMaxFromGemm0 and updateRowSum; selectZeroIfBothNegInf helper centralizes the NaN guard; computeFinalAttentionStage assertion now accepts both per-batch-head and per-(batch-head, query-row) validSplitKV layouts; createMaskSplitKV correctly broadcasts either layout.

CI status

All non-self checks in /tmp/pr/checks.json are passing (Python performance script tests, Python format/lint, detect).

@bogdan-petkovic bogdan-petkovic self-assigned this Jun 1, 2026
Type ldGemm0OutElemType = getElementTypeOrSelf(gemm0Out.getType());
Value ldGemm0Out = InBoundsLoadOp::create(
rewriter, loc, ldGemm0OutElemType, gemm0Out, gemm0OutCoords);
Value ldGemm0Out = InBoundsLoadOp::create(rewriter, loc, gemm0OutElemType,

Member

Can you add some lit tests and E2E tests that exercise this path?

@umangyadav umangyadav left a comment

Member

It would be good to have separate PRs for each bug fix. Okay to land this PR as it is.


3 participants