Sync upstream#141
Merged
Merged
Conversation
* [ROCm Windows] fix triton requirement * pin triton-windows>=3.6.0
* varlen block-sparsity for forward Squashed forward-path varlen support: extends BlockSparseTensors usage to [num_heads, total_m_blocks] / [num_heads, total_n_blocks] layouts, threads cu_seqlens / cu_total_m_blocks / cu_total_n_blocks through the kernel and compute_block_sparsity, and routes through get_curr_blocksparse_tensors and get_total_block_count for shape-aware indexing. * rename cu_total_n_blocks to cu_block_idx_offsets; move cu_total_m_blocks/cu_block_idx_offsets into BlockSparseTensors instead of threading them as standalone parameters; drop the two <tensor>[-1].item() syncs in normalize_block_sparse_config
…ao-AILab#2515) num_splits_heuristic divides num_SMs by total_mblocks, which collapses to 0 when seqlen_q == 0 or batch_size == 0 (e.g. CUDA graph padding or empty microbatches). The existing seqlen_k == 0 early-exit in _flash_attn_fwd does not cover these cases. - Extend the early-exit to also cover total_q == 0, using the same zero-output / -inf-LSE contract. total_q is batch_size * seqlen_q (dense) or q.shape[0] (varlen), so a single predicate handles both code paths. - Add a defensive total_mblocks == 0 guard inside num_splits_heuristic itself so the function is safe in isolation. - Add regression tests covering dense (batch=0, seqlen_q=0) and varlen (total_q=0) paths under both causal and non-causal masks. Fixes Dao-AILab#2503.
* split out varlen batch search into utils * more descriptive name
stack-info: PR: Dao-AILab#2536, branch: drisspg/stack/38
) Summary: Extract the inline `AttentionMask` construction in `FlashAttentionForwardSm100` and `FlashAttentionBackwardSm100` into an overridable `_generate_attention_mask_cls` method. This allows subclasses to inject a custom `AttentionMask` without modifying the base kernel code. For example, a custom attention kernel can override the mask to add a `causal_q_divisor` field for scaling the `row_idx` value. ``` class CustomAttentionMask(AttentionMask): causal_q_divisor: cutlass.Constexpr[int] = 1 @cute.jit def apply_mask_sm100(self, acc_S, m_block, n_block, ...): # Custom causal logic using causal_q_divisor row_idx = (tScS_t2r[0][0] + m_block * self.tile_m) // self.causal_q_divisor ... class CustomFlashAttentionForwardSm100(FlashAttentionForwardSm100): def __init__(self, *args, causal_q_divisor=1, **kwargs): super().__init__(*args, **kwargs) self.causal_q_divisor = causal_q_divisor def _generate_attention_mask_cls(self, window_size_left, window_size_right): return partial( CustomAttentionMask, self.m_block_size, self.n_block_size, window_size_left=window_size_left, window_size_right=window_size_right, bottom_right=self.is_bottom_right, causal_q_divisor=self.causal_q_divisor, ) ``` Test Plan: ``` $ pytest tests/cute/test_flash_attn_fast.py -v ================ 240 passed, 4139 warnings in 984.24s (0:16:24) ================ ``` Reviewers: Subscribers: Tasks: Tags:
…10 (Dao-AILab#2532) * Fix: Remove misleading py_limited_api=cp39 wheel tag for PyTorch extension * Implement dynamic ABI tagging for PyTorch versions Add dynamic ABI tag based on PyTorch version for correct and improved naming of the wheel. * Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> * Update Python version requirements based on torch metadata * Refactor setup.py for dynamic ABI and CUDA settings Refactor dynamic ABI tag and Python version requirements based on installed PyTorch version and streamline CUDA extension arguments. * Update CUDAExtension compile arguments Restored some accidentally removed content * Update setup.py * Updated setup.py minor fix: cleaned up the comments * Brought back Py_LIMITED_API flag to CUDA extension compilation * Minor fix * Update setup.py for Python version requirements Updated the wheel tag to cp310 and python_requires=">=3.10". --------- Co-authored-by: aw920h <alien@alien.alien> Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
* vectorized mask mod application for existing mask mod signatures * add vectorized mask mod examples, get vectorized evaluation and application working * guard sm80/90/120 against mask_vec_size > 2 * thread mask_vec_size thru sm80/90/120 kernel * Small tweaks coverign sm90 * Small tweaks coverign sm90 --------- Co-authored-by: drisspg <drisspguessous@gmail.com>
…b#2572) (Dao-AILab#2590) * Fix bwd postprocess 2CTA gating to include sm_11x The 2CTA gating in flash_bwd_postprocess.py used `arch // 10 == 10`, which only matches SM 10.x (B100/B200/B300) and misses SM 11.x (Thor). The rest of the codebase (e.g. interface.py:549, 563, 834) consistently gates Blackwell-family 2CTA features as `arch // 10 in [10, 11]`. Bring the two postprocess sites in line with that convention. Flagged by @jayhshah in Dao-AILab#2572 follow-up discussion. * Include sm_110 in interface.py Blackwell-family heuristics Three sites in interface.py gate Blackwell-family behavior using `arch // 10 == 10`, which appears inconsistent with the rest of the file's `arch // 10 in [10, 11]` convention (used at lines 549, 563, 834, 974, 1035, etc.): - L533: `q_stage` heuristic for Blackwell forward - L579: `use_dedicated_hd256_kernel` (forward) - L1335: `use_dedicated_hd256_kernel` (backward) The dispatch in `_flash_attn_fwd` already routes both sm_10x and sm_11x through the same `FlashAttentionForwardSm100` / MLA classes, so these gates likely should treat them the same. NOTE FOR REVIEWERS: I'm not certain these are all oversight vs. intentional SM100-only paths. If any of them is intentional, please flag so I can revert just that hunk. The FP8 assert at L480 is left untouched on purpose — its error message reads as deliberate. * Apply ruff format to flash_bwd_sm100.py Pre-existing format drift surfaced by pre-commit. Not in the cute_exclude pattern, so it gets auto-fixed when other files in flash_attn/cute/ are touched in the same commit chain.
* Use is_family_of for sm_90 and sm_103 arch checks Follow-up to Dao-AILab#2572 — apply the same is_family_of pattern to the two remaining range-style arch checks for consistency: - flash_fwd_sm90.py:69 (SM 9.x assert) - flash_fwd_sm100.py:195 (is_sm103 flag) Same semantic narrowing as Dao-AILab#2572: bare-base SMs (sm_90, sm_103) are excluded. These kernels rely on wgmma / UMMA / 2CTA paths that require the a/f PTX variant anyway, so bare-base targets could not compile. * Clarify is_sm103 forward-inclusive semantics is_family_of(sm_103f) also matches any future sm_10x with x > 3, not just sm_103a/f. This was raised in PR review (@ocss884) — adding an inline comment clarifying that this forward-inclusive behavior is intentional: the flag gates ex2 emulation, sm_103 (B300) has fast hardware ex2, and later Blackwell variants in the same family are assumed to inherit it. No code-behavior change.
* Bump aiter submodule commit Co-authored-by: sstamenk <170634954+sstamenk@users.noreply.github.com> * Bump aiter submodule to 3b2e6f48ce97e1d494e8b3f1af5c65f74e304b28 (vllm-project#2) Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: sstamenk <170634954+sstamenk@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: sstamenk <170634954+sstamenk@users.noreply.github.com>
…ao-AILab#2594) * Clamp kv_stage to avoid SMEM overflow for small head_dims on SM100 Fixes Dao-AILab#2591. The unbounded formula at flash_fwd_sm100.py:335 ignores per-stage state (mbarriers, sScale, pipeline counters) and yields kv_stage values that overflow the sm_100a 227 KB SMEM cap when head_dim_padded=16 (head_dim in {8, ..., 16}). Repro: hd=8/16 + seqlen >= 256 + bf16 fails with cudaErrorInvalidValue ("launch shared memory exceeds current GPU arch sm_100a allowed. Allocated: 233472 bytes. Max: 232448 bytes."). Clamp kv_stage at 32. Surgical to the broken case: the unbounded formula maxes at 26 stages for head_dim_padded >= 32, and the 2CTA gate at interface.py:572 restricts 2CTA to hd_padded in {128, 192} (both no-op), so the clamp only fires at hd_padded in {8, 16}. Verified across 24 configs (hd in {8,16,32,64,96,128} x causal in {T,F} x seqlen in {128,2048}) on B200 with max_err vs torch SDPA <= 0.0078. * Add test_flash_attn_small_head_dim regression test The main test_flash_attn_output parametrizes d over {64, 96, 128, 192, 256} and never exercises head_dim < 64, even though _validate_head_dims accepts head_dim >= 8 for sm_100/110. That coverage gap let the SMEM-overflow bug in Dao-AILab#2591 slip through. This focused test covers d in {8, 16, 32} x causal x seqlen in {128, 2048}. The seqlen=2048 cases push q_stage 1->2 (the actual bug trigger); the seqlen=128 cases also exercise the q_stage=1 boundary that fits on main today but is structurally adjacent. d=32 serves as a canary against any future tighter kv_stage clamp regressing it.
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
LucasWilkinson
approved these changes
May 29, 2026
This was referenced May 30, 2026
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
7 tasks
cerisier
pushed a commit
to zml/flash-attention
that referenced
this pull request
Jun 24, 2026
* Fused Bwd (vllm-project#137) * Fused with Good perf and stride fixed Fix fused bugs isolate failing case fix bug bring back test cases rm split impl in fused use exp2 is global variable now try oom fix save make fused the default limit to reproduce failure return default to split fix head size bug use exp2 back to true * new grid * BLK_SLICE_FACTOR = 1 * add tflops * new commit * test in parrallel * strides added by jusson * disable alibi * fix bugs again * default to fused * add bwd options for varlen * backend filter * default to jingning and batch 4 * best fwd config * fix TRITON_PRINT_AUTOTUNING flag bug * tune * Tuning fwd prefill * add if else * use flag * Minor mask fix * FLIP GRID * use best config for default * print when autotuning * test bfloat16 * fix k and v stride bugs * skip bfloat16 * test kvpacked * disable internal tests * pick default config based on arch * Add alibi in the new bwd kernel (vllm-project#139) * enable alibi for jinging kernel enable alibi for jinging kernel match * save bad configs * fix alibi and causal bug * disable autotune by default * auto tune when benching is good * set best config * remove env var * Update amd_tests.yml * upgrad to triton==3.3.0 * increase shm * use 64 x 64 for now * save * handle 1d alibi * Add fp8 to fused kernel (vllm-project#140) * fp8 stuff find test case compute delta fp8 basic fp8 config passing non causal path works * isolate bad case * fix fp8 bug * didnot fix fp8 bug * back to failing test * fp8 tests passing * skip * skip ref tests --------- Co-authored-by: Aliasger Zaidy <aliasger.zaidy@amd.com> * head, seq, batch (vllm-project#141) * Fix keys (vllm-project#144) * save * rm keys * fix keys * use GHA_RENDER_DEVICES * normal docker * Pad LSE (vllm-project#148) * add round multiple * fix fwd * backward fix * use rounded lse flag * passing ROUNDED_LSE * default is new rounded mode * rename to fused_atmoics and fused_no_atomics * add test for torch_compile * add varlen torch compile test * add old one kernel for ref * fix varlen mismatch bug * fix shape issue in varlen but mismatch * sync torch compile kernel launch * simple varlen test * add debug code * rm old * ignore old impls * DEBUG flag works in interface only * ref uses the righ shape for lse * rm oldest bwd kernel * fix typo * fix varlen bug * fix bug. Get info from q for now * simple shape and stride checkout * add more tests * test kvcache * kvcache safe * match case * fix segfault due to bad return_softmax * run bench * run seperate for the main functions * just output benchmark * default csv format and time stamp files * non verbsoe bench * Sliding Window Forward (vllm-project#151) * Compress SWA work test case set up debug inputs add fwd ref one mask ref fwd first pass save ref doesnot work for bigger seqlens save new version some causal cases failing found bad cases working new attn new atten works new attn_fwd works reorg n_extra_tokens use seqlen_delta_qk ref fwd works add sliding window to bwd ref test kvcache decode ref work with everything except sliding window add debug code for 12 failing sliding window cases for decode attention_decode_forward_ref_impl mostly works except for alibi fix alibi in attention_decode_forward_ref_impl ref works with normal, varlen & kvcache move stuff around figure out masking old attn inner two inner functions remove load_fn do Lk - Lq like ref unify IS_CAUSAL code in epilogue clean up add args rm inference stuff simplify compute_masking simpler compute mask stub out returning front masking variables remove pointer pass compute ptrs inloop compute block min and max window stub inside inner mask loop trying to use attn_fwd_mask causes issues fix compiler bug when front masking gen specifc types add sliding window and debug statements use identity for v add more taste cases add comments save use k_max_token for clarity disable debug configs basic NON-CAUSAL SLIDING WINDOW non causal sliding window works on the all the shapes non sliding window working in fwd clean up fused bwd seperate old fwd_prefill move configs to utils.py * fix bwd ref bug * skip local cases so that fa output * no sliding window causal green * add backward test skip for sliding window * clean reduce in fwd_kvcache. no is_CASUAL branching * add kvcache masking * kvcache working * fix some bugs in test.py * clean up * Fix Device Segfault (vllm-project#152) * Compress segfault work fix backward segfault rework offset ignore .profile ignore .analysis save * assert the kernel launch device and tensor devices are the same * fix failing asserts * add asserts to fwd * Fix SDMASK bug * Log triton, torch and fa version * Fix fp8 import issues * fix docs (vllm-project#154) * Sliding Window block classification logic (vllm-project#155) * add aiter code * remove aiter stuff * sliding window non causal masking works * causal and sliding window block masking * extract common * clean up typo * helper for swa * ignore .amd * fix last block bug * Enable FA V3 (Dao-AILab#157) * Compress PA work narrow pa test ref works on most cases inplace ref with new_kv inplace paged attention add pa ref save pa basic paged works save fix swa + causal in pa. Also new_kv only on pa path passing build fa v3 import interface from fa v3 copy fa tests use v3 api clean up rename to match old test support different head sizes remove fp8 basisc passing v3 cases test_flash_attn_varlen_output v3 working isolate bad case for kvcache case passing save use decode is seqused/ cacheseql is given use decode if not varlen basci kvcache v3 working kvcache enable more cases detect kvcache case if seqused_q is non and sequese_k is not None skip failing test find fp8 failing case mha fp8 works fix fp8 MQA/GQA bug clean up more clean up clean up more don't need fp8 dead code remove train code with fp8 stuff fp8 working in kvcache paged + fp8 seems to be working new_kv allowed * clean up * skip hopper race test * clean up more * fix paged + alibi * similar inner paged api * unify _attn_fwd_inner * AITER integration (Dao-AILab#159) * clean up v2 interface * assert fp8 scale shapes * rotary working * move rotary to impl layers * remove einops * enable rotarry in v3 * create interface * fix descale assert * unify bwd * lint from aiter * clean fp8 api * add api change * assert shapes for v2 * remove ref and bench.py * remove metadata class and clean up * bwd_prefill * one bwd.py * rename * lint * add bwd_change (Dao-AILab#156) * Tune FP8 Perf (Dao-AILab#160) * check cu count for gfx942 * create get_cu_count * update repo root * update forward tune * clean up load * use float8_e4m3fnuz * save * show bwd mode * recommend fp8 * use torch.float32 for fp8 kernel * add both best fp16 and fp8 config * tune fp8 backward * descale factors should be b, hk * fp8 bwd working on all primus configs * tune bwd configs * fa v3 tests passing * better warning * clean up bwd launcher * v3 passing * tune more * improve perf * clean up * lint * clean * start tuning gfx950 * tune non causal path * fix bug * save * Skip configs where BLOCK_M2 % BLOCK_N2 != 0 * skip more * stop tuning * fix varlen bug * fix dropout & causal/swa segfault * update the to machine new changes * save * fix more bugs * remove random seed * clean up * update readme * print tensor stats for debug * disable sliding window tests * add rdna configs * fix k partial bug * fix block_size_n bug * fix type check bug --------- Co-authored-by: Aliasger Zaidy <aliasger.zaidy@amd.com> Co-authored-by: Tianxing Wu <tianxing.wu@amd.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Sync with upstream