Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions benchmarks/attention_benchmarks/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,12 @@ def main():
"standalone static-FP8 quant, true = FA4 writes FP8 directly. "
"Default: both.",
)
parser.add_argument(
"--fp8-output-pergroup",
action="store_true",
help="Use per-group (block, 1x128) dynamic FP8 for the FP8-output "
"comparison instead of per-tensor static.",
)

# Batch specifications
parser.add_argument(
Expand Down Expand Up @@ -647,6 +653,8 @@ def main():
args.fp8_output_scale = yaml_config.get("fp8_output_scale", None)
if args.fuse_quant_op is None:
args.fuse_quant_op = yaml_config.get("fuse_quant_op", None)
if not args.fp8_output_pergroup:
args.fp8_output_pergroup = yaml_config.get("fp8_output_pergroup", False)

# Check for special modes
args.mode = yaml_config.get("mode", None)
Expand Down Expand Up @@ -823,10 +831,12 @@ def main():
if fp8_output_scale is not None:
decode_backend = backends[0]
fuse_variants = args.fuse_quant_op or [False, True]
pergroup = getattr(args, "fp8_output_pergroup", False)
label_of = {False: "post_quant", True: "fused"}
console.print(
f"[yellow]FP8 output comparison @ scale={fp8_output_scale} "
f"(prefill=fa4, decode impl={decode_backend})[/]"
f"({'per-group' if pergroup else 'static'}, prefill=fa4, "
f"decode impl={decode_backend})[/]"
)
fp8_results = []
total = len(fuse_variants) * len(args.batch_specs)
Expand All @@ -842,15 +852,18 @@ def main():
num_kv_heads=args.num_kv_heads,
block_size=args.block_size,
device=args.device,
repeats=args.repeats,
warmup_iters=args.warmup_iters,
warmup_ms=args.warmup_ms,
ncu_profile=args.ncu_profile,
profile_memory=args.profile_memory,
kv_cache_dtype=args.kv_cache_dtype,
use_cuda_graphs=args.cuda_graphs,
prefill_backend="fa4",
)
result = run_benchmark(
config, output_scale=fp8_output_scale, fuse_quant_op=fuse
config,
output_scale=fp8_output_scale,
fuse_quant_op=fuse,
output_pergroup=pergroup,
)
label = label_of[fuse]
labeled_config = replace(result.config, backend=label)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# MLA prefill per-group FP8-output microbenchmark (FA4).
# Compares the fused per-group FP8 write against bf16 attention + a standalone
# per-group (block, 1x128) FP8 quant; the delta is the post-quant kernel the fused
# path removes. DeepSeek-V2-Lite dims; FA4 needs SM100/110.
#
# Usage:
# python benchmark.py --config configs/mla_fa4_fp8_output_pergroup.yaml

description: "MLA prefill FA4 fused per-group FP8 output vs post-quant"

model:
name: "deepseek-v2-lite"
num_layers: 27
num_q_heads: 16
num_kv_heads: 1
head_dim: 576
kv_lora_rank: 512
qk_nope_head_dim: 128
qk_rope_head_dim: 64
v_head_dim: 128
block_size: 128

# Pure prefill (q_len == kv_len) so every token goes through forward_mha.
batch_specs:
- "q512"
- "q1k"
- "q2k"
- "q4k"
- "q8k"
- "2q4k"
- "4q4k"
- "8q4k"

# Only used to construct the MLA impl; the pure-prefill specs skip decode.
decode_backends:
- CUTLASS_MLA

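# Sweep the two FP8 write paths (prefill backend is fixed to fa4). With
# fp8_output_pergroup enabled the prefill scales are computed dynamically per
# 1x128 group, so fp8_output_scale only switches the comparison on; its value
# is not applied to the prefill output.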
fp8_output_scale: 0.1
fuse_quant_op: [false, true]
fp8_output_pergroup: true

device: "cuda:0"
repeats: 50
warmup_iters: 10
41 changes: 29 additions & 12 deletions benchmarks/attention_benchmarks/mla_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,7 @@ def _run_single_benchmark(
kv_cache_dtype: str | None = None,
output_scale: float | None = None,
fuse_quant_op: bool = False,
output_pergroup: bool = False,
) -> BenchmarkResult:
"""
Run a single benchmark iteration.
Expand Down Expand Up @@ -832,27 +833,32 @@ def _run_single_benchmark(
)

# Prefill FP8 output: fused (kernel writes e4m3) vs separate post-quant.
# ``output_pergroup`` picks per-group (1x128) dynamic FP8 vs per-tensor static.
prefill_fp8_output = None
prefill_output_scale = None
prefill_output_scales = None
prefill_quant_op = None
if has_prefill and output_scale is not None:
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform

out_t = prefill_inputs["output"]
prefill_output_scale = torch.tensor(
[output_scale], device=device, dtype=torch.float32
)
if fuse_quant_op:
prefill_fp8_output = torch.empty_like(
prefill_inputs["output"], dtype=current_platform.fp8_dtype()
out_t, dtype=current_platform.fp8_dtype()
)
if output_pergroup:
num_groups = out_t.shape[-1] // 128
prefill_output_scales = torch.empty(
out_t.shape[0], num_groups, dtype=torch.float32, device=device
)
elif output_pergroup:
prefill_quant_op = QuantFP8(static=False, group_shape=GroupShape(1, 128))
else:
from vllm.model_executor.layers.quantization.input_quant_fp8 import (
QuantFP8,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
)

prefill_quant_op = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)

fused_output = output_scale is not None and fuse_quant_op
Expand All @@ -871,14 +877,21 @@ def forward_fn():
metadata,
prefill_inputs["k_scale"],
prefill_fp8_output if fused_output else prefill_inputs["output"],
prefill_output_scale if fused_output else None,
output_scale=(
prefill_output_scale
if fused_output and not output_pergroup
else None
),
output_scales=(
prefill_output_scales if fused_output and output_pergroup else None
),
)
if fused_output:
out = prefill_fp8_output
elif prefill_quant_op is not None:
out, _ = prefill_quant_op(
prefill_inputs["output"], prefill_output_scale
)
# Per-group quant is dynamic (the quant op computes its own scales), so no
# scale argument is passed; static per-tensor quant takes the precomputed scale.
qargs = () if output_pergroup else (prefill_output_scale,)
out, _ = prefill_quant_op(prefill_inputs["output"], *qargs)
results.append(out)
return results[0] if len(results) == 1 else tuple(results)

Expand Down Expand Up @@ -922,6 +935,7 @@ def _run_mla_benchmark_batched(
prefill_backend: str | None = None,
output_scale: float | None = None,
fuse_quant_op: bool = False,
output_pergroup: bool = False,
) -> list[BenchmarkResult]:
"""
Unified batched MLA benchmark runner for all backends.
Expand Down Expand Up @@ -1063,6 +1077,7 @@ def _run_mla_benchmark_batched(
kv_cache_dtype=kv_cache_dtype,
output_scale=output_scale,
fuse_quant_op=fuse_quant_op,
output_pergroup=output_pergroup,
)
results.append(result)

Expand Down Expand Up @@ -1092,6 +1107,7 @@ def run_mla_benchmark(
prefill_backend: str | None = None,
output_scale: float | None = None,
fuse_quant_op: bool = False,
output_pergroup: bool = False,
) -> BenchmarkResult | list[BenchmarkResult]:
"""
Unified MLA benchmark runner for all backends.
Expand Down Expand Up @@ -1144,6 +1160,7 @@ def run_mla_benchmark(
prefill_backend=prefill_backend,
output_scale=output_scale,
fuse_quant_op=fuse_quant_op,
output_pergroup=output_pergroup,
)

# Return single result or list based on input
Expand Down
6 changes: 4 additions & 2 deletions cmake/external_projects/vllm_flash_attn.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@ if(VLLM_FLASH_ATTN_SRC_DIR)
else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 803020a8fa15407871341d41eba4919ade2ee1ee
# TEMP: point at carlyou/flash-attention per-group FP8 output branch to run CI
# against it pre-merge. Revert to vllm-project/flash-attention once the FA change lands.
GIT_REPOSITORY https://github.com/carlyou/flash-attention.git
GIT_TAG 3d1ac23b3eca212d37ef50c169e13bb0bd88adc0
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
Expand Down
Loading
Loading