Skip to content

feat: MLA prefill per-group FP8 fused output#7

Closed
carlyou wants to merge 2 commits into
mainfrom
feat--fa4-pergroup-fp8-output
Closed

feat: MLA prefill per-group FP8 fused output#7
carlyou wants to merge 2 commits into
mainfrom
feat--fa4-pergroup-fp8-output

Conversation

@carlyou

@carlyou carlyou commented Jun 17, 2026

Copy link
Copy Markdown
Owner

Warning

Draft / do not merge. Depends on the companion flash-attention change
(vllm-project/flash-attention#151). This PR temporarily pins
vllm_flash_attn at a fork branch to run CI against the new FA commit; the pin must be
reverted to the upstream commit before merge.

Purpose

Wire per-group (block) dynamic FP8 fused attention output through the FA4 MLA prefill backend — the per-group analogue of the static fused-output path. When the o_proj activation quant is per-group FP8 (kFp8Dynamic128Sym / kFp8Dynamic64Sym), the prefill kernel writes FP8 directly with per-group scales (DeepGEMM column-major / UE8M0 / TMA-aligned, or plain row-major) instead of running a standalone post-quant kernel.

Test Plan

  • tests/v1/attention/test_mla_prefill_quant_output.pysupports_quant_output gating + GPU equivalence.
  • tests/compile/passes/test_mla_attn_quant_fusion.py — adds a prefill forward_mha per-group case that exercises the fused path (asserts the FA prefill call receives output_scales and the fused FP8 output matches a bf16 reference).
  • benchmarks/attention_benchmarks — adds a per-group FP8-output mode (--fp8-output-pergroup + configs/mla_fa4_fp8_output_pergroup.yaml).

Test Result

B200 (SM100): test_mla_prefill_quant_output.py 20 passed; test_mla_attn_quant_fusion.py 9 passed (incl. the new prefill case).

Benchmark Result

benchmark.py --config configs/mla_fa4_fp8_output_pergroup.yaml on B200 (DeepSeek-V2-Lite dims, FA4 prefill + DeepGEMM E8M0). post_quant = bf16 attention + standalone per-group quant; fused = FA4 writes FP8 directly. The delta is the post-quant kernel the fused path removes (a small fraction of the full MLA prefill forward).

                       Attention Benchmark Results
┏━━━━━━━┳━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━┓
┃ Batch ┃         ┃ Batch ┃ post_quant ┃ post_quant ┃    fused ┃   fused ┃
┃ Spec  ┃ Type    ┃  Size ┃   Time (s) ┃    vs Best ┃ Time (s) ┃ vs Best ┃
┡━━━━━━━╇━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━┩
│ q512  │ prefill │     1 │   0.000031 │     100.9% │ 0.000030 │  100.0% │
│ q1k   │ prefill │     1 │   0.000045 │     101.0% │ 0.000044 │  100.0% │
│ q2k   │ prefill │     1 │   0.000073 │     103.2% │ 0.000071 │  100.0% │
│ q4k   │ prefill │     1 │   0.000136 │     102.2% │ 0.000133 │  100.0% │
│ q8k   │ prefill │     1 │   0.000386 │     100.0% │ 0.000387 │  100.4% │
│ 2q4k  │ prefill │     2 │   0.000283 │     107.0% │ 0.000265 │  100.0% │
│ 4q4k  │ prefill │     4 │   0.000556 │     103.7% │ 0.000537 │  100.0% │
│ 8q4k  │ prefill │     8 │   0.001113 │     104.3% │ 0.001068 │  100.0% │
└───────┴─────────┴───────┴────────────┴────────────┴──────────┴─────────┘
Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR
  • The test plan
  • The test results

Comment thread cmake/external_projects/vllm_flash_attn.cmake Outdated
@carlyou carlyou force-pushed the feat--fa4-pergroup-fp8-output branch 3 times, most recently from c6b9687 to e753d50 Compare June 17, 2026 23:00
Comment thread vllm/model_executor/layers/attention/mla_attention.py
Comment on lines +417 to +420
# Per-group quant writes scales into output_scales in place and returns
# it as an extra value; the caller already holds that buffer.
if len(result) == 3:
out, _, softmax_lse = result

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in this case, should we:

  1. always require the output_scales tensor input
  2. never return it for a consistent response format?
    this applies to both here and interface on FA side.

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤖 Claude (Opus 4.8) — Done (both points): output_scales is now input-only and the return is a consistent (out, lse) on both sides. FA _flash_attn_fwd / flash_attn_func / flash_attn_varlen_func no longer append the scales (caller reads its preallocated buffer), which also removed the len(grad_outputs)==3 special-case in FA backward; the vLLM wrapper drops the len(result)==3 branch. Validated on B200: FA per-group 44/44, vLLM unit 20/20. (FA pin → 65525e7.)

@carlyou carlyou force-pushed the feat--fa4-pergroup-fp8-output branch from e753d50 to c79c02f Compare June 17, 2026 23:38
@carlyou carlyou changed the title ci: pin vllm-flash-attn at FA4 per-group FP8 output branch feat: MLA prefill per-group FP8 fused output Jun 17, 2026
@carlyou carlyou force-pushed the feat--fa4-pergroup-fp8-output branch 4 times, most recently from e1a3b1d to 5302244 Compare June 18, 2026 06:51
Wire per-group (block) dynamic FP8 fused attention output through the FA4
MLA prefill backend, the per-group analogue of the static fused-output path.
When the o_proj activation quant is per-group FP8 (kFp8Dynamic128Sym /
kFp8Dynamic64Sym), the prefill kernel writes FP8 directly with per-group
scales instead of running a standalone post-quant kernel.

Scales are produced in the layout the downstream GEMM consumes: DeepGEMM's
column-major / TMA-aligned UE8M0 (power-of-two) scales, or plain row-major
fp32 for the Triton fallback. forward_impl only takes the fused path when the
requested scale layout is self-consistent (col-major == ue8m0 == tma-aligned).

Temporarily pins vllm-flash-attn at the per-group FP8 output branch to run CI
pre-merge; revert once the flash-attention change lands.

Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
@carlyou carlyou force-pushed the feat--fa4-pergroup-fp8-output branch from 5302244 to 045b1d0 Compare June 18, 2026 07:00
Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
@carlyou carlyou closed this Jun 19, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant