Skip to content

Sync with upstream#134

Merged
LucasWilkinson merged 47 commits into
vllm-project:mainfrom
MatthewBonanni:merge_upstream
May 12, 2026
Merged

Sync with upstream#134
LucasWilkinson merged 47 commits into
vllm-project:mainfrom
MatthewBonanni:merge_upstream

Conversation

@MatthewBonanni

@MatthewBonanni MatthewBonanni commented Apr 27, 2026

Copy link
Copy Markdown
Member

NOTE: Does NOT include Dao-AILab#2448 (explictly reverted). This is because it introduces another template instantiation which doubles the compiled kernel variants for the non-causal split-KV path

drisspg and others added 30 commits April 1, 2026 15:14
stack-info: PR: Dao-AILab#2414, branch: drisspg/stack/32
* [CK_TILE] Fix NaN for FMHA BWD When seq_q=0

* Add regression test for NaN in dK/dV with zero-length Q subsequence in flash attention BWD

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Add additional assertions for gradients in test_flash_attn_bwd_varlen_seqq_zero

---------

Co-authored-by: Ding, Yi <yi.ding@amd.com>
Co-authored-by: Yi DING <andy-ding@outlook.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
…o-AILab#2393)

* Add FA4 CI: GitHub Actions workflow with Apptainer on B200 self-hosted runner

- GitHub Actions workflow (ci.yml) triggering on push to main
- Two-pass test strategy: kernel compilation (FakeTensorMode) + GPU execution
- Pulls Docker image from Docker Hub (togethercomputer/training-performance)
  and converts to Apptainer SIF, cached by image tag on the runner
- CI_WORK_DIR repo variable for configurable large-disk path on the runner
- Shared Python driver (tools/ci/run_fa4_ci.py) used by both CI and local runs
- Docker image build scripts and Apptainer SIF definition (tools/ci/)
- CI_SETUP.md setup guide covering runner registration, secrets, and migration

* Upgrade CI to CUDA 13.0 (cu130) torch nightly for B200/SM100 support

* Address CI review: security, shell quoting, and config cleanup

- Add permissions: contents: read to ci.yml to restrict GITHUB_TOKEN scope
- Pin FA4_IMAGE by digest for reproducibility; bump to flash-attn-cu13.0-26.04.01
- Move FA4_TEST_FILTER to ci.yml and thread through action inputs
- Fix shell quoting in run_fa4_ci.py (shlex.quote/join for all paths and env values)
- Drop venv fallback mode; script is now apptainer-only with clear error on missing FA4_SIF
- Drive Dockerfile deps from pyproject.toml via full cute/ copy (single source of truth)
- Widen Docker build context to repo root so pyproject.toml is accessible
- Add pytest-xdist to pyproject.toml dev extra (was missing, needed for -n flag)

* Fix stale --venv flag in test_ci_local.sh; add compile-workers input to gpu-test action

- Remove --venv arg from test_ci_local.sh (dropped in 360e6d7, caused immediate argparse failure)
- Add compile-workers input to action.yml (default 64) so Pass 1 compilation runs in parallel

* [CI] Add FA4 CI: SIF cache cleanup and trim setup docs

- Delete stale SIF files after pulling a new image to prevent unbounded
  disk growth on the self-hosted runner
- Replace AI/CI_SETUP.md with a lean tools/ci/README.md co-located with
  the CI scripts; drop one-time runner setup steps, keep maintainer-relevant
  bits (credentials, image update, test expansion, FA2 isolation)
* CI: fix ruff format, Apptainer pull, add cu129/cu130 auto-selection by CUDA version, trigger on main and johnson/ci-fix branches

* CI: trigger on main and ci-fix branches
* Initial FA-2 aiter Triton Windows build support

* minimize diff

* bump commit

* bump commit

* minimize diff

* bump commit

* bump aiter submodule

* bump aiter submodule to merged Dao-AILab#2433

* fix: guard distributed.py fallbacks with hasattr for Windows

---------

Co-authored-by: 0xDELUXA <djernovevo@gmail.com>
…ao-AILab#2433)

The SM100 2-CTA backward kernel does not properly handle block_sparse_tensors.
When block sparsity is combined with 2-CTA mode, the kernel hits an assertion:
  'AssertionError: 2-CTA mode does not support block sparsity'

