diff --git a/.github/actions/gpu-test/action.yml b/.github/actions/gpu-test/action.yml new file mode 100644 index 00000000000..e03680e41dc --- /dev/null +++ b/.github/actions/gpu-test/action.yml @@ -0,0 +1,67 @@ +name: GPU Test +description: Compile and run FA4 tests (pull SIF from Docker Hub, cache by tag) + +inputs: + test-filter: + description: pytest -k filter expression + required: false + default: "" + compile-workers: + description: parallel workers for Pass 1 kernel compilation + required: false + default: "64" + fa4_image_cu129: + description: Docker image for CUDA 12.9 (used when driver does not support CUDA 13.0) + required: true + fa4_image_cu130: + description: Docker image for CUDA 13.0 (used when driver supports CUDA 13.0) + required: true + +runs: + using: composite + steps: + - name: Select FA4 image based on CUDA version + shell: bash + run: | + # Read max supported CUDA version from nvidia-smi header, e.g. "CUDA Version: 12.9" + CUDA_VER=$(nvidia-smi | grep -oP "CUDA Version: \K[0-9]+\.[0-9]+") + CUDA_MAJOR=$(echo "$CUDA_VER" | cut -d. -f1) + echo "Detected max CUDA version: $CUDA_VER" + if [ "$CUDA_MAJOR" -ge 13 ]; then + echo "Using cu130 image" + echo "FA4_IMAGE=${{ inputs.fa4_image_cu130 }}" >> "$GITHUB_ENV" + else + echo "Using cu129 image" + echo "FA4_IMAGE=${{ inputs.fa4_image_cu129 }}" >> "$GITHUB_ENV" + fi + + - name: Pull FA4 SIF + shell: bash + run: | + CI_WORK_DIR="${CI_WORK_DIR:-/scratch/user/$USER}" + TAG=$(echo "$FA4_IMAGE" | tr '/: ' '---') + SIF="$CI_WORK_DIR/${TAG}.sif" + # Apptainer doesn't support tag@digest refs — strip the tag, keep digest only. + PULL_REF=$(echo "$FA4_IMAGE" | sed 's/:[^@]*@/@/') + echo "PULL_REF=$PULL_REF" + echo "SIF=$SIF" + mkdir -p "$CI_WORK_DIR/apptainer_cache" /tmp/apptainer_tmp + if [ ! -f "$SIF" ]; then + echo "Pulling $PULL_REF → $SIF" + APPTAINER_TMPDIR="/tmp/apptainer_tmp" \ + APPTAINER_CACHEDIR="$CI_WORK_DIR/apptainer_cache" \ + apptainer pull "$SIF" "docker://$PULL_REF" + else + echo "Using cached SIF: $SIF" + fi + # Remove stale SIFs from previous image versions to prevent unbounded disk growth. + find "$CI_WORK_DIR" -maxdepth 1 -name "*.sif" ! -name "$(basename "$SIF")" -delete + echo "FA4_SIF=$SIF" >> "$GITHUB_ENV" + + - name: Compile and run tests + shell: bash + run: | + python3 "$GITHUB_WORKSPACE/tools/ci/run_fa4_ci.py" \ + --repo-root "$GITHUB_WORKSPACE" \ + --test-filter "${{ inputs.test-filter }}" \ + --compile-workers "${{ inputs.compile-workers }}" diff --git a/.github/scripts/test_ci_local.sh b/.github/scripts/test_ci_local.sh new file mode 100755 index 00000000000..d767492d4e1 --- /dev/null +++ b/.github/scripts/test_ci_local.sh @@ -0,0 +1,8 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")/../../" && pwd) + +python3 "$SCRIPT_DIR/tools/ci/run_fa4_ci.py" \ + --repo-root "$SCRIPT_DIR" \ + "$@" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000000..a992b677a11 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,45 @@ +name: CI + +on: + push: + branches: [main, ci-fix] + +permissions: + contents: read + +env: + CI_WORK_DIR: ${{ vars.CI_WORK_DIR || format('/scratch/user/{0}', github.actor) }} + FA4_TEST_FILTER: "1024-1024-128-True-0-0.0-False-False-False-mha-dtype0 or 1024-1024-128-False-0-0.0-False-False-False-mha-dtype0" + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Install ruff + run: pip install ruff + - name: Ruff check + run: ruff check flash_attn/cute/ --extend-exclude "flash_attn/cute/flash_bwd.py,flash_attn/cute/flash_fwd.py,flash_attn/cute/flash_fwd_sm100.py,flash_attn/cute/interface.py" + - name: Ruff format + run: ruff format --check flash_attn/cute/ --exclude "flash_attn/cute/flash_bwd.py,flash_attn/cute/flash_fwd.py,flash_attn/cute/flash_fwd_sm100.py,flash_attn/cute/interface.py" + + fa4-correctness-and-benchmark: + strategy: + fail-fast: false + matrix: + gpu: [b200] + runs-on: [self-hosted, '${{ matrix.gpu }}'] + name: fa4-correctness-and-benchmark (${{ matrix.gpu }}) + timeout-minutes: 60 + steps: + - uses: actions/checkout@v4 + - name: Log in to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + - uses: ./.github/actions/gpu-test + with: + test-filter: ${{ env.FA4_TEST_FILTER }} + fa4_image_cu129: "togethercomputer/training-performance:flash-attn-cu12.9-26.03.25@sha256:304a5c3d2b3a75b151cd2a964cd26d444e0d8b5686d63943df13378c9705f943" + fa4_image_cu130: "togethercomputer/training-performance:flash-attn-cu13.0-26.04.01@sha256:56e50b056eb4d671410846c3483e843ee7bd0f5b13cb45b6f0d7eb8bd27694a5" diff --git a/.gitignore b/.gitignore index a4f1703b494..387a5f4535e 100644 --- a/.gitignore +++ b/.gitignore @@ -36,6 +36,8 @@ var/ # Dev venv +agent_space/ +benchmarks/results/ # compile-time generated file flash_attn_config.py diff --git a/AGENTS.md b/AGENTS.md new file mode 120000 index 00000000000..681311eb9cf --- /dev/null +++ b/AGENTS.md @@ -0,0 +1 @@ +CLAUDE.md \ No newline at end of file diff --git a/AI/CLC_TRACE_DEBUG.md b/AI/CLC_TRACE_DEBUG.md new file mode 100644 index 00000000000..9f1502aa57a --- /dev/null +++ b/AI/CLC_TRACE_DEBUG.md @@ -0,0 +1,82 @@ +# CLC Trace Debugging + +Use this when you suspect the CLC work scheduler is making surprising tile assignment decisions and you want a raw scheduler trace from the current kernel. + +## Current trace format + +SM100 forward kernels emit one trace line per scheduler-warp query at `FA_LOG_LEVEL=3`: + +```text +[CLC] query sm= cta= (m_blk=,h=,b=,s=) valid=<0|1> +``` + +Current emit sites: +- `flash_attn/cute/flash_fwd_sm100.py` +- `flash_attn/cute/flash_fwd_mla_sm100.py` + +## How to capture a trace + +Important: +- `FA_LOG_LEVEL=3` is needed for the `[CLC] query ...` device-side prints. +- `FA_CLC=1` only requests CLC; the kernel may still fall back if the shape/features disable it. + +Minimal repro pattern: + +```bash +FA_LOG_LEVEL=3 FA_CLC=1 CUDA_VISIBLE_DEVICES=0 python - <<'PY' \ + > agent_space/clc_trace.log 2>&1 +import torch +from flash_attn.cute.interface import flash_attn_func + +torch.manual_seed(0) +q = torch.randn(1, 512, 16, 128, device='cuda', dtype=torch.bfloat16) +k = torch.randn(1, 512, 1, 128, device='cuda', dtype=torch.bfloat16) +v = torch.randn(1, 512, 1, 128, device='cuda', dtype=torch.bfloat16) +flash_attn_func(q, k, v, causal=True) +torch.cuda.synchronize() +PY +``` + +If you want the run to say explicitly whether CLC was selected, keep the host log prefix too: + +```text +[FA] TileScheduler=SingleTileLPTScheduler, scheduling_mode=CLC, USE_2CTA=False +``` + +## What to look for + +- `scheduling_mode=CLC` in host logs confirms the shape actually used the CLC path. +- `valid=1` means the returned work tile is valid. +- `valid=0` means the scheduler is exhausted for that CTA/scheduler warp query. +- `m_blk`, `h`, `b`, `s` are the logical work coordinates after the scheduler mapping. +- `cta` is the physical `blockIdx.x`; for clustered launches multiple CTAs may participate in the same logical tile. + +## Parse the trace + +A lightweight parser lives in `AI/parse_clc_log.py`. + +Text summary: + +```bash +python AI/parse_clc_log.py agent_space/clc_trace.log +``` + +HTML view: + +```bash +python AI/parse_clc_log.py agent_space/clc_trace.log --html -o agent_space/clc_trace.html +``` + +## Suggested workflow + +1. Reproduce the surprising case with `FA_LOG_LEVEL=3 FA_CLC=1`. +2. Save stdout/stderr to `agent_space/clc_trace.log`. +3. Run `AI/parse_clc_log.py` on that log to get a compact per-SM / per-CTA summary. +4. If the trace still looks suspicious, attach or paste that log in the investigation thread / agent notes. +5. Compare against the relevant mapping logic in `flash_attn/cute/tile_scheduler.py`. + +## Caveats + +- The trace is noisy and expensive; use a single small shape first. +- Because the print happens on scheduler queries, many lines may be terminal `valid=0` queries after work is exhausted. +- Dense noncausal and varlen MHA may intentionally fall back away from CLC depending on the current heuristic in `flash_attn/cute/interface.py`. diff --git a/AI/parse_clc_log.py b/AI/parse_clc_log.py new file mode 100644 index 00000000000..c1b94543bf4 --- /dev/null +++ b/AI/parse_clc_log.py @@ -0,0 +1,396 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import html +import json +import re +import sys +from collections import Counter, defaultdict +from dataclasses import asdict, dataclass +from pathlib import Path + +TRACE_RE = re.compile( + r"\[CLC\]\s+query\s+sm=(?P\d+)\s+cta=(?P\d+)\s+" + r"\(m_blk=(?P-?\d+),h=(?P-?\d+),b=(?P-?\d+),s=(?P-?\d+)\)\s+" + r"valid=(?P[01])" +) + + +@dataclass(frozen=True) +class TraceRow: + sm: int + cta: int + m_blk: int + h: int + b: int + s: int + valid: int + + +def parse_rows(text: str) -> list[TraceRow]: + rows: list[TraceRow] = [] + for line in text.splitlines(): + match = TRACE_RE.search(line) + if match is None: + continue + rows.append(TraceRow(**{key: int(value) for key, value in match.groupdict().items()})) + return rows + + +def summarize(rows: list[TraceRow]) -> dict: + by_sm: dict[int, list[TraceRow]] = defaultdict(list) + by_cta: dict[int, list[TraceRow]] = defaultdict(list) + tile_counter: Counter[tuple[int, int, int, int, int]] = Counter() + for row in rows: + by_sm[row.sm].append(row) + by_cta[row.cta].append(row) + tile_counter[(row.m_blk, row.h, row.b, row.s, row.valid)] += 1 + + def encode_group(grouped: dict[int, list[TraceRow]]) -> dict[str, dict]: + out: dict[str, dict] = {} + for key, group_rows in sorted(grouped.items()): + out[str(key)] = { + "count": len(group_rows), + "valid_count": sum(row.valid for row in group_rows), + "invalid_count": sum(1 - row.valid for row in group_rows), + "first": asdict(group_rows[0]), + "last": asdict(group_rows[-1]), + "unique_tiles": len({(r.m_blk, r.h, r.b, r.s, r.valid) for r in group_rows}), + } + return out + + top_tiles = [ + { + "tile": { + "m_blk": tile[0], + "h": tile[1], + "b": tile[2], + "s": tile[3], + "valid": tile[4], + }, + "count": count, + } + for tile, count in tile_counter.most_common(20) + ] + + return { + "rows": len(rows), + "valid_rows": sum(row.valid for row in rows), + "invalid_rows": sum(1 - row.valid for row in rows), + "unique_sms": len(by_sm), + "unique_ctas": len(by_cta), + "by_sm": encode_group(by_sm), + "by_cta": encode_group(by_cta), + "top_tiles": top_tiles, + } + + +def format_summary(summary: dict) -> str: + lines = [ + f"rows={summary['rows']} valid={summary['valid_rows']} invalid={summary['invalid_rows']}", + f"unique_sms={summary['unique_sms']} unique_ctas={summary['unique_ctas']}", + "top_tiles:", + ] + for entry in summary["top_tiles"][:10]: + tile = entry["tile"] + lines.append( + f" count={entry['count']:>4} tile=(m_blk={tile['m_blk']}, h={tile['h']}, b={tile['b']}, s={tile['s']}, valid={tile['valid']})" + ) + lines.append("by_sm:") + for sm, sm_summary in summary["by_sm"].items(): + first = sm_summary["first"] + last = sm_summary["last"] + lines.append( + f" sm={sm:>3} count={sm_summary['count']:>4} valid={sm_summary['valid_count']:>4} invalid={sm_summary['invalid_count']:>4} " + f"first=(cta={first['cta']},m_blk={first['m_blk']},h={first['h']},b={first['b']},s={first['s']},v={first['valid']}) " + f"last=(cta={last['cta']},m_blk={last['m_blk']},h={last['h']},b={last['b']},s={last['s']},v={last['valid']})" + ) + return "\n".join(lines) + + +def visualize_html(rows: list[TraceRow], summary: dict) -> str: + by_sm: dict[int, list[TraceRow]] = defaultdict(list) + for row in rows: + by_sm[row.sm].append(row) + + data = [ + { + "sm": sm, + "tiles": [ + { + "id": r.m_blk, + "type": "INIT" if idx == 0 else "PULL", + "valid": bool(r.valid), + "m": r.m_blk, + "h": r.h, + "b": r.b, + "s": r.s, + "cta": r.cta, + } + for idx, r in enumerate(chain) + ], + } + for sm, chain in sorted(by_sm.items()) + ] + + total_tiles = sum(len(d["tiles"]) for d in data) + valid_pulls = sum(1 for d in data for t in d["tiles"] if t["type"] == "PULL" and t["valid"]) + work_per_sm = [sum(1 for t in d["tiles"] if t["valid"]) for d in data] + histogram = defaultdict(int) + for work in work_per_sm: + histogram[work] += 1 + histogram_data = [{"work": k, "count": v} for k, v in sorted(histogram.items())] + work_stats = { + "min": min(work_per_sm) if work_per_sm else 0, + "max": max(work_per_sm) if work_per_sm else 0, + "mean": (sum(work_per_sm) / len(work_per_sm)) if work_per_sm else 0.0, + "std": ( + sum((w - sum(work_per_sm) / len(work_per_sm)) ** 2 for w in work_per_sm) / len(work_per_sm) + ) ** 0.5 if work_per_sm else 0.0, + } + + return f""" + + + + +CLC Work Distribution Viewer + + + +

CLC Work Distribution Viewer

+
+ query-trace mode + SMs: {len(data)} + Total queries: {total_tiles} + Valid pulls: {valid_pulls} + Invalid queries: {summary['invalid_rows']} +
+
+ + + Press Esc to clear +
+
+
First query on SM
+
Later query / pull
+
Invalid / exhausted
+
+
+

Work Distribution Histogram min={work_stats['min']}, max={work_stats['max']}, mean={work_stats['mean']:.1f}, std={work_stats['std']:.2f}

+
+
+
+
+
+

SM

+
+
+
+ + + +""" + + +def read_text(path: str | None) -> str: + if path is None or path == "-": + return sys.stdin.read() + return Path(path).read_text() + + +def main() -> None: + parser = argparse.ArgumentParser(description="Parse FlashAttention CLC trace lines.") + parser.add_argument("logfile", nargs="?", default="-", help="Trace log path or - for stdin") + parser.add_argument("--json", action="store_true", help="Emit JSON summary") + parser.add_argument("--rows", action="store_true", help="Emit parsed rows as JSON") + parser.add_argument("--html", action="store_true", help="Emit HTML view") + parser.add_argument("-o", "--output", help="Output path for --html") + args = parser.parse_args() + + rows = parse_rows(read_text(args.logfile)) + if args.rows: + print(json.dumps([asdict(row) for row in rows], indent=2)) + return + + summary = summarize(rows) + if args.html: + html_text = visualize_html(rows, summary) + if args.output is not None: + Path(args.output).write_text(html_text) + else: + print(html_text) + return + if args.json: + print(json.dumps(summary, indent=2)) + else: + print(format_summary(summary)) + + +if __name__ == "__main__": + main() diff --git a/CLAUDE.md b/CLAUDE.md index 9f752d7e0e0..3b5f9672b77 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -8,6 +8,10 @@ FlashAttention-4 (FA4) — fast, memory-efficient exact attention kernels writte The repository also contains older generations (FA2 in top-level `csrc/`, FA3 in `hopper/`) but active development is on FA4 in `flash_attn/cute/`. +## Agent Scratch Space + +Use `agent_space/` for project-local scratch work such as lab notes, profiling outputs, temporary repro scripts, and experiment artifacts. Treat it as disposable workspace rather than product code. + ## Build & Install ```bash @@ -16,7 +20,7 @@ pip install flash-attn-4 pip install -e "flash_attn/cute[dev]" ``` -Dependencies: `nvidia-cutlass-dsl>=4.4.1`, `torch`, `einops`, `apache-tvm-ffi`, `quack-kernels>=0.2.10`. +Dependencies: `nvidia-cutlass-dsl>=4.4.1`, `torch`, `einops`, `apache-tvm-ffi`, `quack-kernels>=0.4.0`. ## Running Tests @@ -124,7 +128,7 @@ Env vars: `CUTE_CUBIN_PATH` (dump CUBIN/SASS), `CUTE_DSL_KEEP_PTX=1` (inspect PT ## Debugging GPU Kernels -See `AI/DEBUG_2CTA.md` for kernel hang/deadlock debugging (printf bisection, pipeline barrier analysis, 2CTA pitfalls). See `AI/RACECHECK_TMA_HAZARD.md` for `compute-sanitizer` false positives with `cp.async.bulk`. +See `AI/DEBUG_2CTA.md` for kernel hang/deadlock debugging (printf bisection, pipeline barrier analysis, 2CTA pitfalls). See `AI/RACECHECK_TMA_HAZARD.md` for `compute-sanitizer` false positives with `cp.async.bulk`. See `AI/CLC_TRACE_DEBUG.md` for visualization of CLC scheduling. Key tools: - `cute.printf` with thread guards (`tidx % 32 == 0`, `elect_one()`) for targeted output diff --git a/benchmarks/benchmark_attn.py b/benchmarks/benchmark_attn.py index 239dff46664..a8abdbd89d4 100644 --- a/benchmarks/benchmark_attn.py +++ b/benchmarks/benchmark_attn.py @@ -10,8 +10,12 @@ from einops import rearrange from flash_attn.cute.bench_utils import ( - flops, attention_ref, - cudnn_fwd_setup, cudnn_bwd_setup, + flops, + bandwidth_fwd_bytes, + bandwidth_bwd_bytes, + attention_ref, + cudnn_fwd_setup, + cudnn_bwd_setup, ) try: @@ -54,7 +58,8 @@ def _make_bwd_fn(fwd_fn, g, inputs): g_match = g[:out.shape[0]] if g.shape[0] != out.shape[0] else g # handle varlen def bwd_fn(): for x in inputs: - x.grad = None + if x is not None: + x.grad = None out.backward(g_match, retain_graph=True) return bwd_fn @@ -134,29 +139,34 @@ def setup_fa3(ctx): def setup_fa4(ctx): if flash_attn_func_python is None: return None, None - q, k, v, g, causal = ctx["q"], ctx["k"], ctx["v"], ctx["g"], ctx["causal"] + q, k, v, qv, g, causal = ctx["q"], ctx["k"], ctx["v"], ctx["qv"], ctx["g"], ctx["causal"] window_size, softcap = ctx["window_size"], ctx["softcap"] pack_gqa, deterministic = ctx["pack_gqa"], ctx["deterministic"] sinks = ctx["sinks"] + gather_kv_indices = ctx.get("gather_kv_indices") k_use = ctx.get("k_paged", k) if ctx["page_size"] is not None else k v_use = ctx.get("v_paged", v) if ctx["page_size"] is not None else v if ctx["varlen"]: qu = ctx["q_unpad"] ku = ctx.get("k_paged", ctx["k_unpad"]) if ctx["page_size"] is not None else ctx["k_unpad"] vu = ctx.get("v_paged", ctx["v_unpad"]) if ctx["page_size"] is not None else ctx["v_unpad"] + qvu = ctx["qv_unpad"] csq, csk = ctx["cu_seqlens_q"], ctx["cu_seqlens_k"] pt = ctx["page_table"] - fwd_fn = lambda: flash_attn_varlen_func_python(qu, ku, vu, csq, csk, page_table=pt, causal=causal, window_size=window_size, softcap=softcap, pack_gqa=pack_gqa) + gather_kv_indices_unpad = ctx.get("gather_kv_indices_unpad") + fwd_fn = lambda: flash_attn_varlen_func_python(qu, ku, vu, qvu, csq, csk, page_table=pt, causal=causal, window_size=window_size, softcap=softcap, pack_gqa=pack_gqa, gather_kv_indices=gather_kv_indices_unpad) else: - fwd_fn = lambda: flash_attn_func_python(q, k_use, v_use, causal=causal, window_size=window_size, learnable_sink=sinks, softcap=softcap, pack_gqa=pack_gqa) + fwd_fn = lambda: flash_attn_func_python(q, k_use, v_use, qv=qv, causal=causal, window_size=window_size, learnable_sink=sinks, softcap=softcap, pack_gqa=pack_gqa, gather_kv_indices=gather_kv_indices) bwd_fn = None if ctx["has_backward"] and ctx["dtype"] != torch.float8_e4m3fn: if ctx["varlen"]: qu, ku, vu, gu = ctx["q_unpad"], ctx["k_unpad"], ctx["v_unpad"], ctx["g_unpad"] + qvu = ctx["qv_unpad"] csq, csk = ctx["cu_seqlens_q"], ctx["cu_seqlens_k"] - bwd_fn = _make_bwd_fn(lambda: flash_attn_varlen_func_python(qu, ku, vu, csq, csk, causal=causal, softcap=softcap, deterministic=deterministic), gu, [qu, ku, vu]) + gather_kv_indices_unpad = ctx.get("gather_kv_indices_unpad") + bwd_fn = _make_bwd_fn(lambda: flash_attn_varlen_func_python(qu, ku, vu, qvu, csq, csk, causal=causal, softcap=softcap, deterministic=deterministic, gather_kv_indices=gather_kv_indices_unpad), gu, [qu, ku, vu, qvu]) else: - bwd_fn = _make_bwd_fn(lambda: flash_attn_func_python(q, k, v, causal=causal, softcap=softcap, deterministic=deterministic), g, [q, k, v]) + bwd_fn = _make_bwd_fn(lambda: flash_attn_func_python(q, k, v, qv=qv, causal=causal, softcap=softcap, deterministic=deterministic, gather_kv_indices=gather_kv_indices), g, [q, k, v, qv]) return fwd_fn, bwd_fn @@ -215,6 +225,33 @@ def get_peak_flops(device_index: int = 0, dtype: torch.dtype = torch.bfloat16) - return peak +def get_peak_bandwidth(device_index: int = 0) -> float | None: + """Return peak HBM bandwidth in bytes/sec for the given device. Returns None if unknown.""" + _PEAK_BW = { + # Ampere + "A100": 2.0e12, + "A6000": 0.768e12, + # Ada Lovelace + "L40S": 0.864e12, + # Hopper + "H100 SXM": 3.35e12, + "H100 NVL": 3.35e12, + "H100 PCIe": 2.0e12, + "H200": 4.8e12, + "H20": 4.0e12, + # Blackwell + "GB200": 8.0e12, + "GB300": 8.0e12, + "B300": 8.0e12, + "B200": 8.0e12, + } + device_name = torch.cuda.get_device_name(device_index) + for key in sorted(_PEAK_BW, key=len, reverse=True): + if key.lower() in device_name.lower(): + return _PEAK_BW[key] + return None + + def parse_int_k(s): """Parse an integer with optional k/K suffix, e.g. '8k' -> 8192.""" s = s.strip().lower() @@ -263,11 +300,16 @@ def parse_args(): parser.add_argument('--causal', type=str.lower, choices=['true', 'false', 'both'], default='both', help='Causal mode (default: both)') parser.add_argument('--seqlen', type=csv_ints, default=[8192], - help='Sequence length(s), comma-separated. Supports k suffix, e.g. 1k,2k,8k') + help='KV sequence length(s), comma-separated. Supports k suffix, e.g. 1k,2k,8k') + parser.add_argument('--seqlen-q', type=csv_ints, default=None, + help='Q sequence length(s), comma-separated. Supports k suffix, e.g. 1,128,1k. ' + 'Defaults to matching --seqlen (i.e. seqlen_q == seqlen_kv). ' + 'If a single value is given it is broadcast across all --seqlen values; ' + 'otherwise the list must match the length of --seqlen or be length 1.') parser.add_argument('--total-seqlen', type=parse_int_k, default='32k', help='Total sequence length for batch sizing (default: 32k)') parser.add_argument('--batch-size', type=int, default=None, - help='Batch size (default: total_seqlen // seqlen)') + help='Batch size (default: total_seqlen // seqlen_kv)') parser.add_argument('--deterministic', action='store_true', default=False) parser.add_argument('--nheads', type=int, default=None, help='Number of Q heads (default: 32 for hdim<=64, 16 for hdim<=192, 8 for hdim>192)') @@ -277,6 +319,17 @@ def parse_args(): help='GQA ratio (nheads // nheads_kv). Ignored if --nheads-kv is set.') parser.add_argument('--backend', type=csv_strs, default=['all'], help='Which backends to benchmark, comma-separated (choices: all,standard,fa2,fa3,fa4,cudnn)') + parser.add_argument('--gather-kv', type=int, default=None, + help='kv sparsity length for MLA (hdim=64, hdim_v=512 only). ' + 'When set, passes random kv indices (without repeats) to FA4 and uses gather-kv as ' + 'the effective KV length for flops/bandwidth accounting.') + parser.add_argument('--num-splits', type=int, default=0, + help='Override kernel num_splits heuristic. 0 = auto (default). ' + '>1 forces SplitKV with that many splits.') + parser.add_argument('--pack-gqa', type=str.lower, choices=['auto', 'true', 'false'], default='auto', + help='Override kernel pack_gqa heuristic. auto = kernel default (default). ' + 'true/false force pack_gqa on or off. Note: for non-MLA GQA with ' + 'num_splits>1, interface.py may still force pack_gqa off pending a fix.') parser.add_argument('--warmup', type=int, default=5, help='Warmup iterations (default: 5)') parser.add_argument('--rep', type=int, default=10, @@ -284,6 +337,27 @@ def parse_args(): return parser.parse_args() +def resolve_seqlen_q_list(seqlen_list, seqlen_q_arg): + """Return a list of seqlen_q values parallel to seqlen_list. + + Rules: + - None → seqlen_q == seqlen_kv for every entry + - single value [x] → broadcast x across all entries + - matching length → use as-is + - mismatch → raise + """ + if seqlen_q_arg is None: + return list(seqlen_list) + if len(seqlen_q_arg) == 1: + return [seqlen_q_arg[0]] * len(seqlen_list) + if len(seqlen_q_arg) != len(seqlen_list): + raise ValueError( + f"--seqlen-q has {len(seqlen_q_arg)} values but --seqlen has " + f"{len(seqlen_list)}; they must match (or pass a single value to broadcast)." + ) + return list(seqlen_q_arg) + + def main(): args = parse_args() @@ -302,7 +376,9 @@ def main(): causal_vals = [False, True] seqlen_list = args.seqlen + seqlen_q_list = resolve_seqlen_q_list(seqlen_list, args.seqlen_q) varlen = args.varlen + gather_kv_length = args.gather_kv # Filter backends to those requested and available enabled = set(args.backend) @@ -317,6 +393,7 @@ def main(): dtype_gen = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype device = 'cuda' peak_flops = get_peak_flops(0, dtype=dtype) + peak_bw = get_peak_bandwidth(0) page_size = None softcap = 0.0 deterministic = args.deterministic @@ -336,25 +413,25 @@ def main(): has_qv = headdim == 64 and headdim_v == 512 sinks = None - num_splits = 0 + num_splits = args.num_splits window_size = (None, None) window_size_fa = (-1, -1) - pack_gqa = None + pack_gqa = None if args.pack_gqa == 'auto' else (args.pack_gqa == 'true') - for seqlen in seqlen_list: + for seqlen, seqlen_q in zip(seqlen_list, seqlen_q_list): batch_size = args.batch_size if args.batch_size is not None else max(1, args.total_seqlen // seqlen) - seqlen_q = seqlen q = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=has_backward) k = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype_gen, requires_grad=has_backward) v = torch.randn(batch_size, seqlen, nheads_kv, headdim_v, device=device, dtype=dtype_gen, requires_grad=has_backward) - q, k, v = [x.detach().to(dtype).requires_grad_(has_backward) for x in [q, k, v]] + qv = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen, requires_grad=has_backward) if has_qv else None + q, k, v, qv = [x.detach().to(dtype).requires_grad_(has_backward) if x is not None else None for x in [q, k, v, qv]] g = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen) # Varlen tensors - q_unpad = k_unpad = v_unpad = g_unpad = cu_seqlens_q = cu_seqlens_k = None + q_unpad = k_unpad = v_unpad = qv_unpad = g_unpad = cu_seqlens_q = cu_seqlens_k = None if varlen: - q_unpad, k_unpad, v_unpad = [rearrange(x.detach(), "b s h d -> (b s) h d").requires_grad_(has_backward) for x in [q, k, v]] + q_unpad, k_unpad, v_unpad, qv_unpad = [rearrange(x.detach(), "b s h d -> (b s) h d").requires_grad_(has_backward) if x is not None else None for x in [q, k, v, qv]] g_unpad = rearrange(g.detach(), "b s h d -> (b s) h d") cu_seqlens_q = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen_q cu_seqlens_k = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen if page_size is None else None @@ -367,33 +444,50 @@ def main(): page_table = rearrange(torch.arange(batch_size * seqlen // page_size, device=device, dtype=torch.int32), "(b s) -> b s", s=seqlen // page_size) + # kv sparsity indices — only meaningful for MLA (hdim=64, hdim_v=512) + gather_kv_indices = gather_kv_indices_unpad = None + gather_kv_eff = None # effective KV length for this config + if gather_kv_length is not None and has_qv: + assert gather_kv_length <= seqlen, f"--gather_kv {gather_kv_length} > seqlen_kv {seqlen}" + gather_kv_indices = ( + torch.rand(batch_size, seqlen_q, gather_kv_length, device=device) + .argsort(dim=-1) + .to(torch.int32) + ) + if varlen: + gather_kv_indices_unpad = rearrange(gather_kv_indices, "b s t -> (b s) t") + gather_kv_eff = gather_kv_length + for causal in causal_vals: - cfg = (headdim, headdim_v, causal, seqlen, batch_size, nheads) + cfg = (headdim, headdim_v, causal, seqlen_q, seqlen, batch_size, nheads, nheads_kv, gather_kv_eff) # Build context dict shared by all backends ctx = dict( - q=q, k=k, v=v, g=g, causal=causal, + q=q, k=k, v=v, qv=qv, g=g, causal=causal, headdim=headdim, headdim_v=headdim_v, dtype=dtype, has_backward=has_backward, - varlen=varlen, q_unpad=q_unpad, k_unpad=k_unpad, v_unpad=v_unpad, g_unpad=g_unpad, + varlen=varlen, q_unpad=q_unpad, k_unpad=k_unpad, v_unpad=v_unpad, qv_unpad=qv_unpad, g_unpad=g_unpad, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, seqlen_q=seqlen_q, seqlen=seqlen, page_size=page_size, k_paged=k_paged, v_paged=v_paged, page_table=page_table, dropout_p=dropout_p, window_size=window_size, window_size_fa=window_size_fa, softcap=softcap, deterministic=deterministic, num_splits=num_splits, pack_gqa=pack_gqa, sinks=sinks, + gather_kv_indices=gather_kv_indices, gather_kv_indices_unpad=gather_kv_indices_unpad, ) for display_name, cli_name, setup_fn in active_backends: fwd_fn, bwd_fn = setup_fn(ctx) if fwd_fn is not None and has_forward: time.sleep(1.0) - print(f"Benchmarking {display_name} fwd, hdim={headdim}, seqlen={seqlen}, causal={causal}, {nheads=}, {nheads_kv=}") + gather_kv_str = f", gather_kv={gather_kv_eff}" if gather_kv_eff is not None else "" + print(f"Benchmarking {display_name} fwd, hdim={headdim}, seqlen_q={seqlen_q}, seqlen_kv={seqlen}{gather_kv_str}, causal={causal}, {nheads=}, {nheads_kv=}") ms = do_bench(fwd_fn, warmup=warmup, rep=rep) * 1e-3 time_f[cfg, display_name] = ms if bwd_fn is not None and has_backward: time.sleep(1.0) - print(f"Benchmarking {display_name} bwd, hdim={headdim}, seqlen={seqlen}, causal={causal}, {nheads=}, {nheads_kv=}, {deterministic=}") + gather_kv_str = f", gather_kv={gather_kv_eff}" if gather_kv_eff is not None else "" + print(f"Benchmarking {display_name} bwd, hdim={headdim}, seqlen_q={seqlen_q}, seqlen_kv={seqlen}{gather_kv_str}, causal={causal}, {nheads=}, {nheads_kv=}, {deterministic=}") ms = do_bench(bwd_fn, warmup=warmup, rep=rep) * 1e-3 time_b[cfg, display_name] = ms @@ -404,7 +498,13 @@ def main(): if not shown_backends: return - col_w = 20 if peak_flops is not None else 16 + col_w = 28 if peak_flops is not None else 16 + + # Determine whether any config has seqlen_q != seqlen_kv so we can show + # the extra column only when it's actually useful. + all_cfgs = sorted(set(k[0] for k in list(time_f) + list(time_b))) + show_seqlen_q_col = any(cfg[3] != cfg[4] for cfg in all_cfgs) # seqlen_q vs seqlen_kv + show_gather_kv_col = any(cfg[8] is not None for cfg in all_cfgs) for direction, times, flops_mult in [("FWD", time_f, 1.0), ("BWD", time_b, 2.5)]: if not times: @@ -413,8 +513,14 @@ def main(): if not configs: continue - col_label = "ms / TFLOPS / MFU%" if peak_flops is not None else "ms / TFLOPS" - header = f"{'hdim':>9} {'causal':>6} {'batch':>5} {'seqlen':>6}" + col_label = "ms / TFLOPS / MFU% / TBs / BW%" if peak_flops is not None else "ms / TFLOPS" + header = f"{'hdim':>9} {'causal':>6} {'batch':>5}" + if show_seqlen_q_col: + header += f" {'seqlen_q':>8} {'seqlen_kv':>9}" + else: + header += f" {'seqlen':>6}" + if show_gather_kv_col: + header += f" {'gather_kv':>6}" for b in shown_backends: header += f" {b:>{col_w}}" print(f"\n{'=' * len(header)}") @@ -424,18 +530,36 @@ def main(): print("-" * len(header)) for cfg in configs: - headdim, headdim_v, causal, seqlen, batch_size, nheads = cfg - nFLOPS = flops(batch_size, nheads, seqlen, seqlen, headdim, headdim_v, causal=causal) + headdim, headdim_v, causal, seqlen_q, seqlen, batch_size, nheads, nheads_kv, gather_kv_eff = cfg + has_qv = (headdim == 64 and headdim_v == 512) + seqlen_k_eff = gather_kv_eff if gather_kv_eff is not None else seqlen + nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen_k_eff, headdim, headdim_v, causal=causal, has_qv=has_qv) + dtype_bytes = 1 if dtype == torch.float8_e4m3fn else 2 + if direction == "FWD": + nbytes = bandwidth_fwd_bytes(batch_size, nheads, nheads_kv, seqlen_q, seqlen_k_eff, + headdim, headdim_v, dtype_bytes=dtype_bytes, has_qv=has_qv) + else: + nbytes = bandwidth_bwd_bytes(batch_size, nheads, nheads_kv, seqlen_q, seqlen_k_eff, + headdim, headdim_v, dtype_bytes=dtype_bytes) hdim_str = str(headdim) if headdim == headdim_v else f"{headdim}-{headdim_v}" - row = f"{hdim_str:>9} {str(causal):>6} {batch_size:>5} {seqlen:>6}" + row = f"{hdim_str:>9} {str(causal):>6} {batch_size:>5}" + if show_seqlen_q_col: + row += f" {seqlen_q:>8} {seqlen:>9}" + else: + row += f" {seqlen:>6}" + if show_gather_kv_col: + row += f" {str(gather_kv_eff) if gather_kv_eff is not None else '—':>6}" for b in shown_backends: t = times.get((cfg, b)) if t is not None: tflops = flops_mult * nFLOPS / t * 1e-12 + tbs = nbytes / t * 1e-12 # TB/s ms = t * 1e3 if peak_flops is not None: - mfu = flops_mult * nFLOPS / t / peak_flops * 100 - cell = f"{ms:.2f}/{tflops:.0f}/{mfu:.1f}%" + mfu = flops_mult * nFLOPS / t / peak_flops * 100 + bw_pct = (nbytes / t / peak_bw * 100) if peak_bw is not None else None + bw_str = f"/{bw_pct:.1f}%" if bw_pct is not None else "" + cell = f"{ms:.2f}/{tflops:.0f}/{mfu:.1f}%/{tbs:.2f}{bw_str}" else: cell = f"{ms:.2f}/{tflops:.0f}" row += f" {cell:>{col_w}}" @@ -445,4 +569,4 @@ def main(): if __name__ == '__main__': - main() + main() \ No newline at end of file diff --git a/benchmarks/clc_bench.py b/benchmarks/clc_bench.py new file mode 100644 index 00000000000..46ee55980eb --- /dev/null +++ b/benchmarks/clc_bench.py @@ -0,0 +1,808 @@ +#!/usr/bin/env python3 +"""CLC benchmark for dense, varlen, and block-sparse FA4 sweeps. + +Run with benchmark against the yaml sweep: + python benchmarks/clc_bench.py --config benchmarks/configs/clc.yaml + +Useful overrides: + --workers 64 # compile parallelism + --case_filter q16_kv4 # run matching cases only +""" +from __future__ import annotations + +import csv +import json +import math +import os +import statistics +import subprocess +import sys +import types +from contextlib import nullcontext +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import asdict, dataclass, field +from datetime import datetime, timezone +from itertools import product +from pathlib import Path +from typing import Literal + +try: + from jsonargparse import CLI +except ImportError as exc: + raise SystemExit( + "Missing jsonargparse. Install it with " + "uv pip install jsonargparse pyyaml" + ) from exc + + +REPO_ROOT = Path(__file__).resolve().parents[1] +RESULTS_ROOT = REPO_ROOT / "benchmarks" / "results" / "clc" +CSV_FLOAT_DIGITS = 3 +BLOCK_SIZE_Q = 256 +BLOCK_SIZE_K = 128 +INTERNAL_REQUEST_ENV = "CLC_BENCH_INTERNAL_REQUEST" + +DTypeName = Literal["bfloat16", "float16"] + + +@dataclass(frozen=True) +class DenseSweep: + enabled: bool = True + batches: list[int] = field(default_factory=lambda: [1, 4, 8, 16, 32]) + seqlen_pairs: list[list[int]] = field( + default_factory=lambda: [[32, 8192], [1024, 1024], [2048, 2048], [4096, 4096], [8192, 8192], [16384, 16384]] + ) + head_dims: list[int | list[int]] = field(default_factory=lambda: [64, 96, 128, [192, 128]]) + head_pairs: list[list[int]] = field(default_factory=lambda: [[16, 16], [16, 8], [16, 4], [16, 2], [16, 1]]) + causal: bool | list[bool] = True + + +@dataclass(frozen=True) +class VarlenSweep: + enabled: bool = True + max_q_tokens: list[int] = field(default_factory=lambda: [2048, 4096, 8192, 16384, 32768]) + max_kv_tokens: list[int] = field(default_factory=lambda: [2048, 4096, 8192, 16384, 32768]) + batches: list[int] = field(default_factory=lambda: [4, 8, 16, 32]) + patterns: list[str] = field(default_factory=lambda: ["uniform", "longtail"]) + head_dims: list[int | list[int]] = field(default_factory=lambda: [64, 96, 128, [192, 128]]) + head_pairs: list[list[int]] = field(default_factory=lambda: [[16, 8], [16, 4], [16, 2], [16, 1]]) + causal: bool | list[bool] = False + + +@dataclass(frozen=True) +class BlockSparseSweep: + enabled: bool = False + batches: list[int] = field(default_factory=lambda: [1, 4, 8, 16, 32]) + seqlen_pairs: list[list[int]] = field( + default_factory=lambda: [[1024, 1024], [2048, 2048], [4096, 4096], [4096, 8192]] + ) + head_dims: list[int | list[int]] = field(default_factory=lambda: [64, 128, [192, 128]]) + head_pairs: list[list[int]] = field(default_factory=lambda: [[16, 16], [16, 4], [16, 1]]) + mask_names: list[str] = field(default_factory=lambda: ["block_diagonal"]) + sliding_window_sizes: list[int] = field(default_factory=lambda: [2048]) + + +@dataclass(frozen=True) +class Case: + name: str + mode: Literal["dense", "varlen", "block_sparse"] + q_heads: int + kv_heads: int + d: int + dv: int + causal: bool + batch: int | None = None + seqlen_q: int | None = None + seqlen_k: int | None = None + seqlens_q: list[int] | None = None + seqlens_k: list[int] | None = None + pattern: str = "" + mask_name: str = "" + window_size: int | None = None + + +def utc_timestamp() -> str: + return datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") + + +def default_out_dir() -> Path: + return RESULTS_ROOT / utc_timestamp() + + +def head_pair_label(q_heads: int, kv_heads: int) -> str: + return f"q{q_heads}_kv{kv_heads}" + + +def token_label(value: int) -> str: + return f"{value // 1024}k" if value >= 1024 and value % 1024 == 0 else str(value) + + +def head_dim_label(d: int, dv: int) -> str: + return f"h{d}" if d == dv else f"h{d}_dv{dv}" + + +def head_dim_pairs(head_dims: list[int | list[int]]) -> list[tuple[int, int]]: + pairs: list[tuple[int, int]] = [] + invalid_pairs: list[int | list[int]] = [] + for dims in head_dims: + if isinstance(dims, int): + pairs.append((dims, dims)) + continue + if len(dims) == 1: + pairs.append((dims[0], dims[0])) + continue + if len(dims) == 2: + pairs.append((dims[0], dims[1])) + continue + invalid_pairs.append(dims) + if invalid_pairs: + raise ValueError(f"Expected d or [d] or [d, dv], got {invalid_pairs}") + return pairs + + +def dense_case_name(q_heads: int, kv_heads: int, causal: bool, d: int, dv: int, batch: int, seqlen_q: int, seqlen_k: int) -> str: + causal_name = "causal" if causal else "noncausal" + pair = head_pair_label(q_heads, kv_heads) + dims = head_dim_label(d, dv) + if seqlen_q == seqlen_k: + return f"{pair}_{causal_name}_{dims}_{token_label(seqlen_q)}_b{batch}" + return f"{pair}_{causal_name}_q{seqlen_q}_k{seqlen_k}_{dims}_b{batch}" + + +def varlen_case_name( + pattern: str, + q_heads: int, + kv_heads: int, + causal: bool, + d: int, + dv: int, + batch: int, + max_q_tokens: int, + max_kv_tokens: int, +) -> str: + causal_name = "causal" if causal else "noncausal" + pair = head_pair_label(q_heads, kv_heads) + dims = head_dim_label(d, dv) + return ( + f"varlen_{pattern}_{pair}_{causal_name}_{dims}_" + f"b{batch}_q{token_label(max_q_tokens)}_kv{token_label(max_kv_tokens)}" + ) + + +def normalize_lengths(weights: list[float], total_tokens: int) -> list[int]: + if total_tokens < len(weights): + raise ValueError(f"total_tokens={total_tokens} is smaller than batch={len(weights)}") + scaled = [weight / sum(weights) * total_tokens for weight in weights] + lengths = [max(1, int(math.floor(value))) for value in scaled] + delta = total_tokens - sum(lengths) + order = sorted( + range(len(weights)), + key=lambda idx: scaled[idx] - math.floor(scaled[idx]), + reverse=delta > 0, + ) + cursor = 0 + while delta != 0: + idx = order[cursor % len(order)] + if delta > 0: + lengths[idx] += 1 + delta -= 1 + elif lengths[idx] > 1: + lengths[idx] -= 1 + delta += 1 + cursor += 1 + return lengths + + +def pattern_weights(pattern: str, batch: int) -> list[float]: + match pattern: + case "uniform": + return [1.0] * batch + case "spiky": + return [32.0] + [1.0] * (batch - 1) + case "longtail": + return [float(batch - idx) for idx in range(batch)] + case "bimodal": + split = max(1, batch // 2) + return [8.0] * split + [1.0] * (batch - split) + case "staircase": + return [1.0 + idx for idx in range(batch)] + case "loss_shape": + base = [130, 1, 1, 1, 1674, 68, 157, 1, 1, 1, 1, 1, 1, 9, 1, 5] + if batch == len(base): + return [float(value) for value in base] + return [float(base[idx % len(base)]) for idx in range(batch)] + case _: + raise ValueError(f"Unsupported varlen pattern: {pattern}") + + +def bool_values(value: bool | list[bool]) -> list[bool]: + return [value] if isinstance(value, bool) else value + + +def generate_cases( + dense: DenseSweep, + varlen: VarlenSweep, + block_sparse: BlockSparseSweep, + case_filter: str = "", +) -> list[Case]: + cases: list[Case] = [] + if dense.enabled: + for batch, seqlen_pair, (d, dv), (q_heads, kv_heads), causal in product( + dense.batches, + dense.seqlen_pairs, + head_dim_pairs(dense.head_dims), + dense.head_pairs, + bool_values(dense.causal), + ): + seqlen_q, seqlen_k = seqlen_pair + cases.append( + Case( + name=dense_case_name(q_heads, kv_heads, causal, d, dv, batch, seqlen_q, seqlen_k), + mode="dense", + q_heads=q_heads, + kv_heads=kv_heads, + d=d, + dv=dv, + causal=causal, + batch=batch, + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + ) + ) + if varlen.enabled: + for max_q_tokens, max_kv_tokens, batch, pattern, (d, dv), (q_heads, kv_heads), causal in product( + varlen.max_q_tokens, + varlen.max_kv_tokens, + varlen.batches, + varlen.patterns, + head_dim_pairs(varlen.head_dims), + varlen.head_pairs, + bool_values(varlen.causal), + ): + weights = pattern_weights(pattern, batch) + lengths_q = normalize_lengths(weights, max_q_tokens) + lengths_k = normalize_lengths(weights, max(batch, max_kv_tokens)) + cases.append( + Case( + name=varlen_case_name(pattern, q_heads, kv_heads, causal, d, dv, batch, max_q_tokens, max_kv_tokens), + mode="varlen", + q_heads=q_heads, + kv_heads=kv_heads, + d=d, + dv=dv, + causal=causal, + batch=batch, + seqlens_q=lengths_q, + seqlens_k=lengths_k, + pattern=pattern, + ) + ) + if block_sparse.enabled: + for batch, seqlen_pair, (d, dv), (q_heads, kv_heads), mask_name in product( + block_sparse.batches, + block_sparse.seqlen_pairs, + head_dim_pairs(block_sparse.head_dims), + block_sparse.head_pairs, + block_sparse.mask_names, + ): + seqlen_q, seqlen_k = seqlen_pair + if seqlen_q > seqlen_k: + continue + window_sizes = block_sparse.sliding_window_sizes if mask_name == "sliding_window" else [None] + for window_size in window_sizes: + window_label = f"_w{window_size}" if window_size is not None else "" + pair = head_pair_label(q_heads, kv_heads) + dims = head_dim_label(d, dv) + cases.append( + Case( + name=( + f"block_sparse_{mask_name}{window_label}_{pair}_" + f"{dims}_q{seqlen_q}_k{seqlen_k}_b{batch}" + ), + mode="block_sparse", + q_heads=q_heads, + kv_heads=kv_heads, + d=d, + dv=dv, + causal=False, + batch=batch, + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + mask_name=mask_name, + window_size=window_size, + ) + ) + if case_filter: + needle = case_filter.lower() + cases = [case for case in cases if needle in case.name.lower()] + return cases + + +def compile_q_stage(case: Case) -> int: + max_seqlen_q = max(case.seqlens_q) if case.seqlens_q is not None else case.seqlen_q + qhead_per_kvhead = case.q_heads // case.kv_heads + return 2 if max_seqlen_q is not None and max_seqlen_q * qhead_per_kvhead > 128 else 1 + + + +def compile_signature(case: Case) -> tuple: + q_stage = compile_q_stage(case) + if case.mode == "block_sparse": + return ( + case.mode, + case.q_heads, + case.kv_heads, + case.d, + case.dv, + case.mask_name, + case.window_size, + q_stage, + ) + return case.mode, case.q_heads, case.kv_heads, case.d, case.dv, case.causal, q_stage + + +def select_compile_cases(cases: list[Case]) -> list[Case]: + selected: dict[tuple, Case] = {} + for case in cases: + selected.setdefault(compile_signature(case), case) + return list(selected.values()) + + +def benchmark_cuda_samples_in_microseconds(func, *args, **kwargs) -> list[float]: + num_iters = kwargs.pop("NUM_ITERS", 100) + warmup_iters = kwargs.pop("MEMORY_WARMUP_ITERS", 25) + is_vetted_benchmarking = kwargs.pop("IS_VETTED_BENCHMARKING", False) + from torch._inductor.runtime.benchmarking import benchmarker + + return [ + float(sample_ms) * 1e3 + for sample_ms in benchmarker.benchmark_gpu( + lambda: func(*args, **kwargs), + benchmark_iters=num_iters, + memory_warmup_iters=warmup_iters, + return_mode="all", + is_vetted_benchmarking=is_vetted_benchmarking, + ) + ] + + +def flash_attn_imports(): + if "flash_attn" not in sys.modules: + stub = types.ModuleType("flash_attn") + stub.__path__ = [str(REPO_ROOT / "flash_attn")] + sys.modules["flash_attn"] = stub + import torch + from torch._subclasses.fake_tensor import FakeTensorMode + from flash_attn.cute import utils as cute_utils + from flash_attn.cute.interface import flash_attn_func, flash_attn_varlen_func + + return torch, FakeTensorMode, cute_utils, flash_attn_func, flash_attn_varlen_func + + +def block_sparse_imports(): + if "flash_attn" not in sys.modules: + stub = types.ModuleType("flash_attn") + stub.__path__ = [str(REPO_ROOT / "flash_attn")] + sys.modules["flash_attn"] = stub + if str(REPO_ROOT / "tests" / "cute") not in sys.path: + sys.path.insert(0, str(REPO_ROOT / "tests" / "cute")) + from flash_attn.cute.compute_block_sparsity import compute_block_sparsity + from mask_mod_definitions import get_mask_pair + + return compute_block_sparsity, get_mask_pair + + +def build_cu_seqlens(torch_mod, lengths: list[int]) -> torch_mod.Tensor: + cu_seqlens = torch_mod.zeros(len(lengths) + 1, device="cuda", dtype=torch_mod.int32) + cu_seqlens[1:] = torch_mod.tensor(lengths, device="cuda", dtype=torch_mod.int32).cumsum(0) + return cu_seqlens + + +def build_dense_inputs(torch_mod, flash_attn_func, case: Case, dtype, factory): + q = factory(case.batch, case.seqlen_q, case.q_heads, case.d, device="cuda", dtype=dtype) + k = factory(case.batch, case.seqlen_k, case.kv_heads, case.d, device="cuda", dtype=dtype) + v = factory(case.batch, case.seqlen_k, case.kv_heads, case.dv, device="cuda", dtype=dtype) + return flash_attn_func, dict(q=q, k=k, v=v, causal=case.causal) + + +def build_varlen_inputs(torch_mod, flash_attn_varlen_func, case: Case, dtype, factory): + lengths_q = case.seqlens_q or [] + lengths_k = case.seqlens_k or lengths_q + total_q = sum(lengths_q) + total_k = sum(lengths_k) + q = factory(total_q, case.q_heads, case.d, device="cuda", dtype=dtype) + k = factory(total_k, case.kv_heads, case.d, device="cuda", dtype=dtype) + v = factory(total_k, case.kv_heads, case.dv, device="cuda", dtype=dtype) + return flash_attn_varlen_func, dict( + q=q, + k=k, + v=v, + cu_seqlens_q=build_cu_seqlens(torch_mod, lengths_q), + cu_seqlens_k=build_cu_seqlens(torch_mod, lengths_k), + max_seqlen_q=max(lengths_q), + max_seqlen_k=max(lengths_k), + causal=case.causal, + ) + + +def build_block_sparse_compile_tensors(torch_mod, case: Case): + num_m_blocks = math.ceil((case.seqlen_q or 0) / BLOCK_SIZE_Q) + count_shape = (case.batch, 1, num_m_blocks) + index_shape = (*count_shape, 1) + return dict( + mask_block_cnt=torch_mod.zeros(count_shape, device="cuda", dtype=torch_mod.int32), + mask_block_idx=torch_mod.zeros(index_shape, device="cuda", dtype=torch_mod.int32), + full_block_cnt=torch_mod.zeros(count_shape, device="cuda", dtype=torch_mod.int32), + full_block_idx=torch_mod.zeros(index_shape, device="cuda", dtype=torch_mod.int32), + ) + + + +def build_block_sparse_inputs(torch_mod, flash_attn_func, case: Case, dtype, tensor_factory, fake_tensor: bool): + compute_block_sparsity, get_mask_pair = block_sparse_imports() + if case.mask_name in {"document", "ima"}: + raise ValueError(f"Aux-backed block-sparse masks are not supported by clc_bench.py: {case.mask_name}") + q = tensor_factory(case.batch, case.seqlen_q, case.q_heads, case.d, device="cuda", dtype=dtype) + k = tensor_factory(case.batch, case.seqlen_k, case.kv_heads, case.d, device="cuda", dtype=dtype) + v = tensor_factory(case.batch, case.seqlen_k, case.kv_heads, case.dv, device="cuda", dtype=dtype) + cute_mask, _ = get_mask_pair( + case.mask_name, + seqlen_q=case.seqlen_q, + seqlen_k=case.seqlen_k, + window_size=case.window_size, + ) + if fake_tensor: + block_sparse_tensors = build_block_sparse_compile_tensors(torch_mod, case) + else: + _, sparse_tensors = compute_block_sparsity( + tile_m=BLOCK_SIZE_Q, + tile_n=BLOCK_SIZE_K, + batch_size=case.batch, + num_heads=1, + seqlen_q=case.seqlen_q, + seqlen_k=case.seqlen_k, + mask_mod=cute_mask, + aux_tensors=None, + device="cuda", + compute_full_blocks=True, + use_fast_sampling=False, + ) + block_sparse_tensors = dict( + mask_block_cnt=sparse_tensors.mask_block_cnt, + mask_block_idx=sparse_tensors.mask_block_idx, + full_block_cnt=sparse_tensors.full_block_cnt, + full_block_idx=sparse_tensors.full_block_idx, + ) + return flash_attn_func, dict( + q=q, + k=k, + v=v, + causal=False, + mask_mod=cute_mask, + **block_sparse_tensors, + block_size=(BLOCK_SIZE_Q, BLOCK_SIZE_K), + ) + + +def build_inputs(case: Case, dtype_name: DTypeName, fake_tensor: bool): + torch, FakeTensorMode, _, flash_attn_func, flash_attn_varlen_func = flash_attn_imports() + dtype = getattr(torch, dtype_name) + tensor_factory = torch.empty if fake_tensor else torch.randn + context = FakeTensorMode() if fake_tensor else nullcontext() + with context: + match case.mode: + case "block_sparse": + return build_block_sparse_inputs(torch, flash_attn_func, case, dtype, tensor_factory, fake_tensor) + case "dense": + return build_dense_inputs(torch, flash_attn_func, case, dtype, tensor_factory) + case "varlen": + return build_varlen_inputs(torch, flash_attn_varlen_func, case, dtype, tensor_factory) + + +def attended_pairs(seqlen_q: int, seqlen_k: int, causal: bool) -> float: + """Lower-right aligned causal: last query aligns with last key. + When M > N, only the bottom N query rows attend (triangle of size N), + so valid pairs = N*(N+1)/2, not the upper-left formula M*N - N*(N-1)/2. + """ + if not causal: + return float(seqlen_q * seqlen_k) + if seqlen_q <= seqlen_k: + return float(seqlen_q * (2 * seqlen_k - seqlen_q + 1) / 2) + return float(seqlen_k * (seqlen_k + 1) / 2) + + +def block_sparse_pairs(case: Case) -> float: + seqlen_q = case.seqlen_q or 0 + seqlen_k = case.seqlen_k or 0 + match case.mask_name: + case "block_diagonal": + total = 0 + for q_idx in range(seqlen_q): + block_start = (q_idx // BLOCK_SIZE_K) * BLOCK_SIZE_K + block_end = min(block_start + BLOCK_SIZE_K, seqlen_k) + total += max(0, block_end - block_start) + return float(total) + case "sliding_window": + window = case.window_size or 0 + offset = seqlen_k - seqlen_q + total = 0 + for q_idx in range(seqlen_q): + center = q_idx + offset + lower = max(0, center - window) + upper = min(seqlen_k - 1, center + window) + total += max(0, upper - lower + 1) + return float(total) + case _: + raise ValueError(f"Unsupported block-sparse FLOP mask: {case.mask_name}") + + +def fwd_flops(case: Case, kwargs: dict | None = None) -> float: + if case.mode == "dense": + return (case.batch or 0) * case.q_heads * 2 * attended_pairs( + case.seqlen_q or 0, + case.seqlen_k or 0, + case.causal, + ) * (case.d + case.dv) + if case.mode == "block_sparse": + num_pairs = (case.batch or 0) * block_sparse_pairs(case) + return case.q_heads * 2 * num_pairs * (case.d + case.dv) + lengths_q = case.seqlens_q or [] + lengths_k = case.seqlens_k or lengths_q + total = 0.0 + for seqlen_q, seqlen_k in zip(lengths_q, lengths_k): + total += case.q_heads * 2 * attended_pairs(seqlen_q, seqlen_k, case.causal) * (case.d + case.dv) + return total + + +def tflops(flop_count: float, time_us: float) -> float: + return 0.0 if time_us <= 0 else flop_count / time_us / 1e6 + + +def case_shape(case: Case) -> str: + match case.mode: + case "dense" | "block_sparse": + if case.seqlen_q == case.seqlen_k: + return token_label(case.seqlen_q or 0) + return f"q={token_label(case.seqlen_q or 0)} kv={token_label(case.seqlen_k or 0)}" + case "varlen": + lengths_q = case.seqlens_q or [] + lengths_k = case.seqlens_k or lengths_q + total_q = sum(lengths_q) + total_k = sum(lengths_k) + max_q = max(lengths_q, default=0) + max_k = max(lengths_k, default=0) + if total_q == total_k and max_q == max_k: + return f"total={token_label(total_q)} max={token_label(max_q)}" + return ( + f"q_total={token_label(total_q)} kv_total={token_label(total_k)} " + f"q_max={token_label(max_q)} kv_max={token_label(max_k)}" + ) + + +def case_metadata(case: Case) -> dict: + return { + "name": case.name, + "mode": case.mode, + "shape": case_shape(case), + "batch": case.batch, + "q_heads": case.q_heads, + "kv_heads": case.kv_heads, + "d": case.d, + "dv": case.dv, + "causal": case.causal, + "pattern": case.pattern, + "mask_name": case.mask_name, + "window_size": case.window_size, + } + + +def summarize_profile(case: Case, samples_off: list[float], samples_on: list[float], flop_count: float) -> dict: + mean_off = statistics.mean(samples_off) + mean_on = statistics.mean(samples_on) + paired_log_ratios = [math.log(off / on) for off, on in zip(samples_off, samples_on)] + mean_log_ratio = statistics.mean(paired_log_ratios) + stderr_log_ratio = ( + statistics.stdev(paired_log_ratios) / math.sqrt(len(paired_log_ratios)) + if len(paired_log_ratios) > 1 + else 0.0 + ) + ci95_low = math.exp(mean_log_ratio - 1.96 * stderr_log_ratio) + ci95_high = math.exp(mean_log_ratio + 1.96 * stderr_log_ratio) + return { + **case_metadata(case), + "samples_off_us": samples_off, + "samples_on_us": samples_on, + "mean_off_us": mean_off, + "mean_on_us": mean_on, + "median_off_us": statistics.median(samples_off), + "median_on_us": statistics.median(samples_on), + "mean_off_tflops": tflops(flop_count, mean_off), + "mean_on_tflops": tflops(flop_count, mean_on), + "speedup_on_vs_off": mean_off / mean_on, + "pct_change_on_vs_off": (mean_off / mean_on - 1.0) * 100.0, + "ci95_low_speedup": ci95_low, + "ci95_high_speedup": ci95_high, + "ci95_excludes_1x": ci95_low > 1.0 or ci95_high < 1.0, + } + + +def run_single_case( + case: Case, + clc: int, + fake_tensor: bool, + dtype_name: DTypeName, + bench_iters: int, + seed: int, +) -> dict: + os.environ["FA_CLC"] = str(clc) + os.environ["FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED"] = "1" + torch, _, cute_utils, _, _ = flash_attn_imports() + if not fake_tensor and not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for runtime profiling in clc_bench.py") + torch.manual_seed(seed) + fn, kwargs = build_inputs(case, dtype_name, fake_tensor) + cute_utils._fa_clc_enabled = bool(clc) + fn(**kwargs) + if fake_tensor: + return {"case": case.name, "clc": clc, "compiled": True} + torch.cuda.synchronize() + warmup_iters = min(25, max(10, bench_iters // 10)) + return { + "case": case.name, + "clc": clc, + "time_us": statistics.mean( + benchmark_cuda_samples_in_microseconds( + fn, + **kwargs, + NUM_ITERS=bench_iters, + MEMORY_WARMUP_ITERS=warmup_iters, + ) + ), + } + + +def run_single_subprocess(case: Case, clc: int, dtype_name: DTypeName, bench_iters: int, seed: int, script_path: Path) -> dict: + env = os.environ.copy() + env[INTERNAL_REQUEST_ENV] = json.dumps( + { + "case": asdict(case), + "clc": clc, + "fake_tensor": True, + "dtype_str": dtype_name, + "bench_iters": bench_iters, + "seed": seed, + } + ) + command = [sys.executable, str(script_path)] + try: + completed = subprocess.run(command, check=True, capture_output=True, text=True, env=env) + except subprocess.CalledProcessError as exc: + detail = (exc.stderr or exc.stdout).strip() + raise RuntimeError(f"Single-case compile failed for {case.name} clc={clc}:\n{detail}") from exc + for line in reversed(completed.stdout.splitlines()): + line = line.strip() + if line.startswith("{") and line.endswith("}"): + return json.loads(line) + raise RuntimeError(f"No JSON result found for {case.name} clc={clc}") + + +def run_compile(cases: list[Case], dtype_name: DTypeName, workers: int, bench_iters: int, seed: int, script_path: Path) -> list[dict]: + compile_cases = select_compile_cases(cases) + rows: list[dict] = [] + with ThreadPoolExecutor(max_workers=workers) as pool: + futures = { + pool.submit(run_single_subprocess, case, clc, dtype_name, bench_iters, seed, script_path): (case.name, clc) + for case in compile_cases + for clc in (0, 1) + } + for index, future in enumerate(as_completed(futures), start=1): + row = future.result() + print(f"[{index}/{len(futures)}] compiled {row['case']} clc={row['clc']}") + rows.append(row) + return sorted(rows, key=lambda row: (row["case"], row["clc"])) + + +def run_profile(cases: list[Case], dtype_name: DTypeName, profile_repeats: int, bench_iters: int, seed: int) -> list[dict]: + torch, _, cute_utils, _, _ = flash_attn_imports() + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for clc_bench.py") + torch.manual_seed(seed) + total_iters = profile_repeats * bench_iters + warmup_iters = min(25, max(10, total_iters // 10)) + rows: list[dict] = [] + for index, case in enumerate(cases, start=1): + fn, kwargs = build_inputs(case, dtype_name, fake_tensor=False) + samples: dict[int, list[float]] = {} + for clc in (0, 1): + cute_utils._fa_clc_enabled = bool(clc) + fn(**kwargs) + torch.cuda.synchronize() + samples[clc] = benchmark_cuda_samples_in_microseconds( + fn, + **kwargs, + NUM_ITERS=total_iters, + MEMORY_WARMUP_ITERS=warmup_iters, + ) + row = summarize_profile(case, samples[0], samples[1], fwd_flops(case, kwargs)) + print( + f"[{index}/{len(cases)}] {case.name}: " + f"off={row['mean_off_us']:.3f}us on={row['mean_on_us']:.3f}us " + f"speedup={row['speedup_on_vs_off']:.3f}x " + f"ci95=[{row['ci95_low_speedup']:.3f}, {row['ci95_high_speedup']:.3f}]" + ) + rows.append(row) + return rows + + +def round_scalar_row(row: dict) -> dict: + return { + key: round(value, CSV_FLOAT_DIGITS) if isinstance(value, float) else value + for key, value in row.items() + } + + +def write_csv(path: Path, rows: list[dict]) -> None: + if not rows: + return + path.parent.mkdir(parents=True, exist_ok=True) + scalar_rows = [ + round_scalar_row({key: value for key, value in row.items() if not isinstance(value, list)}) + for row in rows + ] + with path.open("w", newline="") as handle: + writer = csv.DictWriter(handle, fieldnames=list(scalar_rows[0].keys())) + writer.writeheader() + writer.writerows(scalar_rows) + + +def main( + out_dir: Path | None = None, + workers: int = 32, + profile_repeats: int = 3, + bench_iters: int = 64, + dtype_str: DTypeName = "bfloat16", + seed: int = 0, + case_filter: str = "", + dense: DenseSweep = DenseSweep(), + varlen: VarlenSweep = VarlenSweep(), + block_sparse: BlockSparseSweep = BlockSparseSweep(), +) -> None: + if (request_json := os.environ.get(INTERNAL_REQUEST_ENV)) is not None: + request = json.loads(request_json) + print( + json.dumps( + run_single_case( + Case(**request["case"]), + request["clc"], + request["fake_tensor"], + request["dtype_str"], + request["bench_iters"], + request["seed"], + ) + ) + ) + return + + os.environ["FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED"] = "1" + cases = generate_cases(dense, varlen, block_sparse, case_filter) + if not cases: + raise ValueError("No cases selected. Adjust the YAML sweep or case_filter.") + + run_dir = out_dir or default_out_dir() + print(f"cases={len(cases)}") + print(f"compile_cases={len(select_compile_cases(cases))}") + print(f"out_dir={run_dir}") + print(f"python={sys.executable}") + + script_path = Path(__file__).resolve() + run_compile(cases, dtype_str, workers, bench_iters, seed, script_path) + run_dir.mkdir(parents=True, exist_ok=True) + profile_rows = run_profile(cases, dtype_str, profile_repeats, bench_iters, seed) + profile_csv = run_dir / "profile.csv" + write_csv(profile_csv, profile_rows) + print("Profile written to:") + print(profile_csv) + + +if __name__ == "__main__": + CLI(main, as_positional=False) diff --git a/benchmarks/configs/clc.yaml b/benchmarks/configs/clc.yaml new file mode 100644 index 00000000000..94daf11770d --- /dev/null +++ b/benchmarks/configs/clc.yaml @@ -0,0 +1,35 @@ +dtype_str: bfloat16 +seed: 0 +workers: 64 +profile_repeats: 1 +bench_iters: 256 + +dense: + enabled: true + batches: [1, 4, 8, 16, 32] + seqlen_pairs: [[32, 8192], [2048, 2048], [4096, 4096], [8192, 8192], [16384, 16384]] + head_dims: [64, 96, 128, [192, 128]] + head_pairs: [[16, 16], [16, 8], [16, 4], [16, 1]] + causal: [true] + +varlen: + enabled: true + max_q_tokens: [2048, 4096, 8192, 16384, 32768] + max_kv_tokens: [2048, 4096, 8192, 16384, 32768] + batches: [4, 8, 16, 32] + # uniform: all sequences in the batch are similar length + # longtail: a few long sequences plus many shorter ones + patterns: [uniform, longtail] + head_dims: [64, 128, [192, 128]] + head_pairs: [[16, 8], [16, 4], [16, 1]] + causal: [false] + +block_sparse: + enabled: false + batches: [1, 4, 8, 16, 32] + seqlen_pairs: [[1024, 1024], [2048, 2048], [4096, 4096], [4096, 8192]] + head_dims: [64, 128, [192, 128]] + head_pairs: [[16, 16], [16, 4], [16, 1]] + # supported mask_names: block_diagonal, sliding_window + mask_names: [block_diagonal] + sliding_window_sizes: [2048] diff --git a/benchmarks/tune_ex2_emu.py b/benchmarks/tune_ex2_emu.py index a2a74422e5f..c21eac93cc5 100644 --- a/benchmarks/tune_ex2_emu.py +++ b/benchmarks/tune_ex2_emu.py @@ -10,6 +10,7 @@ Requires: the _TUNING_CONFIG dict in flash_attn/cute/flash_fwd_sm100.py. """ import argparse +import atexit import json import re import subprocess @@ -86,6 +87,98 @@ def detect_sm103(): print(f"GPU: {torch.cuda.get_device_name()}, SM{sm}, is_sm103={is_sm103}") return is_sm103 +def _get_gpu_selector(): + """Return the nvidia-smi GPU selector (-i argument) for the current device. + + Resolves CUDA_VISIBLE_DEVICES so that nvidia-smi targets the same physical + GPU that PyTorch is using. + """ + visible = os.environ.get("CUDA_VISIBLE_DEVICES", "") + if visible: + entries = [e.strip() for e in visible.split(",") if e.strip()] + if entries: + idx = torch.cuda.current_device() if torch.cuda.is_available() else 0 + return entries[idx] if idx < len(entries) else entries[0] + if torch.cuda.is_available(): + return str(torch.cuda.current_device()) + return None + +def _nvidia_smi_cmd(*args): + """Build nvidia-smi command for the current GPU, prepending sudo when not root.""" + prefix = [] if os.geteuid() == 0 else ["sudo"] + cmd = prefix + ["nvidia-smi"] + selector = _get_gpu_selector() + if selector is not None: + cmd += ["-i", selector] + return cmd + list(args) + +def _query_clocks(): + """Return (current_mhz_str, max_mhz_str) or (None, None) on failure.""" + try: + result = subprocess.run( + _nvidia_smi_cmd( + "--query-gpu=clocks.current.graphics,clocks.max.graphics", + "--format=csv,noheader,nounits", + ), + capture_output=True, text=True, + ) + except OSError: + return None, None + if result.returncode != 0: + return None, None + lines = [ln.strip() for ln in result.stdout.strip().splitlines() if ln.strip()] + if not lines: + return None, None + fields = [f.strip() for f in lines[0].split(",")] + if len(fields) < 2 or not fields[0] or not fields[1]: + return None, None + return fields[0], fields[1] + +def lock_clocks(max_mhz): + """Lock GPU clocks to max_mhz. Returns True on success.""" + try: + result = subprocess.run( + _nvidia_smi_cmd("--lock-gpu-clocks", str(max_mhz)), + capture_output=True, text=True, + ) + except OSError as e: + print(f"WARNING: Could not lock GPU clocks ({e}).") + return False + if result.returncode == 0: + print(f"Locked GPU clocks to {max_mhz} MHz.") + return True + print(f"WARNING: Could not lock GPU clocks ({result.stderr.strip()}).") + return False + +def unlock_clocks(): + """Unlock GPU clocks (best-effort, called at exit).""" + try: + subprocess.run(_nvidia_smi_cmd("--reset-gpu-clocks"), capture_output=True) + except OSError: + pass + +def setup_clocks(do_lock): + """Query clock state; if do_lock, attempt to lock and register unlock at exit.""" + cur, max_clk = _query_clocks() + if cur is None: + return + if do_lock: + if cur == max_clk: + print(f"GPU clocks already at max ({max_clk} MHz).") + elif lock_clocks(max_clk): + atexit.register(unlock_clocks) + print("GPU clocks will be unlocked on exit.") + else: + lock_cmd = " ".join(_nvidia_smi_cmd("--lock-gpu-clocks", max_clk)) + print(f" To lock manually: {lock_cmd}") + else: + if cur != max_clk: + print(f"WARNING: GPU clocks not locked ({cur} MHz, max {max_clk} MHz).") + print(" Benchmark results may vary between runs.") + lock_cmd = " ".join(_nvidia_smi_cmd("--lock-gpu-clocks", max_clk)) + print(f" To lock: {lock_cmd}") + print() + def run_benchmark(causal_flag, headdim_str, seqlen, rep=20, warmup=10): """Run benchmark, return (ms, tflops, mfu) or (None, None, None). @@ -125,11 +218,15 @@ def parse_args(): p.add_argument("--seqlen", type=str, default="8192") p.add_argument("--rep", type=int, default=20) p.add_argument("--warmup", type=int, default=10) + p.add_argument("--lock-clocks", action=argparse.BooleanOptionalAction, default=True, + help="Lock GPU clocks before tuning (requires sudo); use --no-lock-clocks to warn only") return p.parse_args() def main(): args = parse_args() + setup_clocks(args.lock_clocks) original_src = read_file() + atexit.register(write_file, original_src) config = parse_tuning_config(original_src) hdim, hdim_v = parse_headdim(args.headdim) @@ -148,51 +245,81 @@ def main(): print(f"Tuning hdim={args.headdim}, hdim_padded={hdim_padded}, seqlen={args.seqlen}, is_sm103={is_sm103}") print(f"Keys to tune: {keys_to_tune}\n") - # ── Phase 1: ex2_emu_freq + ex2_emu_start_frg sweep ── + # ── Phase 1: ex2_emu_freq + ex2_emu_start_frg sweep (+ ex2_emu_res for hd256) ── freq_values = [0, 6, 8, 10, 12, 14, 16, 20, 24, 32] start_frg_values = [0, 1] + # hd256 inner loop steps k by 2, so k%freq only takes even values. + # Meaningful res values that produce distinct hw:emu ratios on this sweep grid: + # freq=6, res=3 → 50:50 | freq=8, res=6 → 25:75 | freq=8, res=4 → 50:50 (diff freq) + hd256_res_values = [3, 6, 4] for key in keys_to_tune: - use_2cta, is_causal, _, _ = key + use_2cta, is_causal, key_hdim, _ = key causal_flag = "true" if is_causal else "false" causal_label = "causal" if is_causal else "non-causal" cta_label = "2CTA" if use_2cta else "1CTA" + is_hd256 = key_hdim == 256 + res_values = hd256_res_values if is_hd256 else [None] print("=" * 70) print(f"Phase 1: ex2_emu sweep for {causal_label} ({cta_label})") print("=" * 70) - print(f"{'freq':>5} {'start':>6} {'ms':>8} {'tflops':>10} {'mfu':>8}") - print("-" * 45) + if is_hd256: + print(f"{'freq':>5} {'res':>4} {'start':>6} {'ms':>8} {'tflops':>10} {'mfu':>8}") + print("-" * 50) + else: + print(f"{'freq':>5} {'start':>6} {'ms':>8} {'tflops':>10} {'mfu':>8}") + print("-" * 45) best_freq = config[key]["ex2_emu_freq"] + best_res = config[key].get("ex2_emu_res", None) best_start = config[key]["ex2_emu_start_frg"] best_tflops = 0 for start_frg in start_frg_values: for freq in freq_values: - test_config = dict(config) - test_config[key] = {**config[key], "ex2_emu_freq": freq, "ex2_emu_start_frg": start_frg} - write_file(patch_config(original_src, test_config)) - try: - ms, tflops, mfu = run_benchmark(causal_flag, args.headdim, args.seqlen, args.rep, args.warmup) - if tflops is None: - print(f"{freq:>5} {start_frg:>6} ERROR") - continue - marker = " ***" if tflops > best_tflops else "" - print(f"{freq:>5} {start_frg:>6} {ms:>8.2f} {tflops:>10.0f} {mfu:>8.1f}{marker}") - if tflops > best_tflops: - best_tflops = tflops - best_freq = freq - best_start = start_frg - except Exception as e: - print(f"{freq:>5} {start_frg:>6} ERROR: {e}") - sys.stdout.flush() - - print(f"\n Best: freq={best_freq}, start_frg={best_start}, {best_tflops:.0f} TFLOPS") - config[key] = {**config[key], "ex2_emu_freq": best_freq, "ex2_emu_start_frg": best_start} + for res in res_values: + test_config = dict(config) + patch = {"ex2_emu_freq": freq, "ex2_emu_start_frg": start_frg} + if res is not None: + patch["ex2_emu_res"] = res + test_config[key] = {**config[key], **patch} + write_file(patch_config(original_src, test_config)) + try: + ms, tflops, mfu = run_benchmark(causal_flag, args.headdim, args.seqlen, args.rep, args.warmup) + if tflops is None: + if is_hd256: + print(f"{freq:>5} {res:>4} {start_frg:>6} ERROR") + else: + print(f"{freq:>5} {start_frg:>6} ERROR") + continue + marker = " ***" if tflops > best_tflops else "" + if is_hd256: + print(f"{freq:>5} {res:>4} {start_frg:>6} {ms:>8.2f} {tflops:>10.0f} {mfu:>8.1f}{marker}") + else: + print(f"{freq:>5} {start_frg:>6} {ms:>8.2f} {tflops:>10.0f} {mfu:>8.1f}{marker}") + if tflops > best_tflops: + best_tflops = tflops + best_freq = freq + best_res = res + best_start = start_frg + except Exception as e: + if is_hd256: + print(f"{freq:>5} {res:>4} {start_frg:>6} ERROR: {e}") + else: + print(f"{freq:>5} {start_frg:>6} ERROR: {e}") + sys.stdout.flush() + + if is_hd256: + print(f"\n Best: freq={best_freq}, res={best_res}, start_frg={best_start}, {best_tflops:.0f} TFLOPS") + config[key] = {**config[key], "ex2_emu_freq": best_freq, "ex2_emu_res": best_res, "ex2_emu_start_frg": best_start} + else: + print(f"\n Best: freq={best_freq}, start_frg={best_start}, {best_tflops:.0f} TFLOPS") + config[key] = {**config[key], "ex2_emu_freq": best_freq, "ex2_emu_start_frg": best_start} # ── Phase 2: Register count sweep (softmax, correction; other = 512 - 2*softmax - correction) ── + # hd256 skipped: its num_regs_other=32 is fixed and the 512-budget formula does not apply. reg_combos = [] for softmax in [176, 184, 192, 200]: @@ -202,7 +329,7 @@ def main(): reg_combos.append((softmax, correction, other)) for key in keys_to_tune: - use_2cta, is_causal, _, _ = key + use_2cta, is_causal, key_hdim, _ = key causal_flag = "true" if is_causal else "false" causal_label = "causal" if is_causal else "non-causal" cta_label = "2CTA" if use_2cta else "1CTA" @@ -210,6 +337,11 @@ def main(): print("\n" + "=" * 70) print(f"Phase 2: Register sweep for {causal_label} ({cta_label})") print("=" * 70) + + if key_hdim == 256: + print(" Skipping: hd256 uses fixed num_regs_other=32; 512-budget formula does not apply.") + continue + print(f"{'softmax':>8} {'corr':>6} {'other':>6} {'ms':>8} {'tflops':>10} {'mfu':>8}") print("-" * 55) @@ -255,10 +387,8 @@ def main(): val_parts = ", ".join(f'"{k}": {json.dumps(v)}' for k, v in config[key].items()) print(f" {key!r}: {{{val_parts}}},") - print(f"\nTo apply, update _TUNING_CONFIG in {KERNEL_FILE}") - # Restore original write_file(original_src) print("Restored original file.") diff --git a/flash_attn/cute/README.md b/flash_attn/cute/README.md index 653f7b1cee2..c7f1b32ebd0 100644 --- a/flash_attn/cute/README.md +++ b/flash_attn/cute/README.md @@ -27,6 +27,7 @@ out = flash_attn_func(q, k, v, causal=True) ```sh git clone https://github.com/Dao-AILab/flash-attention.git cd flash-attention -pip install -e "flash_attn/cute[dev]" +pip install -e "flash_attn/cute[dev]" # CUDA 12.x +pip install -e "flash_attn/cute[dev,cu13]" # CUDA 13.x (e.g. B200) pytest tests/cute/ ``` diff --git a/flash_attn/cute/__init__.py b/flash_attn/cute/__init__.py index 1b84363b63d..be32e149b85 100644 --- a/flash_attn/cute/__init__.py +++ b/flash_attn/cute/__init__.py @@ -7,19 +7,11 @@ except PackageNotFoundError: __version__ = "0.0.0" -import cutlass.cute as cute - from .interface import ( flash_attn_func, flash_attn_varlen_func, ) -from flash_attn.cute.cute_dsl_utils import cute_compile_patched - -# Patch cute.compile to optionally dump SASS -cute.compile = cute_compile_patched - - __all__ = [ "flash_attn_func", "flash_attn_varlen_func", diff --git a/flash_attn/cute/bench_utils.py b/flash_attn/cute/bench_utils.py index 45cbcf1af36..f6ad96d7c4f 100644 --- a/flash_attn/cute/bench_utils.py +++ b/flash_attn/cute/bench_utils.py @@ -13,7 +13,15 @@ def flops( - batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, window_size=(None, None) + batch, + nheads, + seqlen_q, + seqlen_k, + headdim, + headdim_v, + causal=False, + window_size=(None, None), + has_qv=False, ): if causal: avg_seqlen = (max(0, seqlen_k - seqlen_q) + seqlen_k) / 2 @@ -35,7 +43,37 @@ def flops( else torch.full_like(row_idx, seqlen_k - 1) ) avg_seqlen = (col_right - col_left + 1).float().mean().item() - return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v) + eff_headdim = headdim + headdim_v if has_qv else headdim + return batch * nheads * 2 * seqlen_q * avg_seqlen * (eff_headdim + headdim_v) + + +# ── Bandwidth calculation ──────────────────────────────────────────────────── + + +def bandwidth_fwd_bytes( + batch, nheads, nheads_kv, seqlen_q, seqlen_k, headdim, headdim_v, dtype_bytes=2, has_qv=False +): + """HBM traffic for one attention pass: read Q,K,V + write O.""" + q = batch * nheads * seqlen_q * headdim + qv = batch * nheads * seqlen_q * headdim_v if has_qv else 0 + k = batch * nheads_kv * seqlen_k * headdim + v = batch * nheads_kv * seqlen_k * headdim_v + o = batch * nheads * seqlen_q * headdim_v + return (q + qv + k + v + o) * dtype_bytes + + +def bandwidth_bwd_bytes( + batch, nheads, nheads_kv, seqlen_q, seqlen_k, headdim, headdim_v, dtype_bytes=2 +): + """HBM traffic for one attention pass: read Q,K,V,dO + write dQ,dK,dV.""" + q = batch * nheads * seqlen_q * headdim + k = batch * nheads_kv * seqlen_k * headdim + v = batch * nheads_kv * seqlen_k * headdim_v + do = batch * nheads * seqlen_q * headdim_v + dq = q + dk = k + dv = v + return (q + k + v + do + dq + dk + dv) * dtype_bytes # ── Reference attention ───────────────────────────────────────────────────── diff --git a/flash_attn/cute/benchmark_flash_attention_fp8.py b/flash_attn/cute/benchmark_flash_attention_fp8.py new file mode 100644 index 00000000000..c79e7687237 --- /dev/null +++ b/flash_attn/cute/benchmark_flash_attention_fp8.py @@ -0,0 +1,434 @@ +# Benchmark FP8 attention for FA4 (CuTe-DSL) on SM100. +# +# Run (recommended): +# python -m flash_attn.cute.benchmark_flash_attention_fp8 +# +# Notes: +# - This is intended to be used while bringing up FP8 support for SM100. +# - FP8 correctness depends on descales + max-offset scaling being implemented in the SM100 kernel. +# This script optionally checks output vs a BF16 PyTorch baseline on dequantized FP8 inputs. +# +# Adapted from: `hopper/benchmark_flash_attention_fp8.py` + +from __future__ import annotations + +import argparse +import inspect +import math +import time +from typing import Iterable + +import torch +from einops import rearrange + +from flash_attn.cute.benchmark import benchmark_forward +from flash_attn.cute.interface import _flash_attn_fwd as flash_attn_cute_fwd + +try: + import cudnn +except ImportError: + cudnn = None + + +def _torch_float8_dtype(name: str) -> torch.dtype: + if name in ("fp8", "fp8_e4m3", "fp8_e4m3fn"): + return torch.float8_e4m3fn + if name in ("fp8_e5m2", "fp8_e5m2fn"): + return torch.float8_e5m2 + raise ValueError(f"Unsupported fp8 dtype name: {name}") + + +def _parse_int_list(csv: str) -> list[int]: + out: list[int] = [] + for part in csv.split(","): + part = part.strip() + if not part: + continue + out.append(int(part)) + return out + + +def attention_pytorch(qkv: torch.Tensor, causal: bool) -> torch.Tensor: + """ + qkv: (batch, seqlen, 3, nheads, headdim) + out: (batch, seqlen, nheads, headdim) + """ + batch_size, seqlen, _, nheads, d = qkv.shape + q, k, v = qkv.unbind(dim=2) + q = rearrange(q, "b t h d -> (b h) t d") + k = rearrange(k, "b s h d -> (b h) d s") + softmax_scale = 1.0 / math.sqrt(d) + scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device) + scores = rearrange( + torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale), "(b h) t s -> b h t s", h=nheads + ) + if causal: + causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1) + scores = scores + causal_mask.to(dtype=scores.dtype) + attention = torch.softmax(scores, dim=-1) + output = torch.einsum("bhts,bshd->bthd", attention, v) + return output.to(dtype=qkv.dtype) + + +def flops(batch: int, seqlen: int, headdim: int, nheads: int, causal: bool) -> int: + # Matches the hopper benchmark’s convention. + return 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1) + + +def efficiency(flop: int, seconds: float) -> float: + return (flop / seconds / 1e12) if not math.isnan(seconds) else 0.0 + + +def time_fwd(fn, *args, repeats: int, **kwargs) -> float: + time.sleep(1) # reduce residual throttling effects between benchmarks + _, m = benchmark_forward(fn, *args, repeats=repeats, verbose=False, **kwargs) + return float(m.mean) + + +def convert_to_cudnn_type(torch_type): + if torch_type == torch.float16: + return cudnn.data_type.HALF + if torch_type == torch.bfloat16: + return cudnn.data_type.BFLOAT16 + if torch_type == torch.float32: + return cudnn.data_type.FLOAT + if torch_type == torch.int32: + return cudnn.data_type.INT32 + if torch_type == torch.int64: + return cudnn.data_type.INT64 + if torch_type == torch.float8_e4m3fn: + return cudnn.data_type.FP8_E4M3 + if torch_type == torch.float8_e5m2: + return cudnn.data_type.FP8_E5M2 + raise ValueError("Unsupported tensor data type.") + + +def cudnn_sdpa_fp8_setup(qkv: torch.Tensor, seqlen_q: int, seqlen_k: int, causal: bool): + """Minimal cudnn.fp8 sdpa runner (optional).""" + assert cudnn is not None, "cudnn python bindings not available" + b, _, _, nheads, headdim = qkv.shape + o_gpu = torch.zeros(b, seqlen_q, nheads, headdim, dtype=qkv.dtype, device=qkv.device) + o_gpu_transposed = torch.as_strided( + o_gpu, + [b, nheads, seqlen_q, headdim], + [nheads * seqlen_q * headdim, headdim, nheads * headdim, 1], + ) + amax_s_gpu = torch.empty(1, 1, 1, 1, dtype=torch.float32, device=qkv.device) + amax_o_gpu = torch.empty(1, 1, 1, 1, dtype=torch.float32, device=qkv.device) + + graph = cudnn.pygraph( + io_data_type=convert_to_cudnn_type(qkv.dtype), + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + + new_q = torch.as_strided( + qkv, + [b, nheads, seqlen_q, headdim], + [seqlen_q * nheads * headdim * 3, headdim, headdim * nheads * 3, 1], + storage_offset=0, + ) + q = graph.tensor( + name="Q", + dim=list(new_q.shape), + stride=list(new_q.stride()), + data_type=convert_to_cudnn_type(qkv.dtype), + ) + + new_k = torch.as_strided( + qkv, + [b, nheads, seqlen_k, headdim], + [seqlen_k * nheads * headdim * 3, headdim, headdim * nheads * 3, 1], + storage_offset=nheads * headdim, + ) + k = graph.tensor( + name="K", + dim=list(new_k.shape), + stride=list(new_k.stride()), + data_type=convert_to_cudnn_type(qkv.dtype), + ) + + new_v = torch.as_strided( + qkv, + [b, nheads, seqlen_k, headdim], + [seqlen_k * nheads * headdim * 3, headdim, headdim * nheads * 3, 1], + storage_offset=nheads * headdim * 2, + ) + v = graph.tensor( + name="V", + dim=list(new_v.shape), + stride=list(new_v.stride()), + data_type=convert_to_cudnn_type(qkv.dtype), + ) + + def _scale_tensor(): + return graph.tensor(dim=[1, 1, 1, 1], stride=[1, 1, 1, 1], data_type=cudnn.data_type.FLOAT) + + default_scale_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float32, device="cuda") + descale_q = _scale_tensor() + descale_k = _scale_tensor() + descale_v = _scale_tensor() + descale_s = _scale_tensor() + scale_s = _scale_tensor() + scale_o = _scale_tensor() + + o, _, amax_s, amax_o = graph.sdpa_fp8( + q=q, + k=k, + v=v, + descale_q=descale_q, + descale_k=descale_k, + descale_v=descale_v, + descale_s=descale_s, + scale_s=scale_s, + scale_o=scale_o, + is_inference=True, + attn_scale=1.0 / math.sqrt(headdim), + use_causal_mask=causal, + name="sdpa", + ) + o.set_output(True).set_dim(o_gpu_transposed.shape).set_stride(o_gpu_transposed.stride()) + amax_s.set_output(False).set_dim(amax_s_gpu.shape).set_stride(amax_s_gpu.stride()) + amax_o.set_output(False).set_dim(amax_o_gpu.shape).set_stride(amax_o_gpu.stride()) + + graph.validate() + graph.build_operation_graph() + graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph.check_support() + graph.build_plans() + + variant_pack = { + q: new_q, + k: new_k, + v: new_v, + descale_q: default_scale_gpu, + descale_k: default_scale_gpu, + descale_v: default_scale_gpu, + descale_s: default_scale_gpu, + scale_s: default_scale_gpu, + scale_o: default_scale_gpu, + o: o_gpu_transposed, + amax_s: amax_s_gpu, + amax_o: amax_o_gpu, + } + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) + + def run(): + graph.execute(variant_pack, workspace) + return o_gpu + + return run + + +def _maybe_pass_descales(callable_, **kwargs): + sig = inspect.signature(callable_) + return {k: v for k, v in kwargs.items() if k in sig.parameters} + + +def main(argv: Iterable[str] | None = None) -> int: + parser = argparse.ArgumentParser() + parser.add_argument("--repeats", type=int, default=30) + parser.add_argument("--dim", type=int, default=2048) + parser.add_argument("--headdims", default="64,128") + parser.add_argument("--dtype", default="fp8_e4m3fn") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument( + "--check", + action=argparse.BooleanOptionalAction, + default=True, + help="Enable correctness checks vs BF16 PyTorch baseline.", + ) + parser.add_argument( + "--check-quantization-only", + action="store_true", + help="Check FP8 kernel vs dequantized-FP8 baseline (quantization error only).", + ) + parser.add_argument("--atol-bf16", type=float, default=0.10) + parser.add_argument("--rtol-bf16", type=float, default=0.10) + parser.add_argument("--atol-fp8", type=float, default=0.50) + parser.add_argument("--rtol-fp8", type=float, default=0.50) + parser.add_argument("--run-cudnn", action="store_true") + args = parser.parse_args(list(argv) if argv is not None else None) + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + major, minor = torch.cuda.get_device_capability() + if major != 10: + raise RuntimeError( + f"This benchmark is for SM100 (compute capability 10.x). Got {major}.{minor}." + ) + + torch.manual_seed(args.seed) + device = "cuda" + fp8_dtype = _torch_float8_dtype(args.dtype) + headdim_vals = _parse_int_list(args.headdims) + bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)] + + methods = ["Pytorch", "FA4-CuTe-BF16", "FA4-CuTe-FP8"] + ( + ["cuDNN-FP8"] if args.run_cudnn and cudnn is not None else [] + ) + + fp8_failures = [] + + for headdim in headdim_vals: + for causal in (False, True): + for batch, seqlen in bs_seqlen_vals: + torch.cuda.empty_cache() + nheads = args.dim // headdim + if args.dim % headdim != 0: + raise ValueError(f"--dim must be divisible by headdim ({args.dim=} {headdim=})") + + q_bf16 = torch.randn( + batch, seqlen, nheads, headdim, device=device, dtype=torch.bfloat16 + ) + k_bf16 = torch.randn( + batch, seqlen, nheads, headdim, device=device, dtype=torch.bfloat16 + ) + v_bf16 = torch.randn( + batch, seqlen, nheads, headdim, device=device, dtype=torch.bfloat16 + ) + qkv_bf16 = torch.stack([q_bf16, k_bf16, v_bf16], dim=2) + + times = {} + speeds = {} + + out_ref_bf16 = None + try: + out_ref_bf16 = attention_pytorch(qkv_bf16, causal=causal) # warmup / reference + t = time_fwd(attention_pytorch, qkv_bf16, causal=causal, repeats=args.repeats) + times["Pytorch"] = t + except RuntimeError as e: + if "out of memory" in str(e).lower(): + times["Pytorch"] = float("nan") + out_ref_bf16 = None + else: + raise + + # FA4 / CuTe BF16 baseline + try: + softmax_scale = headdim**-0.5 + out_fa4_bf16, _ = flash_attn_cute_fwd( + q_bf16, k_bf16, v_bf16, softmax_scale=softmax_scale, causal=causal + ) # warmup / compile + t = time_fwd( + flash_attn_cute_fwd, + q_bf16, + k_bf16, + v_bf16, + softmax_scale=softmax_scale, + causal=causal, + repeats=args.repeats, + ) + times["FA4-CuTe-BF16"] = t + if args.check and out_ref_bf16 is not None: + torch.testing.assert_close( + out_fa4_bf16, + out_ref_bf16, + atol=args.atol_bf16, + rtol=args.rtol_bf16, + ) + except Exception as e: + # Treat as fatal: BF16 kernel should be usable for basic sanity checking. + raise RuntimeError("FA4-CuTe BF16 baseline failed") from e + + # FA4 / CuTe FP8 + q_fp8 = q_bf16.to(fp8_dtype) + k_fp8 = k_bf16.to(fp8_dtype) + v_fp8 = v_bf16.to(fp8_dtype) + + # Placeholder descales (FA3-style: per-(batch, kv_head)). + q_descale = torch.ones(batch, nheads, device=device, dtype=torch.float32) + k_descale = torch.ones(batch, nheads, device=device, dtype=torch.float32) + v_descale = torch.ones(batch, nheads, device=device, dtype=torch.float32) + + # Optional: FP8 reference baseline (dequantized FP8 -> PyTorch) for quantization-error-only checks + out_ref_fp8 = None + if args.check and args.check_quantization_only: + try: + # Dequantize FP8 inputs back to BF16 (applying descales) + q_ref_fp8 = (q_fp8.to(torch.bfloat16) * q_descale[:, None, :, None]).to( + torch.bfloat16 + ) + k_ref_fp8 = (k_fp8.to(torch.bfloat16) * k_descale[:, None, :, None]).to( + torch.bfloat16 + ) + v_ref_fp8 = (v_fp8.to(torch.bfloat16) * v_descale[:, None, :, None]).to( + torch.bfloat16 + ) + qkv_ref_fp8 = torch.stack([q_ref_fp8, k_ref_fp8, v_ref_fp8], dim=2) + out_ref_fp8 = attention_pytorch(qkv_ref_fp8, causal=causal) + except RuntimeError as e: + if "out of memory" in str(e).lower(): + out_ref_fp8 = None + else: + raise + + fa4_kwargs = dict(softmax_scale=softmax_scale, causal=causal) + fa4_kwargs.update( + _maybe_pass_descales( + flash_attn_cute_fwd, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + ) + ) + + try: + # Warmup/compile (will raise until FP8 is implemented) + out_fa4_fp8, _ = flash_attn_cute_fwd(q_fp8, k_fp8, v_fp8, **fa4_kwargs) + t = time_fwd( + flash_attn_cute_fwd, + q_fp8, + k_fp8, + v_fp8, + repeats=args.repeats, + **fa4_kwargs, + ) + times["FA4-CuTe-FP8"] = t + if args.check: + # Choose baseline: quantization-only (dequantized FP8) or full (BF16) + if args.check_quantization_only: + ref_baseline = out_ref_fp8 + else: + ref_baseline = out_ref_bf16 + + if ref_baseline is not None: + torch.testing.assert_close( + out_fa4_fp8, + ref_baseline, + atol=args.atol_fp8, + rtol=args.rtol_fp8, + ) + except Exception as e: + fp8_failures.append((causal, headdim, batch, seqlen, repr(e))) + times["FA4-CuTe-FP8"] = float("nan") + + if args.run_cudnn and cudnn is not None: + qkv_fp8 = qkv_bf16.to(fp8_dtype) + runner = cudnn_sdpa_fp8_setup(qkv_fp8, seqlen, seqlen, causal=causal) + _ = runner() # warmup + t = time_fwd(lambda: runner(), repeats=args.repeats) + times["cuDNN-FP8"] = t + + print(f"### causal={causal}, headdim={headdim}, batch={batch}, seqlen={seqlen} ###") + for method in methods: + t = times.get(method, float("nan")) + speeds[method] = efficiency(flops(batch, seqlen, headdim, nheads, causal), t) + if math.isnan(t): + print(f"{method} fwd: (skipped)") + else: + print(f"{method} fwd: {speeds[method]:.2f} TFLOPs/s, {t * 1e3:.3f} ms") + if math.isnan(times.get("FA4-CuTe-FP8", float("nan"))): + print("FA4-CuTe-FP8 status: FAILED") + + if fp8_failures: + print(f"\nFP8 failures: {len(fp8_failures)} (showing first 5)") + for causal, headdim, batch, seqlen, err in fp8_failures[:5]: + print(f"- causal={causal} headdim={headdim} batch={batch} seqlen={seqlen}: {err}") + return 1 + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/flash_attn/cute/blackwell_helpers.py b/flash_attn/cute/blackwell_helpers.py index 720778027b2..4caadce864a 100644 --- a/flash_attn/cute/blackwell_helpers.py +++ b/flash_attn/cute/blackwell_helpers.py @@ -10,6 +10,24 @@ import flash_attn.cute.mma_sm100_desc as sm100_desc +def _tcgen05_mma_kind(op: cute.nvgpu.tcgen05.mma.MmaOp) -> str: + if isinstance(op, tcgen05.mma.MmaF16BF16Op): + return "f16" + if isinstance(op, tcgen05.mma.MmaTF32Op): + return "tf32" + if isinstance(op, tcgen05.mma.MmaI8Op): + return "i8" + if isinstance(op, tcgen05.mma.MmaFP8Op): + return "f8f6f4" + if isinstance(op, tcgen05.mma.MmaMXF8Op): + return "mxf8f6f4" + if isinstance(op, tcgen05.mma.MmaMXF4Op): + return "mxf4" + if isinstance(op, tcgen05.mma.MmaMXF4NVF4Op): + return "mxf4nvf4" + raise TypeError(f"Unsupported tcgen05 MMA op kind: {type(op).__name__}") + + @cute.jit def gemm_w_idx( tiled_mma: cute.TiledMma, @@ -108,6 +126,7 @@ def gemm_ptx( sA_layout = sA.layout if sA is not None else None sB_layout = sB.layout idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + kind = _tcgen05_mma_kind(op) if const_expr(not is_ts): sA_swizzle = sA.iterator.type.swizzle_type smem_desc_base_a: int = const_expr( @@ -177,7 +196,7 @@ def gemm_ptx( f"mov.b64 smem_desc_a, {{$1, {hex(smem_desc_a_hi)}}};\n\t" f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t" "setp.ne.b32 p, $3, 0;\n\t" - f"tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, p;\n\t" + f"tcgen05.mma.cta_group::1.kind::{kind} [$0], smem_desc_a, smem_desc_b, idesc, p;\n\t" "}\n", "r,r,r,r", has_side_effects=True, @@ -198,7 +217,7 @@ def gemm_ptx( ".reg .b64 smem_desc_b;\n\t" f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t" "setp.ne.b32 p, $3, 0;\n\t" - f"tcgen05.mma.cta_group::1.kind::f16 [$0], [$1], smem_desc_b, {hex(idesc)}, p;\n\t" + f"tcgen05.mma.cta_group::1.kind::{kind} [$0], [$1], smem_desc_b, {hex(idesc)}, p;\n\t" "}\n", "r,r,r,r", has_side_effects=True, @@ -223,6 +242,7 @@ def gemm_ptx_loop( sA_layout = sA.layout if sA is not None else tCrA.layout sB_layout = sB.layout idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + kind = _tcgen05_mma_kind(op) if const_expr(not is_ts): sA_swizzle = sA.iterator.type.swizzle_type smem_desc_base_a: int = const_expr( @@ -310,14 +330,14 @@ def gemm_ptx_loop( f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" "setp.ne.b32 p, $3, 0;\n\t" - f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::{kind} [$0], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + "".join( ( f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" - f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::{kind} [$0], smem_desc_a, smem_desc_b, idesc, 1;\n\t" ) for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2])) ) @@ -351,14 +371,14 @@ def gemm_ptx_loop( f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" "setp.ne.b32 p, $3, 0;\n\t" - f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::{kind} [$0], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + "".join( ( # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, 1;\n\t" - f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::{kind} [$0], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" ) for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2])) ) @@ -394,6 +414,7 @@ def gemm_ptx_partial( sA_layout = sA.layout if sA is not None else tCrA.layout sB_layout = sB.layout idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + kind = _tcgen05_mma_kind(op) if const_expr(not is_ts): sA_swizzle = sA.iterator.type.swizzle_type smem_desc_base_a: int = const_expr( @@ -477,7 +498,7 @@ def gemm_ptx_partial( f"mov.b64 smem_desc_a, {{smem_desc_a_lo_start, smem_desc_a_hi}};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" "setp.ne.b32 p, $2, 0;\n\t" - f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{kind} [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + "".join( ( # f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" @@ -486,7 +507,7 @@ def gemm_ptx_partial( f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" - f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{kind} [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t" ) for k in range(1, cute.size(tCrA.shape[2])) ) @@ -554,7 +575,7 @@ def gemm_ptx_partial( f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" "setp.ne.b32 p, $2, 0;\n\t" - f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{kind} [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + "".join( ( # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" @@ -562,7 +583,7 @@ def gemm_ptx_partial( f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t" - f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{kind} [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" ) for k in range( 1, @@ -575,7 +596,7 @@ def gemm_ptx_partial( ( f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" - f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{kind} [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" ) for k in range(split_arrive_idx, cute.size(tCrA.shape[2])) ) @@ -613,6 +634,7 @@ def gemm_ptx_partial1( assert sA_layout is not None, "sA_layout must be provided when a_src is not TMEM" assert sA_swizzle is not None, "sA_swizzle must be provided when a_src is not TMEM" idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + kind = _tcgen05_mma_kind(op) if const_expr(not is_ts): smem_desc_base_a: int = const_expr( sm100_desc.make_smem_desc_base( @@ -706,14 +728,14 @@ def gemm_ptx_partial1( f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" "setp.ne.b32 p, $4, 0;\n\t" - f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, {pred_str};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::{kind} [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, {pred_str};\n\t" + "".join( ( f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" - f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::{kind} [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, 1;\n\t" ) for k in range(1, cute.size(tCrA.shape[2])) ) @@ -751,13 +773,13 @@ def gemm_ptx_partial1( f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" "setp.ne.b32 p, $3, 0;\n\t" - f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, {pred_str};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::{kind} [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, {pred_str};\n\t" + "".join( ( f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" - f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::{kind} [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, 1;\n\t" ) for k in range(1, cute.size(tCrA.shape[2])) ) @@ -783,6 +805,7 @@ def gemm_ptx_precomputed( mbar_phase: Optional[Int32] = None, zero_init: bool | Boolean = False, cta_group: int = 1, + kind: str = "f16", ) -> None: # acc_tmem_addr += acc_offset is_ts = const_expr(smem_desc_base_a is None) @@ -842,7 +865,7 @@ def gemm_ptx_precomputed( f"mov.b64 smem_desc_a, {{smem_desc_a_lo_start, smem_desc_a_hi}};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" "setp.ne.b32 p, $2, 0;\n\t" - f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{kind} [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + "".join( ( # f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" @@ -851,7 +874,7 @@ def gemm_ptx_precomputed( f"add.s32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" - f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{kind} [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t" ) for k in range(1, num_k_tile) ) @@ -911,7 +934,7 @@ def gemm_ptx_precomputed( f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" "setp.ne.b32 p, $2, 0;\n\t" - f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{kind} [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + "".join( ( # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" @@ -919,7 +942,7 @@ def gemm_ptx_precomputed( f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t" - f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{kind} [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" ) for k in range( 1, @@ -933,7 +956,7 @@ def gemm_ptx_precomputed( # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" - f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{kind} [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" ) for k in range(num_k_tile // 4 * 3, num_k_tile) ) @@ -1019,6 +1042,7 @@ def gemm_ptx_precomputed_varname( smem_offset: int, zero_init: bool | Boolean = False, cta_group: int = 1, + kind: str = "f16", ) -> None: is_ts = False num_k_tile = cute.size(tCrB_layout.shape[2]) @@ -1067,7 +1091,7 @@ def gemm_ptx_precomputed_varname( ) + "setp.ne.b32 p, $1, 0;\n\t" # f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_0, smem_desc_b, idesc, {pred_str};\n\t" - f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_0, smem_desc_b_0, {idesc_var_name}, {pred_str};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{kind} [tmem_acc], {smem_var_name_prefix}_0, smem_desc_b_0, {idesc_var_name}, {pred_str};\n\t" + "".join( ( # f"mov.b64 {{smem_desc_a_lo, smem_desc_a_hi}}, {smem_var_name_prefix}_{k};\n\t" @@ -1077,7 +1101,7 @@ def gemm_ptx_precomputed_varname( # f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" # f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b, idesc, 1;\n\t" # f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b, {idesc_var_name}, 1;\n\t" - f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b_{k}, {idesc_var_name}, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{kind} [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b_{k}, {idesc_var_name}, 1;\n\t" ) for k in range(1, num_k_tile) ) diff --git a/flash_attn/cute/block_info.py b/flash_attn/cute/block_info.py index f21013891b4..422da2b66a0 100644 --- a/flash_attn/cute/block_info.py +++ b/flash_attn/cute/block_info.py @@ -137,3 +137,20 @@ def get_n_block_min_before_local_mask( n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q n_idx_left = n_idx - self.window_size_left return cutlass.max(n_block_min, cute.ceil_div(n_idx_left, self.tile_n)) + + @cute.jit + def get_n_block_max_for_m_block( + self, + seqlen_info: SeqlenInfoQK, + m_block: Int32, + n_block_global_max: Int32, + ) -> Int32: + if const_expr(self.is_causal or self.window_size_right is not None): + m_idx_max = (m_block + 1) * self.tile_m + if const_expr(self.qhead_per_kvhead_packgqa > 1): + m_idx_max = cute.ceil_div(m_idx_max, self.qhead_per_kvhead_packgqa) + n_idx_right = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q + if const_expr(self.window_size_right is not None): + n_idx_right += self.window_size_right + return min(n_block_global_max, cute.ceil_div(n_idx_right, self.tile_n)) + return n_block_global_max diff --git a/flash_attn/cute/block_sparse_utils.py b/flash_attn/cute/block_sparse_utils.py index 52cb7e06044..d664b16dc64 100644 --- a/flash_attn/cute/block_sparse_utils.py +++ b/flash_attn/cute/block_sparse_utils.py @@ -188,7 +188,7 @@ def produce_block_sparse_loads( must be converted to unpacked for sparse tensor indexing. """ - mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = blocksparse_tensors m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor) @@ -332,7 +332,7 @@ def consume_block_sparse_loads( must be converted to unpacked for sparse tensor indexing. """ - mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = blocksparse_tensors m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor) @@ -552,7 +552,7 @@ def produce_block_sparse_loads_sm100( """ m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor) - mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = blocksparse_tensors curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse] curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None] @@ -629,7 +629,7 @@ def get_total_block_count( ): m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor) - mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = blocksparse_tensors if const_expr(full_block_cnt is not None): return ( mask_block_cnt[batch_idx, head_idx, m_block_sparse] @@ -667,6 +667,8 @@ def handle_block_sparse_empty_tile_correction_sm100( o_corr_consumer_phase: Int32, corr_epi_producer_phase: Int32, softmax_scale_log2: Float32, + max_offset: Float32, + max_offset_scale: Float32, mO_cur: Optional[cute.Tensor] = None, gO: Optional[cute.Tensor] = None, gmem_tiled_copy_O: Optional[cute.TiledCopy] = None, @@ -706,10 +708,11 @@ def handle_block_sparse_empty_tile_correction_sm100( if sink_val != -Float32.inf and (const_expr(not is_split_kv) or split_idx == 0): if row_max_value == -Float32.inf: row_max_value = sink_val * (LOG2_E / softmax_scale_log2) - row_sum_value = Float32(1.0) + row_sum_value = max_offset_scale else: row_sum_value = row_sum_value + cute.math.exp2( - sink_val * LOG2_E - row_max_value * softmax_scale_log2, fastmath=True + sink_val * LOG2_E - row_max_value * softmax_scale_log2 + max_offset, + fastmath=True, ) if tidx < m_block_size: scale_row_idx = tidx + stage * m_block_size @@ -726,6 +729,8 @@ def handle_block_sparse_empty_tile_correction_sm100( if const_expr(gmem_tiled_copy_O is None): pipeline_o_epi.producer_acquire_w_index_phase(stage, corr_epi_producer_phase) + + gO_stage = gO[None, None, stage] if const_expr(gO is not None) else None correction_epilogue( thr_mma_pv, tOtO[None, None, None, stage], @@ -736,7 +741,7 @@ def handle_block_sparse_empty_tile_correction_sm100( Float32(0.0), # zero scale ensures empty tile writes zeros into staged outputs sO[None, None, stage], mO_cur, - gO[None, None, stage], + gO_stage, gmem_tiled_copy_O, ) if const_expr(gmem_tiled_copy_O is None): @@ -775,7 +780,7 @@ def softmax_block_sparse_sm100( warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor) - mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = blocksparse_tensors curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse] curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None] @@ -790,8 +795,6 @@ def softmax_block_sparse_sm100( total_block_cnt = curr_mask_block_cnt + curr_full_block_cnt if total_block_cnt == 0: - # See NOTE [SM100 block-sparse empty tiles: mbarrier contract]. - # pipeline_sm_stats.producer_commit_w_index(stage_idx) sm_stats_barrier.arrive_w_index(index=stage_idx * 4 + warp_idx) else: if curr_mask_block_cnt > 0: @@ -902,7 +905,7 @@ def get_total_q_block_count_bwd( m_block_max: int = 0, ): """Count total tile iterations for given n_block (KV tile) in backward.""" - q_block_cnt, _, full_block_cnt, _ = blocksparse_tensors + q_block_cnt, _, full_block_cnt, _, *_ = blocksparse_tensors total = q_block_cnt[batch_idx, head_idx, n_block] if const_expr(full_block_cnt is not None): total = total + full_block_cnt[batch_idx, head_idx, n_block] @@ -1046,7 +1049,7 @@ def get_block_sparse_iteration_info_bwd( Returns (curr_q_cnt, curr_q_idx, curr_full_cnt, curr_full_idx, total_count). """ - q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors + q_cnt, q_idx, full_cnt, full_idx, *_ = blocksparse_tensors curr_q_cnt = q_cnt[batch_idx, head_idx, n_block] curr_q_idx = q_idx[batch_idx, head_idx, n_block, None] @@ -1170,7 +1173,7 @@ def produce_block_sparse_q_loads_bwd_sm90( Returns updated (producer_state_Q, producer_state_dO). """ - q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors + q_cnt, q_idx, full_cnt, full_idx, *_ = blocksparse_tensors curr_q_cnt = q_cnt[batch_idx, head_idx, n_block] curr_q_idx = q_idx[batch_idx, head_idx, n_block, None] @@ -1265,7 +1268,7 @@ def consume_block_sparse_mma_bwd_sm90( Returns updated (consumer_state_Q, consumer_state_dO). """ - q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors + q_cnt, q_idx, full_cnt, full_idx, *_ = blocksparse_tensors curr_q_cnt = q_cnt[batch_idx, head_idx, n_block] curr_q_idx = q_idx[batch_idx, head_idx, n_block, None] @@ -1391,7 +1394,7 @@ def dQaccum_store_block_sparse_bwd_sm90( Iterates partial blocks first, then full blocks, matching producer/consumer order. """ - q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors + q_cnt, q_idx, full_cnt, full_idx, *_ = blocksparse_tensors curr_q_cnt = q_cnt[batch_idx, head_idx, n_block] curr_q_idx = q_idx[batch_idx, head_idx, n_block, None] diff --git a/flash_attn/cute/block_sparsity.py b/flash_attn/cute/block_sparsity.py index 3fad8c9f491..4a5726b7493 100644 --- a/flash_attn/cute/block_sparsity.py +++ b/flash_attn/cute/block_sparsity.py @@ -19,9 +19,13 @@ class BlockSparseTensors(NamedTuple): mask_block_idx: cute.Tensor full_block_cnt: cute.Tensor | None full_block_idx: cute.Tensor | None + dq_write_order: cute.Tensor | None = None + dq_write_order_full: cute.Tensor | None = None def __new_from_mlir_values__(self, values): if len(values) == 2: + values = (*values, None, None, None, None) + elif len(values) == 4: values = (*values, None, None) return BlockSparseTensors(*values) @@ -32,6 +36,138 @@ class BlockSparseTensorsTorch(NamedTuple): full_block_cnt: torch.Tensor | None = None full_block_idx: torch.Tensor | None = None block_size: tuple[int, int] | None = None + dq_write_order: torch.Tensor | None = None + dq_write_order_full: torch.Tensor | None = None + spt: bool | None = None + + +def _ordered_to_dense_simple( + num_blocks: torch.Tensor, + indices: torch.Tensor, + num_cols: int, +) -> torch.Tensor: + """Convert ordered sparse representation to dense binary matrix. + + Args: + num_blocks: [B, H, num_rows] count of valid entries per row + indices: [B, H, num_rows, max_entries] column indices (valid entries packed left) + num_cols: total number of columns + + Returns: + dense: [B, H, num_rows, num_cols] binary int32 matrix + """ + B, H, num_rows, max_entries = indices.shape + device = indices.device + dense = torch.zeros(B, H, num_rows, num_cols + 1, dtype=torch.int32, device=device) + col_range = torch.arange(max_entries, device=device) + valid = col_range[None, None, None, :] < num_blocks[:, :, :, None] + safe_indices = torch.where(valid, indices.long(), num_cols) + row_idx = torch.arange(num_rows, device=device)[None, None, :, None].expand_as(indices) + b_idx = torch.arange(B, device=device)[:, None, None, None].expand_as(indices) + h_idx = torch.arange(H, device=device)[None, :, None, None].expand_as(indices) + dense[b_idx, h_idx, row_idx, safe_indices] = 1 + return dense[:, :, :, :num_cols] + + +def compute_dq_write_order( + fwd_mask_cnt: torch.Tensor, + fwd_mask_idx: torch.Tensor, + fwd_full_cnt: torch.Tensor | None, + fwd_full_idx: torch.Tensor | None, + bwd_mask_cnt: torch.Tensor, + bwd_mask_idx: torch.Tensor, + bwd_full_cnt: torch.Tensor | None, + bwd_full_idx: torch.Tensor | None, + spt: bool = False, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """Compute dQ write-order metadata for deterministic block-sparse backward. + + For each (n_block, i) in the backward iteration, computes the semaphore + lock value: the rank of n_block in the combined (partial + full) sorted + contributor list for the target m_block. + + Lock values are assigned in ascending n_block order (or descending if spt=True) + to guarantee deadlock-freedom with the CTA scheduling order. + + Args: + fwd_mask_cnt: [B, H, num_m_blocks] partial contributor counts per m_block + fwd_mask_idx: [B, H, num_m_blocks, max_kv] partial contributor n_block indices (ascending) + fwd_full_cnt: [B, H, num_m_blocks] full contributor counts per m_block (optional) + fwd_full_idx: [B, H, num_m_blocks, max_kv] full contributor n_block indices (optional) + bwd_mask_cnt: [B, H, num_n_blocks] partial iteration counts per n_block + bwd_mask_idx: [B, H, num_n_blocks, max_q] partial iteration m_block indices + bwd_full_cnt: [B, H, num_n_blocks] full iteration counts per n_block (optional) + bwd_full_idx: [B, H, num_n_blocks, max_q] full iteration m_block indices (optional) + spt: if True, reverse ordering (highest n_block gets lock_value=0) + + Returns: + (dq_write_order, dq_write_order_full): tensors parallel to bwd_mask_idx + and bwd_full_idx respectively, containing lock values. + """ + device = fwd_mask_idx.device + B, H, num_m, max_kv_partial = fwd_mask_idx.shape + _, _, num_n, max_q_partial = bwd_mask_idx.shape + + has_full = fwd_full_cnt is not None and fwd_full_idx is not None + + dense_partial = _ordered_to_dense_simple(fwd_mask_cnt, fwd_mask_idx, num_n) + if has_full: + dense_full = _ordered_to_dense_simple(fwd_full_cnt, fwd_full_idx, num_n) + dense = (dense_partial + dense_full).clamp(max=1) + else: + dense = dense_partial + + cumsum = dense.cumsum(dim=-1) + rank_table = (cumsum - dense).to(torch.int32) + + if spt: + total_per_m = cumsum[:, :, :, -1:] + rank_table = (total_per_m - 1 - rank_table).to(torch.int32) + + def _gather_write_order(bwd_idx, bwd_cnt): + b_i = torch.arange(B, device=device)[:, None, None, None].expand_as(bwd_idx) + h_i = torch.arange(H, device=device)[None, :, None, None].expand_as(bwd_idx) + n_i = torch.arange(bwd_idx.shape[2], device=device)[None, None, :, None].expand_as(bwd_idx) + m_vals = bwd_idx.long().clamp(0, num_m - 1) + return rank_table[b_i, h_i, m_vals, n_i].to(torch.int32) + + dq_write_order = _gather_write_order(bwd_mask_idx, bwd_mask_cnt) + + dq_write_order_full = None + if has_full and bwd_full_cnt is not None and bwd_full_idx is not None: + dq_write_order_full = _gather_write_order(bwd_full_idx, bwd_full_cnt) + + return dq_write_order, dq_write_order_full + + +def compute_dq_write_order_from_block_mask( + block_mask, + spt: bool = False, +) -> tuple[torch.Tensor, torch.Tensor | None]: + ( + _seq_q, + _seq_k, + kv_mask_cnt, + kv_mask_idx, + full_kv_cnt, + full_kv_idx, + q_mask_cnt, + q_mask_idx, + full_q_cnt, + full_q_idx, + *_, + ) = block_mask.as_tuple() + return compute_dq_write_order( + kv_mask_cnt, + kv_mask_idx, + full_kv_cnt, + full_kv_idx, + q_mask_cnt, + q_mask_idx, + full_q_cnt, + full_q_idx, + spt=spt, + ) def get_sparse_q_block_size( @@ -110,6 +246,25 @@ def _check_and_expand_block( return expanded_cnt, expanded_idx +def _check_and_expand_metadata_tensor( + name: str, + tensor: torch.Tensor | None, + expected_shape: Tuple[int, ...], + context: str | None, + hint: str | Callable[[], str] | None, + device: torch.device, +) -> torch.Tensor | None: + if tensor is None: + return None + if tensor.dtype != torch.int32: + raise ValueError(f"{name} must have dtype torch.int32") + if tensor.device != device: + raise ValueError(f"{name} must be on the same device as block sparse tensors") + if not tensor.is_cuda: + raise ValueError(f"{name} must live on CUDA") + return _expand_sparsity_tensor(tensor, expected_shape, name, context, hint) + + def get_block_sparse_expected_shapes( batch_size: int, num_head: int, @@ -279,12 +434,37 @@ def normalize_block_sparse_tensors( if full_cnt is not None and mask_cnt.device != full_cnt.device: raise ValueError("All block sparse tensors must be on the same device") + dq_write_order = _check_and_expand_metadata_tensor( + "dq_write_order", + tensors.dq_write_order, + tuple(mask_idx.shape), + context, + hint, + mask_cnt.device, + ) + dq_write_order_full = _check_and_expand_metadata_tensor( + "dq_write_order_full", + tensors.dq_write_order_full, + tuple(full_idx.shape) if full_idx is not None else expected_index_shape, + context, + hint, + mask_cnt.device, + ) + spt = tensors.spt + if spt is not None and not isinstance(spt, bool): + raise ValueError("spt must be a bool when provided") + if spt is not None and dq_write_order is None: + raise ValueError("spt requires dq_write_order to be provided") + return BlockSparseTensorsTorch( mask_block_cnt=mask_cnt, mask_block_idx=mask_idx, full_block_cnt=full_cnt, full_block_idx=full_idx, block_size=tensors.block_size, + dq_write_order=dq_write_order, + dq_write_order_full=dq_write_order_full, + spt=spt, ) @@ -316,6 +496,8 @@ def get_block_sparse_broadcast_pattern( tensors.mask_block_idx, tensors.full_block_cnt, tensors.full_block_idx, + tensors.dq_write_order, + tensors.dq_write_order_full, ): if tensor is not None: patterns.append(get_broadcast_dims(tensor)) @@ -423,30 +605,21 @@ def to_cute_block_sparse_tensors( """Convert torch block sparsity tensors to CuTe tensors, optionally for tvm ffi""" if not is_block_sparsity_enabled(tensors): return None - - ( - mask_block_cnt, - mask_block_idx, - full_block_cnt, - full_block_idx, - *_, - ) = tensors - - ( - mask_block_cnt_tensor, - mask_block_idx_tensor, - ) = [ + mask_block_cnt_tensor, mask_block_idx_tensor = [ to_cute_tensor(t, assumed_align=4, leading_dim=-1, enable_tvm_ffi=enable_tvm_ffi) - for t in (mask_block_cnt, mask_block_idx) + for t in (tensors.mask_block_cnt, tensors.mask_block_idx) ] - ( - full_block_cnt_tensor, - full_block_idx_tensor, - ) = [ + full_block_cnt_tensor, full_block_idx_tensor = [ + to_cute_tensor(t, assumed_align=4, leading_dim=-1, enable_tvm_ffi=enable_tvm_ffi) + if t is not None + else None + for t in (tensors.full_block_cnt, tensors.full_block_idx) + ] + dq_write_order_tensor, dq_write_order_full_tensor = [ to_cute_tensor(t, assumed_align=4, leading_dim=-1, enable_tvm_ffi=enable_tvm_ffi) if t is not None else None - for t in (full_block_cnt, full_block_idx) + for t in (tensors.dq_write_order, tensors.dq_write_order_full) ] return BlockSparseTensors( @@ -454,6 +627,8 @@ def to_cute_block_sparse_tensors( mask_block_idx_tensor, full_block_cnt_tensor, full_block_idx_tensor, + dq_write_order_tensor, + dq_write_order_full_tensor, ) diff --git a/flash_attn/cute/cache_utils.py b/flash_attn/cute/cache_utils.py index 3fca0579d98..658a8d5b656 100644 --- a/flash_attn/cute/cache_utils.py +++ b/flash_attn/cute/cache_utils.py @@ -1,7 +1,6 @@ # Manage Ahead-of-Time (AOT) compiled kernels import fcntl import hashlib -import logging import os import pickle import sys @@ -18,6 +17,7 @@ import cutlass.cute as cute import tvm_ffi from cutlass.cutlass_dsl import JitCompiledFunction +from flash_attn.cute.fa_logging import fa_log # Pre-load cute DSL runtime libraries with RTLD_GLOBAL so that their symbols # (e.g. _cudaLibraryLoadData) are visible to .so modules loaded later via dlopen. @@ -30,17 +30,6 @@ CompileKeyType: TypeAlias = tuple[Hashable, ...] CallableFunction: TypeAlias = JitCompiledFunction | tvm_ffi.Function -logger = logging.getLogger(__name__) -_handler = logging.StreamHandler() -_handler.setFormatter( - logging.Formatter( - "%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S" - ) -) -logger.addHandler(_handler) -logger.setLevel(logging.DEBUG) - - # Enable cache via `FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1` CUTE_DSL_CACHE_ENABLED: bool = os.getenv("FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED", "0") == "1" @@ -226,13 +215,13 @@ def _try_load_from_storage(self, key: CompileKeyType) -> bool: label=sha256_hex, ): if obj_path.exists(): - logger.debug("Loading compiled function from disk: %s", obj_path) + fa_log(1, f"Loading compiled function from disk: {obj_path}") m = cute.runtime.load_module(str(obj_path), enable_tvm_ffi=True) fn = getattr(m, self.EXPORT_FUNCTION_PREFIX) JITCache.__setitem__(self, key, fn) return True else: - logger.debug("Cache miss on disk for key hash %s", sha256_hex) + fa_log(1, f"Cache miss on disk for key hash {sha256_hex}") return False def _try_export_to_storage(self, key: CompileKeyType, fn: JitCompiledFunction) -> None: @@ -247,14 +236,14 @@ def _try_export_to_storage(self, key: CompileKeyType, fn: JitCompiledFunction) - obj_path = self.cache_path / f"{sha256_hex}.o" if obj_path.exists(): # Another process already exported. - logger.debug("Skipping export, already on disk: %s", obj_path) + fa_log(1, f"Skipping export, already on disk: {obj_path}") return - logger.debug("Exporting compiled function to disk: %s", obj_path) + fa_log(1, f"Exporting compiled function to disk: {obj_path}") fn.export_to_c( object_file_path=str(obj_path), function_name=self.EXPORT_FUNCTION_PREFIX, ) - logger.debug("Successfully exported compiled function to disk: %s", obj_path) + fa_log(1, f"Successfully exported compiled function to disk: {obj_path}") def _key_to_hash(self, key: CompileKeyType) -> str: return hashlib.sha256(pickle.dumps(key)).hexdigest() @@ -266,7 +255,7 @@ def clear(self) -> None: """ Not only clear the in-memory cache. Also purge persistent compilation cache. """ - logger.debug("Clearing persistent cache at %s", self.cache_path) + fa_log(1, f"Clearing persistent cache at {self.cache_path}") super().clear() for child in self.cache_path.iterdir(): child.unlink() @@ -285,8 +274,8 @@ def get_jit_cache(name: str | None = None) -> JITCache: path = get_cache_path() / _compute_source_fingerprint() if name: path = path / name - logger.debug("Creating persistent JIT cache at %s", path) + fa_log(1, f"Creating persistent JIT cache at {path}") return JITPersistentCache(path) else: - logger.debug("Persistent cache disabled, using in-memory JIT cache") + fa_log(1, "Persistent cache disabled, using in-memory JIT cache") return JITCache() diff --git a/flash_attn/cute/compute_block_sparsity.py b/flash_attn/cute/compute_block_sparsity.py index a2dd98e41d2..69e8309a028 100644 --- a/flash_attn/cute/compute_block_sparsity.py +++ b/flash_attn/cute/compute_block_sparsity.py @@ -54,7 +54,7 @@ def __call__( seqlen_k: Int32, aux_tensors: Optional[list] = None, ): - self.mask_cnt, self.mask_idx, self.full_cnt, self.full_idx = blocksparse_tensors + self.mask_cnt, self.mask_idx, self.full_cnt, self.full_idx, *_ = blocksparse_tensors if const_expr(self.compute_full_blocks): assert self.full_cnt is not None and self.full_idx is not None, ( @@ -366,7 +366,14 @@ def compute_block_sparsity( ) compute_block_sparsity.compile_cache[compile_key]( - blocksparse_tensors_torch[:4], + ( + blocksparse_tensors_torch.mask_block_cnt, + blocksparse_tensors_torch.mask_block_idx, + blocksparse_tensors_torch.full_block_cnt, + blocksparse_tensors_torch.full_block_idx, + blocksparse_tensors_torch.dq_write_order, + blocksparse_tensors_torch.dq_write_order_full, + ), seqlen_q, seqlen_k, aux_tensors, diff --git a/flash_attn/cute/cute_dsl_utils.py b/flash_attn/cute/cute_dsl_utils.py index 79ebd9df6cf..6dfad6606ef 100644 --- a/flash_attn/cute/cute_dsl_utils.py +++ b/flash_attn/cute/cute_dsl_utils.py @@ -1,9 +1,7 @@ # Copyright (c) 2025, Tri Dao. -import os -import pathlib from typing import Tuple -from functools import partial, lru_cache +from functools import lru_cache import torch @@ -28,6 +26,8 @@ torch.float16: cutlass.Float16, torch.bfloat16: cutlass.BFloat16, torch.float32: cutlass.Float32, + torch.float8_e4m3fn: cutlass.Float8E4M3FN, + torch.float8_e5m2: cutlass.Float8E5M2, } @@ -41,27 +41,6 @@ def get_device_capacity(device: torch.device = None) -> Tuple[int, int]: return torch.cuda.get_device_capability(device) -def load_cubin_module_data_patched(cubin_data, filepath): - pathlib.Path(filepath).write_bytes(cubin_data) - return load_cubin_module_data_og(cubin_data) - - -def cute_compile_patched(*args, **kwargs): - """A patched version of cute.compile that dump the SASS to a file if CUTE_CUBIN_PATH is set.""" - cubin_path = os.getenv("CUTE_CUBIN_PATH", None) - if cubin_path is not None: - cutlass.base_dsl.runtime.cuda.load_cubin_module_data = partial( - load_cubin_module_data_patched, filepath=cubin_path - ) - output = cute_compile_og(*args, **kwargs) - if cubin_path is not None: - cutlass.base_dsl.runtime.cuda.load_cubin_module_data = load_cubin_module_data_og - if extract is not None: - sass = extract(cubin_path, None) - pathlib.Path(cubin_path).with_suffix(".annotated.sass").write_text(sass) - return output - - def assume_strides_aligned(t): """Assume all strides except the last are divisible by 128 bits. @@ -82,7 +61,20 @@ def assume_tensor_aligned(t): def to_cute_tensor(t, assumed_align=16, leading_dim=-1, fully_dynamic=False, enable_tvm_ffi=True): """Convert torch tensor to cute tensor for TVM FFI. leading_dim=-1 defaults to t.ndim-1.""" - tensor = from_dlpack(t.detach(), assumed_align=assumed_align, enable_tvm_ffi=enable_tvm_ffi) + # NOTE: torch 2.9.1 doesn't support fp8 via DLPack but 2.11.0 nightly does + # currently export raw bytes as uint8 and tell cutlass correct type + # can directly export as fp8 when torch supports it + if t.dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + tensor = from_dlpack( + t.view(torch.uint8).detach(), + assumed_align=assumed_align, + enable_tvm_ffi=enable_tvm_ffi, + ) + tensor.element_type = ( + cutlass.Float8E4M3FN if t.dtype == torch.float8_e4m3fn else cutlass.Float8E5M2 + ) + else: + tensor = from_dlpack(t.detach(), assumed_align=assumed_align, enable_tvm_ffi=enable_tvm_ffi) if fully_dynamic: return tensor.mark_layout_dynamic() if leading_dim == -1: @@ -127,3 +119,38 @@ def get_broadcast_dims(tensor: torch.Tensor) -> Tuple[bool, ...]: patterns are not interchangeable. """ return tuple(s == 0 for s in tensor.stride()) + + +# credit: monellz (https://github.com/NVIDIA/cutlass/issues/2658#issuecomment-3630564264) +def dump_kernel_attributes(compiled_kernel): + from cuda.bindings import driver + from cutlass.utils import HardwareInfo + import torch + + device_id = torch.cuda.current_device() + hardware_info = HardwareInfo(device_id=device_id) + cubin_data = compiled_kernel.artifacts.CUBIN + assert cubin_data is not None, "cubin_data is None, need '--keep-cubin' option when compiling" + cuda_library = hardware_info._checkCudaErrors( + driver.cuLibraryLoadData(cubin_data, None, None, 0, None, None, 0) + ) + kernels = hardware_info._checkCudaErrors(driver.cuLibraryEnumerateKernels(1, cuda_library)) + kernel = hardware_info._checkCudaErrors(driver.cuKernelGetFunction(kernels[0])) + # more metrics: https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1g5e92a1b0d8d1b82cb00dcfb2de15961b + local_size_bytes = hardware_info._checkCudaErrors( + driver.cuFuncGetAttribute( + driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, + kernel, + ) + ) + num_regs = hardware_info._checkCudaErrors( + driver.cuFuncGetAttribute( + driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NUM_REGS, + kernel, + ) + ) + + print("--- Kernel Info ---") + print(f"local_size_bytes: {local_size_bytes}") + print(f"num_regs: {num_regs}") + print("--- End Kernel Info ---") diff --git a/flash_attn/cute/flash_bwd.py b/flash_attn/cute/flash_bwd.py index 824abdda139..81c8ac68bd9 100644 --- a/flash_attn/cute/flash_bwd.py +++ b/flash_attn/cute/flash_bwd.py @@ -46,6 +46,8 @@ def __init__( AtomLayoutNdKV: int = 8, AtomLayoutMdQ: int = 1, V_in_regs: bool = False, + score_mod: cutlass.Constexpr | None = None, + score_mod_bwd: cutlass.Constexpr | None = None, ): """Initializes the configuration for a flash attention v2 kernel. @@ -63,7 +65,8 @@ def __init__( :param is_causal: is causal """ self.dtype = dtype - # padding head_dim to a multiple of 16 as k_block_size + # padding head_dim to a multiple of 32 (stricter than fwd's 16) due to + # backward kernel register layout requirements for dQ/dK/dV accumulation hdim_multiple_of = 32 self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) head_dim_v = head_dim_v if head_dim_v is not None else head_dim @@ -90,6 +93,8 @@ def __init__( self.Mma_dKV_is_RS = AtomLayoutMSdP == 1 and AtomLayoutNdKV == num_mma_warps and SdP_swapAB and not dKV_swapAB self.V_in_regs = V_in_regs self.share_QV_smem = V_in_regs + self.score_mod = score_mod + self.score_mod_bwd = score_mod_bwd @staticmethod def can_implement( @@ -377,7 +382,6 @@ def __call__( mCuSeqlensK: Optional[cute.Tensor] = None, mSeqUsedQ: Optional[cute.Tensor] = None, mSeqUsedK: Optional[cute.Tensor] = None, - softcap: Float32 | float | None = None, window_size_left: Int32 | int | None = None, window_size_right: Int32 | int | None = None, mdQ_semaphore: Optional[cute.Tensor] = None, @@ -430,7 +434,7 @@ def __call__( tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) grid_dim = TileScheduler.get_grid_shape(tile_sched_params) - softmax_scale_log2 = softmax_scale * math.log2(math.e) + softmax_scale_log2, softmax_scale = utils.compute_softmax_scale_log2(softmax_scale, self.score_mod) self.kernel( mQ, mK, @@ -773,6 +777,7 @@ def kernel( smem_copy_params=smem_copy_params, gmem_copy_params=gmem_copy_params, load_Q_LSE=load_Q_LSE, load_dO_dPsum=load_dO_dPsum, m_block_max=m_block_max, + softmax_scale=softmax_scale, softmax_scale_log2=softmax_scale_log2, ) @@ -861,6 +866,7 @@ def compute_one_m_block( load_Q_LSE: Callable, load_dO_dPsum: Callable, m_block_max: cutlass.Int32, + softmax_scale: cutlass.Float32, softmax_scale_log2: cutlass.Float32, mask_fn: Optional[Callable] = None, ): @@ -890,13 +896,24 @@ def load_dO_next(): smem_copy_params.smem_thr_copy_QdO, smem_copy_params.smem_thr_copy_KV, swap_AB=self.SdP_swapAB, ) + acc_S_pre = cute.make_fragment_like(acc_S) + acc_S_pre.store(acc_S.load()) tLSErLSE = cute.make_fragment_like(smem_copy_params.tSsLSEMma[None, 0]) cute.autovec_copy( smem_copy_params.tSsLSEMma[None, smem_pipe_read_q if cutlass.const_expr(self.num_stages_Q > 1) else 0], tLSErLSE ) + acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S) + acc_S_pre_mn = layout_utils.reshape_acc_to_mn(acc_S_pre) + if cutlass.const_expr(self.score_mod is not None): + for r in cutlass.range(cute.size(acc_S_mn, mode=[0]), unroll_full=True): + acc_S_mn[r, None].store( + self.score_mod( + acc_S_mn[r, None].load() * softmax_scale, + 0, 0, 0, 0, None, [], + ) + ) if cutlass.const_expr(mask_fn is not None): mask_fn(acc_S, m_block=m_block) - acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S) bidx = 0 # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_S_mn) # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == 1: cute.print_tensor(tLSErLSE) @@ -926,7 +943,14 @@ def load_dO_next(): # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dP_mn) assert cute.size(acc_dP_mn, mode=[0]) == cute.size(tLSErdPsum) for r in cutlass.range(cute.size(acc_dP_mn, mode=[0]), unroll_full=True): - acc_dP_mn[r, None].store(acc_S_mn[r, None].load() * (acc_dP_mn[r, None].load() - tLSErdPsum[r])) + grad_val = acc_S_mn[r, None].load() * (acc_dP_mn[r, None].load() - tLSErdPsum[r]) + if cutlass.const_expr(self.score_mod_bwd is not None): + grad_val = self.score_mod_bwd( + grad_val, + acc_S_pre_mn[r, None].load() * softmax_scale, + 0, 0, 0, 0, None, [], + ) + acc_dP_mn[r, None].store(grad_val) # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dP_mn) rP = cute.make_fragment_like(acc_S, self.dtype) rP.store(acc_S.load().to(self.dtype)) diff --git a/flash_attn/cute/flash_bwd_preprocess.py b/flash_attn/cute/flash_bwd_preprocess.py index d93ea5cc50b..8142def5ebb 100644 --- a/flash_attn/cute/flash_bwd_preprocess.py +++ b/flash_attn/cute/flash_bwd_preprocess.py @@ -43,6 +43,7 @@ def __init__( head_dim_v: int, tile_m: int = 128, num_threads: int = 256, + use_padded_offsets: bool = True, ): """ All contiguous dimensions must be at least 16 bytes aligned which indicates the head dimension @@ -64,6 +65,7 @@ def __init__( self.head_dim_v_padded = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of) self.check_hdim_v_oob = head_dim_v != self.head_dim_v_padded self.num_threads = num_threads + self.use_padded_offsets = use_padded_offsets @staticmethod def can_implement(dtype, head_dim, tile_m, num_threads) -> bool: @@ -241,6 +243,14 @@ def kernel( work_tile = tile_scheduler.initial_work_tile_info() m_block, head_idx, batch_idx, _ = work_tile.tile_idx + # This kernel is launched with use_pdl=True, so the GPU may start executing it in + # "prologue" mode while the previous stream kernel is still running. We must wait + # before touching any upstream GMEM outputs (mO, mdO, mLSE); otherwise we risk + # reading a partially-written dout, which silently corrupts dpsum = sum(O * dO) and + # propagates to dQ/dK via dS = P * (dP - dpsum). + if const_expr(self.use_pdl): + cute.arch.griddepcontrol_wait() + if work_tile.is_valid_tile: # /////////////////////////////////////////////////////////////////////////////// # Get the appropriate tiles for this thread block. @@ -250,9 +260,15 @@ def kernel( ) mO_cur = seqlen.offset_batch(mO, batch_idx, dim=0)[None, head_idx, None] mdO_cur = seqlen.offset_batch(mdO, batch_idx, dim=0)[None, head_idx, None] - mPdPsum_cur = seqlen.offset_batch(mPdPsum, batch_idx, dim=2, padded=True)[ - None, head_idx - ] + # Stats buffers (dpsum/lse_log2) are always consumed with padded q-offsets + # on the generic backward path (mdQaccum is present). Keep dedicated hd256 + # behavior controlled by self.use_padded_offsets. + stats_use_padded_offsets = self.use_padded_offsets + if const_expr(mdQaccum is not None): + stats_use_padded_offsets = True + mPdPsum_cur = seqlen.offset_batch( + mPdPsum, batch_idx, dim=2, padded=stats_use_padded_offsets + )[None, head_idx] headdim_v = mO_cur.shape[cute.rank(mO_cur) - 1] seqlen_q = seqlen.seqlen seqlen_q_rounded = cute.round_up(seqlen_q, self.tile_m) @@ -330,7 +346,11 @@ def kernel( # Clear dQaccum if const_expr(mdQaccum is not None): mdQaccum_cur = seqlen.offset_batch( - mdQaccum, batch_idx, dim=2, padded=True, multiple=self.head_dim_padded + mdQaccum, + batch_idx, + dim=2, + padded=True, + multiple=self.head_dim_padded, )[None, head_idx] blkdQaccum_shape = (self.tile_m * self.head_dim_padded,) gdQaccum = cute.local_tile(mdQaccum_cur, blkdQaccum_shape, (m_block,)) @@ -341,9 +361,9 @@ def kernel( cute.copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum) if const_expr(mLSE is not None): - mLSElog2_cur = seqlen.offset_batch(mLSElog2, batch_idx, dim=2, padded=True)[ - None, head_idx - ] + mLSElog2_cur = seqlen.offset_batch( + mLSElog2, batch_idx, dim=2, padded=stats_use_padded_offsets + )[None, head_idx] gLSElog2 = cute.local_tile(mLSElog2_cur, (self.tile_m,), (m_block,)) LOG2_E = math.log2(math.e) if tidx < seqlen_q_rounded - m_block * self.tile_m: diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index e06cd811fc6..9184ddeb029 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -58,6 +58,7 @@ def __init__( tile_n: int = 128, is_persistent: bool = False, deterministic: bool = False, + spt: Optional[bool] = None, cluster_size: int = 1, use_2cta_instrs: bool = False, score_mod: cutlass.Constexpr | None = None, @@ -116,6 +117,7 @@ def __init__( self.qhead_per_kvhead = qhead_per_kvhead self.pack_gqa = False self.deterministic = deterministic + self.spt_override = spt # Score mod and mask mod support self.score_mod = score_mod @@ -456,7 +458,6 @@ def __call__( mCuSeqlensK: Optional[cute.Tensor] = None, mSeqUsedQ: Optional[cute.Tensor] = None, mSeqUsedK: Optional[cute.Tensor] = None, - softcap: Float32 | float | None = None, window_size_left: Int32 | int | None = None, window_size_right: Int32 | int | None = None, mdQ_semaphore: Optional[cute.Tensor] = None, @@ -706,7 +707,11 @@ def __call__( TileScheduler = SingleTileLPTBwdScheduler else: TileScheduler = SingleTileScheduler - self.spt = (self.is_causal or self.is_local) and self.deterministic + if const_expr(self.spt_override is None): + self.spt = (self.is_causal or self.is_local) and self.deterministic + else: + assert self.spt_override is not None + self.spt = self.spt_override and self.deterministic tile_sched_args = TileSchedulerArguments( cute.ceil_div(cute.size(mK.shape[0]), self.cta_tiler[0]), # num_blocks cute.size(mQ.shape[2]), # num_heads = num_query_heads @@ -2990,8 +2995,17 @@ def compute_loop( # prefetch_LSE = not self.is_causal prefetch_LSE = False - # some tiles might be empty due to block sparsity + + curr_q_cnt = Int32(0) + curr_q_idx = None + curr_full_cnt = Int32(0) + curr_full_idx = None + loop_count = m_block_max - m_block_min + process_tile = ( + const_expr(not self.is_local and not self.is_varlen_q) or m_block_min < m_block_max + ) if const_expr(self.use_block_sparsity): + assert blocksparse_tensors is not None ( curr_q_cnt, curr_q_idx, @@ -3007,17 +3021,14 @@ def compute_loop( m_block_max=m_block_max, ) process_tile = loop_count > Int32(0) - else: - process_tile = ( - const_expr(not self.is_local and not self.is_varlen_q) - or m_block_min < m_block_max - ) - loop_count = m_block_max - m_block_min # Mainloop # Block sparsity: iterate over sparse m_block count and derive actual m_block # from Q_IDX/FULL_Q_IDX tensors. Dense: iterate m_block_min..m_block_max directly. for iter_idx in cutlass.range(loop_count, unroll=1): + m_block = m_block_min + iter_idx + m_block_oob = False + is_full_block = False if const_expr(self.use_block_sparsity): m_block, is_full_block = get_m_block_from_iter_bwd( iter_idx, @@ -3029,10 +3040,6 @@ def compute_loop( m_block_max=m_block_max, ) m_block_oob = m_block >= m_block_max - else: - m_block = m_block_min + iter_idx - m_block_oob = False - is_full_block = False # Prefetch 1 stage of LSE pipeline_LSE.consumer_wait(consumer_state_LSE) tSrLSE_s2r = cute.make_fragment(tScS_t2r[None, 0, 0, 0].shape, Float32) @@ -3413,6 +3420,38 @@ def compute_loop( tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() + @cute.jit + def _dq_semaphore_lock_value( + self, + iter_idx: Int32, + curr_q_cnt: Int32, + curr_dq_write_order: Optional[cute.Tensor], + curr_dq_write_order_full: Optional[cute.Tensor], + blocksparse_tensors: Optional[BlockSparseTensors], + block_info: BlockInfo, + seqlen, + m_block: Int32, + n_block: Int32, + n_block_global_max: Int32, + ) -> Int32: + lock_value = n_block + if const_expr(self.spt): + n_block_max_for_m_block = block_info.get_n_block_max_for_m_block( + seqlen, m_block, n_block_global_max + ) + lock_value = n_block_max_for_m_block - 1 - n_block + if const_expr(self.use_block_sparsity): + assert blocksparse_tensors is not None + if const_expr(blocksparse_tensors.dq_write_order is not None): + sparse_iter = iter_idx // self.subtile_factor + if sparse_iter < curr_q_cnt: + assert curr_dq_write_order is not None + lock_value = curr_dq_write_order[sparse_iter] + else: + assert curr_dq_write_order_full is not None + lock_value = curr_dq_write_order_full[sparse_iter - curr_q_cnt] + return lock_value + @cute.jit def dQacc_reduce( self, @@ -3485,13 +3524,24 @@ def dQacc_reduce( ) if const_expr(self.deterministic): + assert mdQ_semaphore is not None mdQ_semaphore_cur = mdQ_semaphore[None, None, head_idx, batch_idx] - # delay_semaphore_release = self.is_causal and not self.tile_hdim == 192 - delay_semaphore_release = not self.tile_hdim == 192 + delay_semaphore_release = not self.tile_hdim == 192 and not self.use_block_sparsity + n_block_global_max = cute.ceil_div(seqlen.seqlen_k, self.tile_n) - # some tiles might be empty due to block sparsity + curr_q_cnt = Int32(0) + curr_q_idx = None + curr_full_cnt = Int32(0) + curr_full_idx = None + curr_dq_write_order = None + curr_dq_write_order_full = None + loop_count = m_block_max - m_block_min + process_tile = ( + const_expr(not self.is_local and not self.is_varlen_q) or m_block_min < m_block_max + ) if const_expr(self.use_block_sparsity): + assert blocksparse_tensors is not None ( curr_q_cnt, curr_q_idx, @@ -3507,17 +3557,25 @@ def dQacc_reduce( m_block_max=m_block_max, ) process_tile = loop_count > Int32(0) - else: - process_tile = ( - const_expr(not self.is_local and not self.is_varlen_q) - or m_block_min < m_block_max - ) - loop_count = m_block_max - m_block_min + if const_expr(self.deterministic and self.use_block_sparsity): + assert blocksparse_tensors is not None + if const_expr(blocksparse_tensors.dq_write_order is not None): + assert blocksparse_tensors.dq_write_order is not None + curr_dq_write_order = blocksparse_tensors.dq_write_order[ + batch_idx, head_idx, n_block, None + ] + if const_expr(blocksparse_tensors.dq_write_order_full is not None): + assert blocksparse_tensors.dq_write_order_full is not None + curr_dq_write_order_full = blocksparse_tensors.dq_write_order_full[ + batch_idx, head_idx, n_block, None + ] # dQacc_reduce mainloop # Block sparsity: iterate over sparse m_block count and derive actual m_block # from Q_IDX/FULL_Q_IDX tensors. Dense: iterate m_block_min..m_block_max directly. for iter_idx in cutlass.range(loop_count, unroll=1): + m_block = m_block_min + iter_idx + m_block_oob_upper = False if const_expr(self.use_block_sparsity): m_block, _ = get_m_block_from_iter_bwd( iter_idx, @@ -3528,10 +3586,7 @@ def dQacc_reduce( subtile_factor=self.subtile_factor, m_block_max=m_block_max, ) - if m_block_max > 0: - m_block = cutlass.min(m_block, m_block_max - 1) - else: - m_block = m_block_min + iter_idx + m_block_oob_upper = m_block >= m_block_max pipeline_dQ.consumer_wait(dQ_consumer_state) # TMEM -> RMEM tdQrdQ_t2r = cute.make_fragment(tdQrdQ_t2r_shape, Float32) @@ -3542,6 +3597,8 @@ def dQacc_reduce( pipeline_dQ.consumer_release(dQ_consumer_state) dQ_consumer_state.advance() + if m_block_max > 0: + m_block = cutlass.min(m_block, m_block_max - 1) gdQaccum_cur = gdQaccum[None, None, m_block] tdQrdQ_shape = ( @@ -3559,22 +3616,28 @@ def dQacc_reduce( cute.arch.fence_view_async_shared() # semaphore acquire if const_expr(self.deterministic and stage == 0): - if const_expr(self.spt): - _, n_block_max_for_m_block = block_info.get_n_block_min_max( - seqlen, m_block + if not m_block_oob_upper: + lock_value = self._dq_semaphore_lock_value( + iter_idx, + curr_q_cnt, + curr_dq_write_order, + curr_dq_write_order_full, + blocksparse_tensors, + block_info, + seqlen, + m_block, + n_block_cta_group, + n_block_global_max, + ) + barrier.wait_eq( + mdQ_semaphore_cur[(m_block, None)].iterator, + tidx, + cta_rank_in_cluster, + lock_value, ) - lock_value = n_block_max_for_m_block - 1 - n_block_cta_group - else: - lock_value = n_block_cta_group - barrier.wait_eq( - mdQ_semaphore_cur[(m_block, None)].iterator, - tidx, - cta_rank_in_cluster, - lock_value, - ) self.reduce_sync_barrier.arrive_and_wait() # Copy from shared memory to global memory - if is_tma_warp: + if is_tma_warp and not m_block_oob_upper: with cute.arch.elect_one(): copy_utils.cpasync_reduce_bulk_add_f32( sdQaccum[None, smem_idx].iterator, @@ -3583,20 +3646,12 @@ def dQacc_reduce( ) cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(self.sdQaccum_stage - 1, read=read_flag) + elif is_tma_warp: + # Drain pending TMA stores so SMEM buffers are safe to reuse + cute.arch.cp_async_bulk_wait_group(0, read=read_flag) self.reduce_sync_barrier.arrive_and_wait() dQ_tma_store_producer_state.advance() - # Directly add to gmem, much slower - # tdQgdQ = thr_copy_dQaccum_r2s.partition_D(gdQaccum[None, stage, m_block]) - # assert cute.size(tdQrdQ_r2s) == cute.size(tdQgdQ) - # for i in cutlass.range(cute.size(tdQrdQ_r2s) // 4, unroll_full=True): - # copy_utils.atomic_add_fp32x4( - # tdQrdQ_r2s[4 * i], - # tdQrdQ_r2s[4 * i + 1], - # tdQrdQ_r2s[4 * i + 2], - # tdQrdQ_r2s[4 * i + 3], - # utils.elem_pointer(tdQgdQ, 4 * i), - # ) - # semaphore release for prior m_block + if const_expr(self.deterministic and stage == 0 and delay_semaphore_release): if m_block > m_block_min: barrier.arrive_inc( @@ -3618,12 +3673,13 @@ def dQacc_reduce( # NOTE: arrive_inc calls red_release which issues membar if const_expr(self.deterministic and not delay_semaphore_release): if const_expr(self.sdQaccum_stage > 1 and not self.tile_hdim == 192): - if is_tma_warp: + if is_tma_warp and not m_block_oob_upper: cute.arch.cp_async_bulk_wait_group(0, read=read_flag) self.reduce_sync_barrier.arrive_and_wait() - barrier.arrive_inc( - mdQ_semaphore_cur[m_block, None].iterator, tidx, cta_rank_in_cluster, 1 - ) + if not m_block_oob_upper: + barrier.arrive_inc( + mdQ_semaphore_cur[m_block, None].iterator, tidx, cta_rank_in_cluster, 1 + ) if process_tile: if is_tma_warp: @@ -3639,7 +3695,10 @@ def dQacc_reduce( ) if const_expr( - self.deterministic and not self.spt and block_info.window_size_left is not None + self.deterministic + and not self.spt + and not self.use_block_sparsity + and block_info.window_size_left is not None ): m_block_global_max = cute.ceil_div(seqlen.seqlen_q, self.tile_m) for m_block in cutlass.range(m_block_max, m_block_global_max, unroll=1): @@ -3864,6 +3923,7 @@ def epilogue_dK_or_dV_tma( deterministic_KV = self.deterministic and self.qhead_per_kvhead > 1 if const_expr(deterministic_KV): + assert mdKV_semaphore is not None mdKV_semaphore_cur = mdKV_semaphore[n_block, None, head_idx_kv, batch_idx] if const_expr(not self.dKV_postprocess): diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py index f724b5a11e3..2e420924e92 100644 --- a/flash_attn/cute/flash_bwd_sm90.py +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -350,7 +350,6 @@ def __call__( mCuSeqlensK: Optional[cute.Tensor] = None, mSeqUsedQ: Optional[cute.Tensor] = None, mSeqUsedK: Optional[cute.Tensor] = None, - softcap: Float32 | float | None = None, window_size_left: Int32 | int | None = None, window_size_right: Int32 | int | None = None, mdQ_semaphore: Optional[cute.Tensor] = None, @@ -403,6 +402,15 @@ def _qkv_transpose(t): if const_expr(self.deterministic): assert mdQ_semaphore is not None mdQ_semaphore = layout_utils.select(mdQ_semaphore, mode=[2, 3, 1, 0]) + if const_expr(self.deterministic and self.qhead_per_kvhead > 1): + assert mdK_semaphore is not None + assert mdV_semaphore is not None + mdK_semaphore, mdV_semaphore = [ + layout_utils.select(t, mode=[2, 3, 1, 0]) for t in (mdK_semaphore, mdV_semaphore) + ] + else: + mdK_semaphore = None + mdV_semaphore = None self.num_mma_threads = tiled_mma_SdP.size assert self.num_mma_threads + 128 == self.num_threads @@ -599,6 +607,8 @@ def _qkv_transpose(t): blocksparse_tensors, qhead_per_kvhead_divmod, mdQ_semaphore, + mdK_semaphore, + mdV_semaphore, window_size_left, window_size_right, ).launch( @@ -652,6 +662,8 @@ def kernel( blocksparse_tensors: Optional[BlockSparseTensors] = None, qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None, mdQ_semaphore: Optional[cute.Tensor] = None, + mdK_semaphore: Optional[cute.Tensor] = None, + mdV_semaphore: Optional[cute.Tensor] = None, window_size_left: Optional[Int32] = None, window_size_right: Optional[Int32] = None, ): @@ -788,6 +800,8 @@ def kernel( tiled_mma_dQ, mdK, mdV, + mdK_semaphore, + mdV_semaphore, mdQaccum, sQ, sK, @@ -1093,6 +1107,8 @@ def mma( tiled_mma_dQ: cute.TiledMma, mdK: cute.Tensor, mdV: cute.Tensor, + mdK_semaphore: Optional[cute.Tensor], + mdV_semaphore: Optional[cute.Tensor], mdQaccum: cute.Tensor, sQ: cute.Tensor, sK: cute.Tensor, @@ -1201,7 +1217,6 @@ def mma( tiled_mma_SdP, sP_cpy, tidx, - self.arch, transpose=self.SdP_swapAB, position_independent=True, major_mode_size=mms_PdS, @@ -1211,7 +1226,6 @@ def mma( tiled_mma_SdP, sdS_cpy, tidx, - self.arch, transpose=self.SdP_swapAB, position_independent=True, major_mode_size=mms_PdS, @@ -1391,6 +1405,8 @@ def mma( head_idx, batch_idx, qhead_per_kvhead_divmod, + mdK_semaphore, + mdV_semaphore, ) else: # KV tile with zero Q blocks produces no dK/dV; write zeros. @@ -1414,6 +1430,8 @@ def mma( head_idx, batch_idx, qhead_per_kvhead_divmod, + mdK_semaphore, + mdV_semaphore, ) tile_scheduler.advance_to_next_work() @@ -1618,6 +1636,8 @@ def epilogue_dKV( head_idx: Int32, batch_idx: Int32, qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None, + mdK_semaphore: Optional[cute.Tensor] = None, + mdV_semaphore: Optional[cute.Tensor] = None, ): epi_barrier = cutlass.pipeline.NamedBarrier( barrier_id=int(NamedBarrierBwd.Epilogue), num_threads=self.num_mma_threads @@ -1645,7 +1665,6 @@ def epilogue_dKV( tiled_mma_dV, sdV, tidx, - self.arch, transpose=self.dKV_swapAB, position_independent=True, ) @@ -1653,7 +1672,6 @@ def epilogue_dKV( tiled_mma_dK, sdK, tidx, - self.arch, transpose=self.dKV_swapAB, position_independent=True, ) @@ -1674,11 +1692,18 @@ def epilogue_dKV( store_dK() cute.arch.cp_async_bulk_commit_group() else: + deterministic_KV = self.deterministic and self.qhead_per_kvhead > 1 sdKaccum_shape0 = self.tile_n * self.tile_hdim // self.num_wg_mma sdVaccum_shape0 = self.tile_n * self.tile_hdimv // self.num_wg_mma sdKaccum_layout = cute.make_layout((sdKaccum_shape0, self.num_wg_mma)) sdVaccum_layout = cute.make_layout((sdVaccum_shape0, self.num_wg_mma)) head_idx_kv = head_idx // qhead_per_kvhead_divmod + if const_expr(deterministic_KV): + assert mdK_semaphore is not None + assert mdV_semaphore is not None + mdK_semaphore_cur = mdK_semaphore[n_block, None, head_idx_kv, batch_idx] + mdV_semaphore_cur = mdV_semaphore[n_block, None, head_idx_kv, batch_idx] + lock_value = head_idx % self.qhead_per_kvhead mdKaccum_cur = seqlen.offset_batch_K( mdK, batch_idx, dim=2, padded=True, multiple=self.tile_hdim )[None, head_idx_kv] @@ -1702,7 +1727,10 @@ def epilogue_dKV( tdKsdKaccum = thr_copy_dKVaccum_r2s.partition_D(sdKaccum) tdVsdVaccum = thr_copy_dKVaccum_r2s.partition_D(sdVaccum) - cute.arch.cp_async_bulk_wait_group(0, read=True) + read_flag = const_expr(not deterministic_KV) + cute.arch.cp_async_bulk_wait_group(0, read=read_flag) + if const_expr(deterministic_KV): + barrier.wait_eq(mdK_semaphore_cur.iterator, tidx, 0, lock_value) epi_barrier.arrive_and_wait() tdKrdKaccum_flat = cute.make_tensor(acc_dK.iterator, tdKsdKaccum.shape) cute.autovec_copy(tdKrdKaccum_flat, tdKsdKaccum) @@ -1718,7 +1746,10 @@ def epilogue_dKV( ) cute.arch.cp_async_bulk_commit_group() - cute.arch.cp_async_bulk_wait_group(0, read=True) + cute.arch.cp_async_bulk_wait_group(0, read=read_flag) + if const_expr(deterministic_KV): + barrier.arrive_inc(mdK_semaphore_cur.iterator, tidx, 0, 1) + barrier.wait_eq(mdV_semaphore_cur.iterator, tidx, 0, lock_value) epi_barrier.arrive_and_wait() tdVrdVaccum_flat = cute.make_tensor(acc_dV.iterator, tdVsdVaccum.shape) cute.autovec_copy(tdVrdVaccum_flat, tdVsdVaccum) @@ -1733,6 +1764,9 @@ def epilogue_dKV( self.tma_copy_bytes["dVacc"] // self.num_wg_mma, ) cute.arch.cp_async_bulk_commit_group() + if const_expr(deterministic_KV): + cute.arch.cp_async_bulk_wait_group(0, read=read_flag) + barrier.arrive_inc(mdV_semaphore_cur.iterator, tidx, 0, 1) @cute.jit def dQaccum_store( diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 33031d0f2cc..ecca680bc91 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -27,7 +27,7 @@ from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned from flash_attn.cute import utils from flash_attn.cute.mask import AttentionMask -from flash_attn.cute.softmax import Softmax +from flash_attn.cute.softmax import Softmax, apply_score_mod_inner from flash_attn.cute.seqlen_info import SeqlenInfoQK from flash_attn.cute.block_info import BlockInfo from flash_attn.cute.pack_gqa import PackGQA @@ -1179,8 +1179,8 @@ def load_V_next(): m_block, acc_S, n_block, - seqlen, softmax_scale=softmax.softmax_scale, + seqlen=seqlen, aux_tensors=aux_tensors, fastdiv_mods=fastdiv_mods, ) @@ -1219,6 +1219,40 @@ def load_K_next(): ) # if const_expr(self.num_stages > 1): # load_K_next() + @cute.jit + def apply_score_mod( + self, + thr_mma_qk, + batch_idx, + head_idx, + m_block, + acc_S, + n_block, + softmax_scale, + seqlen, + aux_tensors: Optional[list] = None, + fastdiv_mods=None, + ): + # Prepare index tensor + cS = cute.make_identity_tensor((self.tile_m, self.tile_n)) + cS = cute.domain_offset((m_block * self.tile_m, n_block * self.tile_n), cS) + tScS = thr_mma_qk.partition_C(cS) + + apply_score_mod_inner( + acc_S, + tScS, + self.score_mod, + batch_idx, + head_idx, + softmax_scale, + self.vec_size, + self.qk_acc_dtype, + aux_tensors, + fastdiv_mods, + seqlen_info=seqlen, + constant_q_idx=None, + qhead_per_kvhead=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) # SM90 forward pass moved to flash_fwd_sm90.py; re-export for backward compatibility diff --git a/flash_attn/cute/flash_fwd_mla_sm100.py b/flash_attn/cute/flash_fwd_mla_sm100.py new file mode 100644 index 00000000000..2987b4c0460 --- /dev/null +++ b/flash_attn/cute/flash_fwd_mla_sm100.py @@ -0,0 +1,3441 @@ +import math +import time +from functools import partial +from typing import Callable, Optional + +import torch +import torch.utils.benchmark as benchmark + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int64, Int32, Uint32, Boolean, const_expr +import cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass.cute.runtime import from_dlpack +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.utils import ClcDynamicPersistentTileScheduler + +from quack import copy_utils + +from flash_attn.cute.pack_gqa import pack_gqa_layout, make_packgqa_tiled_tma_atom +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute.mask import AttentionMask +import flash_attn.cute.blackwell_helpers as fa_sm100_utils +from flash_attn.cute.softmax import SoftmaxSm100 +from flash_attn.cute.tile_scheduler import ( + ClcState, + SchedulingMode, + TileSchedulerArguments, + TileSchedulerProtocol, + SingleTileScheduler, + SingleTileLPTScheduler, + SingleTileVarlenScheduler, + ParamsBase, +) +from flash_attn.cute.fa_logging import fa_log, fa_printf +from flash_attn.cute.utils import smid + +from flash_attn.cute.topk_gather_kv import CpasyncGatherKVManager + +from flash_attn.cute.testing import attention_ref + +from flash_attn.cute.named_barrier import NamedBarrierFwdSm100_MLA2CTA + +from flash_attn.cute.cute_dsl_utils import dump_kernel_attributes + + +class FlashAttentionMLAForwardSm100: + def __init__( + self, + is_causal: bool = False, + use_cpasync_load_KV: bool = False, + topk_length: int = 2048, + is_topk_gather: bool = True, + pack_gqa: bool = False, + qhead_per_kvhead: int = 1, + nheads_kv: int = 1, + hdim: int = 64, + hdimv: int = 512, + is_varlen_q: bool = False, + disable_bitmask: bool = False, + use_clc_scheduler: bool = True, + ): + self.is_causal = is_causal + self.is_local = False + self.pack_gqa = pack_gqa + self.qhead_per_kvhead = qhead_per_kvhead + self.nheads_kv = nheads_kv + self.is_varlen_q = is_varlen_q + self.use_tma_O = True + self.use_cpasync_load_KV = use_cpasync_load_KV + self.use_tma_KV = not use_cpasync_load_KV + self.topk_length = topk_length + self.is_topk_gather = is_topk_gather + if is_topk_gather: + assert pack_gqa + assert qhead_per_kvhead == 128, "require MQA 128 for DSA path" + assert use_cpasync_load_KV + # user-provided option if topk indices guaranteed in bounds + self.disable_bitmask = disable_bitmask + + # ==== tile scheduler ==== + self.is_persistent = False + self.use_clc_scheduler = use_clc_scheduler and not is_varlen_q + self.sched_stages = 1 + self.scheduling_mode = ( + SchedulingMode.CLC if self.use_clc_scheduler else SchedulingMode.STATIC + ) + + if const_expr(is_varlen_q): + self.TileScheduler = SingleTileVarlenScheduler + elif self.use_clc_scheduler: + self.TileScheduler = SingleTileLPTScheduler + else: + self.TileScheduler = SingleTileScheduler + + fa_log( + 1, + f"TileScheduler={self.TileScheduler.__name__}, scheduling_mode={self.scheduling_mode.name}", + ) + + # ==== thread info ==== + self.num_softmax_threads = 128 + self.num_epilogue_threads = 128 + self.num_load_threads = 32 + self.num_mma_threads = 32 + self.num_empty_threads = 32 if use_cpasync_load_KV else 64 + self.num_relay_threads = 32 if use_cpasync_load_KV else 0 + self.num_cpasync_load_threads = 128 if use_cpasync_load_KV else 0 + self.num_threads = ( + self.num_softmax_threads + + self.num_epilogue_threads + + self.num_load_threads + + self.num_mma_threads + + self.num_empty_threads + + self.num_relay_threads + + self.num_cpasync_load_threads + ) + self.num_warps = self.num_threads // 32 + assert self.num_warps == 12 or self.num_warps == 16 + self.softmax_warp_indices = (0, 1, 2, 3) + self.epilogue_warp_indices = (4, 5, 6, 7) + self.load_warp_id = 8 + self.mma_warp_id = 9 + self.clc_scheduler_warp_id = 10 + self.relay_warp_id = 11 + self.empty_warp_ids = tuple( + w + for w, active in [ + (self.relay_warp_id, not use_cpasync_load_KV), + (self.clc_scheduler_warp_id, not self.use_clc_scheduler), + ] + if active + ) + self.cpasync_load_warp_indices = (12, 13, 14, 15) + + # ==== register usage ==== + if self.num_warps == 16: + self.num_regs_load = 80 + self.num_regs_mma = 80 + self.num_regs_softmax = 208 + self.num_regs_epilogue = 128 + self.num_regs_cpasync = 96 if self.use_cpasync_load_KV else 0 + self.num_regs_other = 48 + else: + self.num_regs_load = 168 - 40 + self.num_regs_mma = 168 - 40 + self.num_regs_softmax = 168 + 80 + self.num_regs_epilogue = 168 - 40 + self.num_regs_cpasync = 0 + self.num_regs_other = 48 + + self.num_regs_per_thread = 168 if self.num_warps == 12 else 128 + self.num_regs_total = 504 if self.num_warps == 12 else 512 + + assert ( + self.num_regs_mma + + self.num_regs_softmax + + self.num_regs_epilogue + + self.num_regs_cpasync + <= self.num_regs_total + ) + + # ==== 2cta info ==== + self.use_2cta_instrs = True + self.cta_group = tcgen05.CtaGroup.TWO + self.cta_group_size = 2 + self.cluster_shape_mn = (2, 1) + self.cluster_shape_mnk = (2, 1, 1) + + # ==== problem shape info ==== + self.hdim = hdim + self.hdimv = hdimv + self.cta_tile_m = 64 + self.cluster_tile_m = self.cta_group_size * self.cta_tile_m + self.tile_n = 128 + assert ( + pack_gqa is False + or self.cluster_tile_m % qhead_per_kvhead == 0 + or qhead_per_kvhead % self.cluster_tile_m == 0 + ) + self.num_hdimv_splits = 2 # split hdimv in half for our Qv @ V^T and P @ V mmas. + assert hdimv % 32 == 0 + assert self.topk_length % (self.tile_n * 2) == 0 or not self.is_topk_gather + self.epi_tile = (self.cta_tile_m, self.hdimv // self.num_hdimv_splits) + + # ==== MMA info ==== + self.mma_tiler_QK = ( + self.cluster_tile_m, + self.tile_n, + self.hdim, + ) + self.mma_tiler_QviVi = ( + self.cluster_tile_m, + self.tile_n, + self.hdimv // self.num_hdimv_splits, + ) + self.mma_tiler_PVti = ( + self.cluster_tile_m, + self.hdimv // self.num_hdimv_splits, + self.tile_n, + ) + self.major_mode_Q = tcgen05.OperandMajorMode.K + self.major_mode_Qvi = tcgen05.OperandMajorMode.K + self.major_mode_K = tcgen05.OperandMajorMode.K + self.major_mode_Vi = tcgen05.OperandMajorMode.K + self.major_mode_Vti = tcgen05.OperandMajorMode.MN + self.major_mode_P = tcgen05.OperandMajorMode.K + self.operand_source_Q = tcgen05.OperandSource.SMEM + self.operand_source_Qvi = tcgen05.OperandSource.SMEM + self.operand_source_P = tcgen05.OperandSource.SMEM + + # ==== pipeline info ==== + self.num_stages_Q = 1 + self.num_stages_K = 1 + self.num_stages_Qvi = 1 + self.num_stages_Vi = 2 + self.num_stages_S = 2 + self.num_stages_P = 1 + self.num_stages_Oi = 1 + self.num_stages_sm_stats = 2 + self.num_stages_bitmask = 4 + assert self.num_stages_S == 2, "mainloops expect 2 stages for S" + + # ==== dtype info ==== + self.dtype_acc = Float32 + + # ==== TMEM info ==== + SM100_TMEM_CAPACITY_COLUMNS = 512 + self.tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS + self.tmem_cols_S = self.tile_n // self.cta_group_size + self.tmem_cols_Oi = (self.hdimv // self.num_hdimv_splits) // self.cta_group_size + self.tmem_offset_S = [ + self.tmem_cols_S * stage for stage in range(self.num_stages_S) + ] # allocate 64 TMEM columns for each stage of S + self.tmem_offset_O0 = self.tmem_cols_S * self.num_stages_S + self.tmem_offset_O1 = self.tmem_offset_O0 + self.tmem_cols_Oi + self.tmem_offsets_O = [self.tmem_offset_O0, self.tmem_offset_O1] + self.total_tmem = self.tmem_offset_O1 + self.tmem_cols_Oi + assert self.total_tmem <= self.tmem_alloc_cols, ( + f"Total TMEM columns allocated {self.total_tmem} exceeds capacity {self.tmem_alloc_cols}" + ) + + def _get_shared_storage_cls(self): + self.buffer_align_bytes = 1024 + + def smem_struct_align(dtype, staged_layout): + return cute.struct.Align[ + cute.struct.MemRange[dtype, cute.cosize(staged_layout)], + self.buffer_align_bytes, + ] + + def mbar_struct(num_stages): + return cute.struct.MemRange[Int64, 2 * num_stages] + + (sQ_struct, sK_struct, sQv0_struct, sQv1_struct, sV0_struct, sV1_struct, sP_struct) = ( + smem_struct_align(dtype, layout) + for dtype, layout in [ + (self.dtype_Q, self.sQ_layout_staged), + (self.dtype_K, self.sK_layout_staged), + (self.dtype_Qv, self.sQvi_layout_staged), + (self.dtype_Qv, self.sQvi_layout_staged), + (self.dtype_V, self.sVi_layout_staged), + (self.dtype_V, self.sVi_layout_staged), + (self.dtype_P, self.sP_layout_staged), + ] + ) + sStats_struct = cute.struct.MemRange[Float32, cute.cosize(self.sStats_layout)] + sScale_struct = cute.struct.MemRange[Float32, cute.cosize(self.sScale_layout)] + sBitmask_struct = cute.struct.MemRange[Uint32, cute.cosize(self.sBitmask_layout)] + + ( + mbar_ptr_Q_struct, + mbar_ptr_K_struct, + mbar_ptr_Qv0_struct, + mbar_ptr_Qv1_struct, + mbar_ptr_V0_struct, + mbar_ptr_V1_struct, + mbar_ptr_S_struct, + mbar_ptr_P_struct, + mbar_ptr_O0_struct, + mbar_ptr_O1_struct, + mbar_sm_stats_struct, + mbar_bitmask_struct, + ) = ( + mbar_struct(n) + for n in [ + self.num_stages_Q, + self.num_stages_K, + self.num_stages_Qvi, + self.num_stages_Qvi, + self.num_stages_Vi, + self.num_stages_Vi, + self.num_stages_S, + self.num_stages_P, + self.num_stages_Oi, + self.num_stages_Oi, + self.num_stages_sm_stats, + self.num_stages_bitmask, + ] + ) + mbar_ptr_tmem_dealloc_struct = Int64 + tmem_holding_buf_struct = Int32 + + self.sched_stages = 1 + clc_response_size = self.sched_stages * 4 if self.use_clc_scheduler else 0 + clc_mbar_size = self.sched_stages * 2 if self.use_clc_scheduler else 0 + + @cute.struct + class SharedStorage: + mbar_ptr_Q: mbar_ptr_Q_struct + mbar_ptr_K: mbar_ptr_K_struct + mbar_ptr_Qv0: mbar_ptr_Qv0_struct + mbar_ptr_Qv1: mbar_ptr_Qv1_struct + mbar_ptr_V0: mbar_ptr_V0_struct + mbar_ptr_V1: mbar_ptr_V1_struct + mbar_ptr_S: mbar_ptr_S_struct + mbar_ptr_P: mbar_ptr_P_struct + mbar_ptr_O0: mbar_ptr_O0_struct + mbar_ptr_O1: mbar_ptr_O1_struct + mbar_ptr_K_cpasync: mbar_ptr_K_struct + mbar_ptr_V0_cpasync: mbar_ptr_V0_struct + mbar_ptr_V1_cpasync: mbar_ptr_V1_struct + mbar_ptr_sm_stats: mbar_sm_stats_struct + mbar_ptr_bitmask: mbar_bitmask_struct + mbar_ptr_tmem_dealloc: mbar_ptr_tmem_dealloc_struct + tmem_holding_buf: tmem_holding_buf_struct + clc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, clc_mbar_size] + clc_response: cute.struct.MemRange[Int32, clc_response_size] + sO_empty_mbar_ptr: cutlass.Int64 + + sRowMax: sStats_struct + sRowSum: sStats_struct + sScale: sScale_struct + sBitmask: sBitmask_struct + sQv0: sQv0_struct + sQv1: sQv1_struct + sQ: sQ_struct + sK: sK_struct + sV0: sV0_struct + sV1: sV1_struct + sP: sP_struct + + # print("smem bytes = ", SharedStorage.size_in_bytes()) + + return SharedStorage + + @cute.jit + def __call__( + self, + mQ: cute.Tensor, # (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + mQv: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, d) if there is cu_seqlens_q + mK: cute.Tensor, # (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table + mV: cute.Tensor, # (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table + mO: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + mLSE: Optional[cute.Tensor], # (b, h, s_q) or (h, total_q) if there is cu_seqlens_q + softmax_scale: Float32, + mCuSeqlensQ: Optional[cute.Tensor] = None, # (b + 1) + mCuSeqlensK: Optional[cute.Tensor] = None, # (b + 1) + mSeqUsedQ: Optional[cute.Tensor] = None, # (b) + mSeqUsedK: Optional[cute.Tensor] = None, # (b) + mIndexTopk: Optional[ + cute.Tensor + ] = None, # (b, s_q, topk) or (total_q, topk) if there is cu_seqlens_q + mPageTable: Optional[cute.Tensor] = None, + window_size_left: Int32 | int | None = None, + window_size_right: Int32 | int | None = None, + # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). + stream: cuda.CUstream = None, + ): + # ==== asserts for unimplemented features ==== + assert mPageTable is None, "page table tbd for MLA" + + # ==== dtype info ==== + self.dtype_Q = mQ.element_type + self.dtype_K = mK.element_type + self.dtype_Qv = mQv.element_type + self.dtype_V = mV.element_type + self.dtype_P = mV.element_type + self.dtype_O = mO.element_type + + # ==== Prepare Tensors ==== + new_stride = lambda mX: ( + *(cute.assume(s, divby=128 // mX.element_type.width) for s in mX.stride[:-1]), + mX.stride[-1], + ) + mQ, mQv, mK, mV, mO = [ + cute.make_tensor(mX.iterator, cute.make_layout(mX.shape, stride=new_stride(mX))) + for mX in (mQ, mQv, mK, mV, mO) + ] + + # (b, s, h, d) -> (s, d, h, b) or + # (total, h, d) -> (total, d, h) or + # (num_pages, page_size, h_k, d) -> (page_size, d, h_k, num_pages) + QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] + KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1] + mQ, mQv, mO = [ + cute.make_tensor(mX.iterator, cute.select(mX.layout, mode=QO_layout_transpose)) + for mX in (mQ, mQv, mO) + ] + mK, mV = [ + cute.make_tensor(mX.iterator, cute.select(mX.layout, mode=KV_layout_transpose)) + for mX in (mK, mV) + ] + # (s_k, dv, h_k, b) -> (dv, s_k, h_k, b) or + # (total_k, dv, h_k) -> (dv, total_k, h_k) + V_layout_transpose = [1, 0, 2, 3] if const_expr(mCuSeqlensK is None) else [1, 0, 2] + mVt = cute.make_tensor(mV.iterator, cute.select(mV.layout, mode=V_layout_transpose)) + # (b, h, s_q) -> (s_q, h, b) or (h, total_q) -> (total_q, h) + # (b, s_q, topk) -> (topk, s_q, b) or (total_q, topk) -> (topk, total_q) + LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0] + mLSE, mIndexTopk = ( + cute.make_tensor(t.iterator, cute.select(t.layout, mode=LSE_layout_transpose)) + if t is not None + else None + for t in (mLSE, mIndexTopk) + ) + topk_length_dynamic = mIndexTopk.shape[0] if mIndexTopk is not None else None + + self.o_layout = cutlass.utils.LayoutEnum.from_tensor(mO) + + mO_og = mO + if const_expr(self.pack_gqa): + mQ, mQv, mO = [ + pack_gqa_layout(mX, self.qhead_per_kvhead, self.nheads_kv, head_idx=2) + for mX in (mQ, mQv, mO) + ] + if const_expr(mLSE is not None): + mLSE = pack_gqa_layout(mLSE, self.qhead_per_kvhead, self.nheads_kv, head_idx=1) + + def split_hdimv(m, dim: int): + """Re-tile mode `dim` of tensor `m` from hdimv into (hdimv//S, S), + and return (slice0, slice1) where slice_i selects chunk i.""" + S = self.num_hdimv_splits + chunk = self.hdimv // S + split_shape = (*m.shape[:dim], (chunk, S), *m.shape[dim + 1 :]) + split_stride = (*m.layout.stride[:dim], (1, chunk), *m.layout.stride[dim + 1 :]) + split = cute.make_tensor(m.iterator, cute.make_layout(split_shape, stride=split_stride)) + ndim = len(split.shape) + slices = [ + split[(*([None] * dim), (None, i), *([None] * (ndim - dim - 1)))] for i in range(S) + ] + return slices + + # (seqlen_q, hdimv//2, nheads, batch) or (total_q, hdimv//2, nheads) + mQv0, mQv1 = split_hdimv(mQv, dim=1) + mV0, mV1 = split_hdimv(mV, dim=1) + # (hdimv//2, seqlen_k, nheads_k, batch) or (hdimv//2, total_k, nheads_k) + mVt0, mVt1 = split_hdimv(mVt, dim=0) + + # ==== Prepare MMAs ==== + # (local_var, dtype_a, major_a, major_b, mma_tiler, operand_source_a) + # fmt: off + _mma_specs = [ + ("tiled_mma_QK", self.dtype_Q, self.major_mode_Q, self.major_mode_K, self.mma_tiler_QK, self.operand_source_Q), + ("tiled_mma_QviVi", self.dtype_Qv, self.major_mode_Qvi, self.major_mode_Vi, self.mma_tiler_QviVi, self.operand_source_Qvi), + ("tiled_mma_PVti", self.dtype_P, self.major_mode_P, self.major_mode_Vti, self.mma_tiler_PVti, self.operand_source_P), + ] + tiled_mma_QK, tiled_mma_QviVi, tiled_mma_PVti = ( + sm100_utils.make_trivial_tiled_mma( + dtype_a, major_a, major_b, self.dtype_acc, self.cta_group, mma_tiler[:2], operand_source_a, + ) + for _, dtype_a, major_a, major_b, mma_tiler, operand_source_a in _mma_specs + ) + # fmt: on + + # ==== Prepare SMEM layouts and TMAs ==== + # (attr, make_fn, tiled_mma, mma_tiler, dtype, num_stages) + # fmt: off + _smem_layout_specs = [ + ("sQ_layout", sm100_utils.make_smem_layout_a, tiled_mma_QK, self.mma_tiler_QK, self.dtype_Q, self.num_stages_Q), + ("sK_layout", sm100_utils.make_smem_layout_b, tiled_mma_QK, self.mma_tiler_QK, self.dtype_K, self.num_stages_K), + ("sQvi_layout", sm100_utils.make_smem_layout_a, tiled_mma_QviVi, self.mma_tiler_QviVi, self.dtype_Qv, self.num_stages_Qvi), + ("sVi_layout", sm100_utils.make_smem_layout_b, tiled_mma_QviVi, self.mma_tiler_QviVi, self.dtype_V, self.num_stages_Vi), + ("sVti_layout", sm100_utils.make_smem_layout_b, tiled_mma_PVti, self.mma_tiler_PVti, self.dtype_V, self.num_stages_Vi), + ("sP_layout", sm100_utils.make_smem_layout_a, tiled_mma_PVti, self.mma_tiler_PVti, self.dtype_P, self.num_stages_P), + ] + for attr, make_fn, tiled_mma, mma_tiler, dtype, num_stages in _smem_layout_specs: + ab_kwarg = "a_dtype" if make_fn is sm100_utils.make_smem_layout_a else "b_dtype" + staged = make_fn( + tiled_mma=tiled_mma, + mma_tiler_mnk=mma_tiler, + num_stages=num_stages, + **{ab_kwarg: dtype}, + ) + setattr(self, f"{attr}_staged", staged) + setattr(self, attr, cute.select(staged, mode=[0, 1, 2])) + # fmt: on + + self.sStats_layout = cute.make_layout((self.cta_tile_m, self.cta_group_size)) + self.sScale_layout = cute.make_layout((self.cta_tile_m, self.num_stages_sm_stats)) + self.sBitmask_layout = cute.make_layout((self.tile_n // 32, self.num_stages_bitmask)) + + # fmt: off + for attr, dtype, layout in [ + ("tma_copy_bytes_Q", self.dtype_Q, self.sQ_layout), + ("tma_copy_bytes_K", self.dtype_K, self.sK_layout), + ("tma_copy_bytes_Qvi", self.dtype_Qv, self.sQvi_layout), + ("tma_copy_bytes_Vi", self.dtype_V, self.sVi_layout), + ]: + setattr(self, attr, cute.size_in_bytes(dtype, layout) * self.cta_group_size) + # fmt: on + + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(self.cta_group) + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), (tiled_mma_QK.thr_id.shape,) + ) + cta_shape = cta_layout_vmnk.shape + + def make_tma(make_fn, mX, smem_layout, mma_tiler, tiled_mma): + return make_fn(tma_load_op, mX, smem_layout, mma_tiler, tiled_mma, cta_shape) + + A, B = cute.nvgpu.make_tiled_tma_atom_A, cute.nvgpu.make_tiled_tma_atom_B + + # (atom_name, tensor_name, make_fn, m, smem_layout, mma_tiler, tiled_mma, kv_only) + # fmt: off + _tma_specs = [ + ("tma_atom_Q", "tma_tensor_Q", A, mQ, self.sQ_layout, self.mma_tiler_QK, tiled_mma_QK, False), + ("tma_atom_Qv0", "tma_tensor_Qv0", A, mQv0, self.sQvi_layout, self.mma_tiler_QviVi, tiled_mma_QviVi, False), + ("tma_atom_Qv1", "tma_tensor_Qv1", A, mQv1, self.sQvi_layout, self.mma_tiler_QviVi, tiled_mma_QviVi, False), + ("tma_atom_K", "tma_tensor_K", B, mK, self.sK_layout, self.mma_tiler_QK, tiled_mma_QK, True), + ("tma_atom_V0", "tma_tensor_V0", B, mV0, self.sVi_layout, self.mma_tiler_QviVi, tiled_mma_QviVi, True), + ("tma_atom_V1", "tma_tensor_V1", B, mV1, self.sVi_layout, self.mma_tiler_QviVi, tiled_mma_QviVi, True), + ("tma_atom_Vt0", "tma_tensor_Vt0", B, mVt0, self.sVti_layout, self.mma_tiler_PVti, tiled_mma_PVti, True), + ("tma_atom_Vt1", "tma_tensor_Vt1", B, mVt1, self.sVti_layout, self.mma_tiler_PVti, tiled_mma_PVti, True), + ] + _tmas = {} + for atom_name, tensor_name, make_fn, m, smem_layout, mma_tiler, tiled_mma, kv_only in _tma_specs: + _tmas[atom_name], _tmas[tensor_name] = ( + make_tma(make_fn, m, smem_layout, mma_tiler, tiled_mma) + if const_expr(not kv_only or self.use_tma_KV) + else (None, None) + ) + + (tma_atom_Q, tma_tensor_Q, + tma_atom_Qv0, tma_tensor_Qv0, + tma_atom_Qv1, tma_tensor_Qv1, + tma_atom_K, tma_tensor_K, + tma_atom_V0, tma_tensor_V0, + tma_atom_V1, tma_tensor_V1, + tma_atom_Vt0, tma_tensor_Vt0, + tma_atom_Vt1, tma_tensor_Vt1) = _tmas.values() + # fmt: on + + # ==== Set up Oi smem -> gmem tma store ==== + + self.overlap_sO_sV = True + if const_expr(self.overlap_sO_sV): + num_stages_sO = self.num_hdimv_splits * self.num_stages_Vi + else: + num_stages_sO = self.num_hdimv_splits + sO_layout = sm100_utils.make_smem_layout_epi( + self.dtype_O, self.o_layout, self.epi_tile, num_stages_sO + ) + self.ragged_tma_O = ( + self.use_tma_O + and self.is_varlen_q + and self.pack_gqa + and self.cta_tile_m % self.qhead_per_kvhead == 0 + ) + make_tiled_tma_atom_fn = ( + partial(make_packgqa_tiled_tma_atom, qhead_per_kvhead=self.qhead_per_kvhead, head_idx=2) + if const_expr(self.ragged_tma_O) + else cpasync.make_tiled_tma_atom + ) + if const_expr(self.use_tma_O): + mO_tma = mO_og if const_expr(self.ragged_tma_O) else mO + if const_expr(self.ragged_tma_O): + mO_tma = copy_utils.create_ragged_tensor_for_tma( + mO_tma, ragged_dim=0, ptr_shift=True + ) + tma_store_op = cpasync.CopyBulkTensorTileS2GOp() + tma_atom_O, tma_tensor_O = make_tiled_tma_atom_fn( + tma_store_op, mO_tma, cute.select(sO_layout, mode=[0, 1]), self.epi_tile + ) + else: + tma_atom_O = None + tma_tensor_O = None + + # ==== Set up Oi rmem -> gmem copy ==== + universal_copy_bits = 128 + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dtype_O, + num_bits_per_copy=universal_copy_bits, + ) + thread_layout_O_r2g = cute.make_layout((64, 2), stride=(1, 64)) + value_layout_O_r2g = cute.make_layout( + (1, self.hdimv // self.num_hdimv_splits // self.cta_group_size) + ) + tiled_copy_O_r2g = cute.make_tiled_copy_tv( + atom=atom_universal_copy, + thr_layout=thread_layout_O_r2g, + val_layout=value_layout_O_r2g, + ) + + # ==== Allocate shared memory ==== + SharedStorage = self._get_shared_storage_cls() + + # ==== Tile scheduler ==== + + TileScheduler = self.TileScheduler + + tile_sched_args = TileSchedulerArguments( + num_block=cute.ceil_div(cute.size(mQ.shape[0]), self.cta_tile_m), + num_head=cute.size(mQ.shape[2]), + num_batch=cute.size(mQ.shape[3]) + if const_expr(mCuSeqlensQ is None) + else cute.size(mCuSeqlensQ.shape[0] - 1), + num_splits=1, # todo: split_kv + seqlen_k=cute.size(mK.shape[0]), # todo: page table + headdim=self.hdim, + headdim_v=self.hdimv, + total_q=cute.size(mQ.shape[0]) + if const_expr(mCuSeqlensQ is not None) + else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]), + tile_shape_mn=( + self.cta_tile_m, + self.tile_n, + ), + mCuSeqlensQ=mCuSeqlensQ, + mSeqUsedQ=mSeqUsedQ, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + element_size=self.dtype_K.width // 8, + is_persistent=self.is_persistent, + # lpt=self.is_causal or self.is_local, + lpt=False, + is_split_kv=False, + cluster_shape_mn=self.cluster_shape_mn, + use_cluster_idx=False, + ) + tile_sched_params = TileScheduler.to_underlying_arguments( + tile_sched_args, scheduling_mode=self.scheduling_mode + ) + self.tile_scheduler_cls = TileScheduler + grid_dim = TileScheduler.get_grid_shape(tile_sched_params) + fa_printf(1, "grid = {}", grid_dim) + + # ==== Named Barrier ==== + self.cpasync_barrier = cutlass.pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100_MLA2CTA.Cpasync), + num_threads=self.num_cpasync_load_threads, + ) + self.softmax_barrier = cutlass.pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100_MLA2CTA.Softmax), + num_threads=self.num_softmax_threads, + ) + self.epi_barrier = cutlass.pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100_MLA2CTA.Epilogue), + num_threads=self.num_epilogue_threads, + ) + # softmax -> correction + self.sm_stats_barrier_full = cutlass.pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100_MLA2CTA.SoftmaxStatsFull), + num_threads=self.num_softmax_threads + self.num_epilogue_threads, + ) + self.sm_stats_barrier_empty = cutlass.pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100_MLA2CTA.SoftmaxStatsEmpty), + num_threads=self.num_softmax_threads + self.num_epilogue_threads, + ) + + LOG2_E = math.log2(math.e) + softmax_scale_log2 = softmax_scale * LOG2_E + + # ==== Launch kernel ==== + self.kernel( + tma_tensor_Q, + tma_tensor_Qv0, + tma_tensor_Qv1, + tma_tensor_K if self.use_tma_KV else mK, + tma_tensor_V0 if self.use_tma_KV else mV0, + tma_tensor_V1 if self.use_tma_KV else mV1, + tma_tensor_Vt0 if self.use_tma_KV else mVt0, + tma_tensor_Vt1 if self.use_tma_KV else mVt1, + tma_tensor_O if self.use_tma_O else mO, + mLSE, + mCuSeqlensQ, + mCuSeqlensK, + mSeqUsedQ, + mSeqUsedK, + mIndexTopk, + tma_atom_Q, + tma_atom_Qv0, + tma_atom_Qv1, + tma_atom_K, + tma_atom_V0, + tma_atom_V1, + tma_atom_Vt0, + tma_atom_Vt1, + tma_atom_O, + tiled_copy_O_r2g, + self.sQ_layout_staged, + self.sK_layout_staged, + self.sQvi_layout_staged, + self.sVi_layout_staged, + self.sVti_layout_staged, + self.sP_layout_staged, + self.sStats_layout, + self.sScale_layout, + self.sBitmask_layout, + sO_layout, + tiled_mma_QK, + tiled_mma_QviVi, + tiled_mma_PVti, + softmax_scale, + softmax_scale_log2, + topk_length_dynamic, + tile_sched_params, + SharedStorage, + ).launch( + grid=grid_dim, + block=( + self.num_threads, + 1, + 1, + ), + cluster=self.cluster_shape_mnk, + smem=SharedStorage.size_in_bytes(), + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mQ: cute.Tensor, + mQv0: cute.Tensor, + mQv1: cute.Tensor, + mK: cute.Tensor, + mV0: cute.Tensor, + mV1: cute.Tensor, + mVt0: cute.Tensor, + mVt1: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + mCuSeqlensQ: Optional[cute.Tensor], + mCuSeqlensK: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], + mSeqUsedK: Optional[cute.Tensor], + mIndexTopk: Optional[cute.Tensor], + tma_atom_Q: cute.CopyAtom, + tma_atom_Qv0: cute.CopyAtom, + tma_atom_Qv1: cute.CopyAtom, + tma_atom_K: Optional[cute.CopyAtom], + tma_atom_V0: Optional[cute.CopyAtom], + tma_atom_V1: Optional[cute.CopyAtom], + tma_atom_Vt0: Optional[cute.CopyAtom], + tma_atom_Vt1: Optional[cute.CopyAtom], + tma_atom_O: Optional[cute.CopyAtom], + tiled_copy_O_r2g: cute.TiledCopy, + sQ_layout_staged: cute.ComposedLayout, + sK_layout_staged: cute.ComposedLayout, + sQvi_layout_staged: cute.ComposedLayout, + sVi_layout_staged: cute.ComposedLayout, + sVti_layout_staged: cute.ComposedLayout, + sP_layout_staged: cute.ComposedLayout, + sStats_layout: cute.Layout, + sScale_layout: cute.Layout, + sBitmask_layout: cute.Layout, + sO_layout: cute.ComposedLayout, + tiled_mma_QK: cute.TiledMma, + tiled_mma_QviVi: cute.TiledMma, + tiled_mma_PVti: cute.TiledMma, + softmax_scale: Float32, + softmax_scale_log2: Float32, + topk_length_dynamic: Optional[Int32], + tile_sched_params: ParamsBase, + SharedStorage: cutlass.Constexpr[Callable], + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), (tiled_mma_QK.thr_id.shape,) + ) + + cta_m_block, head_idx, batch_idx = cute.arch.block_idx() + cluster_m_block = cta_m_block // self.cta_group_size + mma_tile_coord_v = cta_m_block % cute.size(tiled_mma_QK.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + + # ==== Allocate SMEM ==== + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + + # ==== TMEM stuff ==== + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100_MLA2CTA.TmemPtr), + num_threads=self.num_mma_threads + self.num_softmax_threads + self.num_epilogue_threads, + ) + tmem = cutlass.utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=tmem_alloc_barrier, + allocator_warp_id=self.mma_warp_id, + is_two_cta=self.use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.mbar_ptr_tmem_dealloc, + ) + + # ==== Prefetch TMA descriptors ==== + if warp_idx == self.load_warp_id: + cpasync.prefetch_descriptor(tma_atom_Q) + cpasync.prefetch_descriptor(tma_atom_Qv0) + cpasync.prefetch_descriptor(tma_atom_Qv1) + if const_expr(self.use_tma_KV): + cpasync.prefetch_descriptor(tma_atom_K) + cpasync.prefetch_descriptor(tma_atom_V0) + cpasync.prefetch_descriptor(tma_atom_V1) + cpasync.prefetch_descriptor(tma_atom_Vt0) + cpasync.prefetch_descriptor(tma_atom_Vt1) + if const_expr(self.use_tma_O): + cpasync.prefetch_descriptor(tma_atom_O) + + # ==== Construct pipelines ==== + tma_warp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 1) + mma_warp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 1) + sm_threads = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.num_softmax_threads) + epi_threads = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.num_epilogue_threads) + sm_threads_cluster = pipeline.CooperativeGroup( + pipeline.Agent.Thread, self.num_softmax_threads * self.cta_group_size + ) + epi_threads_cluster = pipeline.CooperativeGroup( + pipeline.Agent.Thread, self.num_epilogue_threads * self.cta_group_size + ) + + TmaUmma = pipeline.PipelineTmaUmma + AsyncUmma = pipeline.PipelineAsyncUmma + UmmaAsync = pipeline.PipelineUmmaAsync + Async = pipeline.PipelineAsync + + def make_pipeline(cls, mbar_ptr, num_stages, producer, consumer, tx_count=None): + return cls.create( + barrier_storage=mbar_ptr.data_ptr(), + num_stages=num_stages, + producer_group=producer, + consumer_group=consumer, + defer_sync=True, + **({"cta_layout_vmnk": cta_layout_vmnk} if cls is not Async else {}), + **({"tx_count": tx_count} if tx_count is not None else {}), + ) + + # Unconditional pipelines + # fmt: off + pipeline_Q = make_pipeline(TmaUmma, storage.mbar_ptr_Q, self.num_stages_Q, tma_warp, mma_warp, self.tma_copy_bytes_Q) + pipeline_Qv0 = make_pipeline(TmaUmma, storage.mbar_ptr_Qv0, self.num_stages_Qvi, tma_warp, mma_warp, self.tma_copy_bytes_Qvi) + pipeline_Qv1 = make_pipeline(TmaUmma, storage.mbar_ptr_Qv1, self.num_stages_Qvi, tma_warp, mma_warp, self.tma_copy_bytes_Qvi) + pipeline_S = make_pipeline(UmmaAsync, storage.mbar_ptr_S, self.num_stages_S, mma_warp, sm_threads_cluster) + pipeline_P = make_pipeline(AsyncUmma, storage.mbar_ptr_P, self.num_stages_P, sm_threads_cluster, mma_warp) + pipeline_O0 = make_pipeline(UmmaAsync, storage.mbar_ptr_O0, self.num_stages_Oi, mma_warp, epi_threads_cluster) + pipeline_O1 = make_pipeline(UmmaAsync, storage.mbar_ptr_O1, self.num_stages_Oi, mma_warp, epi_threads_cluster) + pipeline_sm_stats = make_pipeline(Async, storage.mbar_ptr_sm_stats, self.num_stages_sm_stats, sm_threads, epi_threads) + + # K/V pipelines: type and producer depend on use_tma_KV + if const_expr(self.use_tma_KV): + pipeline_K = make_pipeline(TmaUmma, storage.mbar_ptr_K, self.num_stages_K, tma_warp, mma_warp, self.tma_copy_bytes_K) + pipeline_V0 = make_pipeline(TmaUmma, storage.mbar_ptr_V0, self.num_stages_Vi, tma_warp, mma_warp, self.tma_copy_bytes_Vi) + pipeline_V1 = make_pipeline(TmaUmma, storage.mbar_ptr_V1, self.num_stages_Vi, tma_warp, mma_warp, self.tma_copy_bytes_Vi) + pipeline_K_cpasync = pipeline_V0_cpasync = pipeline_V1_cpasync = pipeline_bitmask = None + else: + cpasync_load_threads = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.num_cpasync_load_threads) + relay_warps_cluster = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.cta_group_size) + relay_threads = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.num_relay_threads) + + pipeline_K = make_pipeline(AsyncUmma, storage.mbar_ptr_K, self.num_stages_K, relay_warps_cluster, mma_warp) + pipeline_V0 = make_pipeline(AsyncUmma, storage.mbar_ptr_V0, self.num_stages_Vi, relay_warps_cluster, mma_warp) + pipeline_V1 = make_pipeline(AsyncUmma, storage.mbar_ptr_V1, self.num_stages_Vi, relay_warps_cluster, mma_warp) + pipeline_K_cpasync = make_pipeline(Async, storage.mbar_ptr_K_cpasync, self.num_stages_K, cpasync_load_threads, relay_threads) + pipeline_V0_cpasync = make_pipeline(Async, storage.mbar_ptr_V0_cpasync, self.num_stages_Vi, cpasync_load_threads, relay_threads) + pipeline_V1_cpasync = make_pipeline(Async, storage.mbar_ptr_V1_cpasync, self.num_stages_Vi, cpasync_load_threads, relay_threads) + pipeline_bitmask = ( + make_pipeline(Async, storage.mbar_ptr_bitmask, self.num_stages_bitmask, cpasync_load_threads, sm_threads) + if const_expr(self.is_topk_gather and not self.disable_bitmask) else None + ) + # fmt: on + + sO_empty_mbar_ptr = None + if const_expr(self.use_tma_O and self.overlap_sO_sV): + sO_empty_mbar_ptr = storage.sO_empty_mbar_ptr + if warp_idx == 0: + cute.arch.mbarrier_init(sO_empty_mbar_ptr, 1) + + pipeline.pipeline_init_arrive(cluster_shape_mn=cta_layout_vmnk, is_relaxed=True) + + # ==== Get SMEM tensors ==== + # fmt: off + sQ, sK, sQv0, sQv1, sV0, sV1, sVt0, sVt1, sP = ( + store.get_tensor(layout.outer, swizzle=layout.inner) + for store, layout in [ + (storage.sQ, sQ_layout_staged), + (storage.sK, sK_layout_staged), + (storage.sQv0, sQvi_layout_staged), + (storage.sQv1, sQvi_layout_staged), + (storage.sV0, sVi_layout_staged), + (storage.sV1, sVi_layout_staged), + (storage.sV0, sVti_layout_staged), # sVt0 reuses sV0 storage + (storage.sV1, sVti_layout_staged), # sVt1 reuses sV1 storage + (storage.sP, sP_layout_staged), + ] + ) + # fmt: on + sRowMax = storage.sRowMax.get_tensor(sStats_layout) + sRowSum = storage.sRowSum.get_tensor(sStats_layout) + sScale = storage.sScale.get_tensor(sScale_layout) + sBitmask = None + if const_expr(self.is_topk_gather): + sBitmask = storage.sBitmask.get_tensor(sBitmask_layout) + + if const_expr(self.overlap_sO_sV): + sO_iterator = sV0.iterator + assert cute.cosize(sO_layout) <= cute.cosize(sVi_layout_staged) * self.num_hdimv_splits + else: + sO_iterator = sQv0.iterator + assert cute.cosize(sO_layout) <= cute.cosize(sQvi_layout_staged) * self.num_hdimv_splits + sO = cute.make_tensor( + cute.recast_ptr(sO_iterator, sO_layout.inner, self.dtype_O), sO_layout.outer + ) + + # ==== Get thread MMAs and accumulator fragments ==== + thr_mma_QK = tiled_mma_QK.get_slice(mma_tile_coord_v) + thr_mma_QviVi = tiled_mma_QviVi.get_slice(mma_tile_coord_v) + thr_mma_PVti = tiled_mma_PVti.get_slice(mma_tile_coord_v) + + acc_shape_QK = thr_mma_QK.partition_shape_C(self.mma_tiler_QK[:2]) + tStS = thr_mma_QK.make_fragment_C(cute.append(acc_shape_QK, self.num_stages_S)) + + acc_shape_PVi = thr_mma_PVti.partition_shape_C(self.mma_tiler_PVti[:2]) + tO0tO0 = thr_mma_PVti.make_fragment_C(acc_shape_PVi) + tO1tO1 = thr_mma_PVti.make_fragment_C(acc_shape_PVi) + tO0tO0 = cute.make_tensor(tO0tO0.iterator + self.tmem_offset_O0, tO0tO0.layout) + tO1tO1 = cute.make_tensor(tO1tO1.iterator + self.tmem_offset_O1, tO1tO1.layout) + + block_info = BlockInfo( + self.cta_tile_m * self.cta_group_size, + self.tile_n, + is_causal=self.is_causal, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + SeqlenInfoCls = partial( + SeqlenInfoQK.create, + seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1], + seqlen_k_static=mK.shape[0], + tile_m=self.cta_tile_m, + tile_n=self.tile_n, + mCuSeqlensQ=mCuSeqlensQ, + mCuSeqlensK=mCuSeqlensK, + mSeqUsedQ=mSeqUsedQ, + mSeqUsedK=mSeqUsedK, + ) + AttentionMaskCls = partial( + AttentionMask, + self.cta_tile_m * self.cta_group_size, + self.tile_n, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + + if const_expr(self.use_clc_scheduler): + clc_response_ptr = storage.clc_response.data_ptr() + clc_mbar_ptr = storage.clc_mbar_ptr.data_ptr() + + clc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_clc_consumer_warps_per_cta = self.num_threads // cute.arch.WARP_SIZE + num_clc_consumer_warps = num_clc_consumer_warps_per_cta * self.cta_group_size + clc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, cute.arch.WARP_SIZE * num_clc_consumer_warps + ) + clc = ClcState.create( + hw_scheduler=ClcDynamicPersistentTileScheduler.create( + self.tile_scheduler_cls.clc_problem_shape(tile_sched_params), + cute.arch.block_idx(), + cute.arch.grid_dim(), + clc_response_ptr, + ), + pipeline=pipeline.PipelineClcFetchAsync.create( + barrier_storage=clc_mbar_ptr, + num_stages=self.sched_stages, + producer_group=clc_pipeline_producer_group, + consumer_group=clc_pipeline_consumer_group, + tx_count=16, + cta_layout_vmnk=cta_layout_vmnk, + ), + consumer_state=pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.sched_stages + ), + producer_state=pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.sched_stages + ), + ) + tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params, clc=clc) + else: + tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params) + assert isinstance(tile_scheduler, TileSchedulerProtocol), ( + f"tile_scheduler is not a TileSchedulerProtocol: {type(tile_scheduler)}" + ) + + pipeline.pipeline_init_wait(cluster_shape_mn=cta_layout_vmnk) + + if const_expr(self.use_clc_scheduler): + if warp_idx == self.clc_scheduler_warp_id: + cute.arch.setmaxregister_decrease(self.num_regs_other) + if is_leader_cta: + self.clc_scheduler_warp(tile_scheduler) + else: + self.empty_warp(tile_scheduler) + for i in cutlass.range_constexpr(len(self.empty_warp_ids)): + if warp_idx == self.empty_warp_ids[i] and warp_idx != self.clc_scheduler_warp_id: + cute.arch.setmaxregister_decrease(self.num_regs_other) + self.empty_warp(tile_scheduler) + else: + for i in cutlass.range_constexpr(len(self.empty_warp_ids)): + if warp_idx == self.empty_warp_ids[i]: + cute.arch.setmaxregister_decrease(self.num_regs_other) + + if const_expr(self.use_cpasync_load_KV): + if warp_idx == self.relay_warp_id: + if const_expr(self.num_regs_load < self.num_regs_per_thread): + cute.arch.setmaxregister_decrease(self.num_regs_load) + self.relay( + pipeline_K, + pipeline_V0, + pipeline_V1, + pipeline_K_cpasync, + pipeline_V0_cpasync, + pipeline_V1_cpasync, + sO_empty_mbar_ptr, + topk_length_dynamic, + block_info, + SeqlenInfoCls, + tile_scheduler=tile_scheduler, + ) + + if warp_idx in self.cpasync_load_warp_indices: + if const_expr(self.num_regs_cpasync < self.num_regs_per_thread): + cute.arch.setmaxregister_decrease(self.num_regs_cpasync) + self.load_cpasync( + mIndexTopk, + mK, + mV0, + mV1, + mVt0, + mVt1, + sK, + sV0, + sV1, + sVt0, + sVt1, + sBitmask, + pipeline_K, + pipeline_V0, + pipeline_V1, + pipeline_K_cpasync, + pipeline_V0_cpasync, + pipeline_V1_cpasync, + pipeline_bitmask, + sO_empty_mbar_ptr, + topk_length_dynamic, + block_info, + SeqlenInfoCls, + tile_scheduler=tile_scheduler, + ) + + if warp_idx == self.load_warp_id: + if const_expr(self.num_regs_load < self.num_regs_per_thread): + cute.arch.setmaxregister_decrease(self.num_regs_load) + self.load( + mQ, + mK, + mQv0, + mQv1, + mV0, + mV1, + mVt0, + mVt1, + sQ, + sK, + sQv0, + sQv1, + sV0, + sV1, + sVt0, + sVt1, + tma_atom_Q, + tma_atom_K, + tma_atom_Qv0, + tma_atom_Qv1, + tma_atom_V0, + tma_atom_V1, + tma_atom_Vt0, + tma_atom_Vt1, + pipeline_Q, + pipeline_K, + pipeline_Qv0, + pipeline_Qv1, + pipeline_V0, + pipeline_V1, + sO_empty_mbar_ptr, + thr_mma_QK, + thr_mma_QviVi, + thr_mma_PVti, + topk_length_dynamic, + block_info, + SeqlenInfoCls, + tile_scheduler=tile_scheduler, + ) + + if warp_idx == self.mma_warp_id: + if const_expr(self.num_regs_mma < self.num_regs_per_thread): + cute.arch.setmaxregister_decrease(self.num_regs_mma) + # ==== Allocate TMEM ==== + tmem.allocate(self.tmem_alloc_cols) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.dtype_acc) + self.mma( + sQ, + sK, + sQv0, + sQv1, + sV0, + sV1, + sVt0, + sVt1, + sP, + tiled_mma_QK, + tiled_mma_QviVi, + tiled_mma_PVti, + pipeline_Q, + pipeline_K, + pipeline_Qv0, + pipeline_Qv1, + pipeline_V0, + pipeline_V1, + pipeline_S, + pipeline_P, + pipeline_O0, + pipeline_O1, + sO_empty_mbar_ptr, + is_leader_cta, + topk_length_dynamic, + block_info, + SeqlenInfoCls, + tile_scheduler=tile_scheduler, + ) + tmem.relinquish_alloc_permit() + tmem_alloc_barrier.arrive_and_wait() + tmem.free(tmem_ptr) + + if warp_idx in self.softmax_warp_indices: + cute.arch.setmaxregister_increase(self.num_regs_softmax) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.dtype_acc) + self.softmax_loop( + softmax_scale, + softmax_scale_log2, + mLSE, + sRowMax, + sRowSum, + sScale, + sBitmask, + sP, + tStS, + thr_mma_QK, + pipeline_S, + pipeline_P, + pipeline_sm_stats, + pipeline_bitmask, + sO_empty_mbar_ptr, + AttentionMaskCls, + topk_length_dynamic, + block_info, + SeqlenInfoCls, + tile_scheduler=tile_scheduler, + ) + tmem_alloc_barrier.arrive() + + if warp_idx in self.epilogue_warp_indices: + if const_expr(self.num_regs_epilogue < self.num_regs_per_thread): + cute.arch.setmaxregister_decrease(self.num_regs_epilogue) + elif const_expr(self.num_regs_epilogue > self.num_regs_per_thread): + cute.arch.setmaxregister_increase(self.num_regs_epilogue) + + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.dtype_acc) + self.correction_loop( + softmax_scale_log2, + mO, + mLSE, + tma_atom_O, + sRowMax, + sRowSum, + sScale, + sO, + tO0tO0, + tO1tO1, + pipeline_O0, + pipeline_O1, + pipeline_sm_stats, + sO_empty_mbar_ptr, + tiled_copy_O_r2g, + topk_length_dynamic, + block_info, + SeqlenInfoCls, + tile_scheduler=tile_scheduler, + ) + tmem_alloc_barrier.arrive() + + @cute.jit + def clc_scheduler_warp( + self, + tile_scheduler: TileSchedulerProtocol, + ): + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + tile_scheduler.prefetch_next_work() + work_tile = tile_scheduler.advance_to_next_work() + cta_m_block, head_idx, batch_idx, _ = work_tile.tile_idx + if cute.arch.thread_idx()[0] == self.clc_scheduler_warp_id * cute.arch.WARP_SIZE: + fa_printf( + 3, + "[CLC] query sm={} cta={} (m_blk={},h={},b={},s={}) valid={}\n", + smid(), + cute.arch.block_idx()[0], + work_tile.tile_idx[0], + work_tile.tile_idx[1], + work_tile.tile_idx[2], + work_tile.tile_idx[3], + work_tile.is_valid_tile, + ) + tile_scheduler.producer_tail() + + @cute.jit + def empty_warp( + self, + tile_scheduler: TileSchedulerProtocol, + ): + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + work_tile = tile_scheduler.advance_to_next_work() + + @cute.jit + def relay( + self, + pipeline_K: pipeline.PipelineAsyncUmma, + pipeline_V0: pipeline.PipelineAsyncUmma, + pipeline_V1: pipeline.PipelineAsyncUmma, + pipeline_K_cpasync: pipeline.PipelineAsync, + pipeline_V0_cpasync: pipeline.PipelineAsync, + pipeline_V1_cpasync: pipeline.PipelineAsync, + sO_empty_mbar_ptr: Optional[cute.Pointer], + topk_length_dynamic: Optional[Int32], + block_info: BlockInfo, + SeqlenInfoCls: Callable, + tile_scheduler: TileSchedulerProtocol, + ): + # ==== Make pipeline states ==== + # pipeline_{K,V0,V1} producer + # pipeline_{K,V0,V1}_cpasync consumer + producer_state_K = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_K + ) + producer_state_V0 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_Vi + ) + producer_state_V1 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_Vi + ) + consumer_state_K = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_K + ) + consumer_state_V0 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_Vi + ) + consumer_state_V1 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_Vi + ) + relay_K_fn = partial(self.relay_inner, pipeline_K_cpasync, pipeline_K) + relay_V0_fn = partial(self.relay_inner, pipeline_V0_cpasync, pipeline_V0) + relay_V1_fn = partial(self.relay_inner, pipeline_V1_cpasync, pipeline_V1) + + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + cta_m_block, head_idx, batch_idx, _ = work_tile.tile_idx + cluster_m_block = cta_m_block // self.cta_group_size + + seqlen = SeqlenInfoCls(batch_idx) + if const_expr(self.is_topk_gather): + n_block_min = 0 + n_block_max = self.topk_length // self.tile_n + # n_block_max = topk_length_dynamic // self.tile_n + else: + n_block_min, n_block_max = block_info.get_n_block_min_max( + seqlen, + cluster_m_block, + ) + num_n_blocks = n_block_max - n_block_min + + # ==== Prologue ==== + # relay K, V0, V1 + consumer_state_K, producer_state_K = relay_K_fn(consumer_state_K, producer_state_K) + consumer_state_V0, producer_state_V0 = relay_V0_fn(consumer_state_V0, producer_state_V0) + consumer_state_V1, producer_state_V1 = relay_V1_fn(consumer_state_V1, producer_state_V1) + + # ==== Mainloop ==== + for _ in cutlass.range(num_n_blocks - 1, unroll=2): + # relay K, V0, V1, Vt0, Vt1 + consumer_state_K, producer_state_K = relay_K_fn(consumer_state_K, producer_state_K) + for _ in cutlass.range_constexpr(2): + consumer_state_V0, producer_state_V0 = relay_V0_fn( + consumer_state_V0, producer_state_V0 + ) + consumer_state_V1, producer_state_V1 = relay_V1_fn( + consumer_state_V1, producer_state_V1 + ) + + # ==== Epilogue === + # relay Vt0, Vt1 + consumer_state_V0, producer_state_V0 = relay_V0_fn(consumer_state_V0, producer_state_V0) + consumer_state_V1, producer_state_V1 = relay_V1_fn(consumer_state_V1, producer_state_V1) + + # Advance to next tile + work_tile = tile_scheduler.advance_to_next_work() + + pipeline_K.producer_tail(producer_state_K) + pipeline_V0.producer_tail(producer_state_V0) + pipeline_V1.producer_tail(producer_state_V1) + + @cute.jit + def relay_inner( + self, + pipeline_cpasync: pipeline.PipelineAsync, + pipeline_mma: pipeline.PipelineAsyncUmma, + consumer_state: pipeline.PipelineState, + producer_state: pipeline.PipelineState, + ): + pipeline_cpasync.consumer_wait(consumer_state) + with cute.arch.elect_one(): + pipeline_mma.producer_commit(producer_state) + consumer_state.advance() + producer_state.advance() + return consumer_state, producer_state + + @cute.jit + def load_cpasync( + self, + mIndexTopk: cute.Tensor, + mK: cute.Tensor, + mV0: cute.Tensor, + mV1: cute.Tensor, + mVt0: cute.Tensor, + mVt1: cute.Tensor, + sK: cute.Tensor, + sV0: cute.Tensor, + sV1: cute.Tensor, + sVt0: cute.Tensor, + sVt1: cute.Tensor, + sBitmask: Optional[cute.Tensor], + pipeline_K: pipeline.PipelineAsyncUmma, + pipeline_V0: pipeline.PipelineAsyncUmma, + pipeline_V1: pipeline.PipelineAsyncUmma, + pipeline_K_cpasync: pipeline.PipelineAsync, + pipeline_V0_cpasync: pipeline.PipelineAsync, + pipeline_V1_cpasync: pipeline.PipelineAsync, + pipeline_bitmask: pipeline.PipelineAsync, + sO_empty_mbar_ptr: Optional[cute.Pointer], + topk_length_dynamic: Optional[Int32], + block_info: BlockInfo, + SeqlenInfoCls: Callable, + tile_scheduler: TileSchedulerProtocol, + ): + # ==== cpasync load warpgroup ==== + # Description: loads tiles of K, V, V0, V1 from gmem to smem using cpasync + # produces: K, V, V0, V1, bitmask + # consumes: - + + # TODO: use cpasync for non-topk paged attn + assert sBitmask is not None, "cpasync load meant to be used with topk gather" + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + tidx = cute.arch.thread_idx()[0] % self.num_cpasync_load_threads + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % ( + self.num_cpasync_load_threads // 32 + ) + + # ==== Make pipeline states ==== + # producer: acquire PipelineAsyncUmma <- mma + # producer: commit PipelineAsync -> relay + producer_state_K = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_K + ) + producer_state_V0 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_Vi + ) + producer_state_V1 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_Vi + ) + if const_expr(not self.disable_bitmask): + producer_state_bitmask = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, + stages=self.num_stages_bitmask, + ) + if const_expr(self.use_tma_O): + producer_phase_O = Int32(1) + + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + cta_m_block, head_idx, batch_idx, _ = work_tile.tile_idx + cluster_m_block = cta_m_block // self.cta_group_size + head_idx_kv = ( + head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx + ) + + seqlen = SeqlenInfoCls(batch_idx) + if const_expr(self.is_topk_gather): + n_block_min = 0 + n_block_max = self.topk_length // self.tile_n + # n_block_max = topk_length_dynamic // self.tile_n + else: + n_block_min, n_block_max = block_info.get_n_block_min_max( + seqlen, + cluster_m_block, + ) + num_n_blocks = n_block_max - n_block_min + num_n_block_groups = cute.ceil_div(num_n_blocks, self.num_stages_S) + + # cluster_m_block == m_idx under MQA 128 assumption + m_idx = cluster_m_block + if const_expr(not seqlen.has_cu_seqlens_q): + mIndexTopk_cur = mIndexTopk[None, m_idx, batch_idx] + else: + offset_q = seqlen.offset_q + mIndexTopk_cur = mIndexTopk[None, m_idx + offset_q] + + if const_expr(self.is_causal): + seqlen_k_limit = m_idx + 1 + seqlen.seqlen_k - seqlen.seqlen_q + else: + seqlen_k_limit = seqlen.seqlen_k + cpasync_gather_kv_manager = CpasyncGatherKVManager.create( + mIndexTopk_cur, + sBitmask, + cta_rank_in_cluster, + tidx, + warp_idx, + self.topk_length, + seqlen_k_limit, + self.tile_n, + self.hdim, + self.hdimv, + self.num_hdimv_splits, + self.num_cpasync_load_threads, + mK.element_type, + self.cta_group_size, + pipeline_bitmask, + self.num_stages_bitmask, + self.cpasync_barrier, + self.disable_bitmask, + ) + + # (seqlen_k, hdim) or (seqlen_k, hdimv//2) + mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[None, None, head_idx_kv] + mV0_cur = seqlen.offset_batch_K(mV0, batch_idx, dim=3)[None, None, head_idx_kv] + mV1_cur = seqlen.offset_batch_K(mV1, batch_idx, dim=3)[None, None, head_idx_kv] + # (hdimv//2, seqlen_k) + if const_expr(not seqlen.has_cu_seqlens_k): + mVt0_cur = mVt0[None, None, head_idx_kv, batch_idx] + mVt1_cur = mVt1[None, None, head_idx_kv, batch_idx] + else: + mVt0_cur = cute.domain_offset((0, seqlen.offset_k), mVt0[None, None, head_idx_kv]) + mVt1_cur = cute.domain_offset((0, seqlen.offset_k), mVt1[None, None, head_idx_kv]) + # (hdimv//4, seqlen_k) + hdimv_split_per_cta = self.hdimv // self.num_hdimv_splits // self.cta_group_size + mVt0_cur = cute.tiled_divide(mVt0_cur, (hdimv_split_per_cta,))[ + None, cta_rank_in_cluster, None + ] + mVt1_cur = cute.tiled_divide(mVt1_cur, (hdimv_split_per_cta,))[ + None, cta_rank_in_cluster, None + ] + + load_K = partial( + self.cpasync_gather_load_KV, + cpasync_gather_kv_manager, + pipeline_K, + pipeline_K_cpasync, + sK, + False, + "K", + mK_cur, + ) + load_V0 = partial( + self.cpasync_gather_load_KV, + cpasync_gather_kv_manager, + pipeline_V0, + pipeline_V0_cpasync, + sV0, + False, + "V", + mV0_cur, + ) + load_V1 = partial( + self.cpasync_gather_load_KV, + cpasync_gather_kv_manager, + pipeline_V1, + pipeline_V1_cpasync, + sV1, + False, + "V", + mV1_cur, + ) + load_Vt0 = partial( + self.cpasync_gather_load_KV, + cpasync_gather_kv_manager, + pipeline_V0, + pipeline_V0_cpasync, + sVt0, + True, + "V", + mVt0_cur, + ) + load_Vt1 = partial( + self.cpasync_gather_load_KV, + cpasync_gather_kv_manager, + pipeline_V1, + pipeline_V1_cpasync, + sVt1, + True, + "V", + mVt1_cur, + ) + + # gather KV path processes n_blocks in increasing order + n_block = 0 + + # ==== Prologue ==== + # K, V0, V1 + cpasync_gather_kv_manager.load_index_topk(n_block, transpose=False) + producer_state_K = load_K(producer_state_K) + producer_state_V0 = load_V0(producer_state_V0) + producer_state_V1 = load_V1(producer_state_V1) + if const_expr(not self.disable_bitmask): + producer_state_bitmask = cpasync_gather_kv_manager.compute_bitmask( + producer_state_bitmask + ) + + if const_expr(self.use_tma_O and self.overlap_sO_sV): + cute.arch.mbarrier_wait(sO_empty_mbar_ptr, phase=producer_phase_O) + producer_phase_O ^= 1 + + # ==== Mainloop ==== + for n_block_group in cutlass.range(num_n_block_groups - 1, unroll=1): + for stage in cutlass.range_constexpr(self.num_stages_S): + n_block = n_block_group * self.num_stages_S + stage + # K, V0, V1 + cpasync_gather_kv_manager.load_index_topk(n_block + 1, transpose=False) + producer_state_K = load_K(producer_state_K) + producer_state_V0 = load_V0(producer_state_V0) + producer_state_V1 = load_V1(producer_state_V1) + if const_expr(not self.disable_bitmask): + producer_state_bitmask = cpasync_gather_kv_manager.compute_bitmask( + producer_state_bitmask + ) + # Vt0, Vt1 + cpasync_gather_kv_manager.load_index_topk(n_block, transpose=True) + producer_state_V0 = load_Vt0(producer_state_V0) + producer_state_V1 = load_Vt1(producer_state_V1) + + # ==== Epilogue ==== + for stage in cutlass.range_constexpr(self.num_stages_S): + n_block = (num_n_block_groups - 1) * self.num_stages_S + stage + if const_expr(stage == 0): + # K, V0, V1 + cpasync_gather_kv_manager.load_index_topk(n_block + 1, transpose=False) + producer_state_K = load_K(producer_state_K) + producer_state_V0 = load_V0(producer_state_V0) + producer_state_V1 = load_V1(producer_state_V1) + if const_expr(not self.disable_bitmask): + producer_state_bitmask = cpasync_gather_kv_manager.compute_bitmask( + producer_state_bitmask + ) + + # Vt0, Vt1 + cpasync_gather_kv_manager.load_index_topk(n_block, transpose=True) + producer_state_V0 = load_Vt0(producer_state_V0) + producer_state_V1 = load_Vt1(producer_state_V1) + + # Advance to next tile + work_tile = tile_scheduler.advance_to_next_work() + + pipeline_K_cpasync.producer_tail(producer_state_K) + pipeline_V0_cpasync.producer_tail(producer_state_V0) + pipeline_V1_cpasync.producer_tail(producer_state_V1) + if const_expr(not self.disable_bitmask): + pipeline_bitmask.producer_tail(producer_state_bitmask) + + @cute.jit + def cpasync_gather_load_KV( + self, + cpasync_gather_kv_manager: CpasyncGatherKVManager, + pipeline_mma: pipeline.PipelineAsyncUmma, + pipeline_cpasync: pipeline.PipelineAsync, + sX: cute.Tensor, + transpose: bool, + K_or_V: str, + mX: cute.Tensor, + producer_state: pipeline.PipelineState, + ): + stage, phase = producer_state.index, producer_state.phase + pipeline_mma.producer_acquire(producer_state) + cpasync_gather_kv_manager.load_X(mX, sX[None, None, None, stage], transpose, K_or_V) + cute.arch.cp_async_commit_group() + pipeline_cpasync.sync_object_full.arrive_cp_async_mbarrier(stage) + producer_state.advance() + return producer_state + + @cute.jit + def load( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mQv0: cute.Tensor, + mQv1: cute.Tensor, + mV0: cute.Tensor, + mV1: cute.Tensor, + mVt0: cute.Tensor, + mVt1: cute.Tensor, + sQ: cute.Tensor, + sK: cute.Tensor, + sQv0: cute.Tensor, + sQv1: cute.Tensor, + sV0: cute.Tensor, + sV1: cute.Tensor, + sVt0: cute.Tensor, + sVt1: cute.Tensor, + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_Qv0: cute.CopyAtom, + tma_atom_Qv1: cute.CopyAtom, + tma_atom_V0: cute.CopyAtom, + tma_atom_V1: cute.CopyAtom, + tma_atom_Vt0: cute.CopyAtom, + tma_atom_Vt1: cute.CopyAtom, + pipeline_Q: pipeline.PipelineAsync, + pipeline_K: pipeline.PipelineAsync, + pipeline_Qv0: pipeline.PipelineAsync, + pipeline_Qv1: pipeline.PipelineAsync, + pipeline_V0: pipeline.PipelineAsync, + pipeline_V1: pipeline.PipelineAsync, + sO_empty_mbar_ptr: Optional[cute.Pointer], + thr_mma_QK: cute.ThrMma, + thr_mma_QviVi: cute.ThrMma, + thr_mma_PVti: cute.ThrMma, + topk_length_dynamic: Optional[Int32], + block_info: BlockInfo, + SeqlenInfoCls: Callable, + tile_scheduler: TileSchedulerProtocol, + ): + # ==== Load warp ==== + # Description: loads tiles of Q, Qv, K, V, V0, V1 from gmem to smem using TMA + # produces: Q, Qv, K, V, V0, V1 + # consumes: - + + mQvs = [mQv0, mQv1] + mVs = [mV0, mV1] + mVts = [mVt0, mVt1] + + sQvs = [sQv0, sQv1] + sVs = [sV0, sV1] + sVts = [sVt0, sVt1] + + tma_atom_Qvs = [tma_atom_Qv0, tma_atom_Qv1] + tma_atom_Vs = [tma_atom_V0, tma_atom_V1] + tma_atom_Vts = [tma_atom_Vt0, tma_atom_Vt1] + + pipeline_Qvs = [pipeline_Qv0, pipeline_Qv1] + pipeline_Vs = [pipeline_V0, pipeline_V1] + + # ==== Make pipeline states ==== + producer_state_Q = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_Q + ) + producer_state_Qv0 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_Qvi + ) + producer_state_Qv1 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_Qvi + ) + if const_expr(self.use_tma_KV): + producer_state_K = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_K + ) + producer_state_V0 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_Vi + ) + producer_state_V1 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_Vi + ) + if const_expr(self.use_tma_O): + producer_phase_O = Int32(1) + + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + cta_m_block, head_idx, batch_idx, _ = work_tile.tile_idx + cluster_m_block = cta_m_block // self.cta_group_size + head_idx_kv = ( + head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx + ) + + seqlen = SeqlenInfoCls(batch_idx) + if const_expr(self.is_topk_gather): + n_block_min = 0 + n_block_max = self.topk_length // self.tile_n + # n_block_max = topk_length_dynamic // self.tile_n + else: + n_block_min, n_block_max = block_info.get_n_block_min_max( + seqlen, + cluster_m_block, + ) + num_n_blocks = n_block_max - n_block_min + even_n_blocks = num_n_blocks % 2 == 0 and num_n_blocks > 0 + num_n_block_groups = cute.ceil_div(num_n_blocks, self.num_stages_S) + + # ==== Partition GMEM tensors ==== + # (seqlen_q, hdim or hdimv//2) + mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] + mQvs_cur = [ + seqlen.offset_batch_Q(mQvs[split], batch_idx, dim=3)[None, None, head_idx] + for split in range(self.num_hdimv_splits) + ] + # (mma_tile_m, hdim or hdimv//2) + gQ = cute.local_tile( + mQ_cur, + (self.mma_tiler_QK[0], self.mma_tiler_QK[2]), + (cluster_m_block, 0), + ) + gQvs = [ + cute.local_tile( + mQvs_cur[split], + (self.mma_tiler_QviVi[0], self.mma_tiler_QviVi[2]), + (cluster_m_block, 0), + ) + for split in range(self.num_hdimv_splits) + ] + tSgQ = thr_mma_QK.partition_A(gQ) + tSgQvs = [ + thr_mma_QviVi.partition_A(gQvs[split]) for split in range(self.num_hdimv_splits) + ] + tQsQ, tQgQ = cpasync.tma_partition( + atom=tma_atom_Q, + cta_coord=0, + cta_layout=cute.make_layout(1), + smem_tensor=cute.group_modes(sQ, 0, 3), + gmem_tensor=cute.group_modes(tSgQ, 0, 3), + ) + tQvsQvs, tQvgQvs = zip( + *[ + cpasync.tma_partition( + atom=tma_atom, + cta_coord=0, + cta_layout=cute.make_layout(1), + smem_tensor=cute.group_modes(sQv, 0, 3), + gmem_tensor=cute.group_modes(tSgQv, 0, 3), + ) + for tma_atom, sQv, tSgQv in zip(tma_atom_Qvs, sQvs, tSgQvs) + ] + ) + + if const_expr(self.use_tma_KV): + # (seqlen_k, hdim) or (seqlen_k, hdimv//2) + mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[None, None, head_idx_kv] + mVs_cur = [ + seqlen.offset_batch_K(mVs[split], batch_idx, dim=3)[None, None, head_idx_kv] + for split in range(self.num_hdimv_splits) + ] + # (hdimv//2, seqlen_k) + if const_expr(not seqlen.has_cu_seqlens_k): + mVts_cur = [ + mVts[split][None, None, head_idx_kv, batch_idx] + for split in range(self.num_hdimv_splits) + ] + else: + mVts_cur = [ + cute.domain_offset( + (0, seqlen.offset_k), mVts[split][None, None, head_idx_kv] + ) + for split in range(self.num_hdimv_splits) + ] + # (tile_n, hdim or hdimv//2, num_n_blocks) + gK = cute.local_tile( + mK_cur, + (self.mma_tiler_QK[1], self.mma_tiler_QK[2]), + (None, 0), + ) + gVs = [ + cute.local_tile( + mVs_cur[split], + (self.mma_tiler_QviVi[1], self.mma_tiler_QviVi[2]), + (None, 0), + ) + for split in range(self.num_hdimv_splits) + ] + # (hdim or hdimv//2, tile_n, num_n_blocks) + gVts = [ + cute.local_tile( + mVts_cur[split], + (self.mma_tiler_PVti[1], self.mma_tiler_PVti[2]), + (0, None), + ) + for split in range(self.num_hdimv_splits) + ] + tSgK = thr_mma_QK.partition_B(gK) + tSgVs = [ + thr_mma_QviVi.partition_B(gVs[split]) for split in range(self.num_hdimv_splits) + ] + tOgVts = [ + thr_mma_PVti.partition_B(gVts[split]) for split in range(self.num_hdimv_splits) + ] + tKsK, tKgK = cpasync.tma_partition( + atom=tma_atom_K, + cta_coord=0, + cta_layout=cute.make_layout(1), + smem_tensor=cute.group_modes(sK, 0, 3), + gmem_tensor=cute.group_modes(tSgK, 0, 3), + ) + tVsVs, tVgVs = zip( + *[ + cpasync.tma_partition( + atom=tma_atom, + cta_coord=0, + cta_layout=cute.make_layout(1), + smem_tensor=cute.group_modes(sV, 0, 3), + gmem_tensor=cute.group_modes(tSgV, 0, 3), + ) + for tma_atom, sV, tSgV in zip(tma_atom_Vs, sVs, tSgVs) + ] + ) + tVtsVts, tVtgVts = zip( + *[ + cpasync.tma_partition( + atom=tma_atom, + cta_coord=0, + cta_layout=cute.make_layout(1), + smem_tensor=cute.group_modes(sVt, 0, 3), + gmem_tensor=cute.group_modes(tOgV, 0, 3), + ) + for tma_atom, sVt, tOgV in zip(tma_atom_Vts, sVts, tOgVts) + ] + ) + + load_Q = partial(self.load_inner, tma_atom_Q, tQgQ, tQsQ, pipeline_Q) + load_Qv = partial(self.load_inner, tma_atom_Qvs, tQvgQvs, tQvsQvs, pipeline_Qvs) + if const_expr(self.use_tma_KV): + load_K = partial(self.load_inner, tma_atom_K, tKgK, tKsK, pipeline_K) + load_V = partial(self.load_inner, tma_atom_Vs, tVgVs, tVsVs, pipeline_Vs) + load_Vt = partial(self.load_inner, tma_atom_Vts, tVtgVts, tVtsVts, pipeline_Vs) + + # ==== Load stationary operands ==== + + # copy Q, Qvi gmem -> smem + producer_state_Q = load_Q(producer_state_Q) + producer_state_Qv0 = load_Qv(producer_state_Qv0, split=0) + producer_state_Qv1 = load_Qv(producer_state_Qv1, split=1) + + if const_expr(self.use_tma_KV): + # ==== Prologue ==== + n_block_first = n_block_max - 1 + # copy K gmem -> smem + producer_state_K = load_K(producer_state_K, n_block=n_block_first) + # copy Vi gmem -> smem + producer_state_V0 = load_V(producer_state_V0, n_block=n_block_first, split=0) + producer_state_V1 = load_V(producer_state_V1, n_block=n_block_first, split=1) + + if const_expr(self.use_tma_O and self.overlap_sO_sV): + cute.arch.mbarrier_wait(sO_empty_mbar_ptr, phase=producer_phase_O) + producer_phase_O ^= 1 + + # ==== Main loop ==== + for n_block_group in cutlass.range(num_n_block_groups - 1, unroll=1): + for stage in cutlass.range_constexpr(self.num_stages_S): + n_block = n_block_max - 1 - n_block_group * self.num_stages_S - stage + # copy K gmem -> smem + producer_state_K = load_K(producer_state_K, n_block=n_block - 1) + # copy Vi gmem -> smem + producer_state_V0 = load_V(producer_state_V0, n_block=n_block - 1, split=0) + producer_state_V1 = load_V(producer_state_V1, n_block=n_block - 1, split=1) + # copy Vti gmem -> smem + producer_state_V0 = load_Vt(producer_state_V0, n_block=n_block, split=0) + producer_state_V1 = load_Vt(producer_state_V1, n_block=n_block, split=1) + + # ==== Epilogue ==== + num_final_n_blocks = self.num_stages_S if even_n_blocks else self.num_stages_S - 1 + for stage in cutlass.range(num_final_n_blocks, unroll_full=True): + n_block = num_final_n_blocks - 1 - stage + if n_block > 0: + # copy K gmem -> smem + producer_state_K = load_K(producer_state_K, n_block=n_block - 1) + # copy Vi gmem -> smem + producer_state_V0 = load_V(producer_state_V0, n_block=n_block - 1, split=0) + producer_state_V1 = load_V(producer_state_V1, n_block=n_block - 1, split=1) + # copy Vti gmem -> smem + producer_state_V0 = load_Vt(producer_state_V0, n_block=n_block, split=0) + producer_state_V1 = load_Vt(producer_state_V1, n_block=n_block, split=1) + + # Advance to next tile + work_tile = tile_scheduler.advance_to_next_work() + + pipeline_Q.producer_tail(producer_state_Q) + pipeline_Qv0.producer_tail(producer_state_Qv0) + pipeline_Qv1.producer_tail(producer_state_Qv1) + if const_expr(self.use_tma_KV): + pipeline_K.producer_tail(producer_state_K) + pipeline_V0.producer_tail(producer_state_V0) + pipeline_V1.producer_tail(producer_state_V1) + + @cute.jit + def load_inner( + self, + tma_atom: cute.CopyAtom, + tXgX: cute.Tensor, + tXsX: cute.Tensor, + load_pipeline: pipeline.PipelineAsync, + producer_state: pipeline.PipelineState, + n_block: Optional[Int32] = None, + split: Optional[Int32] = None, + ): + stage = producer_state.index + if const_expr(split is not None): + tma_atom = tma_atom[split] + tXgX = tXgX[split] + tXsX = tXsX[split] + load_pipeline = load_pipeline[split] + if const_expr(n_block is not None): + tXgX = tXgX[(None, n_block)] + tXsX = tXsX[(None, stage)] + + load_pipeline.producer_acquire(producer_state) + tma_bar_ptr = load_pipeline.producer_get_barrier(producer_state) + cute.copy(tma_atom, tXgX, tXsX, tma_bar_ptr=tma_bar_ptr) + producer_state.advance() + return producer_state + + @cute.jit + def mma( + self, + sQ: cute.Tensor, + sK: cute.Tensor, + sQv0: cute.Tensor, + sQv1: cute.Tensor, + sV0: cute.Tensor, + sV1: cute.Tensor, + sVt0: cute.Tensor, + sVt1: cute.Tensor, + sP: cute.Tensor, + tiled_mma_QK: cute.TiledMma, + tiled_mma_QviVi: cute.TiledMma, + tiled_mma_PVti: cute.TiledMma, + pipeline_Q: pipeline.PipelineAsync, + pipeline_K: pipeline.PipelineAsync, + pipeline_Qv0: pipeline.PipelineAsync, + pipeline_Qv1: pipeline.PipelineAsync, + pipeline_V0: pipeline.PipelineAsync, + pipeline_V1: pipeline.PipelineAsync, + pipeline_S: pipeline.PipelineAsync, + pipeline_P: pipeline.PipelineAsync, + pipeline_O0: pipeline.PipelineAsync, + pipeline_O1: pipeline.PipelineAsync, + sO_empty_mbar_ptr: Optional[cute.Pointer], + is_leader_cta: Boolean, + topk_length_dynamic: Optional[Int32], + block_info: BlockInfo, + SeqlenInfoCls: Callable, + tile_scheduler: TileSchedulerProtocol, + ): + # ==== mma warp ==== + # Description: Computes Q @ K^T, Qv @ V^T, and P @ V + # Produces: S, O + # Consumes: Q, K, Qv, V, P + + pipelines_V = [pipeline_V0, pipeline_V1] + pipelines_Qv = [pipeline_Qv0, pipeline_Qv1] + pipelines_O = [pipeline_O0, pipeline_O1] + + sQvs = [sQv0, sQv1] + sVs = [sV0, sV1] + sVts = [sVt0, sVt1] + + # Set accumulate = True for Qv @ V^T since we are accumulating on the Q @ K^T result + tiled_mma_QviVi.set(tcgen05.Field.ACCUMULATE, True) + + # Operands for S = Q @ K^T + tSrQ = tiled_mma_QK.make_fragment_A(sQ) + tSrK = tiled_mma_QK.make_fragment_B(sK) + + # Operands for S += Qv @ V^T + tSrQvs = [ + tiled_mma_QviVi.make_fragment_A(sQvs[split]) for split in range(self.num_hdimv_splits) + ] + tSrVs = [ + tiled_mma_QviVi.make_fragment_B(sVs[split]) for split in range(self.num_hdimv_splits) + ] + + # Operands for Oi = P @ Vi + tOrP = tiled_mma_PVti.make_fragment_A(sP) + tOrVts = [ + tiled_mma_PVti.make_fragment_B(sVts[split]) for split in range(self.num_hdimv_splits) + ] + + # GEMM functions + gemm_QK = [ + partial( + fa_sm100_utils.gemm_ptx_partial, + tiled_mma_QK.op, + self.tmem_offset_S[stage], + tCrA=tSrQ[None, None, None, 0], + sA=sQ[None, None, None, 0], + zero_init=True, + cta_group=self.cta_group_size, + ) + for stage in range(self.num_stages_S) + ] + gemms_QvV = [ + [ + partial( + fa_sm100_utils.gemm_ptx_partial, + tiled_mma_QviVi.op, + self.tmem_offset_S[stage], + tCrA=tSrQvs[split][None, None, None, 0], + sA=sQvs[split][None, None, None, 0], + zero_init=False, + cta_group=self.cta_group_size, + ) + for stage in range(self.num_stages_S) + ] + for split in range(self.num_hdimv_splits) + ] + gemms_PVt = [ + partial( + fa_sm100_utils.gemm_ptx_partial, + tiled_mma_PVti.op, + self.tmem_offsets_O[split], + tOrP[None, None, None, 0], + sA=sP[None, None, None, 0], + cta_group=self.cta_group_size, + ) + for split in range(self.num_hdimv_splits) + ] + + consumer_state_Q = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_Q + ) + consumer_state_K = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_K + ) + consumer_state_Qv0 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_Qvi + ) + consumer_state_Qv1 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_Qvi + ) + consumer_state_V0 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_Vi + ) + consumer_state_V1 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_Vi + ) + producer_state_S = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_S + ) + consumer_state_P = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_P + ) + producer_state_O0 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_Oi + ) + producer_state_O1 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_Oi + ) + + mma_QK = partial(self.mma_inner, gemm_QK, pipeline_K, tSrK, sK) + mma_QvV = partial(self.mma_inner, gemms_QvV, pipelines_V, tSrVs, sVs) + mma_PVt = partial(self.mma_inner, gemms_PVt, pipelines_V, tOrVts, sVts) + + work_tile = tile_scheduler.initial_work_tile_info() + O_should_accumulate = False + while work_tile.is_valid_tile: + cta_m_block, head_idx, batch_idx, _ = work_tile.tile_idx + cluster_m_block = cta_m_block // self.cta_group_size + + seqlen = SeqlenInfoCls(batch_idx) + if const_expr(self.is_topk_gather): + n_block_min = 0 + # n_block_max = self.topk_length // self.tile_n + n_block_max = topk_length_dynamic // self.tile_n + else: + n_block_min, n_block_max = block_info.get_n_block_min_max( + seqlen, + cluster_m_block, + ) + num_n_blocks = n_block_max - n_block_min + even_n_blocks = num_n_blocks % 2 == 0 and num_n_blocks > 0 + num_n_block_groups = cute.ceil_div(num_n_blocks, self.num_stages_S) + + if is_leader_cta: + pipeline_Q.consumer_wait(consumer_state_Q) + pipeline_Qv0.consumer_wait(consumer_state_Qv0) + pipeline_Qv1.consumer_wait(consumer_state_Qv1) + + consumer_states_V = [consumer_state_V0, consumer_state_V1] + producer_states_O = [producer_state_O0, producer_state_O1] + + # ==== Prologue ==== + pipeline_S.producer_acquire(producer_state_S) + # S = Q @ K^T + consumer_state_K = mma_QK(consumer_state_K, stage=0) + # S += Qvi @ Vi^T + for split in cutlass.range_constexpr(self.num_hdimv_splits): + consumer_states_V[split] = mma_QvV( + consumer_states_V[split], stage=0, split=split + ) + pipeline_S.producer_commit(producer_state_S) + producer_state_S.advance() + + # ==== Mainloop ==== + for _ in cutlass.range(num_n_block_groups - 1, unroll=1): + for stage in cutlass.range_constexpr(self.num_stages_S): + next_stage = const_expr((stage + 1) % self.num_stages_S) + pipeline_S.producer_acquire(producer_state_S) + # S = Q @ K^T + consumer_state_K = mma_QK(consumer_state_K, stage=next_stage) + # S += Qvi @ Vi^T + for split in cutlass.range_constexpr(self.num_hdimv_splits): + consumer_states_V[split] = mma_QvV( + consumer_states_V[split], stage=next_stage, split=split + ) + pipeline_S.producer_commit(producer_state_S) + producer_state_S.advance() + # Oi += P @ Vi + pipeline_P.consumer_wait(consumer_state_P) + for split in cutlass.range_constexpr(self.num_hdimv_splits): + producer_state_Oi = producer_states_O[split] + pipelines_O[split].producer_acquire(producer_state_Oi) + consumer_states_V[split] = mma_PVt( + consumer_states_V[split], + split=split, + zero_init=not O_should_accumulate, + ) + pipelines_O[split].producer_commit(producer_state_Oi) + producer_state_Oi.advance() + producer_states_O[split] = producer_state_Oi + pipeline_P.consumer_release(consumer_state_P) + consumer_state_P.advance() + O_should_accumulate = True + + # ==== Epilogue ==== + num_final_n_blocks = self.num_stages_S if even_n_blocks else self.num_stages_S - 1 + for stage in cutlass.range_constexpr(self.num_stages_S): + n_block = num_final_n_blocks - 1 - stage + if const_expr(stage == 0): + if n_block > 0: + pipeline_S.producer_acquire(producer_state_S) + # S = Q @ K^T + consumer_state_K = mma_QK(consumer_state_K, stage=stage + 1) + # S += Qvi @ Vi^T + for split in cutlass.range_constexpr(self.num_hdimv_splits): + consumer_states_V[split] = mma_QvV( + consumer_states_V[split], stage=stage + 1, split=split + ) + pipeline_S.producer_commit(producer_state_S) + producer_state_S.advance() + if n_block >= 0: + # Oi += P @ Vi + pipeline_P.consumer_wait(consumer_state_P) + for split in cutlass.range_constexpr(self.num_hdimv_splits): + producer_state_Oi = producer_states_O[split] + pipelines_O[split].producer_acquire(producer_state_Oi) + consumer_states_V[split] = mma_PVt( + consumer_states_V[split], + split=split, + zero_init=not O_should_accumulate, + ) + pipelines_O[split].producer_commit(producer_state_Oi) + producer_state_Oi.advance() + producer_states_O[split] = producer_state_Oi + pipeline_P.consumer_release(consumer_state_P) + consumer_state_P.advance() + O_should_accumulate = True + + consumer_state_V0, consumer_state_V1 = consumer_states_V + producer_state_O0, producer_state_O1 = producer_states_O + + pipeline_Q.consumer_release(consumer_state_Q) + + # if we overlap sOi with sQvi for tma store, need to acquire signal + if const_expr(self.use_tma_O and not self.overlap_sO_sV): + pipeline_O0.producer_tail(producer_state_O0.clone()) + pipeline_O1.producer_tail(producer_state_O1.clone()) + + pipeline_Qv0.consumer_release(consumer_state_Qv0) + pipeline_Qv1.consumer_release(consumer_state_Qv1) + consumer_state_Q.advance() + consumer_state_Qv0.advance() + consumer_state_Qv1.advance() + + # Advance to next tile + work_tile = tile_scheduler.advance_to_next_work() + O_should_accumulate = False + + pipeline_S.producer_tail(producer_state_S) + pipeline_O0.producer_tail(producer_state_O0) + pipeline_O1.producer_tail(producer_state_O1) + + @cute.jit + def mma_inner( + self, + gemm, + load_pipeline, + tCrB, + sB, + consumer_state: pipeline.PipelineState, + stage: Optional[Int32] = None, + split: Optional[Int32] = None, + zero_init: Optional[bool] = None, + ): + if const_expr(split is not None): + gemm = gemm[split] + load_pipeline = load_pipeline[split] + tCrB = tCrB[split] + sB = sB[split] + if const_expr(stage is not None): + gemm = gemm[stage] + + smem_stage = consumer_state.index + tCrB_cur = tCrB[None, None, None, smem_stage] + sB_cur = sB[None, None, None, smem_stage] + + load_pipeline.consumer_wait(consumer_state) + if const_expr(zero_init is not None): + gemm(tCrB=tCrB_cur, sB=sB_cur, zero_init=zero_init) + else: + gemm(tCrB=tCrB_cur, sB=sB_cur) + load_pipeline.consumer_release(consumer_state) + consumer_state.advance() + return consumer_state + + @cute.jit + def softmax_loop( + self, + softmax_scale: Float32, + softmax_scale_log2: Float32, + mLSE: Optional[cute.Tensor], + sRowMax: cute.Tensor, + sRowSum: cute.Tensor, + sScale: cute.Tensor, + sBitmask: Optional[cute.Tensor], + sP: cute.Tensor, + tStS: cute.Tensor, + thr_mma_QK: cute.ThrMma, + pipeline_S: pipeline.PipelineAsync, + pipeline_P: pipeline.PipelineAsync, + pipeline_sm_stats: pipeline.PipelineAsync, + pipeline_bitmask: Optional[pipeline.PipelineAsync], + sO_empty_mbar_ptr: Optional[cute.Pointer], + AttentionMaskCls: Callable, + topk_length_dynamic: Optional[Int32], + block_info: BlockInfo, + SeqlenInfoCls: Callable, + tile_scheduler: TileSchedulerProtocol, + ): + # ==== softmax warpgroup ==== + # Description: computes softmax on S and writes the result to P + # Produces: P, softmax stats + # Consumes: S, bitmask (for topk sparsity) + + tidx = cute.arch.thread_idx()[0] % self.num_softmax_threads + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % ( + self.num_softmax_threads // 32 + ) + + tSAcc = tStS[(None, None), 0, 0, 0] + tSAcc_staged = [tStS[(None, None), 0, 0, stage] for stage in range(self.num_stages_S)] + + cS = cute.make_identity_tensor(self.mma_tiler_QK[:2]) # (128, 128) + tScS = thr_mma_QK.partition_C(cS)[(None, None), 0, 0] # (64, 128) + + # S tmem -> rmem copy objects + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), + self.dtype_acc, + ) + tmem_load_tiled = tcgen05.make_tmem_copy(tmem_load_atom, tSAcc) + tmem_load_thr = tmem_load_tiled.get_slice(tidx) + # S tmem -> rmem copy operands + tStS_t2r = tmem_load_thr.partition_S(tSAcc) # (((32, 32), 1), 1, 2) + tStS_t2r_staged = [ + tmem_load_thr.partition_S(tSAcc_staged[stage]) for stage in range(self.num_stages_S) + ] + tScS_t2r = tmem_load_thr.partition_D(tScS) + tSrS_t2r = cute.make_rmem_tensor(tScS_t2r.shape, self.dtype_acc) + + # P rmem -> smem copy objects + universal_copy_bits = 128 + smem_store_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dtype_P, + num_bits_per_copy=universal_copy_bits, + ) + smem_store_tiled = cute.make_tiled_copy_D(smem_store_atom, tmem_load_tiled) + smem_store_thr = smem_store_tiled.get_slice(tidx) + # P rmem -> smem copy operands + sP_slice = sP[None, None, None, 0] + sP_mn = cute.make_tensor( + sP_slice.iterator, + cute.make_layout( + ( + (sP_slice.shape[0][0], sP_slice.shape[1]), + (sP_slice.shape[0][1], sP_slice.shape[2]), + ), + stride=( + (sP_slice.stride[0][0], sP_slice.stride[1]), + (sP_slice.stride[0][1], sP_slice.stride[2]), + ), + ), + ) + sP_smem_view = smem_store_thr.partition_D(sP_mn) + + consumer_state_S = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_S + ) + producer_state_P = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_P + ) + producer_state_sm_stats = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_sm_stats + ) + consumer_state_bitmask = None + if const_expr(self.is_topk_gather and not self.disable_bitmask): + consumer_state_bitmask = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_bitmask + ) + + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + cta_m_block, head_idx, batch_idx, _ = work_tile.tile_idx + cluster_m_block = cta_m_block // self.cta_group_size + seqlen = SeqlenInfoCls(batch_idx) + if const_expr(self.is_topk_gather): + n_block_min = 0 + n_block_max = self.topk_length // self.tile_n + # n_block_max = topk_length_dynamic // self.tile_n + else: + n_block_min, n_block_max = block_info.get_n_block_min_max( + seqlen, + cluster_m_block, + ) + num_n_blocks = n_block_max - n_block_min + even_n_blocks = num_n_blocks % 2 == 0 and num_n_blocks > 0 + num_n_block_groups = cute.ceil_div(num_n_blocks, self.num_stages_S) + + mask = AttentionMaskCls(seqlen) + mask_fn = partial( + mask.apply_mask_sm100, + m_block=cluster_m_block, + thr_mma=thr_mma_QK, + thr_tmem_load=tmem_load_thr, + mask_causal=self.is_causal, + mask_local=self.is_local, + batch_idx=batch_idx, + head_idx=head_idx, + r2p=False, # TODO: fix r2p for 2cta + ) + disable_mask = self.disable_bitmask and self.is_topk_gather + + softmax = SoftmaxSm100.create( + softmax_scale_log2, + rescale_threshold=8.0 if const_expr(self.dtype_Q.width == 16) else 0.0, + softmax_scale=softmax_scale, + ) + softmax.reset() + + softmax_step_fn = partial( + self.softmax_step, + softmax, + sRowMax, + sScale, + sBitmask, + tStS_t2r_staged, + tSrS_t2r, + sP_smem_view, + tmem_load_thr, + smem_store_thr, + pipeline_S, + pipeline_P, + pipeline_sm_stats, + pipeline_bitmask, + tidx, + warp_idx, + ) + + ### first iteration ### + n_block = n_block_max - 1 + ( + consumer_state_S, + producer_state_P, + producer_state_sm_stats, + consumer_state_bitmask, + ) = softmax_step_fn( + consumer_state_S, + producer_state_P, + producer_state_sm_stats, + consumer_state_bitmask, + 0, + n_block, + mask_fn=partial(mask_fn, mask_seqlen=True) + if not const_expr(disable_mask) + else None, + is_first=True, + ) + n_block -= 1 + + ### Separate iterations with causal masking + # note: For square mma tile, can mask at most 1 n_block_group + if const_expr((self.is_causal or self.is_local) and not self.is_topk_gather): + n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( + seqlen, cluster_m_block, n_block_min + ) + num_masked_n_blocks = n_block_max - 1 - n_block_min_causal_local_mask + num_masked_n_block_groups = min( + num_n_block_groups - 1, cute.ceil_div(num_masked_n_blocks, self.num_stages_S) + ) + num_n_block_groups -= num_masked_n_block_groups + for _ in cutlass.range(num_masked_n_block_groups, unroll=1): + for stage in cutlass.range_constexpr(self.num_stages_S): + ( + consumer_state_S, + producer_state_P, + producer_state_sm_stats, + consumer_state_bitmask, + ) = softmax_step_fn( + consumer_state_S, + producer_state_P, + producer_state_sm_stats, + consumer_state_bitmask, + 1 - stage, + n_block, + mask_fn=partial(mask_fn, mask_seqlen=False), + ) + n_block -= 1 + + ### Mainloop ### + for n_block_group in cutlass.range(num_n_block_groups - 1, unroll=1): + for stage in cutlass.range_constexpr(self.num_stages_S): + ( + consumer_state_S, + producer_state_P, + producer_state_sm_stats, + consumer_state_bitmask, + ) = softmax_step_fn( + consumer_state_S, + producer_state_P, + producer_state_sm_stats, + consumer_state_bitmask, + 1 - stage, + n_block, + mask_fn=partial(mask_fn, mask_seqlen=False) + if const_expr(self.is_topk_gather and not self.disable_bitmask) + else None, + ) + n_block -= 1 + + ### last iteration if even ### + # always mask to simplify logic + if even_n_blocks: + ( + consumer_state_S, + producer_state_P, + producer_state_sm_stats, + consumer_state_bitmask, + ) = softmax_step_fn( + consumer_state_S, + producer_state_P, + producer_state_sm_stats, + consumer_state_bitmask, + 1, + n_block, + mask_fn=partial(mask_fn, mask_seqlen=False) + if not const_expr(disable_mask) + else None, + ) + n_block -= 1 + + # write row max and sum to smem + sRowSum[tidx % self.cta_tile_m, warp_idx // self.cta_group_size] = softmax.row_sum[0] + if const_expr(mLSE is not None): + if tidx < self.cta_tile_m: + sRowMax[tidx, 0] = softmax.row_max[0] + self.sm_stats_barrier_full.arrive() + + # Advance to next tile + work_tile = tile_scheduler.advance_to_next_work() + self.sm_stats_barrier_empty.arrive_and_wait() + + pipeline_P.producer_tail(producer_state_P) + pipeline_sm_stats.producer_tail(producer_state_sm_stats) + + @cute.jit + def softmax_step( + self, + softmax: SoftmaxSm100, + sRowMax: cute.Tensor, + sScale: cute.Tensor, + sBitmask: Optional[cute.Tensor], + tStS_t2r_staged: cute.Tensor, + tSrS_t2r: cute.Tensor, + sP_smem_view: cute.Tensor, + tmem_load_thr: cute.CopyAtom, + smem_store_thr: cute.CopyAtom, + pipeline_S: pipeline.PipelineAsync, + pipeline_P: pipeline.PipelineAsync, + pipeline_sm_stats: pipeline.PipelineAsync, + pipeline_bitmask: Optional[pipeline.PipelineAsync], + tidx: Int32, + warp_idx: Int32, + consumer_state_S: pipeline.PipelineState, + producer_state_P: pipeline.PipelineState, + producer_state_sm_stats: pipeline.PipelineState, + consumer_state_bitmask: Optional[pipeline.PipelineState], + stage: cutlass.Constexpr[Int32], + n_block: Int32, + mask_fn: Optional[Callable] = None, + is_first: Boolean = False, + ): + tSrP = cute.make_rmem_tensor(tSrS_t2r.shape, self.dtype_P) + rP_smem_view = smem_store_thr.retile(tSrP) + + pipeline_S.consumer_wait(consumer_state_S) + cute.copy(tmem_load_thr, tStS_t2r_staged[stage], tSrS_t2r) + cute.arch.fence_view_async_tmem_load() + pipeline_S.consumer_release(consumer_state_S) + + rBitmask = None + if const_expr(self.is_topk_gather and not self.disable_bitmask): + assert pipeline_bitmask is not None + assert consumer_state_bitmask is not None + pipeline_bitmask.consumer_wait(consumer_state_bitmask) + rBitmask = cute.make_rmem_tensor((self.tile_n // 64,), dtype=Uint32) + bitmask_col_offset = self.tile_n // 64 if warp_idx >= 2 else 0 + for i in cutlass.range_constexpr(cute.size(rBitmask)): + rBitmask[i] = sBitmask[bitmask_col_offset + i, consumer_state_bitmask.index] + + if const_expr(mask_fn is not None): + mask_fn(tSrS_t2r, n_block=n_block, rBitmask=rBitmask) + + # compute threadwise row_max + row_max = softmax.compute_row_max_local(tSrS_t2r.load(), is_first) + self.softmax_barrier.arrive_and_wait() + + # 2-thread reduce row_max through smem + assert self.cta_tile_m * self.cta_group_size == 128 + sRowMax[tidx % self.cta_tile_m, warp_idx // self.cta_group_size] = row_max + self.softmax_barrier.arrive_and_wait() + # must release after barrier sync + if const_expr(self.is_topk_gather and not self.disable_bitmask): + pipeline_bitmask.consumer_release(consumer_state_bitmask) + row_max0 = sRowMax[tidx % self.cta_tile_m, 0] + row_max1 = sRowMax[tidx % self.cta_tile_m, 1] + row_max = max(row_max0, row_max1) + + row_max, acc_scale = softmax.update_row_max_from_local(row_max, is_first) + + # note: acc_scales agree for paired threads + pipeline_sm_stats.producer_acquire(producer_state_sm_stats) + if warp_idx < self.cta_group_size: + sScale[tidx % self.cta_tile_m, producer_state_sm_stats.index] = acc_scale + pipeline_sm_stats.producer_commit(producer_state_sm_stats) + + # x -> scale_log2*x-rowmax + softmax.scale_subtract_rowmax(tSrS_t2r, row_max) + + # x -> exp2(x) + softmax.apply_exp2_convert(tSrS_t2r, tSrP) + + pipeline_P.producer_acquire(producer_state_P) + cute.copy(smem_store_thr, rP_smem_view, sP_smem_view) + cute.arch.fence_view_async_shared() + pipeline_P.producer_commit(producer_state_P) + + consumer_state_S.advance() + producer_state_P.advance() + producer_state_sm_stats.advance() + if const_expr(self.is_topk_gather and not self.disable_bitmask): + consumer_state_bitmask.advance() + + softmax.update_row_sum(tSrS_t2r.load(), acc_scale, is_first) + + return consumer_state_S, producer_state_P, producer_state_sm_stats, consumer_state_bitmask + + @cute.jit + def correction_loop( + self, + softmax_scale_log2: Float32, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + tma_atom_O: Optional[cute.CopyAtom], + sRowMax: cute.Tensor, + sRowSum: cute.Tensor, + sScale: cute.Tensor, + sO: cute.Tensor, + tO0tO0: cute.Tensor, + tO1tO1: cute.Tensor, + pipeline_O0: pipeline.PipelineAsync, + pipeline_O1: pipeline.PipelineAsync, + pipeline_sm_stats: pipeline.PipelineAsync, + sO_empty_mbar_ptr: Optional[cute.Pointer], + tiled_copy_O_r2g: cute.TiledCopy, + topk_length_dynamic: Optional[Int32], + block_info: BlockInfo, + SeqlenInfoCls: Callable, + tile_scheduler: TileSchedulerProtocol, + ): + ### ==== correction/epilogue warpgroup ==== + # Correction: copy scale smem -> rmem, copy O tmem -> rmem, rescale O, store O rmem -> tmem + # Epilogue: copy O tmem -> rmem, do final scaling of O, store O rmem -> gmem, + # optionally store LSE + # Produces: - + # Consumes: O, softmax stats + + tidx = cute.arch.thread_idx()[0] % self.num_epilogue_threads + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % ( + self.num_epilogue_threads // 32 + ) + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + leader_warp = warp_idx == 0 + + tO0tO0 = tO0tO0[(None, None), 0, 0] # (64, (128, 2)) + tO1tO1 = tO1tO1[(None, None), 0, 0] # (64, (128, 2)) + tOtOs = [tO0tO0, tO1tO1] + + # tuneable parameter + corr_tile_size = math.gcd(32, self.tmem_cols_Oi) + + tmem_load_atom_O = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), + self.dtype_acc, + ) + tmem_store_atom_O = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), + self.dtype_acc, + ) + thr_tmem_load_O = tcgen05.make_tmem_copy(tmem_load_atom_O, tO0tO0).get_slice(tidx) + thr_tmem_store_O = tcgen05.make_tmem_copy(tmem_store_atom_O, tO0tO0).get_slice(tidx) + + # ((32,1),1,4) + tOtOs_t2r = [ + thr_tmem_load_O.partition_S(tOtOs[split]) for split in range(self.num_hdimv_splits) + ] + tOtOs_r2t = [ + thr_tmem_store_O.partition_D(tOtOs[split]) for split in range(self.num_hdimv_splits) + ] + + cOi = cute.make_identity_tensor((self.cta_tile_m, self.hdimv // self.num_hdimv_splits)) + thr_tiled_copy_O_r2g = tiled_copy_O_r2g.get_slice(tidx) + tOicOi = thr_tiled_copy_O_r2g.partition_S(cOi) + + tOicOi_t2r = thr_tmem_load_O.partition_D(tOicOi[(None, None), 0, 0]) + + pipelines_O = [pipeline_O0, pipeline_O1] + + consumer_state_O0 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_Oi + ) + consumer_state_O1 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_Oi + ) + consumer_state_sm_stats = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_sm_stats + ) + + do_correction_rescale = partial( + self.correction_rescale, + thr_tmem_load_O, + thr_tmem_store_O, + tOicOi_t2r, + ) + + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + cta_m_block, head_idx, batch_idx, _ = work_tile.tile_idx + cluster_m_block = cta_m_block // self.cta_group_size + + seqlen = SeqlenInfoCls(batch_idx) + if const_expr(self.is_topk_gather): + n_block_min = 0 + n_block_max = self.topk_length // self.tile_n + # n_block_max = topk_length_dynamic // self.tile_n + else: + n_block_min, n_block_max = block_info.get_n_block_min_max( + seqlen, + cluster_m_block, + ) + num_n_blocks = n_block_max - n_block_min + + consumer_states_O = [consumer_state_O0, consumer_state_O1] + + # acquire first signal and release immediately + pipeline_sm_stats.consumer_wait(consumer_state_sm_stats) + pipeline_sm_stats.consumer_release(consumer_state_sm_stats) + consumer_state_sm_stats.advance() + + for _ in cutlass.range(num_n_blocks - 1, unroll=1): + pipeline_sm_stats.consumer_wait(consumer_state_sm_stats) + scale = sScale[tidx % self.cta_tile_m, consumer_state_sm_stats.index] + should_rescale = cute.arch.vote_ballot_sync(scale < 1.0) != 0 + pipeline_sm_stats.consumer_release(consumer_state_sm_stats) + consumer_state_sm_stats.advance() + + for split in cutlass.range_constexpr(self.num_hdimv_splits): + consumer_state_Oi = consumer_states_O[split] + pipelines_O[split].consumer_wait(consumer_state_Oi) + if should_rescale: + do_correction_rescale( + tOtOs_t2r[split], + tOtOs_r2t[split], + scale, + ) + pipelines_O[split].consumer_release(consumer_state_Oi) + consumer_state_Oi.advance() + consumer_states_O[split] = consumer_state_Oi + + # (seqlen_q, hdimv) + mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3, ragged=self.ragged_tma_O)[ + None, None, head_idx + ] + # (cta_tile_m, hdimv//2, 2) + gO = cute.local_tile( + mO_cur, + (self.cta_tile_m, self.hdimv // self.num_hdimv_splits), + (cta_m_block, None), + ) + tOgO = thr_tiled_copy_O_r2g.partition_D(gO) + # ((32, 1), 1, 4) + tOrOs_t2r = [ + cute.make_rmem_tensor(tOicOi_t2r.shape, self.dtype_acc) + for split in range(self.num_hdimv_splits) + ] + tOrOs_r2g_f32 = [ + thr_tiled_copy_O_r2g.retile(tOrOs_t2r[split]) + for split in range(self.num_hdimv_splits) + ] + tOrOs_r2g = [ + cute.make_rmem_tensor_like(tOrOs_r2g_f32[split], self.dtype_O) + for split in range(self.num_hdimv_splits) + ] + if const_expr(self.use_tma_O): + tOsO = thr_tiled_copy_O_r2g.partition_D(sO) + store_O, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_O, + 0, + cute.make_layout(1), + sO, + gO, + ) + + self.sm_stats_barrier_full.arrive_and_wait() + + row_sum0 = sRowSum[tidx % self.cta_tile_m, 0] + row_sum1 = sRowSum[tidx % self.cta_tile_m, 1] + row_sum = row_sum0 + row_sum1 + acc_O_mn_row_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum + scale = cute.arch.rcp_approx(row_sum if not acc_O_mn_row_is_zero_or_nan else 1.0) + + self.sm_stats_barrier_empty.arrive() + + seqlen_q = ( + seqlen.seqlen_q + if const_expr(not self.pack_gqa) + else seqlen.seqlen_q * self.qhead_per_kvhead + ) + + # compute and store lse to gmem + if const_expr(mLSE is not None): + if const_expr(not seqlen.has_cu_seqlens_q): + mLSE_cur = mLSE[None, head_idx, batch_idx] + else: + lse_offset = ( + seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) + ) + mLSE_cur = cute.domain_offset((lse_offset,), mLSE[None, head_idx]) + gLSE = cute.local_tile(mLSE_cur, (self.cta_tile_m,), (cta_m_block,)) + if tidx < self.cta_tile_m: + row_max = sRowMax[tidx, 0] + LN2 = math.log(2.0) + lse = ( + (row_max * softmax_scale_log2 + cute.math.log2(row_sum, fastmath=True)) + * LN2 + if not acc_O_mn_row_is_zero_or_nan + else -Float32.inf + ) + if tidx < seqlen_q - cta_m_block * self.cta_tile_m: + gLSE[tidx] = lse + + row_idx = cta_m_block * self.cta_tile_m + tOicOi[0][0] + + for split in cutlass.range_constexpr(self.num_hdimv_splits): + consumer_state_Oi = consumer_states_O[split] + pipelines_O[split].consumer_wait(consumer_state_Oi) + # copy Oi tmem -> rmem + cute.copy( + thr_tmem_load_O, + tOtOs_t2r[split], + tOrOs_t2r[split], + ) + + # scale and downcast Oi + tOrOs_r2g[split].store((tOrOs_r2g_f32[split].load() * scale).to(self.dtype_O)) + + if const_expr(not self.use_tma_O): + # copy Oi rmem -> gmem + if row_idx < seqlen_q: + cute.copy( + thr_tiled_copy_O_r2g, + tOrOs_r2g[split], + tOgO[None, None, None, split], + ) + else: + # copy Oi rmem -> smem + if const_expr(self.overlap_sO_sV): + # last slot for Vti is always 1, 3 + sO_idx = 1 + 2 * split + else: + sO_idx = split + cute.copy( + thr_tiled_copy_O_r2g, + tOrOs_r2g[split], + tOsO[None, None, None, sO_idx], + ) + cute.arch.fence_view_async_shared() + self.epi_barrier.arrive_and_wait() + # tma store Oi smem -> gmem + if leader_warp: + store_O(src_idx=sO_idx, dst_idx=split) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(1 - split, read=True) + if const_expr(split == 1 and self.overlap_sO_sV): + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(sO_empty_mbar_ptr) + + consumer_state_O0, consumer_state_O1 = consumer_states_O + + cute.arch.fence_view_async_tmem_load() + pipeline_O0.consumer_release(consumer_state_O0) + pipeline_O1.consumer_release(consumer_state_O1) + consumer_state_O0.advance() + consumer_state_O1.advance() + + # Advance to next tile + work_tile = tile_scheduler.advance_to_next_work() + + @cute.jit + def correction_rescale( + self, + thr_tmem_load: cute.CopyAtom, + thr_tmem_store: cute.CopyAtom, + tOcO_t2r: cute.Tensor, + tOtO_t2r: cute.Tensor, + tOtO_r2t: cute.Tensor, + scale: Float32, + ): + tOrO_t2r_frg = cute.make_rmem_tensor_like(tOcO_t2r[None, None, 0], self.dtype_acc) + + for i in cutlass.range_constexpr(cute.size(tOtO_t2r, mode=[2])): + tOtO_t2r_cur = tOtO_t2r[None, None, i] + tOtO_r2t_cur = tOtO_r2t[None, None, i] + + cute.copy(thr_tmem_load, tOtO_t2r_cur, tOrO_t2r_frg) + for j in cutlass.range(0, cute.size(tOrO_t2r_frg), 2, unroll_full=True): + tOrO_t2r_frg[j], tOrO_t2r_frg[j + 1] = cute.arch.mul_packed_f32x2( + (tOrO_t2r_frg[j], tOrO_t2r_frg[j + 1]), (scale, scale) + ) + cute.copy(thr_tmem_store, tOrO_t2r_frg, tOtO_r2t_cur) + cute.arch.fence_view_async_tmem_store() + + +def test_mla_kernel( + seqlen_q=2048, + seqlen_k=2048, + topk_length=2048, + nheads=1, + batch=1, + iter=0, + compile_cache=dict(), + validate=True, + seed=0, + gather_kv=True, + pack_gqa=False, + is_causal=False, + varlen_q=False, + varlen_k=False, + disable_bitmask=False, +): + torch.manual_seed(seed) + hdim = 64 + hdimv = 512 + softmax_scale = 1.0 / math.sqrt(hdim + hdimv) + + nheads_kv = 1 + qhead_per_kvhead = nheads + + compile_key = ( + is_causal, + gather_kv, + topk_length if gather_kv else None, + pack_gqa, + qhead_per_kvhead, + nheads_kv, + varlen_q, + varlen_k, + disable_bitmask, + ) + if compile_key not in compile_cache: + total_q_dummy = batch * seqlen_q + total_k_dummy = batch * seqlen_k + + if varlen_q: + Q = torch.randn(total_q_dummy, nheads, hdim, dtype=torch.bfloat16, device="cuda") + Qv = torch.randn(total_q_dummy, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + O = torch.empty(total_q_dummy, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + lse = torch.empty(nheads, total_q_dummy, dtype=torch.float32, device="cuda") + index_topk = ( + torch.rand(total_q_dummy, topk_length, device="cuda") + .argsort(dim=-1) + .to(torch.int32) + ) + cu_seqlens_q_dummy = torch.arange( + 0, (batch + 1) * seqlen_q, seqlen_q, dtype=torch.int32, device="cuda" + ) + else: + Q = torch.randn(batch, seqlen_q, nheads, hdim, dtype=torch.bfloat16, device="cuda") + Qv = torch.randn(batch, seqlen_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + O = torch.empty(batch, seqlen_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + lse = torch.empty(batch, nheads, seqlen_q, dtype=torch.float32, device="cuda") + index_topk = ( + torch.rand(batch, seqlen_q, topk_length, device="cuda") + .argsort(dim=-1) + .to(torch.int32) + ) + + if varlen_k: + K = torch.randn(total_k_dummy, nheads_kv, hdim, dtype=torch.bfloat16, device="cuda") + V = torch.randn(total_k_dummy, nheads_kv, hdimv, dtype=torch.bfloat16, device="cuda") + cu_seqlens_k_dummy = torch.arange( + 0, (batch + 1) * seqlen_k, seqlen_k, dtype=torch.int32, device="cuda" + ) + else: + K = torch.randn(batch, seqlen_k, nheads_kv, hdim, dtype=torch.bfloat16, device="cuda") + V = torch.randn(batch, seqlen_k, nheads_kv, hdimv, dtype=torch.bfloat16, device="cuda") + + mQ = from_dlpack(Q, assumed_align=16).mark_layout_dynamic(leading_dim=Q.ndim - 1) + mQv = from_dlpack(Qv, assumed_align=16).mark_layout_dynamic(leading_dim=Qv.ndim - 1) + mK = from_dlpack(K, assumed_align=16).mark_layout_dynamic(leading_dim=K.ndim - 1) + mV = from_dlpack(V, assumed_align=16).mark_layout_dynamic(leading_dim=V.ndim - 1) + mO = from_dlpack(O, assumed_align=16).mark_layout_dynamic(leading_dim=O.ndim - 1) + mLSE = from_dlpack(lse, assumed_align=4).mark_layout_dynamic(leading_dim=lse.ndim - 1) + if gather_kv: + mIndexTopk = from_dlpack(index_topk, assumed_align=16).mark_layout_dynamic( + leading_dim=index_topk.ndim - 1 + ) + else: + mIndexTopk = None + + compile_kwargs = dict(mIndexTopk=mIndexTopk) + if varlen_q: + compile_kwargs["mCuSeqlensQ"] = from_dlpack(cu_seqlens_q_dummy, assumed_align=4) + if varlen_k: + compile_kwargs["mCuSeqlensK"] = from_dlpack(cu_seqlens_k_dummy, assumed_align=4) + + kernel = cute.compile( + FlashAttentionMLAForwardSm100( + is_causal=is_causal, + use_cpasync_load_KV=gather_kv, + topk_length=topk_length if gather_kv else 2048, + is_topk_gather=gather_kv, + pack_gqa=pack_gqa, + qhead_per_kvhead=qhead_per_kvhead, + nheads_kv=nheads_kv, + is_varlen_q=varlen_q, + disable_bitmask=disable_bitmask, + ), + mQ, + mQv, + mK, + mV, + mO, + mLSE, + softmax_scale, + **compile_kwargs, + options="--keep-ptx --keep-cubin --generate-line-info", + ) + dump_kernel_attributes(kernel) + compile_cache[compile_key] = kernel + + # ================================================================ + # ---- Generate variable seqlens for this run ---- + if varlen_q: + torch.manual_seed(seed + 1000) + # When causal without varlen_k, every per-batch seqlen_q must not exceed seqlen_k. + max_seqlen_q = seqlen_k if (is_causal and not varlen_k) else seqlen_q + seqlens_q = torch.randint(1, max_seqlen_q + 1, (batch,), dtype=torch.int32) + cu_seqlens_q = torch.zeros(batch + 1, dtype=torch.int32, device="cuda") + cu_seqlens_q[1:] = seqlens_q.cumsum(0).to(torch.int32).cuda() + total_q = cu_seqlens_q[-1].item() + else: + seqlens_q = torch.full((batch,), seqlen_q, dtype=torch.int32) + total_q = None # unused + + if varlen_k: + torch.manual_seed(seed + 2000) + # Each batch item must have at least topk_length keys so topk gather is valid. + min_seqlen_k = topk_length if gather_kv else 1 + seqlens_k = torch.randint(min_seqlen_k, seqlen_k + 1, (batch,), dtype=torch.int32) + # When causal, every batch item needs seqlens_k[b] >= seqlens_q[b]. + if is_causal: + seqlens_k = torch.maximum(seqlens_k, seqlens_q) + cu_seqlens_k = torch.zeros(batch + 1, dtype=torch.int32, device="cuda") + cu_seqlens_k[1:] = seqlens_k.cumsum(0).to(torch.int32).cuda() + total_k = cu_seqlens_k[-1].item() + else: + seqlens_k = torch.full((batch,), seqlen_k, dtype=torch.int32) + total_k = None # unused + + torch.manual_seed(seed) # restore main seed before drawing actual tensors + + # ---- Allocate Q / Qv / O / lse ---- + if varlen_q: + Q = torch.randn(total_q, nheads, hdim, dtype=torch.bfloat16, device="cuda") + Qv = torch.randn(total_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + O = torch.empty(total_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + lse = torch.empty(nheads, total_q, dtype=torch.float32, device="cuda") + else: + Q = torch.randn(batch, seqlen_q, nheads, hdim, dtype=torch.bfloat16, device="cuda") + Qv = torch.randn(batch, seqlen_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + O = torch.empty(batch, seqlen_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + lse = torch.empty(batch, nheads, seqlen_q, dtype=torch.float32, device="cuda") + + # ---- Allocate K / V ---- + if varlen_k: + K = torch.randn(total_k, nheads_kv, hdim, dtype=torch.bfloat16, device="cuda") + V = torch.randn(total_k, nheads_kv, hdimv, dtype=torch.bfloat16, device="cuda") + else: + K = torch.randn(batch, seqlen_k, nheads_kv, hdim, dtype=torch.bfloat16, device="cuda") + V = torch.randn(batch, seqlen_k, nheads_kv, hdimv, dtype=torch.bfloat16, device="cuda") + + # ---- Generate index_topk with per-batch valid ranges when varlen_k ---- + # index_topk shape: (total_q, topk_length) if varlen_q else (batch, seqlen_q, topk_length) + if gather_kv: + topk_parts = [] + for b in range(batch): + sl_q_b = seqlens_q[b].item() + sl_k_b = seqlens_k[b].item() + # Draw topk_length unique indices from [0, sl_k_b) for each query in this batch item. + topk_b = ( + torch.rand(sl_q_b, sl_k_b, device="cuda") + .argsort(dim=-1)[..., :topk_length] + .to(torch.int32) + ) # (sl_q_b, topk_length), all < sl_k_b + topk_parts.append(topk_b) + + if varlen_q: + index_topk = torch.cat(topk_parts, dim=0) # (total_q, topk_length) + else: + index_topk = torch.stack(topk_parts, dim=0) # (batch, seqlen_q, topk_length) + else: + index_topk = None + + # ---- Reference computation (per-batch loop covers all four varlen combos) ---- + O_ref_list, O_pt_list, lse_ref_list, lse_pt_list = [], [], [], [] + for b in range(batch): + qs = cu_seqlens_q[b].item() if varlen_q else b * seqlen_q + qe = cu_seqlens_q[b + 1].item() if varlen_q else (b + 1) * seqlen_q + ks = cu_seqlens_k[b].item() if varlen_k else b * seqlen_k + ke = cu_seqlens_k[b + 1].item() if varlen_k else (b + 1) * seqlen_k + + Q_b = Q[qs:qe].unsqueeze(0) if varlen_q else Q[b : b + 1] # (1, sl_q, nheads, hdim) + Qv_b = Qv[qs:qe].unsqueeze(0) if varlen_q else Qv[b : b + 1] # (1, sl_q, nheads, hdimv) + K_b = K[ks:ke].unsqueeze(0) if varlen_k else K[b : b + 1] # (1, sl_k, nheads_kv, hdim) + V_b = V[ks:ke].unsqueeze(0) if varlen_k else V[b : b + 1] # (1, sl_k, nheads_kv, hdimv) + if gather_kv: + topk_b = index_topk[qs:qe].unsqueeze(0) if varlen_q else index_topk[b : b + 1] + else: + topk_b = None + + O_b, _, lse_b = attention_ref( + Q_b, K_b, V_b, qv=Qv_b, causal=is_causal, return_lse=True, gather_kv_indices=topk_b + ) + O_pt_b, _, lse_pt_b = attention_ref( + Q_b, + K_b, + V_b, + qv=Qv_b, + causal=is_causal, + upcast=False, + reorder_ops=True, + return_lse=True, + gather_kv_indices=topk_b, + ) + O_ref_list.append(O_b.squeeze(0)) + O_pt_list.append(O_pt_b.squeeze(0)) + lse_ref_list.append(lse_b.squeeze(0)) + lse_pt_list.append(lse_pt_b.squeeze(0)) + + cat_dim_o = 0 if (varlen_q) else 0 # always 0: leading token/batch dim + cat_dim_lse = -1 if (varlen_q) else -1 # always last: token dim + + if varlen_q: + O_ref = torch.cat(O_ref_list, dim=0) # (total_q, nheads, hdimv) + O_pt = torch.cat(O_pt_list, dim=0) + lse_ref = torch.cat(lse_ref_list, dim=-1) # (nheads, total_q) + lse_pt = torch.cat(lse_pt_list, dim=-1) + else: + O_ref = torch.stack(O_ref_list, dim=0) # (batch, seqlen_q, nheads, hdimv) + O_pt = torch.stack(O_pt_list, dim=0) + lse_ref = torch.stack(lse_ref_list, dim=0) # (batch, nheads, seqlen_q) + lse_pt = torch.stack(lse_pt_list, dim=0) + + rtol = 2 + atol = 2 * (O_ref + 0.3 - 0.3 - O_ref).abs().max().item() + + # ---- CuTe tensor wrappers ---- + mQ = from_dlpack(Q, assumed_align=16).mark_layout_dynamic(leading_dim=Q.ndim - 1) + mQv = from_dlpack(Qv, assumed_align=16).mark_layout_dynamic(leading_dim=Qv.ndim - 1) + mK = from_dlpack(K, assumed_align=16).mark_layout_dynamic(leading_dim=K.ndim - 1) + mV = from_dlpack(V, assumed_align=16).mark_layout_dynamic(leading_dim=V.ndim - 1) + mO = from_dlpack(O, assumed_align=16).mark_layout_dynamic(leading_dim=O.ndim - 1) + mLSE = from_dlpack(lse, assumed_align=4).mark_layout_dynamic(leading_dim=lse.ndim - 1) + if index_topk is not None: + mIndexTopk = from_dlpack(index_topk, assumed_align=16).mark_layout_dynamic( + leading_dim=index_topk.ndim - 1 + ) + else: + mIndexTopk = None + + run_kwargs = dict(mIndexTopk=mIndexTopk) + if varlen_q: + run_kwargs["mCuSeqlensQ"] = from_dlpack(cu_seqlens_q, assumed_align=4) + if varlen_k: + run_kwargs["mCuSeqlensK"] = from_dlpack(cu_seqlens_k, assumed_align=4) + + # ---- Run kernel ---- + compile_cache[compile_key]( + mQ, + mQv, + mK, + mV, + mO, + mLSE, + softmax_scale, + **run_kwargs, + ) + + print(f"Pytorch max O diff: {(O_pt - O_ref).abs().max().item()}") + print(f"Pytorch mean O diff: {(O_pt - O_ref).abs().mean().item()}") + print(f"Max abs diff O, O_ref: {(O - O_ref).abs().max().item()}") + print(f"Mean abs diff O, O_ref: {(O - O_ref).abs().mean().item()}") + + # print(f"Pytorch LSE max diff: {(lse_pt - lse_ref).abs().max().item()}") + # print(f"Pytorch LSE mean diff: {(lse_pt - lse_ref).abs().mean().item()}") + # print(f"Max abs diff LSE: {(lse - lse_ref).abs().max().item()}") + # print(f"Mean abs diff LSE: {(lse - lse_ref).abs().mean().item()}") + + if validate: + assert (O - O_ref).abs().max().item() <= rtol * (O_pt - O_ref).abs().max().item() + atol + varlen_tag = "" + if varlen_q: + varlen_tag += f", total_q:{total_q}" + if varlen_k: + varlen_tag += f", total_k:{total_k}" + print( + f"batch:{batch:3d}, nheads:{nheads:3d}, seqlen_q:{seqlen_q:5d}, seqlen_k:{seqlen_k:5d}" + f"{varlen_tag}, iter:{iter:2d} PASSED" + ) + else: + print(mO) + print( + f"batch:{batch:3d}, nheads:{nheads:3d}, seqlen_q:{seqlen_q:5d}, seqlen_k:{seqlen_k:5d}" + f", iter:{iter:2d} RUN (NOT TESTING CORRECTNESS)" + ) + + return None + + +def timeit(fn, *args, **kwargs): + # Synchronize before timing + torch.cuda.synchronize() + + # Warmup + for _ in range(10): + fn(*args, **kwargs) + + # Benchmark using PyTorch's Timer + t = benchmark.Timer( + stmt="fn(*args, **kwargs)", globals={"fn": fn, "args": args, "kwargs": kwargs} + ) + + # Time it multiple runs + measurement = t.timeit(20) # 20 repeats + avg_time = measurement.mean # Average time in seconds + + time.sleep(1) + + return avg_time + + +def benchmark_mla_kernel( + batch=1, + seqlen_q=2048, + seqlen_k=2048, + topk_length=2048, + nheads=128, + hdim=64, + hdimv=512, + compile_cache=dict(), + gather_kv=True, + is_causal=False, + disable_bitmask=False, +): + assert hdim == 64, "hdim must be 64" + assert hdimv == 512, "hdimv must be 512" + + qhead_per_kvhead = nheads + nheads_kv = 1 + pack_gqa = True + softmax_scale = 1.0 / math.sqrt(hdim + hdimv) + + compile_key = ( + is_causal, + gather_kv, + topk_length if gather_kv else None, + pack_gqa, + qhead_per_kvhead, + nheads_kv, + disable_bitmask, + ) + if compile_key not in compile_cache: + Q = torch.randn(batch, seqlen_q, nheads, hdim, dtype=torch.bfloat16, device="cuda") + Qv = torch.randn(batch, seqlen_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + K = torch.randn(batch, seqlen_k, nheads_kv, hdim, dtype=torch.bfloat16, device="cuda") + V = torch.randn(batch, seqlen_k, nheads_kv, hdimv, dtype=torch.bfloat16, device="cuda") + O = torch.empty(batch, seqlen_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + index_topk = ( + torch.rand(batch, seqlen_q, topk_length, device="cuda").argsort(dim=-1).to(torch.int32) + ) + + mQ = from_dlpack(Q, assumed_align=16).mark_layout_dynamic(leading_dim=Q.ndim - 1) + mQv = from_dlpack(Qv, assumed_align=16).mark_layout_dynamic(leading_dim=Qv.ndim - 1) + mK = from_dlpack(K, assumed_align=16).mark_layout_dynamic(leading_dim=K.ndim - 1) + mV = from_dlpack(V, assumed_align=16).mark_layout_dynamic(leading_dim=V.ndim - 1) + mO = from_dlpack(O, assumed_align=16).mark_layout_dynamic(leading_dim=O.ndim - 1) + if gather_kv: + mIndexTopk = from_dlpack(index_topk, assumed_align=16).mark_layout_dynamic( + leading_dim=index_topk.ndim - 1 + ) + else: + mIndexTopk = None + + mLSE = None + + kernel = cute.compile( + FlashAttentionMLAForwardSm100( + is_causal=is_causal, + use_cpasync_load_KV=gather_kv, + topk_length=topk_length if gather_kv else 2048, + is_topk_gather=gather_kv, + pack_gqa=pack_gqa, + qhead_per_kvhead=qhead_per_kvhead, + nheads_kv=nheads_kv, + disable_bitmask=disable_bitmask, + ), + mQ, + mQv, + mK, + mV, + mO, + mLSE, + softmax_scale, + mIndexTopk=mIndexTopk, + ) + compile_cache[compile_key] = kernel + + Q = torch.randn(batch, seqlen_q, nheads, hdim, dtype=torch.bfloat16, device="cuda") + Qv = torch.randn(batch, seqlen_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + K = torch.randn(batch, seqlen_k, nheads_kv, hdim, dtype=torch.bfloat16, device="cuda") + V = torch.randn(batch, seqlen_k, nheads_kv, hdimv, dtype=torch.bfloat16, device="cuda") + O = torch.empty(batch, seqlen_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + + index_topk = ( + torch.rand(batch, seqlen_q, topk_length, device="cuda").argsort(dim=-1).to(torch.int32) + ) + + mQ = from_dlpack(Q, assumed_align=16).mark_layout_dynamic(leading_dim=Q.ndim - 1) + mQv = from_dlpack(Qv, assumed_align=16).mark_layout_dynamic(leading_dim=Qv.ndim - 1) + mK = from_dlpack(K, assumed_align=16).mark_layout_dynamic(leading_dim=K.ndim - 1) + mV = from_dlpack(V, assumed_align=16).mark_layout_dynamic(leading_dim=V.ndim - 1) + mO = from_dlpack(O, assumed_align=16).mark_layout_dynamic(leading_dim=O.ndim - 1) + if gather_kv: + mIndexTopk = from_dlpack(index_topk, assumed_align=16).mark_layout_dynamic( + leading_dim=index_topk.ndim - 1 + ) + else: + mIndexTopk = None + mLSE = None + + exec_time_in_s = timeit( + compile_cache[compile_key], + mQ, + mQv, + mK, + mV, + mO, + mLSE, + softmax_scale, + mIndexTopk=mIndexTopk, + ) + + seqlen_k_eff = topk_length if gather_kv else seqlen_k + + FLOPs = 2 * batch * nheads * seqlen_q * seqlen_k_eff * (hdim + 2 * hdimv) + if is_causal and not gather_kv: + FLOPs /= 2 + + TFLOPS = FLOPs / exec_time_in_s / 1e12 + + q_bytes = 2 * batch * nheads * seqlen_q * hdim + qv_bytes = 2 * batch * nheads * seqlen_q * hdimv + k_bytes = 2 * batch * nheads_kv * seqlen_k_eff * hdim + v_bytes = 2 * batch * nheads_kv * seqlen_k_eff * hdimv + o_bytes = 2 * batch * nheads * seqlen_q * hdimv + total_bytes = q_bytes + qv_bytes + k_bytes + v_bytes + o_bytes + TBs = total_bytes / exec_time_in_s / 1e12 + + print( + f"batch: {batch}, seqlen_q: {seqlen_q}, seqlen_k: {seqlen_k}, nheads: {nheads}, -> {exec_time_in_s * 1e3:.2f} ms, {TFLOPS:.2f} TFLOPS, {TBs:.2f} TBs" + ) + + +if __name__ == "__main__": + run_test = True + run_benchmark = True + gather_kv = False + is_causal = True + pack_gqa = True + topk_length = 2048 + varlen_q = False + varlen_k = False + disable_bitmask = True + validate = True + + if run_test: + if not gather_kv: + seqlen_q_test_values = range(1, 4002, 400) + seqlen_k_test_values = range(1, 4002, 400) + else: + seqlen_q_test_values = range(1, 1001, 200) + seqlen_k_test_values = range(topk_length, 9001, 2000) + seqlen_q_test_values = [1] + seqlen_k_test_values = [4096] + nheads_test_values = [128] + batch_test_values = [4] + test_configs = [ + ( + batch, + nheads, + seqlen_q, + seqlen_k, + ) + for batch in batch_test_values + for nheads in nheads_test_values + for seqlen_q in seqlen_q_test_values + for seqlen_k in seqlen_k_test_values + ] + iters_per_config = 1 + compile_cache = dict() + print("=" * 40) + print("Testing MLA Kernel") + print("=" * 40) + for config in test_configs: + batch, nheads, seqlen_q, seqlen_k = config + # if is_causal and seqlen_k < seqlen_q: + # continue + for iter in range(iters_per_config): + test_mla_kernel( + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + topk_length=topk_length, + nheads=nheads, + batch=batch, + iter=iter, + compile_cache=compile_cache, + validate=validate, + seed=0, + gather_kv=gather_kv, + pack_gqa=pack_gqa, + is_causal=is_causal, + varlen_q=varlen_q, + varlen_k=varlen_k, + disable_bitmask=disable_bitmask, + ) + if run_benchmark: + if gather_kv: + seqlen_q_benchmark_values = [1] + seqlen_k_benchmark_values = [8192 * 2] + nheads_benchmark_values = [128] + batch_benchmark_values = [512] + else: + seqlen_q_benchmark_values = [1] + seqlen_k_benchmark_values = [8192 * 2] + nheads_benchmark_values = [128] + batch_benchmark_values = [512] + seqlen_q_benchmark_values = [4096] + seqlen_k_benchmark_values = [4096] + nheads_benchmark_values = [16] + batch_benchmark_values = [8] + benchmark_configs = [ + ( + batch, + nheads, + seqlen_q, + seqlen_k, + ) + for batch in batch_benchmark_values + for nheads in nheads_benchmark_values + for seqlen_q in seqlen_q_benchmark_values + for seqlen_k in seqlen_k_benchmark_values + ] + compile_cache = dict() + print("=" * 40) + print("Benchmarking MLA Kernel") + print("=" * 40) + for config in benchmark_configs: + batch, nheads, seqlen_q, seqlen_k = config + benchmark_mla_kernel( + batch=batch, + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + topk_length=topk_length, + nheads=nheads, + gather_kv=gather_kv, + is_causal=is_causal, + disable_bitmask=disable_bitmask, + compile_cache=compile_cache, + ) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 6c9c20d0b76..42acbeaec86 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -14,7 +14,7 @@ # https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/fmha.py import math -from typing import Tuple, Callable, Optional, Literal +from typing import Tuple, Callable, Optional, Literal, NamedTuple from functools import partial import cuda.bindings.driver as cuda @@ -73,23 +73,43 @@ # Values: # ex2_emu_freq: int — how often to use emulated exp2 (0=all hardware exp2, higher=more emulation). # SM103 has fast native exp2, so set freq=0 there. +# ex2_emu_res: int — (hd256 only) number of fragment-pairs per freq period to emulate. # ex2_emu_start_frg: int — fragment index to start emulation from # num_regs_softmax: int — register count for softmax warps (multiple of 8) # num_regs_correction: int — register count for correction warps (multiple of 8) # num_regs_other is derived: 512 - num_regs_softmax * 2 - num_regs_correction +# (hd256 exception: num_regs_other is fixed at 32, not derived) _TUNING_CONFIG = { - (True, False, 128, False): {'ex2_emu_freq': 10, 'ex2_emu_start_frg': 1, 'num_regs_softmax': 176, 'num_regs_correction': 88}, - (False, True, 128, False): {'ex2_emu_freq': 16, 'ex2_emu_start_frg': 1, 'num_regs_softmax': 192, 'num_regs_correction': 72}, + (True, False, 128, False): {"ex2_emu_freq": 10, "ex2_emu_start_frg": 1, "num_regs_softmax": 176, "num_regs_correction": 88}, + (False, True, 128, False): {"ex2_emu_freq": 16, "ex2_emu_start_frg": 1, "num_regs_softmax": 192, "num_regs_correction": 72}, (True, False, 192, False): {"ex2_emu_freq": 16, "ex2_emu_start_frg": 0, "num_regs_softmax": 184, "num_regs_correction": 80}, (False, True, 192, False): {"ex2_emu_freq": 32, "ex2_emu_start_frg": 1, "num_regs_softmax": 192, "num_regs_correction": 72}, (True, False, 128, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 80}, (False, True, 128, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 64}, (True, False, 192, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 64}, (False, True, 192, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 72}, + (True, False, 256, False): {"ex2_emu_freq": 14, "ex2_emu_res": 6, "ex2_emu_start_frg": 0, "num_regs_softmax": 256, "num_regs_correction": 160}, + (True, True, 256, False): {"ex2_emu_freq": 14, "ex2_emu_res": 6, "ex2_emu_start_frg": 0, "num_regs_softmax": 256, "num_regs_correction": 160}, +} +_FP8_TUNING_CONFIG = { + (True, False, 128, False): {'ex2_emu_freq': 10, 'ex2_emu_start_frg': 1, 'num_regs_softmax': 160, 'num_regs_correction': 72}, +} +_FP8_SMALL_HDIM_REGS = { + False: {"num_regs_softmax": 168, "num_regs_correction": 96, "num_regs_other": 80}, + True: {"num_regs_softmax": 152, "num_regs_correction": 96, "num_regs_other": 112}, } # === END TUNING KNOBS === +class DescaleTensors(NamedTuple): + q_descale: Optional[cute.Tensor] = None + k_descale: Optional[cute.Tensor] = None + v_descale: Optional[cute.Tensor] = None + + def __new_from_mlir_values__(self, values): + return DescaleTensors(*((*values, None, None, None)[:3])) + + class FlashAttentionForwardSm100: def __init__( @@ -187,6 +207,7 @@ def __init__( "Paged KV does not support irregular head dim" ) + # ClC does not compose with these other features, so disable even if requested self.use_clc_scheduler = ( use_clc_scheduler and self.use_tma_KV @@ -350,6 +371,7 @@ def __call__( window_size_left: Int32 | int | None = None, window_size_right: Int32 | int | None = None, learnable_sink: Optional[cute.Tensor] = None, + descale_tensors: Optional[DescaleTensors] = None, blocksparse_tensors: Optional[BlockSparseTensors] = None, aux_tensors: Optional[list] = None, # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). @@ -405,6 +427,24 @@ def __call__( raise TypeError(f"Type mismatch: {self.q_dtype} != {self.k_dtype}") if const_expr(self.q_dtype != self.v_dtype): raise TypeError(f"Type mismatch: {self.q_dtype} != {self.v_dtype}") + if const_expr(self.q_dtype.width == 8): + paged_kv_non_tma = not self.use_tma_KV + if const_expr(self.head_dim_padded < 96): + fp8_regs = _FP8_SMALL_HDIM_REGS[paged_kv_non_tma] + self.num_regs_softmax = fp8_regs["num_regs_softmax"] + self.num_regs_correction = fp8_regs["num_regs_correction"] + self.num_regs_other = fp8_regs["num_regs_other"] + else: + fp8_tune = _FP8_TUNING_CONFIG.get( + (self.use_2cta_instrs, self.is_causal, self.head_dim_padded, self.is_sm103), {} + ) + if const_expr("ex2_emu_freq" in fp8_tune): + self._tune = {**self._tune, **fp8_tune} + self.enable_ex2_emu = self._tune["ex2_emu_freq"] > 0 + if const_expr(not paged_kv_non_tma and "num_regs_softmax" in fp8_tune): + self.num_regs_softmax = fp8_tune["num_regs_softmax"] + self.num_regs_correction = fp8_tune["num_regs_correction"] + self.num_regs_other = 512 - self.num_regs_softmax * 2 - self.num_regs_correction self._setup_attributes() self.use_tma_O = ( self.arch >= Arch.sm_90 @@ -704,6 +744,7 @@ class SharedStorage: window_size_left, window_size_right, learnable_sink, + descale_tensors, blocksparse_tensors, sQ_layout, sK_layout, @@ -750,6 +791,7 @@ def kernel( window_size_left: Optional[Int32], window_size_right: Optional[Int32], learnable_sink: Optional[cute.Tensor], + descale_tensors: Optional[DescaleTensors], blocksparse_tensors: Optional[BlockSparseTensors], sQ_layout: cute.ComposedLayout, sK_layout: cute.ComposedLayout, @@ -1178,6 +1220,7 @@ def kernel( self.softmax_loop, softmax_scale_log2=softmax_scale_log2, softmax_scale=softmax_scale, + descale_tensors=descale_tensors, thr_mma_qk=thr_mma_qk, sScale=sScale, mLSE=mLSE, @@ -1233,6 +1276,7 @@ def kernel( sm_stats_barrier, pipeline_o_epi, learnable_sink, + descale_tensors, gmem_tiled_copy_O, tma_atom_O, softmax_scale_log2, @@ -1498,6 +1542,7 @@ def mma( qk_mma_op, pv_mma_op = tiled_mma_qk.op, tiled_mma_pv.op qk_mma_idesc, pv_mma_idesc = sm100_desc.mma_op_to_idesc(qk_mma_op), sm100_desc.mma_op_to_idesc(pv_mma_op) + qk_mma_kind = sm100_utils._tcgen05_mma_kind(qk_mma_op) q_smem_base = sm100_desc.smem_desc_base_from_tensor(sQ, sm100_desc.Major.K) k_smem_base = sm100_desc.smem_desc_base_from_tensor(sK, sm100_desc.Major.K) v_smem_base = sm100_desc.smem_desc_base_from_tensor(sV, sm100_desc.Major.MN) @@ -1524,8 +1569,9 @@ def mma( # idesc=qk_mma_idesc, smem_desc_base_b=k_smem_base, tCrB_layout=tSrK[None, None, None, 0].layout, - smem_var_name_prefix=f"fa_fwd_q_smem_desc", - idesc_var_name=f"fa_fwd_qk_mma_idesc", + smem_var_name_prefix="fa_fwd_q_smem_desc", + idesc_var_name="fa_fwd_qk_mma_idesc", + kind=qk_mma_kind, smem_offset=-sQ_stage_stride if stage == 0 else sQ_stage_stride, zero_init=True, cta_group=self.cta_group_size, @@ -1754,12 +1800,39 @@ def mma( # pipeline_o_acc.producer_acquire() inside the loop. # for both softmax0 and softmax1 warp group + @cute.jit + def _kv_head_idx(self, head_idx: Int32) -> Int32: + """Map query-head tile index -> KV-head index (FA3 descale semantics).""" + if cutlass.const_expr(self.pack_gqa): + return head_idx + return head_idx // self.qhead_per_kvhead + + @cute.jit + def _load_effective_descales( + self, + descale_tensors: Optional[DescaleTensors], + batch_idx: Int32, + kv_head_idx: Int32, + ) -> Tuple[Float32, Float32]: + """Load effective QK and V descales, defaulting unspecified tensors to identity.""" + qk_descale = Float32(1.0) + v_descale = Float32(1.0) + if cutlass.const_expr(descale_tensors is not None): + if cutlass.const_expr(descale_tensors.q_descale is not None): + qk_descale = qk_descale * Float32(descale_tensors.q_descale[batch_idx, kv_head_idx]) + if cutlass.const_expr(descale_tensors.k_descale is not None): + qk_descale = qk_descale * Float32(descale_tensors.k_descale[batch_idx, kv_head_idx]) + if cutlass.const_expr(descale_tensors.v_descale is not None): + v_descale = Float32(descale_tensors.v_descale[batch_idx, kv_head_idx]) + return qk_descale, v_descale + @cute.jit def softmax_loop( self, stage: int | Int32, softmax_scale_log2: Float32, - softmax_scale: Float32, + softmax_scale: Float32 | None, + descale_tensors: Optional[DescaleTensors], thr_mma_qk: cute.core.ThrMma, tStS: cute.Tensor, # ((TILE_M, TILE_N), 1, 1, q_stage) sScale: cute.Tensor, @@ -1825,7 +1898,10 @@ def softmax_loop( ) tStScale_r2t = thr_tmem_store_scale.partition_D(tStScale) tmem_store_atom = cute.make_copy_atom( - tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(16)), Float32 + tcgen05.copy.St32x32bOp( + tcgen05.copy.Repetition(8 if const_expr(self.q_dtype.width == 8) else 16) + ), + Float32, ) thr_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP).get_slice(tidx) tStP_r2t = thr_tmem_store.partition_D(tStP) # (((16,32),1),1,4) @@ -1841,6 +1917,7 @@ def softmax_loop( work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx + kv_head_idx = self._kv_head_idx(head_idx) seqlen = SeqlenInfoCls(batch_idx) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) @@ -1895,10 +1972,26 @@ def softmax_loop( else: mask_fn_none = None + qk_descale, _ = self._load_effective_descales(descale_tensors, batch_idx, kv_head_idx) + + max_offset = 8 if cutlass.const_expr(self.q_dtype.width == 8) else 0 + if const_expr(self.score_mod is None): + softmax_scale_log2_eff = softmax_scale_log2 * qk_descale + softmax_scale_eff = None + else: + softmax_scale_log2_eff = softmax_scale_log2 + softmax_scale_eff = softmax_scale * qk_descale + + rescale_threshold = ( + 8.0 if const_expr(self.q_dtype.width == 16) else + 4.0 if const_expr(self.q_dtype.width == 8) else + 0.0 + ) softmax = SoftmaxSm100.create( - softmax_scale_log2, - rescale_threshold=8.0 if const_expr(self.q_dtype.width == 16) else 0.0, - softmax_scale=softmax_scale, + softmax_scale_log2_eff, + rescale_threshold=rescale_threshold, + softmax_scale=softmax_scale_eff, + max_offset=max_offset, ) softmax.reset() @@ -1943,7 +2036,6 @@ def softmax_loop( ) if const_expr(self.use_block_sparsity) or has_work: - # See block_sparse_utils.py NOTE [SM100 block-sparse empty tiles: mbarrier contract]. pipeline_sm_stats.producer_acquire_w_index_phase(stage, sm_stats_producer_phase) sm_stats_producer_phase ^= 1 @@ -2244,6 +2336,7 @@ def correction_loop( sm_stats_barrier: pipeline.NamedBarrier, pipeline_o_epi: pipeline.PipelineAsync, learnable_sink: Optional[cute.Tensor], + descale_tensors: Optional[DescaleTensors], gmem_tiled_copy_O: cute.TiledCopy, tma_atom_O: cute.CopyAtom, softmax_scale_log2: Float32, @@ -2284,6 +2377,17 @@ def correction_loop( work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx + kv_head_idx = self._kv_head_idx(head_idx) + qk_descale, v_descale = self._load_effective_descales(descale_tensors, batch_idx, kv_head_idx) + if const_expr(self.score_mod is None): + softmax_scale_log2_eff = softmax_scale_log2 * qk_descale + else: + softmax_scale_log2_eff = softmax_scale_log2 + + max_offset = Float32(8.0) if cutlass.const_expr(self.q_dtype.width == 8) else Float32(0.0) + max_offset_scale = ( + Float32(256.0) if cutlass.const_expr(self.q_dtype.width == 8) else Float32(1.0) + ) seqlen = SeqlenInfoCls(batch_idx) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) @@ -2386,15 +2490,16 @@ def correction_loop( if const_expr(not self.is_split_kv) or split_idx == 0: if row_max == -Float32.inf: # It's possible to have an empty row with splitKV. - row_max = sink_val * (LOG2_E / softmax_scale_log2) - row_sum = Float32(1.0) + row_max = sink_val * (LOG2_E / softmax_scale_log2_eff) + row_sum = max_offset_scale else: row_sum += cute.math.exp2( - sink_val * LOG2_E - row_max * softmax_scale_log2, fastmath=True + sink_val * LOG2_E - row_max * softmax_scale_log2_eff + max_offset, fastmath=True ) acc_O_mn_row_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum stats[stage] = (row_sum, row_max, acc_O_mn_row_is_zero_or_nan) scale = cute.arch.rcp_approx(row_sum if not acc_O_mn_row_is_zero_or_nan else 1.0) + scale = scale * v_descale # Wait for the last O to be ready from the MMA warp pipeline_o_acc.consumer_wait_w_index_phase(stage, o_corr_consumer_phase) if const_expr(not self.use_correction_warps_for_epi): @@ -2458,7 +2563,9 @@ def correction_loop( sm_stats_consumer_phase, o_corr_consumer_phase, corr_epi_producer_phase, - softmax_scale_log2, + softmax_scale_log2_eff, + max_offset, + max_offset_scale, mO_cur, gO, gmem_tiled_copy_O_for_empty_tile, @@ -2485,7 +2592,7 @@ def correction_loop( # cute.printf("row_sum = {}, row_max = {}, acc_O_mn_row_is_zero_or_nan = {}\n", row_sum, row_max, acc_O_mn_row_is_zero_or_nan) LN2 = math.log(2.0) lse = ( - (row_max * softmax_scale_log2 + cute.math.log2(row_sum, fastmath=True)) * LN2 + (row_max * softmax_scale_log2_eff + (cute.math.log2(row_sum, fastmath=True) - max_offset)) * LN2 if not acc_O_mn_row_is_zero_or_nan else -Float32.inf ) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 615561b12c3..fe68dc79415 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -1,24 +1,6 @@ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. # [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'll need install nvidia-cutlass-dsl==4.2.0. -# Supported features: -# - BF16 & FP16 dtype -# - noncausal & causal attention -# - MHA, GQA, MQA -# - hdim 64, 96, 128. -# - (hdim_qk, hdim_v) = (192, 128) for Blackwell (i.e. DeepSeek shape) -# - varlen -# - sliding window -# - bwd pass for Ampere (will also run on Hopper/Blackwell, but will be slow) - -# Features not supported yet: -# - split (i.e. FlashDecoding) -# - tuned block sizes -# - paged KV -# - append KV to existing KV cache -# - FP8 -# - bwd pass optimized for Hopper/Blackwell - import os import math from dataclasses import dataclass @@ -52,7 +34,7 @@ ) from flash_attn.cute.flash_fwd import FlashAttentionForwardSm80 from flash_attn.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 -from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100 +from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100, DescaleTensors from flash_attn.cute.flash_fwd_sm120 import FlashAttentionForwardSm120 from flash_attn.cute.flash_bwd_preprocess import FlashAttentionBackwardPreprocess from flash_attn.cute.flash_bwd import FlashAttentionBackwardSm80 @@ -61,6 +43,11 @@ from flash_attn.cute.flash_bwd_sm120 import FlashAttentionBackwardSm120 from flash_attn.cute.flash_bwd_postprocess import FlashAttentionBackwardPostprocess from flash_attn.cute.flash_fwd_combine import FlashAttentionForwardCombine +from flash_attn.cute.flash_fwd_mla_sm100 import FlashAttentionMLAForwardSm100 + +# SM100 head_dim=256 2CTA kernel imports +from flash_attn.cute.sm100_hd256_2cta_fmha_forward import BlackwellFusedMultiHeadAttentionForward +from flash_attn.cute.sm100_hd256_2cta_fmha_backward import BlackwellFusedMultiHeadAttentionBackward from flash_attn.cute.block_sparsity import ( BlockSparseTensorsTorch, @@ -102,6 +89,8 @@ def _get_device_arch(): def _validate_head_dims(head_dim: int, head_dim_v: int, compute_capability: int, alignment: int) -> None: """Validate head dimension constraints based on compute capability.""" is_deepseek_shape = head_dim == 192 and head_dim_v == 128 + is_deepseek_mla_absorbed_shape = head_dim == 64 and head_dim_v == 512 + is_dedicate_kernel_shape = head_dim == 256 and head_dim_v == 256 is_standard_range = 8 <= head_dim <= 128 and 8 <= head_dim_v <= 128 is_sm90_range = 8 <= head_dim <= 512 and 8 <= head_dim_v <= 512 @@ -111,9 +100,9 @@ def _validate_head_dims(head_dim: int, head_dim_v: int, compute_capability: int, f"head_dim and head_dim_v must be between 8 and 256 and divisible by {alignment}." ) elif compute_capability in [10, 11]: - assert (is_standard_range or is_deepseek_shape) and head_dim % alignment == 0 and head_dim_v % alignment == 0, ( + assert (is_standard_range or is_deepseek_shape or is_deepseek_mla_absorbed_shape or is_dedicate_kernel_shape) and head_dim % alignment == 0 and head_dim_v % alignment == 0, ( f"(head_dim, head_dim_v)=({head_dim}, {head_dim_v}) is not supported on SM100/SM110. " - f"head_dim and head_dim_v must be between 8 and 128 and divisible by {alignment}, or (192, 128) for DeepSeek." + f"head_dim and head_dim_v must be between 8 and 128 and divisible by {alignment}, or (192, 128) for DeepSeek, or (256, 256) for hd256." ) @@ -256,11 +245,12 @@ def _validate_tensor(t, name, expected_shape, expected_dtype, expected_device): if not is_fake_mode(): assert t.is_cuda, f"{name} must be on CUDA" - torch2cute_dtype_map = { torch.float16: cutlass.Float16, torch.bfloat16: cutlass.BFloat16, torch.float32: cutlass.Float32, + torch.float8_e4m3fn: cutlass.Float8E4M3FN, + torch.float8_e5m2: cutlass.Float8E5M2, } @@ -296,17 +286,18 @@ def _resolve_causal_local_window(causal, window_size_left, window_size_right, ma local = False return causal, local, window_size_left, window_size_right - def _flash_attn_fwd( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + qv: Optional[torch.Tensor] = None, cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_k: Optional[torch.Tensor] = None, seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, max_seqlen_q: Optional[int] = None, max_seqlen_k: Optional[int] = None, + min_seqlen_k: Optional[int] = None, page_table: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, @@ -328,6 +319,10 @@ def _flash_attn_fwd( out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, aux_tensors: Optional[list[torch.Tensor]] = None, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, + gather_kv_indices: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """Forward pass for FlashAttention. @@ -343,6 +338,7 @@ def _flash_attn_fwd( aux_tensors: Some score_mods will want to read from global aux_tensors. This is how we thread them through to the inner kernel. """ q, k, v = [maybe_contiguous(t) for t in (q, k, v)] + q_descale, k_descale, v_descale = [maybe_contiguous(t) for t in (q_descale, k_descale, v_descale)] num_head, head_dim = q.shape[-2:] if cu_seqlens_q is None: batch_size, seqlen_q = q.shape[:2] @@ -388,7 +384,9 @@ def _flash_attn_fwd( assert seqused_k is None or seqused_k.shape == (batch_size,), ( "seqused_k must have shape (batch_size,)" ) - assert q.dtype in [torch.float16, torch.bfloat16], "inputs must be float16 or bfloat16" + assert q.dtype in [torch.float16, torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2], ( + "inputs must be float16, bfloat16, fp8 e4m3fn, or fp8 e5m2" + ) assert q.dtype == k.dtype == v.dtype, "inputs must have the same dtype" for t in [cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k]: if t is not None: @@ -409,6 +407,9 @@ def _flash_attn_fwd( q, k, v, + q_descale, + k_descale, + v_descale, cu_seqlens_q, cu_seqlens_k, seqused_q, @@ -424,14 +425,17 @@ def _flash_attn_fwd( if arch // 10 not in [8, 12]: _validate_head_dims(head_dim, head_dim_v, arch // 10, alignment) if softmax_scale is None: - softmax_scale = 1.0 / math.sqrt(head_dim) + softmax_scale = 1.0 / math.sqrt(head_dim) if qv is None else 1.0 / math.sqrt(head_dim + head_dim_v) if softcap == 0.0: softcap = None qhead_per_kvhead = num_head // num_head_kv if pack_gqa is None: pack_gqa = qhead_per_kvhead > 1 - out_torch_dtype = q.dtype + is_fp8 = q.dtype in (torch.float8_e4m3fn, torch.float8_e5m2) + if is_fp8 and (q.requires_grad or k.requires_grad or v.requires_grad): + raise NotImplementedError("FA4 CuTe FP8 backward is not supported yet (forward-only).") + out_torch_dtype = torch.bfloat16 if is_fp8 else q.dtype device = q.device q_batch_seqlen_shape = (batch_size, seqlen_q) if cu_seqlens_q is None else (total_q,) lse_shape = (batch_size, num_head, seqlen_q) if cu_seqlens_q is None else (num_head, total_q) @@ -453,7 +457,24 @@ def _flash_attn_fwd( elif lse is not None: _validate_tensor(lse, "lse", lse_shape, torch.float32, device) + if seqlen_k == 0: + out.zero_() + if lse is not None: + lse.fill_(float("-inf")) + return out, lse + + if is_fp8: + for t, name in ((q_descale, "q_descale"), (k_descale, "k_descale"), (v_descale, "v_descale")): + if t is not None: + _validate_tensor(t, name, (batch_size, num_head_kv), torch.float32, device) + else: + assert q_descale is None and k_descale is None and v_descale is None, ( + "q_descale/k_descale/v_descale are only supported for FP8 inputs" + ) + dtype = torch2cute_dtype_map[q.dtype] + if is_fp8: + assert arch // 10 == 10, "FP8 is only supported on SM100 (compute capability 10.x) for FA4 CuTe." use_block_sparsity = block_sparse_tensors is not None causal, local, window_size_left, window_size_right = _resolve_causal_local_window( @@ -495,11 +516,16 @@ def _flash_attn_fwd( # TODO: fix GQA + SplitKV + non-varlen if pack_gqa and num_splits != 1 and cu_seqlens_q is None: pack_gqa = False + + if pack_gqa and qv is not None and 128 % qhead_per_kvhead != 0: + pack_gqa = False if max_seqlen_q is None: max_seqlen_q = seqlen_q if cu_seqlens_q is None else total_q if max_seqlen_k is None: max_seqlen_k = seqlen_k + if cu_seqlens_k is None and seqused_k is None: + min_seqlen_k = seqlen_k seqlen_q_packgqa = max_seqlen_q * qhead_per_kvhead if arch // 10 == 10: q_stage = 2 if seqlen_q_packgqa > tile_m else 1 @@ -518,7 +544,7 @@ def _flash_attn_fwd( # SplitKV uses float32 partial output, which doubles the O buffer size # in shared memory, causing OOM for diff-headdim (192, 128) if arch // 10 in [10, 11] and head_dim != head_dim_v and num_splits > 1: - if num_n_blocks >= 64: + if num_n_blocks >= 64 and head_dim_v != 512: tile_n = 64 num_n_blocks = (seqlen_k_loaded + tile_n - 1) // tile_n num_splits = num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, 128) @@ -546,13 +572,20 @@ def _flash_attn_fwd( and (tile_m % qhead_per_kvhead == 0 or not pack_gqa) ) - # hash score and mask mods for compile cache - score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else False - mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else False + # hd=256 2CTA forward uses dedicated kernel (SM100 only) + use_dedicated_hd256_kernel = arch // 10 == 10 and head_dim == 256 and head_dim_v == 256 + use_2cta_instrs = use_2cta_instrs or use_dedicated_hd256_kernel if softcap is not None: assert score_mod is None, "softcap and score_mod cannot be used together" score_mod = utils.create_softcap_scoremod(softcap) + elif score_mod is not None: + if arch // 10 == 8: + raise NotImplementedError("Custom user-provided score_mod is not supported on SM8x architectures.") + + # hash score and mask mods for compile cache + score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else False + mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else False is_varlen = ( cu_seqlens_q is not None @@ -561,6 +594,13 @@ def _flash_attn_fwd( or seqused_k is not None ) + # CLC regressed for varlen MHA and dense noncausal. Imbalanced varlen shapes + # keep more K/V blocks in flight and hurt L2; dense noncausal mostly just + # pays work-stealing overhead. + is_varlen_mha = is_varlen and qhead_per_kvhead == 1 + is_dense_noncausal = not is_varlen and not causal and not local + use_clc_scheduler = requested_use_clc_scheduler and not is_varlen_mha and not is_dense_noncausal + if mask_mod is not None: if is_varlen: raise NotImplementedError( @@ -605,6 +645,44 @@ def _flash_attn_fwd( else: aux_tensor_metadata = None + if qv is not None: + assert arch // 10 in [10, 11], "only support Blackwell arch with qv" + assert qv.shape[:-1] == q.shape[:-1] + assert qv.shape[-1] == head_dim_v + assert head_dim == 64 and head_dim_v == 512, "only support MLA weight absorbed shape with qv" + assert not local, "local not yet supported with qv" + assert page_table is None, "page table not yet supported with qv" + assert q_descale is None and k_descale is None and v_descale is None, ( + "q_descale/k_descale/v_descale are not yet supported with qv" + ) + + assert not is_split_kv, "split kv not supported with qv" + assert learnable_sink is None + assert softcap is None + assert score_mod is None + assert mask_mod is None + + qv = maybe_contiguous(qv) + + gather_kv_length = 2048 + sparse_kv = gather_kv_indices is not None + disable_sparse_kv_bitmask = False + if sparse_kv: + assert gather_kv_indices.shape[:-1] == q.shape[:-2] + gather_kv_length = gather_kv_indices.shape[-1] + assert gather_kv_length % 256 == 0 + if min_seqlen_k is None or causal: + disable_sparse_kv_bitmask = False + else: + # seqlen_k_boundary = min_seqlen_k - max_seqlen_q + 1 if causal else min_seqlen_k + seqlen_k_boundary = min_seqlen_k + disable_sparse_kv_bitmask = seqlen_k_boundary >= gather_kv_length + else: + assert gather_kv_indices is None, "gather_kv_indices is only supported with qv" + gather_kv_length = None + sparse_kv = None + disable_sparse_kv_bitmask = None + compile_key = ( dtype, head_dim, @@ -625,6 +703,9 @@ def _flash_attn_fwd( window_size_left is not None, window_size_right is not None, learnable_sink is not None, + q_descale is not None, + k_descale is not None, + v_descale is not None, tile_m, tile_n, q_stage, @@ -637,9 +718,14 @@ def _flash_attn_fwd( q_subtile_factor, mma_pv_is_rs, intra_wg_overlap, - requested_use_clc_scheduler, + use_clc_scheduler, + qv is not None, + gather_kv_length, + sparse_kv, + disable_sparse_kv_bitmask, fa_logging.get_fa_log_level(), ) + if compile_key not in _flash_attn_fwd.compile_cache: ( cu_seqlens_q_tensor, @@ -668,6 +754,33 @@ def _flash_attn_fwd( else: lse_tensor = None + q_descale_tensor = ( + to_cute_tensor(q_descale, assumed_align=4, leading_dim=1) + if q_descale is not None + else None + ) + k_descale_tensor = ( + to_cute_tensor(k_descale, assumed_align=4, leading_dim=1) + if k_descale is not None + else None + ) + v_descale_tensor = ( + to_cute_tensor(v_descale, assumed_align=4, leading_dim=1) + if v_descale is not None + else None + ) + descale_tensors_tensor = ( + DescaleTensors( + q_descale=q_descale_tensor, + k_descale=k_descale_tensor, + v_descale=v_descale_tensor, + ) + if q_descale_tensor is not None + or k_descale_tensor is not None + or v_descale_tensor is not None + else None + ) + sparse_tensors = None if normalized_block_sparse_tensors is not None: sparse_tensors = to_cute_block_sparse_tensors(normalized_block_sparse_tensors) @@ -677,6 +790,9 @@ def _flash_attn_fwd( if aux_tensors is not None: cute_aux_tensors = [to_cute_aux_tensor(buf) for buf in aux_tensors] + qv_tensor = to_cute_tensor(qv) if qv is not None else None + gather_kv_indices_tensor = to_cute_tensor(gather_kv_indices) if gather_kv_indices is not None else None + if arch // 10 == 8: assert page_table is None, "paged KV not supported on SM 8.0" assert not is_split_kv, "SplitKV not supported on SM 8.0" @@ -721,31 +837,79 @@ def _flash_attn_fwd( paged_kv_non_tma=page_size not in [None, tile_n], ) elif arch // 10 in [10, 11]: - fa_fwd = FlashAttentionForwardSm100( - head_dim, - head_dim_v, - qhead_per_kvhead=qhead_per_kvhead, - is_causal=causal, - is_local=local, - is_split_kv=is_split_kv, - pack_gqa=pack_gqa, - m_block_size=tile_m, - n_block_size=tile_n, - q_stage=q_stage, - is_persistent=not causal - and not local - and cu_seqlens_q is None - and seqused_q is None - and not is_split_kv, - score_mod=score_mod, - mask_mod=mask_mod, - has_aux_tensors=aux_tensors is not None, - paged_kv_non_tma=page_size not in [None, tile_n], - is_varlen_q=cu_seqlens_q is not None or seqused_q is not None, - q_subtile_factor=q_subtile_factor, - use_2cta_instrs=use_2cta_instrs, - use_clc_scheduler=requested_use_clc_scheduler, - ) + if qv is not None: + fa_fwd = FlashAttentionMLAForwardSm100( + is_causal=causal, + use_cpasync_load_KV=sparse_kv, + topk_length=gather_kv_length, + is_topk_gather=sparse_kv, + pack_gqa=pack_gqa, + qhead_per_kvhead=qhead_per_kvhead, + nheads_kv=num_head_kv, + is_varlen_q=cu_seqlens_q is not None or seqused_q is not None, + disable_bitmask=disable_sparse_kv_bitmask, + ) + else: + if use_dedicated_hd256_kernel: + # hd=256 2CTA forward: check for currently unsupported features + assert softcap is None, "SM100 forward with head_dim=256 does not support softcap" + assert not use_block_sparsity, \ + "SM100 forward with head_dim=256 does not support block sparsity" + assert learnable_sink is None, \ + "SM100 forward with head_dim=256 does not support learnable_sink" + assert seqused_q is None and seqused_k is None, \ + "SM100 forward with head_dim=256 does not support seqused_q/seqused_k" + if page_table is not None: + assert max_seqlen_k % page_size == 0, ( + f"SM100 hd256 2CTA paged KV requires max_seqlen_k divisible by " + f"page_size ({page_size}), got max_seqlen_k={max_seqlen_k}" + ) + assert page_table.shape[1] == max_seqlen_k // page_size, ( + f"SM100 hd256 2CTA paged KV requires page_table.shape[1] == " + f"max_seqlen_k // page_size ({max_seqlen_k} // {page_size} = " + f"{max_seqlen_k // page_size}), got {page_table.shape[1]}; " + f"pass page_table[:, :{max_seqlen_k // page_size}] to slice to " + f"the actual sequence length" + ) + assert page_table.stride(0) == page_table.shape[1], ( + f"SM100 hd256 2CTA paged KV requires a fully contiguous page_table " + f"(stride(0)={page_table.stride(0)} must equal " + f"shape[1]={page_table.shape[1]})" + ) + # pack_gqa is an auto-selected optimization; disable it for hd256 kernel + pack_gqa = False + + flash_fwd_obj_cls = ( + BlackwellFusedMultiHeadAttentionForward + if use_dedicated_hd256_kernel + else FlashAttentionForwardSm100 + ) + + fa_fwd = flash_fwd_obj_cls( + head_dim, + head_dim_v, + qhead_per_kvhead=qhead_per_kvhead, + is_causal=causal, + is_local=local, + is_split_kv=is_split_kv, + pack_gqa=pack_gqa, + m_block_size=tile_m, + n_block_size=tile_n, + q_stage=q_stage, + is_persistent=not causal + and not local + and cu_seqlens_q is None + and seqused_q is None + and not is_split_kv, + score_mod=score_mod, + mask_mod=mask_mod, + has_aux_tensors=aux_tensors is not None, + paged_kv_non_tma=page_size not in [None, tile_n], + is_varlen_q=cu_seqlens_q is not None or seqused_q is not None, + q_subtile_factor=q_subtile_factor, + use_2cta_instrs=use_2cta_instrs, + use_clc_scheduler=use_clc_scheduler, + ) elif arch // 10 == 12: # SM120 (Blackwell GeForce / DGX Spark): uses SM80 MMA with SM120 SMEM capacity assert not use_block_sparsity, "Block sparsity not supported on SM 12.0" @@ -773,51 +937,118 @@ def _flash_attn_fwd( f"Unsupported compute capability: {arch}. Supported: 8.x, 9.x, 10.x, 11.x, 12.x" ) # TODO: check @can_implement - _flash_attn_fwd.compile_cache[compile_key] = cute.compile( - fa_fwd, - q_tensor, - k_tensor, - v_tensor, - o_tensor, - lse_tensor, - softmax_scale, - cu_seqlens_q_tensor, - cu_seqlens_k_tensor, - seqused_q_tensor, - seqused_k_tensor, - page_table_tensor, - window_size_left, - window_size_right, - learnable_sink_tensor, - sparse_tensors, - cute_aux_tensors, - current_stream, - options="--enable-tvm-ffi", - ) + if qv is not None: + _flash_attn_fwd.compile_cache[compile_key] = cute.compile( + fa_fwd, + q_tensor, + qv_tensor, + k_tensor, + v_tensor, + o_tensor, + lse_tensor, + softmax_scale, + cu_seqlens_q_tensor, + cu_seqlens_k_tensor, + seqused_q_tensor, + seqused_k_tensor, + gather_kv_indices_tensor, + page_table_tensor, + window_size_left, + window_size_right, + current_stream, + options="--enable-tvm-ffi", + ) + else: + compile_args = [ + fa_fwd, + q_tensor, + k_tensor, + v_tensor, + o_tensor, + lse_tensor, + softmax_scale, + cu_seqlens_q_tensor, + cu_seqlens_k_tensor, + seqused_q_tensor, + seqused_k_tensor, + page_table_tensor, + window_size_left, + window_size_right, + learnable_sink_tensor, + sparse_tensors, + cute_aux_tensors, + current_stream, + ] + if arch // 10 in [10, 11]: + compile_args.insert(-3, descale_tensors_tensor) + _flash_attn_fwd.compile_cache[compile_key] = cute.compile(*compile_args, options="--enable-tvm-ffi") - # In "fake mode", we will take torch fake tensors as input and the expected behaviors are: - # - Use those fake metadata to populate compilation cache - # - Return "fake" output tensors, which could be needed in follow-up fake operations - # Thus, we skip the actual kernel invocation here. if not is_fake_mode(): - _flash_attn_fwd.compile_cache[compile_key]( - q.detach(), - k.detach(), - v.detach(), - out.detach() if not is_split_kv else out_partial, - lse_partial if is_split_kv else lse, - softmax_scale, - cu_seqlens_q, - cu_seqlens_k, - seqused_q, - seqused_k, - page_table, - window_size_left, - window_size_right, - learnable_sink, - normalized_block_sparse_tensors[:4] if normalized_block_sparse_tensors is not None else None, - aux_tensors, + q_call, k_call, v_call = q.detach(), k.detach(), v.detach() + qv_call = qv.detach() if qv is not None else None + if is_fp8: + # need uint8 workaround until we pin torch >= 2.11.0 where fp8 export is supported + q_call = q_call.view(torch.uint8) + k_call = k_call.view(torch.uint8) + v_call = v_call.view(torch.uint8) + if qv_call is not None: + qv_call = qv_call.view(torch.uint8) + descale_tensors = ( + DescaleTensors(q_descale=q_descale, k_descale=k_descale, v_descale=v_descale) + if q_descale is not None or k_descale is not None or v_descale is not None + else None ) + if qv is not None: + _flash_attn_fwd.compile_cache[compile_key]( + q_call, + qv_call, + k_call, + v_call, + out.detach(), + lse, + softmax_scale, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + gather_kv_indices, + page_table, + window_size_left, + window_size_right, + ) + else: + call_args = [ + q_call, + k_call, + v_call, + out.detach() if not is_split_kv else out_partial, + lse_partial if is_split_kv else lse, + softmax_scale, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + page_table, + window_size_left, + window_size_right, + learnable_sink, + ] + if arch // 10 in [10, 11]: + call_args.append(descale_tensors) + call_args.extend([ + ( + normalized_block_sparse_tensors.mask_block_cnt, + normalized_block_sparse_tensors.mask_block_idx, + normalized_block_sparse_tensors.full_block_cnt, + normalized_block_sparse_tensors.full_block_idx, + normalized_block_sparse_tensors.dq_write_order, + normalized_block_sparse_tensors.dq_write_order_full, + ) + if normalized_block_sparse_tensors is not None + else None, + aux_tensors, + ]) + _flash_attn_fwd.compile_cache[compile_key](*call_args) if is_split_kv: _flash_attn_fwd_combine( out_partial, @@ -878,7 +1109,8 @@ def make_fake_bwd_tensors(dtype, has_gqa, varlen_q, varlen_k): def _compile_bwd_preprocess( - dtype, head_dim, head_dim_v, m_block_size, has_cuseqlens_q, has_seqused_q, has_dlse, + dtype, head_dim, head_dim_v, m_block_size, has_cuseqlens_q, has_seqused_q, has_dlse, has_dq_accum, + use_padded_offsets, ): """Compile bwd preprocess kernel using cute fake tensors (no real GPU tensors needed).""" mQ, mK, mV, mO, mdO, mdQ, mdK, mdV, mLSE, mLSElog2, mPdPsum, mdQaccum, mdKaccum, mdVaccum = make_fake_bwd_tensors( @@ -889,7 +1121,10 @@ def _compile_bwd_preprocess( mCuSeqlensQ = fake_tensor(Int32, (batchp1,), divisibility=1) if has_cuseqlens_q else None mSequsedQ = fake_tensor(Int32, (batch,), divisibility=1) if has_seqused_q else None mdLSE = fake_tensor(Float32, mLSE.shape, divisibility=1) if has_dlse else None - fa_bwd_pre = FlashAttentionBackwardPreprocess(dtype, head_dim, head_dim_v, m_block_size) + mdQaccum = mdQaccum if has_dq_accum else None + fa_bwd_pre = FlashAttentionBackwardPreprocess( + dtype, head_dim, head_dim_v, m_block_size, use_padded_offsets=use_padded_offsets + ) return cute.compile( fa_bwd_pre, mO, mdO, mPdPsum, mLSE, mLSElog2, mdQaccum, mCuSeqlensQ, mSequsedQ, mdLSE, cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), @@ -901,11 +1136,13 @@ def _bwd_preprocess( out, dout, dpsum, lse, lse_log2, dq_accum, cu_seqlens_q, seqused_q, dlse, dtype, head_dim, head_dim_v, m_block_size, + use_padded_offsets=True, ): """Backward preprocess: compute (o * dout).sum(dim=-1) - dLSE, lse * log2_e, and zero out dq_accum.""" is_varlen = cu_seqlens_q is not None compile_key = ( - dtype, head_dim, head_dim_v, m_block_size, is_varlen, seqused_q is not None, dlse is not None, + dtype, head_dim, head_dim_v, m_block_size, is_varlen, seqused_q is not None, dlse is not None, dq_accum is not None, + use_padded_offsets, ) if compile_key not in _bwd_preprocess.compile_cache: _bwd_preprocess.compile_cache[compile_key] = _compile_bwd_preprocess(*compile_key) @@ -1018,6 +1255,7 @@ def _flash_attn_bwd( num_head, head_dim = q.shape[-2:] head_dim_v = v.shape[-1] + window_size = [window_size_left, window_size_right] causal, local, window_size_left, window_size_right = _resolve_causal_local_window( causal, window_size_left, window_size_right ) @@ -1088,10 +1326,14 @@ def _flash_attn_bwd( or score_mod is not None or score_mod_bwd is not None or mask_mod is not None + or block_sparse_tensors is not None ) cluster_size = 2 if head_dim >= 128 and not disable_2cta else 1 use_2cta_instrs = cluster_size==2 + use_dedicated_hd256_kernel = arch // 10 == 10 and head_dim == 256 and head_dim_v == 256 + use_2cta_instrs = use_2cta_instrs or use_dedicated_hd256_kernel + q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = [ maybe_contiguous(t) for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) @@ -1172,12 +1414,20 @@ def _flash_attn_bwd( pack_gqa = qhead_per_kvhead > 1 # pack_gqa backward not yet supported in bwd pack_gqa = False - if score_mod is not None: + + if softcap != 0.0: + assert score_mod is None and score_mod_bwd is None, ( + "softcap and score_mod/score_mod_bwd cannot be used together" + ) + score_mod = utils.create_softcap_scoremod(softcap) + score_mod_bwd = utils.create_softcap_scoremod_bwd(softcap) + elif score_mod is not None: assert score_mod_bwd is not None, "score_mod_bwd is required when score_mod is provided" - assert softcap == 0.0, "softcap and score_mod are mutually exclusive (different log2 scaling)" assert cu_seqlens_q is None and cu_seqlens_k is None, ( "varlen + score_mod not supported in bwd yet" ) + if arch // 10 == 8: + raise NotImplementedError("Custom user-provided score_mod is not supported on SM8x architectures.") device = q.device out_torch_dtype = q.dtype @@ -1200,12 +1450,16 @@ def _flash_attn_bwd( head_dim_rounded = (head_dim + 32 - 1) // 32 * 32 if cu_seqlens_q is None: - dq_accum = torch.empty( - batch_size, - num_head, - seqlen_q_rounded * head_dim_rounded, - dtype=torch.float32, - device=device, + dq_accum = ( + None + if use_dedicated_hd256_kernel + else torch.empty( + batch_size, + num_head, + seqlen_q_rounded * head_dim_rounded, + dtype=torch.float32, + device=device, + ) ) dpsum = torch.empty( batch_size, num_head, seqlen_q_rounded, dtype=torch.float32, device=device @@ -1217,8 +1471,12 @@ def _flash_attn_bwd( total_q_rounded_padded = ( (total_q + cu_seqlens_q.shape[0] * m_block_size - 1) // m_block_size * m_block_size ) - dq_accum = torch.empty( - num_head, total_q_rounded_padded * head_dim_rounded, dtype=torch.float32, device=device + dq_accum = ( + None + if use_dedicated_hd256_kernel + else torch.empty( + num_head, total_q_rounded_padded * head_dim_rounded, dtype=torch.float32, device=device + ) ) dpsum = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device) lse_log2 = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device) @@ -1226,7 +1484,8 @@ def _flash_attn_bwd( # GQA (qhead_per_kvhead > 1) needs dK/dV accum+postprocess since multiple Q heads # accumulate into the same dK/dV. SM90 varlen_k with qhead_per_kvhead==1 now uses # ragged TMA tensors for direct store, so no longer needs accum+postprocess. - dKV_postprocess = qhead_per_kvhead > 1 + # hd=256 2CTA backward has its own internal postprocess for dK/dV. + dKV_postprocess = qhead_per_kvhead > 1 and not use_dedicated_hd256_kernel if dKV_postprocess: head_dim_v_rounded = (head_dim_v + 32 - 1) // 32 * 32 if cu_seqlens_k is None: @@ -1278,10 +1537,12 @@ def _flash_attn_bwd( dV_semaphore = None # Preprocess kernel: compute (o * dout).sum(dim=-1) - dLSE, lse * log2_e, and zero out dq_accum. + # For hd=256 dedicated path, dq_accum is None so preprocess only fills dpsum/lse_log2. _bwd_preprocess( out, dout, dpsum, lse, lse_log2, dq_accum, cu_seqlens_q, seqused_q, dlse, dtype, head_dim, head_dim_v, m_block_size, + use_padded_offsets=use_dedicated_hd256_kernel, ) # num_threads: SM90 derives from BwdConfig.num_wg, SM120 is set to 128 above, # SM100/SM110 uses default from function signature (384). @@ -1312,6 +1573,30 @@ def _flash_attn_bwd( block_size=(m_block_size, n_block_size), subtile_factor=subtile_factor, ) + if deterministic: + if normalized_block_sparse_tensors.dq_write_order is None: + raise ValueError( + "deterministic block-sparse backward requires dq_write_order in block_sparse_tensors" + ) + if ( + normalized_block_sparse_tensors.full_block_cnt is not None + and normalized_block_sparse_tensors.dq_write_order_full is None + ): + raise ValueError( + "deterministic block-sparse backward requires dq_write_order_full when full blocks are present" + ) + if normalized_block_sparse_tensors.spt is None: + raise ValueError( + "deterministic block-sparse backward requires block_sparse_tensors.spt " + "to match dq_write_order direction" + ) + if ( + normalized_block_sparse_tensors is not None + and normalized_block_sparse_tensors.spt is not None + ): + spt = normalized_block_sparse_tensors.spt and deterministic + else: + spt = (causal or local) and deterministic if arch // 10 in [8, 9, 12]: compile_key = ( @@ -1323,7 +1608,6 @@ def _flash_attn_bwd( causal, window_size_left is not None, window_size_right is not None, - softcap != 0.0, m_block_size, n_block_size, num_threads, @@ -1353,6 +1637,9 @@ def _flash_attn_bwd( get_broadcast_dims(k), get_broadcast_dims(v), get_broadcast_dims(dout), + # Prevent TVM stride poisoning when only one block is present. + (seqlen_q_rounded // m_block_size == 1), + (seqlen_k_rounded // n_block_size == 1), ) else: compile_key = ( @@ -1364,7 +1651,6 @@ def _flash_attn_bwd( causal, window_size_left is not None, window_size_right is not None, - softcap != 0.0, m_block_size, n_block_size, num_threads, @@ -1372,6 +1658,7 @@ def _flash_attn_bwd( cluster_size, use_2cta_instrs, deterministic, + spt, score_mod_hash, score_mod_bwd_hash, mask_mod_hash, @@ -1386,14 +1673,17 @@ def _flash_attn_bwd( get_broadcast_dims(k), get_broadcast_dims(v), get_broadcast_dims(dout), + # Prevent TVM stride poisoning when only one block is present. + (seqlen_q_rounded // m_block_size == 1), + (seqlen_k_rounded // n_block_size == 1), ) + if compile_key not in _flash_attn_bwd.compile_cache: q_tensor, k_tensor, v_tensor, do_tensor, dq_tensor, dk_tensor, dv_tensor = [ to_cute_tensor(t) for t in (q, k, v, dout, dq, dk, dv) ] - dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [ - to_cute_tensor(t) for t in (dq_accum, dpsum, lse_log2) - ] + lse_log2_tensor, dpsum_tensor = [to_cute_tensor(t) for t in (lse_log2, dpsum)] + dq_accum_tensor = to_cute_tensor(dq_accum) if dq_accum is not None else None if dKV_postprocess: dk_accum_tensor, dv_accum_tensor = [ to_cute_tensor(t) for t in (dk_accum, dv_accum) @@ -1428,6 +1718,8 @@ def _flash_attn_bwd( AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs=V_in_regs, + score_mod=score_mod, + score_mod_bwd=score_mod_bwd, ) elif arch // 10 == 9: fa_bwd_obj = FlashAttentionBackwardSm90( @@ -1459,28 +1751,62 @@ def _flash_attn_bwd( dQ_single_wg=dQ_single_wg, ) else: - fa_bwd_obj = FlashAttentionBackwardSm100( - head_dim, - head_dim_v, - is_causal=causal, - is_local=local, - qhead_per_kvhead=qhead_per_kvhead, - tile_m=m_block_size, - tile_n=n_block_size, - cluster_size=cluster_size, - use_2cta_instrs=use_2cta_instrs, - deterministic=deterministic, - score_mod=score_mod, - score_mod_bwd=score_mod_bwd, - mask_mod=mask_mod, - has_aux_tensors=aux_tensors is not None, - subtile_factor=subtile_factor, - ) + if use_dedicated_hd256_kernel: + assert softcap == 0.0, "SM100 backward with head_dim=256 does not support softcap" + assert block_sparse_tensors is None, \ + "SM100 backward with head_dim=256 does not support block sparsity" + assert dlse is None, \ + "SM100 backward with head_dim=256 does not support dlse" + assert seqused_q is None and seqused_k is None, \ + "SM100 backward with head_dim=256 does not support seqused_q/seqused_k" + + dq_tile_mn = (128, 128) + dkdv_tile_mn = (128, 64) + fa_bwd_obj = BlackwellFusedMultiHeadAttentionBackward( + head_dim, + head_dim_v, + is_causal=causal, + is_local=local, + qhead_per_kvhead=qhead_per_kvhead, + is_persistent=False, + deterministic=deterministic, + cluster_size=cluster_size, + use_2cta_instrs=use_2cta_instrs, + score_mod=score_mod, + score_mod_bwd=score_mod_bwd, + mask_mod=mask_mod, + has_aux_tensors=aux_tensors is not None, + subtile_factor=subtile_factor, + tile_m_dq=dq_tile_mn[0], + tile_n_dq=dq_tile_mn[1], + tile_m_dkdv=dkdv_tile_mn[0], + tile_n_dkdv=dkdv_tile_mn[1], + ) + else: + fa_bwd_obj = FlashAttentionBackwardSm100( + head_dim, + head_dim_v, + is_causal=causal, + is_local=local, + qhead_per_kvhead=qhead_per_kvhead, + tile_m=m_block_size, + tile_n=n_block_size, + cluster_size=cluster_size, + use_2cta_instrs=use_2cta_instrs, + deterministic=deterministic, + spt=spt, + score_mod=score_mod, + score_mod_bwd=score_mod_bwd, + mask_mod=mask_mod, + has_aux_tensors=aux_tensors is not None, + subtile_factor=subtile_factor, + ) # Block sparse tensors for backward use Q-direction indexing (transposed from forward). sparse_tensors_compile = None if normalized_block_sparse_tensors is not None: sparse_tensors_compile = to_cute_block_sparse_tensors(normalized_block_sparse_tensors) + dq_accum_tensor = dq_tensor if use_dedicated_hd256_kernel else dq_accum_tensor # TODO: check @can_implement _flash_attn_bwd.compile_cache[compile_key] = cute.compile( @@ -1499,7 +1825,6 @@ def _flash_attn_bwd( cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, - None, # softcap - not yet supported in backward window_size_left, window_size_right, dQ_semaphore_tensor, @@ -1511,6 +1836,7 @@ def _flash_attn_bwd( options="--enable-tvm-ffi", ) if not is_fake_mode(): + dq_accum = dq if use_dedicated_hd256_kernel else dq_accum _flash_attn_bwd.compile_cache[compile_key]( q.detach(), k.detach(), @@ -1526,51 +1852,60 @@ def _flash_attn_bwd( cu_seqlens_k, seqused_q, seqused_k, - None, # softcap - not yet supported in backward window_size_left, window_size_right, dQ_semaphore, dK_semaphore, dV_semaphore, aux_tensors, - normalized_block_sparse_tensors[:4] if normalized_block_sparse_tensors is not None else None, + ( + normalized_block_sparse_tensors.mask_block_cnt, + normalized_block_sparse_tensors.mask_block_idx, + normalized_block_sparse_tensors.full_block_cnt, + normalized_block_sparse_tensors.full_block_idx, + normalized_block_sparse_tensors.dq_write_order, + normalized_block_sparse_tensors.dq_write_order_full, + ) + if normalized_block_sparse_tensors is not None + else None, ) - - if arch // 10 == 9: - # dQ postprocess: match main kernel's MMA WG count, unless dQ_single_wg - num_threads_post_dQ = 128 if dQ_single_wg else cfg.num_wg * 128 - num_threads_post_dKV = cfg.num_wg * 128 - else: - num_threads_post_dQ = 128 - num_threads_post_dKV = 128 - # Postprocess: convert dq_accum from float32 to dq in bf16/fp16 - _bwd_postprocess_convert( - dq_accum, dq, softmax_scale, - cu_seqlens_q, seqused_q, - arch, dtype, head_dim, m_block_size, num_threads_post_dQ, - AtomLayoutMdQ, dQ_swapAB, - use_2cta_instrs=use_2cta_instrs, cluster_size=1, - ) + # hd=256 2CTA backward has its own internal postprocess, skip here. + if not use_dedicated_hd256_kernel: + if arch // 10 == 9: + # dQ postprocess: match main kernel's MMA WG count, unless dQ_single_wg + num_threads_post_dQ = 128 if dQ_single_wg else cfg.num_wg * 128 + num_threads_post_dKV = cfg.num_wg * 128 + else: + num_threads_post_dQ = 128 + num_threads_post_dKV = 128 - if dKV_postprocess: - # Postprocess: convert dk_accum from float32 to dk in bf16/fp16 _bwd_postprocess_convert( - dk_accum, dk, softmax_scale, - cu_seqlens_k, seqused_k, - arch, dtype, head_dim, n_block_size, num_threads_post_dKV, - AtomLayoutNdKV, dKV_swapAB, - cluster_size=cluster_size, - ) - # Postprocess: convert dv_accum from float32 to dv in bf16/fp16 - _bwd_postprocess_convert( - dv_accum, dv, 1.0, - cu_seqlens_k, seqused_k, - arch, dtype, head_dim_v, n_block_size, num_threads_post_dKV, - AtomLayoutNdKV, dKV_swapAB, - cluster_size=cluster_size, + dq_accum, dq, softmax_scale, + cu_seqlens_q, seqused_q, + arch, dtype, head_dim, m_block_size, num_threads_post_dQ, + AtomLayoutMdQ, dQ_swapAB, + use_2cta_instrs=use_2cta_instrs, cluster_size=1, ) + if dKV_postprocess: + # Postprocess: convert dk_accum from float32 to dk in bf16/fp16 + _bwd_postprocess_convert( + dk_accum, dk, softmax_scale, + cu_seqlens_k, seqused_k, + arch, dtype, head_dim, n_block_size, num_threads_post_dKV, + AtomLayoutNdKV, dKV_swapAB, + cluster_size=cluster_size, + ) + # Postprocess: convert dv_accum from float32 to dv in bf16/fp16 + _bwd_postprocess_convert( + dv_accum, dv, 1.0, + cu_seqlens_k, seqused_k, + arch, dtype, head_dim_v, n_block_size, num_threads_post_dKV, + AtomLayoutNdKV, dKV_swapAB, + cluster_size=cluster_size, + ) + return dq, dk, dv @@ -1584,6 +1919,8 @@ def forward( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + qv: Optional[torch.Tensor] = None, + gather_kv_indices: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, window_size: Tuple[Optional[int], Optional[int]] = (None, None), @@ -1592,28 +1929,19 @@ def forward( num_splits: int = 1, pack_gqa: Optional[bool] = None, deterministic: bool = False, + score_mod: Optional[Callable] = None, + score_mod_bwd: Optional[Callable] = None, mask_mod: Optional[Callable] = None, - full_block_cnt: Optional[torch.Tensor] = None, - full_block_idx: Optional[torch.Tensor] = None, - mask_block_cnt: Optional[torch.Tensor] = None, - mask_block_idx: Optional[torch.Tensor] = None, - block_size: Optional[Tuple[int, int]] = None, + aux_tensors: Optional[list] = None, + block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None, + block_sparse_tensors_bwd: Optional[BlockSparseTensorsTorch] = None, return_lse: bool = False, ): - # Only create block sparse tensors if at least one block sparse parameter is provided - block_sparse_tensors = None - if any(t is not None for t in [full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx]): - block_sparse_tensors = BlockSparseTensorsTorch( - full_block_cnt=full_block_cnt, - full_block_idx=full_block_idx, - mask_block_cnt=mask_block_cnt, - mask_block_idx=mask_block_idx, - block_size=block_size, - ) out, lse = _flash_attn_fwd( q, k, v, + qv=qv, softmax_scale=softmax_scale, causal=causal, window_size_left=window_size[0], @@ -1622,23 +1950,31 @@ def forward( softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, + score_mod=score_mod, mask_mod=mask_mod, + aux_tensors=aux_tensors, block_sparse_tensors=block_sparse_tensors, return_lse=return_lse, + gather_kv_indices=gather_kv_indices, ) - ctx.save_for_backward(q, k, v, out, lse) + ctx.save_for_backward(q, k, v, out, lse, *(aux_tensors or ())) ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size ctx.softcap = softcap ctx.deterministic = deterministic ctx.return_lse = return_lse + ctx.score_mod = score_mod + ctx.score_mod_bwd = score_mod_bwd + ctx.mask_mod = mask_mod + ctx.block_sparse_tensors_bwd = block_sparse_tensors_bwd ctx.set_materialize_grads(False) return out, lse @staticmethod def backward(ctx, dout, dlse): - q, k, v, out, lse = ctx.saved_tensors + q, k, v, out, lse, *aux = ctx.saved_tensors + aux_tensors = aux if aux else None if not ctx.return_lse: dlse = None if dout is None: @@ -1656,9 +1992,14 @@ def backward(ctx, dout, dlse): window_size_left=ctx.window_size[0], window_size_right=ctx.window_size[1], deterministic=ctx.deterministic, + score_mod=ctx.score_mod, + score_mod_bwd=ctx.score_mod_bwd, + mask_mod=ctx.mask_mod, + aux_tensors=aux_tensors, + block_sparse_tensors=ctx.block_sparse_tensors_bwd, dlse=dlse, ) - return dq, dk, dv, *((None,) * 20) # Extra Nones is fine + return dq, dk, dv, *((None,) * 30) # Extra Nones is fine class FlashAttnVarlenFunc(torch.autograd.Function): @@ -1668,12 +2009,15 @@ def forward( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - cu_seqlens_q: Optional[torch.Tensor], - cu_seqlens_k: Optional[torch.Tensor], + qv: Optional[torch.Tensor] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, max_seqlen_q: Optional[int] = None, max_seqlen_k: Optional[int] = None, + min_seqlen_k: Optional[int] = None, + gather_kv_indices: Optional[torch.Tensor] = None, page_table: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, @@ -1684,6 +2028,7 @@ def forward( pack_gqa: Optional[bool] = None, deterministic: bool = False, score_mod: Optional[Callable] = None, + score_mod_bwd: Optional[Callable] = None, aux_tensors: Optional[list] = None, return_lse: bool = False, ): @@ -1691,12 +2036,14 @@ def forward( q, k, v, - cu_seqlens_q, - cu_seqlens_k, - seqused_q, - seqused_k, + qv=qv, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + seqused_q=seqused_q, + seqused_k=seqused_k, max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, + min_seqlen_k=min_seqlen_k, page_table=page_table, softmax_scale=softmax_scale, causal=causal, @@ -1709,8 +2056,20 @@ def forward( score_mod=score_mod, aux_tensors=aux_tensors, return_lse=return_lse, + gather_kv_indices=gather_kv_indices, + ) + ctx.save_for_backward( + q, + k, + v, + out, + lse, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + *(aux_tensors or ()), ) - ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size @@ -1719,13 +2078,15 @@ def forward( ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_k = max_seqlen_k ctx.return_lse = return_lse + ctx.score_mod = score_mod + ctx.score_mod_bwd = score_mod_bwd ctx.set_materialize_grads(False) return out, lse @staticmethod def backward(ctx, dout, dlse): - q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors - assert ctx.softcap == 0.0 + q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, *aux = ctx.saved_tensors + aux_tensors = aux if aux else None if not ctx.return_lse: dlse = None if dout is None: @@ -1749,16 +2110,21 @@ def backward(ctx, dout, dlse): max_seqlen_q=ctx.max_seqlen_q, max_seqlen_k=ctx.max_seqlen_k, deterministic=ctx.deterministic, + score_mod=ctx.score_mod, + score_mod_bwd=ctx.score_mod_bwd, + aux_tensors=aux_tensors, dlse=dlse, ) - return dq, dk, dv, *((None,) * 20) + return dq, dk, dv, *((None,) * 30) def flash_attn_func( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + qv: Optional[torch.Tensor] = None, + gather_kv_indices: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, window_size: Tuple[Optional[int], Optional[int]] = (None, None), @@ -1767,18 +2133,20 @@ def flash_attn_func( num_splits: int = 1, pack_gqa: Optional[bool] = None, deterministic: bool = False, + score_mod: Optional[Callable] = None, + score_mod_bwd: Optional[Callable] = None, mask_mod: Optional[Callable] = None, - full_block_cnt: Optional[torch.Tensor] = None, - full_block_idx: Optional[torch.Tensor] = None, - mask_block_cnt: Optional[torch.Tensor] = None, - mask_block_idx: Optional[torch.Tensor] = None, - block_size: Optional[Tuple[int, int]] = None, + aux_tensors: Optional[list] = None, + block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None, + block_sparse_tensors_bwd: Optional[BlockSparseTensorsTorch] = None, return_lse: bool = False, ): return FlashAttnFunc.apply( q, k, v, + qv, + gather_kv_indices, softmax_scale, causal, window_size, @@ -1787,12 +2155,12 @@ def flash_attn_func( num_splits, pack_gqa, deterministic, + score_mod, + score_mod_bwd, mask_mod, - full_block_cnt, - full_block_idx, - mask_block_cnt, - mask_block_idx, - block_size, + aux_tensors, + block_sparse_tensors, + block_sparse_tensors_bwd, return_lse, ) @@ -1801,12 +2169,15 @@ def flash_attn_varlen_func( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + qv: Optional[torch.Tensor] = None, cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_k: Optional[torch.Tensor] = None, max_seqlen_q: Optional[int] = None, max_seqlen_k: Optional[int] = None, + min_seqlen_k: Optional[int] = None, seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, + gather_kv_indices: Optional[torch.Tensor] = None, page_table: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, @@ -1817,19 +2188,37 @@ def flash_attn_varlen_func( pack_gqa: Optional[bool] = None, deterministic: bool = False, score_mod: Optional[Callable] = None, + score_mod_bwd: Optional[Callable] = None, aux_tensors: Optional[list] = None, return_lse: bool = False, ): + """ + Explanation of some optional arguments: + + qv: we write the MLA weight absorbed formula as + O = softmax(scale * (Q @ K.T + Qv @ V.T)) @ V + where Q = q_pe, Qv = q_nope, K = pe_cache, V = kv_cache. + + gather_kv_indices: a tensor of shape (batch, seqlen_q, gather_kv_length) or + (total_q, gather_kv_length) if there is cu_seqlens_q. + Currently, only used for topk sparsity with MLA absorption kernel. + + min_seqlen_k: for varlen, specifies the minimum kv sequence length for any batch. + Used with gather_kv_indices to determine if we need oob masking. + """ return FlashAttnVarlenFunc.apply( q, k, v, + qv, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, max_seqlen_q, max_seqlen_k, + min_seqlen_k, + gather_kv_indices, page_table, softmax_scale, causal, @@ -1840,6 +2229,7 @@ def flash_attn_varlen_func( pack_gqa, deterministic, score_mod, + score_mod_bwd, aux_tensors, return_lse, ) diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index 6b5ca16c6f5..9c171ba9865 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -1,14 +1,17 @@ # Copyright (c) 2025, Tri Dao. -from typing import Optional, Callable, TypeAlias +from typing import Optional, Callable, TypeAlias, Tuple from dataclasses import dataclass +import enum import cutlass import cutlass.cute as cute from cutlass import Float32, Int32, Uint32, const_expr +from cutlass.cutlass_dsl import min as dsl_min from quack import layout_utils import flash_attn.cute.utils as utils +from flash_attn.cute.block_info import BlockInfo from flash_attn.cute.seqlen_info import SeqlenInfoQK MaskGenFn: TypeAlias = Callable[[int], Uint32] @@ -384,6 +387,8 @@ def apply_mask_sm100( fastdiv_mods=(None, None), head_divmod=None, check_q_boundary: bool = False, + r2p: bool = True, + rBitmask: Optional[cute.Tensor] = None, ) -> None: assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" acc_shape = (self.tile_m, self.tile_n) @@ -397,8 +402,18 @@ def apply_mask_sm100( if n_block < 0: n_block = 0 seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - r2p = True - if const_expr(not mask_causal and not mask_local and mask_mod is None): + + if const_expr(rBitmask is not None): + ncol_packed = const_expr(cute.size(rBitmask.shape[0])) + for i in cutlass.range_constexpr(ncol_packed): + col_start = 32 * i # mask is bit-packed into uint32 + curr_mask_val = rBitmask[i] + for j in cutlass.range_constexpr(32): + curr_col = col_start + j + mask = (curr_mask_val >> j) & 1 + acc_S[curr_col] = acc_S[curr_col] if cutlass.Boolean(mask) else -Float32.inf + + elif const_expr(not mask_causal and not mask_local and mask_mod is None): if const_expr(mask_seqlen): if const_expr(not r2p): for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True): @@ -548,7 +563,7 @@ def apply_mask_sm100_transposed( """ Backward pass: mask S = K @ Q.T where n_block tiles seqlen_k and m_block tiles seqlen_q. - Coordinate conventio: + Coordinate convention: - ROW corresponds to Q (m_block) - COL corresponds to KV (n_block) @@ -709,3 +724,747 @@ def mask_gen_fn(s: int) -> Uint32: mask_gen_fn, rank1=True, ) + + +# ----------------------------------------------------------------------------- +# SM100 FMHA fused-mask policy layer (separate from generic mask primitives). +# ----------------------------------------------------------------------------- + + +class Sm100MaskEnum(enum.Enum): + """Enumeration of mask types for FMHA operations. + + - RESIDUAL_MASK: Residual mask for handling variable sequence lengths + - WINDOW_MASK: Window mask for attention which also includes causal and no mask + - WINDOW_MASK_INFERENCE: Same as the window mask, but has the limitation that the end of q is aligned with the end of k + - WINDOW_MASK_BWD: Window mask for backward pass + - WINDOW_MASK_BWD_INFERENCE: Same as the window mask for backward pass, but has the limitation that the end of q is aligned with the end of k + """ + + NO_MASK = enum.auto() + RESIDUAL_MASK = enum.auto() + CAUSAL_MASK = enum.auto() + WINDOW_MASK = enum.auto() + WINDOW_MASK_INFERENCE = enum.auto() + # Deprecated the following types + WINDOW_MASK_BWD = enum.auto() + WINDOW_MASK_BWD_INFERENCE = enum.auto() + RESIDUAL_MASK_BWD = enum.auto() + + +class Sm100FusedMask: + """A fused mask implementation for FMHA operations. + + This class handles different types of attention masks including no mask, + residual mask for variable sequence lengths, and causal mask for + autoregressive attention patterns. + + The class provides methods to: + - Calculate trip counts for different mask types + - Apply masks to attention scores + - Handle masked and unmasked trip calculations + """ + + def get_trip_count( + mask_type: Sm100MaskEnum, + blk_coord: cute.Coord, + tile_shape: cute.Shape, + seqlen_q: Int32, + seqlen_k: Int32, + window_size_left: Optional[Int32] = None, + window_size_right: Optional[Int32] = None, + ) -> Int32: + """ + Calculate the number of trips needed for the current block. + + The trip count depends on the mask type and the block coordinates. + For causal masks, it considers the autoregressive constraint. + + :param mask_type: Type of mask to use + :type mask_type: utils.Sm100MaskEnum + :param blk_coord: Block coordinates. + :type blk_coord: cute.Coord + :param tile_shape: Shape of the tile. + :type tile_shape: cute.Shape + :param seqlen_q: Query sequence length for attention computation. + :type seqlen_q: Int32 + :param seqlen_k: Key sequence length for attention computation. + :type seqlen_k: Int32 + :param window_size_left: Left-side sliding window size for attention masking. + :type window_size_left: Optional[Int32] + :param window_size_right: Right-side sliding window size for attention masking. + :type window_size_right: Optional[Int32] + + :return: Number of trips needed. + :rtype: Int32 + """ + result = 0 + offset = 0 + if cutlass.const_expr(mask_type is Sm100MaskEnum.WINDOW_MASK_INFERENCE): + offset = seqlen_k - seqlen_q + if cutlass.const_expr(mask_type is Sm100MaskEnum.WINDOW_MASK_BWD_INFERENCE): + offset = seqlen_q - seqlen_k + if cutlass.const_expr(mask_type == Sm100MaskEnum.RESIDUAL_MASK): + result = cute.ceil_div(seqlen_k, tile_shape[1]) + if cutlass.const_expr(mask_type is Sm100MaskEnum.RESIDUAL_MASK_BWD): + result = cute.ceil_div(seqlen_q, tile_shape[0]) + if cutlass.const_expr( + mask_type == Sm100MaskEnum.WINDOW_MASK + or mask_type == Sm100MaskEnum.WINDOW_MASK_INFERENCE + ): + if cutlass.const_expr(window_size_right is None): + result = cute.ceil_div(seqlen_k, tile_shape[1]) + else: + max_idx_q = (blk_coord[0] + 1) * tile_shape[0] + idx_k = max_idx_q + offset + window_size_right + tmp_blocks_k = cute.ceil_div(idx_k, tile_shape[1]) + max_blocks_k = cute.ceil_div(seqlen_k, tile_shape[1]) + result = dsl_min(max_blocks_k, tmp_blocks_k) + if cutlass.const_expr( + mask_type == Sm100MaskEnum.WINDOW_MASK_BWD + or mask_type == Sm100MaskEnum.WINDOW_MASK_BWD_INFERENCE + ): + if cutlass.const_expr(window_size_left is None): + result = cute.ceil_div(seqlen_q, tile_shape[0]) + else: + max_idx_k = (blk_coord[1] + 1) * tile_shape[1] + idx_k = max_idx_k + offset + window_size_left + tmp_blocks_q = cute.ceil_div(idx_k, tile_shape[0]) + max_blocks_q = cute.ceil_div(seqlen_q, tile_shape[0]) + result = dsl_min(max_blocks_q, tmp_blocks_q) + start_block = Sm100FusedMask.get_trip_start( + mask_type, + blk_coord, + tile_shape, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + ) + result = result - start_block + return result + + @cute.jit + def get_trip_start_count_via_block_info( + blk_coord: cute.Coord, + tile_shape: cute.Shape, + seqlen_q: Int32, + seqlen_k: Int32, + is_causal: cutlass.Constexpr[bool] = False, + is_local: cutlass.Constexpr[bool] = False, + window_size_left: Optional[Int32] = None, + window_size_right: Optional[Int32] = None, + ) -> Tuple[Int32, Int32]: + block_info = BlockInfo( + tile_m=tile_shape[0], + tile_n=tile_shape[1], + is_causal=is_causal, + is_local=is_local and not is_causal, + window_size_left=window_size_left, + window_size_right=window_size_right, + ) + + seqlen_info = SeqlenInfoQK( + offset_q=Int32(0), + offset_k=Int32(0), + padded_offset_q=Int32(0), + padded_offset_k=Int32(0), + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + has_cu_seqlens_q=False, + has_cu_seqlens_k=False, + has_seqused_q=False, + has_seqused_k=False, + ) + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen_info, blk_coord[0]) + return n_block_min, n_block_max - n_block_min + + @cute.jit + def get_trip_mask_bounds_via_block_info( + blk_coord: cute.Coord, + tile_shape: cute.Shape, + seqlen_q: Int32, + seqlen_k: Int32, + is_causal: cutlass.Constexpr[bool] = False, + is_local: cutlass.Constexpr[bool] = False, + window_size_left: Optional[Int32] = None, + window_size_right: Optional[Int32] = None, + ) -> Tuple[Int32, Int32]: + """Return SM100-style mask boundaries for dense iteration. + + Returns: + - n_block_min_causal_local_mask: right-side masked region start + - n_block_min_before_local_mask: start of fully unmasked middle region + """ + block_info = BlockInfo( + tile_m=tile_shape[0], + tile_n=tile_shape[1], + is_causal=is_causal, + is_local=is_local and not is_causal, + window_size_left=window_size_left, + window_size_right=window_size_right, + ) + seqlen_info = SeqlenInfoQK( + offset_q=Int32(0), + offset_k=Int32(0), + padded_offset_q=Int32(0), + padded_offset_k=Int32(0), + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + has_cu_seqlens_q=False, + has_cu_seqlens_k=False, + has_seqused_q=False, + has_seqused_k=False, + ) + n_block_min, _ = block_info.get_n_block_min_max(seqlen_info, blk_coord[0]) + n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( + seqlen_info, blk_coord[0], n_block_min + ) + n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask( + seqlen_info, blk_coord[0], n_block_min + ) + return n_block_min_causal_local_mask, n_block_min_before_local_mask + + @cute.jit + def get_trip_start( + mask_type: Sm100MaskEnum, + blk_coord: cute.Coord, + tile_shape: cute.Shape, + seqlen_q: Int32, + seqlen_k: Int32, + window_size_left: Optional[Int32] = None, + window_size_right: Optional[Int32] = None, + ) -> Int32: + """ + Get the start of the trip for the current block. + + :param mask_type: Type of mask to use + :type mask_type: utils.Sm100MaskEnum + :param blk_coord: Block coordinates. + :type blk_coord: cute.Coord + :param tile_shape: Shape of the tile. + :type tile_shape: cute.Shape + :param seqlen_q: Query sequence length for attention computation. + :type seqlen_q: Int32 + :param seqlen_k: Key sequence length for attention computation. + :type seqlen_k: Int32 + :param window_size_left: Left-side sliding window size for attention masking. + :type window_size_left: Optional[Int32] + :param window_size_right: Right-side sliding window size for attention masking. + :type window_size_right: Optional[Int32] + """ + result = 0 + offset = 0 + if cutlass.const_expr(mask_type is Sm100MaskEnum.WINDOW_MASK_INFERENCE): + offset = seqlen_k - seqlen_q + if cutlass.const_expr(mask_type is Sm100MaskEnum.WINDOW_MASK_BWD_INFERENCE): + offset = seqlen_q - seqlen_k + if cutlass.const_expr( + mask_type is Sm100MaskEnum.WINDOW_MASK + or mask_type is Sm100MaskEnum.WINDOW_MASK_INFERENCE + ): + if cutlass.const_expr(window_size_left is not None): + min_idx_q = blk_coord[0] * tile_shape[0] + idx_k = min_idx_q + offset - window_size_left + tmp_blocks_k = idx_k // tile_shape[1] + result = max(tmp_blocks_k, result) + if cutlass.const_expr( + mask_type is Sm100MaskEnum.WINDOW_MASK_BWD + or mask_type is Sm100MaskEnum.WINDOW_MASK_BWD_INFERENCE + ): + if cutlass.const_expr(window_size_right is not None): + min_idx_k = blk_coord[1] * tile_shape[1] + idx_q = min_idx_k + offset - window_size_right + tmp_blocks_q = idx_q // tile_shape[0] + result = max(tmp_blocks_q, result) + return result + + @cute.jit + def get_leading_mask_id( + mask_type: Sm100MaskEnum, + blk_coord: cute.Coord, + tile_shape: cute.Shape, + seqlen_q: Int32, + seqlen_k: Int32, + window_size_left: Optional[Int32] = None, + window_size_right: Optional[Int32] = None, + ) -> Tuple[Int32, Int32]: + """ + Get the begin and end tile idx for the leading mask. + + :param mask_type: Type of mask to use + :type mask_type: utils.Sm100MaskEnum + :param blk_coord: Block coordinates. + :type blk_coord: cute.Coord + :param tile_shape: Shape of the tile. + :type tile_shape: cute.Shape + :param seqlen_q: Query sequence length for attention computation. + :type seqlen_q: Int32 + :param seqlen_k: Key sequence length for attention computation. + :type seqlen_k: Int32 + :param window_size_left: Left-side sliding window size for attention masking. + :type window_size_left: Optional[Int32] + :param window_size_right: Right-side sliding window size for attention masking. + :type window_size_right: Optional[Int32] + + :return: Tuple of (begin, end) tile idx for the leading mask. + :rtype: Tuple[Int32, Int32] + """ + offset = 0 + if cutlass.const_expr(mask_type is Sm100MaskEnum.WINDOW_MASK_INFERENCE): + offset = seqlen_k - seqlen_q + if cutlass.const_expr(mask_type is Sm100MaskEnum.WINDOW_MASK_BWD_INFERENCE): + offset = seqlen_q - seqlen_k + leading_mask_begin = Sm100FusedMask.get_trip_start( + mask_type, + blk_coord, + tile_shape, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + ) + trip_count = Sm100FusedMask.get_trip_count( + mask_type, + blk_coord, + tile_shape, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + ) + + leading_mask_end = leading_mask_begin + if cutlass.const_expr( + mask_type is Sm100MaskEnum.WINDOW_MASK + or mask_type is Sm100MaskEnum.WINDOW_MASK_INFERENCE + ): + if cutlass.const_expr(window_size_left is not None): + min_idx_q = (blk_coord[0] + 1) * tile_shape[0] + offset - window_size_left + leading_mask_end = dsl_min( + cute.ceil_div(min_idx_q, tile_shape[1]) - 1, + trip_count + leading_mask_begin - 1, + ) + else: + leading_mask_end = leading_mask_begin - 1 + elif cutlass.const_expr( + mask_type is Sm100MaskEnum.WINDOW_MASK_BWD + or mask_type is Sm100MaskEnum.WINDOW_MASK_BWD_INFERENCE + ): + if cutlass.const_expr(window_size_right is not None): + min_idx_k = (blk_coord[1] + 1) * tile_shape[1] + offset - window_size_right + leading_mask_end = cute.ceil_div(min_idx_k, tile_shape[0]) - 1 + else: + leading_mask_end = leading_mask_begin - 1 + return leading_mask_begin, leading_mask_end + + @cute.jit + def get_trailing_mask_id( + mask_type: Sm100MaskEnum, + blk_coord: cute.Coord, + tile_shape: cute.Shape, + seqlen_q: Int32, + seqlen_k: Int32, + window_size_left: Optional[Int32] = None, + window_size_right: Optional[Int32] = None, + ) -> Tuple[Optional[Int32], Optional[Int32]]: + """ + Get the begin and end tile idx for the trailing mask. + + :param mask_type: Type of mask to use + :type mask_type: utils.Sm100MaskEnum + :param blk_coord: Block coordinates. + :type blk_coord: cute.Coord + :param tile_shape: Shape of the tile. + :type tile_shape: cute.Shape + :param seqlen_q: Query sequence length for attention computation. + :type seqlen_q: Int32 + :param seqlen_k: Key sequence length for attention computation. + :type seqlen_k: Int32 + :param window_size_left: Left-side sliding window size for attention masking. + :type window_size_left: Optional[Int32] + :param window_size_right: Right-side sliding window size for attention masking. + :type window_size_right: Optional[Int32] + + :return: Tuple of (begin, end) tile idx for the trailing mask. + :rtype: Tuple[Int32, Int32] + """ + offset = 0 + if cutlass.const_expr(mask_type is Sm100MaskEnum.WINDOW_MASK_INFERENCE): + offset = seqlen_k - seqlen_q + if cutlass.const_expr(mask_type is Sm100MaskEnum.WINDOW_MASK_BWD_INFERENCE): + offset = seqlen_q - seqlen_k + trip_start = Sm100FusedMask.get_trip_start( + mask_type, + blk_coord, + tile_shape, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + ) + trip_count = Sm100FusedMask.get_trip_count( + mask_type, + blk_coord, + tile_shape, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + ) + + trailing_mask_begin, trailing_mask_end = None, None + if cutlass.const_expr( + mask_type is Sm100MaskEnum.WINDOW_MASK + or mask_type is Sm100MaskEnum.WINDOW_MASK_INFERENCE + ): + if cutlass.const_expr(window_size_right is not None): + min_idx_q = blk_coord[0] * tile_shape[0] + offset + window_size_right + trailing_mask_begin = dsl_min( + min_idx_q // tile_shape[1], trip_count + trip_start - 1 + ) + trailing_mask_end = trip_count + trip_start - 1 + else: + # last tile, we always apply mask on it regardless whether it's a residual tile + trailing_mask_begin = trip_count + trip_start - 1 + trailing_mask_end = trip_count + trip_start - 1 + else: + if cutlass.const_expr(window_size_left is not None): + min_idx_k = blk_coord[1] * tile_shape[1] + offset + window_size_left + 1 + max_idx_k = (blk_coord[1] + 1) * tile_shape[1] + offset + window_size_left + trailing_mask_begin = dsl_min( + cute.ceil_div(min_idx_k, tile_shape[0]) - 1, + trip_count + trip_start - 1, + ) + trailing_mask_end = dsl_min( + cute.ceil_div(max_idx_k, tile_shape[0]) - 1, + trip_count + trip_start - 1, + ) + else: + # last tile, we always apply mask on it regardless whether it's a residual tile + trailing_mask_begin = trip_count + trip_start - 1 + trailing_mask_end = trip_count + trip_start - 1 + + return trailing_mask_begin, trailing_mask_end + + @cute.jit + def get_masked_leading_count( + mask_type: Sm100MaskEnum, + blk_coord: cute.Coord, + tile_shape: cute.Shape, + seqlen_q: Int32, + seqlen_k: Int32, + window_size_left: Optional[Int32] = None, + window_size_right: Optional[Int32] = None, + ) -> Int32: + """ + Calculate the number of masked trips for the leading mask. + + This is used for blocks that need special handling due to masking. + + :param mask_type: Type of mask to use + :type mask_type: utils.Sm100MaskEnum + :param blk_coord: Block coordinates. + :type blk_coord: cute.Coord + :param tile_shape: Shape of the tile. + :type tile_shape: cute.Shape + :param seqlen_q: Query sequence length for attention computation. + :type seqlen_q: Int32 + :param seqlen_k: Key sequence length for attention computation. + :type seqlen_k: Int32 + :param window_size_left: Left-side sliding window size for attention masking. + :type window_size_left: Optional[Int32] + :param window_size_right: Right-side sliding window size for attention masking. + :type window_size_right: Optional[Int32] + + :return: Number of masked trips. + :rtype: Int32 + """ + result = 0 + if cutlass.const_expr( + mask_type is not Sm100MaskEnum.RESIDUAL_MASK + and mask_type is not Sm100MaskEnum.RESIDUAL_MASK_BWD + ): + if cutlass.const_expr(window_size_left is not None or window_size_right is not None): + leading_mask_begin, leading_mask_end = Sm100FusedMask.get_leading_mask_id( + mask_type, + blk_coord, + tile_shape, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + ) + result = max(leading_mask_end - leading_mask_begin + 1, 0) + + return result + + @cute.jit + def get_masked_trailing_count( + mask_type: Sm100MaskEnum, + blk_coord: cute.Coord, + tile_shape: cute.Shape, + seqlen_q: Int32, + seqlen_k: Int32, + window_size_left: Optional[Int32] = None, + window_size_right: Optional[Int32] = None, + rem_count: Optional[Int32] = 0, + ) -> Int32: + """ + Calculate the number of masked trips for the trailing mask. + + This is used for blocks that need special handling due to masking. + + :param mask_type: Type of mask to use + :type mask_type: utils.Sm100MaskEnum + :param blk_coord: Block coordinates. + :type blk_coord: cute.Coord + :param tile_shape: Shape of the tile. + :type tile_shape: cute.Shape + :param seqlen_q: Query sequence length for attention computation. + :type seqlen_q: Int32 + :param seqlen_k: Key sequence length for attention computation. + :type seqlen_k: Int32 + :param window_size_left: Left-side sliding window size for attention masking. + :type window_size_left: Optional[Int32] + :param window_size_right: Right-side sliding window size for attention masking. + :type window_size_right: Optional[Int32] + :param rem_count: Remaining count from previous calculations. + :type rem_count: Int32 + + :return: Number of masked trips. + :rtype: Int32 + """ + result = 0 + + if cutlass.const_expr( + mask_type is not Sm100MaskEnum.RESIDUAL_MASK + and mask_type is not Sm100MaskEnum.RESIDUAL_MASK_BWD + ): + if cutlass.const_expr(window_size_left is not None or window_size_right is not None): + trailing_mask_begin, trailing_mask_end = Sm100FusedMask.get_trailing_mask_id( + mask_type, + blk_coord, + tile_shape, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + ) + leading_mask_begin, leading_mask_end = Sm100FusedMask.get_leading_mask_id( + mask_type, + blk_coord, + tile_shape, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + ) + if cutlass.const_expr( + trailing_mask_begin is not None and trailing_mask_end is not None + ): + if trailing_mask_begin <= leading_mask_end: + result = max(trailing_mask_end - leading_mask_end, 0) + else: + result = max(trailing_mask_end - trailing_mask_begin + 1, 0) + else: + if seqlen_k % tile_shape[1] != 0: + result = 1 + else: + result = 0 + + return result + rem_count + + @cute.jit + def get_unmasked_trip_count( + mask_type: Sm100MaskEnum, + blk_coord: cute.Coord, + tile_shape: cute.Shape, + seqlen_q: Int32, + seqlen_k: Int32, + window_size_left: Optional[Int32] = None, + window_size_right: Optional[Int32] = None, + ) -> Int32: + """ + Calculate the number of unmasked trips for the current block. + + This represents the number of trips that don't require special + masking treatment. + + :param mask_type: Type of mask to use + :type mask_type: utils.Sm100MaskEnum + :param blk_coord: Block coordinates. + :type blk_coord: cute.Coord + :param tile_shape: Shape of the tile. + :type tile_shape: cute.Shape + :param seqlen_q: Query sequence length for attention computation. + :type seqlen_q: Int32 + :param seqlen_k: Key sequence length for attention computation. + :type seqlen_k: Int32 + :param window_size_left: Left-side sliding window size for attention masking. + :type window_size_left: Optional[Int32] + :param window_size_right: Right-side sliding window size for attention masking. + :type window_size_right: Optional[Int32] + + :return: Number of unmasked trips. + :rtype: Int32 + """ + result = ( + Sm100FusedMask.get_trip_count( + mask_type, + blk_coord, + tile_shape, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + ) + - Sm100FusedMask.get_masked_leading_count( + mask_type, + blk_coord, + tile_shape, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + ) + - Sm100FusedMask.get_masked_trailing_count( + mask_type, + blk_coord, + tile_shape, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + 0, + ) + ) + return result + + @cute.jit + def apply_mask( + mask_type: Sm100MaskEnum, + acc_qk: cute.Tensor, + index_qk: cute.Tensor, + seqlen_q: Int32, + seqlen_k: Int32, + window_size_left: Optional[int] = None, + window_size_right: Optional[int] = None, + index_transform: cutlass.Constexpr = lambda index_q, index_k: ( + index_q, + index_k, + ), + ): + """ + Apply the appropriate mask to the attention scores. + + This method modifies the attention scores (acc_qk) based on the mask type + and the positions in the index tensor. + + :param mask_type: Type of mask to use + :type mask_type: utils.Sm100MaskEnum + :param acc_qk: Accumulated QK attention scores tensor. + :type acc_qk: cute.Tensor + :param index_qk: Index tensor containing position information. + :type index_qk: cute.Tensor + :param seqlen_k: Key sequence length for attention computation. + :type seqlen_k: Int32 + :param seqlen_q: Query sequence length for attention computation. + :type seqlen_q: Optional[int] + :param window_size_left: Left-side sliding window size for attention masking. + :type window_size_left: Optional[int] + :param window_size_right: Right-side sliding window size for attention masking. + :type window_size_right: Optional[int] + """ + offset = 0 + # NOTE: causal masking in this repo aligns the *end* of Q with the *end* of K + # when seqlen_k != seqlen_q (same as the test/reference implementation): + # k_index <= q_index + (seqlen_k - seqlen_q) + window_right + # In our kernels, causal is represented by (window_left is None, window_right is not None). + if cutlass.const_expr(window_size_left is None and window_size_right is not None): + offset = seqlen_k - seqlen_q + elif cutlass.const_expr( + mask_type is Sm100MaskEnum.WINDOW_MASK_INFERENCE + or mask_type is Sm100MaskEnum.WINDOW_MASK_BWD_INFERENCE + ): + offset = seqlen_k - seqlen_q + for i in cutlass.range_constexpr(cute.size(acc_qk), unroll_full=True): + index_q, index_k = index_transform(*index_qk[i]) + if cutlass.const_expr(window_size_left is not None or window_size_right is not None): + if cutlass.const_expr(window_size_left is None): + if index_q + offset + window_size_right < index_k: + acc_qk[i] = -Float32.inf + if index_k >= seqlen_k or index_q >= seqlen_q: # residual mask + acc_qk[i] = -Float32.inf + elif cutlass.const_expr(window_size_right is None): + if index_q + offset - window_size_left > index_k: + acc_qk[i] = -Float32.inf + if index_k >= seqlen_k or index_q >= seqlen_q: # residual mask + acc_qk[i] = -Float32.inf + else: + max_K_index = dsl_min(index_q + offset + window_size_right, seqlen_k) + min_K_index = max(0, index_q + offset - window_size_left) + if index_k > max_K_index or index_k < min_K_index: + acc_qk[i] = -Float32.inf + if index_k >= seqlen_k or index_q >= seqlen_q: # residual mask + acc_qk[i] = -Float32.inf + + if cutlass.const_expr( + mask_type == Sm100MaskEnum.RESIDUAL_MASK + or mask_type == Sm100MaskEnum.RESIDUAL_MASK_BWD + ): + if index_k >= seqlen_k or index_q >= seqlen_q: + acc_qk[i] = -Float32.inf + + @cute.jit + def apply_mask_via_causal_local( + acc_qk: cute.Tensor, + index_qk: cute.Tensor, + seqlen_q: Int32, + seqlen_k: Int32, + apply_semantic_window: cutlass.Constexpr[bool] = True, + is_causal: cutlass.Constexpr[bool] = False, + is_local: cutlass.Constexpr[bool] = False, + window_size_left: Optional[int] = None, + window_size_right: Optional[int] = None, + index_transform: cutlass.Constexpr = lambda index_q, index_k: ( + index_q, + index_k, + ), + ): + """Apply forward mask without mask_type. + + - If apply_semantic_window=True, apply causal/local window constraints. + - Always apply residual OOB masking (index_k>=seqlen_k or index_q>=seqlen_q). + """ + offset = 0 + if cutlass.const_expr(apply_semantic_window): + # Match WINDOW_MASK_INFERENCE semantics: end-align Q/K when lengths differ. + offset = seqlen_k - seqlen_q + for i in cutlass.range_constexpr(cute.size(acc_qk), unroll_full=True): + index_q, index_k = index_transform(*index_qk[i]) + if cutlass.const_expr(apply_semantic_window): + if cutlass.const_expr(is_causal and not is_local): + # Pure causal; tolerate both external forms: + # - (None, None) from interface + # - (None, 0) from fused-mask-style callers + right = 0 if const_expr(window_size_right is None) else window_size_right + if index_q + offset + right < index_k: + acc_qk[i] = -Float32.inf + elif cutlass.const_expr( + is_local or window_size_left is not None or window_size_right is not None + ): + if cutlass.const_expr(window_size_left is None): + if index_q + offset + window_size_right < index_k: + acc_qk[i] = -Float32.inf + elif cutlass.const_expr(window_size_right is None): + if index_q + offset - window_size_left > index_k: + acc_qk[i] = -Float32.inf + else: + max_K_index = dsl_min(index_q + offset + window_size_right, seqlen_k) + min_K_index = max(0, index_q + offset - window_size_left) + if index_k > max_K_index or index_k < min_K_index: + acc_qk[i] = -Float32.inf + # Residual mask is always needed for boundary protection. + if index_k >= seqlen_k or index_q >= seqlen_q: + acc_qk[i] = -Float32.inf diff --git a/flash_attn/cute/mma_sm100_desc.py b/flash_attn/cute/mma_sm100_desc.py index ab8dd098b92..9a2adfca8d7 100644 --- a/flash_attn/cute/mma_sm100_desc.py +++ b/flash_attn/cute/mma_sm100_desc.py @@ -83,9 +83,9 @@ def to_UMMA_format(cutlass_type) -> int: if cutlass_type is cutlass.TFloat32: return F16F32Format.TF32 # Float-8 / Float-6 / Float-4 – add whenever CUTLASS exposes them - if cutlass_type is cutlass.FloatE4M3FN: + if cutlass_type is cutlass.Float8E4M3FN: return MXF8F6F4Format.E4M3 - if cutlass_type is cutlass.FloatE5M2: + if cutlass_type is cutlass.Float8E5M2: return MXF8F6F4Format.E5M2 raise TypeError(f"Unsupported CUTLASS scalar type for A/B: {cutlass_type!r}") diff --git a/flash_attn/cute/named_barrier.py b/flash_attn/cute/named_barrier.py index dd0d1988960..c4536dabd0f 100644 --- a/flash_attn/cute/named_barrier.py +++ b/flash_attn/cute/named_barrier.py @@ -45,3 +45,12 @@ class NamedBarrierBwdSm100(enum.IntEnum): Compute = enum.auto() dQaccReduce = enum.auto() TmemPtr = enum.auto() + + +class NamedBarrierFwdSm100_MLA2CTA(enum.IntEnum): + Epilogue = enum.auto() + TmemPtr = enum.auto() + Cpasync = enum.auto() + Softmax = enum.auto() + SoftmaxStatsFull = enum.auto() + SoftmaxStatsEmpty = enum.auto() diff --git a/flash_attn/cute/paged_kv.py b/flash_attn/cute/paged_kv.py index 40d69a6de3f..ed0e498dcb0 100644 --- a/flash_attn/cute/paged_kv.py +++ b/flash_attn/cute/paged_kv.py @@ -169,6 +169,42 @@ def compute_X_ptr(self, K_or_V: str): tPrXPtr[i] = utils.elem_pointer(mX, (page_offset, 0, page)).toint() return tPrXPtr + @cute.jit + def _flatten_smem_sm100(self, sX: cute.Tensor, K_or_V: str): + """Flatten SM100 smem ((a,b), cta_split, k) to (a,(b,k)); transpose V to (d,page_size).""" + sX_pi = cute.make_tensor( + sX.iterator, + cute.make_layout( + (sX.shape[0][0], (sX.shape[0][1], sX.shape[2])), + stride=(sX.stride[0][0], (sX.stride[0][1], sX.stride[2])), + ), + ) + if const_expr(K_or_V == "V"): + sX_pi = cute.make_tensor(sX_pi.iterator, cute.select(sX_pi.layout, mode=[1, 0])) + return sX_pi + + @cute.jit + def _copy_row_async( + self, + tXsX: cute.Tensor, + tXcX: cute.Tensor, + mX_paged_cur_copy: cute.Tensor, + m: Int32, + should_load: cute.Tensor, + ): + """Issue cp.async copies for one row across all k-tiles.""" + for k in cutlass.range_constexpr(cute.size(tXsX, mode=[2])): + ki = tXcX[0, 0, k][1] // self.async_copy_elems + mX_paged_cur_copy_ki = mX_paged_cur_copy[None, ki] + tXsX_k = tXsX[None, m, k] + mX_paged_cur_copy_ki = cute.make_tensor(mX_paged_cur_copy_ki.iterator, tXsX_k.layout) + cute.copy( + self.gmem_tiled_copy_KV, + mX_paged_cur_copy_ki, + tXsX_k, + pred=should_load, + ) + @cute.jit def load_KV(self, n_block: Int32, sX: cute.Tensor, K_or_V: str): assert K_or_V in ("K", "V") @@ -181,18 +217,7 @@ def load_KV(self, n_block: Int32, sX: cute.Tensor, K_or_V: str): sX_pi = cute.group_modes(sX, 0, 1) # SM90 does NOT transpose V here (it's transposed via utils.transpose_view before MMA) else: - # SM100: Finesse sX layout to be (M, N). - sX_pi = cute.make_tensor( - sX.iterator, - cute.make_layout( - (sX.shape[0][0], (sX.shape[0][1], sX.shape[2])), - stride=(sX.stride[0][0], (sX.stride[0][1], sX.stride[2])), - ), - ) - - if const_expr(K_or_V == "V"): - # Transpose smem V to match transposed gmem layout - sX_pi = cute.make_tensor(sX_pi.iterator, cute.select(sX_pi.layout, mode=[1, 0])) + sX_pi = self._flatten_smem_sm100(sX, K_or_V) head_dim = self.head_dim_v_padded if const_expr(K_or_V == "V") else self.head_dim_padded cX = cute.make_identity_tensor((self.n_block_size, head_dim)) @@ -218,17 +243,4 @@ def load_KV(self, n_block: Int32, sX: cute.Tensor, K_or_V: str): ) mX_paged_cur = cute.make_tensor(x_gmem_ptr, cute.make_layout((head_dim,))) mX_paged_cur_copy = cute.tiled_divide(mX_paged_cur, (self.async_copy_elems,)) - - for k in cutlass.range_constexpr(cute.size(tXsX, mode=[2])): - ki = tXcX[0, 0, k][1] // self.async_copy_elems - mX_paged_cur_copy_ki = mX_paged_cur_copy[None, ki] - tXsX_k = tXsX[None, m, k] - mX_paged_cur_copy_ki = cute.make_tensor( - mX_paged_cur_copy_ki.iterator, tXsX_k.layout - ) - cute.copy( - self.gmem_tiled_copy_KV, - mX_paged_cur_copy_ki, - tXsX_k, - pred=should_load, - ) + self._copy_row_async(tXsX, tXcX, mX_paged_cur_copy, m, should_load) diff --git a/flash_attn/cute/pyproject.toml b/flash_attn/cute/pyproject.toml index 2b0b60b42f1..cb1c3bb884f 100644 --- a/flash_attn/cute/pyproject.toml +++ b/flash_attn/cute/pyproject.toml @@ -28,13 +28,14 @@ dependencies = [ "typing_extensions", "apache-tvm-ffi>=0.1.5,<0.2", "torch-c-dlpack-ext", - "quack-kernels>=0.3.3", + "quack-kernels>=0.4.0", ] [project.optional-dependencies] cu13 = ["nvidia-cutlass-dsl[cu13]>=4.4.2"] dev = [ "pytest", + "pytest-xdist", "ruff", ] diff --git a/flash_attn/cute/sm100_hd256_2cta_fmha_backward.py b/flash_attn/cute/sm100_hd256_2cta_fmha_backward.py new file mode 100644 index 00000000000..c07e3e94176 --- /dev/null +++ b/flash_attn/cute/sm100_hd256_2cta_fmha_backward.py @@ -0,0 +1,291 @@ +# Copyright (c) 2025, Siyu Wang, Shengbin Di, Yuxi Chi, Johnsonms, Linfeng Zheng, Haoyan Huang, Lanbo Li, Yun Zhong, Man Yuan, Minmin Sun, Yong Li, Wei Lin. + + +"""Fused multi-head attention (FMHA) backward for the SM100 architecture using CUTE DSL. + +Constraints: +* Supported head dimensions: 256 only +* mma_tiler_mn must be 64,64 +* Batch size must be the same for Q, K, and V tensors +""" + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass.cute.typing import Int32 + +from flash_attn.cute.sm100_hd256_2cta_fmha_backward_dqkernel import ( + BlackwellFusedMultiHeadAttentionBackwardDQKernel, +) +from flash_attn.cute.sm100_hd256_2cta_fmha_backward_dkdvkernel import ( + BlackwellFusedMultiHeadAttentionBackwardDKDVKernel, +) + + +def _as_bshkrd_tensor( + tensor: cute.Tensor, + h_k: Int32, + h_r: Int32, + varlen: bool, +) -> cute.Tensor: + """Normalize (B,S,H,D)/(S,H,D) tensors to (B,S,H_k,H_r,D) view.""" + if cutlass.const_expr(cute.rank(tensor.layout) == 5): + return tensor + if cutlass.const_expr(cute.rank(tensor.layout) == 4): + return cute.make_tensor( + tensor.iterator, + cute.make_layout( + (tensor.shape[0], tensor.shape[1], h_k, h_r, tensor.shape[3]), + stride=( + tensor.stride[0], + tensor.stride[1], + tensor.stride[2] * h_r, + tensor.stride[2], + tensor.stride[3], + ), + ), + ) + assert cutlass.const_expr(cute.rank(tensor.layout) == 3), "Expected rank-3 varlen tensor" + assert cutlass.const_expr(varlen), "Rank-3 input is only valid for varlen backward" + return cute.make_tensor( + tensor.iterator, + cute.make_layout( + (1, tensor.shape[0], h_k, h_r, tensor.shape[2]), + stride=( + 0, + tensor.stride[0], + tensor.stride[1] * h_r, + tensor.stride[1], + tensor.stride[2], + ), + ), + ) + + +def _as_shhb_tensor( + tensor: cute.Tensor, + h_k: Int32, + h_r: Int32, + b: Int32, + varlen: bool, +) -> cute.Tensor: + """Normalize (B,H,S)/(H,S) tensors to (S, ((H_r, H_k), B)) view.""" + if cutlass.const_expr(cute.rank(tensor.layout) == 3): + return cute.make_tensor( + tensor.iterator, + cute.make_layout( + (tensor.shape[2], ((h_r, h_k), tensor.shape[0])), + stride=( + tensor.stride[2], + ((tensor.stride[1], tensor.stride[1] * h_r), tensor.stride[0]), + ), + ), + ) + assert cutlass.const_expr(cute.rank(tensor.layout) == 2), "Expected rank-2 varlen tensor" + assert cutlass.const_expr(varlen), "Rank-2 input is only valid for varlen backward" + return cute.make_tensor( + tensor.iterator, + cute.make_layout( + (tensor.shape[1], ((h_r, h_k), b)), + stride=( + tensor.stride[1], + ((tensor.stride[0], tensor.stride[0] * h_r), 0), + ), + ), + ) + + +class BlackwellFusedMultiHeadAttentionBackward: + """FMHA backward class for executing CuTeDSL kernel.""" + + def __init__( + self, + head_dim: int, + head_dim_v: int | None = None, + is_causal: bool = False, + is_local: bool = False, + qhead_per_kvhead: cutlass.Constexpr[int] = 1, + is_persistent: bool = False, + deterministic: bool = False, + cluster_size: int = 1, + use_2cta_instrs: bool = False, + score_mod: cutlass.Constexpr | None = None, + score_mod_bwd: cutlass.Constexpr | None = None, + mask_mod: cutlass.Constexpr | None = None, + has_aux_tensors: cutlass.Constexpr = False, + subtile_factor: cutlass.Constexpr[int] = 1, + tile_m_dq: int = 128, + tile_n_dq: int = 128, + tile_m_dkdv: int = 128, + tile_n_dkdv: int = 64, + window_size_left: int | None = None, + window_size_right: int | None = None, + use_clc_scheduler: bool = False, + ): + """Initialization.""" + head_dim_v = head_dim if head_dim_v is None else head_dim_v + assert head_dim == 256 and head_dim_v == 256, ( + "SM100 dedicated backward kernel only supports (head_dim, head_dim_v) = (256, 256)" + ) + assert not is_local, "SM100 backward with head_dim=256 does not support local attention" + assert tile_m_dq == 128 and tile_n_dq == 128, ( + "SM100 dedicated backward kernel only supports tile_m_dq=128 and tile_n_dq=128" + ) + assert tile_m_dkdv == 128 and tile_n_dkdv == 64, ( + "SM100 dedicated backward kernel only supports tile_m_dkdv=128 and tile_n_dkdv=64" + ) + assert score_mod is None and score_mod_bwd is None and mask_mod is None, ( + "SM100 backward with head_dim=256 does not support score_mod/mask_mod" + ) + assert not deterministic, ( + "SM100 backward with head_dim=256 does not support deterministic mode" + ) + assert not has_aux_tensors, "SM100 backward with head_dim=256 does not support aux_tensors" + assert cluster_size in (1, 2), ( + "SM100 backward with head_dim=256 only supports cluster_size in {1, 2}" + ) + assert use_2cta_instrs, "SM100 backward with head_dim=256 requires use_2cta_instrs=True" + # subtile_factor is accepted for interface parity with FlashAttentionBackwardSm100, + # but this dedicated kernel uses fixed internal behavior. + + self.acc_dtype = cutlass.Float32 + self.is_causal = is_causal + self.window_size_left = ( + None if (window_size_left is None or window_size_left < 0) else window_size_left + ) + self.window_size_right = ( + None if (window_size_right is None or window_size_right < 0) else window_size_right + ) + self.tile_m_dq = tile_m_dq + self.tile_n_dq = tile_n_dq + self.tile_m_dkdv = tile_m_dkdv + self.tile_n_dkdv = tile_n_dkdv + self.use_clc_scheduler = use_clc_scheduler + + self.dq_kernel = BlackwellFusedMultiHeadAttentionBackwardDQKernel( + self.acc_dtype, + (self.tile_m_dq, self.tile_n_dq, 256), + self.is_causal, + self.window_size_left, + self.window_size_right, + False, # is_persistent + False, # split_head + use_clc_scheduler=self.use_clc_scheduler, + ) + self.dkdv_kernel = BlackwellFusedMultiHeadAttentionBackwardDKDVKernel( + self.acc_dtype, + (self.tile_m_dkdv, self.tile_n_dkdv, 256), + self.is_causal, + self.window_size_left, + self.window_size_right, + use_clc_scheduler=self.use_clc_scheduler, + ) + + @cute.jit + def __call__( + self, + Q: cute.Tensor, + K: cute.Tensor, + V: cute.Tensor, + dO: cute.Tensor, + lse_log2: cute.Tensor, + dpsum: cute.Tensor, + dQ_accum: cute.Tensor | None, + dK: cute.Tensor, + dV: cute.Tensor, + scale_softmax: cutlass.Float32, + cumulative_s_q: cute.Tensor | None, + cumulative_s_k: cute.Tensor | None, + seqused_q: cute.Tensor | None = None, + seqused_k: cute.Tensor | None = None, + window_size_left: Int32 | None = None, + window_size_right: Int32 | None = None, + dQ_semaphore: cute.Tensor | None = None, + dK_semaphore: cute.Tensor | None = None, + dV_semaphore: cute.Tensor | None = None, + aux_tensors: tuple[cute.Tensor] | None = None, + block_sparse_tensors: cute.Tensor | None = None, + stream: cuda.CUstream = None, + ): + """Host function to launch CuTeDSL kernel.""" + assert seqused_q is None and seqused_k is None, ( + "SM100 backward with head_dim=256 does not support seqused_q/seqused_k" + ) + assert window_size_left is None and window_size_right is None, ( + "SM100 backward with head_dim=256 uses constructor-provided window sizes" + ) + assert dQ_semaphore is None and dK_semaphore is None and dV_semaphore is None, ( + "SM100 backward with head_dim=256 does not use semaphores" + ) + assert block_sparse_tensors is None, ( + "SM100 backward with head_dim=256 does not support block sparse tensors" + ) + assert aux_tensors is None or len(aux_tensors) == 0, ( + "SM100 backward with head_dim=256 does not support aux_tensors" + ) + assert dQ_accum is not None, ( + "SM100 backward with head_dim=256 expects dQ tensor at dQ_accum slot" + ) + dQ = dQ_accum + varlen = cumulative_s_q is not None or cumulative_s_k is not None + q_rank = cute.rank(Q.layout) + k_rank = cute.rank(K.layout) + if cutlass.const_expr(q_rank == 5): + h_q = Q.shape[2] * Q.shape[3] + elif cutlass.const_expr(q_rank == 4): + h_q = Q.shape[2] + else: + h_q = Q.shape[1] + if cutlass.const_expr(k_rank == 5): + h_k = K.shape[2] + elif cutlass.const_expr(k_rank == 4): + h_k = K.shape[2] + else: + h_k = K.shape[1] + h_r = h_q // h_k + if cutlass.const_expr(cumulative_s_q is not None): + b = cumulative_s_q.shape[0] - 1 + elif cutlass.const_expr(cumulative_s_k is not None): + b = cumulative_s_k.shape[0] - 1 + else: + b = Q.shape[0] + + Q = _as_bshkrd_tensor(Q, h_k, h_r, varlen) + K = _as_bshkrd_tensor(K, h_k, 1, varlen) + V = _as_bshkrd_tensor(V, h_k, 1, varlen) + dQ = _as_bshkrd_tensor(dQ, h_k, h_r, varlen) + dK = _as_bshkrd_tensor(dK, h_k, 1, varlen) + dV = _as_bshkrd_tensor(dV, h_k, 1, varlen) + dO = _as_bshkrd_tensor(dO, h_k, h_r, varlen) + scaled_LSE = _as_shhb_tensor(lse_log2, h_k, h_r, b, varlen) + sum_OdO = _as_shhb_tensor(dpsum, h_k, h_r, b, varlen) + + # Keep original order: dQ first, then dKdV. + self.dq_kernel( + Q, + K, + V, + dQ, + dO, + scaled_LSE, + sum_OdO, + cumulative_s_q, + cumulative_s_k, + scale_softmax, + stream, + ) + self.dkdv_kernel( + Q, + K, + V, + dK, + dV, + dO, + scaled_LSE, + sum_OdO, + cumulative_s_q, + cumulative_s_k, + scale_softmax, + stream, + ) diff --git a/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dkdvkernel.py b/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dkdvkernel.py new file mode 100644 index 00000000000..7a8cdeede6a --- /dev/null +++ b/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dkdvkernel.py @@ -0,0 +1,3360 @@ +# Copyright (c) 2025, Siyu Wang, Shengbin Di, Yuxi Chi, Johnsonms, Linfeng Zheng, Haoyan Huang, Lanbo Li, Yun Zhong, Man Yuan, Minmin Sun, Yong Li, Wei Lin. + +"""Fused multi-head attention (FMHA) backward for the SM100 architecture using CUTE DSL. + +Constraints: +* Supported head dimensions: 256 only +* cta_tiler_mn must be 64,128 +* Batch size must be the same for Q, K, and V tensors +""" + +import enum +import math + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.utils as utils +import cutlass.pipeline as pipeline +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cute.typing import Int32 +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait + +from cutlass.utils import ClcDynamicPersistentTileScheduler +from flash_attn.cute.tile_scheduler import ( + ClcState, + SM100_TMEM_CAPACITY_COLUMNS, + make_sm100_thread_cooperative_group as make_thread_cooperative_group, + Sm100FmhaClcDynamicTileSchedulerParams as FmhaClcDynamicTileSchedulerParams, + Sm100FmhaClcDynamicTileScheduler as FmhaClcDynamicTileScheduler, + Sm100FmhaStaticTileSchedulerParams as FmhaStaticTileSchedulerParams, +) + + +LAYOUT_RANK_CONSTANT = 3 + + +@cute.jit +def split_wg( + t: cute.Tensor, + num_warp_groups: Int32, + wg_idx: Int32, +) -> cute.Tensor: + """Split warp group.""" + # dishengbin, TODO:need to double check if more efficient to split in other dimensions + ret = None + if cutlass.const_expr(cute.rank(t.layout) == LAYOUT_RANK_CONSTANT): + p = cute.composition( + t, + cute.make_layout( + ( + t.shape[0], + t.shape[1], + (num_warp_groups, cute.size(t, mode=[2]) // num_warp_groups), + ) + ), + ) + ret = p[None, None, (wg_idx, None)] + else: + p = cute.composition( + t, + cute.make_layout( + ( + t.shape[0], + t.shape[1], + t.shape[2], + (num_warp_groups, cute.size(t, mode=[3]) // num_warp_groups), + ) + ), + ) + ret = p[None, None, None, (wg_idx, None)] + return ret + + +class MaskType(enum.Enum): + """Mask type used in FMHA backward.""" + + NO_MASK = enum.auto() + RESIDUAL_MASK_FOR_BACKWARD = enum.auto() + CAUSAL_MASK_FOR_BACKWARD = enum.auto() + + +def Tmemory_offset(lane, col): + """Tensor memory offset.""" + return (lane << 16) + col + + +permute_order = (0, 1, 2, 3, 4) + + +class BlackwellFusedMultiHeadAttentionBackwardDKDVKernel: + """FMHA backward class for executing CuTeDSL kernel.""" + + def __init__( + self, + acc_dtype: type[cutlass.Numeric], + cta_tiler: tuple[int, int, int], + is_causal: bool, + window_size_left: int | None, + window_size_right: int | None, + use_clc_scheduler: bool = False, + ): + """Initialization.""" + self.acc_dtype = acc_dtype + self.cta_tiler = cta_tiler + self.use_clc_scheduler = use_clc_scheduler + self.sched_warp_id = 10 if use_clc_scheduler else None + # TODO: need check, not sure whether need to *2 if 2cta + self.tile_shape_Q = cta_tiler[0] + self.tile_shape_K = cta_tiler[1] + self.tile_shape_dQ_K = cta_tiler[2] + self.tile_shape_dV_dO = cta_tiler[2] + # For S + self.KQ_mma_tiler = ( + cta_tiler[1] * 2, + cta_tiler[0], + cta_tiler[2], + ) + # For dP + self.VdO_mma_tiler = ( + cta_tiler[1] * 2, + cta_tiler[0], + cta_tiler[2], + ) + # For dV + self.PdO_mma_tiler = ( + cta_tiler[1] * 2, + cta_tiler[2], + cta_tiler[0], + ) + # For dK + self.dSQ_mma_tiler = ( + cta_tiler[1] * 2, + cta_tiler[2], + cta_tiler[0], + ) + # For dQ, dishengbin, need to remove + self.dSK_mma_tiler = ( + cta_tiler[0] * 2, + cta_tiler[2], + cta_tiler[1], + ) + self.cluster_shape_mn = (2, 1) + self.is_causal = is_causal + self.window_size_left: int = -1 if window_size_left is None else window_size_left + self.window_size_right: int = -1 if window_size_right is None else window_size_right + self.has_sliding_window = False + if self.window_size_left > 0 or self.window_size_right > 0: + self.has_sliding_window = True + if self.is_causal: + self.window_size_right = 0 + + self.compute_warp_id = (0, 1, 2, 3, 4, 5, 6, 7) + self.mma_warp_id = 8 + self.load_warp_id = 9 + self.empty_warp_id = 10 + + self.num_compute_warps = 8 + + self.tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS + + self.threads_per_warp = 32 + self.threads_per_cta = self.threads_per_warp * (self.num_compute_warps + 4) + + self.cta_sync_bar_id = 0 + self.tmem_alloc_sync_bar_id = 1 + self.compute_sync_bar_id = 2 + self.epilogue_sync_bar_id = 3 + self.reduce_sync_bar_id = 4 + + self.tmem_dK_offset = 0 + self.tmem_dV_offset = Tmemory_offset(0, cta_tiler[2] // 2) + self.tmem_dP_offset = Tmemory_offset(0, cta_tiler[2] + cta_tiler[0] // 2) + self.tmem_S_offset = Tmemory_offset(0, cta_tiler[2]) + + self.num_regs_reduce = 152 + self.num_regs_compute = 128 + self.num_regs_mma = 128 + self.num_regs_empty = 96 + self.num_regs_load = 96 + + self.buffer_align_bytes = 128 + + def _setup_attributes(self): + """Settings for pipeline stage.""" + self.load_mma_Q_stage = 1 + self.load_mma_K_stage = 1 + self.load_mma_V_stage = 1 + self.load_mma_QT_stage = 1 + self.load_mma_dO_stage = 1 + self.load_compute_LSE_stage = 1 + self.load_compute_sum_OdO_stage = 1 + self.mma_compute_S_stage = 1 + self.mma_compute_dP_stage = 1 + self.compute_mma_P_stage = 1 + self.compute_mma_dS_stage = 1 + self.mma_compute_dKdV_stage = 2 + + if cutlass.const_expr(self.use_clc_scheduler): + self.num_clc_stage = 1 + self.num_clc_response_bytes = 16 + + @cute.jit + def __call__( + self, + Q: cute.Tensor, + K: cute.Tensor, + V: cute.Tensor, + dK: cute.Tensor, + dV: cute.Tensor, + dO: cute.Tensor, + scaled_LSE: cute.Tensor, + sum_OdO: cute.Tensor, + cumulative_s_q: cute.Tensor | None, + cumulative_s_k: cute.Tensor | None, + scale_softmax: cutlass.Float32, + stream: cuda.CUstream, + ): + """Host function to launch CuTeDSL kernel.""" + varlen = cumulative_s_q is not None or cumulative_s_k is not None + # Infer shape metadata from normalized 5D tensors (B, S, H_k, H_r, D). + h_r = Q.shape[3] + h_k = Q.shape[2] + if cutlass.const_expr(cumulative_s_q is not None): + b = cumulative_s_q.shape[0] - 1 + elif cutlass.const_expr(cumulative_s_k is not None): + b = cumulative_s_k.shape[0] - 1 + else: + b = Q.shape[0] + problem_shape = ( + Q.shape[1], + K.shape[1], + Q.shape[4], + ((h_r, h_k), b), + ) + hb = ((h_r, h_k), b) + # (b, s, h_k, h_r, d) -> (s, d, ((h_r, h_k), b)) + Q = cute.make_tensor( + Q.iterator, + cute.make_layout( + (Q.shape[1], Q.shape[4], hb), + stride=( + cute.assume(Q.stride[1], divby=64), + Q.stride[4], + ( + (Q.shape[4], Q.shape[4] * Q.shape[3]), + ( + 0 + if varlen + else cute.assume(Q.shape[1] * Q.shape[4] * h_r * h_k, divby=64) + ), + ), + ), + ), + ) + # (b, s, h_k, 1, d) -> (s, d, ((1, h_k), b)) + K = cute.make_tensor( + K.iterator, + cute.make_layout( + (K.shape[1], K.shape[4], hb), + stride=( + cute.assume(K.stride[1], divby=64), + K.stride[4], + ( + (0, K.shape[4]), + (0 if varlen else cute.assume(K.shape[1] * K.shape[4] * 1 * h_k, divby=64)), + ), + ), + ), + ) + # (b, s, h_k, 1, d) -> (s, d, ((1, h_k), b)) + V = cute.make_tensor( + V.iterator, + cute.make_layout( + (V.shape[1], V.shape[4], hb), + stride=( + cute.assume(V.stride[1], divby=64), + V.stride[4], + ( + (0, V.shape[4]), + (0 if varlen else cute.assume(V.shape[1] * V.shape[4] * 1 * h_k, divby=64)), + ), + ), + ), + ) + # (s, d, ((h_r, h_k), b)) -> (d, s, ((h_r, h_k), b)) + QT = cute.make_tensor( + Q.iterator, + cute.make_layout( + (Q.shape[1], Q.shape[0], Q.shape[2]), + stride=( + Q.stride[1], + Q.stride[0], + Q.stride[2], + ), + ), + ) + dK = cute.make_tensor(dK.iterator, K.layout) + dV = cute.make_tensor(dV.iterator, V.layout) + # (s, d, ((h_r, h_k), b)) + dO = cute.make_tensor(dO.iterator, Q.layout) + + # (s, d, ((h_r, h_k), b)) -> (d, s, ((h_r, h_k), b)) + dOT = cute.make_tensor( + dO.iterator, + cute.make_layout( + (dO.shape[1], dO.shape[0], dO.shape[2]), + stride=( + dO.stride[1], + dO.stride[0], + dO.stride[2], + ), + ), + ) + + self.Q_major_mode = utils.LayoutEnum.from_tensor(Q).mma_major_mode() + self.K_major_mode = utils.LayoutEnum.from_tensor(K).mma_major_mode() + self.dK_major_mode = utils.LayoutEnum.from_tensor(dK).mma_major_mode() + self.V_major_mode = utils.LayoutEnum.from_tensor(V).mma_major_mode() + self.dV_major_mode = utils.LayoutEnum.from_tensor(dV).mma_major_mode() + self.dO_major_mode = utils.LayoutEnum.from_tensor(dO).mma_major_mode() + + if cutlass.const_expr(self.Q_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError(f"The layout of q is not supported: {self.Q_major_mode}") + if cutlass.const_expr(self.K_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of k is not supported") + if cutlass.const_expr(self.dK_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of dk is not supported") + if cutlass.const_expr(self.V_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of v is not supported") + if cutlass.const_expr(self.dV_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of dv is not supported") + + self._setup_attributes() + + cta_group = tcgen05.CtaGroup.TWO + PT_source = tcgen05.OperandSource.SMEM + + # compute S + KQ_tiled_mma = sm100_utils.make_trivial_tiled_mma( + K.element_type, + tcgen05.OperandMajorMode.K, + tcgen05.OperandMajorMode.K, + self.acc_dtype, + cta_group, + self.KQ_mma_tiler[:2], + ) + # compute dP + VdO_tiled_mma = sm100_utils.make_trivial_tiled_mma( + V.element_type, + tcgen05.OperandMajorMode.K, + tcgen05.OperandMajorMode.K, + self.acc_dtype, + cta_group, + self.VdO_mma_tiler[:2], + ) + # compute dV + PdO_tiled_mma = sm100_utils.make_trivial_tiled_mma( + dO.element_type, + tcgen05.OperandMajorMode.K, + tcgen05.OperandMajorMode.MN, + self.acc_dtype, + cta_group, + self.PdO_mma_tiler[:2], + PT_source, + ) + # compute dK + dSQ_tiled_mma = sm100_utils.make_trivial_tiled_mma( + Q.element_type, + tcgen05.OperandMajorMode.K, + tcgen05.OperandMajorMode.MN, + self.acc_dtype, + cta_group, + self.dSQ_mma_tiler[:2], + ) + # compute dQ + # dishengbin, need to remove, but used in dS_mem_layout_staged + dSK_tiled_mma = sm100_utils.make_trivial_tiled_mma( + K.element_type, + tcgen05.OperandMajorMode.MN, + tcgen05.OperandMajorMode.MN, + self.acc_dtype, + cta_group, + self.dSK_mma_tiler[:2], + ) + + atom_thr_size = cute.size(KQ_tiled_mma.thr_id.shape) + self.cluster_shape_mnk = (*self.cluster_shape_mn, 1) # type: ignore[assignment] + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), + (atom_thr_size,), + ) + + K_smem_layout_staged = sm100_utils.make_smem_layout_a( + KQ_tiled_mma, + self.KQ_mma_tiler, + K.element_type, + 1, + ) + Q_smem_layout_staged = sm100_utils.make_smem_layout_b( + KQ_tiled_mma, + self.KQ_mma_tiler, + Q.element_type, + self.load_mma_Q_stage, + ) + V_smem_layout_staged = sm100_utils.make_smem_layout_a( + VdO_tiled_mma, + self.VdO_mma_tiler, + V.element_type, + 1, + ) + dO_smem_layout_staged = sm100_utils.make_smem_layout_b( + VdO_tiled_mma, + self.VdO_mma_tiler, + dO.element_type, + self.load_mma_dO_stage, + ) + # dishengbin, need to remove, but used for sPT, need to double check + dS_smem_layout_staged = sm100_utils.make_smem_layout_a( + dSK_tiled_mma, + self.dSK_mma_tiler, + Q.element_type, + self.compute_mma_dS_stage, + ) + dST_smem_layout_staged = sm100_utils.make_smem_layout_a( + dSQ_tiled_mma, + self.dSQ_mma_tiler, + Q.element_type, + self.compute_mma_dS_stage, + ) + tiled_mma = dSQ_tiled_mma + is_k_major = tiled_mma.op.a_major_mode == tcgen05.OperandMajorMode.K + a_major_mode = tcgen05.OperandMajorMode.K if is_k_major else tcgen05.OperandMajorMode.MN + tmp = cute.dice(self.dSQ_mma_tiler, (1, None, 1)) + a_smem_shape = tiled_mma.partition_shape_A( + cute.dice(self.dSQ_mma_tiler, (1, None, 1)), + ) + a_smem_shape_mn_k = ( + cute.size(a_smem_shape[0][0]) * a_smem_shape[1], + cute.size(a_smem_shape[0][1]) * a_smem_shape[2], + ) + smem_layout_atom_kind = sm100_utils.get_smem_layout_atom_ab( + a_major_mode, + K.element_type, + a_smem_shape_mn_k, + ) + a_smem_layout_atom = sm100_utils.make_smem_layout_atom( + smem_layout_atom_kind, + K.element_type, + ) + + a_smem_shape = cute.append( + a_smem_shape, + self.compute_mma_dS_stage, + ) + order = (2, 1, 3) if not is_k_major else (1, 2, 3) + dST_smem_layout_staged_tmp = sm100_utils.tile_to_mma_shape( + a_smem_layout_atom, + a_smem_shape, + order=order, + ) + QT_smem_layout_staged = sm100_utils.make_smem_layout_b( + dSQ_tiled_mma, + self.dSQ_mma_tiler, + Q.element_type, + self.load_mma_QT_stage, + ) + P_smem_layout_staged = sm100_utils.make_smem_layout_a( + PdO_tiled_mma, + self.PdO_mma_tiler, + Q.element_type, + self.compute_mma_P_stage, + ) + dOT_smem_layout_staged = sm100_utils.make_smem_layout_b( + PdO_tiled_mma, + self.PdO_mma_tiler, + dO.element_type, + self.load_mma_dO_stage, + ) + LSE_smem_layout = cute.make_layout((self.cta_tiler[0], self.load_compute_LSE_stage)) + sum_OdO_smem_layout = cute.make_layout((self.cta_tiler[0], self.load_compute_sum_OdO_stage)) + + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) + tma_reduce_op = cpasync.CopyReduceBulkTensorTileS2GOp() + + K_smem_layout = cute.select(K_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_K, tma_tensor_K = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + K, + K_smem_layout, + self.KQ_mma_tiler, + KQ_tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + V_smem_layout = cute.select(V_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_V, tma_tensor_V = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + V, + V_smem_layout, + self.VdO_mma_tiler, + VdO_tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + Q_smem_layout = cute.select(Q_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_Q, tma_tensor_Q = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + Q, + Q_smem_layout, + self.KQ_mma_tiler, + KQ_tiled_mma, + self.cluster_layout_vmnk.shape, + ) + QT_smem_layout = cute.select(QT_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_QT, tma_tensor_QT = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + QT, + QT_smem_layout, + self.dSQ_mma_tiler, + dSQ_tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + dO_smem_layout = cute.select(dO_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_dO, tma_tensor_dO = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + dO, + dO_smem_layout, + self.VdO_mma_tiler, + VdO_tiled_mma, + self.cluster_layout_vmnk.shape, + ) + dOT_smem_layout = cute.select(dOT_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_dOT, tma_tensor_dOT = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + dOT, + dOT_smem_layout, + self.PdO_mma_tiler, + PdO_tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # for 2cta, tma_copy_QT_bytes is same as the tma_copy_Q_bytes + self.tma_copy_Q_bytes = cute.size_in_bytes(Q.element_type, Q_smem_layout) * atom_thr_size + self.tma_copy_K_bytes = cute.size_in_bytes(K.element_type, K_smem_layout) * atom_thr_size + self.tma_copy_V_bytes = cute.size_in_bytes(V.element_type, V_smem_layout) * atom_thr_size + self.tma_copy_dO_bytes = cute.size_in_bytes(dO.element_type, dO_smem_layout) * atom_thr_size + + # Variant 3a epilogue: TMA store atoms (S2G) for dK / dV. + # Each compute warp group owns half the hd_v output via split_wg, so + # the per-WG epi tile is (cta_tiler[1], cta_tiler[2] / num_compute_wgs). + # SMEM staging will alias onto sP+sdST in subsequent commits; for now + # the atoms and layouts are built and threaded through but unused. + tma_store_op = cpasync.CopyBulkTensorTileS2GOp() + num_compute_wgs = self.num_compute_warps // 4 + # Variant 3a Path 2: CTA-shared epilogue SMEM. epi_tile is (M, gcd(128B, ...)) + # = (M, 64) for bf16 hd=256. Total stages = num_compute_wgs * num_epi_stages + # = 4 stages of (64, 64), virtually a per-CTA (64, 256) buffer aliased onto + # sP+sdST. Both warp-groups cooperatively populate this buffer; TMA fires + # one (64, 64) box per stage to the corresponding (64, 64) GMEM slice. + epi_cols_dKV = math.gcd( + 128 // (dK.element_type.width // 8), self.cta_tiler[2] // num_compute_wgs + ) + num_epi_stages_dKV = (self.cta_tiler[2] // num_compute_wgs) // epi_cols_dKV + epi_tile_dKV = (self.cta_tiler[1], epi_cols_dKV) + total_epi_stages = num_compute_wgs * num_epi_stages_dKV + dK_layout_enum = utils.LayoutEnum.from_tensor(dK) + dV_layout_enum = utils.LayoutEnum.from_tensor(dV) + sdK_epi_layout = sm100_utils.make_smem_layout_epi( + dK.element_type, + dK_layout_enum, + epi_tile_dKV, + total_epi_stages, + ) + sdV_epi_layout = sm100_utils.make_smem_layout_epi( + dV.element_type, + dV_layout_enum, + epi_tile_dKV, + total_epi_stages, + ) + tma_atom_dK, tma_tensor_dK = cpasync.make_tiled_tma_atom( + tma_store_op, + dK, + cute.select(sdK_epi_layout, mode=[0, 1]), + epi_tile_dKV, + ) + tma_atom_dV, tma_tensor_dV = cpasync.make_tiled_tma_atom( + tma_store_op, + dV, + cute.select(sdV_epi_layout, mode=[0, 1]), + epi_tile_dKV, + ) + + @cute.struct + class SharedStorage: + # Pipeline barriers + load_mma_Q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.load_mma_Q_stage * 2] + load_mma_K_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.load_mma_K_stage * 2] + load_mma_V_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.load_mma_V_stage * 2] + load_mma_QT_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.load_mma_QT_stage * 2] + load_mma_dO_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.load_mma_dO_stage * 2] + load_mma_dOT_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.load_mma_dO_stage * 2] + load_compute_lse_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.load_compute_LSE_stage * 2 + ] + load_compute_sum_OdO_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.load_compute_sum_OdO_stage * 2 + ] + mma_compute_S_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.mma_compute_S_stage * 2 + ] + mma_compute_dP_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.mma_compute_dP_stage * 2 + ] + compute_mma_P_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.compute_mma_P_stage * 2 + ] + compute_mma_dS_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.compute_mma_dS_stage * 2 + ] + mma_compute_dKdV_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.mma_compute_dKdV_stage * 2 + ] + tmem_holding_buf: cutlass.Int32 + tmem_dealloc_mbar_ptr: cutlass.Int64 + clc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2] + clc_response: cute.struct.MemRange[Int32, 4] + # Smem tensors + sK: cute.struct.Align[ + cute.struct.MemRange[K.element_type, cute.cosize(K_smem_layout_staged)], + self.buffer_align_bytes, + ] + # only used in 2cta + sV: cute.struct.Align[ + cute.struct.MemRange[V.element_type, cute.cosize(V_smem_layout_staged)], + self.buffer_align_bytes, + ] + sQ: cute.struct.Align[ + cute.struct.MemRange[Q.element_type, cute.cosize(Q_smem_layout_staged)], + self.buffer_align_bytes, + ] + sQT: cute.struct.Align[ + cute.struct.MemRange[Q.element_type, cute.cosize(QT_smem_layout_staged)], + self.buffer_align_bytes, + ] + sdO: cute.struct.Align[ + cute.struct.MemRange[dO.element_type, cute.cosize(dO_smem_layout_staged)], + self.buffer_align_bytes, + ] + sdOT: cute.struct.Align[ + cute.struct.MemRange[dO.element_type, cute.cosize(dOT_smem_layout_staged)], + self.buffer_align_bytes, + ] + # only used in 2cta + # dishengbin checked whether we need sP + sP: cute.struct.Align[ + cute.struct.MemRange[Q.element_type, cute.cosize(P_smem_layout_staged)], + self.buffer_align_bytes, + ] + sdST: cute.struct.Align[ + cute.struct.MemRange[Q.element_type, cute.cosize(dST_smem_layout_staged)], + self.buffer_align_bytes, + ] + + sLSE: cute.struct.Align[ + cute.struct.MemRange[self.acc_dtype, cute.cosize(LSE_smem_layout)], + self.buffer_align_bytes, + ] + sSum_OdO: cute.struct.Align[ + cute.struct.MemRange[self.acc_dtype, cute.cosize(sum_OdO_smem_layout)], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # =============================== bwd =============================== + K_val = problem_shape[1] + _, H_K = problem_shape[3][0] + B = problem_shape[3][1] + problem_shape_mbh = ( + cute.ceil_div(K_val, self.cta_tiler[1]), + cute.size(B), + cute.size(H_K), + ) + if cutlass.const_expr(self.use_clc_scheduler): + self.tile_sched_params = FmhaClcDynamicTileSchedulerParams( + problem_shape_mbh, + (*self.cluster_shape_mn, 1), + ) + bwd_grid = FmhaClcDynamicTileScheduler.get_grid_shape(self.tile_sched_params) + else: + self.tile_sched_params = FmhaStaticTileSchedulerParams( + is_persistent=False, + problem_shape_mbh=problem_shape_mbh, + ) + bwd_grid = self._compute_bwd_grid(problem_shape, self.cta_tiler[1]) + bwd_grid = cute.round_up(bwd_grid, self.cluster_shape_mnk) + + self.dkdv_bwd( + KQ_tiled_mma, + VdO_tiled_mma, + PdO_tiled_mma, + dSQ_tiled_mma, + tma_atom_K, + tma_tensor_K, + K, + tma_atom_V, + tma_tensor_V, + tma_atom_Q, + tma_tensor_Q, + Q, + tma_atom_QT, + tma_tensor_QT, + tma_atom_dO, + tma_tensor_dO, + tma_atom_dOT, + tma_tensor_dOT, + dK, + dV, + tma_atom_dK, + tma_tensor_dK, + tma_atom_dV, + tma_tensor_dV, + scaled_LSE, + scale_softmax, + sum_OdO, + problem_shape, + cumulative_s_q, + cumulative_s_k, + self.cluster_layout_vmnk, + K_smem_layout_staged, + Q_smem_layout_staged, + V_smem_layout_staged, + dO_smem_layout_staged, + dS_smem_layout_staged, + dST_smem_layout_staged, + QT_smem_layout_staged, + dOT_smem_layout_staged, + P_smem_layout_staged, + LSE_smem_layout, + sum_OdO_smem_layout, + sdK_epi_layout, + sdV_epi_layout, + self.tile_sched_params, + ).launch( + grid=bwd_grid, + block=[self.threads_per_cta, 1, 1], + cluster=self.cluster_shape_mnk, + smem=self.shared_storage.size_in_bytes(), # type: ignore [attr-defined] + stream=stream, + min_blocks_per_mp=1, + ) + + @cute.kernel + def dkdv_bwd( + self, + KQ_tiled_mma: cute.TiledMma, + VdO_tiled_mma: cute.TiledMma, + PdO_tiled_mma: cute.TiledMma, + dSQ_tiled_mma: cute.TiledMma, + tma_atom_K: cute.CopyAtom, + K_in: cute.Tensor, + K_ref: cute.Tensor, + tma_atom_V: cute.CopyAtom, + V_in: cute.Tensor, + tma_atom_Q: cute.CopyAtom, + Q_in: cute.Tensor, + Q_ref: cute.Tensor, + tma_atom_QT: cute.CopyAtom, + QT_in: cute.Tensor, + tma_atom_dO: cute.CopyAtom, + dO_in: cute.Tensor, + tma_atom_dOT: cute.CopyAtom, + dOT_in: cute.Tensor, + dK: cute.Tensor, + dV: cute.Tensor, + tma_atom_dK: cute.CopyAtom, + dK_tma: cute.Tensor, + tma_atom_dV: cute.CopyAtom, + dV_tma: cute.Tensor, + LSE: cute.Tensor, + scale_softmax: cutlass.Float32, + sum_OdO: cute.Tensor, + problem_shape: tuple[Int32, Int32, Int32, tuple[tuple[Int32, Int32], Int32]], + cumulative_s_q: cute.Tensor | None, + cumulative_s_k: cute.Tensor | None, + cluster_layout_vmnk: cute.Layout, + K_smem_layout_staged: cute.ComposedLayout, + Q_smem_layout_staged: cute.ComposedLayout, + V_smem_layout_staged: cute.ComposedLayout, + dO_smem_layout_staged: cute.ComposedLayout, + dS_smem_layout_staged: cute.ComposedLayout, + dST_smem_layout_staged: cute.ComposedLayout, + QT_smem_layout_staged: cute.ComposedLayout, + dOT_smem_layout_staged: cute.ComposedLayout, + P_smem_layout_staged: cute.ComposedLayout, + LSE_smem_layout: cute.Layout, + sum_OdO_smem_layout: cute.Layout, + sdK_epi_layout: cute.ComposedLayout, + sdV_epi_layout: cute.ComposedLayout, + tile_sched_params: FmhaStaticTileSchedulerParams | FmhaClcDynamicTileSchedulerParams, + ): + """Core CuTeDSL backward kernel.""" + bidx, bidy, bidz = cute.arch.block_idx() + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + varlen = cumulative_s_q is not None or cumulative_s_k is not None + + mma_tile_coord_v = bidx % cute.size(KQ_tiled_mma.thr_id.shape) + + if warp_idx == self.load_warp_id: + cpasync.prefetch_descriptor(tma_atom_K) + cpasync.prefetch_descriptor(tma_atom_Q) + cpasync.prefetch_descriptor(tma_atom_QT) + cpasync.prefetch_descriptor(tma_atom_V) + cpasync.prefetch_descriptor(tma_atom_dO) + cpasync.prefetch_descriptor(tma_atom_dOT) + + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + load_mma_Q_producer, load_mma_Q_consumer = pipeline.PipelineTmaUmma.create( + num_stages=self.load_mma_Q_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + tx_count=self.tma_copy_Q_bytes, + barrier_storage=storage.load_mma_Q_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + load_mma_K_producer, load_mma_K_consumer = pipeline.PipelineTmaUmma.create( + num_stages=self.load_mma_K_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + tx_count=self.tma_copy_K_bytes, + barrier_storage=storage.load_mma_K_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + load_mma_V_producer, load_mma_V_consumer = pipeline.PipelineTmaUmma.create( + num_stages=self.load_mma_V_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + tx_count=self.tma_copy_V_bytes, + barrier_storage=storage.load_mma_V_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + load_mma_QT_producer, load_mma_QT_consumer = pipeline.PipelineTmaUmma.create( + num_stages=self.load_mma_QT_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + tx_count=self.tma_copy_Q_bytes, + barrier_storage=storage.load_mma_QT_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + load_mma_dO_producer, load_mma_dO_consumer = pipeline.PipelineTmaUmma.create( + num_stages=self.load_mma_dO_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + tx_count=self.tma_copy_dO_bytes, + barrier_storage=storage.load_mma_dO_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + load_mma_dOT_producer, load_mma_dOT_consumer = pipeline.PipelineTmaUmma.create( + num_stages=self.load_mma_dO_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + tx_count=self.tma_copy_dO_bytes, + barrier_storage=storage.load_mma_dOT_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + load_compute_LSE_producer, load_compute_LSE_consumer = pipeline.PipelineCpAsync.create( + num_stages=self.load_compute_LSE_stage, + producer_group=make_thread_cooperative_group(self.threads_per_warp), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * self.num_compute_warps + ), + barrier_storage=storage.load_compute_lse_mbar_ptr.data_ptr(), + ).make_participants() + load_compute_sum_OdO_producer, load_compute_sum_OdO_consumer = ( + pipeline.PipelineCpAsync.create( + num_stages=self.load_compute_sum_OdO_stage, + producer_group=make_thread_cooperative_group(self.threads_per_warp), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * self.num_compute_warps + ), + barrier_storage=storage.load_compute_sum_OdO_mbar_ptr.data_ptr(), + ).make_participants() + ) + mma_compute_S_producer, mma_compute_S_consumer = pipeline.PipelineUmmaAsync.create( + num_stages=self.mma_compute_S_stage, + producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + consumer_group=make_thread_cooperative_group( + self.num_compute_warps * self.threads_per_warp * cluster_layout_vmnk.shape[0][0] + ), + barrier_storage=storage.mma_compute_S_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + mma_compute_dP_producer, mma_compute_dP_consumer = pipeline.PipelineUmmaAsync.create( + num_stages=self.mma_compute_dP_stage, + producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + consumer_group=make_thread_cooperative_group( + self.num_compute_warps * self.threads_per_warp * cluster_layout_vmnk.shape[0][0] + ), + barrier_storage=storage.mma_compute_dP_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + compute_mma_P_producer, compute_mma_P_consumer = pipeline.PipelineAsyncUmma.create( + num_stages=self.compute_mma_P_stage, + producer_group=make_thread_cooperative_group( + self.num_compute_warps * self.threads_per_warp * cluster_layout_vmnk.shape[0][0] + ), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + barrier_storage=storage.compute_mma_P_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + compute_mma_dS_producer, compute_mma_dS_consumer = pipeline.PipelineAsyncUmma.create( + num_stages=self.compute_mma_dS_stage, + producer_group=make_thread_cooperative_group( + self.num_compute_warps * self.threads_per_warp * cluster_layout_vmnk.shape[0][0] + ), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + barrier_storage=storage.compute_mma_dS_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + mma_compute_dKdV_producer, mma_compute_dKdV_consumer = pipeline.PipelineUmmaAsync.create( + num_stages=self.mma_compute_dKdV_stage, + producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + consumer_group=make_thread_cooperative_group( + self.num_compute_warps * self.threads_per_warp * cluster_layout_vmnk.shape[0][0] + ), + barrier_storage=storage.mma_compute_dKdV_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + + cute.arch.barrier(barrier_id=self.cta_sync_bar_id, number_of_threads=self.threads_per_cta) + + # setup mma + sQ = storage.sQ.get_tensor(Q_smem_layout_staged.outer, swizzle=Q_smem_layout_staged.inner) + sK = storage.sK.get_tensor(K_smem_layout_staged.outer, swizzle=K_smem_layout_staged.inner) + sV = storage.sV.get_tensor(V_smem_layout_staged.outer, swizzle=V_smem_layout_staged.inner) + sdO = storage.sdO.get_tensor( + dO_smem_layout_staged.outer, swizzle=dO_smem_layout_staged.inner + ) + sLSE = storage.sLSE.get_tensor(LSE_smem_layout) + sSum_OdO = storage.sSum_OdO.get_tensor(sum_OdO_smem_layout) + tmem_holding_buf = storage.tmem_holding_buf + # for 2cta, QT use different mem from Q + + sQT = storage.sQT.get_tensor( + QT_smem_layout_staged.outer, swizzle=QT_smem_layout_staged.inner + ) + sdST = storage.sdST.get_tensor( + dST_smem_layout_staged.outer, swizzle=dST_smem_layout_staged.inner + ) + tP_fake_ptr = cute.make_ptr(sQ.element_type, 0, cute.AddressSpace.tmem) + tP = cute.make_tensor(tP_fake_ptr, P_smem_layout_staged.outer) + + sP = storage.sP.get_tensor(P_smem_layout_staged.outer, swizzle=P_smem_layout_staged.inner) + + sdOT = storage.sdOT.get_tensor( + dOT_smem_layout_staged.outer, swizzle=dOT_smem_layout_staged.inner + ) + + # tSTrK shape : (MMA, MMA_M, MMA_K, STAGE) + tSTrK = KQ_tiled_mma.make_fragment_A(sK) + # tSTrQ shape : (MMA, MMA_N, MMA_K, STAGE) + tSTrQ = KQ_tiled_mma.make_fragment_B(sQ) + + # tdPTrV shape : (MMA, MMA_M, MMA_K, STAGE) + tdPTrV = VdO_tiled_mma.make_fragment_A(sV) + # tdPTrdO shape : (MMA, MMA_N, MMA_K, STAGE) + tdPTrdO = VdO_tiled_mma.make_fragment_B(sdO) + + # tdKrdST shape: (MMA, MMA_M, MMA_K, STAGE) + tdKrdST = dSQ_tiled_mma.make_fragment_A(sdST) + # tdKrQT shape : (MMA, MMA_N, MMA_K, STAGE) + tdKrQT = dSQ_tiled_mma.make_fragment_B(sQT) + + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=self.tmem_alloc_sync_bar_id, + num_threads=self.threads_per_cta, + ) + + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=tmem_alloc_barrier, + allocator_warp_id=self.load_warp_id, + is_two_cta=True, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + + tmem.allocate(self.tmem_alloc_cols) + + # wait for tmem allocation and retrieve the pointer + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + + # Cluster arrive after barrier init + # is_relaxed=False has memory consistency guarantee + pipeline_init_arrive(cluster_shape_mn=cluster_layout_vmnk, is_relaxed=False) + + if cutlass.const_expr(self.use_clc_scheduler): + clc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + cluster_size = cute.size(self.cluster_shape_mnk) + num_clc_consumer_threads = self.threads_per_warp * ( + 1 # sched_warp (CTA 0 only) + + cluster_size + * ( + len(self.compute_warp_id) + + 1 # mma_warp + + 1 # load_warp + ) + ) + clc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_clc_consumer_threads + ) + clc_response_ptr = storage.clc_response.data_ptr() + clc = ClcState.create( + hw_scheduler=ClcDynamicPersistentTileScheduler.create( + self.tile_sched_params.clc_hw_params(), + cute.arch.block_idx(), + cute.arch.grid_dim(), + clc_response_ptr, + ), + pipeline=pipeline.PipelineClcFetchAsync.create( + barrier_storage=storage.clc_mbar_ptr.data_ptr(), + num_stages=self.num_clc_stage, + producer_group=clc_pipeline_producer_group, + consumer_group=clc_pipeline_consumer_group, + tx_count=self.num_clc_response_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ), + consumer_state=pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_clc_stage + ), + producer_state=pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_clc_stage + ), + ) + tile_sched = FmhaClcDynamicTileScheduler.create( + tile_sched_params, + cute.arch.block_idx(), + cute.arch.grid_dim(), + clc_response_ptr, + clc, + ) + work_tile = tile_sched.initial_work_tile_info() + else: + clc = None + clc_response_ptr = None + + tSTtST_shape = KQ_tiled_mma.partition_shape_C(cute.select(self.KQ_mma_tiler, mode=[0, 1])) + tSTtST = KQ_tiled_mma.make_fragment_C(tSTtST_shape) + # tSTtST shape : (MMA, MMA_M, MMA_N) + tSTtST = cute.make_tensor(tmem_ptr + self.tmem_S_offset, tSTtST.layout) + + # tdVrP shape : (MMA, MMA_M, MMA_K, STAGE) + tdVrP = PdO_tiled_mma.make_fragment_A(sP) + # tdVrdOT shape : (MMA, MMA_N, MMA_K, STAGE) + tdVrdOT = PdO_tiled_mma.make_fragment_B(sdOT) + + tdPTtdPT_shape = VdO_tiled_mma.partition_shape_C( + cute.select(self.VdO_mma_tiler, mode=[0, 1]) + ) + tdPTtdPT = VdO_tiled_mma.make_fragment_C(tdPTtdPT_shape) + # tdPTtdPT shape : (MMA, MMA_M, MMA_N) + tdPTtdPT = cute.make_tensor(tmem_ptr + self.tmem_dP_offset, tdPTtdPT.layout) + + tdKtdK_shape = dSQ_tiled_mma.partition_shape_C(cute.select(self.dSQ_mma_tiler, mode=[0, 1])) + tdKtdK = dSQ_tiled_mma.make_fragment_C(tdKtdK_shape) + # tdKtdK shape : (MMA, MMA_M, MMA_N) + tdKtdK = cute.make_tensor(tmem_ptr + self.tmem_dK_offset, tdKtdK.layout) + + tdVtdV_shape = PdO_tiled_mma.partition_shape_C(cute.select(self.PdO_mma_tiler, mode=[0, 1])) + tdVtdV = PdO_tiled_mma.make_fragment_C(tdVtdV_shape) + # tdVtdV shape : (MMA, MMA_M, MMA_N) + tdVtdV = cute.make_tensor(tmem_ptr + self.tmem_dV_offset, tdVtdV.layout) + + # get the current batch problem shape + + if cutlass.const_expr(self.use_clc_scheduler): + # ===== CLC PERSISTENT PATH: per-warp while loops ===== + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + is_first_cta_in_cluster = cta_rank_in_cluster == 0 + is_sched_warp = warp_idx == self.sched_warp_id and is_first_cta_in_cluster + + # Register allocation ONCE (before any loop) + if warp_idx == self.load_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_load) + elif warp_idx == self.mma_warp_id: + cute.arch.warpgroup_reg_alloc(self.num_regs_mma) + elif warp_idx >= self.compute_warp_id[0] and warp_idx <= self.compute_warp_id[-1]: + cute.arch.warpgroup_reg_alloc(self.num_regs_compute) + else: + cute.arch.warpgroup_reg_dealloc(self.num_regs_empty) + + # Cluster wait + pipeline_init_wait(cluster_shape_mn=cluster_layout_vmnk) + + # ===== SCHEDULER WARP ===== + if is_sched_warp: + while work_tile.is_valid_tile: + tile_sched.prefetch_next_work() + work_tile = tile_sched.advance_to_next_work() + tile_sched.producer_tail() + + # ===== LOAD WARP ===== + elif warp_idx == self.load_warp_id: + while work_tile.is_valid_tile: + # Decode coordinates from CLC work tile + blk_coord_k = work_tile.tile_idx[0] + blk_coord_b = work_tile.tile_idx[2][0] + blk_coord_h_k = work_tile.tile_idx[2][1] + blk_coord = ( + Int32(0), + blk_coord_k, + Int32(0), + ((Int32(0), blk_coord_h_k), blk_coord_b), + ) + seqlen_q_cur_batch = Q_ref.shape[0] + seqlen_k_cur_batch = K_ref.shape[0] + blk_offset = (Int32(0), Int32(0), Int32(0), ((Int32(0), Int32(0)), Int32(0))) + if cutlass.const_expr(varlen): + assert isinstance(cumulative_s_q, cute.Tensor) + assert isinstance(cumulative_s_k, cute.Tensor) + seqlen_q_cur_batch = ( + cumulative_s_q[blk_coord_b + 1] - cumulative_s_q[blk_coord_b] + ) + seqlen_k_cur_batch = ( + cumulative_s_k[blk_coord_b + 1] - cumulative_s_k[blk_coord_b] + ) + blk_offset = ( + cumulative_s_q[blk_coord_b], + cumulative_s_k[blk_coord_b], + Int32(0), + ((Int32(0), Int32(0)), Int32(0)), + ) + iter_start, iter_end = self.get_Q_block_min_max( + seqlen_q_cur_batch, + seqlen_k_cur_batch, + blk_coord_k, + is_2cta=True, + ) + iter_count = (iter_end - iter_start) * problem_shape[3][0][0] + if iter_count <= 0: + if blk_coord_k * self.tile_shape_K < seqlen_k_cur_batch: + problem_shape_cur_batch = ( + seqlen_q_cur_batch, + seqlen_k_cur_batch, + problem_shape[2], + problem_shape[3], + ) + self.epilogue_clear( + blk_coord, blk_offset, problem_shape_cur_batch, dK, dV + ) + else: + problem_shape_cur_batch = ( + seqlen_q_cur_batch, + seqlen_k_cur_batch, + problem_shape[2], + problem_shape[3], + ) + ( + load_mma_Q_producer, + load_mma_K_producer, + load_mma_V_producer, + load_compute_LSE_producer, + load_mma_dO_producer, + load_mma_dOT_producer, + load_compute_sum_OdO_producer, + load_mma_QT_producer, + ) = self.load( + K_in, + V_in, + Q_in, + QT_in, + dO_in, + dOT_in, + LSE, + sum_OdO, + sK, + sQ, + sQT, + sV, + sdO, + sdOT, + sLSE, + sSum_OdO, + KQ_tiled_mma, + VdO_tiled_mma, + PdO_tiled_mma, + dSQ_tiled_mma, + tma_atom_K, + tma_atom_Q, + tma_atom_QT, + tma_atom_V, + tma_atom_dO, + tma_atom_dOT, + blk_offset, + problem_shape_cur_batch, + varlen, + iter_count, + iter_start, + iter_end, + load_mma_Q_producer, + load_mma_Q_consumer, + load_mma_K_producer, + load_mma_K_consumer, + load_mma_V_producer, + load_mma_V_consumer, + load_compute_LSE_producer, + load_compute_LSE_consumer, + load_mma_dO_producer, + load_mma_dO_consumer, + load_mma_dOT_producer, + load_mma_dOT_consumer, + load_compute_sum_OdO_producer, + load_compute_sum_OdO_consumer, + load_mma_QT_producer, + load_mma_QT_consumer, + blk_coord_k, + blk_coord_h_k, + blk_coord_b, + ) + # CLC advance + work_tile = tile_sched.advance_to_next_work() + # producer_tail after loop + load_mma_K_producer.tail() + load_mma_V_producer.tail() + load_mma_Q_producer.tail() + load_compute_LSE_producer.tail() + load_mma_dO_producer.tail() + load_mma_dOT_producer.tail() + load_compute_sum_OdO_producer.tail() + load_mma_QT_producer.tail() + + # ===== MMA WARP ===== + elif warp_idx == self.mma_warp_id: + while work_tile.is_valid_tile: + blk_coord_k = work_tile.tile_idx[0] + blk_coord_b = work_tile.tile_idx[2][0] + blk_coord_h_k = work_tile.tile_idx[2][1] + blk_coord = ( + Int32(0), + blk_coord_k, + Int32(0), + ((Int32(0), blk_coord_h_k), blk_coord_b), + ) + seqlen_q_cur_batch = Q_ref.shape[0] + seqlen_k_cur_batch = K_ref.shape[0] + blk_offset = (Int32(0), Int32(0), Int32(0), ((Int32(0), Int32(0)), Int32(0))) + if cutlass.const_expr(varlen): + assert isinstance(cumulative_s_q, cute.Tensor) + assert isinstance(cumulative_s_k, cute.Tensor) + seqlen_q_cur_batch = ( + cumulative_s_q[blk_coord_b + 1] - cumulative_s_q[blk_coord_b] + ) + seqlen_k_cur_batch = ( + cumulative_s_k[blk_coord_b + 1] - cumulative_s_k[blk_coord_b] + ) + blk_offset = ( + cumulative_s_q[blk_coord_b], + cumulative_s_k[blk_coord_b], + Int32(0), + ((Int32(0), Int32(0)), Int32(0)), + ) + iter_start, iter_end = self.get_Q_block_min_max( + seqlen_q_cur_batch, + seqlen_k_cur_batch, + blk_coord_k, + is_2cta=True, + ) + iter_count = (iter_end - iter_start) * problem_shape[3][0][0] + if iter_count <= 0: + if blk_coord_k * self.tile_shape_K < seqlen_k_cur_batch: + problem_shape_cur_batch = ( + seqlen_q_cur_batch, + seqlen_k_cur_batch, + problem_shape[2], + problem_shape[3], + ) + self.epilogue_clear( + blk_coord, blk_offset, problem_shape_cur_batch, dK, dV + ) + else: + ( + mma_compute_S_producer, + mma_compute_dP_producer, + mma_compute_dKdV_producer, + load_mma_Q_consumer, + load_mma_K_consumer, + load_mma_V_consumer, + load_mma_dO_consumer, + load_mma_dOT_consumer, + compute_mma_P_consumer, + compute_mma_dS_consumer, + load_mma_QT_consumer, + ) = self.mma_2cta( + KQ_tiled_mma, + VdO_tiled_mma, + PdO_tiled_mma, + dSQ_tiled_mma, + tSTtST, + tSTrQ, + tSTrK, + tdPTtdPT, + tdPTrV, + tdPTrdO, + tdVtdV, + tdVrP, + tdVrdOT, + tdKrdST, + tdKtdK, + tdKrQT, + iter_count, + load_mma_Q_consumer, + load_mma_K_consumer, + load_mma_V_consumer, + mma_compute_S_producer, + load_mma_dO_consumer, + mma_compute_dP_producer, + load_mma_dOT_consumer, + compute_mma_P_consumer, + compute_mma_dS_consumer, + load_mma_QT_consumer, + mma_compute_dKdV_producer, + ) + # CLC advance + work_tile = tile_sched.advance_to_next_work() + # producer_tail after loop + mma_compute_S_producer.tail() + mma_compute_dP_producer.tail() + mma_compute_dKdV_producer.tail() + + # ===== COMPUTE WARPS ===== + elif warp_idx >= self.compute_warp_id[0] and warp_idx <= self.compute_warp_id[-1]: + while work_tile.is_valid_tile: + blk_coord_k = work_tile.tile_idx[0] + blk_coord_b = work_tile.tile_idx[2][0] + blk_coord_h_k = work_tile.tile_idx[2][1] + blk_coord = ( + Int32(0), + blk_coord_k, + Int32(0), + ((Int32(0), blk_coord_h_k), blk_coord_b), + ) + seqlen_q_cur_batch = Q_ref.shape[0] + seqlen_k_cur_batch = K_ref.shape[0] + blk_offset = (Int32(0), Int32(0), Int32(0), ((Int32(0), Int32(0)), Int32(0))) + if cutlass.const_expr(varlen): + assert isinstance(cumulative_s_q, cute.Tensor) + assert isinstance(cumulative_s_k, cute.Tensor) + seqlen_q_cur_batch = ( + cumulative_s_q[blk_coord_b + 1] - cumulative_s_q[blk_coord_b] + ) + seqlen_k_cur_batch = ( + cumulative_s_k[blk_coord_b + 1] - cumulative_s_k[blk_coord_b] + ) + blk_offset = ( + cumulative_s_q[blk_coord_b], + cumulative_s_k[blk_coord_b], + Int32(0), + ((Int32(0), Int32(0)), Int32(0)), + ) + iter_start, iter_end = self.get_Q_block_min_max( + seqlen_q_cur_batch, + seqlen_k_cur_batch, + blk_coord_k, + is_2cta=True, + ) + iter_count = (iter_end - iter_start) * problem_shape[3][0][0] + if iter_count <= 0: + if blk_coord_k * self.tile_shape_K < seqlen_k_cur_batch: + problem_shape_cur_batch = ( + seqlen_q_cur_batch, + seqlen_k_cur_batch, + problem_shape[2], + problem_shape[3], + ) + self.epilogue_clear( + blk_coord, blk_offset, problem_shape_cur_batch, dK, dV + ) + else: + problem_shape_cur_batch = ( + seqlen_q_cur_batch, + seqlen_k_cur_batch, + problem_shape[2], + problem_shape[3], + ) + ( + compute_mma_P_producer, + compute_mma_dS_producer, + mma_compute_S_consumer, + compute_mma_P_consumer, + load_compute_LSE_consumer, + load_compute_sum_OdO_consumer, + mma_compute_dP_consumer, + compute_mma_dS_consumer, + mma_compute_dKdV_consumer, + ) = self.compute( + tSTtST, + tdPTtdPT, + tdVrP, + sP, + sLSE, + sdST, + sdOT, + sSum_OdO, + dK, + dV, + tdKtdK, + tdVtdV, + PdO_tiled_mma, + dSQ_tiled_mma, + blk_coord, + blk_offset, + problem_shape_cur_batch, + iter_count, + iter_start, + iter_end, + scale_softmax, + mma_compute_S_producer, + mma_compute_S_consumer, + compute_mma_P_producer, + compute_mma_P_consumer, + load_compute_LSE_producer, + load_compute_LSE_consumer, + load_compute_sum_OdO_producer, + load_compute_sum_OdO_consumer, + mma_compute_dP_producer, + mma_compute_dP_consumer, + compute_mma_dS_producer, + compute_mma_dS_consumer, + mma_compute_dKdV_producer, + mma_compute_dKdV_consumer, + varlen, + sK, + seqlen_k_cur_batch, + tma_atom_dK, + dK_tma, + tma_atom_dV, + dV_tma, + sdK_epi_layout, + sdV_epi_layout, + ) + cute.arch.barrier( + barrier_id=self.epilogue_sync_bar_id, + number_of_threads=self.num_compute_warps * self.threads_per_warp, + ) + # CLC advance + work_tile = tile_sched.advance_to_next_work() + # producer_tail after loop + compute_mma_P_producer.tail() + compute_mma_dS_producer.tail() + + else: + # ===== STATIC PATH: original non-persistent code ===== + blk_coord = (Int32(0), bidx, Int32(0), ((Int32(0), bidy), bidz)) + seqlen_q_cur_batch = Q_ref.shape[0] + seqlen_k_cur_batch = K_ref.shape[0] + blk_offset = (Int32(0), Int32(0), Int32(0), ((Int32(0), Int32(0)), Int32(0))) + if cutlass.const_expr(varlen): + assert isinstance(cumulative_s_q, cute.Tensor) + assert isinstance(cumulative_s_k, cute.Tensor) + seqlen_q_cur_batch = cumulative_s_q[bidz + 1] - cumulative_s_q[bidz] + seqlen_k_cur_batch = cumulative_s_k[bidz + 1] - cumulative_s_k[bidz] + blk_offset = ( + cumulative_s_q[bidz], + cumulative_s_k[bidz], + Int32(0), + ((Int32(0), Int32(0)), Int32(0)), + ) + + iter_start, iter_end = self.get_Q_block_min_max( + seqlen_q_cur_batch, + seqlen_k_cur_batch, + blk_coord[1], + is_2cta=True, + ) + + # Cluster wait + pipeline_init_wait(cluster_shape_mn=cluster_layout_vmnk) + + iter_count = (iter_end - iter_start) * problem_shape[3][0][0] + problem_shape_cur_batch = ( + seqlen_q_cur_batch, + seqlen_k_cur_batch, + problem_shape[2], + problem_shape[3], + ) + if iter_count <= 0: + if bidx * self.tile_shape_K < seqlen_k_cur_batch: + self.epilogue_clear( + blk_coord, + blk_offset, + problem_shape_cur_batch, + dK, + dV, + ) + # /////////////////////////////////////////////////////////////////////////////// + # LOAD + # /////////////////////////////////////////////////////////////////////////////// + elif warp_idx == self.load_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_load) + + self.load( + K_in, + V_in, + Q_in, + QT_in, + dO_in, + dOT_in, + LSE, + sum_OdO, + sK, + sQ, + sQT, + sV, + sdO, + sdOT, + sLSE, + sSum_OdO, + KQ_tiled_mma, + VdO_tiled_mma, + PdO_tiled_mma, + dSQ_tiled_mma, + tma_atom_K, + tma_atom_Q, + tma_atom_QT, + tma_atom_V, + tma_atom_dO, + tma_atom_dOT, + blk_offset, + problem_shape_cur_batch, + varlen, + iter_count, + iter_start, + iter_end, + load_mma_Q_producer, + load_mma_Q_consumer, + load_mma_K_producer, + load_mma_K_consumer, + load_mma_V_producer, + load_mma_V_consumer, + load_compute_LSE_producer, + load_compute_LSE_consumer, + load_mma_dO_producer, + load_mma_dO_consumer, + load_mma_dOT_producer, + load_mma_dOT_consumer, + load_compute_sum_OdO_producer, + load_compute_sum_OdO_consumer, + load_mma_QT_producer, + load_mma_QT_consumer, + ) + + # /////////////////////////////////////////////////////////////////////////////// + # MMA + # /////////////////////////////////////////////////////////////////////////////// + elif warp_idx == self.mma_warp_id: + cute.arch.warpgroup_reg_alloc(self.num_regs_mma) + + self.mma_2cta( + KQ_tiled_mma, + VdO_tiled_mma, + PdO_tiled_mma, + dSQ_tiled_mma, + tSTtST, + tSTrQ, + tSTrK, + tdPTtdPT, + tdPTrV, + tdPTrdO, + tdVtdV, + tdVrP, + tdVrdOT, + tdKrdST, + tdKtdK, + tdKrQT, + iter_count, + load_mma_Q_consumer, + load_mma_K_consumer, + load_mma_V_consumer, + mma_compute_S_producer, + load_mma_dO_consumer, + mma_compute_dP_producer, + load_mma_dOT_consumer, + compute_mma_P_consumer, + compute_mma_dS_consumer, + load_mma_QT_consumer, + mma_compute_dKdV_producer, + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Compute + # /////////////////////////////////////////////////////////////////////////////// + elif warp_idx >= self.compute_warp_id[0] and warp_idx <= self.compute_warp_id[-1]: + cute.arch.warpgroup_reg_alloc(self.num_regs_compute) + + self.compute( + tSTtST, + tdPTtdPT, + tdVrP, + sP, + sLSE, + sdST, + sdOT, + sSum_OdO, + dK, + dV, + tdKtdK, + tdVtdV, + PdO_tiled_mma, + dSQ_tiled_mma, + blk_coord, + blk_offset, + problem_shape_cur_batch, + iter_count, + iter_start, + iter_end, + scale_softmax, + mma_compute_S_producer, + mma_compute_S_consumer, + compute_mma_P_producer, + compute_mma_P_consumer, + load_compute_LSE_producer, + load_compute_LSE_consumer, + load_compute_sum_OdO_producer, + load_compute_sum_OdO_consumer, + mma_compute_dP_producer, + mma_compute_dP_consumer, + compute_mma_dS_producer, + compute_mma_dS_consumer, + mma_compute_dKdV_producer, + mma_compute_dKdV_consumer, + varlen, + sK, + seqlen_k_cur_batch, + tma_atom_dK, + dK_tma, + tma_atom_dV, + dV_tma, + sdK_epi_layout, + sdV_epi_layout, + ) + + cute.arch.barrier( + barrier_id=self.epilogue_sync_bar_id, + number_of_threads=self.num_compute_warps * self.threads_per_warp, + ) + + else: + cute.arch.warpgroup_reg_dealloc(self.num_regs_empty) + + cute.arch.cluster_arrive() + cute.arch.cluster_wait() + # dishengbin Deallocate tmem for early exit + # Dealloc the tensor memory + tmem.relinquish_alloc_permit() + tmem.free(tmem_ptr) + + @cute.jit + def get_Q_block_min_max( + self, + seq_Q: Int32, + seq_K: Int32, + blk_coord_k: Int32, + is_2cta: bool, + ): + """Get Q tiles range.""" + Q_block_max = cute.ceil_div(seq_Q, self.tile_shape_Q) + Q_block_min = cutlass.Int32(0) + if cutlass.const_expr(self.has_sliding_window): + # For 2cta, use the last K block in the cluster so both CTAs get the same Q_block_max + blk_coord_k_for_max = (blk_coord_k // 2) * 2 + 1 + Q_block_max_tmp = cute.ceil_div( + (blk_coord_k_for_max + 1) * self.tile_shape_K + + seq_Q + - seq_K + + self.window_size_left, + self.tile_shape_Q, + ) + Q_block_max = min(Q_block_max, Q_block_max_tmp) + if cutlass.const_expr(self.is_causal or self.has_sliding_window): + # For 2cta, use the first K block in the cluster so both CTAs get the same Q_block_min. + # This ensures both CTAs in a cluster run the same number of pipeline iterations, + # avoiding hang from mismatched producer_commit / consumer_wait counts. + blk_coord_k_for_min = (blk_coord_k // 2) * 2 + Q_block_min_tmp = ( + blk_coord_k_for_min * self.tile_shape_K + seq_Q - seq_K - self.window_size_right + ) // self.tile_shape_Q + # Consider the case of 2cta, we need to ensure the K block is aligned to 2 + Q_block_min_tmp = Q_block_min_tmp - Q_block_min_tmp % 2 + Q_block_min = max(Q_block_min_tmp, Q_block_min) + return Q_block_min, Q_block_max + + @cute.jit + def load( + self, + K_in: cute.Tensor, + V_in: cute.Tensor, + Q_in: cute.Tensor, + QT_in: cute.Tensor, + dO_in: cute.Tensor, + dOT_in: cute.Tensor, + LSE_in: cute.Tensor, + sum_OdO_in: cute.Tensor, + sK: cute.Tensor, + sQ: cute.Tensor, + sQT: cute.Tensor, + sV: cute.Tensor, + sdO: cute.Tensor, + sdOT: cute.Tensor, + sLSE: cute.Tensor, + sSum_OdO: cute.Tensor, + KQ_tiled_mma: cute.TiledMma, + VdO_tiled_mma: cute.TiledMma, + PdO_tiled_mma: cute.TiledMma, + dSQ_tiled_mma: cute.TiledMma, + tma_atom_K: cute.CopyAtom, + tma_atom_Q: cute.CopyAtom, + tma_atom_QT: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + tma_atom_dO: cute.CopyAtom, + tma_atom_dOT: cute.CopyAtom, + blk_offset: cute.Shape, + problem_shape: tuple[Int32, Int32, Int32, tuple[tuple[Int32, Int32], Int32]], + varlen: bool, + iter_count: Int32, + iter_start: Int32, + iter_end: Int32, + load_mma_Q_producer, + load_mma_Q_consumer, + load_mma_K_producer, + load_mma_K_consumer, + load_mma_V_producer, + load_mma_V_consumer, + load_compute_LSE_producer, + load_compute_LSE_consumer, + load_mma_dO_producer, + load_mma_dO_consumer, + load_mma_dOT_producer, + load_mma_dOT_consumer, + load_compute_sum_OdO_producer, + load_compute_sum_OdO_consumer, + load_mma_QT_producer, + load_mma_QT_consumer, + blk_coord_k_override: Int32 = Int32(-1), + blk_coord_h_k_override: Int32 = Int32(-1), + blk_coord_b_override: Int32 = Int32(-1), + ): + """TMA load.""" + tidx, _, _ = cute.arch.thread_idx() + if cutlass.const_expr(self.use_clc_scheduler): + blk_coord_k = blk_coord_k_override + blk_coord_h_k = blk_coord_h_k_override + blk_coord_b = blk_coord_b_override + else: + blk_coord_k, blk_coord_h_k, blk_coord_b = cute.arch.block_idx() + blk_coord_h_r = Int32(0) + blk_coord_h = (blk_coord_h_r, blk_coord_h_k) + iter_index = iter_start + mma_tile_coord_v = blk_coord_k % cute.size(KQ_tiled_mma.thr_id.shape) + mma_tile_coord_m = blk_coord_k // cute.size(KQ_tiled_mma.thr_id.shape) + + K = cute.domain_offset(cute.select(blk_offset, mode=[1, 2, 3]), K_in) + V = cute.domain_offset(cute.select(blk_offset, mode=[1, 2, 3]), V_in) + Q = cute.domain_offset(cute.select(blk_offset, mode=[0, 2, 3]), Q_in) + QT = cute.domain_offset(cute.select(blk_offset, mode=[2, 0, 3]), QT_in) + dO = cute.domain_offset(cute.select(blk_offset, mode=[0, 2, 3]), dO_in) + dOT = cute.domain_offset(cute.select(blk_offset, mode=[2, 0, 3]), dOT_in) + blk_offset_stats = blk_offset + if cutlass.const_expr(varlen): + cuseqlen_q_stats = cute.assume( + (blk_offset[0] + blk_coord_b * self.tile_shape_Q) + // self.tile_shape_Q + * self.tile_shape_Q, + divby=self.tile_shape_Q, + ) + blk_offset_stats = ( + cuseqlen_q_stats, + blk_offset[1], + blk_offset[2], + blk_offset[3], + ) + LSE = cute.domain_offset(cute.select(blk_offset_stats, mode=[0, 3]), LSE_in) + sum_OdO = cute.domain_offset(cute.select(blk_offset_stats, mode=[0, 3]), sum_OdO_in) + + gK = cute.local_tile(K, cute.select(self.KQ_mma_tiler, mode=[0, 2]), (None, None, None)) + gQ = cute.local_tile(Q, cute.select(self.KQ_mma_tiler, mode=[1, 2]), (None, None, None)) + gQT = cute.local_tile(QT, cute.select(self.dSQ_mma_tiler, mode=[1, 2]), (None, None, None)) + gV = cute.local_tile(V, cute.select(self.VdO_mma_tiler, mode=[0, 2]), (None, None, None)) + gdO = cute.local_tile(dO, cute.select(self.VdO_mma_tiler, mode=[1, 2]), (None, None, None)) + gdOT = cute.local_tile( + dOT, cute.select(self.PdO_mma_tiler, mode=[1, 2]), (None, None, None) + ) + + KQ_thr_mma = KQ_tiled_mma.get_slice(mma_tile_coord_v) + VdO_thr_mma = VdO_tiled_mma.get_slice(mma_tile_coord_v) + PdO_thr_mma = PdO_tiled_mma.get_slice(mma_tile_coord_v) + dSQ_thr_mma = dSQ_tiled_mma.get_slice(mma_tile_coord_v) + + tSTgK = KQ_thr_mma.partition_A(gK) + tSTgQ = KQ_thr_mma.partition_B(gQ) + tdKgQT = dSQ_thr_mma.partition_B(gQT) + tdPTgV = VdO_thr_mma.partition_A(gV) + tdPTgdO = VdO_thr_mma.partition_B(gdO) + tdVgdOT = PdO_thr_mma.partition_B(gdOT) + + cta_layout_mnk = cute.make_layout(self.cluster_shape_mnk) + cta_layout_vmnk = cute.tiled_divide(cta_layout_mnk, (KQ_tiled_mma.thr_id,)) + cta_in_cluster_coord_vmnk = cta_layout_vmnk.get_flat_coord(cute.arch.block_idx_in_cluster()) + + tKsK, tKgK_mkl = cute.nvgpu.cpasync.tma_partition( + tma_atom_K, + cta_in_cluster_coord_vmnk[2], + cute.make_layout(cute.size(cta_layout_vmnk, mode=[2])), + cute.group_modes(sK, 0, 3), + cute.group_modes(tSTgK, 0, 3), + ) + tQsQ, tQgQ_mkl = cute.nvgpu.cpasync.tma_partition( + tma_atom_Q, + cta_in_cluster_coord_vmnk[1], + cute.make_layout(cute.size(cta_layout_vmnk, mode=[1])), + cute.group_modes(sQ, 0, 3), + cute.group_modes(tSTgQ, 0, 3), + ) + tQTsQT, tQTgQT_mkl = cute.nvgpu.cpasync.tma_partition( + tma_atom_QT, + cta_in_cluster_coord_vmnk[1], + cute.make_layout(cute.size(cta_layout_vmnk, mode=[1])), + cute.group_modes(sQT, 0, 3), + cute.group_modes(tdKgQT, 0, 3), + ) + tVsV, tVgV_mkl = cute.nvgpu.cpasync.tma_partition( + tma_atom_V, + cta_in_cluster_coord_vmnk[2], + cute.make_layout(cute.size(cta_layout_vmnk, mode=[2])), + cute.group_modes(sV, 0, 3), + cute.group_modes(tdPTgV, 0, 3), + ) + tdOsdO, tdOgdO_mkl = cute.nvgpu.cpasync.tma_partition( + tma_atom_dO, + cta_in_cluster_coord_vmnk[1], + cute.make_layout(cute.size(cta_layout_vmnk, mode=[1])), + cute.group_modes(sdO, 0, 3), + cute.group_modes(tdPTgdO, 0, 3), + ) + tdOTsdOT, tdOTgdOT_mkl = cute.nvgpu.cpasync.tma_partition( + tma_atom_dOT, + cta_in_cluster_coord_vmnk[1], + cute.make_layout(cute.size(cta_layout_vmnk, mode=[1])), + cute.group_modes(sdOT, 0, 3), + cute.group_modes(tdVgdOT, 0, 3), + ) + + k_handle = load_mma_K_producer.acquire_and_advance() + cute.copy( + tma_atom_K, + tKgK_mkl[(None, mma_tile_coord_m, 0, (blk_coord_h, blk_coord_b))], + tKsK[None, 0], + tma_bar_ptr=k_handle.barrier, + ) + + q_handle = load_mma_Q_producer.acquire_and_advance() + cute.copy( + tma_atom_Q, + tQgQ_mkl[(None, iter_index, 0, (blk_coord_h, blk_coord_b))], + tQsQ[None, q_handle.index], + tma_bar_ptr=q_handle.barrier, + ) + + lse_handle = load_compute_LSE_producer.acquire_and_advance() + thread_idx = tidx % self.threads_per_warp + async_copy_num_elts = sLSE.shape[0] // self.threads_per_warp + atom_async_copy = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.ALWAYS), + self.acc_dtype, + num_bits_per_copy=self.acc_dtype.width, + ) + sLSE_for_copy = cute.flat_divide(sLSE, (1,)) + LSE_for_copy = cute.flat_divide(LSE, (1,)) + # Warp-coalesced: at each i, lane T accesses index `T + i*W` (stride-1 + # across the warp) instead of `T*N + i` (stride-N across the warp). + for i in cutlass.range_constexpr(async_copy_num_elts): + LSE_idx = self.tile_shape_Q * iter_index + thread_idx + i * self.threads_per_warp + sLSE_idx = thread_idx + i * self.threads_per_warp + if cute.elem_less(LSE_idx, problem_shape[0]): + cute.copy( + atom_async_copy, + LSE_for_copy[None, LSE_idx, (blk_coord_h, blk_coord_b)], + sLSE_for_copy[None, sLSE_idx, lse_handle.index], + ) + else: + sLSE_for_copy[None, sLSE_idx, lse_handle.index].fill(0.0) + lse_handle.commit() + + v_handle = load_mma_V_producer.acquire_and_advance() + cute.copy( + tma_atom_V, + tVgV_mkl[(None, mma_tile_coord_m, 0, (blk_coord_h, blk_coord_b))], + tVsV[(None, 0)], + tma_bar_ptr=v_handle.barrier, + ) + + do_handle = load_mma_dO_producer.acquire_and_advance() + cute.copy( + tma_atom_dO, + tdOgdO_mkl[(None, iter_index, 0, (blk_coord_h, blk_coord_b))], + tdOsdO[(None, do_handle.index)], + tma_bar_ptr=do_handle.barrier, + ) + + sum_odo_handle = load_compute_sum_OdO_producer.acquire_and_advance() + sSum_OdO_for_copy = cute.flat_divide(sSum_OdO, (1,)) + sum_OdO_for_copy = cute.flat_divide(sum_OdO, (1,)) + for i in cutlass.range_constexpr(async_copy_num_elts): + sum_OdO_idx = self.tile_shape_Q * iter_index + thread_idx + i * self.threads_per_warp + sSum_OdO_idx = thread_idx + i * self.threads_per_warp + if cute.elem_less(sum_OdO_idx, problem_shape[0]): + cute.copy( + atom_async_copy, + sum_OdO_for_copy[None, sum_OdO_idx, (blk_coord_h, blk_coord_b)], + sSum_OdO_for_copy[None, sSum_OdO_idx, sum_odo_handle.index], + ) + else: + sSum_OdO_for_copy[None, sSum_OdO_idx, sum_odo_handle.index].fill(0.0) + sum_odo_handle.commit() + + dot_handle = load_mma_dOT_producer.acquire_and_advance() + cute.copy( + tma_atom_dOT, + tdOTgdOT_mkl[(None, 0, iter_index, (blk_coord_h, blk_coord_b))], + tdOTsdOT[None, dot_handle.index], + tma_bar_ptr=dot_handle.barrier, + ) + + qt_handle = load_mma_QT_producer.acquire_and_advance() + cute.copy( + tma_atom_QT, + tQTgQT_mkl[(None, 0, iter_index, (blk_coord_h, blk_coord_b))], + tQTsQT[None, qt_handle.index], + tma_bar_ptr=qt_handle.barrier, + ) + + iter_count -= 1 + iter_index += 1 + + while iter_count > 0: + if iter_index == iter_end: + iter_index = iter_start + blk_coord_h_r += 1 + blk_coord_h = (blk_coord_h_r, blk_coord_h_k) + + q_handle = load_mma_Q_producer.acquire_and_advance() + cute.copy( + tma_atom_Q, + tQgQ_mkl[(None, iter_index, 0, (blk_coord_h, blk_coord_b))], + tQsQ[None, q_handle.index], + tma_bar_ptr=q_handle.barrier, + ) + + lse_handle = load_compute_LSE_producer.acquire_and_advance() + sLSE_for_copy = cute.flat_divide(sLSE, (1,)) + LSE_for_copy = cute.flat_divide(LSE, (1,)) + for i in cutlass.range_constexpr(async_copy_num_elts): + LSE_idx = self.tile_shape_Q * iter_index + thread_idx + i * self.threads_per_warp + sLSE_idx = thread_idx + i * self.threads_per_warp + if cute.elem_less(LSE_idx, problem_shape[0]): + cute.copy( + atom_async_copy, + LSE_for_copy[None, LSE_idx, (blk_coord_h, blk_coord_b)], + sLSE_for_copy[None, sLSE_idx, lse_handle.index], + ) + else: + sLSE_for_copy[None, sLSE_idx, lse_handle.index].fill(0.0) + lse_handle.commit() + + do_handle = load_mma_dO_producer.acquire_and_advance() + cute.copy( + tma_atom_dO, + tdOgdO_mkl[(None, iter_index, 0, (blk_coord_h, blk_coord_b))], + tdOsdO[None, do_handle.index], + tma_bar_ptr=do_handle.barrier, + ) + + sum_odo_handle = load_compute_sum_OdO_producer.acquire_and_advance() + sSum_OdO_for_copy = cute.flat_divide(sSum_OdO, (1,)) + sum_OdO_for_copy = cute.flat_divide(sum_OdO, (1,)) + for i in cutlass.range_constexpr(async_copy_num_elts): + sum_OdO_idx = ( + self.tile_shape_Q * iter_index + thread_idx + i * self.threads_per_warp + ) + sSum_OdO_idx = thread_idx + i * self.threads_per_warp + if cute.elem_less(sum_OdO_idx, problem_shape[0]): + cute.copy( + atom_async_copy, + sum_OdO_for_copy[None, sum_OdO_idx, (blk_coord_h, blk_coord_b)], + sSum_OdO_for_copy[None, sSum_OdO_idx, sum_odo_handle.index], + ) + else: + sSum_OdO_for_copy[None, sSum_OdO_idx, sum_odo_handle.index].fill(0.0) + sum_odo_handle.commit() + + dot_handle = load_mma_dOT_producer.acquire_and_advance() + cute.copy( + tma_atom_dOT, + tdOTgdOT_mkl[(None, 0, iter_index, (blk_coord_h, blk_coord_b))], + tdOTsdOT[None, dot_handle.index], + tma_bar_ptr=dot_handle.barrier, + ) + + qt_handle = load_mma_QT_producer.acquire_and_advance() + cute.copy( + tma_atom_QT, + tQTgQT_mkl[(None, 0, iter_index, (blk_coord_h, blk_coord_b))], + tQTsQT[None, qt_handle.index], + tma_bar_ptr=qt_handle.barrier, + ) + + iter_count -= 1 + iter_index += 1 + + if not cutlass.const_expr(self.use_clc_scheduler): + load_mma_K_producer.tail() + load_mma_V_producer.tail() + load_mma_Q_producer.tail() + load_compute_LSE_producer.tail() + load_mma_dO_producer.tail() + load_mma_dOT_producer.tail() + load_compute_sum_OdO_producer.tail() + load_mma_QT_producer.tail() + + return ( + load_mma_Q_producer, + load_mma_K_producer, + load_mma_V_producer, + load_compute_LSE_producer, + load_mma_dO_producer, + load_mma_dOT_producer, + load_compute_sum_OdO_producer, + load_mma_QT_producer, + ) + + @cute.jit + def mma_2cta( + self, + KQ_tiled_mma: cute.TiledMma, + VdO_tiled_mma: cute.TiledMma, + PdO_tiled_mma: cute.TiledMma, + dSQ_tiled_mma: cute.TiledMma, + tSTtST: cute.Tensor, + tSTrQ: cute.Tensor, + tSTrK: cute.Tensor, + tdPTtdPT: cute.Tensor, + tdPTrV: cute.Tensor, + tdPTrdO: cute.Tensor, + tdVtdV: cute.Tensor, + tdVrP: cute.Tensor, + tdVrdOT: cute.Tensor, + tdKrdST: cute.Tensor, + tdKtdK: cute.Tensor, + tdKrQT: cute.Tensor, + iter_count: Int32, + load_mma_Q_consumer, + load_mma_K_consumer, + load_mma_V_consumer, + mma_compute_S_producer, + load_mma_dO_consumer, + mma_compute_dP_producer, + load_mma_dOT_consumer, + compute_mma_P_consumer, + compute_mma_dS_consumer, + load_mma_QT_consumer, + mma_compute_dKdV_producer, + ): + """CuTeDSL kernel for mma pipeline.""" + load_mma_Q_releaser = load_mma_Q_consumer.clone() + load_mma_K_releaser = load_mma_K_consumer.clone() + load_mma_V_releaser = load_mma_V_consumer.clone() + + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + is_leader_cta = cta_rank_in_cluster % 2 == 0 + + if is_leader_cta: + s_handle = mma_compute_S_producer.acquire_and_advance() + k_handle = load_mma_K_consumer.wait_and_advance() + q_handle = load_mma_Q_consumer.wait_and_advance() + + # Compute S = K * Q + for k_block in cutlass.range(0, cute.size(tSTrQ, mode=[2]), unroll_full=True): + KQ_tiled_mma.set(tcgen05.Field.ACCUMULATE, k_block != 0) + cute.gemm( + KQ_tiled_mma, + tSTtST, + tSTrK[None, None, k_block, 0], + tSTrQ[None, None, k_block, q_handle.index], + tSTtST, + ) + q_handle.release() + + cute.arch.fence_view_async_tmem_store() + s_handle.commit() + + do_handle = load_mma_dO_consumer.wait_and_advance() + v_handle = load_mma_V_consumer.wait_and_advance() + + dp_handle = mma_compute_dP_producer.acquire_and_advance() + + # Compute dP = V * dO + VdO_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + for k_block in cutlass.range(0, cute.size(tdPTrV, mode=[2]), unroll_full=True): + cute.gemm( + VdO_tiled_mma, + tdPTtdPT, + tdPTrV[None, None, k_block, 0], + tdPTrdO[None, None, k_block, do_handle.index], + tdPTtdPT, + ) + VdO_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + dp_handle.commit() + do_handle.release() + # V only produced once by load(); hold v_handle until end, release there via releaser + + p_handle = compute_mma_P_consumer.wait_and_advance() + dot_handle = load_mma_dOT_consumer.wait_and_advance() + + # Compute dV = P * dO (First iteration) + PdO_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + for k_block in cutlass.range(0, cute.size(tdVrP, mode=[2]), unroll_full=True): + cute.gemm( + PdO_tiled_mma, + tdVtdV, + tdVrP[None, None, k_block, 0], + tdVrdOT[None, None, k_block, dot_handle.index], + tdVtdV, + ) + PdO_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + dot_handle.release() + p_handle.release() + + iter_count -= 1 + + dSQ_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + while iter_count > 0: + if is_leader_cta: + q_handle = load_mma_Q_consumer.wait_and_advance() + s_handle = mma_compute_S_producer.acquire_and_advance() + + # Compute S = K * Q + for k_block in cutlass.range(0, cute.size(tSTrQ, mode=[2]), unroll_full=True): + KQ_tiled_mma.set(tcgen05.Field.ACCUMULATE, k_block != 0) + cute.gemm( + KQ_tiled_mma, + tSTtST, + tSTrK[None, None, k_block, 0], + tSTrQ[None, None, k_block, q_handle.index], + tSTtST, + ) + q_handle.release() + s_handle.commit() + + if is_leader_cta: + qt_handle = load_mma_QT_consumer.wait_and_advance() + ds_handle = compute_mma_dS_consumer.wait_and_advance() + + # Compute dK = dS * QT + for k_block in cutlass.range(0, cute.size(tdKrdST, mode=[2]), unroll_full=True): + cute.gemm( + dSQ_tiled_mma, + tdKtdK, + tdKrdST[ + None, + None, + k_block, + ds_handle.index, + ], + tdKrQT[None, None, k_block, qt_handle.index], + tdKtdK, + ) + dSQ_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + qt_handle.release() + ds_handle.release() + + if is_leader_cta: + dp_handle = mma_compute_dP_producer.acquire_and_advance() + do_handle = load_mma_dO_consumer.wait_and_advance() + # V only produced once by load(); reuse same V (index 0) for all loop iterations + VdO_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + for k_block in cutlass.range(0, cute.size(tdPTrV, mode=[2]), unroll_full=True): + cute.gemm( + VdO_tiled_mma, + tdPTtdPT, + tdPTrV[None, None, k_block, 0], + tdPTrdO[None, None, k_block, do_handle.index], + tdPTtdPT, + ) + VdO_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + dp_handle.commit() + do_handle.release() + + if is_leader_cta: + p_handle = compute_mma_P_consumer.wait_and_advance() + dot_handle = load_mma_dOT_consumer.wait_and_advance() + + # Compute dV = P * dO (Loop iterations) + for k_block in cutlass.range(0, cute.size(tdVrP, mode=[2]), unroll_full=True): + cute.gemm( + PdO_tiled_mma, + tdVtdV, + tdVrP[None, None, k_block, 0], + tdVrdOT[None, None, k_block, dot_handle.index], + tdVtdV, + ) + PdO_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + p_handle.release() + dot_handle.release() + + iter_count -= 1 + + if is_leader_cta: + dkdv_handle = mma_compute_dKdV_producer.acquire_and_advance() + dkdv_handle.commit() + + load_mma_K_releaser.release() + load_mma_K_releaser.advance() + load_mma_V_releaser.release() + load_mma_V_releaser.advance() + + if is_leader_cta: + dkdv_handle = mma_compute_dKdV_producer.acquire_and_advance() + + ds_handle = compute_mma_dS_consumer.wait_and_advance() + qt_handle = load_mma_QT_consumer.wait_and_advance() + + # Compute dK = dS * Q + for k_block in cutlass.range(0, cute.size(tdKrdST, mode=[2]), unroll_full=True): + cute.gemm( + dSQ_tiled_mma, + tdKtdK, + tdKrdST[None, None, k_block, ds_handle.index], + tdKrQT[None, None, k_block, qt_handle.index], + tdKtdK, + ) + dSQ_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + dkdv_handle.commit() + qt_handle.release() + ds_handle.release() + + if not cutlass.const_expr(self.use_clc_scheduler): + mma_compute_S_producer.tail() + mma_compute_dP_producer.tail() + mma_compute_dKdV_producer.tail() + + return ( + mma_compute_S_producer, + mma_compute_dP_producer, + mma_compute_dKdV_producer, + load_mma_Q_consumer, + load_mma_K_consumer, + load_mma_V_consumer, + load_mma_dO_consumer, + load_mma_dOT_consumer, + compute_mma_P_consumer, + compute_mma_dS_consumer, + load_mma_QT_consumer, + ) + + @cute.jit + def reg_to_smem_mma64x64( + self, + regs: cute.Tensor, + smem: cute.Tensor, + index: Int32, + tiler_mn: tuple[Int32, Int32], + dp_idx: Int32, + wg_idx: Int32, + ): + smem_slice = smem[None, None, None, index] + # TODO: double check the layout of the data in reg. + # TODO: this may introduce additional smem transpose. + thread_layout = cute.make_ordered_layout( + tiler_mn, + (0, 1), # TODO: (0,1) or (1,0) ??? + ) + smem_slice_tmp = cute.composition(smem_slice, thread_layout) + + # TODO: temporary code for tile 64 x 64. + tmp_shape = ((8, 2, 4), (2, 4, 2, 2, 2)) + tmp_stride = ((64, 512, 1024), (1, 2, 8, 16, 32)) + smem_copy = cute.composition(smem_slice_tmp, cute.make_layout(tmp_shape, stride=tmp_stride)) + + # TODO: the following code is only for tile 64 x 64. + # TODO: need to modify the code for other tile sizes. + lane_idx = dp_idx % 32 + reg_shape = regs.shape + atom_loops = reg_shape[0][0][2] + block_loops = reg_shape[2] + # | 00 ~ 07 | 08 ~ 15 | 16 ~ 23 | 24 ~ 31 | 32 ~ 39 | 40 ~ 47 | 48 ~ 55 | 56 ~ 63 | + # |---- atom size ----|---- atom size ----|---- atom size ----|---- atom size ----| + # |---- wg0 ----|---- wg1 ----|---- wg0 ----|---- wg1 ----| + for ia in cutlass.range(atom_loops): + for ib in cutlass.range(block_loops): + # the lower 8 lines + regs_copy = regs[((None, 0, ia), 0), 0, ib] # two elements + smem_copy_slice = smem_copy[ + (lane_idx // 4, 0, dp_idx // 32), + (None, lane_idx % 4, ia, wg_idx, ib), + ] + cute.autovec_copy(regs_copy, smem_copy_slice) + # the upper 8 lines + regs_copy = regs[((None, 1, ia), 0), 0, ib] + smem_copy_slice = smem_copy[ + (lane_idx // 4, 1, dp_idx // 32), + (None, lane_idx % 4, ia, wg_idx, ib), + ] + cute.autovec_copy(regs_copy, smem_copy_slice) + + @cute.jit + def reg_to_smem_mma128x128_2cta( + self, + regs: cute.Tensor, + smem: cute.Tensor, + index: Int32, + tiler_mn: tuple[Int32, Int32], + dp_idx: Int32, + wg_idx: Int32, + smem_RowMajor: bool = True, + ): + smem_slice = smem[None, None, None, index] + # K>> smem_slice: tensor, S<3,4,3>> o ((64,16),1,(4,2)):((64,1),0,(16,4096))> + thread_layout = cute.make_ordered_layout( + # (tileN, tileM) + tiler_mn, + (0, 1), + ) + # K>> thread_layout: (64,128):(128,1) + smem_slice_tmp = cute.composition(smem_slice, thread_layout) + + # NOTE: hardcode for tcgen05.ld.32x32b.x8 & mma128x64+2cta + # tmp_shape = ((32, 2), (8, 2, 2, 2)) # for 64x64 tile + # tmp_stride = ((64, 32*64), (1, 8, 16, 32)) + # NOTE: hardcode for tcgen05.ld.32x32b.x16 & mma128x64+2cta + tmp_shape = ((32, 2), (16, 2, 2, 2)) # for 128x64 tile + tmp_stride = ((64, 32 * 64), (1, 16, 32, 64 * 64)) + # smem_copy = cute.composition(smem_slice_tmp, cute.make_layout(tmp_shape, stride=tmp_stride)) + smem_copy = cute.make_tensor( + smem_slice_tmp.iterator, cute.make_layout(tmp_shape, stride=tmp_stride) + ) + + warp_idx = dp_idx // 32 + warp_row_idx = warp_idx % 2 + warp_col_idx = warp_idx // 2 # corresponding to the second 64 cols in smem + lane_idx = dp_idx % 32 + reg_shape = ( + regs.shape + ) # ((8,1),1,2):((1,0),0,8) for 64x64, ((16,1),1,2):((1,0),0,16) for 128x64 + block_loops = reg_shape[2] + + # TODO: maybe can use cp.async for optimization + for ib in cutlass.range(block_loops): + regs_copy = regs[(None, 0), 0, ib] + smem_copy_slice = smem_copy[(lane_idx, warp_row_idx), (None, wg_idx, ib, warp_col_idx)] + cute.autovec_copy(regs_copy, smem_copy_slice) + + @cute.jit + def reg_to_smem_mma128x64_2cta( + self, + regs: cute.Tensor, + smem: cute.Tensor, + index: Int32, + tiler_mn: tuple[Int32, Int32], + dp_idx: Int32, + wg_idx: Int32, + smem_RowMajor: bool = True, + ): + smem_slice = smem[None, None, None, index] + thread_layout = cute.make_ordered_layout( + # (tileN, tileM) + tiler_mn, + (1, 0) if smem_RowMajor else (0, 1), + ) + smem_slice_tmp = cute.composition(smem_slice, thread_layout) + # NOTE: hardcode for tcgen05.ld.32x32b.x8 & mma128x64+2cta + tmp_shape = ((32, 2), (8, 2, 2, 2)) + tmp_stride = ((64, 32 * 64), (1, 8, 16, 32)) + smem_copy = cute.composition(smem_slice_tmp, cute.make_layout(tmp_shape, stride=tmp_stride)) + + warp_idx = dp_idx // 32 + warp_row_idx = warp_idx % 2 + warp_col_idx = warp_idx // 2 + lane_idx = dp_idx % 32 + reg_shape = regs.shape # ((8,1),1,2):((1,0),0,8) + block_loops = reg_shape[2] + + # TODO: maybe can use cp.async for optimization + for ib in cutlass.range(block_loops): + regs_copy = regs[(None, 0), 0, ib] + smem_copy_slice = smem_copy[(lane_idx, warp_row_idx), (None, wg_idx, ib, warp_col_idx)] + cute.autovec_copy(regs_copy, smem_copy_slice) + + @cute.jit + def compute( + self, + tSTtST: cute.Tensor, + tdPTtdPT: cute.Tensor, + tdVrP: cute.Tensor, + sP: cute.Tensor, + sLSE: cute.Tensor, + # sdS: cute.Tensor, + sdST: cute.Tensor, + sdOT: cute.Tensor, + sSum_OdO: cute.Tensor, + dK: cute.Tensor, + dV: cute.Tensor, + tdKtdK: cute.Tensor, + tdVtdV: cute.Tensor, + PdO_tiled_mma: cute.TiledMma, + dSQ_tiled_mma: cute.TiledMma, + blk_coord: cute.Coord, + blk_offset: cute.Shape, + problem_shape: tuple[Int32, Int32, Int32, tuple[tuple[Int32, Int32], Int32]], + iter_count: Int32, + iter_start: Int32, + iter_end: Int32, + scale_softmax: cutlass.Float32, + mma_compute_S_producer, + mma_compute_S_consumer, + compute_mma_P_producer, + compute_mma_P_consumer, + load_compute_LSE_producer, + load_compute_LSE_consumer, + load_compute_sum_OdO_producer, + load_compute_sum_OdO_consumer, + mma_compute_dP_producer, + mma_compute_dP_consumer, + compute_mma_dS_producer, + compute_mma_dS_consumer, + mma_compute_dKdV_producer, + mma_compute_dKdV_consumer, + varlen: bool, + sK: cute.Tensor, + problem_shape_k_cur_batch: Int32, + tma_atom_dK: cute.CopyAtom, + dK_tma: cute.Tensor, + tma_atom_dV: cute.CopyAtom, + dV_tma: cute.Tensor, + sdK_epi_layout: cute.ComposedLayout, + sdV_epi_layout: cute.ComposedLayout, + ): + """CuTeDSL kernel for recomputing softmax and producing dk and dv.""" + tidx, _, _ = cute.arch.thread_idx() + Q, K, _, _ = problem_shape + _, blk_coord_k, _, _ = blk_coord + + iter_index = iter_start + + # adi: TMEM_ST, TMEM_DPT + tmem_load_op = tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(16)) + tmem_load_atom = cute.make_copy_atom( + tmem_load_op, + self.acc_dtype, + ) + + tSTtST = tSTtST[(None, None), 0, 0] + tdPTtdPT = tdPTtdPT[(None, None), 0, 0] + + cST = cute.make_identity_tensor(cute.select(self.cta_tiler, mode=[1, 0])) + cdPT = cute.make_identity_tensor(cute.select(self.cta_tiler, mode=[1, 0])) + + num_warp_groups = self.num_compute_warps // 4 + dp_idx = tidx % 128 + wg_idx = (tidx % (self.num_compute_warps * self.threads_per_warp)) // 128 + tiled_t2r = tcgen05.make_tmem_copy(tmem_load_atom, tSTtST) + thr_t2r = tiled_t2r.get_slice(dp_idx) + + tTR_cST = thr_t2r.partition_D(cST) + tTR_cST = split_wg(tTR_cST, num_warp_groups, wg_idx) + tTR_rST = cute.make_rmem_tensor(tTR_cST.shape, self.acc_dtype) + + tTR_tST = thr_t2r.partition_S(tSTtST) + tTR_tST = split_wg(tTR_tST, num_warp_groups, wg_idx) + + tTR_cdPT_p = thr_t2r.partition_D(cdPT) + tTR_cdPT = split_wg(tTR_cdPT_p, num_warp_groups, wg_idx) + tTR_rdPT = cute.make_rmem_tensor(tTR_cdPT.shape, self.acc_dtype) + + tTR_tdPT = thr_t2r.partition_S(tdPTtdPT) + tTR_tdPT = split_wg(tTR_tdPT, num_warp_groups, wg_idx) + + tdVcST = PdO_tiled_mma.get_slice(0).partition_A(cST) + + is_residual_k = blk_coord_k * self.tile_shape_K + self.tile_shape_K > K + last_iter = iter_end - 1 + + while iter_count > 0: + s_handle = mma_compute_S_consumer.wait_and_advance() + p_handle = compute_mma_P_producer.acquire_and_advance() + lse_handle = load_compute_LSE_consumer.wait_and_advance() + + leading_causal_masking = cutlass.Boolean(False) + if cutlass.const_expr(self.is_causal): + # TODO: could be optimized by specify an exact iter_index + # + # NOTE (causal + 2CTA correctness): + # `iter_start` can be rounded down to an even index for 2CTA. When the true + # causal boundary Q tile is odd, it becomes (iter_start + 1). In that case, + # we must treat both (iter_start, iter_start + 1) as "masked tiles" so the + # per-element causal mask is applied on the boundary tile too. + leading_causal_masking = iter_index == iter_start + q_block_min_unaligned = ( + blk_coord_k * self.tile_shape_K + Q - K - self.window_size_right + ) // self.tile_shape_Q + boundary_in_second = (q_block_min_unaligned % 2) == 1 + leading_causal_masking = leading_causal_masking or ( + boundary_in_second and (iter_index == iter_start + 1) + ) + offset_partial_tile = (K - Q) % self.tile_shape_K + need_additional_mask = offset_partial_tile and iter_index == iter_start + 1 + leading_causal_masking = leading_causal_masking or need_additional_mask + leading_causal_masking = cute.arch.shuffle_sync(leading_causal_masking, 0) + + trailing_residual_masking = cutlass.Boolean(False) + trailing_residual_masking = iter_index == last_iter or is_residual_k + trailing_residual_masking = cute.arch.shuffle_sync(trailing_residual_masking, 0) + + # For causal, every tile may contain (q,k) with k > q; we must apply per-element mask for all Q tiles. + is_masked_tile = ( + leading_causal_masking + or trailing_residual_masking + or self.has_sliding_window + or cutlass.const_expr(self.is_causal) + ) + + # Compute P = softmax(S, LSE) + cute.copy(tiled_t2r, tTR_tST, tTR_rST) + + if is_masked_tile: + for i in cutlass.range(cute.size(tTR_rST), unroll_full=True): + c_transpose = tTR_cST[i] + pos = ( + cute.get(c_transpose, mode=[1]) + iter_index * self.tile_shape_Q, + cute.get(c_transpose, mode=[0]) + blk_coord_k * self.tile_shape_K, + ) + if cutlass.const_expr(self.has_sliding_window): + if cutlass.const_expr(self.window_size_left < 0): + tTR_rST[i] = ( + -cutlass.Float32.inf + if pos[1] > pos[0] + K - Q + self.window_size_right + else tTR_rST[i] + ) + else: + max_K_index = min(pos[0] + K - Q + self.window_size_right, K) + min_K_index = max(0, pos[0] + K - Q - self.window_size_left) + tTR_rST[i] = ( + -cutlass.Float32.inf + if pos[1] > max_K_index or pos[1] < min_K_index + else tTR_rST[i] + ) + if cutlass.const_expr(self.is_causal) and ( + pos[0] + K - Q < pos[1] or not cute.elem_less(pos, (Q, K)) + ): + tTR_rST[i] = -cutlass.Float32.inf + if not cute.elem_less(pos, (Q, K)): + tTR_rST[i] = -cutlass.Float32.inf + + log2_e = cutlass.Float32(math.log2(math.e)) + softmax_scale_log2_e = scale_softmax * log2_e + + for i in cutlass.range(0, cute.size(tTR_rST), 2, unroll_full=True): + lse = ( + -sLSE[ + cute.get(tTR_cST[i], mode=[1]), + lse_handle.index, + ], + -sLSE[ + cute.get(tTR_cST[i + 1], mode=[1]), + lse_handle.index, + ], + ) + tTR_rST[i], tTR_rST[i + 1] = cute.arch.fma_packed_f32x2( + (tTR_rST[i], tTR_rST[i + 1]), + (softmax_scale_log2_e, softmax_scale_log2_e), + lse, + ) + tTR_rST[i] = cute.math.exp2(tTR_rST[i], fastmath=True) + tTR_rST[i + 1] = cute.math.exp2(tTR_rST[i + 1], fastmath=True) + + # convert fp32 P to fp16 P which will be used in the PdO + tTR_rPT = self.quantize(tTR_rST, dV.element_type) # tTR_rST is ST in fp32 in RF. + self.reg_to_smem_mma128x128_2cta( + tTR_rPT, + sP, + p_handle.index, + (self.tile_shape_K, self.tile_shape_Q), + dp_idx, + wg_idx, + ) + cute.arch.fence_view_async_shared() + cute.arch.barrier( + barrier_id=self.compute_sync_bar_id, + number_of_threads=self.num_compute_warps * self.threads_per_warp, + ) + + p_handle.commit() + + s_handle.release() + lse_handle.release() + + sum_odo_handle = load_compute_sum_OdO_consumer.wait_and_advance() + dp_handle = mma_compute_dP_consumer.wait_and_advance() + ds_handle = compute_mma_dS_producer.acquire_and_advance() + + # Compute dS = dsoftmax(P, dP, sum_OdO) + cute.copy(tiled_t2r, tTR_tdPT, tTR_rdPT) + + for i in cutlass.range(0, cute.size(tTR_rdPT), 2, unroll_full=True): + dpsum_0 = -sSum_OdO[ + cute.get(tTR_cdPT[i], mode=[1]), + sum_odo_handle.index, + ] + dpsum_1 = -sSum_OdO[ + cute.get(tTR_cdPT[i + 1], mode=[1]), + sum_odo_handle.index, + ] + if cutlass.const_expr(varlen): + if not cute.elem_less(cute.get(tTR_cdPT[i], mode=[1]), Q): + dpsum_0 = 0.0 + if not cute.elem_less(cute.get(tTR_cdPT[i + 1], mode=[1]), Q): + dpsum_1 = 0.0 + tTR_rdPT[i], tTR_rdPT[i + 1] = cute.arch.add_packed_f32x2( + (tTR_rdPT[i], tTR_rdPT[i + 1]), + (dpsum_0, dpsum_1), + ) + tTR_rdPT[i], tTR_rdPT[i + 1] = cute.arch.mul_packed_f32x2( + (tTR_rdPT[i], tTR_rdPT[i + 1]), (tTR_rST[i], tTR_rST[i + 1]) + ) + # For causal, force dS to zero at masked (q,k) so dK/dV accumulation is correct + if cutlass.const_expr(self.is_causal): + for i in cutlass.range(cute.size(tTR_rdPT), unroll_full=True): + c_transpose = tTR_cdPT[i] + pos = ( + cute.get(c_transpose, mode=[1]) + iter_index * self.tile_shape_Q, + cute.get(c_transpose, mode=[0]) + blk_coord_k * self.tile_shape_K, + ) + if pos[0] + K - Q < pos[1] or not cute.elem_less(pos, (Q, K)): + tTR_rdPT[i] = cutlass.Float32(0.0) + # convert fp32 dS to fp16 dS which will be used in the computation of dK and DQ + tTR_rdST = self.quantize(tTR_rdPT, dV.element_type) + + cute.arch.fence_view_async_tmem_load() + dp_handle.release() + + self.reg_to_smem_mma128x128_2cta( + tTR_rdST, + sdST, + ds_handle.index, + (self.tile_shape_K, self.tile_shape_Q), + dp_idx, + wg_idx, + ) + cute.arch.fence_view_async_shared() + cute.arch.barrier( + barrier_id=self.compute_sync_bar_id, + number_of_threads=self.num_compute_warps * self.threads_per_warp, + ) + + ds_handle.commit() + sum_odo_handle.release() + + iter_count -= 1 + iter_index += 1 + if iter_index == iter_end: + iter_index = iter_start + + # Epilogue + mma_compute_dKdV_consumer = self.epilogue( + blk_coord, + blk_offset, + problem_shape, + dK, + dV, + tdKtdK, + tdVtdV, + scale_softmax, + mma_compute_dKdV_producer, + mma_compute_dKdV_consumer, + problem_shape_k_cur_batch, + tma_atom_dK, + dK_tma, + tma_atom_dV, + dV_tma, + sdK_epi_layout, + sdV_epi_layout, + varlen, + sdOT, + sP, + ) + + if not cutlass.const_expr(self.use_clc_scheduler): + compute_mma_P_producer.tail() + compute_mma_dS_producer.tail() + + return ( + compute_mma_P_producer, + compute_mma_dS_producer, + mma_compute_S_consumer, + compute_mma_P_consumer, + load_compute_LSE_consumer, + load_compute_sum_OdO_consumer, + mma_compute_dP_consumer, + compute_mma_dS_consumer, + mma_compute_dKdV_consumer, + ) + + @cute.jit + def quantize( + self, + input_t: cute.Tensor, + element_dtype: type[cutlass.Numeric], + ) -> cute.Tensor: + """Convert Float32 to element dtype.""" + output = cute.make_rmem_tensor(input_t.shape, element_dtype) + output.store(input_t.load().to(element_dtype)) + return output + + @cute.jit + def store( + self, + gmem: cute.Tensor, + regs: cute.Tensor, + coord: cute.Tensor, + tensor_shape: cute.Shape, + ): + for i in cutlass.range(cute.size(coord, mode=[2]), unroll_full=True): + coord_i = coord[None, 0, i] + gmem_i = gmem[None, 0, i] + regs_i = regs[None, 0, i] + if cute.elem_less(coord_i[0], tensor_shape): + gmem_i.store(regs_i.load().to(gmem.element_type)) + + @cute.jit + def epilogue_clear( + self, + blk_coord: cute.Coord, + blk_offset: cute.Shape, + problem_shape: tuple[Int32, Int32, Int32, tuple[tuple[Int32, Int32], Int32]], + dK: cute.Tensor, + dV: cute.Tensor, + ): + """Early stopping needs to clear dK and dV.""" + tidx, _, _ = cute.arch.thread_idx() + block_dim_x, _, _ = cute.arch.block_dim() + _, K, _, HB = problem_shape + _, blk_coord_k, _, blk_coord_batch = blk_coord + + mdK_offset = cute.assume(blk_offset[1] * dK.stride[0], divby=64) + mdK = cute.make_tensor( + dK.iterator + mdK_offset, + cute.make_layout((K, self.tile_shape_dQ_K, HB), stride=dK.stride), + ) + gdK = cute.local_tile( + mdK, (self.dSQ_mma_tiler[0], self.dSQ_mma_tiler[1]), (None, None, None) + ) + gdK = gdK[None, None, blk_coord_k, 0, blk_coord_batch] + cdK = cute.domain_offset( + (blk_coord_k * self.tile_shape_K, 0), + cute.make_identity_tensor((self.dSQ_mma_tiler[0], self.dSQ_mma_tiler[1])), + ) + + mdV_offset = cute.assume(blk_offset[1] * dV.stride[0], divby=64) + mdV = cute.make_tensor( + dV.iterator + mdV_offset, + cute.make_layout((K, self.tile_shape_dV_dO, HB), stride=dV.stride), + ) + gdV = cute.local_tile( + mdV, (self.PdO_mma_tiler[0], self.PdO_mma_tiler[1]), (None, None, None) + ) + gdV = gdV[None, None, blk_coord_k, 0, blk_coord_batch] + cdV = cute.domain_offset( + (blk_coord_k * self.tile_shape_K, 0), + cute.make_identity_tensor((self.PdO_mma_tiler[0], self.PdO_mma_tiler[1])), + ) + + for i in cutlass.range(tidx * 8, cute.size(gdK), block_dim_x * 8): + if cute.elem_less(cdK[i], cute.select(problem_shape, mode=[1, 2])): + gdK_i = cute.make_tensor(gdK.iterator + cute.assume(i, divby=8), (8)) + gdK_i.fill(0) + + for i in cutlass.range(tidx * 8, cute.size(gdV), block_dim_x * 8): + if cute.elem_less(cdV[i], cute.select(problem_shape, mode=[1, 2])): + gdV_i = cute.make_tensor(gdV.iterator + cute.assume(i, divby=8), (8)) + gdV_i.fill(0) + + @cute.jit + def epilogue( + self, + blk_coord: cute.Coord, + blk_offset: cute.Shape, + problem_shape: tuple[Int32, Int32, Int32, tuple[tuple[Int32, Int32], Int32]], + dK: cute.Tensor, + dV: cute.Tensor, + tdKtdK: cute.Tensor, + tdVtdV: cute.Tensor, + scale_softmax: cutlass.Float32, + mma_compute_dKdV_producer, + mma_compute_dKdV_consumer, + problem_shape_k_cur_batch: Int32, + tma_atom_dK: cute.CopyAtom, + dK_tma: cute.Tensor, + tma_atom_dV: cute.CopyAtom, + dV_tma: cute.Tensor, + sdK_epi_layout: cute.ComposedLayout, + sdV_epi_layout: cute.ComposedLayout, + varlen: bool, + sdOT: cute.Tensor, + sP: cute.Tensor, + ): + """Variant 3a (5/5) Path 2: CTA-shared SMEM with cooperative WG writes + TMA bulk store. + + Both warp-groups cooperatively populate a per-CTA (64, 256) virtual SMEM + buffer (4 stages of (64, 64) aliased onto sP+sdST). Per-thread t2r N + coverage is interleaved across the full hd=256, so per-WG TMA is not + viable — instead we treat SMEM as one shared per-CTA buffer and let + each thread's `self.store`-equivalent write into it via a (64, 256) + virtual tensor whose N axis maps (n%64, n//64) → (N_within, stage). + After an inter-WG barrier (256 threads), the leader warp fires 4 TMA + bulk stores, one per stage, to the corresponding (64, 64) GMEM slice. + Varlen falls back to per-thread self.store as in flash_bwd_sm100.py. + """ + tidx, _, _ = cute.arch.thread_idx() + _, K, D, HB = problem_shape + _, blk_coord_k, _, blk_coord_batch = blk_coord + + # adi: TMEM_DK, TMEM_DV + tmem_copy_op = tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)) + load_op = cute.make_copy_atom( + tmem_copy_op, + self.acc_dtype, + ) + + tdKtdK = tdKtdK[(None, None), 0, 0] + mdK_offset = cute.assume(blk_offset[1] * dK.stride[0], divby=64) + mdK = cute.make_tensor( + dK.iterator + mdK_offset, + cute.make_layout((K, self.tile_shape_dQ_K, HB), stride=dK.stride), + ) + gdK = cute.local_tile(mdK, (self.cta_tiler[1], self.cta_tiler[2]), (None, None, None)) + gdK = gdK[None, None, blk_coord_k, 0, blk_coord_batch] + cdK = cute.domain_offset( + (blk_coord_k * self.tile_shape_K, 0), + cute.make_identity_tensor((self.cta_tiler[1], self.cta_tiler[2])), + ) + + num_warp_groups = self.num_compute_warps // 4 + dp_idx = tidx % 128 + wg_idx = (tidx % (self.num_compute_warps * self.threads_per_warp)) // 128 + leader_warp = (cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4) == 0 + + # Path 2 SMEM staging. dV stages through sdOT (already-consumed by the + # dV MMA before the dV epilogue begins). dK stages through sP+sdST + # (dead after dK MMA completes, before dK epilogue runs). + s_epi_dK = cute.make_tensor( + cute.recast_ptr(sP.iterator, sdK_epi_layout.inner, dK.element_type), + sdK_epi_layout.outer, + ) + s_epi_dV = cute.make_tensor( + cute.recast_ptr(sdOT.iterator, sdV_epi_layout.inner, dV.element_type), + sdV_epi_layout.outer, + ) + + # Compile-time: stage tile shape and number of stages. + epi_cols_dKV = math.gcd( + 128 // (dV.element_type.width // 8), self.cta_tiler[2] // num_warp_groups + ) + num_epi_stages_dKV = (self.cta_tiler[2] // num_warp_groups) // epi_cols_dKV + total_epi_stages = num_warp_groups * num_epi_stages_dKV + epi_tile_dKV = (self.cta_tiler[1], epi_cols_dKV) + + # Local (M, N) coord tensor for SMEM indexing (no global domain offset + # — cdK/cdV are domain-offset by blk_coord_k * tile_shape_K to match + # the GMEM destination, but the SMEM indexing must be per-CTA-local). + cdV_local = cute.make_identity_tensor((self.cta_tiler[1], self.cta_tiler[2])) + cdK_local = cdV_local + + tiled_t2r_dK = tcgen05.make_tmem_copy(load_op, tdKtdK) + thread_t2r_dK = tiled_t2r_dK.get_slice(dp_idx) + + tTR_cdK = thread_t2r_dK.partition_D(cdK) + tTR_cdK = split_wg(tTR_cdK, num_warp_groups, wg_idx) + tTR_cdK_local = thread_t2r_dK.partition_D(cdK_local) + tTR_cdK_local = split_wg(tTR_cdK_local, num_warp_groups, wg_idx) + tTR_gdK = thread_t2r_dK.partition_D(gdK) + tTR_gdK = split_wg(tTR_gdK, num_warp_groups, wg_idx) + tTR_rdK = cute.make_rmem_tensor(tTR_cdK.shape, self.acc_dtype) + tTR_tdK = thread_t2r_dK.partition_S(tdKtdK) + tTR_tdK = split_wg(tTR_tdK, num_warp_groups, wg_idx) + + mdV_in = cute.make_tensor( + dV.iterator, cute.make_layout((K, self.cta_tiler[2], HB), stride=dV.stride) + ) + offset_mdV = cute.assume(blk_offset[1] * mdV_in.stride[0], divby=64) + mdV = cute.make_tensor(mdV_in.iterator + offset_mdV, mdV_in.layout) + gdV = cute.local_tile(mdV, (self.cta_tiler[1], self.cta_tiler[2]), (None, None, None)) + gdV = gdV[None, None, blk_coord_k, 0, blk_coord_batch] + + cdV = cute.domain_offset( + (blk_coord_k * self.cta_tiler[1], 0), + cute.make_identity_tensor((self.cta_tiler[1], self.cta_tiler[2])), + ) + + tdVtdV = tdVtdV[(None, None), 0, 0] + + tiled_t2r_dV = tcgen05.make_tmem_copy(load_op, tdVtdV) + thread_t2r_dV = tiled_t2r_dV.get_slice(dp_idx) + + tTR_cdV = thread_t2r_dV.partition_D(cdV) + tTR_cdV = split_wg(tTR_cdV, num_warp_groups, wg_idx) + tTR_cdV_local = thread_t2r_dV.partition_D(cdV_local) + tTR_cdV_local = split_wg(tTR_cdV_local, num_warp_groups, wg_idx) + tTR_gdV = thread_t2r_dV.partition_D(gdV) + tTR_gdV = split_wg(tTR_gdV, num_warp_groups, wg_idx) + tTR_rdV = cute.make_rmem_tensor(tTR_cdV.shape, self.acc_dtype) + tTR_tdV = thread_t2r_dV.partition_S(tdVtdV) + tTR_tdV = split_wg(tTR_tdV, num_warp_groups, wg_idx) + + # GMEM destinations for the multi-stage TMA path (gated on not-varlen). + if cutlass.const_expr(not varlen): + mdV_tma_3d = cute.make_tensor( + dV_tma.iterator, + cute.make_layout((K, self.cta_tiler[2], HB), stride=dV_tma.stride), + ) + mdV_tma_cur = mdV_tma_3d[None, None, blk_coord_batch] + gdV_tma = cute.local_tile( + mdV_tma_cur, (self.cta_tiler[1], self.cta_tiler[2]), (blk_coord_k, 0) + ) + gdV_tma_epi = cute.local_tile(gdV_tma, epi_tile_dKV, (0, None)) + + mdK_tma_3d = cute.make_tensor( + dK_tma.iterator, + cute.make_layout((K, self.cta_tiler[2], HB), stride=dK_tma.stride), + ) + mdK_tma_cur = mdK_tma_3d[None, None, blk_coord_batch] + gdK_tma = cute.local_tile( + mdK_tma_cur, (self.cta_tiler[1], self.cta_tiler[2]), (blk_coord_k, 0) + ) + gdK_tma_epi = cute.local_tile(gdK_tma, epi_tile_dKV, (0, None)) + + cta_threads = self.num_compute_warps * self.threads_per_warp + + dkdv_handle = mma_compute_dKdV_consumer.wait_and_advance() + + if blk_coord_k * self.tile_shape_K < problem_shape_k_cur_batch: + cute.copy(tiled_t2r_dV, tTR_tdV, tTR_rdV) + tTR_rdV_cast = cute.make_rmem_tensor(tTR_rdV.shape, dV.element_type) + tTR_rdV_cast.store(tTR_rdV.load().to(dV.element_type)) + + if cutlass.const_expr(not varlen): + # reg -> SMEM via per-element indexed stores using tTR_cdV's + # per-thread (M, N) coords. (M, N) is per-CTA cdV space (M=0..63, + # N=0..255). We map N=(n%epi_cols, n//epi_cols) → (N_within, stage) + # of the 3D s_epi_dV tensor. + for _i in cutlass.range_constexpr(cute.size(tTR_cdV_local, mode=[2])): + for _j in cutlass.range_constexpr(cute.size(tTR_cdV_local[None, 0, _i])): + c = tTR_cdV_local[None, 0, _i][_j] + m_pos = c[0] + n_pos = c[1] + stage_pos = n_pos // epi_cols_dKV + n_within_pos = n_pos % epi_cols_dKV + v = tTR_rdV_cast[None, 0, _i][_j] + s_epi_dV[m_pos, n_within_pos, stage_pos] = v + cute.arch.fence_view_async_shared() + # Inter-WG barrier — both warp-groups must finish their writes + # before the leader warp reads SMEM via TMA. + cute.arch.barrier(barrier_id=5, number_of_threads=cta_threads) + # TMA bulk store, one (64, 64) box per stage. + if leader_warp and wg_idx == 0: + for _stage in cutlass.range_constexpr(total_epi_stages): + sdV_stage = s_epi_dV[None, None, _stage] + gdV_stage = gdV_tma_epi[None, None, _stage] + td_sdV, td_gdV = cpasync.tma_partition( + tma_atom_dV, + 0, + cute.make_layout(1), + cute.group_modes(sdV_stage, 0, 2), + cute.group_modes(gdV_stage, 0, 2), + ) + cute.copy(tma_atom_dV, td_sdV, td_gdV) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=True) + else: + self.store(tTR_gdV, tTR_rdV, tTR_cdV, (K, D)) + + cute.arch.fence_view_async_tmem_load() + dkdv_handle.release() + + dkdv_handle = mma_compute_dKdV_consumer.wait_and_advance() + + if blk_coord_k * self.tile_shape_K < problem_shape_k_cur_batch: + cute.copy(tiled_t2r_dK, tTR_tdK, tTR_rdK) + + for i in cutlass.range(cute.size(tTR_rdK), unroll_full=True): + tTR_rdK[i] = scale_softmax * tTR_rdK[i] + + tTR_rdK_cast = cute.make_rmem_tensor(tTR_rdK.shape, dK.element_type) + tTR_rdK_cast.store(tTR_rdK.load().to(dK.element_type)) + + if cutlass.const_expr(not varlen): + for _i in cutlass.range_constexpr(cute.size(tTR_cdK_local, mode=[2])): + for _j in cutlass.range_constexpr(cute.size(tTR_cdK_local[None, 0, _i])): + c = tTR_cdK_local[None, 0, _i][_j] + m_pos = c[0] + n_pos = c[1] + stage_pos = n_pos // epi_cols_dKV + n_within_pos = n_pos % epi_cols_dKV + v = tTR_rdK_cast[None, 0, _i][_j] + s_epi_dK[m_pos, n_within_pos, stage_pos] = v + cute.arch.fence_view_async_shared() + cute.arch.barrier(barrier_id=6, number_of_threads=cta_threads) + if leader_warp and wg_idx == 0: + for _stage in cutlass.range_constexpr(total_epi_stages): + sdK_stage = s_epi_dK[None, None, _stage] + gdK_stage = gdK_tma_epi[None, None, _stage] + td_sdK, td_gdK = cpasync.tma_partition( + tma_atom_dK, + 0, + cute.make_layout(1), + cute.group_modes(sdK_stage, 0, 2), + cute.group_modes(gdK_stage, 0, 2), + ) + cute.copy(tma_atom_dK, td_sdK, td_gdK) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=True) + else: + self.store(tTR_gdK, tTR_rdK, tTR_cdK, (K, D)) + + cute.arch.fence_view_async_tmem_load() + dkdv_handle.release() + + return mma_compute_dKdV_consumer + + def get_workspace_tensor( + self, + problem_shape: tuple[Int32, Int32, Int32, tuple[tuple[Int32, Int32], Int32]], + workspace: cute.Tensor, + acc_dtype: type[cutlass.Numeric], + varlen: bool, + ) -> tuple[cute.Tensor, cute.Tensor, cute.Tensor]: + """Get workspace tensor.""" + D = problem_shape[2] + H, B = cute.size(problem_shape[3][0]), cute.size(problem_shape[3][1]) + H_r, H_k = problem_shape[3][0] + D = cute.round_up(D, 8) + + # b = 1 for varlen, else batch_size + b = workspace.shape[0] + # s_q_sum for varlen, else s_q_max, already rounded to 8 + S_Q = workspace.shape[1] + + acc_bytes = acc_dtype.width // 8 + sum_OdO_bytes = cute.assume(b * H * S_Q * acc_bytes, divby=acc_bytes) + scaled_lse_bytes = cute.assume(b * H * S_Q * acc_bytes, divby=acc_bytes) + + sum_OdO_iter = workspace.iterator + scaled_lse_iter = sum_OdO_iter + sum_OdO_bytes + + sum_OdO_iter = cute.recast_ptr(sum_OdO_iter, dtype=self.acc_dtype) + scaled_lse_iter = cute.recast_ptr(scaled_lse_iter, dtype=self.acc_dtype) + + sum_OdO = cute.make_tensor( + sum_OdO_iter, + cute.make_layout( + (S_Q, ((H_r, H_k), B)), + stride=(1, ((S_Q, S_Q * H_r), 0 if varlen else S_Q * H)), + ), + ) + scaled_lse = cute.make_tensor( + scaled_lse_iter, + cute.make_layout( + (S_Q, ((H_r, H_k), B)), + stride=(1, ((S_Q, S_Q * H_r), 0 if varlen else S_Q * H)), + ), + ) + + return sum_OdO, scaled_lse + + @staticmethod + def _compute_sum_OdO_grid( + problem_shape: tuple[Int32, Int32, Int32, tuple[tuple[Int32, Int32], Int32]], + block_q: int, + ) -> tuple[Int32, Int32, Int32]: + """Compute grid shape for sum_OdO kernel.""" + return ( + cute.ceil_div(cute.size(problem_shape[0]), block_q), + cute.size(problem_shape[3][0]), # H + cute.size(problem_shape[3][1]), # B + ) + + @staticmethod + def _compute_bwd_grid( + problem_shape: tuple[Int32, Int32, Int32, tuple[tuple[Int32, Int32], Int32]], + block_k: int, + ) -> tuple[Int32, Int32, Int32]: + """Compute grid shape for bwd kernel.""" + K = problem_shape[1] + _, H_K = problem_shape[3][0] + B = problem_shape[3][1] + return (cute.ceil_div(K, block_k), cute.size(H_K), cute.size(B)) + + @staticmethod + def get_workspace_size(s_q: int, d: int, h: int, b: int, acc_dtype: type[cutlass.Numeric]): + """Get workspace size.""" + d = (d + 7) // 8 * 8 # round up to 8 + s_q = (s_q + 7) // 8 * 8 # round up to 8 + workspace_bytes = 0 + # OdO vector + workspace_bytes += acc_dtype.width // 8 + # scaled LSE vector + workspace_bytes += acc_dtype.width // 8 + # FP32 versions of outputs that are churned (start off with Q only) + workspace_bytes += d * acc_dtype.width // 8 + return (b, s_q, h, workspace_bytes) + + def make_and_init_load_mma_K_pipeline(self, load_mma_K_mbar_ptr, cluster_layout_vmnk): + """Create and initialize barrier for load mma Q.""" + load_mma_K_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.load_warp_id]) + ) + load_mma_K_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + return pipeline.PipelineTmaUmma.create( + barrier_storage=load_mma_K_mbar_ptr, + num_stages=self.load_mma_K_stage, + producer_group=load_mma_K_producer_group, + consumer_group=load_mma_K_consumer_group, + tx_count=self.tma_copy_K_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + def make_and_init_load_mma_V_pipeline(self, load_mma_V_mbar_ptr, cluster_layout_vmnk): + """Create and initialize barrier for load mma Q.""" + load_mma_V_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.load_warp_id]) + ) + load_mma_V_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + return pipeline.PipelineTmaUmma.create( + barrier_storage=load_mma_V_mbar_ptr, + num_stages=self.load_mma_V_stage, + producer_group=load_mma_V_producer_group, + consumer_group=load_mma_V_consumer_group, + tx_count=self.tma_copy_V_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + def make_and_init_load_mma_Q_pipeline(self, load_mma_Q_mbar_ptr, cluster_layout_vmnk): + """Create and initialize barrier for load mma Q.""" + load_mma_Q_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.load_warp_id]) + ) + load_mma_Q_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + return pipeline.PipelineTmaUmma.create( + barrier_storage=load_mma_Q_mbar_ptr, + num_stages=self.load_mma_Q_stage, + producer_group=load_mma_Q_producer_group, + consumer_group=load_mma_Q_consumer_group, + tx_count=self.tma_copy_Q_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + def make_and_init_load_mma_QT_pipeline(self, load_mma_QT_mbar_ptr, cluster_layout_vmnk): + """Create and initialize barrier for load mma QT.""" + load_mma_QT_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.load_warp_id]) + ) + load_mma_QT_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + return pipeline.PipelineTmaUmma.create( + barrier_storage=load_mma_QT_mbar_ptr, + num_stages=self.load_mma_QT_stage, + producer_group=load_mma_QT_producer_group, + consumer_group=load_mma_QT_consumer_group, + tx_count=self.tma_copy_Q_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + def make_and_init_load_mma_dO_pipeline(self, load_mma_dO_mbar_ptr, cluster_layout_vmnk): + """Create and initialize barrier for load mma dO.""" + load_mma_dO_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.load_warp_id]) + ) + load_mma_dO_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + return pipeline.PipelineTmaUmma.create( + barrier_storage=load_mma_dO_mbar_ptr, + num_stages=self.load_mma_dO_stage, + producer_group=load_mma_dO_producer_group, + consumer_group=load_mma_dO_consumer_group, + tx_count=self.tma_copy_dO_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + def make_and_init_load_compute_LSE_pipeline(self, load_compute_lse_mbar_ptr): + """Create and initialize barrier for load compute lse.""" + load_compute_lse_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp, + # self.threads_per_warp, + ) + load_compute_lse_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * self.num_compute_warps, + # self.threads_per_warp * self.num_compute_warps, + ) + return pipeline.PipelineCpAsync.create( + barrier_storage=load_compute_lse_mbar_ptr, + num_stages=self.load_compute_LSE_stage, + producer_group=load_compute_lse_producer_group, + consumer_group=load_compute_lse_consumer_group, + ) + + def make_and_init_load_compute_sum_OdO_pipeline(self, load_compute_sum_OdO_mbar_ptr): + """Create and initialize barrier for load sum OdO.""" + load_compute_sum_OdO_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp, + # self.threads_per_warp, + ) + load_compute_sum_OdO_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * self.num_compute_warps, + # self.threads_per_warp * self.num_compute_warps, + ) + return pipeline.PipelineCpAsync.create( + barrier_storage=load_compute_sum_OdO_mbar_ptr, + num_stages=self.load_compute_sum_OdO_stage, + producer_group=load_compute_sum_OdO_producer_group, + consumer_group=load_compute_sum_OdO_consumer_group, + ) + + def make_and_init_mma_compute_S_pipeline(self, mma_compute_S_mbar_ptr, cluster_layout_vmnk): + """Create and initialize barrier for mma S.""" + mma_compute_S_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + len([self.mma_warp_id]), + ) + mma_compute_S_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.num_compute_warps * self.threads_per_warp * cluster_layout_vmnk.shape[0][0], + ) + return pipeline.PipelineUmmaAsync.create( + barrier_storage=mma_compute_S_mbar_ptr, + num_stages=self.mma_compute_S_stage, + producer_group=mma_compute_S_producer_group, + consumer_group=mma_compute_S_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + # Barrier to between dP = v * dO and consume of dP in compute() + def make_and_init_mma_compute_dP_pipeline(self, mma_compute_dP_mbar_ptr, cluster_layout_vmnk): + """Create and initialize barrier for mma Q.""" + mma_compute_dP_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + len([self.mma_warp_id]), + ) + mma_compute_dP_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.num_compute_warps * self.threads_per_warp * cluster_layout_vmnk.shape[0][0], + ) + return pipeline.PipelineUmmaAsync.create( + barrier_storage=mma_compute_dP_mbar_ptr, + num_stages=self.mma_compute_dP_stage, + producer_group=mma_compute_dP_producer_group, + consumer_group=mma_compute_dP_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + def make_and_init_compute_mma_P_pipeline(self, compute_mma_P_mbar_ptr, cluster_layout_vmnk): + """Create and initialize barrier for mma P.""" + compute_mma_P_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.num_compute_warps * self.threads_per_warp * cluster_layout_vmnk.shape[0][0], + # self.num_compute_warps * self.threads_per_warp, + ) + compute_mma_P_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + len([self.mma_warp_id]), + ) + return pipeline.PipelineAsyncUmma.create( + barrier_storage=compute_mma_P_mbar_ptr, + num_stages=self.compute_mma_P_stage, + producer_group=compute_mma_P_producer_group, + consumer_group=compute_mma_P_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + def make_and_init_compute_mma_dS_pipeline(self, compute_mma_dS_mbar_ptr, cluster_layout_vmnk): + """Create and initialize barrier for mma dS.""" + + compute_mma_dS_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.num_compute_warps * self.threads_per_warp * cluster_layout_vmnk.shape[0][0], + # self.num_compute_warps * self.threads_per_warp, + ) + compute_mma_dS_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + len([self.mma_warp_id]), + ) + + return pipeline.PipelineAsyncUmma.create( + barrier_storage=compute_mma_dS_mbar_ptr, + num_stages=self.compute_mma_dS_stage, + producer_group=compute_mma_dS_producer_group, + consumer_group=compute_mma_dS_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + def make_and_init_mma_compute_dKdV_pipeline( + self, mma_compute_dKdV_mbar_ptr, cluster_layout_vmnk + ): + """Create and initialize barrier for mma dKdV.""" + mma_compute_dKdV_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + len([self.mma_warp_id]), + ) + mma_compute_dKdV_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.num_compute_warps * self.threads_per_warp * cluster_layout_vmnk.shape[0][0], + # self.num_compute_warps * self.threads_per_warp, + ) + return pipeline.PipelineUmmaAsync.create( + barrier_storage=mma_compute_dKdV_mbar_ptr, + num_stages=self.mma_compute_dKdV_stage, + producer_group=mma_compute_dKdV_producer_group, + consumer_group=mma_compute_dKdV_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + ) diff --git a/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dqkernel.py b/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dqkernel.py new file mode 100644 index 00000000000..b25ca48f007 --- /dev/null +++ b/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dqkernel.py @@ -0,0 +1,2237 @@ +# Copyright (c) 2025, Siyu Wang, Shengbin Di, Yuxi Chi, Johnsonms, Linfeng Zheng, Haoyan Huang, Lanbo Li, Yun Zhong, Man Yuan, Minmin Sun, Yong Li, Wei Lin. + +from typing import Type, Tuple, Optional + +import cuda.bindings.driver as cuda + +import math +import cutlass +import cutlass.cute as cute +import cutlass.cute.nvgpu.tcgen05 as tcgen05 +from cutlass.cute.nvgpu import cpasync +import cutlass.utils as utils +import cutlass.pipeline as pipeline +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cute.typing import Int32, Int64, Float32 + +from cutlass.utils import ClcDynamicPersistentTileScheduler +from flash_attn.cute.tile_scheduler import ( + ClcState, + compute_sm100_fmha_grid as compute_grid, + compute_sm100_fmha_grid_clc as compute_grid_clc, + make_sm100_thread_cooperative_group as make_thread_cooperative_group, + Sm100FmhaStaticTileScheduler as FmhaStaticTileScheduler, + Sm100FmhaStaticTileSchedulerParams as FmhaStaticTileSchedulerParams, + Sm100FmhaClcDynamicTileScheduler as FmhaClcDynamicTileScheduler, + Sm100FmhaClcDynamicTileSchedulerParams as FmhaClcDynamicTileSchedulerParams, +) +from flash_attn.cute.mask import ( + Sm100FusedMask as FusedMask, +) +from flash_attn.cute.tile_scheduler import SM100_TMEM_CAPACITY_COLUMNS + + +class BlackwellFusedMultiHeadAttentionBackwardDQKernel: + def __init__( + self, + acc_dtype: Type[cutlass.Numeric], + mma_tiler: Tuple[int, int, int], + is_causal: bool, + window_size_left: int | None, + window_size_right: int | None, + is_persistent: bool, + split_head: bool, + use_clc_scheduler: bool = False, + ): + self.acc_dtype = acc_dtype + self.mma_tiler = mma_tiler + self.is_causal = is_causal + self.window_size_left = window_size_left + # Keep original behavior (known-good in this repo) + window_size_left = ( + None + if (window_size_left is None or window_size_left < 0) + else cutlass.Int32(window_size_left) + ) + window_size_right = ( + None + if (window_size_right is None or window_size_right < 0) + else cutlass.Int32(window_size_right) + ) + self.window_size_left = None if self.is_causal else window_size_left + self.window_size_right = cutlass.Int32(0) if self.is_causal else window_size_right + self.is_local = (not self.is_causal) and ( + self.window_size_left is not None or self.window_size_right is not None + ) + assert mma_tiler[0] == 128 and mma_tiler[1] == 128, "Only 128x128 tile impl is supported" + assert mma_tiler[2] == 256, "Only 256 is supported for 128x128 tile impl" + self.cta_tiler = ( + mma_tiler[0], + mma_tiler[1], + mma_tiler[2], + ) + self.qk_mma_tiler = ( + 2 * mma_tiler[0], + mma_tiler[1], + min(self.cta_tiler[2], 128) if split_head else self.cta_tiler[2], + ) + self.dov_mma_tiler = self.qk_mma_tiler + self.dsk_mma_tiler = ( + 2 * mma_tiler[0], + min(self.cta_tiler[2], 128) if split_head else self.cta_tiler[2], + mma_tiler[1], + ) + + self.dsk_block_tiler = ( + self.dsk_mma_tiler[0] // 2, + self.dsk_mma_tiler[1], + self.dsk_mma_tiler[2], + ) + self.iterations_qk = self.cta_tiler[2] // self.qk_mma_tiler[2] + self.iterations_dov = self.cta_tiler[2] // self.dov_mma_tiler[2] + self.iterations_dsk = self.cta_tiler[2] // self.dsk_mma_tiler[1] + self.cluster_shape_mn = (2, 1) + self.tmem_warp_shape_mn = (4, 1) + self.is_persistent = is_persistent + self.use_clc_scheduler = use_clc_scheduler + self.use_semantic_trip_range = self.is_causal or self.is_local + + self.compute_warp_ids = (0, 1, 2, 3) + self.epilogue_warp_ids = (4, 5, 6, 7) + self.mma_warp_id = 8 + self.load_warp_id = 9 + self.empty_warp_id = (10, 11) + self.sched_warp_id = self.empty_warp_id[0] if use_clc_scheduler else None + self.tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS + self.num_compute_warps = len(self.compute_warp_ids) + + self.cta_sync_bar_id = 0 + self.tmem_alloc_sync_bar_id = 1 + self.compute_sync_bar_id = 2 + + self.threads_per_warp = 32 + self.threads_per_cta = self.threads_per_warp * len( + ( + *self.compute_warp_ids, # this is to get a round num threads + *self.epilogue_warp_ids, + self.mma_warp_id, + self.load_warp_id, + *self.empty_warp_id, + ) + ) + + self.tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=1, + num_threads=self.threads_per_cta, + ) + + self.tmem_s_offset = 0 + self.tmem_dp_offset = 128 + self.tmem_dq_offset = 256 + + self.num_regs_compute = 256 + self.num_regs_epilogue = 160 + self.num_regs_other = 32 + + self.buffer_align_bytes = 1024 + + def _setup_attributes(self): + self.q_stage = self.iterations_qk + self.k_stage = self.iterations_qk + self.do_stage = self.iterations_dov + self.v_stage = self.iterations_dov + self.kt_stage = 1 + self.qk_acc_stage = 1 + self.dov_acc_stage = 1 + self.dsk_acc_stage = 1 + self.epi_stage = 1 + self.load_compute_LSE_stage = 1 + self.load_compute_sum_OdO_stage = 1 + if cutlass.const_expr(self.use_clc_scheduler): + self.num_clc_stage = 1 + self.num_clc_response_bytes = 16 + + @cute.jit + def __call__( + self, + q_tensor: cute.Tensor, + k_tensor: cute.Tensor, + v_tensor: cute.Tensor, + dq_tensor: cute.Tensor, + do_tensor: cute.Tensor, + lse_tensor: cute.Tensor, + sum_odo_tensor: cute.Tensor, + cum_seqlen_q: Optional[cute.Tensor], + cum_seqlen_k: Optional[cute.Tensor], + scale_softmax: cutlass.Float32, + stream: cuda.CUstream, + ): + varlen = cum_seqlen_q is not None or cum_seqlen_k is not None + # Infer shape metadata from normalized 5D tensors (B, S, H_k, H_r, D), + # similar to the dedicated hd256 forward path. + s_q = q_tensor.shape[1] + s_k = k_tensor.shape[1] + d = q_tensor.shape[4] + h_k = q_tensor.shape[2] + h_r = q_tensor.shape[3] + if cutlass.const_expr(cum_seqlen_q is not None): + b = cum_seqlen_q.shape[0] - 1 + elif cutlass.const_expr(cum_seqlen_k is not None): + b = cum_seqlen_k.shape[0] - 1 + else: + b = q_tensor.shape[0] + # `lse_tensor` / `sum_odo_tensor` are preallocated FP32 buffers (lse_log2 / dpsum) + # whose sequence dimension is padded (rounded up) by the caller. + # Use their leading dimension as the LSE length to ensure correct batch strides. + s_lse = lse_tensor.shape[0] + s_q64 = Int64(s_q) + s_k64 = Int64(s_k) + s_lse64 = Int64(s_lse) + d64 = cute.assume(Int64(d), divby=128) + h_r64 = Int64(h_r) + h_k64 = Int64(h_k) + b64 = Int64(b) + # Packed-varlen representation uses batch-dim = 1 and sequence-dim = total_{q,k}. + # Keep the *physical* sequence extent in the tensor layouts so that applying + # `cuseqlen_*` offsets stays within the tensor domain. + s_q_total = q_tensor.shape[1] if cum_seqlen_q is not None else s_q64 + s_k_total = k_tensor.shape[1] if cum_seqlen_k is not None else s_k64 + stride_b_qo = h_r64 * h_k64 * s_q64 * d64 if cum_seqlen_q is None else 0 + stride_b_kv = h_k64 * s_k64 * d64 if cum_seqlen_k is None else 0 + b_lse = b64 if cum_seqlen_q is None else 1 + stride_b_lse = h_r64 * h_k64 * s_lse64 if cum_seqlen_q is None else 0 + + # (s, d, ((h_r, h_k), b)) + q_layout = cute.make_layout( + (s_q_total, d, ((h_r, h_k), b)), + stride=(d64 * h_r64 * h_k64, 1, ((d64, d64 * h_r64), stride_b_qo)), + ) + q = cute.make_tensor(q_tensor.iterator, q_layout) + # (s, d, ((h_r, h_k), b)) + do_layout = cute.make_layout( + (s_q_total, d, ((h_r, h_k), b)), + stride=(d64 * h_r64 * h_k64, 1, ((d64, d64 * h_r64), stride_b_qo)), + ) + do = cute.make_tensor(do_tensor.iterator, do_layout) + # (s, d, ((h_r, h_k), b)), 0-stride for h_r to broadcast + k_layout = cute.make_layout( + (s_k_total, d, ((h_r, h_k), b)), + stride=(d64 * h_k64, 1, ((0, d64), stride_b_kv)), + ) + k = cute.make_tensor(k_tensor.iterator, k_layout) + # (d, s, ((h_r, h_k), b)), 0-stride for h_r to broadcast + kt_layout = cute.make_layout( + (d, s_k_total, ((h_r, h_k), b)), + stride=(1, d64 * h_k64, ((0, d64), stride_b_kv)), + ) + kt = cute.make_tensor(k_tensor.iterator, kt_layout) + # (s, d, ((h_r, h_k), b)), 0-stride for h_r to broadcast + v_layout = cute.make_layout( + (s_k_total, d, ((h_r, h_k), b)), + stride=(d64 * h_k64, 1, ((0, d64), stride_b_kv)), + ) + v = cute.make_tensor(v_tensor.iterator, v_layout) + # (s, ((h_r, h_k), b)) + lse = cute.make_tensor(lse_tensor.iterator, lse_tensor.layout) + # (s, ((h_r, h_k), b)) + sum_odo_layout = cute.make_layout( + (s_lse64, ((h_r, h_k), b_lse)), + stride=(1, ((s_lse64, h_r64 * s_lse64), stride_b_lse)), + ) + sum_odo = cute.make_tensor(sum_odo_tensor.iterator, sum_odo_layout) + # (s, d, ((h_r, h_k), b)) + dq_layout = cute.make_layout( + (s_q_total, d, ((h_r, h_k), b)), + stride=(d64 * h_r64 * h_k64, 1, ((d64, d64 * h_r64), stride_b_qo)), + ) + dq = cute.make_tensor(dq_tensor.iterator, dq_layout) + + # setup static attributes before smem/grid/tma computation + self.q_dtype = q.element_type + self.k_dtype = k.element_type + self.v_dtype = v.element_type + self.do_dtype = do.element_type + self.dq_dtype = dq.element_type + self.tilePlikeFP32 = self.qk_mma_tiler[1] // Float32.width * self.q_dtype.width + + if cutlass.const_expr(self.use_clc_scheduler): + self.tile_sched_params, grid = compute_grid_clc( + (s_q, dq.shape[1], dq.shape[2]) if cum_seqlen_q is not None else dq.shape, + self.cta_tiler, + (*self.cluster_shape_mn, 1), + ) + else: + self.tile_sched_params, grid = compute_grid( + (s_q, dq.shape[1], dq.shape[2]) if cum_seqlen_q is not None else dq.shape, + self.cta_tiler, + self.is_persistent, + ) + + self.q_major_mode = utils.LayoutEnum.from_tensor(q).mma_major_mode() + self.do_major_mode = utils.LayoutEnum.from_tensor(do).mma_major_mode() + self.k_major_mode = utils.LayoutEnum.from_tensor(k).mma_major_mode() + self.v_major_mode = utils.LayoutEnum.from_tensor(v).mma_major_mode() + self.dq_layout = utils.LayoutEnum.from_tensor(dq) + + if cutlass.const_expr(self.q_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of q is not supported") + if cutlass.const_expr(self.k_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of k is not supported") + if cutlass.const_expr(self.v_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of v is not supported") + if cutlass.const_expr(self.do_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of v is not supported") + + # check type consistency + if cutlass.const_expr(self.q_dtype != self.k_dtype): + raise TypeError(f"Type mismatch: {self.q_dtype} != {self.k_dtype}") + if cutlass.const_expr(self.q_dtype != self.v_dtype): + raise TypeError(f"Type mismatch: {self.q_dtype} != {self.v_dtype}") + if cutlass.const_expr(self.q_dtype != self.do_dtype): + raise TypeError(f"Type mismatch: {self.q_dtype} != {self.do_dtype}") + + self._setup_attributes() + + cta_group = tcgen05.CtaGroup.TWO + # the intermediate tensor p is from tmem & k-major + ds_source = tcgen05.OperandSource.TMEM + ds_major_mode = tcgen05.OperandMajorMode.K + k_trans_major_mode = tcgen05.OperandMajorMode.MN + qk_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.q_dtype, + self.q_major_mode, + self.k_major_mode, + self.acc_dtype, + cta_group, + self.qk_mma_tiler[:2], + ) + dov_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.do_dtype, + self.do_major_mode, + self.v_major_mode, + self.acc_dtype, + cta_group, + self.dov_mma_tiler[:2], + ) + dsk_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.q_dtype, + ds_major_mode, + k_trans_major_mode, + self.acc_dtype, + cta_group, + self.dsk_mma_tiler[:2], + ds_source, + ) + + self.cluster_shape_mnk = (*self.cluster_shape_mn, 1) + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), + (qk_tiled_mma.thr_id.shape,), + ) + + self.epi_tile = self.dsk_block_tiler[:2] + + q_smem_layout_staged = sm100_utils.make_smem_layout_a( + qk_tiled_mma, + self.qk_mma_tiler, + self.q_dtype, + self.q_stage, + ) + k_smem_layout_staged = sm100_utils.make_smem_layout_b( + qk_tiled_mma, + self.qk_mma_tiler, + self.k_dtype, + self.k_stage, + ) + do_smem_layout_staged = sm100_utils.make_smem_layout_a( + dov_tiled_mma, + self.dov_mma_tiler, + self.do_dtype, + self.do_stage, + ) + v_smem_layout_staged = sm100_utils.make_smem_layout_b( + dov_tiled_mma, + self.dov_mma_tiler, + self.v_dtype, + self.v_stage, + ) + ds_tmem_layout_staged = sm100_utils.make_smem_layout_a( + dsk_tiled_mma, + self.dsk_mma_tiler, + self.q_dtype, + self.qk_acc_stage, + ) + ds_tmem_layout = cute.select(ds_tmem_layout_staged, mode=[0, 1, 2]) + kt_smem_layout_staged = sm100_utils.make_smem_layout_b( + dsk_tiled_mma, + self.dsk_mma_tiler, + self.k_dtype, + self.dsk_acc_stage, + ) + + # TMA load for Q + tma_load_op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp(cta_group) + + q_smem_layout = cute.select(q_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_q, tma_tensor_q = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + q, + q_smem_layout, + self.qk_mma_tiler, + qk_tiled_mma, + self.cluster_layout_vmnk.shape, + ) + # TMA load for K + k_smem_layout = cute.select(k_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_k, tma_tensor_k = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + k, + k_smem_layout, + self.qk_mma_tiler, + qk_tiled_mma, + self.cluster_layout_vmnk.shape, + ) + # TMA load for dO + do_smem_layout = cute.select(do_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_do, tma_tensor_do = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + do, + do_smem_layout, + self.dov_mma_tiler, + dov_tiled_mma, + self.cluster_layout_vmnk.shape, + ) + # TMA load for V + v_smem_layout = cute.select(v_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_v, tma_tensor_v = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + v, + v_smem_layout, + self.dov_mma_tiler, + dov_tiled_mma, + self.cluster_layout_vmnk.shape, + ) + # TMA load for KT + kt_smem_layout = cute.select(kt_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_kt, tma_tensor_kt = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + kt, + kt_smem_layout, + self.dsk_mma_tiler, + dsk_tiled_mma, + self.cluster_layout_vmnk.shape, + ) + lse_smem_layout = cute.make_layout((self.cta_tiler[0], self.load_compute_LSE_stage)) + sum_odo_smem_layout = cute.make_layout((self.cta_tiler[0], self.load_compute_sum_OdO_stage)) + + q_copy_size = cute.size_in_bytes(self.q_dtype, q_smem_layout) + k_copy_size = cute.size_in_bytes(self.k_dtype, k_smem_layout) + do_copy_size = cute.size_in_bytes(self.do_dtype, do_smem_layout) + v_copy_size = cute.size_in_bytes(self.v_dtype, v_smem_layout) + kt_copy_size = cute.size_in_bytes(self.k_dtype, kt_smem_layout) + self.tma_copy_q_bytes = q_copy_size * cute.size(qk_tiled_mma.thr_id.shape) + self.tma_copy_k_bytes = k_copy_size * cute.size(qk_tiled_mma.thr_id.shape) + self.tma_copy_do_bytes = do_copy_size * cute.size(qk_tiled_mma.thr_id.shape) + self.tma_copy_v_bytes = v_copy_size * cute.size(qk_tiled_mma.thr_id.shape) + self.tma_copy_kt_bytes = kt_copy_size * cute.size(qk_tiled_mma.thr_id.shape) + + # dQ epilogue TMA store. Use a single (M, 64) SMEM stage and rotate it + # across the hd=256 N dimension to stay within the SMEM budget. + tma_store_op = cpasync.CopyBulkTensorTileS2GOp() + epi_cols_dQ = math.gcd(128 // (dq.element_type.width // 8), self.epi_tile[1]) + epi_tile_dQ = (self.epi_tile[0], epi_cols_dQ) + sdQ_epi_layout = sm100_utils.make_smem_layout_epi( + dq.element_type, + self.dq_layout, + epi_tile_dQ, + 1, + ) + tma_atom_dQ, tma_tensor_dQ = cpasync.make_tiled_tma_atom( + tma_store_op, + dq, + cute.select(sdQ_epi_layout, mode=[0, 1]), + epi_tile_dQ, + ) + + @cute.struct + class SharedStorage: + # TMA G2S load barriers: LOAD warp (producer) -> MMA warp (consumer) + load_q_mbar_ptr: cute.struct.MemRange[ + Int64, self.q_stage * 2 + ] # load_q_{producer,consumer} + load_do_mbar_ptr: cute.struct.MemRange[ + Int64, self.do_stage * 2 + ] # load_do_{producer,consumer} + load_k_mbar_ptr: cute.struct.MemRange[ + Int64, self.k_stage * 2 + ] # load_k_{producer,consumer} + load_kt_mbar_ptr: cute.struct.MemRange[ + Int64, self.kt_stage * 2 + ] # load_kt_{producer,consumer} + load_v_mbar_ptr: cute.struct.MemRange[ + Int64, self.v_stage * 2 + ] # load_v_{producer,consumer} + mma_s_mbar_ptr: cute.struct.MemRange[Int64, self.qk_acc_stage * 2] + mma_dp_mbar_ptr: cute.struct.MemRange[Int64, self.dov_acc_stage * 2] + mma_dq_mbar_ptr: cute.struct.MemRange[Int64, self.epi_stage * 2] + ds_mma_mbar_ptr: cute.struct.MemRange[Int64, self.dsk_acc_stage * 2] + lse_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.load_compute_LSE_stage * 2] + sum_odo_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.load_compute_sum_OdO_stage * 2 + ] + # A CTA-wide "TMEM lifetime" barrier used to safely deallocate TMEM after all users finish. + tmem_dealloc_mbar_ptr: Int64 + # Tmem holding buffer + tmem_holding_buf: Int32 + # CLC pipeline barriers and response buffer + clc_mbar_ptr: cute.struct.MemRange[Int64, 2] + clc_response: cute.struct.MemRange[Int32, 4] + + self.shared_storage = SharedStorage + + grid = cute.round_up(grid, self.cluster_shape_mnk) + # Launch the kernel synchronously + self.kernel( + qk_tiled_mma, + dov_tiled_mma, + dsk_tiled_mma, + tma_atom_q, + tma_tensor_q, + tma_atom_k, + tma_tensor_k, + tma_atom_v, + tma_tensor_v, + tma_atom_do, + tma_tensor_do, + tma_atom_kt, + tma_tensor_kt, + tma_atom_dQ, + tma_tensor_dQ, + lse, + sum_odo, + dq, + cum_seqlen_q, + cum_seqlen_k, + scale_softmax, + self.window_size_left, + self.window_size_right, + self.cluster_layout_vmnk, + q_smem_layout_staged, + k_smem_layout_staged, + v_smem_layout_staged, + do_smem_layout_staged, + kt_smem_layout_staged, + ds_tmem_layout, + sdQ_epi_layout, + lse_smem_layout, + sum_odo_smem_layout, + self.tile_sched_params, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=self.cluster_shape_mnk, + stream=stream, + min_blocks_per_mp=1, + ) + + # GPU device kernel + @cute.kernel + def kernel( + self, + qk_tiled_mma: cute.TiledMma, + dov_tiled_mma: cute.TiledMma, + dsk_tiled_mma: cute.TiledMma, + tma_atom_q: cute.CopyAtom, + mQ_qdl: cute.Tensor, + tma_atom_k: cute.CopyAtom, + mK_kdl: cute.Tensor, + tma_atom_v: cute.CopyAtom, + mV_dkl: cute.Tensor, + tma_atom_do: cute.CopyAtom, + mdO_qdl: cute.Tensor, + tma_atom_kt: cute.CopyAtom, + mK_dkl: cute.Tensor, + tma_atom_dQ: cute.CopyAtom, + mdQ_tma: cute.Tensor, + mLSE: cute.Tensor, + mSum_OdO: cute.Tensor, + mdQ_qdl: cute.Tensor, + cum_seqlen_q: Optional[cute.Tensor], + cum_seqlen_k: Optional[cute.Tensor], + scale_softmax: Float32, + window_size_left: Optional[Int32], + window_size_right: Optional[Int32], + cluster_layout_vmnk: cute.Layout, + q_smem_layout_staged: cute.ComposedLayout, + k_smem_layout_staged: cute.ComposedLayout, + v_smem_layout_staged: cute.ComposedLayout, + do_smem_layout_staged: cute.ComposedLayout, + kt_smem_layout_staged: cute.ComposedLayout, + ds_tmem_layout_staged: cute.ComposedLayout, + sdQ_epi_layout: cute.ComposedLayout, + lse_smem_layout: cute.Layout, + sum_odo_smem_layout: cute.Layout, + tile_sched_params: FmhaStaticTileSchedulerParams | FmhaClcDynamicTileSchedulerParams, + ): + # llvm.inline_asm( + # None, + # [], + # '.pragma "global knob CommonIntoMultiBlockLoop=1";', + # "", + # has_side_effects=True, + # is_align_stack=False, + # asm_dialect=llvm.AsmDialect.AD_ATT, + # ) + + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + # + # Prefetch tma desc + # + if warp_idx == self.load_warp_id: + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_q) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_k) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_v) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_do) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_kt) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_dQ) + + bidx, bidy, bidz = cute.arch.block_idx() + tidx, _, _ = cute.arch.thread_idx() + varlen = cum_seqlen_q is not None or cum_seqlen_k is not None + mma_tile_coord_v = bidx % cute.size(qk_tiled_mma.thr_id.shape) + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster) + + # Alloc + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + load_q_producer, load_q_consumer = pipeline.PipelineTmaUmma.create( + num_stages=self.q_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + tx_count=self.tma_copy_q_bytes, + barrier_storage=storage.load_q_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + load_k_producer, load_k_consumer = pipeline.PipelineTmaUmma.create( + num_stages=self.k_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + tx_count=self.tma_copy_k_bytes, + barrier_storage=storage.load_k_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + load_v_producer, load_v_consumer = pipeline.PipelineTmaUmma.create( + num_stages=self.v_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + tx_count=self.tma_copy_v_bytes, + barrier_storage=storage.load_v_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + load_do_producer, load_do_consumer = pipeline.PipelineTmaUmma.create( + num_stages=self.do_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + tx_count=self.tma_copy_do_bytes, + barrier_storage=storage.load_do_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + load_kt_producer, load_kt_consumer = pipeline.PipelineTmaUmma.create( + num_stages=self.kt_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + tx_count=self.tma_copy_kt_bytes, + barrier_storage=storage.load_kt_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + mma_s_producer, mma_s_consumer = pipeline.PipelineUmmaAsync.create( + num_stages=self.qk_acc_stage, + producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + consumer_group=make_thread_cooperative_group( + len(self.compute_warp_ids) * self.threads_per_warp * self.cluster_shape_mnk[0], + ), + barrier_storage=storage.mma_s_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + mma_dp_producer, mma_dp_consumer = pipeline.PipelineUmmaAsync.create( + num_stages=self.dov_acc_stage, + producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + consumer_group=make_thread_cooperative_group( + len(self.compute_warp_ids) * self.threads_per_warp * self.cluster_shape_mnk[0], + ), + barrier_storage=storage.mma_dp_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + ds_mma_producer, ds_mma_consumer = pipeline.PipelineAsyncUmma.create( + num_stages=self.dsk_acc_stage, + producer_group=make_thread_cooperative_group( + len(self.compute_warp_ids) * self.threads_per_warp * self.cluster_shape_mnk[0], + ), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + barrier_storage=storage.ds_mma_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + mma_dq_producer, mma_dq_consumer = pipeline.PipelineUmmaAsync.create( + num_stages=self.epi_stage, + producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + consumer_group=make_thread_cooperative_group( + len(self.epilogue_warp_ids) * self.threads_per_warp * self.cluster_shape_mnk[0], + ), + barrier_storage=storage.mma_dq_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + + load_lse_producer, load_lse_consumer = pipeline.PipelineCpAsync.create( + num_stages=self.load_compute_LSE_stage, + producer_group=make_thread_cooperative_group(self.threads_per_warp), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * self.num_compute_warps + ), + barrier_storage=storage.lse_mbar_ptr.data_ptr(), + ).make_participants() + load_sum_odo_producer, load_sum_odo_consumer = pipeline.PipelineCpAsync.create( + num_stages=self.load_compute_sum_OdO_stage, + producer_group=make_thread_cooperative_group(self.threads_per_warp), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * self.num_compute_warps + ), + barrier_storage=storage.sum_odo_mbar_ptr.data_ptr(), + ).make_participants() + + # Tensor memory dealloc barrier init + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=self.tmem_alloc_barrier, + allocator_warp_id=self.epilogue_warp_ids[0], + is_two_cta=True, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + tmem.allocate(self.tmem_alloc_cols) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + + # Initialize CLC state if using dynamic scheduler + if cutlass.const_expr(self.use_clc_scheduler): + clc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + cluster_size = cute.size(self.cluster_shape_mnk) + num_clc_consumer_threads = self.threads_per_warp * ( + 1 # sched_warp (CTA 0 only) + + cluster_size + * ( + len(self.compute_warp_ids) + + len(self.epilogue_warp_ids) + + 1 # mma_warp + + 1 # load_warp + ) + ) + clc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_clc_consumer_threads + ) + clc_response_ptr = storage.clc_response.data_ptr() + clc = ClcState.create( + hw_scheduler=ClcDynamicPersistentTileScheduler.create( + self.tile_sched_params.clc_hw_params(), + cute.arch.block_idx(), + cute.arch.grid_dim(), + clc_response_ptr, + ), + pipeline=pipeline.PipelineClcFetchAsync.create( + barrier_storage=storage.clc_mbar_ptr.data_ptr(), + num_stages=self.num_clc_stage, + producer_group=clc_pipeline_producer_group, + consumer_group=clc_pipeline_consumer_group, + tx_count=self.num_clc_response_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ), + consumer_state=pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_clc_stage + ), + producer_state=pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_clc_stage + ), + ) + else: + clc = None + clc_response_ptr = None + + # Cluster arrive after barrier init + pipeline.pipeline_init_arrive(cluster_shape_mn=cluster_layout_vmnk, is_relaxed=True) + + sQ = smem.allocate_tensor( + element_type=self.q_dtype, + layout=q_smem_layout_staged.outer, + swizzle=q_smem_layout_staged.inner, + byte_alignment=128, + ) + sK = smem.allocate_tensor( + element_type=self.k_dtype, + layout=k_smem_layout_staged.outer, + swizzle=k_smem_layout_staged.inner, + byte_alignment=128, + ) + # K and V now use separate memory since we removed the transform stage + sV = smem.allocate_tensor( + element_type=self.v_dtype, + layout=v_smem_layout_staged.outer, + swizzle=v_smem_layout_staged.inner, + byte_alignment=128, + ) + sdO = smem.allocate_tensor( + element_type=self.do_dtype, + layout=do_smem_layout_staged.outer, + swizzle=do_smem_layout_staged.inner, + byte_alignment=128, + ) + sKT = smem.allocate_tensor( + element_type=self.k_dtype, + layout=kt_smem_layout_staged.outer, + swizzle=kt_smem_layout_staged.inner, + byte_alignment=128, + ) + sLSE = smem.allocate_tensor( + element_type=self.acc_dtype, + layout=lse_smem_layout, + byte_alignment=128, + ) + sSum_OdO = smem.allocate_tensor( + element_type=self.acc_dtype, + layout=sum_odo_smem_layout, + byte_alignment=128, + ) + # Alias the dQ TMA epilogue staging buffer onto sdO. A standalone + # (128, 64) bf16 allocation exceeds B200's per-block SMEM limit for + # this kernel, while sdO is no longer needed by the epilogue warp once + # the dQ accumulator is ready to store. + s_epi_dQ = cute.make_tensor( + cute.recast_ptr(sdO.iterator, sdQ_epi_layout.inner, self.dq_dtype), + sdQ_epi_layout.outer, + ) + qk_thr_mma = qk_tiled_mma.get_slice(mma_tile_coord_v) # default 1sm + dov_thr_mma = dov_tiled_mma.get_slice(mma_tile_coord_v) # default 1sm + dsk_thr_mma = dsk_tiled_mma.get_slice(mma_tile_coord_v) # default 1sm + tSrQ = qk_thr_mma.make_fragment_A(sQ) + tSrK = qk_thr_mma.make_fragment_B(sK) + tdPrdO = dov_thr_mma.make_fragment_A(sdO) + tdPrV = dov_thr_mma.make_fragment_B(sV) + tdQrKT = dsk_thr_mma.make_fragment_B(sKT) + qk_acc_shape = qk_thr_mma.partition_shape_C((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tStS = qk_thr_mma.make_fragment_C(cute.append(qk_acc_shape, self.qk_acc_stage)) + dov_acc_shape = dov_thr_mma.partition_shape_C( + (self.dov_mma_tiler[0], self.dov_mma_tiler[1]) + ) + tdPtdP = dov_thr_mma.make_fragment_C(cute.append(dov_acc_shape, self.dov_acc_stage)) + dsk_acc_shape = dsk_thr_mma.partition_shape_C( + (self.dsk_mma_tiler[0], self.dsk_mma_tiler[1]) + ) + tdQtdQ = dsk_thr_mma.make_fragment_C(dsk_acc_shape) + tdQtdQ_layout = cute.append( + tdQtdQ.layout, + cute.make_layout( + self.iterations_dsk, + stride=self.dsk_mma_tiler[1] // self.tmem_warp_shape_mn[1], + ), + ) + tStS = cute.make_tensor(tStS.iterator + self.tmem_s_offset, tStS.layout) + tdPtdP = cute.make_tensor(tdPtdP.iterator + self.tmem_dp_offset, tdPtdP.layout) + tdQtdQ_staged = cute.make_tensor(tdQtdQ.iterator + self.tmem_dq_offset, tdQtdQ_layout) + + # /////////////////////////////////////////////////////////////////////////////// + # EMPTY + # /////////////////////////////////////////////////////////////////////////////// + for _i in cutlass.range_constexpr(len(self.empty_warp_id)): + if warp_idx == self.empty_warp_id[_i]: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + + if cutlass.const_expr(self.use_clc_scheduler): + tile_sched = FmhaClcDynamicTileScheduler.create( + tile_sched_params, + cute.arch.block_idx(), + cute.arch.grid_dim(), + clc_response_ptr, + clc, + ) + else: + blk_idx = cute.arch.block_idx() + tile_sched = FmhaStaticTileScheduler( + tile_sched_params, blk_idx[0], blk_idx, cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + # Cluster wait + pipeline.pipeline_init_wait(cluster_shape_mn=cluster_layout_vmnk) + + # /////////////////////////////////////////////////////////////////////////////// + # LOAD + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx == self.load_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + + while work_tile.is_valid_tile: + curr_block_coord = work_tile.tile_idx + mma_block_coord = ( + curr_block_coord[0] // cute.size(qk_tiled_mma.thr_id.shape), + curr_block_coord[1], + curr_block_coord[2], + ) + continue_cond = False + batch_coord = curr_block_coord[2][1] + seqlen_q = mQ_qdl.shape[0] + seqlen_k = mK_kdl.shape[0] + cuseqlen_q = Int32(0) + cuseqlen_k = Int32(0) + block_offset = ( + Int32(0), + Int32(0), + Int32(0), + ((Int32(0), Int32(0)), Int32(0)), + ) + if cutlass.const_expr(cum_seqlen_q is not None): + cuseqlen_q = cum_seqlen_q[batch_coord] + seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + block_offset = ( + cuseqlen_q, + cuseqlen_k, + Int32(0), + ((Int32(0), Int32(0)), Int32(0)), + ) + continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.qk_mma_tiler[0], + mma_block_coord[0], + seqlen_q, + ) + if not continue_cond: + mQ_qdl_ = cute.domain_offset(cute.select(block_offset, mode=[0, 2, 3]), mQ_qdl) + mK_kdl_ = cute.domain_offset(cute.select(block_offset, mode=[1, 2, 3]), mK_kdl) + mdO_qdl_ = cute.domain_offset( + cute.select(block_offset, mode=[0, 2, 3]), mdO_qdl + ) + mV_dkl_ = cute.domain_offset(cute.select(block_offset, mode=[1, 2, 3]), mV_dkl) + mK_dkl_ = cute.domain_offset(cute.select(block_offset, mode=[2, 1, 3]), mK_dkl) + block_offset_stats = block_offset + if cutlass.const_expr(cum_seqlen_q is not None): + cuseqlen_q_stats = cute.assume( + (cuseqlen_q + batch_coord * self.cta_tiler[0]) + // self.cta_tiler[0] + * self.cta_tiler[0], + divby=self.cta_tiler[0], + ) + block_offset_stats = ( + cuseqlen_q_stats, + block_offset[1], + block_offset[2], + block_offset[3], + ) + LSE = cute.domain_offset(cute.select(block_offset_stats, mode=[0, 3]), mLSE) + sum_OdO = cute.domain_offset( + cute.select(block_offset_stats, mode=[0, 3]), mSum_OdO + ) + + # Local tile partition global tensors + q_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + # (bM, bK, loopM, loopK, loopL) + gQ_qdl = cute.flat_divide(mQ_qdl_, cute.select(self.qk_mma_tiler, mode=[0, 2])) + tSgQ_qdl = qk_thr_mma.partition_A(gQ_qdl) + tQsQ, tQgQ_qdl = cute.nvgpu.cpasync.tma_partition( + tma_atom_q, + block_in_cluster_coord_vmnk[2], + q_cta_layout, + cute.group_modes(sQ, 0, 3), + cute.group_modes(tSgQ_qdl, 0, 3), + ) + k_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + gK_kdl = cute.flat_divide(mK_kdl_, cute.select(self.qk_mma_tiler, mode=[1, 2])) + tSgK_kdl = qk_thr_mma.partition_B(gK_kdl) + tKsK, tKgK_kdl = cute.nvgpu.cpasync.tma_partition( + tma_atom_k, + block_in_cluster_coord_vmnk[1], + k_cta_layout, + cute.group_modes(sK, 0, 3), + cute.group_modes(tSgK_kdl, 0, 3), + ) + do_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + # (bM, bK, loopM, loopK, loopL) + gdO_qdl = cute.flat_divide( + mdO_qdl_, cute.select(self.dov_mma_tiler, mode=[0, 2]) + ) + tdPgdO_qdl = dov_thr_mma.partition_A(gdO_qdl) + tdOsdO, tdOgdO_qdl = cute.nvgpu.cpasync.tma_partition( + tma_atom_do, + block_in_cluster_coord_vmnk[2], + do_cta_layout, + cute.group_modes(sdO, 0, 3), + cute.group_modes(tdPgdO_qdl, 0, 3), + ) + v_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + gV_dkl = cute.flat_divide(mV_dkl_, cute.select(self.dov_mma_tiler, mode=[1, 2])) + tSgV_dkl = dov_thr_mma.partition_B(gV_dkl) + tVsV, tVgV_dkl = cute.nvgpu.cpasync.tma_partition( + tma_atom_v, + block_in_cluster_coord_vmnk[1], + v_cta_layout, + cute.group_modes(sV, 0, 3), + cute.group_modes(tSgV_dkl, 0, 3), + ) + # kt layout + kt_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + gK_dkl = cute.flat_divide(mK_dkl_, cute.select(self.dsk_mma_tiler, mode=[1, 2])) + tdQgK_dkl = dsk_thr_mma.partition_B(gK_dkl) + tKTsKT, tKgK_dkl = cute.nvgpu.cpasync.tma_partition( + tma_atom_kt, + block_in_cluster_coord_vmnk[1], + kt_cta_layout, + cute.group_modes(sKT, 0, 3), + cute.group_modes(tdQgK_dkl, 0, 3), + ) + # ((atom_v, rest_v), RestK) + tQgQ = tQgQ_qdl[None, mma_block_coord[0], None, mma_block_coord[2]] + # ((atom_v, rest_v), RestK) + tdOgdO = tdOgdO_qdl[None, mma_block_coord[0], None, mma_block_coord[2]] + # ((atom_v, rest_v), RestN, RestK) + tKgK = tKgK_kdl[None, None, None, mma_block_coord[2]] + # ((atom_v, rest_v), RestN, RestK) + tVgV = tVgV_dkl[None, None, None, mma_block_coord[2]] + # ((atom_v, rest_v), RestN, RestK) + tKTgKT = tKgK_dkl[None, None, None, mma_block_coord[2]] + + seqlen_kv_loop_start, seqlen_kv_loop_steps = ( + FusedMask.get_trip_start_count_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + ) + # LSE + lse_handle = load_lse_producer.acquire_and_advance() + # 32 threads loading 128 values of 32b each + # so 4*32b = 128b + thread_idx = tidx % self.threads_per_warp + async_copy_num_elts = sLSE.shape[0] // self.threads_per_warp + atom_async_copy = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.ALWAYS), + self.acc_dtype, + num_bits_per_copy=self.acc_dtype.width, + ) + sLSE_for_copy = cute.flat_divide(sLSE, (1,)) + LSE_for_copy = cute.flat_divide(LSE, (1,)) + for i in cutlass.range_constexpr(async_copy_num_elts): + LSE_idx = ( + self.cta_tiler[0] * curr_block_coord[0] + + thread_idx * async_copy_num_elts + ) + if cute.elem_less(LSE_idx + i, seqlen_q): + cute.copy( + atom_async_copy, + LSE_for_copy[None, LSE_idx + i, curr_block_coord[2]], + sLSE_for_copy[ + None, + thread_idx * async_copy_num_elts + i, + lse_handle.index, + ], + ) + else: + sLSE_for_copy[ + None, + thread_idx * async_copy_num_elts + i, + lse_handle.index, + ].fill(0.0) + lse_handle.commit() + + sum_odo_handle = load_sum_odo_producer.acquire_and_advance() + sSum_OdO_for_copy = cute.flat_divide(sSum_OdO, (1,)) + sum_OdO_for_copy = cute.flat_divide(sum_OdO, (1,)) + for i in cutlass.range_constexpr(async_copy_num_elts): + sum_OdO_idx = ( + self.cta_tiler[0] * curr_block_coord[0] + + thread_idx * async_copy_num_elts + ) + if cute.elem_less(sum_OdO_idx + i, seqlen_q): + cute.copy( + atom_async_copy, + sum_OdO_for_copy[None, sum_OdO_idx + i, curr_block_coord[2]], + sSum_OdO_for_copy[ + None, + thread_idx * async_copy_num_elts + i, + sum_odo_handle.index, + ], + ) + else: + sSum_OdO_for_copy[ + None, + thread_idx * async_copy_num_elts + i, + sum_odo_handle.index, + ].fill(0.0) + sum_odo_handle.commit() + + # Q + for iter in cutlass.range(self.iterations_qk, unroll=1): + q_handle = load_q_producer.acquire_and_advance() + cute.copy( + tma_atom_q, + tQgQ[None, iter], + tQsQ[None, q_handle.index], + tma_bar_ptr=q_handle.barrier, + ) + # dO + for iter in cutlass.range(self.iterations_dov, unroll=1): + do_handle = load_do_producer.acquire_and_advance() + cute.copy( + tma_atom_do, + tdOgdO[None, iter], + tdOsdO[None, do_handle.index], + tma_bar_ptr=do_handle.barrier, + ) + + kv_coord = seqlen_kv_loop_start + for i in cutlass.range(0, seqlen_kv_loop_steps, 1, unroll=1): + # Ki + for iter in cutlass.range(self.iterations_qk, unroll=1): + k_handle = load_k_producer.acquire_and_advance() + cute.copy( + tma_atom_k, + tKgK[None, kv_coord, iter], + tKsK[None, k_handle.index], + tma_bar_ptr=k_handle.barrier, + ) + # Vi + for iter in cutlass.range(self.iterations_dov, unroll=1): + v_handle = load_v_producer.acquire_and_advance() + cute.copy( + tma_atom_v, + tVgV[None, kv_coord, iter], + tVsV[None, v_handle.index], + tma_bar_ptr=v_handle.barrier, + ) + # KTi + for iter in cutlass.range(self.iterations_dsk, unroll=1): + kt_handle = load_kt_producer.acquire_and_advance() + cute.copy( + tma_atom_kt, + tKTgKT[None, iter, kv_coord], + tKTsKT[None, kt_handle.index], + tma_bar_ptr=kt_handle.barrier, + ) + kv_coord += 1 + + work_tile = tile_sched.advance_to_next_work() + # End of persistent scheduler loop + load_k_producer.tail() + load_v_producer.tail() + load_kt_producer.tail() + load_q_producer.tail() + load_do_producer.tail() + load_lse_producer.tail() + load_sum_odo_producer.tail() + + # /////////////////////////////////////////////////////////////////////////////// + # MMA + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx == self.mma_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + + while work_tile.is_valid_tile: + curr_block_coord = work_tile.tile_idx + mma_block_coord = ( + curr_block_coord[0] // cute.size(qk_tiled_mma.thr_id.shape), + curr_block_coord[1], + curr_block_coord[2], + ) + continue_cond = False + seqlen_q = mQ_qdl.shape[0] + seqlen_k = mK_kdl.shape[0] + batch_coord = curr_block_coord[2][1] + if cutlass.const_expr(cum_seqlen_q is not None): + cuseqlen_q = cum_seqlen_q[batch_coord] + seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q + continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.qk_mma_tiler[0], + mma_block_coord[0], + seqlen_q, + ) + + if not continue_cond: + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + + seqlen_kv_loop_start, seqlen_kv_loop_steps = ( + FusedMask.get_trip_start_count_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + ) + + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + is_leader_cta = cta_rank_in_cluster % 2 == 0 + # dq_handle = mma_dq_producer.acquire_and_advance() + load_q_releaser = load_q_consumer.clone() + load_do_releaser = load_do_consumer.clone() + dsk_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + + num_innerloop = 8 + + if is_leader_cta: + dq_handle = mma_dq_producer.acquire_and_advance() + if seqlen_kv_loop_steps > 1: + # QK0 + s_handle = mma_s_producer.acquire_and_advance() + tStS_slice = tStS[None, None, None, s_handle.index] + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + for iter in cutlass.range(self.iterations_qk, unroll=1): + load_q_consumer.wait_and_advance() + tSrQ_slice = tSrQ[None, None, None, iter] + + k_handle = load_k_consumer.wait_and_advance() + tSrK_trans_slice = tSrK[None, None, None, k_handle.index] + num_kphases = cute.size(tSrQ_slice, mode=[2]) + if cutlass.const_expr(num_kphases % num_innerloop == 0): + num_outer_iter = num_kphases // num_innerloop + for outer_iter in cutlass.range(num_outer_iter, unroll=1): + for kphase_idx in cutlass.range( + num_innerloop, unroll_full=True + ): + kphase_coord = ( + None, + None, + outer_iter * num_innerloop + kphase_idx, + ) + cute.gemm( + qk_tiled_mma, + tStS_slice, + tSrQ_slice[kphase_coord], + tSrK_trans_slice[kphase_coord], + tStS_slice, + ) + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + else: + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord = (None, None, kphase_idx) + cute.gemm( + qk_tiled_mma, + tStS_slice, + tSrQ_slice[kphase_coord], + tSrK_trans_slice[kphase_coord], + tStS_slice, + ) + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + k_handle.release() + cute.arch.fence_view_async_tmem_store() + s_handle.commit() + + # dOV0 + dp_handle = mma_dp_producer.acquire_and_advance() + tdPtdP_slice = tdPtdP[None, None, None, dp_handle.index] + dov_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + for iter in cutlass.range(self.iterations_dov, unroll=1): + load_do_consumer.wait_and_advance() + tdPrdO_slice = tdPrdO[None, None, None, iter] + v_handle = load_v_consumer.wait_and_advance() + tdPrV_trans_slice = tdPrV[None, None, None, v_handle.index] + num_kphases = cute.size(tdPrdO_slice, mode=[2]) + if cutlass.const_expr(num_kphases % num_innerloop == 0): + num_outer_iter = num_kphases // num_innerloop + for outer_iter in cutlass.range(num_outer_iter, unroll=1): + for kphase_idx in cutlass.range( + num_innerloop, unroll_full=True + ): + kphase_coord = ( + None, + None, + outer_iter * num_innerloop + kphase_idx, + ) + cute.gemm( + dov_tiled_mma, + tdPtdP_slice, + tdPrdO_slice[kphase_coord], + tdPrV_trans_slice[kphase_coord], + tdPtdP_slice, + ) + dov_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + else: + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord = (None, None, kphase_idx) + cute.gemm( + dov_tiled_mma, + tdPtdP_slice, + tdPrdO_slice[kphase_coord], + tdPrV_trans_slice[kphase_coord], + tdPtdP_slice, + ) + dov_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + v_handle.release() + cute.arch.fence_view_async_tmem_store() + dp_handle.commit() + + for i in cutlass.range(1, seqlen_kv_loop_steps - 1, 1, unroll=1): + # QKi + s_handle = mma_s_producer.acquire_and_advance() + + tStS_slice = tStS[None, None, None, s_handle.index] + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + for iter in cutlass.range(self.iterations_qk, unroll=1): + tSrQ_slice = tSrQ[None, None, None, iter] + k_handle = load_k_consumer.wait_and_advance() + tSrK_trans_slice = tSrK[None, None, None, k_handle.index] + num_kphases = cute.size(tSrQ_slice, mode=[2]) + if cutlass.const_expr(num_kphases % num_innerloop == 0): + num_outer_iter = num_kphases // num_innerloop + for outer_iter in cutlass.range(num_outer_iter, unroll=1): + for kphase_idx in cutlass.range( + num_innerloop, unroll_full=True + ): + kphase_coord = ( + None, + None, + outer_iter * num_innerloop + kphase_idx, + ) + cute.gemm( + qk_tiled_mma, + tStS_slice, + tSrQ_slice[kphase_coord], + tSrK_trans_slice[kphase_coord], + tStS_slice, + ) + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + else: + for kphase_idx in cutlass.range( + num_kphases, unroll_full=True + ): + kphase_coord = (None, None, kphase_idx) + cute.gemm( + qk_tiled_mma, + tStS_slice, + tSrQ_slice[kphase_coord], + tSrK_trans_slice[kphase_coord], + tStS_slice, + ) + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + k_handle.release() + s_handle.commit() + + # dSKTi + ds_handle = ds_mma_consumer.wait_and_advance() + dsk_whether_acc = dsk_tiled_mma.get(tcgen05.Field.ACCUMULATE) + for iter in cutlass.range(self.iterations_dsk, unroll=1): + kt_handle = load_kt_consumer.wait_and_advance() + dsk_tiled_mma.set(tcgen05.Field.ACCUMULATE, dsk_whether_acc) + tdQtdQ_slice = tdQtdQ_staged[None, None, None, iter] + tdStdS_slice = tdPtdP[None, None, None, ds_handle.index] + tdS = cute.make_tensor( + tdStdS_slice.iterator, ds_tmem_layout_staged.outer + ) + tdQrdS = dsk_thr_mma.make_fragment_A(tdS) + tdQrdS_slice = cute.make_tensor( + cute.recast_ptr(tdStdS_slice.iterator, dtype=self.q_dtype), + tdQrdS.layout, + ) + + tdQrKT_slice = tdQrKT[None, None, None, kt_handle.index] + num_kphases = cute.size(tdQrKT_slice, mode=[2]) + if cutlass.const_expr(num_kphases % num_innerloop == 0): + num_outer_iter = num_kphases // num_innerloop + for outer_iter in cutlass.range(num_outer_iter, unroll=1): + for kphase_idx in cutlass.range( + num_innerloop, unroll_full=True + ): + kphase_coord = ( + None, + None, + outer_iter * num_innerloop + kphase_idx, + ) + cute.gemm( + dsk_tiled_mma, + tdQtdQ_slice, + tdQrdS_slice[kphase_coord], + tdQrKT_slice[kphase_coord], + tdQtdQ_slice, + ) + dsk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + else: + for kphase_idx in cutlass.range( + num_kphases, unroll_full=True + ): + kphase_coord = (None, None, kphase_idx) + cute.gemm( + dsk_tiled_mma, + tdQtdQ_slice, + tdQrdS_slice[kphase_coord], + tdQrKT_slice[kphase_coord], + tdQtdQ_slice, + ) + dsk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + kt_handle.release() + ds_handle.release() + + # dOVi + dp_handle = mma_dp_producer.acquire_and_advance() + tdPtdP_slice = tdPtdP[None, None, None, dp_handle.index] + dov_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + for iter in cutlass.range(self.iterations_dov, unroll=1): + tdPrdO_slice = tdPrdO[None, None, None, iter] + v_handle = load_v_consumer.wait_and_advance() + tdPrV_trans_slice = tdPrV[None, None, None, v_handle.index] + num_kphases = cute.size(tdPrdO_slice, mode=[2]) + if cutlass.const_expr(num_kphases % num_innerloop == 0): + num_outer_iter = num_kphases // num_innerloop + for outer_iter in cutlass.range(num_outer_iter, unroll=1): + for kphase_idx in cutlass.range( + num_innerloop, unroll_full=True + ): + kphase_coord = ( + None, + None, + outer_iter * num_innerloop + kphase_idx, + ) + cute.gemm( + dov_tiled_mma, + tdPtdP_slice, + tdPrdO_slice[kphase_coord], + tdPrV_trans_slice[kphase_coord], + tdPtdP_slice, + ) + dov_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + else: + for kphase_idx in cutlass.range( + num_kphases, unroll_full=True + ): + kphase_coord = (None, None, kphase_idx) + cute.gemm( + dov_tiled_mma, + tdPtdP_slice, + tdPrdO_slice[kphase_coord], + tdPrV_trans_slice[kphase_coord], + tdPtdP_slice, + ) + dov_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + v_handle.release() + dp_handle.commit() + + # QKend + s_handle = mma_s_producer.acquire_and_advance() + tStS_slice = tStS[None, None, None, s_handle.index] + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + for iter in cutlass.range(self.iterations_qk, unroll=1): + tSrQ_slice = tSrQ[None, None, None, iter] + k_handle = load_k_consumer.wait_and_advance() + + tSrK_trans_slice = tSrK[None, None, None, k_handle.index] + + num_kphases = cute.size(tSrQ_slice, mode=[2]) + if cutlass.const_expr(num_kphases % num_innerloop == 0): + num_outer_iter = num_kphases // num_innerloop + for outer_iter in cutlass.range(num_outer_iter, unroll=1): + for kphase_idx in cutlass.range( + num_innerloop, unroll_full=True + ): + kphase_coord = ( + None, + None, + outer_iter * num_innerloop + kphase_idx, + ) + cute.gemm( + qk_tiled_mma, + tStS_slice, + tSrQ_slice[kphase_coord], + tSrK_trans_slice[kphase_coord], + tStS_slice, + ) + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + else: + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord = (None, None, kphase_idx) + cute.gemm( + qk_tiled_mma, + tStS_slice, + tSrQ_slice[kphase_coord], + tSrK_trans_slice[kphase_coord], + tStS_slice, + ) + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + k_handle.release() + load_q_releaser.release() + load_q_releaser.advance() + s_handle.commit() + + # dSKTend - 1 + ds_handle = ds_mma_consumer.wait_and_advance() + dsk_whether_acc = dsk_tiled_mma.get(tcgen05.Field.ACCUMULATE) + for iter in cutlass.range(self.iterations_dsk, unroll=1): + kt_handle = load_kt_consumer.wait_and_advance() + dsk_tiled_mma.set(tcgen05.Field.ACCUMULATE, dsk_whether_acc) + tdQtdQ_slice = tdQtdQ_staged[None, None, None, iter] + tdStdS_slice = tdPtdP[None, None, None, ds_handle.index] + tdS = cute.make_tensor( + tdStdS_slice.iterator, ds_tmem_layout_staged.outer + ) + tdQrdS = dsk_thr_mma.make_fragment_A(tdS) + tdQrdS_slice = cute.make_tensor( + cute.recast_ptr(tdStdS_slice.iterator, dtype=self.q_dtype), + tdQrdS.layout, + ) + + tdQrKT_slice = tdQrKT[None, None, None, kt_handle.index] + num_kphases = cute.size(tdQrKT_slice, mode=[2]) + if cutlass.const_expr(num_kphases % num_innerloop == 0): + num_outer_iter = num_kphases // num_innerloop + for outer_iter in cutlass.range(num_outer_iter, unroll=1): + for kphase_idx in cutlass.range( + num_innerloop, unroll_full=True + ): + kphase_coord = ( + None, + None, + outer_iter * num_innerloop + kphase_idx, + ) + cute.gemm( + dsk_tiled_mma, + tdQtdQ_slice, + tdQrdS_slice[kphase_coord], + tdQrKT_slice[kphase_coord], + tdQtdQ_slice, + ) + dsk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + else: + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord = (None, None, kphase_idx) + cute.gemm( + dsk_tiled_mma, + tdQtdQ_slice, + tdQrdS_slice[kphase_coord], + tdQrKT_slice[kphase_coord], + tdQtdQ_slice, + ) + dsk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + kt_handle.release() + ds_handle.release() + + # dOVend + dp_handle = mma_dp_producer.acquire_and_advance() + tdPtdP_slice = tdPtdP[None, None, None, dp_handle.index] + dov_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + for iter in cutlass.range(self.iterations_dov, unroll=1): + tdPrdO_slice = tdPrdO[None, None, None, iter] + v_handle = load_v_consumer.wait_and_advance() + tdPrV_trans_slice = tdPrV[None, None, None, v_handle.index] + num_kphases = cute.size(tdPrdO_slice, mode=[2]) + if cutlass.const_expr(num_kphases % num_innerloop == 0): + num_outer_iter = num_kphases // num_innerloop + for outer_iter in cutlass.range(num_outer_iter, unroll=1): + for kphase_idx in cutlass.range( + num_innerloop, unroll_full=True + ): + kphase_coord = ( + None, + None, + outer_iter * num_innerloop + kphase_idx, + ) + cute.gemm( + dov_tiled_mma, + tdPtdP_slice, + tdPrdO_slice[kphase_coord], + tdPrV_trans_slice[kphase_coord], + tdPtdP_slice, + ) + dov_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + else: + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord = (None, None, kphase_idx) + cute.gemm( + dov_tiled_mma, + tdPtdP_slice, + tdPrdO_slice[kphase_coord], + tdPrV_trans_slice[kphase_coord], + tdPtdP_slice, + ) + dov_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + v_handle.release() + load_do_releaser.release() + load_do_releaser.advance() + dp_handle.commit() + # dSKTend + ds_handle = ds_mma_consumer.wait_and_advance() + dsk_whether_acc = dsk_tiled_mma.get(tcgen05.Field.ACCUMULATE) + for iter in cutlass.range(self.iterations_dsk, unroll=1): + kt_handle = load_kt_consumer.wait_and_advance() + dsk_tiled_mma.set(tcgen05.Field.ACCUMULATE, dsk_whether_acc) + tdQtdQ_slice = tdQtdQ_staged[None, None, None, iter] + tdStdS_slice = tdPtdP[None, None, None, ds_handle.index] + tdS = cute.make_tensor( + tdStdS_slice.iterator, ds_tmem_layout_staged.outer + ) + tdQrdS = dsk_thr_mma.make_fragment_A(tdS) + tdQrdS_slice = cute.make_tensor( + cute.recast_ptr(tdStdS_slice.iterator, dtype=self.q_dtype), + tdQrdS.layout, + ) + + tdQrKT_slice = tdQrKT[None, None, None, kt_handle.index] + num_kphases = cute.size(tdQrKT_slice, mode=[2]) + if cutlass.const_expr(num_kphases % num_innerloop == 0): + num_outer_iter = num_kphases // num_innerloop + for outer_iter in cutlass.range(num_outer_iter, unroll=1): + for kphase_idx in cutlass.range( + num_innerloop, unroll_full=True + ): + kphase_coord = ( + None, + None, + outer_iter * num_innerloop + kphase_idx, + ) + cute.gemm( + dsk_tiled_mma, + tdQtdQ_slice, + tdQrdS_slice[kphase_coord], + tdQrKT_slice[kphase_coord], + tdQtdQ_slice, + ) + dsk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + else: + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord = (None, None, kphase_idx) + cute.gemm( + dsk_tiled_mma, + tdQtdQ_slice, + tdQrdS_slice[kphase_coord], + tdQrKT_slice[kphase_coord], + tdQtdQ_slice, + ) + dsk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + kt_handle.release() + ds_handle.release() + else: + # QK0 + s_handle = mma_s_producer.acquire_and_advance() + tStS_slice = tStS[None, None, None, s_handle.index] + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + + for iter in cutlass.range(self.iterations_qk, unroll=1): + load_q_consumer.wait_and_advance() + tSrQ_slice = tSrQ[None, None, None, iter] + k_handle = load_k_consumer.wait_and_advance() + tSrK_trans_slice = tSrK[None, None, None, k_handle.index] + num_kphases = cute.size(tSrQ_slice, mode=[2]) + if cutlass.const_expr(num_kphases % num_innerloop == 0): + num_outer_iter = num_kphases // num_innerloop + for outer_iter in cutlass.range(num_outer_iter, unroll=1): + for kphase_idx in cutlass.range( + num_innerloop, unroll_full=True + ): + kphase_coord = ( + None, + None, + outer_iter * num_innerloop + kphase_idx, + ) + cute.gemm( + qk_tiled_mma, + tStS_slice, + tSrQ_slice[kphase_coord], + tSrK_trans_slice[kphase_coord], + tStS_slice, + ) + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + else: + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord = (None, None, kphase_idx) + cute.gemm( + qk_tiled_mma, + tStS_slice, + tSrQ_slice[kphase_coord], + tSrK_trans_slice[kphase_coord], + tStS_slice, + ) + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + k_handle.release() + load_q_releaser.release() + load_q_releaser.advance() + s_handle.commit() + + # dOV0 + dp_handle = mma_dp_producer.acquire_and_advance() + tdPtdP_slice = tdPtdP[None, None, None, dp_handle.index] + dov_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + for iter in cutlass.range(self.iterations_dov, unroll=1): + load_do_consumer.wait_and_advance() + tdPrdO_slice = tdPrdO[None, None, None, iter] + v_handle = load_v_consumer.wait_and_advance() + tdPrV_trans_slice = tdPrV[None, None, None, v_handle.index] + num_kphases = cute.size(tdPrdO_slice, mode=[2]) + if cutlass.const_expr(num_kphases % num_innerloop == 0): + num_outer_iter = num_kphases // num_innerloop + for outer_iter in cutlass.range(num_outer_iter, unroll=1): + for kphase_idx in cutlass.range( + num_innerloop, unroll_full=True + ): + kphase_coord = ( + None, + None, + outer_iter * num_innerloop + kphase_idx, + ) + cute.gemm( + dov_tiled_mma, + tdPtdP_slice, + tdPrdO_slice[kphase_coord], + tdPrV_trans_slice[kphase_coord], + tdPtdP_slice, + ) + dov_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + else: + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord = (None, None, kphase_idx) + cute.gemm( + dov_tiled_mma, + tdPtdP_slice, + tdPrdO_slice[kphase_coord], + tdPrV_trans_slice[kphase_coord], + tdPtdP_slice, + ) + dov_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + v_handle.release() + load_do_releaser.release() + load_do_releaser.advance() + dp_handle.commit() + + # dSKT0 + ds_handle = ds_mma_consumer.wait_and_advance() + dsk_whether_acc = dsk_tiled_mma.get(tcgen05.Field.ACCUMULATE) + for iter in cutlass.range(self.iterations_dsk, unroll=1): + kt_handle = load_kt_consumer.wait_and_advance() + dsk_tiled_mma.set(tcgen05.Field.ACCUMULATE, dsk_whether_acc) + tdQtdQ_slice = tdQtdQ_staged[None, None, None, iter] + tdStdS_slice = tdPtdP[None, None, None, ds_handle.index] + tdS = cute.make_tensor( + tdStdS_slice.iterator, ds_tmem_layout_staged.outer + ) + tdQrdS = dsk_thr_mma.make_fragment_A(tdS) + tdQrdS_slice = cute.make_tensor( + cute.recast_ptr(tdStdS_slice.iterator, dtype=self.q_dtype), + tdQrdS.layout, + ) + + tdQrKT_slice = tdQrKT[None, None, None, kt_handle.index] + num_kphases = cute.size(tdQrKT_slice, mode=[2]) + if cutlass.const_expr(num_kphases % num_innerloop == 0): + num_outer_iter = num_kphases // num_innerloop + for outer_iter in cutlass.range(num_outer_iter, unroll=1): + for kphase_idx in cutlass.range( + num_innerloop, unroll_full=True + ): + kphase_coord = ( + None, + None, + outer_iter * num_innerloop + kphase_idx, + ) + cute.gemm( + dsk_tiled_mma, + tdQtdQ_slice, + tdQrdS_slice[kphase_coord], + tdQrKT_slice[kphase_coord], + tdQtdQ_slice, + ) + dsk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + else: + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord = (None, None, kphase_idx) + cute.gemm( + dsk_tiled_mma, + tdQtdQ_slice, + tdQrdS_slice[kphase_coord], + tdQrKT_slice[kphase_coord], + tdQtdQ_slice, + ) + dsk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + kt_handle.release() + ds_handle.release() + dq_handle.commit() + work_tile = tile_sched.advance_to_next_work() + # End of persistent scheduler loop + mma_s_producer.tail() + mma_dp_producer.tail() + mma_dq_producer.tail() + + # Softmax and dSoftmax warp + if warp_idx >= self.compute_warp_ids[0] and warp_idx <= self.compute_warp_ids[-1]: + # increase register after decreasing + cute.arch.warpgroup_reg_alloc(self.num_regs_compute) + while work_tile.is_valid_tile: + curr_block_coord = work_tile.tile_idx + mma_block_coord = ( + curr_block_coord[0] // cute.size(qk_tiled_mma.thr_id.shape), + curr_block_coord[1], + curr_block_coord[2], + ) + batch_coord = curr_block_coord[2][1] + continue_cond = False + seqlen_q = mQ_qdl.shape[0] + seqlen_k = mK_kdl.shape[0] + cuseqlen_q = Int32(0) + if cutlass.const_expr(cum_seqlen_q is not None): + cuseqlen_q = cum_seqlen_q[batch_coord] + seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q + continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.qk_mma_tiler[0], + mma_block_coord[0], + seqlen_q, + ) + if not continue_cond: + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + + start_count, trip_count = FusedMask.get_trip_start_count_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + end_count = start_count + trip_count + if cutlass.const_expr(self.use_semantic_trip_range): + n_block_min_causal_local_mask, n_block_min_before_local_mask = ( + FusedMask.get_trip_mask_bounds_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + ) + + cS_base = cute.make_identity_tensor( + (self.qk_mma_tiler[0], self.qk_mma_tiler[1]) + ) + cS = cute.domain_offset((mma_block_coord[0] * self.qk_mma_tiler[0], 0), cS_base) + tScS = qk_thr_mma.partition_C(cS) + + cdP_base = cute.make_identity_tensor( + (self.dov_mma_tiler[0], self.dov_mma_tiler[1]) + ) + cdP = cute.domain_offset( + (mma_block_coord[0] * self.dov_mma_tiler[0], 0), cdP_base + ) + tdPcdP = dov_thr_mma.partition_C(cdP) + + lse_handle = load_lse_consumer.wait_and_advance() + sum_odo_handle = load_sum_odo_consumer.wait_and_advance() + for step in cutlass.range(start_count, end_count, 1, unroll=1): + cS_iter = cute.domain_offset((0, step * self.qk_mma_tiler[1]), cS) + tScS_iter = qk_thr_mma.partition_C(cS_iter) + + cdP_iter = cute.domain_offset((0, step * self.dov_mma_tiler[1]), cdP) + + tdPcdP_iter = dov_thr_mma.partition_C(cdP_iter) + + # Si, dPi -> dSi + if cutlass.const_expr(self.use_semantic_trip_range): + need_apply_mask = ( + step >= n_block_min_causal_local_mask + or step < n_block_min_before_local_mask + ) + else: + need_apply_mask = step == end_count - 1 + mma_s_consumer, mma_dp_consumer, ds_mma_producer = self.compute_step( + (need_apply_mask, window_size_left, window_size_right), + ( + seqlen_q, + seqlen_k, + scale_softmax, + batch_coord, + curr_block_coord[0], + varlen, + ), + (tStS, tScS_iter, tdPtdP, tdPcdP_iter, sLSE, sSum_OdO), + ( + mma_s_consumer, + mma_dp_consumer, + ds_mma_producer, + lse_handle, + sum_odo_handle, + ), + step, + ) + lse_handle.release() + sum_odo_handle.release() + work_tile = tile_sched.advance_to_next_work() + ds_mma_producer.tail() + + # /////////////////////////////////////////////////////////////////////////////// + # Epilogue + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx >= self.epilogue_warp_ids[0] and warp_idx <= self.epilogue_warp_ids[-1]: + cute.arch.warpgroup_reg_dealloc(self.num_regs_epilogue) + + while work_tile.is_valid_tile: + curr_block_coord = work_tile.tile_idx + mma_block_coord = ( + curr_block_coord[0] // cute.size(qk_tiled_mma.thr_id.shape), + curr_block_coord[1], + curr_block_coord[2], + ) + batch_coord = curr_block_coord[2][1] + # cute.printf("batch_coord={}", batch_coord) + seqlen_q = mQ_qdl.shape[0] + seqlen_k = mK_kdl.shape[0] + continue_cond = False + cuseqlen_q = Int32(0) + if cutlass.const_expr(cum_seqlen_q is not None): + cuseqlen_q = cum_seqlen_q[batch_coord] + seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q + continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.qk_mma_tiler[0], + mma_block_coord[0], + seqlen_q, + ) + + if not continue_cond: + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + + mdQ_qdl_eff = mdQ_qdl + if cutlass.const_expr(cum_seqlen_q is not None): + block_offset_dQ = ( + cuseqlen_q, + Int32(0), + Int32(0), + ((Int32(0), Int32(0)), Int32(0)), + ) + mdQ_qdl_eff = cute.domain_offset( + cute.select(block_offset_dQ, mode=[0, 2, 3]), mdQ_qdl + ) + + # (bM, bN, loopM, loopN, loopL) + gdQ_qdl = cute.flat_divide( + mdQ_qdl_eff, cute.select(self.dsk_block_tiler, mode=[0, 1]) + ) + cdQ_qdl = cute.flat_divide( + cute.make_identity_tensor(mdQ_qdl_eff.shape), + cute.select(self.dsk_block_tiler, mode=[0, 1]), + ) + + gdQ_staged = gdQ_qdl[None, None, curr_block_coord[0], None, curr_block_coord[2]] + cdQ_staged = cdQ_qdl[None, None, curr_block_coord[0], None, curr_block_coord[2]] + gdQ_tma_staged = gdQ_staged + if cutlass.const_expr(not varlen): + gdQ_tma_qdl = cute.flat_divide( + mdQ_tma, cute.select(self.dsk_block_tiler, mode=[0, 1]) + ) + gdQ_tma_staged = gdQ_tma_qdl[ + None, None, curr_block_coord[0], None, curr_block_coord[2] + ] + + # dQ TMEM to GMEM + mma_dq_consumer = self.dQ_epilogue( + (seqlen_q, cuseqlen_q, mQ_qdl.shape[0], batch_coord), + (mma_dq_consumer, gdQ_staged, cdQ_staged, tdQtdQ_staged), + self.epi_tile, + (tma_atom_dQ, gdQ_tma_staged, s_epi_dQ, varlen), + ) + work_tile = tile_sched.advance_to_next_work() + # NOTE: tmem.free() moved to kernel end to enable cluster-wide sync + + # /////////////////////////////////////////////////////////////////////////////// + # Scheduler Warp (only for CLC dynamic scheduler) + # /////////////////////////////////////////////////////////////////////////////// + if cutlass.const_expr(self.use_clc_scheduler): + is_first_cta_in_cluster = cta_rank_in_cluster == 0 + + if warp_idx == self.sched_warp_id and is_first_cta_in_cluster: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + while work_tile.is_valid_tile: + tile_sched.prefetch_next_work() + work_tile = tile_sched.advance_to_next_work() + tile_sched.producer_tail() + + # /////////////////////////////////////////////////////////////////////////////// + # Empty warps reg dealloc + # /////////////////////////////////////////////////////////////////////////////// + if cutlass.const_expr(self.use_clc_scheduler): + if warp_idx > self.load_warp_id: + if not (warp_idx == self.sched_warp_id and is_first_cta_in_cluster): + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + else: + if warp_idx > self.load_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + + # /////////////////////////////////////////////////////////////////////////////// + # Cooperative TMEM Deallocation (2CTA) + # /////////////////////////////////////////////////////////////////////////////// + # All warps (including scheduler) have finished by this point. + # Cluster-wide sync ensures both CTAs reach here before dealloc. + cute.arch.cluster_arrive() + cute.arch.cluster_wait() + tmem.relinquish_alloc_permit() + tmem.free(tmem_ptr) + + return + + @cute.jit + def compute_step( + self, + mask_args: Tuple, + value_args: Tuple, + tensor_args: Tuple, + pipeline_args: Tuple, + step: Int32, + ) -> Tuple[Float32, Float32, pipeline.PipelineConsumer, pipeline.PipelineProducer]: + need_apply_mask, window_size_left, window_size_right = mask_args + seqlen_q, seqlen_k, scale_softmax, batch_coord, block_m_idx, varlen = value_args + tStS, tScS, tdPtdP, tdPcdP, sLSE, sSum_OdO = tensor_args + mma_s_consumer, mma_dp_consumer, ds_mma_producer, lse_handle, sum_odo_handle = pipeline_args + + bidx = block_m_idx + tidx, _, _ = cute.arch.thread_idx() + thread_idx = tidx % (self.threads_per_warp * len(self.compute_warp_ids)) + s_handle = mma_s_consumer.wait_and_advance() + tStS_slice = tStS[(None, None), 0, 0, s_handle.index] + tScS_slice = tScS[(None, None), 0, 0] + tmem_load_atom = cute.make_copy_atom( + tcgen05.Ld32x32bOp(tcgen05.Repetition(16)), self.acc_dtype + ) + tmem_tiled_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS_slice) + thr_load = tmem_tiled_load.get_slice(thread_idx) + tTMEM_LOADtS = thr_load.partition_S(tStS_slice) + tTMEM_LOADcS = thr_load.partition_D(tScS_slice) + tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.acc_dtype) + cute.copy(tmem_tiled_load, tTMEM_LOADtS, tTMEM_LOADrS) + cute.arch.fence_view_async_tmem_load() + s_handle.release() + if need_apply_mask: + FusedMask.apply_mask_via_causal_local( + tTMEM_LOADrS, + tTMEM_LOADcS, + seqlen_q, + seqlen_k, + self.use_semantic_trip_range, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + + log2_e = cutlass.Float32(math.log2(math.e)) + softmax_scale_log2_e = scale_softmax * log2_e + tTMEM_STORErP = cute.make_rmem_tensor(tTMEM_LOADrS.shape, self.q_dtype) + for k in cutlass.range(0, cute.size(tTMEM_LOADrS), 2, unroll_full=True): + lse = ( + -sLSE[ + cute.get(tTMEM_LOADcS[k], mode=[0]) - bidx * self.cta_tiler[0], + lse_handle.index, + ], + -sLSE[ + cute.get(tTMEM_LOADcS[k + 1], mode=[0]) - bidx * self.cta_tiler[0], + lse_handle.index, + ], + ) + tTMEM_LOADrS[k], tTMEM_LOADrS[k + 1] = cute.arch.fma_packed_f32x2( + (tTMEM_LOADrS[k], tTMEM_LOADrS[k + 1]), + (softmax_scale_log2_e, softmax_scale_log2_e), + lse, + ) + tTMEM_LOADrS[k] = cute.math.exp2(tTMEM_LOADrS[k], fastmath=True) + tTMEM_LOADrS[k + 1] = cute.math.exp2(tTMEM_LOADrS[k + 1], fastmath=True) + + dp_handle = mma_dp_consumer.wait_and_advance() + tdPtdP_slice = tdPtdP[(None, None), 0, 0, dp_handle.index] + tdPcdP_slice = tdPcdP[(None, None), 0, 0] + thr_load = tmem_tiled_load.get_slice(thread_idx) + tTMEM_LOADtdP = thr_load.partition_S(tdPtdP_slice) + tTMEM_LOADcdP = thr_load.partition_D(tdPcdP_slice) + tTMEM_LOADrdP = cute.make_rmem_tensor(tTMEM_LOADcdP.shape, self.acc_dtype) + cute.copy(tmem_tiled_load, tTMEM_LOADtdP, tTMEM_LOADrdP) + cute.arch.fence_view_async_tmem_load() + dp_handle.release() + tTMEM_STORErdP = cute.make_rmem_tensor(tTMEM_LOADrdP.shape, self.q_dtype) + + for k in cutlass.range(0, cute.size(tTMEM_LOADrdP), 2, unroll_full=True): + dpsum_0 = -sSum_OdO[ + cute.get(tTMEM_LOADcdP[k], mode=[0]) - bidx * self.cta_tiler[0], + sum_odo_handle.index, + ] + dpsum_1 = -sSum_OdO[ + cute.get(tTMEM_LOADcdP[k + 1], mode=[0]) - bidx * self.cta_tiler[0], + sum_odo_handle.index, + ] + if cutlass.const_expr(varlen): + if not cute.elem_less(cute.get(tTMEM_LOADcdP[k], mode=[0]), seqlen_q): + dpsum_0 = 0.0 + if not cute.elem_less(cute.get(tTMEM_LOADcdP[k + 1], mode=[0]), seqlen_q): + dpsum_1 = 0.0 + tTMEM_LOADrdP[k], tTMEM_LOADrdP[k + 1] = cute.arch.add_packed_f32x2( + (tTMEM_LOADrdP[k], tTMEM_LOADrdP[k + 1]), + (dpsum_0, dpsum_1), + ) + tTMEM_LOADrdP[k], tTMEM_LOADrdP[k + 1] = cute.arch.mul_packed_f32x2( + (tTMEM_LOADrdP[k], tTMEM_LOADrdP[k + 1]), (tTMEM_LOADrS[k], tTMEM_LOADrS[k + 1]) + ) + tTMEM_LOADrdP[k], tTMEM_LOADrdP[k + 1] = cute.arch.mul_packed_f32x2( + (tTMEM_LOADrdP[k], tTMEM_LOADrdP[k + 1]), (scale_softmax, scale_softmax) + ) + dp_vec = tTMEM_LOADrdP.load() + tTMEM_STORErdP.store(dp_vec.to(self.q_dtype)) + + ds_handle = ds_mma_producer.acquire_and_advance() + tmem_store_atom = cute.make_copy_atom( + tcgen05.St32x32bOp(tcgen05.Repetition(32)), self.acc_dtype + ) + tilePlikeFP32 = tdPtdP_slice.shape[1] // Float32.width * self.q_dtype.width + tdPtdP_dS_layout = cute.composition( + tdPtdP_slice.layout, cute.make_layout((tdPtdP_slice.shape[0], tilePlikeFP32)) + ) + tdPtdP_dS = cute.make_tensor(tdPtdP_slice.iterator, tdPtdP_dS_layout) + tdPcdP_dS_layout = cute.composition( + tdPcdP_slice.layout, cute.make_layout((tdPcdP_slice.shape[0], tilePlikeFP32)) + ) + tdPcdP_dS = cute.make_tensor(tdPcdP_slice.iterator, tdPcdP_dS_layout) + tmem_tiled_store = tcgen05.make_tmem_copy(tmem_store_atom, tdPtdP_dS) + + thr_store = tmem_tiled_store.get_slice(thread_idx) + tTMEM_STOREtdS = thr_store.partition_D(tdPtdP_dS) + tTMEM_STOREcdP = thr_store.partition_S(tdPcdP_dS) + tTMEM_STORErdS_ = cute.make_tensor( + cute.recast_ptr(tTMEM_STORErdP.iterator, dtype=self.acc_dtype), + tTMEM_STOREcdP.shape, + ) + cute.copy(tmem_tiled_store, tTMEM_STORErdS_, tTMEM_STOREtdS) + cute.arch.fence_view_async_tmem_store() + ds_handle.commit() + return mma_s_consumer, mma_dp_consumer, ds_mma_producer + + @cute.jit + def dQ_epilogue( + self, + value_args: Tuple, + dq_args: Tuple, + epi_tile: cute.Tile, + tma_args: Tuple, + ) -> Tuple[pipeline.PipelineConsumer, pipeline.PipelineProducer]: + seqlen_q, cuseqlen_q, total_q, batch_coord = value_args + (mma_dq_consumer, gdQ_staged, cdQ_staged, tdQtdQ_staged) = dq_args + tma_atom_dQ, gdQ_tma_staged, s_epi_dQ, varlen = tma_args + dq_handle = mma_dq_consumer.wait_and_advance() + tidx, _, _ = cute.arch.thread_idx() + bidx, bidy, bidz = cute.arch.block_idx() + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + cute.arch.fence_view_async_shared() + + epi_cols_dQ = math.gcd(128 // (self.dq_dtype.width // 8), epi_tile[1]) + num_epi_stages_dQ = epi_tile[1] // epi_cols_dQ + epi_tile_dQ = (epi_tile[0], epi_cols_dQ) + leader_warp = (cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4) == 0 + + for iter in cutlass.range(self.iterations_dsk): + gdQ = gdQ_staged[None, None, iter] + cdQ = cdQ_staged[None, None, iter] + tdQtdQ = tdQtdQ_staged[(None, None), 0, 0, iter] + tdQtdQ_epi = cute.zipped_divide(tdQtdQ, epi_tile) + cdQ_epi = cute.zipped_divide(cdQ, epi_tile) + gdQ_epi = cute.zipped_divide(gdQ, epi_tile) + cdQ_local = cute.make_identity_tensor(epi_tile) + cdQ_local_epi = cute.zipped_divide(cdQ_local, epi_tile) + tidx, _, _ = cute.arch.thread_idx() + thread_idx = tidx % (self.threads_per_warp * len(self.epilogue_warp_ids)) + tmem_copy_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype + ) + tiled_tmem_load = tcgen05.make_tmem_copy(tmem_copy_atom, tdQtdQ_epi) + thr_tmem_load = tiled_tmem_load.get_slice(thread_idx) + tTMEM_LOADtdQ = thr_tmem_load.partition_S(tdQtdQ_epi) + tTMEM_LOADgdQ = thr_tmem_load.partition_D(gdQ_epi) + tTMEM_LOADcdQ = thr_tmem_load.partition_D(cdQ_epi) + tTMEM_LOADcdQ_local = thr_tmem_load.partition_D(cdQ_local_epi) + + if cutlass.const_expr(not varlen): + gdQ_tma = gdQ_tma_staged[None, None, iter] + gdQ_tma_epi = cute.local_tile(gdQ_tma, epi_tile_dQ, (0, None)) + sdQ_stage = s_epi_dQ[None, None, 0] + + for stage_k in cutlass.range_constexpr(num_epi_stages_dQ): + for i in cutlass.range(cute.size(tTMEM_LOADtdQ, mode=[1]), unroll_full=True): + tTMEM_LOADtdQ_i = tTMEM_LOADtdQ[None, i, 0] + tTMEM_LOADcdQ_i_local = tTMEM_LOADcdQ_local[None, i, 0] + tTMrdQ = cute.make_rmem_tensor(tTMEM_LOADcdQ_i_local.shape, self.acc_dtype) + cute.copy(tiled_tmem_load, tTMEM_LOADtdQ_i, tTMrdQ) + tSMrdQ = cute.make_rmem_tensor(tTMrdQ.shape, self.q_dtype) + dq_vec = tTMrdQ.load() + tSMrdQ.store(dq_vec.to(self.q_dtype)) + for j in cutlass.range_constexpr(cute.size(tTMEM_LOADcdQ_i_local)): + c = tTMEM_LOADcdQ_i_local[j] + m_pos = c[0] + n_pos = c[1] + if n_pos // epi_cols_dQ == stage_k: + s_epi_dQ[m_pos, n_pos % epi_cols_dQ, 0] = tSMrdQ[j] + + cute.arch.fence_view_async_shared() + cute.arch.barrier(barrier_id=3, number_of_threads=128) + + if leader_warp: + gdQ_stage = gdQ_tma_epi[None, None, stage_k] + td_sdQ, td_gdQ = cpasync.tma_partition( + tma_atom_dQ, + 0, + cute.make_layout(1), + cute.group_modes(sdQ_stage, 0, 2), + cute.group_modes(gdQ_stage, 0, 2), + ) + cute.copy(tma_atom_dQ, td_sdQ, td_gdQ) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=True) + # Non-issuing threads must not rotate the SMEM buffer + # until the leader warp's TMA read has completed. + cute.arch.barrier(barrier_id=3, number_of_threads=128) + else: + for i in cutlass.range(cute.size(tTMEM_LOADtdQ, mode=[1]), unroll_full=True): + tTMEM_LOADtdQ_i = tTMEM_LOADtdQ[None, i, 0] + tTMEM_LOADgdQ_i = tTMEM_LOADgdQ[None, i, 0] + tTMEM_LOADcdQ_i = tTMEM_LOADcdQ[None, i, 0] + tTMrdQ = cute.make_rmem_tensor(tTMEM_LOADcdQ[None, 0, i].shape, self.acc_dtype) + cute.copy(tiled_tmem_load, tTMEM_LOADtdQ_i, tTMrdQ) + tSMrdQ = cute.make_rmem_tensor(tTMrdQ.shape, self.q_dtype) + dq_vec = tTMrdQ.load() + tSMrdQ.store(dq_vec.to(self.q_dtype)) + if cute.elem_less(tTMEM_LOADcdQ_i[0][0], seqlen_q): + cute.autovec_copy(tSMrdQ, tTMEM_LOADgdQ_i) + dq_handle.release() + return mma_dq_consumer diff --git a/flash_attn/cute/sm100_hd256_2cta_fmha_forward.py b/flash_attn/cute/sm100_hd256_2cta_fmha_forward.py new file mode 100644 index 00000000000..28087125f47 --- /dev/null +++ b/flash_attn/cute/sm100_hd256_2cta_fmha_forward.py @@ -0,0 +1,1872 @@ +# Copyright (c) 2025, Siyu Wang, Shengbin Di, Yuxi Chi, Johnsonms, Linfeng Zheng, Haoyan Huang, Lanbo Li, Yun Zhong, Man Yuan, Minmin Sun, Yong Li, Wei Lin. + +import math +from typing import Tuple, Optional + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +import cutlass.cute.nvgpu.tcgen05 as tcgen05 +import cutlass.utils as utils +import cutlass.pipeline as pipeline +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cute.typing import Int32, Int64, Float32 + +from cutlass.utils import ClcDynamicPersistentTileScheduler +from flash_attn.cute.tile_scheduler import ( + ClcState, + compute_sm100_fmha_grid as compute_grid, + compute_sm100_fmha_grid_clc as compute_grid_clc, + make_sm100_thread_cooperative_group as make_thread_cooperative_group, + Sm100FmhaStaticTileScheduler as FmhaStaticTileScheduler, + Sm100FmhaStaticTileSchedulerParams as FmhaStaticTileSchedulerParams, + Sm100FmhaClcDynamicTileScheduler as FmhaClcDynamicTileScheduler, + Sm100FmhaClcDynamicTileSchedulerParams as FmhaClcDynamicTileSchedulerParams, +) +from flash_attn.cute.mask import ( + Sm100FusedMask as FusedMask, +) +from flash_attn.cute.tile_scheduler import SM100_TMEM_CAPACITY_COLUMNS +from flash_attn.cute.flash_fwd_sm100 import DescaleTensors, _TUNING_CONFIG +from flash_attn.cute.utils import ex2_emulation_2 + + +class BlackwellFusedMultiHeadAttentionForward: + def __init__( + self, + head_dim: int, + head_dim_v: Optional[int] = None, + qhead_per_kvhead: int = 1, + is_causal: bool = False, + is_local: bool = False, + is_split_kv: bool = False, + pack_gqa: bool = False, + q_subtile_factor: int | None = None, + m_block_size: int = 128, + n_block_size: int = 128, + q_stage: int = 2, + is_persistent: bool = True, + score_mod=None, + mask_mod=None, + has_aux_tensors: bool = False, + paged_kv_non_tma: bool = False, + is_varlen_q: bool = False, + use_2cta_instrs: bool = False, + use_clc_scheduler: bool = False, + ): + head_dim_v = head_dim if head_dim_v is None else head_dim_v + assert head_dim == 256 and head_dim_v == 256, ( + "SM100 dedicated kernel only supports (head_dim, head_dim_v) = (256, 256)" + ) + assert score_mod is None, "SM100 forward with head_dim=256 does not support score_mod" + assert mask_mod is None, "SM100 forward with head_dim=256 does not support mask_mod" + assert not has_aux_tensors, "SM100 forward with head_dim=256 does not support aux tensors" + assert not paged_kv_non_tma, ( + "SM100 hd256 2CTA supports TMA paged KV only (page_size must equal tile_n=128)" + ) + assert not pack_gqa, "SM100 forward with head_dim=256 does not support pack_gqa" + assert not is_split_kv, "SM100 forward with head_dim=256 does not support SplitKV" + assert q_subtile_factor is None, ( + "SM100 forward with head_dim=256 does not support q_subtile_factor" + ) + assert m_block_size == 128 and n_block_size == 128, ( + "SM100 dedicated kernel only supports tile_m=128 and tile_n=128" + ) + # q_stage / persistence / scheduler knobs are accepted for interface parity, + # but this dedicated kernel uses fixed internal settings. + + qk_acc_dtype = cutlass.Float32 + pv_acc_dtype = cutlass.Float32 + mma_tiler = (128, 128, head_dim) + self.qk_acc_dtype = qk_acc_dtype + self.pv_acc_dtype = pv_acc_dtype + self.qhead_per_kvhead = qhead_per_kvhead + self.mma_tiler = mma_tiler + assert mma_tiler[0] == 128 and mma_tiler[1] == 128, "Only 128x128 tile impl is supported" + assert mma_tiler[2] == 256, "Only 256 is supported for 128x128 tile impl" + self.cta_tiler = ( + mma_tiler[0], + mma_tiler[1], + mma_tiler[2], + ) + self.qk_mma_tiler = ( + 2 * mma_tiler[0], + mma_tiler[1], + min(self.cta_tiler[2], 128), + ) + self.pv_mma_tiler = self.qk_mma_tiler + self.pv_block_tiler = ( + self.pv_mma_tiler[0] // 2, + self.pv_mma_tiler[1], + self.pv_mma_tiler[2], + ) + self.iterations_qk = self.cta_tiler[2] // self.qk_mma_tiler[2] + self.iterations_pv = self.cta_tiler[2] // self.pv_mma_tiler[1] + self.cluster_shape_mn = (2, 1) + self.tmem_warp_shape_mn = (4, 1) + # Dedicated hd256 kernel uses fixed scheduling policy. + self.is_persistent = False + self.is_causal = is_causal + self.is_local = is_local + self.use_semantic_trip_range = is_causal or is_local + self.use_clc_scheduler = False + + self.softmax_warp_ids = (0, 1, 2, 3) + self.correction_warp_ids = (4, 5, 6, 7) + self.mma_warp_id = 8 + self.load_warp_id = 9 + self.empty_warp_id = (10, 11) + self.sched_warp_id = self.empty_warp_id[0] if use_clc_scheduler else None + self.tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS + + self.threads_per_warp = 32 + self.threads_per_cta = self.threads_per_warp * len( + ( + *self.softmax_warp_ids, # this is to get a round num threads + *self.correction_warp_ids, + self.mma_warp_id, + self.load_warp_id, + *self.empty_warp_id, + ) + ) + + self.tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=1, + num_threads=self.threads_per_cta, + ) + + self.tmem_s_offset = 0 + self.tmem_o_offset = 256 + self.tmem_p_offset = self.tmem_s_offset + + _tune_key = (True, is_causal, 256, False) # hd256: always 2cta, no sm103 variant + _tune = _TUNING_CONFIG.get(_tune_key, {}) + self.num_regs_softmax = _tune.get("num_regs_softmax", 256) + self.num_regs_correction = _tune.get("num_regs_correction", 160) + self.num_regs_other = 32 # fixed for hd256; not derived from 512 budget like other kernels + self.ex2_emu_freq = _tune.get("ex2_emu_freq", 4) + self.ex2_emu_res = _tune.get("ex2_emu_res", 3) + self.ex2_emu_start_frg = _tune.get("ex2_emu_start_frg", 0) + + self.buffer_align_bytes = 1024 + + def _setup_attributes(self): + self.q_stage = self.iterations_qk + self.kv_stage = 4 + self.qk_acc_stage = 2 + self.mma_corr_stage = 1 + if cutlass.const_expr(self.use_clc_scheduler): + self.num_clc_stage = 1 + self.num_clc_response_bytes = 16 + + @cute.jit + def __call__( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + softmax_scale: Float32, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mCuSeqlensK: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, + mPageTable: Optional[cute.Tensor] = None, + window_size_left: Int32 | int | None = None, + window_size_right: Int32 | int | None = None, + learnable_sink: Optional[cute.Tensor] = None, + descale_tensors: Optional[DescaleTensors] = None, + blocksparse_tensors: Optional[cute.Tensor] = None, + aux_tensors: Optional[list] = None, + stream: cuda.CUstream = None, + ): + # Keep parity with FlashAttentionForwardSm100.__call__ interface. + # (TODO@wangsiyu) Implement these features. + assert mSeqUsedQ is None and mSeqUsedK is None, ( + "SM100 forward with head_dim=256 does not support seqused_q/seqused_k" + ) + assert learnable_sink is None, ( + "SM100 forward with head_dim=256 does not support learnable_sink" + ) + assert blocksparse_tensors is None, ( + "SM100 forward with head_dim=256 does not support block sparsity" + ) + assert aux_tensors is None, "SM100 forward with head_dim=256 does not support aux_tensors" + assert not self.is_local, ( + "SM100 forward with head_dim=256 does not support local attention yet" + ) + assert window_size_left is None and window_size_right is None, ( + "SM100 forward with head_dim=256 does not support runtime window_size overrides" + ) + assert descale_tensors is None, ( + "SM100 forward with head_dim=256 does not support descale_tensors" + ) + + q_tensor, k_tensor, v_tensor, o_tensor = mQ, mK, mV, mO + lse_tensor = mLSE + cum_seqlen_q = mCuSeqlensQ + cum_seqlen_k = mCuSeqlensK + + q_rank = len(mQ.shape) + k_rank = len(mK.shape) + if cutlass.const_expr(cum_seqlen_q is not None): + # Varlen path accepts either legacy 5D tensors or standard 3D tensors. + if cutlass.const_expr(q_rank == 5): + s_q = mQ.shape[1] + h_q = mQ.shape[2] * mQ.shape[3] + d = mQ.shape[4] + elif cutlass.const_expr(q_rank == 3): + s_q = mQ.shape[0] + h_q = mQ.shape[1] + d = mQ.shape[2] + else: + raise RuntimeError(f"hd256 forward varlen expects q rank 3 or 5, got rank {q_rank}") + else: + # Non-varlen path accepts either legacy 5D tensors or standard 4D tensors. + if cutlass.const_expr(q_rank == 5): + s_q = mQ.shape[1] + h_q = mQ.shape[2] * mQ.shape[3] + d = mQ.shape[4] + elif cutlass.const_expr(q_rank == 4): + s_q = mQ.shape[1] + h_q = mQ.shape[2] + d = mQ.shape[3] + else: + raise RuntimeError( + f"hd256 forward non-varlen expects q rank 4 or 5, got rank {q_rank}" + ) + + if cutlass.const_expr(cum_seqlen_k is not None): + if cutlass.const_expr(k_rank == 5): + s_k = mK.shape[1] + h_k = mK.shape[2] + elif cutlass.const_expr(k_rank == 3): + s_k = mK.shape[0] + h_k = mK.shape[1] + else: + raise RuntimeError(f"hd256 forward varlen expects k rank 3 or 5, got rank {k_rank}") + else: + if cutlass.const_expr(k_rank == 5): + s_k = mK.shape[1] + h_k = mK.shape[2] + elif cutlass.const_expr(k_rank == 4): + s_k = mK.shape[1] + h_k = mK.shape[2] + else: + raise RuntimeError( + f"hd256 forward non-varlen expects k rank 4 or 5, got rank {k_rank}" + ) + if cutlass.const_expr(cum_seqlen_q is not None): + b = mCuSeqlensQ.shape[0] - 1 + elif cutlass.const_expr(cum_seqlen_k is not None): + b = mCuSeqlensK.shape[0] - 1 + else: + b = mQ.shape[0] + + scale_softmax = softmax_scale + scale_softmax_log2 = softmax_scale * math.log2(math.exp(1.0)) + scale_output = 1.0 + s_lse = s_q + h_r = h_q // h_k + s_q64 = Int64(s_q) + s_k64 = Int64(s_k) + s_lse64 = Int64(s_lse) + d64 = cute.assume(Int64(d), divby=128) + h_r64 = Int64(h_r) + h_k64 = Int64(h_k) + b64 = Int64(b) + s_q_total = ( + q_tensor.shape[1] + if cum_seqlen_q is not None and q_rank == 5 + else (q_tensor.shape[0] if cum_seqlen_q is not None else s_q64) + ) + s_k_total = ( + k_tensor.shape[1] + if cum_seqlen_k is not None and k_rank == 5 + else (k_tensor.shape[0] if cum_seqlen_k is not None else s_k64) + ) + stride_b_qo = h_r64 * h_k64 * s_q64 * d64 if cum_seqlen_q is None else 0 + stride_b_kv = h_k64 * s_k64 * d64 if cum_seqlen_k is None else 0 + b_lse = b64 if cum_seqlen_q is None else 1 + stride_b_lse = h_r64 * h_k64 * s_lse64 if cum_seqlen_q is None else 0 + + # (s, d, ((h_r, h_k), b)) + q_layout = cute.make_layout( + (s_q_total, d, ((h_r, h_k), b)), + stride=(d64 * h_r64 * h_k64, 1, ((d64, d64 * h_r64), stride_b_qo)), + ) + q = cute.make_tensor(q_tensor.iterator, q_layout) + if cutlass.const_expr(mPageTable is not None): + # Paged: K layout (num_pages, page_size, h_k, d); page_table maps kv_coord→physical page. + num_pages = k_tensor.shape[0] + page_size = k_tensor.shape[1] + page_size64 = Int64(page_size) + max_seqlen_k_paged = Int32(mPageTable.shape[1] * page_size) + k_paged_layout = cute.make_layout( + (page_size, d, h_k, num_pages), + stride=(d64 * h_k64, 1, d64, page_size64 * d64 * h_k64), + ) + k = cute.make_tensor(k_tensor.iterator, k_paged_layout) + v_paged_layout = cute.make_layout( + (d, page_size, h_k, num_pages), + stride=(1, d64 * h_k64, d64, page_size64 * d64 * h_k64), + ) + v = cute.make_tensor(v_tensor.iterator, v_paged_layout) + page_table_layout = cute.make_layout( + (b, mPageTable.shape[1]), + stride=(Int64(mPageTable.shape[1]), 1), + ) + page_table = cute.make_tensor(mPageTable.iterator, page_table_layout) + else: + # (s, d, ((h_r, h_k), b)), 0-stride for h_r to broadcast + k_layout = cute.make_layout( + (s_k_total, d, ((h_r, h_k), b)), + stride=(d64 * h_k64, 1, ((0, d64), stride_b_kv)), + ) + k = cute.make_tensor(k_tensor.iterator, k_layout) + # (d, s, ((h_r, h_k), b)), 0-stride for h_r to broadcast + v_layout = cute.make_layout( + (d, s_k_total, ((h_r, h_k), b)), + stride=(1, d64 * h_k64, ((0, d64), stride_b_kv)), + ) + v = cute.make_tensor(v_tensor.iterator, v_layout) + page_table = None + max_seqlen_k_paged = None + # (s, d, ((h_r, h_k), b)) + o_layout = cute.make_layout( + (s_q_total, d, ((h_r, h_k), b)), + stride=(d64 * h_r64 * h_k64, 1, ((d64, d64 * h_r64), stride_b_qo)), + ) + o = cute.make_tensor(o_tensor.iterator, o_layout) + if cutlass.const_expr(lse_tensor is not None): + # (s, ((h_r, h_k), b)) + lse_layout = cute.make_layout( + (s_lse64, ((h_r, h_k), b_lse)), + stride=(1, ((s_lse64, h_r64 * s_lse64), stride_b_lse)), + ) + lse = cute.make_tensor(lse_tensor.iterator, lse_layout) + else: + lse = None + + # setup static attributes before smem/grid/tma computation + self.q_dtype = q.element_type + self.k_dtype = k.element_type + self.v_dtype = v.element_type + self.o_dtype = o.element_type + self.tilePlikeFP32 = self.qk_mma_tiler[1] // Float32.width * self.q_dtype.width + + if cutlass.const_expr(self.use_clc_scheduler): + self.tile_sched_params, grid = compute_grid_clc( + (s_q, o.shape[1], o.shape[2]) if cum_seqlen_q is not None else o.shape, + self.cta_tiler, + (*self.cluster_shape_mn, 1), + ) + else: + self.tile_sched_params, grid = compute_grid( + (s_q, o.shape[1], o.shape[2]) if cum_seqlen_q is not None else o.shape, + self.cta_tiler, + self.is_persistent, + ) + + self.q_major_mode = utils.LayoutEnum.from_tensor(q).mma_major_mode() + self.k_major_mode = utils.LayoutEnum.from_tensor(k).mma_major_mode() + self.v_major_mode = utils.LayoutEnum.from_tensor(v).mma_major_mode() + self.o_layout = utils.LayoutEnum.from_tensor(o) + + if cutlass.const_expr(self.q_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of q is not supported") + if cutlass.const_expr(self.k_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of k is not supported") + if cutlass.const_expr(self.v_major_mode != tcgen05.OperandMajorMode.MN): + raise RuntimeError("The layout of v is not supported") + + # check type consistency + if cutlass.const_expr(self.q_dtype != self.k_dtype): + raise TypeError(f"Type mismatch: {self.q_dtype} != {self.k_dtype}") + if cutlass.const_expr(self.q_dtype != self.v_dtype): + raise TypeError(f"Type mismatch: {self.q_dtype} != {self.v_dtype}") + self._setup_attributes() + + cta_group = tcgen05.CtaGroup.TWO + # the intermediate tensor p is from tmem & k-major + p_source = tcgen05.OperandSource.TMEM + p_major_mode = tcgen05.OperandMajorMode.K + qk_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.q_dtype, + self.q_major_mode, + self.k_major_mode, + self.qk_acc_dtype, + cta_group, + self.qk_mma_tiler[:2], + ) + pv_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.v_dtype, + p_major_mode, + self.v_major_mode, + self.pv_acc_dtype, + cta_group, + self.pv_mma_tiler[:2], + p_source, + ) + + self.cluster_shape_mnk = (*self.cluster_shape_mn, 1) + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), + (qk_tiled_mma.thr_id.shape,), + ) + + self.epi_tile = self.pv_block_tiler[:2] + + q_smem_layout_staged = sm100_utils.make_smem_layout_a( + qk_tiled_mma, + self.qk_mma_tiler, + self.q_dtype, + self.q_stage, + ) + k_smem_layout_staged = sm100_utils.make_smem_layout_b( + qk_tiled_mma, + self.qk_mma_tiler, + self.k_dtype, + self.kv_stage, + ) + p_tmem_layout_staged = sm100_utils.make_smem_layout_a( + pv_tiled_mma, + self.pv_mma_tiler, + self.q_dtype, + self.qk_acc_stage, + ) + p_tmem_layout = cute.select(p_tmem_layout_staged, mode=[0, 1, 2]) + v_smem_layout_staged = sm100_utils.make_smem_layout_b( + pv_tiled_mma, + self.pv_mma_tiler, + self.v_dtype, + self.kv_stage, + ) + # TMA load for Q + tma_load_op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp(cta_group) + + q_smem_layout = cute.select(q_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_q, tma_tensor_q = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + q, + q_smem_layout, + self.qk_mma_tiler, + qk_tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # TMA load for K + k_smem_layout = cute.select(k_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_k, tma_tensor_k = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + k, + k_smem_layout, + self.qk_mma_tiler, + qk_tiled_mma, + self.cluster_layout_vmnk.shape, + ) + # TMA load for V + v_smem_layout = cute.select(v_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_v, tma_tensor_v = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + v, + v_smem_layout, + self.pv_mma_tiler, + pv_tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + q_copy_size = cute.size_in_bytes(self.q_dtype, q_smem_layout) + k_copy_size = cute.size_in_bytes(self.k_dtype, k_smem_layout) + self.tma_copy_q_bytes = q_copy_size * cute.size(qk_tiled_mma.thr_id.shape) + self.tma_copy_kv_bytes = k_copy_size * cute.size(qk_tiled_mma.thr_id.shape) + + @cute.struct + class SharedStorage: + # TMA G2S load barriers: LOAD warp (producer) -> MMA warp (consumer) + load_q_mbar_ptr: cute.struct.MemRange[ + Int64, self.q_stage * 2 + ] # load_q_{producer,consumer} + load_kv_mbar_ptr: cute.struct.MemRange[ + Int64, self.kv_stage * 2 + ] # load_kv_{producer,consumer} + mma_s_mbar_ptr: cute.struct.MemRange[Int64, self.qk_acc_stage * 2] + p_mma_mbar_ptr: cute.struct.MemRange[Int64, self.qk_acc_stage * 2] + # Softmax -> Correction signaling barriers (row_max/row_sum vec ready) + s_corr_mbar_ptr: cute.struct.MemRange[ + Int64, self.qk_acc_stage * 2 + ] # s_corr_{producer,consumer} + sum_mbar_ptr: cute.struct.MemRange[Int64, 2] + # MMA -> Correction ownership barriers for O_partial tokens (online rescale/finalize) + mma_corr_mbar_ptr: cute.struct.MemRange[ + Int64, self.mma_corr_stage * 2 + ] # mma_corr_{producer,consumer} + # A CTA-wide "TMEM lifetime" barrier used to safely deallocate TMEM after all users finish. + tmem_dealloc_mbar_ptr: Int64 + # Tmem holding buffer + tmem_holding_buf: Int32 + # CLC pipeline barriers and response buffer + clc_mbar_ptr: cute.struct.MemRange[Int64, 2] + clc_response: cute.struct.MemRange[Int32, 4] + + self.shared_storage = SharedStorage + + grid = cute.round_up(grid, self.cluster_shape_mnk) + # Launch the kernel synchronously + self.kernel( + qk_tiled_mma, + pv_tiled_mma, + tma_atom_q, + tma_tensor_q, + tma_atom_k, + tma_tensor_k, + tma_atom_v, + tma_tensor_v, + o, + cum_seqlen_q, + cum_seqlen_k, + lse, + scale_softmax_log2, + scale_softmax, + scale_output, + page_table, + max_seqlen_k_paged, + window_size_left, + window_size_right, + self.cluster_layout_vmnk, + q_smem_layout_staged, + k_smem_layout_staged, + p_tmem_layout, + v_smem_layout_staged, + self.tile_sched_params, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=self.cluster_shape_mnk, + stream=stream, + min_blocks_per_mp=1, + ) + + # GPU device kernel + @cute.kernel + def kernel( + self, + qk_tiled_mma: cute.TiledMma, + pv_tiled_mma: cute.TiledMma, + tma_atom_q: cute.CopyAtom, + mQ_qdl: cute.Tensor, + tma_atom_k: cute.CopyAtom, + mK_kdl: cute.Tensor, + tma_atom_v: cute.CopyAtom, + mV_dkl: cute.Tensor, + mO_qdl: cute.Tensor, + cum_seqlen_q: Optional[cute.Tensor], + cum_seqlen_k: Optional[cute.Tensor], + mLSE: Optional[cute.Tensor], + scale_softmax_log2: Float32, + scale_softmax: Float32, + scale_output: Float32, + mPageTable: Optional[cute.Tensor], + max_seqlen_k: Optional[Int32], + window_size_left: Optional[Int32], + window_size_right: Optional[Int32], + cluster_layout_vmnk: cute.Layout, + q_smem_layout_staged: cute.ComposedLayout, + k_smem_layout_staged: cute.ComposedLayout, + p_tmem_layout_staged: cute.ComposedLayout, + v_smem_layout_staged: cute.ComposedLayout, + tile_sched_params: FmhaStaticTileSchedulerParams | FmhaClcDynamicTileSchedulerParams, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + + # + # Prefetch tma desc + # + if warp_idx == self.load_warp_id: + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_q) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_k) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_v) + + bidx, _, _ = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(qk_tiled_mma.thr_id.shape) + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster) + + # Alloc + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + load_q_producer, load_q_consumer = pipeline.PipelineTmaUmma.create( + num_stages=self.q_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + tx_count=self.tma_copy_q_bytes, + barrier_storage=storage.load_q_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + load_kv_producer, load_kv_consumer = pipeline.PipelineTmaUmma.create( + num_stages=self.kv_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + tx_count=self.tma_copy_kv_bytes, + barrier_storage=storage.load_kv_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + mma_s_producer, mma_s_consumer = pipeline.PipelineUmmaAsync.create( + num_stages=self.qk_acc_stage, + producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + consumer_group=make_thread_cooperative_group( + len(self.softmax_warp_ids) * self.threads_per_warp * self.cluster_shape_mnk[0], + ), + barrier_storage=storage.mma_s_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + p_mma_producer, p_mma_consumer = pipeline.PipelineAsyncUmma.create( + num_stages=self.qk_acc_stage, + producer_group=make_thread_cooperative_group( + len(self.softmax_warp_ids) * self.threads_per_warp * self.cluster_shape_mnk[0], + ), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + barrier_storage=storage.p_mma_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + s_corr_producer, s_corr_consumer = pipeline.PipelineAsync.create( + num_stages=self.qk_acc_stage, + producer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.softmax_warp_ids) + ), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.correction_warp_ids) + ), + barrier_storage=storage.s_corr_mbar_ptr.data_ptr(), + defer_sync=True, + ).make_participants() + sum_producer, sum_consumer = pipeline.PipelineAsync.create( + num_stages=1, + producer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.softmax_warp_ids) + ), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.correction_warp_ids) + ), + barrier_storage=storage.sum_mbar_ptr.data_ptr(), + defer_sync=True, + ).make_participants() + mma_corr_producer, mma_corr_consumer = pipeline.PipelineUmmaAsync.create( + num_stages=self.mma_corr_stage, + producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + consumer_group=make_thread_cooperative_group( + len(self.correction_warp_ids) * self.threads_per_warp * self.cluster_shape_mnk[0], + ), + barrier_storage=storage.mma_corr_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + # Tensor memory dealloc barrier init + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=self.tmem_alloc_barrier, + allocator_warp_id=self.correction_warp_ids[0], + is_two_cta=True, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + tmem.allocate(self.tmem_alloc_cols) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) + # Initialize CLC state if using dynamic scheduler + if cutlass.const_expr(self.use_clc_scheduler): + clc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + cluster_size = cute.size(self.cluster_shape_mnk) + num_clc_consumer_threads = self.threads_per_warp * ( + 1 # sched_warp (CTA 0 only) + + cluster_size + * ( + len(self.softmax_warp_ids) + + len(self.correction_warp_ids) + + 1 # mma_warp + + 1 # load_warp + ) + ) + clc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_clc_consumer_threads + ) + clc_response_ptr = storage.clc_response.data_ptr() + clc = ClcState.create( + hw_scheduler=ClcDynamicPersistentTileScheduler.create( + self.tile_sched_params.clc_hw_params(), + cute.arch.block_idx(), + cute.arch.grid_dim(), + clc_response_ptr, + ), + pipeline=pipeline.PipelineClcFetchAsync.create( + barrier_storage=storage.clc_mbar_ptr.data_ptr(), + num_stages=self.num_clc_stage, + producer_group=clc_pipeline_producer_group, + consumer_group=clc_pipeline_consumer_group, + tx_count=self.num_clc_response_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ), + consumer_state=pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_clc_stage + ), + producer_state=pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_clc_stage + ), + ) + else: + clc = None + clc_response_ptr = None + + # Cluster arrive after barrier init + pipeline.pipeline_init_arrive(cluster_shape_mn=cluster_layout_vmnk, is_relaxed=True) + + sQ = smem.allocate_tensor( + element_type=self.q_dtype, + layout=q_smem_layout_staged.outer, + swizzle=q_smem_layout_staged.inner, + byte_alignment=128, + ) + sK = smem.allocate_tensor( + element_type=self.k_dtype, + layout=k_smem_layout_staged.outer, + swizzle=k_smem_layout_staged.inner, + byte_alignment=128, + ) + # K and V now use separate memory since we removed the transform stage + sV = smem.allocate_tensor( + element_type=self.v_dtype, + layout=v_smem_layout_staged.outer, + swizzle=v_smem_layout_staged.inner, + byte_alignment=128, + ) + + sSum = smem.allocate_tensor( + element_type=self.qk_acc_dtype, + layout=cute.make_layout(len(self.softmax_warp_ids) * self.threads_per_warp), + byte_alignment=128, + ) + qk_thr_mma = qk_tiled_mma.get_slice(mma_tile_coord_v) # default 1sm + pv_thr_mma = pv_tiled_mma.get_slice(mma_tile_coord_v) # default 1sm + tSrQ = qk_thr_mma.make_fragment_A(sQ) + tSrK = qk_thr_mma.make_fragment_B(sK) + tOrV = pv_thr_mma.make_fragment_B(sV) + qk_acc_shape = qk_thr_mma.partition_shape_C((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tStS = qk_thr_mma.make_fragment_C(cute.append(qk_acc_shape, self.qk_acc_stage)) + pv_acc_shape = pv_thr_mma.partition_shape_C((self.pv_mma_tiler[0], self.pv_mma_tiler[1])) + tOtO = pv_thr_mma.make_fragment_C(pv_acc_shape) + tOtO_layout = cute.append( + tOtO.layout, + cute.make_layout( + self.iterations_pv, + stride=self.pv_mma_tiler[1] // self.tmem_warp_shape_mn[1], + ), + ) + tStS = cute.make_tensor(tStS.iterator + self.tmem_s_offset, tStS.layout) + tOtO_staged = cute.make_tensor(tOtO.iterator + self.tmem_o_offset, tOtO_layout) + + # /////////////////////////////////////////////////////////////////////////////// + # EMPTY + # /////////////////////////////////////////////////////////////////////////////// + for _i in cutlass.range_constexpr(len(self.empty_warp_id)): + if warp_idx == self.empty_warp_id[_i]: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + + if cutlass.const_expr(self.use_clc_scheduler): + tile_sched = FmhaClcDynamicTileScheduler.create( + tile_sched_params, + cute.arch.block_idx(), + cute.arch.grid_dim(), + clc_response_ptr, + clc, + ) + else: + blk_idx = cute.arch.block_idx() + tile_sched = FmhaStaticTileScheduler( + tile_sched_params, blk_idx[0], blk_idx, cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + # Cluster wait + pipeline.pipeline_init_wait(cluster_shape_mn=cluster_layout_vmnk) + + # /////////////////////////////////////////////////////////////////////////////// + # LOAD + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx == self.load_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + while work_tile.is_valid_tile: + curr_block_coord = work_tile.tile_idx # (q_tile_idx, 0, (head_idx, batch_idx)) + mma_block_coord = ( + curr_block_coord[0] // cute.size(qk_tiled_mma.thr_id.shape), + curr_block_coord[1], + curr_block_coord[2], + ) + continue_cond = False + batch_coord = curr_block_coord[2][1] + seqlen_q = mQ_qdl.shape[0] + seqlen_k = ( + mK_kdl.shape[0] if cutlass.const_expr(mPageTable is None) else max_seqlen_k + ) + cuseqlen_q = Int32(0) + cuseqlen_k = Int32(0) + block_offset = ( + Int32(0), + Int32(0), + Int32(0), + ((Int32(0), Int32(0)), Int32(0)), + ) + if cutlass.const_expr(cum_seqlen_q is not None): + cuseqlen_q = cum_seqlen_q[batch_coord] + seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + block_offset = ( + cuseqlen_q, + cuseqlen_k, + Int32(0), + ((Int32(0), Int32(0)), Int32(0)), + ) + continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.qk_mma_tiler[0], + mma_block_coord[0], + seqlen_q, + ) + if not continue_cond: + mQ_qdl_ = cute.domain_offset(cute.select(block_offset, mode=[0, 2, 3]), mQ_qdl) + # Local tile partition global tensors + q_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + # (bM, bK, loopM, loopK, loopL) + gQ_qdl = cute.flat_divide(mQ_qdl_, cute.select(self.qk_mma_tiler, mode=[0, 2])) + tSgQ_qdl = qk_thr_mma.partition_A(gQ_qdl) + tQsQ, tQgQ_qdl = cute.nvgpu.cpasync.tma_partition( + tma_atom_q, + block_in_cluster_coord_vmnk[2], + q_cta_layout, + cute.group_modes(sQ, 0, 3), + cute.group_modes(tSgQ_qdl, 0, 3), + ) + kv_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + if cutlass.const_expr(mPageTable is None): + # Dense path: domain_offset K/V by batch block, select batch via mma_block_coord[2]. + mK_kdl_ = cute.domain_offset( + cute.select(block_offset, mode=[1, 2, 3]), mK_kdl + ) + mV_dkl_ = cute.domain_offset( + cute.select(block_offset, mode=[2, 1, 3]), mV_dkl + ) + gK_kdl = cute.flat_divide( + mK_kdl_, cute.select(self.qk_mma_tiler, mode=[1, 2]) + ) + tSgK_kdl = qk_thr_mma.partition_B(gK_kdl) + tKsK, tKgK_kdl = cute.nvgpu.cpasync.tma_partition( + tma_atom_k, + block_in_cluster_coord_vmnk[1], + kv_cta_layout, + cute.group_modes(sK, 0, 3), + cute.group_modes(tSgK_kdl, 0, 3), + ) + gV_dkl = cute.flat_divide( + mV_dkl_, cute.select(self.pv_mma_tiler, mode=[1, 2]) + ) + tSgV_dkl = pv_thr_mma.partition_B(gV_dkl) + tVsV, tVgV_dkl = cute.nvgpu.cpasync.tma_partition( + tma_atom_v, + block_in_cluster_coord_vmnk[1], + kv_cta_layout, + cute.group_modes(sV, 0, 3), + cute.group_modes(tSgV_dkl, 0, 3), + ) + # ((atom_v, rest_v), RestN, RestK) + tKgK = tKgK_kdl[None, None, None, mma_block_coord[2]] + tVgV = tVgV_dkl[None, None, None, mma_block_coord[2]] + else: + # Paged path: slice K/V by KV head, keep num_pages dim for page_idx-based TMA. + head_kv_coord = curr_block_coord[2][0] // self.qhead_per_kvhead + mK_kdl_ = mK_kdl[None, None, head_kv_coord, None] + mV_dkl_ = mV_dkl[None, None, head_kv_coord, None] + gK_kdl = cute.flat_divide( + mK_kdl_, cute.select(self.qk_mma_tiler, mode=[1, 2]) + ) + tSgK_kdl = qk_thr_mma.partition_B(gK_kdl) + tKsK, tKgK_kdl = cute.nvgpu.cpasync.tma_partition( + tma_atom_k, + block_in_cluster_coord_vmnk[1], + kv_cta_layout, + cute.group_modes(sK, 0, 3), + cute.group_modes(tSgK_kdl, 0, 3), + ) + gV_dkl = cute.flat_divide( + mV_dkl_, cute.select(self.pv_mma_tiler, mode=[1, 2]) + ) + tSgV_dkl = pv_thr_mma.partition_B(gV_dkl) + tVsV, tVgV_dkl = cute.nvgpu.cpasync.tma_partition( + tma_atom_v, + block_in_cluster_coord_vmnk[1], + kv_cta_layout, + cute.group_modes(sV, 0, 3), + cute.group_modes(tSgV_dkl, 0, 3), + ) + tKgK = tKgK_kdl + tVgV = tVgV_dkl + # ((atom_v, rest_v), RestK) + tQgQ = tQgQ_qdl[None, mma_block_coord[0], None, mma_block_coord[2]] + + seqlen_kv_loop_start, seqlen_kv_loop_steps = ( + FusedMask.get_trip_start_count_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + ) + seqlen_kv_loop_end = seqlen_kv_loop_start + seqlen_kv_loop_steps + # Q + for iter in cutlass.range(self.iterations_qk, unroll=1): + q_handle = load_q_producer.acquire_and_advance() + cute.copy( + tma_atom_q, + tQgQ[None, iter], + tQsQ[None, q_handle.index], + tma_bar_ptr=q_handle.barrier, + ) + + # K0 + kv_coord = seqlen_kv_loop_start + k_page_idx = ( + mPageTable[batch_coord, kv_coord] + if cutlass.const_expr(mPageTable is not None) + else None + ) + for iter in cutlass.range(self.iterations_qk, unroll=1): + k_handle = load_kv_producer.acquire_and_advance() + cute.copy( + tma_atom_k, + tKgK[None, kv_coord, iter] + if cutlass.const_expr(mPageTable is None) + else tKgK[None, 0, iter, k_page_idx], + tKsK[None, k_handle.index], + tma_bar_ptr=k_handle.barrier, + ) + kv_coord += 1 + # v_page_idx_prev carries K[i-1]'s page index for use as V[i-1]'s page + # (K and V for the same KV block share the same physical page). + # Also serves as the Vend page index when seqlen_kv_loop_steps == 1. + v_page_idx_prev = ( + k_page_idx if cutlass.const_expr(mPageTable is not None) else None + ) + # Prefetch K1 page after K0 TMA dispatch to hide L2 latency. + if cutlass.const_expr(mPageTable is not None): + if seqlen_kv_loop_steps > 1: + k_page_idx = mPageTable[batch_coord, kv_coord] + + for i in cutlass.range(1, seqlen_kv_loop_steps, 1, unroll=1): + # Ki: k_page_idx was prefetched at end of previous iteration + # (or in the prologue for i==1); L2 latency already hidden. + for iter in cutlass.range(self.iterations_qk, unroll=1): + k_handle = load_kv_producer.acquire_and_advance() + cute.copy( + tma_atom_k, + tKgK[None, kv_coord, iter] + if cutlass.const_expr(mPageTable is None) + else tKgK[None, 0, iter, k_page_idx], + tKsK[None, k_handle.index], + tma_bar_ptr=k_handle.barrier, + ) + # Vi-1: reuse v_page_idx_prev (= K[i-1]'s page), no extra GMEM read. + for iter in cutlass.range(self.iterations_pv, unroll=1): + v_handle = load_kv_producer.acquire_and_advance() + cute.copy( + tma_atom_v, + tVgV[None, iter, kv_coord - 1] + if cutlass.const_expr(mPageTable is None) + else tVgV[None, iter, 0, v_page_idx_prev], + tVsV[None, v_handle.index], + tma_bar_ptr=v_handle.barrier, + ) + v_page_idx_prev = ( + k_page_idx if cutlass.const_expr(mPageTable is not None) else None + ) + kv_coord += 1 + # Prefetch next K page while V TMA is in flight. + if cutlass.const_expr(mPageTable is not None): + if kv_coord < seqlen_kv_loop_end: + k_page_idx = mPageTable[batch_coord, kv_coord] + # Vend: reuse v_page_idx_prev (= K[end-1]'s page), no extra GMEM read. + for iter in cutlass.range(self.iterations_pv, unroll=1): + v_handle = load_kv_producer.acquire_and_advance() + cute.copy( + tma_atom_v, + tVgV[None, iter, seqlen_kv_loop_end - 1] + if cutlass.const_expr(mPageTable is None) + else tVgV[None, iter, 0, v_page_idx_prev], + tVsV[None, v_handle.index], + tma_bar_ptr=v_handle.barrier, + ) + + work_tile = tile_sched.advance_to_next_work() + # End of persistent scheduler loop + load_kv_producer.tail() + load_q_producer.tail() + + # /////////////////////////////////////////////////////////////////////////////// + # MMA + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx == self.mma_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + + while work_tile.is_valid_tile: + curr_block_coord = work_tile.tile_idx + mma_block_coord = ( + curr_block_coord[0] // cute.size(qk_tiled_mma.thr_id.shape), + curr_block_coord[1], + curr_block_coord[2], + ) + continue_cond = False + seqlen_q = mQ_qdl.shape[0] + seqlen_k = ( + mK_kdl.shape[0] if cutlass.const_expr(mPageTable is None) else max_seqlen_k + ) + batch_coord = curr_block_coord[2][1] + if cutlass.const_expr(cum_seqlen_q is not None): + cuseqlen_q = cum_seqlen_q[batch_coord] + seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q + continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.qk_mma_tiler[0], + mma_block_coord[0], + seqlen_q, + ) + + if not continue_cond: + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + + seqlen_kv_loop_start, seqlen_kv_loop_steps = ( + FusedMask.get_trip_start_count_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + ) + seqlen_kv_loop_end = seqlen_kv_loop_start + seqlen_kv_loop_steps + + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + is_leader_cta = cta_rank_in_cluster % 2 == 0 + load_q_releaser = load_q_consumer.clone() + pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + if seqlen_kv_loop_steps > 1: + # QK0 + if is_leader_cta: + s_handle = mma_s_producer.acquire_and_advance() + tStS_slice = tStS[None, None, None, s_handle.index] + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + for iter in cutlass.range(self.iterations_qk, unroll=1): + load_q_consumer.wait_and_advance() + tSrQ_slice = tSrQ[None, None, None, iter] + k_handle = load_kv_consumer.wait_and_advance() + tSrK_trans_slice = tSrK[None, None, None, k_handle.index] + num_kphases = cute.size(tSrQ_slice, mode=[2]) + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord = (None, None, kphase_idx) + cute.gemm( + qk_tiled_mma, + tStS_slice, + tSrQ_slice[kphase_coord], + tSrK_trans_slice[kphase_coord], + tStS_slice, + ) + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + k_handle.release() + s_handle.commit() + for i in cutlass.range(1, seqlen_kv_loop_steps - 1, 1, unroll=1): + # QKi + if is_leader_cta: + s_handle = mma_s_producer.acquire_and_advance() + tStS_slice = tStS[None, None, None, s_handle.index] + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + for iter in cutlass.range(self.iterations_qk, unroll=1): + tSrQ_slice = tSrQ[None, None, None, iter] + k_handle = load_kv_consumer.wait_and_advance() + tSrK_trans_slice = tSrK[None, None, None, k_handle.index] + num_kphases = cute.size(tSrQ_slice, mode=[2]) + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord = (None, None, kphase_idx) + cute.gemm( + qk_tiled_mma, + tStS_slice, + tSrQ_slice[kphase_coord], + tSrK_trans_slice[kphase_coord], + tStS_slice, + ) + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + k_handle.release() + s_handle.commit() + + # PVi-1 + p_handle = p_mma_consumer.wait_and_advance() + o_handle = mma_corr_producer.acquire_and_advance() + pv_whether_acc = pv_tiled_mma.get(tcgen05.Field.ACCUMULATE) + for iter in cutlass.range(self.iterations_pv, unroll=1): + v_handle = load_kv_consumer.wait_and_advance() + pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, pv_whether_acc) + tOtO_slice = tOtO_staged[None, None, None, iter] + tStS_slice = tStS[None, None, None, p_handle.index] + tP = cute.make_tensor( + tStS_slice.iterator, p_tmem_layout_staged.outer + ) + tOrP = pv_thr_mma.make_fragment_A(tP) + tOrP_slice = cute.make_tensor( + cute.recast_ptr(tStS_slice.iterator, dtype=self.q_dtype), + tOrP.layout, + ) + tOrV_slice = tOrV[None, None, None, v_handle.index] + num_kphases = cute.size(tOrV_slice, mode=[2]) + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord = (None, None, kphase_idx) + cute.gemm( + pv_tiled_mma, + tOtO_slice, + tOrP_slice[kphase_coord], + tOrV_slice[kphase_coord], + tOtO_slice, + ) + pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + v_handle.release() + o_handle.commit() + p_handle.release() + if is_leader_cta: + # QKend + s_handle = mma_s_producer.acquire_and_advance() + tStS_slice = tStS[None, None, None, s_handle.index] + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + for iter in cutlass.range(self.iterations_qk, unroll=1): + tSrQ_slice = tSrQ[None, None, None, iter] + k_handle = load_kv_consumer.wait_and_advance() + tSrK_trans_slice = tSrK[None, None, None, k_handle.index] + num_kphases = cute.size(tSrQ_slice, mode=[2]) + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord = (None, None, kphase_idx) + cute.gemm( + qk_tiled_mma, + tStS_slice, + tSrQ_slice[kphase_coord], + tSrK_trans_slice[kphase_coord], + tStS_slice, + ) + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + k_handle.release() + load_q_releaser.release() + load_q_releaser.advance() + s_handle.commit() + + # PVend-1 + p_handle = p_mma_consumer.wait_and_advance() + o_handle = mma_corr_producer.acquire_and_advance() + pv_whether_acc = pv_tiled_mma.get(tcgen05.Field.ACCUMULATE) + for iter in cutlass.range(self.iterations_pv, unroll=1): + v_handle = load_kv_consumer.wait_and_advance() + pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, pv_whether_acc) + tOtO_slice = tOtO_staged[None, None, None, iter] + tStS_slice = tStS[None, None, None, p_handle.index] + tP = cute.make_tensor( + tStS_slice.iterator, p_tmem_layout_staged.outer + ) + tOrP = pv_thr_mma.make_fragment_A(tP) + tOrP_slice = cute.make_tensor( + cute.recast_ptr(tStS_slice.iterator, dtype=self.q_dtype), + tOrP.layout, + ) + tOrV_slice = tOrV[None, None, None, v_handle.index] + num_kphases = cute.size(tOrV_slice, mode=[2]) + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord = (None, None, kphase_idx) + cute.gemm( + pv_tiled_mma, + tOtO_slice, + tOrP_slice[kphase_coord], + tOrV_slice[kphase_coord], + tOtO_slice, + ) + pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + v_handle.release() + o_handle.commit() + p_handle.release() + else: + if is_leader_cta: + # QK0 + s_handle = mma_s_producer.acquire_and_advance() + tStS_slice = tStS[None, None, None, s_handle.index] + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + for iter in cutlass.range(self.iterations_qk, unroll=1): + load_q_consumer.wait_and_advance() + tSrQ_slice = tSrQ[None, None, None, iter] + k_handle = load_kv_consumer.wait_and_advance() + tSrK_trans_slice = tSrK[None, None, None, k_handle.index] + num_kphases = cute.size(tSrQ_slice, mode=[2]) + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord = (None, None, kphase_idx) + cute.gemm( + qk_tiled_mma, + tStS_slice, + tSrQ_slice[kphase_coord], + tSrK_trans_slice[kphase_coord], + tStS_slice, + ) + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + k_handle.release() + load_q_releaser.release() + load_q_releaser.advance() + s_handle.commit() + + if is_leader_cta: + # PVend + p_handle = p_mma_consumer.wait_and_advance() + o_handle = mma_corr_producer.acquire_and_advance() + pv_whether_acc = pv_tiled_mma.get(tcgen05.Field.ACCUMULATE) + for iter in cutlass.range(self.iterations_pv, unroll=1): + v_handle = load_kv_consumer.wait_and_advance() + pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, pv_whether_acc) + tOtO_slice = tOtO_staged[None, None, None, iter] + tStS_slice = tStS[None, None, None, p_handle.index] + tP = cute.make_tensor(tStS_slice.iterator, p_tmem_layout_staged.outer) + tOrP = pv_thr_mma.make_fragment_A(tP) + tOrP_slice = cute.make_tensor( + cute.recast_ptr(tStS_slice.iterator, dtype=self.q_dtype), + tOrP.layout, + ) + tOrV_slice = tOrV[None, None, None, v_handle.index] + num_kphases = cute.size(tOrV_slice, mode=[2]) + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord = (None, None, kphase_idx) + cute.gemm( + pv_tiled_mma, + tOtO_slice, + tOrP_slice[kphase_coord], + tOrV_slice[kphase_coord], + tOtO_slice, + ) + pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + v_handle.release() + o_handle.commit() + p_handle.release() + work_tile = tile_sched.advance_to_next_work() + # End of persistent scheduler loop + mma_s_producer.tail() + mma_corr_producer.tail() + + if warp_idx < self.correction_warp_ids[0] and warp_idx >= self.softmax_warp_ids[0]: + # increase register after decreasing + cute.arch.warpgroup_reg_alloc(self.num_regs_softmax) + + while work_tile.is_valid_tile: + curr_block_coord = work_tile.tile_idx + mma_block_coord = ( + curr_block_coord[0] // cute.size(qk_tiled_mma.thr_id.shape), + curr_block_coord[1], + curr_block_coord[2], + ) + batch_coord = curr_block_coord[2][1] + continue_cond = False + seqlen_q = mQ_qdl.shape[0] + seqlen_k = ( + mK_kdl.shape[0] if cutlass.const_expr(mPageTable is None) else max_seqlen_k + ) + cuseqlen_q = Int32(0) + if cutlass.const_expr(cum_seqlen_q is not None): + cuseqlen_q = cum_seqlen_q[batch_coord] + seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q + continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.qk_mma_tiler[0], + mma_block_coord[0], + seqlen_q, + ) + if not continue_cond: + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + + row_max = -Float32.inf + row_max_prev = -Float32.inf + row_sum = 0.0 + + start_count, trip_count = FusedMask.get_trip_start_count_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + end_count = start_count + trip_count + if cutlass.const_expr(self.use_semantic_trip_range): + n_block_min_causal_local_mask, n_block_min_before_local_mask = ( + FusedMask.get_trip_mask_bounds_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + ) + cS_base = cute.make_identity_tensor( + (self.qk_mma_tiler[0], self.qk_mma_tiler[1]) + ) + cS = cute.domain_offset((mma_block_coord[0] * self.qk_mma_tiler[0], 0), cS_base) + tScS = qk_thr_mma.partition_C(cS) + + for step in cutlass.range(start_count, end_count, 1, unroll=1): + cS_iter = cute.domain_offset((0, step * self.qk_mma_tiler[1]), cS) + tScS_iter = qk_thr_mma.partition_C(cS_iter) + if cutlass.const_expr(self.use_semantic_trip_range): + need_apply_mask = ( + step >= n_block_min_causal_local_mask + or step < n_block_min_before_local_mask + ) + else: + # Residual path only needs seqlen masking on the last K tile. + need_apply_mask = step == end_count - 1 + # Si -> Pi + ( + row_max, + row_sum, + mma_s_consumer, + p_mma_producer, + s_corr_producer, + ) = self.softmax_step( + (need_apply_mask, window_size_left, window_size_right), + ( + row_max_prev, + row_sum, + seqlen_q, + seqlen_k, + scale_softmax_log2, + ), + (tStS, tScS_iter), + (mma_s_consumer, p_mma_producer, s_corr_producer), + ) + row_max_prev = row_max + sum_producer = self.store_sum_max( + row_max, + mLSE, + row_sum, + sSum, + sum_producer, + curr_block_coord, + seqlen_q, + cum_seqlen_q, + cuseqlen_q, + scale_softmax, + ) + work_tile = tile_sched.advance_to_next_work() + p_mma_producer.tail() + s_corr_producer.tail() + + # /////////////////////////////////////////////////////////////////////////////// + # Correction + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx >= self.correction_warp_ids[0] and warp_idx < self.mma_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_correction) + + while work_tile.is_valid_tile: + curr_block_coord = work_tile.tile_idx + mma_block_coord = ( + curr_block_coord[0] // cute.size(qk_tiled_mma.thr_id.shape), + curr_block_coord[1], + curr_block_coord[2], + ) + batch_coord = curr_block_coord[2][1] + seqlen_q = mQ_qdl.shape[0] + seqlen_k = ( + mK_kdl.shape[0] if cutlass.const_expr(mPageTable is None) else max_seqlen_k + ) + continue_cond = False + cuseqlen_q = Int32(0) + if cutlass.const_expr(cum_seqlen_q is not None): + cuseqlen_q = cum_seqlen_q[batch_coord] + seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q + continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.qk_mma_tiler[0], + mma_block_coord[0], + seqlen_q, + ) + + if not continue_cond: + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + + mO_qdl_eff = mO_qdl + if cutlass.const_expr(cum_seqlen_q is not None): + block_offset_o = ( + cuseqlen_q, + Int32(0), + Int32(0), + ((Int32(0), Int32(0)), Int32(0)), + ) + mO_qdl_eff = cute.domain_offset( + cute.select(block_offset_o, mode=[0, 2, 3]), mO_qdl + ) + + # (bM, bN, loopM, loopN, loopL) + gO_qdl = cute.flat_divide( + mO_qdl_eff, cute.select(self.pv_block_tiler, mode=[0, 1]) + ) + cO_qdl = cute.flat_divide( + cute.make_identity_tensor(mO_qdl_eff.shape), + cute.select(self.pv_block_tiler, mode=[0, 1]), + ) + + _, seqlen_kv_loop_steps = FusedMask.get_trip_start_count_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + gO_staged = gO_qdl[None, None, curr_block_coord[0], None, curr_block_coord[2]] + cO_staged = cO_qdl[None, None, curr_block_coord[0], None, curr_block_coord[2]] + cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tScS = qk_thr_mma.partition_C(cS) + + # Empty step as the first step is no need for correction + stats_handle = s_corr_consumer.wait_and_advance() + stats_handle.release() + for step in cutlass.range(1, seqlen_kv_loop_steps, 1, unroll=1): + # Oi-1 -> Oi + mma_corr_consumer, s_corr_consumer = self.correction_rescale( + scale_softmax_log2, + (s_corr_consumer, tStS, tScS), + (mma_corr_consumer, tOtO_staged, cO_staged), + self.epi_tile, + ) + # O_partial -> O_final + mma_corr_consumer, sum_consumer = self.correction_epilog( + (seqlen_q, scale_output), + (sum_consumer, sSum), + (mma_corr_consumer, gO_staged, cO_staged, tOtO_staged), + self.epi_tile, + ) + work_tile = tile_sched.advance_to_next_work() + # NOTE: tmem.free() moved to kernel end to enable cluster-wide sync + + # /////////////////////////////////////////////////////////////////////////////// + # Scheduler Warp (only for CLC dynamic scheduler) + # /////////////////////////////////////////////////////////////////////////////// + if cutlass.const_expr(self.use_clc_scheduler): + is_first_cta_in_cluster = cta_rank_in_cluster == 0 + + if warp_idx == self.sched_warp_id and is_first_cta_in_cluster: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + while work_tile.is_valid_tile: + tile_sched.prefetch_next_work() + work_tile = tile_sched.advance_to_next_work() + tile_sched.producer_tail() + + # /////////////////////////////////////////////////////////////////////////////// + # Empty warps reg dealloc + # /////////////////////////////////////////////////////////////////////////////// + if cutlass.const_expr(self.use_clc_scheduler): + if warp_idx > self.load_warp_id: + if not (warp_idx == self.sched_warp_id and is_first_cta_in_cluster): + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + else: + if warp_idx > self.load_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + + # /////////////////////////////////////////////////////////////////////////////// + # Cooperative TMEM Deallocation (2CTA) + # /////////////////////////////////////////////////////////////////////////////// + # All warps (including scheduler) have finished by this point. + # Cluster-wide sync ensures both CTAs reach here before dealloc. + cute.arch.cluster_arrive() + cute.arch.cluster_wait() + tmem.relinquish_alloc_permit() + tmem.free(tmem_ptr) + + return + + @cute.jit + def softmax_step( + self, + mask_args: Tuple, + value_args: Tuple, + tensor_args: Tuple, + pipeline_args: Tuple, + ) -> Tuple[Float32, Float32, pipeline.PipelineConsumer, pipeline.PipelineProducer]: + need_apply_mask, window_size_left, window_size_right = mask_args + row_max, row_sum, seqlen_q, seqlen_k, scale_softmax_log2 = value_args + tStS, tScS = tensor_args + mma_s_consumer, p_mma_producer, s_corr_producer = pipeline_args + tidx, _, _ = cute.arch.thread_idx() + thread_idx = tidx % (self.threads_per_warp * len(self.softmax_warp_ids)) + s_handle = mma_s_consumer.wait_and_advance() + tStS_slice = tStS[(None, None), 0, 0, s_handle.index] + tScS_slice = tScS[(None, None), 0, 0] + tmem_load_atom = cute.make_copy_atom( + tcgen05.Ld32x32bOp(tcgen05.Repetition(32)), self.qk_acc_dtype + ) + tmem_tiled_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS_slice) + thr_load = tmem_tiled_load.get_slice(thread_idx) + tTMEM_LOADtS = thr_load.partition_S(tStS_slice) + tTMEM_LOADcS = thr_load.partition_D(tScS_slice) + tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype) + cute.copy(tmem_tiled_load, tTMEM_LOADtS, tTMEM_LOADrS) + + cute.arch.fence_view_async_tmem_load() + s_handle.release() + if need_apply_mask: + FusedMask.apply_mask_via_causal_local( + tTMEM_LOADrS, + tTMEM_LOADcS, + seqlen_q, + seqlen_k, + self.use_semantic_trip_range, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + old_row_max = row_max + row_max = tTMEM_LOADrS.load().reduce(cute.ReductionOp.MAX, row_max, 0) + row_max_safe = row_max + if row_max == -cutlass.Float32.inf: + row_max_safe = 0.0 + + stats_handle = s_corr_producer.acquire_and_advance() + stats_layout = cute.composition( + tStS_slice.layout, cute.make_layout((tStS_slice.shape[0], 2)) + ) + stats_c_layout = cute.composition( + tScS_slice.layout, cute.make_layout((tScS_slice.shape[0], 2)) + ) + tOtStats = cute.make_tensor(tStS_slice.iterator + self.tilePlikeFP32, stats_layout) + tOcStats = cute.make_tensor(tScS_slice.iterator, stats_c_layout) + tmem_store_stats_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(2)), + self.qk_acc_dtype, + ) + tiled_tmem_store_stats = tcgen05.make_tmem_copy(tmem_store_stats_atom, tOtStats) + thr_tmem_store_stats = tiled_tmem_store_stats.get_slice(thread_idx) + tTMEM_STOREcStats = thr_tmem_store_stats.partition_S(tOcStats) + tTMEM_STORErStats = cute.make_rmem_tensor(tTMEM_STOREcStats.shape, self.qk_acc_dtype) + tTMEM_STORErStats[0] = old_row_max + tTMEM_STORErStats[1] = row_max_safe + tTMEM_STOREtStats = thr_tmem_store_stats.partition_D(tOtStats) + cute.copy(tiled_tmem_store_stats, tTMEM_STORErStats, tTMEM_STOREtStats) + cute.arch.fence_view_async_tmem_store() + stats_handle.commit() + + scale = scale_softmax_log2 + minus_row_max_scale = (0.0 - row_max_safe) * scale + # Acquire P write slot early — overlaps any pipeline stall with exp2 compute + p_handle = p_mma_producer.acquire_and_advance() + # Fragment-based FMA + exp2 + bf16 conversion + # Trades SFU for FMA via polynomial emulation on a fraction of elements + ex2_frg_tile = 32 + ex2_frg_cnt = cute.size(tTMEM_LOADrS) // ex2_frg_tile + tTMEM_LOADrS_ex2 = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(ex2_frg_tile)) + tTMEM_STORErP = cute.make_rmem_tensor(tTMEM_LOADrS.shape, self.q_dtype) + tTMEM_STORErP_ex2 = cute.logical_divide(tTMEM_STORErP, cute.make_layout(ex2_frg_tile)) + for j in cutlass.range_constexpr(ex2_frg_cnt): + for k in cutlass.range_constexpr(0, ex2_frg_tile, 2): + tTMEM_LOADrS_ex2[k, j], tTMEM_LOADrS_ex2[k + 1, j] = cute.arch.fma_packed_f32x2( + (tTMEM_LOADrS_ex2[k, j], tTMEM_LOADrS_ex2[k + 1, j]), + (scale, scale), + (minus_row_max_scale, minus_row_max_scale), + ) + if cutlass.const_expr(self.ex2_emu_freq == 0): + tTMEM_LOADrS_ex2[k, j] = cute.math.exp2(tTMEM_LOADrS_ex2[k, j], fastmath=True) + tTMEM_LOADrS_ex2[k + 1, j] = cute.math.exp2( + tTMEM_LOADrS_ex2[k + 1, j], fastmath=True + ) + else: + if cutlass.const_expr( + k % self.ex2_emu_freq < self.ex2_emu_freq - self.ex2_emu_res + or j >= ex2_frg_cnt - 1 + or j < self.ex2_emu_start_frg + ): + tTMEM_LOADrS_ex2[k, j] = cute.math.exp2( + tTMEM_LOADrS_ex2[k, j], fastmath=True + ) + tTMEM_LOADrS_ex2[k + 1, j] = cute.math.exp2( + tTMEM_LOADrS_ex2[k + 1, j], fastmath=True + ) + else: + tTMEM_LOADrS_ex2[k, j], tTMEM_LOADrS_ex2[k + 1, j] = ex2_emulation_2( + tTMEM_LOADrS_ex2[k, j], tTMEM_LOADrS_ex2[k + 1, j] + ) + tTMEM_STORErP_ex2[None, j].store(tTMEM_LOADrS_ex2[None, j].load().to(self.q_dtype)) + tmem_store_atom = cute.make_copy_atom( + tcgen05.St32x32bOp(tcgen05.Repetition(32)), self.qk_acc_dtype + ) + tilePlikeFP32 = tStS_slice.shape[1] // Float32.width * self.q_dtype.width + tStS_P_layout = cute.composition( + tStS_slice.layout, cute.make_layout((tStS_slice.shape[0], tilePlikeFP32)) + ) + tStS_P = cute.make_tensor(tStS_slice.iterator, tStS_P_layout) + tScS_P_layout = cute.composition( + tScS_slice.layout, cute.make_layout((tScS_slice.shape[0], tilePlikeFP32)) + ) + tScS_P = cute.make_tensor(tScS_slice.iterator, tScS_P_layout) + tmem_tiled_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P) + thr_store = tmem_tiled_store.get_slice(thread_idx) + tTMEM_STOREtP = thr_store.partition_D(tStS_P) + tTMEM_STOREcS = thr_store.partition_S(tScS_P) + tTMEM_STORErP_ = cute.make_tensor( + cute.recast_ptr(tTMEM_STORErP.iterator, dtype=self.qk_acc_dtype), + tTMEM_STOREcS.shape, + ) + cute.copy(tmem_tiled_store, tTMEM_STORErP_, tTMEM_STOREtP) + cute.arch.fence_view_async_tmem_store() + + p_handle.commit() + acc_scale_ = scale * (old_row_max - row_max_safe) + acc_scale = cute.math.exp2(acc_scale_, fastmath=True) * 0.5 + # TODO: calc row sum with TensorSSA + row_sum *= acc_scale + local_row_sum_0 = (row_sum, row_sum) + local_row_sum_1 = (0.0, 0.0) + local_row_sum_2 = (0.0, 0.0) + local_row_sum_3 = (0.0, 0.0) + reduction_unroll = 4 + frg_tile = cute.size(tTMEM_LOADrS) // reduction_unroll + tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile)) + for j in cutlass.range_constexpr(0, cute.size(tTMEM_LOADrS_frg, mode=[0]), 2): + local_row_sum_0 = cute.arch.add_packed_f32x2( + local_row_sum_0, (tTMEM_LOADrS_frg[j, 0], tTMEM_LOADrS_frg[j + 1, 0]) + ) + local_row_sum_1 = cute.arch.add_packed_f32x2( + local_row_sum_1, (tTMEM_LOADrS_frg[j, 1], tTMEM_LOADrS_frg[j + 1, 1]) + ) + local_row_sum_2 = cute.arch.add_packed_f32x2( + local_row_sum_2, (tTMEM_LOADrS_frg[j, 2], tTMEM_LOADrS_frg[j + 1, 2]) + ) + local_row_sum_3 = cute.arch.add_packed_f32x2( + local_row_sum_3, (tTMEM_LOADrS_frg[j, 3], tTMEM_LOADrS_frg[j + 1, 3]) + ) + local_row_sum_0 = cute.arch.add_packed_f32x2(local_row_sum_0, local_row_sum_1) + local_row_sum_2 = cute.arch.add_packed_f32x2(local_row_sum_2, local_row_sum_3) + local_row_sum_0 = cute.arch.add_packed_f32x2(local_row_sum_0, local_row_sum_2) + row_sum = local_row_sum_0[0] + local_row_sum_0[1] + return row_max, row_sum, mma_s_consumer, p_mma_producer, s_corr_producer + + @cute.jit + def correction_rescale( + self, + scale_softmax_log2: Float32, + stats_args: tuple, + o_args: tuple, + epi_tile: cute.Tile, + ) -> pipeline.PipelineConsumer: + (s_corr_consumer, tStS, tScS) = stats_args + (mma_o_consumer, tOtO_staged, cO_staged) = o_args + tidx, _, _ = cute.arch.thread_idx() + thread_idx = tidx % (self.threads_per_warp * len(self.softmax_warp_ids)) + + stats_handle = s_corr_consumer.wait_and_advance() + tStS_slice = tStS[(None, None), 0, 0, stats_handle.index] + tScS_slice = tScS[(None, None), 0, 0] + stats_layout = cute.composition( + tStS_slice.layout, cute.make_layout((tStS_slice.shape[0], 2)) + ) + stats_c_layout = cute.composition( + tScS_slice.layout, cute.make_layout((tScS_slice.shape[0], 2)) + ) + tOtStats = cute.make_tensor(tStS_slice.iterator + self.tilePlikeFP32, stats_layout) + tOcStats = cute.make_tensor(tScS_slice.iterator, stats_c_layout) + tmem_load_stats_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(2)), + self.qk_acc_dtype, + ) + tiled_tmem_load_stats = tcgen05.make_tmem_copy(tmem_load_stats_atom, tOtStats) + thr_tmem_load_stats = tiled_tmem_load_stats.get_slice(thread_idx) + tTMEM_LOADtStats = thr_tmem_load_stats.partition_S(tOtStats) + tTMEM_LOADcStats = thr_tmem_load_stats.partition_D(tOcStats) + tTMEM_LOADrStats = cute.make_rmem_tensor(tTMEM_LOADcStats.shape, self.qk_acc_dtype) + cute.copy(tiled_tmem_load_stats, tTMEM_LOADtStats, tTMEM_LOADrStats) + + scale = scale_softmax_log2 * (tTMEM_LOADrStats[0] - tTMEM_LOADrStats[1]) + scale = cute.math.exp2(scale, fastmath=True) + stats_handle.release() + o_handle = mma_o_consumer.wait_and_advance() + for iter in cutlass.range(self.iterations_pv, unroll_full=True): + tOtO = tOtO_staged[(None, None), 0, 0, iter] + cO = cO_staged[None, None, iter] + tOtO_epi = cute.zipped_divide(tOtO, epi_tile) + cO_epi = cute.zipped_divide(cO, epi_tile) + tmem_load_atom = cute.make_copy_atom( + tcgen05.Ld32x32bOp(tcgen05.Repetition(16)), + self.pv_acc_dtype, + ) + tmem_tiled_load = tcgen05.make_tmem_copy(tmem_load_atom, tOtO_epi) + thr_load = tmem_tiled_load.get_slice(thread_idx) + tmem_store_atom = cute.make_copy_atom( + tcgen05.St32x32bOp(tcgen05.Repetition(16)), + self.pv_acc_dtype, + ) + tmem_store_atom = tcgen05.make_tmem_copy(tmem_store_atom, tOtO_epi) + thr_store = tmem_store_atom.get_slice(thread_idx) + tTMEM_LOADtO = thr_load.partition_S(tOtO_epi) + tTMEM_LOADcO = thr_load.partition_D(cO_epi) + tTMEM_STOREtO = thr_store.partition_D(tOtO_epi) + tTMrO = cute.make_rmem_tensor_like( + cute.append( + cute.make_layout(tTMEM_LOADcO[None, 0, 0].shape), + cute.make_layout(2, stride=cute.size(tTMEM_LOADcO[None, 0, 0].shape)), + ), + self.pv_acc_dtype, + ) + tTMEM_LOADtO_0 = tTMEM_LOADtO[None, 0, 0] + cute.copy(tmem_tiled_load, tTMEM_LOADtO_0, tTMrO[None, 0]) + iter_num = cute.size(tTMEM_LOADtO, mode=[1]) + for i in cutlass.range(1, iter_num, unroll_full=True): + tTMEM_LOADtO_i = tTMEM_LOADtO[None, i, 0] + cute.copy(tmem_tiled_load, tTMEM_LOADtO_i, tTMrO[None, i % 2]) + for j in cutlass.range(0, cute.size(tTMrO, mode=[0]), 2, unroll_full=True): + tTMrO[j, (i - 1) % 2], tTMrO[j + 1, (i - 1) % 2] = cute.arch.mul_packed_f32x2( + (tTMrO[j, (i - 1) % 2], tTMrO[j + 1, (i - 1) % 2]), + (scale, scale), + ) + tTMEM_STOREtO_prev_i = tTMEM_STOREtO[None, i - 1, 0] + cute.copy(tmem_store_atom, tTMrO[None, (i - 1) % 2], tTMEM_STOREtO_prev_i) + + for j in cutlass.range(0, cute.size(tTMrO, mode=[0]), 2, unroll_full=True): + tTMrO[j, (iter_num - 1) % 2], tTMrO[j + 1, (iter_num - 1) % 2] = ( + cute.arch.mul_packed_f32x2( + ( + tTMrO[j, (iter_num - 1) % 2], + tTMrO[j + 1, (iter_num - 1) % 2], + ), + (scale, scale), + ) + ) + cute.copy( + tmem_store_atom, + tTMrO[None, (iter_num - 1) % 2], + tTMEM_STOREtO[None, iter_num - 1, 0], + ) + cute.arch.fence_view_async_tmem_store() + o_handle.release() + return mma_o_consumer, s_corr_consumer + + @cute.jit + def correction_epilog( + self, + value_args: Tuple, + sum_args: Tuple, + o_args: Tuple, + epi_tile: cute.Tile, + ) -> Tuple[pipeline.PipelineConsumer, pipeline.PipelineProducer]: + (seqlen_q, scale_output) = value_args + (sum_consumer, sSum) = sum_args + (mma_o_consumer, gO_staged, cO_staged, tOtO_staged) = o_args + tidx, _, _ = cute.arch.thread_idx() + thread_idx = tidx % (self.threads_per_warp * len(self.softmax_warp_ids)) + sum_handle = sum_consumer.wait_and_advance() + row_sum = sSum[thread_idx] + cute.arch.fence_view_async_shared() + sum_handle.release() + scale = scale_output / row_sum + o_handle = mma_o_consumer.wait_and_advance() + for iter in cutlass.range(self.iterations_pv): + gO = gO_staged[None, None, iter] + cO = cO_staged[None, None, iter] + tOtO = tOtO_staged[(None, None), 0, 0, iter] + tOtO_epi = cute.zipped_divide(tOtO, epi_tile) + cO_epi = cute.zipped_divide(cO, epi_tile) + gO_epi = cute.zipped_divide(gO, epi_tile) + tidx, _, _ = cute.arch.thread_idx() + thread_idx = tidx % (self.threads_per_warp * len(self.softmax_warp_ids)) + tmem_copy_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.pv_acc_dtype + ) + tiled_tmem_load = tcgen05.make_tmem_copy(tmem_copy_atom, tOtO_epi) + thr_tmem_load = tiled_tmem_load.get_slice(thread_idx) + tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_epi) + tTMEM_LOADgO = thr_tmem_load.partition_D(gO_epi) + tTMEM_LOADcO = thr_tmem_load.partition_D(cO_epi) + for i in cutlass.range(cute.size(tTMEM_LOADtO, mode=[1]), unroll_full=True): + tTMEM_LOADtO_i = tTMEM_LOADtO[None, i, 0] + tTMEM_LOADgO_i = tTMEM_LOADgO[None, i, 0] + tTMEM_LOADcO_i = tTMEM_LOADcO[None, i, 0] + tTMrO = cute.make_rmem_tensor(tTMEM_LOADcO[None, 0, i].shape, self.pv_acc_dtype) + cute.copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO) + for j in cutlass.range(0, cute.size(tTMrO), 2, unroll_full=True): + tTMrO[j], tTMrO[j + 1] = cute.arch.mul_packed_f32x2( + (tTMrO[j], tTMrO[j + 1]), + (scale, scale), + ) + tSMrO = cute.make_rmem_tensor(tTMrO.shape, self.o_dtype) + o_vec = tTMrO.load() + tSMrO.store(o_vec.to(self.o_dtype)) + if cute.elem_less(tTMEM_LOADcO_i[0][0], seqlen_q): + cute.autovec_copy(tSMrO, tTMEM_LOADgO_i) + o_handle.release() + return mma_o_consumer, sum_consumer + + @cute.jit + def store_sum_max( + self, + row_max, + mLSE, + row_sum, + sSum, + sum_producer, + current_block_coord, + seqlen_q, + cum_seqlen_q, + cuseqlen_q, + scale_softmax, + ): + tidx, _, _ = cute.arch.thread_idx() + thread_idx = tidx % (self.threads_per_warp * len(self.softmax_warp_ids)) + sum_handle = sum_producer.acquire_and_advance() + sSum[thread_idx] = row_sum + cute.arch.fence_view_async_shared() + sum_handle.commit() + + if cutlass.const_expr(mLSE is not None): + q_idx = current_block_coord[0] * self.cta_tiler[0] + tidx + hb_idx = ( + (current_block_coord[2][0], Int32(0)) + if cutlass.const_expr(cum_seqlen_q is not None) + else current_block_coord[2] + ) + lse_value = scale_softmax * row_max + cute.math.log(row_sum, fastmath=True) + if cute.elem_less(q_idx, seqlen_q): + global_q_idx = ( + q_idx + cuseqlen_q if cutlass.const_expr(cum_seqlen_q is not None) else q_idx + ) + mLSE[global_q_idx, hb_idx] = lse_value + return sum_producer diff --git a/flash_attn/cute/softmax.py b/flash_attn/cute/softmax.py index bbbc038e483..4e8bb7271f4 100644 --- a/flash_attn/cute/softmax.py +++ b/flash_attn/cute/softmax.py @@ -7,7 +7,7 @@ import cutlass import cutlass.cute as cute -from cutlass import Float32 +from cutlass import Float32, Boolean from quack import layout_utils import flash_attn.cute.utils as utils @@ -175,12 +175,14 @@ def rescale_O(self, acc_O: cute.Tensor, row_scale: cute.Tensor) -> None: @dataclass class SoftmaxSm100(Softmax): rescale_threshold: cutlass.Constexpr[float] = 0.0 + max_offset: cutlass.Constexpr[int] = 0 @staticmethod def create( scale_log2: Float32, rescale_threshold: cutlass.Constexpr[float] = 0.0, softmax_scale: Float32 | None = None, + max_offset: cutlass.Constexpr[int] = 0, ): num_rows = 1 arch = 100 @@ -194,8 +196,40 @@ def create( arch, softmax_scale, rescale_threshold=rescale_threshold, + max_offset=max_offset, ) + @cute.jit + def compute_row_max_local(self, acc_S_row: cute.TensorSSA, is_first: Boolean) -> Float32: + if cutlass.const_expr(is_first): + row_max_new = self._compute_row_max(acc_S_row) + else: + row_max_old = self.row_max[0] + row_max_new = self._compute_row_max(acc_S_row, init_val=row_max_old) + return row_max_new + + @cute.jit + def update_row_max_from_local( + self, + row_max_new: Float32, + is_first: Boolean, + ) -> Tuple[Float32, Float32]: + if cutlass.const_expr(is_first): + row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 + acc_scale = 0.0 + else: + row_max_old = self.row_max[0] + row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 + acc_scale_ = (row_max_old - row_max_safe) * self.scale_log2 + acc_scale = cute.math.exp2(acc_scale_) + if cutlass.const_expr(self.rescale_threshold > 0.0): + if acc_scale_ >= -self.rescale_threshold: + row_max_new = row_max_old + row_max_safe = row_max_old + acc_scale = 1.0 + self.row_max[0] = row_max_new + return row_max_safe, acc_scale + @cute.jit def update_row_max(self, acc_S_row: cute.TensorSSA, is_first: int) -> Tuple[Float32, Float32]: if cutlass.const_expr(is_first): @@ -233,11 +267,13 @@ def scale_subtract_rowmax( ): assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" row_max_scaled = row_max * self.scale_log2 + max_offset = Float32(self.max_offset) + bias = max_offset - row_max_scaled for i in cutlass.range(0, cute.size(acc_S_row.shape), 2, unroll_full=True): acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2( (acc_S_row[i], acc_S_row[i + 1]), (self.scale_log2, self.scale_log2), - (-row_max_scaled, -row_max_scaled), + (bias, bias), ) @cute.jit @@ -404,7 +440,7 @@ def apply_score_mod_inner( q_idx_vec = cute.make_rmem_tensor(vec_size, cutlass.Int32) # For Pack-GQA with non-constant q_idx, we need per-element head indices - # since a thread my process multiple query head indices + # since a thread may process multiple query head indices if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None): head_idx_vec = cute.make_rmem_tensor(vec_size, cutlass.Int32) diff --git a/flash_attn/cute/testing.py b/flash_attn/cute/testing.py index 6e3c40eb451..6e4bfed1335 100644 --- a/flash_attn/cute/testing.py +++ b/flash_attn/cute/testing.py @@ -91,20 +91,23 @@ def pad_input(hidden_states, indices, batch, seqlen): return rearrange(output, "(b s) ... -> b s ...", b=batch) -def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", zero_lengths=False): +def generate_random_padding_mask( + max_seqlen, batch_size, device, mode="random", zero_lengths=False, min_seqlen=None +): assert mode in ["full", "random", "third"] + min_seqlen = min_seqlen if min_seqlen is not None else 0 if zero_lengths else 1 if mode == "full": lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32) elif mode == "random": lengths = torch.randint( - max(0 if zero_lengths else 1, max_seqlen - 20), + max(min_seqlen, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device, ) else: lengths = torch.randint( - max(0 if zero_lengths else 1, max_seqlen // 3), + max(min_seqlen, max_seqlen // 3), max_seqlen + 1, (batch_size, 1), device=device, @@ -343,6 +346,8 @@ def attention_ref( upcast=True, reorder_ops=False, intermediate_dtype=None, + return_lse=False, + gather_kv_indices=None, ): if causal: window_size = (window_size[0], 0) @@ -399,10 +404,21 @@ def attention_ref( local_mask = ( torch.logical_or(local_mask, chunk_mask) if local_mask is not None else chunk_mask ) + if gather_kv_indices is not None: + batch = q.shape[0] + topk_len = gather_kv_indices.shape[2] + if topk_len < seqlen_k: + topk_index_mask = torch.full( + (batch, seqlen_q, seqlen_k), False, device="cuda" + ).scatter_(-1, gather_kv_indices, True) + scores.masked_fill_(rearrange(~topk_index_mask, "b t s -> b 1 t s"), float("-inf")) if local_mask is not None: scores.masked_fill_(local_mask, float("-inf")) if attn_bias is not None: scores = scores + attn_bias + # After all masks are applied, before softmax: + # scores shape: [b, h, t, s] + lse = torch.logsumexp(scores, dim=-1) # [b, h, t] if learnable_sink is None: attention = torch.softmax(scores, dim=-1).to(v.dtype) else: @@ -414,6 +430,8 @@ def attention_ref( normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + torch.exp( learnable_sink - logits_or_sinks_max ) + # LSE with sink: log(Z) = log(normalizer) + max + lse = (torch.log(normalizer.squeeze(-1)) + logits_or_sinks_max.squeeze(-1)).to(dtype_og) attention = (unnormalized_scores / normalizer).to(v.dtype) if query_padding_mask is not None: attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) @@ -431,6 +449,8 @@ def attention_ref( output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) if query_padding_mask is not None: output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + if return_lse: + return output.to(dtype_og), attention.to(dtype_og), lse.to(dtype_og) return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py index 3ee4bc8bab1..ff820e59626 100644 --- a/flash_attn/cute/tile_scheduler.py +++ b/flash_attn/cute/tile_scheduler.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, Tri Dao. +# Copyright (c) 2025, Tri Dao, Siyu Wang, Shengbin Di, Yuxi Chi, Johnsonms, Linfeng Zheng, Haoyan Huang, Lanbo Li, Yun Zhong, Man Yuan, Minmin Sun, Yong Li, Wei Lin. from enum import IntEnum, auto from typing import Optional, Tuple, Protocol, runtime_checkable @@ -16,6 +16,13 @@ from cutlass import Int32, const_expr from cutlass.cute import FastDivmodDivisor from cutlass.utils import ClcDynamicPersistentTileScheduler, ClcDynamicPersistentTileSchedulerParams +from cutlass.cute.typing import Boolean +from cutlass.cutlass_dsl import ( + min as dsl_min, + extract_mlir_values, + new_from_mlir_values, +) +from cutlass.utils.hardware_info import HardwareInfo from quack.cute_dsl_utils import ParamsBase @@ -402,6 +409,7 @@ class Params(ParamsBase): cluster_shape_m: cutlass.Constexpr[int] = 1 scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC lpt: cutlass.Constexpr[bool] = True + use_cluster_idx: cutlass.Constexpr[bool] = True @staticmethod @cute.jit @@ -445,6 +453,7 @@ def create( cluster_shape_m=args.cluster_shape_mn[0], scheduling_mode=scheduling_mode, lpt=args.lpt, + use_cluster_idx=args.use_cluster_idx, ) def __init__( @@ -532,12 +541,19 @@ def clc_work_to_coords(self, work) -> WorkTileInfo: block_idx = block_idx // self.params.cluster_shape_m if const_expr(self.params.lpt): # Longest-processing-time-first: reverse block order - block_idx = self.params.num_block - 1 - block_idx + if const_expr(self.params.cluster_shape_m > 1 and not self.params.use_cluster_idx): + num_block = self.params.num_block // self.params.cluster_shape_m + else: + num_block = self.params.num_block + block_idx = num_block - 1 - block_idx split_idx = Int32(0) if const_expr(self.params.is_split_kv): batch_idx, split_idx = divmod(work.tile_idx[2], self.params.num_splits_divmod) else: batch_idx = work.tile_idx[2] + if const_expr(self.params.cluster_shape_m > 1 and not self.params.use_cluster_idx): + bidx_in_cluster = cute.arch.block_in_cluster_idx() + block_idx = block_idx * self.params.cluster_shape_m + bidx_in_cluster[0] return WorkTileInfo( (Int32(block_idx), Int32(work.tile_idx[1]), Int32(batch_idx), Int32(split_idx)), work.is_valid_tile, @@ -1085,3 +1101,536 @@ def __new_from_mlir_values__(self, values): obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) values = values[n_items:] return self.__class__(*obj_list, loc=self._loc) + + +# ----------------------------------------------------------------------------- +# SM100 FMHA-specific schedulers (kept separate from generic schedulers). +# ----------------------------------------------------------------------------- + + +class Sm100FmhaStaticTileSchedulerParams: + """A class to represent parameters for the FMHA (Fused Multi-Head Attention) static tile scheduler. + + This class holds the configuration parameters needed to initialize and configure + the tile scheduler for FMHA operations. + + :ivar is_persistent: Whether to use persistent kernel mode. + :type is_persistent: bool + :ivar problem_shape_mbh: Problem shape in (M, B, H) format. + :type problem_shape_mbh: cute.Shape + """ + + def __init__( + self, + is_persistent: bool, + problem_shape_mbh: cute.Shape, + *, + loc=None, + ip=None, + ): + """ + Initializes the Sm100FmhaStaticTileSchedulerParams with the given parameters. + + :param is_persistent: Whether to use persistent kernel mode. + :type is_persistent: bool + :param problem_shape_mbh: Problem shape in (M, B, H) format. + :type problem_shape_mbh: cute.Shape + """ + self.is_persistent = is_persistent + self.problem_shape_mbh = problem_shape_mbh + self._loc = loc + self._ip = ip + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.problem_shape_mbh]: + obj_values = extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip([self.problem_shape_mbh], self._values_pos): + obj_list.append(new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return Sm100FmhaStaticTileSchedulerParams( + self.is_persistent, *(tuple(obj_list)), loc=self._loc + ) + + +class Sm100FmhaStaticTileScheduler: + """A static tile scheduler for FMHA (Fused Multi-Head Attention) operations. + + This class manages the scheduling of work tiles for FMHA kernels, supporting + both persistent and non-persistent kernel modes. It tracks the current work + position and advances through the problem space efficiently. + + :ivar _params: Scheduler parameters. + :type _params: Sm100FmhaStaticTileSchedulerParams + :ivar _blk_coord: Block coordinates. + :type _blk_coord: cute.Coord + :ivar _grid_shape: Grid shape for the kernel. + :type _grid_shape: cute.Shape + :ivar _is_persistent: Whether to use persistent kernel mode. + :type _is_persistent: bool + :ivar _current_work_linear_idx: Current linear work index. + :type _current_work_linear_idx: Int32 + :ivar _problem_shape_mbh: Problem shape in (M, B, H) format. + :type _problem_shape_mbh: cute.Layout + :ivar _num_blocks: Number of blocks in the problem. + :type _num_blocks: Int32 + :ivar _is_first_block: Whether this is the first block. + :type _is_first_block: bool + :ivar num_persistent_sm: Number of persistent SMs. + :type num_persistent_sm: Int32 + """ + + def __init__( + self, + params: Sm100FmhaStaticTileSchedulerParams, + current_work_linear_idx: Int32, + blk_coord: cute.Coord, + grid_shape: cute.Shape, + *, + loc=None, + ip=None, + ): + """ + Initializes the Sm100FmhaStaticTileScheduler with the given parameters. + + :param params: Scheduler parameters. + :type params: Sm100FmhaStaticTileSchedulerParams + :param current_work_linear_idx: Current linear work index. + :type current_work_linear_idx: Int32 + :param blk_coord: Block coordinates. + :type blk_coord: cute.Coord + :param grid_shape: Grid shape for the kernel. + :type grid_shape: cute.Shape + """ + self._params = params + self._blk_coord = blk_coord + self._grid_shape = grid_shape + self._is_persistent = params.is_persistent + self._current_work_linear_idx = current_work_linear_idx + self._problem_shape_mbh = cute.make_layout(params.problem_shape_mbh, loc=loc, ip=ip) + self._num_blocks = cute.size(self._problem_shape_mbh, loc=loc, ip=ip) + self._is_first_block = True + self.num_persistent_sm = cute.size(grid_shape, loc=loc, ip=ip) + self._loc = loc + self._ip = ip + + # called by host + @staticmethod + def get_grid_shape( + params: Sm100FmhaStaticTileSchedulerParams, + *, + loc=None, + ip=None, + ) -> cute.Shape: + """ + Determine the grid shape for the FMHA kernel. + + For persistent kernels, the grid shape is limited by the number of SMs + (Streaming Multiprocessors) available on the device. For non-persistent + kernels, the grid shape matches the problem shape. + + :param params: Scheduler parameters. + :type params: Sm100FmhaStaticTileSchedulerParams + + :return: Grid shape as (M, B, H) tuple. + :rtype: cute.Shape + """ + if params.is_persistent: + hardware_info = HardwareInfo() + sm_count = hardware_info.get_device_multiprocessor_count() + return ( + dsl_min(sm_count, cute.size(params.problem_shape_mbh, loc=loc, ip=ip)), + 1, + 1, + ) + else: + return params.problem_shape_mbh + + @staticmethod + def check_valid_work_for_seqlen_q( + q_tiler: int, + current_idx: Int32, + seqlen_q: Int32, + ) -> Boolean: + """ + Check if the current work index is valid for the given query sequence length. + + This method verifies that the current work tile index multiplied by the + query tiler size is within the bounds of the query sequence length. + + :param q_tiler: Query tiler size. + :type q_tiler: int + :param current_idx: Current work index. + :type current_idx: Int32 + :param seqlen_q: Query sequence length. + :type seqlen_q: Int32 + + :return: True if the work is valid, False otherwise. + :rtype: Boolean + """ + return current_idx * q_tiler < seqlen_q + + def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: + """ + Get information about the current work tile. + + Determines if the current work is valid and computes the tile coordinates + based on whether the kernel is persistent or non-persistent. + + :return: WorkTileInfo containing tile coordinates and validity flag. + :rtype: WorkTileInfo + """ + is_valid = ( + self._current_work_linear_idx < self._num_blocks + if self._is_persistent + else self._is_first_block + ) + + blk_coord = (0, 0, 0) + if self._is_persistent: + blk_coord = self._problem_shape_mbh.get_hier_coord( + self._current_work_linear_idx, loc=loc, ip=ip + ) + else: + blk_coord = self._blk_coord + + # cur_tile_coord is (mid, 0, (bid, hid)) + cur_tile_coord = ( + blk_coord[0], + 0, + (blk_coord[1], blk_coord[2]), + ) + + return cutlass.utils.WorkTileInfo(cur_tile_coord, is_valid) + + def initial_work_tile_info(self, *, loc=None, ip=None): + """ + Get the initial work tile information. + + :return: Initial WorkTileInfo. + :rtype: WorkTileInfo + """ + return self.get_current_work(loc=loc, ip=ip) + + def advance_to_next_work(self, *, advance_count=1, loc=None, ip=None): + """ + Advance to the next work tile and return it. + + For persistent kernels, advances by the number of persistent SMs. + For non-persistent kernels, marks that the first block has been processed. + """ + if self._is_persistent: + self._current_work_linear_idx += advance_count * self.num_persistent_sm + self._is_first_block = False + return self.get_current_work() + + def prefetch_next_work(self, *, loc=None, ip=None): + """No-op for static scheduler.""" + pass + + def producer_tail(self, *, loc=None, ip=None): + """No-op for static scheduler.""" + pass + + def __extract_mlir_values__(self): + values = extract_mlir_values(self._params) + values.extend(extract_mlir_values(self._current_work_linear_idx)) + values.extend(extract_mlir_values(self._blk_coord)) + values.extend(extract_mlir_values(self._grid_shape)) + return values + + def __new_from_mlir_values__(self, values): + assert len(values) == 10 + new_params = new_from_mlir_values(self._params, values[0:3]) + new_current_work_linear_idx = new_from_mlir_values( + self._current_work_linear_idx, [values[3]] + ) + new_blk_coord = new_from_mlir_values(self._blk_coord, values[4:7]) + new_grid_shape = new_from_mlir_values(self._grid_shape, values[7:]) + return Sm100FmhaStaticTileScheduler( + new_params, new_current_work_linear_idx, new_blk_coord, new_grid_shape + ) + + +def compute_sm100_fmha_grid( + o_shape: cute.Shape, + cta_tiler: Tuple[int, int, int], + is_persistent: bool, +) -> Tuple[Sm100FmhaStaticTileSchedulerParams, Tuple[int, int, int]]: + """Compute grid parameters for FMHA (static scheduler). + + The output tensor o has shape (s, d, ((h_r, h_k), b)). + """ + tile_sched_params = Sm100FmhaStaticTileSchedulerParams( + is_persistent, + ( + cute.ceil_div(cute.size(o_shape[0]), cta_tiler[0]), + cute.size(o_shape[2][0]), + cute.size(o_shape[2][1]), + ), + ) + grid = Sm100FmhaStaticTileScheduler.get_grid_shape(tile_sched_params) + return tile_sched_params, grid + + +############################################################################## +# Fmha CLC dynamic tile scheduler +############################################################################## + + +class Sm100FmhaClcDynamicTileSchedulerParams: + """Parameters for FMHA CLC dynamic persistent tile scheduler. + + This class manages the layout of tiles for CLC (Cluster Launch Control) + based dynamic scheduling, adapted for FMHA's (M, B, H) problem shape. + + :ivar problem_shape_mbh: Problem shape in (M, B, H) format. + :type problem_shape_mbh: cute.Shape + :ivar cluster_shape_mnk: Cluster shape in (M, N, K) format. + :type cluster_shape_mnk: cute.Shape + """ + + def __init__( + self, + problem_shape_mbh: cute.Shape, + cluster_shape_mnk: cute.Shape, + *, + loc=None, + ip=None, + ): + self.problem_shape_mbh = problem_shape_mbh + self._cluster_shape_mnk = cluster_shape_mnk + self.cluster_shape_mn = cluster_shape_mnk[:2] + self._loc = loc + self._ip = ip + + # FMHA uses linear indexing over (M, B, H), convert to (M, N, L) style + # For FMHA: M dim is tile count along sequence, N=1, L=(B*H) + self.problem_shape_ntile_mnl = ( + problem_shape_mbh[0], # M tiles + 1, # N tiles (always 1 for FMHA) + problem_shape_mbh[1] * problem_shape_mbh[2], # L = B * H + ) + + # Create layout for cluster-to-tile mapping + self.problem_layout_ncluster_mnl = cute.make_layout( + cute.ceil_div(self.problem_shape_ntile_mnl, cluster_shape_mnk[:2]), + loc=loc, + ip=ip, + ) + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [ + self.problem_shape_mbh, + self._cluster_shape_mnk, + ]: + obj_values = extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + values_copy = list(values) + for obj, n_items in zip( + [self.problem_shape_mbh, self._cluster_shape_mnk], + self._values_pos, + ): + obj_list.append(new_from_mlir_values(obj, values_copy[:n_items])) + values_copy = values_copy[n_items:] + return Sm100FmhaClcDynamicTileSchedulerParams(*(tuple(obj_list)), loc=self._loc) + + def get_grid_shape(self, *, loc=None, ip=None) -> Tuple[int, int, int]: + """Compute grid shape aligned with cluster shape.""" + return cute.round_up(self.problem_shape_ntile_mnl, self._cluster_shape_mnk) + + def clc_hw_params(self) -> ClcDynamicPersistentTileSchedulerParams: + """Return params for the upstream CLC hardware scheduler.""" + return ClcDynamicPersistentTileSchedulerParams( + problem_shape_ntile_mnl=self.problem_shape_ntile_mnl, + cluster_shape_mnk=self._cluster_shape_mnk, + ) + + +class Sm100FmhaClcDynamicTileScheduler: + """CLC dynamic persistent tile scheduler for FMHA. + + This scheduler uses Blackwell's Cluster Launch Control hardware mechanism + for dynamic tile distribution, providing automatic load balancing. + Adapted for FMHA's (M, B, H) problem shape. + """ + + def __init__( + self, + params: Sm100FmhaClcDynamicTileSchedulerParams, + cta_id_in_cluster: cute.Coord, + num_tiles_executed: Int32, + clc_response_ptr: cute.Pointer, + block_idx: Tuple, + clc: ClcState = None, + *, + loc=None, + ip=None, + ): + self.params = params + self.cta_id_in_cluster = cta_id_in_cluster + self._num_tiles_executed = num_tiles_executed + self._clc_response_ptr = clc_response_ptr + self._block_idx = block_idx + self.clc = clc + self._loc = loc + self._ip = ip + + def __extract_mlir_values__(self): + values = extract_mlir_values(self.cta_id_in_cluster) + values.extend(extract_mlir_values(self._num_tiles_executed)) + values.extend(extract_mlir_values(self._clc_response_ptr)) + values.extend(extract_mlir_values(self._block_idx)) + if self.clc is not None: + values.extend(extract_mlir_values(self.clc)) + return values + + def __new_from_mlir_values__(self, values): + new_cta_id_in_cluster = new_from_mlir_values(self.cta_id_in_cluster, values[0:3]) + new_num_tiles_executed = new_from_mlir_values(self._num_tiles_executed, [values[3]]) + new_clc_response_ptr = new_from_mlir_values(self._clc_response_ptr, [values[4]]) + new_block_idx = new_from_mlir_values(self._block_idx, values[5:8]) + new_clc = None + if self.clc is not None: + new_clc = new_from_mlir_values(self.clc, values[8:]) + return Sm100FmhaClcDynamicTileScheduler( + self.params, + new_cta_id_in_cluster, + new_num_tiles_executed, + new_clc_response_ptr, + new_block_idx, + new_clc, + ) + + @staticmethod + def create( + params: Sm100FmhaClcDynamicTileSchedulerParams, + block_idx: Tuple, + grid_dim: Tuple, + clc_response_ptr: cute.Pointer, + clc: ClcState = None, + *, + loc=None, + ip=None, + ): + """Create a CLC dynamic tile scheduler instance.""" + bidx, bidy, bidz = block_idx + + # CTA id in cluster + cta_id_in_cluster = ( + Int32(bidx % params.cluster_shape_mn[0]), + Int32(bidy % params.cluster_shape_mn[1]), + Int32(0), + ) + + num_tiles_executed = Int32(0) + + return Sm100FmhaClcDynamicTileScheduler( + params, + cta_id_in_cluster, + num_tiles_executed, + clc_response_ptr, + block_idx, + clc, + ) + + @staticmethod + def get_grid_shape( + params: Sm100FmhaClcDynamicTileSchedulerParams, + *, + loc=None, + ip=None, + ) -> Tuple[int, int, int]: + """Get grid shape for kernel launch.""" + return params.get_grid_shape(loc=loc, ip=ip) + + def work_tile_info_from_clc_response(self, result_addr: cute.Pointer, *, loc=None, ip=None): + """Parse CLC response and convert to FMHA tile coordinates.""" + m_idx, n_idx, l_idx, vld = cute.arch.clc_response(result_addr, loc=loc, ip=ip) + cute.arch.fence_proxy("async.shared", space="cta") + + # CLC returns first CTA coordinates: m_idx=x, l_idx=z + # l_idx is the L (batch) dimension; decode to (bid, hid) + hid = l_idx % self.params.problem_shape_mbh[2] + bid = l_idx // self.params.problem_shape_mbh[2] + + cta_idx_in_cluster, cta_idy_in_cluster, _ = self.cta_id_in_cluster + cur_tile_coord = ( + m_idx + cta_idx_in_cluster, # M dimension + 0, # N always 0 for FMHA + (bid, hid), # (B, H) packed + ) + + return cutlass.utils.WorkTileInfo(cur_tile_coord, vld) + + def get_current_work(self, *, loc=None, ip=None): + """Get current work tile from CLC response.""" + return self.work_tile_info_from_clc_response(self._clc_response_ptr, loc=loc, ip=ip) + + def initial_work_tile_info(self, *, loc=None, ip=None): + """Get initial work tile based on block index.""" + bidx, bidy, bidz = self._block_idx + # bidz is the L (batch) dimension; decode to (bid, hid) + hid = bidz % self.params.problem_shape_mbh[2] + bid = bidz // self.params.problem_shape_mbh[2] + return cutlass.utils.WorkTileInfo((bidx, 0, (bid, hid)), True) + + def advance_to_next_work(self, *, loc=None, ip=None): + """Consumer-side advance: wait for next tile, read coordinates, release.""" + self.clc.consumer_wait(loc=loc, ip=ip) + work = self.get_current_work(loc=loc, ip=ip) + self.clc.consumer_release(loc=loc, ip=ip) + self._num_tiles_executed += Int32(1) + return work + + def prefetch_next_work(self, *, loc=None, ip=None): + """Producer-side: issue CLC query for next tile.""" + self.clc.prefetch_next_work(loc=loc, ip=ip) + + def producer_tail(self, *, loc=None, ip=None): + """Producer-side cleanup after last tile.""" + self.clc.producer_tail(loc=loc, ip=ip) + + @property + def num_tiles_executed(self) -> Int32: + return self._num_tiles_executed + + +def compute_sm100_fmha_grid_clc( + o_shape: cute.Shape, + cta_tiler: Tuple[int, int, int], + cluster_shape_mnk: Tuple[int, int, int], +) -> Tuple[Sm100FmhaClcDynamicTileSchedulerParams, Tuple[int, int, int]]: + """Compute grid parameters for FMHA with CLC dynamic scheduling.""" + problem_shape_mbh = ( + cute.ceil_div(cute.size(o_shape[0]), cta_tiler[0]), + cute.size(o_shape[2][0]), + cute.size(o_shape[2][1]), + ) + tile_sched_params = Sm100FmhaClcDynamicTileSchedulerParams(problem_shape_mbh, cluster_shape_mnk) + grid = Sm100FmhaClcDynamicTileScheduler.get_grid_shape(tile_sched_params) + return tile_sched_params, grid + + +############################################################################## +# Fused Mask +############################################################################## + + +def make_sm100_thread_cooperative_group(size: int): + return cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, size) + + +SM100_TMEM_CAPACITY_COLUMNS = 512 diff --git a/flash_attn/cute/topk_gather_kv.py b/flash_attn/cute/topk_gather_kv.py new file mode 100644 index 00000000000..67169fb5900 --- /dev/null +++ b/flash_attn/cute/topk_gather_kv.py @@ -0,0 +1,274 @@ +from typing import Type, Optional +from dataclasses import dataclass +import operator + +import cutlass +import cutlass.cute as cute +import cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync +from cutlass import Int32, Uint32, const_expr, Boolean + +from flash_attn.cute import utils +from flash_attn.cute.utils import warp_reduce +from quack.cute_dsl_utils import ParamsBase + +import math + + +@dataclass +class CpasyncGatherKVManager(ParamsBase): + mIndexTopk: cute.Tensor + sBitmask: cute.Tensor + + cta_rank_in_cluster: Int32 + thread_idx: Int32 + warp_idx: Int32 + + topk_length: Int32 + seqlen_k_limit: Int32 + tile_n: Int32 + num_threads: cutlass.Constexpr[Int32] + hdim: cutlass.Constexpr[Int32] + hdim_v: cutlass.Constexpr[Int32] + num_hdimv_splits: cutlass.Constexpr[Int32] + cta_group_size: cutlass.Constexpr[Int32] + + gmem_threads_per_row: cutlass.Constexpr[Int32] + topk_indices_per_thread: Int32 + async_copy_elems: Int32 + + gmem_tiled_copy_KV: cute.TiledCopy + gmem_thr_copy_KV: cute.TiledCopy + + rTopk: cute.Tensor + rTopkHalf: cute.Tensor + # for bitmask + rTopk_NonInterleaved: cute.Tensor + + pipeline_bitmask: Optional[pipeline.PipelineAsync] + cpasync_barrier: pipeline.NamedBarrier + + disable_bitmask: cutlass.Constexpr[Boolean] + + @staticmethod + def create( + mIndexTopk: cute.Tensor, + sBitmask: cute.Tensor, + cta_rank_in_cluster: Int32, + thread_idx: Int32, + warp_idx: Int32, + topk_length: Int32, + seqlen_k_limit: Int32, + tile_n: cutlass.Constexpr[Int32], + hdim: cutlass.Constexpr[Int32], + hdim_v: cutlass.Constexpr[Int32], + num_hdimv_splits: cutlass.Constexpr[Int32], + num_threads: cutlass.Constexpr[Int32], + dtype: Type[cutlass.Numeric], + cta_group_size: cutlass.Constexpr[Int32], + pipeline_bitmask: Optional[pipeline.PipelineAsync], + num_stages_bitmask: cutlass.Constexpr[Int32], + cpasync_barrier: pipeline.NamedBarrier, + disable_bitmask: cutlass.Constexpr[Boolean], + ): + assert tile_n % num_threads == 0 + assert num_threads == 128 + assert hdim % 64 == 0 + assert (hdim_v // num_hdimv_splits // cta_group_size) % 64 == 0 + assert num_threads % cute.arch.WARP_SIZE == 0 + universal_copy_bits = 128 + async_copy_elems = universal_copy_bits // dtype.width + dtype_bytes = dtype.width // 8 + # assumes hdim is never part of transposed operand + gmem_k_block_size = math.gcd( + hdim, + hdim_v // num_hdimv_splits // cta_group_size, + 128 // dtype_bytes, + ) + assert gmem_k_block_size % async_copy_elems == 0 + gmem_threads_per_row = gmem_k_block_size // async_copy_elems + assert cute.arch.WARP_SIZE % gmem_threads_per_row == 0 + atom_async_copy = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), + dtype, + num_bits_per_copy=universal_copy_bits, + ) + thr_layout = cute.make_ordered_layout( + (num_threads // gmem_threads_per_row, gmem_threads_per_row), + order=(1, 0), + ) + val_layout = cute.make_layout((1, async_copy_elems)) + gmem_tiled_copy_KV = cute.make_tiled_copy_tv(atom_async_copy, thr_layout, val_layout) + gmem_thr_copy_KV = gmem_tiled_copy_KV.get_slice(thread_idx) + topk_indices_per_thread = tile_n // num_threads + + rTopk = cute.make_rmem_tensor((topk_indices_per_thread,), Int32) + rTopkHalf = cute.make_rmem_tensor((topk_indices_per_thread,), Int32) + rTopk_NonInterleaved = cute.make_rmem_tensor((topk_indices_per_thread,), Int32) + + return CpasyncGatherKVManager( + mIndexTopk, + sBitmask, + cta_rank_in_cluster, + thread_idx, + warp_idx, + topk_length, + seqlen_k_limit, + tile_n, + num_threads, + hdim, + hdim_v, + num_hdimv_splits, + cta_group_size, + gmem_threads_per_row, + topk_indices_per_thread, + async_copy_elems, + gmem_tiled_copy_KV, + gmem_thr_copy_KV, + rTopk, + rTopkHalf, + rTopk_NonInterleaved, + pipeline_bitmask, + cpasync_barrier, + disable_bitmask, + ) + + @cute.jit + def load_index_topk( + self, + n_block: Int32, + transpose: bool, + ): + entries_per_thread = self.topk_indices_per_thread + rTopk = self.rTopk if const_expr(transpose) else self.rTopkHalf + + for i in cutlass.range_constexpr(entries_per_thread): + row = ( + i * self.num_threads + + (self.thread_idx % self.gmem_threads_per_row) + * (self.num_threads // self.gmem_threads_per_row) + + (self.thread_idx // self.gmem_threads_per_row) + ) + # need this if not offset in load_X + # if const_expr(not transpose): + # row += self.cta_rank_in_cluster * (self.tile_n//self.cta_group_size) + # row = row % self.tile_n + row_idx = n_block * self.tile_n + row + rTopk[i] = self.mIndexTopk[row_idx] + + if const_expr(not transpose and not self.disable_bitmask): + row_non_interleaved = i * self.num_threads + self.thread_idx + row_idx_non_interleaved = n_block * self.tile_n + row_non_interleaved + self.rTopk_NonInterleaved[0] = self.mIndexTopk[row_idx_non_interleaved] + + @cute.jit + def compute_bitmask( + self, + producer_state_bitmask, + ): + lane_idx = cute.arch.lane_idx() + assert cute.size(self.rTopk_NonInterleaved) == 1 + bitmask = Uint32(0) + + # Step 1. Construct per-thread bitmask + topk_idx = self.rTopk_NonInterleaved[0] + is_valid = topk_idx >= 0 and topk_idx < self.seqlen_k_limit + if is_valid: + bitmask = Uint32(1 << lane_idx) + + # Step 2. Warp shuffle bitwise OR = add since indices are exclusive. + bitmask = warp_reduce(bitmask, operator.add) + + self.pipeline_bitmask.producer_acquire(producer_state_bitmask) + # store to smem and sync threads + if lane_idx == 0: + self.sBitmask[self.warp_idx, producer_state_bitmask.index] = bitmask + self.cpasync_barrier.arrive_and_wait() + + self.pipeline_bitmask.producer_commit(producer_state_bitmask) + producer_state_bitmask.advance() + return producer_state_bitmask + + @cute.jit + def compute_X_ptr( + self, + mX: cute.Tensor, + transpose: bool, + ): + entries_per_thread = self.topk_indices_per_thread + tPrXPtr = cute.make_rmem_tensor((entries_per_thread,), cutlass.Int64) + tPrRowValid = cute.make_rmem_tensor((entries_per_thread,), cutlass.Int32) + rTopk = self.rTopk if const_expr(transpose) else self.rTopkHalf + + for i in cutlass.range_constexpr(entries_per_thread): + topk_idx = rTopk[i] + if const_expr(not self.disable_bitmask): + row_valid = topk_idx >= 0 and topk_idx < self.seqlen_k_limit + tPrRowValid[i] = row_valid + if const_expr(not transpose): + tPrXPtr[i] = utils.elem_pointer(mX, (topk_idx, 0)).toint() + else: + tPrXPtr[i] = utils.elem_pointer(mX, (0, topk_idx)).toint() + + return tPrXPtr, tPrRowValid + + @cute.jit + def load_X( + self, + mX: cute.Tensor, + sX: cute.Tensor, + transpose: bool, + K_or_V: str, + ): + assert K_or_V in ("K", "V") + cta_tile_n = self.tile_n if const_expr(transpose) else self.tile_n // self.cta_group_size + head_dim = self.hdim if const_expr(K_or_V == "K") else self.hdim_v // self.num_hdimv_splits + if const_expr(transpose): + head_dim = head_dim // self.cta_group_size + order = (1, 0) if const_expr(transpose) else (0, 1) + + sX_nd_layout = cute.make_ordered_layout((cta_tile_n, head_dim), order=order) + sX_nd = cute.composition(sX, sX_nd_layout) + + cX = cute.make_identity_tensor((cta_tile_n, head_dim)) + tXsX = self.gmem_thr_copy_KV.partition_D(sX_nd) + tXcX = self.gmem_thr_copy_KV.partition_S(cX) + + tPrXPtr, tPrRowValid = self.compute_X_ptr(mX, transpose) + + if const_expr(not transpose): + offset = self.cta_rank_in_cluster * (self.gmem_threads_per_row // self.cta_group_size) + else: + offset = 0 + + for m in cutlass.range_constexpr(cute.size(tXsX, mode=[1])): + if const_expr(not self.disable_bitmask): + row_valid = utils.shuffle_sync( + tPrRowValid[m // self.gmem_threads_per_row], + (m + offset) % self.gmem_threads_per_row, + width=self.gmem_threads_per_row, + ) + should_load = cute.make_fragment_like(tXsX[(0, None), m, 0], Boolean) + should_load.fill(Boolean(row_valid)) + x_ptr_i64 = utils.shuffle_sync( + tPrXPtr[m // self.gmem_threads_per_row], + (m + offset) % self.gmem_threads_per_row, + width=self.gmem_threads_per_row, + ) + x_gmem_ptr = cute.make_ptr( + mX.element_type, x_ptr_i64, cute.AddressSpace.gmem, assumed_align=16 + ) + mX_cur = cute.make_tensor(x_gmem_ptr, cute.make_layout((head_dim,))) + mX_cur_copy = cute.tiled_divide(mX_cur, (self.async_copy_elems,)) + + for k in cutlass.range_constexpr(cute.size(tXsX, mode=[2])): + ki = tXcX[0, 0, k][1] // self.async_copy_elems + mX_cur_copy_ki = mX_cur_copy[None, ki] + tXsX_k = tXsX[None, m, k] + mX_cur_copy_ki = cute.make_tensor(mX_cur_copy_ki.iterator, tXsX_k.layout) + cute.copy( + self.gmem_tiled_copy_KV, + mX_cur_copy_ki, + tXsX_k, + pred=should_load if const_expr(not self.disable_bitmask) else None, + ) diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index 31186618569..b9dc4f5c112 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -4,6 +4,7 @@ import hashlib import inspect import os +from functools import partial from typing import Type, Callable, Optional, Tuple, overload import cutlass @@ -60,12 +61,33 @@ _fa_disable_2cta_enabled: bool = os.environ.get("FA_DISABLE_2CTA", "0") == "1" +def _is_cuda_12() -> bool: + """Check if the CUDA toolkit version is 12.x. + + 2CTA forward non-causal has a codegen regression on CUDA 12 that causes + ~18% slowdown compared to 1CTA. This is fixed in CUDA 13.x. + """ + try: + import torch + + cuda_version = torch.version.cuda + if cuda_version is not None: + major = cuda_version.split(".")[0] + return int(major) == 12 + except Exception: + pass + return False + + +_fa_disable_2cta_cuda12: bool = _is_cuda_12() + + def _get_use_clc_scheduler_default() -> bool: return _fa_clc_enabled def _get_disable_2cta_default() -> bool: - return _fa_disable_2cta_enabled + return _fa_disable_2cta_enabled or _fa_disable_2cta_cuda12 def _compute_base_hash(func: Callable) -> str: @@ -126,16 +148,28 @@ def hash_callable( def create_softcap_scoremod(softcap_val): - inv_softcap = 1.0 / softcap_val - @cute.jit - def scoremod_premask_fn(acc_S_SSA, batch_idx, head_idx, q_idx, kv_idx, aux_tensors): - scores = acc_S_SSA * inv_softcap - return scores * cute.math.tanh(scores, fastmath=True) + def scoremod_premask_fn( + acc_S_SSA, batch_idx, head_idx, q_idx, kv_idx, seqlen_info, aux_tensors + ): + scores = acc_S_SSA / softcap_val + return softcap_val * cute.math.tanh(scores, fastmath=True) return scoremod_premask_fn +def create_softcap_scoremod_bwd(softcap_val): + @cute.jit + def scoremod_bwd_fn( + grad_out_SSA, score_SSA, batch_idx, head_idx, q_idx, kv_idx, seqlen_info, aux_tensors + ): + scores = score_SSA / softcap_val + tanh_scores = cute.math.tanh(scores, fastmath=True) + return grad_out_SSA * (1.0 - tanh_scores * tanh_scores) + + return scoremod_bwd_fn + + LOG2_E = math.log2(math.e) @@ -180,6 +214,34 @@ def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Te ) +def convert_from_dlpack_compact_dynamic( + x, + *, + dynamic_modes: tuple[int, ...], + alignment: int = 16, + stride_order=None, + divisibility: int = 1, + enable_tvm_ffi: bool = False, +) -> cute.Tensor: + """Convert via DLPack and mark selected compact dimensions as dynamic.""" + if isinstance(dynamic_modes, int): + dynamic_modes = (dynamic_modes,) + if stride_order is None: + stride_order = x.dim_order() + t = ( + from_dlpack(x, assumed_align=alignment, enable_tvm_ffi=True) + if enable_tvm_ffi + else from_dlpack(x, assumed_align=alignment) + ) + for mode in dynamic_modes: + t = t.mark_compact_shape_dynamic( + mode=mode, + stride_order=stride_order, + divisibility=divisibility, + ) + return t + + def convert_from_dlpack_leading_static( x, leading_dim, alignment=16, static_modes=None, stride_order=None ) -> cute.Tensor: @@ -787,6 +849,92 @@ def domain_offset_aligned( return cute.make_tensor(new_ptr, tensor.layout) +@dsl_user_op +def warp_reduction( + val: cute.Numeric, op: Callable, *, threads_in_group: int = 32, loc=None, ip=None +) -> cute.Numeric: + """Warp-wide reduction helper for a custom binary op.""" + offset = threads_in_group // 2 + while offset > 0: + val = op( + val, + cute.arch.shuffle_sync_bfly( + val, offset=offset, mask=-1, mask_and_clamp=31, loc=loc, ip=ip + ), + ) + offset //= 2 + return val + + +warp_reduction_max = partial( + warp_reduction, op=lambda x, y: fmax(x, y) if isinstance(x, Float32) else max(x, y) +) +warp_reduction_sum = partial(warp_reduction, op=lambda x, y: x + y) # noqa: FURB118 + + +@dsl_user_op +def make_cotiled_copy( + atom: cute.CopyAtom, atom_layout_tv: cute.Layout, data_layout: cute.Layout, *, loc=None, ip=None +) -> cute.TiledCopy: + """Compatibility wrapper for deprecated CuTeDSL `make_cotiled_copy`.""" + assert cute.is_static(atom_layout_tv.type), "atom_layout_tv must be static" + assert cute.is_static(data_layout.type), "data_layout must be static" + + inv_layout_ = cute.left_inverse(data_layout, loc=loc, ip=ip) + inv_data_layout = cute.make_layout( + (inv_layout_.shape, (1)), stride=(inv_layout_.stride, (0)), loc=loc, ip=ip + ) + layout_tv_data = cute.composition(inv_data_layout, atom_layout_tv, loc=loc, ip=ip) + + atom_layout_v_to_check = cute.coalesce( + cute.make_layout(atom_layout_tv.shape[1], stride=atom_layout_tv.stride[1], loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + data_layout_v_to_check = cute.coalesce( + cute.composition( + data_layout, + cute.make_layout( + layout_tv_data.shape[1], stride=layout_tv_data.stride[1], loc=loc, ip=ip + ), + loc=loc, + ip=ip, + ), + loc=loc, + ip=ip, + ) + assert data_layout_v_to_check == atom_layout_v_to_check, ( + "the memory pointed to by atom_layout_tv does not exist in the data_layout." + ) + + flat_data_shape = cute.product_each(data_layout.shape, loc=loc, ip=ip) + tiler = tuple( + cute.filter( + cute.composition( + cute.make_layout( + flat_data_shape, + stride=tuple(0 if j != i else 1 for j in range(cute.rank(flat_data_shape))), + loc=loc, + ip=ip, + ), + layout_tv_data, + loc=loc, + ip=ip, + ), + loc=loc, + ip=ip, + ) + for i in range(cute.rank(flat_data_shape)) + ) + tile2data = cute.composition( + cute.make_layout(flat_data_shape, loc=loc, ip=ip), tiler, loc=loc, ip=ip + ) + layout_tv = cute.composition( + cute.left_inverse(tile2data, loc=loc, ip=ip), layout_tv_data, loc=loc, ip=ip + ) + return cute.make_tiled_copy(atom, layout_tv, tiler, loc=loc, ip=ip) + + @cute.jit def scalar_to_ssa(a: cute.Numeric, dtype) -> cute.TensorSSA: """Convert a scalar to a cute TensorSSA of shape (1,) and given dtype""" diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index a53b4a3108a..2cbb3480bda 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -163,6 +163,7 @@ def _flash_attn_varlen_forward( leftpad_k: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, zero_tensors: bool = False, + num_splits: int = 0, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: q, k, v = [maybe_contiguous(x) for x in (q, k, v)] out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd( @@ -187,6 +188,7 @@ def _flash_attn_varlen_forward( softcap, return_softmax, None, + num_splits, ) # if out.isnan().any() or softmax_lse.isnan().any(): # breakpoint() @@ -214,6 +216,7 @@ def _flash_attn_varlen_forward_fake( leftpad_k: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, zero_tensors: bool = False, + num_splits: int = 0, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: q, k, v = [maybe_contiguous(x) for x in (q, k, v)] paged_kv = block_table is not None diff --git a/flash_attn/utils/distributed.py b/flash_attn/utils/distributed.py index 74c55279645..6fe93790154 100644 --- a/flash_attn/utils/distributed.py +++ b/flash_attn/utils/distributed.py @@ -9,9 +9,11 @@ # version of PyTorch. The following 4 lines are for backward compatibility with # older PyTorch. if "all_gather_into_tensor" not in dir(torch.distributed): - torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base + if hasattr(torch.distributed, "_all_gather_base"): + torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base if "reduce_scatter_tensor" not in dir(torch.distributed): - torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base + if hasattr(torch.distributed, "_reduce_scatter_base"): + torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base # Raw operation, does not support autograd, but does support async diff --git a/tests/cute/test_cache_utils.py b/tests/cute/test_cache_utils.py new file mode 100644 index 00000000000..4b1e690920b --- /dev/null +++ b/tests/cute/test_cache_utils.py @@ -0,0 +1,30 @@ +import logging +from types import SimpleNamespace + +import flash_attn.cute.cache_utils as cache_utils +from flash_attn.cute import fa_logging + + +def test_persistent_cache_hit_logs_at_host_level_only(tmp_path, monkeypatch, caplog): + caplog.set_level(logging.INFO, logger="flash_attn") + original_level = fa_logging.get_fa_log_level() + key = ("test-key",) + cache = cache_utils.JITPersistentCache(tmp_path) + obj_path = tmp_path / f"{cache._key_to_hash(key)}.o" + obj_path.write_bytes(b"cache-hit") + monkeypatch.setattr( + cache_utils.cute.runtime, + "load_module", + lambda *_args, **_kwargs: SimpleNamespace(func=object()), + ) + try: + monkeypatch.setattr(fa_logging, "_fa_log_level", 0) + assert cache_utils.JITPersistentCache(tmp_path)._try_load_from_storage(key) + assert "Loading compiled function from disk" not in caplog.text + + caplog.clear() + monkeypatch.setattr(fa_logging, "_fa_log_level", 1) + assert cache_utils.JITPersistentCache(tmp_path)._try_load_from_storage(key) + assert "Loading compiled function from disk" in caplog.text + finally: + monkeypatch.setattr(fa_logging, "_fa_log_level", original_level) diff --git a/tests/cute/test_clc_fuzz.py b/tests/cute/test_clc_fuzz.py index 022276d3281..c988681da3b 100644 --- a/tests/cute/test_clc_fuzz.py +++ b/tests/cute/test_clc_fuzz.py @@ -17,7 +17,12 @@ from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100 from flash_attn.cute.interface import flash_attn_func, flash_attn_varlen_func from flash_attn.cute.testing import attention_ref -from flash_attn.cute.tile_scheduler import SchedulingMode, SingleTileLPTScheduler, SingleTileVarlenScheduler +from flash_attn.cute.tile_scheduler import ( + SchedulingMode, + SingleTileLPTScheduler, + SingleTileVarlenScheduler, + StaticPersistentTileScheduler, +) if torch.cuda.is_available(): @@ -60,10 +65,19 @@ def check_output(q, k, v, *, causal=False, window_size=(None, None), num_splits= torch.cuda.synchronize() if assert_clc and _captured_schedulers: sched_cls, sched_mode, use_2cta = _captured_schedulers[-1] - assert sched_cls is SingleTileLPTScheduler, f"Expected SingleTileLPTScheduler, got {sched_cls.__name__}" - assert sched_mode == SchedulingMode.CLC, f"Expected CLC scheduling mode, got {sched_mode!r}" - if assert_2cta: - assert use_2cta, "Expected use_2cta_instrs=True but got False" + is_local = window_size != (None, None) + if causal or is_local: + assert sched_cls is SingleTileLPTScheduler, f"Expected SingleTileLPTScheduler, got {sched_cls.__name__}" + assert sched_mode == SchedulingMode.CLC, f"Expected CLC scheduling mode, got {sched_mode!r}" + if assert_2cta: + assert use_2cta, "Expected use_2cta_instrs=True but got False" + else: + assert sched_cls is StaticPersistentTileScheduler, ( + f"Expected StaticPersistentTileScheduler for dense noncausal, got {sched_cls.__name__}" + ) + assert sched_mode == SchedulingMode.STATIC, ( + f"Expected STATIC scheduling mode for dense noncausal, got {sched_mode!r}" + ) out_ref, _ = attention_ref(q, k, v, causal=causal, window_size=window_size) out_pt, _ = attention_ref(q, k, v, causal=causal, window_size=window_size, upcast=False, reorder_ops=True) fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() @@ -249,7 +263,7 @@ def test_overlap_sO_sQ_fallback(self): class TestCLCFallback: - def test_varlen_uses_clc(self): + def test_varlen_mha_uses_static(self): _captured_schedulers.clear() batch, seqlen, heads, d = 4, 256, 4, 128 lens = torch.tensor([64, 128, 32, 32], dtype=torch.int32) @@ -271,7 +285,7 @@ def test_varlen_uses_clc(self): assert sched_cls is SingleTileVarlenScheduler, ( f"Expected SingleTileVarlenScheduler for varlen, got {sched_cls.__name__}" ) - assert sched_mode == SchedulingMode.CLC, f"Expected CLC scheduling mode, got {sched_mode!r}" + assert sched_mode == SchedulingMode.STATIC, f"Expected STATIC scheduling mode, got {sched_mode!r}" @pytest.mark.parametrize("sq,sk,wl,wr", [ (512, 512, 128, 128), @@ -311,7 +325,10 @@ def check_varlen_output(seqlens, heads, d, *, causal=False, kv_heads=None, num_s if _captured_schedulers: sched_cls, sched_mode, *_ = _captured_schedulers[-1] assert sched_cls is SingleTileVarlenScheduler, f"Expected SingleTileVarlenScheduler, got {sched_cls.__name__}" - assert sched_mode == SchedulingMode.CLC, f"Expected CLC scheduling mode, got {sched_mode!r}" + expected_sched_mode = SchedulingMode.CLC if heads != kv_heads else SchedulingMode.STATIC + assert sched_mode == expected_sched_mode, ( + f"Expected {expected_sched_mode.name} scheduling mode, got {sched_mode!r}" + ) for i in range(len(seqlens)): s = slice(cu_seqlens[i], cu_seqlens[i + 1]) @@ -355,7 +372,10 @@ def check_varlen_output_seqused(seqlens, heads, d, *, causal=False, kv_heads=Non if _captured_schedulers: sched_cls, sched_mode, *_ = _captured_schedulers[-1] assert sched_cls is SingleTileVarlenScheduler, f"Expected SingleTileVarlenScheduler, got {sched_cls.__name__}" - assert sched_mode == SchedulingMode.CLC, f"Expected CLC scheduling mode, got {sched_mode!r}" + expected_sched_mode = SchedulingMode.CLC if heads != kv_heads else SchedulingMode.STATIC + assert sched_mode == expected_sched_mode, ( + f"Expected {expected_sched_mode.name} scheduling mode, got {sched_mode!r}" + ) out_ref, _ = attention_ref(q, k, v, q_mask, k_mask, causal=causal) out_pt, _ = attention_ref(q, k, v, q_mask, k_mask, causal=causal, upcast=False, reorder_ops=True) diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 7b1b510d08a..57c98134bca 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -1,10 +1,12 @@ -# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao, Siyu Wang, Shengbin Di, Yuxi Chi, Johnsonms, Linfeng Zheng, Haoyan Huang, Lanbo Li, Yun Zhong, Man Yuan, Minmin Sun, Yong Li, Wei Lin. import math import itertools import os import random import re +import gc +from functools import wraps import pytest import torch @@ -28,8 +30,28 @@ from flash_attn.cute.interface import ( flash_attn_func, flash_attn_varlen_func, + _flash_attn_fwd, + _flash_attn_bwd, ) +def retry_on_oom(func): + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except torch.OutOfMemoryError as e: + if "out of memory" in str(e).lower(): + if hasattr(_flash_attn_fwd, "compile_cache"): + _flash_attn_fwd.compile_cache.clear() + if hasattr(_flash_attn_bwd, "compile_cache"): + _flash_attn_bwd.compile_cache.clear() + gc.collect() + torch.cuda.empty_cache() + return func(*args, **kwargs) + else: + raise + return wrapper + # torch FakeTensorMode would enable fast cutedsl kernel compilation without allocating the actual GPU memory or running the kernel # When operating fake tensors, we cannot perform data-dependent operations (e.g., `tensor.max()`). USE_FAKE_TENSOR = int(os.getenv("FLASH_ATTENTION_FAKE_TENSOR", 0)) == 1 @@ -49,8 +71,8 @@ @pytest.mark.parametrize("has_qv", [False]) @pytest.mark.parametrize("deterministic", [False, True]) # @pytest.mark.parametrize("deterministic", [False]) -# @pytest.mark.parametrize("softcap", [0.0, 15.0]) -@pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("softcap", [0.0, 15.0]) +# @pytest.mark.parametrize("softcap", [0.0]) @pytest.mark.parametrize("local_enum", [0, 1, 2, 3]) # @pytest.mark.parametrize("local_enum", [0]) @pytest.mark.parametrize("causal", [False, True]) @@ -72,9 +94,11 @@ (3, 3), (64, 32), (64, 128), + (64, 1), # SM100 hd256 2CTA test case (128, 128), (128, 192), (256, 256), + (255, 256), # SM100 hd256 2CTA test case (239, 1), (799, 3), (113, 203), @@ -95,6 +119,7 @@ ], ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) +@retry_on_oom @maybe_fake_tensor_mode(USE_FAKE_TENSOR) def test_flash_attn_output( seqlen_q, @@ -112,6 +137,25 @@ def test_flash_attn_output( local = local_enum > 0 if local and causal: pytest.skip() + if has_qv and d != 64: + pytest.skip() + if has_qv and local: + pytest.xfail("has_qv: local not supported yet") + if has_qv and has_learnable_sink: + pytest.xfail("has_qv: learnable sink not supported yet") + # TODO(wangsiyu): SM100 head_dim=256 2CTA kernel currently does not support the following features. + # Remove these skips when support is added. + if d == 256 and IS_SM100: + if has_learnable_sink: + pytest.skip("SM100 head_dim=256 2CTA kernel does not support learnable_sink yet") + if local: + pytest.skip("SM100 head_dim=256 2CTA kernel does not support local attention yet") + if softcap > 0.0: + pytest.skip("SM100 head_dim=256 2CTA kernel does not support softcap yet") + if deterministic: + pytest.skip("SM100 head_dim=256 2CTA kernel does not support deterministic mode yet") + if causal and seqlen_q > seqlen_k: + pytest.skip("SM100 head_dim=256 2CTA kernel does not support causal attention with seqlen_q > seqlen_k yet") device = "cuda" # set seed seed = 0 @@ -121,14 +165,20 @@ def test_flash_attn_output( torch.cuda.synchronize() batch_size = 9 if seqlen_k <= 2048 else 2 # batch_size = 2 - nheads = 6 + nheads = 6 if not has_qv else 128 # nheads = 1 - nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1) + if not has_qv: + nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1) + else: + nheads_kv = nheads if mha_type == "mha" else (8 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d]) if dtype == torch.float8_e4m3fn: dv_vals = [d] + if has_qv: + assert d == 64 + dv_vals = [512] # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] attention_chunk_vals = [0] for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): @@ -244,19 +294,30 @@ def test_flash_attn_output( print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") # num_splits_vals = [1, 3] - pack_gqa_vals = [False, True, None] if not TEST_BWD_ONLY else [False] + pack_gqa_vals = [True] if has_qv else [False, True, None] if not TEST_BWD_ONLY else [False] # SplitKV is not supported for hdim >= 192 # pack_gqa_vals = [False] - num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT and not TEST_BWD_ONLY else [1] + num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT and not TEST_BWD_ONLY and not has_qv else [1] for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): - if IS_SM100 and (d >= 192 and dv >= 192): # hdim 192 and 256 not support on SM100 + # SplitKV not supported on SM90 - skip this iteration + if IS_SM90 and num_splits > 1: + continue + if IS_SM100 and (d >= 192 and dv >= 192) and not (d == 256 and dv == 256): continue + # TODO(wangsiyu): SM100 head_dim=256 2CTA kernel does not support pack_gqa yet. + # pack_gqa=None means auto-enable for GQA/MQA (qhead_per_kvhead > 1) + # Remove this when support is added. + if d == 256 and IS_SM100: + if pack_gqa is True: + continue + if pack_gqa is None and mha_type != "mha": + continue out, lse = flash_attn_func( q, k, v, + qv=qv, causal=causal, - # qv=qv, # q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, # attention_chunk=attention_chunk, @@ -288,7 +349,11 @@ def test_flash_attn_output( and not dv > 256 and not attention_chunk != 0 and softcap == 0.0 - and ((dv == d and d <= 128) or (d == 192 and dv == 128)) + and ( + (dv == d and d <= 128) + or (d == 192 and dv == 128) + or (IS_SM100 and d == 256 and dv == 256) + ) and learnable_sink is None # and False and not ((causal or local) and seqlen_k < seqlen_q) @@ -384,8 +449,8 @@ def test_flash_attn_output( @pytest.mark.parametrize("has_qv", [False]) @pytest.mark.parametrize("deterministic", [False, True]) # @pytest.mark.parametrize("deterministic", [False]) -# @pytest.mark.parametrize("softcap", [0.0, 15.0]) -@pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("softcap", [0.0, 15.0]) +# @pytest.mark.parametrize("softcap", [0.0]) @pytest.mark.parametrize("local_enum", [0, 1, 2, 3]) # @pytest.mark.parametrize("local_enum", [0]) @pytest.mark.parametrize("causal", [False, True]) @@ -399,7 +464,7 @@ def test_flash_attn_output( # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128]) # @pytest.mark.parametrize("d", [128, 192]) -@pytest.mark.parametrize("d", [64, 128, 192]) +@pytest.mark.parametrize("d", [64, 128, 192, 256]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -423,6 +488,11 @@ def test_flash_attn_output( (1023, 1024), (1024, 1023), (2048, 2048), + # SM100 hd256 2CTA test cases + (64, 1), + (255, 256), + (4096, 4096), + (4224, 4224), ], ) @pytest.mark.parametrize("varlen_mode", ["random", "third", "full"]) @@ -445,6 +515,7 @@ def test_flash_attn_output( (False, True), ], ) +@retry_on_oom @maybe_fake_tensor_mode(USE_FAKE_TENSOR) def test_flash_attn_varlen_output( seqlen_q, @@ -468,6 +539,23 @@ def test_flash_attn_varlen_output( local = local_enum > 0 if local and causal: pytest.skip() + # TODO(wangsiyu): SM100 head_dim=256 2CTA kernel currently does not support the following features. + # Remove these skips when support is added. + if d == 256 and IS_SM100: + if has_learnable_sink: + pytest.skip("SM100 head_dim=256 2CTA kernel does not support learnable_sink yet") + if local: + pytest.skip("SM100 head_dim=256 2CTA kernel does not support local attention yet") + if softcap > 0.0: + pytest.skip("SM100 head_dim=256 2CTA kernel does not support softcap yet") + if deterministic: + pytest.skip("SM100 head_dim=256 2CTA kernel does not support deterministic mode yet") + if causal and seqlen_q > seqlen_k: + pytest.skip("SM100 head_dim=256 2CTA kernel does not support causal attention with seqlen_q > seqlen_k yet") + if zero_lengths_q or zero_lengths_k: + pytest.skip("SM100 head_dim=256 2CTA kernel does not support zero-length sequences yet") + if not unpad_q or not unpad_kv: + pytest.skip("SM100 head_dim=256 2CTA kernel does not support seqused_q/seqused_k mode yet (requires unpad_q=True and unpad_kv=True)") if ( causal or local ): # Right now reference only supports causal attention with seqlen_k == seqlen_q @@ -484,6 +572,8 @@ def test_flash_attn_varlen_output( dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d]) + if d == 256: + dv_vals = [256] # SM100 hd=256 2CTA kernel only supports dv=256 if dtype == torch.float8_e4m3fn: dv_vals = [d] # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k else [0] @@ -678,6 +768,17 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): # SplitKV is not supported for hdim >= 192 num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT and not TEST_BWD_ONLY else [1] for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + # SplitKV not supported on SM90 - skip this iteration + if IS_SM90 and num_splits > 1: + continue + # TODO(wangsiyu): SM100 head_dim=256 2CTA kernel does not support pack_gqa yet. + # pack_gqa=None means auto-enable for GQA/MQA (qhead_per_kvhead > 1) + # Remove this when support is added. + if d == 256 and IS_SM100: + if pack_gqa is True: + continue + if pack_gqa is None and mha_type != "mha": + continue out_unpad, lse = flash_attn_varlen_func( q_unpad if unpad_q else q, k_unpad if unpad_kv else k, @@ -734,7 +835,11 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): and not has_qv and not dv > 256 and not attention_chunk != 0 - and ((dv == d and d <= 128) or (d == 192 and dv == 128)) + and ( + (dv == d and d <= 128) + or (d == 192 and dv == 128) + or (IS_SM100 and d == 256 and dv == 256) + ) and not has_learnable_sink # and False ): @@ -920,6 +1025,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): ], ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) +@retry_on_oom @maybe_fake_tensor_mode(USE_FAKE_TENSOR) def test_flash_attn_kvcache( seqlen_q, @@ -1680,6 +1786,207 @@ def test_flash_attn_paged_deepseek(seqlen_q, page_size): assert torch.equal(out, out_ref) +@pytest.mark.parametrize("seqlen_q", [128, 512, 2048]) +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_paged_hd256_sm100_tma(seqlen_q): + """TMA paged KV in the SM100 hd256 2CTA forward kernel. + + Verifies paged KV (page_table + TMA) matches the non-paged varlen reference + and is deterministic across runs. page_size must equal tile_n=128. + """ + if not IS_SM100: + pytest.skip("SM100-specific paged hd256 test") + device = "cuda" + dtype = torch.bfloat16 + d = 256 + batch_size = 2 + nheads = 16 + nheads_kv = 16 + page_size = 128 + assert seqlen_q % page_size == 0 + + torch.random.manual_seed(0) + q = torch.randn(batch_size * seqlen_q, nheads, d, device=device, dtype=dtype) + k = torch.randn(batch_size * seqlen_q, nheads_kv, d, device=device, dtype=dtype) + v = torch.randn(batch_size * seqlen_q, nheads_kv, d, device=device, dtype=dtype) + cu_seqlens_q = torch.arange(0, batch_size + 1, dtype=torch.int32, device=device) * seqlen_q + cu_seqlens_k = cu_seqlens_q.clone() + + # Non-paged reference (varlen). + out_ref, _ = flash_attn_varlen_func( + q, k, v, + cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=seqlen_q, max_seqlen_k=seqlen_q, + ) + + # Repack into paged layout: (total_pages, page_size, nheads_kv, d). + num_pages_per_seq = seqlen_q // page_size + total_pages = batch_size * num_pages_per_seq + k_paged = torch.zeros(total_pages, page_size, nheads_kv, d, device=device, dtype=dtype) + v_paged = torch.zeros(total_pages, page_size, nheads_kv, d, device=device, dtype=dtype) + for b in range(batch_size): + for s in range(seqlen_q): + pi = b * num_pages_per_seq + s // page_size + po = s % page_size + k_paged[pi, po] = k[b * seqlen_q + s] + v_paged[pi, po] = v[b * seqlen_q + s] + page_table = torch.arange(total_pages, dtype=torch.int32, device=device).reshape( + batch_size, num_pages_per_seq + ) + + # Paged via hd256 2CTA TMA paged path — run twice for determinism. + out_paged_0, _ = flash_attn_varlen_func( + q, k_paged, v_paged, + cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=None, + max_seqlen_q=seqlen_q, max_seqlen_k=seqlen_q, + page_table=page_table, + ) + out_paged_1, _ = flash_attn_varlen_func( + q, k_paged, v_paged, + cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=None, + max_seqlen_q=seqlen_q, max_seqlen_k=seqlen_q, + page_table=page_table, + ) + + if is_fake_mode(): + return + + print(f"Paged vs non-paged max diff: {(out_paged_0 - out_ref).abs().max().item()}") + print(f"Paged determinism diff: {(out_paged_1 - out_paged_0).abs().max().item()}") + assert torch.allclose(out_paged_0, out_ref, atol=1e-3, rtol=1e-3), "Paged output does not match non-paged reference" + assert torch.equal(out_paged_1, out_paged_0), "Paged output is not deterministic" + + +@pytest.mark.parametrize("nheads_kv", [2, 4, 8]) +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_paged_hd256_sm100_tma_gqa(nheads_kv): + """TMA paged KV for SM100 hd256 2CTA with GQA (nheads_q > nheads_kv). + + Exercises the head_kv_coord derivation for qhead_per_kvhead > 1 — the MHA + test passes by coincidence since modulo and integer division agree when + qhead_per_kvhead == 1. + """ + if not IS_SM100: + pytest.skip("SM100-specific paged hd256 test") + device = "cuda" + dtype = torch.bfloat16 + d = 256 + batch_size = 2 + nheads = 16 + page_size = 128 + seqlen_q = 512 + assert nheads % nheads_kv == 0 and seqlen_q % page_size == 0 + + torch.random.manual_seed(0) + q = torch.randn(batch_size * seqlen_q, nheads, d, device=device, dtype=dtype) + k = torch.randn(batch_size * seqlen_q, nheads_kv, d, device=device, dtype=dtype) + v = torch.randn(batch_size * seqlen_q, nheads_kv, d, device=device, dtype=dtype) + cu_seqlens_q = torch.arange(0, batch_size + 1, dtype=torch.int32, device=device) * seqlen_q + cu_seqlens_k = cu_seqlens_q.clone() + + out_ref, _ = flash_attn_varlen_func( + q, k, v, + cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=seqlen_q, max_seqlen_k=seqlen_q, + ) + + num_pages_per_seq = seqlen_q // page_size + total_pages = batch_size * num_pages_per_seq + k_paged = torch.zeros(total_pages, page_size, nheads_kv, d, device=device, dtype=dtype) + v_paged = torch.zeros(total_pages, page_size, nheads_kv, d, device=device, dtype=dtype) + for b in range(batch_size): + for s in range(seqlen_q): + pi = b * num_pages_per_seq + s // page_size + po = s % page_size + k_paged[pi, po] = k[b * seqlen_q + s] + v_paged[pi, po] = v[b * seqlen_q + s] + page_table = torch.arange(total_pages, dtype=torch.int32, device=device).reshape( + batch_size, num_pages_per_seq + ) + + out_paged, _ = flash_attn_varlen_func( + q, k_paged, v_paged, + cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=None, + max_seqlen_q=seqlen_q, max_seqlen_k=seqlen_q, + page_table=page_table, + ) + + if is_fake_mode(): + return + + print(f"GQA nheads_kv={nheads_kv} paged vs non-paged max diff: {(out_paged - out_ref).abs().max().item()}") + assert torch.allclose(out_paged, out_ref, atol=1e-3, rtol=1e-3), ( + f"Paged GQA output does not match non-paged reference (nheads_kv={nheads_kv})" + ) + + +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_paged_hd256_sm100_tma_shuffled(): + """TMA paged KV for SM100 hd256 2CTA with a non-identity (shuffled) page_table. + + An identity page_table passes even if the kernel ignores it. This test + shuffles physical pages so a kernel that bypasses page_table would silently + read wrong data, proving the remapping path is exercised. + """ + if not IS_SM100: + pytest.skip("SM100-specific paged hd256 test") + device = "cuda" + dtype = torch.bfloat16 + d = 256 + batch_size = 2 + nheads = 16 + nheads_kv = 16 + page_size = 128 + seqlen_q = 512 + num_pages_per_seq = seqlen_q // page_size + total_pages = batch_size * num_pages_per_seq + + torch.random.manual_seed(42) + q = torch.randn(batch_size * seqlen_q, nheads, d, device=device, dtype=dtype) + k = torch.randn(batch_size * seqlen_q, nheads_kv, d, device=device, dtype=dtype) + v = torch.randn(batch_size * seqlen_q, nheads_kv, d, device=device, dtype=dtype) + cu_seqlens_q = torch.arange(0, batch_size + 1, dtype=torch.int32, device=device) * seqlen_q + cu_seqlens_k = cu_seqlens_q.clone() + + out_ref, _ = flash_attn_varlen_func( + q, k, v, + cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=seqlen_q, max_seqlen_k=seqlen_q, + ) + + # Shuffle physical pages: reverse order within each batch item. + # Build as Python list of ints to avoid .item() calls on FakeTensors during compilation. + perm = [ + list(range((b + 1) * num_pages_per_seq - 1, b * num_pages_per_seq - 1, -1)) + for b in range(batch_size) + ] + page_table = torch.tensor(perm, dtype=torch.int32, device=device) + + k_paged = torch.zeros(total_pages, page_size, nheads_kv, d, device=device, dtype=dtype) + v_paged = torch.zeros(total_pages, page_size, nheads_kv, d, device=device, dtype=dtype) + for b in range(batch_size): + for s in range(seqlen_q): + phys = perm[b][s // page_size] # Python int, safe in FakeTensorMode + po = s % page_size + k_paged[phys, po] = k[b * seqlen_q + s] + v_paged[phys, po] = v[b * seqlen_q + s] + + out_paged, _ = flash_attn_varlen_func( + q, k_paged, v_paged, + cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=None, + max_seqlen_q=seqlen_q, max_seqlen_k=seqlen_q, + page_table=page_table, + ) + + if is_fake_mode(): + return + + print(f"Shuffled paged vs non-paged max diff: {(out_paged - out_ref).abs().max().item()}") + assert torch.allclose(out_paged, out_ref, atol=1e-3, rtol=1e-3), ( + "Shuffled paged output does not match non-paged reference" + ) + + @pytest.mark.parametrize("head_dim", [4, 148, 288]) def test_flash_attn_invalid_head_dim(head_dim): device = "cuda" @@ -1692,3 +1999,743 @@ def test_flash_attn_invalid_head_dim(head_dim): with pytest.raises(AssertionError, match=re.escape(f"(head_dim, head_dim_v)=({head_dim}, {head_dim}) is not supported on SM")): flash_attn_func(q, k, v) + + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +@pytest.mark.parametrize("mha_type", ["mqa"]) +@pytest.mark.parametrize("has_learnable_sink", [False]) +@pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("deterministic", [False]) +# @pytest.mark.parametrize("local_enum", [0, 1]) +@pytest.mark.parametrize("local_enum", [0]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [False]) +@pytest.mark.parametrize("d", [64]) +@pytest.mark.parametrize("nheads", [16, 128]) +@pytest.mark.parametrize("kv_sparsity", [False, True]) +# @pytest.mark.parametrize("kv_sparsity", [True]) +@pytest.mark.parametrize("gather_kv_length", [1024, 2048]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 1), + (3, 3), + (64, 32), + (64, 128), + (128, 128), + (128, 192), + (256, 256), + (239, 1), + (799, 3), + (113, 203), + (113, 128), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (384, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (2048, 2048), + (1, 8192), + (4096, 4096), + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_mla_absorbed( + seqlen_q, + seqlen_k, + d, + nheads, + causal, + local_enum, + softcap, + deterministic, + has_learnable_sink, + mha_type, + dtype, + kv_sparsity, + gather_kv_length, +): + has_qv = True + if not IS_SM100: + pytest.skip() + if kv_sparsity and seqlen_k < gather_kv_length: + seqlen_k += gather_kv_length + local = local_enum > 0 + if local and causal: + pytest.skip() + if local: + pytest.xfail("mla absorbed: local not supported yet") + if kv_sparsity and nheads != 128: + pytest.skip() + device = "cuda" + # set seed + seed = 0 + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.empty_cache() + torch.cuda.synchronize() + batch_size = 9 if seqlen_k <= 2048 else 2 + # batch_size = 2 + # nheads = 128 + nheads_kv = nheads if mha_type == "mha" else (8 if mha_type == "gqa" else 1) + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + dv_vals = [512] + # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] + attention_chunk_vals = [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + q_ref = torch.randn( + batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref + ) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. + q_ref = q_ref * softcap / 4 + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) + v_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) + if has_qv: + qv_ref = ( + torch.randn( + batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + else: + qv_ref = None + if kv_sparsity: + gather_kv_indices = torch.rand(batch_size, seqlen_q, gather_kv_length, device=device).argsort(dim=-1).to(torch.int32) + else: + gather_kv_indices = None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = ( + (None, None) if not local else tuple(random.randrange(0, seqlen_k) for _ in range(2)) + ) + if local_enum == 2: + window_size = (None, -window_size[1]) + elif local_enum == 3: + window_size = (-window_size[0], None) + if local: + print("window size = ", window_size) + # window_size = (-1, -1) if not local else (16, 0) + if has_learnable_sink: + learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device) + else: + learnable_sink = None + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [ + torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) + * 2 + for _ in range(3) + ] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach().to(dtype).requires_grad_() if has_qv else None + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + None, + None, + causal=causal, + qv=qv_ref, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + learnable_sink=learnable_sink, + softcap=softcap, + gather_kv_indices=gather_kv_indices, + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + None, + None, + causal=causal, + qv=qv_ref, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + learnable_sink=learnable_sink, + softcap=softcap, + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + gather_kv_indices=gather_kv_indices, + ) + + # k_extended = repeat(k_ref, "b s h d -> b s (h k) d", k=nheads // nheads_kv) + # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_extended).float() + # # if qv is not None: + # # qk += torch.einsum('bshd,bthd->bhst', qv_ref, v_ref).float() + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # exp_sum = s_tmp.sum(-1) + # # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float()) + # # lse_ref = torch.logsumexp(qk, dim=-1) + + # Numerical error if we just do any arithmetic on out_ref + if not is_fake_mode(): + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + num_splits_vals = [1] + pack_gqa_vals = [True] + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + out, lse = flash_attn_func( + q, + k, + v, + qv=qv, + gather_kv_indices=gather_kv_indices, + causal=causal, + # q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + # attention_chunk=attention_chunk, + softcap=softcap, + learnable_sink=learnable_sink, + pack_gqa=pack_gqa, + num_splits=num_splits, + deterministic=deterministic, + ) + if is_fake_mode(): + # no more flash_attn cutedsl calls for the rest of the loop + # skip data-dependent postprocessing + continue + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= rtol * ( + out_pt - out_ref + ).abs().max().item() + fwd_atol + assert not torch.isnan(lse).any(), "LSE contains NaN" + + repeats = 1000 + for iter in range(repeats): + out2, lse2 = flash_attn_func( + q, + k, + v, + qv=qv, + gather_kv_indices=gather_kv_indices, + causal=causal, + # q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + # attention_chunk=attention_chunk, + softcap=softcap, + learnable_sink=learnable_sink, + pack_gqa=pack_gqa, + num_splits=num_splits, + deterministic=deterministic, + ) + # print(f"out max: {out.abs().max().item()}, {iter=}") + # print(f"out vs out2 max diff: {(out - out2).abs().max().item()}, {iter=}") + # print(f"out vs out2 mean diff: {(out - out2).abs().mean().item()}, {iter=}") + assert torch.equal(out, out2), f"non-deterministic with max diff = {(out - out2).abs().max().item()} on {iter=}" + + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +@pytest.mark.parametrize("mha_type", ["mqa"]) +@pytest.mark.parametrize("has_learnable_sink", [False]) +@pytest.mark.parametrize("deterministic", [False]) +@pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("local_enum", [0]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [False]) +# @pytest.mark.parametrize("add_unused_qkv", [False, True]) +@pytest.mark.parametrize("add_unused_qkv", [False]) +@pytest.mark.parametrize("kv_sparsity", [False, True]) +# @pytest.mark.parametrize("kv_sparsity", [False]) +@pytest.mark.parametrize("gather_kv_length", [1024, 2048]) +@pytest.mark.parametrize("d", [64]) +@pytest.mark.parametrize("nheads", [16, 128]) +# @pytest.mark.parametrize("nheads", [128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + # (1, 1), + # (1, 3), + # (2, 1), + (511, 1), + (3, 513), + (64, 128), + (128, 128), + (256, 256), + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (307, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (2048, 2048), + ], +) +@pytest.mark.parametrize("varlen_mode", ["random", "full"]) +# @pytest.mark.parametrize("varlen_mode", ["random"]) +@pytest.mark.parametrize( + "zero_lengths_q, zero_lengths_k", + [ + (False, False), + # (True, False), + ], +) +@pytest.mark.parametrize( + "unpad_q, unpad_kv", + [ + (True, True), + (True, False), + (False, False), + (False, True), + ], +) +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_mla_absorbed_varlen( + seqlen_q, + seqlen_k, + d, + nheads, + add_unused_qkv, + causal, + local_enum, + softcap, + deterministic, + has_learnable_sink, + mha_type, + dtype, + varlen_mode, + zero_lengths_q, + zero_lengths_k, + unpad_q, + unpad_kv, + kv_sparsity, + gather_kv_length, +): + has_qv = True + if not IS_SM100: + pytest.skip() + if kv_sparsity and seqlen_k < gather_kv_length: + seqlen_k += gather_kv_length + local = local_enum > 0 + if local and causal: + pytest.skip() + if has_qv and local: + pytest.xfail("has_qv: local not supported yet") + if kv_sparsity and nheads != 128: + pytest.skip() + seqlen_q_og = seqlen_q + seqlen_k_og = seqlen_k + if ( + causal or local + ): # Right now reference only supports causal attention with seqlen_k == seqlen_q + seqlen_q = max(seqlen_q_og, seqlen_k_og) + seqlen_k = max(seqlen_q_og, seqlen_k_og) + device = "cuda" + # set seed + seed = seqlen_q + seqlen_k + d + int(causal) * 2 + int(local) + random.seed(seed) + torch.random.manual_seed(seed) + batch_size = 7 if seqlen_q <= 512 else 3 + nheads_kv = nheads if mha_type == "mha" else (8 if mha_type == "gqa" else 1) + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + dv_vals = [512] + # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k else [0] + attention_chunk_vals = [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + q_ref = torch.randn( + batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref + ) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. + q_ref = (q_ref * softcap / 4).detach().requires_grad_() + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) + v_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) + if has_qv: + qv_ref = ( + torch.randn( + batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + else: + qv_ref = None + if kv_sparsity: + gather_kv_indices = torch.rand(batch_size, seqlen_q, gather_kv_length, device=device).argsort(dim=-1).to(torch.int32) + else: + gather_kv_indices = None + + # Put window_size after QKV randn so that window_size changes from test to test + window_size = ( + (None, None) if not local else tuple(random.randrange(0, seqlen_k) for _ in range(2)) + ) + if local_enum == 2: + window_size = (None, window_size[1]) + elif local_enum == 3: + window_size = (window_size[0], None) + if local: + print("window size = ", window_size) + if has_learnable_sink: + learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device) + else: + learnable_sink = None + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [ + torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) + * 2 + for _ in range(3) + ] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach() if has_qv else None + query_padding_mask = generate_random_padding_mask( + seqlen_q, + batch_size, + device, + mode=varlen_mode, + zero_lengths=zero_lengths_q, + ) + key_padding_mask = generate_random_padding_mask( + seqlen_k, + batch_size, + device, + mode=varlen_mode, + zero_lengths=zero_lengths_k, + min_seqlen=gather_kv_length if kv_sparsity else None, + ) + def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): + if add_unused: + another_mask = generate_random_padding_mask(max_seq_len, bs, device) + attn_mask = torch.logical_and(padding_mask, another_mask) + unused_mask = torch.logical_xor( + torch.logical_or(padding_mask, another_mask), attn_mask + ) + else: + attn_mask = padding_mask + unused_mask = None + return attn_mask, unused_mask + + query_padding_mask, query_unused_mask = _gen_unused_masks( + query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device + ) + # query_padding_mask[:] = True + # query_unused_mask = None + key_padding_mask, key_unused_mask = _gen_unused_masks( + key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device + ) + + if causal or local: + key_padding_mask = query_padding_mask + + ( + q_unpad, + k_unpad, + v_unpad, + qv_unpad, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + qv, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) = generate_qkv( + q, + k, + v, + query_padding_mask, + key_padding_mask, + qv=qv, + kvpacked=False, + query_unused_mask=query_unused_mask, + key_unused_mask=key_unused_mask, + ) + # unpad gather_kv_indices + if kv_sparsity: + _, indices_q, _, _, _ = unpad_input( + q, query_padding_mask, query_unused_mask + ) + gather_kv_indices_unpad = rearrange(gather_kv_indices, "b s ... -> (b s) ...")[indices_q] + else: + gather_kv_indices_unpad = None + if unpad_q: + print("cu_seqlens_q = ", cu_seqlens_q) + else: + print("seqused_q = ", seqused_q) + if unpad_kv: + print("cu_seqlens_k = ", cu_seqlens_k) + else: + print("seqused_k = ", seqused_k) + q_unpad, k_unpad, v_unpad = [ + x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad) + ] + + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + learnable_sink=learnable_sink, + softcap=softcap, + gather_kv_indices=gather_kv_indices, + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + learnable_sink=learnable_sink, + softcap=softcap, + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + gather_kv_indices=gather_kv_indices, + ) + + if not is_fake_mode(): + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + + if query_unused_mask is not None: + q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + + pack_gqa_vals = [True] + num_splits_vals = [1] + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + # SplitKV not supported on SM90 - skip this iteration + if IS_SM90 and num_splits > 1: + continue + out_unpad, lse = flash_attn_varlen_func( + q_unpad if unpad_q else q, + k_unpad if unpad_kv else k, + v_unpad if unpad_kv else v, + qv_unpad if unpad_q else qv, + cu_seqlens_q=cu_seqlens_q if unpad_q else None, + cu_seqlens_k=cu_seqlens_k if unpad_kv else None, + max_seqlen_q=seqlen_q, + max_seqlen_k=seqlen_k, + min_seqlen_k=gather_kv_length if kv_sparsity else None, + seqused_q=seqused_q if not unpad_q else None, + seqused_k=seqused_k if not unpad_kv else None, + causal=causal, + window_size=window_size, + learnable_sink=learnable_sink, + softcap=softcap, + num_splits=num_splits, + pack_gqa=pack_gqa, + deterministic=deterministic, + gather_kv_indices=gather_kv_indices_unpad if unpad_q else gather_kv_indices, + ) + out = output_pad_fn(out_unpad) if unpad_q else out_unpad + if is_fake_mode(): + # no more flash_attn cutedsl calls for the rest of the loop + # skip data-dependent postprocessing + continue + if query_unused_mask is not None: + out.masked_fill_(q_zero_masking, 0.0) + # When unpad_q=False with seqused_q, the kernel doesn't write positions + # beyond seqused_q, so those contain uninitialized values. Mask them out + # before comparing. + out_cmp, out_ref_cmp, out_pt_cmp = out, out_ref, out_pt + if not unpad_q and seqused_q is not None: + seqused_mask = torch.arange(seqlen_q, device=device)[None, :] < seqused_q[:, None] + seqused_mask = rearrange(seqused_mask, "b s -> b s 1 1") + out_cmp = out.clone().masked_fill_(~seqused_mask, 0.0) + out_ref_cmp = out_ref.clone().masked_fill_(~seqused_mask, 0.0) + out_pt_cmp = out_pt.clone().masked_fill_(~seqused_mask, 0.0) + print(f"Output max diff: {(out_cmp - out_ref_cmp).abs().max().item()}") + print(f"Output mean diff: {(out_cmp - out_ref_cmp).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most 3x the numerical error + # of a Pytorch implementation. + assert (out_cmp - out_ref_cmp).abs().max().item() <= rtol * ( + out_pt_cmp - out_ref_cmp + ).abs().max().item() + fwd_atol + # LSE sanity: only valid positions (packed unpad path; padded path + # can legitimately contain uninit tail beyond seqused_q). + if unpad_q: + assert not torch.isnan(lse).any(), "LSE contains NaN" + + repeats = 1000 + for iter in range(repeats): + out_unpad2, lse = flash_attn_varlen_func( + q_unpad if unpad_q else q, + k_unpad if unpad_kv else k, + v_unpad if unpad_kv else v, + qv_unpad if unpad_q else qv, + cu_seqlens_q=cu_seqlens_q if unpad_q else None, + cu_seqlens_k=cu_seqlens_k if unpad_kv else None, + max_seqlen_q=seqlen_q, + max_seqlen_k=seqlen_k, + min_seqlen_k=gather_kv_length if kv_sparsity else None, + seqused_q=seqused_q if not unpad_q else None, + seqused_k=seqused_k if not unpad_kv else None, + causal=causal, + window_size=window_size, + learnable_sink=learnable_sink, + softcap=softcap, + num_splits=num_splits, + pack_gqa=pack_gqa, + deterministic=deterministic, + gather_kv_indices=gather_kv_indices_unpad if unpad_q else gather_kv_indices, + ) + out2 = output_pad_fn(out_unpad2) if unpad_q else out_unpad2 + if query_unused_mask is not None: + out2.masked_fill_(q_zero_masking, 0.0) + # When unpad_q=False with seqused_q, the kernel doesn't write positions + # beyond seqused_q, so those contain uninitialized values. Mask them out + # before comparing. + if not unpad_q and seqused_q is not None: + seqused_mask = torch.arange(seqlen_q, device=device)[None, :] < seqused_q[:, None] + seqused_mask = rearrange(seqused_mask, "b s -> b s 1 1") + out2.masked_fill_(~seqused_mask, 0.0) + # print(f"out2 max: {out2.abs().max().item()}, {iter=}") + # print(f"out vs out2 max diff: {(out_cmp - out2).abs().max().item()}, {iter=}") + # print(f"out vs out2 mean diff: {(out_cmp - out2).abs().mean().item()}, {iter=}") + assert torch.equal(out_cmp, out2), f"non-deterministic with max diff = {(out_cmp - out2).abs().max().item()} on {iter=}" + + +# --------------------------------------------------------------------------- +# Regression test: seqlen_k=0 must not crash (CUDA graph padding scenario) +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("d", [128, 192]) +@pytest.mark.parametrize("seqlen_q", [1, 64, 128, 256]) +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_seqlen_k_zero(seqlen_q, d, causal): + """K/V with physical seqlen dim == 0 must not crash. + + seqlen_k == 0 violates two downstream invariants, producing two + different crashes depending on the mask: + + causal=False -> TMA descriptor over a 0-length K tensor goes OOB + on first tile load -> PTX IllegalInstruction. + + causal=True -> SingleTileLPTScheduler's L2-swizzle heuristic in + tile_scheduler.py evaluates + size_l2 // (seqlen_k * (d + d_v) * elem_size) + -> host SIGFPE before the kernel launches. + + Varlen paths (cu_seqlens_k / seqused_k with K physical seqlen > 0) + are not exercised here: per-batch empty slots are already handled + by the kernel's fake-iteration path and do not hit either invariant. + """ + if IS_SM90: + pytest.skip("SM90 uses a different kernel path") + + device = "cuda" + dtype = torch.bfloat16 + dv = 128 if d == 192 else d + batch_size = 4 + nheads = 16 + nheads_kv = 16 + + torch.manual_seed(0) + + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) + # K/V have physical seqlen dim == 0 — this is what crashes on unpatched FA4. + # causal=False hits GPU IllegalInstruction (TMA OOB on 0-length K tensor). + # causal=True hits host SIGFPE in tile_scheduler.py LPT L2-swizzle heuristic + # (size_l2 // size_one_head with size_one_head = seqlen_k*... = 0). + k = torch.empty(batch_size, 0, nheads_kv, d, device=device, dtype=dtype) + v = torch.empty(batch_size, 0, nheads_kv, dv, device=device, dtype=dtype) + + out, lse = flash_attn_func(q, k, v, causal=causal) + + if is_fake_mode(): + return + + # No crash above already validates the fix. Below validates the contract + # the early-return promises: zero output, -inf LSE. + assert out.shape == (batch_size, seqlen_q, nheads, dv), \ + f"Unexpected output shape: {out.shape}" + assert torch.all(out == 0).item(), \ + f"Expected all-zero output when seqlen_k=0, got max={out.abs().max().item():.6f}" + if lse is not None: + assert torch.all(torch.isinf(lse) & (lse < 0)).item(), \ + f"Expected all -inf LSE when seqlen_k=0, got: {lse}" diff --git a/tests/cute/test_flash_attn_fast.py b/tests/cute/test_flash_attn_fast.py index 433859d94d8..32deb4b5168 100644 --- a/tests/cute/test_flash_attn_fast.py +++ b/tests/cute/test_flash_attn_fast.py @@ -139,8 +139,8 @@ def test_flash_attn_varlen_output(seqlen, d, causal, mha_type, dtype): out_varlen, lse = flash_attn_varlen_func( q_varlen, k_varlen, v_varlen, - cu_seqlens, cu_seqlens, - seqlen, seqlen, + cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, + max_seqlen_q=seqlen, max_seqlen_k=seqlen, causal=causal, ) diff --git a/tests/cute/test_flash_attn_race_condition.py b/tests/cute/test_flash_attn_race_condition.py index a9b8799f4c1..12f4c15ae3e 100644 --- a/tests/cute/test_flash_attn_race_condition.py +++ b/tests/cute/test_flash_attn_race_condition.py @@ -43,8 +43,8 @@ @pytest.mark.parametrize("has_qv", [False]) # @pytest.mark.parametrize("deterministic", [False, True]) @pytest.mark.parametrize("deterministic", [True]) -# @pytest.mark.parametrize("softcap", [0.0, 15.0]) -@pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("softcap", [0.0, 15.0]) +# @pytest.mark.parametrize("softcap", [0.0]) # @pytest.mark.parametrize("local_enum", [0, 1, 2, 3]) @pytest.mark.parametrize("local_enum", [0, 1]) @pytest.mark.parametrize("causal", [False, True]) @@ -251,8 +251,6 @@ def test_flash_attn_output( and learnable_sink is None # and False ): - if IS_SM90 and mha_type != "mha": - pytest.xfail("SM90 backward: GQA/MQA has tensor layout issue (qhead_per_kvhead > 1)") if IS_SM90 and local: pytest.xfail("SM90 backward: local attention not supported yet") g = torch.randn_like(out) @@ -356,8 +354,8 @@ def test_flash_attn_output( @pytest.mark.parametrize("has_qv", [False]) # @pytest.mark.parametrize("deterministic", [False, True]) @pytest.mark.parametrize("deterministic", [True]) -# @pytest.mark.parametrize("softcap", [0.0, 15.0]) -@pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("softcap", [0.0, 15.0]) +# @pytest.mark.parametrize("softcap", [0.0]) # @pytest.mark.parametrize("local_enum", [0, 1, 2, 3]) @pytest.mark.parametrize("local_enum", [0, 1]) @pytest.mark.parametrize("causal", [False, True]) diff --git a/tests/cute/test_flash_attn_varlen.py b/tests/cute/test_flash_attn_varlen.py index 24e55315671..ad6dcd6da43 100644 --- a/tests/cute/test_flash_attn_varlen.py +++ b/tests/cute/test_flash_attn_varlen.py @@ -7,7 +7,7 @@ @pytest.mark.parametrize("B", [1, 7, 20]) @pytest.mark.parametrize("H", [1, 4, 6]) -@pytest.mark.parametrize("D", [64, 128]) +@pytest.mark.parametrize("D", [64, 128, 256]) @pytest.mark.parametrize("min_seq_len", [1, 32, 128]) @pytest.mark.parametrize("max_seq_len", [8, 64, 2048]) @pytest.mark.parametrize("causal", [True, False]) diff --git a/tests/cute/test_mask_mod.py b/tests/cute/test_mask_mod.py index 26e0a5e1353..ceef6500b97 100644 --- a/tests/cute/test_mask_mod.py +++ b/tests/cute/test_mask_mod.py @@ -22,11 +22,13 @@ from torch.nn.attention.flex_attention import create_block_mask, flex_attention import torch.nn.functional as F -from flash_attn.cute.interface import _flash_attn_fwd, _flash_attn_bwd +from flash_attn.cute.interface import _flash_attn_fwd, _flash_attn_bwd, flash_attn_func from flash_attn.cute.block_sparsity import ( BlockSparseTensorsTorch, fast_sampling, normalize_block_sparse_config, + compute_dq_write_order, + compute_dq_write_order_from_block_mask, ) from flash_attn.cute.cache_utils import get_jit_cache from flash_attn.cute import utils @@ -253,6 +255,7 @@ def _run_mask_test( tile_n, use_block_sparsity, needs_backward=False, + use_autograd=False, ): torch.manual_seed(42) @@ -402,35 +405,55 @@ def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias): else None ) - out_tuple = _flash_attn_fwd( - q=tensors["q"], - k=tensors["k"], - v=tensors["v"], - out=tensors["out"], - lse=tensors["lse"], - cu_seqlens_q=None, - cu_seqlens_k=None, - seqused_q=None, - seqused_k=None, - page_table=None, - softmax_scale=softmax_scale, - causal=causal, - softcap=None, - window_size_left=window_left, - window_size_right=window_right, - learnable_sink=None, - tile_mn=(tile_m, tile_n), - pack_gqa=pack_gqa, - _arch=None, - score_mod=None, - mask_mod=mask_mod_cute, - block_sparse_tensors=block_sparse_mask_fwd, - return_lse=True, - aux_tensors=aux_tensors_arg, - ) + if use_autograd: + q_ag = tensors["q"].detach().requires_grad_(True) + k_ag = tensors["k"].detach().requires_grad_(True) + v_ag = tensors["v"].detach().requires_grad_(True) - out_cute = out_tuple[0] - lse_cute = out_tuple[1] + out_cute, lse_cute = flash_attn_func( + q_ag, + k_ag, + v_ag, + softmax_scale=softmax_scale, + causal=causal, + window_size=(window_left, window_right), + pack_gqa=pack_gqa, + mask_mod=mask_mod_cute, + aux_tensors=aux_tensors_arg, + block_sparse_tensors=block_sparse_mask_fwd, + block_sparse_tensors_bwd=block_sparse_mask_bwd, + return_lse=True, + ) + else: + out_tuple = _flash_attn_fwd( + q=tensors["q"], + k=tensors["k"], + v=tensors["v"], + out=tensors["out"], + lse=tensors["lse"], + cu_seqlens_q=None, + cu_seqlens_k=None, + seqused_q=None, + seqused_k=None, + page_table=None, + softmax_scale=softmax_scale, + causal=causal, + softcap=None, + window_size_left=window_left, + window_size_right=window_right, + learnable_sink=None, + tile_mn=(tile_m, tile_n), + pack_gqa=pack_gqa, + _arch=None, + score_mod=None, + mask_mod=mask_mod_cute, + block_sparse_tensors=block_sparse_mask_fwd, + return_lse=True, + aux_tensors=aux_tensors_arg, + ) + + out_cute = out_tuple[0] + lse_cute = out_tuple[1] tensors_fp32 = { k: v.float() if v.dtype in [torch.float16, torch.bfloat16] else v for k, v in tensors.items() @@ -488,11 +511,16 @@ def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias): device="cuda", BLOCK_SIZE=(tile_m, tile_n), ) - dq_cute, dk_cute, dv_cute = run_cute_mask_bwd( - q, k, v, out_cute, lse_cute, grad_out, mask_mod_cute, - block_sparse_mask_bwd=block_sparse_mask_bwd, tile_m=tile_m, tile_n=tile_n, - aux_tensors=aux_tensors_arg, - ) + if use_autograd: + dq_cute, dk_cute, dv_cute = torch.autograd.grad( + out_cute, (q_ag, k_ag, v_ag), grad_out + ) + else: + dq_cute, dk_cute, dv_cute = run_cute_mask_bwd( + q, k, v, out_cute, lse_cute, grad_out, mask_mod_cute, + block_sparse_mask_bwd=block_sparse_mask_bwd, tile_m=tile_m, tile_n=tile_n, + aux_tensors=aux_tensors_arg, + ) _, dq_ref_fp32, dk_ref_fp32, dv_ref_fp32 = run_flex_reference_bwd( q, k, v, flex_block_mask, grad_out, dtype=torch.float32 ) @@ -651,8 +679,12 @@ def mask_mod_flex(b, h, q_idx, kv_idx, doc_ids=doc_ids): full_block_idx=full_q_idx, block_size=(sparse_tile_m, tile_n), ) - - + dq_write_order = compute_dq_write_order_from_block_mask(bm, spt=False) + block_sparse_mask_bwd = block_sparse_mask_bwd._replace( + dq_write_order=dq_write_order[0], + dq_write_order_full=dq_write_order[1], + spt=False, + ) out_tuple = _flash_attn_fwd( q=q, k=k, v=v, out=out, lse=lse, cu_seqlens_q=None, cu_seqlens_k=None, @@ -760,8 +792,9 @@ def test_static_masks( ], ) @pytest.mark.parametrize("tile_m,tile_n", [(128, 128), (128, 112), (64, 128)]) +@pytest.mark.parametrize("use_autograd", [True, False]) def test_parameterized_masks( - seqlen_q, seqlen_k, nheads, kv_mode, headdim, dtype, use_block_sparsity, mask_name, window_size, tile_m, tile_n + seqlen_q, seqlen_k, nheads, kv_mode, headdim, dtype, use_block_sparsity, mask_name, window_size, tile_m, tile_n, use_autograd, ): """Test parameterized masks that require recompilation per seqlen pair. @@ -790,6 +823,7 @@ def test_parameterized_masks( tile_n=tile_n, use_block_sparsity=use_block_sparsity, needs_backward=True, + use_autograd=use_autograd, ) @@ -1091,7 +1125,7 @@ def wrapped_normalize(*args, **kwargs): def run_cute_mask_bwd( q, k, v, out, lse, grad_out, mask_mod_cute, block_sparse_mask_bwd=None, tile_m=128, tile_n=128, - aux_tensors=None, + aux_tensors=None, deterministic=False, causal=False, window_size_left=None, window_size_right=None, ): """Run flash attention backward with mask_mod. @@ -1104,6 +1138,7 @@ def run_cute_mask_bwd( block_sparse_mask_bwd: Block sparse tensors for backward pass tile_m, tile_n: Tile sizes aux_tensors: Auxiliary tensors for mask_mod (e.g., doc_ids for document masking) + deterministic: Whether to enable deterministic backward Returns (dq, dk, dv) all in BSHD format. """ @@ -1114,12 +1149,15 @@ def run_cute_mask_bwd( out=out, dout=grad_out, lse=lse, - causal=False, + causal=causal, m_block_size=tile_m, n_block_size=tile_n, + window_size_left=window_size_left, + window_size_right=window_size_right, mask_mod=mask_mod_cute, block_sparse_tensors=block_sparse_mask_bwd, aux_tensors=aux_tensors, + deterministic=deterministic, ) return dq, dk, dv @@ -1711,6 +1749,513 @@ def test_persistent_blocksparse_empty_tiles(): assert not out.isnan().any() +def _build_dense_from_ordered(num_blocks, indices, num_cols): + """Build dense binary matrix from ordered sparse representation (test helper).""" + B, H, num_rows, max_entries = indices.shape + batch_is_broadcast = B == 1 or (indices.stride(0) == 0 and num_blocks.stride(0) == 0) + head_is_broadcast = H == 1 or (indices.stride(1) == 0 and num_blocks.stride(1) == 0) + batch_size = 1 if batch_is_broadcast else B + head_size = 1 if head_is_broadcast else H + indices_view = indices[:batch_size, :head_size] + num_blocks_view = num_blocks[:batch_size, :head_size] + dense = torch.zeros( + batch_size, + head_size, + num_rows, + num_cols + 1, + dtype=torch.int32, + device=indices.device, + ) + valid = ( + torch.arange(max_entries, device=indices.device)[None, None, None, :] + < num_blocks_view[:, :, :, None] + ) + safe_indices = torch.where(valid, indices_view.long(), num_cols) + dense.scatter_(-1, safe_indices, valid.to(torch.int32)) + dense = dense[:, :, :, :num_cols] + if batch_size != B or head_size != H: + return dense.expand(B, H, num_rows, num_cols) + return dense + + +def _verify_deadlock_freedom( + kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx, + q_mask_cnt, q_mask_idx, full_q_cnt, full_q_idx, + dq_wo, dq_wo_full, spt=False, +): + """Verify the critical deadlock-freedom invariant for all m_blocks. + + For non-spt: the lowest n_block contributor to each m_block must have lock_value=0. + For spt: the highest n_block contributor must have lock_value=0. + """ + B, H, num_m = kv_mask_cnt.shape + num_n = kv_mask_idx.shape[-1] + + dense = _build_dense_from_ordered(kv_mask_cnt, kv_mask_idx, num_n) + if full_kv_cnt is not None: + dense = dense | _build_dense_from_ordered(full_kv_cnt, full_kv_idx, num_n) + + for b in range(B): + for h in range(H): + for m in range(num_m): + contributors = dense[b, h, m].nonzero(as_tuple=True)[0] + if len(contributors) == 0: + continue + target_n = contributors[-1].item() if spt else contributors[0].item() + + found = False + cnt_partial = q_mask_cnt[b, h, target_n].item() + for i in range(cnt_partial): + if q_mask_idx[b, h, target_n, i].item() == m: + assert dq_wo[b, h, target_n, i].item() == 0, ( + f"n_block={target_n} should get lock_value=0 for m_block={m} (spt={spt})" + ) + found = True + break + if not found and full_q_cnt is not None: + cnt_full = full_q_cnt[b, h, target_n].item() + for i in range(cnt_full): + if full_q_idx[b, h, target_n, i].item() == m: + assert dq_wo_full[b, h, target_n, i].item() == 0, ( + f"n_block={target_n} (full) should get lock_value=0 for m_block={m} (spt={spt})" + ) + found = True + break + assert found, f"target n_block={target_n} not found in backward lists for m_block={m}" + + +def _verify_unique_ranks_per_m_block( + kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx, + q_mask_cnt, q_mask_idx, full_q_cnt, full_q_idx, + dq_wo, dq_wo_full, +): + """Verify that for each m_block, the lock values form a contiguous 0..N-1 range.""" + B, H, num_m = kv_mask_cnt.shape + num_n = kv_mask_idx.shape[-1] + + dense = _build_dense_from_ordered(kv_mask_cnt, kv_mask_idx, num_n) + if full_kv_cnt is not None: + dense = dense | _build_dense_from_ordered(full_kv_cnt, full_kv_idx, num_n) + + for b in range(B): + for h in range(H): + for m in range(num_m): + contributors = dense[b, h, m].nonzero(as_tuple=True)[0] + total = len(contributors) + if total == 0: + continue + lock_vals = set() + for n in contributors.tolist(): + cnt_p = q_mask_cnt[b, h, n].item() + for i in range(cnt_p): + if q_mask_idx[b, h, n, i].item() == m: + lock_vals.add(dq_wo[b, h, n, i].item()) + if full_q_cnt is not None: + cnt_f = full_q_cnt[b, h, n].item() + for i in range(cnt_f): + if full_q_idx[b, h, n, i].item() == m: + lock_vals.add(dq_wo_full[b, h, n, i].item()) + assert lock_vals == set(range(total)), ( + f"m_block={m}: expected ranks {{0..{total-1}}}, got {lock_vals}" + ) + + +def _run_write_order_test(mask_mod_flex, seqlen_q, seqlen_k, block_size, B=1, H=4, spt=False): + bm = create_block_mask( + mask_mod_flex, B, H, seqlen_q, seqlen_k, + device="cuda", BLOCK_SIZE=(block_size, block_size), + ) + ( + _, _, + kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx, + q_mask_cnt, q_mask_idx, full_q_cnt, full_q_idx, *_, + ) = bm.as_tuple() + + dq_wo, dq_wo_full = compute_dq_write_order( + kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx, + q_mask_cnt, q_mask_idx, full_q_cnt, full_q_idx, + spt=spt, + ) + + _verify_deadlock_freedom( + kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx, + q_mask_cnt, q_mask_idx, full_q_cnt, full_q_idx, + dq_wo, dq_wo_full, spt=spt, + ) + if not spt: + _verify_unique_ranks_per_m_block( + kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx, + q_mask_cnt, q_mask_idx, full_q_cnt, full_q_idx, + dq_wo, dq_wo_full, + ) + + +def _build_block_sparse_masks_for_bwd( + mask_mod_flex, + batch_size, + nheads, + seqlen_q, + seqlen_k, + tile_m, + tile_n, + spt, +): + sparse_tile_m = 2 * tile_m if COMPUTE_CAPABILITY == 10 else tile_m + bm = create_block_mask( + mask_mod_flex, + batch_size, + nheads, + seqlen_q, + seqlen_k, + device="cuda", + BLOCK_SIZE=(sparse_tile_m, tile_n), + ) + ( + _seq_q, + _seq_k, + kv_mask_cnt, + kv_mask_idx, + full_kv_cnt, + full_kv_idx, + q_mask_cnt, + q_mask_idx, + full_q_cnt, + full_q_idx, + *_, + ) = bm.as_tuple() + + block_sparse_mask_fwd = BlockSparseTensorsTorch( + mask_block_cnt=kv_mask_cnt, + mask_block_idx=kv_mask_idx, + full_block_cnt=full_kv_cnt, + full_block_idx=full_kv_idx, + block_size=(sparse_tile_m, tile_n), + ) + block_sparse_mask_bwd = BlockSparseTensorsTorch( + mask_block_cnt=q_mask_cnt, + mask_block_idx=q_mask_idx, + full_block_cnt=full_q_cnt, + full_block_idx=full_q_idx, + block_size=(sparse_tile_m, tile_n), + ) + dq_write_order = compute_dq_write_order_from_block_mask(bm, spt=spt) + return block_sparse_mask_fwd, block_sparse_mask_bwd._replace( + dq_write_order=dq_write_order[0], + dq_write_order_full=dq_write_order[1], + spt=spt, + ) + + +@pytest.mark.skipif(COMPUTE_CAPABILITY not in (10, 11), reason="deterministic bwd only supported on sm100/sm110") +@pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 256), (512, 512), (383, 769)]) +@pytest.mark.parametrize( + "mask_name,window_size", + [ + ("block_diagonal", None), + ("causal", None), + ("sliding_window", 256), + ("document", None), + ], +) +@pytest.mark.parametrize("spt", [False, True]) +@pytest.mark.parametrize("kv_mode", ["mha", "gqa"]) +def test_block_sparse_bwd_deterministic(seqlen_q, seqlen_k, mask_name, window_size, spt, kv_mode): + torch.manual_seed(42) + if mask_name == "sliding_window" and seqlen_q > seqlen_k: + pytest.skip("sliding_window requires seqlen_q <= seqlen_k") + if spt and mask_name not in ("sliding_window", "causal"): + pytest.skip("spt path is only exercised for sliding_window and causal in this test") + + batch_size = 1 + nheads = 4 + nheads_kv = 1 if kv_mode == "gqa" else nheads + pack_gqa = nheads != nheads_kv + headdim = 128 + tile_m = 128 + tile_n = 128 + dtype = torch.bfloat16 + + mask_mod_cute, mask_mod_flex = get_mask_pair( + mask_name, seqlen_q=seqlen_q, seqlen_k=seqlen_k, window_size=window_size + ) + + aux_tensors_arg = None + if mask_name == "document": + doc_ids = random_doc_id_tensor(nheads, batch_size, max(seqlen_q, seqlen_k), device="cuda").to( + torch.int32 + ) + original_flex_mask = mask_mod_flex + + def mask_mod_flex(b, h, q_idx, kv_idx, doc_ids=doc_ids): + return original_flex_mask(b, h, q_idx, kv_idx, doc_ids) + + aux_tensors_arg = [doc_ids] + + tensors = create_tensors(batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, headdim, headdim, dtype) + q = tensors["q"] + k = tensors["k"] + v = tensors["v"] + block_mask_nheads = 1 if pack_gqa else nheads + block_sparse_mask_fwd, block_sparse_mask_bwd = _build_block_sparse_masks_for_bwd( + mask_mod_flex=mask_mod_flex, + batch_size=batch_size, + nheads=block_mask_nheads, + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + tile_m=tile_m, + tile_n=tile_n, + spt=spt, + ) + causal_arg = spt and mask_name == "causal" + window_size_left_arg = window_size if spt and mask_name == "sliding_window" else None + window_size_right_arg = 0 if spt and mask_name == "sliding_window" else None + mask_mod_arg = mask_mod_cute if not spt else None + + out_cute, lse_cute = _flash_attn_fwd( + q=q, + k=k, + v=v, + out=torch.empty(batch_size, seqlen_q, nheads, headdim, device="cuda", dtype=dtype), + lse=torch.empty(batch_size, nheads, seqlen_q, device="cuda", dtype=torch.float32), + cu_seqlens_q=None, + cu_seqlens_k=None, + seqused_q=None, + seqused_k=None, + page_table=None, + softmax_scale=1.0 / math.sqrt(headdim), + causal=causal_arg, + softcap=None, + window_size_left=window_size_left_arg, + window_size_right=window_size_right_arg, + learnable_sink=None, + tile_mn=(tile_m, tile_n), + pack_gqa=pack_gqa, + _arch=None, + score_mod=None, + mask_mod=mask_mod_arg, + block_sparse_tensors=block_sparse_mask_fwd, + return_lse=True, + aux_tensors=aux_tensors_arg, + ) + + grad_out = torch.randn_like(out_cute) + dq0, dk0, dv0 = run_cute_mask_bwd( + q, + k, + v, + out_cute, + lse_cute, + grad_out, + mask_mod_arg, + block_sparse_mask_bwd=block_sparse_mask_bwd, + tile_m=tile_m, + tile_n=tile_n, + aux_tensors=aux_tensors_arg, + deterministic=True, + causal=causal_arg, + window_size_left=window_size_left_arg, + window_size_right=window_size_right_arg, + ) + + num_repeats = 3 if spt else 50 + for _ in range(num_repeats): + dq, dk, dv = run_cute_mask_bwd( + q, + k, + v, + out_cute, + lse_cute, + grad_out, + mask_mod_arg, + block_sparse_mask_bwd=block_sparse_mask_bwd, + tile_m=tile_m, + tile_n=tile_n, + aux_tensors=aux_tensors_arg, + deterministic=True, + causal=causal_arg, + window_size_left=window_size_left_arg, + window_size_right=window_size_right_arg, + ) + assert torch.equal(dq, dq0) + assert torch.equal(dk, dk0) + assert torch.equal(dv, dv0) + + +@pytest.mark.skipif(COMPUTE_CAPABILITY not in (10, 11), reason="deterministic bwd only supported on sm100/sm110") +def _setup_block_sparse_deterministic_validation_case(): + torch.manual_seed(42) + batch_size = 1 + nheads = 4 + seqlen_q = 256 + seqlen_k = 256 + headdim = 128 + tile_m = 128 + tile_n = 128 + dtype = torch.bfloat16 + + _, mask_mod_flex = get_mask_pair( + "block_diagonal", seqlen_q=seqlen_q, seqlen_k=seqlen_k + ) + + q = torch.randn(batch_size, seqlen_q, nheads, headdim, device="cuda", dtype=dtype) + k = torch.randn(batch_size, seqlen_k, nheads, headdim, device="cuda", dtype=dtype) + v = torch.randn(batch_size, seqlen_k, nheads, headdim, device="cuda", dtype=dtype) + block_sparse_mask_fwd, block_sparse_mask_bwd = _build_block_sparse_masks_for_bwd( + mask_mod_flex=mask_mod_flex, + batch_size=batch_size, + nheads=nheads, + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + tile_m=tile_m, + tile_n=tile_n, + spt=False, + ) + out_cute, lse_cute = _flash_attn_fwd( + q=q, + k=k, + v=v, + out=torch.empty(batch_size, seqlen_q, nheads, headdim, device="cuda", dtype=dtype), + lse=torch.empty(batch_size, nheads, seqlen_q, device="cuda", dtype=torch.float32), + softmax_scale=1.0 / math.sqrt(headdim), + tile_mn=(tile_m, tile_n), + block_sparse_tensors=block_sparse_mask_fwd, + return_lse=True, + ) + + return q, k, v, out_cute, lse_cute, torch.randn_like(out_cute), block_sparse_mask_bwd, tile_m, tile_n + + +@pytest.mark.skipif(COMPUTE_CAPABILITY not in (10, 11), reason="deterministic bwd only supported on sm100/sm110") +def test_block_sparse_bwd_deterministic_missing_dq_write_order_raises(): + q, k, v, out_cute, lse_cute, grad_out, block_sparse_mask_bwd, tile_m, tile_n = ( + _setup_block_sparse_deterministic_validation_case() + ) + block_sparse_mask_bwd_no_dq_write_order = block_sparse_mask_bwd._replace( + dq_write_order=None, + dq_write_order_full=None, + spt=None, + ) + + with pytest.raises(ValueError, match="requires dq_write_order in block_sparse_tensors"): + run_cute_mask_bwd( + q, + k, + v, + out_cute, + lse_cute, + grad_out, + None, + block_sparse_mask_bwd=block_sparse_mask_bwd_no_dq_write_order, + tile_m=tile_m, + tile_n=tile_n, + deterministic=True, + ) + + +@pytest.mark.skipif(COMPUTE_CAPABILITY not in (10, 11), reason="deterministic bwd only supported on sm100/sm110") +def test_block_sparse_bwd_deterministic_missing_dq_write_order_full_raises(): + q, k, v, out_cute, lse_cute, grad_out, block_sparse_mask_bwd, tile_m, tile_n = ( + _setup_block_sparse_deterministic_validation_case() + ) + block_sparse_mask_bwd_no_dq_write_order_full = block_sparse_mask_bwd._replace( + full_block_cnt=torch.zeros_like(block_sparse_mask_bwd.mask_block_cnt), + full_block_idx=torch.zeros_like(block_sparse_mask_bwd.mask_block_idx), + dq_write_order_full=None, + spt=False, + ) + + with pytest.raises(ValueError, match="requires dq_write_order_full when full blocks are present"): + run_cute_mask_bwd( + q, + k, + v, + out_cute, + lse_cute, + grad_out, + None, + block_sparse_mask_bwd=block_sparse_mask_bwd_no_dq_write_order_full, + tile_m=tile_m, + tile_n=tile_n, + deterministic=True, + ) + + +@pytest.mark.skipif(COMPUTE_CAPABILITY not in (10, 11), reason="deterministic bwd only supported on sm100/sm110") +def test_block_sparse_bwd_deterministic_missing_spt_raises(): + q, k, v, out_cute, lse_cute, grad_out, block_sparse_mask_bwd, tile_m, tile_n = ( + _setup_block_sparse_deterministic_validation_case() + ) + block_sparse_mask_bwd_no_spt = block_sparse_mask_bwd._replace(spt=None) + + with pytest.raises(ValueError, match="requires block_sparse_tensors.spt"): + run_cute_mask_bwd( + q, + k, + v, + out_cute, + lse_cute, + grad_out, + None, + block_sparse_mask_bwd=block_sparse_mask_bwd_no_spt, + tile_m=tile_m, + tile_n=tile_n, + deterministic=True, + ) + + +WRITE_ORDER_SEQLENS = [ + (256, 256), + (512, 512), + (1024, 1024), + (2048, 2048), + (4096, 4096), + (512, 1024), + (1024, 512), + (384, 768), +] + + +@pytest.mark.parametrize("seqlen_q,seqlen_k", WRITE_ORDER_SEQLENS) +@pytest.mark.parametrize("mask_name", ["block_diagonal", "mini_causal", "prefix_lm", "dilated_sliding_window"]) +@pytest.mark.parametrize("spt", [False, True]) +def test_dq_write_order_static_masks(seqlen_q, seqlen_k, mask_name, spt): + torch.manual_seed(42) + _, mask_mod_flex = get_mask_pair(mask_name, seqlen_q=seqlen_q, seqlen_k=seqlen_k) + _run_write_order_test(mask_mod_flex, seqlen_q, seqlen_k, block_size=128, spt=spt) + + +@pytest.mark.parametrize("seqlen_q,seqlen_k", WRITE_ORDER_SEQLENS) +@pytest.mark.parametrize( + "mask_name,window_size", + [ + ("causal", None), + ("block_causal", None), + ("sliding_window", 128), + ("sliding_window", 256), + ("sliding_window", 512), + ], +) +@pytest.mark.parametrize("spt", [False, True]) +def test_dq_write_order_parameterized_masks(seqlen_q, seqlen_k, mask_name, window_size, spt): + torch.manual_seed(42) + if mask_name == "sliding_window" and seqlen_q > seqlen_k: + pytest.skip("sliding_window requires seqlen_q <= seqlen_k") + _, mask_mod_flex = get_mask_pair(mask_name, seqlen_q=seqlen_q, seqlen_k=seqlen_k, window_size=window_size) + _run_write_order_test(mask_mod_flex, seqlen_q, seqlen_k, block_size=128, spt=spt) + + +@pytest.mark.parametrize("seqlen_q,seqlen_k", [(512, 512), (1024, 1024), (2048, 2048)]) +@pytest.mark.parametrize("spt", [False, True]) +def test_dq_write_order_document_mask(seqlen_q, seqlen_k, spt): + torch.manual_seed(42) + B, H = 1, 4 + doc_ids = random_doc_id_tensor(H, B, max(seqlen_q, seqlen_k), device="cuda").to(torch.int32) + + def doc_mask(b, h, q_idx, kv_idx): + return doc_ids[b, h, q_idx] == doc_ids[b, h, kv_idx] + + _run_write_order_test(doc_mask, seqlen_q, seqlen_k, block_size=128, B=B, H=H, spt=spt) + def test_compact_block_sparse_indices(): """Test that compact block sparse index tensors (idx.shape[3] < n_blocks) work correctly. diff --git a/tests/cute/test_score_mod.py b/tests/cute/test_score_mod.py index 43bf62e7d54..95a05a1d60b 100644 --- a/tests/cute/test_score_mod.py +++ b/tests/cute/test_score_mod.py @@ -5,7 +5,12 @@ from cutlass._mlir.dialects import math as mlir_math import operator from torch.nn.attention.flex_attention import create_block_mask, flex_attention -from flash_attn.cute.interface import _flash_attn_fwd, _flash_attn_bwd, _tile_size_bwd_sm90 +from flash_attn.cute.interface import ( + flash_attn_func, + _flash_attn_fwd, + _flash_attn_bwd, + _tile_size_bwd_sm90, +) from flash_attn.cute.block_sparsity import BlockSparseTensorsTorch COMPUTE_CAPABILITY = torch.cuda.get_device_capability()[0] @@ -753,37 +758,55 @@ def score_squared_eager(score, b, h, q_idx, kv_idx): def run_cute_flash_bwd( - q, k, v, cute_score_mod, cute_score_mod_bwd, aux_tensors=None, pack_gqa=False + q, k, v, cute_score_mod, cute_score_mod_bwd, aux_tensors=None, pack_gqa=False, use_autograd=True, ): """Run flash attention forward + backward with score_mod.""" q_t = q.transpose(1, 2) k_t = k.transpose(1, 2) v_t = v.transpose(1, 2) - out, lse = _flash_attn_fwd( - q_t, - k_t, - v_t, - return_lse=True, - score_mod=cute_score_mod, - aux_tensors=aux_tensors, - pack_gqa=pack_gqa, - ) + if use_autograd: + q_t = q_t.detach().requires_grad_(True) + k_t = k_t.detach().requires_grad_(True) + v_t = v_t.detach().requires_grad_(True) + out, lse = flash_attn_func( + q_t, + k_t, + v_t, + score_mod=cute_score_mod, + score_mod_bwd=cute_score_mod_bwd, + aux_tensors=aux_tensors, + pack_gqa=pack_gqa, + ) - grad_out = torch.randn_like(out) + grad_out = torch.randn_like(out) - dq, dk, dv = _flash_attn_bwd( - q_t, - k_t, - v_t, - out, - grad_out, - lse, - score_mod=cute_score_mod, - score_mod_bwd=cute_score_mod_bwd, - aux_tensors=aux_tensors, - pack_gqa=pack_gqa, - ) + dq, dk, dv = torch.autograd.grad(out, (q_t, k_t, v_t), grad_out) + else: + out, lse = _flash_attn_fwd( + q_t, + k_t, + v_t, + return_lse=True, + score_mod=cute_score_mod, + aux_tensors=aux_tensors, + pack_gqa=pack_gqa, + ) + + grad_out = torch.randn_like(out) + + dq, dk, dv = _flash_attn_bwd( + q_t, + k_t, + v_t, + out, + grad_out, + lse, + score_mod=cute_score_mod, + score_mod_bwd=cute_score_mod_bwd, + aux_tensors=aux_tensors, + pack_gqa=pack_gqa, + ) return ( out.transpose(1, 2), @@ -982,7 +1005,8 @@ def run_flex_block_sparse_score_mod_ref(q_ref, k_ref, v_ref, grad_out_ref, ref_d @pytest.mark.parametrize("dim", [64, 128]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("score_mod_triple", BWD_TEST_PAIRS) -def test_cute_vs_flex_attention_backward(seqlen_q, seqlen_kv, dim, dtype, score_mod_triple): +@pytest.mark.parametrize("use_autograd", [True, False]) +def test_cute_vs_flex_attention_backward(seqlen_q, seqlen_kv, dim, dtype, score_mod_triple, use_autograd): """Test backward pass with score_mod against flex_attention reference.""" if COMPUTE_CAPABILITY == 9 and dim == 64: pytest.skip("head_dim=64 not supported on SM90 for backward") @@ -994,7 +1018,7 @@ def test_cute_vs_flex_attention_backward(seqlen_q, seqlen_kv, dim, dtype, score_ seqlen_q=seqlen_q, seqlen_kv=seqlen_kv, num_heads=4, dim=dim, dtype=dtype ) - out_cute, grad_out, dq_cute, dk_cute, dv_cute = run_cute_flash_bwd(q, k, v, cute_fwd, cute_bwd) + out_cute, grad_out, dq_cute, dk_cute, dv_cute = run_cute_flash_bwd(q, k, v, cute_fwd, cute_bwd, use_autograd=use_autograd) out_ref_fp32, dq_ref_fp32, dk_ref_fp32, dv_ref_fp32 = run_flex_reference_bwd( q, k, v, eager_ref, grad_out, dtype=torch.float32 ) diff --git a/tests/test_flash_attn_ck.py b/tests/test_flash_attn_ck.py index 24c17b0133c..abd1eb147ad 100644 --- a/tests/test_flash_attn_ck.py +++ b/tests/test_flash_attn_ck.py @@ -1550,6 +1550,52 @@ def test_flash_attn_bwd_varlen_overflow(d, causal, dtype): assert not v.grad.isnan().any() +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("deterministic", [False, True]) +@pytest.mark.parametrize("nheads_kv", [1, 8]) # GQA (1) and MHA (8) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("d", [64, 128]) +def test_flash_attn_bwd_varlen_seqq_zero(d, causal, nheads_kv, deterministic, dtype): + """Regression test: NaN in dK/dV for the zero-length Q subsequence (run for both GQA and MHA). + """ + if not is_bwd_supported(d, deterministic=deterministic): + pytest.skip(get_bwd_unsupported_reason(d, deterministic=deterministic)) + + device = "cuda" + torch.random.manual_seed(0) + nheads = 8 # n_q_heads; GQA when nheads_kv < nheads + q_cuseqlen = torch.tensor([0, 0, 256, 512], device=device, dtype=torch.int32) + k_cuseqlen = torch.tensor([0, 503, 768, 1536], device=device, dtype=torch.int32) + total_q = int(q_cuseqlen[-1]) # 512 + total_k = int(k_cuseqlen[-1]) # 1536 + Mq = 256 + Mk = 768 + + q = torch.randn([total_q, nheads, d], dtype=dtype, device=device) + k = torch.randn([total_k, nheads_kv, d], dtype=dtype, device=device) + v = torch.randn([total_k, nheads_kv, d], dtype=dtype, device=device) + q.requires_grad_(True) + k.requires_grad_(True) + v.requires_grad_(True) + + out = flash_attn_varlen_func( + q, k, v, q_cuseqlen, k_cuseqlen, Mq, Mk, causal=causal, deterministic=deterministic + ) + g = torch.randn_like(out) + out.backward(g) + + # The bug produced NaN specifically in dK/dV for the zero-length Q subsequence + assert not q.grad.isnan().any() + assert not k.grad.isnan().any() + assert not v.grad.isnan().any() + + # Additionally, for the batch with seqlen_q == 0, the corresponding K/V segment + # should contribute nothing to the loss, so dK/dV for that segment must be zero. + end_kv_zero = int(k_cuseqlen[1]) + assert k.grad[:end_kv_zero].abs().max() == 0 + assert v.grad[:end_kv_zero].abs().max() == 0 + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("local", [False, True]) @pytest.mark.parametrize("causal", [False, True]) diff --git a/third_party/aiter b/third_party/aiter index 428e8e761c7..b4b75165fbd 160000 --- a/third_party/aiter +++ b/third_party/aiter @@ -1 +1 @@ -Subproject commit 428e8e761c7bc22d03513bcb8507375afef1f916 +Subproject commit b4b75165fbd2456dfd0f074c5b2ef91bc87d97e5 diff --git a/tools/ci/README.md b/tools/ci/README.md new file mode 100644 index 00000000000..11deaeb10e0 --- /dev/null +++ b/tools/ci/README.md @@ -0,0 +1,48 @@ +# FA4 CI + +CI runs on a self-hosted GPU runner using an Apptainer (SIF) container pulled from Docker Hub. +Triggered on every push to `main`. + +## Two-pass test strategy + +- **Pass 1** — compile kernels in parallel via `FakeTensorMode` (no GPU memory needed) +- **Pass 2** — run tests using cached compiled kernels on real GPU + +See `run_fa4_ci.py` for the shared logic used by both CI and `test_ci_local.sh`. + +## Required GitHub secrets / variables + +| Name | Kind | Value | +|------|------|-------| +| `DOCKERHUB_USERNAME` | Secret | Docker Hub username | +| `DOCKERHUB_TOKEN` | Secret | Docker Hub access token | +| `CI_WORK_DIR` | Variable | Large-disk path on runner, e.g. `/scratch/user/johnson` | + +`CI_WORK_DIR` is used for SIF caching and Apptainer temp files. Falls back to `/scratch/user/` if unset. + +## Updating the container image + +1. Build and push a new image via `tools/ci/docker/build.sh` + `tag_and_push.sh`. +2. Update `FA4_IMAGE` in `.github/workflows/ci.yml` with the new tag and `sha256` digest. +3. The old SIF is automatically deleted from the runner on the next CI run. + +## Expanding test coverage + +Edit `FA4_TEST_FILTER` in `.github/workflows/ci.yml`. To run the full suite, set it to an empty string and increase `compile-workers` in the `gpu-test` action call. + +Alternatively, edit `run_fa4_ci.py` to change `DEFAULT_TEST_TARGET` or worker defaults — changes there apply to both CI and local runs. + +## FA2 import isolation + +Tests run inside the Apptainer container. The repo's `flash_attn/__init__.py` imports the FA2 C extension (`flash_attn_2_cuda`) which is absent in the container. `run_fa4_ci.py` works around this by: + +1. Installing FA4 from the current repo into the container at runtime (`uv pip install -e flash_attn/cute`). +2. Running pytest from `/tmp` with absolute test paths — this keeps the repo root out of `sys.path[0]` so the installed FA4 package is found instead of the FA2 `__init__.py`. + +`flash_attn/__init__.py` is intentionally not modified; isolation is handled entirely in CI. + +## Adding a new runner / GPU type + +1. Register a self-hosted runner on the machine with the desired label (e.g. `h100`). +2. Add the label to the `gpu` matrix in `.github/workflows/ci.yml`. +3. Set `CI_WORK_DIR` for the new machine if its scratch path differs. diff --git a/tools/ci/build_sif.sh b/tools/ci/build_sif.sh new file mode 100755 index 00000000000..429f8720945 --- /dev/null +++ b/tools/ci/build_sif.sh @@ -0,0 +1,45 @@ +#!/usr/bin/env bash +# Build the FA4 Apptainer SIF image. +# +# Usage: +# ./tools/ci/build_sif.sh [OUTPUT_PATH] +# +# Default output: /scratch/user/$USER/attention_fa4_.sif +# All temp/cache dirs are redirected to /scratch to avoid filling the root volume. +# +# Example: +# ./tools/ci/build_sif.sh +# ./tools/ci/build_sif.sh ~/scratch/my_fa4.sif + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +DEF_FILE="$SCRIPT_DIR/fa4.def" + +DATE=$(date +%Y%m%d) +SCRATCH_BASE="${CI_WORK_DIR:-/scratch/user/${USER}}" +DEFAULT_OUT="${SCRATCH_BASE}/attention_fa4_${DATE}.sif" +OUTPUT="${1:-$DEFAULT_OUT}" +TMP_DIR="$SCRATCH_BASE/apptainer_tmp" +CACHE_DIR="$SCRATCH_BASE/apptainer_cache" + +mkdir -p "$TMP_DIR" "$CACHE_DIR" + +echo "=== FA4 SIF Build ===" +echo " def file : $DEF_FILE" +echo " output : $OUTPUT" +echo " tmp dir : $TMP_DIR" +echo " cache dir: $CACHE_DIR" +echo + +sudo \ + APPTAINER_TMPDIR="$TMP_DIR" \ + APPTAINER_CACHEDIR="$CACHE_DIR" \ + apptainer build "$OUTPUT" "$DEF_FILE" + +echo +echo "Build complete: $OUTPUT" +echo "File size: $(du -sh "$OUTPUT" | cut -f1)" +echo +echo "To use in CI, set on the runner:" +echo " export FA4_SIF=$OUTPUT" diff --git a/tools/ci/docker/Dockerfile b/tools/ci/docker/Dockerfile new file mode 100644 index 00000000000..e01b1d5575a --- /dev/null +++ b/tools/ci/docker/Dockerfile @@ -0,0 +1,30 @@ +FROM ubuntu:24.04 + +ENV PATH=/usr/local/cuda/bin:/root/.local/bin:$PATH +ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH +ENV DEBIAN_FRONTEND=noninteractive + +# System dependencies +RUN apt-get update -y && apt-get install -y \ + python3 python3.12 python3.12-dev python3.12-venv \ + curl git ca-certificates gcc && \ + rm -rf /var/lib/apt/lists/* + +# Install uv +RUN curl -LsSf https://astral.sh/uv/install.sh | sh + +# Install torch from the published cu130 nightly package page and let pip resolve +# the matching nightly Triton wheel from PyTorch's nightly triton page. +RUN uv pip install --system --break-system-packages --no-cache --pre \ + "torch>=2.12.0.dev0,<2.13.0" \ + --find-links https://download.pytorch.org/whl/nightly/cu130/torch \ + --find-links https://download.pytorch.org/whl/nightly/triton + +# FA4 dependencies — derived from pyproject.toml to keep a single source of truth. +# torch stays pinned from the direct wheel install above. +# The package itself (version 0.0.0 via setuptools-scm fallback) is overwritten at +# step time by the editable install from the mounted repo. +COPY flash_attn/cute/ /tmp/fa4/ +RUN uv pip install --system --break-system-packages --no-cache "/tmp/fa4[cu13,dev]" + +CMD ["/bin/bash"] diff --git a/tools/ci/docker/build.sh b/tools/ci/docker/build.sh new file mode 100755 index 00000000000..9d82bf146d2 --- /dev/null +++ b/tools/ci/docker/build.sh @@ -0,0 +1,12 @@ +#!/bin/bash +set -e + +DATE=$(date +%y.%m.%d) +IMAGE_NAME="flash-attn-4:flash-attn-cu13.0-${DATE}" + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(git -C "$SCRIPT_DIR" rev-parse --show-toplevel)" + +echo "Building $IMAGE_NAME ..." +sudo docker build -t "$IMAGE_NAME" -f "$SCRIPT_DIR/Dockerfile" "$REPO_ROOT" +echo "Done: $IMAGE_NAME" diff --git a/tools/ci/docker/tag_and_push.sh b/tools/ci/docker/tag_and_push.sh new file mode 100755 index 00000000000..b67dc232324 --- /dev/null +++ b/tools/ci/docker/tag_and_push.sh @@ -0,0 +1,9 @@ +#!/bin/bash +set -e + +DATE=$(date +%y.%m.%d) +LOCAL_TAG="flash-attn-4:flash-attn-cu13.0-${DATE}" +REMOTE_TAG="togethercomputer/training-performance:flash-attn-cu13.0-${DATE}" + +docker tag "$LOCAL_TAG" "$REMOTE_TAG" +docker push "$REMOTE_TAG" diff --git a/tools/ci/fa4.def b/tools/ci/fa4.def new file mode 100644 index 00000000000..4ffda0a8445 --- /dev/null +++ b/tools/ci/fa4.def @@ -0,0 +1,7 @@ +Bootstrap: docker +From: togethercomputer/training-performance:flash-attn-cu13.0-26.04.01 + +%labels + maintainer tridao + cuda cu130 + built_for B200/SM100 diff --git a/tools/ci/run_fa4_ci.py b/tools/ci/run_fa4_ci.py new file mode 100644 index 00000000000..e539df7f056 --- /dev/null +++ b/tools/ci/run_fa4_ci.py @@ -0,0 +1,177 @@ +#!/usr/bin/env python3 +"""FA4 CI driver — runs inside an Apptainer SIF on a self-hosted GPU runner. + +Requires FA4_SIF (path to the .sif image) to be set, either via env var or --sif. +""" + +from __future__ import annotations + +import argparse +import os +import shlex +import subprocess +from dataclasses import dataclass +from pathlib import Path + +DEFAULT_TEST_FILTER = "" # empty = run all; CI overrides via --test-filter +DEFAULT_TEST_TARGET = "tests/cute/test_flash_attn.py" + + +@dataclass(frozen=True) +class Step: + name: str + command: list[str] + extra_env: dict[str, str] + + +# ── GPU helpers ─────────────────────────────────────────────────────────────── + +def parse_free_gpu_indices(nvidia_smi_output: str, min_free_memory_mb: int) -> list[str]: + indices: list[str] = [] + for raw_line in nvidia_smi_output.splitlines(): + line = raw_line.strip() + if not line: + continue + try: + index, free_memory = [part.strip() for part in line.split(",", maxsplit=1)] + if int(free_memory) >= min_free_memory_mb: + indices.append(index) + except ValueError as exc: + raise ValueError(f"Unexpected nvidia-smi output line: {raw_line!r}") from exc + return indices + + +def select_visible_devices(free_gpu_indices: list[str], use_all_free_gpus: bool) -> str: + if not free_gpu_indices: + raise ValueError("No GPUs satisfy the free-memory threshold") + if use_all_free_gpus: + return ",".join(free_gpu_indices) + return free_gpu_indices[0] + + +def read_free_gpu_indices(min_free_memory_mb: int) -> list[str]: + result = subprocess.run( + ["nvidia-smi", "--query-gpu=index,memory.free", "--format=csv,noheader,nounits"], + check=True, capture_output=True, text=True, + ) + return parse_free_gpu_indices(result.stdout, min_free_memory_mb) + + +# ── Step plan ───────────────────────────────────────────────────────────────── + +def build_step_plan( + test_target: str, + test_filter: str, + compile_workers: int, + run_workers: int, + test_visible_devices: str, + benchmark_visible_devices: str, + skip_benchmark: bool, +) -> list[Step]: + pytest_base = ["python3", "-m", "pytest", test_target, *(["-k", test_filter] if test_filter else [])] + + steps = [ + Step( + name="Pass 1: compile kernels (no GPU memory)", + command=[*pytest_base, "-n", str(compile_workers), "-x"], + extra_env={"FLASH_ATTENTION_FAKE_TENSOR": "1"}, + ), + Step( + name="Pass 2: run tests on GPU", + command=[*pytest_base, "-n", str(run_workers), "-x"], + extra_env={ + "FLASH_ATTENTION_FAKE_TENSOR": "0", + "CUDA_VISIBLE_DEVICES": test_visible_devices, + }, + ), + ] + if not skip_benchmark: + steps.append(Step( + name="Benchmark (FA4 fwd, hdim=128, causal=both, seqlen=1K-32K)", + command=[ + "python3", "benchmarks/benchmark_attn.py", + "--backend", "fa4", "--fwd", "--bwd", + "--headdim", "128", + "--seqlen", "1024,2048,4096,8192,16384,32768", + "--causal", "both", "--warmup", "1", "--rep", "3", + ], + extra_env={"CUDA_VISIBLE_DEVICES": benchmark_visible_devices}, + )) + return steps + + +# ── Step runner ─────────────────────────────────────────────────────────────── + +def run_step(step: Step, repo_root: Path, base_env: dict[str, str], sif: str, work_dir: str) -> None: + print(f"=== {step.name} ===") + + # Install FA4 from the current repo inside this exec invocation. + # Must be done per-step because --writable-tmpfs creates a fresh overlay each time. + install_cmd = f"uv pip install --system --break-system-packages --no-deps -q -e {shlex.quote(str(repo_root / 'flash_attn/cute'))}" + + # Convert relative test/benchmark paths to absolute so we can run from /tmp. + # Running from /tmp ensures Python does not insert repo_root into sys.path[0] + # (which would cause flash_attn/__init__.py to trigger FA2 imports unavailable in the SIF). + command = [ + str(repo_root / arg) if (arg.startswith("tests/") or arg.startswith("benchmarks/")) else arg + for arg in step.command + ] + env_exports = " && ".join(f"export {k}={shlex.quote(v)}" for k, v in step.extra_env.items()) + inner_cmd = shlex.join(command) + shell_parts = [install_cmd] + if env_exports: + shell_parts.append(env_exports) + shell_parts.append(f"cd /tmp && {inner_cmd}") + cmd = ["apptainer", "exec", "--nv", "--writable-tmpfs", "--bind", work_dir, sif, "bash", "-c", " && ".join(shell_parts)] + subprocess.run(cmd, check=True, cwd=repo_root, env=base_env) + + +# ── CLI ─────────────────────────────────────────────────────────────────────── + +def make_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--repo-root", type=Path, default=Path.cwd()) + parser.add_argument("--sif", default=os.environ.get("FA4_SIF", ""), + help="Apptainer .sif image path (or set FA4_SIF env var)") + parser.add_argument("--test-target", default=DEFAULT_TEST_TARGET) + parser.add_argument("--test-filter", default=DEFAULT_TEST_FILTER) + parser.add_argument("--compile-workers", type=int, default=1) + parser.add_argument("--run-workers", type=int, default=1) + parser.add_argument("--min-free-memory-mb", type=int, default=40000) + parser.add_argument("--use-all-free-gpus", action="store_true") + parser.add_argument("--skip-benchmark", action="store_true") + return parser + + +def main() -> None: + args = make_parser().parse_args() + repo_root = args.repo_root.resolve() + + if not args.sif: + raise SystemExit("FA4_SIF is not set — provide --sif or set the FA4_SIF env var.") + print(f"Using SIF: {args.sif}") + + free_gpu_indices = read_free_gpu_indices(args.min_free_memory_mb) + test_visible_devices = select_visible_devices(free_gpu_indices, args.use_all_free_gpus) + benchmark_visible_devices = free_gpu_indices[0] + print(f"Running tests on GPUs: {test_visible_devices}") + + base_env = {**os.environ, "FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED": "1"} + work_dir = os.environ.get("CI_WORK_DIR", f"/scratch/user/{os.environ.get('USER', 'user')}") + + for step in build_step_plan( + test_target=args.test_target, + test_filter=args.test_filter, + compile_workers=args.compile_workers, + run_workers=args.run_workers, + test_visible_devices=test_visible_devices, + benchmark_visible_devices=benchmark_visible_devices, + skip_benchmark=args.skip_benchmark, + ): + run_step(step, repo_root=repo_root, base_env=base_env, sif=args.sif, work_dir=work_dir) + + print("=== All tests passed ===") + + +if __name__ == "__main__": + main()