[AIROCMLIR-707] Fix split-kv attention masking and sweep RMS for attention configs#2371
bogdan-petkovic wants to merge 33 commits into
Signed-off-by: bogdan-petkovic <bogdan.petkovic@htecgroup.com>
Pull request overview
This PR aims to make split-KV attention performance sweeps reliable by fixing NaN production in the split-KV kernel softmax update path, aligning rocmlir-gen’s host-side split-KV masking/combine behavior with causal/prefix-causal semantics, and updating the attention sweep driver’s default RMS threshold when none is explicitly provided.
Changes:
- Add `-inf - (-inf)` guards in split-KV softmax state updates to prevent NaNs in fully-masked/empty KV partitions.
- Update rocmlir-gen split-KV validity masking to support per-row causal masking and do split-KV combine math in f32 for fp16/bf16 storage.
- Change attention sweep default `-RMS_threshold` injection to `0.005` (for attention sweeps without an explicit threshold).
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp | Prevent NaNs in split-KV softmax updates by guarding `-inf - (-inf)` cases. |
| mlir/tools/rocmlir-gen/rocmlir-gen.cpp | Rework split-KV valid-split masking for causal cases and do f32 combine for fp16/bf16. |
| mlir/utils/performance/parameterSweeps.py | Set a new default attention sweep RMS threshold (0.005) when none is specified. |
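To illustrate why the `-inf - (-inf)` guard in the first row matters, here is a minimal Python sketch of a guarded online-softmax update over KV partitions. It is not the kernel code: the helper names `guarded_exp2_diff` and `update_row_state` are hypothetical, and the sketch simply zeroes the resulting term when both operands are `-inf`. Where the actual rewrite places its select relative to `exp2` is not shown on this page.

```python
import math

NEG_INF = float("-inf")


def guarded_exp2_diff(a: float, b: float) -> float:
    """Return 2**(a - b), treating (-inf) - (-inf) as a zero contribution.

    Assumption for this sketch: a fully masked partition should add nothing
    to the running row sum, so the NaN-producing case yields 0.0.
    """
    if a == NEG_INF and b == NEG_INF:
        return 0.0
    return 2.0 ** (a - b)


def update_row_state(row_max: float, row_sum: float, scores: list[float]):
    """One online-softmax step: fold a partition's scores into (max, sum)."""
    new_max = max([row_max, *scores])
    # Rescale the old sum to the new max, then add this partition's terms.
    rescale = guarded_exp2_diff(row_max, new_max)
    new_sum = row_sum * rescale + sum(
        guarded_exp2_diff(s, new_max) for s in scores
    )
    return new_max, new_sum


# Unguarded arithmetic on an empty partition produces NaN.
print(math.isnan(NEG_INF - NEG_INF))

# Guarded update: a fully masked partition leaves the state finite.
state = update_row_state(NEG_INF, 0.0, [NEG_INF, NEG_INF])
print(state)
# A later partition with real scores then behaves normally.
print(update_row_state(*state, [0.0, 1.0]))
```

Without the guard, the first update would set the row sum to NaN and every later partition would inherit it.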
Signed-off-by: bogdan-petkovic <bpetkovi@amd.com>
Claude auto-review skipped (fork PR)
This PR is from a fork, so the label-triggered Claude review can't run. A maintainer can still produce a review one of two ways:
The …

Reviewed 4 files in PR 2371 (split-KV attention correctness + sweep tolerance + memory-aware prefilter). Kernel-side …

Reviewed 4 changed files (split-KV attention masking/iter-math kernel fixes, host-reference per-row valid-split + f32 combine, and sweep RMS / OOM prefilter policy). 1 new inline finding (Minor). 2 prior Claude findings (unmapped-dtype silent bypass in attentionSweeps.py; unreachable `itersPerBlock==0` guard in rocmlir-gen.cpp) remain present and unchanged with no developer reply; suppressed per dedup rules. Verdict: COMMENT.
Verdict: COMMENT · New findings: 0 (0 Critical, 0 Major, 0 Minor)
Scope
Re-review of split-KV attention masking + sweep RMS fixes across the kernel rewrite (GridwiseGemmToBlockwise.cpp), the host reference generator (rocmlir-gen.cpp), and the attention sweep driver (attentionSweeps.py, parameterSweeps.py).
Findings
No new blocking issues found in this revision. The three previously-flagged minor issues (silent dtype bypass in the splitKV extra-bytes estimator, unreachable itersPerBlock == 0 ternary guard, and the possibility of appending two -RMS_threshold entries when both branches fire) are still present in the diff but were already raised on prior runs; the existing inline threads remain visible.
Notes
- The `-inf - (-inf)` NaN guards in `expSubstractMaxFromGemm0`/`updateRowSum` are now correctly factored into `selectZeroIfBothNegInf` with loop-invariant `negInfF`/`zeroF` constants hoisted outside the loop body, addressing the earlier shadowing concern.
- The per-row vs per-batch-head `validSplitKV` layout split, the `usePerRowMask = causalMasking || !prefixOffset.empty()` predicate, and the `currSeqLen` clamp removal all line up with how the kernel treats prefix-causal as causal.
- f32 combine for fp16/bf16 in `computeFinalAttentionStage` is gated by storage type and the cast lambdas correctly no-op when types match.
- Sweep default `-RMS_threshold` policy is dtype-conditioned and configs that set their own threshold are left untouched.
CI status
No non-self failing checks in /tmp/pr/checks.json; Jenkins / ml-ci / parameter sweeps / MIGraphX are still in-progress or pending. Build and Test, Python format/lint, and Python performance script tests have passed.
Verdict: APPROVE -- submitted as COMMENT (automated reviews are advisory) · New findings: 0 (0 Critical, 0 Major, 0 Minor)
Scope
Split-KV attention correctness + sweep tolerance fixes across GridwiseGemmToBlockwise.cpp (kernel: NaN guards in softmax updates, ceil-div iteration math, 0/0 guard in scaleFinalOutput), rocmlir-gen.cpp (host reference: per-row causal/prefix-causal valid-split masking, f32 combine for narrow floats, updated assertions), attentionSweeps.py (device-memory-aware split-KV prefilter, separate filter-count buckets, CLI override), and parameterSweeps.py (dtype-aware default RMS thresholds + widened band for saturated softmax).
Findings
No new blocking issues found on this revision.
Notes
All three prior Claude inline comments have been addressed in a44bf2224:
- Unmapped-dtype gap in `_estimate_splitkv_extra_bytes` is now caught by a module-load `assert all(dt in DTYPE_BYTES for dt in DATA_TYPES_ATTENTION)`.
- The unreachable `itersPerBlock == 0 ? 0 : ...` guard in `computeValidSplitKV` was removed and replaced with an explicit `assert(mPerBlock > 0, ...)` at function entry, making explicit the invariant the contract relies on.
- `test_config` in `parameterSweeps.py` is now structured as an if/elif chain (saturated-softmax case first, then dtype-specific defaults), so exactly one `-RMS_threshold` is appended and precedence is explicit in Python rather than relying on `cl::opt<double>` last-wins semantics.
Spot-checked: hoisted negInfF/zeroF constants are correctly outside the loop body in both expSubstractMaxFromGemm0 and updateRowSum; selectZeroIfBothNegInf helper centralizes the NaN guard; computeFinalAttentionStage assertion now accepts both per-batch-head and per-(batch-head, query-row) validSplitKV layouts; createMaskSplitKV correctly broadcasts either layout.
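The if/elif precedence described for `test_config` can be sketched as follows. This is an illustration under assumptions, not the code in `parameterSweeps.py`: the function name `default_rms_threshold` and its parameters are hypothetical, and the numeric values are the ones quoted in the PR description (`0.15` for the saturated-softmax regime, `0.01` for bf16, `0.005` otherwise).

```python
from typing import Optional


def default_rms_threshold(
    dtype: str,
    with_attn_scale: bool,
    head_dim_qk: int,
    explicit: Optional[float] = None,
) -> Optional[float]:
    """Pick at most one default threshold; None means "append nothing"."""
    if explicit is not None:
        # The config already sets its own -RMS_threshold: leave it untouched.
        return None
    if not with_attn_scale and head_dim_qk > 64:
        # Saturated softmax (unscaled, large head_dim): widened band.
        return 0.15
    elif dtype == "bf16":
        return 0.01
    else:
        return 0.005


for cfg in [("f16", True, 128), ("bf16", True, 64), ("bf16", False, 128)]:
    print(cfg, default_rms_threshold(*cfg))
```

Because the branches are mutually exclusive, at most one value is ever returned, which is the property the review note is pointing at.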
CI status
All non-self checks in /tmp/pr/checks.json are passing (Python performance script tests, Python format/lint, detect).
    Type ldGemm0OutElemType = getElementTypeOrSelf(gemm0Out.getType());
    Value ldGemm0Out = InBoundsLoadOp::create(
        rewriter, loc, ldGemm0OutElemType, gemm0Out, gemm0OutCoords);
    Value ldGemm0Out = InBoundsLoadOp::create(rewriter, loc, gemm0OutElemType,
Can you add some lit tests and E2E tests that would exercise this path?
umangyadav left a comment
It would be good to have separate PRs for each bug fix. Okay to land this PR as it is.
Motivation
Attention performance sweeps were failing on split-KV configurations (`split_kv > 1`) with RMS validation errors, NaNs, OOM crashes, or invalid results. The failures showed up across causal masking, GQA, KV-cache style `current_seqlen`, and mixed dtypes (`f16`, `bf16`, `i8`), including cases with `trans_q` and bias.

The goal is to make split-KV attention sweeps reliable without weakening the default kernel verifier policy. Fixes target:

- […] `scaleFinalOutput`) for small `seq_len_k` with large `split_kv`.
- […] `rocmlir-gen` that did not match split-KV behavior under causal masking.
- […] `rocmlir-gen`'s default `-RMS_threshold` for fp16/bf16 without an explicit override stays `0.001`. Only the attention sweep driver gets a separate default when the config does not set a threshold.

Technical Details
Kernel path: guard `-inf - (-inf)` in split-KV softmax updates

In `mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp`, split-KV attention rewrites update row-wise softmax state across KV partitions. When a partition has no valid keys, both the running row max and the partition max can be `-inf`, so `exp2(score - max)` and `exp2(old_max - new_max)` can see `-inf - (-inf)` and become NaN. That poisons row sums and downstream combine.

The change detects the case where both operands are `-inf` and uses `0` before `exp2`, so empty or fully masked partitions contribute zero instead of NaN. This is a correctness fix on the kernel path, not a verifier relaxation.

Kernel path: non-causal split-KV iteration math + `scaleFinalOutput` 0/0 guard

The non-causal / non-KV-cache split-KV branch in `GridwiseAttentionAccelRewritePattern` computed per-split iterations as `gemm0M / (gemm0MPerBlock * splitKV)` using truncating integer division. When `gemm0MBlocks < splitKV` (small `seq_len_k` with large `split_kv`), this evaluates to `0` for every split, so every split-block skips the softmax loop entirely. The kernel then divides the (zero) output by `sum = 0` in `scaleFinalOutput`, producing NaNs that propagate through the host combine stage.

This is fixed by:

- […] `end` to `gemm0MBlocks` so trailing splits where `start >= gemm0MBlocks` become cleanly empty.
- […] `0/0` guard in `scaleFinalOutput`: when a row's sum is exactly zero, the per-split output stores `0` instead of `NaN`. The host combine stage already tolerates `-inf` max, but only when each per-split partial output is finite.

Host reference: causal valid-split masking
In `mlir/tools/rocmlir-gen/rocmlir-gen.cpp`, `computeValidSplitKV()` used a causal branch that forced `currSeqLen = 0`, which made every split invalid and did not reflect per-query causal reach. Causal masking now builds a per-(batch-head, query-row) valid split count from each row's effective key length (including `prefix_offset` when present). Non-causal configs keep the per-batch-head path. The `usePerRowMask` predicate also triggers for non-empty `prefix_offset`, mirroring how the kernel treats prefix-causal as causal. `createMaskSplitKV()` accepts either per-batch-head or per-(batch-head, query-row) `validSplitKV` layouts and broadcasts the threshold tensor accordingly. `computeFinalAttentionStage()` still masks invalid splits on both the partial output and LSE tensors before the split-KV combine; its size assertion now accepts either layout.

Host reference: f32 combine for narrow floats
For `f16` and `bf16`, the split-KV combine stage (reduce-max, exp, weighted sum, normalization) runs in `f32` and casts the final combined result back to storage type. That reduces accumulation error in the reference path when comparing against the GPU kernel during sweeps.

Sweep policy: default attention RMS 0.005
In `mlir/utils/performance/parameterSweeps.py`, attention sweeps without an explicit `-RMS_threshold` now append `-RMS_threshold 0.005` for all attention dtypes (`bf16` keeps `0.01`). The band covers observed sweep disagreement (including i8 with `trans_q` and high `split_kv`) while staying tighter than the old bf16-only default.

Configs that set their own `-RMS_threshold` are unchanged.

Sweep policy: widen RMS band for unscaled large-head_dim attention
Sampled configs with `with_attn_scale=False` and `head_dim_qk > 64` saturate softmax (`|QK| ~ O(sqrt(d))` collapses to near one-hot), so CPU vs GPU float-arithmetic ordering inside `exp`/accumulate dominates the diff (observed RMS up to ~6% in bf16, ~1% in f16/i8) independent of `split_kv`. For this regime only, `test_config` widens `-RMS_threshold` to `0.15` so the verifier still catches NaN/crash regressions but no longer false-fails on known float associativity in saturated softmax. All other configs are unaffected.

Sweep generator: device-memory-aware split-KV prefilter (folds in draft #2366)
In `mlir/utils/performance/attentionSweeps.py`, the sweep generator now estimates extra split-KV temporary storage for each sampled shape and rejects samples above a budget before generating MLIR. The budget defaults to `deviceMem / 8`, clamped to `[1 GiB, 8 GiB]`, with a `1.5 GiB` fallback if the HIP query fails. A `--splitkv-extra-bytes-limit` CLI override is available. Filter-out reasons are now tracked separately (`MAX_TOKENS` vs `splitKV extra-storage`) and reported cumulatively across initial and refill batches.

No compiler/verifier logic changes for this part; scope is limited to sweep sampling / filtering behavior.
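The budget rule above can be written down as a small Python sketch. The function name `splitkv_extra_bytes_limit` and its parameters are hypothetical, and passing `None` for the device memory stands in for a failed HIP query. The per-shape storage estimate itself is not reproduced because its formula is not given here.

```python
from typing import Optional

GIB = 1 << 30


def splitkv_extra_bytes_limit(
    device_mem_bytes: Optional[int],
    override: Optional[int] = None,
) -> int:
    """Return the split-KV extra-storage budget in bytes."""
    if override is not None:
        # Corresponds to an explicit --splitkv-extra-bytes-limit value.
        return override
    if device_mem_bytes is None:
        # Device query failed: fall back to 1.5 GiB.
        return int(1.5 * GIB)
    # deviceMem / 8, clamped to [1 GiB, 8 GiB].
    return min(max(device_mem_bytes // 8, 1 * GIB), 8 * GIB)


for mem_gib in (4, 16, 128):
    limit = splitkv_extra_bytes_limit(mem_gib * GIB)
    print(f"{mem_gib} GiB device -> {limit / GIB} GiB budget")
```

Under these assumptions a 16 GiB device gets a 2 GiB budget, while very small and very large devices hit the lower and upper clamps.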
Test Plan
- […] `PrAttention{F16,BF16,F32}` E2E configs (GPU), Python format/lint, the Python performance script tests, and the `AttentionSweeps` job.
- […] `gfx1201` (16 GiB) with all changes in place. 6/13 PASS on the kernel path. Of the remaining 7: 6 are correctly filtered out by the new split-KV prefilter (they were OOM-bound on a 16 GiB GPU because `split_kv = 128` × large `g × num_heads_q × head_dim_v` blew past VRAM), and 1 (`split_kv=1`, bf16 trans-all) trips only the per-element `relDiff` check on near-zero outputs while passing both RMS and abs-diff; that case is out of scope here and tracked separately.

Test Coverage
New and updated tests pin each fix above:
Lit: kernel path (IR lowering)

- `mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir` and `toblockwise_attention_accel_lowering.mlir` (updated): add explicit `CHECK`s for the `-inf - (-inf)` softmax guard and the `scaleFinalOutput` 0/0 guard, so the guarded select is verified in the emitted IR instead of the old unguarded store.
- `mlir/test/Dialect/Rock/gridwise_attention_splitkv_noncausal_lowering.mlir` (new): pins the non-causal split-KV ceil-division + `end` clamp iteration math, covering the `gemm0MBlocks < splitKV` case.

Lit: host reference (`rocmlir-gen`)

- `mlir/test/rocmlir-gen/attention-splitkv-host-perrow-mask.mlir` (new): pins the per-(batch-head, query-row) valid-split mask under causal + `prefix_offset`.
- `mlir/test/rocmlir-gen/attention-splitkv-host-f32-combine.mlir` (new): pins the f16/bf16 → f32 combine and cast-back in the reference path.
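For readers unfamiliar with the combine that the f32-combine test pins, here is a rough Python sketch of a split-KV combine for one query row (reduce-max, exp, weighted sum, normalization, one cast back). It is the standard log-sum-exp weighted form, assumed rather than taken from `rocmlir-gen`. The names `to_f16` and `combine_splits` are hypothetical, Python doubles stand in for the wider `f32` type, and base-e `exp` is used for simplicity.

```python
import math
import struct

NEG_INF = float("-inf")


def to_f16(x: float) -> float:
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


def combine_splits(partials: list[list[float]], lses: list[float]) -> list[float]:
    """Combine per-split partial outputs for a single query row.

    partials[i] is split i's already-normalized output vector and lses[i]
    its log-sum-exp; -inf marks a masked or empty split.
    """
    dim = len(partials[0])
    m = max(lses)  # reduce-max over splits
    if m == NEG_INF:
        return [0.0] * dim  # every split masked: nothing to combine
    weights = [math.exp(lse - m) for lse in lses]  # masked splits weigh 0
    total = sum(weights)
    out = []
    for d in range(dim):
        # Accumulate in the wide type, normalize, then cast back once.
        acc = sum(w * p[d] for w, p in zip(weights, partials))
        out.append(to_f16(acc / total))
    return out


print(combine_splits([[1.0, 2.0], [3.0, 4.0]], [0.0, 0.0]))
print(combine_splits([[1.0, 2.0], [9.0, 9.0]], [0.5, NEG_INF]))
```

The point of the wide accumulation is that rounding to the storage type happens once at the end instead of after every multiply-add.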
E2E: GPU numeric (run in PR CI)

- `mlir/test/e2e/PrAttention{F16,BF16,F32}.toml` (new configs): non-causal split-KV (basic, trailing/padded splits, GQA) and causal split-KV with `seq_len_q > 1` (per-row mask, GQA, `prefix_offset`), exercising the guards and combine on real hardware.
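The trailing/padded-split configs exercise the per-split iteration math. The Python sketch below contrasts truncating division with ceil division plus an `end` clamp, working in units of `gemm0M` blocks. The function names are hypothetical and the index arithmetic is a simplification of the description, not the rewrite pattern's actual code.

```python
def split_ranges_truncating(gemm0_m_blocks: int, split_kv: int):
    """Old behaviour: truncating division gives 0 iterations per split
    whenever gemm0_m_blocks < split_kv."""
    iters = gemm0_m_blocks // split_kv
    return [(i * iters, (i + 1) * iters) for i in range(split_kv)]


def split_ranges_ceil(gemm0_m_blocks: int, split_kv: int):
    """Ceil division plus clamping `end` to the block count, so trailing
    splits with start >= gemm0_m_blocks are simply empty."""
    iters = -(-gemm0_m_blocks // split_kv)  # ceil division on integers
    ranges = []
    for i in range(split_kv):
        start = i * iters
        end = min(start + iters, gemm0_m_blocks)
        ranges.append((start, end))
    return ranges


def covered(ranges):
    """Blocks visited across all splits; range(s, e) is empty if s >= e."""
    return [b for s, e in ranges for b in range(s, e)]


# 3 KV blocks spread over 8 splits: the case gemm0MBlocks < splitKV.
print(covered(split_ranges_truncating(3, 8)))
print(covered(split_ranges_ceil(3, 8)))
```

With truncating division no split visits any block, which is how the all-zero row sums (and then the 0/0 in `scaleFinalOutput`) arise. With ceil division and the clamp, each block is visited exactly once and the surplus splits are empty.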
The sweep-policy (`RMS` band) and split-KV memory prefilter changes remain covered by the existing `mlir/utils/performance` pytest suite and the `AttentionSweeps` job.

Test Result
Submission Checklist