This fix adds block_sparse_tensors to the disable_2cta condition in the backward
path, forcing the 1-CTA kernel when block sparsity is active. The 1-CTA backward
kernel already supports block_sparse_tensors correctly.

Without this fix, any backward pass using block_sparse_tensors on SM100 (B200/GB200)
with head_dim >= 128 will crash with the above assertion.
…ao-AILab#2441)

* add mla 2cta with topk sparsity support

* add tma store O

* add clc option; performs worse than single tile

* enable clc for topk gather

* add producer tails

* add mla dsa to interface

* ruff format

* use tma store for varlen

* decouple sm stats from scale for smem

* add varlen tests

* credit monellz for kernel dump attributes utility

* add docstring for optional args, change default value of topk_indices_maybe_oob to None

* give default vals for new args in interface

* more rigorous tests; fix race condition on smem for rowmax

* add bandwidth calc and qv to benchmark script

* refactor interface per suggestions

* return more Nones for gradients
* fix outstanding ruff check and exclude flash_fwd_mla_sm100.py from ci

* add fmt comments for ruff
…ion (Dao-AILab#2461)

* Disable 2CTA forward non-causal on CUDA 12.9 to work around codegen regression

CUDA 12.9 has a codegen issue that causes ~18% slowdown for 2CTA
forward non-causal (hdim=128: 1280 vs 1542 TFLOPS). This is fixed
in CUDA 13.x. Auto-disable 2CTA when CUDA 12.9 is detected.

Users on CUDA 13.x are unaffected. The manual `FA_DISABLE_2CTA=1`
override continues to work regardless of CUDA version.

* Disable 2CTA forward non-causal on all CUDA 12.x (not just 12.9)
* fa4 benchmark and correctness test

* initial fa4 fp8 e4m3 support

* fix tmem p overlap/signaling for fp8

* kv_stage 4 for fp8

* fp8 e5m2 support

* update benchmark

* fix lint

* descale named tuple

* compile time gating

* fix lint

* fix rescale bug

* fp8 register tuning

* defensive restore default rescale threshold to 0, add fp8 override

* load effective descales helper

* uint8 workaround for fp8 note
Both flags override the kernel's internal heuristics so users can
benchmark with forced settings instead of editing the script. Defaults
are unchanged (num_splits=0 and pack_gqa=None both mean "kernel auto").

Useful for A/B comparisons. Example: MLA decode bs=32 seqlen_kv=65536
num_splits=2 on B200 gives 0.75 ms with --pack-gqa true and 50.32 ms
with --pack-gqa false -- a 67x gap that confirms pack_gqa is the
deciding factor for long-context MLA decode parallelism.

Note: for non-MLA GQA with num_splits>1, interface.py may still force
pack_gqa off regardless of --pack-gqa true (pending a separate fix).
For MLA (qv is not None), the flag is honored. No new code path is
exposed -- the flag only makes existing kernel options reachable
from the CLI.
…ab#2448 regression) (Dao-AILab#2476)

PR Dao-AILab#2448 added `int num_splits = 0` as a trailing positional arg of
`mha_varlen_fwd` in csrc/flash_attn/flash_api.cpp but did not update
the Python wrapper nor the pybind11 binding to expose that default.
Because the binding at the bottom of flash_api.cpp is just
`m.def("varlen_fwd", &mha_varlen_fwd, ...)` (no `py::arg("num_splits")
= 0`), pybind11 does not honour the C++ default value, so every call
from `_flash_attn_varlen_forward` now fails with:

    TypeError: varlen_fwd(): incompatible function arguments.

This patch plumbs `num_splits` through `_flash_attn_varlen_forward`
(and its fake counterpart, to keep the torch.custom_op schemas in
sync) and passes it to `flash_attn_gpu.varlen_fwd`, restoring the
previous behaviour and exposing the knob to Python callers.

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Issue
-----
flash_attn_func crashes when K/V is built with physical seqlen dim 0
(e.g. vLLM CUDA-graph capture with an all-padding batch):
- causal=False: PTX IllegalInstruction (async CUDA error from kernel)
- causal=True:  host SIGFPE before kernel launch

Root cause
----------
seqlen_k == 0 violates two downstream invariants:

1. TMA descriptor invariant (K physical seqlen > 0).
   On non-causal, StaticPersistentTileScheduler passes host-side setup
   and launches the kernel; the first TMA load over the 0-length K
   tensor goes OOB -> PTX IllegalInstruction.

2. LPT L2-swizzle heuristic invariant (size_one_head > 0).
   On causal, SingleTileLPTScheduler.Params.create in tile_scheduler.py
   computes
       size_one_head = seqlen_k * (headdim + headdim_v) * element_size
   which is 0, then evaluates "size_l2 // size_one_head" -> host
   integer divide by zero (SIGFPE) before the kernel launches.

Fix
---
Early return in interface._flash_attn_fwd when seqlen_k == 0 — zero the
output, fill LSE with -inf, skip kernel launch. Guards both invariants
at the boundary before either scheduler runs.

Only affects the non-varlen path. Varlen is unchanged: its K tensor has
physical seqlen > 0, and per-batch empty slots are already handled
correctly by the kernel's fake-iteration path.

Regression test: tests/cute/test_flash_attn.py::test_flash_attn_seqlen_k_zero
covers both crash paths across seqlen_q in {1, 64, 128, 256} and d in
{128, 192}.

Co-authored-by: yunzhongOvO <lzy21@mails.tsinghua.edu.cn>
* Fix script

* Fix lower right causal bug and add clc parse viewer
Co-authored-by: wangziheng <wangziheng@bytedance.com>
…ao-AILab#2412)

* [Feat] Support flash-attention head_dim 256 in CuteDSL

This PR adds head_dim=256 support to the FA4 FlashAttention implementation built with the CUTLASS CUTE DSL.

* Forward: uses a 2-CTA design and introduces a new pipeline to better hide memory latency; includes a TMEM-based design for intermediate storage.
* Backward: uses a 2-kernel approach and a 2-CTA design for the backward path.

No API changes for existing head dimensions. But coding style should be adjusted step by step.

This feature is authored by Siyu Wang, Shengbin Di, Yuxi Chi, Johnsonms,
Linfeng Zheng, Haoyan Huang, Lanbo Li, Yun Zhong, Man Yuan, Minmin Sun, Yong Li, Wei Lin.

* Fix ruff lint errors in head_dim=256 changes

Apply ruff check --fix and ruff format to bring the new hd256 files in
line with the project's pre-commit config (flash_attn/cute/*.py, minus
the excluded set in .pre-commit-config.yaml).

