diff --git a/benchmarks/attention_benchmarks/benchmark.py b/benchmarks/attention_benchmarks/benchmark.py index 9860d4b2d1c2..46925ae9d761 100644 --- a/benchmarks/attention_benchmarks/benchmark.py +++ b/benchmarks/attention_benchmarks/benchmark.py @@ -509,6 +509,12 @@ def main(): "standalone static-FP8 quant, true = FA4 writes FP8 directly. " "Default: both.", ) + parser.add_argument( + "--fp8-output-pergroup", + action="store_true", + help="Use per-group (block, 1x128) dynamic FP8 for the FP8-output " + "comparison instead of per-tensor static.", + ) # Batch specifications parser.add_argument( @@ -647,6 +653,8 @@ def main(): args.fp8_output_scale = yaml_config.get("fp8_output_scale", None) if args.fuse_quant_op is None: args.fuse_quant_op = yaml_config.get("fuse_quant_op", None) + if not args.fp8_output_pergroup: + args.fp8_output_pergroup = yaml_config.get("fp8_output_pergroup", False) # Check for special modes args.mode = yaml_config.get("mode", None) @@ -823,10 +831,12 @@ def main(): if fp8_output_scale is not None: decode_backend = backends[0] fuse_variants = args.fuse_quant_op or [False, True] + pergroup = getattr(args, "fp8_output_pergroup", False) label_of = {False: "post_quant", True: "fused"} console.print( f"[yellow]FP8 output comparison @ scale={fp8_output_scale} " - f"(prefill=fa4, decode impl={decode_backend})[/]" + f"({'per-group' if pergroup else 'static'}, prefill=fa4, " + f"decode impl={decode_backend})[/]" ) fp8_results = [] total = len(fuse_variants) * len(args.batch_specs) @@ -842,15 +852,18 @@ def main(): num_kv_heads=args.num_kv_heads, block_size=args.block_size, device=args.device, - repeats=args.repeats, - warmup_iters=args.warmup_iters, + warmup_ms=args.warmup_ms, + ncu_profile=args.ncu_profile, profile_memory=args.profile_memory, kv_cache_dtype=args.kv_cache_dtype, use_cuda_graphs=args.cuda_graphs, prefill_backend="fa4", ) result = run_benchmark( - config, output_scale=fp8_output_scale, fuse_quant_op=fuse + config, + output_scale=fp8_output_scale, + fuse_quant_op=fuse, + output_pergroup=pergroup, ) label = label_of[fuse] labeled_config = replace(result.config, backend=label) diff --git a/benchmarks/attention_benchmarks/configs/mla_fa4_fp8_output_pergroup.yaml b/benchmarks/attention_benchmarks/configs/mla_fa4_fp8_output_pergroup.yaml new file mode 100644 index 000000000000..f3cd4d6127e2 --- /dev/null +++ b/benchmarks/attention_benchmarks/configs/mla_fa4_fp8_output_pergroup.yaml @@ -0,0 +1,45 @@ +# MLA prefill per-group FP8-output microbenchmark (FA4). +# Compares the fused per-group FP8 write against bf16 attention + a standalone +# per-group (block, 1x128) FP8 quant; the delta is the post-quant kernel the fused +# path removes. DeepSeek-V2-Lite dims; FA4 needs SM100/110. +# +# Usage: +# python benchmark.py --config configs/mla_fa4_fp8_output_pergroup.yaml + +description: "MLA prefill FA4 fused per-group FP8 output vs post-quant" + +model: + name: "deepseek-v2-lite" + num_layers: 27 + num_q_heads: 16 + num_kv_heads: 1 + head_dim: 576 + kv_lora_rank: 512 + qk_nope_head_dim: 128 + qk_rope_head_dim: 64 + v_head_dim: 128 + block_size: 128 + +# Pure prefill (q_len == kv_len) so every token goes through forward_mha. +batch_specs: + - "q512" + - "q1k" + - "q2k" + - "q4k" + - "q8k" + - "2q4k" + - "4q4k" + - "8q4k" + +# Only used to construct the MLA impl; the pure-prefill specs skip decode. +decode_backends: + - CUTLASS_MLA + +# Sweep the two FP8 write paths (prefill backend is fixed to fa4); per-group scales. +fp8_output_scale: 0.1 +fuse_quant_op: [false, true] +fp8_output_pergroup: true + +device: "cuda:0" +repeats: 50 +warmup_iters: 10 diff --git a/benchmarks/attention_benchmarks/mla_runner.py b/benchmarks/attention_benchmarks/mla_runner.py index c9b3fb29bb92..5be1e054d7a0 100644 --- a/benchmarks/attention_benchmarks/mla_runner.py +++ b/benchmarks/attention_benchmarks/mla_runner.py @@ -710,6 +710,7 @@ def _run_single_benchmark( kv_cache_dtype: str | None = None, output_scale: float | None = None, fuse_quant_op: bool = False, + output_pergroup: bool = False, ) -> BenchmarkResult: """ Run a single benchmark iteration. @@ -832,27 +833,32 @@ def _run_single_benchmark( ) # Prefill FP8 output: fused (kernel writes e4m3) vs separate post-quant. + # ``output_pergroup`` picks per-group (1x128) dynamic FP8 vs per-tensor static. prefill_fp8_output = None prefill_output_scale = None + prefill_output_scales = None prefill_quant_op = None if has_prefill and output_scale is not None: + from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 + from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.platforms import current_platform + out_t = prefill_inputs["output"] prefill_output_scale = torch.tensor( [output_scale], device=device, dtype=torch.float32 ) if fuse_quant_op: prefill_fp8_output = torch.empty_like( - prefill_inputs["output"], dtype=current_platform.fp8_dtype() + out_t, dtype=current_platform.fp8_dtype() ) + if output_pergroup: + num_groups = out_t.shape[-1] // 128 + prefill_output_scales = torch.empty( + out_t.shape[0], num_groups, dtype=torch.float32, device=device + ) + elif output_pergroup: + prefill_quant_op = QuantFP8(static=False, group_shape=GroupShape(1, 128)) else: - from vllm.model_executor.layers.quantization.input_quant_fp8 import ( - QuantFP8, - ) - from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, - ) - prefill_quant_op = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR) fused_output = output_scale is not None and fuse_quant_op @@ -871,14 +877,21 @@ def forward_fn(): metadata, prefill_inputs["k_scale"], prefill_fp8_output if fused_output else prefill_inputs["output"], - prefill_output_scale if fused_output else None, + output_scale=( + prefill_output_scale + if fused_output and not output_pergroup + else None + ), + output_scales=( + prefill_output_scales if fused_output and output_pergroup else None + ), ) if fused_output: out = prefill_fp8_output elif prefill_quant_op is not None: - out, _ = prefill_quant_op( - prefill_inputs["output"], prefill_output_scale - ) + # Per-group is dynamic (own scales); static takes the scale. + qargs = () if output_pergroup else (prefill_output_scale,) + out, _ = prefill_quant_op(prefill_inputs["output"], *qargs) results.append(out) return results[0] if len(results) == 1 else tuple(results) @@ -922,6 +935,7 @@ def _run_mla_benchmark_batched( prefill_backend: str | None = None, output_scale: float | None = None, fuse_quant_op: bool = False, + output_pergroup: bool = False, ) -> list[BenchmarkResult]: """ Unified batched MLA benchmark runner for all backends. @@ -1063,6 +1077,7 @@ def _run_mla_benchmark_batched( kv_cache_dtype=kv_cache_dtype, output_scale=output_scale, fuse_quant_op=fuse_quant_op, + output_pergroup=output_pergroup, ) results.append(result) @@ -1092,6 +1107,7 @@ def run_mla_benchmark( prefill_backend: str | None = None, output_scale: float | None = None, fuse_quant_op: bool = False, + output_pergroup: bool = False, ) -> BenchmarkResult | list[BenchmarkResult]: """ Unified MLA benchmark runner for all backends. @@ -1144,6 +1160,7 @@ def run_mla_benchmark( prefill_backend=prefill_backend, output_scale=output_scale, fuse_quant_op=fuse_quant_op, + output_pergroup=output_pergroup, ) # Return single result or list based on input diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index ea7ac544b9d7..56f31d08d8f6 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -38,8 +38,10 @@ if(VLLM_FLASH_ATTN_SRC_DIR) else() FetchContent_Declare( vllm-flash-attn - GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG 803020a8fa15407871341d41eba4919ade2ee1ee + # TEMP: point at carlyou/flash-attention per-group FP8 output branch to run CI + # against it pre-merge. Revert to vllm-project/flash-attention once the FA change lands. + GIT_REPOSITORY https://github.com/carlyou/flash-attention.git + GIT_TAG 3d1ac23b3eca212d37ef50c169e13bb0bd88adc0 GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn diff --git a/tests/compile/passes/test_mla_attn_quant_fusion.py b/tests/compile/passes/test_mla_attn_quant_fusion.py index 0a38ffca483a..4de6896ed79f 100644 --- a/tests/compile/passes/test_mla_attn_quant_fusion.py +++ b/tests/compile/passes/test_mla_attn_quant_fusion.py @@ -32,6 +32,9 @@ from vllm.model_executor.kernels.linear.scaled_mm.cutlass import ( CutlassFp8BlockScaledMMKernel, ) +from vllm.model_executor.kernels.linear.scaled_mm.triton import ( + TritonFp8BlockScaledMMKernel, +) from vllm.model_executor.layers.attention import MLAAttention from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.quantization.fp8 import Fp8Config @@ -46,6 +49,7 @@ ) from vllm.platforms import current_platform from vllm.v1.attention.backend import AttentionMetadata +from vllm.v1.attention.backends.mla.prefill.flash_attn import FlashAttnPrefillBackend from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.kv_cache_interface import MLAAttentionSpec @@ -131,18 +135,19 @@ def __init__( device=self.device, ) - def build_attn_metadata(self, batch_size: int) -> AttentionMetadata: + def build_attn_metadata( + self, batch_size: int, query_len: int = 1 + ) -> AttentionMetadata: """Initialize MLA attention metadata. - NOTE: Uses decode-only batch (query_len=1 per request). The prefill - (forward_mha) path is not separately tested here because it requires - FlashAttention availability and different input tensor shapes. The - quant logic in forward_impl is identical for both paths — it quantizes - the full output[:num_actual_toks] buffer after both forward_mha and - forward_mqa have written their results. + ``query_len == 1`` is a decode-only batch (forward_mqa). ``query_len > 1`` + is a pure-prefill batch (seq_len == query_len, no context) which routes + through forward_mha — needed to exercise the fused per-group output path. """ - batch_spec = BatchSpec(seq_lens=[1] * batch_size, query_lens=[1] * batch_size) + batch_spec = BatchSpec( + seq_lens=[query_len] * batch_size, query_lens=[query_len] * batch_size + ) common_attn_metadata = create_common_attn_metadata( batch_spec, self.block_size, self.device, arange_block_indices=True ) @@ -288,6 +293,8 @@ class TestMLAAttentionFp8GroupQuantPatternModel(MLAAttentionQuantPatternModel): is_checkpoint_fp8_serialized=True, weight_block_size=[128, 128], ) + # o_proj block-scaled MM kernel; subclasses override to change the scale layout. + block_kernel = CutlassFp8BlockScaledMMKernel def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -303,15 +310,15 @@ def __init__(self, *a, **kw): self.weight_block_size = [128, 128] super().__init__(*a, **kw) - # Force CutlassFp8BlockScaledMMKernel to ensure the graph uses - # per_token_group_fp8_quant (not the deepgemm packed variant). + # Force a block-scaled kernel that emits per_token_group_fp8_quant (not the + # deepgemm packed variant) so the fusion pattern matches. self.block_fp8_linear = _BlockFP8Layer( weight_shape=(self.output_dim, self.output_dim), activation_quant_key=self.quant_key, weight_quant_key=weight_quant_key, input_dtype=self.dtype, device=device, - force_kernel=CutlassFp8BlockScaledMMKernel, + force_kernel=self.block_kernel, ) w = kwargs.get("w") @@ -341,6 +348,16 @@ def forward( return self.block_fp8_linear(attn_output) +class TestMLAAttentionFp8GroupQuantPatternModelTriton( + TestMLAAttentionFp8GroupQuantPatternModel +): + """Per-group FP8 with the Triton block-scaled o_proj, which produces plain row-major + (non-ue8m0, non-col-major) scales. That layout satisfies the FA4 fused-output gate + (col == ue8m0 == tma), so the prefill forward_mha fused path engages.""" + + block_kernel = TritonFp8BlockScaledMMKernel + + def is_nvfp4_supported(): return current_platform.has_device_capability(100) @@ -585,3 +602,130 @@ def test_mla_attention_quant_pattern( # Check numerical correctness torch.testing.assert_close(result_unfused, result_fused, atol=1e-2, rtol=1e-2) + + +@pytest.mark.skipif( + not current_platform.is_cuda() or not current_platform.has_device_capability(100), + reason="FA4 fused per-group output requires Blackwell (SM100/SM110).", +) +@pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("query_len", [32]) +def test_mla_prefill_pergroup_fused_output( + batch_size: int, + query_len: int, + dist_init, + monkeypatch, + use_fresh_inductor_cache, +): + """Exercise the FA4 *prefill* per-group FP8 fused-output path (forward_mha). + + ``test_mla_attention_quant_pattern`` is decode-only (forward_mqa -> post-quant). A + prefill batch routes through ``forward_mha``, and the Triton block-scaled o_proj + (plain row-major scales) makes ``forward_impl``'s layout gate engage the fused path. + + Asserts (a) the fused path ran — the FA prefill call received ``output_scales`` — + and (b) the fused fp8 output is correct: dequantizing it matches a bf16 reference + attention on the same inputs within fp8 per-group tolerance. The downstream o_proj + GEMM's *consumption* of the scales is a separate concern that this synthetic harness + can't set up faithfully (the block-scaled GEMM needs real weight processing), so the + GEMM result itself is not asserted. + """ + monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") + + num_heads, qk_nope, qk_rope, v_head_dim, kv_lora = 16, 128, 64, 128, 512 + qk_head_dim = qk_nope + qk_rope + num_tokens = batch_size * query_len + dtype = torch.bfloat16 + device = torch.device(f"{DEVICE_TYPE}:0") + torch.set_default_dtype(dtype) + torch.manual_seed(42) + + model_config = ModelConfig( + model="deepseek-ai/DeepSeek-V3", max_model_len=2048, dtype=dtype + ) + vllm_config = VllmConfig( + model_config=model_config, + scheduler_config=SchedulerConfig( + max_num_seqs=1024, + max_model_len=model_config.max_model_len, + is_encoder_decoder=model_config.is_encoder_decoder, + ), + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, custom_ops=["+quant_fp8"] + ), + cache_config=CacheConfig(cache_dtype="auto"), + attention_config=AttentionConfig(backend=AttentionBackendEnum.TRITON_MLA), + ) + vllm_config.compilation_config.pass_config = PassConfig( + fuse_attn_quant=True, eliminate_noops=True + ) + + q = torch.randn(num_tokens, num_heads, qk_head_dim, dtype=dtype, device=device) + kv_c_normed = torch.randn(num_tokens, kv_lora, dtype=dtype, device=device) + k_pe = torch.randn(num_tokens, 1, qk_rope, dtype=dtype, device=device) + torch._dynamo.mark_dynamic(q, 0) + torch._dynamo.mark_dynamic(kv_c_normed, 0) + torch._dynamo.mark_dynamic(k_pe, 0) + + # Spy on the FA prefill call: confirm the fused path ran (output_scales passed) and + # that the fused fp8 output dequantizes to a bf16 reference attention (same inputs). + orig_run = FlashAttnPrefillBackend.run_prefill_new_tokens + fused_ran = False + fa_err: dict = {} + + def spy_run(self, *args, out=None, output_scale=None, output_scales=None, **kwargs): + nonlocal fused_ran + if output_scales is None: + return orig_run(self, *args, out=out, output_scale=output_scale, **kwargs) + fused_ran = True + bf16_ref = orig_run(self, *args, out=None, output_scale=None, **kwargs) + ret = orig_run( + self, + *args, + out=out, + output_scale=output_scale, + output_scales=output_scales, + **kwargs, + ) + deq = out.float() * output_scales.float() + fa_err["abs"] = (deq - bf16_ref.float()).abs().max().item() + fa_err["amax"] = bf16_ref.float().abs().max().item() + return ret + + monkeypatch.setattr(FlashAttnPrefillBackend, "run_prefill_new_tokens", spy_run) + + with ( + set_current_vllm_config(vllm_config), + set_forward_context(attn_metadata=None, vllm_config=vllm_config), + ): + model = TestMLAAttentionFp8GroupQuantPatternModelTriton( + num_heads=num_heads, + qk_nope_head_dim=qk_nope, + qk_rope_head_dim=qk_rope, + v_head_dim=v_head_dim, + kv_lora_rank=kv_lora, + kv_cache_dtype=dtype, + device=device, + vllm_config=vllm_config, + ).to(device) + model(q, kv_c_normed, k_pe) # HACK: warmup, see #131044 + get_forward_context().attn_metadata = model.build_attn_metadata( + batch_size, query_len + ) + test_backend = TestBackend( + NoOpEliminationPass(vllm_config), + LazyInitPass(MLAAttnQuantFusionPass, vllm_config), + PostCleanupPass(vllm_config), + ) + # forward_impl -> forward_mha (prefill) -> FA fused output -> o_proj GEMM. + torch.compile(model, backend=test_backend, fullgraph=True)(q, kv_c_normed, k_pe) + + assert fused_ran, ( + "fused per-group prefill path was not exercised: the FA prefill call never " + "received output_scales (forward_mha fell back to the post-quant path)" + ) + # fp8-e4m3 per-group quant: max abs error is a few % of the block amax. + assert fa_err["abs"] <= 0.1 * fa_err["amax"], ( + f"fused per-group fp8 output diverges from bf16 attention: " + f"max_abs={fa_err['abs']:.3f} vs amax={fa_err['amax']:.3f}" + ) diff --git a/tests/v1/attention/test_mla_prefill_quant_output.py b/tests/v1/attention/test_mla_prefill_quant_output.py index d7659485aa9f..6a3931856748 100644 --- a/tests/v1/attention/test_mla_prefill_quant_output.py +++ b/tests/v1/attention/test_mla_prefill_quant_output.py @@ -16,6 +16,7 @@ import torch from vllm.model_executor.layers.quantization.utils.quant_utils import ( + kFp8Dynamic64Sym, kFp8Dynamic128Sym, kFp8StaticTensorSym, kNvfp4Dynamic, @@ -63,20 +64,22 @@ def _make_fa_backend(version: int | None, is_vllm_fa: bool): @pytest.mark.parametrize( ("version", "is_vllm_fa", "dc_major", "quant_key", "expected"), [ - # FA4 + vLLM-FA + Blackwell SM100/SM110 + static FP8 -> fused. + # FA4 + vLLM-FA + Blackwell SM100/SM110 + static / per-group FP8 -> fused. (4, True, 10, kFp8StaticTensorSym, True), (4, True, 11, kFp8StaticTensorSym, True), + (4, True, 10, kFp8Dynamic128Sym, True), + (4, True, 11, kFp8Dynamic64Sym, True), # Wrong compute capability (SM90 / SM120) -> not supported (#135). (4, True, 9, kFp8StaticTensorSym, False), (4, True, 12, kFp8StaticTensorSym, False), + (4, True, 9, kFp8Dynamic128Sym, False), # Not FA4. (3, True, 10, kFp8StaticTensorSym, False), (2, True, 10, kFp8StaticTensorSym, False), (None, True, 10, kFp8StaticTensorSym, False), # Upstream (ROCm) flash-attn, not vLLM-FA. (4, False, 10, kFp8StaticTensorSym, False), - # Quant keys not wired through FA4 yet. - (4, True, 10, kFp8Dynamic128Sym, False), + # NVFP4 not wired through FA4. (4, True, 10, kNvfp4Dynamic, False), ], ) diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index 21e3215479fe..663210f659d3 100644 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -687,6 +687,13 @@ def forward_impl( num_mqa_tokens = attn_metadata.num_decode_tokens num_mha_tokens = q.size(0) - num_mqa_tokens + is_pergroup = quant_key in (kFp8Dynamic128Sym, kFp8Dynamic64Sym) + # FA derives UE8M0 from the column-major scale layout, so it can only write the + # two self-consistent layouts: DeepGEMM (col-major + ue8m0 + tma-aligned) or + # plain row-major fp32. + pergroup_layout_ok = ( + bool(quant_col_major) == bool(quant_scale_ue8m0) == bool(quant_tma_aligned) + ) mha_use_quant_output = ( quant_key is not None and self.prefill_backend.supports_quant_output(quant_key) @@ -694,15 +701,18 @@ def forward_impl( and attn_metadata.prefill is not None and attn_metadata.prefill.chunked_context is None and self.impl.dcp_world_size <= 1 + and (not is_pergroup or pergroup_layout_ok) ) if num_mha_tokens > 0: if mha_use_quant_output: mha_output = quant_output mha_output_scale = output_scale + mha_output_scales = output_block_scale else: mha_output = output mha_output_scale = None + mha_output_scales = None self.impl.forward_mha( # type: ignore[attr-defined] q[num_mqa_tokens:], @@ -713,6 +723,11 @@ def forward_impl( self._k_scale, output=mha_output[num_mqa_tokens:num_actual_toks], output_scale=mha_output_scale, + output_scales=( + mha_output_scales[num_mqa_tokens:num_actual_toks] + if mha_output_scales is not None + else None + ), ) if num_mqa_tokens > 0: @@ -2274,6 +2289,7 @@ def forward_mha( k_scale: torch.Tensor, output: torch.Tensor, output_scale: torch.Tensor | None = None, + output_scales: torch.Tensor | None = None, ) -> None: assert attn_metadata.prefill is not None assert self.dcp_world_size != -1 @@ -2287,7 +2303,8 @@ def forward_mha( q = q.to(prefill_metadata.q_data_type) has_context = prefill_metadata.chunked_context is not None - assert output_scale is None or not has_context, ( + fused_output = output_scale is not None or output_scales is not None + assert not fused_output or not has_context, ( "Fused FP8 output is only wired for the non-chunked-context path" ) @@ -2308,10 +2325,19 @@ def forward_mha( return_softmax_lse=has_context, out=( output.view(-1, self.num_heads, self.v_head_dim) - if output_scale is not None + if fused_output else None ), output_scale=output_scale, + # (tokens, heads*groups) -> (tokens, heads, groups); unflatten (not view) + # preserves the column-major / TMA-aligned DeepGEMM scale strides. + output_scales=( + output_scales.unflatten( + -1, (self.num_heads, output_scales.shape[-1] // self.num_heads) + ) + if output_scales is not None + else None + ), ) if has_context: @@ -2341,8 +2367,8 @@ def forward_mha( suffix_lse=suffix_lse, prefill_tokens_with_context=prefill_metadata.chunked_context.prefill_tokens_with_context, ) - elif output_scale is None: - # With output_scale set, backend already wrote into `output` in place. + elif not fused_output: + # With fused output, the backend already wrote `output` in place. assert isinstance(output_prefill, torch.Tensor) output_prefill = output_prefill.flatten(start_dim=-2) output.copy_(output_prefill) diff --git a/vllm/v1/attention/backends/mla/prefill/base.py b/vllm/v1/attention/backends/mla/prefill/base.py index ff478aec4ad1..49cea0bb41cd 100644 --- a/vllm/v1/attention/backends/mla/prefill/base.py +++ b/vllm/v1/attention/backends/mla/prefill/base.py @@ -135,6 +135,7 @@ def run_prefill_new_tokens( return_softmax_lse: bool, out: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, + output_scales: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError diff --git a/vllm/v1/attention/backends/mla/prefill/flash_attn.py b/vllm/v1/attention/backends/mla/prefill/flash_attn.py index 24763378e66b..25bbc26dc437 100644 --- a/vllm/v1/attention/backends/mla/prefill/flash_attn.py +++ b/vllm/v1/attention/backends/mla/prefill/flash_attn.py @@ -9,6 +9,8 @@ import vllm.envs as envs from vllm.model_executor.layers.quantization.utils.quant_utils import ( + kFp8Dynamic64Sym, + kFp8Dynamic128Sym, kFp8StaticTensorSym, ) from vllm.platforms import current_platform @@ -98,7 +100,7 @@ def supports_quant_output(self, quant_key: "QuantKey") -> bool: and self._is_vllm_fa and device_capability is not None and device_capability[0] in (10, 11) - and quant_key == kFp8StaticTensorSym + and quant_key in (kFp8StaticTensorSym, kFp8Dynamic128Sym, kFp8Dynamic64Sym) ) def _flash_attn_varlen_diff_headdims( @@ -110,6 +112,7 @@ def _flash_attn_varlen_diff_headdims( softmax_scale: float | None = None, out: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, + output_scales: torch.Tensor | None = None, **kwargs, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: maybe_padded_v = v @@ -122,11 +125,12 @@ def _flash_attn_varlen_diff_headdims( kwargs["return_softmax_lse"] = return_softmax_lse kwargs["out"] = out kwargs["output_scale"] = output_scale + kwargs["output_scales"] = output_scales else: # ROCm leverages the upstream flash_attn, which takes a parameter # called "return_attn_probs" instead of return_softmax_lse kwargs["return_attn_probs"] = return_softmax_lse - assert out is None and output_scale is None + assert out is None and output_scale is None and output_scales is None if envs.VLLM_BATCH_INVARIANT: kwargs["num_splits"] = 1 @@ -161,6 +165,7 @@ def run_prefill_new_tokens( return_softmax_lse: bool, out: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, + output_scales: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: return self._flash_attn_varlen_diff_headdims( q=q, @@ -175,6 +180,7 @@ def run_prefill_new_tokens( return_softmax_lse=return_softmax_lse, out=out, output_scale=output_scale, + output_scales=output_scales, ) def run_prefill_context_chunk( diff --git a/vllm/v1/attention/backends/mla/prefill/flashinfer.py b/vllm/v1/attention/backends/mla/prefill/flashinfer.py index 557c16f97f01..4cd1aace36ec 100644 --- a/vllm/v1/attention/backends/mla/prefill/flashinfer.py +++ b/vllm/v1/attention/backends/mla/prefill/flashinfer.py @@ -199,6 +199,7 @@ def run_prefill_new_tokens( return_softmax_lse: bool, out: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, + output_scales: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert self._prefill_main is not None diff --git a/vllm/v1/attention/backends/mla/prefill/tokenspeed_mla.py b/vllm/v1/attention/backends/mla/prefill/tokenspeed_mla.py index 1f041f37317f..e0fa478039d0 100644 --- a/vllm/v1/attention/backends/mla/prefill/tokenspeed_mla.py +++ b/vllm/v1/attention/backends/mla/prefill/tokenspeed_mla.py @@ -126,6 +126,7 @@ def run_prefill_new_tokens( return_softmax_lse: bool, out: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, + output_scales: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: from tokenspeed_mla import tokenspeed_mla_prefill diff --git a/vllm/v1/attention/backends/mla/prefill/trtllm_ragged.py b/vllm/v1/attention/backends/mla/prefill/trtllm_ragged.py index 90f721272dc9..1597db420f0c 100644 --- a/vllm/v1/attention/backends/mla/prefill/trtllm_ragged.py +++ b/vllm/v1/attention/backends/mla/prefill/trtllm_ragged.py @@ -99,6 +99,7 @@ def run_prefill_new_tokens( return_softmax_lse: bool, out: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, + output_scales: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: from flashinfer.prefill import trtllm_ragged_attention_deepseek diff --git a/vllm/vllm_flash_attn/flash_attn_interface.py b/vllm/vllm_flash_attn/flash_attn_interface.py index f0150033672e..832c05bb6f88 100644 --- a/vllm/vllm_flash_attn/flash_attn_interface.py +++ b/vllm/vllm_flash_attn/flash_attn_interface.py @@ -202,6 +202,7 @@ def flash_attn_varlen_func( num_splits: int = 0, # FA4 Only output_scale=None, + output_scales=None, # Version selector fa_version: int = DEFAULT_FA_VERSION, s_aux=None, @@ -275,8 +276,8 @@ def flash_attn_varlen_func( "seqused_k must be provided if block_table is provided" ) - assert output_scale is None or fa_version == 4, ( - f"Fused FP8 output (output_scale) is only supported by FA4, " + assert (output_scale is None and output_scales is None) or fa_version == 4, ( + f"Fused FP8 output (output_scale/output_scales) is only supported by FA4, " f"got fa_version={fa_version}" ) @@ -411,6 +412,8 @@ def flash_attn_varlen_func( mask_mod=mask_mod, aux_tensors=aux_tensors, output_scale=output_scale, + # Per-group scales are written in place into output_scales (input-only). + output_scales=output_scales, ) else: raise ValueError(f"Unsupported FA version: {fa_version}")