feat: MLA prefill per-group FP8 fused output #7
Closed
carlyou wants to merge 2 commits into
carlyou
commented
Jun 17, 2026
c6b9687 to e753d50
carlyou
commented
Jun 17, 2026
Comment on lines
+417
to
+420
# Per-group quant writes scales into output_scales in place and returns
# it as an extra value; the caller already holds that buffer.
if len(result) == 3:
    out, _, softmax_lse = result
Owner
Author
In this case, should we:
- always require the output_scales tensor as an input
- never return it, for a consistent response format?
This applies both here and to the interface on the FA side.
Owner
Author
🤖 Claude (Opus 4.8) — Done (both points): output_scales is now input-only and the return is a consistent (out, lse) on both sides. FA _flash_attn_fwd / flash_attn_func / flash_attn_varlen_func no longer append the scales (caller reads its preallocated buffer), which also removed the len(grad_outputs)==3 special-case in FA backward; the vLLM wrapper drops the len(result)==3 branch. Validated on B200: FA per-group 44/44, vLLM unit 20/20. (FA pin → 65525e7.)
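The resolved contract (scales are input-only, the return is always a two-tuple) can be illustrated with a minimal pure-Python sketch. Everything here is hypothetical: `attn_fwd_sketch` is a stand-in and not the FA or vLLM function, plain lists replace tensors, the "output" is only divided by the scale rather than cast to FP8, and 448 is assumed as the FP8 E4M3 maximum.

```python
import math
from typing import List, Optional, Tuple

FP8_E4M3_MAX = 448.0  # assumed: largest finite float8_e4m3fn value


def attn_fwd_sketch(
    values: List[float],
    group_size: int,
    output_scales: Optional[List[float]] = None,
) -> Tuple[List[float], float]:
    """Hypothetical stand-in for an attention forward call.

    Always returns (out, lse). If output_scales is given, one scale per
    group is written into that caller-owned buffer in place.
    """
    # Placeholder log-sum-exp so the second return value is meaningful.
    lse = math.log(sum(math.exp(v) for v in values))
    if output_scales is None:
        return list(values), lse

    num_groups = math.ceil(len(values) / group_size)
    if len(output_scales) != num_groups:
        raise ValueError("output_scales must hold one entry per group")

    out: List[float] = []
    for g in range(num_groups):
        group = values[g * group_size:(g + 1) * group_size]
        amax = max(abs(v) for v in group)
        scale = amax / FP8_E4M3_MAX if amax > 0.0 else 1.0
        output_scales[g] = scale          # written in place, never returned
        out.extend(v / scale for v in group)
    return out, lse


# The caller preallocates the buffer and reads it back after the call.
scales = [0.0, 0.0]
out, lse = attn_fwd_sketch([1.0, -2.0, 4.0, 0.5], group_size=2,
                           output_scales=scales)
```

Because the tuple shape no longer depends on the quantization mode, callers need no `len(result)` branch; they unpack two values in every case.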
e753d50 to c79c02f
e1a3b1d to 5302244
Wire per-group (block) dynamic FP8 fused attention output through the FA4 MLA prefill backend, the per-group analogue of the static fused-output path. When the o_proj activation quant is per-group FP8 (kFp8Dynamic128Sym / kFp8Dynamic64Sym), the prefill kernel writes FP8 directly with per-group scales instead of running a standalone post-quant kernel.

Scales are produced in the layout the downstream GEMM consumes: DeepGEMM's column-major / TMA-aligned UE8M0 (power-of-two) scales, or plain row-major fp32 for the Triton fallback. forward_impl only takes the fused path when the requested scale layout is self-consistent (col-major == ue8m0 == tma-aligned).

Temporarily pins vllm-flash-attn at the per-group FP8 output branch to run CI pre-merge; revert once the flash-attention change lands.

Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
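As a rough illustration of the two scale flavours and the layout gate described in this commit message, here is a pure-Python sketch. It is not the kernel or the vLLM code: the function names are hypothetical, 448 is assumed as the FP8 E4M3 maximum, and rounding the scale up to the next power of two is an assumption about how a UE8M0-style scale would be derived.

```python
import math
from typing import List

FP8_E4M3_MAX = 448.0  # assumed: largest finite float8_e4m3fn value


def per_group_scales(row: List[float], group_size: int,
                     power_of_two: bool) -> List[float]:
    """Compute one dynamic symmetric scale per group of `group_size` values.

    power_of_two=False gives plain fp32-style scales (amax / 448).
    power_of_two=True rounds each scale up to a power of two, so the
    scaled values still fit in range (assumed UE8M0-style behaviour).
    """
    if len(row) % group_size != 0:
        raise ValueError("row length must be a multiple of group_size")
    scales: List[float] = []
    for start in range(0, len(row), group_size):
        amax = max(abs(v) for v in row[start:start + group_size])
        scale = amax / FP8_E4M3_MAX if amax > 0.0 else 1.0
        if power_of_two:
            scale = 2.0 ** math.ceil(math.log2(scale))
        scales.append(scale)
    return scales


def scale_layout_is_consistent(col_major: bool, ue8m0: bool,
                               tma_aligned: bool) -> bool:
    """Hypothetical gate: all three layout flags must agree."""
    return col_major == ue8m0 == tma_aligned


row = [448.0, 1.0, 3.0, -3.0]
fp32_scales = per_group_scales(row, group_size=2, power_of_two=False)
pow2_scales = per_group_scales(row, group_size=2, power_of_two=True)
```

With these inputs the second group has amax 3.0, so its fp32 scale is 3/448 and its power-of-two scale is 2^-7. The gate accepts either all flags set (the DeepGEMM-style layout) or none set (plain row-major) and rejects any mix.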
5302244 to 045b1d0
Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
Warning
Draft / do not merge. Depends on the companion flash-attention change (vllm-project/flash-attention#151). This PR temporarily pins vllm_flash_attn at a fork branch to run CI against the new FA commit; the pin must be reverted to the upstream commit before merge.
Purpose
Wire per-group (block) dynamic FP8 fused attention output through the FA4 MLA prefill backend, the per-group analogue of the static fused-output path. When the o_proj activation quant is per-group FP8 (kFp8Dynamic128Sym / kFp8Dynamic64Sym), the prefill kernel writes FP8 directly with per-group scales (DeepGEMM column-major / UE8M0 / TMA-aligned, or plain row-major) instead of running a standalone post-quant kernel.

Test Plan
- tests/v1/attention/test_mla_prefill_quant_output.py: supports_quant_output gating + GPU equivalence.
- tests/compile/passes/test_mla_attn_quant_fusion.py: adds a prefill forward_mha per-group case that exercises the fused path (asserts the FA prefill call receives output_scales and the fused FP8 output matches a bf16 reference).
- benchmarks/attention_benchmarks: adds a per-group FP8-output mode (--fp8-output-pergroup + configs/mla_fa4_fp8_output_pergroup.yaml).

Test Result
B200 (SM100): test_mla_prefill_quant_output.py: 20 passed; test_mla_attn_quant_fusion.py: 9 passed (incl. the new prefill case).

Benchmark Result
benchmark.py --config configs/mla_fa4_fp8_output_pergroup.yaml on B200 (DeepSeek-V2-Lite dims, FA4 prefill + DeepGEMM E8M0). post_quant = bf16 attention + standalone per-group quant; fused = FA4 writes FP8 directly. The delta is the post-quant kernel the fused path removes (a small fraction of the full MLA prefill forward).

Essential Elements of an Effective PR Description Checklist