Manual fixes:
* mask.py: `Boolean(mask)` -> `cutlass.Boolean(mask)` (F821; other call
  sites in the file already use the qualified form).
* sm100_hd256_2cta_fmha_backward_dkdvkernel.py: drop duplicate
  `SM100_TMEM_CAPACITY_COLUMNS = 512` local definition that shadowed the
  import from tile_scheduler (F811); the values were identical.
* sm100_hd256_2cta_fmha_backward.py: both branches of the
  try/except ImportError imported the same two kernels once
  make_cotiled_copy/warp_reduction_sum were removed as unused; collapse
  to a single unconditional import.

Auto-fixes: 41 unused imports (F401) + 2 f-strings without placeholders
(F541) removed across sm100_hd256_2cta_fmha_{forward,backward,backward_dqkernel,
backward_dkdvkernel}.py, tile_scheduler.py, mask.py. ruff format
reformatted the 8 in-scope files touched by this PR.

Verified: `ruff check` and `ruff format --check` both clean on
flash_attn/cute/ (minus the pre-commit exclude list). Forward + varlen
smoke tests on B200 pass (150 passed, 35 skipped, 0 failed across
non-causal MHA, causal MHA, MQA/GQA, and varlen MHA at d=256).
Backward kernels not yet test-exercised; change is imports/whitespace
only and the kernels parse cleanly.

---------

Co-authored-by: Johnsonms <lizhaofu@gmail.com>
…Lab#2487)

Follow-up polish on the freshly-merged hd256 feature (Dao-AILab#2412), sourced
from Copilot AI review comments on the original PR.

interface.py: drop duplicate `from cutlass import Int32` (already imported
at line 17) and unused `from flash_attn.cute.mask import Sm100MaskEnum as
MaskEnum`, which is never referenced.

mask.py: remove two dead `tidx, tidy, tidx = cute.arch.thread_idx()` lines
in Sm100FusedMask.apply_mask and apply_mask_via_causal_local. Neither
`tidx` nor `tidy` is ever read in the function bodies; these calls are
leftover debug scaffolding (consistent with the commented-out
`cute.printf("tidx = ...")` lines nearby at 490/525/665).

test_flash_attn.py: drop the stray "/SM110" from two TODO comments. The
skip guard is `IS_SM100` only (capability major == 10), and the hd256
2CTA kernel path is only taken when `arch // 10 == 10` (interface.py:573,
1310), never on SM110 (major == 11).
)

- interface.py: remove extra space in is_deepseek_mla_absorbed_shape condition
- softmax.py: fix comment typo "my" -> "may" in apply_score_mod_inner
- mask.py: fix docstring typo "conventio" -> "convention" in backward mask
- flash_bwd.py: clarify why hdim_multiple_of=32 differs from fwd's 16

Co-authored-by: watt <watt@micous.com>
…ILab#2506)

* simplify blocksparse tensors interface in flash_attn_func

* remove blocksparse kwargs
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
stack-info: PR: Dao-AILab#2508, branch: drisspg/stack/35
Johnsonms and others added 11 commits April 27, 2026 15:11
…len} (Dao-AILab#2483)

Expand the two MLA-absorbed tests with a few low-risk axes and unblock
a stale nheads parametrize. Axes stay inside what the MLA-absorbed
kernel supports per the guards in `interface.py:633-648` (the existing
skips for softcap/learnable_sink/local/paged/fp8/split_kv/score_mod/
mask_mod under `qv` are untouched).

Non-varlen `test_flash_attn_mla_absorbed`:
  - Comment out the hardcoded `nheads = 128` override at line 1827
    so the `nheads [16, 128]` parametrize is actually respected. The
    override was unintentional and the `nheads=16` path is meant to
    check the kernel works under head sharding (it's already exercised
    that way in the varlen test). The existing
    `if kv_sparsity and nheads != 128: pytest.skip()` at line 1816
    is left in place — that's a separate intentional guard.
  - `gather_kv_length: [2048] -> [1024, 2048]` — exercise a second
    value on the `kv_sparsity=True` path.
  - Add seqlen `(4096, 4096)` — the table had `(2048, 2048)` and
    `(1, 8192)` but nothing with both Q and K in the mid-prefill
    range.
  - Assert `lse` is NaN-free. LSE is consumed by backward and by
    split-KV output combine; NaNs there aren't caught by the existing
    `out` diff check and silently corrupt any downstream use.

Varlen `test_flash_attn_mla_absorbed_varlen`:
  - `gather_kv_length: [2048] -> [1024, 2048]` (same rationale).
  - LSE NaN check, gated on `unpad_q`. The padded path can
    legitimately contain uninit tail beyond `seqused_q`, so the check
    only runs on the packed-unpad path where every LSE slot is live.

Context on a narrower scope than the first draft: tiny seqlens
((1,1), (1,3), (2,1)), `zero_lengths_q/k` with True, and
`add_unused_qkv=True` were tried but triggered order-of-magnitude
output divergence on MLA varlen (FA4 returning zeros where the
reference has real values; `Output max diff ~5.25` against a
`Pytorch max diff` of ~0.03). The combinations that fire it are
`add_unused_qkv=True` with short K (seqlen_k=1) or with
`kv_sparsity=True`, and `zero_lengths_k=True` across most unpad/
varlen_mode combos. Those look like real kernel issues orthogonal
to a coverage expansion, so they're left out of this PR and can be
filed separately.
stack-info: PR: Dao-AILab#2509, branch: drisspg/stack/36
…rformance gain) (Dao-AILab#2488)

* [hd256] Improve forward kernel with exp2 FMA emulation

Rebased cherry-pick of `e122e67` from `Johnsonms/exp2-emu-hd256` on top
of merged main (hd256 PR Dao-AILab#2412, `27b4eb9`). The original branch was
based on a pre-merge snapshot; the other five commits in that branch
were absorbed into the squash-merge, leaving this one novel change.

## Change

Replace a fraction of hardware `exp2` (SFU) instructions with a
polynomial FMA emulation (`ex2_emulation_2`) in the softmax P-tile
computation. The key insight: SM100's SFU throughput is a bottleneck
for hdim=256 due to the large tile size. By substituting 3 out of
every 4 `exp2` calls (`ex2_emu_freq=4`, `ex2_emu_res=3`) with packed
FMA polynomial approximation, we shift pressure onto the underutilized
FMA pipeline.

Additionally, the P write-slot acquisition is moved earlier to overlap
any pipeline stall with the `exp2` compute.

Kernel-only change; no API change. Backward is untouched.

## Validation (this PR vs `origin/main` @ `b21e204` — includes Dao-AILab#2412 hd256 base and Dao-AILab#2487 post-merge cleanup)

B200, bf16, hdim=256, MHA (32:32) and GQA (32:2), 3-run means, locked
clocks @ 1755 MHz, seqlens 4k..128k.

### FWD delta vs `origin/main` (TFLOPS mean, 3 runs each)

| seqlen | causal | MHA 32:32 | GQA 32:2 |
|-------:|:------:|:---------:|:--------:|
| 4k     |   F    |  -0.2%    |  +0.7%   |
| 8k     |   F    |  +0.4%    |  +0.3%   |
| 16k    |   F    |  +0.7%    |  -1.0%   |
| 32k    |   F    |  +2.3%    |  -0.3%   |
| 64k    |   F    |  **+5.4%** |  **+5.2%** |
| 128k   |   F    |  **+7.4%** |  **+7.3%** |
| 4k     |   T    |  +0.3%    |  +0.9%   |
| 8k     |   T    |  +0.8%    |  +0.9%   |
| 16k    |   T    |  +0.8%    |  +1.1%   |
| 32k    |   T    |  **+5.2%** |  -0.4%   |
| 64k    |   T    |  +1.1%    |  +1.1%   |
| 128k   |   T    |  **+3.4%** |  **+2.6%** |

- 19 of 24 cells positive; 4 slightly negative, all within the
  batch-quantization noise band that `origin/main` itself already
  showed in our 3-run regression sweep.
- Peak gain **+7.4%** (MHA 128k non-causal) — exactly where softmax
  SFU pressure is worst, consistent with the theory above.
- Averages: **MHA fwd +2.3%**, **GQA fwd +1.5%**.

### Correctness smoke

`pytest tests/cute/test_flash_attn.py::test_flash_attn_output -k
"256-False-0-0.0-False-False"` on B200:
**78 passed, 78 skipped, 0 failed** — identical pass/skip count to
`origin/main`.

## Caveat

Exp2 FMA emulation introduces small numerical differences vs hardware
`exp2`. The existing test tolerances accept the delta.

* [hd256] Wire ex2_emu params through _TUNING_CONFIG with tuned values

The exp2 emulation knobs (ex2_emu_freq, ex2_emu_res, ex2_emu_start_frg)
and softmax register counts for the hd256 forward kernel were hardcoded
in BlackwellFusedMultiHeadAttentionForward.__init__, invisible to the
central _TUNING_CONFIG table used by all other kernel configs.

- flash_fwd_sm100.py: add hd256 entries to _TUNING_CONFIG (causal and
  non-causal; always 2cta, no sm103 variant). New ex2_emu_res field is
  hd256-specific; existing entries are unaffected. hd256 uses a fixed
  num_regs_other=32 (not derived from the 512-budget formula).
- sm100_hd256_2cta_fmha_forward.py: replace hardcoded self.* assignments
  with a _TUNING_CONFIG lookup.

Tuned values (B200, bf16, locked clocks): freq=14, res=6, start_frg=0
for both causal and non-causal. The inner loop steps k by 2, so k%freq
only takes even values; freq=14/res=6 gives ~43% emulation (3 out of 7
even k%14 steps), replacing the previous 50:50 split (freq=4/res=3).
* SM90 FA4 QuACK 0.4 Compatibility

* Require QuACK>=0.4
Use a temporary response file when setuptools emits an oversized link.exe command so Windows builds with many object files can complete.

Made-with: Cursor
…AILab#2527)

flash_attn_varlen_func now takes qv as the 4th positional parameter, which
shifted cu_seqlens_q/k and max_seqlen_q/k by one slot in the existing
positional call and caused 132 AttributeError failures at interface.py:370
(`'int' object has no attribute 'shape'`). Switch the call to keyword args.
…AILab#2510)

* [Cute,Bwd,Sm90] Fix determinism for GQA, port Sm100 approach in

* add tests
…ao-AILab#2495)

- Add ex2_emu_res as a third sweep dimension for hd256 keys; skip Phase 2
  (num_regs_other is fixed for hd256).
- Upgrade clock handling to lock/unlock: setup_clocks() locks at startup
  and unlocks via atexit, with --lock-clocks/--no-lock-clocks flag.
- Fix nvidia-smi GPU targeting to respect CUDA_VISIBLE_DEVICES.
…ao-AILab#2497)

* [FA4][hd256] Coalesce LSE/dpsum per-K-iter loads in dkdv

Switch LSE and sum_OdO GMEM→SMEM loads from scatter indexing
(thread_idx*N + i) to warp-coalesced indexing (thread_idx + i*32).
Applied at all three load sites in the dkdv accumulation loop.

* [FA4][hd256] TMA bulk-store epilogue for dK/dV and dQ

dkdv: replace per-thread scattered GMEM stores with a cooperative
TMA bulk-store path. Both warp-groups write into a CTA-shared (64, 256)
SMEM staging tile aliased onto the dead sP+sdST buffers; WG 0 fires
4x(64, 64) cp.async.bulk stores. Per-thread store retained for varlen.

dq: single-stage TMA bulk store aliased onto the consumed sdO buffer.
Per-thread store retained for varlen.
…Lab#2489)

* [hd256] Add TMA paged KV support to SM100 2CTA forward kernel

Rebased cherry-pick of `49fe257` from `Johnsonms/paged-kv-hd256` on top
of merged main (hd256 PR Dao-AILab#2412, `27b4eb9` + post-merge cleanup Dao-AILab#2487,
`b21e204`). Original branch was based on a pre-merge snapshot; its
base commits were absorbed into the squash-merge.

## Change

Adds paged KV support to the SM100 hd256 2CTA forward kernel. The paged
path reuses the dense TMA load path — logical KV blocks are remapped to
physical page indices through the page table at load time, so each page
maps to exactly one TMA tile.

**Constraint:** `page_size` must equal `tile_n = 128`.

### `flash_attn/cute/sm100_hd256_2cta_fmha_forward.py`

- Conditional K/V tensor layout in `__call__`: dense
  `(s_k, d, ((h_r, h_k), b))` vs paged
  `(page_size, d, h_k, num_pages)` for K (and transposed for V).
- Conditional K/V TMA setup in the load warp: dense uses
  `domain_offset` + batch indexing; paged uses `head_kv` slicing and
  keeps `num_pages` as the outer mode for per-load `page_idx` lookup.
- Conditional per-load `page_idx`: K uses mode-2 subtile + mode-3 page;
  V uses mode-1 page.
- Plumb `mPageTable` + `max_seqlen_k` through the kernel signature.
  `seqlen_k` in each of the 4 warp sections now uses `max_seqlen_k`
  for the paged path.
- Store `qhead_per_kvhead` on `self` and derive `head_kv_coord` via
  integer divide (matches the `flash_fwd_sm100` convention for
  contiguous GQA grouping).
- Relax `mPageTable` / `paged_kv_non_tma` assertions.

### `flash_attn/cute/paged_kv.py`

- Extract `_flatten_smem_sm100` / `_copy_row_async` helpers from
  `load_KV` — pure refactor, no behavior change for existing callers.

### `tests/cute/test_flash_attn.py`

- `test_flash_attn_paged_hd256_sm100_tma`: bit-exact vs dense varlen
  reference + determinism check, parametrized over `seqlen_q`.
- `test_flash_attn_paged_hd256_sm100_tma_gqa`: same check for GQA with
  `nheads_kv in {2, 4, 8}` — exercises `qhead_per_kvhead > 1`, which
  a modulo-aliasing bug would fail.

## Validation (this PR vs `origin/main` @ `b21e204`)

### Correctness smoke

`pytest tests/cute/test_flash_attn.py` on B200, filter combines the
6 new paged tests with the existing d=256 dense subset:

```
-k "paged_hd256_sm100_tma or (test_flash_attn_output and 256-False-0-0.0-False-False)"
```

Result: **84 passed, 78 skipped, 0 failed** in 2 min — 78 from the
dense d=256 subset (identical pass/skip count to `origin/main`) and
**6 from the new `paged_hd256_sm100_tma[_gqa]` tests**.

### FWD perf delta vs `origin/main` (TFLOPS mean, 3 runs each)

B200, bf16, hdim=256, locked clocks @ 1755 MHz.

| seqlen | causal | MHA 32:32 | GQA 32:2 |
|-------:|:------:|:---------:|:--------:|
| 4k     |   F    |  +0.2%    |  +0.3%   |
| 8k     |   F    |   0.0%    |  -0.1%   |
| 16k    |   F    |  -0.2%    |   0.0%   |
| 32k    |   F    |  -1.0%    |  **+2.3%** |
| 64k    |   F    |  **+2.1%** |  +0.8%  |
| 128k   |   F    |  -0.8%    |  -1.7%   |
| 4k     |   T    |  +0.2%    |  +0.2%   |
| 8k     |   T    |  +0.1%    |  +0.1%   |
| 16k    |   T    |  +0.3%    |  +0.1%   |
| 32k    |   T    |  **+3.0%** |  -1.2%  |
| 64k    |   T    |  -0.3%    |  -0.1%   |
| 128k   |   T    |  +0.1%    |  -0.7%   |

- **22 of 24 cells within ±2%.**
- Two `> 2%` outliers are both **positive** and in the batch-
  quantization noise zone that `origin/main` itself showed run-spread
  in during our 3-run baseline sweep — not regressions.
- **Aggregated means: MHA +0.31%, GQA +0.00%.**
- Paged-KV path isn't exercised by `benchmark_attn.py` (which uses
  contiguous KV); dense-path perf parity is the regression-critical
  property and is preserved.

## Caveat

- **page_size == tile_n == 128 is a hard constraint.** Callers that
  want a different page size will need a separate path.
- The paged-KV path itself is correctness-tested by the two new
  `paged_hd256_sm100_tma` tests (bit-exact vs dense reference, with
  and without GQA). Perf of the paged path was not benchmarked.

* [hd256] Address review comments on TMA paged KV

- interface.py: assert max_seqlen_k % page_size == 0, page_table sized to
  exact seqlen, and page_table fully contiguous for hd256 paged path
- tests: add shuffled-page-table test; allclose for correctness checks
- paged_kv.py: trim _flatten_smem_sm100 docstring to one line
- sm100_hd256_2cta_fmha_forward.py: cut multi-line comment blocks

* [hd256] Prefetch page indices and eliminate redundant V page reads in TMA paged KV

K and V for the same KV block share the same physical page, so the
separate mPageTable read issued for V was always fetching the same
index already loaded for K.  Carry k_page_idx forward as
v_page_idx_prev and drop all V-side page-table reads.

Additionally, issue the next K page read immediately after K TMA
dispatch (while V TMA is being issued) so the ~25-cycle L2 latency
is hidden behind in-flight work.  Together these changes halve the
number of scalar GMEM page-table reads per kernel call.

NCU (B=4 S=8192 H=8 D=256):
  executed instructions  −0.4 %
  L2 elapsed cycles      −2.2 %  (overhead vs dense: +3.5 % → +1.2 %)

Benchmark — paged vs. dense latency overhead
GPU 0 locked 1965 MHz, non-causal, bf16, page_size=128:

  seqlen    B   before    after    delta
  ------   --   ------   ------   ------
    1024   32   +0.2 %   +0.4 %   −0.2 %
    2048   16   +0.4 %   +0.4 %    0.0 %
    4096    8   −0.1 %   +0.2 %   −0.3 %
    8192    4   +4.9 %   +1.8 %   −3.1 %
   16384    2   +7.7 %   +5.2 %   −2.5 %
   32768    1   +4.9 %   +0.4 %   −4.5 %
   65536    1   +0.9 %   −1.8 %   −2.7 %

No effect at short sequences (TMEM-bound); −2.5 to −4.5 % overhead
reduction at medium-to-long sequences where page-table reads were on
the producer warp's critical path.
@MatthewBonanni MatthewBonanni force-pushed the merge_upstream branch 2 times, most recently from 5b1f58d to f25ff07 Compare May 6, 2026 13:41
@LucasWilkinson LucasWilkinson merged commit bce2942 into vllm-project:main May 12, 2026
2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.