From fa5b632ddaeaa7daea53405ad1fa8ca042f431b2 Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Wed, 1 Apr 2026 15:14:25 -0700 Subject: [PATCH 01/44] fix noisy logger (#2414) stack-info: PR: https://github.com/Dao-AILab/flash-attention/pull/2414, branch: drisspg/stack/32 --- flash_attn/cute/cache_utils.py | 28 +++++++++------------------- 1 file changed, 9 insertions(+), 19 deletions(-) diff --git a/flash_attn/cute/cache_utils.py b/flash_attn/cute/cache_utils.py index 3fca0579d98..f1b59700448 100644 --- a/flash_attn/cute/cache_utils.py +++ b/flash_attn/cute/cache_utils.py @@ -1,7 +1,6 @@ # Manage Ahead-of-Time (AOT) compiled kernels import fcntl import hashlib -import logging import os import pickle import sys @@ -18,6 +17,7 @@ import cutlass.cute as cute import tvm_ffi from cutlass.cutlass_dsl import JitCompiledFunction +from flash_attn.cute.fa_logging import fa_log # Pre-load cute DSL runtime libraries with RTLD_GLOBAL so that their symbols # (e.g. _cudaLibraryLoadData) are visible to .so modules loaded later via dlopen. @@ -30,16 +30,6 @@ CompileKeyType: TypeAlias = tuple[Hashable, ...] CallableFunction: TypeAlias = JitCompiledFunction | tvm_ffi.Function -logger = logging.getLogger(__name__) -_handler = logging.StreamHandler() -_handler.setFormatter( - logging.Formatter( - "%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S" - ) -) -logger.addHandler(_handler) -logger.setLevel(logging.DEBUG) - # Enable cache via `FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1` CUTE_DSL_CACHE_ENABLED: bool = os.getenv("FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED", "0") == "1" @@ -226,13 +216,13 @@ def _try_load_from_storage(self, key: CompileKeyType) -> bool: label=sha256_hex, ): if obj_path.exists(): - logger.debug("Loading compiled function from disk: %s", obj_path) + fa_log(1, f"Loading compiled function from disk: {obj_path}") m = cute.runtime.load_module(str(obj_path), enable_tvm_ffi=True) fn = getattr(m, self.EXPORT_FUNCTION_PREFIX) JITCache.__setitem__(self, key, fn) return True else: - logger.debug("Cache miss on disk for key hash %s", sha256_hex) + fa_log(1, f"Cache miss on disk for key hash {sha256_hex}") return False def _try_export_to_storage(self, key: CompileKeyType, fn: JitCompiledFunction) -> None: @@ -247,14 +237,14 @@ def _try_export_to_storage(self, key: CompileKeyType, fn: JitCompiledFunction) - obj_path = self.cache_path / f"{sha256_hex}.o" if obj_path.exists(): # Another process already exported. - logger.debug("Skipping export, already on disk: %s", obj_path) + fa_log(1, f"Skipping export, already on disk: {obj_path}") return - logger.debug("Exporting compiled function to disk: %s", obj_path) + fa_log(1, f"Exporting compiled function to disk: {obj_path}") fn.export_to_c( object_file_path=str(obj_path), function_name=self.EXPORT_FUNCTION_PREFIX, ) - logger.debug("Successfully exported compiled function to disk: %s", obj_path) + fa_log(1, f"Successfully exported compiled function to disk: {obj_path}") def _key_to_hash(self, key: CompileKeyType) -> str: return hashlib.sha256(pickle.dumps(key)).hexdigest() @@ -266,7 +256,7 @@ def clear(self) -> None: """ Not only clear the in-memory cache. Also purge persistent compilation cache. """ - logger.debug("Clearing persistent cache at %s", self.cache_path) + fa_log(1, f"Clearing persistent cache at {self.cache_path}") super().clear() for child in self.cache_path.iterdir(): child.unlink() @@ -285,8 +275,8 @@ def get_jit_cache(name: str | None = None) -> JITCache: path = get_cache_path() / _compute_source_fingerprint() if name: path = path / name - logger.debug("Creating persistent JIT cache at %s", path) + fa_log(1, f"Creating persistent JIT cache at {path}") return JITPersistentCache(path) else: - logger.debug("Persistent cache disabled, using in-memory JIT cache") + fa_log(1, "Persistent cache disabled, using in-memory JIT cache") return JITCache() From b46c5875374de6ace968c775c6d319398ac5ab02 Mon Sep 17 00:00:00 2001 From: rocking Date: Thu, 2 Apr 2026 06:16:30 +0800 Subject: [PATCH 02/44] [AMD ROCm] Fix NaN in FMHA BWD when seq_q=0 (#2421) * [CK_TILE] Fix NaN for FMHA BWD When seq_q=0 * Add regression test for NaN in dK/dV with zero-length Q subsequence in flash attention BWD * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Add additional assertions for gradients in test_flash_attn_bwd_varlen_seqq_zero --------- Co-authored-by: Ding, Yi Co-authored-by: Yi DING Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- csrc/composable_kernel | 2 +- tests/test_flash_attn_ck.py | 46 +++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/csrc/composable_kernel b/csrc/composable_kernel index 859acb5ae7f..791afc64655 160000 --- a/csrc/composable_kernel +++ b/csrc/composable_kernel @@ -1 +1 @@ -Subproject commit 859acb5ae7fdd7f1016a7bfbd1a85c26bb403c6b +Subproject commit 791afc64655301487cac6e5361c677a0a4b82059 diff --git a/tests/test_flash_attn_ck.py b/tests/test_flash_attn_ck.py index 24c17b0133c..abd1eb147ad 100644 --- a/tests/test_flash_attn_ck.py +++ b/tests/test_flash_attn_ck.py @@ -1550,6 +1550,52 @@ def test_flash_attn_bwd_varlen_overflow(d, causal, dtype): assert not v.grad.isnan().any() +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("deterministic", [False, True]) +@pytest.mark.parametrize("nheads_kv", [1, 8]) # GQA (1) and MHA (8) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("d", [64, 128]) +def test_flash_attn_bwd_varlen_seqq_zero(d, causal, nheads_kv, deterministic, dtype): + """Regression test: NaN in dK/dV for the zero-length Q subsequence (run for both GQA and MHA). + """ + if not is_bwd_supported(d, deterministic=deterministic): + pytest.skip(get_bwd_unsupported_reason(d, deterministic=deterministic)) + + device = "cuda" + torch.random.manual_seed(0) + nheads = 8 # n_q_heads; GQA when nheads_kv < nheads + q_cuseqlen = torch.tensor([0, 0, 256, 512], device=device, dtype=torch.int32) + k_cuseqlen = torch.tensor([0, 503, 768, 1536], device=device, dtype=torch.int32) + total_q = int(q_cuseqlen[-1]) # 512 + total_k = int(k_cuseqlen[-1]) # 1536 + Mq = 256 + Mk = 768 + + q = torch.randn([total_q, nheads, d], dtype=dtype, device=device) + k = torch.randn([total_k, nheads_kv, d], dtype=dtype, device=device) + v = torch.randn([total_k, nheads_kv, d], dtype=dtype, device=device) + q.requires_grad_(True) + k.requires_grad_(True) + v.requires_grad_(True) + + out = flash_attn_varlen_func( + q, k, v, q_cuseqlen, k_cuseqlen, Mq, Mk, causal=causal, deterministic=deterministic + ) + g = torch.randn_like(out) + out.backward(g) + + # The bug produced NaN specifically in dK/dV for the zero-length Q subsequence + assert not q.grad.isnan().any() + assert not k.grad.isnan().any() + assert not v.grad.isnan().any() + + # Additionally, for the batch with seqlen_q == 0, the corresponding K/V segment + # should contribute nothing to the loss, so dK/dV for that segment must be zero. + end_kv_zero = int(k_cuseqlen[1]) + assert k.grad[:end_kv_zero].abs().max() == 0 + assert v.grad[:end_kv_zero].abs().max() == 0 + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("local", [False, True]) @pytest.mark.parametrize("causal", [False, True]) From 4bc0ab1f0567670610ab884c80abbfba6fa17ab6 Mon Sep 17 00:00:00 2001 From: Johnsonms Date: Wed, 1 Apr 2026 18:41:39 -0700 Subject: [PATCH 03/44] Add FA4 CI: GitHub Actions workflow with Apptainer on B200 runner (#2393) * Add FA4 CI: GitHub Actions workflow with Apptainer on B200 self-hosted runner - GitHub Actions workflow (ci.yml) triggering on push to main - Two-pass test strategy: kernel compilation (FakeTensorMode) + GPU execution - Pulls Docker image from Docker Hub (togethercomputer/training-performance) and converts to Apptainer SIF, cached by image tag on the runner - CI_WORK_DIR repo variable for configurable large-disk path on the runner - Shared Python driver (tools/ci/run_fa4_ci.py) used by both CI and local runs - Docker image build scripts and Apptainer SIF definition (tools/ci/) - CI_SETUP.md setup guide covering runner registration, secrets, and migration * Upgrade CI to CUDA 13.0 (cu130) torch nightly for B200/SM100 support * Address CI review: security, shell quoting, and config cleanup - Add permissions: contents: read to ci.yml to restrict GITHUB_TOKEN scope - Pin FA4_IMAGE by digest for reproducibility; bump to flash-attn-cu13.0-26.04.01 - Move FA4_TEST_FILTER to ci.yml and thread through action inputs - Fix shell quoting in run_fa4_ci.py (shlex.quote/join for all paths and env values) - Drop venv fallback mode; script is now apptainer-only with clear error on missing FA4_SIF - Drive Dockerfile deps from pyproject.toml via full cute/ copy (single source of truth) - Widen Docker build context to repo root so pyproject.toml is accessible - Add pytest-xdist to pyproject.toml dev extra (was missing, needed for -n flag) * Fix stale --venv flag in test_ci_local.sh; add compile-workers input to gpu-test action - Remove --venv arg from test_ci_local.sh (dropped in 360e6d7, caused immediate argparse failure) - Add compile-workers input to action.yml (default 64) so Pass 1 compilation runs in parallel * [CI] Add FA4 CI: SIF cache cleanup and trim setup docs - Delete stale SIF files after pulling a new image to prevent unbounded disk growth on the self-hosted runner - Replace AI/CI_SETUP.md with a lean tools/ci/README.md co-located with the CI scripts; drop one-time runner setup steps, keep maintainer-relevant bits (credentials, image update, test expansion, FA2 isolation) --- .github/actions/gpu-test/action.yml | 43 +++++++ .github/workflows/ci.yml | 44 +++++++ flash_attn/cute/cache_utils.py | 1 + flash_attn/cute/pyproject.toml | 1 + test_ci_local.sh | 8 ++ tools/ci/README.md | 48 ++++++++ tools/ci/build_sif.sh | 45 +++++++ tools/ci/docker/Dockerfile | 30 +++++ tools/ci/docker/build.sh | 12 ++ tools/ci/docker/tag_and_push.sh | 9 ++ tools/ci/fa4.def | 7 ++ tools/ci/run_fa4_ci.py | 176 ++++++++++++++++++++++++++++ 12 files changed, 424 insertions(+) create mode 100644 .github/actions/gpu-test/action.yml create mode 100644 .github/workflows/ci.yml create mode 100755 test_ci_local.sh create mode 100644 tools/ci/README.md create mode 100755 tools/ci/build_sif.sh create mode 100644 tools/ci/docker/Dockerfile create mode 100755 tools/ci/docker/build.sh create mode 100755 tools/ci/docker/tag_and_push.sh create mode 100644 tools/ci/fa4.def create mode 100644 tools/ci/run_fa4_ci.py diff --git a/.github/actions/gpu-test/action.yml b/.github/actions/gpu-test/action.yml new file mode 100644 index 00000000000..4759bf0903b --- /dev/null +++ b/.github/actions/gpu-test/action.yml @@ -0,0 +1,43 @@ +name: GPU Test +description: Compile and run FA4 tests (pull SIF from Docker Hub, cache by tag) + +inputs: + test-filter: + description: pytest -k filter expression + required: false + default: "" + compile-workers: + description: parallel workers for Pass 1 kernel compilation + required: false + default: "64" + +runs: + using: composite + steps: + - name: Pull FA4 SIF + shell: bash + run: | + CI_WORK_DIR="${CI_WORK_DIR:-/scratch/user/$USER}" + TAG=$(echo "$FA4_IMAGE" | tr '/: ' '---') + SIF="$CI_WORK_DIR/${TAG}.sif" + echo "FA4_IMAGE=$FA4_IMAGE" + echo "SIF=$SIF" + if [ ! -f "$SIF" ]; then + echo "Pulling $FA4_IMAGE → $SIF" + APPTAINER_TMPDIR="$CI_WORK_DIR/apptainer_tmp" \ + APPTAINER_CACHEDIR="$CI_WORK_DIR/apptainer_cache" \ + apptainer pull "$SIF" "docker://$FA4_IMAGE" + else + echo "Using cached SIF: $SIF" + fi + # Remove stale SIFs from previous image versions to prevent unbounded disk growth. + find "$CI_WORK_DIR" -maxdepth 1 -name "*.sif" ! -name "$(basename "$SIF")" -delete + echo "FA4_SIF=$SIF" >> "$GITHUB_ENV" + + - name: Compile and run tests + shell: bash + run: | + python3 "$GITHUB_WORKSPACE/tools/ci/run_fa4_ci.py" \ + --repo-root "$GITHUB_WORKSPACE" \ + --test-filter "${{ inputs.test-filter }}" \ + --compile-workers "${{ inputs.compile-workers }}" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000000..8ac2467d3c9 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,44 @@ +name: CI + +on: + push: + branches: [main] + +permissions: + contents: read + +env: + FA4_IMAGE: togethercomputer/training-performance:flash-attn-cu13.0-26.04.01@sha256:56e50b056eb4d671410846c3483e843ee7bd0f5b13cb45b6f0d7eb8bd27694a5 + CI_WORK_DIR: ${{ vars.CI_WORK_DIR || format('/scratch/user/{0}', github.actor) }} + FA4_TEST_FILTER: "1-1-128-True-0-0.0-False-False-False-mha-dtype0" + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Install ruff + run: pip install ruff + - name: Ruff check + run: ruff check flash_attn/cute/ --extend-exclude "flash_attn/cute/flash_bwd.py,flash_attn/cute/flash_fwd.py,flash_attn/cute/flash_fwd_sm100.py,flash_attn/cute/interface.py" + - name: Ruff format + run: ruff format --check flash_attn/cute/ --exclude "flash_attn/cute/flash_bwd.py,flash_attn/cute/flash_fwd.py,flash_attn/cute/flash_fwd_sm100.py,flash_attn/cute/interface.py" + + test: + strategy: + fail-fast: false + matrix: + gpu: [b200] + runs-on: [self-hosted, '${{ matrix.gpu }}'] + name: test (${{ matrix.gpu }}) + timeout-minutes: 60 + steps: + - uses: actions/checkout@v4 + - name: Log in to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + - uses: ./.github/actions/gpu-test + with: + test-filter: ${{ env.FA4_TEST_FILTER }} diff --git a/flash_attn/cute/cache_utils.py b/flash_attn/cute/cache_utils.py index f1b59700448..8982aa6bbcf 100644 --- a/flash_attn/cute/cache_utils.py +++ b/flash_attn/cute/cache_utils.py @@ -31,6 +31,7 @@ CallableFunction: TypeAlias = JitCompiledFunction | tvm_ffi.Function + # Enable cache via `FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1` CUTE_DSL_CACHE_ENABLED: bool = os.getenv("FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED", "0") == "1" diff --git a/flash_attn/cute/pyproject.toml b/flash_attn/cute/pyproject.toml index 2b0b60b42f1..6ecf64d4e1d 100644 --- a/flash_attn/cute/pyproject.toml +++ b/flash_attn/cute/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ cu13 = ["nvidia-cutlass-dsl[cu13]>=4.4.2"] dev = [ "pytest", + "pytest-xdist", "ruff", ] diff --git a/test_ci_local.sh b/test_ci_local.sh new file mode 100755 index 00000000000..24a8404fabe --- /dev/null +++ b/test_ci_local.sh @@ -0,0 +1,8 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd) + +python3 "$SCRIPT_DIR/tools/ci/run_fa4_ci.py" \ + --repo-root "$SCRIPT_DIR" \ + "$@" diff --git a/tools/ci/README.md b/tools/ci/README.md new file mode 100644 index 00000000000..11deaeb10e0 --- /dev/null +++ b/tools/ci/README.md @@ -0,0 +1,48 @@ +# FA4 CI + +CI runs on a self-hosted GPU runner using an Apptainer (SIF) container pulled from Docker Hub. +Triggered on every push to `main`. + +## Two-pass test strategy + +- **Pass 1** — compile kernels in parallel via `FakeTensorMode` (no GPU memory needed) +- **Pass 2** — run tests using cached compiled kernels on real GPU + +See `run_fa4_ci.py` for the shared logic used by both CI and `test_ci_local.sh`. + +## Required GitHub secrets / variables + +| Name | Kind | Value | +|------|------|-------| +| `DOCKERHUB_USERNAME` | Secret | Docker Hub username | +| `DOCKERHUB_TOKEN` | Secret | Docker Hub access token | +| `CI_WORK_DIR` | Variable | Large-disk path on runner, e.g. `/scratch/user/johnson` | + +`CI_WORK_DIR` is used for SIF caching and Apptainer temp files. Falls back to `/scratch/user/` if unset. + +## Updating the container image + +1. Build and push a new image via `tools/ci/docker/build.sh` + `tag_and_push.sh`. +2. Update `FA4_IMAGE` in `.github/workflows/ci.yml` with the new tag and `sha256` digest. +3. The old SIF is automatically deleted from the runner on the next CI run. + +## Expanding test coverage + +Edit `FA4_TEST_FILTER` in `.github/workflows/ci.yml`. To run the full suite, set it to an empty string and increase `compile-workers` in the `gpu-test` action call. + +Alternatively, edit `run_fa4_ci.py` to change `DEFAULT_TEST_TARGET` or worker defaults — changes there apply to both CI and local runs. + +## FA2 import isolation + +Tests run inside the Apptainer container. The repo's `flash_attn/__init__.py` imports the FA2 C extension (`flash_attn_2_cuda`) which is absent in the container. `run_fa4_ci.py` works around this by: + +1. Installing FA4 from the current repo into the container at runtime (`uv pip install -e flash_attn/cute`). +2. Running pytest from `/tmp` with absolute test paths — this keeps the repo root out of `sys.path[0]` so the installed FA4 package is found instead of the FA2 `__init__.py`. + +`flash_attn/__init__.py` is intentionally not modified; isolation is handled entirely in CI. + +## Adding a new runner / GPU type + +1. Register a self-hosted runner on the machine with the desired label (e.g. `h100`). +2. Add the label to the `gpu` matrix in `.github/workflows/ci.yml`. +3. Set `CI_WORK_DIR` for the new machine if its scratch path differs. diff --git a/tools/ci/build_sif.sh b/tools/ci/build_sif.sh new file mode 100755 index 00000000000..429f8720945 --- /dev/null +++ b/tools/ci/build_sif.sh @@ -0,0 +1,45 @@ +#!/usr/bin/env bash +# Build the FA4 Apptainer SIF image. +# +# Usage: +# ./tools/ci/build_sif.sh [OUTPUT_PATH] +# +# Default output: /scratch/user/$USER/attention_fa4_.sif +# All temp/cache dirs are redirected to /scratch to avoid filling the root volume. +# +# Example: +# ./tools/ci/build_sif.sh +# ./tools/ci/build_sif.sh ~/scratch/my_fa4.sif + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +DEF_FILE="$SCRIPT_DIR/fa4.def" + +DATE=$(date +%Y%m%d) +SCRATCH_BASE="${CI_WORK_DIR:-/scratch/user/${USER}}" +DEFAULT_OUT="${SCRATCH_BASE}/attention_fa4_${DATE}.sif" +OUTPUT="${1:-$DEFAULT_OUT}" +TMP_DIR="$SCRATCH_BASE/apptainer_tmp" +CACHE_DIR="$SCRATCH_BASE/apptainer_cache" + +mkdir -p "$TMP_DIR" "$CACHE_DIR" + +echo "=== FA4 SIF Build ===" +echo " def file : $DEF_FILE" +echo " output : $OUTPUT" +echo " tmp dir : $TMP_DIR" +echo " cache dir: $CACHE_DIR" +echo + +sudo \ + APPTAINER_TMPDIR="$TMP_DIR" \ + APPTAINER_CACHEDIR="$CACHE_DIR" \ + apptainer build "$OUTPUT" "$DEF_FILE" + +echo +echo "Build complete: $OUTPUT" +echo "File size: $(du -sh "$OUTPUT" | cut -f1)" +echo +echo "To use in CI, set on the runner:" +echo " export FA4_SIF=$OUTPUT" diff --git a/tools/ci/docker/Dockerfile b/tools/ci/docker/Dockerfile new file mode 100644 index 00000000000..e01b1d5575a --- /dev/null +++ b/tools/ci/docker/Dockerfile @@ -0,0 +1,30 @@ +FROM ubuntu:24.04 + +ENV PATH=/usr/local/cuda/bin:/root/.local/bin:$PATH +ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH +ENV DEBIAN_FRONTEND=noninteractive + +# System dependencies +RUN apt-get update -y && apt-get install -y \ + python3 python3.12 python3.12-dev python3.12-venv \ + curl git ca-certificates gcc && \ + rm -rf /var/lib/apt/lists/* + +# Install uv +RUN curl -LsSf https://astral.sh/uv/install.sh | sh + +# Install torch from the published cu130 nightly package page and let pip resolve +# the matching nightly Triton wheel from PyTorch's nightly triton page. +RUN uv pip install --system --break-system-packages --no-cache --pre \ + "torch>=2.12.0.dev0,<2.13.0" \ + --find-links https://download.pytorch.org/whl/nightly/cu130/torch \ + --find-links https://download.pytorch.org/whl/nightly/triton + +# FA4 dependencies — derived from pyproject.toml to keep a single source of truth. +# torch stays pinned from the direct wheel install above. +# The package itself (version 0.0.0 via setuptools-scm fallback) is overwritten at +# step time by the editable install from the mounted repo. +COPY flash_attn/cute/ /tmp/fa4/ +RUN uv pip install --system --break-system-packages --no-cache "/tmp/fa4[cu13,dev]" + +CMD ["/bin/bash"] diff --git a/tools/ci/docker/build.sh b/tools/ci/docker/build.sh new file mode 100755 index 00000000000..9d82bf146d2 --- /dev/null +++ b/tools/ci/docker/build.sh @@ -0,0 +1,12 @@ +#!/bin/bash +set -e + +DATE=$(date +%y.%m.%d) +IMAGE_NAME="flash-attn-4:flash-attn-cu13.0-${DATE}" + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(git -C "$SCRIPT_DIR" rev-parse --show-toplevel)" + +echo "Building $IMAGE_NAME ..." +sudo docker build -t "$IMAGE_NAME" -f "$SCRIPT_DIR/Dockerfile" "$REPO_ROOT" +echo "Done: $IMAGE_NAME" diff --git a/tools/ci/docker/tag_and_push.sh b/tools/ci/docker/tag_and_push.sh new file mode 100755 index 00000000000..b67dc232324 --- /dev/null +++ b/tools/ci/docker/tag_and_push.sh @@ -0,0 +1,9 @@ +#!/bin/bash +set -e + +DATE=$(date +%y.%m.%d) +LOCAL_TAG="flash-attn-4:flash-attn-cu13.0-${DATE}" +REMOTE_TAG="togethercomputer/training-performance:flash-attn-cu13.0-${DATE}" + +docker tag "$LOCAL_TAG" "$REMOTE_TAG" +docker push "$REMOTE_TAG" diff --git a/tools/ci/fa4.def b/tools/ci/fa4.def new file mode 100644 index 00000000000..4ffda0a8445 --- /dev/null +++ b/tools/ci/fa4.def @@ -0,0 +1,7 @@ +Bootstrap: docker +From: togethercomputer/training-performance:flash-attn-cu13.0-26.04.01 + +%labels + maintainer tridao + cuda cu130 + built_for B200/SM100 diff --git a/tools/ci/run_fa4_ci.py b/tools/ci/run_fa4_ci.py new file mode 100644 index 00000000000..29b2aeaaaa1 --- /dev/null +++ b/tools/ci/run_fa4_ci.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python3 +"""FA4 CI driver — runs inside an Apptainer SIF on a self-hosted GPU runner. + +Requires FA4_SIF (path to the .sif image) to be set, either via env var or --sif. +""" + +from __future__ import annotations + +import argparse +import os +import shlex +import subprocess +from dataclasses import dataclass +from pathlib import Path + +DEFAULT_TEST_FILTER = "" # empty = run all; CI overrides via --test-filter +DEFAULT_TEST_TARGET = "tests/cute/test_flash_attn.py" + + +@dataclass(frozen=True) +class Step: + name: str + command: list[str] + extra_env: dict[str, str] + + +# ── GPU helpers ─────────────────────────────────────────────────────────────── + +def parse_free_gpu_indices(nvidia_smi_output: str, min_free_memory_mb: int) -> list[str]: + indices: list[str] = [] + for raw_line in nvidia_smi_output.splitlines(): + line = raw_line.strip() + if not line: + continue + try: + index, free_memory = [part.strip() for part in line.split(",", maxsplit=1)] + if int(free_memory) >= min_free_memory_mb: + indices.append(index) + except ValueError as exc: + raise ValueError(f"Unexpected nvidia-smi output line: {raw_line!r}") from exc + return indices + + +def select_visible_devices(free_gpu_indices: list[str], use_all_free_gpus: bool) -> str: + if not free_gpu_indices: + raise ValueError("No GPUs satisfy the free-memory threshold") + if use_all_free_gpus: + return ",".join(free_gpu_indices) + return free_gpu_indices[0] + + +def read_free_gpu_indices(min_free_memory_mb: int) -> list[str]: + result = subprocess.run( + ["nvidia-smi", "--query-gpu=index,memory.free", "--format=csv,noheader,nounits"], + check=True, capture_output=True, text=True, + ) + return parse_free_gpu_indices(result.stdout, min_free_memory_mb) + + +# ── Step plan ───────────────────────────────────────────────────────────────── + +def build_step_plan( + test_target: str, + test_filter: str, + compile_workers: int, + run_workers: int, + test_visible_devices: str, + benchmark_visible_devices: str, + skip_benchmark: bool, +) -> list[Step]: + pytest_base = ["python3", "-m", "pytest", test_target, *(["-k", test_filter] if test_filter else [])] + + steps = [ + Step( + name="Pass 1: compile kernels (no GPU memory)", + command=[*pytest_base, "-n", str(compile_workers), "-x"], + extra_env={"FLASH_ATTENTION_FAKE_TENSOR": "1"}, + ), + Step( + name="Pass 2: run tests on GPU", + command=[*pytest_base, "-n", str(run_workers), "-x"], + extra_env={ + "FLASH_ATTENTION_FAKE_TENSOR": "0", + "CUDA_VISIBLE_DEVICES": test_visible_devices, + }, + ), + ] + if not skip_benchmark: + steps.append(Step( + name="Benchmark (FA4 fwd, hdim=128, seqlen=8192)", + command=[ + "python3", "benchmarks/benchmark_attn.py", + "--backend", "fa4", "--fwd", + "--headdim", "128", "--seqlen", "8192", + "--causal", "both", "--warmup", "1", "--rep", "3", + ], + extra_env={"CUDA_VISIBLE_DEVICES": benchmark_visible_devices}, + )) + return steps + + +# ── Step runner ─────────────────────────────────────────────────────────────── + +def run_step(step: Step, repo_root: Path, base_env: dict[str, str], sif: str, work_dir: str) -> None: + print(f"=== {step.name} ===") + + # Install FA4 from the current repo inside this exec invocation. + # Must be done per-step because --writable-tmpfs creates a fresh overlay each time. + install_cmd = f"uv pip install --system --break-system-packages --no-deps -q -e {shlex.quote(str(repo_root / 'flash_attn/cute'))}" + + # Convert relative test/benchmark paths to absolute so we can run from /tmp. + # Running from /tmp ensures Python does not insert repo_root into sys.path[0] + # (which would cause flash_attn/__init__.py to trigger FA2 imports unavailable in the SIF). + command = [ + str(repo_root / arg) if (arg.startswith("tests/") or arg.startswith("benchmarks/")) else arg + for arg in step.command + ] + env_exports = " && ".join(f"export {k}={shlex.quote(v)}" for k, v in step.extra_env.items()) + inner_cmd = shlex.join(command) + shell_parts = [install_cmd] + if env_exports: + shell_parts.append(env_exports) + shell_parts.append(f"cd /tmp && {inner_cmd}") + cmd = ["apptainer", "exec", "--nv", "--writable-tmpfs", "--bind", work_dir, sif, "bash", "-c", " && ".join(shell_parts)] + subprocess.run(cmd, check=True, cwd=repo_root, env=base_env) + + +# ── CLI ─────────────────────────────────────────────────────────────────────── + +def make_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--repo-root", type=Path, default=Path.cwd()) + parser.add_argument("--sif", default=os.environ.get("FA4_SIF", ""), + help="Apptainer .sif image path (or set FA4_SIF env var)") + parser.add_argument("--test-target", default=DEFAULT_TEST_TARGET) + parser.add_argument("--test-filter", default=DEFAULT_TEST_FILTER) + parser.add_argument("--compile-workers", type=int, default=1) + parser.add_argument("--run-workers", type=int, default=1) + parser.add_argument("--min-free-memory-mb", type=int, default=40000) + parser.add_argument("--use-all-free-gpus", action="store_true") + parser.add_argument("--skip-benchmark", action="store_true") + return parser + + +def main() -> None: + args = make_parser().parse_args() + repo_root = args.repo_root.resolve() + + if not args.sif: + raise SystemExit("FA4_SIF is not set — provide --sif or set the FA4_SIF env var.") + print(f"Using SIF: {args.sif}") + + free_gpu_indices = read_free_gpu_indices(args.min_free_memory_mb) + test_visible_devices = select_visible_devices(free_gpu_indices, args.use_all_free_gpus) + benchmark_visible_devices = free_gpu_indices[0] + print(f"Running tests on GPUs: {test_visible_devices}") + + base_env = {**os.environ, "FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED": "1"} + work_dir = os.environ.get("CI_WORK_DIR", f"/scratch/user/{os.environ.get('USER', 'user')}") + + for step in build_step_plan( + test_target=args.test_target, + test_filter=args.test_filter, + compile_workers=args.compile_workers, + run_workers=args.run_workers, + test_visible_devices=test_visible_devices, + benchmark_visible_devices=benchmark_visible_devices, + skip_benchmark=args.skip_benchmark, + ): + run_step(step, repo_root=repo_root, base_env=base_env, sif=args.sif, work_dir=work_dir) + + print("=== All tests passed ===") + + +if __name__ == "__main__": + main() From 83f9e450cd10e20701fb109db9c7703d376f282b Mon Sep 17 00:00:00 2001 From: Johnsonms Date: Wed, 1 Apr 2026 21:38:44 -0700 Subject: [PATCH 04/44] Fix some bugs of CI (#2423) * CI: fix ruff format, Apptainer pull, add cu129/cu130 auto-selection by CUDA version, trigger on main and johnson/ci-fix branches * CI: trigger on main and ci-fix branches --- .github/actions/gpu-test/action.yml | 30 ++++++++++++++++++++++++++--- .github/workflows/ci.yml | 5 +++-- flash_attn/cute/cache_utils.py | 1 - 3 files changed, 30 insertions(+), 6 deletions(-) diff --git a/.github/actions/gpu-test/action.yml b/.github/actions/gpu-test/action.yml index 4759bf0903b..54adc89b3b2 100644 --- a/.github/actions/gpu-test/action.yml +++ b/.github/actions/gpu-test/action.yml @@ -10,23 +10,47 @@ inputs: description: parallel workers for Pass 1 kernel compilation required: false default: "64" + fa4_image_cu129: + description: Docker image for CUDA 12.9 (used when driver does not support CUDA 13.0) + required: true + fa4_image_cu130: + description: Docker image for CUDA 13.0 (used when driver supports CUDA 13.0) + required: true runs: using: composite steps: + - name: Select FA4 image based on CUDA version + shell: bash + run: | + # Read max supported CUDA version from nvidia-smi header, e.g. "CUDA Version: 12.9" + CUDA_VER=$(nvidia-smi | grep -oP "CUDA Version: \K[0-9]+\.[0-9]+") + CUDA_MAJOR=$(echo "$CUDA_VER" | cut -d. -f1) + echo "Detected max CUDA version: $CUDA_VER" + if [ "$CUDA_MAJOR" -ge 13 ]; then + echo "Using cu130 image" + echo "FA4_IMAGE=${{ inputs.fa4_image_cu130 }}" >> "$GITHUB_ENV" + else + echo "Using cu129 image" + echo "FA4_IMAGE=${{ inputs.fa4_image_cu129 }}" >> "$GITHUB_ENV" + fi + - name: Pull FA4 SIF shell: bash run: | CI_WORK_DIR="${CI_WORK_DIR:-/scratch/user/$USER}" TAG=$(echo "$FA4_IMAGE" | tr '/: ' '---') SIF="$CI_WORK_DIR/${TAG}.sif" - echo "FA4_IMAGE=$FA4_IMAGE" + # Apptainer doesn't support tag@digest refs — strip the tag, keep digest only. + PULL_REF=$(echo "$FA4_IMAGE" | sed 's/:[^@]*@/@/') + echo "PULL_REF=$PULL_REF" echo "SIF=$SIF" + mkdir -p "$CI_WORK_DIR/apptainer_tmp" "$CI_WORK_DIR/apptainer_cache" if [ ! -f "$SIF" ]; then - echo "Pulling $FA4_IMAGE → $SIF" + echo "Pulling $PULL_REF → $SIF" APPTAINER_TMPDIR="$CI_WORK_DIR/apptainer_tmp" \ APPTAINER_CACHEDIR="$CI_WORK_DIR/apptainer_cache" \ - apptainer pull "$SIF" "docker://$FA4_IMAGE" + apptainer pull "$SIF" "docker://$PULL_REF" else echo "Using cached SIF: $SIF" fi diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8ac2467d3c9..41cb14e5688 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,13 +2,12 @@ name: CI on: push: - branches: [main] + branches: [main, ci-fix] permissions: contents: read env: - FA4_IMAGE: togethercomputer/training-performance:flash-attn-cu13.0-26.04.01@sha256:56e50b056eb4d671410846c3483e843ee7bd0f5b13cb45b6f0d7eb8bd27694a5 CI_WORK_DIR: ${{ vars.CI_WORK_DIR || format('/scratch/user/{0}', github.actor) }} FA4_TEST_FILTER: "1-1-128-True-0-0.0-False-False-False-mha-dtype0" @@ -42,3 +41,5 @@ jobs: - uses: ./.github/actions/gpu-test with: test-filter: ${{ env.FA4_TEST_FILTER }} + fa4_image_cu129: "togethercomputer/training-performance:flash-attn-cu12.9-26.03.25@sha256:304a5c3d2b3a75b151cd2a964cd26d444e0d8b5686d63943df13378c9705f943" + fa4_image_cu130: "togethercomputer/training-performance:flash-attn-cu13.0-26.04.01@sha256:56e50b056eb4d671410846c3483e843ee7bd0f5b13cb45b6f0d7eb8bd27694a5" diff --git a/flash_attn/cute/cache_utils.py b/flash_attn/cute/cache_utils.py index 8982aa6bbcf..f1b59700448 100644 --- a/flash_attn/cute/cache_utils.py +++ b/flash_attn/cute/cache_utils.py @@ -31,7 +31,6 @@ CallableFunction: TypeAlias = JitCompiledFunction | tvm_ffi.Function - # Enable cache via `FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1` CUTE_DSL_CACHE_ENABLED: bool = os.getenv("FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED", "0") == "1" From ab5cb6e68b9203c4ce1811765a6e53dcb8bae069 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Thu, 2 Apr 2026 17:08:45 -0400 Subject: [PATCH 05/44] [ROCM] Fix windows issues (#2385) * Initial FA-2 aiter Triton Windows build support * minimize diff * bump commit * bump commit * minimize diff * bump commit * bump aiter submodule * bump aiter submodule to merged #2433 * fix: guard distributed.py fallbacks with hasattr for Windows --------- Co-authored-by: 0xDELUXA --- flash_attn/utils/distributed.py | 6 ++++-- third_party/aiter | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/flash_attn/utils/distributed.py b/flash_attn/utils/distributed.py index 74c55279645..6fe93790154 100644 --- a/flash_attn/utils/distributed.py +++ b/flash_attn/utils/distributed.py @@ -9,9 +9,11 @@ # version of PyTorch. The following 4 lines are for backward compatibility with # older PyTorch. if "all_gather_into_tensor" not in dir(torch.distributed): - torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base + if hasattr(torch.distributed, "_all_gather_base"): + torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base if "reduce_scatter_tensor" not in dir(torch.distributed): - torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base + if hasattr(torch.distributed, "_reduce_scatter_base"): + torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base # Raw operation, does not support autograd, but does support async diff --git a/third_party/aiter b/third_party/aiter index 428e8e761c7..b4b75165fbd 160000 --- a/third_party/aiter +++ b/third_party/aiter @@ -1 +1 @@ -Subproject commit 428e8e761c7bc22d03513bcb8507375afef1f916 +Subproject commit b4b75165fbd2456dfd0f074c5b2ef91bc87d97e5 From 1233b73b6c95340c65c9edfe929611838354fc6e Mon Sep 17 00:00:00 2001 From: Johnsonms Date: Thu, 2 Apr 2026 19:06:51 -0700 Subject: [PATCH 06/44] fix: add cu13 extra to dev install instructions for CUDA 13 / B200 systems (#2430) --- flash_attn/cute/README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flash_attn/cute/README.md b/flash_attn/cute/README.md index 653f7b1cee2..c7f1b32ebd0 100644 --- a/flash_attn/cute/README.md +++ b/flash_attn/cute/README.md @@ -27,6 +27,7 @@ out = flash_attn_func(q, k, v, causal=True) ```sh git clone https://github.com/Dao-AILab/flash-attention.git cd flash-attention -pip install -e "flash_attn/cute[dev]" +pip install -e "flash_attn/cute[dev]" # CUDA 12.x +pip install -e "flash_attn/cute[dev,cu13]" # CUDA 13.x (e.g. B200) pytest tests/cute/ ``` From 65bfd9a6f27803636e86d75b048da2c13a06c096 Mon Sep 17 00:00:00 2001 From: Jean-Luc Duprat Date: Fri, 3 Apr 2026 11:39:09 -0700 Subject: [PATCH 07/44] Fix: disable 2-CTA backward mode when block_sparse_tensors is used (#2433) The SM100 2-CTA backward kernel does not properly handle block_sparse_tensors. When block sparsity is combined with 2-CTA mode, the kernel hits an assertion: 'AssertionError: 2-CTA mode does not support block sparsity' This fix adds block_sparse_tensors to the disable_2cta condition in the backward path, forcing the 1-CTA kernel when block sparsity is active. The 1-CTA backward kernel already supports block_sparse_tensors correctly. Without this fix, any backward pass using block_sparse_tensors on SM100 (B200/GB200) with head_dim >= 128 will crash with the above assertion. --- flash_attn/cute/interface.py | 1 + 1 file changed, 1 insertion(+) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index ef624677f01..872960a298b 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -1086,6 +1086,7 @@ def _flash_attn_bwd( or score_mod is not None or score_mod_bwd is not None or mask_mod is not None + or block_sparse_tensors is not None ) cluster_size = 2 if head_dim >= 128 and not disable_2cta else 1 use_2cta_instrs = cluster_size==2 From 15270e66dc88d70d08f6ba6003eb17e7307149b0 Mon Sep 17 00:00:00 2001 From: Johnsonms Date: Fri, 3 Apr 2026 23:15:34 -0700 Subject: [PATCH 08/44] CI: extend FA4 test matrix with causal/non-causal correctness and fwd+bwd benchmark seqlen 1K-32K (#2428) --- test_ci_local.sh => .github/scripts/test_ci_local.sh | 2 +- .github/workflows/ci.yml | 6 +++--- tools/ci/run_fa4_ci.py | 7 ++++--- 3 files changed, 8 insertions(+), 7 deletions(-) rename test_ci_local.sh => .github/scripts/test_ci_local.sh (63%) diff --git a/test_ci_local.sh b/.github/scripts/test_ci_local.sh similarity index 63% rename from test_ci_local.sh rename to .github/scripts/test_ci_local.sh index 24a8404fabe..d767492d4e1 100755 --- a/test_ci_local.sh +++ b/.github/scripts/test_ci_local.sh @@ -1,7 +1,7 @@ #!/usr/bin/env bash set -euo pipefail -SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd) +SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")/../../" && pwd) python3 "$SCRIPT_DIR/tools/ci/run_fa4_ci.py" \ --repo-root "$SCRIPT_DIR" \ diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 41cb14e5688..a992b677a11 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -9,7 +9,7 @@ permissions: env: CI_WORK_DIR: ${{ vars.CI_WORK_DIR || format('/scratch/user/{0}', github.actor) }} - FA4_TEST_FILTER: "1-1-128-True-0-0.0-False-False-False-mha-dtype0" + FA4_TEST_FILTER: "1024-1024-128-True-0-0.0-False-False-False-mha-dtype0 or 1024-1024-128-False-0-0.0-False-False-False-mha-dtype0" jobs: lint: @@ -23,13 +23,13 @@ jobs: - name: Ruff format run: ruff format --check flash_attn/cute/ --exclude "flash_attn/cute/flash_bwd.py,flash_attn/cute/flash_fwd.py,flash_attn/cute/flash_fwd_sm100.py,flash_attn/cute/interface.py" - test: + fa4-correctness-and-benchmark: strategy: fail-fast: false matrix: gpu: [b200] runs-on: [self-hosted, '${{ matrix.gpu }}'] - name: test (${{ matrix.gpu }}) + name: fa4-correctness-and-benchmark (${{ matrix.gpu }}) timeout-minutes: 60 steps: - uses: actions/checkout@v4 diff --git a/tools/ci/run_fa4_ci.py b/tools/ci/run_fa4_ci.py index 29b2aeaaaa1..e539df7f056 100644 --- a/tools/ci/run_fa4_ci.py +++ b/tools/ci/run_fa4_ci.py @@ -87,11 +87,12 @@ def build_step_plan( ] if not skip_benchmark: steps.append(Step( - name="Benchmark (FA4 fwd, hdim=128, seqlen=8192)", + name="Benchmark (FA4 fwd, hdim=128, causal=both, seqlen=1K-32K)", command=[ "python3", "benchmarks/benchmark_attn.py", - "--backend", "fa4", "--fwd", - "--headdim", "128", "--seqlen", "8192", + "--backend", "fa4", "--fwd", "--bwd", + "--headdim", "128", + "--seqlen", "1024,2048,4096,8192,16384,32768", "--causal", "both", "--warmup", "1", "--rep", "3", ], extra_env={"CUDA_VISIBLE_DEVICES": benchmark_visible_devices}, From 14f3627d44687513adff00819ec894e54bf92cd7 Mon Sep 17 00:00:00 2001 From: CaesarG <44970034+CaesarG@users.noreply.github.com> Date: Sat, 11 Apr 2026 21:14:52 +0800 Subject: [PATCH 09/44] feat(cute): implement softcap backward pass, correct math formula, and resolve JIT cache bug (#2402) --- flash_attn/cute/flash_bwd.py | 31 +++++++++++++--- flash_attn/cute/flash_bwd_sm100.py | 1 - flash_attn/cute/flash_bwd_sm90.py | 1 - flash_attn/cute/flash_fwd.py | 38 ++++++++++++++++++-- flash_attn/cute/interface.py | 36 +++++++++++++------ flash_attn/cute/utils.py | 22 +++++++++--- tests/cute/test_flash_attn.py | 33 ++++++++++++++--- tests/cute/test_flash_attn_race_condition.py | 8 ++--- 8 files changed, 138 insertions(+), 32 deletions(-) diff --git a/flash_attn/cute/flash_bwd.py b/flash_attn/cute/flash_bwd.py index 824abdda139..eeb7615b1d3 100644 --- a/flash_attn/cute/flash_bwd.py +++ b/flash_attn/cute/flash_bwd.py @@ -46,6 +46,8 @@ def __init__( AtomLayoutNdKV: int = 8, AtomLayoutMdQ: int = 1, V_in_regs: bool = False, + score_mod: cutlass.Constexpr | None = None, + score_mod_bwd: cutlass.Constexpr | None = None, ): """Initializes the configuration for a flash attention v2 kernel. @@ -90,6 +92,8 @@ def __init__( self.Mma_dKV_is_RS = AtomLayoutMSdP == 1 and AtomLayoutNdKV == num_mma_warps and SdP_swapAB and not dKV_swapAB self.V_in_regs = V_in_regs self.share_QV_smem = V_in_regs + self.score_mod = score_mod + self.score_mod_bwd = score_mod_bwd @staticmethod def can_implement( @@ -377,7 +381,6 @@ def __call__( mCuSeqlensK: Optional[cute.Tensor] = None, mSeqUsedQ: Optional[cute.Tensor] = None, mSeqUsedK: Optional[cute.Tensor] = None, - softcap: Float32 | float | None = None, window_size_left: Int32 | int | None = None, window_size_right: Int32 | int | None = None, mdQ_semaphore: Optional[cute.Tensor] = None, @@ -430,7 +433,7 @@ def __call__( tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) grid_dim = TileScheduler.get_grid_shape(tile_sched_params) - softmax_scale_log2 = softmax_scale * math.log2(math.e) + softmax_scale_log2, softmax_scale = utils.compute_softmax_scale_log2(softmax_scale, self.score_mod) self.kernel( mQ, mK, @@ -773,6 +776,7 @@ def kernel( smem_copy_params=smem_copy_params, gmem_copy_params=gmem_copy_params, load_Q_LSE=load_Q_LSE, load_dO_dPsum=load_dO_dPsum, m_block_max=m_block_max, + softmax_scale=softmax_scale, softmax_scale_log2=softmax_scale_log2, ) @@ -861,6 +865,7 @@ def compute_one_m_block( load_Q_LSE: Callable, load_dO_dPsum: Callable, m_block_max: cutlass.Int32, + softmax_scale: cutlass.Float32, softmax_scale_log2: cutlass.Float32, mask_fn: Optional[Callable] = None, ): @@ -890,13 +895,24 @@ def load_dO_next(): smem_copy_params.smem_thr_copy_QdO, smem_copy_params.smem_thr_copy_KV, swap_AB=self.SdP_swapAB, ) + acc_S_pre = cute.make_fragment_like(acc_S) + acc_S_pre.store(acc_S.load()) tLSErLSE = cute.make_fragment_like(smem_copy_params.tSsLSEMma[None, 0]) cute.autovec_copy( smem_copy_params.tSsLSEMma[None, smem_pipe_read_q if cutlass.const_expr(self.num_stages_Q > 1) else 0], tLSErLSE ) + acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S) + acc_S_pre_mn = layout_utils.reshape_acc_to_mn(acc_S_pre) + if cutlass.const_expr(self.score_mod is not None): + for r in cutlass.range(cute.size(acc_S_mn, mode=[0]), unroll_full=True): + acc_S_mn[r, None].store( + self.score_mod( + acc_S_mn[r, None].load() * softmax_scale, + 0, 0, 0, 0, None, [], + ) + ) if cutlass.const_expr(mask_fn is not None): mask_fn(acc_S, m_block=m_block) - acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S) bidx = 0 # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_S_mn) # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == 1: cute.print_tensor(tLSErLSE) @@ -926,7 +942,14 @@ def load_dO_next(): # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dP_mn) assert cute.size(acc_dP_mn, mode=[0]) == cute.size(tLSErdPsum) for r in cutlass.range(cute.size(acc_dP_mn, mode=[0]), unroll_full=True): - acc_dP_mn[r, None].store(acc_S_mn[r, None].load() * (acc_dP_mn[r, None].load() - tLSErdPsum[r])) + grad_val = acc_S_mn[r, None].load() * (acc_dP_mn[r, None].load() - tLSErdPsum[r]) + if cutlass.const_expr(self.score_mod_bwd is not None): + grad_val = self.score_mod_bwd( + grad_val, + acc_S_pre_mn[r, None].load() * softmax_scale, + 0, 0, 0, 0, None, [], + ) + acc_dP_mn[r, None].store(grad_val) # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dP_mn) rP = cute.make_fragment_like(acc_S, self.dtype) rP.store(acc_S.load().to(self.dtype)) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index e06cd811fc6..4b4083eda9e 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -456,7 +456,6 @@ def __call__( mCuSeqlensK: Optional[cute.Tensor] = None, mSeqUsedQ: Optional[cute.Tensor] = None, mSeqUsedK: Optional[cute.Tensor] = None, - softcap: Float32 | float | None = None, window_size_left: Int32 | int | None = None, window_size_right: Int32 | int | None = None, mdQ_semaphore: Optional[cute.Tensor] = None, diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py index f724b5a11e3..c9a690d1e90 100644 --- a/flash_attn/cute/flash_bwd_sm90.py +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -350,7 +350,6 @@ def __call__( mCuSeqlensK: Optional[cute.Tensor] = None, mSeqUsedQ: Optional[cute.Tensor] = None, mSeqUsedK: Optional[cute.Tensor] = None, - softcap: Float32 | float | None = None, window_size_left: Int32 | int | None = None, window_size_right: Int32 | int | None = None, mdQ_semaphore: Optional[cute.Tensor] = None, diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 4d47fab109f..d1a43cfd247 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -27,7 +27,7 @@ from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned from flash_attn.cute import utils from flash_attn.cute.mask import AttentionMask -from flash_attn.cute.softmax import Softmax +from flash_attn.cute.softmax import Softmax, apply_score_mod_inner from flash_attn.cute.seqlen_info import SeqlenInfoQK from flash_attn.cute.block_info import BlockInfo from flash_attn.cute.pack_gqa import PackGQA @@ -1145,8 +1145,8 @@ def load_V_next(): m_block, acc_S, n_block, - seqlen, softmax_scale=softmax.softmax_scale, + seqlen=seqlen, aux_tensors=aux_tensors, fastdiv_mods=fastdiv_mods, ) @@ -1185,6 +1185,40 @@ def load_K_next(): ) # if const_expr(self.num_stages > 1): # load_K_next() + @cute.jit + def apply_score_mod( + self, + thr_mma_qk, + batch_idx, + head_idx, + m_block, + acc_S, + n_block, + softmax_scale, + seqlen, + aux_tensors: Optional[list] = None, + fastdiv_mods=None, + ): + # Prepare index tensor + cS = cute.make_identity_tensor((self.tile_m, self.tile_n)) + cS = cute.domain_offset((m_block * self.tile_m, n_block * self.tile_n), cS) + tScS = thr_mma_qk.partition_C(cS) + + apply_score_mod_inner( + acc_S, + tScS, + self.score_mod, + batch_idx, + head_idx, + softmax_scale, + self.vec_size, + self.qk_acc_dtype, + aux_tensors, + fastdiv_mods, + seqlen_info=seqlen, + constant_q_idx=None, + qhead_per_kvhead=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) # SM90 forward pass moved to flash_fwd_sm90.py; re-export for backward compatibility diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 872960a298b..d092be1d8f6 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -543,13 +543,16 @@ def _flash_attn_fwd( and (tile_m % qhead_per_kvhead == 0 or not pack_gqa) ) - # hash score and mask mods for compile cache - score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else False - mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else False - if softcap is not None: assert score_mod is None, "softcap and score_mod cannot be used together" score_mod = utils.create_softcap_scoremod(softcap) + elif score_mod is not None: + if arch // 10 == 8: + raise NotImplementedError("Custom user-provided score_mod is not supported on SM8x architectures.") + + # hash score and mask mods for compile cache + score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else False + mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else False is_varlen = ( cu_seqlens_q is not None @@ -1171,12 +1174,20 @@ def _flash_attn_bwd( pack_gqa = qhead_per_kvhead > 1 # pack_gqa backward not yet supported in bwd pack_gqa = False - if score_mod is not None: + + if softcap != 0.0: + assert score_mod is None and score_mod_bwd is None, ( + "softcap and score_mod/score_mod_bwd cannot be used together" + ) + score_mod = utils.create_softcap_scoremod(softcap) + score_mod_bwd = utils.create_softcap_scoremod_bwd(softcap) + elif score_mod is not None: assert score_mod_bwd is not None, "score_mod_bwd is required when score_mod is provided" - assert softcap == 0.0, "softcap and score_mod are mutually exclusive (different log2 scaling)" assert cu_seqlens_q is None and cu_seqlens_k is None, ( "varlen + score_mod not supported in bwd yet" ) + if arch // 10 == 8: + raise NotImplementedError("Custom user-provided score_mod is not supported on SM8x architectures.") device = q.device out_torch_dtype = q.dtype @@ -1322,7 +1333,6 @@ def _flash_attn_bwd( causal, window_size_left is not None, window_size_right is not None, - softcap != 0.0, m_block_size, n_block_size, num_threads, @@ -1352,6 +1362,9 @@ def _flash_attn_bwd( get_broadcast_dims(k), get_broadcast_dims(v), get_broadcast_dims(dout), + # Prevent TVM stride poisoning when only one block is present. + (seqlen_q_rounded // m_block_size == 1), + (seqlen_k_rounded // n_block_size == 1), ) else: compile_key = ( @@ -1363,7 +1376,6 @@ def _flash_attn_bwd( causal, window_size_left is not None, window_size_right is not None, - softcap != 0.0, m_block_size, n_block_size, num_threads, @@ -1385,6 +1397,9 @@ def _flash_attn_bwd( get_broadcast_dims(k), get_broadcast_dims(v), get_broadcast_dims(dout), + # Prevent TVM stride poisoning when only one block is present. + (seqlen_q_rounded // m_block_size == 1), + (seqlen_k_rounded // n_block_size == 1), ) if compile_key not in _flash_attn_bwd.compile_cache: q_tensor, k_tensor, v_tensor, do_tensor, dq_tensor, dk_tensor, dv_tensor = [ @@ -1427,6 +1442,8 @@ def _flash_attn_bwd( AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs=V_in_regs, + score_mod=score_mod, + score_mod_bwd=score_mod_bwd, ) elif arch // 10 == 9: fa_bwd_obj = FlashAttentionBackwardSm90( @@ -1498,7 +1515,6 @@ def _flash_attn_bwd( cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, - None, # softcap - not yet supported in backward window_size_left, window_size_right, dQ_semaphore_tensor, @@ -1525,7 +1541,6 @@ def _flash_attn_bwd( cu_seqlens_k, seqused_q, seqused_k, - None, # softcap - not yet supported in backward window_size_left, window_size_right, dQ_semaphore, @@ -1724,7 +1739,6 @@ def forward( @staticmethod def backward(ctx, dout, dlse): q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors - assert ctx.softcap == 0.0 if not ctx.return_lse: dlse = None if dout is None: diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index 31186618569..76579c81cc7 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -126,16 +126,28 @@ def hash_callable( def create_softcap_scoremod(softcap_val): - inv_softcap = 1.0 / softcap_val - @cute.jit - def scoremod_premask_fn(acc_S_SSA, batch_idx, head_idx, q_idx, kv_idx, aux_tensors): - scores = acc_S_SSA * inv_softcap - return scores * cute.math.tanh(scores, fastmath=True) + def scoremod_premask_fn( + acc_S_SSA, batch_idx, head_idx, q_idx, kv_idx, seqlen_info, aux_tensors + ): + scores = acc_S_SSA / softcap_val + return softcap_val * cute.math.tanh(scores, fastmath=True) return scoremod_premask_fn +def create_softcap_scoremod_bwd(softcap_val): + @cute.jit + def scoremod_bwd_fn( + grad_out_SSA, score_SSA, batch_idx, head_idx, q_idx, kv_idx, seqlen_info, aux_tensors + ): + scores = score_SSA / softcap_val + tanh_scores = cute.math.tanh(scores, fastmath=True) + return grad_out_SSA * (1.0 - tanh_scores * tanh_scores) + + return scoremod_bwd_fn + + LOG2_E = math.log2(math.e) diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 69e6308fb60..5f2c7732956 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -5,6 +5,8 @@ import os import random import re +import gc +from functools import wraps import pytest import torch @@ -28,8 +30,28 @@ from flash_attn.cute.interface import ( flash_attn_func, flash_attn_varlen_func, + _flash_attn_fwd, + _flash_attn_bwd, ) +def retry_on_oom(func): + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except torch.OutOfMemoryError as e: + if "out of memory" in str(e).lower(): + if hasattr(_flash_attn_fwd, "compile_cache"): + _flash_attn_fwd.compile_cache.clear() + if hasattr(_flash_attn_bwd, "compile_cache"): + _flash_attn_bwd.compile_cache.clear() + gc.collect() + torch.cuda.empty_cache() + return func(*args, **kwargs) + else: + raise + return wrapper + # torch FakeTensorMode would enable fast cutedsl kernel compilation without allocating the actual GPU memory or running the kernel # When operating fake tensors, we cannot perform data-dependent operations (e.g., `tensor.max()`). USE_FAKE_TENSOR = int(os.getenv("FLASH_ATTENTION_FAKE_TENSOR", 0)) == 1 @@ -50,8 +72,8 @@ @pytest.mark.parametrize("has_qv", [False]) @pytest.mark.parametrize("deterministic", [False, True]) # @pytest.mark.parametrize("deterministic", [False]) -# @pytest.mark.parametrize("softcap", [0.0, 15.0]) -@pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("softcap", [0.0, 15.0]) +# @pytest.mark.parametrize("softcap", [0.0]) @pytest.mark.parametrize("local_enum", [0, 1, 2, 3]) # @pytest.mark.parametrize("local_enum", [0]) @pytest.mark.parametrize("causal", [False, True]) @@ -96,6 +118,7 @@ ], ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) +@retry_on_oom @maybe_fake_tensor_mode(USE_FAKE_TENSOR) def test_flash_attn_output( seqlen_q, @@ -388,8 +411,8 @@ def test_flash_attn_output( @pytest.mark.parametrize("has_qv", [False]) @pytest.mark.parametrize("deterministic", [False, True]) # @pytest.mark.parametrize("deterministic", [False]) -# @pytest.mark.parametrize("softcap", [0.0, 15.0]) -@pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("softcap", [0.0, 15.0]) +# @pytest.mark.parametrize("softcap", [0.0]) @pytest.mark.parametrize("local_enum", [0, 1, 2, 3]) # @pytest.mark.parametrize("local_enum", [0]) @pytest.mark.parametrize("causal", [False, True]) @@ -449,6 +472,7 @@ def test_flash_attn_output( (False, True), ], ) +@retry_on_oom @maybe_fake_tensor_mode(USE_FAKE_TENSOR) def test_flash_attn_varlen_output( seqlen_q, @@ -927,6 +951,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): ], ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) +@retry_on_oom @maybe_fake_tensor_mode(USE_FAKE_TENSOR) def test_flash_attn_kvcache( seqlen_q, diff --git a/tests/cute/test_flash_attn_race_condition.py b/tests/cute/test_flash_attn_race_condition.py index a9b8799f4c1..18295e01843 100644 --- a/tests/cute/test_flash_attn_race_condition.py +++ b/tests/cute/test_flash_attn_race_condition.py @@ -43,8 +43,8 @@ @pytest.mark.parametrize("has_qv", [False]) # @pytest.mark.parametrize("deterministic", [False, True]) @pytest.mark.parametrize("deterministic", [True]) -# @pytest.mark.parametrize("softcap", [0.0, 15.0]) -@pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("softcap", [0.0, 15.0]) +# @pytest.mark.parametrize("softcap", [0.0]) # @pytest.mark.parametrize("local_enum", [0, 1, 2, 3]) @pytest.mark.parametrize("local_enum", [0, 1]) @pytest.mark.parametrize("causal", [False, True]) @@ -356,8 +356,8 @@ def test_flash_attn_output( @pytest.mark.parametrize("has_qv", [False]) # @pytest.mark.parametrize("deterministic", [False, True]) @pytest.mark.parametrize("deterministic", [True]) -# @pytest.mark.parametrize("softcap", [0.0, 15.0]) -@pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("softcap", [0.0, 15.0]) +# @pytest.mark.parametrize("softcap", [0.0]) # @pytest.mark.parametrize("local_enum", [0, 1, 2, 3]) @pytest.mark.parametrize("local_enum", [0, 1]) @pytest.mark.parametrize("causal", [False, True]) From 79f317c7573dfa136dbe2d994b93ea1bdb30ef49 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 14 Apr 2026 06:52:19 +0700 Subject: [PATCH 10/44] [Doc] Remove old comment about supported features --- flash_attn/cute/interface.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index d092be1d8f6..5a4085ed4e1 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -1,24 +1,6 @@ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. # [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'll need install nvidia-cutlass-dsl==4.2.0. -# Supported features: -# - BF16 & FP16 dtype -# - noncausal & causal attention -# - MHA, GQA, MQA -# - hdim 64, 96, 128. -# - (hdim_qk, hdim_v) = (192, 128) for Blackwell (i.e. DeepSeek shape) -# - varlen -# - sliding window -# - bwd pass for Ampere (will also run on Hopper/Blackwell, but will be slow) - -# Features not supported yet: -# - split (i.e. FlashDecoding) -# - tuned block sizes -# - paged KV -# - append KV to existing KV cache -# - FP8 -# - bwd pass optimized for Hopper/Blackwell - import os import math from dataclasses import dataclass From 09c93eaa710b90cf7ee2b3270fc775b3b742181e Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 14 Apr 2026 06:54:29 +0700 Subject: [PATCH 11/44] [DSL] Remove cute_compile_patched --- flash_attn/cute/__init__.py | 8 -------- flash_attn/cute/cute_dsl_utils.py | 25 +------------------------ 2 files changed, 1 insertion(+), 32 deletions(-) diff --git a/flash_attn/cute/__init__.py b/flash_attn/cute/__init__.py index 1b84363b63d..be32e149b85 100644 --- a/flash_attn/cute/__init__.py +++ b/flash_attn/cute/__init__.py @@ -7,19 +7,11 @@ except PackageNotFoundError: __version__ = "0.0.0" -import cutlass.cute as cute - from .interface import ( flash_attn_func, flash_attn_varlen_func, ) -from flash_attn.cute.cute_dsl_utils import cute_compile_patched - -# Patch cute.compile to optionally dump SASS -cute.compile = cute_compile_patched - - __all__ = [ "flash_attn_func", "flash_attn_varlen_func", diff --git a/flash_attn/cute/cute_dsl_utils.py b/flash_attn/cute/cute_dsl_utils.py index 79ebd9df6cf..a6bae4179bb 100644 --- a/flash_attn/cute/cute_dsl_utils.py +++ b/flash_attn/cute/cute_dsl_utils.py @@ -1,9 +1,7 @@ # Copyright (c) 2025, Tri Dao. -import os -import pathlib from typing import Tuple -from functools import partial, lru_cache +from functools import lru_cache import torch @@ -41,27 +39,6 @@ def get_device_capacity(device: torch.device = None) -> Tuple[int, int]: return torch.cuda.get_device_capability(device) -def load_cubin_module_data_patched(cubin_data, filepath): - pathlib.Path(filepath).write_bytes(cubin_data) - return load_cubin_module_data_og(cubin_data) - - -def cute_compile_patched(*args, **kwargs): - """A patched version of cute.compile that dump the SASS to a file if CUTE_CUBIN_PATH is set.""" - cubin_path = os.getenv("CUTE_CUBIN_PATH", None) - if cubin_path is not None: - cutlass.base_dsl.runtime.cuda.load_cubin_module_data = partial( - load_cubin_module_data_patched, filepath=cubin_path - ) - output = cute_compile_og(*args, **kwargs) - if cubin_path is not None: - cutlass.base_dsl.runtime.cuda.load_cubin_module_data = load_cubin_module_data_og - if extract is not None: - sass = extract(cubin_path, None) - pathlib.Path(cubin_path).with_suffix(".annotated.sass").write_text(sass) - return output - - def assume_strides_aligned(t): """Assume all strides except the last are divisible by 128 bits. From f219c89c886c6ccbf9d3dbd9fe41b11ac64e9df8 Mon Sep 17 00:00:00 2001 From: jayhshah Date: Tue, 14 Apr 2026 20:42:40 -0700 Subject: [PATCH 12/44] [Cute,Sm100,Fwd] add MLA 64/512 with topk sparsity for MQA 128 heads (#2441) * add mla 2cta with topk sparsity support * add tma store O * add clc option; performs worse than single tile * enable clc for topk gather * add producer tails * add mla dsa to interface * ruff format * use tma store for varlen * decouple sm stats from scale for smem * add varlen tests * credit monellz for kernel dump attributes utility * add docstring for optional args, change default value of topk_indices_maybe_oob to None * give default vals for new args in interface * more rigorous tests; fix race condition on smem for rowmax * add bandwidth calc and qv to benchmark script * refactor interface per suggestions * return more Nones for gradients --- .pre-commit-config.yaml | 1 + benchmarks/benchmark_attn.py | 175 +- flash_attn/cute/bench_utils.py | 42 +- flash_attn/cute/cute_dsl_utils.py | 35 + flash_attn/cute/flash_fwd_mla_sm100.py | 3215 ++++++++++++++++++++++++ flash_attn/cute/interface.py | 289 ++- flash_attn/cute/mask.py | 18 +- flash_attn/cute/named_barrier.py | 9 + flash_attn/cute/softmax.py | 33 +- flash_attn/cute/testing.py | 26 +- flash_attn/cute/tile_scheduler.py | 11 +- flash_attn/cute/topk_gather_kv.py | 274 ++ tests/cute/test_flash_attn.py | 693 ++++- 13 files changed, 4702 insertions(+), 119 deletions(-) create mode 100644 flash_attn/cute/flash_fwd_mla_sm100.py create mode 100644 flash_attn/cute/topk_gather_kv.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6118dfa2283..8a4ab4bc7e4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,6 +10,7 @@ repos: flash_bwd| flash_fwd| flash_fwd_sm100| + flash_fwd_mla_sm100| interface| )\.py$ - id: ruff-format diff --git a/benchmarks/benchmark_attn.py b/benchmarks/benchmark_attn.py index 239dff46664..1f727461256 100644 --- a/benchmarks/benchmark_attn.py +++ b/benchmarks/benchmark_attn.py @@ -10,8 +10,12 @@ from einops import rearrange from flash_attn.cute.bench_utils import ( - flops, attention_ref, - cudnn_fwd_setup, cudnn_bwd_setup, + flops, + bandwidth_fwd_bytes, + bandwidth_bwd_bytes, + attention_ref, + cudnn_fwd_setup, + cudnn_bwd_setup, ) try: @@ -54,7 +58,8 @@ def _make_bwd_fn(fwd_fn, g, inputs): g_match = g[:out.shape[0]] if g.shape[0] != out.shape[0] else g # handle varlen def bwd_fn(): for x in inputs: - x.grad = None + if x is not None: + x.grad = None out.backward(g_match, retain_graph=True) return bwd_fn @@ -134,29 +139,34 @@ def setup_fa3(ctx): def setup_fa4(ctx): if flash_attn_func_python is None: return None, None - q, k, v, g, causal = ctx["q"], ctx["k"], ctx["v"], ctx["g"], ctx["causal"] + q, k, v, qv, g, causal = ctx["q"], ctx["k"], ctx["v"], ctx["qv"], ctx["g"], ctx["causal"] window_size, softcap = ctx["window_size"], ctx["softcap"] pack_gqa, deterministic = ctx["pack_gqa"], ctx["deterministic"] sinks = ctx["sinks"] + gather_kv_indices = ctx.get("gather_kv_indices") k_use = ctx.get("k_paged", k) if ctx["page_size"] is not None else k v_use = ctx.get("v_paged", v) if ctx["page_size"] is not None else v if ctx["varlen"]: qu = ctx["q_unpad"] ku = ctx.get("k_paged", ctx["k_unpad"]) if ctx["page_size"] is not None else ctx["k_unpad"] vu = ctx.get("v_paged", ctx["v_unpad"]) if ctx["page_size"] is not None else ctx["v_unpad"] + qvu = ctx["qv_unpad"] csq, csk = ctx["cu_seqlens_q"], ctx["cu_seqlens_k"] pt = ctx["page_table"] - fwd_fn = lambda: flash_attn_varlen_func_python(qu, ku, vu, csq, csk, page_table=pt, causal=causal, window_size=window_size, softcap=softcap, pack_gqa=pack_gqa) + gather_kv_indices_unpad = ctx.get("gather_kv_indices_unpad") + fwd_fn = lambda: flash_attn_varlen_func_python(qu, ku, vu, qvu, csq, csk, page_table=pt, causal=causal, window_size=window_size, softcap=softcap, pack_gqa=pack_gqa, gather_kv_indices=gather_kv_indices_unpad) else: - fwd_fn = lambda: flash_attn_func_python(q, k_use, v_use, causal=causal, window_size=window_size, learnable_sink=sinks, softcap=softcap, pack_gqa=pack_gqa) + fwd_fn = lambda: flash_attn_func_python(q, k_use, v_use, qv=qv, causal=causal, window_size=window_size, learnable_sink=sinks, softcap=softcap, pack_gqa=pack_gqa, gather_kv_indices=gather_kv_indices) bwd_fn = None if ctx["has_backward"] and ctx["dtype"] != torch.float8_e4m3fn: if ctx["varlen"]: qu, ku, vu, gu = ctx["q_unpad"], ctx["k_unpad"], ctx["v_unpad"], ctx["g_unpad"] + qvu = ctx["qv_unpad"] csq, csk = ctx["cu_seqlens_q"], ctx["cu_seqlens_k"] - bwd_fn = _make_bwd_fn(lambda: flash_attn_varlen_func_python(qu, ku, vu, csq, csk, causal=causal, softcap=softcap, deterministic=deterministic), gu, [qu, ku, vu]) + gather_kv_indices_unpad = ctx.get("gather_kv_indices_unpad") + bwd_fn = _make_bwd_fn(lambda: flash_attn_varlen_func_python(qu, ku, vu, qvu, csq, csk, causal=causal, softcap=softcap, deterministic=deterministic, gather_kv_indices=gather_kv_indices_unpad), gu, [qu, ku, vu, qvu]) else: - bwd_fn = _make_bwd_fn(lambda: flash_attn_func_python(q, k, v, causal=causal, softcap=softcap, deterministic=deterministic), g, [q, k, v]) + bwd_fn = _make_bwd_fn(lambda: flash_attn_func_python(q, k, v, qv=qv, causal=causal, softcap=softcap, deterministic=deterministic, gather_kv_indices=gather_kv_indices), g, [q, k, v, qv]) return fwd_fn, bwd_fn @@ -215,6 +225,33 @@ def get_peak_flops(device_index: int = 0, dtype: torch.dtype = torch.bfloat16) - return peak +def get_peak_bandwidth(device_index: int = 0) -> float | None: + """Return peak HBM bandwidth in bytes/sec for the given device. Returns None if unknown.""" + _PEAK_BW = { + # Ampere + "A100": 2.0e12, + "A6000": 0.768e12, + # Ada Lovelace + "L40S": 0.864e12, + # Hopper + "H100 SXM": 3.35e12, + "H100 NVL": 3.35e12, + "H100 PCIe": 2.0e12, + "H200": 4.8e12, + "H20": 4.0e12, + # Blackwell + "GB200": 8.0e12, + "GB300": 8.0e12, + "B300": 8.0e12, + "B200": 8.0e12, + } + device_name = torch.cuda.get_device_name(device_index) + for key in sorted(_PEAK_BW, key=len, reverse=True): + if key.lower() in device_name.lower(): + return _PEAK_BW[key] + return None + + def parse_int_k(s): """Parse an integer with optional k/K suffix, e.g. '8k' -> 8192.""" s = s.strip().lower() @@ -263,11 +300,16 @@ def parse_args(): parser.add_argument('--causal', type=str.lower, choices=['true', 'false', 'both'], default='both', help='Causal mode (default: both)') parser.add_argument('--seqlen', type=csv_ints, default=[8192], - help='Sequence length(s), comma-separated. Supports k suffix, e.g. 1k,2k,8k') + help='KV sequence length(s), comma-separated. Supports k suffix, e.g. 1k,2k,8k') + parser.add_argument('--seqlen-q', type=csv_ints, default=None, + help='Q sequence length(s), comma-separated. Supports k suffix, e.g. 1,128,1k. ' + 'Defaults to matching --seqlen (i.e. seqlen_q == seqlen_kv). ' + 'If a single value is given it is broadcast across all --seqlen values; ' + 'otherwise the list must match the length of --seqlen or be length 1.') parser.add_argument('--total-seqlen', type=parse_int_k, default='32k', help='Total sequence length for batch sizing (default: 32k)') parser.add_argument('--batch-size', type=int, default=None, - help='Batch size (default: total_seqlen // seqlen)') + help='Batch size (default: total_seqlen // seqlen_kv)') parser.add_argument('--deterministic', action='store_true', default=False) parser.add_argument('--nheads', type=int, default=None, help='Number of Q heads (default: 32 for hdim<=64, 16 for hdim<=192, 8 for hdim>192)') @@ -277,6 +319,10 @@ def parse_args(): help='GQA ratio (nheads // nheads_kv). Ignored if --nheads-kv is set.') parser.add_argument('--backend', type=csv_strs, default=['all'], help='Which backends to benchmark, comma-separated (choices: all,standard,fa2,fa3,fa4,cudnn)') + parser.add_argument('--gather-kv', type=int, default=None, + help='kv sparsity length for MLA (hdim=64, hdim_v=512 only). ' + 'When set, passes random kv indices (without repeats) to FA4 and uses gather-kv as ' + 'the effective KV length for flops/bandwidth accounting.') parser.add_argument('--warmup', type=int, default=5, help='Warmup iterations (default: 5)') parser.add_argument('--rep', type=int, default=10, @@ -284,6 +330,27 @@ def parse_args(): return parser.parse_args() +def resolve_seqlen_q_list(seqlen_list, seqlen_q_arg): + """Return a list of seqlen_q values parallel to seqlen_list. + + Rules: + - None → seqlen_q == seqlen_kv for every entry + - single value [x] → broadcast x across all entries + - matching length → use as-is + - mismatch → raise + """ + if seqlen_q_arg is None: + return list(seqlen_list) + if len(seqlen_q_arg) == 1: + return [seqlen_q_arg[0]] * len(seqlen_list) + if len(seqlen_q_arg) != len(seqlen_list): + raise ValueError( + f"--seqlen-q has {len(seqlen_q_arg)} values but --seqlen has " + f"{len(seqlen_list)}; they must match (or pass a single value to broadcast)." + ) + return list(seqlen_q_arg) + + def main(): args = parse_args() @@ -302,7 +369,9 @@ def main(): causal_vals = [False, True] seqlen_list = args.seqlen + seqlen_q_list = resolve_seqlen_q_list(seqlen_list, args.seqlen_q) varlen = args.varlen + gather_kv_length = args.gather_kv # Filter backends to those requested and available enabled = set(args.backend) @@ -317,6 +386,7 @@ def main(): dtype_gen = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype device = 'cuda' peak_flops = get_peak_flops(0, dtype=dtype) + peak_bw = get_peak_bandwidth(0) page_size = None softcap = 0.0 deterministic = args.deterministic @@ -341,20 +411,20 @@ def main(): window_size_fa = (-1, -1) pack_gqa = None - for seqlen in seqlen_list: + for seqlen, seqlen_q in zip(seqlen_list, seqlen_q_list): batch_size = args.batch_size if args.batch_size is not None else max(1, args.total_seqlen // seqlen) - seqlen_q = seqlen q = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=has_backward) k = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype_gen, requires_grad=has_backward) v = torch.randn(batch_size, seqlen, nheads_kv, headdim_v, device=device, dtype=dtype_gen, requires_grad=has_backward) - q, k, v = [x.detach().to(dtype).requires_grad_(has_backward) for x in [q, k, v]] + qv = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen, requires_grad=has_backward) if has_qv else None + q, k, v, qv = [x.detach().to(dtype).requires_grad_(has_backward) if x is not None else None for x in [q, k, v, qv]] g = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen) # Varlen tensors - q_unpad = k_unpad = v_unpad = g_unpad = cu_seqlens_q = cu_seqlens_k = None + q_unpad = k_unpad = v_unpad = qv_unpad = g_unpad = cu_seqlens_q = cu_seqlens_k = None if varlen: - q_unpad, k_unpad, v_unpad = [rearrange(x.detach(), "b s h d -> (b s) h d").requires_grad_(has_backward) for x in [q, k, v]] + q_unpad, k_unpad, v_unpad, qv_unpad = [rearrange(x.detach(), "b s h d -> (b s) h d").requires_grad_(has_backward) if x is not None else None for x in [q, k, v, qv]] g_unpad = rearrange(g.detach(), "b s h d -> (b s) h d") cu_seqlens_q = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen_q cu_seqlens_k = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen if page_size is None else None @@ -367,33 +437,50 @@ def main(): page_table = rearrange(torch.arange(batch_size * seqlen // page_size, device=device, dtype=torch.int32), "(b s) -> b s", s=seqlen // page_size) + # kv sparsity indices — only meaningful for MLA (hdim=64, hdim_v=512) + gather_kv_indices = gather_kv_indices_unpad = None + gather_kv_eff = None # effective KV length for this config + if gather_kv_length is not None and has_qv: + assert gather_kv_length <= seqlen, f"--gather_kv {gather_kv_length} > seqlen_kv {seqlen}" + gather_kv_indices = ( + torch.rand(batch_size, seqlen_q, gather_kv_length, device=device) + .argsort(dim=-1) + .to(torch.int32) + ) + if varlen: + gather_kv_indices_unpad = rearrange(gather_kv_indices, "b s t -> (b s) t") + gather_kv_eff = gather_kv_length + for causal in causal_vals: - cfg = (headdim, headdim_v, causal, seqlen, batch_size, nheads) + cfg = (headdim, headdim_v, causal, seqlen_q, seqlen, batch_size, nheads, nheads_kv, gather_kv_eff) # Build context dict shared by all backends ctx = dict( - q=q, k=k, v=v, g=g, causal=causal, + q=q, k=k, v=v, qv=qv, g=g, causal=causal, headdim=headdim, headdim_v=headdim_v, dtype=dtype, has_backward=has_backward, - varlen=varlen, q_unpad=q_unpad, k_unpad=k_unpad, v_unpad=v_unpad, g_unpad=g_unpad, + varlen=varlen, q_unpad=q_unpad, k_unpad=k_unpad, v_unpad=v_unpad, qv_unpad=qv_unpad, g_unpad=g_unpad, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, seqlen_q=seqlen_q, seqlen=seqlen, page_size=page_size, k_paged=k_paged, v_paged=v_paged, page_table=page_table, dropout_p=dropout_p, window_size=window_size, window_size_fa=window_size_fa, softcap=softcap, deterministic=deterministic, num_splits=num_splits, pack_gqa=pack_gqa, sinks=sinks, + gather_kv_indices=gather_kv_indices, gather_kv_indices_unpad=gather_kv_indices_unpad, ) for display_name, cli_name, setup_fn in active_backends: fwd_fn, bwd_fn = setup_fn(ctx) if fwd_fn is not None and has_forward: time.sleep(1.0) - print(f"Benchmarking {display_name} fwd, hdim={headdim}, seqlen={seqlen}, causal={causal}, {nheads=}, {nheads_kv=}") + gather_kv_str = f", gather_kv={gather_kv_eff}" if gather_kv_eff is not None else "" + print(f"Benchmarking {display_name} fwd, hdim={headdim}, seqlen_q={seqlen_q}, seqlen_kv={seqlen}{gather_kv_str}, causal={causal}, {nheads=}, {nheads_kv=}") ms = do_bench(fwd_fn, warmup=warmup, rep=rep) * 1e-3 time_f[cfg, display_name] = ms if bwd_fn is not None and has_backward: time.sleep(1.0) - print(f"Benchmarking {display_name} bwd, hdim={headdim}, seqlen={seqlen}, causal={causal}, {nheads=}, {nheads_kv=}, {deterministic=}") + gather_kv_str = f", gather_kv={gather_kv_eff}" if gather_kv_eff is not None else "" + print(f"Benchmarking {display_name} bwd, hdim={headdim}, seqlen_q={seqlen_q}, seqlen_kv={seqlen}{gather_kv_str}, causal={causal}, {nheads=}, {nheads_kv=}, {deterministic=}") ms = do_bench(bwd_fn, warmup=warmup, rep=rep) * 1e-3 time_b[cfg, display_name] = ms @@ -404,7 +491,13 @@ def main(): if not shown_backends: return - col_w = 20 if peak_flops is not None else 16 + col_w = 28 if peak_flops is not None else 16 + + # Determine whether any config has seqlen_q != seqlen_kv so we can show + # the extra column only when it's actually useful. + all_cfgs = sorted(set(k[0] for k in list(time_f) + list(time_b))) + show_seqlen_q_col = any(cfg[3] != cfg[4] for cfg in all_cfgs) # seqlen_q vs seqlen_kv + show_gather_kv_col = any(cfg[8] is not None for cfg in all_cfgs) for direction, times, flops_mult in [("FWD", time_f, 1.0), ("BWD", time_b, 2.5)]: if not times: @@ -413,8 +506,14 @@ def main(): if not configs: continue - col_label = "ms / TFLOPS / MFU%" if peak_flops is not None else "ms / TFLOPS" - header = f"{'hdim':>9} {'causal':>6} {'batch':>5} {'seqlen':>6}" + col_label = "ms / TFLOPS / MFU% / TBs / BW%" if peak_flops is not None else "ms / TFLOPS" + header = f"{'hdim':>9} {'causal':>6} {'batch':>5}" + if show_seqlen_q_col: + header += f" {'seqlen_q':>8} {'seqlen_kv':>9}" + else: + header += f" {'seqlen':>6}" + if show_gather_kv_col: + header += f" {'gather_kv':>6}" for b in shown_backends: header += f" {b:>{col_w}}" print(f"\n{'=' * len(header)}") @@ -424,18 +523,36 @@ def main(): print("-" * len(header)) for cfg in configs: - headdim, headdim_v, causal, seqlen, batch_size, nheads = cfg - nFLOPS = flops(batch_size, nheads, seqlen, seqlen, headdim, headdim_v, causal=causal) + headdim, headdim_v, causal, seqlen_q, seqlen, batch_size, nheads, nheads_kv, gather_kv_eff = cfg + has_qv = (headdim == 64 and headdim_v == 512) + seqlen_k_eff = gather_kv_eff if gather_kv_eff is not None else seqlen + nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen_k_eff, headdim, headdim_v, causal=causal, has_qv=has_qv) + dtype_bytes = 1 if dtype == torch.float8_e4m3fn else 2 + if direction == "FWD": + nbytes = bandwidth_fwd_bytes(batch_size, nheads, nheads_kv, seqlen_q, seqlen_k_eff, + headdim, headdim_v, dtype_bytes=dtype_bytes, has_qv=has_qv) + else: + nbytes = bandwidth_bwd_bytes(batch_size, nheads, nheads_kv, seqlen_q, seqlen_k_eff, + headdim, headdim_v, dtype_bytes=dtype_bytes) hdim_str = str(headdim) if headdim == headdim_v else f"{headdim}-{headdim_v}" - row = f"{hdim_str:>9} {str(causal):>6} {batch_size:>5} {seqlen:>6}" + row = f"{hdim_str:>9} {str(causal):>6} {batch_size:>5}" + if show_seqlen_q_col: + row += f" {seqlen_q:>8} {seqlen:>9}" + else: + row += f" {seqlen:>6}" + if show_gather_kv_col: + row += f" {str(gather_kv_eff) if gather_kv_eff is not None else '—':>6}" for b in shown_backends: t = times.get((cfg, b)) if t is not None: tflops = flops_mult * nFLOPS / t * 1e-12 + tbs = nbytes / t * 1e-12 # TB/s ms = t * 1e3 if peak_flops is not None: - mfu = flops_mult * nFLOPS / t / peak_flops * 100 - cell = f"{ms:.2f}/{tflops:.0f}/{mfu:.1f}%" + mfu = flops_mult * nFLOPS / t / peak_flops * 100 + bw_pct = (nbytes / t / peak_bw * 100) if peak_bw is not None else None + bw_str = f"/{bw_pct:.1f}%" if bw_pct is not None else "" + cell = f"{ms:.2f}/{tflops:.0f}/{mfu:.1f}%/{tbs:.2f}{bw_str}" else: cell = f"{ms:.2f}/{tflops:.0f}" row += f" {cell:>{col_w}}" @@ -445,4 +562,4 @@ def main(): if __name__ == '__main__': - main() + main() \ No newline at end of file diff --git a/flash_attn/cute/bench_utils.py b/flash_attn/cute/bench_utils.py index 45cbcf1af36..f6ad96d7c4f 100644 --- a/flash_attn/cute/bench_utils.py +++ b/flash_attn/cute/bench_utils.py @@ -13,7 +13,15 @@ def flops( - batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, window_size=(None, None) + batch, + nheads, + seqlen_q, + seqlen_k, + headdim, + headdim_v, + causal=False, + window_size=(None, None), + has_qv=False, ): if causal: avg_seqlen = (max(0, seqlen_k - seqlen_q) + seqlen_k) / 2 @@ -35,7 +43,37 @@ def flops( else torch.full_like(row_idx, seqlen_k - 1) ) avg_seqlen = (col_right - col_left + 1).float().mean().item() - return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v) + eff_headdim = headdim + headdim_v if has_qv else headdim + return batch * nheads * 2 * seqlen_q * avg_seqlen * (eff_headdim + headdim_v) + + +# ── Bandwidth calculation ──────────────────────────────────────────────────── + + +def bandwidth_fwd_bytes( + batch, nheads, nheads_kv, seqlen_q, seqlen_k, headdim, headdim_v, dtype_bytes=2, has_qv=False +): + """HBM traffic for one attention pass: read Q,K,V + write O.""" + q = batch * nheads * seqlen_q * headdim + qv = batch * nheads * seqlen_q * headdim_v if has_qv else 0 + k = batch * nheads_kv * seqlen_k * headdim + v = batch * nheads_kv * seqlen_k * headdim_v + o = batch * nheads * seqlen_q * headdim_v + return (q + qv + k + v + o) * dtype_bytes + + +def bandwidth_bwd_bytes( + batch, nheads, nheads_kv, seqlen_q, seqlen_k, headdim, headdim_v, dtype_bytes=2 +): + """HBM traffic for one attention pass: read Q,K,V,dO + write dQ,dK,dV.""" + q = batch * nheads * seqlen_q * headdim + k = batch * nheads_kv * seqlen_k * headdim + v = batch * nheads_kv * seqlen_k * headdim_v + do = batch * nheads * seqlen_q * headdim_v + dq = q + dk = k + dv = v + return (q + k + v + do + dq + dk + dv) * dtype_bytes # ── Reference attention ───────────────────────────────────────────────────── diff --git a/flash_attn/cute/cute_dsl_utils.py b/flash_attn/cute/cute_dsl_utils.py index a6bae4179bb..636f7de4de5 100644 --- a/flash_attn/cute/cute_dsl_utils.py +++ b/flash_attn/cute/cute_dsl_utils.py @@ -104,3 +104,38 @@ def get_broadcast_dims(tensor: torch.Tensor) -> Tuple[bool, ...]: patterns are not interchangeable. """ return tuple(s == 0 for s in tensor.stride()) + + +# credit: monellz (https://github.com/NVIDIA/cutlass/issues/2658#issuecomment-3630564264) +def dump_kernel_attributes(compiled_kernel): + from cuda.bindings import driver + from cutlass.utils import HardwareInfo + import torch + + device_id = torch.cuda.current_device() + hardware_info = HardwareInfo(device_id=device_id) + cubin_data = compiled_kernel.artifacts.CUBIN + assert cubin_data is not None, "cubin_data is None, need '--keep-cubin' option when compiling" + cuda_library = hardware_info._checkCudaErrors( + driver.cuLibraryLoadData(cubin_data, None, None, 0, None, None, 0) + ) + kernels = hardware_info._checkCudaErrors(driver.cuLibraryEnumerateKernels(1, cuda_library)) + kernel = hardware_info._checkCudaErrors(driver.cuKernelGetFunction(kernels[0])) + # more metrics: https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1g5e92a1b0d8d1b82cb00dcfb2de15961b + local_size_bytes = hardware_info._checkCudaErrors( + driver.cuFuncGetAttribute( + driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, + kernel, + ) + ) + num_regs = hardware_info._checkCudaErrors( + driver.cuFuncGetAttribute( + driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NUM_REGS, + kernel, + ) + ) + + print("--- Kernel Info ---") + print(f"local_size_bytes: {local_size_bytes}") + print(f"num_regs: {num_regs}") + print("--- End Kernel Info ---") diff --git a/flash_attn/cute/flash_fwd_mla_sm100.py b/flash_attn/cute/flash_fwd_mla_sm100.py new file mode 100644 index 00000000000..8ab546326ec --- /dev/null +++ b/flash_attn/cute/flash_fwd_mla_sm100.py @@ -0,0 +1,3215 @@ +import math +import time +from functools import partial +from typing import Callable, Optional + +import torch +import torch.utils.benchmark as benchmark + +import cuda.bindings.driver as cuda + +import cutlass +from cutlass.base_dsl.arch import Arch +from cutlass.cutlass_dsl import BaseDSL +import cutlass.cute as cute +from cutlass import Float32, BFloat16, Int64, Int32, Uint32, Boolean, const_expr +import cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass.cute.runtime import from_dlpack +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.utils import ClcDynamicPersistentTileScheduler + +from quack import copy_utils, layout_utils + +from flash_attn.cute.pack_gqa import PackGQA, pack_gqa_layout, make_packgqa_tiled_tma_atom +from flash_attn.cute import utils +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.mask import AttentionMask +import flash_attn.cute.blackwell_helpers as fa_sm100_utils +from flash_attn.cute.softmax import SoftmaxSm100 +from flash_attn.cute.tile_scheduler import ( + ClcState, + SchedulingMode, + TileSchedulerArguments, + TileSchedulerProtocol, + SingleTileScheduler, + StaticPersistentTileScheduler, + SingleTileLPTScheduler, + SingleTileVarlenScheduler, + ParamsBase, +) +from flash_attn.cute.fa_logging import fa_log, fa_printf +from flash_attn.cute.utils import smid + +from flash_attn.cute.topk_gather_kv import CpasyncGatherKVManager + +from flash_attn.cute.testing import attention_ref + +from flash_attn.cute.named_barrier import NamedBarrierFwdSm100_MLA2CTA + +from flash_attn.cute.cute_dsl_utils import dump_kernel_attributes + +class FlashAttentionMLAForwardSm100: + def __init__( + self, + is_causal: bool = False, + use_cpasync_load_KV: bool = False, + topk_length: int = 2048, + is_topk_gather: bool = True, + pack_gqa: bool = False, + qhead_per_kvhead: int = 1, + nheads_kv: int = 1, + hdim: int = 64, + hdimv: int = 512, + is_varlen_q: bool = False, + disable_bitmask: bool = False, + use_clc_scheduler: bool = True, + ): + self.is_causal = is_causal + self.is_local = False + self.pack_gqa = pack_gqa + self.qhead_per_kvhead = qhead_per_kvhead + self.nheads_kv = nheads_kv + self.is_varlen_q = is_varlen_q + self.use_tma_O = True + self.use_cpasync_load_KV = use_cpasync_load_KV + self.use_tma_KV = not use_cpasync_load_KV + self.topk_length = topk_length + self.is_topk_gather = is_topk_gather + if is_topk_gather: + assert pack_gqa + assert qhead_per_kvhead == 128, "require MQA 128 for DSA path" + assert use_cpasync_load_KV + # user-provided option if topk indices guaranteed in bounds + self.disable_bitmask = disable_bitmask + + # ==== tile scheduler ==== + self.is_persistent = False + self.use_clc_scheduler = use_clc_scheduler and not is_varlen_q + self.sched_stages = 1 + self.scheduling_mode = SchedulingMode.CLC if self.use_clc_scheduler else SchedulingMode.STATIC + + if const_expr(is_varlen_q): + self.TileScheduler = SingleTileVarlenScheduler + elif self.use_clc_scheduler: + self.TileScheduler = SingleTileLPTScheduler + else: + self.TileScheduler = SingleTileScheduler + + fa_log(1, f"TileScheduler={self.TileScheduler.__name__}, scheduling_mode={self.scheduling_mode.name}") + + # ==== thread info ==== + self.num_softmax_threads = 128 + self.num_epilogue_threads = 128 + self.num_load_threads = 32 + self.num_mma_threads = 32 + self.num_empty_threads = 32 if use_cpasync_load_KV else 64 + self.num_relay_threads = 32 if use_cpasync_load_KV else 0 + self.num_cpasync_load_threads = 128 if use_cpasync_load_KV else 0 + self.num_threads = ( + self.num_softmax_threads + + self.num_epilogue_threads + + self.num_load_threads + + self.num_mma_threads + + self.num_empty_threads + + self.num_relay_threads + + self.num_cpasync_load_threads + ) + self.num_warps = self.num_threads // 32 + assert self.num_warps == 12 or self.num_warps == 16 + self.softmax_warp_indices = (0, 1, 2, 3,) + self.epilogue_warp_indices = (4, 5, 6, 7,) + self.load_warp_id = 8 + self.mma_warp_id = 9 + self.clc_scheduler_warp_id = 10 + self.relay_warp_id = 11 + self.empty_warp_ids = tuple( + w for w, active in [ + (self.relay_warp_id, not use_cpasync_load_KV), + (self.clc_scheduler_warp_id, not self.use_clc_scheduler), + ] if active + ) + self.cpasync_load_warp_indices = (12, 13, 14, 15,) + + # ==== register usage ==== + if self.num_warps == 16: + self.num_regs_load = 80 + self.num_regs_mma = 80 + self.num_regs_softmax = 208 + self.num_regs_epilogue = 128 + self.num_regs_cpasync = 96 if self.use_cpasync_load_KV else 0 + self.num_regs_other = 48 + else: + self.num_regs_load = 168 - 40 + self.num_regs_mma = 168 - 40 + self.num_regs_softmax = 168 + 80 + self.num_regs_epilogue = 168 - 40 + self.num_regs_cpasync = 0 + self.num_regs_other = 48 + + self.num_regs_per_thread = 168 if self.num_warps == 12 else 128 + self.num_regs_total = 504 if self.num_warps == 12 else 512 + + assert self.num_regs_mma + self.num_regs_softmax + self.num_regs_epilogue + self.num_regs_cpasync <= self.num_regs_total + + # ==== 2cta info ==== + self.use_2cta_instrs = True + self.cta_group = tcgen05.CtaGroup.TWO + self.cta_group_size = 2 + self.cluster_shape_mn = (2, 1,) + self.cluster_shape_mnk = (2, 1, 1,) + + # ==== problem shape info ==== + self.hdim = hdim + self.hdimv = hdimv + self.cta_tile_m = 64 + self.cluster_tile_m = self.cta_group_size * self.cta_tile_m + self.tile_n = 128 + assert ( + pack_gqa is False + or self.cluster_tile_m % qhead_per_kvhead == 0 + or qhead_per_kvhead % self.cluster_tile_m == 0 + ) + self.num_hdimv_splits = 2 # split hdimv in half for our Qv @ V^T and P @ V mmas. + assert hdimv % 32 == 0 + assert self.topk_length % (self.tile_n * 2) == 0 or not self.is_topk_gather + self.epi_tile = (self.cta_tile_m, self.hdimv//self.num_hdimv_splits) + + # ==== MMA info ==== + self.mma_tiler_QK = (self.cluster_tile_m, self.tile_n, self.hdim,) + self.mma_tiler_QviVi = (self.cluster_tile_m, self.tile_n, self.hdimv//self.num_hdimv_splits,) + self.mma_tiler_PVti = (self.cluster_tile_m, self.hdimv // self.num_hdimv_splits, self.tile_n,) + self.major_mode_Q = tcgen05.OperandMajorMode.K + self.major_mode_Qvi = tcgen05.OperandMajorMode.K + self.major_mode_K = tcgen05.OperandMajorMode.K + self.major_mode_Vi = tcgen05.OperandMajorMode.K + self.major_mode_Vti = tcgen05.OperandMajorMode.MN + self.major_mode_P = tcgen05.OperandMajorMode.K + self.operand_source_Q = tcgen05.OperandSource.SMEM + self.operand_source_Qvi = tcgen05.OperandSource.SMEM + self.operand_source_P = tcgen05.OperandSource.SMEM + + # ==== pipeline info ==== + self.num_stages_Q = 1 + self.num_stages_K = 1 + self.num_stages_Qvi = 1 + self.num_stages_Vi = 2 + self.num_stages_S = 2 + self.num_stages_P = 1 + self.num_stages_Oi = 1 + self.num_stages_sm_stats = 2 + self.num_stages_bitmask = 4 + assert self.num_stages_S == 2, "mainloops expect 2 stages for S" + + # ==== dtype info ==== + self.dtype_acc = Float32 + + # ==== TMEM info ==== + SM100_TMEM_CAPACITY_COLUMNS = 512 + self.tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS + self.tmem_cols_S = self.tile_n // self.cta_group_size + self.tmem_cols_Oi = (self.hdimv // self.num_hdimv_splits) // self.cta_group_size + self.tmem_offset_S = [self.tmem_cols_S*stage for stage in range(self.num_stages_S)] # allocate 64 TMEM columns for each stage of S + self.tmem_offset_O0 = self.tmem_cols_S * self.num_stages_S + self.tmem_offset_O1 = self.tmem_offset_O0 + self.tmem_cols_Oi + self.tmem_offsets_O = [self.tmem_offset_O0, self.tmem_offset_O1] + self.total_tmem = self.tmem_offset_O1 + self.tmem_cols_Oi + assert self.total_tmem <= self.tmem_alloc_cols, f"Total TMEM columns allocated {self.total_tmem} exceeds capacity {self.tmem_alloc_cols}" + + + def _get_shared_storage_cls(self): + self.buffer_align_bytes = 1024 + + def smem_struct_align(dtype, staged_layout): + return cute.struct.Align[ + cute.struct.MemRange[dtype, cute.cosize(staged_layout)], + self.buffer_align_bytes, + ] + + def mbar_struct(num_stages): + return cute.struct.MemRange[Int64, 2 * num_stages] + + (sQ_struct, sK_struct, sQv0_struct, sQv1_struct, sV0_struct, sV1_struct, sP_struct) = ( + smem_struct_align(dtype, layout) for dtype, layout in [ + (self.dtype_Q, self.sQ_layout_staged), + (self.dtype_K, self.sK_layout_staged), + (self.dtype_Qv, self.sQvi_layout_staged), + (self.dtype_Qv, self.sQvi_layout_staged), + (self.dtype_V, self.sVi_layout_staged), + (self.dtype_V, self.sVi_layout_staged), + (self.dtype_P, self.sP_layout_staged), + ] + ) + sStats_struct = cute.struct.MemRange[Float32, cute.cosize(self.sStats_layout)] + sScale_struct = cute.struct.MemRange[Float32, cute.cosize(self.sScale_layout)] + sBitmask_struct = cute.struct.MemRange[Uint32, cute.cosize(self.sBitmask_layout)] + + (mbar_ptr_Q_struct, mbar_ptr_K_struct, mbar_ptr_Qv0_struct, mbar_ptr_Qv1_struct, + mbar_ptr_V0_struct, mbar_ptr_V1_struct, mbar_ptr_S_struct, mbar_ptr_P_struct, + mbar_ptr_O0_struct, mbar_ptr_O1_struct, mbar_sm_stats_struct, + mbar_bitmask_struct) = ( + mbar_struct(n) for n in [ + self.num_stages_Q, self.num_stages_K, + self.num_stages_Qvi, self.num_stages_Qvi, + self.num_stages_Vi, self.num_stages_Vi, + self.num_stages_S, self.num_stages_P, + self.num_stages_Oi, self.num_stages_Oi, + self.num_stages_sm_stats, + self.num_stages_bitmask, + ] + ) + mbar_ptr_tmem_dealloc_struct = Int64 + tmem_holding_buf_struct = Int32 + + self.sched_stages = 1 + clc_response_size = self.sched_stages * 4 if self.use_clc_scheduler else 0 + clc_mbar_size = self.sched_stages * 2 if self.use_clc_scheduler else 0 + + @cute.struct + class SharedStorage: + mbar_ptr_Q: mbar_ptr_Q_struct + mbar_ptr_K: mbar_ptr_K_struct + mbar_ptr_Qv0: mbar_ptr_Qv0_struct + mbar_ptr_Qv1: mbar_ptr_Qv1_struct + mbar_ptr_V0: mbar_ptr_V0_struct + mbar_ptr_V1: mbar_ptr_V1_struct + mbar_ptr_S: mbar_ptr_S_struct + mbar_ptr_P: mbar_ptr_P_struct + mbar_ptr_O0: mbar_ptr_O0_struct + mbar_ptr_O1: mbar_ptr_O1_struct + mbar_ptr_K_cpasync: mbar_ptr_K_struct + mbar_ptr_V0_cpasync: mbar_ptr_V0_struct + mbar_ptr_V1_cpasync: mbar_ptr_V1_struct + mbar_ptr_sm_stats: mbar_sm_stats_struct + mbar_ptr_bitmask: mbar_bitmask_struct + mbar_ptr_tmem_dealloc: mbar_ptr_tmem_dealloc_struct + tmem_holding_buf: tmem_holding_buf_struct + clc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, clc_mbar_size] + clc_response: cute.struct.MemRange[Int32, clc_response_size] + sO_empty_mbar_ptr: cutlass.Int64 + + sRowMax: sStats_struct + sRowSum: sStats_struct + sScale: sScale_struct + sBitmask: sBitmask_struct + sQv0: sQv0_struct + sQv1: sQv1_struct + sQ: sQ_struct + sK: sK_struct + sV0: sV0_struct + sV1: sV1_struct + sP: sP_struct + + # print("smem bytes = ", SharedStorage.size_in_bytes()) + + return SharedStorage + + + @cute.jit + def __call__( + self, + mQ: cute.Tensor, # (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + mQv: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, d) if there is cu_seqlens_q + mK: cute.Tensor, # (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table + mV: cute.Tensor, # (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table + mO: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + mLSE: Optional[cute.Tensor], # (b, h, s_q) or (h, total_q) if there is cu_seqlens_q + softmax_scale: Float32, + mCuSeqlensQ: Optional[cute.Tensor] = None, # (b + 1) + mCuSeqlensK: Optional[cute.Tensor] = None, # (b + 1) + mSeqUsedQ: Optional[cute.Tensor] = None, # (b) + mSeqUsedK: Optional[cute.Tensor] = None, # (b) + mIndexTopk: Optional[cute.Tensor] = None, # (b, s_q, topk) or (total_q, topk) if there is cu_seqlens_q + mPageTable: Optional[cute.Tensor] = None, + window_size_left: Int32 | int | None = None, + window_size_right: Int32 | int | None = None, + # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). + stream: cuda.CUstream = None, + ): + # ==== asserts for unimplemented features ==== + assert mPageTable is None, "page table tbd for MLA" + + # ==== dtype info ==== + self.dtype_Q = mQ.element_type + self.dtype_K = mK.element_type + self.dtype_Qv = mQv.element_type + self.dtype_V = mV.element_type + self.dtype_P = mV.element_type + self.dtype_O = mO.element_type + + # ==== Prepare Tensors ==== + new_stride = lambda mX: ( + *(cute.assume(s, divby=128 // mX.element_type.width) for s in mX.stride[:-1]), + mX.stride[-1], + ) + mQ, mQv, mK, mV, mO = [ + cute.make_tensor(mX.iterator, cute.make_layout(mX.shape, stride=new_stride(mX))) + for mX in (mQ, mQv, mK, mV, mO) + ] + + # (b, s, h, d) -> (s, d, h, b) or + # (total, h, d) -> (total, d, h) or + # (num_pages, page_size, h_k, d) -> (page_size, d, h_k, num_pages) + QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] + KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1] + mQ, mQv, mO = [ + cute.make_tensor(mX.iterator, cute.select(mX.layout, mode=QO_layout_transpose)) + for mX in (mQ, mQv, mO) + ] + mK, mV = [ + cute.make_tensor(mX.iterator, cute.select(mX.layout, mode=KV_layout_transpose)) + for mX in (mK, mV) + ] + # (s_k, dv, h_k, b) -> (dv, s_k, h_k, b) or + # (total_k, dv, h_k) -> (dv, total_k, h_k) + V_layout_transpose = [1, 0, 2, 3] if const_expr(mCuSeqlensK is None) else [1, 0, 2] + mVt = cute.make_tensor( + mV.iterator, cute.select(mV.layout, mode=V_layout_transpose) + ) + # (b, h, s_q) -> (s_q, h, b) or (h, total_q) -> (total_q, h) + # (b, s_q, topk) -> (topk, s_q, b) or (total_q, topk) -> (topk, total_q) + LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0] + mLSE, mIndexTopk = ( + cute.make_tensor(t.iterator, cute.select(t.layout, mode=LSE_layout_transpose)) + if t is not None else None for t in (mLSE, mIndexTopk) + ) + topk_length_dynamic = mIndexTopk.shape[0] if mIndexTopk is not None else None + + self.o_layout = cutlass.utils.LayoutEnum.from_tensor(mO) + + mO_og = mO + if const_expr(self.pack_gqa): + mQ, mQv, mO = [ + pack_gqa_layout(mX, self.qhead_per_kvhead, self.nheads_kv, head_idx=2) + for mX in (mQ, mQv, mO) + ] + if const_expr(mLSE is not None): + mLSE = pack_gqa_layout(mLSE, self.qhead_per_kvhead, self.nheads_kv, head_idx=1) + + def split_hdimv(m, dim: int): + """Re-tile mode `dim` of tensor `m` from hdimv into (hdimv//S, S), + and return (slice0, slice1) where slice_i selects chunk i.""" + S = self.num_hdimv_splits + chunk = self.hdimv // S + split_shape = (*m.shape[:dim], (chunk, S), *m.shape[dim+1:]) + split_stride = (*m.layout.stride[:dim], (1, chunk), *m.layout.stride[dim+1:]) + split = cute.make_tensor(m.iterator, cute.make_layout(split_shape, stride=split_stride)) + ndim = len(split.shape) + slices = [ + split[(*([None] * dim), (None, i), *([None] * (ndim - dim - 1)))] + for i in range(S) + ] + return slices + + # (seqlen_q, hdimv//2, nheads, batch) or (total_q, hdimv//2, nheads) + mQv0, mQv1 = split_hdimv(mQv, dim=1) + mV0, mV1 = split_hdimv(mV, dim=1) + # (hdimv//2, seqlen_k, nheads_k, batch) or (hdimv//2, total_k, nheads_k) + mVt0, mVt1 = split_hdimv(mVt, dim=0) + + # ==== Prepare MMAs ==== + # (local_var, dtype_a, major_a, major_b, mma_tiler, operand_source_a) + _mma_specs = [ + ("tiled_mma_QK", self.dtype_Q, self.major_mode_Q, self.major_mode_K, self.mma_tiler_QK, self.operand_source_Q), + ("tiled_mma_QviVi", self.dtype_Qv, self.major_mode_Qvi, self.major_mode_Vi, self.mma_tiler_QviVi, self.operand_source_Qvi), + ("tiled_mma_PVti", self.dtype_P, self.major_mode_P, self.major_mode_Vti, self.mma_tiler_PVti, self.operand_source_P), + ] + tiled_mma_QK, tiled_mma_QviVi, tiled_mma_PVti = ( + sm100_utils.make_trivial_tiled_mma( + dtype_a, major_a, major_b, self.dtype_acc, self.cta_group, mma_tiler[:2], operand_source_a, + ) + for _, dtype_a, major_a, major_b, mma_tiler, operand_source_a in _mma_specs + ) + + # ==== Prepare SMEM layouts and TMAs ==== + # (attr, make_fn, tiled_mma, mma_tiler, dtype, num_stages) + _smem_layout_specs = [ + ("sQ_layout", sm100_utils.make_smem_layout_a, tiled_mma_QK, self.mma_tiler_QK, self.dtype_Q, self.num_stages_Q), + ("sK_layout", sm100_utils.make_smem_layout_b, tiled_mma_QK, self.mma_tiler_QK, self.dtype_K, self.num_stages_K), + ("sQvi_layout", sm100_utils.make_smem_layout_a, tiled_mma_QviVi, self.mma_tiler_QviVi, self.dtype_Qv, self.num_stages_Qvi), + ("sVi_layout", sm100_utils.make_smem_layout_b, tiled_mma_QviVi, self.mma_tiler_QviVi, self.dtype_V, self.num_stages_Vi), + ("sVti_layout", sm100_utils.make_smem_layout_b, tiled_mma_PVti, self.mma_tiler_PVti, self.dtype_V, self.num_stages_Vi), + ("sP_layout", sm100_utils.make_smem_layout_a, tiled_mma_PVti, self.mma_tiler_PVti, self.dtype_P, self.num_stages_P), + ] + for attr, make_fn, tiled_mma, mma_tiler, dtype, num_stages in _smem_layout_specs: + ab_kwarg = "a_dtype" if make_fn is sm100_utils.make_smem_layout_a else "b_dtype" + staged = make_fn( + tiled_mma=tiled_mma, + mma_tiler_mnk=mma_tiler, + num_stages=num_stages, + **{ab_kwarg: dtype}, + ) + setattr(self, f"{attr}_staged", staged) + setattr(self, attr, cute.select(staged, mode=[0, 1, 2])) + + self.sStats_layout = cute.make_layout((self.cta_tile_m, self.cta_group_size)) + self.sScale_layout = cute.make_layout((self.cta_tile_m, self.num_stages_sm_stats)) + self.sBitmask_layout = cute.make_layout((self.tile_n//32, self.num_stages_bitmask)) + + for attr, dtype, layout in [ + ("tma_copy_bytes_Q", self.dtype_Q, self.sQ_layout), + ("tma_copy_bytes_K", self.dtype_K, self.sK_layout), + ("tma_copy_bytes_Qvi", self.dtype_Qv, self.sQvi_layout), + ("tma_copy_bytes_Vi", self.dtype_V, self.sVi_layout), + ]: + setattr(self, attr, cute.size_in_bytes(dtype, layout) * self.cta_group_size) + + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(self.cta_group) + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), (tiled_mma_QK.thr_id.shape,) + ) + cta_shape = cta_layout_vmnk.shape + + def make_tma(make_fn, mX, smem_layout, mma_tiler, tiled_mma): + return make_fn(tma_load_op, mX, smem_layout, mma_tiler, tiled_mma, cta_shape) + + A, B = cute.nvgpu.make_tiled_tma_atom_A, cute.nvgpu.make_tiled_tma_atom_B + + # (atom_name, tensor_name, make_fn, m, smem_layout, mma_tiler, tiled_mma, kv_only) + _tma_specs = [ + ("tma_atom_Q", "tma_tensor_Q", A, mQ, self.sQ_layout, self.mma_tiler_QK, tiled_mma_QK, False), + ("tma_atom_Qv0", "tma_tensor_Qv0", A, mQv0, self.sQvi_layout, self.mma_tiler_QviVi, tiled_mma_QviVi, False), + ("tma_atom_Qv1", "tma_tensor_Qv1", A, mQv1, self.sQvi_layout, self.mma_tiler_QviVi, tiled_mma_QviVi, False), + ("tma_atom_K", "tma_tensor_K", B, mK, self.sK_layout, self.mma_tiler_QK, tiled_mma_QK, True), + ("tma_atom_V0", "tma_tensor_V0", B, mV0, self.sVi_layout, self.mma_tiler_QviVi, tiled_mma_QviVi, True), + ("tma_atom_V1", "tma_tensor_V1", B, mV1, self.sVi_layout, self.mma_tiler_QviVi, tiled_mma_QviVi, True), + ("tma_atom_Vt0", "tma_tensor_Vt0", B, mVt0, self.sVti_layout, self.mma_tiler_PVti, tiled_mma_PVti, True), + ("tma_atom_Vt1", "tma_tensor_Vt1", B, mVt1, self.sVti_layout, self.mma_tiler_PVti, tiled_mma_PVti, True), + ] + _tmas = {} + for atom_name, tensor_name, make_fn, m, smem_layout, mma_tiler, tiled_mma, kv_only in _tma_specs: + _tmas[atom_name], _tmas[tensor_name] = ( + make_tma(make_fn, m, smem_layout, mma_tiler, tiled_mma) + if const_expr(not kv_only or self.use_tma_KV) + else (None, None) + ) + + (tma_atom_Q, tma_tensor_Q, + tma_atom_Qv0, tma_tensor_Qv0, + tma_atom_Qv1, tma_tensor_Qv1, + tma_atom_K, tma_tensor_K, + tma_atom_V0, tma_tensor_V0, + tma_atom_V1, tma_tensor_V1, + tma_atom_Vt0, tma_tensor_Vt0, + tma_atom_Vt1, tma_tensor_Vt1) = _tmas.values() + + # ==== Set up Oi smem -> gmem tma store ==== + + self.overlap_sO_sV = True + if const_expr(self.overlap_sO_sV): + num_stages_sO = self.num_hdimv_splits * self.num_stages_Vi + else: + num_stages_sO = self.num_hdimv_splits + sO_layout = sm100_utils.make_smem_layout_epi( + self.dtype_O, self.o_layout, self.epi_tile, num_stages_sO + ) + self.ragged_tma_O = ( + self.use_tma_O + and self.is_varlen_q + and self.pack_gqa + and self.cta_tile_m % self.qhead_per_kvhead == 0 + ) + make_tiled_tma_atom_fn = ( + partial(make_packgqa_tiled_tma_atom, qhead_per_kvhead=self.qhead_per_kvhead, head_idx=2) + if const_expr(self.ragged_tma_O) + else cpasync.make_tiled_tma_atom + ) + if const_expr(self.use_tma_O): + mO_tma = mO_og if const_expr(self.ragged_tma_O) else mO + if const_expr(self.ragged_tma_O): + mO_tma = copy_utils.create_ragged_tensor_for_tma( + mO_tma, ragged_dim=0, ptr_shift=True + ) + tma_store_op = cpasync.CopyBulkTensorTileS2GOp() + tma_atom_O, tma_tensor_O = make_tiled_tma_atom_fn( + tma_store_op, mO_tma, cute.select(sO_layout, mode=[0, 1]), self.epi_tile + ) + else: + tma_atom_O = None + tma_tensor_O = None + + # ==== Set up Oi rmem -> gmem copy ==== + universal_copy_bits = 128 + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dtype_O, + num_bits_per_copy=universal_copy_bits, + ) + thread_layout_O_r2g = cute.make_layout((64, 2), stride=(1, 64)) + value_layout_O_r2g = cute.make_layout( + (1, self.hdimv//self.num_hdimv_splits//self.cta_group_size) + ) + tiled_copy_O_r2g = cute.make_tiled_copy_tv( + atom=atom_universal_copy, + thr_layout=thread_layout_O_r2g, + val_layout=value_layout_O_r2g, + ) + + # ==== Allocate shared memory ==== + SharedStorage = self._get_shared_storage_cls() + + # ==== Tile scheduler ==== + + TileScheduler = self.TileScheduler + + tile_sched_args = TileSchedulerArguments( + num_block=cute.ceil_div(cute.size(mQ.shape[0]), self.cta_tile_m), + num_head=cute.size(mQ.shape[2]), + num_batch=cute.size(mQ.shape[3]) + if const_expr(mCuSeqlensQ is None) + else cute.size(mCuSeqlensQ.shape[0] - 1), + num_splits=1, # todo: split_kv + seqlen_k=cute.size(mK.shape[0]), # todo: page table + headdim=self.hdim, + headdim_v=self.hdimv, + total_q=cute.size(mQ.shape[0]) + if const_expr(mCuSeqlensQ is not None) + else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]), + tile_shape_mn=(self.cta_tile_m, self.tile_n,), + mCuSeqlensQ=mCuSeqlensQ, + mSeqUsedQ=mSeqUsedQ, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + element_size=self.dtype_K.width // 8, + is_persistent=self.is_persistent, + # lpt=self.is_causal or self.is_local, + lpt=False, + is_split_kv=False, + cluster_shape_mn=self.cluster_shape_mn, + use_cluster_idx=False, + ) + tile_sched_params = TileScheduler.to_underlying_arguments( + tile_sched_args, scheduling_mode=self.scheduling_mode + ) + self.tile_scheduler_cls = TileScheduler + grid_dim = TileScheduler.get_grid_shape(tile_sched_params) + fa_printf(1, "grid = {}", grid_dim) + + # ==== Named Barrier ==== + self.cpasync_barrier = cutlass.pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100_MLA2CTA.Cpasync), + num_threads=self.num_cpasync_load_threads, + ) + self.softmax_barrier = cutlass.pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100_MLA2CTA.Softmax), + num_threads=self.num_softmax_threads, + ) + self.epi_barrier = cutlass.pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100_MLA2CTA.Epilogue), + num_threads=self.num_epilogue_threads, + ) + # softmax -> correction + self.sm_stats_barrier_full = cutlass.pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100_MLA2CTA.SoftmaxStatsFull), + num_threads=self.num_softmax_threads + self.num_epilogue_threads, + ) + self.sm_stats_barrier_empty = cutlass.pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100_MLA2CTA.SoftmaxStatsEmpty), + num_threads=self.num_softmax_threads + self.num_epilogue_threads, + ) + + LOG2_E = math.log2(math.e) + softmax_scale_log2 = softmax_scale * LOG2_E + + # ==== Launch kernel ==== + self.kernel( + tma_tensor_Q, + tma_tensor_Qv0, + tma_tensor_Qv1, + tma_tensor_K if self.use_tma_KV else mK, + tma_tensor_V0 if self.use_tma_KV else mV0, + tma_tensor_V1 if self.use_tma_KV else mV1, + tma_tensor_Vt0 if self.use_tma_KV else mVt0, + tma_tensor_Vt1 if self.use_tma_KV else mVt1, + tma_tensor_O if self.use_tma_O else mO, + mLSE, + mCuSeqlensQ, + mCuSeqlensK, + mSeqUsedQ, + mSeqUsedK, + mIndexTopk, + tma_atom_Q, + tma_atom_Qv0, + tma_atom_Qv1, + tma_atom_K, + tma_atom_V0, + tma_atom_V1, + tma_atom_Vt0, + tma_atom_Vt1, + tma_atom_O, + tiled_copy_O_r2g, + self.sQ_layout_staged, + self.sK_layout_staged, + self.sQvi_layout_staged, + self.sVi_layout_staged, + self.sVti_layout_staged, + self.sP_layout_staged, + self.sStats_layout, + self.sScale_layout, + self.sBitmask_layout, + sO_layout, + tiled_mma_QK, + tiled_mma_QviVi, + tiled_mma_PVti, + softmax_scale, + softmax_scale_log2, + topk_length_dynamic, + tile_sched_params, + SharedStorage, + ).launch( + grid=grid_dim, + block=(self.num_threads, 1, 1,), + cluster=self.cluster_shape_mnk, + smem = SharedStorage.size_in_bytes(), + ) + + @cute.kernel + def kernel( + self, + mQ: cute.Tensor, + mQv0: cute.Tensor, + mQv1: cute.Tensor, + mK: cute.Tensor, + mV0: cute.Tensor, + mV1: cute.Tensor, + mVt0: cute.Tensor, + mVt1: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + mCuSeqlensQ: Optional[cute.Tensor], + mCuSeqlensK: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], + mSeqUsedK: Optional[cute.Tensor], + mIndexTopk: Optional[cute.Tensor], + tma_atom_Q: cute.CopyAtom, + tma_atom_Qv0: cute.CopyAtom, + tma_atom_Qv1: cute.CopyAtom, + tma_atom_K: Optional[cute.CopyAtom], + tma_atom_V0: Optional[cute.CopyAtom], + tma_atom_V1: Optional[cute.CopyAtom], + tma_atom_Vt0: Optional[cute.CopyAtom], + tma_atom_Vt1: Optional[cute.CopyAtom], + tma_atom_O: Optional[cute.CopyAtom], + tiled_copy_O_r2g: cute.TiledCopy, + sQ_layout_staged: cute.ComposedLayout, + sK_layout_staged: cute.ComposedLayout, + sQvi_layout_staged: cute.ComposedLayout, + sVi_layout_staged: cute.ComposedLayout, + sVti_layout_staged: cute.ComposedLayout, + sP_layout_staged: cute.ComposedLayout, + sStats_layout: cute.Layout, + sScale_layout: cute.Layout, + sBitmask_layout: cute.Layout, + sO_layout: cute.ComposedLayout, + tiled_mma_QK: cute.TiledMma, + tiled_mma_QviVi: cute.TiledMma, + tiled_mma_PVti: cute.TiledMma, + softmax_scale: Float32, + softmax_scale_log2: Float32, + topk_length_dynamic: Optional[Int32], + tile_sched_params: ParamsBase, + SharedStorage: cutlass.Constexpr[Callable], + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), (tiled_mma_QK.thr_id.shape,) + ) + + cta_m_block, head_idx, batch_idx = cute.arch.block_idx() + cluster_m_block = cta_m_block // self.cta_group_size + mma_tile_coord_v = cta_m_block % cute.size(tiled_mma_QK.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + + # ==== Allocate SMEM ==== + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + + # ==== TMEM stuff ==== + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100_MLA2CTA.TmemPtr), + num_threads=self.num_mma_threads + self.num_softmax_threads + self.num_epilogue_threads, + ) + tmem = cutlass.utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=tmem_alloc_barrier, + allocator_warp_id=self.mma_warp_id, + is_two_cta=self.use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.mbar_ptr_tmem_dealloc, + ) + + # ==== Prefetch TMA descriptors ==== + if warp_idx == self.load_warp_id: + cpasync.prefetch_descriptor(tma_atom_Q) + cpasync.prefetch_descriptor(tma_atom_Qv0) + cpasync.prefetch_descriptor(tma_atom_Qv1) + if const_expr(self.use_tma_KV): + cpasync.prefetch_descriptor(tma_atom_K) + cpasync.prefetch_descriptor(tma_atom_V0) + cpasync.prefetch_descriptor(tma_atom_V1) + cpasync.prefetch_descriptor(tma_atom_Vt0) + cpasync.prefetch_descriptor(tma_atom_Vt1) + if const_expr(self.use_tma_O): + cpasync.prefetch_descriptor(tma_atom_O) + + # ==== Construct pipelines ==== + tma_warp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 1) + mma_warp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 1) + sm_threads = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.num_softmax_threads) + epi_threads = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.num_epilogue_threads) + sm_threads_cluster = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.num_softmax_threads * self.cta_group_size) + epi_threads_cluster = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.num_epilogue_threads * self.cta_group_size) + + TmaUmma = pipeline.PipelineTmaUmma + AsyncUmma = pipeline.PipelineAsyncUmma + UmmaAsync = pipeline.PipelineUmmaAsync + Async = pipeline.PipelineAsync + + def make_pipeline(cls, mbar_ptr, num_stages, producer, consumer, tx_count=None): + return cls.create( + barrier_storage=mbar_ptr.data_ptr(), + num_stages=num_stages, + producer_group=producer, + consumer_group=consumer, + defer_sync=True, + **({"cta_layout_vmnk": cta_layout_vmnk} if cls is not Async else {}), + **({"tx_count": tx_count} if tx_count is not None else {}), + ) + + # Unconditional pipelines + pipeline_Q = make_pipeline(TmaUmma, storage.mbar_ptr_Q, self.num_stages_Q, tma_warp, mma_warp, self.tma_copy_bytes_Q) + pipeline_Qv0 = make_pipeline(TmaUmma, storage.mbar_ptr_Qv0, self.num_stages_Qvi, tma_warp, mma_warp, self.tma_copy_bytes_Qvi) + pipeline_Qv1 = make_pipeline(TmaUmma, storage.mbar_ptr_Qv1, self.num_stages_Qvi, tma_warp, mma_warp, self.tma_copy_bytes_Qvi) + pipeline_S = make_pipeline(UmmaAsync, storage.mbar_ptr_S, self.num_stages_S, mma_warp, sm_threads_cluster) + pipeline_P = make_pipeline(AsyncUmma, storage.mbar_ptr_P, self.num_stages_P, sm_threads_cluster, mma_warp) + pipeline_O0 = make_pipeline(UmmaAsync, storage.mbar_ptr_O0, self.num_stages_Oi, mma_warp, epi_threads_cluster) + pipeline_O1 = make_pipeline(UmmaAsync, storage.mbar_ptr_O1, self.num_stages_Oi, mma_warp, epi_threads_cluster) + pipeline_sm_stats = make_pipeline(Async, storage.mbar_ptr_sm_stats, self.num_stages_sm_stats, sm_threads, epi_threads) + + # K/V pipelines: type and producer depend on use_tma_KV + if const_expr(self.use_tma_KV): + pipeline_K = make_pipeline(TmaUmma, storage.mbar_ptr_K, self.num_stages_K, tma_warp, mma_warp, self.tma_copy_bytes_K) + pipeline_V0 = make_pipeline(TmaUmma, storage.mbar_ptr_V0, self.num_stages_Vi, tma_warp, mma_warp, self.tma_copy_bytes_Vi) + pipeline_V1 = make_pipeline(TmaUmma, storage.mbar_ptr_V1, self.num_stages_Vi, tma_warp, mma_warp, self.tma_copy_bytes_Vi) + pipeline_K_cpasync = pipeline_V0_cpasync = pipeline_V1_cpasync = pipeline_bitmask = None + else: + cpasync_load_threads = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.num_cpasync_load_threads) + relay_warps_cluster = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.cta_group_size) + relay_threads = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.num_relay_threads) + + pipeline_K = make_pipeline(AsyncUmma, storage.mbar_ptr_K, self.num_stages_K, relay_warps_cluster, mma_warp) + pipeline_V0 = make_pipeline(AsyncUmma, storage.mbar_ptr_V0, self.num_stages_Vi, relay_warps_cluster, mma_warp) + pipeline_V1 = make_pipeline(AsyncUmma, storage.mbar_ptr_V1, self.num_stages_Vi, relay_warps_cluster, mma_warp) + pipeline_K_cpasync = make_pipeline(Async, storage.mbar_ptr_K_cpasync, self.num_stages_K, cpasync_load_threads, relay_threads) + pipeline_V0_cpasync = make_pipeline(Async, storage.mbar_ptr_V0_cpasync, self.num_stages_Vi, cpasync_load_threads, relay_threads) + pipeline_V1_cpasync = make_pipeline(Async, storage.mbar_ptr_V1_cpasync, self.num_stages_Vi, cpasync_load_threads, relay_threads) + pipeline_bitmask = ( + make_pipeline(Async, storage.mbar_ptr_bitmask, self.num_stages_bitmask, cpasync_load_threads, sm_threads) + if const_expr(self.is_topk_gather and not self.disable_bitmask) else None + ) + + sO_empty_mbar_ptr = None + if const_expr(self.use_tma_O and self.overlap_sO_sV): + sO_empty_mbar_ptr = storage.sO_empty_mbar_ptr + if warp_idx == 0: + cute.arch.mbarrier_init(sO_empty_mbar_ptr, 1) + + pipeline.pipeline_init_arrive(cluster_shape_mn=cta_layout_vmnk, is_relaxed=True) + + # ==== Get SMEM tensors ==== + sQ, sK, sQv0, sQv1, sV0, sV1, sVt0, sVt1, sP = ( + store.get_tensor(layout.outer, swizzle=layout.inner) + for store, layout in [ + (storage.sQ, sQ_layout_staged), + (storage.sK, sK_layout_staged), + (storage.sQv0, sQvi_layout_staged), + (storage.sQv1, sQvi_layout_staged), + (storage.sV0, sVi_layout_staged), + (storage.sV1, sVi_layout_staged), + (storage.sV0, sVti_layout_staged), # sVt0 reuses sV0 storage + (storage.sV1, sVti_layout_staged), # sVt1 reuses sV1 storage + (storage.sP, sP_layout_staged), + ] + ) + sRowMax = storage.sRowMax.get_tensor(sStats_layout) + sRowSum = storage.sRowSum.get_tensor(sStats_layout) + sScale = storage.sScale.get_tensor(sScale_layout) + sBitmask = None + if const_expr(self.is_topk_gather): + sBitmask = storage.sBitmask.get_tensor(sBitmask_layout) + + if const_expr(self.overlap_sO_sV): + sO_iterator = sV0.iterator + assert cute.cosize(sO_layout) <= cute.cosize(sVi_layout_staged) * self.num_hdimv_splits + else: + sO_iterator = sQv0.iterator + assert cute.cosize(sO_layout) <= cute.cosize(sQvi_layout_staged) * self.num_hdimv_splits + sO = cute.make_tensor(cute.recast_ptr(sO_iterator, sO_layout.inner, self.dtype_O), sO_layout.outer) + + # ==== Get thread MMAs and accumulator fragments ==== + thr_mma_QK = tiled_mma_QK.get_slice(mma_tile_coord_v) + thr_mma_QviVi = tiled_mma_QviVi.get_slice(mma_tile_coord_v) + thr_mma_PVti = tiled_mma_PVti.get_slice(mma_tile_coord_v) + + acc_shape_QK = thr_mma_QK.partition_shape_C(self.mma_tiler_QK[:2]) + tStS = thr_mma_QK.make_fragment_C(cute.append(acc_shape_QK, self.num_stages_S)) + + acc_shape_PVi = thr_mma_PVti.partition_shape_C(self.mma_tiler_PVti[:2]) + tO0tO0 = thr_mma_PVti.make_fragment_C(acc_shape_PVi) + tO1tO1 = thr_mma_PVti.make_fragment_C(acc_shape_PVi) + tO0tO0 = cute.make_tensor(tO0tO0.iterator + self.tmem_offset_O0, tO0tO0.layout) + tO1tO1 = cute.make_tensor(tO1tO1.iterator + self.tmem_offset_O1, tO1tO1.layout) + + block_info = BlockInfo( + self.cta_tile_m * self.cta_group_size, + self.tile_n, + is_causal=self.is_causal, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + SeqlenInfoCls = partial( + SeqlenInfoQK.create, + seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1], + seqlen_k_static=mK.shape[0], + tile_m=self.cta_tile_m, + tile_n=self.tile_n, + mCuSeqlensQ=mCuSeqlensQ, + mCuSeqlensK=mCuSeqlensK, + mSeqUsedQ=mSeqUsedQ, + mSeqUsedK=mSeqUsedK, + ) + AttentionMaskCls = partial( + AttentionMask, + self.cta_tile_m * self.cta_group_size, + self.tile_n, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + + if const_expr(self.use_clc_scheduler): + clc_response_ptr = storage.clc_response.data_ptr() + clc_mbar_ptr = storage.clc_mbar_ptr.data_ptr() + + clc_pipeline_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread + ) + num_clc_consumer_warps_per_cta = self.num_threads // cute.arch.WARP_SIZE + num_clc_consumer_warps = num_clc_consumer_warps_per_cta * self.cta_group_size + clc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, cute.arch.WARP_SIZE * num_clc_consumer_warps + ) + clc = ClcState.create( + hw_scheduler=ClcDynamicPersistentTileScheduler.create( + self.tile_scheduler_cls.clc_problem_shape(tile_sched_params), + cute.arch.block_idx(), + cute.arch.grid_dim(), + clc_response_ptr, + ), + pipeline=pipeline.PipelineClcFetchAsync.create( + barrier_storage=clc_mbar_ptr, + num_stages=self.sched_stages, + producer_group=clc_pipeline_producer_group, + consumer_group=clc_pipeline_consumer_group, + tx_count=16, + cta_layout_vmnk=cta_layout_vmnk, + ), + consumer_state=pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.sched_stages + ), + producer_state=pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.sched_stages + ), + ) + tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params, clc=clc) + else: + tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params) + assert isinstance(tile_scheduler, TileSchedulerProtocol), f"tile_scheduler is not a TileSchedulerProtocol: {type(tile_scheduler)}" + + pipeline.pipeline_init_wait(cluster_shape_mn=cta_layout_vmnk) + + if const_expr(self.use_clc_scheduler): + if warp_idx == self.clc_scheduler_warp_id: + cute.arch.setmaxregister_decrease(self.num_regs_other) + if is_leader_cta: + self.clc_scheduler_warp(tile_scheduler) + else: + self.empty_warp(tile_scheduler) + for i in cutlass.range_constexpr(len(self.empty_warp_ids)): + if warp_idx == self.empty_warp_ids[i] and warp_idx != self.clc_scheduler_warp_id: + cute.arch.setmaxregister_decrease(self.num_regs_other) + self.empty_warp(tile_scheduler) + else: + for i in cutlass.range_constexpr(len(self.empty_warp_ids)): + if warp_idx == self.empty_warp_ids[i]: + cute.arch.setmaxregister_decrease(self.num_regs_other) + + if const_expr(self.use_cpasync_load_KV): + if warp_idx == self.relay_warp_id: + if const_expr(self.num_regs_load < self.num_regs_per_thread): + cute.arch.setmaxregister_decrease(self.num_regs_load) + self.relay( + pipeline_K, + pipeline_V0, + pipeline_V1, + pipeline_K_cpasync, + pipeline_V0_cpasync, + pipeline_V1_cpasync, + sO_empty_mbar_ptr, + topk_length_dynamic, + block_info, + SeqlenInfoCls, + tile_scheduler=tile_scheduler, + ) + + if warp_idx in self.cpasync_load_warp_indices: + if const_expr(self.num_regs_cpasync < self.num_regs_per_thread): + cute.arch.setmaxregister_decrease(self.num_regs_cpasync) + self.load_cpasync( + mIndexTopk, + mK, + mV0, + mV1, + mVt0, + mVt1, + sK, + sV0, + sV1, + sVt0, + sVt1, + sBitmask, + pipeline_K, + pipeline_V0, + pipeline_V1, + pipeline_K_cpasync, + pipeline_V0_cpasync, + pipeline_V1_cpasync, + pipeline_bitmask, + sO_empty_mbar_ptr, + topk_length_dynamic, + block_info, + SeqlenInfoCls, + tile_scheduler=tile_scheduler, + ) + + if warp_idx == self.load_warp_id: + if const_expr(self.num_regs_load < self.num_regs_per_thread): + cute.arch.setmaxregister_decrease(self.num_regs_load) + self.load( + mQ, + mK, + mQv0, + mQv1, + mV0, + mV1, + mVt0, + mVt1, + sQ, + sK, + sQv0, + sQv1, + sV0, + sV1, + sVt0, + sVt1, + tma_atom_Q, + tma_atom_K, + tma_atom_Qv0, + tma_atom_Qv1, + tma_atom_V0, + tma_atom_V1, + tma_atom_Vt0, + tma_atom_Vt1, + pipeline_Q, + pipeline_K, + pipeline_Qv0, + pipeline_Qv1, + pipeline_V0, + pipeline_V1, + sO_empty_mbar_ptr, + thr_mma_QK, + thr_mma_QviVi, + thr_mma_PVti, + topk_length_dynamic, + block_info, + SeqlenInfoCls, + tile_scheduler=tile_scheduler, + ) + + if warp_idx == self.mma_warp_id: + if const_expr(self.num_regs_mma < self.num_regs_per_thread): + cute.arch.setmaxregister_decrease(self.num_regs_mma) + # ==== Allocate TMEM ==== + tmem.allocate(self.tmem_alloc_cols) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.dtype_acc) + self.mma( + sQ, + sK, + sQv0, + sQv1, + sV0, + sV1, + sVt0, + sVt1, + sP, + tiled_mma_QK, + tiled_mma_QviVi, + tiled_mma_PVti, + pipeline_Q, + pipeline_K, + pipeline_Qv0, + pipeline_Qv1, + pipeline_V0, + pipeline_V1, + pipeline_S, + pipeline_P, + pipeline_O0, + pipeline_O1, + sO_empty_mbar_ptr, + is_leader_cta, + topk_length_dynamic, + block_info, + SeqlenInfoCls, + tile_scheduler=tile_scheduler, + ) + tmem.relinquish_alloc_permit() + tmem_alloc_barrier.arrive_and_wait() + tmem.free(tmem_ptr) + + if warp_idx in self.softmax_warp_indices: + cute.arch.setmaxregister_increase(self.num_regs_softmax) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.dtype_acc) + self.softmax_loop( + softmax_scale, + softmax_scale_log2, + mLSE, + sRowMax, + sRowSum, + sScale, + sBitmask, + sP, + tStS, + thr_mma_QK, + pipeline_S, + pipeline_P, + pipeline_sm_stats, + pipeline_bitmask, + sO_empty_mbar_ptr, + AttentionMaskCls, + topk_length_dynamic, + block_info, + SeqlenInfoCls, + tile_scheduler=tile_scheduler, + ) + tmem_alloc_barrier.arrive() + + if warp_idx in self.epilogue_warp_indices: + if const_expr(self.num_regs_epilogue < self.num_regs_per_thread): + cute.arch.setmaxregister_decrease(self.num_regs_epilogue) + elif const_expr(self.num_regs_epilogue > self.num_regs_per_thread): + cute.arch.setmaxregister_increase(self.num_regs_epilogue) + + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.dtype_acc) + self.correction_loop( + softmax_scale_log2, + mO, + mLSE, + tma_atom_O, + sRowMax, + sRowSum, + sScale, + sO, + tO0tO0, + tO1tO1, + pipeline_O0, + pipeline_O1, + pipeline_sm_stats, + sO_empty_mbar_ptr, + tiled_copy_O_r2g, + topk_length_dynamic, + block_info, + SeqlenInfoCls, + tile_scheduler=tile_scheduler, + ) + tmem_alloc_barrier.arrive() + + @cute.jit + def clc_scheduler_warp( + self, + tile_scheduler: TileSchedulerProtocol, + ): + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + tile_scheduler.prefetch_next_work() + work_tile = tile_scheduler.advance_to_next_work() + cta_m_block, head_idx, batch_idx, _ = work_tile.tile_idx + if cute.arch.thread_idx()[0] == self.clc_scheduler_warp_id * cute.arch.WARP_SIZE and cute.arch.block_idx() == (0, 0, 0): + fa_printf( + 3, + "[CLC] query sm={} cta={} (m_blk={},h={},b={},s={}) valid={}\n", + smid(), + cute.arch.block_idx()[0], + work_tile.tile_idx[0], + work_tile.tile_idx[1], + work_tile.tile_idx[2], + work_tile.tile_idx[3], + work_tile.is_valid_tile, + ) + tile_scheduler.producer_tail() + + @cute.jit + def empty_warp( + self, + tile_scheduler: TileSchedulerProtocol, + ): + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + work_tile = tile_scheduler.advance_to_next_work() + + @cute.jit + def relay( + self, + pipeline_K: pipeline.PipelineAsyncUmma, + pipeline_V0: pipeline.PipelineAsyncUmma, + pipeline_V1: pipeline.PipelineAsyncUmma, + pipeline_K_cpasync: pipeline.PipelineAsync, + pipeline_V0_cpasync: pipeline.PipelineAsync, + pipeline_V1_cpasync: pipeline.PipelineAsync, + sO_empty_mbar_ptr: Optional[cute.Pointer], + topk_length_dynamic: Optional[Int32], + block_info: BlockInfo, + SeqlenInfoCls: Callable, + tile_scheduler: TileSchedulerProtocol, + ): + # ==== Make pipeline states ==== + # pipeline_{K,V0,V1} producer + # pipeline_{K,V0,V1}_cpasync consumer + producer_state_K = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_K + ) + producer_state_V0 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_Vi + ) + producer_state_V1 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_Vi + ) + consumer_state_K = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_K + ) + consumer_state_V0 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_Vi + ) + consumer_state_V1 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_Vi + ) + relay_K_fn = partial(self.relay_inner, pipeline_K_cpasync, pipeline_K) + relay_V0_fn = partial(self.relay_inner, pipeline_V0_cpasync, pipeline_V0) + relay_V1_fn = partial(self.relay_inner, pipeline_V1_cpasync, pipeline_V1) + + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + cta_m_block, head_idx, batch_idx, _ = work_tile.tile_idx + cluster_m_block = cta_m_block // self.cta_group_size + + seqlen = SeqlenInfoCls(batch_idx) + if const_expr(self.is_topk_gather): + n_block_min = 0 + n_block_max = self.topk_length // self.tile_n + # n_block_max = topk_length_dynamic // self.tile_n + else: + n_block_min, n_block_max = block_info.get_n_block_min_max( + seqlen, cluster_m_block, + ) + num_n_blocks = n_block_max - n_block_min + + # ==== Prologue ==== + # relay K, V0, V1 + consumer_state_K, producer_state_K = relay_K_fn(consumer_state_K, producer_state_K) + consumer_state_V0, producer_state_V0 = relay_V0_fn(consumer_state_V0, producer_state_V0) + consumer_state_V1, producer_state_V1 = relay_V1_fn(consumer_state_V1, producer_state_V1) + + # ==== Mainloop ==== + for _ in cutlass.range(num_n_blocks-1, unroll=2): + # relay K, V0, V1, Vt0, Vt1 + consumer_state_K, producer_state_K = relay_K_fn(consumer_state_K, producer_state_K) + for _ in cutlass.range_constexpr(2): + consumer_state_V0, producer_state_V0 = relay_V0_fn(consumer_state_V0, producer_state_V0) + consumer_state_V1, producer_state_V1 = relay_V1_fn(consumer_state_V1, producer_state_V1) + + # ==== Epilogue === + # relay Vt0, Vt1 + consumer_state_V0, producer_state_V0 = relay_V0_fn(consumer_state_V0, producer_state_V0) + consumer_state_V1, producer_state_V1 = relay_V1_fn(consumer_state_V1, producer_state_V1) + + # Advance to next tile + work_tile = tile_scheduler.advance_to_next_work() + + pipeline_K.producer_tail(producer_state_K) + pipeline_V0.producer_tail(producer_state_V0) + pipeline_V1.producer_tail(producer_state_V1) + + + @cute.jit + def relay_inner( + self, + pipeline_cpasync: pipeline.PipelineAsync, + pipeline_mma: pipeline.PipelineAsyncUmma, + consumer_state: pipeline.PipelineState, + producer_state: pipeline.PipelineState, + ): + pipeline_cpasync.consumer_wait(consumer_state) + with cute.arch.elect_one(): + pipeline_mma.producer_commit(producer_state) + consumer_state.advance() + producer_state.advance() + return consumer_state, producer_state + + + @cute.jit + def load_cpasync( + self, + mIndexTopk: cute.Tensor, + mK: cute.Tensor, + mV0: cute.Tensor, + mV1: cute.Tensor, + mVt0: cute.Tensor, + mVt1: cute.Tensor, + sK: cute.Tensor, + sV0: cute.Tensor, + sV1: cute.Tensor, + sVt0: cute.Tensor, + sVt1: cute.Tensor, + sBitmask: Optional[cute.Tensor], + pipeline_K: pipeline.PipelineAsyncUmma, + pipeline_V0: pipeline.PipelineAsyncUmma, + pipeline_V1: pipeline.PipelineAsyncUmma, + pipeline_K_cpasync: pipeline.PipelineAsync, + pipeline_V0_cpasync: pipeline.PipelineAsync, + pipeline_V1_cpasync: pipeline.PipelineAsync, + pipeline_bitmask: pipeline.PipelineAsync, + sO_empty_mbar_ptr: Optional[cute.Pointer], + topk_length_dynamic: Optional[Int32], + block_info: BlockInfo, + SeqlenInfoCls: Callable, + tile_scheduler: TileSchedulerProtocol, + ): + # ==== cpasync load warpgroup ==== + # Description: loads tiles of K, V, V0, V1 from gmem to smem using cpasync + # produces: K, V, V0, V1, bitmask + # consumes: - + + # TODO: use cpasync for non-topk paged attn + assert sBitmask is not None, "cpasync load meant to be used with topk gather" + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + tidx = cute.arch.thread_idx()[0] % self.num_cpasync_load_threads + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % (self.num_cpasync_load_threads//32) + + # ==== Make pipeline states ==== + # producer: acquire PipelineAsyncUmma <- mma + # producer: commit PipelineAsync -> relay + producer_state_K = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_K + ) + producer_state_V0 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_Vi + ) + producer_state_V1 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_Vi + ) + if const_expr(not self.disable_bitmask): + producer_state_bitmask = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_bitmask, + ) + if const_expr(self.use_tma_O): + producer_phase_O = Int32(1) + + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + cta_m_block, head_idx, batch_idx, _ = work_tile.tile_idx + cluster_m_block = cta_m_block // self.cta_group_size + head_idx_kv = ( + head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx + ) + + seqlen = SeqlenInfoCls(batch_idx) + if const_expr(self.is_topk_gather): + n_block_min = 0 + n_block_max = self.topk_length // self.tile_n + # n_block_max = topk_length_dynamic // self.tile_n + else: + n_block_min, n_block_max = block_info.get_n_block_min_max( + seqlen, cluster_m_block, + ) + num_n_blocks = n_block_max - n_block_min + num_n_block_groups = cute.ceil_div(num_n_blocks, self.num_stages_S) + + # cluster_m_block == m_idx under MQA 128 assumption + m_idx = cluster_m_block + if const_expr(not seqlen.has_cu_seqlens_q): + mIndexTopk_cur = mIndexTopk[None, m_idx, batch_idx] + else: + offset_q = seqlen.offset_q + mIndexTopk_cur = mIndexTopk[None, m_idx + offset_q] + + if const_expr(self.is_causal): + seqlen_k_limit = m_idx + 1 + seqlen.seqlen_k - seqlen.seqlen_q + else: + seqlen_k_limit = seqlen.seqlen_k + cpasync_gather_kv_manager = CpasyncGatherKVManager.create( + mIndexTopk_cur, + sBitmask, + cta_rank_in_cluster, + tidx, + warp_idx, + self.topk_length, + seqlen_k_limit, + self.tile_n, + self.hdim, + self.hdimv, + self.num_hdimv_splits, + self.num_cpasync_load_threads, + mK.element_type, + self.cta_group_size, + pipeline_bitmask, + self.num_stages_bitmask, + self.cpasync_barrier, + self.disable_bitmask, + ) + + # (seqlen_k, hdim) or (seqlen_k, hdimv//2) + mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[None, None, head_idx_kv] + mV0_cur = seqlen.offset_batch_K(mV0, batch_idx, dim=3)[None, None, head_idx_kv] + mV1_cur = seqlen.offset_batch_K(mV1, batch_idx, dim=3)[None, None, head_idx_kv] + # (hdimv//2, seqlen_k) + if const_expr(not seqlen.has_cu_seqlens_k): + mVt0_cur = mVt0[None, None, head_idx_kv, batch_idx] + mVt1_cur = mVt1[None, None, head_idx_kv, batch_idx] + else: + mVt0_cur = cute.domain_offset((0, seqlen.offset_k), mVt0[None, None, head_idx_kv]) + mVt1_cur = cute.domain_offset((0, seqlen.offset_k), mVt1[None, None, head_idx_kv]) + # (hdimv//4, seqlen_k) + hdimv_split_per_cta = self.hdimv // self.num_hdimv_splits // self.cta_group_size + mVt0_cur = cute.tiled_divide(mVt0_cur, (hdimv_split_per_cta, ))[None, cta_rank_in_cluster, None] + mVt1_cur = cute.tiled_divide(mVt1_cur, (hdimv_split_per_cta, ))[None, cta_rank_in_cluster, None] + + load_K = partial(self.cpasync_gather_load_KV, + cpasync_gather_kv_manager, + pipeline_K, pipeline_K_cpasync, sK, False, "K", mK_cur, + ) + load_V0 = partial(self.cpasync_gather_load_KV, + cpasync_gather_kv_manager, + pipeline_V0, pipeline_V0_cpasync, sV0, False, "V", mV0_cur, + ) + load_V1 = partial(self.cpasync_gather_load_KV, + cpasync_gather_kv_manager, + pipeline_V1, pipeline_V1_cpasync, sV1, False, "V", mV1_cur, + ) + load_Vt0 = partial(self.cpasync_gather_load_KV, + cpasync_gather_kv_manager, + pipeline_V0, pipeline_V0_cpasync, sVt0, True, "V", mVt0_cur, + ) + load_Vt1 = partial(self.cpasync_gather_load_KV, + cpasync_gather_kv_manager, + pipeline_V1, pipeline_V1_cpasync, sVt1, True, "V", mVt1_cur, + ) + + # gather KV path processes n_blocks in increasing order + n_block = 0 + + # ==== Prologue ==== + # K, V0, V1 + cpasync_gather_kv_manager.load_index_topk(n_block, transpose=False) + producer_state_K = load_K(producer_state_K) + producer_state_V0 = load_V0(producer_state_V0) + producer_state_V1 = load_V1(producer_state_V1) + if const_expr(not self.disable_bitmask): + producer_state_bitmask = cpasync_gather_kv_manager.compute_bitmask( + producer_state_bitmask + ) + + if const_expr(self.use_tma_O and self.overlap_sO_sV): + cute.arch.mbarrier_wait(sO_empty_mbar_ptr, phase=producer_phase_O) + producer_phase_O ^= 1 + + # ==== Mainloop ==== + for n_block_group in cutlass.range(num_n_block_groups-1, unroll=1): + for stage in cutlass.range_constexpr(self.num_stages_S): + n_block = n_block_group * self.num_stages_S + stage + # K, V0, V1 + cpasync_gather_kv_manager.load_index_topk(n_block + 1, transpose=False) + producer_state_K = load_K(producer_state_K) + producer_state_V0 = load_V0(producer_state_V0) + producer_state_V1 = load_V1(producer_state_V1) + if const_expr(not self.disable_bitmask): + producer_state_bitmask = cpasync_gather_kv_manager.compute_bitmask( + producer_state_bitmask + ) + # Vt0, Vt1 + cpasync_gather_kv_manager.load_index_topk(n_block, transpose=True) + producer_state_V0 = load_Vt0(producer_state_V0) + producer_state_V1 = load_Vt1(producer_state_V1) + + # ==== Epilogue ==== + for stage in cutlass.range_constexpr(self.num_stages_S): + n_block = (num_n_block_groups-1) * self.num_stages_S + stage + if const_expr(stage == 0): + # K, V0, V1 + cpasync_gather_kv_manager.load_index_topk(n_block + 1, transpose=False) + producer_state_K = load_K(producer_state_K) + producer_state_V0 = load_V0(producer_state_V0) + producer_state_V1 = load_V1(producer_state_V1) + if const_expr(not self.disable_bitmask): + producer_state_bitmask = cpasync_gather_kv_manager.compute_bitmask( + producer_state_bitmask + ) + + # Vt0, Vt1 + cpasync_gather_kv_manager.load_index_topk(n_block, transpose=True) + producer_state_V0 = load_Vt0(producer_state_V0) + producer_state_V1 = load_Vt1(producer_state_V1) + + # Advance to next tile + work_tile = tile_scheduler.advance_to_next_work() + + pipeline_K_cpasync.producer_tail(producer_state_K) + pipeline_V0_cpasync.producer_tail(producer_state_V0) + pipeline_V1_cpasync.producer_tail(producer_state_V1) + if const_expr(not self.disable_bitmask): + pipeline_bitmask.producer_tail(producer_state_bitmask) + + + @cute.jit + def cpasync_gather_load_KV( + self, + cpasync_gather_kv_manager: CpasyncGatherKVManager, + pipeline_mma: pipeline.PipelineAsyncUmma, + pipeline_cpasync: pipeline.PipelineAsync, + sX: cute.Tensor, + transpose: bool, + K_or_V: str, + mX: cute.Tensor, + producer_state: pipeline.PipelineState, + ): + stage, phase = producer_state.index, producer_state.phase + pipeline_mma.producer_acquire(producer_state) + cpasync_gather_kv_manager.load_X( + mX, sX[None, None, None, stage], transpose, K_or_V + ) + cute.arch.cp_async_commit_group() + pipeline_cpasync.sync_object_full.arrive_cp_async_mbarrier(stage) + producer_state.advance() + return producer_state + + + @cute.jit + def load( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mQv0: cute.Tensor, + mQv1: cute.Tensor, + mV0: cute.Tensor, + mV1: cute.Tensor, + mVt0: cute.Tensor, + mVt1: cute.Tensor, + sQ: cute.Tensor, + sK: cute.Tensor, + sQv0: cute.Tensor, + sQv1: cute.Tensor, + sV0: cute.Tensor, + sV1: cute.Tensor, + sVt0: cute.Tensor, + sVt1: cute.Tensor, + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_Qv0: cute.CopyAtom, + tma_atom_Qv1: cute.CopyAtom, + tma_atom_V0: cute.CopyAtom, + tma_atom_V1: cute.CopyAtom, + tma_atom_Vt0: cute.CopyAtom, + tma_atom_Vt1: cute.CopyAtom, + pipeline_Q: pipeline.PipelineAsync, + pipeline_K: pipeline.PipelineAsync, + pipeline_Qv0: pipeline.PipelineAsync, + pipeline_Qv1: pipeline.PipelineAsync, + pipeline_V0: pipeline.PipelineAsync, + pipeline_V1: pipeline.PipelineAsync, + sO_empty_mbar_ptr: Optional[cute.Pointer], + thr_mma_QK: cute.ThrMma, + thr_mma_QviVi: cute.ThrMma, + thr_mma_PVti: cute.ThrMma, + topk_length_dynamic: Optional[Int32], + block_info: BlockInfo, + SeqlenInfoCls: Callable, + tile_scheduler: TileSchedulerProtocol, + ): + # ==== Load warp ==== + # Description: loads tiles of Q, Qv, K, V, V0, V1 from gmem to smem using TMA + # produces: Q, Qv, K, V, V0, V1 + # consumes: - + + mQvs = [mQv0, mQv1] + mVs = [mV0, mV1] + mVts = [mVt0, mVt1] + + sQvs = [sQv0, sQv1] + sVs = [sV0, sV1] + sVts = [sVt0, sVt1] + + tma_atom_Qvs = [tma_atom_Qv0, tma_atom_Qv1] + tma_atom_Vs = [tma_atom_V0, tma_atom_V1] + tma_atom_Vts = [tma_atom_Vt0, tma_atom_Vt1] + + pipeline_Qvs = [pipeline_Qv0, pipeline_Qv1] + pipeline_Vs = [pipeline_V0, pipeline_V1] + + # ==== Make pipeline states ==== + producer_state_Q = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_Q + ) + producer_state_Qv0 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_Qvi + ) + producer_state_Qv1 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_Qvi + ) + if const_expr(self.use_tma_KV): + producer_state_K = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_K + ) + producer_state_V0 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_Vi + ) + producer_state_V1 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_Vi + ) + if const_expr(self.use_tma_O): + producer_phase_O = Int32(1) + + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + cta_m_block, head_idx, batch_idx, _ = work_tile.tile_idx + cluster_m_block = cta_m_block // self.cta_group_size + head_idx_kv = ( + head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx + ) + + seqlen = SeqlenInfoCls(batch_idx) + if const_expr(self.is_topk_gather): + n_block_min = 0 + n_block_max = self.topk_length // self.tile_n + # n_block_max = topk_length_dynamic // self.tile_n + else: + n_block_min, n_block_max = block_info.get_n_block_min_max( + seqlen, cluster_m_block, + ) + num_n_blocks = n_block_max - n_block_min + even_n_blocks = num_n_blocks % 2 == 0 and num_n_blocks > 0 + num_n_block_groups = cute.ceil_div(num_n_blocks, self.num_stages_S) + + # ==== Partition GMEM tensors ==== + # (seqlen_q, hdim or hdimv//2) + mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] + mQvs_cur = [ + seqlen.offset_batch_Q(mQvs[split], batch_idx, dim=3)[None, None, head_idx] + for split in range(self.num_hdimv_splits) + ] + # (mma_tile_m, hdim or hdimv//2) + gQ = cute.local_tile( + mQ_cur, + (self.mma_tiler_QK[0], self.mma_tiler_QK[2]), + (cluster_m_block, 0), + ) + gQvs = [ + cute.local_tile( + mQvs_cur[split], + (self.mma_tiler_QviVi[0], self.mma_tiler_QviVi[2]), + (cluster_m_block, 0), + ) for split in range(self.num_hdimv_splits) + ] + tSgQ = thr_mma_QK.partition_A(gQ) + tSgQvs = [ + thr_mma_QviVi.partition_A(gQvs[split]) + for split in range(self.num_hdimv_splits) + ] + tQsQ, tQgQ = cpasync.tma_partition( + atom=tma_atom_Q, + cta_coord=0, + cta_layout=cute.make_layout(1), + smem_tensor=cute.group_modes(sQ, 0, 3), + gmem_tensor=cute.group_modes(tSgQ, 0, 3), + ) + tQvsQvs, tQvgQvs = zip(*[ + cpasync.tma_partition( + atom=tma_atom, + cta_coord=0, + cta_layout=cute.make_layout(1), + smem_tensor=cute.group_modes(sQv, 0, 3), + gmem_tensor=cute.group_modes(tSgQv, 0, 3), + ) + for tma_atom, sQv, tSgQv in zip(tma_atom_Qvs, sQvs, tSgQvs) + ]) + + if const_expr(self.use_tma_KV): + # (seqlen_k, hdim) or (seqlen_k, hdimv//2) + mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[None, None, head_idx_kv] + mVs_cur = [ + seqlen.offset_batch_K(mVs[split], batch_idx, dim=3)[None, None, head_idx_kv] + for split in range(self.num_hdimv_splits) + ] + # (hdimv//2, seqlen_k) + if const_expr(not seqlen.has_cu_seqlens_k): + mVts_cur = [ + mVts[split][None, None, head_idx_kv, batch_idx] + for split in range(self.num_hdimv_splits) + ] + else: + mVts_cur = [ + cute.domain_offset((0, seqlen.offset_k), mVts[split][None, None, head_idx_kv]) + for split in range(self.num_hdimv_splits) + ] + # (tile_n, hdim or hdimv//2, num_n_blocks) + gK = cute.local_tile( + mK_cur, + (self.mma_tiler_QK[1], self.mma_tiler_QK[2]), + (None, 0), + ) + gVs = [ + cute.local_tile( + mVs_cur[split], + (self.mma_tiler_QviVi[1], self.mma_tiler_QviVi[2]), + (None, 0), + ) for split in range(self.num_hdimv_splits) + ] + # (hdim or hdimv//2, tile_n, num_n_blocks) + gVts = [ + cute.local_tile( + mVts_cur[split], + (self.mma_tiler_PVti[1], self.mma_tiler_PVti[2]), + (0, None), + ) for split in range(self.num_hdimv_splits) + ] + tSgK = thr_mma_QK.partition_B(gK) + tSgVs = [ + thr_mma_QviVi.partition_B(gVs[split]) + for split in range(self.num_hdimv_splits) + ] + tOgVts = [ + thr_mma_PVti.partition_B(gVts[split]) + for split in range(self.num_hdimv_splits) + ] + tKsK, tKgK = cpasync.tma_partition( + atom=tma_atom_K, + cta_coord=0, + cta_layout=cute.make_layout(1), + smem_tensor=cute.group_modes(sK, 0, 3), + gmem_tensor=cute.group_modes(tSgK, 0, 3), + ) + tVsVs, tVgVs = zip(*[ + cpasync.tma_partition( + atom=tma_atom, + cta_coord=0, + cta_layout=cute.make_layout(1), + smem_tensor=cute.group_modes(sV, 0, 3), + gmem_tensor=cute.group_modes(tSgV, 0, 3), + ) + for tma_atom, sV, tSgV in zip(tma_atom_Vs, sVs, tSgVs) + ]) + tVtsVts, tVtgVts = zip(*[ + cpasync.tma_partition( + atom=tma_atom, + cta_coord=0, + cta_layout=cute.make_layout(1), + smem_tensor=cute.group_modes(sVt, 0, 3), + gmem_tensor=cute.group_modes(tOgV, 0, 3), + ) + for tma_atom, sVt, tOgV in zip(tma_atom_Vts, sVts, tOgVts) + ]) + + load_Q = partial(self.load_inner, tma_atom_Q, tQgQ, tQsQ, pipeline_Q) + load_Qv = partial(self.load_inner, tma_atom_Qvs, tQvgQvs, tQvsQvs, pipeline_Qvs) + if const_expr(self.use_tma_KV): + load_K = partial(self.load_inner, tma_atom_K, tKgK, tKsK, pipeline_K) + load_V = partial(self.load_inner, tma_atom_Vs, tVgVs, tVsVs, pipeline_Vs) + load_Vt = partial(self.load_inner, tma_atom_Vts, tVtgVts, tVtsVts, pipeline_Vs) + + # ==== Load stationary operands ==== + + # copy Q, Qvi gmem -> smem + producer_state_Q = load_Q(producer_state_Q) + producer_state_Qv0 = load_Qv(producer_state_Qv0, split=0) + producer_state_Qv1 = load_Qv(producer_state_Qv1, split=1) + + if const_expr(self.use_tma_KV): + # ==== Prologue ==== + n_block_first = n_block_max - 1 + # copy K gmem -> smem + producer_state_K = load_K(producer_state_K, n_block=n_block_first) + # copy Vi gmem -> smem + producer_state_V0 = load_V(producer_state_V0, n_block=n_block_first, split=0) + producer_state_V1 = load_V(producer_state_V1, n_block=n_block_first, split=1) + + if const_expr(self.use_tma_O and self.overlap_sO_sV): + cute.arch.mbarrier_wait(sO_empty_mbar_ptr, phase=producer_phase_O) + producer_phase_O ^= 1 + + # ==== Main loop ==== + for n_block_group in cutlass.range(num_n_block_groups-1, unroll=1): + for stage in cutlass.range_constexpr(self.num_stages_S): + n_block = n_block_max - 1 - n_block_group * self.num_stages_S - stage + # copy K gmem -> smem + producer_state_K = load_K(producer_state_K, n_block=n_block-1) + # copy Vi gmem -> smem + producer_state_V0 = load_V(producer_state_V0, n_block=n_block-1, split=0) + producer_state_V1 = load_V(producer_state_V1, n_block=n_block-1, split=1) + # copy Vti gmem -> smem + producer_state_V0 = load_Vt(producer_state_V0, n_block=n_block, split=0) + producer_state_V1 = load_Vt(producer_state_V1, n_block=n_block, split=1) + + # ==== Epilogue ==== + num_final_n_blocks = self.num_stages_S if even_n_blocks else self.num_stages_S - 1 + for stage in cutlass.range(num_final_n_blocks, unroll_full=True): + n_block = num_final_n_blocks - 1 - stage + if n_block > 0: + # copy K gmem -> smem + producer_state_K = load_K(producer_state_K, n_block=n_block-1) + # copy Vi gmem -> smem + producer_state_V0 = load_V(producer_state_V0, n_block=n_block-1, split=0) + producer_state_V1 = load_V(producer_state_V1, n_block=n_block-1, split=1) + # copy Vti gmem -> smem + producer_state_V0 = load_Vt(producer_state_V0, n_block=n_block, split=0) + producer_state_V1 = load_Vt(producer_state_V1, n_block=n_block, split=1) + + # Advance to next tile + work_tile = tile_scheduler.advance_to_next_work() + + pipeline_Q.producer_tail(producer_state_Q) + pipeline_Qv0.producer_tail(producer_state_Qv0) + pipeline_Qv1.producer_tail(producer_state_Qv1) + if const_expr(self.use_tma_KV): + pipeline_K.producer_tail(producer_state_K) + pipeline_V0.producer_tail(producer_state_V0) + pipeline_V1.producer_tail(producer_state_V1) + + @cute.jit + def load_inner( + self, + tma_atom: cute.CopyAtom, + tXgX: cute.Tensor, + tXsX: cute.Tensor, + load_pipeline: pipeline.PipelineAsync, + producer_state: pipeline.PipelineState, + n_block: Optional[Int32] = None, + split: Optional[Int32] = None, + ): + stage = producer_state.index + if const_expr(split is not None): + tma_atom = tma_atom[split] + tXgX = tXgX[split] + tXsX = tXsX[split] + load_pipeline = load_pipeline[split] + if const_expr(n_block is not None): + tXgX = tXgX[(None, n_block)] + tXsX = tXsX[(None, stage)] + + load_pipeline.producer_acquire(producer_state) + tma_bar_ptr = load_pipeline.producer_get_barrier(producer_state) + cute.copy(tma_atom, tXgX, tXsX, tma_bar_ptr=tma_bar_ptr) + producer_state.advance() + return producer_state + + @cute.jit + def mma( + self, + sQ: cute.Tensor, + sK: cute.Tensor, + sQv0: cute.Tensor, + sQv1: cute.Tensor, + sV0: cute.Tensor, + sV1: cute.Tensor, + sVt0: cute.Tensor, + sVt1: cute.Tensor, + sP: cute.Tensor, + tiled_mma_QK: cute.TiledMma, + tiled_mma_QviVi: cute.TiledMma, + tiled_mma_PVti: cute.TiledMma, + pipeline_Q: pipeline.PipelineAsync, + pipeline_K: pipeline.PipelineAsync, + pipeline_Qv0: pipeline.PipelineAsync, + pipeline_Qv1: pipeline.PipelineAsync, + pipeline_V0: pipeline.PipelineAsync, + pipeline_V1: pipeline.PipelineAsync, + pipeline_S: pipeline.PipelineAsync, + pipeline_P: pipeline.PipelineAsync, + pipeline_O0: pipeline.PipelineAsync, + pipeline_O1: pipeline.PipelineAsync, + sO_empty_mbar_ptr: Optional[cute.Pointer], + is_leader_cta: Boolean, + topk_length_dynamic: Optional[Int32], + block_info: BlockInfo, + SeqlenInfoCls: Callable, + tile_scheduler: TileSchedulerProtocol, + ): + # ==== mma warp ==== + # Description: Computes Q @ K^T, Qv @ V^T, and P @ V + # Produces: S, O + # Consumes: Q, K, Qv, V, P + + pipelines_V = [pipeline_V0, pipeline_V1] + pipelines_Qv = [pipeline_Qv0, pipeline_Qv1] + pipelines_O = [pipeline_O0, pipeline_O1] + + sQvs = [sQv0, sQv1] + sVs = [sV0, sV1] + sVts = [sVt0, sVt1] + + # Set accumulate = True for Qv @ V^T since we are accumulating on the Q @ K^T result + tiled_mma_QviVi.set(tcgen05.Field.ACCUMULATE, True) + + # Operands for S = Q @ K^T + tSrQ = tiled_mma_QK.make_fragment_A(sQ) + tSrK = tiled_mma_QK.make_fragment_B(sK) + + # Operands for S += Qv @ V^T + tSrQvs = [ + tiled_mma_QviVi.make_fragment_A(sQvs[split]) + for split in range(self.num_hdimv_splits) + ] + tSrVs = [ + tiled_mma_QviVi.make_fragment_B(sVs[split]) + for split in range(self.num_hdimv_splits) + ] + + # Operands for Oi = P @ Vi + tOrP = tiled_mma_PVti.make_fragment_A(sP) + tOrVts = [ + tiled_mma_PVti.make_fragment_B(sVts[split]) + for split in range(self.num_hdimv_splits) + ] + + # GEMM functions + gemm_QK = [ + partial( + fa_sm100_utils.gemm_ptx_partial, + tiled_mma_QK.op, + self.tmem_offset_S[stage], + tCrA=tSrQ[None, None, None, 0], + sA=sQ[None, None, None, 0], + zero_init=True, + cta_group=self.cta_group_size, + ) for stage in range(self.num_stages_S) + ] + gemms_QvV = [ + [ + partial( + fa_sm100_utils.gemm_ptx_partial, + tiled_mma_QviVi.op, + self.tmem_offset_S[stage], + tCrA=tSrQvs[split][None, None, None, 0], + sA=sQvs[split][None, None, None, 0], + zero_init=False, + cta_group=self.cta_group_size, + ) for stage in range(self.num_stages_S) + ] for split in range(self.num_hdimv_splits) + ] + gemms_PVt = [ + partial( + fa_sm100_utils.gemm_ptx_partial, + tiled_mma_PVti.op, + self.tmem_offsets_O[split], + tOrP[None, None, None, 0], + sA=sP[None, None, None, 0], + cta_group=self.cta_group_size, + ) for split in range(self.num_hdimv_splits) + ] + + consumer_state_Q = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_Q + ) + consumer_state_K = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_K + ) + consumer_state_Qv0 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_Qvi + ) + consumer_state_Qv1 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_Qvi + ) + consumer_state_V0 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_Vi + ) + consumer_state_V1 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_Vi + ) + producer_state_S = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_S + ) + consumer_state_P = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_P + ) + producer_state_O0 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_Oi + ) + producer_state_O1 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_Oi + ) + + mma_QK = partial(self.mma_inner, gemm_QK, pipeline_K, tSrK, sK) + mma_QvV = partial(self.mma_inner, gemms_QvV, pipelines_V, tSrVs, sVs) + mma_PVt = partial(self.mma_inner, gemms_PVt, pipelines_V, tOrVts, sVts) + + work_tile = tile_scheduler.initial_work_tile_info() + O_should_accumulate = False + while work_tile.is_valid_tile: + cta_m_block, head_idx, batch_idx, _ = work_tile.tile_idx + cluster_m_block = cta_m_block // self.cta_group_size + + seqlen = SeqlenInfoCls(batch_idx) + if const_expr(self.is_topk_gather): + n_block_min = 0 + # n_block_max = self.topk_length // self.tile_n + n_block_max = topk_length_dynamic // self.tile_n + else: + n_block_min, n_block_max = block_info.get_n_block_min_max( + seqlen, cluster_m_block, + ) + num_n_blocks = n_block_max - n_block_min + even_n_blocks = num_n_blocks % 2 == 0 and num_n_blocks > 0 + num_n_block_groups = cute.ceil_div(num_n_blocks, self.num_stages_S) + + if is_leader_cta: + pipeline_Q.consumer_wait(consumer_state_Q) + pipeline_Qv0.consumer_wait(consumer_state_Qv0) + pipeline_Qv1.consumer_wait(consumer_state_Qv1) + + consumer_states_V = [consumer_state_V0, consumer_state_V1] + producer_states_O = [producer_state_O0, producer_state_O1] + + # ==== Prologue ==== + pipeline_S.producer_acquire(producer_state_S) + # S = Q @ K^T + consumer_state_K = mma_QK(consumer_state_K, stage=0) + # S += Qvi @ Vi^T + for split in cutlass.range_constexpr(self.num_hdimv_splits): + consumer_states_V[split] = mma_QvV( + consumer_states_V[split], stage=0, split=split + ) + pipeline_S.producer_commit(producer_state_S) + producer_state_S.advance() + + # ==== Mainloop ==== + for _ in cutlass.range(num_n_block_groups-1, unroll=1): + for stage in cutlass.range_constexpr(self.num_stages_S): + next_stage = const_expr((stage + 1) % self.num_stages_S) + pipeline_S.producer_acquire(producer_state_S) + # S = Q @ K^T + consumer_state_K = mma_QK(consumer_state_K, stage=next_stage) + # S += Qvi @ Vi^T + for split in cutlass.range_constexpr(self.num_hdimv_splits): + consumer_states_V[split] = mma_QvV( + consumer_states_V[split], stage=next_stage, split=split + ) + pipeline_S.producer_commit(producer_state_S) + producer_state_S.advance() + # Oi += P @ Vi + pipeline_P.consumer_wait(consumer_state_P) + for split in cutlass.range_constexpr(self.num_hdimv_splits): + producer_state_Oi = producer_states_O[split] + pipelines_O[split].producer_acquire(producer_state_Oi) + consumer_states_V[split] = mma_PVt( + consumer_states_V[split], + split=split, + zero_init=not O_should_accumulate + ) + pipelines_O[split].producer_commit(producer_state_Oi) + producer_state_Oi.advance() + producer_states_O[split] = producer_state_Oi + pipeline_P.consumer_release(consumer_state_P) + consumer_state_P.advance() + O_should_accumulate = True + + # ==== Epilogue ==== + num_final_n_blocks = self.num_stages_S if even_n_blocks else self.num_stages_S - 1 + for stage in cutlass.range_constexpr(self.num_stages_S): + n_block = num_final_n_blocks - 1 - stage + if const_expr(stage == 0): + if n_block > 0: + pipeline_S.producer_acquire(producer_state_S) + # S = Q @ K^T + consumer_state_K = mma_QK(consumer_state_K, stage=stage+1) + # S += Qvi @ Vi^T + for split in cutlass.range_constexpr(self.num_hdimv_splits): + consumer_states_V[split] = mma_QvV( + consumer_states_V[split], stage=stage+1, split=split + ) + pipeline_S.producer_commit(producer_state_S) + producer_state_S.advance() + if n_block >= 0: + # Oi += P @ Vi + pipeline_P.consumer_wait(consumer_state_P) + for split in cutlass.range_constexpr(self.num_hdimv_splits): + producer_state_Oi = producer_states_O[split] + pipelines_O[split].producer_acquire(producer_state_Oi) + consumer_states_V[split] = mma_PVt( + consumer_states_V[split], + split=split, + zero_init=not O_should_accumulate + ) + pipelines_O[split].producer_commit(producer_state_Oi) + producer_state_Oi.advance() + producer_states_O[split] = producer_state_Oi + pipeline_P.consumer_release(consumer_state_P) + consumer_state_P.advance() + O_should_accumulate = True + + consumer_state_V0, consumer_state_V1 = consumer_states_V + producer_state_O0, producer_state_O1 = producer_states_O + + pipeline_Q.consumer_release(consumer_state_Q) + + # if we overlap sOi with sQvi for tma store, need to acquire signal + if const_expr(self.use_tma_O and not self.overlap_sO_sV): + pipeline_O0.producer_tail(producer_state_O0.clone()) + pipeline_O1.producer_tail(producer_state_O1.clone()) + + pipeline_Qv0.consumer_release(consumer_state_Qv0) + pipeline_Qv1.consumer_release(consumer_state_Qv1) + consumer_state_Q.advance() + consumer_state_Qv0.advance() + consumer_state_Qv1.advance() + + # Advance to next tile + work_tile = tile_scheduler.advance_to_next_work() + O_should_accumulate = False + + pipeline_S.producer_tail(producer_state_S) + pipeline_O0.producer_tail(producer_state_O0) + pipeline_O1.producer_tail(producer_state_O1) + + @cute.jit + def mma_inner( + self, + gemm, + load_pipeline, + tCrB, + sB, + consumer_state: pipeline.PipelineState, + stage: Optional[Int32] = None, + split: Optional[Int32] = None, + zero_init: Optional[bool] = None, + ): + if const_expr(split is not None): + gemm = gemm[split] + load_pipeline = load_pipeline[split] + tCrB = tCrB[split] + sB = sB[split] + if const_expr(stage is not None): + gemm = gemm[stage] + + smem_stage = consumer_state.index + tCrB_cur = tCrB[None, None, None, smem_stage] + sB_cur = sB[None, None, None, smem_stage] + + load_pipeline.consumer_wait(consumer_state) + if const_expr(zero_init is not None): + gemm(tCrB=tCrB_cur, sB=sB_cur, zero_init=zero_init) + else: + gemm(tCrB=tCrB_cur, sB=sB_cur) + load_pipeline.consumer_release(consumer_state) + consumer_state.advance() + return consumer_state + + + @cute.jit + def softmax_loop( + self, + softmax_scale: Float32, + softmax_scale_log2: Float32, + mLSE: Optional[cute.Tensor], + sRowMax: cute.Tensor, + sRowSum: cute.Tensor, + sScale: cute.Tensor, + sBitmask: Optional[cute.Tensor], + sP: cute.Tensor, + tStS: cute.Tensor, + thr_mma_QK: cute.ThrMma, + pipeline_S: pipeline.PipelineAsync, + pipeline_P: pipeline.PipelineAsync, + pipeline_sm_stats: pipeline.PipelineAsync, + pipeline_bitmask: Optional[pipeline.PipelineAsync], + sO_empty_mbar_ptr: Optional[cute.Pointer], + AttentionMaskCls: Callable, + topk_length_dynamic: Optional[Int32], + block_info: BlockInfo, + SeqlenInfoCls: Callable, + tile_scheduler: TileSchedulerProtocol, + ): + # ==== softmax warpgroup ==== + # Description: computes softmax on S and writes the result to P + # Produces: P, softmax stats + # Consumes: S, bitmask (for topk sparsity) + + tidx = cute.arch.thread_idx()[0] % self.num_softmax_threads + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % (self.num_softmax_threads//32) + + tSAcc = tStS[(None, None), 0, 0, 0] + tSAcc_staged = [tStS[(None, None), 0, 0, stage] for stage in range(self.num_stages_S)] + + cS = cute.make_identity_tensor(self.mma_tiler_QK[:2]) # (128, 128) + tScS = thr_mma_QK.partition_C(cS)[(None, None), 0, 0] # (64, 128) + + # S tmem -> rmem copy objects + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), + self.dtype_acc, + ) + tmem_load_tiled = tcgen05.make_tmem_copy(tmem_load_atom, tSAcc) + tmem_load_thr = tmem_load_tiled.get_slice(tidx) + # S tmem -> rmem copy operands + tStS_t2r = tmem_load_thr.partition_S(tSAcc) # (((32, 32), 1), 1, 2) + tStS_t2r_staged = [tmem_load_thr.partition_S(tSAcc_staged[stage]) for stage in range(self.num_stages_S)] + tScS_t2r = tmem_load_thr.partition_D(tScS) + tSrS_t2r = cute.make_rmem_tensor(tScS_t2r.shape, self.dtype_acc) + + # P rmem -> smem copy objects + universal_copy_bits = 128 + smem_store_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dtype_P, + num_bits_per_copy=universal_copy_bits, + ) + smem_store_tiled = cute.make_tiled_copy_D(smem_store_atom, tmem_load_tiled) + smem_store_thr = smem_store_tiled.get_slice(tidx) + # P rmem -> smem copy operands + sP_slice = sP[None, None, None, 0] + sP_mn = cute.make_tensor( + sP_slice.iterator, + cute.make_layout( + ( + (sP_slice.shape[0][0], sP_slice.shape[1]), + (sP_slice.shape[0][1], sP_slice.shape[2]), + ), + stride=( + (sP_slice.stride[0][0], sP_slice.stride[1]), + (sP_slice.stride[0][1], sP_slice.stride[2]), + ), + ), + ) + sP_smem_view = smem_store_thr.partition_D(sP_mn) + + consumer_state_S = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_S + ) + producer_state_P = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_P + ) + producer_state_sm_stats = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_sm_stats + ) + consumer_state_bitmask = None + if const_expr(self.is_topk_gather and not self.disable_bitmask): + consumer_state_bitmask = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_bitmask + ) + + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + cta_m_block, head_idx, batch_idx, _ = work_tile.tile_idx + cluster_m_block = cta_m_block // self.cta_group_size + seqlen = SeqlenInfoCls(batch_idx) + if const_expr(self.is_topk_gather): + n_block_min = 0 + n_block_max = self.topk_length // self.tile_n + # n_block_max = topk_length_dynamic // self.tile_n + else: + n_block_min, n_block_max = block_info.get_n_block_min_max( + seqlen, cluster_m_block, + ) + num_n_blocks = n_block_max - n_block_min + even_n_blocks = num_n_blocks % 2 == 0 and num_n_blocks > 0 + num_n_block_groups = cute.ceil_div(num_n_blocks, self.num_stages_S) + + mask = AttentionMaskCls(seqlen) + mask_fn = partial( + mask.apply_mask_sm100, + m_block=cluster_m_block, + thr_mma=thr_mma_QK, + thr_tmem_load=tmem_load_thr, + mask_causal=self.is_causal, + mask_local=self.is_local, + batch_idx=batch_idx, + head_idx=head_idx, + r2p=False, # TODO: fix r2p for 2cta + ) + disable_mask = self.disable_bitmask and self.is_topk_gather + + softmax = SoftmaxSm100.create( + softmax_scale_log2, + rescale_threshold=8.0 if const_expr(self.dtype_Q.width == 16) else 0.0, + softmax_scale=softmax_scale, + ) + softmax.reset() + + softmax_step_fn = partial( + self.softmax_step, + softmax, + sRowMax, + sScale, + sBitmask, + tStS_t2r_staged, + tSrS_t2r, + sP_smem_view, + tmem_load_thr, + smem_store_thr, + pipeline_S, + pipeline_P, + pipeline_sm_stats, + pipeline_bitmask, + tidx, + warp_idx, + ) + + ### first iteration ### + n_block = n_block_max - 1 + (consumer_state_S, producer_state_P, producer_state_sm_stats, consumer_state_bitmask + ) = softmax_step_fn( + consumer_state_S, producer_state_P, producer_state_sm_stats, consumer_state_bitmask, + 0, n_block, + mask_fn=partial(mask_fn, mask_seqlen=True) + if not const_expr(disable_mask) else None, + is_first=True, + ) + n_block -= 1 + + ### Separate iterations with causal masking + # note: For square mma tile, can mask at most 1 n_block_group + if const_expr((self.is_causal or self.is_local) and not self.is_topk_gather): + n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( + seqlen, cluster_m_block, n_block_min + ) + num_masked_n_blocks = n_block_max - 1 - n_block_min_causal_local_mask + num_masked_n_block_groups = min( + num_n_block_groups - 1, + cute.ceil_div(num_masked_n_blocks, self.num_stages_S) + ) + num_n_block_groups -= num_masked_n_block_groups + for _ in cutlass.range(num_masked_n_block_groups, unroll=1): + for stage in cutlass.range_constexpr(self.num_stages_S): + (consumer_state_S, producer_state_P, producer_state_sm_stats, consumer_state_bitmask + ) = softmax_step_fn( + consumer_state_S, producer_state_P, producer_state_sm_stats, consumer_state_bitmask, + 1-stage, n_block, + mask_fn=partial(mask_fn, mask_seqlen=False), + ) + n_block -= 1 + + ### Mainloop ### + for n_block_group in cutlass.range(num_n_block_groups-1, unroll=1): + for stage in cutlass.range_constexpr(self.num_stages_S): + (consumer_state_S, producer_state_P, producer_state_sm_stats, consumer_state_bitmask + ) = softmax_step_fn( + consumer_state_S, producer_state_P, producer_state_sm_stats, consumer_state_bitmask, + 1-stage, n_block, + mask_fn=partial(mask_fn, mask_seqlen=False) + if const_expr(self.is_topk_gather and not self.disable_bitmask) else None, + ) + n_block -= 1 + + ### last iteration if even ### + # always mask to simplify logic + if even_n_blocks: + (consumer_state_S, producer_state_P, producer_state_sm_stats, consumer_state_bitmask + ) = softmax_step_fn( + consumer_state_S, producer_state_P, producer_state_sm_stats, consumer_state_bitmask, + 1, n_block, + mask_fn=partial(mask_fn, mask_seqlen=False) + if not const_expr(disable_mask) else None, + ) + n_block -= 1 + + # write row max and sum to smem + sRowSum[tidx % self.cta_tile_m, warp_idx//self.cta_group_size] = softmax.row_sum[0] + if const_expr(mLSE is not None): + if tidx < self.cta_tile_m: + sRowMax[tidx, 0] = softmax.row_max[0] + self.sm_stats_barrier_full.arrive() + + # Advance to next tile + work_tile = tile_scheduler.advance_to_next_work() + self.sm_stats_barrier_empty.arrive_and_wait() + + pipeline_P.producer_tail(producer_state_P) + pipeline_sm_stats.producer_tail(producer_state_sm_stats) + + + @cute.jit + def softmax_step( + self, + softmax: SoftmaxSm100, + sRowMax: cute.Tensor, + sScale: cute.Tensor, + sBitmask: Optional[cute.Tensor], + tStS_t2r_staged: cute.Tensor, + tSrS_t2r: cute.Tensor, + sP_smem_view: cute.Tensor, + tmem_load_thr: cute.CopyAtom, + smem_store_thr: cute.CopyAtom, + pipeline_S: pipeline.PipelineAsync, + pipeline_P: pipeline.PipelineAsync, + pipeline_sm_stats: pipeline.PipelineAsync, + pipeline_bitmask: Optional[pipeline.PipelineAsync], + tidx: Int32, + warp_idx: Int32, + consumer_state_S: pipeline.PipelineState, + producer_state_P: pipeline.PipelineState, + producer_state_sm_stats: pipeline.PipelineState, + consumer_state_bitmask: Optional[pipeline.PipelineState], + stage: cutlass.Constexpr[Int32], + n_block: Int32, + mask_fn: Optional[Callable] = None, + is_first: Boolean = False, + ): + tSrP = cute.make_rmem_tensor(tSrS_t2r.shape, self.dtype_P) + rP_smem_view = smem_store_thr.retile(tSrP) + + pipeline_S.consumer_wait(consumer_state_S) + cute.copy(tmem_load_thr, tStS_t2r_staged[stage], tSrS_t2r) + cute.arch.fence_view_async_tmem_load() + pipeline_S.consumer_release(consumer_state_S) + + rBitmask = None + if const_expr(self.is_topk_gather and not self.disable_bitmask): + assert pipeline_bitmask is not None + assert consumer_state_bitmask is not None + pipeline_bitmask.consumer_wait(consumer_state_bitmask) + rBitmask = cute.make_rmem_tensor((self.tile_n//64, ), dtype=Uint32) + bitmask_col_offset = self.tile_n//64 if warp_idx >= 2 else 0 + for i in cutlass.range_constexpr(cute.size(rBitmask)): + rBitmask[i] = sBitmask[bitmask_col_offset + i, consumer_state_bitmask.index] + + if const_expr(mask_fn is not None): + mask_fn(tSrS_t2r, n_block=n_block, rBitmask=rBitmask) + + # compute threadwise row_max + row_max = softmax.compute_row_max_local(tSrS_t2r.load(), is_first) + self.softmax_barrier.arrive_and_wait() + + # 2-thread reduce row_max through smem + assert self.cta_tile_m * self.cta_group_size == 128 + sRowMax[tidx % self.cta_tile_m, warp_idx//self.cta_group_size] = row_max + self.softmax_barrier.arrive_and_wait() + # must release after barrier sync + if const_expr(self.is_topk_gather and not self.disable_bitmask): + pipeline_bitmask.consumer_release(consumer_state_bitmask) + row_max0 = sRowMax[tidx % self.cta_tile_m, 0] + row_max1 = sRowMax[tidx % self.cta_tile_m, 1] + row_max = max(row_max0, row_max1) + + row_max, acc_scale = softmax.update_row_max_from_local(row_max, is_first) + + # note: acc_scales agree for paired threads + pipeline_sm_stats.producer_acquire(producer_state_sm_stats) + if warp_idx < self.cta_group_size: + sScale[tidx % self.cta_tile_m, producer_state_sm_stats.index] = acc_scale + pipeline_sm_stats.producer_commit(producer_state_sm_stats) + + # x -> scale_log2*x-rowmax + softmax.scale_subtract_rowmax(tSrS_t2r, row_max) + + # x -> exp2(x) + softmax.apply_exp2_convert(tSrS_t2r, tSrP) + + pipeline_P.producer_acquire(producer_state_P) + cute.copy(smem_store_thr, rP_smem_view, sP_smem_view) + cute.arch.fence_view_async_shared() + pipeline_P.producer_commit(producer_state_P) + + consumer_state_S.advance() + producer_state_P.advance() + producer_state_sm_stats.advance() + if const_expr(self.is_topk_gather and not self.disable_bitmask): + consumer_state_bitmask.advance() + + softmax.update_row_sum(tSrS_t2r.load(), acc_scale, is_first) + + return consumer_state_S, producer_state_P, producer_state_sm_stats, consumer_state_bitmask + + + @cute.jit + def correction_loop( + self, + softmax_scale_log2: Float32, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + tma_atom_O: Optional[cute.CopyAtom], + sRowMax: cute.Tensor, + sRowSum: cute.Tensor, + sScale: cute.Tensor, + sO: cute.Tensor, + tO0tO0: cute.Tensor, + tO1tO1: cute.Tensor, + pipeline_O0: pipeline.PipelineAsync, + pipeline_O1: pipeline.PipelineAsync, + pipeline_sm_stats: pipeline.PipelineAsync, + sO_empty_mbar_ptr: Optional[cute.Pointer], + tiled_copy_O_r2g: cute.TiledCopy, + topk_length_dynamic: Optional[Int32], + block_info: BlockInfo, + SeqlenInfoCls: Callable, + tile_scheduler: TileSchedulerProtocol, + ): + ### ==== correction/epilogue warpgroup ==== + # Correction: copy scale smem -> rmem, copy O tmem -> rmem, rescale O, store O rmem -> tmem + # Epilogue: copy O tmem -> rmem, do final scaling of O, store O rmem -> gmem, + # optionally store LSE + # Produces: - + # Consumes: O, softmax stats + + tidx = cute.arch.thread_idx()[0] % self.num_epilogue_threads + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % (self.num_epilogue_threads//32) + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + leader_warp = warp_idx==0 + + tO0tO0 = tO0tO0[(None, None), 0, 0] # (64, (128, 2)) + tO1tO1 = tO1tO1[(None, None), 0, 0] # (64, (128, 2)) + tOtOs = [tO0tO0, tO1tO1] + + # tuneable parameter + corr_tile_size = math.gcd(32, self.tmem_cols_Oi) + + tmem_load_atom_O = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), + self.dtype_acc, + ) + tmem_store_atom_O = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), + self.dtype_acc, + ) + thr_tmem_load_O = tcgen05.make_tmem_copy(tmem_load_atom_O, tO0tO0).get_slice(tidx) + thr_tmem_store_O = tcgen05.make_tmem_copy(tmem_store_atom_O, tO0tO0).get_slice(tidx) + + # ((32,1),1,4) + tOtOs_t2r = [ + thr_tmem_load_O.partition_S(tOtOs[split]) + for split in range(self.num_hdimv_splits) + ] + tOtOs_r2t = [ + thr_tmem_store_O.partition_D(tOtOs[split]) + for split in range(self.num_hdimv_splits) + ] + + cOi = cute.make_identity_tensor((self.cta_tile_m, self.hdimv // self.num_hdimv_splits)) + thr_tiled_copy_O_r2g = tiled_copy_O_r2g.get_slice(tidx) + tOicOi = thr_tiled_copy_O_r2g.partition_S(cOi) + + tOicOi_t2r = thr_tmem_load_O.partition_D(tOicOi[(None, None), 0, 0]) + + pipelines_O = [pipeline_O0, pipeline_O1] + + consumer_state_O0 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_Oi + ) + consumer_state_O1 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_Oi + ) + consumer_state_sm_stats = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_sm_stats + ) + + do_correction_rescale = partial( + self.correction_rescale, + thr_tmem_load_O, + thr_tmem_store_O, + tOicOi_t2r, + ) + + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + cta_m_block, head_idx, batch_idx, _ = work_tile.tile_idx + cluster_m_block = cta_m_block // self.cta_group_size + + seqlen = SeqlenInfoCls(batch_idx) + if const_expr(self.is_topk_gather): + n_block_min = 0 + n_block_max = self.topk_length // self.tile_n + # n_block_max = topk_length_dynamic // self.tile_n + else: + n_block_min, n_block_max = block_info.get_n_block_min_max( + seqlen, cluster_m_block, + ) + num_n_blocks = n_block_max - n_block_min + + consumer_states_O = [consumer_state_O0, consumer_state_O1] + + # acquire first signal and release immediately + pipeline_sm_stats.consumer_wait(consumer_state_sm_stats) + pipeline_sm_stats.consumer_release(consumer_state_sm_stats) + consumer_state_sm_stats.advance() + + for _ in cutlass.range(num_n_blocks - 1, unroll=1): + pipeline_sm_stats.consumer_wait(consumer_state_sm_stats) + scale = sScale[tidx % self.cta_tile_m, consumer_state_sm_stats.index] + should_rescale = cute.arch.vote_ballot_sync(scale < 1.0) != 0 + pipeline_sm_stats.consumer_release(consumer_state_sm_stats) + consumer_state_sm_stats.advance() + + for split in cutlass.range_constexpr(self.num_hdimv_splits): + consumer_state_Oi = consumer_states_O[split] + pipelines_O[split].consumer_wait(consumer_state_Oi) + if should_rescale: + do_correction_rescale( + tOtOs_t2r[split], + tOtOs_r2t[split], + scale, + ) + pipelines_O[split].consumer_release(consumer_state_Oi) + consumer_state_Oi.advance() + consumer_states_O[split] = consumer_state_Oi + + # (seqlen_q, hdimv) + mO_cur = seqlen.offset_batch_Q( + mO, batch_idx, dim=3, ragged=self.ragged_tma_O + )[None, None, head_idx] + # (cta_tile_m, hdimv//2, 2) + gO = cute.local_tile( + mO_cur, + (self.cta_tile_m, self.hdimv // self.num_hdimv_splits), + (cta_m_block, None), + ) + tOgO = thr_tiled_copy_O_r2g.partition_D(gO) + # ((32, 1), 1, 4) + tOrOs_t2r = [ + cute.make_rmem_tensor(tOicOi_t2r.shape, self.dtype_acc) + for split in range(self.num_hdimv_splits) + ] + tOrOs_r2g_f32 = [ + thr_tiled_copy_O_r2g.retile(tOrOs_t2r[split]) + for split in range(self.num_hdimv_splits) + ] + tOrOs_r2g = [ + cute.make_rmem_tensor_like(tOrOs_r2g_f32[split], self.dtype_O) + for split in range(self.num_hdimv_splits) + ] + if const_expr(self.use_tma_O): + tOsO = thr_tiled_copy_O_r2g.partition_D(sO) + store_O, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_O, 0, cute.make_layout(1), + sO, gO, + ) + + self.sm_stats_barrier_full.arrive_and_wait() + + row_sum0 = sRowSum[tidx % self.cta_tile_m, 0] + row_sum1 = sRowSum[tidx % self.cta_tile_m, 1] + row_sum = row_sum0 + row_sum1 + acc_O_mn_row_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum + scale = cute.arch.rcp_approx(row_sum if not acc_O_mn_row_is_zero_or_nan else 1.0) + + self.sm_stats_barrier_empty.arrive() + + seqlen_q = ( + seqlen.seqlen_q + if const_expr(not self.pack_gqa) + else seqlen.seqlen_q * self.qhead_per_kvhead + ) + + # compute and store lse to gmem + if const_expr(mLSE is not None): + if const_expr(not seqlen.has_cu_seqlens_q): + mLSE_cur = mLSE[None, head_idx, batch_idx] + else: + lse_offset = ( + seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) + ) + mLSE_cur = cute.domain_offset((lse_offset,), mLSE[None, head_idx]) + gLSE = cute.local_tile(mLSE_cur, (self.cta_tile_m,), (cta_m_block,)) + if tidx < self.cta_tile_m: + row_max = sRowMax[tidx, 0] + LN2 = math.log(2.0) + lse = ( + (row_max * softmax_scale_log2 + cute.math.log2(row_sum, fastmath=True)) * LN2 + if not acc_O_mn_row_is_zero_or_nan + else -Float32.inf + ) + if tidx < seqlen_q - cta_m_block * self.cta_tile_m: + gLSE[tidx] = lse + + row_idx = cta_m_block * self.cta_tile_m + tOicOi[0][0] + + for split in cutlass.range_constexpr(self.num_hdimv_splits): + consumer_state_Oi = consumer_states_O[split] + pipelines_O[split].consumer_wait(consumer_state_Oi) + # copy Oi tmem -> rmem + cute.copy( + thr_tmem_load_O, + tOtOs_t2r[split], + tOrOs_t2r[split], + ) + + # scale and downcast Oi + tOrOs_r2g[split].store((tOrOs_r2g_f32[split].load() * scale).to(self.dtype_O)) + + if const_expr(not self.use_tma_O): + # copy Oi rmem -> gmem + if row_idx < seqlen_q: + cute.copy( + thr_tiled_copy_O_r2g, + tOrOs_r2g[split], + tOgO[None, None, None, split], + ) + else: + # copy Oi rmem -> smem + if const_expr(self.overlap_sO_sV): + # last slot for Vti is always 1, 3 + sO_idx = 1 + 2 * split + else: + sO_idx = split + cute.copy( + thr_tiled_copy_O_r2g, + tOrOs_r2g[split], + tOsO[None, None, None, sO_idx], + ) + cute.arch.fence_view_async_shared() + self.epi_barrier.arrive_and_wait() + # tma store Oi smem -> gmem + if leader_warp: + store_O(src_idx=sO_idx, dst_idx=split) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(1 - split, read=True) + if const_expr(split == 1 and self.overlap_sO_sV): + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(sO_empty_mbar_ptr) + + consumer_state_O0, consumer_state_O1 = consumer_states_O + + cute.arch.fence_view_async_tmem_load() + pipeline_O0.consumer_release(consumer_state_O0) + pipeline_O1.consumer_release(consumer_state_O1) + consumer_state_O0.advance() + consumer_state_O1.advance() + + # Advance to next tile + work_tile = tile_scheduler.advance_to_next_work() + + @cute.jit + def correction_rescale( + self, + thr_tmem_load: cute.CopyAtom, + thr_tmem_store: cute.CopyAtom, + tOcO_t2r: cute.Tensor, + tOtO_t2r: cute.Tensor, + tOtO_r2t: cute.Tensor, + scale: Float32, + ): + tOrO_t2r_frg = cute.make_rmem_tensor_like(tOcO_t2r[None, None, 0], self.dtype_acc) + + for i in cutlass.range_constexpr(cute.size(tOtO_t2r, mode=[2])): + tOtO_t2r_cur = tOtO_t2r[None, None, i] + tOtO_r2t_cur = tOtO_r2t[None, None, i] + + cute.copy(thr_tmem_load, tOtO_t2r_cur, tOrO_t2r_frg) + for j in cutlass.range(0, cute.size(tOrO_t2r_frg), 2, unroll_full=True): + tOrO_t2r_frg[j], tOrO_t2r_frg[j + 1] = cute.arch.mul_packed_f32x2( + (tOrO_t2r_frg[j], tOrO_t2r_frg[j + 1]), (scale, scale) + ) + cute.copy(thr_tmem_store, tOrO_t2r_frg, tOtO_r2t_cur) + cute.arch.fence_view_async_tmem_store() + + +def test_mla_kernel( + seqlen_q=2048, + seqlen_k=2048, + topk_length=2048, + nheads=1, + batch=1, + iter=0, + compile_cache=dict(), + validate=True, + seed=0, + gather_kv=True, + pack_gqa=False, + is_causal=False, + varlen_q=False, + varlen_k=False, + disable_bitmask=False, +): + torch.manual_seed(seed) + hdim = 64 + hdimv = 512 + softmax_scale = 1.0 / math.sqrt(hdim + hdimv) + + nheads_kv = 1 + qhead_per_kvhead = nheads + + compile_key = ( + is_causal, + gather_kv, + topk_length if gather_kv else None, + pack_gqa, + qhead_per_kvhead, + nheads_kv, + varlen_q, + varlen_k, + disable_bitmask, + ) + if compile_key not in compile_cache: + total_q_dummy = batch * seqlen_q + total_k_dummy = batch * seqlen_k + + if varlen_q: + Q = torch.randn(total_q_dummy, nheads, hdim, dtype=torch.bfloat16, device="cuda") + Qv = torch.randn(total_q_dummy, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + O = torch.empty(total_q_dummy, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + lse = torch.empty(nheads, total_q_dummy, dtype=torch.float32, device="cuda") + index_topk = ( + torch.rand(total_q_dummy, topk_length, device="cuda") + .argsort(dim=-1).to(torch.int32) + ) + cu_seqlens_q_dummy = torch.arange( + 0, (batch + 1) * seqlen_q, seqlen_q, dtype=torch.int32, device="cuda" + ) + else: + Q = torch.randn(batch, seqlen_q, nheads, hdim, dtype=torch.bfloat16, device="cuda") + Qv = torch.randn(batch, seqlen_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + O = torch.empty(batch, seqlen_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + lse = torch.empty(batch, nheads, seqlen_q, dtype=torch.float32, device="cuda") + index_topk = ( + torch.rand(batch, seqlen_q, topk_length, device="cuda") + .argsort(dim=-1).to(torch.int32) + ) + + if varlen_k: + K = torch.randn(total_k_dummy, nheads_kv, hdim, dtype=torch.bfloat16, device="cuda") + V = torch.randn(total_k_dummy, nheads_kv, hdimv, dtype=torch.bfloat16, device="cuda") + cu_seqlens_k_dummy = torch.arange( + 0, (batch + 1) * seqlen_k, seqlen_k, dtype=torch.int32, device="cuda" + ) + else: + K = torch.randn(batch, seqlen_k, nheads_kv, hdim, dtype=torch.bfloat16, device="cuda") + V = torch.randn(batch, seqlen_k, nheads_kv, hdimv, dtype=torch.bfloat16, device="cuda") + + mQ = from_dlpack(Q, assumed_align=16).mark_layout_dynamic(leading_dim=Q.ndim - 1) + mQv = from_dlpack(Qv, assumed_align=16).mark_layout_dynamic(leading_dim=Qv.ndim - 1) + mK = from_dlpack(K, assumed_align=16).mark_layout_dynamic(leading_dim=K.ndim - 1) + mV = from_dlpack(V, assumed_align=16).mark_layout_dynamic(leading_dim=V.ndim - 1) + mO = from_dlpack(O, assumed_align=16).mark_layout_dynamic(leading_dim=O.ndim - 1) + mLSE = from_dlpack(lse, assumed_align=4 ).mark_layout_dynamic(leading_dim=lse.ndim - 1) + if gather_kv: + mIndexTopk = from_dlpack(index_topk, assumed_align=16).mark_layout_dynamic(leading_dim=index_topk.ndim - 1) + else: + mIndexTopk = None + + compile_kwargs = dict(mIndexTopk=mIndexTopk) + if varlen_q: + compile_kwargs["mCuSeqlensQ"] = from_dlpack(cu_seqlens_q_dummy, assumed_align=4) + if varlen_k: + compile_kwargs["mCuSeqlensK"] = from_dlpack(cu_seqlens_k_dummy, assumed_align=4) + + kernel = cute.compile( + FlashAttentionMLAForwardSm100( + is_causal=is_causal, + use_cpasync_load_KV=gather_kv, + topk_length=topk_length if gather_kv else 2048, + is_topk_gather=gather_kv, + pack_gqa=pack_gqa, + qhead_per_kvhead=qhead_per_kvhead, + nheads_kv=nheads_kv, + is_varlen_q=varlen_q, + disable_bitmask=disable_bitmask, + ), + mQ, mQv, mK, mV, mO, mLSE, softmax_scale, + **compile_kwargs, + options="--keep-ptx --keep-cubin --generate-line-info" + ) + dump_kernel_attributes(kernel) + compile_cache[compile_key] = kernel + + # ================================================================ + # ---- Generate variable seqlens for this run ---- + if varlen_q: + torch.manual_seed(seed + 1000) + # When causal without varlen_k, every per-batch seqlen_q must not exceed seqlen_k. + max_seqlen_q = (seqlen_k if (is_causal and not varlen_k) else seqlen_q) + seqlens_q = torch.randint(1, max_seqlen_q + 1, (batch,), dtype=torch.int32) + cu_seqlens_q = torch.zeros(batch + 1, dtype=torch.int32, device="cuda") + cu_seqlens_q[1:] = seqlens_q.cumsum(0).to(torch.int32).cuda() + total_q = cu_seqlens_q[-1].item() + else: + seqlens_q = torch.full((batch,), seqlen_q, dtype=torch.int32) + total_q = None # unused + + if varlen_k: + torch.manual_seed(seed + 2000) + # Each batch item must have at least topk_length keys so topk gather is valid. + min_seqlen_k = topk_length if gather_kv else 1 + seqlens_k = torch.randint(min_seqlen_k, seqlen_k + 1, (batch,), dtype=torch.int32) + # When causal, every batch item needs seqlens_k[b] >= seqlens_q[b]. + if is_causal: + seqlens_k = torch.maximum(seqlens_k, seqlens_q) + cu_seqlens_k = torch.zeros(batch + 1, dtype=torch.int32, device="cuda") + cu_seqlens_k[1:] = seqlens_k.cumsum(0).to(torch.int32).cuda() + total_k = cu_seqlens_k[-1].item() + else: + seqlens_k = torch.full((batch,), seqlen_k, dtype=torch.int32) + total_k = None # unused + + torch.manual_seed(seed) # restore main seed before drawing actual tensors + + # ---- Allocate Q / Qv / O / lse ---- + if varlen_q: + Q = torch.randn(total_q, nheads, hdim, dtype=torch.bfloat16, device="cuda") + Qv = torch.randn(total_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + O = torch.empty(total_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + lse = torch.empty(nheads, total_q, dtype=torch.float32, device="cuda") + else: + Q = torch.randn(batch, seqlen_q, nheads, hdim, dtype=torch.bfloat16, device="cuda") + Qv = torch.randn(batch, seqlen_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + O = torch.empty(batch, seqlen_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + lse = torch.empty(batch, nheads, seqlen_q, dtype=torch.float32, device="cuda") + + # ---- Allocate K / V ---- + if varlen_k: + K = torch.randn(total_k, nheads_kv, hdim, dtype=torch.bfloat16, device="cuda") + V = torch.randn(total_k, nheads_kv, hdimv, dtype=torch.bfloat16, device="cuda") + else: + K = torch.randn(batch, seqlen_k, nheads_kv, hdim, dtype=torch.bfloat16, device="cuda") + V = torch.randn(batch, seqlen_k, nheads_kv, hdimv, dtype=torch.bfloat16, device="cuda") + + # ---- Generate index_topk with per-batch valid ranges when varlen_k ---- + # index_topk shape: (total_q, topk_length) if varlen_q else (batch, seqlen_q, topk_length) + if gather_kv: + topk_parts = [] + for b in range(batch): + sl_q_b = seqlens_q[b].item() + sl_k_b = seqlens_k[b].item() + # Draw topk_length unique indices from [0, sl_k_b) for each query in this batch item. + topk_b = ( + torch.rand(sl_q_b, sl_k_b, device="cuda") + .argsort(dim=-1)[..., :topk_length] + .to(torch.int32) + ) # (sl_q_b, topk_length), all < sl_k_b + topk_parts.append(topk_b) + + if varlen_q: + index_topk = torch.cat(topk_parts, dim=0) # (total_q, topk_length) + else: + index_topk = torch.stack(topk_parts, dim=0) # (batch, seqlen_q, topk_length) + else: + index_topk = None + + # ---- Reference computation (per-batch loop covers all four varlen combos) ---- + O_ref_list, O_pt_list, lse_ref_list, lse_pt_list = [], [], [], [] + for b in range(batch): + qs = cu_seqlens_q[b].item() if varlen_q else b * seqlen_q + qe = cu_seqlens_q[b+1].item() if varlen_q else (b + 1) * seqlen_q + ks = cu_seqlens_k[b].item() if varlen_k else b * seqlen_k + ke = cu_seqlens_k[b+1].item() if varlen_k else (b + 1) * seqlen_k + + Q_b = Q [qs:qe].unsqueeze(0) if varlen_q else Q [b:b+1] # (1, sl_q, nheads, hdim) + Qv_b = Qv[qs:qe].unsqueeze(0) if varlen_q else Qv[b:b+1] # (1, sl_q, nheads, hdimv) + K_b = K [ks:ke].unsqueeze(0) if varlen_k else K [b:b+1] # (1, sl_k, nheads_kv, hdim) + V_b = V [ks:ke].unsqueeze(0) if varlen_k else V [b:b+1] # (1, sl_k, nheads_kv, hdimv) + if gather_kv: + topk_b = index_topk[qs:qe].unsqueeze(0) if varlen_q else index_topk[b:b+1] + else: + topk_b = None + + O_b, _, lse_b = attention_ref(Q_b, K_b, V_b, qv=Qv_b, causal=is_causal, + return_lse=True, gather_kv_indices=topk_b) + O_pt_b, _, lse_pt_b = attention_ref(Q_b, K_b, V_b, qv=Qv_b, causal=is_causal, + upcast=False, reorder_ops=True, + return_lse=True, gather_kv_indices=topk_b) + O_ref_list.append(O_b.squeeze(0)) + O_pt_list.append(O_pt_b.squeeze(0)) + lse_ref_list.append(lse_b.squeeze(0)) + lse_pt_list.append(lse_pt_b.squeeze(0)) + + cat_dim_o = 0 if (varlen_q) else 0 # always 0: leading token/batch dim + cat_dim_lse = -1 if (varlen_q) else -1 # always last: token dim + + if varlen_q: + O_ref = torch.cat(O_ref_list, dim=0) # (total_q, nheads, hdimv) + O_pt = torch.cat(O_pt_list, dim=0) + lse_ref = torch.cat(lse_ref_list, dim=-1) # (nheads, total_q) + lse_pt = torch.cat(lse_pt_list, dim=-1) + else: + O_ref = torch.stack(O_ref_list, dim=0) # (batch, seqlen_q, nheads, hdimv) + O_pt = torch.stack(O_pt_list, dim=0) + lse_ref = torch.stack(lse_ref_list, dim=0) # (batch, nheads, seqlen_q) + lse_pt = torch.stack(lse_pt_list, dim=0) + + rtol = 2 + atol = 2 * (O_ref + 0.3 - 0.3 - O_ref).abs().max().item() + + # ---- CuTe tensor wrappers ---- + mQ = from_dlpack(Q, assumed_align=16).mark_layout_dynamic(leading_dim=Q.ndim - 1) + mQv = from_dlpack(Qv, assumed_align=16).mark_layout_dynamic(leading_dim=Qv.ndim - 1) + mK = from_dlpack(K, assumed_align=16).mark_layout_dynamic(leading_dim=K.ndim - 1) + mV = from_dlpack(V, assumed_align=16).mark_layout_dynamic(leading_dim=V.ndim - 1) + mO = from_dlpack(O, assumed_align=16).mark_layout_dynamic(leading_dim=O.ndim - 1) + mLSE = from_dlpack(lse, assumed_align=4 ).mark_layout_dynamic(leading_dim=lse.ndim - 1) + if index_topk is not None: + mIndexTopk = from_dlpack(index_topk, assumed_align=16).mark_layout_dynamic(leading_dim=index_topk.ndim - 1) + else: + mIndexTopk = None + + run_kwargs = dict(mIndexTopk=mIndexTopk) + if varlen_q: + run_kwargs["mCuSeqlensQ"] = from_dlpack(cu_seqlens_q, assumed_align=4) + if varlen_k: + run_kwargs["mCuSeqlensK"] = from_dlpack(cu_seqlens_k, assumed_align=4) + + # ---- Run kernel ---- + compile_cache[compile_key]( + mQ, mQv, mK, mV, mO, mLSE, softmax_scale, + **run_kwargs, + ) + + print(f"Pytorch max O diff: {(O_pt - O_ref).abs().max().item()}") + print(f"Pytorch mean O diff: {(O_pt - O_ref).abs().mean().item()}") + print(f"Max abs diff O, O_ref: {(O - O_ref).abs().max().item()}") + print(f"Mean abs diff O, O_ref: {(O - O_ref).abs().mean().item()}") + + # print(f"Pytorch LSE max diff: {(lse_pt - lse_ref).abs().max().item()}") + # print(f"Pytorch LSE mean diff: {(lse_pt - lse_ref).abs().mean().item()}") + # print(f"Max abs diff LSE: {(lse - lse_ref).abs().max().item()}") + # print(f"Mean abs diff LSE: {(lse - lse_ref).abs().mean().item()}") + + if validate: + assert (O - O_ref).abs().max().item() <= rtol * (O_pt - O_ref).abs().max().item() + atol + varlen_tag = "" + if varlen_q: varlen_tag += f", total_q:{total_q}" + if varlen_k: varlen_tag += f", total_k:{total_k}" + print( + f"batch:{batch:3d}, nheads:{nheads:3d}, seqlen_q:{seqlen_q:5d}, seqlen_k:{seqlen_k:5d}" + f"{varlen_tag}, iter:{iter:2d} PASSED" + ) + else: + print(mO) + print( + f"batch:{batch:3d}, nheads:{nheads:3d}, seqlen_q:{seqlen_q:5d}, seqlen_k:{seqlen_k:5d}" + f", iter:{iter:2d} RUN (NOT TESTING CORRECTNESS)" + ) + + return None + +def timeit(fn, *args, **kwargs): + # Synchronize before timing + torch.cuda.synchronize() + + # Warmup + for _ in range(10): + fn(*args, **kwargs) + + # Benchmark using PyTorch's Timer + t = benchmark.Timer( + stmt="fn(*args, **kwargs)", globals={"fn": fn, "args": args, "kwargs": kwargs} + ) + + # Time it multiple runs + measurement = t.timeit(20) # 20 repeats + avg_time = measurement.mean # Average time in seconds + + time.sleep(1) + + return avg_time + + +def benchmark_mla_kernel( + batch=1, seqlen_q=2048, seqlen_k=2048, topk_length=2048, nheads=128, hdim=64, hdimv=512, compile_cache=dict(), + gather_kv=True, is_causal=False, disable_bitmask=False, +): + assert hdim == 64, "hdim must be 64" + assert hdimv == 512, "hdimv must be 512" + + qhead_per_kvhead=nheads + nheads_kv=1 + pack_gqa=True + softmax_scale = 1.0 / math.sqrt(hdim + hdimv) + + compile_key = ( + is_causal, + gather_kv, + topk_length if gather_kv else None, + pack_gqa, + qhead_per_kvhead, + nheads_kv, + disable_bitmask, + ) + if compile_key not in compile_cache: + Q = torch.randn(batch, seqlen_q, nheads, hdim, dtype=torch.bfloat16, device="cuda") + Qv = torch.randn(batch, seqlen_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + K = torch.randn(batch, seqlen_k, nheads_kv, hdim, dtype=torch.bfloat16, device="cuda") + V = torch.randn(batch, seqlen_k, nheads_kv, hdimv, dtype=torch.bfloat16, device="cuda") + O = torch.empty(batch, seqlen_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + index_topk = torch.rand(batch, seqlen_q, topk_length, device="cuda").argsort(dim=-1).to(torch.int32) + + mQ = from_dlpack(Q, assumed_align=16).mark_layout_dynamic(leading_dim=Q.ndim - 1) + mQv = from_dlpack(Qv, assumed_align=16).mark_layout_dynamic(leading_dim=Qv.ndim - 1) + mK = from_dlpack(K, assumed_align=16).mark_layout_dynamic(leading_dim=K.ndim - 1) + mV = from_dlpack(V, assumed_align=16).mark_layout_dynamic(leading_dim=V.ndim - 1) + mO = from_dlpack(O, assumed_align=16).mark_layout_dynamic(leading_dim=O.ndim - 1) + if gather_kv: + mIndexTopk = from_dlpack(index_topk, assumed_align=16).mark_layout_dynamic(leading_dim=index_topk.ndim - 1) + else: + mIndexTopk = None + + mLSE = None + + kernel = cute.compile( + FlashAttentionMLAForwardSm100( + is_causal=is_causal, + use_cpasync_load_KV = gather_kv, + topk_length = topk_length if gather_kv else 2048, + is_topk_gather = gather_kv, + pack_gqa=pack_gqa, + qhead_per_kvhead=qhead_per_kvhead, + nheads_kv=nheads_kv, + disable_bitmask=disable_bitmask, + ), + mQ, mQv, mK, mV, mO, mLSE, softmax_scale, + mIndexTopk=mIndexTopk, + ) + compile_cache[compile_key] = kernel + + Q = torch.randn(batch, seqlen_q, nheads, hdim, dtype=torch.bfloat16, device="cuda") + Qv = torch.randn(batch, seqlen_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + K = torch.randn(batch, seqlen_k, nheads_kv, hdim, dtype=torch.bfloat16, device="cuda") + V = torch.randn(batch, seqlen_k, nheads_kv, hdimv, dtype=torch.bfloat16, device="cuda") + O = torch.empty(batch, seqlen_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + + index_topk = torch.rand(batch, seqlen_q, topk_length, device="cuda").argsort(dim=-1).to(torch.int32) + + mQ = from_dlpack(Q, assumed_align=16).mark_layout_dynamic(leading_dim=Q.ndim - 1) + mQv = from_dlpack(Qv, assumed_align=16).mark_layout_dynamic(leading_dim=Qv.ndim - 1) + mK = from_dlpack(K, assumed_align=16).mark_layout_dynamic(leading_dim=K.ndim - 1) + mV = from_dlpack(V, assumed_align=16).mark_layout_dynamic(leading_dim=V.ndim - 1) + mO = from_dlpack(O, assumed_align=16).mark_layout_dynamic(leading_dim=O.ndim - 1) + if gather_kv: + mIndexTopk = from_dlpack(index_topk, assumed_align=16).mark_layout_dynamic(leading_dim=index_topk.ndim - 1) + else: + mIndexTopk = None + mLSE = None + + exec_time_in_s = timeit( + compile_cache[compile_key], + mQ, mQv, mK, mV, mO, mLSE, softmax_scale, + mIndexTopk=mIndexTopk, + ) + + seqlen_k_eff = topk_length if gather_kv else seqlen_k + + FLOPs = 2 * batch * nheads * seqlen_q * seqlen_k_eff * (hdim + 2 * hdimv) + if is_causal and not gather_kv: + FLOPs /= 2 + + TFLOPS = FLOPs / exec_time_in_s / 1e12 + + q_bytes = 2 * batch * nheads * seqlen_q * hdim + qv_bytes = 2 * batch * nheads * seqlen_q * hdimv + k_bytes = 2 * batch * nheads_kv * seqlen_k_eff * hdim + v_bytes = 2 * batch * nheads_kv * seqlen_k_eff * hdimv + o_bytes = 2 * batch * nheads * seqlen_q * hdimv + total_bytes = q_bytes + qv_bytes + k_bytes + v_bytes + o_bytes + TBs = total_bytes / exec_time_in_s / 1e12 + + print( + f"batch: {batch}, seqlen_q: {seqlen_q}, seqlen_k: {seqlen_k}, nheads: {nheads}, -> {exec_time_in_s * 1e3:.2f} ms, {TFLOPS:.2f} TFLOPS, {TBs:.2f} TBs" + ) + + +if __name__ == "__main__": + run_test = True + run_benchmark = True + gather_kv = False + is_causal = True + pack_gqa = True + topk_length = 2048 + varlen_q = False + varlen_k = False + disable_bitmask = True + validate = True + + if run_test: + if not gather_kv: + seqlen_q_test_values = range(1, 4002, 400) + seqlen_k_test_values = range(1, 4002, 400) + else: + seqlen_q_test_values = range(1, 1001, 200) + seqlen_k_test_values = range(topk_length, 9001, 2000) + seqlen_q_test_values = [1] + seqlen_k_test_values = [4096] + nheads_test_values = [128] + batch_test_values = [4] + test_configs = [ + (batch, nheads, seqlen_q, seqlen_k,) + for batch in batch_test_values + for nheads in nheads_test_values + for seqlen_q in seqlen_q_test_values + for seqlen_k in seqlen_k_test_values + ] + iters_per_config = 1 + compile_cache = dict() + print("=" * 40) + print("Testing MLA Kernel") + print("=" * 40) + for config in test_configs: + batch, nheads, seqlen_q, seqlen_k = config + # if is_causal and seqlen_k < seqlen_q: + # continue + for iter in range(iters_per_config): + test_mla_kernel( + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + topk_length=topk_length, + nheads=nheads, + batch=batch, + iter=iter, + compile_cache=compile_cache, + validate=validate, + seed=0, + gather_kv=gather_kv, + pack_gqa=pack_gqa, + is_causal=is_causal, + varlen_q=varlen_q, + varlen_k=varlen_k, + disable_bitmask=disable_bitmask, + ) + if run_benchmark: + if gather_kv: + seqlen_q_benchmark_values = [1] + seqlen_k_benchmark_values = [8192*2] + nheads_benchmark_values = [128] + batch_benchmark_values = [512] + else: + seqlen_q_benchmark_values = [1] + seqlen_k_benchmark_values = [8192*2] + nheads_benchmark_values = [128] + batch_benchmark_values = [512] + seqlen_q_benchmark_values = [4096] + seqlen_k_benchmark_values = [4096] + nheads_benchmark_values = [16] + batch_benchmark_values = [8] + benchmark_configs = [ (batch, nheads, seqlen_q, seqlen_k,) + for batch in batch_benchmark_values + for nheads in nheads_benchmark_values + for seqlen_q in seqlen_q_benchmark_values + for seqlen_k in seqlen_k_benchmark_values + ] + compile_cache=dict() + print("="*40) + print("Benchmarking MLA Kernel") + print("="*40) + for config in benchmark_configs: + batch, nheads, seqlen_q, seqlen_k = config + benchmark_mla_kernel( + batch = batch, + seqlen_q = seqlen_q, + seqlen_k = seqlen_k, + topk_length = topk_length, + nheads = nheads, + gather_kv=gather_kv, + is_causal=is_causal, + disable_bitmask=disable_bitmask, + compile_cache=compile_cache + ) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 5a4085ed4e1..b01376a4214 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -43,6 +43,7 @@ from flash_attn.cute.flash_bwd_sm120 import FlashAttentionBackwardSm120 from flash_attn.cute.flash_bwd_postprocess import FlashAttentionBackwardPostprocess from flash_attn.cute.flash_fwd_combine import FlashAttentionForwardCombine +from flash_attn.cute.flash_fwd_mla_sm100 import FlashAttentionMLAForwardSm100 from flash_attn.cute.block_sparsity import ( BlockSparseTensorsTorch, @@ -84,6 +85,7 @@ def _get_device_arch(): def _validate_head_dims(head_dim: int, head_dim_v: int, compute_capability: int, alignment: int) -> None: """Validate head dimension constraints based on compute capability.""" is_deepseek_shape = head_dim == 192 and head_dim_v == 128 + is_deepseek_mla_absorbed_shape = head_dim == 64 and head_dim_v == 512 is_standard_range = 8 <= head_dim <= 128 and 8 <= head_dim_v <= 128 is_sm90_range = 8 <= head_dim <= 256 and 8 <= head_dim_v <= 256 @@ -93,7 +95,7 @@ def _validate_head_dims(head_dim: int, head_dim_v: int, compute_capability: int, f"head_dim and head_dim_v must be between 8 and 256 and divisible by {alignment}." ) elif compute_capability in [10, 11]: - assert (is_standard_range or is_deepseek_shape) and head_dim % alignment == 0 and head_dim_v % alignment == 0, ( + assert (is_standard_range or is_deepseek_shape or is_deepseek_mla_absorbed_shape) and head_dim % alignment == 0 and head_dim_v % alignment == 0, ( f"(head_dim, head_dim_v)=({head_dim}, {head_dim_v}) is not supported on SM100/SM110. " f"head_dim and head_dim_v must be between 8 and 128 and divisible by {alignment}, or (192, 128) for DeepSeek." ) @@ -280,12 +282,14 @@ def _flash_attn_fwd( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + qv: Optional[torch.Tensor] = None, cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_k: Optional[torch.Tensor] = None, seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, max_seqlen_q: Optional[int] = None, max_seqlen_k: Optional[int] = None, + min_seqlen_k: Optional[int] = None, page_table: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, @@ -307,6 +311,7 @@ def _flash_attn_fwd( out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, aux_tensors: Optional[list[torch.Tensor]] = None, + gather_kv_indices: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """Forward pass for FlashAttention. @@ -403,7 +408,7 @@ def _flash_attn_fwd( if arch // 10 not in [8, 12]: _validate_head_dims(head_dim, head_dim_v, arch // 10, alignment) if softmax_scale is None: - softmax_scale = 1.0 / math.sqrt(head_dim) + softmax_scale = 1.0 / math.sqrt(head_dim) if qv is None else 1.0 / math.sqrt(head_dim + head_dim_v) if softcap == 0.0: softcap = None qhead_per_kvhead = num_head // num_head_kv @@ -474,11 +479,16 @@ def _flash_attn_fwd( # TODO: fix GQA + SplitKV + non-varlen if pack_gqa and num_splits != 1 and cu_seqlens_q is None: pack_gqa = False + + if pack_gqa and qv is not None and 128 % qhead_per_kvhead != 0: + pack_gqa = False if max_seqlen_q is None: max_seqlen_q = seqlen_q if cu_seqlens_q is None else total_q if max_seqlen_k is None: max_seqlen_k = seqlen_k + if cu_seqlens_k is None and seqused_k is None: + min_seqlen_k = seqlen_k seqlen_q_packgqa = max_seqlen_q * qhead_per_kvhead if arch // 10 == 10: q_stage = 2 if seqlen_q_packgqa > tile_m else 1 @@ -497,7 +507,7 @@ def _flash_attn_fwd( # SplitKV uses float32 partial output, which doubles the O buffer size # in shared memory, causing OOM for diff-headdim (192, 128) if arch // 10 in [10, 11] and head_dim != head_dim_v and num_splits > 1: - if num_n_blocks >= 64: + if num_n_blocks >= 64 and head_dim_v != 512: tile_n = 64 num_n_blocks = (seqlen_k_loaded + tile_n - 1) // tile_n num_splits = num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, 128) @@ -587,6 +597,40 @@ def _flash_attn_fwd( else: aux_tensor_metadata = None + if qv is not None: + assert arch // 10 in [10, 11], "only support Blackwell arch with qv" + assert qv.shape[:-1] == q.shape[:-1] + assert qv.shape[-1] == head_dim_v + assert head_dim == 64 and head_dim_v == 512, "only support MLA weight absorbed shape with qv" + assert not local, "local not yet supported with qv" + assert page_table is None, "page table not yet supported with qv" + + assert not is_split_kv, "split kv not supported with qv" + assert learnable_sink is None + assert softcap is None + assert score_mod is None + assert mask_mod is None + + qv = maybe_contiguous(qv) + + gather_kv_length = 2048 + sparse_kv = gather_kv_indices is not None + disable_sparse_kv_bitmask = False + if sparse_kv: + assert gather_kv_indices.shape[:-1] == q.shape[:-2] + gather_kv_length = gather_kv_indices.shape[-1] + assert gather_kv_length % 256 == 0 + if min_seqlen_k is None or causal: + disable_sparse_kv_bitmask = False + else: + # seqlen_k_boundary = min_seqlen_k - max_seqlen_q + 1 if causal else min_seqlen_k + seqlen_k_boundary = min_seqlen_k + disable_sparse_kv_bitmask = seqlen_k_boundary >= gather_kv_length + else: + gather_kv_length = None + sparse_kv = None + disable_sparse_kv_bitmask = None + compile_key = ( dtype, head_dim, @@ -620,6 +664,10 @@ def _flash_attn_fwd( mma_pv_is_rs, intra_wg_overlap, requested_use_clc_scheduler, + qv is not None, + gather_kv_length, + sparse_kv, + disable_sparse_kv_bitmask, fa_logging.get_fa_log_level(), ) if compile_key not in _flash_attn_fwd.compile_cache: @@ -659,6 +707,9 @@ def _flash_attn_fwd( if aux_tensors is not None: cute_aux_tensors = [to_cute_aux_tensor(buf) for buf in aux_tensors] + qv_tensor = to_cute_tensor(qv) if qv is not None else None + gather_kv_indices_tensor = to_cute_tensor(gather_kv_indices) if gather_kv_indices is not None else None + if arch // 10 == 8: assert page_table is None, "paged KV not supported on SM 8.0" assert not is_split_kv, "SplitKV not supported on SM 8.0" @@ -704,31 +755,44 @@ def _flash_attn_fwd( paged_kv_non_tma=page_size not in [None, tile_n], ) elif arch // 10 in [10, 11]: - fa_fwd = FlashAttentionForwardSm100( - head_dim, - head_dim_v, - qhead_per_kvhead=qhead_per_kvhead, - is_causal=causal, - is_local=local, - is_split_kv=is_split_kv, - pack_gqa=pack_gqa, - m_block_size=tile_m, - n_block_size=tile_n, - q_stage=q_stage, - is_persistent=not causal - and not local - and cu_seqlens_q is None - and seqused_q is None - and not is_split_kv, - score_mod=score_mod, - mask_mod=mask_mod, - has_aux_tensors=aux_tensors is not None, - paged_kv_non_tma=page_size not in [None, tile_n], - is_varlen_q=cu_seqlens_q is not None or seqused_q is not None, - q_subtile_factor=q_subtile_factor, - use_2cta_instrs=use_2cta_instrs, - use_clc_scheduler=requested_use_clc_scheduler, - ) + if qv is not None: + fa_fwd = FlashAttentionMLAForwardSm100( + is_causal=causal, + use_cpasync_load_KV=sparse_kv, + topk_length=gather_kv_length, + is_topk_gather=sparse_kv, + pack_gqa=pack_gqa, + qhead_per_kvhead=qhead_per_kvhead, + nheads_kv=num_head_kv, + is_varlen_q=cu_seqlens_q is not None or seqused_q is not None, + disable_bitmask=disable_sparse_kv_bitmask, + ) + else: + fa_fwd = FlashAttentionForwardSm100( + head_dim, + head_dim_v, + qhead_per_kvhead=qhead_per_kvhead, + is_causal=causal, + is_local=local, + is_split_kv=is_split_kv, + pack_gqa=pack_gqa, + m_block_size=tile_m, + n_block_size=tile_n, + q_stage=q_stage, + is_persistent=not causal + and not local + and cu_seqlens_q is None + and seqused_q is None + and not is_split_kv, + score_mod=score_mod, + mask_mod=mask_mod, + has_aux_tensors=aux_tensors is not None, + paged_kv_non_tma=page_size not in [None, tile_n], + is_varlen_q=cu_seqlens_q is not None or seqused_q is not None, + q_subtile_factor=q_subtile_factor, + use_2cta_instrs=use_2cta_instrs, + use_clc_scheduler=requested_use_clc_scheduler, + ) elif arch // 10 == 12: # SM120 (Blackwell GeForce / DGX Spark): uses SM80 MMA with SM120 SMEM capacity assert not use_block_sparsity, "Block sparsity not supported on SM 12.0" @@ -756,51 +820,92 @@ def _flash_attn_fwd( f"Unsupported compute capability: {arch}. Supported: 8.x, 9.x, 10.x, 11.x, 12.x" ) # TODO: check @can_implement - _flash_attn_fwd.compile_cache[compile_key] = cute.compile( - fa_fwd, - q_tensor, - k_tensor, - v_tensor, - o_tensor, - lse_tensor, - softmax_scale, - cu_seqlens_q_tensor, - cu_seqlens_k_tensor, - seqused_q_tensor, - seqused_k_tensor, - page_table_tensor, - window_size_left, - window_size_right, - learnable_sink_tensor, - sparse_tensors, - cute_aux_tensors, - current_stream, - options="--enable-tvm-ffi", - ) + if qv is not None: + _flash_attn_fwd.compile_cache[compile_key] = cute.compile( + fa_fwd, + q_tensor, + qv_tensor, + k_tensor, + v_tensor, + o_tensor, + lse_tensor, + softmax_scale, + cu_seqlens_q_tensor, + cu_seqlens_k_tensor, + seqused_q_tensor, + seqused_k_tensor, + gather_kv_indices_tensor, + page_table_tensor, + window_size_left, + window_size_right, + current_stream, + options="--enable-tvm-ffi", + ) + else: + _flash_attn_fwd.compile_cache[compile_key] = cute.compile( + fa_fwd, + q_tensor, + k_tensor, + v_tensor, + o_tensor, + lse_tensor, + softmax_scale, + cu_seqlens_q_tensor, + cu_seqlens_k_tensor, + seqused_q_tensor, + seqused_k_tensor, + page_table_tensor, + window_size_left, + window_size_right, + learnable_sink_tensor, + sparse_tensors, + cute_aux_tensors, + current_stream, + options="--enable-tvm-ffi", + ) # In "fake mode", we will take torch fake tensors as input and the expected behaviors are: # - Use those fake metadata to populate compilation cache # - Return "fake" output tensors, which could be needed in follow-up fake operations # Thus, we skip the actual kernel invocation here. if not is_fake_mode(): - _flash_attn_fwd.compile_cache[compile_key]( - q.detach(), - k.detach(), - v.detach(), - out.detach() if not is_split_kv else out_partial, - lse_partial if is_split_kv else lse, - softmax_scale, - cu_seqlens_q, - cu_seqlens_k, - seqused_q, - seqused_k, - page_table, - window_size_left, - window_size_right, - learnable_sink, - normalized_block_sparse_tensors[:4] if normalized_block_sparse_tensors is not None else None, - aux_tensors, - ) + if qv is not None: + _flash_attn_fwd.compile_cache[compile_key]( + q.detach(), + qv.detach(), + k.detach(), + v.detach(), + out.detach(), + lse, + softmax_scale, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + gather_kv_indices, + page_table, + window_size_left, + window_size_right, + ) + else: + _flash_attn_fwd.compile_cache[compile_key]( + q.detach(), + k.detach(), + v.detach(), + out.detach() if not is_split_kv else out_partial, + lse_partial if is_split_kv else lse, + softmax_scale, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + page_table, + window_size_left, + window_size_right, + learnable_sink, + normalized_block_sparse_tensors[:4] if normalized_block_sparse_tensors is not None else None, + aux_tensors, + ) if is_split_kv: _flash_attn_fwd_combine( out_partial, @@ -1580,6 +1685,8 @@ def forward( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + qv: Optional[torch.Tensor] = None, + gather_kv_indices: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, window_size: Tuple[Optional[int], Optional[int]] = (None, None), @@ -1610,6 +1717,7 @@ def forward( q, k, v, + qv=qv, softmax_scale=softmax_scale, causal=causal, window_size_left=window_size[0], @@ -1621,6 +1729,7 @@ def forward( mask_mod=mask_mod, block_sparse_tensors=block_sparse_tensors, return_lse=return_lse, + gather_kv_indices=gather_kv_indices, ) ctx.save_for_backward(q, k, v, out, lse) ctx.softmax_scale = softmax_scale @@ -1654,7 +1763,7 @@ def backward(ctx, dout, dlse): deterministic=ctx.deterministic, dlse=dlse, ) - return dq, dk, dv, *((None,) * 20) # Extra Nones is fine + return dq, dk, dv, *((None,) * 30) # Extra Nones is fine class FlashAttnVarlenFunc(torch.autograd.Function): @@ -1664,12 +1773,15 @@ def forward( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - cu_seqlens_q: Optional[torch.Tensor], - cu_seqlens_k: Optional[torch.Tensor], + qv: Optional[torch.Tensor] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, max_seqlen_q: Optional[int] = None, max_seqlen_k: Optional[int] = None, + min_seqlen_k: Optional[int] = None, + gather_kv_indices: Optional[torch.Tensor] = None, page_table: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, @@ -1687,12 +1799,14 @@ def forward( q, k, v, - cu_seqlens_q, - cu_seqlens_k, - seqused_q, - seqused_k, + qv=qv, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + seqused_q=seqused_q, + seqused_k=seqused_k, max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, + min_seqlen_k=min_seqlen_k, page_table=page_table, softmax_scale=softmax_scale, causal=causal, @@ -1705,6 +1819,7 @@ def forward( score_mod=score_mod, aux_tensors=aux_tensors, return_lse=return_lse, + gather_kv_indices=gather_kv_indices, ) ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) ctx.softmax_scale = softmax_scale @@ -1747,13 +1862,15 @@ def backward(ctx, dout, dlse): dlse=dlse, ) - return dq, dk, dv, *((None,) * 20) + return dq, dk, dv, *((None,) * 30) def flash_attn_func( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + qv: Optional[torch.Tensor] = None, + gather_kv_indices: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, window_size: Tuple[Optional[int], Optional[int]] = (None, None), @@ -1774,6 +1891,8 @@ def flash_attn_func( q, k, v, + qv, + gather_kv_indices, softmax_scale, causal, window_size, @@ -1796,12 +1915,15 @@ def flash_attn_varlen_func( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + qv: Optional[torch.Tensor] = None, cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_k: Optional[torch.Tensor] = None, max_seqlen_q: Optional[int] = None, max_seqlen_k: Optional[int] = None, + min_seqlen_k: Optional[int] = None, seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, + gather_kv_indices: Optional[torch.Tensor] = None, page_table: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, @@ -1815,16 +1937,33 @@ def flash_attn_varlen_func( aux_tensors: Optional[list] = None, return_lse: bool = False, ): + """ + Explanation of some optional arguments: + + qv: we write the MLA weight absorbed formula as + O = softmax(scale * (Q @ K.T + Qv @ V.T)) @ V + where Q = q_pe, Qv = q_nope, K = pe_cache, V = kv_cache. + + gather_kv_indices: a tensor of shape (batch, seqlen_q, gather_kv_length) or + (total_q, gather_kv_length) if there is cu_seqlens_q. + Currently, only used for topk sparsity with MLA absorption kernel. + + min_seqlen_k: for varlen, specifies the minimum kv sequence length for any batch. + Used with gather_kv_indices to determine if we need oob masking. + """ return FlashAttnVarlenFunc.apply( q, k, v, + qv, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, max_seqlen_q, max_seqlen_k, + min_seqlen_k, + gather_kv_indices, page_table, softmax_scale, causal, diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index 6b5ca16c6f5..99e7008ab82 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -5,7 +5,7 @@ import cutlass import cutlass.cute as cute -from cutlass import Float32, Int32, Uint32, const_expr +from cutlass import Float32, Int32, Uint32, const_expr, Boolean from quack import layout_utils import flash_attn.cute.utils as utils @@ -384,6 +384,8 @@ def apply_mask_sm100( fastdiv_mods=(None, None), head_divmod=None, check_q_boundary: bool = False, + r2p: bool = True, + rBitmask: Optional[cute.Tensor] = None, ) -> None: assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" acc_shape = (self.tile_m, self.tile_n) @@ -397,8 +399,18 @@ def apply_mask_sm100( if n_block < 0: n_block = 0 seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - r2p = True - if const_expr(not mask_causal and not mask_local and mask_mod is None): + + if const_expr(rBitmask is not None): + ncol_packed = const_expr(cute.size(rBitmask.shape[0])) + for i in cutlass.range_constexpr(ncol_packed): + col_start = 32 * i # mask is bit-packed into uint32 + curr_mask_val = rBitmask[i] + for j in cutlass.range_constexpr(32): + curr_col = col_start + j + mask = (curr_mask_val >> j) & 1 + acc_S[curr_col] = acc_S[curr_col] if Boolean(mask) else -Float32.inf + + elif const_expr(not mask_causal and not mask_local and mask_mod is None): if const_expr(mask_seqlen): if const_expr(not r2p): for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True): diff --git a/flash_attn/cute/named_barrier.py b/flash_attn/cute/named_barrier.py index dd0d1988960..c4536dabd0f 100644 --- a/flash_attn/cute/named_barrier.py +++ b/flash_attn/cute/named_barrier.py @@ -45,3 +45,12 @@ class NamedBarrierBwdSm100(enum.IntEnum): Compute = enum.auto() dQaccReduce = enum.auto() TmemPtr = enum.auto() + + +class NamedBarrierFwdSm100_MLA2CTA(enum.IntEnum): + Epilogue = enum.auto() + TmemPtr = enum.auto() + Cpasync = enum.auto() + Softmax = enum.auto() + SoftmaxStatsFull = enum.auto() + SoftmaxStatsEmpty = enum.auto() diff --git a/flash_attn/cute/softmax.py b/flash_attn/cute/softmax.py index eed55a0b721..0565827b601 100644 --- a/flash_attn/cute/softmax.py +++ b/flash_attn/cute/softmax.py @@ -7,7 +7,7 @@ import cutlass import cutlass.cute as cute -from cutlass import Float32 +from cutlass import Float32, Boolean from quack import layout_utils import flash_attn.cute.utils as utils @@ -190,6 +190,37 @@ def create( rescale_threshold=rescale_threshold, ) + @cute.jit + def compute_row_max_local(self, acc_S_row: cute.TensorSSA, is_first: Boolean) -> Float32: + if cutlass.const_expr(is_first): + row_max_new = self._compute_row_max(acc_S_row) + else: + row_max_old = self.row_max[0] + row_max_new = self._compute_row_max(acc_S_row, init_val=row_max_old) + return row_max_new + + @cute.jit + def update_row_max_from_local( + self, + row_max_new: Float32, + is_first: Boolean, + ) -> Tuple[Float32, Float32]: + if cutlass.const_expr(is_first): + row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 + acc_scale = 0.0 + else: + row_max_old = self.row_max[0] + row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 + acc_scale_ = (row_max_old - row_max_safe) * self.scale_log2 + acc_scale = cute.math.exp2(acc_scale_) + if cutlass.const_expr(self.rescale_threshold > 0.0): + if acc_scale_ >= -self.rescale_threshold: + row_max_new = row_max_old + row_max_safe = row_max_old + acc_scale = 1.0 + self.row_max[0] = row_max_new + return row_max_safe, acc_scale + @cute.jit def update_row_max(self, acc_S_row: cute.TensorSSA, is_first: int) -> Tuple[Float32, Float32]: if cutlass.const_expr(is_first): diff --git a/flash_attn/cute/testing.py b/flash_attn/cute/testing.py index 6e3c40eb451..6e4bfed1335 100644 --- a/flash_attn/cute/testing.py +++ b/flash_attn/cute/testing.py @@ -91,20 +91,23 @@ def pad_input(hidden_states, indices, batch, seqlen): return rearrange(output, "(b s) ... -> b s ...", b=batch) -def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", zero_lengths=False): +def generate_random_padding_mask( + max_seqlen, batch_size, device, mode="random", zero_lengths=False, min_seqlen=None +): assert mode in ["full", "random", "third"] + min_seqlen = min_seqlen if min_seqlen is not None else 0 if zero_lengths else 1 if mode == "full": lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32) elif mode == "random": lengths = torch.randint( - max(0 if zero_lengths else 1, max_seqlen - 20), + max(min_seqlen, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device, ) else: lengths = torch.randint( - max(0 if zero_lengths else 1, max_seqlen // 3), + max(min_seqlen, max_seqlen // 3), max_seqlen + 1, (batch_size, 1), device=device, @@ -343,6 +346,8 @@ def attention_ref( upcast=True, reorder_ops=False, intermediate_dtype=None, + return_lse=False, + gather_kv_indices=None, ): if causal: window_size = (window_size[0], 0) @@ -399,10 +404,21 @@ def attention_ref( local_mask = ( torch.logical_or(local_mask, chunk_mask) if local_mask is not None else chunk_mask ) + if gather_kv_indices is not None: + batch = q.shape[0] + topk_len = gather_kv_indices.shape[2] + if topk_len < seqlen_k: + topk_index_mask = torch.full( + (batch, seqlen_q, seqlen_k), False, device="cuda" + ).scatter_(-1, gather_kv_indices, True) + scores.masked_fill_(rearrange(~topk_index_mask, "b t s -> b 1 t s"), float("-inf")) if local_mask is not None: scores.masked_fill_(local_mask, float("-inf")) if attn_bias is not None: scores = scores + attn_bias + # After all masks are applied, before softmax: + # scores shape: [b, h, t, s] + lse = torch.logsumexp(scores, dim=-1) # [b, h, t] if learnable_sink is None: attention = torch.softmax(scores, dim=-1).to(v.dtype) else: @@ -414,6 +430,8 @@ def attention_ref( normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + torch.exp( learnable_sink - logits_or_sinks_max ) + # LSE with sink: log(Z) = log(normalizer) + max + lse = (torch.log(normalizer.squeeze(-1)) + logits_or_sinks_max.squeeze(-1)).to(dtype_og) attention = (unnormalized_scores / normalizer).to(v.dtype) if query_padding_mask is not None: attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) @@ -431,6 +449,8 @@ def attention_ref( output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) if query_padding_mask is not None: output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + if return_lse: + return output.to(dtype_og), attention.to(dtype_og), lse.to(dtype_og) return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py index 3ee4bc8bab1..ae57858acd5 100644 --- a/flash_attn/cute/tile_scheduler.py +++ b/flash_attn/cute/tile_scheduler.py @@ -402,6 +402,7 @@ class Params(ParamsBase): cluster_shape_m: cutlass.Constexpr[int] = 1 scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC lpt: cutlass.Constexpr[bool] = True + use_cluster_idx: cutlass.Constexpr[bool] = True @staticmethod @cute.jit @@ -445,6 +446,7 @@ def create( cluster_shape_m=args.cluster_shape_mn[0], scheduling_mode=scheduling_mode, lpt=args.lpt, + use_cluster_idx=args.use_cluster_idx, ) def __init__( @@ -532,12 +534,19 @@ def clc_work_to_coords(self, work) -> WorkTileInfo: block_idx = block_idx // self.params.cluster_shape_m if const_expr(self.params.lpt): # Longest-processing-time-first: reverse block order - block_idx = self.params.num_block - 1 - block_idx + if const_expr(self.params.cluster_shape_m > 1 and not self.params.use_cluster_idx): + num_block = self.params.num_block // self.params.cluster_shape_m + else: + num_block = self.params.num_block + block_idx = num_block - 1 - block_idx split_idx = Int32(0) if const_expr(self.params.is_split_kv): batch_idx, split_idx = divmod(work.tile_idx[2], self.params.num_splits_divmod) else: batch_idx = work.tile_idx[2] + if const_expr(self.params.cluster_shape_m > 1 and not self.params.use_cluster_idx): + bidx_in_cluster = cute.arch.block_in_cluster_idx() + block_idx = block_idx * self.params.cluster_shape_m + bidx_in_cluster[0] return WorkTileInfo( (Int32(block_idx), Int32(work.tile_idx[1]), Int32(batch_idx), Int32(split_idx)), work.is_valid_tile, diff --git a/flash_attn/cute/topk_gather_kv.py b/flash_attn/cute/topk_gather_kv.py new file mode 100644 index 00000000000..67169fb5900 --- /dev/null +++ b/flash_attn/cute/topk_gather_kv.py @@ -0,0 +1,274 @@ +from typing import Type, Optional +from dataclasses import dataclass +import operator + +import cutlass +import cutlass.cute as cute +import cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync +from cutlass import Int32, Uint32, const_expr, Boolean + +from flash_attn.cute import utils +from flash_attn.cute.utils import warp_reduce +from quack.cute_dsl_utils import ParamsBase + +import math + + +@dataclass +class CpasyncGatherKVManager(ParamsBase): + mIndexTopk: cute.Tensor + sBitmask: cute.Tensor + + cta_rank_in_cluster: Int32 + thread_idx: Int32 + warp_idx: Int32 + + topk_length: Int32 + seqlen_k_limit: Int32 + tile_n: Int32 + num_threads: cutlass.Constexpr[Int32] + hdim: cutlass.Constexpr[Int32] + hdim_v: cutlass.Constexpr[Int32] + num_hdimv_splits: cutlass.Constexpr[Int32] + cta_group_size: cutlass.Constexpr[Int32] + + gmem_threads_per_row: cutlass.Constexpr[Int32] + topk_indices_per_thread: Int32 + async_copy_elems: Int32 + + gmem_tiled_copy_KV: cute.TiledCopy + gmem_thr_copy_KV: cute.TiledCopy + + rTopk: cute.Tensor + rTopkHalf: cute.Tensor + # for bitmask + rTopk_NonInterleaved: cute.Tensor + + pipeline_bitmask: Optional[pipeline.PipelineAsync] + cpasync_barrier: pipeline.NamedBarrier + + disable_bitmask: cutlass.Constexpr[Boolean] + + @staticmethod + def create( + mIndexTopk: cute.Tensor, + sBitmask: cute.Tensor, + cta_rank_in_cluster: Int32, + thread_idx: Int32, + warp_idx: Int32, + topk_length: Int32, + seqlen_k_limit: Int32, + tile_n: cutlass.Constexpr[Int32], + hdim: cutlass.Constexpr[Int32], + hdim_v: cutlass.Constexpr[Int32], + num_hdimv_splits: cutlass.Constexpr[Int32], + num_threads: cutlass.Constexpr[Int32], + dtype: Type[cutlass.Numeric], + cta_group_size: cutlass.Constexpr[Int32], + pipeline_bitmask: Optional[pipeline.PipelineAsync], + num_stages_bitmask: cutlass.Constexpr[Int32], + cpasync_barrier: pipeline.NamedBarrier, + disable_bitmask: cutlass.Constexpr[Boolean], + ): + assert tile_n % num_threads == 0 + assert num_threads == 128 + assert hdim % 64 == 0 + assert (hdim_v // num_hdimv_splits // cta_group_size) % 64 == 0 + assert num_threads % cute.arch.WARP_SIZE == 0 + universal_copy_bits = 128 + async_copy_elems = universal_copy_bits // dtype.width + dtype_bytes = dtype.width // 8 + # assumes hdim is never part of transposed operand + gmem_k_block_size = math.gcd( + hdim, + hdim_v // num_hdimv_splits // cta_group_size, + 128 // dtype_bytes, + ) + assert gmem_k_block_size % async_copy_elems == 0 + gmem_threads_per_row = gmem_k_block_size // async_copy_elems + assert cute.arch.WARP_SIZE % gmem_threads_per_row == 0 + atom_async_copy = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), + dtype, + num_bits_per_copy=universal_copy_bits, + ) + thr_layout = cute.make_ordered_layout( + (num_threads // gmem_threads_per_row, gmem_threads_per_row), + order=(1, 0), + ) + val_layout = cute.make_layout((1, async_copy_elems)) + gmem_tiled_copy_KV = cute.make_tiled_copy_tv(atom_async_copy, thr_layout, val_layout) + gmem_thr_copy_KV = gmem_tiled_copy_KV.get_slice(thread_idx) + topk_indices_per_thread = tile_n // num_threads + + rTopk = cute.make_rmem_tensor((topk_indices_per_thread,), Int32) + rTopkHalf = cute.make_rmem_tensor((topk_indices_per_thread,), Int32) + rTopk_NonInterleaved = cute.make_rmem_tensor((topk_indices_per_thread,), Int32) + + return CpasyncGatherKVManager( + mIndexTopk, + sBitmask, + cta_rank_in_cluster, + thread_idx, + warp_idx, + topk_length, + seqlen_k_limit, + tile_n, + num_threads, + hdim, + hdim_v, + num_hdimv_splits, + cta_group_size, + gmem_threads_per_row, + topk_indices_per_thread, + async_copy_elems, + gmem_tiled_copy_KV, + gmem_thr_copy_KV, + rTopk, + rTopkHalf, + rTopk_NonInterleaved, + pipeline_bitmask, + cpasync_barrier, + disable_bitmask, + ) + + @cute.jit + def load_index_topk( + self, + n_block: Int32, + transpose: bool, + ): + entries_per_thread = self.topk_indices_per_thread + rTopk = self.rTopk if const_expr(transpose) else self.rTopkHalf + + for i in cutlass.range_constexpr(entries_per_thread): + row = ( + i * self.num_threads + + (self.thread_idx % self.gmem_threads_per_row) + * (self.num_threads // self.gmem_threads_per_row) + + (self.thread_idx // self.gmem_threads_per_row) + ) + # need this if not offset in load_X + # if const_expr(not transpose): + # row += self.cta_rank_in_cluster * (self.tile_n//self.cta_group_size) + # row = row % self.tile_n + row_idx = n_block * self.tile_n + row + rTopk[i] = self.mIndexTopk[row_idx] + + if const_expr(not transpose and not self.disable_bitmask): + row_non_interleaved = i * self.num_threads + self.thread_idx + row_idx_non_interleaved = n_block * self.tile_n + row_non_interleaved + self.rTopk_NonInterleaved[0] = self.mIndexTopk[row_idx_non_interleaved] + + @cute.jit + def compute_bitmask( + self, + producer_state_bitmask, + ): + lane_idx = cute.arch.lane_idx() + assert cute.size(self.rTopk_NonInterleaved) == 1 + bitmask = Uint32(0) + + # Step 1. Construct per-thread bitmask + topk_idx = self.rTopk_NonInterleaved[0] + is_valid = topk_idx >= 0 and topk_idx < self.seqlen_k_limit + if is_valid: + bitmask = Uint32(1 << lane_idx) + + # Step 2. Warp shuffle bitwise OR = add since indices are exclusive. + bitmask = warp_reduce(bitmask, operator.add) + + self.pipeline_bitmask.producer_acquire(producer_state_bitmask) + # store to smem and sync threads + if lane_idx == 0: + self.sBitmask[self.warp_idx, producer_state_bitmask.index] = bitmask + self.cpasync_barrier.arrive_and_wait() + + self.pipeline_bitmask.producer_commit(producer_state_bitmask) + producer_state_bitmask.advance() + return producer_state_bitmask + + @cute.jit + def compute_X_ptr( + self, + mX: cute.Tensor, + transpose: bool, + ): + entries_per_thread = self.topk_indices_per_thread + tPrXPtr = cute.make_rmem_tensor((entries_per_thread,), cutlass.Int64) + tPrRowValid = cute.make_rmem_tensor((entries_per_thread,), cutlass.Int32) + rTopk = self.rTopk if const_expr(transpose) else self.rTopkHalf + + for i in cutlass.range_constexpr(entries_per_thread): + topk_idx = rTopk[i] + if const_expr(not self.disable_bitmask): + row_valid = topk_idx >= 0 and topk_idx < self.seqlen_k_limit + tPrRowValid[i] = row_valid + if const_expr(not transpose): + tPrXPtr[i] = utils.elem_pointer(mX, (topk_idx, 0)).toint() + else: + tPrXPtr[i] = utils.elem_pointer(mX, (0, topk_idx)).toint() + + return tPrXPtr, tPrRowValid + + @cute.jit + def load_X( + self, + mX: cute.Tensor, + sX: cute.Tensor, + transpose: bool, + K_or_V: str, + ): + assert K_or_V in ("K", "V") + cta_tile_n = self.tile_n if const_expr(transpose) else self.tile_n // self.cta_group_size + head_dim = self.hdim if const_expr(K_or_V == "K") else self.hdim_v // self.num_hdimv_splits + if const_expr(transpose): + head_dim = head_dim // self.cta_group_size + order = (1, 0) if const_expr(transpose) else (0, 1) + + sX_nd_layout = cute.make_ordered_layout((cta_tile_n, head_dim), order=order) + sX_nd = cute.composition(sX, sX_nd_layout) + + cX = cute.make_identity_tensor((cta_tile_n, head_dim)) + tXsX = self.gmem_thr_copy_KV.partition_D(sX_nd) + tXcX = self.gmem_thr_copy_KV.partition_S(cX) + + tPrXPtr, tPrRowValid = self.compute_X_ptr(mX, transpose) + + if const_expr(not transpose): + offset = self.cta_rank_in_cluster * (self.gmem_threads_per_row // self.cta_group_size) + else: + offset = 0 + + for m in cutlass.range_constexpr(cute.size(tXsX, mode=[1])): + if const_expr(not self.disable_bitmask): + row_valid = utils.shuffle_sync( + tPrRowValid[m // self.gmem_threads_per_row], + (m + offset) % self.gmem_threads_per_row, + width=self.gmem_threads_per_row, + ) + should_load = cute.make_fragment_like(tXsX[(0, None), m, 0], Boolean) + should_load.fill(Boolean(row_valid)) + x_ptr_i64 = utils.shuffle_sync( + tPrXPtr[m // self.gmem_threads_per_row], + (m + offset) % self.gmem_threads_per_row, + width=self.gmem_threads_per_row, + ) + x_gmem_ptr = cute.make_ptr( + mX.element_type, x_ptr_i64, cute.AddressSpace.gmem, assumed_align=16 + ) + mX_cur = cute.make_tensor(x_gmem_ptr, cute.make_layout((head_dim,))) + mX_cur_copy = cute.tiled_divide(mX_cur, (self.async_copy_elems,)) + + for k in cutlass.range_constexpr(cute.size(tXsX, mode=[2])): + ki = tXcX[0, 0, k][1] // self.async_copy_elems + mX_cur_copy_ki = mX_cur_copy[None, ki] + tXsX_k = tXsX[None, m, k] + mX_cur_copy_ki = cute.make_tensor(mX_cur_copy_ki.iterator, tXsX_k.layout) + cute.copy( + self.gmem_tiled_copy_KV, + mX_cur_copy_ki, + tXsX_k, + pred=should_load if const_expr(not self.disable_bitmask) else None, + ) diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 5f2c7732956..b551c01d7d6 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -136,6 +136,12 @@ def test_flash_attn_output( local = local_enum > 0 if local and causal: pytest.skip() + if has_qv and d != 64: + pytest.skip() + if has_qv and local: + pytest.xfail("has_qv: local not supported yet") + if has_qv and has_learnable_sink: + pytest.xfail("has_qv: learnable sink not supported yet") device = "cuda" # set seed seed = 0 @@ -145,14 +151,20 @@ def test_flash_attn_output( torch.cuda.synchronize() batch_size = 9 if seqlen_k <= 2048 else 2 # batch_size = 2 - nheads = 6 + nheads = 6 if not has_qv else 128 # nheads = 1 - nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1) + if not has_qv: + nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1) + else: + nheads_kv = nheads if mha_type == "mha" else (8 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d]) if dtype == torch.float8_e4m3fn: dv_vals = [d] + if has_qv: + assert d == 64 + dv_vals = [512] # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] attention_chunk_vals = [0] for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): @@ -268,10 +280,10 @@ def test_flash_attn_output( print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") # num_splits_vals = [1, 3] - pack_gqa_vals = [False, True, None] if not TEST_BWD_ONLY else [False] + pack_gqa_vals = [True] if has_qv else [False, True, None] if not TEST_BWD_ONLY else [False] # SplitKV is not supported for hdim >= 192 # pack_gqa_vals = [False] - num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT and not TEST_BWD_ONLY else [1] + num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT and not TEST_BWD_ONLY and not has_qv else [1] for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): # SplitKV not supported on SM90 - skip this iteration if IS_SM90 and num_splits > 1: @@ -282,8 +294,8 @@ def test_flash_attn_output( q, k, v, + qv=qv, causal=causal, - # qv=qv, # q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, # attention_chunk=attention_chunk, @@ -1727,3 +1739,674 @@ def test_flash_attn_invalid_head_dim(head_dim): with pytest.raises(AssertionError, match=re.escape(f"(head_dim, head_dim_v)=({head_dim}, {head_dim}) is not supported on SM")): flash_attn_func(q, k, v) + + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +@pytest.mark.parametrize("mha_type", ["mqa"]) +@pytest.mark.parametrize("has_learnable_sink", [False]) +@pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("deterministic", [False]) +# @pytest.mark.parametrize("local_enum", [0, 1]) +@pytest.mark.parametrize("local_enum", [0]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [False]) +@pytest.mark.parametrize("d", [64]) +@pytest.mark.parametrize("nheads", [16, 128]) +@pytest.mark.parametrize("kv_sparsity", [False, True]) +# @pytest.mark.parametrize("kv_sparsity", [True]) +@pytest.mark.parametrize("gather_kv_length", [2048]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 1), + (3, 3), + (64, 32), + (64, 128), + (128, 128), + (128, 192), + (256, 256), + (239, 1), + (799, 3), + (113, 203), + (113, 128), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (384, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (2048, 2048), + (1, 8192), + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_mla_absorbed( + seqlen_q, + seqlen_k, + d, + nheads, + causal, + local_enum, + softcap, + deterministic, + has_learnable_sink, + mha_type, + dtype, + kv_sparsity, + gather_kv_length, +): + has_qv = True + if not IS_SM100: + pytest.skip() + if kv_sparsity and seqlen_k < gather_kv_length: + seqlen_k += gather_kv_length + local = local_enum > 0 + if local and causal: + pytest.skip() + if local: + pytest.xfail("mla absorbed: local not supported yet") + if kv_sparsity and nheads != 128: + pytest.skip() + device = "cuda" + # set seed + seed = 0 + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.empty_cache() + torch.cuda.synchronize() + batch_size = 9 if seqlen_k <= 2048 else 2 + # batch_size = 2 + nheads = 128 + nheads_kv = nheads if mha_type == "mha" else (8 if mha_type == "gqa" else 1) + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + dv_vals = [512] + # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] + attention_chunk_vals = [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + q_ref = torch.randn( + batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref + ) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. + q_ref = q_ref * softcap / 4 + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) + v_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) + if has_qv: + qv_ref = ( + torch.randn( + batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + else: + qv_ref = None + if kv_sparsity: + gather_kv_indices = torch.rand(batch_size, seqlen_q, gather_kv_length, device=device).argsort(dim=-1).to(torch.int32) + else: + gather_kv_indices = None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = ( + (None, None) if not local else tuple(random.randrange(0, seqlen_k) for _ in range(2)) + ) + if local_enum == 2: + window_size = (None, -window_size[1]) + elif local_enum == 3: + window_size = (-window_size[0], None) + if local: + print("window size = ", window_size) + # window_size = (-1, -1) if not local else (16, 0) + if has_learnable_sink: + learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device) + else: + learnable_sink = None + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [ + torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) + * 2 + for _ in range(3) + ] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach().to(dtype).requires_grad_() if has_qv else None + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + None, + None, + causal=causal, + qv=qv_ref, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + learnable_sink=learnable_sink, + softcap=softcap, + gather_kv_indices=gather_kv_indices, + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + None, + None, + causal=causal, + qv=qv_ref, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + learnable_sink=learnable_sink, + softcap=softcap, + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + gather_kv_indices=gather_kv_indices, + ) + + # k_extended = repeat(k_ref, "b s h d -> b s (h k) d", k=nheads // nheads_kv) + # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_extended).float() + # # if qv is not None: + # # qk += torch.einsum('bshd,bthd->bhst', qv_ref, v_ref).float() + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # exp_sum = s_tmp.sum(-1) + # # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float()) + # # lse_ref = torch.logsumexp(qk, dim=-1) + + # Numerical error if we just do any arithmetic on out_ref + if not is_fake_mode(): + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + num_splits_vals = [1] + pack_gqa_vals = [True] + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + out, lse = flash_attn_func( + q, + k, + v, + qv=qv, + gather_kv_indices=gather_kv_indices, + causal=causal, + # q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + # attention_chunk=attention_chunk, + softcap=softcap, + learnable_sink=learnable_sink, + pack_gqa=pack_gqa, + num_splits=num_splits, + deterministic=deterministic, + ) + if is_fake_mode(): + # no more flash_attn cutedsl calls for the rest of the loop + # skip data-dependent postprocessing + continue + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= rtol * ( + out_pt - out_ref + ).abs().max().item() + fwd_atol + + repeats = 1000 + for iter in range(repeats): + out2, lse2 = flash_attn_func( + q, + k, + v, + qv=qv, + gather_kv_indices=gather_kv_indices, + causal=causal, + # q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + # attention_chunk=attention_chunk, + softcap=softcap, + learnable_sink=learnable_sink, + pack_gqa=pack_gqa, + num_splits=num_splits, + deterministic=deterministic, + ) + # print(f"out max: {out.abs().max().item()}, {iter=}") + # print(f"out vs out2 max diff: {(out - out2).abs().max().item()}, {iter=}") + # print(f"out vs out2 mean diff: {(out - out2).abs().mean().item()}, {iter=}") + assert torch.equal(out, out2), f"non-deterministic with max diff = {(out - out2).abs().max().item()} on {iter=}" + + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +@pytest.mark.parametrize("mha_type", ["mqa"]) +@pytest.mark.parametrize("has_learnable_sink", [False]) +@pytest.mark.parametrize("deterministic", [False]) +@pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("local_enum", [0]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [False]) +@pytest.mark.parametrize("add_unused_qkv", [False]) +@pytest.mark.parametrize("kv_sparsity", [False, True]) +# @pytest.mark.parametrize("kv_sparsity", [False]) +@pytest.mark.parametrize("gather_kv_length", [2048]) +@pytest.mark.parametrize("d", [64]) +@pytest.mark.parametrize("nheads", [16, 128]) +# @pytest.mark.parametrize("nheads", [128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + # (1, 1), + # (1, 3), + # (2, 1), + (511, 1), + (3, 513), + (64, 128), + (128, 128), + (256, 256), + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (307, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (2048, 2048), + ], +) +@pytest.mark.parametrize("varlen_mode", ["random", "full"]) +# @pytest.mark.parametrize("varlen_mode", ["random"]) +@pytest.mark.parametrize( + "zero_lengths_q, zero_lengths_k", + [ + (False, False), + # (True, False), + ], +) +@pytest.mark.parametrize( + "unpad_q, unpad_kv", + [ + (True, True), + (True, False), + (False, False), + (False, True), + ], +) +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_mla_absorbed_varlen( + seqlen_q, + seqlen_k, + d, + nheads, + add_unused_qkv, + causal, + local_enum, + softcap, + deterministic, + has_learnable_sink, + mha_type, + dtype, + varlen_mode, + zero_lengths_q, + zero_lengths_k, + unpad_q, + unpad_kv, + kv_sparsity, + gather_kv_length, +): + has_qv = True + if not IS_SM100: + pytest.skip() + if kv_sparsity and seqlen_k < gather_kv_length: + seqlen_k += gather_kv_length + local = local_enum > 0 + if local and causal: + pytest.skip() + if has_qv and local: + pytest.xfail("has_qv: local not supported yet") + if kv_sparsity and nheads != 128: + pytest.skip() + seqlen_q_og = seqlen_q + seqlen_k_og = seqlen_k + if ( + causal or local + ): # Right now reference only supports causal attention with seqlen_k == seqlen_q + seqlen_q = max(seqlen_q_og, seqlen_k_og) + seqlen_k = max(seqlen_q_og, seqlen_k_og) + device = "cuda" + # set seed + seed = seqlen_q + seqlen_k + d + int(causal) * 2 + int(local) + random.seed(seed) + torch.random.manual_seed(seed) + batch_size = 7 if seqlen_q <= 512 else 3 + nheads_kv = nheads if mha_type == "mha" else (8 if mha_type == "gqa" else 1) + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + dv_vals = [512] + # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k else [0] + attention_chunk_vals = [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + q_ref = torch.randn( + batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref + ) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. + q_ref = (q_ref * softcap / 4).detach().requires_grad_() + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) + v_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) + if has_qv: + qv_ref = ( + torch.randn( + batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + else: + qv_ref = None + if kv_sparsity: + gather_kv_indices = torch.rand(batch_size, seqlen_q, gather_kv_length, device=device).argsort(dim=-1).to(torch.int32) + else: + gather_kv_indices = None + + # Put window_size after QKV randn so that window_size changes from test to test + window_size = ( + (None, None) if not local else tuple(random.randrange(0, seqlen_k) for _ in range(2)) + ) + if local_enum == 2: + window_size = (None, window_size[1]) + elif local_enum == 3: + window_size = (window_size[0], None) + if local: + print("window size = ", window_size) + if has_learnable_sink: + learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device) + else: + learnable_sink = None + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [ + torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) + * 2 + for _ in range(3) + ] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach() if has_qv else None + query_padding_mask = generate_random_padding_mask( + seqlen_q, + batch_size, + device, + mode=varlen_mode, + zero_lengths=zero_lengths_q, + ) + key_padding_mask = generate_random_padding_mask( + seqlen_k, + batch_size, + device, + mode=varlen_mode, + zero_lengths=zero_lengths_k, + min_seqlen=gather_kv_length if kv_sparsity else None, + ) + def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): + if add_unused: + another_mask = generate_random_padding_mask(max_seq_len, bs, device) + attn_mask = torch.logical_and(padding_mask, another_mask) + unused_mask = torch.logical_xor( + torch.logical_or(padding_mask, another_mask), attn_mask + ) + else: + attn_mask = padding_mask + unused_mask = None + return attn_mask, unused_mask + + query_padding_mask, query_unused_mask = _gen_unused_masks( + query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device + ) + # query_padding_mask[:] = True + # query_unused_mask = None + key_padding_mask, key_unused_mask = _gen_unused_masks( + key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device + ) + + if causal or local: + key_padding_mask = query_padding_mask + + ( + q_unpad, + k_unpad, + v_unpad, + qv_unpad, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + qv, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) = generate_qkv( + q, + k, + v, + query_padding_mask, + key_padding_mask, + qv=qv, + kvpacked=False, + query_unused_mask=query_unused_mask, + key_unused_mask=key_unused_mask, + ) + # unpad gather_kv_indices + if kv_sparsity: + _, indices_q, _, _, _ = unpad_input( + q, query_padding_mask, query_unused_mask + ) + gather_kv_indices_unpad = rearrange(gather_kv_indices, "b s ... -> (b s) ...")[indices_q] + else: + gather_kv_indices_unpad = None + if unpad_q: + print("cu_seqlens_q = ", cu_seqlens_q) + else: + print("seqused_q = ", seqused_q) + if unpad_kv: + print("cu_seqlens_k = ", cu_seqlens_k) + else: + print("seqused_k = ", seqused_k) + q_unpad, k_unpad, v_unpad = [ + x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad) + ] + + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + learnable_sink=learnable_sink, + softcap=softcap, + gather_kv_indices=gather_kv_indices, + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + learnable_sink=learnable_sink, + softcap=softcap, + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + gather_kv_indices=gather_kv_indices, + ) + + if not is_fake_mode(): + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + + if query_unused_mask is not None: + q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + + pack_gqa_vals = [True] + num_splits_vals = [1] + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + # SplitKV not supported on SM90 - skip this iteration + if IS_SM90 and num_splits > 1: + continue + out_unpad, lse = flash_attn_varlen_func( + q_unpad if unpad_q else q, + k_unpad if unpad_kv else k, + v_unpad if unpad_kv else v, + qv_unpad if unpad_q else qv, + cu_seqlens_q=cu_seqlens_q if unpad_q else None, + cu_seqlens_k=cu_seqlens_k if unpad_kv else None, + max_seqlen_q=seqlen_q, + max_seqlen_k=seqlen_k, + min_seqlen_k=gather_kv_length if kv_sparsity else None, + seqused_q=seqused_q if not unpad_q else None, + seqused_k=seqused_k if not unpad_kv else None, + causal=causal, + window_size=window_size, + learnable_sink=learnable_sink, + softcap=softcap, + num_splits=num_splits, + pack_gqa=pack_gqa, + deterministic=deterministic, + gather_kv_indices=gather_kv_indices_unpad if unpad_q else gather_kv_indices, + ) + out = output_pad_fn(out_unpad) if unpad_q else out_unpad + if is_fake_mode(): + # no more flash_attn cutedsl calls for the rest of the loop + # skip data-dependent postprocessing + continue + if query_unused_mask is not None: + out.masked_fill_(q_zero_masking, 0.0) + # When unpad_q=False with seqused_q, the kernel doesn't write positions + # beyond seqused_q, so those contain uninitialized values. Mask them out + # before comparing. + out_cmp, out_ref_cmp, out_pt_cmp = out, out_ref, out_pt + if not unpad_q and seqused_q is not None: + seqused_mask = torch.arange(seqlen_q, device=device)[None, :] < seqused_q[:, None] + seqused_mask = rearrange(seqused_mask, "b s -> b s 1 1") + out_cmp = out.clone().masked_fill_(~seqused_mask, 0.0) + out_ref_cmp = out_ref.clone().masked_fill_(~seqused_mask, 0.0) + out_pt_cmp = out_pt.clone().masked_fill_(~seqused_mask, 0.0) + print(f"Output max diff: {(out_cmp - out_ref_cmp).abs().max().item()}") + print(f"Output mean diff: {(out_cmp - out_ref_cmp).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most 3x the numerical error + # of a Pytorch implementation. + assert (out_cmp - out_ref_cmp).abs().max().item() <= rtol * ( + out_pt_cmp - out_ref_cmp + ).abs().max().item() + fwd_atol + + repeats = 1000 + for iter in range(repeats): + out_unpad2, lse = flash_attn_varlen_func( + q_unpad if unpad_q else q, + k_unpad if unpad_kv else k, + v_unpad if unpad_kv else v, + qv_unpad if unpad_q else qv, + cu_seqlens_q=cu_seqlens_q if unpad_q else None, + cu_seqlens_k=cu_seqlens_k if unpad_kv else None, + max_seqlen_q=seqlen_q, + max_seqlen_k=seqlen_k, + min_seqlen_k=gather_kv_length if kv_sparsity else None, + seqused_q=seqused_q if not unpad_q else None, + seqused_k=seqused_k if not unpad_kv else None, + causal=causal, + window_size=window_size, + learnable_sink=learnable_sink, + softcap=softcap, + num_splits=num_splits, + pack_gqa=pack_gqa, + deterministic=deterministic, + gather_kv_indices=gather_kv_indices_unpad if unpad_q else gather_kv_indices, + ) + out2 = output_pad_fn(out_unpad2) if unpad_q else out_unpad2 + if query_unused_mask is not None: + out2.masked_fill_(q_zero_masking, 0.0) + # When unpad_q=False with seqused_q, the kernel doesn't write positions + # beyond seqused_q, so those contain uninitialized values. Mask them out + # before comparing. + if not unpad_q and seqused_q is not None: + seqused_mask = torch.arange(seqlen_q, device=device)[None, :] < seqused_q[:, None] + seqused_mask = rearrange(seqused_mask, "b s -> b s 1 1") + out2.masked_fill_(~seqused_mask, 0.0) + # print(f"out2 max: {out2.abs().max().item()}, {iter=}") + # print(f"out vs out2 max diff: {(out_cmp - out2).abs().max().item()}, {iter=}") + # print(f"out vs out2 mean diff: {(out_cmp - out2).abs().mean().item()}, {iter=}") + assert torch.equal(out_cmp, out2), f"non-deterministic with max diff = {(out_cmp - out2).abs().max().item()} on {iter=}" \ No newline at end of file From 628452c73a4fab560189a7caa8702642c6a38235 Mon Sep 17 00:00:00 2001 From: jayhshah Date: Tue, 14 Apr 2026 21:23:24 -0700 Subject: [PATCH 13/44] Handle linter for flash mla file (#2459) * fix outstanding ruff check and exclude flash_fwd_mla_sm100.py from ci * add fmt comments for ruff --- .pre-commit-config.yaml | 1 - flash_attn/cute/flash_fwd_mla_sm100.py | 941 +++++++++++++++---------- 2 files changed, 583 insertions(+), 359 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8a4ab4bc7e4..6118dfa2283 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,7 +10,6 @@ repos: flash_bwd| flash_fwd| flash_fwd_sm100| - flash_fwd_mla_sm100| interface| )\.py$ - id: ruff-format diff --git a/flash_attn/cute/flash_fwd_mla_sm100.py b/flash_attn/cute/flash_fwd_mla_sm100.py index 8ab546326ec..07cd99f71e9 100644 --- a/flash_attn/cute/flash_fwd_mla_sm100.py +++ b/flash_attn/cute/flash_fwd_mla_sm100.py @@ -9,23 +9,19 @@ import cuda.bindings.driver as cuda import cutlass -from cutlass.base_dsl.arch import Arch -from cutlass.cutlass_dsl import BaseDSL import cutlass.cute as cute -from cutlass import Float32, BFloat16, Int64, Int32, Uint32, Boolean, const_expr +from cutlass import Float32, Int64, Int32, Uint32, Boolean, const_expr import cutlass.pipeline as pipeline from cutlass.cute.nvgpu import cpasync, tcgen05 from cutlass.cute.runtime import from_dlpack import cutlass.utils.blackwell_helpers as sm100_utils from cutlass.utils import ClcDynamicPersistentTileScheduler -from quack import copy_utils, layout_utils +from quack import copy_utils -from flash_attn.cute.pack_gqa import PackGQA, pack_gqa_layout, make_packgqa_tiled_tma_atom -from flash_attn.cute import utils +from flash_attn.cute.pack_gqa import pack_gqa_layout, make_packgqa_tiled_tma_atom from flash_attn.cute.seqlen_info import SeqlenInfoQK from flash_attn.cute.block_info import BlockInfo -from flash_attn.cute.block_sparsity import BlockSparseTensors from flash_attn.cute.mask import AttentionMask import flash_attn.cute.blackwell_helpers as fa_sm100_utils from flash_attn.cute.softmax import SoftmaxSm100 @@ -35,7 +31,6 @@ TileSchedulerArguments, TileSchedulerProtocol, SingleTileScheduler, - StaticPersistentTileScheduler, SingleTileLPTScheduler, SingleTileVarlenScheduler, ParamsBase, @@ -51,6 +46,7 @@ from flash_attn.cute.cute_dsl_utils import dump_kernel_attributes + class FlashAttentionMLAForwardSm100: def __init__( self, @@ -84,12 +80,14 @@ def __init__( assert use_cpasync_load_KV # user-provided option if topk indices guaranteed in bounds self.disable_bitmask = disable_bitmask - + # ==== tile scheduler ==== self.is_persistent = False self.use_clc_scheduler = use_clc_scheduler and not is_varlen_q self.sched_stages = 1 - self.scheduling_mode = SchedulingMode.CLC if self.use_clc_scheduler else SchedulingMode.STATIC + self.scheduling_mode = ( + SchedulingMode.CLC if self.use_clc_scheduler else SchedulingMode.STATIC + ) if const_expr(is_varlen_q): self.TileScheduler = SingleTileVarlenScheduler @@ -98,7 +96,10 @@ def __init__( else: self.TileScheduler = SingleTileScheduler - fa_log(1, f"TileScheduler={self.TileScheduler.__name__}, scheduling_mode={self.scheduling_mode.name}") + fa_log( + 1, + f"TileScheduler={self.TileScheduler.__name__}, scheduling_mode={self.scheduling_mode.name}", + ) # ==== thread info ==== self.num_softmax_threads = 128 @@ -119,19 +120,21 @@ def __init__( ) self.num_warps = self.num_threads // 32 assert self.num_warps == 12 or self.num_warps == 16 - self.softmax_warp_indices = (0, 1, 2, 3,) - self.epilogue_warp_indices = (4, 5, 6, 7,) + self.softmax_warp_indices = (0, 1, 2, 3) + self.epilogue_warp_indices = (4, 5, 6, 7) self.load_warp_id = 8 self.mma_warp_id = 9 self.clc_scheduler_warp_id = 10 self.relay_warp_id = 11 self.empty_warp_ids = tuple( - w for w, active in [ - (self.relay_warp_id, not use_cpasync_load_KV), + w + for w, active in [ + (self.relay_warp_id, not use_cpasync_load_KV), (self.clc_scheduler_warp_id, not self.use_clc_scheduler), - ] if active + ] + if active ) - self.cpasync_load_warp_indices = (12, 13, 14, 15,) + self.cpasync_load_warp_indices = (12, 13, 14, 15) # ==== register usage ==== if self.num_warps == 16: @@ -151,15 +154,21 @@ def __init__( self.num_regs_per_thread = 168 if self.num_warps == 12 else 128 self.num_regs_total = 504 if self.num_warps == 12 else 512 - - assert self.num_regs_mma + self.num_regs_softmax + self.num_regs_epilogue + self.num_regs_cpasync <= self.num_regs_total + + assert ( + self.num_regs_mma + + self.num_regs_softmax + + self.num_regs_epilogue + + self.num_regs_cpasync + <= self.num_regs_total + ) # ==== 2cta info ==== self.use_2cta_instrs = True self.cta_group = tcgen05.CtaGroup.TWO self.cta_group_size = 2 - self.cluster_shape_mn = (2, 1,) - self.cluster_shape_mnk = (2, 1, 1,) + self.cluster_shape_mn = (2, 1) + self.cluster_shape_mnk = (2, 1, 1) # ==== problem shape info ==== self.hdim = hdim @@ -175,12 +184,24 @@ def __init__( self.num_hdimv_splits = 2 # split hdimv in half for our Qv @ V^T and P @ V mmas. assert hdimv % 32 == 0 assert self.topk_length % (self.tile_n * 2) == 0 or not self.is_topk_gather - self.epi_tile = (self.cta_tile_m, self.hdimv//self.num_hdimv_splits) + self.epi_tile = (self.cta_tile_m, self.hdimv // self.num_hdimv_splits) # ==== MMA info ==== - self.mma_tiler_QK = (self.cluster_tile_m, self.tile_n, self.hdim,) - self.mma_tiler_QviVi = (self.cluster_tile_m, self.tile_n, self.hdimv//self.num_hdimv_splits,) - self.mma_tiler_PVti = (self.cluster_tile_m, self.hdimv // self.num_hdimv_splits, self.tile_n,) + self.mma_tiler_QK = ( + self.cluster_tile_m, + self.tile_n, + self.hdim, + ) + self.mma_tiler_QviVi = ( + self.cluster_tile_m, + self.tile_n, + self.hdimv // self.num_hdimv_splits, + ) + self.mma_tiler_PVti = ( + self.cluster_tile_m, + self.hdimv // self.num_hdimv_splits, + self.tile_n, + ) self.major_mode_Q = tcgen05.OperandMajorMode.K self.major_mode_Qvi = tcgen05.OperandMajorMode.K self.major_mode_K = tcgen05.OperandMajorMode.K @@ -211,17 +232,20 @@ def __init__( self.tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS self.tmem_cols_S = self.tile_n // self.cta_group_size self.tmem_cols_Oi = (self.hdimv // self.num_hdimv_splits) // self.cta_group_size - self.tmem_offset_S = [self.tmem_cols_S*stage for stage in range(self.num_stages_S)] # allocate 64 TMEM columns for each stage of S - self.tmem_offset_O0 = self.tmem_cols_S * self.num_stages_S + self.tmem_offset_S = [ + self.tmem_cols_S * stage for stage in range(self.num_stages_S) + ] # allocate 64 TMEM columns for each stage of S + self.tmem_offset_O0 = self.tmem_cols_S * self.num_stages_S self.tmem_offset_O1 = self.tmem_offset_O0 + self.tmem_cols_Oi self.tmem_offsets_O = [self.tmem_offset_O0, self.tmem_offset_O1] self.total_tmem = self.tmem_offset_O1 + self.tmem_cols_Oi - assert self.total_tmem <= self.tmem_alloc_cols, f"Total TMEM columns allocated {self.total_tmem} exceeds capacity {self.tmem_alloc_cols}" - + assert self.total_tmem <= self.tmem_alloc_cols, ( + f"Total TMEM columns allocated {self.total_tmem} exceeds capacity {self.tmem_alloc_cols}" + ) def _get_shared_storage_cls(self): self.buffer_align_bytes = 1024 - + def smem_struct_align(dtype, staged_layout): return cute.struct.Align[ cute.struct.MemRange[dtype, cute.cosize(staged_layout)], @@ -232,30 +256,47 @@ def mbar_struct(num_stages): return cute.struct.MemRange[Int64, 2 * num_stages] (sQ_struct, sK_struct, sQv0_struct, sQv1_struct, sV0_struct, sV1_struct, sP_struct) = ( - smem_struct_align(dtype, layout) for dtype, layout in [ - (self.dtype_Q, self.sQ_layout_staged), - (self.dtype_K, self.sK_layout_staged), + smem_struct_align(dtype, layout) + for dtype, layout in [ + (self.dtype_Q, self.sQ_layout_staged), + (self.dtype_K, self.sK_layout_staged), (self.dtype_Qv, self.sQvi_layout_staged), (self.dtype_Qv, self.sQvi_layout_staged), - (self.dtype_V, self.sVi_layout_staged), - (self.dtype_V, self.sVi_layout_staged), - (self.dtype_P, self.sP_layout_staged), + (self.dtype_V, self.sVi_layout_staged), + (self.dtype_V, self.sVi_layout_staged), + (self.dtype_P, self.sP_layout_staged), ] ) sStats_struct = cute.struct.MemRange[Float32, cute.cosize(self.sStats_layout)] sScale_struct = cute.struct.MemRange[Float32, cute.cosize(self.sScale_layout)] sBitmask_struct = cute.struct.MemRange[Uint32, cute.cosize(self.sBitmask_layout)] - (mbar_ptr_Q_struct, mbar_ptr_K_struct, mbar_ptr_Qv0_struct, mbar_ptr_Qv1_struct, - mbar_ptr_V0_struct, mbar_ptr_V1_struct, mbar_ptr_S_struct, mbar_ptr_P_struct, - mbar_ptr_O0_struct, mbar_ptr_O1_struct, mbar_sm_stats_struct, - mbar_bitmask_struct) = ( - mbar_struct(n) for n in [ - self.num_stages_Q, self.num_stages_K, - self.num_stages_Qvi, self.num_stages_Qvi, - self.num_stages_Vi, self.num_stages_Vi, - self.num_stages_S, self.num_stages_P, - self.num_stages_Oi, self.num_stages_Oi, + ( + mbar_ptr_Q_struct, + mbar_ptr_K_struct, + mbar_ptr_Qv0_struct, + mbar_ptr_Qv1_struct, + mbar_ptr_V0_struct, + mbar_ptr_V1_struct, + mbar_ptr_S_struct, + mbar_ptr_P_struct, + mbar_ptr_O0_struct, + mbar_ptr_O1_struct, + mbar_sm_stats_struct, + mbar_bitmask_struct, + ) = ( + mbar_struct(n) + for n in [ + self.num_stages_Q, + self.num_stages_K, + self.num_stages_Qvi, + self.num_stages_Qvi, + self.num_stages_Vi, + self.num_stages_Vi, + self.num_stages_S, + self.num_stages_P, + self.num_stages_Oi, + self.num_stages_Oi, self.num_stages_sm_stats, self.num_stages_bitmask, ] @@ -305,29 +346,30 @@ class SharedStorage: # print("smem bytes = ", SharedStorage.size_in_bytes()) return SharedStorage - - + @cute.jit def __call__( self, - mQ: cute.Tensor, # (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + mQ: cute.Tensor, # (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q mQv: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, d) if there is cu_seqlens_q - mK: cute.Tensor, # (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table - mV: cute.Tensor, # (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table - mO: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + mK: cute.Tensor, # (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table + mV: cute.Tensor, # (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table + mO: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q mLSE: Optional[cute.Tensor], # (b, h, s_q) or (h, total_q) if there is cu_seqlens_q softmax_scale: Float32, - mCuSeqlensQ: Optional[cute.Tensor] = None, # (b + 1) - mCuSeqlensK: Optional[cute.Tensor] = None, # (b + 1) - mSeqUsedQ: Optional[cute.Tensor] = None, # (b) - mSeqUsedK: Optional[cute.Tensor] = None, # (b) - mIndexTopk: Optional[cute.Tensor] = None, # (b, s_q, topk) or (total_q, topk) if there is cu_seqlens_q + mCuSeqlensQ: Optional[cute.Tensor] = None, # (b + 1) + mCuSeqlensK: Optional[cute.Tensor] = None, # (b + 1) + mSeqUsedQ: Optional[cute.Tensor] = None, # (b) + mSeqUsedK: Optional[cute.Tensor] = None, # (b) + mIndexTopk: Optional[ + cute.Tensor + ] = None, # (b, s_q, topk) or (total_q, topk) if there is cu_seqlens_q mPageTable: Optional[cute.Tensor] = None, window_size_left: Int32 | int | None = None, window_size_right: Int32 | int | None = None, # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). stream: cuda.CUstream = None, - ): + ): # ==== asserts for unimplemented features ==== assert mPageTable is None, "page table tbd for MLA" @@ -365,15 +407,15 @@ def __call__( # (s_k, dv, h_k, b) -> (dv, s_k, h_k, b) or # (total_k, dv, h_k) -> (dv, total_k, h_k) V_layout_transpose = [1, 0, 2, 3] if const_expr(mCuSeqlensK is None) else [1, 0, 2] - mVt = cute.make_tensor( - mV.iterator, cute.select(mV.layout, mode=V_layout_transpose) - ) + mVt = cute.make_tensor(mV.iterator, cute.select(mV.layout, mode=V_layout_transpose)) # (b, h, s_q) -> (s_q, h, b) or (h, total_q) -> (total_q, h) # (b, s_q, topk) -> (topk, s_q, b) or (total_q, topk) -> (topk, total_q) LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0] mLSE, mIndexTopk = ( cute.make_tensor(t.iterator, cute.select(t.layout, mode=LSE_layout_transpose)) - if t is not None else None for t in (mLSE, mIndexTopk) + if t is not None + else None + for t in (mLSE, mIndexTopk) ) topk_length_dynamic = mIndexTopk.shape[0] if mIndexTopk is not None else None @@ -393,24 +435,24 @@ def split_hdimv(m, dim: int): and return (slice0, slice1) where slice_i selects chunk i.""" S = self.num_hdimv_splits chunk = self.hdimv // S - split_shape = (*m.shape[:dim], (chunk, S), *m.shape[dim+1:]) - split_stride = (*m.layout.stride[:dim], (1, chunk), *m.layout.stride[dim+1:]) + split_shape = (*m.shape[:dim], (chunk, S), *m.shape[dim + 1 :]) + split_stride = (*m.layout.stride[:dim], (1, chunk), *m.layout.stride[dim + 1 :]) split = cute.make_tensor(m.iterator, cute.make_layout(split_shape, stride=split_stride)) ndim = len(split.shape) slices = [ - split[(*([None] * dim), (None, i), *([None] * (ndim - dim - 1)))] - for i in range(S) + split[(*([None] * dim), (None, i), *([None] * (ndim - dim - 1)))] for i in range(S) ] return slices # (seqlen_q, hdimv//2, nheads, batch) or (total_q, hdimv//2, nheads) mQv0, mQv1 = split_hdimv(mQv, dim=1) - mV0, mV1 = split_hdimv(mV, dim=1) + mV0, mV1 = split_hdimv(mV, dim=1) # (hdimv//2, seqlen_k, nheads_k, batch) or (hdimv//2, total_k, nheads_k) mVt0, mVt1 = split_hdimv(mVt, dim=0) # ==== Prepare MMAs ==== # (local_var, dtype_a, major_a, major_b, mma_tiler, operand_source_a) + # fmt: off _mma_specs = [ ("tiled_mma_QK", self.dtype_Q, self.major_mode_Q, self.major_mode_K, self.mma_tiler_QK, self.operand_source_Q), ("tiled_mma_QviVi", self.dtype_Qv, self.major_mode_Qvi, self.major_mode_Vi, self.mma_tiler_QviVi, self.operand_source_Qvi), @@ -422,9 +464,11 @@ def split_hdimv(m, dim: int): ) for _, dtype_a, major_a, major_b, mma_tiler, operand_source_a in _mma_specs ) + # fmt: on # ==== Prepare SMEM layouts and TMAs ==== # (attr, make_fn, tiled_mma, mma_tiler, dtype, num_stages) + # fmt: off _smem_layout_specs = [ ("sQ_layout", sm100_utils.make_smem_layout_a, tiled_mma_QK, self.mma_tiler_QK, self.dtype_Q, self.num_stages_Q), ("sK_layout", sm100_utils.make_smem_layout_b, tiled_mma_QK, self.mma_tiler_QK, self.dtype_K, self.num_stages_K), @@ -443,11 +487,13 @@ def split_hdimv(m, dim: int): ) setattr(self, f"{attr}_staged", staged) setattr(self, attr, cute.select(staged, mode=[0, 1, 2])) + # fmt: on self.sStats_layout = cute.make_layout((self.cta_tile_m, self.cta_group_size)) self.sScale_layout = cute.make_layout((self.cta_tile_m, self.num_stages_sm_stats)) - self.sBitmask_layout = cute.make_layout((self.tile_n//32, self.num_stages_bitmask)) + self.sBitmask_layout = cute.make_layout((self.tile_n // 32, self.num_stages_bitmask)) + # fmt: off for attr, dtype, layout in [ ("tma_copy_bytes_Q", self.dtype_Q, self.sQ_layout), ("tma_copy_bytes_K", self.dtype_K, self.sK_layout), @@ -455,6 +501,7 @@ def split_hdimv(m, dim: int): ("tma_copy_bytes_Vi", self.dtype_V, self.sVi_layout), ]: setattr(self, attr, cute.size_in_bytes(dtype, layout) * self.cta_group_size) + # fmt: on tma_load_op = cpasync.CopyBulkTensorTileG2SOp(self.cta_group) cta_layout_vmnk = cute.tiled_divide( @@ -468,6 +515,7 @@ def make_tma(make_fn, mX, smem_layout, mma_tiler, tiled_mma): A, B = cute.nvgpu.make_tiled_tma_atom_A, cute.nvgpu.make_tiled_tma_atom_B # (atom_name, tensor_name, make_fn, m, smem_layout, mma_tiler, tiled_mma, kv_only) + # fmt: off _tma_specs = [ ("tma_atom_Q", "tma_tensor_Q", A, mQ, self.sQ_layout, self.mma_tiler_QK, tiled_mma_QK, False), ("tma_atom_Qv0", "tma_tensor_Qv0", A, mQv0, self.sQvi_layout, self.mma_tiler_QviVi, tiled_mma_QviVi, False), @@ -494,6 +542,7 @@ def make_tma(make_fn, mX, smem_layout, mma_tiler, tiled_mma): tma_atom_V1, tma_tensor_V1, tma_atom_Vt0, tma_tensor_Vt0, tma_atom_Vt1, tma_tensor_Vt1) = _tmas.values() + # fmt: on # ==== Set up Oi smem -> gmem tma store ==== @@ -539,7 +588,7 @@ def make_tma(make_fn, mX, smem_layout, mma_tiler, tiled_mma): ) thread_layout_O_r2g = cute.make_layout((64, 2), stride=(1, 64)) value_layout_O_r2g = cute.make_layout( - (1, self.hdimv//self.num_hdimv_splits//self.cta_group_size) + (1, self.hdimv // self.num_hdimv_splits // self.cta_group_size) ) tiled_copy_O_r2g = cute.make_tiled_copy_tv( atom=atom_universal_copy, @@ -553,21 +602,24 @@ def make_tma(make_fn, mX, smem_layout, mma_tiler, tiled_mma): # ==== Tile scheduler ==== TileScheduler = self.TileScheduler - + tile_sched_args = TileSchedulerArguments( num_block=cute.ceil_div(cute.size(mQ.shape[0]), self.cta_tile_m), num_head=cute.size(mQ.shape[2]), num_batch=cute.size(mQ.shape[3]) if const_expr(mCuSeqlensQ is None) else cute.size(mCuSeqlensQ.shape[0] - 1), - num_splits=1, # todo: split_kv - seqlen_k=cute.size(mK.shape[0]), # todo: page table + num_splits=1, # todo: split_kv + seqlen_k=cute.size(mK.shape[0]), # todo: page table headdim=self.hdim, headdim_v=self.hdimv, total_q=cute.size(mQ.shape[0]) if const_expr(mCuSeqlensQ is not None) else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]), - tile_shape_mn=(self.cta_tile_m, self.tile_n,), + tile_shape_mn=( + self.cta_tile_m, + self.tile_n, + ), mCuSeqlensQ=mCuSeqlensQ, mSeqUsedQ=mSeqUsedQ, qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, @@ -585,7 +637,7 @@ def make_tma(make_fn, mX, smem_layout, mma_tiler, tiled_mma): self.tile_scheduler_cls = TileScheduler grid_dim = TileScheduler.get_grid_shape(tile_sched_params) fa_printf(1, "grid = {}", grid_dim) - + # ==== Named Barrier ==== self.cpasync_barrier = cutlass.pipeline.NamedBarrier( barrier_id=int(NamedBarrierFwdSm100_MLA2CTA.Cpasync), @@ -659,9 +711,13 @@ def make_tma(make_fn, mX, smem_layout, mma_tiler, tiled_mma): SharedStorage, ).launch( grid=grid_dim, - block=(self.num_threads, 1, 1,), + block=( + self.num_threads, + 1, + 1, + ), cluster=self.cluster_shape_mnk, - smem = SharedStorage.size_in_bytes(), + smem=SharedStorage.size_in_bytes(), ) @cute.kernel @@ -715,7 +771,7 @@ def kernel( cta_layout_vmnk = cute.tiled_divide( cute.make_layout(self.cluster_shape_mnk), (tiled_mma_QK.thr_id.shape,) ) - + cta_m_block, head_idx, batch_idx = cute.arch.block_idx() cluster_m_block = cta_m_block // self.cta_group_size mma_tile_coord_v = cta_m_block % cute.size(tiled_mma_QK.thr_id.shape) @@ -753,17 +809,21 @@ def kernel( cpasync.prefetch_descriptor(tma_atom_O) # ==== Construct pipelines ==== - tma_warp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 1) - mma_warp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 1) - sm_threads = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.num_softmax_threads) - epi_threads = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.num_epilogue_threads) - sm_threads_cluster = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.num_softmax_threads * self.cta_group_size) - epi_threads_cluster = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.num_epilogue_threads * self.cta_group_size) - - TmaUmma = pipeline.PipelineTmaUmma + tma_warp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 1) + mma_warp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 1) + sm_threads = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.num_softmax_threads) + epi_threads = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.num_epilogue_threads) + sm_threads_cluster = pipeline.CooperativeGroup( + pipeline.Agent.Thread, self.num_softmax_threads * self.cta_group_size + ) + epi_threads_cluster = pipeline.CooperativeGroup( + pipeline.Agent.Thread, self.num_epilogue_threads * self.cta_group_size + ) + + TmaUmma = pipeline.PipelineTmaUmma AsyncUmma = pipeline.PipelineAsyncUmma UmmaAsync = pipeline.PipelineUmmaAsync - Async = pipeline.PipelineAsync + Async = pipeline.PipelineAsync def make_pipeline(cls, mbar_ptr, num_stages, producer, consumer, tx_count=None): return cls.create( @@ -777,6 +837,7 @@ def make_pipeline(cls, mbar_ptr, num_stages, producer, consumer, tx_count=None): ) # Unconditional pipelines + # fmt: off pipeline_Q = make_pipeline(TmaUmma, storage.mbar_ptr_Q, self.num_stages_Q, tma_warp, mma_warp, self.tma_copy_bytes_Q) pipeline_Qv0 = make_pipeline(TmaUmma, storage.mbar_ptr_Qv0, self.num_stages_Qvi, tma_warp, mma_warp, self.tma_copy_bytes_Qvi) pipeline_Qv1 = make_pipeline(TmaUmma, storage.mbar_ptr_Qv1, self.num_stages_Qvi, tma_warp, mma_warp, self.tma_copy_bytes_Qvi) @@ -807,7 +868,8 @@ def make_pipeline(cls, mbar_ptr, num_stages, producer, consumer, tx_count=None): make_pipeline(Async, storage.mbar_ptr_bitmask, self.num_stages_bitmask, cpasync_load_threads, sm_threads) if const_expr(self.is_topk_gather and not self.disable_bitmask) else None ) - + # fmt: on + sO_empty_mbar_ptr = None if const_expr(self.use_tma_O and self.overlap_sO_sV): sO_empty_mbar_ptr = storage.sO_empty_mbar_ptr @@ -817,6 +879,7 @@ def make_pipeline(cls, mbar_ptr, num_stages, producer, consumer, tx_count=None): pipeline.pipeline_init_arrive(cluster_shape_mn=cta_layout_vmnk, is_relaxed=True) # ==== Get SMEM tensors ==== + # fmt: off sQ, sK, sQv0, sQv1, sV0, sV1, sVt0, sVt1, sP = ( store.get_tensor(layout.outer, swizzle=layout.inner) for store, layout in [ @@ -831,9 +894,10 @@ def make_pipeline(cls, mbar_ptr, num_stages, producer, consumer, tx_count=None): (storage.sP, sP_layout_staged), ] ) + # fmt: on sRowMax = storage.sRowMax.get_tensor(sStats_layout) sRowSum = storage.sRowSum.get_tensor(sStats_layout) - sScale = storage.sScale.get_tensor(sScale_layout) + sScale = storage.sScale.get_tensor(sScale_layout) sBitmask = None if const_expr(self.is_topk_gather): sBitmask = storage.sBitmask.get_tensor(sBitmask_layout) @@ -844,7 +908,9 @@ def make_pipeline(cls, mbar_ptr, num_stages, producer, consumer, tx_count=None): else: sO_iterator = sQv0.iterator assert cute.cosize(sO_layout) <= cute.cosize(sQvi_layout_staged) * self.num_hdimv_splits - sO = cute.make_tensor(cute.recast_ptr(sO_iterator, sO_layout.inner, self.dtype_O), sO_layout.outer) + sO = cute.make_tensor( + cute.recast_ptr(sO_iterator, sO_layout.inner, self.dtype_O), sO_layout.outer + ) # ==== Get thread MMAs and accumulator fragments ==== thr_mma_QK = tiled_mma_QK.get_slice(mma_tile_coord_v) @@ -888,9 +954,7 @@ def make_pipeline(cls, mbar_ptr, num_stages, producer, consumer, tx_count=None): clc_response_ptr = storage.clc_response.data_ptr() clc_mbar_ptr = storage.clc_mbar_ptr.data_ptr() - clc_pipeline_producer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread - ) + clc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) num_clc_consumer_warps_per_cta = self.num_threads // cute.arch.WARP_SIZE num_clc_consumer_warps = num_clc_consumer_warps_per_cta * self.cta_group_size clc_pipeline_consumer_group = pipeline.CooperativeGroup( @@ -921,7 +985,9 @@ def make_pipeline(cls, mbar_ptr, num_stages, producer, consumer, tx_count=None): tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params, clc=clc) else: tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params) - assert isinstance(tile_scheduler, TileSchedulerProtocol), f"tile_scheduler is not a TileSchedulerProtocol: {type(tile_scheduler)}" + assert isinstance(tile_scheduler, TileSchedulerProtocol), ( + f"tile_scheduler is not a TileSchedulerProtocol: {type(tile_scheduler)}" + ) pipeline.pipeline_init_wait(cluster_shape_mn=cta_layout_vmnk) @@ -1143,7 +1209,7 @@ def clc_scheduler_warp( tile_scheduler.prefetch_next_work() work_tile = tile_scheduler.advance_to_next_work() cta_m_block, head_idx, batch_idx, _ = work_tile.tile_idx - if cute.arch.thread_idx()[0] == self.clc_scheduler_warp_id * cute.arch.WARP_SIZE and cute.arch.block_idx() == (0, 0, 0): + if cute.arch.thread_idx()[0] == self.clc_scheduler_warp_id * cute.arch.WARP_SIZE: fa_printf( 3, "[CLC] query sm={} cta={} (m_blk={},h={},b={},s={}) valid={}\n", @@ -1218,7 +1284,8 @@ def relay( # n_block_max = topk_length_dynamic // self.tile_n else: n_block_min, n_block_max = block_info.get_n_block_min_max( - seqlen, cluster_m_block, + seqlen, + cluster_m_block, ) num_n_blocks = n_block_max - n_block_min @@ -1227,15 +1294,19 @@ def relay( consumer_state_K, producer_state_K = relay_K_fn(consumer_state_K, producer_state_K) consumer_state_V0, producer_state_V0 = relay_V0_fn(consumer_state_V0, producer_state_V0) consumer_state_V1, producer_state_V1 = relay_V1_fn(consumer_state_V1, producer_state_V1) - + # ==== Mainloop ==== - for _ in cutlass.range(num_n_blocks-1, unroll=2): + for _ in cutlass.range(num_n_blocks - 1, unroll=2): # relay K, V0, V1, Vt0, Vt1 consumer_state_K, producer_state_K = relay_K_fn(consumer_state_K, producer_state_K) for _ in cutlass.range_constexpr(2): - consumer_state_V0, producer_state_V0 = relay_V0_fn(consumer_state_V0, producer_state_V0) - consumer_state_V1, producer_state_V1 = relay_V1_fn(consumer_state_V1, producer_state_V1) - + consumer_state_V0, producer_state_V0 = relay_V0_fn( + consumer_state_V0, producer_state_V0 + ) + consumer_state_V1, producer_state_V1 = relay_V1_fn( + consumer_state_V1, producer_state_V1 + ) + # ==== Epilogue === # relay Vt0, Vt1 consumer_state_V0, producer_state_V0 = relay_V0_fn(consumer_state_V0, producer_state_V0) @@ -1248,7 +1319,6 @@ def relay( pipeline_V0.producer_tail(producer_state_V0) pipeline_V1.producer_tail(producer_state_V1) - @cute.jit def relay_inner( self, @@ -1264,7 +1334,6 @@ def relay_inner( producer_state.advance() return consumer_state, producer_state - @cute.jit def load_cpasync( self, @@ -1297,13 +1366,15 @@ def load_cpasync( # Description: loads tiles of K, V, V0, V1 from gmem to smem using cpasync # produces: K, V, V0, V1, bitmask # consumes: - - + # TODO: use cpasync for non-topk paged attn assert sBitmask is not None, "cpasync load meant to be used with topk gather" cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) - tidx = cute.arch.thread_idx()[0] % self.num_cpasync_load_threads - warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % (self.num_cpasync_load_threads//32) - + tidx = cute.arch.thread_idx()[0] % self.num_cpasync_load_threads + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % ( + self.num_cpasync_load_threads // 32 + ) + # ==== Make pipeline states ==== # producer: acquire PipelineAsyncUmma <- mma # producer: commit PipelineAsync -> relay @@ -1318,7 +1389,8 @@ def load_cpasync( ) if const_expr(not self.disable_bitmask): producer_state_bitmask = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, stages=self.num_stages_bitmask, + pipeline.PipelineUserType.Producer, + stages=self.num_stages_bitmask, ) if const_expr(self.use_tma_O): producer_phase_O = Int32(1) @@ -1338,7 +1410,8 @@ def load_cpasync( # n_block_max = topk_length_dynamic // self.tile_n else: n_block_min, n_block_max = block_info.get_n_block_min_max( - seqlen, cluster_m_block, + seqlen, + cluster_m_block, ) num_n_blocks = n_block_max - n_block_min num_n_block_groups = cute.ceil_div(num_n_blocks, self.num_stages_S) @@ -1350,7 +1423,7 @@ def load_cpasync( else: offset_q = seqlen.offset_q mIndexTopk_cur = mIndexTopk[None, m_idx + offset_q] - + if const_expr(self.is_causal): seqlen_k_limit = m_idx + 1 + seqlen.seqlen_k - seqlen.seqlen_q else: @@ -1389,28 +1462,62 @@ def load_cpasync( mVt1_cur = cute.domain_offset((0, seqlen.offset_k), mVt1[None, None, head_idx_kv]) # (hdimv//4, seqlen_k) hdimv_split_per_cta = self.hdimv // self.num_hdimv_splits // self.cta_group_size - mVt0_cur = cute.tiled_divide(mVt0_cur, (hdimv_split_per_cta, ))[None, cta_rank_in_cluster, None] - mVt1_cur = cute.tiled_divide(mVt1_cur, (hdimv_split_per_cta, ))[None, cta_rank_in_cluster, None] + mVt0_cur = cute.tiled_divide(mVt0_cur, (hdimv_split_per_cta,))[ + None, cta_rank_in_cluster, None + ] + mVt1_cur = cute.tiled_divide(mVt1_cur, (hdimv_split_per_cta,))[ + None, cta_rank_in_cluster, None + ] - load_K = partial(self.cpasync_gather_load_KV, + load_K = partial( + self.cpasync_gather_load_KV, cpasync_gather_kv_manager, - pipeline_K, pipeline_K_cpasync, sK, False, "K", mK_cur, + pipeline_K, + pipeline_K_cpasync, + sK, + False, + "K", + mK_cur, ) - load_V0 = partial(self.cpasync_gather_load_KV, + load_V0 = partial( + self.cpasync_gather_load_KV, cpasync_gather_kv_manager, - pipeline_V0, pipeline_V0_cpasync, sV0, False, "V", mV0_cur, + pipeline_V0, + pipeline_V0_cpasync, + sV0, + False, + "V", + mV0_cur, ) - load_V1 = partial(self.cpasync_gather_load_KV, + load_V1 = partial( + self.cpasync_gather_load_KV, cpasync_gather_kv_manager, - pipeline_V1, pipeline_V1_cpasync, sV1, False, "V", mV1_cur, + pipeline_V1, + pipeline_V1_cpasync, + sV1, + False, + "V", + mV1_cur, ) - load_Vt0 = partial(self.cpasync_gather_load_KV, + load_Vt0 = partial( + self.cpasync_gather_load_KV, cpasync_gather_kv_manager, - pipeline_V0, pipeline_V0_cpasync, sVt0, True, "V", mVt0_cur, + pipeline_V0, + pipeline_V0_cpasync, + sVt0, + True, + "V", + mVt0_cur, ) - load_Vt1 = partial(self.cpasync_gather_load_KV, + load_Vt1 = partial( + self.cpasync_gather_load_KV, cpasync_gather_kv_manager, - pipeline_V1, pipeline_V1_cpasync, sVt1, True, "V", mVt1_cur, + pipeline_V1, + pipeline_V1_cpasync, + sVt1, + True, + "V", + mVt1_cur, ) # gather KV path processes n_blocks in increasing order @@ -1432,7 +1539,7 @@ def load_cpasync( producer_phase_O ^= 1 # ==== Mainloop ==== - for n_block_group in cutlass.range(num_n_block_groups-1, unroll=1): + for n_block_group in cutlass.range(num_n_block_groups - 1, unroll=1): for stage in cutlass.range_constexpr(self.num_stages_S): n_block = n_block_group * self.num_stages_S + stage # K, V0, V1 @@ -1451,7 +1558,7 @@ def load_cpasync( # ==== Epilogue ==== for stage in cutlass.range_constexpr(self.num_stages_S): - n_block = (num_n_block_groups-1) * self.num_stages_S + stage + n_block = (num_n_block_groups - 1) * self.num_stages_S + stage if const_expr(stage == 0): # K, V0, V1 cpasync_gather_kv_manager.load_index_topk(n_block + 1, transpose=False) @@ -1462,7 +1569,7 @@ def load_cpasync( producer_state_bitmask = cpasync_gather_kv_manager.compute_bitmask( producer_state_bitmask ) - + # Vt0, Vt1 cpasync_gather_kv_manager.load_index_topk(n_block, transpose=True) producer_state_V0 = load_Vt0(producer_state_V0) @@ -1470,14 +1577,13 @@ def load_cpasync( # Advance to next tile work_tile = tile_scheduler.advance_to_next_work() - + pipeline_K_cpasync.producer_tail(producer_state_K) pipeline_V0_cpasync.producer_tail(producer_state_V0) pipeline_V1_cpasync.producer_tail(producer_state_V1) if const_expr(not self.disable_bitmask): pipeline_bitmask.producer_tail(producer_state_bitmask) - @cute.jit def cpasync_gather_load_KV( self, @@ -1492,14 +1598,11 @@ def cpasync_gather_load_KV( ): stage, phase = producer_state.index, producer_state.phase pipeline_mma.producer_acquire(producer_state) - cpasync_gather_kv_manager.load_X( - mX, sX[None, None, None, stage], transpose, K_or_V - ) + cpasync_gather_kv_manager.load_X(mX, sX[None, None, None, stage], transpose, K_or_V) cute.arch.cp_async_commit_group() pipeline_cpasync.sync_object_full.arrive_cp_async_mbarrier(stage) producer_state.advance() return producer_state - @cute.jit def load( @@ -1562,7 +1665,7 @@ def load( pipeline_Qvs = [pipeline_Qv0, pipeline_Qv1] pipeline_Vs = [pipeline_V0, pipeline_V1] - + # ==== Make pipeline states ==== producer_state_Q = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, stages=self.num_stages_Q @@ -1601,7 +1704,8 @@ def load( # n_block_max = topk_length_dynamic // self.tile_n else: n_block_min, n_block_max = block_info.get_n_block_min_max( - seqlen, cluster_m_block, + seqlen, + cluster_m_block, ) num_n_blocks = n_block_max - n_block_min even_n_blocks = num_n_blocks % 2 == 0 and num_n_blocks > 0 @@ -1616,21 +1720,21 @@ def load( ] # (mma_tile_m, hdim or hdimv//2) gQ = cute.local_tile( - mQ_cur, + mQ_cur, (self.mma_tiler_QK[0], self.mma_tiler_QK[2]), (cluster_m_block, 0), - ) + ) gQvs = [ cute.local_tile( mQvs_cur[split], (self.mma_tiler_QviVi[0], self.mma_tiler_QviVi[2]), (cluster_m_block, 0), - ) for split in range(self.num_hdimv_splits) + ) + for split in range(self.num_hdimv_splits) ] tSgQ = thr_mma_QK.partition_A(gQ) tSgQvs = [ - thr_mma_QviVi.partition_A(gQvs[split]) - for split in range(self.num_hdimv_splits) + thr_mma_QviVi.partition_A(gQvs[split]) for split in range(self.num_hdimv_splits) ] tQsQ, tQgQ = cpasync.tma_partition( atom=tma_atom_Q, @@ -1639,16 +1743,18 @@ def load( smem_tensor=cute.group_modes(sQ, 0, 3), gmem_tensor=cute.group_modes(tSgQ, 0, 3), ) - tQvsQvs, tQvgQvs = zip(*[ - cpasync.tma_partition( - atom=tma_atom, - cta_coord=0, - cta_layout=cute.make_layout(1), - smem_tensor=cute.group_modes(sQv, 0, 3), - gmem_tensor=cute.group_modes(tSgQv, 0, 3), - ) - for tma_atom, sQv, tSgQv in zip(tma_atom_Qvs, sQvs, tSgQvs) - ]) + tQvsQvs, tQvgQvs = zip( + *[ + cpasync.tma_partition( + atom=tma_atom, + cta_coord=0, + cta_layout=cute.make_layout(1), + smem_tensor=cute.group_modes(sQv, 0, 3), + gmem_tensor=cute.group_modes(tSgQv, 0, 3), + ) + for tma_atom, sQv, tSgQv in zip(tma_atom_Qvs, sQvs, tSgQvs) + ] + ) if const_expr(self.use_tma_KV): # (seqlen_k, hdim) or (seqlen_k, hdimv//2) @@ -1665,7 +1771,9 @@ def load( ] else: mVts_cur = [ - cute.domain_offset((0, seqlen.offset_k), mVts[split][None, None, head_idx_kv]) + cute.domain_offset( + (0, seqlen.offset_k), mVts[split][None, None, head_idx_kv] + ) for split in range(self.num_hdimv_splits) ] # (tile_n, hdim or hdimv//2, num_n_blocks) @@ -1673,13 +1781,14 @@ def load( mK_cur, (self.mma_tiler_QK[1], self.mma_tiler_QK[2]), (None, 0), - ) + ) gVs = [ cute.local_tile( mVs_cur[split], (self.mma_tiler_QviVi[1], self.mma_tiler_QviVi[2]), (None, 0), - ) for split in range(self.num_hdimv_splits) + ) + for split in range(self.num_hdimv_splits) ] # (hdim or hdimv//2, tile_n, num_n_blocks) gVts = [ @@ -1687,16 +1796,15 @@ def load( mVts_cur[split], (self.mma_tiler_PVti[1], self.mma_tiler_PVti[2]), (0, None), - ) for split in range(self.num_hdimv_splits) + ) + for split in range(self.num_hdimv_splits) ] tSgK = thr_mma_QK.partition_B(gK) tSgVs = [ - thr_mma_QviVi.partition_B(gVs[split]) - for split in range(self.num_hdimv_splits) + thr_mma_QviVi.partition_B(gVs[split]) for split in range(self.num_hdimv_splits) ] tOgVts = [ - thr_mma_PVti.partition_B(gVts[split]) - for split in range(self.num_hdimv_splits) + thr_mma_PVti.partition_B(gVts[split]) for split in range(self.num_hdimv_splits) ] tKsK, tKgK = cpasync.tma_partition( atom=tma_atom_K, @@ -1705,26 +1813,30 @@ def load( smem_tensor=cute.group_modes(sK, 0, 3), gmem_tensor=cute.group_modes(tSgK, 0, 3), ) - tVsVs, tVgVs = zip(*[ - cpasync.tma_partition( - atom=tma_atom, - cta_coord=0, - cta_layout=cute.make_layout(1), - smem_tensor=cute.group_modes(sV, 0, 3), - gmem_tensor=cute.group_modes(tSgV, 0, 3), - ) - for tma_atom, sV, tSgV in zip(tma_atom_Vs, sVs, tSgVs) - ]) - tVtsVts, tVtgVts = zip(*[ - cpasync.tma_partition( - atom=tma_atom, - cta_coord=0, - cta_layout=cute.make_layout(1), - smem_tensor=cute.group_modes(sVt, 0, 3), - gmem_tensor=cute.group_modes(tOgV, 0, 3), - ) - for tma_atom, sVt, tOgV in zip(tma_atom_Vts, sVts, tOgVts) - ]) + tVsVs, tVgVs = zip( + *[ + cpasync.tma_partition( + atom=tma_atom, + cta_coord=0, + cta_layout=cute.make_layout(1), + smem_tensor=cute.group_modes(sV, 0, 3), + gmem_tensor=cute.group_modes(tSgV, 0, 3), + ) + for tma_atom, sV, tSgV in zip(tma_atom_Vs, sVs, tSgVs) + ] + ) + tVtsVts, tVtgVts = zip( + *[ + cpasync.tma_partition( + atom=tma_atom, + cta_coord=0, + cta_layout=cute.make_layout(1), + smem_tensor=cute.group_modes(sVt, 0, 3), + gmem_tensor=cute.group_modes(tOgV, 0, 3), + ) + for tma_atom, sVt, tOgV in zip(tma_atom_Vts, sVts, tOgVts) + ] + ) load_Q = partial(self.load_inner, tma_atom_Q, tQgQ, tQsQ, pipeline_Q) load_Qv = partial(self.load_inner, tma_atom_Qvs, tQvgQvs, tQvsQvs, pipeline_Qvs) @@ -1754,28 +1866,28 @@ def load( producer_phase_O ^= 1 # ==== Main loop ==== - for n_block_group in cutlass.range(num_n_block_groups-1, unroll=1): + for n_block_group in cutlass.range(num_n_block_groups - 1, unroll=1): for stage in cutlass.range_constexpr(self.num_stages_S): n_block = n_block_max - 1 - n_block_group * self.num_stages_S - stage # copy K gmem -> smem - producer_state_K = load_K(producer_state_K, n_block=n_block-1) + producer_state_K = load_K(producer_state_K, n_block=n_block - 1) # copy Vi gmem -> smem - producer_state_V0 = load_V(producer_state_V0, n_block=n_block-1, split=0) - producer_state_V1 = load_V(producer_state_V1, n_block=n_block-1, split=1) + producer_state_V0 = load_V(producer_state_V0, n_block=n_block - 1, split=0) + producer_state_V1 = load_V(producer_state_V1, n_block=n_block - 1, split=1) # copy Vti gmem -> smem producer_state_V0 = load_Vt(producer_state_V0, n_block=n_block, split=0) producer_state_V1 = load_Vt(producer_state_V1, n_block=n_block, split=1) - + # ==== Epilogue ==== num_final_n_blocks = self.num_stages_S if even_n_blocks else self.num_stages_S - 1 for stage in cutlass.range(num_final_n_blocks, unroll_full=True): n_block = num_final_n_blocks - 1 - stage if n_block > 0: # copy K gmem -> smem - producer_state_K = load_K(producer_state_K, n_block=n_block-1) + producer_state_K = load_K(producer_state_K, n_block=n_block - 1) # copy Vi gmem -> smem - producer_state_V0 = load_V(producer_state_V0, n_block=n_block-1, split=0) - producer_state_V1 = load_V(producer_state_V1, n_block=n_block-1, split=1) + producer_state_V0 = load_V(producer_state_V0, n_block=n_block - 1, split=0) + producer_state_V1 = load_V(producer_state_V1, n_block=n_block - 1, split=1) # copy Vti gmem -> smem producer_state_V0 = load_Vt(producer_state_V0, n_block=n_block, split=0) producer_state_V1 = load_Vt(producer_state_V1, n_block=n_block, split=1) @@ -1872,19 +1984,16 @@ def mma( # Operands for S += Qv @ V^T tSrQvs = [ - tiled_mma_QviVi.make_fragment_A(sQvs[split]) - for split in range(self.num_hdimv_splits) + tiled_mma_QviVi.make_fragment_A(sQvs[split]) for split in range(self.num_hdimv_splits) ] tSrVs = [ - tiled_mma_QviVi.make_fragment_B(sVs[split]) - for split in range(self.num_hdimv_splits) + tiled_mma_QviVi.make_fragment_B(sVs[split]) for split in range(self.num_hdimv_splits) ] # Operands for Oi = P @ Vi tOrP = tiled_mma_PVti.make_fragment_A(sP) tOrVts = [ - tiled_mma_PVti.make_fragment_B(sVts[split]) - for split in range(self.num_hdimv_splits) + tiled_mma_PVti.make_fragment_B(sVts[split]) for split in range(self.num_hdimv_splits) ] # GEMM functions @@ -1897,7 +2006,8 @@ def mma( sA=sQ[None, None, None, 0], zero_init=True, cta_group=self.cta_group_size, - ) for stage in range(self.num_stages_S) + ) + for stage in range(self.num_stages_S) ] gemms_QvV = [ [ @@ -1909,8 +2019,10 @@ def mma( sA=sQvs[split][None, None, None, 0], zero_init=False, cta_group=self.cta_group_size, - ) for stage in range(self.num_stages_S) - ] for split in range(self.num_hdimv_splits) + ) + for stage in range(self.num_stages_S) + ] + for split in range(self.num_hdimv_splits) ] gemms_PVt = [ partial( @@ -1920,7 +2032,8 @@ def mma( tOrP[None, None, None, 0], sA=sP[None, None, None, 0], cta_group=self.cta_group_size, - ) for split in range(self.num_hdimv_splits) + ) + for split in range(self.num_hdimv_splits) ] consumer_state_Q = pipeline.make_pipeline_state( @@ -1971,7 +2084,8 @@ def mma( n_block_max = topk_length_dynamic // self.tile_n else: n_block_min, n_block_max = block_info.get_n_block_min_max( - seqlen, cluster_m_block, + seqlen, + cluster_m_block, ) num_n_blocks = n_block_max - n_block_min even_n_blocks = num_n_blocks % 2 == 0 and num_n_blocks > 0 @@ -1998,7 +2112,7 @@ def mma( producer_state_S.advance() # ==== Mainloop ==== - for _ in cutlass.range(num_n_block_groups-1, unroll=1): + for _ in cutlass.range(num_n_block_groups - 1, unroll=1): for stage in cutlass.range_constexpr(self.num_stages_S): next_stage = const_expr((stage + 1) % self.num_stages_S) pipeline_S.producer_acquire(producer_state_S) @@ -2019,7 +2133,7 @@ def mma( consumer_states_V[split] = mma_PVt( consumer_states_V[split], split=split, - zero_init=not O_should_accumulate + zero_init=not O_should_accumulate, ) pipelines_O[split].producer_commit(producer_state_Oi) producer_state_Oi.advance() @@ -2027,7 +2141,7 @@ def mma( pipeline_P.consumer_release(consumer_state_P) consumer_state_P.advance() O_should_accumulate = True - + # ==== Epilogue ==== num_final_n_blocks = self.num_stages_S if even_n_blocks else self.num_stages_S - 1 for stage in cutlass.range_constexpr(self.num_stages_S): @@ -2036,11 +2150,11 @@ def mma( if n_block > 0: pipeline_S.producer_acquire(producer_state_S) # S = Q @ K^T - consumer_state_K = mma_QK(consumer_state_K, stage=stage+1) + consumer_state_K = mma_QK(consumer_state_K, stage=stage + 1) # S += Qvi @ Vi^T for split in cutlass.range_constexpr(self.num_hdimv_splits): consumer_states_V[split] = mma_QvV( - consumer_states_V[split], stage=stage+1, split=split + consumer_states_V[split], stage=stage + 1, split=split ) pipeline_S.producer_commit(producer_state_S) producer_state_S.advance() @@ -2053,7 +2167,7 @@ def mma( consumer_states_V[split] = mma_PVt( consumer_states_V[split], split=split, - zero_init=not O_should_accumulate + zero_init=not O_should_accumulate, ) pipelines_O[split].producer_commit(producer_state_Oi) producer_state_Oi.advance() @@ -2071,7 +2185,7 @@ def mma( if const_expr(self.use_tma_O and not self.overlap_sO_sV): pipeline_O0.producer_tail(producer_state_O0.clone()) pipeline_O1.producer_tail(producer_state_O1.clone()) - + pipeline_Qv0.consumer_release(consumer_state_Qv0) pipeline_Qv1.consumer_release(consumer_state_Qv1) consumer_state_Q.advance() @@ -2118,7 +2232,6 @@ def mma_inner( load_pipeline.consumer_release(consumer_state) consumer_state.advance() return consumer_state - @cute.jit def softmax_loop( @@ -2150,11 +2263,13 @@ def softmax_loop( # Consumes: S, bitmask (for topk sparsity) tidx = cute.arch.thread_idx()[0] % self.num_softmax_threads - warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % (self.num_softmax_threads//32) + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % ( + self.num_softmax_threads // 32 + ) tSAcc = tStS[(None, None), 0, 0, 0] tSAcc_staged = [tStS[(None, None), 0, 0, stage] for stage in range(self.num_stages_S)] - + cS = cute.make_identity_tensor(self.mma_tiler_QK[:2]) # (128, 128) tScS = thr_mma_QK.partition_C(cS)[(None, None), 0, 0] # (64, 128) @@ -2167,7 +2282,9 @@ def softmax_loop( tmem_load_thr = tmem_load_tiled.get_slice(tidx) # S tmem -> rmem copy operands tStS_t2r = tmem_load_thr.partition_S(tSAcc) # (((32, 32), 1), 1, 2) - tStS_t2r_staged = [tmem_load_thr.partition_S(tSAcc_staged[stage]) for stage in range(self.num_stages_S)] + tStS_t2r_staged = [ + tmem_load_thr.partition_S(tSAcc_staged[stage]) for stage in range(self.num_stages_S) + ] tScS_t2r = tmem_load_thr.partition_D(tScS) tSrS_t2r = cute.make_rmem_tensor(tScS_t2r.shape, self.dtype_acc) @@ -2223,7 +2340,8 @@ def softmax_loop( # n_block_max = topk_length_dynamic // self.tile_n else: n_block_min, n_block_max = block_info.get_n_block_min_max( - seqlen, cluster_m_block, + seqlen, + cluster_m_block, ) num_n_blocks = n_block_max - n_block_min even_n_blocks = num_n_blocks % 2 == 0 and num_n_blocks > 0 @@ -2239,7 +2357,7 @@ def softmax_loop( mask_local=self.is_local, batch_idx=batch_idx, head_idx=head_idx, - r2p=False, # TODO: fix r2p for 2cta + r2p=False, # TODO: fix r2p for 2cta ) disable_mask = self.disable_bitmask and self.is_topk_gather @@ -2271,12 +2389,21 @@ def softmax_loop( ### first iteration ### n_block = n_block_max - 1 - (consumer_state_S, producer_state_P, producer_state_sm_stats, consumer_state_bitmask - ) = softmax_step_fn( - consumer_state_S, producer_state_P, producer_state_sm_stats, consumer_state_bitmask, - 0, n_block, + ( + consumer_state_S, + producer_state_P, + producer_state_sm_stats, + consumer_state_bitmask, + ) = softmax_step_fn( + consumer_state_S, + producer_state_P, + producer_state_sm_stats, + consumer_state_bitmask, + 0, + n_block, mask_fn=partial(mask_fn, mask_seqlen=True) - if not const_expr(disable_mask) else None, + if not const_expr(disable_mask) + else None, is_first=True, ) n_block -= 1 @@ -2289,46 +2416,71 @@ def softmax_loop( ) num_masked_n_blocks = n_block_max - 1 - n_block_min_causal_local_mask num_masked_n_block_groups = min( - num_n_block_groups - 1, - cute.ceil_div(num_masked_n_blocks, self.num_stages_S) + num_n_block_groups - 1, cute.ceil_div(num_masked_n_blocks, self.num_stages_S) ) num_n_block_groups -= num_masked_n_block_groups for _ in cutlass.range(num_masked_n_block_groups, unroll=1): for stage in cutlass.range_constexpr(self.num_stages_S): - (consumer_state_S, producer_state_P, producer_state_sm_stats, consumer_state_bitmask - ) = softmax_step_fn( - consumer_state_S, producer_state_P, producer_state_sm_stats, consumer_state_bitmask, - 1-stage, n_block, + ( + consumer_state_S, + producer_state_P, + producer_state_sm_stats, + consumer_state_bitmask, + ) = softmax_step_fn( + consumer_state_S, + producer_state_P, + producer_state_sm_stats, + consumer_state_bitmask, + 1 - stage, + n_block, mask_fn=partial(mask_fn, mask_seqlen=False), ) n_block -= 1 ### Mainloop ### - for n_block_group in cutlass.range(num_n_block_groups-1, unroll=1): + for n_block_group in cutlass.range(num_n_block_groups - 1, unroll=1): for stage in cutlass.range_constexpr(self.num_stages_S): - (consumer_state_S, producer_state_P, producer_state_sm_stats, consumer_state_bitmask - ) = softmax_step_fn( - consumer_state_S, producer_state_P, producer_state_sm_stats, consumer_state_bitmask, - 1-stage, n_block, + ( + consumer_state_S, + producer_state_P, + producer_state_sm_stats, + consumer_state_bitmask, + ) = softmax_step_fn( + consumer_state_S, + producer_state_P, + producer_state_sm_stats, + consumer_state_bitmask, + 1 - stage, + n_block, mask_fn=partial(mask_fn, mask_seqlen=False) - if const_expr(self.is_topk_gather and not self.disable_bitmask) else None, + if const_expr(self.is_topk_gather and not self.disable_bitmask) + else None, ) n_block -= 1 ### last iteration if even ### # always mask to simplify logic if even_n_blocks: - (consumer_state_S, producer_state_P, producer_state_sm_stats, consumer_state_bitmask - ) = softmax_step_fn( - consumer_state_S, producer_state_P, producer_state_sm_stats, consumer_state_bitmask, - 1, n_block, + ( + consumer_state_S, + producer_state_P, + producer_state_sm_stats, + consumer_state_bitmask, + ) = softmax_step_fn( + consumer_state_S, + producer_state_P, + producer_state_sm_stats, + consumer_state_bitmask, + 1, + n_block, mask_fn=partial(mask_fn, mask_seqlen=False) - if not const_expr(disable_mask) else None, + if not const_expr(disable_mask) + else None, ) n_block -= 1 - + # write row max and sum to smem - sRowSum[tidx % self.cta_tile_m, warp_idx//self.cta_group_size] = softmax.row_sum[0] + sRowSum[tidx % self.cta_tile_m, warp_idx // self.cta_group_size] = softmax.row_sum[0] if const_expr(mLSE is not None): if tidx < self.cta_tile_m: sRowMax[tidx, 0] = softmax.row_max[0] @@ -2341,7 +2493,6 @@ def softmax_loop( pipeline_P.producer_tail(producer_state_P) pipeline_sm_stats.producer_tail(producer_state_sm_stats) - @cute.jit def softmax_step( self, @@ -2382,8 +2533,8 @@ def softmax_step( assert pipeline_bitmask is not None assert consumer_state_bitmask is not None pipeline_bitmask.consumer_wait(consumer_state_bitmask) - rBitmask = cute.make_rmem_tensor((self.tile_n//64, ), dtype=Uint32) - bitmask_col_offset = self.tile_n//64 if warp_idx >= 2 else 0 + rBitmask = cute.make_rmem_tensor((self.tile_n // 64,), dtype=Uint32) + bitmask_col_offset = self.tile_n // 64 if warp_idx >= 2 else 0 for i in cutlass.range_constexpr(cute.size(rBitmask)): rBitmask[i] = sBitmask[bitmask_col_offset + i, consumer_state_bitmask.index] @@ -2396,7 +2547,7 @@ def softmax_step( # 2-thread reduce row_max through smem assert self.cta_tile_m * self.cta_group_size == 128 - sRowMax[tidx % self.cta_tile_m, warp_idx//self.cta_group_size] = row_max + sRowMax[tidx % self.cta_tile_m, warp_idx // self.cta_group_size] = row_max self.softmax_barrier.arrive_and_wait() # must release after barrier sync if const_expr(self.is_topk_gather and not self.disable_bitmask): @@ -2412,7 +2563,7 @@ def softmax_step( if warp_idx < self.cta_group_size: sScale[tidx % self.cta_tile_m, producer_state_sm_stats.index] = acc_scale pipeline_sm_stats.producer_commit(producer_state_sm_stats) - + # x -> scale_log2*x-rowmax softmax.scale_subtract_rowmax(tSrS_t2r, row_max) @@ -2430,11 +2581,10 @@ def softmax_step( if const_expr(self.is_topk_gather and not self.disable_bitmask): consumer_state_bitmask.advance() - softmax.update_row_sum(tSrS_t2r.load(), acc_scale, is_first) + softmax.update_row_sum(tSrS_t2r.load(), acc_scale, is_first) return consumer_state_S, producer_state_P, producer_state_sm_stats, consumer_state_bitmask - @cute.jit def correction_loop( self, @@ -2464,11 +2614,13 @@ def correction_loop( # optionally store LSE # Produces: - # Consumes: O, softmax stats - + tidx = cute.arch.thread_idx()[0] % self.num_epilogue_threads - warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % (self.num_epilogue_threads//32) + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % ( + self.num_epilogue_threads // 32 + ) cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) - leader_warp = warp_idx==0 + leader_warp = warp_idx == 0 tO0tO0 = tO0tO0[(None, None), 0, 0] # (64, (128, 2)) tO1tO1 = tO1tO1[(None, None), 0, 0] # (64, (128, 2)) @@ -2476,7 +2628,7 @@ def correction_loop( # tuneable parameter corr_tile_size = math.gcd(32, self.tmem_cols_Oi) - + tmem_load_atom_O = cute.make_copy_atom( tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.dtype_acc, @@ -2490,12 +2642,10 @@ def correction_loop( # ((32,1),1,4) tOtOs_t2r = [ - thr_tmem_load_O.partition_S(tOtOs[split]) - for split in range(self.num_hdimv_splits) + thr_tmem_load_O.partition_S(tOtOs[split]) for split in range(self.num_hdimv_splits) ] tOtOs_r2t = [ - thr_tmem_store_O.partition_D(tOtOs[split]) - for split in range(self.num_hdimv_splits) + thr_tmem_store_O.partition_D(tOtOs[split]) for split in range(self.num_hdimv_splits) ] cOi = cute.make_identity_tensor((self.cta_tile_m, self.hdimv // self.num_hdimv_splits)) @@ -2535,7 +2685,8 @@ def correction_loop( # n_block_max = topk_length_dynamic // self.tile_n else: n_block_min, n_block_max = block_info.get_n_block_min_max( - seqlen, cluster_m_block, + seqlen, + cluster_m_block, ) num_n_blocks = n_block_max - n_block_min @@ -2552,7 +2703,7 @@ def correction_loop( should_rescale = cute.arch.vote_ballot_sync(scale < 1.0) != 0 pipeline_sm_stats.consumer_release(consumer_state_sm_stats) consumer_state_sm_stats.advance() - + for split in cutlass.range_constexpr(self.num_hdimv_splits): consumer_state_Oi = consumer_states_O[split] pipelines_O[split].consumer_wait(consumer_state_Oi) @@ -2567,9 +2718,9 @@ def correction_loop( consumer_states_O[split] = consumer_state_Oi # (seqlen_q, hdimv) - mO_cur = seqlen.offset_batch_Q( - mO, batch_idx, dim=3, ragged=self.ragged_tma_O - )[None, None, head_idx] + mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3, ragged=self.ragged_tma_O)[ + None, None, head_idx + ] # (cta_tile_m, hdimv//2, 2) gO = cute.local_tile( mO_cur, @@ -2593,8 +2744,11 @@ def correction_loop( if const_expr(self.use_tma_O): tOsO = thr_tiled_copy_O_r2g.partition_D(sO) store_O, _, _ = copy_utils.tma_get_copy_fn( - tma_atom_O, 0, cute.make_layout(1), - sO, gO, + tma_atom_O, + 0, + cute.make_layout(1), + sO, + gO, ) self.sm_stats_barrier_full.arrive_and_wait() @@ -2613,7 +2767,7 @@ def correction_loop( else seqlen.seqlen_q * self.qhead_per_kvhead ) - # compute and store lse to gmem + # compute and store lse to gmem if const_expr(mLSE is not None): if const_expr(not seqlen.has_cu_seqlens_q): mLSE_cur = mLSE[None, head_idx, batch_idx] @@ -2627,7 +2781,8 @@ def correction_loop( row_max = sRowMax[tidx, 0] LN2 = math.log(2.0) lse = ( - (row_max * softmax_scale_log2 + cute.math.log2(row_sum, fastmath=True)) * LN2 + (row_max * softmax_scale_log2 + cute.math.log2(row_sum, fastmath=True)) + * LN2 if not acc_O_mn_row_is_zero_or_nan else -Float32.inf ) @@ -2648,7 +2803,7 @@ def correction_loop( # scale and downcast Oi tOrOs_r2g[split].store((tOrOs_r2g_f32[split].load() * scale).to(self.dtype_O)) - + if const_expr(not self.use_tma_O): # copy Oi rmem -> gmem if row_idx < seqlen_q: @@ -2757,45 +2912,49 @@ def test_mla_kernel( total_k_dummy = batch * seqlen_k if varlen_q: - Q = torch.randn(total_q_dummy, nheads, hdim, dtype=torch.bfloat16, device="cuda") - Qv = torch.randn(total_q_dummy, nheads, hdimv, dtype=torch.bfloat16, device="cuda") - O = torch.empty(total_q_dummy, nheads, hdimv, dtype=torch.bfloat16, device="cuda") - lse = torch.empty(nheads, total_q_dummy, dtype=torch.float32, device="cuda") + Q = torch.randn(total_q_dummy, nheads, hdim, dtype=torch.bfloat16, device="cuda") + Qv = torch.randn(total_q_dummy, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + O = torch.empty(total_q_dummy, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + lse = torch.empty(nheads, total_q_dummy, dtype=torch.float32, device="cuda") index_topk = ( torch.rand(total_q_dummy, topk_length, device="cuda") - .argsort(dim=-1).to(torch.int32) + .argsort(dim=-1) + .to(torch.int32) ) cu_seqlens_q_dummy = torch.arange( 0, (batch + 1) * seqlen_q, seqlen_q, dtype=torch.int32, device="cuda" ) else: - Q = torch.randn(batch, seqlen_q, nheads, hdim, dtype=torch.bfloat16, device="cuda") - Qv = torch.randn(batch, seqlen_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") - O = torch.empty(batch, seqlen_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") - lse = torch.empty(batch, nheads, seqlen_q, dtype=torch.float32, device="cuda") + Q = torch.randn(batch, seqlen_q, nheads, hdim, dtype=torch.bfloat16, device="cuda") + Qv = torch.randn(batch, seqlen_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + O = torch.empty(batch, seqlen_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + lse = torch.empty(batch, nheads, seqlen_q, dtype=torch.float32, device="cuda") index_topk = ( torch.rand(batch, seqlen_q, topk_length, device="cuda") - .argsort(dim=-1).to(torch.int32) + .argsort(dim=-1) + .to(torch.int32) ) if varlen_k: - K = torch.randn(total_k_dummy, nheads_kv, hdim, dtype=torch.bfloat16, device="cuda") + K = torch.randn(total_k_dummy, nheads_kv, hdim, dtype=torch.bfloat16, device="cuda") V = torch.randn(total_k_dummy, nheads_kv, hdimv, dtype=torch.bfloat16, device="cuda") cu_seqlens_k_dummy = torch.arange( 0, (batch + 1) * seqlen_k, seqlen_k, dtype=torch.int32, device="cuda" ) else: - K = torch.randn(batch, seqlen_k, nheads_kv, hdim, dtype=torch.bfloat16, device="cuda") + K = torch.randn(batch, seqlen_k, nheads_kv, hdim, dtype=torch.bfloat16, device="cuda") V = torch.randn(batch, seqlen_k, nheads_kv, hdimv, dtype=torch.bfloat16, device="cuda") - mQ = from_dlpack(Q, assumed_align=16).mark_layout_dynamic(leading_dim=Q.ndim - 1) + mQ = from_dlpack(Q, assumed_align=16).mark_layout_dynamic(leading_dim=Q.ndim - 1) mQv = from_dlpack(Qv, assumed_align=16).mark_layout_dynamic(leading_dim=Qv.ndim - 1) - mK = from_dlpack(K, assumed_align=16).mark_layout_dynamic(leading_dim=K.ndim - 1) - mV = from_dlpack(V, assumed_align=16).mark_layout_dynamic(leading_dim=V.ndim - 1) - mO = from_dlpack(O, assumed_align=16).mark_layout_dynamic(leading_dim=O.ndim - 1) - mLSE = from_dlpack(lse, assumed_align=4 ).mark_layout_dynamic(leading_dim=lse.ndim - 1) + mK = from_dlpack(K, assumed_align=16).mark_layout_dynamic(leading_dim=K.ndim - 1) + mV = from_dlpack(V, assumed_align=16).mark_layout_dynamic(leading_dim=V.ndim - 1) + mO = from_dlpack(O, assumed_align=16).mark_layout_dynamic(leading_dim=O.ndim - 1) + mLSE = from_dlpack(lse, assumed_align=4).mark_layout_dynamic(leading_dim=lse.ndim - 1) if gather_kv: - mIndexTopk = from_dlpack(index_topk, assumed_align=16).mark_layout_dynamic(leading_dim=index_topk.ndim - 1) + mIndexTopk = from_dlpack(index_topk, assumed_align=16).mark_layout_dynamic( + leading_dim=index_topk.ndim - 1 + ) else: mIndexTopk = None @@ -2817,9 +2976,15 @@ def test_mla_kernel( is_varlen_q=varlen_q, disable_bitmask=disable_bitmask, ), - mQ, mQv, mK, mV, mO, mLSE, softmax_scale, + mQ, + mQv, + mK, + mV, + mO, + mLSE, + softmax_scale, **compile_kwargs, - options="--keep-ptx --keep-cubin --generate-line-info" + options="--keep-ptx --keep-cubin --generate-line-info", ) dump_kernel_attributes(kernel) compile_cache[compile_key] = kernel @@ -2829,14 +2994,14 @@ def test_mla_kernel( if varlen_q: torch.manual_seed(seed + 1000) # When causal without varlen_k, every per-batch seqlen_q must not exceed seqlen_k. - max_seqlen_q = (seqlen_k if (is_causal and not varlen_k) else seqlen_q) + max_seqlen_q = seqlen_k if (is_causal and not varlen_k) else seqlen_q seqlens_q = torch.randint(1, max_seqlen_q + 1, (batch,), dtype=torch.int32) cu_seqlens_q = torch.zeros(batch + 1, dtype=torch.int32, device="cuda") cu_seqlens_q[1:] = seqlens_q.cumsum(0).to(torch.int32).cuda() total_q = cu_seqlens_q[-1].item() else: seqlens_q = torch.full((batch,), seqlen_q, dtype=torch.int32) - total_q = None # unused + total_q = None # unused if varlen_k: torch.manual_seed(seed + 2000) @@ -2851,28 +3016,28 @@ def test_mla_kernel( total_k = cu_seqlens_k[-1].item() else: seqlens_k = torch.full((batch,), seqlen_k, dtype=torch.int32) - total_k = None # unused + total_k = None # unused - torch.manual_seed(seed) # restore main seed before drawing actual tensors + torch.manual_seed(seed) # restore main seed before drawing actual tensors # ---- Allocate Q / Qv / O / lse ---- if varlen_q: - Q = torch.randn(total_q, nheads, hdim, dtype=torch.bfloat16, device="cuda") - Qv = torch.randn(total_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") - O = torch.empty(total_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") - lse = torch.empty(nheads, total_q, dtype=torch.float32, device="cuda") + Q = torch.randn(total_q, nheads, hdim, dtype=torch.bfloat16, device="cuda") + Qv = torch.randn(total_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + O = torch.empty(total_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + lse = torch.empty(nheads, total_q, dtype=torch.float32, device="cuda") else: - Q = torch.randn(batch, seqlen_q, nheads, hdim, dtype=torch.bfloat16, device="cuda") - Qv = torch.randn(batch, seqlen_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") - O = torch.empty(batch, seqlen_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") - lse = torch.empty(batch, nheads, seqlen_q, dtype=torch.float32, device="cuda") + Q = torch.randn(batch, seqlen_q, nheads, hdim, dtype=torch.bfloat16, device="cuda") + Qv = torch.randn(batch, seqlen_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + O = torch.empty(batch, seqlen_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + lse = torch.empty(batch, nheads, seqlen_q, dtype=torch.float32, device="cuda") # ---- Allocate K / V ---- if varlen_k: - K = torch.randn(total_k, nheads_kv, hdim, dtype=torch.bfloat16, device="cuda") + K = torch.randn(total_k, nheads_kv, hdim, dtype=torch.bfloat16, device="cuda") V = torch.randn(total_k, nheads_kv, hdimv, dtype=torch.bfloat16, device="cuda") else: - K = torch.randn(batch, seqlen_k, nheads_kv, hdim, dtype=torch.bfloat16, device="cuda") + K = torch.randn(batch, seqlen_k, nheads_kv, hdim, dtype=torch.bfloat16, device="cuda") V = torch.randn(batch, seqlen_k, nheads_kv, hdimv, dtype=torch.bfloat16, device="cuda") # ---- Generate index_topk with per-batch valid ranges when varlen_k ---- @@ -2891,65 +3056,76 @@ def test_mla_kernel( topk_parts.append(topk_b) if varlen_q: - index_topk = torch.cat(topk_parts, dim=0) # (total_q, topk_length) + index_topk = torch.cat(topk_parts, dim=0) # (total_q, topk_length) else: - index_topk = torch.stack(topk_parts, dim=0) # (batch, seqlen_q, topk_length) + index_topk = torch.stack(topk_parts, dim=0) # (batch, seqlen_q, topk_length) else: index_topk = None # ---- Reference computation (per-batch loop covers all four varlen combos) ---- O_ref_list, O_pt_list, lse_ref_list, lse_pt_list = [], [], [], [] for b in range(batch): - qs = cu_seqlens_q[b].item() if varlen_q else b * seqlen_q - qe = cu_seqlens_q[b+1].item() if varlen_q else (b + 1) * seqlen_q - ks = cu_seqlens_k[b].item() if varlen_k else b * seqlen_k - ke = cu_seqlens_k[b+1].item() if varlen_k else (b + 1) * seqlen_k - - Q_b = Q [qs:qe].unsqueeze(0) if varlen_q else Q [b:b+1] # (1, sl_q, nheads, hdim) - Qv_b = Qv[qs:qe].unsqueeze(0) if varlen_q else Qv[b:b+1] # (1, sl_q, nheads, hdimv) - K_b = K [ks:ke].unsqueeze(0) if varlen_k else K [b:b+1] # (1, sl_k, nheads_kv, hdim) - V_b = V [ks:ke].unsqueeze(0) if varlen_k else V [b:b+1] # (1, sl_k, nheads_kv, hdimv) + qs = cu_seqlens_q[b].item() if varlen_q else b * seqlen_q + qe = cu_seqlens_q[b + 1].item() if varlen_q else (b + 1) * seqlen_q + ks = cu_seqlens_k[b].item() if varlen_k else b * seqlen_k + ke = cu_seqlens_k[b + 1].item() if varlen_k else (b + 1) * seqlen_k + + Q_b = Q[qs:qe].unsqueeze(0) if varlen_q else Q[b : b + 1] # (1, sl_q, nheads, hdim) + Qv_b = Qv[qs:qe].unsqueeze(0) if varlen_q else Qv[b : b + 1] # (1, sl_q, nheads, hdimv) + K_b = K[ks:ke].unsqueeze(0) if varlen_k else K[b : b + 1] # (1, sl_k, nheads_kv, hdim) + V_b = V[ks:ke].unsqueeze(0) if varlen_k else V[b : b + 1] # (1, sl_k, nheads_kv, hdimv) if gather_kv: - topk_b = index_topk[qs:qe].unsqueeze(0) if varlen_q else index_topk[b:b+1] + topk_b = index_topk[qs:qe].unsqueeze(0) if varlen_q else index_topk[b : b + 1] else: topk_b = None - O_b, _, lse_b = attention_ref(Q_b, K_b, V_b, qv=Qv_b, causal=is_causal, - return_lse=True, gather_kv_indices=topk_b) - O_pt_b, _, lse_pt_b = attention_ref(Q_b, K_b, V_b, qv=Qv_b, causal=is_causal, - upcast=False, reorder_ops=True, - return_lse=True, gather_kv_indices=topk_b) + O_b, _, lse_b = attention_ref( + Q_b, K_b, V_b, qv=Qv_b, causal=is_causal, return_lse=True, gather_kv_indices=topk_b + ) + O_pt_b, _, lse_pt_b = attention_ref( + Q_b, + K_b, + V_b, + qv=Qv_b, + causal=is_causal, + upcast=False, + reorder_ops=True, + return_lse=True, + gather_kv_indices=topk_b, + ) O_ref_list.append(O_b.squeeze(0)) O_pt_list.append(O_pt_b.squeeze(0)) lse_ref_list.append(lse_b.squeeze(0)) lse_pt_list.append(lse_pt_b.squeeze(0)) - cat_dim_o = 0 if (varlen_q) else 0 # always 0: leading token/batch dim + cat_dim_o = 0 if (varlen_q) else 0 # always 0: leading token/batch dim cat_dim_lse = -1 if (varlen_q) else -1 # always last: token dim if varlen_q: - O_ref = torch.cat(O_ref_list, dim=0) # (total_q, nheads, hdimv) - O_pt = torch.cat(O_pt_list, dim=0) - lse_ref = torch.cat(lse_ref_list, dim=-1) # (nheads, total_q) - lse_pt = torch.cat(lse_pt_list, dim=-1) + O_ref = torch.cat(O_ref_list, dim=0) # (total_q, nheads, hdimv) + O_pt = torch.cat(O_pt_list, dim=0) + lse_ref = torch.cat(lse_ref_list, dim=-1) # (nheads, total_q) + lse_pt = torch.cat(lse_pt_list, dim=-1) else: - O_ref = torch.stack(O_ref_list, dim=0) # (batch, seqlen_q, nheads, hdimv) - O_pt = torch.stack(O_pt_list, dim=0) + O_ref = torch.stack(O_ref_list, dim=0) # (batch, seqlen_q, nheads, hdimv) + O_pt = torch.stack(O_pt_list, dim=0) lse_ref = torch.stack(lse_ref_list, dim=0) # (batch, nheads, seqlen_q) - lse_pt = torch.stack(lse_pt_list, dim=0) + lse_pt = torch.stack(lse_pt_list, dim=0) rtol = 2 atol = 2 * (O_ref + 0.3 - 0.3 - O_ref).abs().max().item() # ---- CuTe tensor wrappers ---- - mQ = from_dlpack(Q, assumed_align=16).mark_layout_dynamic(leading_dim=Q.ndim - 1) + mQ = from_dlpack(Q, assumed_align=16).mark_layout_dynamic(leading_dim=Q.ndim - 1) mQv = from_dlpack(Qv, assumed_align=16).mark_layout_dynamic(leading_dim=Qv.ndim - 1) - mK = from_dlpack(K, assumed_align=16).mark_layout_dynamic(leading_dim=K.ndim - 1) - mV = from_dlpack(V, assumed_align=16).mark_layout_dynamic(leading_dim=V.ndim - 1) - mO = from_dlpack(O, assumed_align=16).mark_layout_dynamic(leading_dim=O.ndim - 1) - mLSE = from_dlpack(lse, assumed_align=4 ).mark_layout_dynamic(leading_dim=lse.ndim - 1) + mK = from_dlpack(K, assumed_align=16).mark_layout_dynamic(leading_dim=K.ndim - 1) + mV = from_dlpack(V, assumed_align=16).mark_layout_dynamic(leading_dim=V.ndim - 1) + mO = from_dlpack(O, assumed_align=16).mark_layout_dynamic(leading_dim=O.ndim - 1) + mLSE = from_dlpack(lse, assumed_align=4).mark_layout_dynamic(leading_dim=lse.ndim - 1) if index_topk is not None: - mIndexTopk = from_dlpack(index_topk, assumed_align=16).mark_layout_dynamic(leading_dim=index_topk.ndim - 1) + mIndexTopk = from_dlpack(index_topk, assumed_align=16).mark_layout_dynamic( + leading_dim=index_topk.ndim - 1 + ) else: mIndexTopk = None @@ -2961,7 +3137,13 @@ def test_mla_kernel( # ---- Run kernel ---- compile_cache[compile_key]( - mQ, mQv, mK, mV, mO, mLSE, softmax_scale, + mQ, + mQv, + mK, + mV, + mO, + mLSE, + softmax_scale, **run_kwargs, ) @@ -2978,8 +3160,10 @@ def test_mla_kernel( if validate: assert (O - O_ref).abs().max().item() <= rtol * (O_pt - O_ref).abs().max().item() + atol varlen_tag = "" - if varlen_q: varlen_tag += f", total_q:{total_q}" - if varlen_k: varlen_tag += f", total_k:{total_k}" + if varlen_q: + varlen_tag += f", total_q:{total_q}" + if varlen_k: + varlen_tag += f", total_k:{total_k}" print( f"batch:{batch:3d}, nheads:{nheads:3d}, seqlen_q:{seqlen_q:5d}, seqlen_k:{seqlen_k:5d}" f"{varlen_tag}, iter:{iter:2d} PASSED" @@ -2993,6 +3177,7 @@ def test_mla_kernel( return None + def timeit(fn, *args, **kwargs): # Synchronize before timing torch.cuda.synchronize() @@ -3016,15 +3201,24 @@ def timeit(fn, *args, **kwargs): def benchmark_mla_kernel( - batch=1, seqlen_q=2048, seqlen_k=2048, topk_length=2048, nheads=128, hdim=64, hdimv=512, compile_cache=dict(), - gather_kv=True, is_causal=False, disable_bitmask=False, + batch=1, + seqlen_q=2048, + seqlen_k=2048, + topk_length=2048, + nheads=128, + hdim=64, + hdimv=512, + compile_cache=dict(), + gather_kv=True, + is_causal=False, + disable_bitmask=False, ): assert hdim == 64, "hdim must be 64" assert hdimv == 512, "hdimv must be 512" - qhead_per_kvhead=nheads - nheads_kv=1 - pack_gqa=True + qhead_per_kvhead = nheads + nheads_kv = 1 + pack_gqa = True softmax_scale = 1.0 / math.sqrt(hdim + hdimv) compile_key = ( @@ -3042,7 +3236,9 @@ def benchmark_mla_kernel( K = torch.randn(batch, seqlen_k, nheads_kv, hdim, dtype=torch.bfloat16, device="cuda") V = torch.randn(batch, seqlen_k, nheads_kv, hdimv, dtype=torch.bfloat16, device="cuda") O = torch.empty(batch, seqlen_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") - index_topk = torch.rand(batch, seqlen_q, topk_length, device="cuda").argsort(dim=-1).to(torch.int32) + index_topk = ( + torch.rand(batch, seqlen_q, topk_length, device="cuda").argsort(dim=-1).to(torch.int32) + ) mQ = from_dlpack(Q, assumed_align=16).mark_layout_dynamic(leading_dim=Q.ndim - 1) mQv = from_dlpack(Qv, assumed_align=16).mark_layout_dynamic(leading_dim=Qv.ndim - 1) @@ -3050,7 +3246,9 @@ def benchmark_mla_kernel( mV = from_dlpack(V, assumed_align=16).mark_layout_dynamic(leading_dim=V.ndim - 1) mO = from_dlpack(O, assumed_align=16).mark_layout_dynamic(leading_dim=O.ndim - 1) if gather_kv: - mIndexTopk = from_dlpack(index_topk, assumed_align=16).mark_layout_dynamic(leading_dim=index_topk.ndim - 1) + mIndexTopk = from_dlpack(index_topk, assumed_align=16).mark_layout_dynamic( + leading_dim=index_topk.ndim - 1 + ) else: mIndexTopk = None @@ -3059,26 +3257,34 @@ def benchmark_mla_kernel( kernel = cute.compile( FlashAttentionMLAForwardSm100( is_causal=is_causal, - use_cpasync_load_KV = gather_kv, - topk_length = topk_length if gather_kv else 2048, - is_topk_gather = gather_kv, + use_cpasync_load_KV=gather_kv, + topk_length=topk_length if gather_kv else 2048, + is_topk_gather=gather_kv, pack_gqa=pack_gqa, qhead_per_kvhead=qhead_per_kvhead, nheads_kv=nheads_kv, disable_bitmask=disable_bitmask, ), - mQ, mQv, mK, mV, mO, mLSE, softmax_scale, + mQ, + mQv, + mK, + mV, + mO, + mLSE, + softmax_scale, mIndexTopk=mIndexTopk, ) compile_cache[compile_key] = kernel - + Q = torch.randn(batch, seqlen_q, nheads, hdim, dtype=torch.bfloat16, device="cuda") Qv = torch.randn(batch, seqlen_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") K = torch.randn(batch, seqlen_k, nheads_kv, hdim, dtype=torch.bfloat16, device="cuda") V = torch.randn(batch, seqlen_k, nheads_kv, hdimv, dtype=torch.bfloat16, device="cuda") O = torch.empty(batch, seqlen_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") - index_topk = torch.rand(batch, seqlen_q, topk_length, device="cuda").argsort(dim=-1).to(torch.int32) + index_topk = ( + torch.rand(batch, seqlen_q, topk_length, device="cuda").argsort(dim=-1).to(torch.int32) + ) mQ = from_dlpack(Q, assumed_align=16).mark_layout_dynamic(leading_dim=Q.ndim - 1) mQv = from_dlpack(Qv, assumed_align=16).mark_layout_dynamic(leading_dim=Qv.ndim - 1) @@ -3086,17 +3292,25 @@ def benchmark_mla_kernel( mV = from_dlpack(V, assumed_align=16).mark_layout_dynamic(leading_dim=V.ndim - 1) mO = from_dlpack(O, assumed_align=16).mark_layout_dynamic(leading_dim=O.ndim - 1) if gather_kv: - mIndexTopk = from_dlpack(index_topk, assumed_align=16).mark_layout_dynamic(leading_dim=index_topk.ndim - 1) + mIndexTopk = from_dlpack(index_topk, assumed_align=16).mark_layout_dynamic( + leading_dim=index_topk.ndim - 1 + ) else: mIndexTopk = None mLSE = None exec_time_in_s = timeit( compile_cache[compile_key], - mQ, mQv, mK, mV, mO, mLSE, softmax_scale, + mQ, + mQv, + mK, + mV, + mO, + mLSE, + softmax_scale, mIndexTopk=mIndexTopk, ) - + seqlen_k_eff = topk_length if gather_kv else seqlen_k FLOPs = 2 * batch * nheads * seqlen_q * seqlen_k_eff * (hdim + 2 * hdimv) @@ -3142,7 +3356,12 @@ def benchmark_mla_kernel( nheads_test_values = [128] batch_test_values = [4] test_configs = [ - (batch, nheads, seqlen_q, seqlen_k,) + ( + batch, + nheads, + seqlen_q, + seqlen_k, + ) for batch in batch_test_values for nheads in nheads_test_values for seqlen_q in seqlen_q_test_values @@ -3178,38 +3397,44 @@ def benchmark_mla_kernel( if run_benchmark: if gather_kv: seqlen_q_benchmark_values = [1] - seqlen_k_benchmark_values = [8192*2] + seqlen_k_benchmark_values = [8192 * 2] nheads_benchmark_values = [128] batch_benchmark_values = [512] else: seqlen_q_benchmark_values = [1] - seqlen_k_benchmark_values = [8192*2] + seqlen_k_benchmark_values = [8192 * 2] nheads_benchmark_values = [128] batch_benchmark_values = [512] seqlen_q_benchmark_values = [4096] seqlen_k_benchmark_values = [4096] nheads_benchmark_values = [16] batch_benchmark_values = [8] - benchmark_configs = [ (batch, nheads, seqlen_q, seqlen_k,) - for batch in batch_benchmark_values - for nheads in nheads_benchmark_values - for seqlen_q in seqlen_q_benchmark_values - for seqlen_k in seqlen_k_benchmark_values - ] - compile_cache=dict() - print("="*40) + benchmark_configs = [ + ( + batch, + nheads, + seqlen_q, + seqlen_k, + ) + for batch in batch_benchmark_values + for nheads in nheads_benchmark_values + for seqlen_q in seqlen_q_benchmark_values + for seqlen_k in seqlen_k_benchmark_values + ] + compile_cache = dict() + print("=" * 40) print("Benchmarking MLA Kernel") - print("="*40) + print("=" * 40) for config in benchmark_configs: batch, nheads, seqlen_q, seqlen_k = config benchmark_mla_kernel( - batch = batch, - seqlen_q = seqlen_q, - seqlen_k = seqlen_k, - topk_length = topk_length, - nheads = nheads, + batch=batch, + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + topk_length=topk_length, + nheads=nheads, gather_kv=gather_kv, is_causal=is_causal, disable_bitmask=disable_bitmask, - compile_cache=compile_cache + compile_cache=compile_cache, ) From 83b8b8f0279f2632cd1829b247f2edc2f5cfd5b1 Mon Sep 17 00:00:00 2001 From: Johnsonms Date: Wed, 15 Apr 2026 08:58:15 -0700 Subject: [PATCH 14/44] Disable 2CTA fwd non-causal on CUDA 12 to work around codegen regression (#2461) * Disable 2CTA forward non-causal on CUDA 12.9 to work around codegen regression CUDA 12.9 has a codegen issue that causes ~18% slowdown for 2CTA forward non-causal (hdim=128: 1280 vs 1542 TFLOPS). This is fixed in CUDA 13.x. Auto-disable 2CTA when CUDA 12.9 is detected. Users on CUDA 13.x are unaffected. The manual `FA_DISABLE_2CTA=1` override continues to work regardless of CUDA version. * Disable 2CTA forward non-causal on all CUDA 12.x (not just 12.9) --- flash_attn/cute/utils.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index 76579c81cc7..a39230520e4 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -60,12 +60,33 @@ _fa_disable_2cta_enabled: bool = os.environ.get("FA_DISABLE_2CTA", "0") == "1" +def _is_cuda_12() -> bool: + """Check if the CUDA toolkit version is 12.x. + + 2CTA forward non-causal has a codegen regression on CUDA 12 that causes + ~18% slowdown compared to 1CTA. This is fixed in CUDA 13.x. + """ + try: + import torch + + cuda_version = torch.version.cuda + if cuda_version is not None: + major = cuda_version.split(".")[0] + return int(major) == 12 + except Exception: + pass + return False + + +_fa_disable_2cta_cuda12: bool = _is_cuda_12() + + def _get_use_clc_scheduler_default() -> bool: return _fa_clc_enabled def _get_disable_2cta_default() -> bool: - return _fa_disable_2cta_enabled + return _fa_disable_2cta_enabled or _fa_disable_2cta_cuda12 def _compute_base_hash(func: Callable) -> str: From d7f60e6639250626589c528a943aaaba6fe5955a Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Wed, 15 Apr 2026 17:13:52 -0700 Subject: [PATCH 15/44] Add CLC scheduler heuristic (#2455) Made-with: Cursor --- .gitignore | 2 + AGENTS.md | 1 + CLAUDE.md | 4 + benchmarks/clc_bench.py | 750 +++++++++++++++++++++++++++++ benchmarks/configs/clc.yaml | 35 ++ flash_attn/cute/flash_fwd_sm100.py | 1 + flash_attn/cute/interface.py | 9 +- tests/cute/test_clc_fuzz.py | 38 +- 8 files changed, 830 insertions(+), 10 deletions(-) create mode 120000 AGENTS.md create mode 100644 benchmarks/clc_bench.py create mode 100644 benchmarks/configs/clc.yaml diff --git a/.gitignore b/.gitignore index a4f1703b494..387a5f4535e 100644 --- a/.gitignore +++ b/.gitignore @@ -36,6 +36,8 @@ var/ # Dev venv +agent_space/ +benchmarks/results/ # compile-time generated file flash_attn_config.py diff --git a/AGENTS.md b/AGENTS.md new file mode 120000 index 00000000000..681311eb9cf --- /dev/null +++ b/AGENTS.md @@ -0,0 +1 @@ +CLAUDE.md \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md index 9f752d7e0e0..f170541d482 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -8,6 +8,10 @@ FlashAttention-4 (FA4) — fast, memory-efficient exact attention kernels writte The repository also contains older generations (FA2 in top-level `csrc/`, FA3 in `hopper/`) but active development is on FA4 in `flash_attn/cute/`. +## Agent Scratch Space + +Use `agent_space/` for project-local scratch work such as lab notes, profiling outputs, temporary repro scripts, and experiment artifacts. Treat it as disposable workspace rather than product code. + ## Build & Install ```bash diff --git a/benchmarks/clc_bench.py b/benchmarks/clc_bench.py new file mode 100644 index 00000000000..18e7358a6d7 --- /dev/null +++ b/benchmarks/clc_bench.py @@ -0,0 +1,750 @@ +#!/usr/bin/env python3 +"""CLC benchmark for dense, varlen, and block-sparse FA4 sweeps. + +Run with benchmark against the yaml sweep: + python benchmarks/clc_bench.py --config benchmarks/configs/clc.yaml + +Useful overrides: + --workers 64 # compile parallelism + --case_filter q16_kv4 # run matching cases only +""" +from __future__ import annotations + +import csv +import json +import math +import os +import statistics +import subprocess +import sys +import types +from contextlib import nullcontext +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import asdict, dataclass, field +from datetime import datetime, timezone +from itertools import product +from pathlib import Path +from typing import Literal + +try: + from jsonargparse import CLI +except ImportError as exc: + raise SystemExit( + "Missing jsonargparse. Install it with " + "uv pip install jsonargparse pyyaml" + ) from exc + + +REPO_ROOT = Path(__file__).resolve().parents[1] +RESULTS_ROOT = REPO_ROOT / "benchmarks" / "results" / "clc" +CSV_FLOAT_DIGITS = 3 +BLOCK_SIZE_Q = 256 +BLOCK_SIZE_K = 128 +INTERNAL_REQUEST_ENV = "CLC_BENCH_INTERNAL_REQUEST" + +DTypeName = Literal["bfloat16", "float16"] + + +@dataclass(frozen=True) +class DenseSweep: + enabled: bool = True + batches: list[int] = field(default_factory=lambda: [1, 4, 8, 16, 32]) + seqlen_pairs: list[list[int]] = field( + default_factory=lambda: [[1024, 1024], [2048, 2048], [4096, 4096], [8192, 8192], [16384, 16384]] + ) + head_dims: list[int] = field(default_factory=lambda: [64, 96, 128]) + head_pairs: list[list[int]] = field(default_factory=lambda: [[16, 16], [16, 8], [16, 4], [16, 2], [16, 1]]) + causal: bool | list[bool] = True + + +@dataclass(frozen=True) +class VarlenSweep: + enabled: bool = True + max_q_tokens: list[int] = field(default_factory=lambda: [2048, 4096, 8192, 16384, 32768]) + max_kv_tokens: list[int] = field(default_factory=lambda: [2048, 4096, 8192, 16384, 32768]) + batches: list[int] = field(default_factory=lambda: [4, 8, 16, 32]) + patterns: list[str] = field(default_factory=lambda: ["uniform", "longtail"]) + head_dims: list[int] = field(default_factory=lambda: [64, 96, 128]) + head_pairs: list[list[int]] = field(default_factory=lambda: [[16, 8], [16, 4], [16, 2], [16, 1]]) + causal: bool | list[bool] = False + + +@dataclass(frozen=True) +class BlockSparseSweep: + enabled: bool = False + batches: list[int] = field(default_factory=lambda: [1, 4, 8, 16, 32]) + seqlen_pairs: list[list[int]] = field( + default_factory=lambda: [[1024, 1024], [2048, 2048], [4096, 4096], [4096, 8192]] + ) + head_dims: list[int] = field(default_factory=lambda: [64, 128]) + head_pairs: list[list[int]] = field(default_factory=lambda: [[16, 16], [16, 4], [16, 1]]) + mask_names: list[str] = field(default_factory=lambda: ["block_diagonal"]) + sliding_window_sizes: list[int] = field(default_factory=lambda: [2048]) + + +@dataclass(frozen=True) +class Case: + name: str + mode: Literal["dense", "varlen", "block_sparse"] + q_heads: int + kv_heads: int + d: int + causal: bool + batch: int | None = None + seqlen_q: int | None = None + seqlen_k: int | None = None + seqlens_q: list[int] | None = None + seqlens_k: list[int] | None = None + pattern: str = "" + mask_name: str = "" + window_size: int | None = None + + +def utc_timestamp() -> str: + return datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") + + +def default_out_dir() -> Path: + return RESULTS_ROOT / utc_timestamp() + + +def head_pair_label(q_heads: int, kv_heads: int) -> str: + return f"q{q_heads}_kv{kv_heads}" + + +def token_label(value: int) -> str: + return f"{value // 1024}k" if value >= 1024 and value % 1024 == 0 else str(value) + + +def dense_case_name(q_heads: int, kv_heads: int, causal: bool, d: int, batch: int, seqlen_q: int, seqlen_k: int) -> str: + causal_name = "causal" if causal else "noncausal" + pair = head_pair_label(q_heads, kv_heads) + if seqlen_q == seqlen_k: + return f"{pair}_{causal_name}_h{d}_{token_label(seqlen_q)}_b{batch}" + return f"{pair}_{causal_name}_q{seqlen_q}_k{seqlen_k}_h{d}_b{batch}" + + +def varlen_case_name( + pattern: str, + q_heads: int, + kv_heads: int, + causal: bool, + d: int, + batch: int, + max_q_tokens: int, + max_kv_tokens: int, +) -> str: + causal_name = "causal" if causal else "noncausal" + pair = head_pair_label(q_heads, kv_heads) + return ( + f"varlen_{pattern}_{pair}_{causal_name}_h{d}_" + f"b{batch}_q{token_label(max_q_tokens)}_kv{token_label(max_kv_tokens)}" + ) + + +def normalize_lengths(weights: list[float], total_tokens: int) -> list[int]: + if total_tokens < len(weights): + raise ValueError(f"total_tokens={total_tokens} is smaller than batch={len(weights)}") + scaled = [weight / sum(weights) * total_tokens for weight in weights] + lengths = [max(1, int(math.floor(value))) for value in scaled] + delta = total_tokens - sum(lengths) + order = sorted( + range(len(weights)), + key=lambda idx: scaled[idx] - math.floor(scaled[idx]), + reverse=delta > 0, + ) + cursor = 0 + while delta != 0: + idx = order[cursor % len(order)] + if delta > 0: + lengths[idx] += 1 + delta -= 1 + elif lengths[idx] > 1: + lengths[idx] -= 1 + delta += 1 + cursor += 1 + return lengths + + +def pattern_weights(pattern: str, batch: int) -> list[float]: + match pattern: + case "uniform": + return [1.0] * batch + case "spiky": + return [32.0] + [1.0] * (batch - 1) + case "longtail": + return [float(batch - idx) for idx in range(batch)] + case "bimodal": + split = max(1, batch // 2) + return [8.0] * split + [1.0] * (batch - split) + case "staircase": + return [1.0 + idx for idx in range(batch)] + case "loss_shape": + base = [130, 1, 1, 1, 1674, 68, 157, 1, 1, 1, 1, 1, 1, 9, 1, 5] + if batch == len(base): + return [float(value) for value in base] + return [float(base[idx % len(base)]) for idx in range(batch)] + case _: + raise ValueError(f"Unsupported varlen pattern: {pattern}") + + +def bool_values(value: bool | list[bool]) -> list[bool]: + return [value] if isinstance(value, bool) else value + + +def generate_cases( + dense: DenseSweep, + varlen: VarlenSweep, + block_sparse: BlockSparseSweep, + case_filter: str = "", +) -> list[Case]: + cases: list[Case] = [] + if dense.enabled: + for batch, seqlen_pair, d, (q_heads, kv_heads), causal in product( + dense.batches, + dense.seqlen_pairs, + dense.head_dims, + dense.head_pairs, + bool_values(dense.causal), + ): + seqlen_q, seqlen_k = seqlen_pair + cases.append( + Case( + name=dense_case_name(q_heads, kv_heads, causal, d, batch, seqlen_q, seqlen_k), + mode="dense", + q_heads=q_heads, + kv_heads=kv_heads, + d=d, + causal=causal, + batch=batch, + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + ) + ) + if varlen.enabled: + for max_q_tokens, max_kv_tokens, batch, pattern, d, (q_heads, kv_heads), causal in product( + varlen.max_q_tokens, + varlen.max_kv_tokens, + varlen.batches, + varlen.patterns, + varlen.head_dims, + varlen.head_pairs, + bool_values(varlen.causal), + ): + weights = pattern_weights(pattern, batch) + lengths_q = normalize_lengths(weights, max_q_tokens) + lengths_k = normalize_lengths(weights, max(batch, max_kv_tokens)) + cases.append( + Case( + name=varlen_case_name(pattern, q_heads, kv_heads, causal, d, batch, max_q_tokens, max_kv_tokens), + mode="varlen", + q_heads=q_heads, + kv_heads=kv_heads, + d=d, + causal=causal, + batch=batch, + seqlens_q=lengths_q, + seqlens_k=lengths_k, + pattern=pattern, + ) + ) + if block_sparse.enabled: + for batch, seqlen_pair, d, (q_heads, kv_heads), mask_name in product( + block_sparse.batches, + block_sparse.seqlen_pairs, + block_sparse.head_dims, + block_sparse.head_pairs, + block_sparse.mask_names, + ): + seqlen_q, seqlen_k = seqlen_pair + if seqlen_q > seqlen_k: + continue + window_sizes = block_sparse.sliding_window_sizes if mask_name == "sliding_window" else [None] + for window_size in window_sizes: + window_label = f"_w{window_size}" if window_size is not None else "" + pair = head_pair_label(q_heads, kv_heads) + cases.append( + Case( + name=( + f"block_sparse_{mask_name}{window_label}_{pair}_" + f"h{d}_q{seqlen_q}_k{seqlen_k}_b{batch}" + ), + mode="block_sparse", + q_heads=q_heads, + kv_heads=kv_heads, + d=d, + causal=False, + batch=batch, + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + mask_name=mask_name, + window_size=window_size, + ) + ) + if case_filter: + needle = case_filter.lower() + cases = [case for case in cases if needle in case.name.lower()] + return cases + + +def compile_q_stage(case: Case) -> int: + max_seqlen_q = max(case.seqlens_q) if case.seqlens_q is not None else case.seqlen_q + qhead_per_kvhead = case.q_heads // case.kv_heads + return 2 if max_seqlen_q is not None and max_seqlen_q * qhead_per_kvhead > 128 else 1 + + + +def compile_signature(case: Case) -> tuple: + q_stage = compile_q_stage(case) + if case.mode == "block_sparse": + return ( + case.mode, + case.q_heads, + case.kv_heads, + case.d, + case.mask_name, + case.window_size, + q_stage, + ) + return case.mode, case.q_heads, case.kv_heads, case.d, case.causal, q_stage + + +def select_compile_cases(cases: list[Case]) -> list[Case]: + selected: dict[tuple, Case] = {} + for case in cases: + selected.setdefault(compile_signature(case), case) + return list(selected.values()) + + +def benchmark_cuda_samples_in_microseconds(func, *args, **kwargs) -> list[float]: + num_iters = kwargs.pop("NUM_ITERS", 100) + warmup_iters = kwargs.pop("MEMORY_WARMUP_ITERS", 25) + is_vetted_benchmarking = kwargs.pop("IS_VETTED_BENCHMARKING", False) + from torch._inductor.runtime.benchmarking import benchmarker + + return [ + float(sample_ms) * 1e3 + for sample_ms in benchmarker.benchmark_gpu( + lambda: func(*args, **kwargs), + benchmark_iters=num_iters, + memory_warmup_iters=warmup_iters, + return_mode="all", + is_vetted_benchmarking=is_vetted_benchmarking, + ) + ] + + +def flash_attn_imports(): + if "flash_attn" not in sys.modules: + stub = types.ModuleType("flash_attn") + stub.__path__ = [str(REPO_ROOT / "flash_attn")] + sys.modules["flash_attn"] = stub + import torch + from torch._subclasses.fake_tensor import FakeTensorMode + from flash_attn.cute import utils as cute_utils + from flash_attn.cute.interface import flash_attn_func, flash_attn_varlen_func + + return torch, FakeTensorMode, cute_utils, flash_attn_func, flash_attn_varlen_func + + +def block_sparse_imports(): + if "flash_attn" not in sys.modules: + stub = types.ModuleType("flash_attn") + stub.__path__ = [str(REPO_ROOT / "flash_attn")] + sys.modules["flash_attn"] = stub + if str(REPO_ROOT / "tests" / "cute") not in sys.path: + sys.path.insert(0, str(REPO_ROOT / "tests" / "cute")) + from flash_attn.cute.compute_block_sparsity import compute_block_sparsity + from mask_mod_definitions import get_mask_pair + + return compute_block_sparsity, get_mask_pair + + +def build_cu_seqlens(torch_mod, lengths: list[int]) -> torch_mod.Tensor: + cu_seqlens = torch_mod.zeros(len(lengths) + 1, device="cuda", dtype=torch_mod.int32) + cu_seqlens[1:] = torch_mod.tensor(lengths, device="cuda", dtype=torch_mod.int32).cumsum(0) + return cu_seqlens + + +def build_dense_inputs(torch_mod, flash_attn_func, case: Case, dtype, factory): + q = factory(case.batch, case.seqlen_q, case.q_heads, case.d, device="cuda", dtype=dtype) + k = factory(case.batch, case.seqlen_k, case.kv_heads, case.d, device="cuda", dtype=dtype) + v = factory(case.batch, case.seqlen_k, case.kv_heads, case.d, device="cuda", dtype=dtype) + return flash_attn_func, dict(q=q, k=k, v=v, causal=case.causal) + + +def build_varlen_inputs(torch_mod, flash_attn_varlen_func, case: Case, dtype, factory): + lengths_q = case.seqlens_q or [] + lengths_k = case.seqlens_k or lengths_q + total_q = sum(lengths_q) + total_k = sum(lengths_k) + q = factory(total_q, case.q_heads, case.d, device="cuda", dtype=dtype) + k = factory(total_k, case.kv_heads, case.d, device="cuda", dtype=dtype) + v = factory(total_k, case.kv_heads, case.d, device="cuda", dtype=dtype) + return flash_attn_varlen_func, dict( + q=q, + k=k, + v=v, + cu_seqlens_q=build_cu_seqlens(torch_mod, lengths_q), + cu_seqlens_k=build_cu_seqlens(torch_mod, lengths_k), + max_seqlen_q=max(lengths_q), + max_seqlen_k=max(lengths_k), + causal=case.causal, + ) + + +def build_block_sparse_compile_tensors(torch_mod, case: Case): + num_m_blocks = math.ceil((case.seqlen_q or 0) / BLOCK_SIZE_Q) + count_shape = (case.batch, 1, num_m_blocks) + index_shape = (*count_shape, 1) + return dict( + mask_block_cnt=torch_mod.zeros(count_shape, device="cuda", dtype=torch_mod.int32), + mask_block_idx=torch_mod.zeros(index_shape, device="cuda", dtype=torch_mod.int32), + full_block_cnt=torch_mod.zeros(count_shape, device="cuda", dtype=torch_mod.int32), + full_block_idx=torch_mod.zeros(index_shape, device="cuda", dtype=torch_mod.int32), + ) + + + +def build_block_sparse_inputs(torch_mod, flash_attn_func, case: Case, dtype, tensor_factory, fake_tensor: bool): + compute_block_sparsity, get_mask_pair = block_sparse_imports() + if case.mask_name in {"document", "ima"}: + raise ValueError(f"Aux-backed block-sparse masks are not supported by clc_bench.py: {case.mask_name}") + q = tensor_factory(case.batch, case.seqlen_q, case.q_heads, case.d, device="cuda", dtype=dtype) + k = tensor_factory(case.batch, case.seqlen_k, case.kv_heads, case.d, device="cuda", dtype=dtype) + v = tensor_factory(case.batch, case.seqlen_k, case.kv_heads, case.d, device="cuda", dtype=dtype) + cute_mask, _ = get_mask_pair( + case.mask_name, + seqlen_q=case.seqlen_q, + seqlen_k=case.seqlen_k, + window_size=case.window_size, + ) + if fake_tensor: + block_sparse_tensors = build_block_sparse_compile_tensors(torch_mod, case) + else: + _, sparse_tensors = compute_block_sparsity( + tile_m=BLOCK_SIZE_Q, + tile_n=BLOCK_SIZE_K, + batch_size=case.batch, + num_heads=1, + seqlen_q=case.seqlen_q, + seqlen_k=case.seqlen_k, + mask_mod=cute_mask, + aux_tensors=None, + device="cuda", + compute_full_blocks=True, + use_fast_sampling=False, + ) + block_sparse_tensors = dict( + mask_block_cnt=sparse_tensors.mask_block_cnt, + mask_block_idx=sparse_tensors.mask_block_idx, + full_block_cnt=sparse_tensors.full_block_cnt, + full_block_idx=sparse_tensors.full_block_idx, + ) + return flash_attn_func, dict( + q=q, + k=k, + v=v, + causal=False, + mask_mod=cute_mask, + **block_sparse_tensors, + block_size=(BLOCK_SIZE_Q, BLOCK_SIZE_K), + ) + + +def build_inputs(case: Case, dtype_name: DTypeName, fake_tensor: bool): + torch, FakeTensorMode, _, flash_attn_func, flash_attn_varlen_func = flash_attn_imports() + dtype = getattr(torch, dtype_name) + tensor_factory = torch.empty if fake_tensor else torch.randn + context = FakeTensorMode() if fake_tensor else nullcontext() + with context: + match case.mode: + case "block_sparse": + return build_block_sparse_inputs(torch, flash_attn_func, case, dtype, tensor_factory, fake_tensor) + case "dense": + return build_dense_inputs(torch, flash_attn_func, case, dtype, tensor_factory) + case "varlen": + return build_varlen_inputs(torch, flash_attn_varlen_func, case, dtype, tensor_factory) + + +def attended_pairs(seqlen_q: int, seqlen_k: int, causal: bool) -> float: + if not causal: + return float(seqlen_q * seqlen_k) + if seqlen_q <= seqlen_k: + return float(seqlen_q * (2 * seqlen_k - seqlen_q + 1) / 2) + return float(seqlen_q * seqlen_k - seqlen_k * (seqlen_k - 1) / 2) + + +def fwd_flops(case: Case, kwargs: dict | None = None) -> float: + if case.mode == "dense": + return (case.batch or 0) * case.q_heads * 2 * attended_pairs( + case.seqlen_q or 0, + case.seqlen_k or 0, + case.causal, + ) * (case.d + case.d) + if case.mode == "block_sparse": + if kwargs is None: + return 0.0 + total_blocks = kwargs["mask_block_cnt"].sum().item() + if kwargs["full_block_cnt"] is not None: + total_blocks += kwargs["full_block_cnt"].sum().item() + return float(total_blocks * BLOCK_SIZE_Q * BLOCK_SIZE_K * case.q_heads * 2 * (case.d + case.d)) + lengths_q = case.seqlens_q or [] + lengths_k = case.seqlens_k or lengths_q + total = 0.0 + for seqlen_q, seqlen_k in zip(lengths_q, lengths_k): + total += case.q_heads * 2 * attended_pairs(seqlen_q, seqlen_k, case.causal) * (case.d + case.d) + return total + + +def tflops(flop_count: float, time_us: float) -> float: + return 0.0 if time_us <= 0 else flop_count / time_us / 1e6 + + +def case_shape(case: Case) -> str: + match case.mode: + case "dense" | "block_sparse": + if case.seqlen_q == case.seqlen_k: + return token_label(case.seqlen_q or 0) + return f"q={token_label(case.seqlen_q or 0)} kv={token_label(case.seqlen_k or 0)}" + case "varlen": + lengths_q = case.seqlens_q or [] + lengths_k = case.seqlens_k or lengths_q + total_q = sum(lengths_q) + total_k = sum(lengths_k) + max_q = max(lengths_q, default=0) + max_k = max(lengths_k, default=0) + if total_q == total_k and max_q == max_k: + return f"total={token_label(total_q)} max={token_label(max_q)}" + return ( + f"q_total={token_label(total_q)} kv_total={token_label(total_k)} " + f"q_max={token_label(max_q)} kv_max={token_label(max_k)}" + ) + + +def case_metadata(case: Case) -> dict: + return { + "name": case.name, + "mode": case.mode, + "shape": case_shape(case), + "batch": case.batch, + "q_heads": case.q_heads, + "kv_heads": case.kv_heads, + "d": case.d, + "causal": case.causal, + "pattern": case.pattern, + "mask_name": case.mask_name, + "window_size": case.window_size, + } + + +def summarize_profile(case: Case, samples_off: list[float], samples_on: list[float], flop_count: float) -> dict: + mean_off = statistics.mean(samples_off) + mean_on = statistics.mean(samples_on) + paired_log_ratios = [math.log(off / on) for off, on in zip(samples_off, samples_on)] + mean_log_ratio = statistics.mean(paired_log_ratios) + stderr_log_ratio = ( + statistics.stdev(paired_log_ratios) / math.sqrt(len(paired_log_ratios)) + if len(paired_log_ratios) > 1 + else 0.0 + ) + ci95_low = math.exp(mean_log_ratio - 1.96 * stderr_log_ratio) + ci95_high = math.exp(mean_log_ratio + 1.96 * stderr_log_ratio) + return { + **case_metadata(case), + "samples_off_us": samples_off, + "samples_on_us": samples_on, + "mean_off_us": mean_off, + "mean_on_us": mean_on, + "median_off_us": statistics.median(samples_off), + "median_on_us": statistics.median(samples_on), + "mean_off_tflops": tflops(flop_count, mean_off), + "mean_on_tflops": tflops(flop_count, mean_on), + "speedup_on_vs_off": mean_off / mean_on, + "pct_change_on_vs_off": (mean_off / mean_on - 1.0) * 100.0, + "ci95_low_speedup": ci95_low, + "ci95_high_speedup": ci95_high, + "ci95_excludes_1x": ci95_low > 1.0 or ci95_high < 1.0, + } + + +def run_single_case( + case: Case, + clc: int, + fake_tensor: bool, + dtype_name: DTypeName, + bench_iters: int, + seed: int, +) -> dict: + os.environ["FA_CLC"] = str(clc) + os.environ["FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED"] = "1" + torch, _, cute_utils, _, _ = flash_attn_imports() + if not fake_tensor and not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for runtime profiling in clc_bench.py") + torch.manual_seed(seed) + fn, kwargs = build_inputs(case, dtype_name, fake_tensor) + cute_utils._fa_clc_enabled = bool(clc) + fn(**kwargs) + if fake_tensor: + return {"case": case.name, "clc": clc, "compiled": True} + torch.cuda.synchronize() + warmup_iters = min(25, max(10, bench_iters // 10)) + return { + "case": case.name, + "clc": clc, + "time_us": statistics.mean( + benchmark_cuda_samples_in_microseconds( + fn, + **kwargs, + NUM_ITERS=bench_iters, + MEMORY_WARMUP_ITERS=warmup_iters, + ) + ), + } + + +def run_single_subprocess(case: Case, clc: int, dtype_name: DTypeName, bench_iters: int, seed: int, script_path: Path) -> dict: + env = os.environ.copy() + env[INTERNAL_REQUEST_ENV] = json.dumps( + { + "case": asdict(case), + "clc": clc, + "fake_tensor": True, + "dtype_str": dtype_name, + "bench_iters": bench_iters, + "seed": seed, + } + ) + command = [sys.executable, str(script_path)] + try: + completed = subprocess.run(command, check=True, capture_output=True, text=True, env=env) + except subprocess.CalledProcessError as exc: + detail = (exc.stderr or exc.stdout).strip() + raise RuntimeError(f"Single-case compile failed for {case.name} clc={clc}:\n{detail}") from exc + for line in reversed(completed.stdout.splitlines()): + line = line.strip() + if line.startswith("{") and line.endswith("}"): + return json.loads(line) + raise RuntimeError(f"No JSON result found for {case.name} clc={clc}") + + +def run_compile(cases: list[Case], dtype_name: DTypeName, workers: int, bench_iters: int, seed: int, script_path: Path) -> list[dict]: + compile_cases = select_compile_cases(cases) + rows: list[dict] = [] + with ThreadPoolExecutor(max_workers=workers) as pool: + futures = { + pool.submit(run_single_subprocess, case, clc, dtype_name, bench_iters, seed, script_path): (case.name, clc) + for case in compile_cases + for clc in (0, 1) + } + for index, future in enumerate(as_completed(futures), start=1): + row = future.result() + print(f"[{index}/{len(futures)}] compiled {row['case']} clc={row['clc']}") + rows.append(row) + return sorted(rows, key=lambda row: (row["case"], row["clc"])) + + +def run_profile(cases: list[Case], dtype_name: DTypeName, profile_repeats: int, bench_iters: int, seed: int) -> list[dict]: + torch, _, cute_utils, _, _ = flash_attn_imports() + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for clc_bench.py") + torch.manual_seed(seed) + total_iters = profile_repeats * bench_iters + warmup_iters = min(25, max(10, total_iters // 10)) + rows: list[dict] = [] + for index, case in enumerate(cases, start=1): + fn, kwargs = build_inputs(case, dtype_name, fake_tensor=False) + samples: dict[int, list[float]] = {} + for clc in (0, 1): + cute_utils._fa_clc_enabled = bool(clc) + fn(**kwargs) + torch.cuda.synchronize() + samples[clc] = benchmark_cuda_samples_in_microseconds( + fn, + **kwargs, + NUM_ITERS=total_iters, + MEMORY_WARMUP_ITERS=warmup_iters, + ) + row = summarize_profile(case, samples[0], samples[1], fwd_flops(case, kwargs)) + print( + f"[{index}/{len(cases)}] {case.name}: " + f"off={row['mean_off_us']:.3f}us on={row['mean_on_us']:.3f}us " + f"speedup={row['speedup_on_vs_off']:.3f}x " + f"ci95=[{row['ci95_low_speedup']:.3f}, {row['ci95_high_speedup']:.3f}]" + ) + rows.append(row) + return rows + + +def round_scalar_row(row: dict) -> dict: + return { + key: round(value, CSV_FLOAT_DIGITS) if isinstance(value, float) else value + for key, value in row.items() + } + + +def write_csv(path: Path, rows: list[dict]) -> None: + if not rows: + return + path.parent.mkdir(parents=True, exist_ok=True) + scalar_rows = [ + round_scalar_row({key: value for key, value in row.items() if not isinstance(value, list)}) + for row in rows + ] + with path.open("w", newline="") as handle: + writer = csv.DictWriter(handle, fieldnames=list(scalar_rows[0].keys())) + writer.writeheader() + writer.writerows(scalar_rows) + + +def main( + out_dir: Path | None = None, + workers: int = 32, + profile_repeats: int = 3, + bench_iters: int = 64, + dtype_str: DTypeName = "bfloat16", + seed: int = 0, + case_filter: str = "", + dense: DenseSweep = DenseSweep(), + varlen: VarlenSweep = VarlenSweep(), + block_sparse: BlockSparseSweep = BlockSparseSweep(), +) -> None: + if (request_json := os.environ.get(INTERNAL_REQUEST_ENV)) is not None: + request = json.loads(request_json) + print( + json.dumps( + run_single_case( + Case(**request["case"]), + request["clc"], + request["fake_tensor"], + request["dtype_str"], + request["bench_iters"], + request["seed"], + ) + ) + ) + return + + os.environ["FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED"] = "1" + cases = generate_cases(dense, varlen, block_sparse, case_filter) + if not cases: + raise ValueError("No cases selected. Adjust the YAML sweep or case_filter.") + + run_dir = out_dir or default_out_dir() + print(f"cases={len(cases)}") + print(f"compile_cases={len(select_compile_cases(cases))}") + print(f"out_dir={run_dir}") + print(f"python={sys.executable}") + + script_path = Path(__file__).resolve() + run_compile(cases, dtype_str, workers, bench_iters, seed, script_path) + run_dir.mkdir(parents=True, exist_ok=True) + profile_rows = run_profile(cases, dtype_str, profile_repeats, bench_iters, seed) + profile_csv = run_dir / "profile.csv" + write_csv(profile_csv, profile_rows) + print("Profile written to:") + print(profile_csv) + + +if __name__ == "__main__": + CLI(main, as_positional=False) diff --git a/benchmarks/configs/clc.yaml b/benchmarks/configs/clc.yaml new file mode 100644 index 00000000000..b7bc5d4a949 --- /dev/null +++ b/benchmarks/configs/clc.yaml @@ -0,0 +1,35 @@ +dtype_str: bfloat16 +seed: 0 +workers: 64 +profile_repeats: 1 +bench_iters: 256 + +dense: + enabled: true + batches: [1, 4, 8, 16, 32] + seqlen_pairs: [[1024, 1024], [2048, 2048], [4096, 4096], [8192, 8192], [16384, 16384]] + head_dims: [64, 96, 128] + head_pairs: [[16, 16], [16, 8], [16, 4], [16, 2], [16, 1]] + causal: [true] + +varlen: + enabled: true + max_q_tokens: [2048, 4096, 8192, 16384, 32768] + max_kv_tokens: [2048, 4096, 8192, 16384, 32768] + batches: [4, 8, 16, 32] + # uniform: all sequences in the batch are similar length + # longtail: a few long sequences plus many shorter ones + patterns: [uniform, longtail] + head_dims: [64, 96, 128] + head_pairs: [[16, 8], [16, 4], [16, 2], [16, 1]] + causal: [false] + +block_sparse: + enabled: false + batches: [1, 4, 8, 16, 32] + seqlen_pairs: [[1024, 1024], [2048, 2048], [4096, 4096], [4096, 8192]] + head_dims: [64, 128] + head_pairs: [[16, 16], [16, 4], [16, 1]] + # supported mask_names: block_diagonal, sliding_window + mask_names: [block_diagonal] + sliding_window_sizes: [2048] diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 6c9c20d0b76..23a96a17b1a 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -187,6 +187,7 @@ def __init__( "Paged KV does not support irregular head dim" ) + # ClC does not compose with these other features, so disable even if requested self.use_clc_scheduler = ( use_clc_scheduler and self.use_tma_KV diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index b01376a4214..4be51d38839 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -553,6 +553,13 @@ def _flash_attn_fwd( or seqused_k is not None ) + # CLC regressed for varlen MHA and dense noncausal. Imbalanced varlen shapes + # keep more K/V blocks in flight and hurt L2; dense noncausal mostly just + # pays work-stealing overhead. + is_varlen_mha = is_varlen and qhead_per_kvhead == 1 + is_dense_noncausal = not is_varlen and not causal and not local + use_clc_scheduler = requested_use_clc_scheduler and not is_varlen_mha and not is_dense_noncausal + if mask_mod is not None: if is_varlen: raise NotImplementedError( @@ -791,7 +798,7 @@ def _flash_attn_fwd( is_varlen_q=cu_seqlens_q is not None or seqused_q is not None, q_subtile_factor=q_subtile_factor, use_2cta_instrs=use_2cta_instrs, - use_clc_scheduler=requested_use_clc_scheduler, + use_clc_scheduler=use_clc_scheduler, ) elif arch // 10 == 12: # SM120 (Blackwell GeForce / DGX Spark): uses SM80 MMA with SM120 SMEM capacity diff --git a/tests/cute/test_clc_fuzz.py b/tests/cute/test_clc_fuzz.py index 022276d3281..c988681da3b 100644 --- a/tests/cute/test_clc_fuzz.py +++ b/tests/cute/test_clc_fuzz.py @@ -17,7 +17,12 @@ from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100 from flash_attn.cute.interface import flash_attn_func, flash_attn_varlen_func from flash_attn.cute.testing import attention_ref -from flash_attn.cute.tile_scheduler import SchedulingMode, SingleTileLPTScheduler, SingleTileVarlenScheduler +from flash_attn.cute.tile_scheduler import ( + SchedulingMode, + SingleTileLPTScheduler, + SingleTileVarlenScheduler, + StaticPersistentTileScheduler, +) if torch.cuda.is_available(): @@ -60,10 +65,19 @@ def check_output(q, k, v, *, causal=False, window_size=(None, None), num_splits= torch.cuda.synchronize() if assert_clc and _captured_schedulers: sched_cls, sched_mode, use_2cta = _captured_schedulers[-1] - assert sched_cls is SingleTileLPTScheduler, f"Expected SingleTileLPTScheduler, got {sched_cls.__name__}" - assert sched_mode == SchedulingMode.CLC, f"Expected CLC scheduling mode, got {sched_mode!r}" - if assert_2cta: - assert use_2cta, "Expected use_2cta_instrs=True but got False" + is_local = window_size != (None, None) + if causal or is_local: + assert sched_cls is SingleTileLPTScheduler, f"Expected SingleTileLPTScheduler, got {sched_cls.__name__}" + assert sched_mode == SchedulingMode.CLC, f"Expected CLC scheduling mode, got {sched_mode!r}" + if assert_2cta: + assert use_2cta, "Expected use_2cta_instrs=True but got False" + else: + assert sched_cls is StaticPersistentTileScheduler, ( + f"Expected StaticPersistentTileScheduler for dense noncausal, got {sched_cls.__name__}" + ) + assert sched_mode == SchedulingMode.STATIC, ( + f"Expected STATIC scheduling mode for dense noncausal, got {sched_mode!r}" + ) out_ref, _ = attention_ref(q, k, v, causal=causal, window_size=window_size) out_pt, _ = attention_ref(q, k, v, causal=causal, window_size=window_size, upcast=False, reorder_ops=True) fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() @@ -249,7 +263,7 @@ def test_overlap_sO_sQ_fallback(self): class TestCLCFallback: - def test_varlen_uses_clc(self): + def test_varlen_mha_uses_static(self): _captured_schedulers.clear() batch, seqlen, heads, d = 4, 256, 4, 128 lens = torch.tensor([64, 128, 32, 32], dtype=torch.int32) @@ -271,7 +285,7 @@ def test_varlen_uses_clc(self): assert sched_cls is SingleTileVarlenScheduler, ( f"Expected SingleTileVarlenScheduler for varlen, got {sched_cls.__name__}" ) - assert sched_mode == SchedulingMode.CLC, f"Expected CLC scheduling mode, got {sched_mode!r}" + assert sched_mode == SchedulingMode.STATIC, f"Expected STATIC scheduling mode, got {sched_mode!r}" @pytest.mark.parametrize("sq,sk,wl,wr", [ (512, 512, 128, 128), @@ -311,7 +325,10 @@ def check_varlen_output(seqlens, heads, d, *, causal=False, kv_heads=None, num_s if _captured_schedulers: sched_cls, sched_mode, *_ = _captured_schedulers[-1] assert sched_cls is SingleTileVarlenScheduler, f"Expected SingleTileVarlenScheduler, got {sched_cls.__name__}" - assert sched_mode == SchedulingMode.CLC, f"Expected CLC scheduling mode, got {sched_mode!r}" + expected_sched_mode = SchedulingMode.CLC if heads != kv_heads else SchedulingMode.STATIC + assert sched_mode == expected_sched_mode, ( + f"Expected {expected_sched_mode.name} scheduling mode, got {sched_mode!r}" + ) for i in range(len(seqlens)): s = slice(cu_seqlens[i], cu_seqlens[i + 1]) @@ -355,7 +372,10 @@ def check_varlen_output_seqused(seqlens, heads, d, *, causal=False, kv_heads=Non if _captured_schedulers: sched_cls, sched_mode, *_ = _captured_schedulers[-1] assert sched_cls is SingleTileVarlenScheduler, f"Expected SingleTileVarlenScheduler, got {sched_cls.__name__}" - assert sched_mode == SchedulingMode.CLC, f"Expected CLC scheduling mode, got {sched_mode!r}" + expected_sched_mode = SchedulingMode.CLC if heads != kv_heads else SchedulingMode.STATIC + assert sched_mode == expected_sched_mode, ( + f"Expected {expected_sched_mode.name} scheduling mode, got {sched_mode!r}" + ) out_ref, _ = attention_ref(q, k, v, q_mask, k_mask, causal=causal) out_pt, _ = attention_ref(q, k, v, q_mask, k_mask, causal=causal, upcast=False, reorder_ops=True) From b322ae2675065ad96ad3c248fd6ef0f32252808f Mon Sep 17 00:00:00 2001 From: liangel-02 Date: Thu, 16 Apr 2026 13:41:21 -0400 Subject: [PATCH 16/44] expose num_splits for FA2 and add option for kernel blocksize alignment (#2448) --- csrc/flash_attn/flash_api.cpp | 8 ++++---- csrc/flash_attn/src/flash_fwd_launch_template.h | 12 ++++++++++-- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index c0c0e42176c..70270f40fff 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -532,7 +532,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s int window_size_right, const float softcap, const bool return_softmax, - std::optional gen_) { + std::optional gen_, + int num_splits = 0) { // Otherwise the kernel will be launched from cuda:0 device at::cuda::CUDAGuard device_guard{q.device()}; @@ -697,12 +698,11 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s params.page_block_size = page_block_size; // Keep references to these tensors to extend their lifetime at::Tensor softmax_lse_accum, out_accum; - if (seqlenq_ngroups_swapped) { - // Only apply split-k for decoding + if (paged_KV || seqlenq_ngroups_swapped) { std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(params, batch_size, num_heads, head_size, max_seqlen_k, max_seqlen_q, head_size_rounded, - p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts); + p_dropout, num_splits, get_num_sm(get_current_device()), opts); } if (leftpad_k_.has_value()) { diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h index 934e7b9114b..b7831c5e832 100644 --- a/csrc/flash_attn/src/flash_fwd_launch_template.h +++ b/csrc/flash_attn/src/flash_fwd_launch_template.h @@ -162,11 +162,19 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr static int kBlockM = 64; // Fixed for all head dimensions + constexpr static int kBlockM = 64; // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256, // and for headdim 192 with block size 64 x 128. constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64); - run_flash_splitkv_fwd, Is_causal>(params, stream); + // if user specifies num_splits=1, we assume they want bitwise identical + // numerics across the split KV and standard kernels so we align kBLockN to + // match + if (params.num_splits == 1) { + constexpr static int kBlockN_standard = Headdim <= 64 ? 128 : 64; + run_flash_splitkv_fwd, Is_causal>(params, stream); + } else { + run_flash_splitkv_fwd, Is_causal>(params, stream); + } } template From b65ae6b175f2438de55601695b6a21971fc5e429 Mon Sep 17 00:00:00 2001 From: David Wang <21328423+dcw02@users.noreply.github.com> Date: Fri, 17 Apr 2026 13:51:49 -0400 Subject: [PATCH 17/44] [Cute,Fwd,Sm100] fp8 e4m3 and e5m2 support (#2109) * fa4 benchmark and correctness test * initial fa4 fp8 e4m3 support * fix tmem p overlap/signaling for fp8 * kv_stage 4 for fp8 * fp8 e5m2 support * update benchmark * fix lint * descale named tuple * compile time gating * fix lint * fix rescale bug * fp8 register tuning * defensive restore default rescale threshold to 0, add fp8 override * load effective descales helper * uint8 workaround for fp8 note --- .../cute/benchmark_flash_attention_fp8.py | 434 ++++++++++++++++++ flash_attn/cute/blackwell_helpers.py | 68 ++- flash_attn/cute/block_sparse_utils.py | 7 +- flash_attn/cute/cute_dsl_utils.py | 17 +- flash_attn/cute/flash_fwd_sm100.py | 125 ++++- flash_attn/cute/interface.py | 111 ++++- flash_attn/cute/mma_sm100_desc.py | 4 +- flash_attn/cute/softmax.py | 7 +- 8 files changed, 718 insertions(+), 55 deletions(-) create mode 100644 flash_attn/cute/benchmark_flash_attention_fp8.py diff --git a/flash_attn/cute/benchmark_flash_attention_fp8.py b/flash_attn/cute/benchmark_flash_attention_fp8.py new file mode 100644 index 00000000000..c79e7687237 --- /dev/null +++ b/flash_attn/cute/benchmark_flash_attention_fp8.py @@ -0,0 +1,434 @@ +# Benchmark FP8 attention for FA4 (CuTe-DSL) on SM100. +# +# Run (recommended): +# python -m flash_attn.cute.benchmark_flash_attention_fp8 +# +# Notes: +# - This is intended to be used while bringing up FP8 support for SM100. +# - FP8 correctness depends on descales + max-offset scaling being implemented in the SM100 kernel. +# This script optionally checks output vs a BF16 PyTorch baseline on dequantized FP8 inputs. +# +# Adapted from: `hopper/benchmark_flash_attention_fp8.py` + +from __future__ import annotations + +import argparse +import inspect +import math +import time +from typing import Iterable + +import torch +from einops import rearrange + +from flash_attn.cute.benchmark import benchmark_forward +from flash_attn.cute.interface import _flash_attn_fwd as flash_attn_cute_fwd + +try: + import cudnn +except ImportError: + cudnn = None + + +def _torch_float8_dtype(name: str) -> torch.dtype: + if name in ("fp8", "fp8_e4m3", "fp8_e4m3fn"): + return torch.float8_e4m3fn + if name in ("fp8_e5m2", "fp8_e5m2fn"): + return torch.float8_e5m2 + raise ValueError(f"Unsupported fp8 dtype name: {name}") + + +def _parse_int_list(csv: str) -> list[int]: + out: list[int] = [] + for part in csv.split(","): + part = part.strip() + if not part: + continue + out.append(int(part)) + return out + + +def attention_pytorch(qkv: torch.Tensor, causal: bool) -> torch.Tensor: + """ + qkv: (batch, seqlen, 3, nheads, headdim) + out: (batch, seqlen, nheads, headdim) + """ + batch_size, seqlen, _, nheads, d = qkv.shape + q, k, v = qkv.unbind(dim=2) + q = rearrange(q, "b t h d -> (b h) t d") + k = rearrange(k, "b s h d -> (b h) d s") + softmax_scale = 1.0 / math.sqrt(d) + scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device) + scores = rearrange( + torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale), "(b h) t s -> b h t s", h=nheads + ) + if causal: + causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1) + scores = scores + causal_mask.to(dtype=scores.dtype) + attention = torch.softmax(scores, dim=-1) + output = torch.einsum("bhts,bshd->bthd", attention, v) + return output.to(dtype=qkv.dtype) + + +def flops(batch: int, seqlen: int, headdim: int, nheads: int, causal: bool) -> int: + # Matches the hopper benchmark’s convention. + return 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1) + + +def efficiency(flop: int, seconds: float) -> float: + return (flop / seconds / 1e12) if not math.isnan(seconds) else 0.0 + + +def time_fwd(fn, *args, repeats: int, **kwargs) -> float: + time.sleep(1) # reduce residual throttling effects between benchmarks + _, m = benchmark_forward(fn, *args, repeats=repeats, verbose=False, **kwargs) + return float(m.mean) + + +def convert_to_cudnn_type(torch_type): + if torch_type == torch.float16: + return cudnn.data_type.HALF + if torch_type == torch.bfloat16: + return cudnn.data_type.BFLOAT16 + if torch_type == torch.float32: + return cudnn.data_type.FLOAT + if torch_type == torch.int32: + return cudnn.data_type.INT32 + if torch_type == torch.int64: + return cudnn.data_type.INT64 + if torch_type == torch.float8_e4m3fn: + return cudnn.data_type.FP8_E4M3 + if torch_type == torch.float8_e5m2: + return cudnn.data_type.FP8_E5M2 + raise ValueError("Unsupported tensor data type.") + + +def cudnn_sdpa_fp8_setup(qkv: torch.Tensor, seqlen_q: int, seqlen_k: int, causal: bool): + """Minimal cudnn.fp8 sdpa runner (optional).""" + assert cudnn is not None, "cudnn python bindings not available" + b, _, _, nheads, headdim = qkv.shape + o_gpu = torch.zeros(b, seqlen_q, nheads, headdim, dtype=qkv.dtype, device=qkv.device) + o_gpu_transposed = torch.as_strided( + o_gpu, + [b, nheads, seqlen_q, headdim], + [nheads * seqlen_q * headdim, headdim, nheads * headdim, 1], + ) + amax_s_gpu = torch.empty(1, 1, 1, 1, dtype=torch.float32, device=qkv.device) + amax_o_gpu = torch.empty(1, 1, 1, 1, dtype=torch.float32, device=qkv.device) + + graph = cudnn.pygraph( + io_data_type=convert_to_cudnn_type(qkv.dtype), + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + + new_q = torch.as_strided( + qkv, + [b, nheads, seqlen_q, headdim], + [seqlen_q * nheads * headdim * 3, headdim, headdim * nheads * 3, 1], + storage_offset=0, + ) + q = graph.tensor( + name="Q", + dim=list(new_q.shape), + stride=list(new_q.stride()), + data_type=convert_to_cudnn_type(qkv.dtype), + ) + + new_k = torch.as_strided( + qkv, + [b, nheads, seqlen_k, headdim], + [seqlen_k * nheads * headdim * 3, headdim, headdim * nheads * 3, 1], + storage_offset=nheads * headdim, + ) + k = graph.tensor( + name="K", + dim=list(new_k.shape), + stride=list(new_k.stride()), + data_type=convert_to_cudnn_type(qkv.dtype), + ) + + new_v = torch.as_strided( + qkv, + [b, nheads, seqlen_k, headdim], + [seqlen_k * nheads * headdim * 3, headdim, headdim * nheads * 3, 1], + storage_offset=nheads * headdim * 2, + ) + v = graph.tensor( + name="V", + dim=list(new_v.shape), + stride=list(new_v.stride()), + data_type=convert_to_cudnn_type(qkv.dtype), + ) + + def _scale_tensor(): + return graph.tensor(dim=[1, 1, 1, 1], stride=[1, 1, 1, 1], data_type=cudnn.data_type.FLOAT) + + default_scale_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float32, device="cuda") + descale_q = _scale_tensor() + descale_k = _scale_tensor() + descale_v = _scale_tensor() + descale_s = _scale_tensor() + scale_s = _scale_tensor() + scale_o = _scale_tensor() + + o, _, amax_s, amax_o = graph.sdpa_fp8( + q=q, + k=k, + v=v, + descale_q=descale_q, + descale_k=descale_k, + descale_v=descale_v, + descale_s=descale_s, + scale_s=scale_s, + scale_o=scale_o, + is_inference=True, + attn_scale=1.0 / math.sqrt(headdim), + use_causal_mask=causal, + name="sdpa", + ) + o.set_output(True).set_dim(o_gpu_transposed.shape).set_stride(o_gpu_transposed.stride()) + amax_s.set_output(False).set_dim(amax_s_gpu.shape).set_stride(amax_s_gpu.stride()) + amax_o.set_output(False).set_dim(amax_o_gpu.shape).set_stride(amax_o_gpu.stride()) + + graph.validate() + graph.build_operation_graph() + graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph.check_support() + graph.build_plans() + + variant_pack = { + q: new_q, + k: new_k, + v: new_v, + descale_q: default_scale_gpu, + descale_k: default_scale_gpu, + descale_v: default_scale_gpu, + descale_s: default_scale_gpu, + scale_s: default_scale_gpu, + scale_o: default_scale_gpu, + o: o_gpu_transposed, + amax_s: amax_s_gpu, + amax_o: amax_o_gpu, + } + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) + + def run(): + graph.execute(variant_pack, workspace) + return o_gpu + + return run + + +def _maybe_pass_descales(callable_, **kwargs): + sig = inspect.signature(callable_) + return {k: v for k, v in kwargs.items() if k in sig.parameters} + + +def main(argv: Iterable[str] | None = None) -> int: + parser = argparse.ArgumentParser() + parser.add_argument("--repeats", type=int, default=30) + parser.add_argument("--dim", type=int, default=2048) + parser.add_argument("--headdims", default="64,128") + parser.add_argument("--dtype", default="fp8_e4m3fn") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument( + "--check", + action=argparse.BooleanOptionalAction, + default=True, + help="Enable correctness checks vs BF16 PyTorch baseline.", + ) + parser.add_argument( + "--check-quantization-only", + action="store_true", + help="Check FP8 kernel vs dequantized-FP8 baseline (quantization error only).", + ) + parser.add_argument("--atol-bf16", type=float, default=0.10) + parser.add_argument("--rtol-bf16", type=float, default=0.10) + parser.add_argument("--atol-fp8", type=float, default=0.50) + parser.add_argument("--rtol-fp8", type=float, default=0.50) + parser.add_argument("--run-cudnn", action="store_true") + args = parser.parse_args(list(argv) if argv is not None else None) + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + major, minor = torch.cuda.get_device_capability() + if major != 10: + raise RuntimeError( + f"This benchmark is for SM100 (compute capability 10.x). Got {major}.{minor}." + ) + + torch.manual_seed(args.seed) + device = "cuda" + fp8_dtype = _torch_float8_dtype(args.dtype) + headdim_vals = _parse_int_list(args.headdims) + bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)] + + methods = ["Pytorch", "FA4-CuTe-BF16", "FA4-CuTe-FP8"] + ( + ["cuDNN-FP8"] if args.run_cudnn and cudnn is not None else [] + ) + + fp8_failures = [] + + for headdim in headdim_vals: + for causal in (False, True): + for batch, seqlen in bs_seqlen_vals: + torch.cuda.empty_cache() + nheads = args.dim // headdim + if args.dim % headdim != 0: + raise ValueError(f"--dim must be divisible by headdim ({args.dim=} {headdim=})") + + q_bf16 = torch.randn( + batch, seqlen, nheads, headdim, device=device, dtype=torch.bfloat16 + ) + k_bf16 = torch.randn( + batch, seqlen, nheads, headdim, device=device, dtype=torch.bfloat16 + ) + v_bf16 = torch.randn( + batch, seqlen, nheads, headdim, device=device, dtype=torch.bfloat16 + ) + qkv_bf16 = torch.stack([q_bf16, k_bf16, v_bf16], dim=2) + + times = {} + speeds = {} + + out_ref_bf16 = None + try: + out_ref_bf16 = attention_pytorch(qkv_bf16, causal=causal) # warmup / reference + t = time_fwd(attention_pytorch, qkv_bf16, causal=causal, repeats=args.repeats) + times["Pytorch"] = t + except RuntimeError as e: + if "out of memory" in str(e).lower(): + times["Pytorch"] = float("nan") + out_ref_bf16 = None + else: + raise + + # FA4 / CuTe BF16 baseline + try: + softmax_scale = headdim**-0.5 + out_fa4_bf16, _ = flash_attn_cute_fwd( + q_bf16, k_bf16, v_bf16, softmax_scale=softmax_scale, causal=causal + ) # warmup / compile + t = time_fwd( + flash_attn_cute_fwd, + q_bf16, + k_bf16, + v_bf16, + softmax_scale=softmax_scale, + causal=causal, + repeats=args.repeats, + ) + times["FA4-CuTe-BF16"] = t + if args.check and out_ref_bf16 is not None: + torch.testing.assert_close( + out_fa4_bf16, + out_ref_bf16, + atol=args.atol_bf16, + rtol=args.rtol_bf16, + ) + except Exception as e: + # Treat as fatal: BF16 kernel should be usable for basic sanity checking. + raise RuntimeError("FA4-CuTe BF16 baseline failed") from e + + # FA4 / CuTe FP8 + q_fp8 = q_bf16.to(fp8_dtype) + k_fp8 = k_bf16.to(fp8_dtype) + v_fp8 = v_bf16.to(fp8_dtype) + + # Placeholder descales (FA3-style: per-(batch, kv_head)). + q_descale = torch.ones(batch, nheads, device=device, dtype=torch.float32) + k_descale = torch.ones(batch, nheads, device=device, dtype=torch.float32) + v_descale = torch.ones(batch, nheads, device=device, dtype=torch.float32) + + # Optional: FP8 reference baseline (dequantized FP8 -> PyTorch) for quantization-error-only checks + out_ref_fp8 = None + if args.check and args.check_quantization_only: + try: + # Dequantize FP8 inputs back to BF16 (applying descales) + q_ref_fp8 = (q_fp8.to(torch.bfloat16) * q_descale[:, None, :, None]).to( + torch.bfloat16 + ) + k_ref_fp8 = (k_fp8.to(torch.bfloat16) * k_descale[:, None, :, None]).to( + torch.bfloat16 + ) + v_ref_fp8 = (v_fp8.to(torch.bfloat16) * v_descale[:, None, :, None]).to( + torch.bfloat16 + ) + qkv_ref_fp8 = torch.stack([q_ref_fp8, k_ref_fp8, v_ref_fp8], dim=2) + out_ref_fp8 = attention_pytorch(qkv_ref_fp8, causal=causal) + except RuntimeError as e: + if "out of memory" in str(e).lower(): + out_ref_fp8 = None + else: + raise + + fa4_kwargs = dict(softmax_scale=softmax_scale, causal=causal) + fa4_kwargs.update( + _maybe_pass_descales( + flash_attn_cute_fwd, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + ) + ) + + try: + # Warmup/compile (will raise until FP8 is implemented) + out_fa4_fp8, _ = flash_attn_cute_fwd(q_fp8, k_fp8, v_fp8, **fa4_kwargs) + t = time_fwd( + flash_attn_cute_fwd, + q_fp8, + k_fp8, + v_fp8, + repeats=args.repeats, + **fa4_kwargs, + ) + times["FA4-CuTe-FP8"] = t + if args.check: + # Choose baseline: quantization-only (dequantized FP8) or full (BF16) + if args.check_quantization_only: + ref_baseline = out_ref_fp8 + else: + ref_baseline = out_ref_bf16 + + if ref_baseline is not None: + torch.testing.assert_close( + out_fa4_fp8, + ref_baseline, + atol=args.atol_fp8, + rtol=args.rtol_fp8, + ) + except Exception as e: + fp8_failures.append((causal, headdim, batch, seqlen, repr(e))) + times["FA4-CuTe-FP8"] = float("nan") + + if args.run_cudnn and cudnn is not None: + qkv_fp8 = qkv_bf16.to(fp8_dtype) + runner = cudnn_sdpa_fp8_setup(qkv_fp8, seqlen, seqlen, causal=causal) + _ = runner() # warmup + t = time_fwd(lambda: runner(), repeats=args.repeats) + times["cuDNN-FP8"] = t + + print(f"### causal={causal}, headdim={headdim}, batch={batch}, seqlen={seqlen} ###") + for method in methods: + t = times.get(method, float("nan")) + speeds[method] = efficiency(flops(batch, seqlen, headdim, nheads, causal), t) + if math.isnan(t): + print(f"{method} fwd: (skipped)") + else: + print(f"{method} fwd: {speeds[method]:.2f} TFLOPs/s, {t * 1e3:.3f} ms") + if math.isnan(times.get("FA4-CuTe-FP8", float("nan"))): + print("FA4-CuTe-FP8 status: FAILED") + + if fp8_failures: + print(f"\nFP8 failures: {len(fp8_failures)} (showing first 5)") + for causal, headdim, batch, seqlen, err in fp8_failures[:5]: + print(f"- causal={causal} headdim={headdim} batch={batch} seqlen={seqlen}: {err}") + return 1 + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/flash_attn/cute/blackwell_helpers.py b/flash_attn/cute/blackwell_helpers.py index 720778027b2..4caadce864a 100644 --- a/flash_attn/cute/blackwell_helpers.py +++ b/flash_attn/cute/blackwell_helpers.py @@ -10,6 +10,24 @@ import flash_attn.cute.mma_sm100_desc as sm100_desc +def _tcgen05_mma_kind(op: cute.nvgpu.tcgen05.mma.MmaOp) -> str: + if isinstance(op, tcgen05.mma.MmaF16BF16Op): + return "f16" + if isinstance(op, tcgen05.mma.MmaTF32Op): + return "tf32" + if isinstance(op, tcgen05.mma.MmaI8Op): + return "i8" + if isinstance(op, tcgen05.mma.MmaFP8Op): + return "f8f6f4" + if isinstance(op, tcgen05.mma.MmaMXF8Op): + return "mxf8f6f4" + if isinstance(op, tcgen05.mma.MmaMXF4Op): + return "mxf4" + if isinstance(op, tcgen05.mma.MmaMXF4NVF4Op): + return "mxf4nvf4" + raise TypeError(f"Unsupported tcgen05 MMA op kind: {type(op).__name__}") + + @cute.jit def gemm_w_idx( tiled_mma: cute.TiledMma, @@ -108,6 +126,7 @@ def gemm_ptx( sA_layout = sA.layout if sA is not None else None sB_layout = sB.layout idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + kind = _tcgen05_mma_kind(op) if const_expr(not is_ts): sA_swizzle = sA.iterator.type.swizzle_type smem_desc_base_a: int = const_expr( @@ -177,7 +196,7 @@ def gemm_ptx( f"mov.b64 smem_desc_a, {{$1, {hex(smem_desc_a_hi)}}};\n\t" f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t" "setp.ne.b32 p, $3, 0;\n\t" - f"tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, p;\n\t" + f"tcgen05.mma.cta_group::1.kind::{kind} [$0], smem_desc_a, smem_desc_b, idesc, p;\n\t" "}\n", "r,r,r,r", has_side_effects=True, @@ -198,7 +217,7 @@ def gemm_ptx( ".reg .b64 smem_desc_b;\n\t" f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t" "setp.ne.b32 p, $3, 0;\n\t" - f"tcgen05.mma.cta_group::1.kind::f16 [$0], [$1], smem_desc_b, {hex(idesc)}, p;\n\t" + f"tcgen05.mma.cta_group::1.kind::{kind} [$0], [$1], smem_desc_b, {hex(idesc)}, p;\n\t" "}\n", "r,r,r,r", has_side_effects=True, @@ -223,6 +242,7 @@ def gemm_ptx_loop( sA_layout = sA.layout if sA is not None else tCrA.layout sB_layout = sB.layout idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + kind = _tcgen05_mma_kind(op) if const_expr(not is_ts): sA_swizzle = sA.iterator.type.swizzle_type smem_desc_base_a: int = const_expr( @@ -310,14 +330,14 @@ def gemm_ptx_loop( f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" "setp.ne.b32 p, $3, 0;\n\t" - f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::{kind} [$0], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + "".join( ( f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" - f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::{kind} [$0], smem_desc_a, smem_desc_b, idesc, 1;\n\t" ) for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2])) ) @@ -351,14 +371,14 @@ def gemm_ptx_loop( f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" "setp.ne.b32 p, $3, 0;\n\t" - f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::{kind} [$0], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + "".join( ( # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, 1;\n\t" - f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::{kind} [$0], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" ) for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2])) ) @@ -394,6 +414,7 @@ def gemm_ptx_partial( sA_layout = sA.layout if sA is not None else tCrA.layout sB_layout = sB.layout idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + kind = _tcgen05_mma_kind(op) if const_expr(not is_ts): sA_swizzle = sA.iterator.type.swizzle_type smem_desc_base_a: int = const_expr( @@ -477,7 +498,7 @@ def gemm_ptx_partial( f"mov.b64 smem_desc_a, {{smem_desc_a_lo_start, smem_desc_a_hi}};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" "setp.ne.b32 p, $2, 0;\n\t" - f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{kind} [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + "".join( ( # f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" @@ -486,7 +507,7 @@ def gemm_ptx_partial( f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" - f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{kind} [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t" ) for k in range(1, cute.size(tCrA.shape[2])) ) @@ -554,7 +575,7 @@ def gemm_ptx_partial( f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" "setp.ne.b32 p, $2, 0;\n\t" - f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{kind} [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + "".join( ( # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" @@ -562,7 +583,7 @@ def gemm_ptx_partial( f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t" - f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{kind} [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" ) for k in range( 1, @@ -575,7 +596,7 @@ def gemm_ptx_partial( ( f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" - f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{kind} [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" ) for k in range(split_arrive_idx, cute.size(tCrA.shape[2])) ) @@ -613,6 +634,7 @@ def gemm_ptx_partial1( assert sA_layout is not None, "sA_layout must be provided when a_src is not TMEM" assert sA_swizzle is not None, "sA_swizzle must be provided when a_src is not TMEM" idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + kind = _tcgen05_mma_kind(op) if const_expr(not is_ts): smem_desc_base_a: int = const_expr( sm100_desc.make_smem_desc_base( @@ -706,14 +728,14 @@ def gemm_ptx_partial1( f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" "setp.ne.b32 p, $4, 0;\n\t" - f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, {pred_str};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::{kind} [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, {pred_str};\n\t" + "".join( ( f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" - f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::{kind} [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, 1;\n\t" ) for k in range(1, cute.size(tCrA.shape[2])) ) @@ -751,13 +773,13 @@ def gemm_ptx_partial1( f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" "setp.ne.b32 p, $3, 0;\n\t" - f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, {pred_str};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::{kind} [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, {pred_str};\n\t" + "".join( ( f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" - f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::{kind} [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, 1;\n\t" ) for k in range(1, cute.size(tCrA.shape[2])) ) @@ -783,6 +805,7 @@ def gemm_ptx_precomputed( mbar_phase: Optional[Int32] = None, zero_init: bool | Boolean = False, cta_group: int = 1, + kind: str = "f16", ) -> None: # acc_tmem_addr += acc_offset is_ts = const_expr(smem_desc_base_a is None) @@ -842,7 +865,7 @@ def gemm_ptx_precomputed( f"mov.b64 smem_desc_a, {{smem_desc_a_lo_start, smem_desc_a_hi}};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" "setp.ne.b32 p, $2, 0;\n\t" - f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{kind} [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + "".join( ( # f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" @@ -851,7 +874,7 @@ def gemm_ptx_precomputed( f"add.s32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" - f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{kind} [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t" ) for k in range(1, num_k_tile) ) @@ -911,7 +934,7 @@ def gemm_ptx_precomputed( f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" "setp.ne.b32 p, $2, 0;\n\t" - f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{kind} [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + "".join( ( # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" @@ -919,7 +942,7 @@ def gemm_ptx_precomputed( f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t" - f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{kind} [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" ) for k in range( 1, @@ -933,7 +956,7 @@ def gemm_ptx_precomputed( # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" - f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{kind} [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" ) for k in range(num_k_tile // 4 * 3, num_k_tile) ) @@ -1019,6 +1042,7 @@ def gemm_ptx_precomputed_varname( smem_offset: int, zero_init: bool | Boolean = False, cta_group: int = 1, + kind: str = "f16", ) -> None: is_ts = False num_k_tile = cute.size(tCrB_layout.shape[2]) @@ -1067,7 +1091,7 @@ def gemm_ptx_precomputed_varname( ) + "setp.ne.b32 p, $1, 0;\n\t" # f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_0, smem_desc_b, idesc, {pred_str};\n\t" - f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_0, smem_desc_b_0, {idesc_var_name}, {pred_str};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{kind} [tmem_acc], {smem_var_name_prefix}_0, smem_desc_b_0, {idesc_var_name}, {pred_str};\n\t" + "".join( ( # f"mov.b64 {{smem_desc_a_lo, smem_desc_a_hi}}, {smem_var_name_prefix}_{k};\n\t" @@ -1077,7 +1101,7 @@ def gemm_ptx_precomputed_varname( # f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" # f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b, idesc, 1;\n\t" # f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b, {idesc_var_name}, 1;\n\t" - f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b_{k}, {idesc_var_name}, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{kind} [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b_{k}, {idesc_var_name}, 1;\n\t" ) for k in range(1, num_k_tile) ) diff --git a/flash_attn/cute/block_sparse_utils.py b/flash_attn/cute/block_sparse_utils.py index 52cb7e06044..b19dcd37cf8 100644 --- a/flash_attn/cute/block_sparse_utils.py +++ b/flash_attn/cute/block_sparse_utils.py @@ -667,6 +667,8 @@ def handle_block_sparse_empty_tile_correction_sm100( o_corr_consumer_phase: Int32, corr_epi_producer_phase: Int32, softmax_scale_log2: Float32, + max_offset: Float32, + max_offset_scale: Float32, mO_cur: Optional[cute.Tensor] = None, gO: Optional[cute.Tensor] = None, gmem_tiled_copy_O: Optional[cute.TiledCopy] = None, @@ -706,10 +708,11 @@ def handle_block_sparse_empty_tile_correction_sm100( if sink_val != -Float32.inf and (const_expr(not is_split_kv) or split_idx == 0): if row_max_value == -Float32.inf: row_max_value = sink_val * (LOG2_E / softmax_scale_log2) - row_sum_value = Float32(1.0) + row_sum_value = max_offset_scale else: row_sum_value = row_sum_value + cute.math.exp2( - sink_val * LOG2_E - row_max_value * softmax_scale_log2, fastmath=True + sink_val * LOG2_E - row_max_value * softmax_scale_log2 + max_offset, + fastmath=True, ) if tidx < m_block_size: scale_row_idx = tidx + stage * m_block_size diff --git a/flash_attn/cute/cute_dsl_utils.py b/flash_attn/cute/cute_dsl_utils.py index 636f7de4de5..6dfad6606ef 100644 --- a/flash_attn/cute/cute_dsl_utils.py +++ b/flash_attn/cute/cute_dsl_utils.py @@ -26,6 +26,8 @@ torch.float16: cutlass.Float16, torch.bfloat16: cutlass.BFloat16, torch.float32: cutlass.Float32, + torch.float8_e4m3fn: cutlass.Float8E4M3FN, + torch.float8_e5m2: cutlass.Float8E5M2, } @@ -59,7 +61,20 @@ def assume_tensor_aligned(t): def to_cute_tensor(t, assumed_align=16, leading_dim=-1, fully_dynamic=False, enable_tvm_ffi=True): """Convert torch tensor to cute tensor for TVM FFI. leading_dim=-1 defaults to t.ndim-1.""" - tensor = from_dlpack(t.detach(), assumed_align=assumed_align, enable_tvm_ffi=enable_tvm_ffi) + # NOTE: torch 2.9.1 doesn't support fp8 via DLPack but 2.11.0 nightly does + # currently export raw bytes as uint8 and tell cutlass correct type + # can directly export as fp8 when torch supports it + if t.dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + tensor = from_dlpack( + t.view(torch.uint8).detach(), + assumed_align=assumed_align, + enable_tvm_ffi=enable_tvm_ffi, + ) + tensor.element_type = ( + cutlass.Float8E4M3FN if t.dtype == torch.float8_e4m3fn else cutlass.Float8E5M2 + ) + else: + tensor = from_dlpack(t.detach(), assumed_align=assumed_align, enable_tvm_ffi=enable_tvm_ffi) if fully_dynamic: return tensor.mark_layout_dynamic() if leading_dim == -1: diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 23a96a17b1a..75e767cdc44 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -14,7 +14,7 @@ # https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/fmha.py import math -from typing import Tuple, Callable, Optional, Literal +from typing import Tuple, Callable, Optional, Literal, NamedTuple from functools import partial import cuda.bindings.driver as cuda @@ -87,9 +87,25 @@ (True, False, 192, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 64}, (False, True, 192, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 72}, } +_FP8_TUNING_CONFIG = { + (True, False, 128, False): {'ex2_emu_freq': 10, 'ex2_emu_start_frg': 1, 'num_regs_softmax': 160, 'num_regs_correction': 72}, +} +_FP8_SMALL_HDIM_REGS = { + False: {"num_regs_softmax": 168, "num_regs_correction": 96, "num_regs_other": 80}, + True: {"num_regs_softmax": 152, "num_regs_correction": 96, "num_regs_other": 112}, +} # === END TUNING KNOBS === +class DescaleTensors(NamedTuple): + q_descale: Optional[cute.Tensor] = None + k_descale: Optional[cute.Tensor] = None + v_descale: Optional[cute.Tensor] = None + + def __new_from_mlir_values__(self, values): + return DescaleTensors(*((*values, None, None, None)[:3])) + + class FlashAttentionForwardSm100: def __init__( @@ -351,6 +367,7 @@ def __call__( window_size_left: Int32 | int | None = None, window_size_right: Int32 | int | None = None, learnable_sink: Optional[cute.Tensor] = None, + descale_tensors: Optional[DescaleTensors] = None, blocksparse_tensors: Optional[BlockSparseTensors] = None, aux_tensors: Optional[list] = None, # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). @@ -406,6 +423,24 @@ def __call__( raise TypeError(f"Type mismatch: {self.q_dtype} != {self.k_dtype}") if const_expr(self.q_dtype != self.v_dtype): raise TypeError(f"Type mismatch: {self.q_dtype} != {self.v_dtype}") + if const_expr(self.q_dtype.width == 8): + paged_kv_non_tma = not self.use_tma_KV + if const_expr(self.head_dim_padded < 96): + fp8_regs = _FP8_SMALL_HDIM_REGS[paged_kv_non_tma] + self.num_regs_softmax = fp8_regs["num_regs_softmax"] + self.num_regs_correction = fp8_regs["num_regs_correction"] + self.num_regs_other = fp8_regs["num_regs_other"] + else: + fp8_tune = _FP8_TUNING_CONFIG.get( + (self.use_2cta_instrs, self.is_causal, self.head_dim_padded, self.is_sm103), {} + ) + if const_expr("ex2_emu_freq" in fp8_tune): + self._tune = {**self._tune, **fp8_tune} + self.enable_ex2_emu = self._tune["ex2_emu_freq"] > 0 + if const_expr(not paged_kv_non_tma and "num_regs_softmax" in fp8_tune): + self.num_regs_softmax = fp8_tune["num_regs_softmax"] + self.num_regs_correction = fp8_tune["num_regs_correction"] + self.num_regs_other = 512 - self.num_regs_softmax * 2 - self.num_regs_correction self._setup_attributes() self.use_tma_O = ( self.arch >= Arch.sm_90 @@ -705,6 +740,7 @@ class SharedStorage: window_size_left, window_size_right, learnable_sink, + descale_tensors, blocksparse_tensors, sQ_layout, sK_layout, @@ -751,6 +787,7 @@ def kernel( window_size_left: Optional[Int32], window_size_right: Optional[Int32], learnable_sink: Optional[cute.Tensor], + descale_tensors: Optional[DescaleTensors], blocksparse_tensors: Optional[BlockSparseTensors], sQ_layout: cute.ComposedLayout, sK_layout: cute.ComposedLayout, @@ -1179,6 +1216,7 @@ def kernel( self.softmax_loop, softmax_scale_log2=softmax_scale_log2, softmax_scale=softmax_scale, + descale_tensors=descale_tensors, thr_mma_qk=thr_mma_qk, sScale=sScale, mLSE=mLSE, @@ -1234,6 +1272,7 @@ def kernel( sm_stats_barrier, pipeline_o_epi, learnable_sink, + descale_tensors, gmem_tiled_copy_O, tma_atom_O, softmax_scale_log2, @@ -1499,6 +1538,7 @@ def mma( qk_mma_op, pv_mma_op = tiled_mma_qk.op, tiled_mma_pv.op qk_mma_idesc, pv_mma_idesc = sm100_desc.mma_op_to_idesc(qk_mma_op), sm100_desc.mma_op_to_idesc(pv_mma_op) + qk_mma_kind = sm100_utils._tcgen05_mma_kind(qk_mma_op) q_smem_base = sm100_desc.smem_desc_base_from_tensor(sQ, sm100_desc.Major.K) k_smem_base = sm100_desc.smem_desc_base_from_tensor(sK, sm100_desc.Major.K) v_smem_base = sm100_desc.smem_desc_base_from_tensor(sV, sm100_desc.Major.MN) @@ -1527,6 +1567,7 @@ def mma( tCrB_layout=tSrK[None, None, None, 0].layout, smem_var_name_prefix=f"fa_fwd_q_smem_desc", idesc_var_name=f"fa_fwd_qk_mma_idesc", + kind=qk_mma_kind, smem_offset=-sQ_stage_stride if stage == 0 else sQ_stage_stride, zero_init=True, cta_group=self.cta_group_size, @@ -1755,12 +1796,39 @@ def mma( # pipeline_o_acc.producer_acquire() inside the loop. # for both softmax0 and softmax1 warp group + @cute.jit + def _kv_head_idx(self, head_idx: Int32) -> Int32: + """Map query-head tile index -> KV-head index (FA3 descale semantics).""" + if cutlass.const_expr(self.pack_gqa): + return head_idx + return head_idx // self.qhead_per_kvhead + + @cute.jit + def _load_effective_descales( + self, + descale_tensors: Optional[DescaleTensors], + batch_idx: Int32, + kv_head_idx: Int32, + ) -> Tuple[Float32, Float32]: + """Load effective QK and V descales, defaulting unspecified tensors to identity.""" + qk_descale = Float32(1.0) + v_descale = Float32(1.0) + if cutlass.const_expr(descale_tensors is not None): + if cutlass.const_expr(descale_tensors.q_descale is not None): + qk_descale = qk_descale * Float32(descale_tensors.q_descale[batch_idx, kv_head_idx]) + if cutlass.const_expr(descale_tensors.k_descale is not None): + qk_descale = qk_descale * Float32(descale_tensors.k_descale[batch_idx, kv_head_idx]) + if cutlass.const_expr(descale_tensors.v_descale is not None): + v_descale = Float32(descale_tensors.v_descale[batch_idx, kv_head_idx]) + return qk_descale, v_descale + @cute.jit def softmax_loop( self, stage: int | Int32, softmax_scale_log2: Float32, - softmax_scale: Float32, + softmax_scale: Float32 | None, + descale_tensors: Optional[DescaleTensors], thr_mma_qk: cute.core.ThrMma, tStS: cute.Tensor, # ((TILE_M, TILE_N), 1, 1, q_stage) sScale: cute.Tensor, @@ -1826,7 +1894,10 @@ def softmax_loop( ) tStScale_r2t = thr_tmem_store_scale.partition_D(tStScale) tmem_store_atom = cute.make_copy_atom( - tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(16)), Float32 + tcgen05.copy.St32x32bOp( + tcgen05.copy.Repetition(8 if const_expr(self.q_dtype.width == 8) else 16) + ), + Float32, ) thr_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP).get_slice(tidx) tStP_r2t = thr_tmem_store.partition_D(tStP) # (((16,32),1),1,4) @@ -1842,6 +1913,7 @@ def softmax_loop( work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx + kv_head_idx = self._kv_head_idx(head_idx) seqlen = SeqlenInfoCls(batch_idx) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) @@ -1896,10 +1968,26 @@ def softmax_loop( else: mask_fn_none = None + qk_descale, _ = self._load_effective_descales(descale_tensors, batch_idx, kv_head_idx) + + max_offset = 8 if cutlass.const_expr(self.q_dtype.width == 8) else 0 + if const_expr(self.score_mod is None): + softmax_scale_log2_eff = softmax_scale_log2 * qk_descale + softmax_scale_eff = None + else: + softmax_scale_log2_eff = softmax_scale_log2 + softmax_scale_eff = softmax_scale * qk_descale + + rescale_threshold = ( + 8.0 if const_expr(self.q_dtype.width == 16) else + 4.0 if const_expr(self.q_dtype.width == 8) else + 0.0 + ) softmax = SoftmaxSm100.create( - softmax_scale_log2, - rescale_threshold=8.0 if const_expr(self.q_dtype.width == 16) else 0.0, - softmax_scale=softmax_scale, + softmax_scale_log2_eff, + rescale_threshold=rescale_threshold, + softmax_scale=softmax_scale_eff, + max_offset=max_offset, ) softmax.reset() @@ -2245,6 +2333,7 @@ def correction_loop( sm_stats_barrier: pipeline.NamedBarrier, pipeline_o_epi: pipeline.PipelineAsync, learnable_sink: Optional[cute.Tensor], + descale_tensors: Optional[DescaleTensors], gmem_tiled_copy_O: cute.TiledCopy, tma_atom_O: cute.CopyAtom, softmax_scale_log2: Float32, @@ -2285,6 +2374,17 @@ def correction_loop( work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx + kv_head_idx = self._kv_head_idx(head_idx) + qk_descale, v_descale = self._load_effective_descales(descale_tensors, batch_idx, kv_head_idx) + if const_expr(self.score_mod is None): + softmax_scale_log2_eff = softmax_scale_log2 * qk_descale + else: + softmax_scale_log2_eff = softmax_scale_log2 + + max_offset = Float32(8.0) if cutlass.const_expr(self.q_dtype.width == 8) else Float32(0.0) + max_offset_scale = ( + Float32(256.0) if cutlass.const_expr(self.q_dtype.width == 8) else Float32(1.0) + ) seqlen = SeqlenInfoCls(batch_idx) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) @@ -2387,15 +2487,16 @@ def correction_loop( if const_expr(not self.is_split_kv) or split_idx == 0: if row_max == -Float32.inf: # It's possible to have an empty row with splitKV. - row_max = sink_val * (LOG2_E / softmax_scale_log2) - row_sum = Float32(1.0) + row_max = sink_val * (LOG2_E / softmax_scale_log2_eff) + row_sum = max_offset_scale else: row_sum += cute.math.exp2( - sink_val * LOG2_E - row_max * softmax_scale_log2, fastmath=True + sink_val * LOG2_E - row_max * softmax_scale_log2_eff + max_offset, fastmath=True ) acc_O_mn_row_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum stats[stage] = (row_sum, row_max, acc_O_mn_row_is_zero_or_nan) scale = cute.arch.rcp_approx(row_sum if not acc_O_mn_row_is_zero_or_nan else 1.0) + scale = scale * v_descale # Wait for the last O to be ready from the MMA warp pipeline_o_acc.consumer_wait_w_index_phase(stage, o_corr_consumer_phase) if const_expr(not self.use_correction_warps_for_epi): @@ -2459,7 +2560,9 @@ def correction_loop( sm_stats_consumer_phase, o_corr_consumer_phase, corr_epi_producer_phase, - softmax_scale_log2, + softmax_scale_log2_eff, + max_offset, + max_offset_scale, mO_cur, gO, gmem_tiled_copy_O_for_empty_tile, @@ -2486,7 +2589,7 @@ def correction_loop( # cute.printf("row_sum = {}, row_max = {}, acc_O_mn_row_is_zero_or_nan = {}\n", row_sum, row_max, acc_O_mn_row_is_zero_or_nan) LN2 = math.log(2.0) lse = ( - (row_max * softmax_scale_log2 + cute.math.log2(row_sum, fastmath=True)) * LN2 + (row_max * softmax_scale_log2_eff + (cute.math.log2(row_sum, fastmath=True) - max_offset)) * LN2 if not acc_O_mn_row_is_zero_or_nan else -Float32.inf ) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 4be51d38839..5b9e382d217 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -34,7 +34,7 @@ ) from flash_attn.cute.flash_fwd import FlashAttentionForwardSm80 from flash_attn.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 -from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100 +from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100, DescaleTensors from flash_attn.cute.flash_fwd_sm120 import FlashAttentionForwardSm120 from flash_attn.cute.flash_bwd_preprocess import FlashAttentionBackwardPreprocess from flash_attn.cute.flash_bwd import FlashAttentionBackwardSm80 @@ -237,11 +237,12 @@ def _validate_tensor(t, name, expected_shape, expected_dtype, expected_device): if not is_fake_mode(): assert t.is_cuda, f"{name} must be on CUDA" - torch2cute_dtype_map = { torch.float16: cutlass.Float16, torch.bfloat16: cutlass.BFloat16, torch.float32: cutlass.Float32, + torch.float8_e4m3fn: cutlass.Float8E4M3FN, + torch.float8_e5m2: cutlass.Float8E5M2, } @@ -311,6 +312,9 @@ def _flash_attn_fwd( out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, aux_tensors: Optional[list[torch.Tensor]] = None, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, gather_kv_indices: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """Forward pass for FlashAttention. @@ -327,6 +331,7 @@ def _flash_attn_fwd( aux_tensors: Some score_mods will want to read from global aux_tensors. This is how we thread them through to the inner kernel. """ q, k, v = [maybe_contiguous(t) for t in (q, k, v)] + q_descale, k_descale, v_descale = [maybe_contiguous(t) for t in (q_descale, k_descale, v_descale)] num_head, head_dim = q.shape[-2:] if cu_seqlens_q is None: batch_size, seqlen_q = q.shape[:2] @@ -372,7 +377,9 @@ def _flash_attn_fwd( assert seqused_k is None or seqused_k.shape == (batch_size,), ( "seqused_k must have shape (batch_size,)" ) - assert q.dtype in [torch.float16, torch.bfloat16], "inputs must be float16 or bfloat16" + assert q.dtype in [torch.float16, torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2], ( + "inputs must be float16, bfloat16, fp8 e4m3fn, or fp8 e5m2" + ) assert q.dtype == k.dtype == v.dtype, "inputs must have the same dtype" for t in [cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k]: if t is not None: @@ -393,6 +400,9 @@ def _flash_attn_fwd( q, k, v, + q_descale, + k_descale, + v_descale, cu_seqlens_q, cu_seqlens_k, seqused_q, @@ -415,7 +425,10 @@ def _flash_attn_fwd( if pack_gqa is None: pack_gqa = qhead_per_kvhead > 1 - out_torch_dtype = q.dtype + is_fp8 = q.dtype in (torch.float8_e4m3fn, torch.float8_e5m2) + if is_fp8 and (q.requires_grad or k.requires_grad or v.requires_grad): + raise NotImplementedError("FA4 CuTe FP8 backward is not supported yet (forward-only).") + out_torch_dtype = torch.bfloat16 if is_fp8 else q.dtype device = q.device q_batch_seqlen_shape = (batch_size, seqlen_q) if cu_seqlens_q is None else (total_q,) lse_shape = (batch_size, num_head, seqlen_q) if cu_seqlens_q is None else (num_head, total_q) @@ -437,7 +450,18 @@ def _flash_attn_fwd( elif lse is not None: _validate_tensor(lse, "lse", lse_shape, torch.float32, device) + if is_fp8: + for t, name in ((q_descale, "q_descale"), (k_descale, "k_descale"), (v_descale, "v_descale")): + if t is not None: + _validate_tensor(t, name, (batch_size, num_head_kv), torch.float32, device) + else: + assert q_descale is None and k_descale is None and v_descale is None, ( + "q_descale/k_descale/v_descale are only supported for FP8 inputs" + ) + dtype = torch2cute_dtype_map[q.dtype] + if is_fp8: + assert arch // 10 == 10, "FP8 is only supported on SM100 (compute capability 10.x) for FA4 CuTe." use_block_sparsity = block_sparse_tensors is not None causal, local, window_size_left, window_size_right = _resolve_causal_local_window( @@ -611,6 +635,9 @@ def _flash_attn_fwd( assert head_dim == 64 and head_dim_v == 512, "only support MLA weight absorbed shape with qv" assert not local, "local not yet supported with qv" assert page_table is None, "page table not yet supported with qv" + assert q_descale is None and k_descale is None and v_descale is None, ( + "q_descale/k_descale/v_descale are not yet supported with qv" + ) assert not is_split_kv, "split kv not supported with qv" assert learnable_sink is None @@ -634,6 +661,7 @@ def _flash_attn_fwd( seqlen_k_boundary = min_seqlen_k disable_sparse_kv_bitmask = seqlen_k_boundary >= gather_kv_length else: + assert gather_kv_indices is None, "gather_kv_indices is only supported with qv" gather_kv_length = None sparse_kv = None disable_sparse_kv_bitmask = None @@ -658,6 +686,9 @@ def _flash_attn_fwd( window_size_left is not None, window_size_right is not None, learnable_sink is not None, + q_descale is not None, + k_descale is not None, + v_descale is not None, tile_m, tile_n, q_stage, @@ -705,6 +736,33 @@ def _flash_attn_fwd( else: lse_tensor = None + q_descale_tensor = ( + to_cute_tensor(q_descale, assumed_align=4, leading_dim=1) + if q_descale is not None + else None + ) + k_descale_tensor = ( + to_cute_tensor(k_descale, assumed_align=4, leading_dim=1) + if k_descale is not None + else None + ) + v_descale_tensor = ( + to_cute_tensor(v_descale, assumed_align=4, leading_dim=1) + if v_descale is not None + else None + ) + descale_tensors_tensor = ( + DescaleTensors( + q_descale=q_descale_tensor, + k_descale=k_descale_tensor, + v_descale=v_descale_tensor, + ) + if q_descale_tensor is not None + or k_descale_tensor is not None + or v_descale_tensor is not None + else None + ) + sparse_tensors = None if normalized_block_sparse_tensors is not None: sparse_tensors = to_cute_block_sparse_tensors(normalized_block_sparse_tensors) @@ -849,7 +907,7 @@ def _flash_attn_fwd( options="--enable-tvm-ffi", ) else: - _flash_attn_fwd.compile_cache[compile_key] = cute.compile( + compile_args = [ fa_fwd, q_tensor, k_tensor, @@ -868,20 +926,36 @@ def _flash_attn_fwd( sparse_tensors, cute_aux_tensors, current_stream, - options="--enable-tvm-ffi", - ) + ] + if arch // 10 in [10, 11]: + compile_args.insert(-3, descale_tensors_tensor) + _flash_attn_fwd.compile_cache[compile_key] = cute.compile(*compile_args, options="--enable-tvm-ffi") # In "fake mode", we will take torch fake tensors as input and the expected behaviors are: # - Use those fake metadata to populate compilation cache # - Return "fake" output tensors, which could be needed in follow-up fake operations # Thus, we skip the actual kernel invocation here. if not is_fake_mode(): + q_call, k_call, v_call = q.detach(), k.detach(), v.detach() + qv_call = qv.detach() if qv is not None else None + if is_fp8: + # need uint8 workaround until we pin torch >= 2.11.0 where fp8 export is supported + q_call = q_call.view(torch.uint8) + k_call = k_call.view(torch.uint8) + v_call = v_call.view(torch.uint8) + if qv_call is not None: + qv_call = qv_call.view(torch.uint8) + descale_tensors = ( + DescaleTensors(q_descale=q_descale, k_descale=k_descale, v_descale=v_descale) + if q_descale is not None or k_descale is not None or v_descale is not None + else None + ) if qv is not None: _flash_attn_fwd.compile_cache[compile_key]( - q.detach(), - qv.detach(), - k.detach(), - v.detach(), + q_call, + qv_call, + k_call, + v_call, out.detach(), lse, softmax_scale, @@ -895,10 +969,10 @@ def _flash_attn_fwd( window_size_right, ) else: - _flash_attn_fwd.compile_cache[compile_key]( - q.detach(), - k.detach(), - v.detach(), + call_args = [ + q_call, + k_call, + v_call, out.detach() if not is_split_kv else out_partial, lse_partial if is_split_kv else lse, softmax_scale, @@ -910,9 +984,14 @@ def _flash_attn_fwd( window_size_left, window_size_right, learnable_sink, + ] + if arch // 10 in [10, 11]: + call_args.append(descale_tensors) + call_args.extend([ normalized_block_sparse_tensors[:4] if normalized_block_sparse_tensors is not None else None, aux_tensors, - ) + ]) + _flash_attn_fwd.compile_cache[compile_key](*call_args) if is_split_kv: _flash_attn_fwd_combine( out_partial, diff --git a/flash_attn/cute/mma_sm100_desc.py b/flash_attn/cute/mma_sm100_desc.py index ab8dd098b92..9a2adfca8d7 100644 --- a/flash_attn/cute/mma_sm100_desc.py +++ b/flash_attn/cute/mma_sm100_desc.py @@ -83,9 +83,9 @@ def to_UMMA_format(cutlass_type) -> int: if cutlass_type is cutlass.TFloat32: return F16F32Format.TF32 # Float-8 / Float-6 / Float-4 – add whenever CUTLASS exposes them - if cutlass_type is cutlass.FloatE4M3FN: + if cutlass_type is cutlass.Float8E4M3FN: return MXF8F6F4Format.E4M3 - if cutlass_type is cutlass.FloatE5M2: + if cutlass_type is cutlass.Float8E5M2: return MXF8F6F4Format.E5M2 raise TypeError(f"Unsupported CUTLASS scalar type for A/B: {cutlass_type!r}") diff --git a/flash_attn/cute/softmax.py b/flash_attn/cute/softmax.py index 0565827b601..9369e0d49ca 100644 --- a/flash_attn/cute/softmax.py +++ b/flash_attn/cute/softmax.py @@ -169,12 +169,14 @@ def rescale_O(self, acc_O: cute.Tensor, row_scale: cute.Tensor) -> None: @dataclass class SoftmaxSm100(Softmax): rescale_threshold: cutlass.Constexpr[float] = 0.0 + max_offset: cutlass.Constexpr[int] = 0 @staticmethod def create( scale_log2: Float32, rescale_threshold: cutlass.Constexpr[float] = 0.0, softmax_scale: Float32 | None = None, + max_offset: cutlass.Constexpr[int] = 0, ): num_rows = 1 arch = 100 @@ -188,6 +190,7 @@ def create( arch, softmax_scale, rescale_threshold=rescale_threshold, + max_offset=max_offset, ) @cute.jit @@ -258,11 +261,13 @@ def scale_subtract_rowmax( ): assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" row_max_scaled = row_max * self.scale_log2 + max_offset = Float32(self.max_offset) + bias = max_offset - row_max_scaled for i in cutlass.range(0, cute.size(acc_S_row.shape), 2, unroll_full=True): acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2( (acc_S_row[i], acc_S_row[i + 1]), (self.scale_log2, self.scale_log2), - (-row_max_scaled, -row_max_scaled), + (bias, bias), ) @cute.jit From 881139b60c70dcf1406d397c5293585f5e1e5db3 Mon Sep 17 00:00:00 2001 From: Johnsonms Date: Sat, 18 Apr 2026 22:50:10 -0700 Subject: [PATCH 18/44] Expose --pack-gqa and --num-splits in benchmark_attn.py (#2473) Both flags override the kernel's internal heuristics so users can benchmark with forced settings instead of editing the script. Defaults are unchanged (num_splits=0 and pack_gqa=None both mean "kernel auto"). Useful for A/B comparisons. Example: MLA decode bs=32 seqlen_kv=65536 num_splits=2 on B200 gives 0.75 ms with --pack-gqa true and 50.32 ms with --pack-gqa false -- a 67x gap that confirms pack_gqa is the deciding factor for long-context MLA decode parallelism. Note: for non-MLA GQA with num_splits>1, interface.py may still force pack_gqa off regardless of --pack-gqa true (pending a separate fix). For MLA (qv is not None), the flag is honored. No new code path is exposed -- the flag only makes existing kernel options reachable from the CLI. --- benchmarks/benchmark_attn.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/benchmarks/benchmark_attn.py b/benchmarks/benchmark_attn.py index 1f727461256..a8abdbd89d4 100644 --- a/benchmarks/benchmark_attn.py +++ b/benchmarks/benchmark_attn.py @@ -323,6 +323,13 @@ def parse_args(): help='kv sparsity length for MLA (hdim=64, hdim_v=512 only). ' 'When set, passes random kv indices (without repeats) to FA4 and uses gather-kv as ' 'the effective KV length for flops/bandwidth accounting.') + parser.add_argument('--num-splits', type=int, default=0, + help='Override kernel num_splits heuristic. 0 = auto (default). ' + '>1 forces SplitKV with that many splits.') + parser.add_argument('--pack-gqa', type=str.lower, choices=['auto', 'true', 'false'], default='auto', + help='Override kernel pack_gqa heuristic. auto = kernel default (default). ' + 'true/false force pack_gqa on or off. Note: for non-MLA GQA with ' + 'num_splits>1, interface.py may still force pack_gqa off pending a fix.') parser.add_argument('--warmup', type=int, default=5, help='Warmup iterations (default: 5)') parser.add_argument('--rep', type=int, default=10, @@ -406,10 +413,10 @@ def main(): has_qv = headdim == 64 and headdim_v == 512 sinks = None - num_splits = 0 + num_splits = args.num_splits window_size = (None, None) window_size_fa = (-1, -1) - pack_gqa = None + pack_gqa = None if args.pack_gqa == 'auto' else (args.pack_gqa == 'true') for seqlen, seqlen_q in zip(seqlen_list, seqlen_q_list): batch_size = args.batch_size if args.batch_size is not None else max(1, args.total_seqlen // seqlen) From bbda031f1cd1adf1a57a3b3e8cdc3db2b54d3994 Mon Sep 17 00:00:00 2001 From: Thomas Young Date: Mon, 20 Apr 2026 11:06:45 +0800 Subject: [PATCH 19/44] Fix: pass num_splits through varlen_fwd Python wrapper (fixes #2448 regression) (#2476) PR #2448 added `int num_splits = 0` as a trailing positional arg of `mha_varlen_fwd` in csrc/flash_attn/flash_api.cpp but did not update the Python wrapper nor the pybind11 binding to expose that default. Because the binding at the bottom of flash_api.cpp is just `m.def("varlen_fwd", &mha_varlen_fwd, ...)` (no `py::arg("num_splits") = 0`), pybind11 does not honour the C++ default value, so every call from `_flash_attn_varlen_forward` now fails with: TypeError: varlen_fwd(): incompatible function arguments. This patch plumbs `num_splits` through `_flash_attn_varlen_forward` (and its fake counterpart, to keep the torch.custom_op schemas in sync) and passes it to `flash_attn_gpu.varlen_fwd`, restoring the previous behaviour and exposing the knob to Python callers. Co-authored-by: Claude Opus 4.7 (1M context) --- flash_attn/flash_attn_interface.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index 6c74e3af8cd..9fa5c873dd7 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -171,6 +171,7 @@ def _flash_attn_varlen_forward( leftpad_k: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, zero_tensors: bool = False, + num_splits: int = 0, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: q, k, v = [maybe_contiguous(x) for x in (q, k, v)] out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd( @@ -195,6 +196,7 @@ def _flash_attn_varlen_forward( softcap, return_softmax, None, + num_splits, ) # if out.isnan().any() or softmax_lse.isnan().any(): # breakpoint() @@ -222,6 +224,7 @@ def _flash_attn_varlen_forward_fake( leftpad_k: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, zero_tensors: bool = False, + num_splits: int = 0, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: q, k, v = [maybe_contiguous(x) for x in (q, k, v)] paged_kv = block_table is not None From 03cd0651b469fe3390b6b17fed1a7e9c23d32c91 Mon Sep 17 00:00:00 2001 From: Johnsonms Date: Mon, 20 Apr 2026 09:40:43 -0700 Subject: [PATCH 20/44] [Cute,Fwd] Fix crash when K/V has seqlen == 0 (#2470) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Issue ----- flash_attn_func crashes when K/V is built with physical seqlen dim 0 (e.g. vLLM CUDA-graph capture with an all-padding batch): - causal=False: PTX IllegalInstruction (async CUDA error from kernel) - causal=True: host SIGFPE before kernel launch Root cause ---------- seqlen_k == 0 violates two downstream invariants: 1. TMA descriptor invariant (K physical seqlen > 0). On non-causal, StaticPersistentTileScheduler passes host-side setup and launches the kernel; the first TMA load over the 0-length K tensor goes OOB -> PTX IllegalInstruction. 2. LPT L2-swizzle heuristic invariant (size_one_head > 0). On causal, SingleTileLPTScheduler.Params.create in tile_scheduler.py computes size_one_head = seqlen_k * (headdim + headdim_v) * element_size which is 0, then evaluates "size_l2 // size_one_head" -> host integer divide by zero (SIGFPE) before the kernel launches. Fix --- Early return in interface._flash_attn_fwd when seqlen_k == 0 — zero the output, fill LSE with -inf, skip kernel launch. Guards both invariants at the boundary before either scheduler runs. Only affects the non-varlen path. Varlen is unchanged: its K tensor has physical seqlen > 0, and per-batch empty slots are already handled correctly by the kernel's fake-iteration path. Regression test: tests/cute/test_flash_attn.py::test_flash_attn_seqlen_k_zero covers both crash paths across seqlen_q in {1, 64, 128, 256} and d in {128, 192}. Co-authored-by: yunzhongOvO --- flash_attn/cute/interface.py | 6 ++++ tests/cute/test_flash_attn.py | 64 ++++++++++++++++++++++++++++++++++- 2 files changed, 69 insertions(+), 1 deletion(-) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 5b9e382d217..17b703b3f21 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -450,6 +450,12 @@ def _flash_attn_fwd( elif lse is not None: _validate_tensor(lse, "lse", lse_shape, torch.float32, device) + if seqlen_k == 0: + out.zero_() + if lse is not None: + lse.fill_(float("-inf")) + return out, lse + if is_fp8: for t, name in ((q_descale, "q_descale"), (k_descale, "k_descale"), (v_descale, "v_descale")): if t is not None: diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index b551c01d7d6..e2fa31bdf88 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -2409,4 +2409,66 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): # print(f"out2 max: {out2.abs().max().item()}, {iter=}") # print(f"out vs out2 max diff: {(out_cmp - out2).abs().max().item()}, {iter=}") # print(f"out vs out2 mean diff: {(out_cmp - out2).abs().mean().item()}, {iter=}") - assert torch.equal(out_cmp, out2), f"non-deterministic with max diff = {(out_cmp - out2).abs().max().item()} on {iter=}" \ No newline at end of file + assert torch.equal(out_cmp, out2), f"non-deterministic with max diff = {(out_cmp - out2).abs().max().item()} on {iter=}" + + +# --------------------------------------------------------------------------- +# Regression test: seqlen_k=0 must not crash (CUDA graph padding scenario) +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("d", [128, 192]) +@pytest.mark.parametrize("seqlen_q", [1, 64, 128, 256]) +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_seqlen_k_zero(seqlen_q, d, causal): + """K/V with physical seqlen dim == 0 must not crash. + + seqlen_k == 0 violates two downstream invariants, producing two + different crashes depending on the mask: + + causal=False -> TMA descriptor over a 0-length K tensor goes OOB + on first tile load -> PTX IllegalInstruction. + + causal=True -> SingleTileLPTScheduler's L2-swizzle heuristic in + tile_scheduler.py evaluates + size_l2 // (seqlen_k * (d + d_v) * elem_size) + -> host SIGFPE before the kernel launches. + + Varlen paths (cu_seqlens_k / seqused_k with K physical seqlen > 0) + are not exercised here: per-batch empty slots are already handled + by the kernel's fake-iteration path and do not hit either invariant. + """ + if IS_SM90: + pytest.skip("SM90 uses a different kernel path") + + device = "cuda" + dtype = torch.bfloat16 + dv = 128 if d == 192 else d + batch_size = 4 + nheads = 16 + nheads_kv = 16 + + torch.manual_seed(0) + + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) + # K/V have physical seqlen dim == 0 — this is what crashes on unpatched FA4. + # causal=False hits GPU IllegalInstruction (TMA OOB on 0-length K tensor). + # causal=True hits host SIGFPE in tile_scheduler.py LPT L2-swizzle heuristic + # (size_l2 // size_one_head with size_one_head = seqlen_k*... = 0). + k = torch.empty(batch_size, 0, nheads_kv, d, device=device, dtype=dtype) + v = torch.empty(batch_size, 0, nheads_kv, dv, device=device, dtype=dtype) + + out, lse = flash_attn_func(q, k, v, causal=causal) + + if is_fake_mode(): + return + + # No crash above already validates the fix. Below validates the contract + # the early-return promises: zero output, -inf LSE. + assert out.shape == (batch_size, seqlen_q, nheads, dv), \ + f"Unexpected output shape: {out.shape}" + assert torch.all(out == 0).item(), \ + f"Expected all-zero output when seqlen_k=0, got max={out.abs().max().item():.6f}" + if lse is not None: + assert torch.all(torch.isinf(lse) & (lse < 0)).item(), \ + f"Expected all -inf LSE when seqlen_k=0, got: {lse}" From 41b2ef6cb5b372076802dd5202ca9441b8211d20 Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Mon, 20 Apr 2026 11:08:03 -0700 Subject: [PATCH 21/44] fix causal calcs (#2463) * Fix script * Fix lower right causal bug and add clc parse viewer --- AI/CLC_TRACE_DEBUG.md | 82 ++++++++ AI/parse_clc_log.py | 396 ++++++++++++++++++++++++++++++++++++ CLAUDE.md | 2 +- benchmarks/clc_bench.py | 118 ++++++++--- benchmarks/configs/clc.yaml | 12 +- 5 files changed, 573 insertions(+), 37 deletions(-) create mode 100644 AI/CLC_TRACE_DEBUG.md create mode 100644 AI/parse_clc_log.py diff --git a/AI/CLC_TRACE_DEBUG.md b/AI/CLC_TRACE_DEBUG.md new file mode 100644 index 00000000000..9f1502aa57a --- /dev/null +++ b/AI/CLC_TRACE_DEBUG.md @@ -0,0 +1,82 @@ +# CLC Trace Debugging + +Use this when you suspect the CLC work scheduler is making surprising tile assignment decisions and you want a raw scheduler trace from the current kernel. + +## Current trace format + +SM100 forward kernels emit one trace line per scheduler-warp query at `FA_LOG_LEVEL=3`: + +```text +[CLC] query sm= cta= (m_blk=,h=,b=,s=) valid=<0|1> +``` + +Current emit sites: +- `flash_attn/cute/flash_fwd_sm100.py` +- `flash_attn/cute/flash_fwd_mla_sm100.py` + +## How to capture a trace + +Important: +- `FA_LOG_LEVEL=3` is needed for the `[CLC] query ...` device-side prints. +- `FA_CLC=1` only requests CLC; the kernel may still fall back if the shape/features disable it. + +Minimal repro pattern: + +```bash +FA_LOG_LEVEL=3 FA_CLC=1 CUDA_VISIBLE_DEVICES=0 python - <<'PY' \ + > agent_space/clc_trace.log 2>&1 +import torch +from flash_attn.cute.interface import flash_attn_func + +torch.manual_seed(0) +q = torch.randn(1, 512, 16, 128, device='cuda', dtype=torch.bfloat16) +k = torch.randn(1, 512, 1, 128, device='cuda', dtype=torch.bfloat16) +v = torch.randn(1, 512, 1, 128, device='cuda', dtype=torch.bfloat16) +flash_attn_func(q, k, v, causal=True) +torch.cuda.synchronize() +PY +``` + +If you want the run to say explicitly whether CLC was selected, keep the host log prefix too: + +```text +[FA] TileScheduler=SingleTileLPTScheduler, scheduling_mode=CLC, USE_2CTA=False +``` + +## What to look for + +- `scheduling_mode=CLC` in host logs confirms the shape actually used the CLC path. +- `valid=1` means the returned work tile is valid. +- `valid=0` means the scheduler is exhausted for that CTA/scheduler warp query. +- `m_blk`, `h`, `b`, `s` are the logical work coordinates after the scheduler mapping. +- `cta` is the physical `blockIdx.x`; for clustered launches multiple CTAs may participate in the same logical tile. + +## Parse the trace + +A lightweight parser lives in `AI/parse_clc_log.py`. + +Text summary: + +```bash +python AI/parse_clc_log.py agent_space/clc_trace.log +``` + +HTML view: + +```bash +python AI/parse_clc_log.py agent_space/clc_trace.log --html -o agent_space/clc_trace.html +``` + +## Suggested workflow + +1. Reproduce the surprising case with `FA_LOG_LEVEL=3 FA_CLC=1`. +2. Save stdout/stderr to `agent_space/clc_trace.log`. +3. Run `AI/parse_clc_log.py` on that log to get a compact per-SM / per-CTA summary. +4. If the trace still looks suspicious, attach or paste that log in the investigation thread / agent notes. +5. Compare against the relevant mapping logic in `flash_attn/cute/tile_scheduler.py`. + +## Caveats + +- The trace is noisy and expensive; use a single small shape first. +- Because the print happens on scheduler queries, many lines may be terminal `valid=0` queries after work is exhausted. +- Dense noncausal and varlen MHA may intentionally fall back away from CLC depending on the current heuristic in `flash_attn/cute/interface.py`. diff --git a/AI/parse_clc_log.py b/AI/parse_clc_log.py new file mode 100644 index 00000000000..c1b94543bf4 --- /dev/null +++ b/AI/parse_clc_log.py @@ -0,0 +1,396 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import html +import json +import re +import sys +from collections import Counter, defaultdict +from dataclasses import asdict, dataclass +from pathlib import Path + +TRACE_RE = re.compile( + r"\[CLC\]\s+query\s+sm=(?P\d+)\s+cta=(?P\d+)\s+" + r"\(m_blk=(?P-?\d+),h=(?P-?\d+),b=(?P-?\d+),s=(?P-?\d+)\)\s+" + r"valid=(?P[01])" +) + + +@dataclass(frozen=True) +class TraceRow: + sm: int + cta: int + m_blk: int + h: int + b: int + s: int + valid: int + + +def parse_rows(text: str) -> list[TraceRow]: + rows: list[TraceRow] = [] + for line in text.splitlines(): + match = TRACE_RE.search(line) + if match is None: + continue + rows.append(TraceRow(**{key: int(value) for key, value in match.groupdict().items()})) + return rows + + +def summarize(rows: list[TraceRow]) -> dict: + by_sm: dict[int, list[TraceRow]] = defaultdict(list) + by_cta: dict[int, list[TraceRow]] = defaultdict(list) + tile_counter: Counter[tuple[int, int, int, int, int]] = Counter() + for row in rows: + by_sm[row.sm].append(row) + by_cta[row.cta].append(row) + tile_counter[(row.m_blk, row.h, row.b, row.s, row.valid)] += 1 + + def encode_group(grouped: dict[int, list[TraceRow]]) -> dict[str, dict]: + out: dict[str, dict] = {} + for key, group_rows in sorted(grouped.items()): + out[str(key)] = { + "count": len(group_rows), + "valid_count": sum(row.valid for row in group_rows), + "invalid_count": sum(1 - row.valid for row in group_rows), + "first": asdict(group_rows[0]), + "last": asdict(group_rows[-1]), + "unique_tiles": len({(r.m_blk, r.h, r.b, r.s, r.valid) for r in group_rows}), + } + return out + + top_tiles = [ + { + "tile": { + "m_blk": tile[0], + "h": tile[1], + "b": tile[2], + "s": tile[3], + "valid": tile[4], + }, + "count": count, + } + for tile, count in tile_counter.most_common(20) + ] + + return { + "rows": len(rows), + "valid_rows": sum(row.valid for row in rows), + "invalid_rows": sum(1 - row.valid for row in rows), + "unique_sms": len(by_sm), + "unique_ctas": len(by_cta), + "by_sm": encode_group(by_sm), + "by_cta": encode_group(by_cta), + "top_tiles": top_tiles, + } + + +def format_summary(summary: dict) -> str: + lines = [ + f"rows={summary['rows']} valid={summary['valid_rows']} invalid={summary['invalid_rows']}", + f"unique_sms={summary['unique_sms']} unique_ctas={summary['unique_ctas']}", + "top_tiles:", + ] + for entry in summary["top_tiles"][:10]: + tile = entry["tile"] + lines.append( + f" count={entry['count']:>4} tile=(m_blk={tile['m_blk']}, h={tile['h']}, b={tile['b']}, s={tile['s']}, valid={tile['valid']})" + ) + lines.append("by_sm:") + for sm, sm_summary in summary["by_sm"].items(): + first = sm_summary["first"] + last = sm_summary["last"] + lines.append( + f" sm={sm:>3} count={sm_summary['count']:>4} valid={sm_summary['valid_count']:>4} invalid={sm_summary['invalid_count']:>4} " + f"first=(cta={first['cta']},m_blk={first['m_blk']},h={first['h']},b={first['b']},s={first['s']},v={first['valid']}) " + f"last=(cta={last['cta']},m_blk={last['m_blk']},h={last['h']},b={last['b']},s={last['s']},v={last['valid']})" + ) + return "\n".join(lines) + + +def visualize_html(rows: list[TraceRow], summary: dict) -> str: + by_sm: dict[int, list[TraceRow]] = defaultdict(list) + for row in rows: + by_sm[row.sm].append(row) + + data = [ + { + "sm": sm, + "tiles": [ + { + "id": r.m_blk, + "type": "INIT" if idx == 0 else "PULL", + "valid": bool(r.valid), + "m": r.m_blk, + "h": r.h, + "b": r.b, + "s": r.s, + "cta": r.cta, + } + for idx, r in enumerate(chain) + ], + } + for sm, chain in sorted(by_sm.items()) + ] + + total_tiles = sum(len(d["tiles"]) for d in data) + valid_pulls = sum(1 for d in data for t in d["tiles"] if t["type"] == "PULL" and t["valid"]) + work_per_sm = [sum(1 for t in d["tiles"] if t["valid"]) for d in data] + histogram = defaultdict(int) + for work in work_per_sm: + histogram[work] += 1 + histogram_data = [{"work": k, "count": v} for k, v in sorted(histogram.items())] + work_stats = { + "min": min(work_per_sm) if work_per_sm else 0, + "max": max(work_per_sm) if work_per_sm else 0, + "mean": (sum(work_per_sm) / len(work_per_sm)) if work_per_sm else 0.0, + "std": ( + sum((w - sum(work_per_sm) / len(work_per_sm)) ** 2 for w in work_per_sm) / len(work_per_sm) + ) ** 0.5 if work_per_sm else 0.0, + } + + return f""" + + + + +CLC Work Distribution Viewer + + + +

CLC Work Distribution Viewer

+
+ query-trace mode + SMs: {len(data)} + Total queries: {total_tiles} + Valid pulls: {valid_pulls} + Invalid queries: {summary['invalid_rows']} +
+
+ + + Press Esc to clear +
+
+
First query on SM
+
Later query / pull
+
Invalid / exhausted
+
+
+

Work Distribution Histogram min={work_stats['min']}, max={work_stats['max']}, mean={work_stats['mean']:.1f}, std={work_stats['std']:.2f}

+
+
+
+
+
+

SM

+
+
+
+ + + +""" + + +def read_text(path: str | None) -> str: + if path is None or path == "-": + return sys.stdin.read() + return Path(path).read_text() + + +def main() -> None: + parser = argparse.ArgumentParser(description="Parse FlashAttention CLC trace lines.") + parser.add_argument("logfile", nargs="?", default="-", help="Trace log path or - for stdin") + parser.add_argument("--json", action="store_true", help="Emit JSON summary") + parser.add_argument("--rows", action="store_true", help="Emit parsed rows as JSON") + parser.add_argument("--html", action="store_true", help="Emit HTML view") + parser.add_argument("-o", "--output", help="Output path for --html") + args = parser.parse_args() + + rows = parse_rows(read_text(args.logfile)) + if args.rows: + print(json.dumps([asdict(row) for row in rows], indent=2)) + return + + summary = summarize(rows) + if args.html: + html_text = visualize_html(rows, summary) + if args.output is not None: + Path(args.output).write_text(html_text) + else: + print(html_text) + return + if args.json: + print(json.dumps(summary, indent=2)) + else: + print(format_summary(summary)) + + +if __name__ == "__main__": + main() diff --git a/CLAUDE.md b/CLAUDE.md index f170541d482..35d96195fd8 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -128,7 +128,7 @@ Env vars: `CUTE_CUBIN_PATH` (dump CUBIN/SASS), `CUTE_DSL_KEEP_PTX=1` (inspect PT ## Debugging GPU Kernels -See `AI/DEBUG_2CTA.md` for kernel hang/deadlock debugging (printf bisection, pipeline barrier analysis, 2CTA pitfalls). See `AI/RACECHECK_TMA_HAZARD.md` for `compute-sanitizer` false positives with `cp.async.bulk`. +See `AI/DEBUG_2CTA.md` for kernel hang/deadlock debugging (printf bisection, pipeline barrier analysis, 2CTA pitfalls). See `AI/RACECHECK_TMA_HAZARD.md` for `compute-sanitizer` false positives with `cp.async.bulk`. See `AI/CLC_TRACE_DEBUG.md` for visualization of CLC scheduling. Key tools: - `cute.printf` with thread guards (`tidx % 32 == 0`, `elect_one()`) for targeted output diff --git a/benchmarks/clc_bench.py b/benchmarks/clc_bench.py index 18e7358a6d7..46ee55980eb 100644 --- a/benchmarks/clc_bench.py +++ b/benchmarks/clc_bench.py @@ -50,9 +50,9 @@ class DenseSweep: enabled: bool = True batches: list[int] = field(default_factory=lambda: [1, 4, 8, 16, 32]) seqlen_pairs: list[list[int]] = field( - default_factory=lambda: [[1024, 1024], [2048, 2048], [4096, 4096], [8192, 8192], [16384, 16384]] + default_factory=lambda: [[32, 8192], [1024, 1024], [2048, 2048], [4096, 4096], [8192, 8192], [16384, 16384]] ) - head_dims: list[int] = field(default_factory=lambda: [64, 96, 128]) + head_dims: list[int | list[int]] = field(default_factory=lambda: [64, 96, 128, [192, 128]]) head_pairs: list[list[int]] = field(default_factory=lambda: [[16, 16], [16, 8], [16, 4], [16, 2], [16, 1]]) causal: bool | list[bool] = True @@ -64,7 +64,7 @@ class VarlenSweep: max_kv_tokens: list[int] = field(default_factory=lambda: [2048, 4096, 8192, 16384, 32768]) batches: list[int] = field(default_factory=lambda: [4, 8, 16, 32]) patterns: list[str] = field(default_factory=lambda: ["uniform", "longtail"]) - head_dims: list[int] = field(default_factory=lambda: [64, 96, 128]) + head_dims: list[int | list[int]] = field(default_factory=lambda: [64, 96, 128, [192, 128]]) head_pairs: list[list[int]] = field(default_factory=lambda: [[16, 8], [16, 4], [16, 2], [16, 1]]) causal: bool | list[bool] = False @@ -76,7 +76,7 @@ class BlockSparseSweep: seqlen_pairs: list[list[int]] = field( default_factory=lambda: [[1024, 1024], [2048, 2048], [4096, 4096], [4096, 8192]] ) - head_dims: list[int] = field(default_factory=lambda: [64, 128]) + head_dims: list[int | list[int]] = field(default_factory=lambda: [64, 128, [192, 128]]) head_pairs: list[list[int]] = field(default_factory=lambda: [[16, 16], [16, 4], [16, 1]]) mask_names: list[str] = field(default_factory=lambda: ["block_diagonal"]) sliding_window_sizes: list[int] = field(default_factory=lambda: [2048]) @@ -89,6 +89,7 @@ class Case: q_heads: int kv_heads: int d: int + dv: int causal: bool batch: int | None = None seqlen_q: int | None = None @@ -116,12 +117,36 @@ def token_label(value: int) -> str: return f"{value // 1024}k" if value >= 1024 and value % 1024 == 0 else str(value) -def dense_case_name(q_heads: int, kv_heads: int, causal: bool, d: int, batch: int, seqlen_q: int, seqlen_k: int) -> str: +def head_dim_label(d: int, dv: int) -> str: + return f"h{d}" if d == dv else f"h{d}_dv{dv}" + + +def head_dim_pairs(head_dims: list[int | list[int]]) -> list[tuple[int, int]]: + pairs: list[tuple[int, int]] = [] + invalid_pairs: list[int | list[int]] = [] + for dims in head_dims: + if isinstance(dims, int): + pairs.append((dims, dims)) + continue + if len(dims) == 1: + pairs.append((dims[0], dims[0])) + continue + if len(dims) == 2: + pairs.append((dims[0], dims[1])) + continue + invalid_pairs.append(dims) + if invalid_pairs: + raise ValueError(f"Expected d or [d] or [d, dv], got {invalid_pairs}") + return pairs + + +def dense_case_name(q_heads: int, kv_heads: int, causal: bool, d: int, dv: int, batch: int, seqlen_q: int, seqlen_k: int) -> str: causal_name = "causal" if causal else "noncausal" pair = head_pair_label(q_heads, kv_heads) + dims = head_dim_label(d, dv) if seqlen_q == seqlen_k: - return f"{pair}_{causal_name}_h{d}_{token_label(seqlen_q)}_b{batch}" - return f"{pair}_{causal_name}_q{seqlen_q}_k{seqlen_k}_h{d}_b{batch}" + return f"{pair}_{causal_name}_{dims}_{token_label(seqlen_q)}_b{batch}" + return f"{pair}_{causal_name}_q{seqlen_q}_k{seqlen_k}_{dims}_b{batch}" def varlen_case_name( @@ -130,14 +155,16 @@ def varlen_case_name( kv_heads: int, causal: bool, d: int, + dv: int, batch: int, max_q_tokens: int, max_kv_tokens: int, ) -> str: causal_name = "causal" if causal else "noncausal" pair = head_pair_label(q_heads, kv_heads) + dims = head_dim_label(d, dv) return ( - f"varlen_{pattern}_{pair}_{causal_name}_h{d}_" + f"varlen_{pattern}_{pair}_{causal_name}_{dims}_" f"b{batch}_q{token_label(max_q_tokens)}_kv{token_label(max_kv_tokens)}" ) @@ -200,21 +227,22 @@ def generate_cases( ) -> list[Case]: cases: list[Case] = [] if dense.enabled: - for batch, seqlen_pair, d, (q_heads, kv_heads), causal in product( + for batch, seqlen_pair, (d, dv), (q_heads, kv_heads), causal in product( dense.batches, dense.seqlen_pairs, - dense.head_dims, + head_dim_pairs(dense.head_dims), dense.head_pairs, bool_values(dense.causal), ): seqlen_q, seqlen_k = seqlen_pair cases.append( Case( - name=dense_case_name(q_heads, kv_heads, causal, d, batch, seqlen_q, seqlen_k), + name=dense_case_name(q_heads, kv_heads, causal, d, dv, batch, seqlen_q, seqlen_k), mode="dense", q_heads=q_heads, kv_heads=kv_heads, d=d, + dv=dv, causal=causal, batch=batch, seqlen_q=seqlen_q, @@ -222,12 +250,12 @@ def generate_cases( ) ) if varlen.enabled: - for max_q_tokens, max_kv_tokens, batch, pattern, d, (q_heads, kv_heads), causal in product( + for max_q_tokens, max_kv_tokens, batch, pattern, (d, dv), (q_heads, kv_heads), causal in product( varlen.max_q_tokens, varlen.max_kv_tokens, varlen.batches, varlen.patterns, - varlen.head_dims, + head_dim_pairs(varlen.head_dims), varlen.head_pairs, bool_values(varlen.causal), ): @@ -236,11 +264,12 @@ def generate_cases( lengths_k = normalize_lengths(weights, max(batch, max_kv_tokens)) cases.append( Case( - name=varlen_case_name(pattern, q_heads, kv_heads, causal, d, batch, max_q_tokens, max_kv_tokens), + name=varlen_case_name(pattern, q_heads, kv_heads, causal, d, dv, batch, max_q_tokens, max_kv_tokens), mode="varlen", q_heads=q_heads, kv_heads=kv_heads, d=d, + dv=dv, causal=causal, batch=batch, seqlens_q=lengths_q, @@ -249,10 +278,10 @@ def generate_cases( ) ) if block_sparse.enabled: - for batch, seqlen_pair, d, (q_heads, kv_heads), mask_name in product( + for batch, seqlen_pair, (d, dv), (q_heads, kv_heads), mask_name in product( block_sparse.batches, block_sparse.seqlen_pairs, - block_sparse.head_dims, + head_dim_pairs(block_sparse.head_dims), block_sparse.head_pairs, block_sparse.mask_names, ): @@ -263,16 +292,18 @@ def generate_cases( for window_size in window_sizes: window_label = f"_w{window_size}" if window_size is not None else "" pair = head_pair_label(q_heads, kv_heads) + dims = head_dim_label(d, dv) cases.append( Case( name=( f"block_sparse_{mask_name}{window_label}_{pair}_" - f"h{d}_q{seqlen_q}_k{seqlen_k}_b{batch}" + f"{dims}_q{seqlen_q}_k{seqlen_k}_b{batch}" ), mode="block_sparse", q_heads=q_heads, kv_heads=kv_heads, d=d, + dv=dv, causal=False, batch=batch, seqlen_q=seqlen_q, @@ -302,11 +333,12 @@ def compile_signature(case: Case) -> tuple: case.q_heads, case.kv_heads, case.d, + case.dv, case.mask_name, case.window_size, q_stage, ) - return case.mode, case.q_heads, case.kv_heads, case.d, case.causal, q_stage + return case.mode, case.q_heads, case.kv_heads, case.d, case.dv, case.causal, q_stage def select_compile_cases(cases: list[Case]) -> list[Case]: @@ -369,7 +401,7 @@ def build_cu_seqlens(torch_mod, lengths: list[int]) -> torch_mod.Tensor: def build_dense_inputs(torch_mod, flash_attn_func, case: Case, dtype, factory): q = factory(case.batch, case.seqlen_q, case.q_heads, case.d, device="cuda", dtype=dtype) k = factory(case.batch, case.seqlen_k, case.kv_heads, case.d, device="cuda", dtype=dtype) - v = factory(case.batch, case.seqlen_k, case.kv_heads, case.d, device="cuda", dtype=dtype) + v = factory(case.batch, case.seqlen_k, case.kv_heads, case.dv, device="cuda", dtype=dtype) return flash_attn_func, dict(q=q, k=k, v=v, causal=case.causal) @@ -380,7 +412,7 @@ def build_varlen_inputs(torch_mod, flash_attn_varlen_func, case: Case, dtype, fa total_k = sum(lengths_k) q = factory(total_q, case.q_heads, case.d, device="cuda", dtype=dtype) k = factory(total_k, case.kv_heads, case.d, device="cuda", dtype=dtype) - v = factory(total_k, case.kv_heads, case.d, device="cuda", dtype=dtype) + v = factory(total_k, case.kv_heads, case.dv, device="cuda", dtype=dtype) return flash_attn_varlen_func, dict( q=q, k=k, @@ -412,7 +444,7 @@ def build_block_sparse_inputs(torch_mod, flash_attn_func, case: Case, dtype, ten raise ValueError(f"Aux-backed block-sparse masks are not supported by clc_bench.py: {case.mask_name}") q = tensor_factory(case.batch, case.seqlen_q, case.q_heads, case.d, device="cuda", dtype=dtype) k = tensor_factory(case.batch, case.seqlen_k, case.kv_heads, case.d, device="cuda", dtype=dtype) - v = tensor_factory(case.batch, case.seqlen_k, case.kv_heads, case.d, device="cuda", dtype=dtype) + v = tensor_factory(case.batch, case.seqlen_k, case.kv_heads, case.dv, device="cuda", dtype=dtype) cute_mask, _ = get_mask_pair( case.mask_name, seqlen_q=case.seqlen_q, @@ -468,11 +500,40 @@ def build_inputs(case: Case, dtype_name: DTypeName, fake_tensor: bool): def attended_pairs(seqlen_q: int, seqlen_k: int, causal: bool) -> float: + """Lower-right aligned causal: last query aligns with last key. + When M > N, only the bottom N query rows attend (triangle of size N), + so valid pairs = N*(N+1)/2, not the upper-left formula M*N - N*(N-1)/2. + """ if not causal: return float(seqlen_q * seqlen_k) if seqlen_q <= seqlen_k: return float(seqlen_q * (2 * seqlen_k - seqlen_q + 1) / 2) - return float(seqlen_q * seqlen_k - seqlen_k * (seqlen_k - 1) / 2) + return float(seqlen_k * (seqlen_k + 1) / 2) + + +def block_sparse_pairs(case: Case) -> float: + seqlen_q = case.seqlen_q or 0 + seqlen_k = case.seqlen_k or 0 + match case.mask_name: + case "block_diagonal": + total = 0 + for q_idx in range(seqlen_q): + block_start = (q_idx // BLOCK_SIZE_K) * BLOCK_SIZE_K + block_end = min(block_start + BLOCK_SIZE_K, seqlen_k) + total += max(0, block_end - block_start) + return float(total) + case "sliding_window": + window = case.window_size or 0 + offset = seqlen_k - seqlen_q + total = 0 + for q_idx in range(seqlen_q): + center = q_idx + offset + lower = max(0, center - window) + upper = min(seqlen_k - 1, center + window) + total += max(0, upper - lower + 1) + return float(total) + case _: + raise ValueError(f"Unsupported block-sparse FLOP mask: {case.mask_name}") def fwd_flops(case: Case, kwargs: dict | None = None) -> float: @@ -481,19 +542,15 @@ def fwd_flops(case: Case, kwargs: dict | None = None) -> float: case.seqlen_q or 0, case.seqlen_k or 0, case.causal, - ) * (case.d + case.d) + ) * (case.d + case.dv) if case.mode == "block_sparse": - if kwargs is None: - return 0.0 - total_blocks = kwargs["mask_block_cnt"].sum().item() - if kwargs["full_block_cnt"] is not None: - total_blocks += kwargs["full_block_cnt"].sum().item() - return float(total_blocks * BLOCK_SIZE_Q * BLOCK_SIZE_K * case.q_heads * 2 * (case.d + case.d)) + num_pairs = (case.batch or 0) * block_sparse_pairs(case) + return case.q_heads * 2 * num_pairs * (case.d + case.dv) lengths_q = case.seqlens_q or [] lengths_k = case.seqlens_k or lengths_q total = 0.0 for seqlen_q, seqlen_k in zip(lengths_q, lengths_k): - total += case.q_heads * 2 * attended_pairs(seqlen_q, seqlen_k, case.causal) * (case.d + case.d) + total += case.q_heads * 2 * attended_pairs(seqlen_q, seqlen_k, case.causal) * (case.d + case.dv) return total @@ -531,6 +588,7 @@ def case_metadata(case: Case) -> dict: "q_heads": case.q_heads, "kv_heads": case.kv_heads, "d": case.d, + "dv": case.dv, "causal": case.causal, "pattern": case.pattern, "mask_name": case.mask_name, diff --git a/benchmarks/configs/clc.yaml b/benchmarks/configs/clc.yaml index b7bc5d4a949..94daf11770d 100644 --- a/benchmarks/configs/clc.yaml +++ b/benchmarks/configs/clc.yaml @@ -7,9 +7,9 @@ bench_iters: 256 dense: enabled: true batches: [1, 4, 8, 16, 32] - seqlen_pairs: [[1024, 1024], [2048, 2048], [4096, 4096], [8192, 8192], [16384, 16384]] - head_dims: [64, 96, 128] - head_pairs: [[16, 16], [16, 8], [16, 4], [16, 2], [16, 1]] + seqlen_pairs: [[32, 8192], [2048, 2048], [4096, 4096], [8192, 8192], [16384, 16384]] + head_dims: [64, 96, 128, [192, 128]] + head_pairs: [[16, 16], [16, 8], [16, 4], [16, 1]] causal: [true] varlen: @@ -20,15 +20,15 @@ varlen: # uniform: all sequences in the batch are similar length # longtail: a few long sequences plus many shorter ones patterns: [uniform, longtail] - head_dims: [64, 96, 128] - head_pairs: [[16, 8], [16, 4], [16, 2], [16, 1]] + head_dims: [64, 128, [192, 128]] + head_pairs: [[16, 8], [16, 4], [16, 1]] causal: [false] block_sparse: enabled: false batches: [1, 4, 8, 16, 32] seqlen_pairs: [[1024, 1024], [2048, 2048], [4096, 4096], [4096, 8192]] - head_dims: [64, 128] + head_dims: [64, 128, [192, 128]] head_pairs: [[16, 16], [16, 4], [16, 1]] # supported mask_names: block_diagonal, sliding_window mask_names: [block_diagonal] From 3a7694c04748e3ba1cd954eb5dfdfd4b6646a234 Mon Sep 17 00:00:00 2001 From: geruome <85235464+geruome@users.noreply.github.com> Date: Wed, 22 Apr 2026 01:21:19 +0800 Subject: [PATCH 22/44] fix (#2481) Co-authored-by: wangziheng --- flash_attn/cute/flash_bwd_preprocess.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/flash_attn/cute/flash_bwd_preprocess.py b/flash_attn/cute/flash_bwd_preprocess.py index d93ea5cc50b..48e590ecb67 100644 --- a/flash_attn/cute/flash_bwd_preprocess.py +++ b/flash_attn/cute/flash_bwd_preprocess.py @@ -241,6 +241,14 @@ def kernel( work_tile = tile_scheduler.initial_work_tile_info() m_block, head_idx, batch_idx, _ = work_tile.tile_idx + # This kernel is launched with use_pdl=True, so the GPU may start executing it in + # "prologue" mode while the previous stream kernel is still running. We must wait + # before touching any upstream GMEM outputs (mO, mdO, mLSE); otherwise we risk + # reading a partially-written dout, which silently corrupts dpsum = sum(O * dO) and + # propagates to dQ/dK via dS = P * (dP - dpsum). + if const_expr(self.use_pdl): + cute.arch.griddepcontrol_wait() + if work_tile.is_valid_tile: # /////////////////////////////////////////////////////////////////////////////// # Get the appropriate tiles for this thread block. From 27b4eb9822d35a6ab3a34e4d5335573e7212fbc5 Mon Sep 17 00:00:00 2001 From: wangsiyu Date: Fri, 24 Apr 2026 02:22:11 +0800 Subject: [PATCH 23/44] Feat([FA4][CUTE DSL]) Add head_dim=256 support (forward + backward) (#2412) * [Feat] Support flash-attention head_dim 256 in CuteDSL This PR adds head_dim=256 support to the FA4 FlashAttention implementation built with the CUTLASS CUTE DSL. * Forward: uses a 2-CTA design and introduces a new pipeline to better hide memory latency; includes a TMEM-based design for intermediate storage. * Backward: uses a 2-kernel approach and a 2-CTA design for the backward path. No API changes for existing head dimensions. But coding style should be adjusted step by step. This feature is authored by Siyu Wang, Shengbin Di, Yuxi Chi, Johnsonms, Linfeng Zheng, Haoyan Huang, Lanbo Li, Yun Zhong, Man Yuan, Minmin Sun, Yong Li, Wei Lin. * Fix ruff lint errors in head_dim=256 changes Apply ruff check --fix and ruff format to bring the new hd256 files in line with the project's pre-commit config (flash_attn/cute/*.py, minus the excluded set in .pre-commit-config.yaml). Manual fixes: * mask.py: `Boolean(mask)` -> `cutlass.Boolean(mask)` (F821; other call sites in the file already use the qualified form). * sm100_hd256_2cta_fmha_backward_dkdvkernel.py: drop duplicate `SM100_TMEM_CAPACITY_COLUMNS = 512` local definition that shadowed the import from tile_scheduler (F811); the values were identical. * sm100_hd256_2cta_fmha_backward.py: both branches of the try/except ImportError imported the same two kernels once make_cotiled_copy/warp_reduction_sum were removed as unused; collapse to a single unconditional import. Auto-fixes: 41 unused imports (F401) + 2 f-strings without placeholders (F541) removed across sm100_hd256_2cta_fmha_{forward,backward,backward_dqkernel, backward_dkdvkernel}.py, tile_scheduler.py, mask.py. ruff format reformatted the 8 in-scope files touched by this PR. Verified: `ruff check` and `ruff format --check` both clean on flash_attn/cute/ (minus the pre-commit exclude list). Forward + varlen smoke tests on B200 pass (150 passed, 35 skipped, 0 failed across non-causal MHA, causal MHA, MQA/GQA, and varlen MHA at d=256). Backward kernels not yet test-exercised; change is imports/whitespace only and the kernels parse cleanly. --------- Co-authored-by: Johnsonms --- flash_attn/cute/flash_bwd_preprocess.py | 26 +- flash_attn/cute/interface.py | 224 +- flash_attn/cute/mask.py | 756 +++- .../cute/sm100_hd256_2cta_fmha_backward.py | 291 ++ ...100_hd256_2cta_fmha_backward_dkdvkernel.py | 3155 +++++++++++++++++ ...sm100_hd256_2cta_fmha_backward_dqkernel.py | 2145 +++++++++++ .../cute/sm100_hd256_2cta_fmha_forward.py | 1735 +++++++++ flash_attn/cute/tile_scheduler.py | 542 ++- flash_attn/cute/utils.py | 115 + tests/cute/test_flash_attn.py | 73 +- tests/cute/test_flash_attn_varlen.py | 2 +- 11 files changed, 8978 insertions(+), 86 deletions(-) create mode 100644 flash_attn/cute/sm100_hd256_2cta_fmha_backward.py create mode 100644 flash_attn/cute/sm100_hd256_2cta_fmha_backward_dkdvkernel.py create mode 100644 flash_attn/cute/sm100_hd256_2cta_fmha_backward_dqkernel.py create mode 100644 flash_attn/cute/sm100_hd256_2cta_fmha_forward.py diff --git a/flash_attn/cute/flash_bwd_preprocess.py b/flash_attn/cute/flash_bwd_preprocess.py index 48e590ecb67..8142def5ebb 100644 --- a/flash_attn/cute/flash_bwd_preprocess.py +++ b/flash_attn/cute/flash_bwd_preprocess.py @@ -43,6 +43,7 @@ def __init__( head_dim_v: int, tile_m: int = 128, num_threads: int = 256, + use_padded_offsets: bool = True, ): """ All contiguous dimensions must be at least 16 bytes aligned which indicates the head dimension @@ -64,6 +65,7 @@ def __init__( self.head_dim_v_padded = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of) self.check_hdim_v_oob = head_dim_v != self.head_dim_v_padded self.num_threads = num_threads + self.use_padded_offsets = use_padded_offsets @staticmethod def can_implement(dtype, head_dim, tile_m, num_threads) -> bool: @@ -258,9 +260,15 @@ def kernel( ) mO_cur = seqlen.offset_batch(mO, batch_idx, dim=0)[None, head_idx, None] mdO_cur = seqlen.offset_batch(mdO, batch_idx, dim=0)[None, head_idx, None] - mPdPsum_cur = seqlen.offset_batch(mPdPsum, batch_idx, dim=2, padded=True)[ - None, head_idx - ] + # Stats buffers (dpsum/lse_log2) are always consumed with padded q-offsets + # on the generic backward path (mdQaccum is present). Keep dedicated hd256 + # behavior controlled by self.use_padded_offsets. + stats_use_padded_offsets = self.use_padded_offsets + if const_expr(mdQaccum is not None): + stats_use_padded_offsets = True + mPdPsum_cur = seqlen.offset_batch( + mPdPsum, batch_idx, dim=2, padded=stats_use_padded_offsets + )[None, head_idx] headdim_v = mO_cur.shape[cute.rank(mO_cur) - 1] seqlen_q = seqlen.seqlen seqlen_q_rounded = cute.round_up(seqlen_q, self.tile_m) @@ -338,7 +346,11 @@ def kernel( # Clear dQaccum if const_expr(mdQaccum is not None): mdQaccum_cur = seqlen.offset_batch( - mdQaccum, batch_idx, dim=2, padded=True, multiple=self.head_dim_padded + mdQaccum, + batch_idx, + dim=2, + padded=True, + multiple=self.head_dim_padded, )[None, head_idx] blkdQaccum_shape = (self.tile_m * self.head_dim_padded,) gdQaccum = cute.local_tile(mdQaccum_cur, blkdQaccum_shape, (m_block,)) @@ -349,9 +361,9 @@ def kernel( cute.copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum) if const_expr(mLSE is not None): - mLSElog2_cur = seqlen.offset_batch(mLSElog2, batch_idx, dim=2, padded=True)[ - None, head_idx - ] + mLSElog2_cur = seqlen.offset_batch( + mLSElog2, batch_idx, dim=2, padded=stats_use_padded_offsets + )[None, head_idx] gLSElog2 = cute.local_tile(mLSElog2_cur, (self.tile_m,), (m_block,)) LOG2_E = math.log2(math.e) if tidx < seqlen_q_rounded - m_block * self.tile_m: diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 17b703b3f21..afdda11f351 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -45,6 +45,13 @@ from flash_attn.cute.flash_fwd_combine import FlashAttentionForwardCombine from flash_attn.cute.flash_fwd_mla_sm100 import FlashAttentionMLAForwardSm100 +from cutlass import Int32 + +# SM100 head_dim=256 2CTA kernel imports +from flash_attn.cute.sm100_hd256_2cta_fmha_forward import BlackwellFusedMultiHeadAttentionForward +from flash_attn.cute.sm100_hd256_2cta_fmha_backward import BlackwellFusedMultiHeadAttentionBackward +from flash_attn.cute.mask import Sm100MaskEnum as MaskEnum + from flash_attn.cute.block_sparsity import ( BlockSparseTensorsTorch, get_sparse_q_block_size, @@ -86,6 +93,7 @@ def _validate_head_dims(head_dim: int, head_dim_v: int, compute_capability: int, """Validate head dimension constraints based on compute capability.""" is_deepseek_shape = head_dim == 192 and head_dim_v == 128 is_deepseek_mla_absorbed_shape = head_dim == 64 and head_dim_v == 512 + is_dedicate_kernel_shape = head_dim == 256 and head_dim_v == 256 is_standard_range = 8 <= head_dim <= 128 and 8 <= head_dim_v <= 128 is_sm90_range = 8 <= head_dim <= 256 and 8 <= head_dim_v <= 256 @@ -95,9 +103,9 @@ def _validate_head_dims(head_dim: int, head_dim_v: int, compute_capability: int, f"head_dim and head_dim_v must be between 8 and 256 and divisible by {alignment}." ) elif compute_capability in [10, 11]: - assert (is_standard_range or is_deepseek_shape or is_deepseek_mla_absorbed_shape) and head_dim % alignment == 0 and head_dim_v % alignment == 0, ( + assert (is_standard_range or is_deepseek_shape or is_deepseek_mla_absorbed_shape or is_dedicate_kernel_shape) and head_dim % alignment == 0 and head_dim_v % alignment == 0, ( f"(head_dim, head_dim_v)=({head_dim}, {head_dim_v}) is not supported on SM100/SM110. " - f"head_dim and head_dim_v must be between 8 and 128 and divisible by {alignment}, or (192, 128) for DeepSeek." + f"head_dim and head_dim_v must be between 8 and 128 and divisible by {alignment}, or (192, 128) for DeepSeek, or (256, 256) for hd256." ) @@ -278,7 +286,6 @@ def _resolve_causal_local_window(causal, window_size_left, window_size_right, ma local = False return causal, local, window_size_left, window_size_right - def _flash_attn_fwd( q: torch.Tensor, k: torch.Tensor, @@ -565,6 +572,10 @@ def _flash_attn_fwd( and (tile_m % qhead_per_kvhead == 0 or not pack_gqa) ) + # hd=256 2CTA forward uses dedicated kernel (SM100 only) + use_dedicated_hd256_kernel = arch // 10 == 10 and head_dim == 256 and head_dim_v == 256 + use_2cta_instrs = use_2cta_instrs or use_dedicated_hd256_kernel + if softcap is not None: assert score_mod is None, "softcap and score_mod cannot be used together" score_mod = utils.create_softcap_scoremod(softcap) @@ -714,6 +725,7 @@ def _flash_attn_fwd( disable_sparse_kv_bitmask, fa_logging.get_fa_log_level(), ) + if compile_key not in _flash_attn_fwd.compile_cache: ( cu_seqlens_q_tensor, @@ -839,7 +851,25 @@ def _flash_attn_fwd( disable_bitmask=disable_sparse_kv_bitmask, ) else: - fa_fwd = FlashAttentionForwardSm100( + if use_dedicated_hd256_kernel: + # hd=256 2CTA forward: check for currently unsupported features + assert softcap is None, "SM100 forward with head_dim=256 does not support softcap" + assert not use_block_sparsity, \ + "SM100 forward with head_dim=256 does not support block sparsity" + assert learnable_sink is None, \ + "SM100 forward with head_dim=256 does not support learnable_sink" + assert seqused_q is None and seqused_k is None, \ + "SM100 forward with head_dim=256 does not support seqused_q/seqused_k" + # pack_gqa is an auto-selected optimization; disable it for hd256 kernel + pack_gqa = False + + flash_fwd_obj_cls = ( + BlackwellFusedMultiHeadAttentionForward + if use_dedicated_hd256_kernel + else FlashAttentionForwardSm100 + ) + + fa_fwd = flash_fwd_obj_cls( head_dim, head_dim_v, qhead_per_kvhead=qhead_per_kvhead, @@ -862,7 +892,7 @@ def _flash_attn_fwd( is_varlen_q=cu_seqlens_q is not None or seqused_q is not None, q_subtile_factor=q_subtile_factor, use_2cta_instrs=use_2cta_instrs, - use_clc_scheduler=use_clc_scheduler, + use_clc_scheduler=requested_use_clc_scheduler, ) elif arch // 10 == 12: # SM120 (Blackwell GeForce / DGX Spark): uses SM80 MMA with SM120 SMEM capacity @@ -1058,7 +1088,8 @@ def make_fake_bwd_tensors(dtype, has_gqa, varlen_q, varlen_k): def _compile_bwd_preprocess( - dtype, head_dim, head_dim_v, m_block_size, has_cuseqlens_q, has_seqused_q, has_dlse, + dtype, head_dim, head_dim_v, m_block_size, has_cuseqlens_q, has_seqused_q, has_dlse, has_dq_accum, + use_padded_offsets, ): """Compile bwd preprocess kernel using cute fake tensors (no real GPU tensors needed).""" mQ, mK, mV, mO, mdO, mdQ, mdK, mdV, mLSE, mLSElog2, mPdPsum, mdQaccum, mdKaccum, mdVaccum = make_fake_bwd_tensors( @@ -1069,7 +1100,10 @@ def _compile_bwd_preprocess( mCuSeqlensQ = fake_tensor(Int32, (batchp1,), divisibility=1) if has_cuseqlens_q else None mSequsedQ = fake_tensor(Int32, (batch,), divisibility=1) if has_seqused_q else None mdLSE = fake_tensor(Float32, mLSE.shape, divisibility=1) if has_dlse else None - fa_bwd_pre = FlashAttentionBackwardPreprocess(dtype, head_dim, head_dim_v, m_block_size) + mdQaccum = mdQaccum if has_dq_accum else None + fa_bwd_pre = FlashAttentionBackwardPreprocess( + dtype, head_dim, head_dim_v, m_block_size, use_padded_offsets=use_padded_offsets + ) return cute.compile( fa_bwd_pre, mO, mdO, mPdPsum, mLSE, mLSElog2, mdQaccum, mCuSeqlensQ, mSequsedQ, mdLSE, cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), @@ -1081,11 +1115,13 @@ def _bwd_preprocess( out, dout, dpsum, lse, lse_log2, dq_accum, cu_seqlens_q, seqused_q, dlse, dtype, head_dim, head_dim_v, m_block_size, + use_padded_offsets=True, ): """Backward preprocess: compute (o * dout).sum(dim=-1) - dLSE, lse * log2_e, and zero out dq_accum.""" is_varlen = cu_seqlens_q is not None compile_key = ( - dtype, head_dim, head_dim_v, m_block_size, is_varlen, seqused_q is not None, dlse is not None, + dtype, head_dim, head_dim_v, m_block_size, is_varlen, seqused_q is not None, dlse is not None, dq_accum is not None, + use_padded_offsets, ) if compile_key not in _bwd_preprocess.compile_cache: _bwd_preprocess.compile_cache[compile_key] = _compile_bwd_preprocess(*compile_key) @@ -1198,6 +1234,7 @@ def _flash_attn_bwd( num_head, head_dim = q.shape[-2:] head_dim_v = v.shape[-1] + window_size = [window_size_left, window_size_right] causal, local, window_size_left, window_size_right = _resolve_causal_local_window( causal, window_size_left, window_size_right ) @@ -1273,6 +1310,9 @@ def _flash_attn_bwd( cluster_size = 2 if head_dim >= 128 and not disable_2cta else 1 use_2cta_instrs = cluster_size==2 + use_dedicated_hd256_kernel = arch // 10 == 10 and head_dim == 256 and head_dim_v == 256 + use_2cta_instrs = use_2cta_instrs or use_dedicated_hd256_kernel + q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = [ maybe_contiguous(t) for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) @@ -1389,12 +1429,16 @@ def _flash_attn_bwd( head_dim_rounded = (head_dim + 32 - 1) // 32 * 32 if cu_seqlens_q is None: - dq_accum = torch.empty( - batch_size, - num_head, - seqlen_q_rounded * head_dim_rounded, - dtype=torch.float32, - device=device, + dq_accum = ( + None + if use_dedicated_hd256_kernel + else torch.empty( + batch_size, + num_head, + seqlen_q_rounded * head_dim_rounded, + dtype=torch.float32, + device=device, + ) ) dpsum = torch.empty( batch_size, num_head, seqlen_q_rounded, dtype=torch.float32, device=device @@ -1406,8 +1450,12 @@ def _flash_attn_bwd( total_q_rounded_padded = ( (total_q + cu_seqlens_q.shape[0] * m_block_size - 1) // m_block_size * m_block_size ) - dq_accum = torch.empty( - num_head, total_q_rounded_padded * head_dim_rounded, dtype=torch.float32, device=device + dq_accum = ( + None + if use_dedicated_hd256_kernel + else torch.empty( + num_head, total_q_rounded_padded * head_dim_rounded, dtype=torch.float32, device=device + ) ) dpsum = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device) lse_log2 = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device) @@ -1415,7 +1463,8 @@ def _flash_attn_bwd( # GQA (qhead_per_kvhead > 1) needs dK/dV accum+postprocess since multiple Q heads # accumulate into the same dK/dV. SM90 varlen_k with qhead_per_kvhead==1 now uses # ragged TMA tensors for direct store, so no longer needs accum+postprocess. - dKV_postprocess = qhead_per_kvhead > 1 + # hd=256 2CTA backward has its own internal postprocess for dK/dV. + dKV_postprocess = qhead_per_kvhead > 1 and not use_dedicated_hd256_kernel if dKV_postprocess: head_dim_v_rounded = (head_dim_v + 32 - 1) // 32 * 32 if cu_seqlens_k is None: @@ -1467,10 +1516,12 @@ def _flash_attn_bwd( dV_semaphore = None # Preprocess kernel: compute (o * dout).sum(dim=-1) - dLSE, lse * log2_e, and zero out dq_accum. + # For hd=256 dedicated path, dq_accum is None so preprocess only fills dpsum/lse_log2. _bwd_preprocess( out, dout, dpsum, lse, lse_log2, dq_accum, cu_seqlens_q, seqused_q, dlse, dtype, head_dim, head_dim_v, m_block_size, + use_padded_offsets=use_dedicated_hd256_kernel, ) # num_threads: SM90 derives from BwdConfig.num_wg, SM120 is set to 128 above, # SM100/SM110 uses default from function signature (384). @@ -1580,13 +1631,13 @@ def _flash_attn_bwd( (seqlen_q_rounded // m_block_size == 1), (seqlen_k_rounded // n_block_size == 1), ) + if compile_key not in _flash_attn_bwd.compile_cache: q_tensor, k_tensor, v_tensor, do_tensor, dq_tensor, dk_tensor, dv_tensor = [ to_cute_tensor(t) for t in (q, k, v, dout, dq, dk, dv) ] - dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [ - to_cute_tensor(t) for t in (dq_accum, dpsum, lse_log2) - ] + lse_log2_tensor, dpsum_tensor = [to_cute_tensor(t) for t in (lse_log2, dpsum)] + dq_accum_tensor = to_cute_tensor(dq_accum) if dq_accum is not None else None if dKV_postprocess: dk_accum_tensor, dv_accum_tensor = [ to_cute_tensor(t) for t in (dk_accum, dv_accum) @@ -1654,28 +1705,61 @@ def _flash_attn_bwd( dQ_single_wg=dQ_single_wg, ) else: - fa_bwd_obj = FlashAttentionBackwardSm100( - head_dim, - head_dim_v, - is_causal=causal, - is_local=local, - qhead_per_kvhead=qhead_per_kvhead, - tile_m=m_block_size, - tile_n=n_block_size, - cluster_size=cluster_size, - use_2cta_instrs=use_2cta_instrs, - deterministic=deterministic, - score_mod=score_mod, - score_mod_bwd=score_mod_bwd, - mask_mod=mask_mod, - has_aux_tensors=aux_tensors is not None, - subtile_factor=subtile_factor, - ) + if use_dedicated_hd256_kernel: + assert softcap == 0.0, "SM100 backward with head_dim=256 does not support softcap" + assert block_sparse_tensors is None, \ + "SM100 backward with head_dim=256 does not support block sparsity" + assert dlse is None, \ + "SM100 backward with head_dim=256 does not support dlse" + assert seqused_q is None and seqused_k is None, \ + "SM100 backward with head_dim=256 does not support seqused_q/seqused_k" + + dq_tile_mn = (128, 128) + dkdv_tile_mn = (128, 64) + fa_bwd_obj = BlackwellFusedMultiHeadAttentionBackward( + head_dim, + head_dim_v, + is_causal=causal, + is_local=local, + qhead_per_kvhead=qhead_per_kvhead, + is_persistent=False, + deterministic=deterministic, + cluster_size=cluster_size, + use_2cta_instrs=use_2cta_instrs, + score_mod=score_mod, + score_mod_bwd=score_mod_bwd, + mask_mod=mask_mod, + has_aux_tensors=aux_tensors is not None, + subtile_factor=subtile_factor, + tile_m_dq=dq_tile_mn[0], + tile_n_dq=dq_tile_mn[1], + tile_m_dkdv=dkdv_tile_mn[0], + tile_n_dkdv=dkdv_tile_mn[1], + ) + else: + fa_bwd_obj = FlashAttentionBackwardSm100( + head_dim, + head_dim_v, + is_causal=causal, + is_local=local, + qhead_per_kvhead=qhead_per_kvhead, + tile_m=m_block_size, + tile_n=n_block_size, + cluster_size=cluster_size, + use_2cta_instrs=use_2cta_instrs, + deterministic=deterministic, + score_mod=score_mod, + score_mod_bwd=score_mod_bwd, + mask_mod=mask_mod, + has_aux_tensors=aux_tensors is not None, + subtile_factor=subtile_factor, + ) # Block sparse tensors for backward use Q-direction indexing (transposed from forward). sparse_tensors_compile = None if normalized_block_sparse_tensors is not None: sparse_tensors_compile = to_cute_block_sparse_tensors(normalized_block_sparse_tensors) + dq_accum_tensor = dq_tensor if use_dedicated_hd256_kernel else dq_accum_tensor # TODO: check @can_implement _flash_attn_bwd.compile_cache[compile_key] = cute.compile( @@ -1705,6 +1789,7 @@ def _flash_attn_bwd( options="--enable-tvm-ffi", ) if not is_fake_mode(): + dq_accum = dq if use_dedicated_hd256_kernel else dq_accum _flash_attn_bwd.compile_cache[compile_key]( q.detach(), k.detach(), @@ -1728,42 +1813,43 @@ def _flash_attn_bwd( aux_tensors, normalized_block_sparse_tensors[:4] if normalized_block_sparse_tensors is not None else None, ) - - if arch // 10 == 9: - # dQ postprocess: match main kernel's MMA WG count, unless dQ_single_wg - num_threads_post_dQ = 128 if dQ_single_wg else cfg.num_wg * 128 - num_threads_post_dKV = cfg.num_wg * 128 - else: - num_threads_post_dQ = 128 - num_threads_post_dKV = 128 - # Postprocess: convert dq_accum from float32 to dq in bf16/fp16 - _bwd_postprocess_convert( - dq_accum, dq, softmax_scale, - cu_seqlens_q, seqused_q, - arch, dtype, head_dim, m_block_size, num_threads_post_dQ, - AtomLayoutMdQ, dQ_swapAB, - use_2cta_instrs=use_2cta_instrs, cluster_size=1, - ) + # hd=256 2CTA backward has its own internal postprocess, skip here. + if not use_dedicated_hd256_kernel: + if arch // 10 == 9: + # dQ postprocess: match main kernel's MMA WG count, unless dQ_single_wg + num_threads_post_dQ = 128 if dQ_single_wg else cfg.num_wg * 128 + num_threads_post_dKV = cfg.num_wg * 128 + else: + num_threads_post_dQ = 128 + num_threads_post_dKV = 128 - if dKV_postprocess: - # Postprocess: convert dk_accum from float32 to dk in bf16/fp16 - _bwd_postprocess_convert( - dk_accum, dk, softmax_scale, - cu_seqlens_k, seqused_k, - arch, dtype, head_dim, n_block_size, num_threads_post_dKV, - AtomLayoutNdKV, dKV_swapAB, - cluster_size=cluster_size, - ) - # Postprocess: convert dv_accum from float32 to dv in bf16/fp16 _bwd_postprocess_convert( - dv_accum, dv, 1.0, - cu_seqlens_k, seqused_k, - arch, dtype, head_dim_v, n_block_size, num_threads_post_dKV, - AtomLayoutNdKV, dKV_swapAB, - cluster_size=cluster_size, + dq_accum, dq, softmax_scale, + cu_seqlens_q, seqused_q, + arch, dtype, head_dim, m_block_size, num_threads_post_dQ, + AtomLayoutMdQ, dQ_swapAB, + use_2cta_instrs=use_2cta_instrs, cluster_size=1, ) + if dKV_postprocess: + # Postprocess: convert dk_accum from float32 to dk in bf16/fp16 + _bwd_postprocess_convert( + dk_accum, dk, softmax_scale, + cu_seqlens_k, seqused_k, + arch, dtype, head_dim, n_block_size, num_threads_post_dKV, + AtomLayoutNdKV, dKV_swapAB, + cluster_size=cluster_size, + ) + # Postprocess: convert dv_accum from float32 to dv in bf16/fp16 + _bwd_postprocess_convert( + dv_accum, dv, 1.0, + cu_seqlens_k, seqused_k, + arch, dtype, head_dim_v, n_block_size, num_threads_post_dKV, + AtomLayoutNdKV, dKV_swapAB, + cluster_size=cluster_size, + ) + return dq, dk, dv @@ -2308,4 +2394,4 @@ def flash_attn_combine( seqused, varlen_batch_idx=varlen_batch_idx, ) - return out, lse + return out, lse \ No newline at end of file diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index 99e7008ab82..a1971366183 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -1,14 +1,17 @@ # Copyright (c) 2025, Tri Dao. -from typing import Optional, Callable, TypeAlias +from typing import Optional, Callable, TypeAlias, Tuple from dataclasses import dataclass +import enum import cutlass import cutlass.cute as cute -from cutlass import Float32, Int32, Uint32, const_expr, Boolean +from cutlass import Float32, Int32, Uint32, const_expr +from cutlass.cutlass_dsl import min as dsl_min from quack import layout_utils import flash_attn.cute.utils as utils +from flash_attn.cute.block_info import BlockInfo from flash_attn.cute.seqlen_info import SeqlenInfoQK MaskGenFn: TypeAlias = Callable[[int], Uint32] @@ -408,7 +411,7 @@ def apply_mask_sm100( for j in cutlass.range_constexpr(32): curr_col = col_start + j mask = (curr_mask_val >> j) & 1 - acc_S[curr_col] = acc_S[curr_col] if Boolean(mask) else -Float32.inf + acc_S[curr_col] = acc_S[curr_col] if cutlass.Boolean(mask) else -Float32.inf elif const_expr(not mask_causal and not mask_local and mask_mod is None): if const_expr(mask_seqlen): @@ -721,3 +724,750 @@ def mask_gen_fn(s: int) -> Uint32: mask_gen_fn, rank1=True, ) + + +# ----------------------------------------------------------------------------- +# SM100 FMHA fused-mask policy layer (separate from generic mask primitives). +# ----------------------------------------------------------------------------- + + +class Sm100MaskEnum(enum.Enum): + """Enumeration of mask types for FMHA operations. + + - RESIDUAL_MASK: Residual mask for handling variable sequence lengths + - WINDOW_MASK: Window mask for attention which also includes causal and no mask + - WINDOW_MASK_INFERENCE: Same as the window mask, but has the limitation that the end of q is aligned with the end of k + - WINDOW_MASK_BWD: Window mask for backward pass + - WINDOW_MASK_BWD_INFERENCE: Same as the window mask for backward pass, but has the limitation that the end of q is aligned with the end of k + """ + + NO_MASK = enum.auto() + RESIDUAL_MASK = enum.auto() + CAUSAL_MASK = enum.auto() + WINDOW_MASK = enum.auto() + WINDOW_MASK_INFERENCE = enum.auto() + # Deprecated the following types + WINDOW_MASK_BWD = enum.auto() + WINDOW_MASK_BWD_INFERENCE = enum.auto() + RESIDUAL_MASK_BWD = enum.auto() + + +class Sm100FusedMask: + """A fused mask implementation for FMHA operations. + + This class handles different types of attention masks including no mask, + residual mask for variable sequence lengths, and causal mask for + autoregressive attention patterns. + + The class provides methods to: + - Calculate trip counts for different mask types + - Apply masks to attention scores + - Handle masked and unmasked trip calculations + """ + + def get_trip_count( + mask_type: Sm100MaskEnum, + blk_coord: cute.Coord, + tile_shape: cute.Shape, + seqlen_q: Int32, + seqlen_k: Int32, + window_size_left: Optional[Int32] = None, + window_size_right: Optional[Int32] = None, + ) -> Int32: + """ + Calculate the number of trips needed for the current block. + + The trip count depends on the mask type and the block coordinates. + For causal masks, it considers the autoregressive constraint. + + :param mask_type: Type of mask to use + :type mask_type: utils.Sm100MaskEnum + :param blk_coord: Block coordinates. + :type blk_coord: cute.Coord + :param tile_shape: Shape of the tile. + :type tile_shape: cute.Shape + :param seqlen_q: Query sequence length for attention computation. + :type seqlen_q: Int32 + :param seqlen_k: Key sequence length for attention computation. + :type seqlen_k: Int32 + :param window_size_left: Left-side sliding window size for attention masking. + :type window_size_left: Optional[Int32] + :param window_size_right: Right-side sliding window size for attention masking. + :type window_size_right: Optional[Int32] + + :return: Number of trips needed. + :rtype: Int32 + """ + result = 0 + offset = 0 + if cutlass.const_expr(mask_type is Sm100MaskEnum.WINDOW_MASK_INFERENCE): + offset = seqlen_k - seqlen_q + if cutlass.const_expr(mask_type is Sm100MaskEnum.WINDOW_MASK_BWD_INFERENCE): + offset = seqlen_q - seqlen_k + if cutlass.const_expr(mask_type == Sm100MaskEnum.RESIDUAL_MASK): + result = cute.ceil_div(seqlen_k, tile_shape[1]) + if cutlass.const_expr(mask_type is Sm100MaskEnum.RESIDUAL_MASK_BWD): + result = cute.ceil_div(seqlen_q, tile_shape[0]) + if cutlass.const_expr( + mask_type == Sm100MaskEnum.WINDOW_MASK + or mask_type == Sm100MaskEnum.WINDOW_MASK_INFERENCE + ): + if cutlass.const_expr(window_size_right is None): + result = cute.ceil_div(seqlen_k, tile_shape[1]) + else: + max_idx_q = (blk_coord[0] + 1) * tile_shape[0] + idx_k = max_idx_q + offset + window_size_right + tmp_blocks_k = cute.ceil_div(idx_k, tile_shape[1]) + max_blocks_k = cute.ceil_div(seqlen_k, tile_shape[1]) + result = dsl_min(max_blocks_k, tmp_blocks_k) + if cutlass.const_expr( + mask_type == Sm100MaskEnum.WINDOW_MASK_BWD + or mask_type == Sm100MaskEnum.WINDOW_MASK_BWD_INFERENCE + ): + if cutlass.const_expr(window_size_left is None): + result = cute.ceil_div(seqlen_q, tile_shape[0]) + else: + max_idx_k = (blk_coord[1] + 1) * tile_shape[1] + idx_k = max_idx_k + offset + window_size_left + tmp_blocks_q = cute.ceil_div(idx_k, tile_shape[0]) + max_blocks_q = cute.ceil_div(seqlen_q, tile_shape[0]) + result = dsl_min(max_blocks_q, tmp_blocks_q) + start_block = Sm100FusedMask.get_trip_start( + mask_type, + blk_coord, + tile_shape, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + ) + result = result - start_block + return result + + @cute.jit + def get_trip_start_count_via_block_info( + blk_coord: cute.Coord, + tile_shape: cute.Shape, + seqlen_q: Int32, + seqlen_k: Int32, + is_causal: cutlass.Constexpr[bool] = False, + is_local: cutlass.Constexpr[bool] = False, + window_size_left: Optional[Int32] = None, + window_size_right: Optional[Int32] = None, + ) -> Tuple[Int32, Int32]: + block_info = BlockInfo( + tile_m=tile_shape[0], + tile_n=tile_shape[1], + is_causal=is_causal, + is_local=is_local and not is_causal, + window_size_left=window_size_left, + window_size_right=window_size_right, + ) + + seqlen_info = SeqlenInfoQK( + offset_q=Int32(0), + offset_k=Int32(0), + padded_offset_q=Int32(0), + padded_offset_k=Int32(0), + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + has_cu_seqlens_q=False, + has_cu_seqlens_k=False, + has_seqused_q=False, + has_seqused_k=False, + ) + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen_info, blk_coord[0]) + return n_block_min, n_block_max - n_block_min + + @cute.jit + def get_trip_mask_bounds_via_block_info( + blk_coord: cute.Coord, + tile_shape: cute.Shape, + seqlen_q: Int32, + seqlen_k: Int32, + is_causal: cutlass.Constexpr[bool] = False, + is_local: cutlass.Constexpr[bool] = False, + window_size_left: Optional[Int32] = None, + window_size_right: Optional[Int32] = None, + ) -> Tuple[Int32, Int32]: + """Return SM100-style mask boundaries for dense iteration. + + Returns: + - n_block_min_causal_local_mask: right-side masked region start + - n_block_min_before_local_mask: start of fully unmasked middle region + """ + block_info = BlockInfo( + tile_m=tile_shape[0], + tile_n=tile_shape[1], + is_causal=is_causal, + is_local=is_local and not is_causal, + window_size_left=window_size_left, + window_size_right=window_size_right, + ) + seqlen_info = SeqlenInfoQK( + offset_q=Int32(0), + offset_k=Int32(0), + padded_offset_q=Int32(0), + padded_offset_k=Int32(0), + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + has_cu_seqlens_q=False, + has_cu_seqlens_k=False, + has_seqused_q=False, + has_seqused_k=False, + ) + n_block_min, _ = block_info.get_n_block_min_max(seqlen_info, blk_coord[0]) + n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( + seqlen_info, blk_coord[0], n_block_min + ) + n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask( + seqlen_info, blk_coord[0], n_block_min + ) + return n_block_min_causal_local_mask, n_block_min_before_local_mask + + @cute.jit + def get_trip_start( + mask_type: Sm100MaskEnum, + blk_coord: cute.Coord, + tile_shape: cute.Shape, + seqlen_q: Int32, + seqlen_k: Int32, + window_size_left: Optional[Int32] = None, + window_size_right: Optional[Int32] = None, + ) -> Int32: + """ + Get the start of the trip for the current block. + + :param mask_type: Type of mask to use + :type mask_type: utils.Sm100MaskEnum + :param blk_coord: Block coordinates. + :type blk_coord: cute.Coord + :param tile_shape: Shape of the tile. + :type tile_shape: cute.Shape + :param seqlen_q: Query sequence length for attention computation. + :type seqlen_q: Int32 + :param seqlen_k: Key sequence length for attention computation. + :type seqlen_k: Int32 + :param window_size_left: Left-side sliding window size for attention masking. + :type window_size_left: Optional[Int32] + :param window_size_right: Right-side sliding window size for attention masking. + :type window_size_right: Optional[Int32] + """ + result = 0 + offset = 0 + if cutlass.const_expr(mask_type is Sm100MaskEnum.WINDOW_MASK_INFERENCE): + offset = seqlen_k - seqlen_q + if cutlass.const_expr(mask_type is Sm100MaskEnum.WINDOW_MASK_BWD_INFERENCE): + offset = seqlen_q - seqlen_k + if cutlass.const_expr( + mask_type is Sm100MaskEnum.WINDOW_MASK + or mask_type is Sm100MaskEnum.WINDOW_MASK_INFERENCE + ): + if cutlass.const_expr(window_size_left is not None): + min_idx_q = blk_coord[0] * tile_shape[0] + idx_k = min_idx_q + offset - window_size_left + tmp_blocks_k = idx_k // tile_shape[1] + result = max(tmp_blocks_k, result) + if cutlass.const_expr( + mask_type is Sm100MaskEnum.WINDOW_MASK_BWD + or mask_type is Sm100MaskEnum.WINDOW_MASK_BWD_INFERENCE + ): + if cutlass.const_expr(window_size_right is not None): + min_idx_k = blk_coord[1] * tile_shape[1] + idx_q = min_idx_k + offset - window_size_right + tmp_blocks_q = idx_q // tile_shape[0] + result = max(tmp_blocks_q, result) + return result + + @cute.jit + def get_leading_mask_id( + mask_type: Sm100MaskEnum, + blk_coord: cute.Coord, + tile_shape: cute.Shape, + seqlen_q: Int32, + seqlen_k: Int32, + window_size_left: Optional[Int32] = None, + window_size_right: Optional[Int32] = None, + ) -> Tuple[Int32, Int32]: + """ + Get the begin and end tile idx for the leading mask. + + :param mask_type: Type of mask to use + :type mask_type: utils.Sm100MaskEnum + :param blk_coord: Block coordinates. + :type blk_coord: cute.Coord + :param tile_shape: Shape of the tile. + :type tile_shape: cute.Shape + :param seqlen_q: Query sequence length for attention computation. + :type seqlen_q: Int32 + :param seqlen_k: Key sequence length for attention computation. + :type seqlen_k: Int32 + :param window_size_left: Left-side sliding window size for attention masking. + :type window_size_left: Optional[Int32] + :param window_size_right: Right-side sliding window size for attention masking. + :type window_size_right: Optional[Int32] + + :return: Tuple of (begin, end) tile idx for the leading mask. + :rtype: Tuple[Int32, Int32] + """ + offset = 0 + if cutlass.const_expr(mask_type is Sm100MaskEnum.WINDOW_MASK_INFERENCE): + offset = seqlen_k - seqlen_q + if cutlass.const_expr(mask_type is Sm100MaskEnum.WINDOW_MASK_BWD_INFERENCE): + offset = seqlen_q - seqlen_k + leading_mask_begin = Sm100FusedMask.get_trip_start( + mask_type, + blk_coord, + tile_shape, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + ) + trip_count = Sm100FusedMask.get_trip_count( + mask_type, + blk_coord, + tile_shape, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + ) + + leading_mask_end = leading_mask_begin + if cutlass.const_expr( + mask_type is Sm100MaskEnum.WINDOW_MASK + or mask_type is Sm100MaskEnum.WINDOW_MASK_INFERENCE + ): + if cutlass.const_expr(window_size_left is not None): + min_idx_q = (blk_coord[0] + 1) * tile_shape[0] + offset - window_size_left + leading_mask_end = dsl_min( + cute.ceil_div(min_idx_q, tile_shape[1]) - 1, + trip_count + leading_mask_begin - 1, + ) + else: + leading_mask_end = leading_mask_begin - 1 + elif cutlass.const_expr( + mask_type is Sm100MaskEnum.WINDOW_MASK_BWD + or mask_type is Sm100MaskEnum.WINDOW_MASK_BWD_INFERENCE + ): + if cutlass.const_expr(window_size_right is not None): + min_idx_k = (blk_coord[1] + 1) * tile_shape[1] + offset - window_size_right + leading_mask_end = cute.ceil_div(min_idx_k, tile_shape[0]) - 1 + else: + leading_mask_end = leading_mask_begin - 1 + return leading_mask_begin, leading_mask_end + + @cute.jit + def get_trailing_mask_id( + mask_type: Sm100MaskEnum, + blk_coord: cute.Coord, + tile_shape: cute.Shape, + seqlen_q: Int32, + seqlen_k: Int32, + window_size_left: Optional[Int32] = None, + window_size_right: Optional[Int32] = None, + ) -> Tuple[Optional[Int32], Optional[Int32]]: + """ + Get the begin and end tile idx for the trailing mask. + + :param mask_type: Type of mask to use + :type mask_type: utils.Sm100MaskEnum + :param blk_coord: Block coordinates. + :type blk_coord: cute.Coord + :param tile_shape: Shape of the tile. + :type tile_shape: cute.Shape + :param seqlen_q: Query sequence length for attention computation. + :type seqlen_q: Int32 + :param seqlen_k: Key sequence length for attention computation. + :type seqlen_k: Int32 + :param window_size_left: Left-side sliding window size for attention masking. + :type window_size_left: Optional[Int32] + :param window_size_right: Right-side sliding window size for attention masking. + :type window_size_right: Optional[Int32] + + :return: Tuple of (begin, end) tile idx for the trailing mask. + :rtype: Tuple[Int32, Int32] + """ + offset = 0 + if cutlass.const_expr(mask_type is Sm100MaskEnum.WINDOW_MASK_INFERENCE): + offset = seqlen_k - seqlen_q + if cutlass.const_expr(mask_type is Sm100MaskEnum.WINDOW_MASK_BWD_INFERENCE): + offset = seqlen_q - seqlen_k + trip_start = Sm100FusedMask.get_trip_start( + mask_type, + blk_coord, + tile_shape, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + ) + trip_count = Sm100FusedMask.get_trip_count( + mask_type, + blk_coord, + tile_shape, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + ) + + trailing_mask_begin, trailing_mask_end = None, None + if cutlass.const_expr( + mask_type is Sm100MaskEnum.WINDOW_MASK + or mask_type is Sm100MaskEnum.WINDOW_MASK_INFERENCE + ): + if cutlass.const_expr(window_size_right is not None): + min_idx_q = blk_coord[0] * tile_shape[0] + offset + window_size_right + trailing_mask_begin = dsl_min( + min_idx_q // tile_shape[1], trip_count + trip_start - 1 + ) + trailing_mask_end = trip_count + trip_start - 1 + else: + # last tile, we always apply mask on it regardless whether it's a residual tile + trailing_mask_begin = trip_count + trip_start - 1 + trailing_mask_end = trip_count + trip_start - 1 + else: + if cutlass.const_expr(window_size_left is not None): + min_idx_k = blk_coord[1] * tile_shape[1] + offset + window_size_left + 1 + max_idx_k = (blk_coord[1] + 1) * tile_shape[1] + offset + window_size_left + trailing_mask_begin = dsl_min( + cute.ceil_div(min_idx_k, tile_shape[0]) - 1, + trip_count + trip_start - 1, + ) + trailing_mask_end = dsl_min( + cute.ceil_div(max_idx_k, tile_shape[0]) - 1, + trip_count + trip_start - 1, + ) + else: + # last tile, we always apply mask on it regardless whether it's a residual tile + trailing_mask_begin = trip_count + trip_start - 1 + trailing_mask_end = trip_count + trip_start - 1 + + return trailing_mask_begin, trailing_mask_end + + @cute.jit + def get_masked_leading_count( + mask_type: Sm100MaskEnum, + blk_coord: cute.Coord, + tile_shape: cute.Shape, + seqlen_q: Int32, + seqlen_k: Int32, + window_size_left: Optional[Int32] = None, + window_size_right: Optional[Int32] = None, + ) -> Int32: + """ + Calculate the number of masked trips for the leading mask. + + This is used for blocks that need special handling due to masking. + + :param mask_type: Type of mask to use + :type mask_type: utils.Sm100MaskEnum + :param blk_coord: Block coordinates. + :type blk_coord: cute.Coord + :param tile_shape: Shape of the tile. + :type tile_shape: cute.Shape + :param seqlen_q: Query sequence length for attention computation. + :type seqlen_q: Int32 + :param seqlen_k: Key sequence length for attention computation. + :type seqlen_k: Int32 + :param window_size_left: Left-side sliding window size for attention masking. + :type window_size_left: Optional[Int32] + :param window_size_right: Right-side sliding window size for attention masking. + :type window_size_right: Optional[Int32] + + :return: Number of masked trips. + :rtype: Int32 + """ + result = 0 + if cutlass.const_expr( + mask_type is not Sm100MaskEnum.RESIDUAL_MASK + and mask_type is not Sm100MaskEnum.RESIDUAL_MASK_BWD + ): + if cutlass.const_expr(window_size_left is not None or window_size_right is not None): + leading_mask_begin, leading_mask_end = Sm100FusedMask.get_leading_mask_id( + mask_type, + blk_coord, + tile_shape, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + ) + result = max(leading_mask_end - leading_mask_begin + 1, 0) + + return result + + @cute.jit + def get_masked_trailing_count( + mask_type: Sm100MaskEnum, + blk_coord: cute.Coord, + tile_shape: cute.Shape, + seqlen_q: Int32, + seqlen_k: Int32, + window_size_left: Optional[Int32] = None, + window_size_right: Optional[Int32] = None, + rem_count: Optional[Int32] = 0, + ) -> Int32: + """ + Calculate the number of masked trips for the trailing mask. + + This is used for blocks that need special handling due to masking. + + :param mask_type: Type of mask to use + :type mask_type: utils.Sm100MaskEnum + :param blk_coord: Block coordinates. + :type blk_coord: cute.Coord + :param tile_shape: Shape of the tile. + :type tile_shape: cute.Shape + :param seqlen_q: Query sequence length for attention computation. + :type seqlen_q: Int32 + :param seqlen_k: Key sequence length for attention computation. + :type seqlen_k: Int32 + :param window_size_left: Left-side sliding window size for attention masking. + :type window_size_left: Optional[Int32] + :param window_size_right: Right-side sliding window size for attention masking. + :type window_size_right: Optional[Int32] + :param rem_count: Remaining count from previous calculations. + :type rem_count: Int32 + + :return: Number of masked trips. + :rtype: Int32 + """ + result = 0 + + if cutlass.const_expr( + mask_type is not Sm100MaskEnum.RESIDUAL_MASK + and mask_type is not Sm100MaskEnum.RESIDUAL_MASK_BWD + ): + if cutlass.const_expr(window_size_left is not None or window_size_right is not None): + trailing_mask_begin, trailing_mask_end = Sm100FusedMask.get_trailing_mask_id( + mask_type, + blk_coord, + tile_shape, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + ) + leading_mask_begin, leading_mask_end = Sm100FusedMask.get_leading_mask_id( + mask_type, + blk_coord, + tile_shape, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + ) + if cutlass.const_expr( + trailing_mask_begin is not None and trailing_mask_end is not None + ): + if trailing_mask_begin <= leading_mask_end: + result = max(trailing_mask_end - leading_mask_end, 0) + else: + result = max(trailing_mask_end - trailing_mask_begin + 1, 0) + else: + if seqlen_k % tile_shape[1] != 0: + result = 1 + else: + result = 0 + + return result + rem_count + + @cute.jit + def get_unmasked_trip_count( + mask_type: Sm100MaskEnum, + blk_coord: cute.Coord, + tile_shape: cute.Shape, + seqlen_q: Int32, + seqlen_k: Int32, + window_size_left: Optional[Int32] = None, + window_size_right: Optional[Int32] = None, + ) -> Int32: + """ + Calculate the number of unmasked trips for the current block. + + This represents the number of trips that don't require special + masking treatment. + + :param mask_type: Type of mask to use + :type mask_type: utils.Sm100MaskEnum + :param blk_coord: Block coordinates. + :type blk_coord: cute.Coord + :param tile_shape: Shape of the tile. + :type tile_shape: cute.Shape + :param seqlen_q: Query sequence length for attention computation. + :type seqlen_q: Int32 + :param seqlen_k: Key sequence length for attention computation. + :type seqlen_k: Int32 + :param window_size_left: Left-side sliding window size for attention masking. + :type window_size_left: Optional[Int32] + :param window_size_right: Right-side sliding window size for attention masking. + :type window_size_right: Optional[Int32] + + :return: Number of unmasked trips. + :rtype: Int32 + """ + result = ( + Sm100FusedMask.get_trip_count( + mask_type, + blk_coord, + tile_shape, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + ) + - Sm100FusedMask.get_masked_leading_count( + mask_type, + blk_coord, + tile_shape, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + ) + - Sm100FusedMask.get_masked_trailing_count( + mask_type, + blk_coord, + tile_shape, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, + 0, + ) + ) + return result + + @cute.jit + def apply_mask( + mask_type: Sm100MaskEnum, + acc_qk: cute.Tensor, + index_qk: cute.Tensor, + seqlen_q: Int32, + seqlen_k: Int32, + window_size_left: Optional[int] = None, + window_size_right: Optional[int] = None, + index_transform: cutlass.Constexpr = lambda index_q, index_k: ( + index_q, + index_k, + ), + ): + """ + Apply the appropriate mask to the attention scores. + + This method modifies the attention scores (acc_qk) based on the mask type + and the positions in the index tensor. + + :param mask_type: Type of mask to use + :type mask_type: utils.Sm100MaskEnum + :param acc_qk: Accumulated QK attention scores tensor. + :type acc_qk: cute.Tensor + :param index_qk: Index tensor containing position information. + :type index_qk: cute.Tensor + :param seqlen_k: Key sequence length for attention computation. + :type seqlen_k: Int32 + :param seqlen_q: Query sequence length for attention computation. + :type seqlen_q: Optional[int] + :param window_size_left: Left-side sliding window size for attention masking. + :type window_size_left: Optional[int] + :param window_size_right: Right-side sliding window size for attention masking. + :type window_size_right: Optional[int] + """ + + tidx, tidy, tidx = cute.arch.thread_idx() + offset = 0 + # NOTE: causal masking in this repo aligns the *end* of Q with the *end* of K + # when seqlen_k != seqlen_q (same as the test/reference implementation): + # k_index <= q_index + (seqlen_k - seqlen_q) + window_right + # In our kernels, causal is represented by (window_left is None, window_right is not None). + if cutlass.const_expr(window_size_left is None and window_size_right is not None): + offset = seqlen_k - seqlen_q + elif cutlass.const_expr( + mask_type is Sm100MaskEnum.WINDOW_MASK_INFERENCE + or mask_type is Sm100MaskEnum.WINDOW_MASK_BWD_INFERENCE + ): + offset = seqlen_k - seqlen_q + for i in cutlass.range_constexpr(cute.size(acc_qk), unroll_full=True): + index_q, index_k = index_transform(*index_qk[i]) + if cutlass.const_expr(window_size_left is not None or window_size_right is not None): + if cutlass.const_expr(window_size_left is None): + if index_q + offset + window_size_right < index_k: + acc_qk[i] = -Float32.inf + if index_k >= seqlen_k or index_q >= seqlen_q: # residual mask + acc_qk[i] = -Float32.inf + elif cutlass.const_expr(window_size_right is None): + if index_q + offset - window_size_left > index_k: + acc_qk[i] = -Float32.inf + if index_k >= seqlen_k or index_q >= seqlen_q: # residual mask + acc_qk[i] = -Float32.inf + else: + max_K_index = dsl_min(index_q + offset + window_size_right, seqlen_k) + min_K_index = max(0, index_q + offset - window_size_left) + if index_k > max_K_index or index_k < min_K_index: + acc_qk[i] = -Float32.inf + if index_k >= seqlen_k or index_q >= seqlen_q: # residual mask + acc_qk[i] = -Float32.inf + + if cutlass.const_expr( + mask_type == Sm100MaskEnum.RESIDUAL_MASK + or mask_type == Sm100MaskEnum.RESIDUAL_MASK_BWD + ): + if index_k >= seqlen_k or index_q >= seqlen_q: + acc_qk[i] = -Float32.inf + + @cute.jit + def apply_mask_via_causal_local( + acc_qk: cute.Tensor, + index_qk: cute.Tensor, + seqlen_q: Int32, + seqlen_k: Int32, + apply_semantic_window: cutlass.Constexpr[bool] = True, + is_causal: cutlass.Constexpr[bool] = False, + is_local: cutlass.Constexpr[bool] = False, + window_size_left: Optional[int] = None, + window_size_right: Optional[int] = None, + index_transform: cutlass.Constexpr = lambda index_q, index_k: ( + index_q, + index_k, + ), + ): + """Apply forward mask without mask_type. + + - If apply_semantic_window=True, apply causal/local window constraints. + - Always apply residual OOB masking (index_k>=seqlen_k or index_q>=seqlen_q). + """ + tidx, tidy, tidx = cute.arch.thread_idx() + offset = 0 + if cutlass.const_expr(apply_semantic_window): + # Match WINDOW_MASK_INFERENCE semantics: end-align Q/K when lengths differ. + offset = seqlen_k - seqlen_q + for i in cutlass.range_constexpr(cute.size(acc_qk), unroll_full=True): + index_q, index_k = index_transform(*index_qk[i]) + if cutlass.const_expr(apply_semantic_window): + if cutlass.const_expr(is_causal and not is_local): + # Pure causal; tolerate both external forms: + # - (None, None) from interface + # - (None, 0) from fused-mask-style callers + right = 0 if const_expr(window_size_right is None) else window_size_right + if index_q + offset + right < index_k: + acc_qk[i] = -Float32.inf + elif cutlass.const_expr( + is_local or window_size_left is not None or window_size_right is not None + ): + if cutlass.const_expr(window_size_left is None): + if index_q + offset + window_size_right < index_k: + acc_qk[i] = -Float32.inf + elif cutlass.const_expr(window_size_right is None): + if index_q + offset - window_size_left > index_k: + acc_qk[i] = -Float32.inf + else: + max_K_index = dsl_min(index_q + offset + window_size_right, seqlen_k) + min_K_index = max(0, index_q + offset - window_size_left) + if index_k > max_K_index or index_k < min_K_index: + acc_qk[i] = -Float32.inf + # Residual mask is always needed for boundary protection. + if index_k >= seqlen_k or index_q >= seqlen_q: + acc_qk[i] = -Float32.inf diff --git a/flash_attn/cute/sm100_hd256_2cta_fmha_backward.py b/flash_attn/cute/sm100_hd256_2cta_fmha_backward.py new file mode 100644 index 00000000000..c07e3e94176 --- /dev/null +++ b/flash_attn/cute/sm100_hd256_2cta_fmha_backward.py @@ -0,0 +1,291 @@ +# Copyright (c) 2025, Siyu Wang, Shengbin Di, Yuxi Chi, Johnsonms, Linfeng Zheng, Haoyan Huang, Lanbo Li, Yun Zhong, Man Yuan, Minmin Sun, Yong Li, Wei Lin. + + +"""Fused multi-head attention (FMHA) backward for the SM100 architecture using CUTE DSL. + +Constraints: +* Supported head dimensions: 256 only +* mma_tiler_mn must be 64,64 +* Batch size must be the same for Q, K, and V tensors +""" + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass.cute.typing import Int32 + +from flash_attn.cute.sm100_hd256_2cta_fmha_backward_dqkernel import ( + BlackwellFusedMultiHeadAttentionBackwardDQKernel, +) +from flash_attn.cute.sm100_hd256_2cta_fmha_backward_dkdvkernel import ( + BlackwellFusedMultiHeadAttentionBackwardDKDVKernel, +) + + +def _as_bshkrd_tensor( + tensor: cute.Tensor, + h_k: Int32, + h_r: Int32, + varlen: bool, +) -> cute.Tensor: + """Normalize (B,S,H,D)/(S,H,D) tensors to (B,S,H_k,H_r,D) view.""" + if cutlass.const_expr(cute.rank(tensor.layout) == 5): + return tensor + if cutlass.const_expr(cute.rank(tensor.layout) == 4): + return cute.make_tensor( + tensor.iterator, + cute.make_layout( + (tensor.shape[0], tensor.shape[1], h_k, h_r, tensor.shape[3]), + stride=( + tensor.stride[0], + tensor.stride[1], + tensor.stride[2] * h_r, + tensor.stride[2], + tensor.stride[3], + ), + ), + ) + assert cutlass.const_expr(cute.rank(tensor.layout) == 3), "Expected rank-3 varlen tensor" + assert cutlass.const_expr(varlen), "Rank-3 input is only valid for varlen backward" + return cute.make_tensor( + tensor.iterator, + cute.make_layout( + (1, tensor.shape[0], h_k, h_r, tensor.shape[2]), + stride=( + 0, + tensor.stride[0], + tensor.stride[1] * h_r, + tensor.stride[1], + tensor.stride[2], + ), + ), + ) + + +def _as_shhb_tensor( + tensor: cute.Tensor, + h_k: Int32, + h_r: Int32, + b: Int32, + varlen: bool, +) -> cute.Tensor: + """Normalize (B,H,S)/(H,S) tensors to (S, ((H_r, H_k), B)) view.""" + if cutlass.const_expr(cute.rank(tensor.layout) == 3): + return cute.make_tensor( + tensor.iterator, + cute.make_layout( + (tensor.shape[2], ((h_r, h_k), tensor.shape[0])), + stride=( + tensor.stride[2], + ((tensor.stride[1], tensor.stride[1] * h_r), tensor.stride[0]), + ), + ), + ) + assert cutlass.const_expr(cute.rank(tensor.layout) == 2), "Expected rank-2 varlen tensor" + assert cutlass.const_expr(varlen), "Rank-2 input is only valid for varlen backward" + return cute.make_tensor( + tensor.iterator, + cute.make_layout( + (tensor.shape[1], ((h_r, h_k), b)), + stride=( + tensor.stride[1], + ((tensor.stride[0], tensor.stride[0] * h_r), 0), + ), + ), + ) + + +class BlackwellFusedMultiHeadAttentionBackward: + """FMHA backward class for executing CuTeDSL kernel.""" + + def __init__( + self, + head_dim: int, + head_dim_v: int | None = None, + is_causal: bool = False, + is_local: bool = False, + qhead_per_kvhead: cutlass.Constexpr[int] = 1, + is_persistent: bool = False, + deterministic: bool = False, + cluster_size: int = 1, + use_2cta_instrs: bool = False, + score_mod: cutlass.Constexpr | None = None, + score_mod_bwd: cutlass.Constexpr | None = None, + mask_mod: cutlass.Constexpr | None = None, + has_aux_tensors: cutlass.Constexpr = False, + subtile_factor: cutlass.Constexpr[int] = 1, + tile_m_dq: int = 128, + tile_n_dq: int = 128, + tile_m_dkdv: int = 128, + tile_n_dkdv: int = 64, + window_size_left: int | None = None, + window_size_right: int | None = None, + use_clc_scheduler: bool = False, + ): + """Initialization.""" + head_dim_v = head_dim if head_dim_v is None else head_dim_v + assert head_dim == 256 and head_dim_v == 256, ( + "SM100 dedicated backward kernel only supports (head_dim, head_dim_v) = (256, 256)" + ) + assert not is_local, "SM100 backward with head_dim=256 does not support local attention" + assert tile_m_dq == 128 and tile_n_dq == 128, ( + "SM100 dedicated backward kernel only supports tile_m_dq=128 and tile_n_dq=128" + ) + assert tile_m_dkdv == 128 and tile_n_dkdv == 64, ( + "SM100 dedicated backward kernel only supports tile_m_dkdv=128 and tile_n_dkdv=64" + ) + assert score_mod is None and score_mod_bwd is None and mask_mod is None, ( + "SM100 backward with head_dim=256 does not support score_mod/mask_mod" + ) + assert not deterministic, ( + "SM100 backward with head_dim=256 does not support deterministic mode" + ) + assert not has_aux_tensors, "SM100 backward with head_dim=256 does not support aux_tensors" + assert cluster_size in (1, 2), ( + "SM100 backward with head_dim=256 only supports cluster_size in {1, 2}" + ) + assert use_2cta_instrs, "SM100 backward with head_dim=256 requires use_2cta_instrs=True" + # subtile_factor is accepted for interface parity with FlashAttentionBackwardSm100, + # but this dedicated kernel uses fixed internal behavior. + + self.acc_dtype = cutlass.Float32 + self.is_causal = is_causal + self.window_size_left = ( + None if (window_size_left is None or window_size_left < 0) else window_size_left + ) + self.window_size_right = ( + None if (window_size_right is None or window_size_right < 0) else window_size_right + ) + self.tile_m_dq = tile_m_dq + self.tile_n_dq = tile_n_dq + self.tile_m_dkdv = tile_m_dkdv + self.tile_n_dkdv = tile_n_dkdv + self.use_clc_scheduler = use_clc_scheduler + + self.dq_kernel = BlackwellFusedMultiHeadAttentionBackwardDQKernel( + self.acc_dtype, + (self.tile_m_dq, self.tile_n_dq, 256), + self.is_causal, + self.window_size_left, + self.window_size_right, + False, # is_persistent + False, # split_head + use_clc_scheduler=self.use_clc_scheduler, + ) + self.dkdv_kernel = BlackwellFusedMultiHeadAttentionBackwardDKDVKernel( + self.acc_dtype, + (self.tile_m_dkdv, self.tile_n_dkdv, 256), + self.is_causal, + self.window_size_left, + self.window_size_right, + use_clc_scheduler=self.use_clc_scheduler, + ) + + @cute.jit + def __call__( + self, + Q: cute.Tensor, + K: cute.Tensor, + V: cute.Tensor, + dO: cute.Tensor, + lse_log2: cute.Tensor, + dpsum: cute.Tensor, + dQ_accum: cute.Tensor | None, + dK: cute.Tensor, + dV: cute.Tensor, + scale_softmax: cutlass.Float32, + cumulative_s_q: cute.Tensor | None, + cumulative_s_k: cute.Tensor | None, + seqused_q: cute.Tensor | None = None, + seqused_k: cute.Tensor | None = None, + window_size_left: Int32 | None = None, + window_size_right: Int32 | None = None, + dQ_semaphore: cute.Tensor | None = None, + dK_semaphore: cute.Tensor | None = None, + dV_semaphore: cute.Tensor | None = None, + aux_tensors: tuple[cute.Tensor] | None = None, + block_sparse_tensors: cute.Tensor | None = None, + stream: cuda.CUstream = None, + ): + """Host function to launch CuTeDSL kernel.""" + assert seqused_q is None and seqused_k is None, ( + "SM100 backward with head_dim=256 does not support seqused_q/seqused_k" + ) + assert window_size_left is None and window_size_right is None, ( + "SM100 backward with head_dim=256 uses constructor-provided window sizes" + ) + assert dQ_semaphore is None and dK_semaphore is None and dV_semaphore is None, ( + "SM100 backward with head_dim=256 does not use semaphores" + ) + assert block_sparse_tensors is None, ( + "SM100 backward with head_dim=256 does not support block sparse tensors" + ) + assert aux_tensors is None or len(aux_tensors) == 0, ( + "SM100 backward with head_dim=256 does not support aux_tensors" + ) + assert dQ_accum is not None, ( + "SM100 backward with head_dim=256 expects dQ tensor at dQ_accum slot" + ) + dQ = dQ_accum + varlen = cumulative_s_q is not None or cumulative_s_k is not None + q_rank = cute.rank(Q.layout) + k_rank = cute.rank(K.layout) + if cutlass.const_expr(q_rank == 5): + h_q = Q.shape[2] * Q.shape[3] + elif cutlass.const_expr(q_rank == 4): + h_q = Q.shape[2] + else: + h_q = Q.shape[1] + if cutlass.const_expr(k_rank == 5): + h_k = K.shape[2] + elif cutlass.const_expr(k_rank == 4): + h_k = K.shape[2] + else: + h_k = K.shape[1] + h_r = h_q // h_k + if cutlass.const_expr(cumulative_s_q is not None): + b = cumulative_s_q.shape[0] - 1 + elif cutlass.const_expr(cumulative_s_k is not None): + b = cumulative_s_k.shape[0] - 1 + else: + b = Q.shape[0] + + Q = _as_bshkrd_tensor(Q, h_k, h_r, varlen) + K = _as_bshkrd_tensor(K, h_k, 1, varlen) + V = _as_bshkrd_tensor(V, h_k, 1, varlen) + dQ = _as_bshkrd_tensor(dQ, h_k, h_r, varlen) + dK = _as_bshkrd_tensor(dK, h_k, 1, varlen) + dV = _as_bshkrd_tensor(dV, h_k, 1, varlen) + dO = _as_bshkrd_tensor(dO, h_k, h_r, varlen) + scaled_LSE = _as_shhb_tensor(lse_log2, h_k, h_r, b, varlen) + sum_OdO = _as_shhb_tensor(dpsum, h_k, h_r, b, varlen) + + # Keep original order: dQ first, then dKdV. + self.dq_kernel( + Q, + K, + V, + dQ, + dO, + scaled_LSE, + sum_OdO, + cumulative_s_q, + cumulative_s_k, + scale_softmax, + stream, + ) + self.dkdv_kernel( + Q, + K, + V, + dK, + dV, + dO, + scaled_LSE, + sum_OdO, + cumulative_s_q, + cumulative_s_k, + scale_softmax, + stream, + ) diff --git a/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dkdvkernel.py b/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dkdvkernel.py new file mode 100644 index 00000000000..e413d6f638a --- /dev/null +++ b/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dkdvkernel.py @@ -0,0 +1,3155 @@ +# Copyright (c) 2025, Siyu Wang, Shengbin Di, Yuxi Chi, Johnsonms, Linfeng Zheng, Haoyan Huang, Lanbo Li, Yun Zhong, Man Yuan, Minmin Sun, Yong Li, Wei Lin. + +"""Fused multi-head attention (FMHA) backward for the SM100 architecture using CUTE DSL. + +Constraints: +* Supported head dimensions: 256 only +* cta_tiler_mn must be 64,128 +* Batch size must be the same for Q, K, and V tensors +""" + +import enum +import math + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.utils as utils +import cutlass.pipeline as pipeline +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cute.typing import Int32 +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait + +from cutlass.utils import ClcDynamicPersistentTileScheduler +from flash_attn.cute.tile_scheduler import ( + ClcState, + SM100_TMEM_CAPACITY_COLUMNS, + make_sm100_thread_cooperative_group as make_thread_cooperative_group, + Sm100FmhaClcDynamicTileSchedulerParams as FmhaClcDynamicTileSchedulerParams, + Sm100FmhaClcDynamicTileScheduler as FmhaClcDynamicTileScheduler, + Sm100FmhaStaticTileSchedulerParams as FmhaStaticTileSchedulerParams, +) + + +LAYOUT_RANK_CONSTANT = 3 + + +@cute.jit +def split_wg( + t: cute.Tensor, + num_warp_groups: Int32, + wg_idx: Int32, +) -> cute.Tensor: + """Split warp group.""" + # dishengbin, TODO:need to double check if more efficient to split in other dimensions + ret = None + if cutlass.const_expr(cute.rank(t.layout) == LAYOUT_RANK_CONSTANT): + p = cute.composition( + t, + cute.make_layout( + ( + t.shape[0], + t.shape[1], + (num_warp_groups, cute.size(t, mode=[2]) // num_warp_groups), + ) + ), + ) + ret = p[None, None, (wg_idx, None)] + else: + p = cute.composition( + t, + cute.make_layout( + ( + t.shape[0], + t.shape[1], + t.shape[2], + (num_warp_groups, cute.size(t, mode=[3]) // num_warp_groups), + ) + ), + ) + ret = p[None, None, None, (wg_idx, None)] + return ret + + +class MaskType(enum.Enum): + """Mask type used in FMHA backward.""" + + NO_MASK = enum.auto() + RESIDUAL_MASK_FOR_BACKWARD = enum.auto() + CAUSAL_MASK_FOR_BACKWARD = enum.auto() + + +def Tmemory_offset(lane, col): + """Tensor memory offset.""" + return (lane << 16) + col + + +permute_order = (0, 1, 2, 3, 4) + + +class BlackwellFusedMultiHeadAttentionBackwardDKDVKernel: + """FMHA backward class for executing CuTeDSL kernel.""" + + def __init__( + self, + acc_dtype: type[cutlass.Numeric], + cta_tiler: tuple[int, int, int], + is_causal: bool, + window_size_left: int | None, + window_size_right: int | None, + use_clc_scheduler: bool = False, + ): + """Initialization.""" + self.acc_dtype = acc_dtype + self.cta_tiler = cta_tiler + self.use_clc_scheduler = use_clc_scheduler + self.sched_warp_id = 10 if use_clc_scheduler else None + # TODO: need check, not sure whether need to *2 if 2cta + self.tile_shape_Q = cta_tiler[0] + self.tile_shape_K = cta_tiler[1] + self.tile_shape_dQ_K = cta_tiler[2] + self.tile_shape_dV_dO = cta_tiler[2] + # For S + self.KQ_mma_tiler = ( + cta_tiler[1] * 2, + cta_tiler[0], + cta_tiler[2], + ) + # For dP + self.VdO_mma_tiler = ( + cta_tiler[1] * 2, + cta_tiler[0], + cta_tiler[2], + ) + # For dV + self.PdO_mma_tiler = ( + cta_tiler[1] * 2, + cta_tiler[2], + cta_tiler[0], + ) + # For dK + self.dSQ_mma_tiler = ( + cta_tiler[1] * 2, + cta_tiler[2], + cta_tiler[0], + ) + # For dQ, dishengbin, need to remove + self.dSK_mma_tiler = ( + cta_tiler[0] * 2, + cta_tiler[2], + cta_tiler[1], + ) + self.cluster_shape_mn = (2, 1) + self.is_causal = is_causal + self.window_size_left: int = -1 if window_size_left is None else window_size_left + self.window_size_right: int = -1 if window_size_right is None else window_size_right + self.has_sliding_window = False + if self.window_size_left > 0 or self.window_size_right > 0: + self.has_sliding_window = True + if self.is_causal: + self.window_size_right = 0 + + self.compute_warp_id = (0, 1, 2, 3, 4, 5, 6, 7) + self.mma_warp_id = 8 + self.load_warp_id = 9 + self.empty_warp_id = 10 + + self.num_compute_warps = 8 + + self.tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS + + self.threads_per_warp = 32 + self.threads_per_cta = self.threads_per_warp * (self.num_compute_warps + 4) + + self.cta_sync_bar_id = 0 + self.tmem_alloc_sync_bar_id = 1 + self.compute_sync_bar_id = 2 + self.epilogue_sync_bar_id = 3 + self.reduce_sync_bar_id = 4 + + self.tmem_dK_offset = 0 + self.tmem_dV_offset = Tmemory_offset(0, cta_tiler[2] // 2) + self.tmem_dP_offset = Tmemory_offset(0, cta_tiler[2] + cta_tiler[0] // 2) + self.tmem_S_offset = Tmemory_offset(0, cta_tiler[2]) + + self.num_regs_reduce = 152 + self.num_regs_compute = 128 + self.num_regs_mma = 128 + self.num_regs_empty = 96 + self.num_regs_load = 96 + + self.buffer_align_bytes = 128 + + def _setup_attributes(self): + """Settings for pipeline stage.""" + self.load_mma_Q_stage = 1 + self.load_mma_K_stage = 1 + self.load_mma_V_stage = 1 + self.load_mma_QT_stage = 1 + self.load_mma_dO_stage = 1 + self.load_compute_LSE_stage = 1 + self.load_compute_sum_OdO_stage = 1 + self.mma_compute_S_stage = 1 + self.mma_compute_dP_stage = 1 + self.compute_mma_P_stage = 1 + self.compute_mma_dS_stage = 1 + self.mma_compute_dKdV_stage = 2 + + if cutlass.const_expr(self.use_clc_scheduler): + self.num_clc_stage = 1 + self.num_clc_response_bytes = 16 + + @cute.jit + def __call__( + self, + Q: cute.Tensor, + K: cute.Tensor, + V: cute.Tensor, + dK: cute.Tensor, + dV: cute.Tensor, + dO: cute.Tensor, + scaled_LSE: cute.Tensor, + sum_OdO: cute.Tensor, + cumulative_s_q: cute.Tensor | None, + cumulative_s_k: cute.Tensor | None, + scale_softmax: cutlass.Float32, + stream: cuda.CUstream, + ): + """Host function to launch CuTeDSL kernel.""" + varlen = cumulative_s_q is not None or cumulative_s_k is not None + # Infer shape metadata from normalized 5D tensors (B, S, H_k, H_r, D). + h_r = Q.shape[3] + h_k = Q.shape[2] + if cutlass.const_expr(cumulative_s_q is not None): + b = cumulative_s_q.shape[0] - 1 + elif cutlass.const_expr(cumulative_s_k is not None): + b = cumulative_s_k.shape[0] - 1 + else: + b = Q.shape[0] + problem_shape = ( + Q.shape[1], + K.shape[1], + Q.shape[4], + ((h_r, h_k), b), + ) + hb = ((h_r, h_k), b) + # (b, s, h_k, h_r, d) -> (s, d, ((h_r, h_k), b)) + Q = cute.make_tensor( + Q.iterator, + cute.make_layout( + (Q.shape[1], Q.shape[4], hb), + stride=( + cute.assume(Q.stride[1], divby=64), + Q.stride[4], + ( + (Q.shape[4], Q.shape[4] * Q.shape[3]), + ( + 0 + if varlen + else cute.assume(Q.shape[1] * Q.shape[4] * h_r * h_k, divby=64) + ), + ), + ), + ), + ) + # (b, s, h_k, 1, d) -> (s, d, ((1, h_k), b)) + K = cute.make_tensor( + K.iterator, + cute.make_layout( + (K.shape[1], K.shape[4], hb), + stride=( + cute.assume(K.stride[1], divby=64), + K.stride[4], + ( + (0, K.shape[4]), + (0 if varlen else cute.assume(K.shape[1] * K.shape[4] * 1 * h_k, divby=64)), + ), + ), + ), + ) + # (b, s, h_k, 1, d) -> (s, d, ((1, h_k), b)) + V = cute.make_tensor( + V.iterator, + cute.make_layout( + (V.shape[1], V.shape[4], hb), + stride=( + cute.assume(V.stride[1], divby=64), + V.stride[4], + ( + (0, V.shape[4]), + (0 if varlen else cute.assume(V.shape[1] * V.shape[4] * 1 * h_k, divby=64)), + ), + ), + ), + ) + # (s, d, ((h_r, h_k), b)) -> (d, s, ((h_r, h_k), b)) + QT = cute.make_tensor( + Q.iterator, + cute.make_layout( + (Q.shape[1], Q.shape[0], Q.shape[2]), + stride=( + Q.stride[1], + Q.stride[0], + Q.stride[2], + ), + ), + ) + dK = cute.make_tensor(dK.iterator, K.layout) + dV = cute.make_tensor(dV.iterator, V.layout) + # (s, d, ((h_r, h_k), b)) + dO = cute.make_tensor(dO.iterator, Q.layout) + + # (s, d, ((h_r, h_k), b)) -> (d, s, ((h_r, h_k), b)) + dOT = cute.make_tensor( + dO.iterator, + cute.make_layout( + (dO.shape[1], dO.shape[0], dO.shape[2]), + stride=( + dO.stride[1], + dO.stride[0], + dO.stride[2], + ), + ), + ) + + self.Q_major_mode = utils.LayoutEnum.from_tensor(Q).mma_major_mode() + self.K_major_mode = utils.LayoutEnum.from_tensor(K).mma_major_mode() + self.dK_major_mode = utils.LayoutEnum.from_tensor(dK).mma_major_mode() + self.V_major_mode = utils.LayoutEnum.from_tensor(V).mma_major_mode() + self.dV_major_mode = utils.LayoutEnum.from_tensor(dV).mma_major_mode() + self.dO_major_mode = utils.LayoutEnum.from_tensor(dO).mma_major_mode() + + if cutlass.const_expr(self.Q_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError(f"The layout of q is not supported: {self.Q_major_mode}") + if cutlass.const_expr(self.K_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of k is not supported") + if cutlass.const_expr(self.dK_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of dk is not supported") + if cutlass.const_expr(self.V_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of v is not supported") + if cutlass.const_expr(self.dV_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of dv is not supported") + + self._setup_attributes() + + cta_group = tcgen05.CtaGroup.TWO + PT_source = tcgen05.OperandSource.SMEM + + # compute S + KQ_tiled_mma = sm100_utils.make_trivial_tiled_mma( + K.element_type, + tcgen05.OperandMajorMode.K, + tcgen05.OperandMajorMode.K, + self.acc_dtype, + cta_group, + self.KQ_mma_tiler[:2], + ) + # compute dP + VdO_tiled_mma = sm100_utils.make_trivial_tiled_mma( + V.element_type, + tcgen05.OperandMajorMode.K, + tcgen05.OperandMajorMode.K, + self.acc_dtype, + cta_group, + self.VdO_mma_tiler[:2], + ) + # compute dV + PdO_tiled_mma = sm100_utils.make_trivial_tiled_mma( + dO.element_type, + tcgen05.OperandMajorMode.K, + tcgen05.OperandMajorMode.MN, + self.acc_dtype, + cta_group, + self.PdO_mma_tiler[:2], + PT_source, + ) + # compute dK + dSQ_tiled_mma = sm100_utils.make_trivial_tiled_mma( + Q.element_type, + tcgen05.OperandMajorMode.K, + tcgen05.OperandMajorMode.MN, + self.acc_dtype, + cta_group, + self.dSQ_mma_tiler[:2], + ) + # compute dQ + # dishengbin, need to remove, but used in dS_mem_layout_staged + dSK_tiled_mma = sm100_utils.make_trivial_tiled_mma( + K.element_type, + tcgen05.OperandMajorMode.MN, + tcgen05.OperandMajorMode.MN, + self.acc_dtype, + cta_group, + self.dSK_mma_tiler[:2], + ) + + atom_thr_size = cute.size(KQ_tiled_mma.thr_id.shape) + self.cluster_shape_mnk = (*self.cluster_shape_mn, 1) # type: ignore[assignment] + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), + (atom_thr_size,), + ) + + K_smem_layout_staged = sm100_utils.make_smem_layout_a( + KQ_tiled_mma, + self.KQ_mma_tiler, + K.element_type, + 1, + ) + Q_smem_layout_staged = sm100_utils.make_smem_layout_b( + KQ_tiled_mma, + self.KQ_mma_tiler, + Q.element_type, + self.load_mma_Q_stage, + ) + V_smem_layout_staged = sm100_utils.make_smem_layout_a( + VdO_tiled_mma, + self.VdO_mma_tiler, + V.element_type, + 1, + ) + dO_smem_layout_staged = sm100_utils.make_smem_layout_b( + VdO_tiled_mma, + self.VdO_mma_tiler, + dO.element_type, + self.load_mma_dO_stage, + ) + # dishengbin, need to remove, but used for sPT, need to double check + dS_smem_layout_staged = sm100_utils.make_smem_layout_a( + dSK_tiled_mma, + self.dSK_mma_tiler, + Q.element_type, + self.compute_mma_dS_stage, + ) + dST_smem_layout_staged = sm100_utils.make_smem_layout_a( + dSQ_tiled_mma, + self.dSQ_mma_tiler, + Q.element_type, + self.compute_mma_dS_stage, + ) + tiled_mma = dSQ_tiled_mma + is_k_major = tiled_mma.op.a_major_mode == tcgen05.OperandMajorMode.K + a_major_mode = tcgen05.OperandMajorMode.K if is_k_major else tcgen05.OperandMajorMode.MN + tmp = cute.dice(self.dSQ_mma_tiler, (1, None, 1)) + a_smem_shape = tiled_mma.partition_shape_A( + cute.dice(self.dSQ_mma_tiler, (1, None, 1)), + ) + a_smem_shape_mn_k = ( + cute.size(a_smem_shape[0][0]) * a_smem_shape[1], + cute.size(a_smem_shape[0][1]) * a_smem_shape[2], + ) + smem_layout_atom_kind = sm100_utils.get_smem_layout_atom_ab( + a_major_mode, + K.element_type, + a_smem_shape_mn_k, + ) + a_smem_layout_atom = sm100_utils.make_smem_layout_atom( + smem_layout_atom_kind, + K.element_type, + ) + + a_smem_shape = cute.append( + a_smem_shape, + self.compute_mma_dS_stage, + ) + order = (2, 1, 3) if not is_k_major else (1, 2, 3) + dST_smem_layout_staged_tmp = sm100_utils.tile_to_mma_shape( + a_smem_layout_atom, + a_smem_shape, + order=order, + ) + QT_smem_layout_staged = sm100_utils.make_smem_layout_b( + dSQ_tiled_mma, + self.dSQ_mma_tiler, + Q.element_type, + self.load_mma_QT_stage, + ) + P_smem_layout_staged = sm100_utils.make_smem_layout_a( + PdO_tiled_mma, + self.PdO_mma_tiler, + Q.element_type, + self.compute_mma_P_stage, + ) + dOT_smem_layout_staged = sm100_utils.make_smem_layout_b( + PdO_tiled_mma, + self.PdO_mma_tiler, + dO.element_type, + self.load_mma_dO_stage, + ) + LSE_smem_layout = cute.make_layout((self.cta_tiler[0], self.load_compute_LSE_stage)) + sum_OdO_smem_layout = cute.make_layout((self.cta_tiler[0], self.load_compute_sum_OdO_stage)) + + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) + tma_reduce_op = cpasync.CopyReduceBulkTensorTileS2GOp() + + K_smem_layout = cute.select(K_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_K, tma_tensor_K = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + K, + K_smem_layout, + self.KQ_mma_tiler, + KQ_tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + V_smem_layout = cute.select(V_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_V, tma_tensor_V = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + V, + V_smem_layout, + self.VdO_mma_tiler, + VdO_tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + Q_smem_layout = cute.select(Q_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_Q, tma_tensor_Q = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + Q, + Q_smem_layout, + self.KQ_mma_tiler, + KQ_tiled_mma, + self.cluster_layout_vmnk.shape, + ) + QT_smem_layout = cute.select(QT_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_QT, tma_tensor_QT = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + QT, + QT_smem_layout, + self.dSQ_mma_tiler, + dSQ_tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + dO_smem_layout = cute.select(dO_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_dO, tma_tensor_dO = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + dO, + dO_smem_layout, + self.VdO_mma_tiler, + VdO_tiled_mma, + self.cluster_layout_vmnk.shape, + ) + dOT_smem_layout = cute.select(dOT_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_dOT, tma_tensor_dOT = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + dOT, + dOT_smem_layout, + self.PdO_mma_tiler, + PdO_tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # for 2cta, tma_copy_QT_bytes is same as the tma_copy_Q_bytes + self.tma_copy_Q_bytes = cute.size_in_bytes(Q.element_type, Q_smem_layout) * atom_thr_size + self.tma_copy_K_bytes = cute.size_in_bytes(K.element_type, K_smem_layout) * atom_thr_size + self.tma_copy_V_bytes = cute.size_in_bytes(V.element_type, V_smem_layout) * atom_thr_size + self.tma_copy_dO_bytes = cute.size_in_bytes(dO.element_type, dO_smem_layout) * atom_thr_size + + @cute.struct + class SharedStorage: + # Pipeline barriers + load_mma_Q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.load_mma_Q_stage * 2] + load_mma_K_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.load_mma_K_stage * 2] + load_mma_V_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.load_mma_V_stage * 2] + load_mma_QT_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.load_mma_QT_stage * 2] + load_mma_dO_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.load_mma_dO_stage * 2] + load_mma_dOT_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.load_mma_dO_stage * 2] + load_compute_lse_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.load_compute_LSE_stage * 2 + ] + load_compute_sum_OdO_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.load_compute_sum_OdO_stage * 2 + ] + mma_compute_S_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.mma_compute_S_stage * 2 + ] + mma_compute_dP_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.mma_compute_dP_stage * 2 + ] + compute_mma_P_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.compute_mma_P_stage * 2 + ] + compute_mma_dS_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.compute_mma_dS_stage * 2 + ] + mma_compute_dKdV_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.mma_compute_dKdV_stage * 2 + ] + tmem_holding_buf: cutlass.Int32 + tmem_dealloc_mbar_ptr: cutlass.Int64 + clc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2] + clc_response: cute.struct.MemRange[Int32, 4] + # Smem tensors + sK: cute.struct.Align[ + cute.struct.MemRange[K.element_type, cute.cosize(K_smem_layout_staged)], + self.buffer_align_bytes, + ] + # only used in 2cta + sV: cute.struct.Align[ + cute.struct.MemRange[V.element_type, cute.cosize(V_smem_layout_staged)], + self.buffer_align_bytes, + ] + sQ: cute.struct.Align[ + cute.struct.MemRange[Q.element_type, cute.cosize(Q_smem_layout_staged)], + self.buffer_align_bytes, + ] + sQT: cute.struct.Align[ + cute.struct.MemRange[Q.element_type, cute.cosize(QT_smem_layout_staged)], + self.buffer_align_bytes, + ] + sdO: cute.struct.Align[ + cute.struct.MemRange[dO.element_type, cute.cosize(dO_smem_layout_staged)], + self.buffer_align_bytes, + ] + sdOT: cute.struct.Align[ + cute.struct.MemRange[dO.element_type, cute.cosize(dOT_smem_layout_staged)], + self.buffer_align_bytes, + ] + # only used in 2cta + # dishengbin checked whether we need sP + sP: cute.struct.Align[ + cute.struct.MemRange[Q.element_type, cute.cosize(P_smem_layout_staged)], + self.buffer_align_bytes, + ] + sdST: cute.struct.Align[ + cute.struct.MemRange[Q.element_type, cute.cosize(dST_smem_layout_staged)], + self.buffer_align_bytes, + ] + + sLSE: cute.struct.Align[ + cute.struct.MemRange[self.acc_dtype, cute.cosize(LSE_smem_layout)], + self.buffer_align_bytes, + ] + sSum_OdO: cute.struct.Align[ + cute.struct.MemRange[self.acc_dtype, cute.cosize(sum_OdO_smem_layout)], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # =============================== bwd =============================== + K_val = problem_shape[1] + _, H_K = problem_shape[3][0] + B = problem_shape[3][1] + problem_shape_mbh = ( + cute.ceil_div(K_val, self.cta_tiler[1]), + cute.size(B), + cute.size(H_K), + ) + if cutlass.const_expr(self.use_clc_scheduler): + self.tile_sched_params = FmhaClcDynamicTileSchedulerParams( + problem_shape_mbh, + (*self.cluster_shape_mn, 1), + ) + bwd_grid = FmhaClcDynamicTileScheduler.get_grid_shape(self.tile_sched_params) + else: + self.tile_sched_params = FmhaStaticTileSchedulerParams( + is_persistent=False, + problem_shape_mbh=problem_shape_mbh, + ) + bwd_grid = self._compute_bwd_grid(problem_shape, self.cta_tiler[1]) + bwd_grid = cute.round_up(bwd_grid, self.cluster_shape_mnk) + + self.dkdv_bwd( + KQ_tiled_mma, + VdO_tiled_mma, + PdO_tiled_mma, + dSQ_tiled_mma, + tma_atom_K, + tma_tensor_K, + K, + tma_atom_V, + tma_tensor_V, + tma_atom_Q, + tma_tensor_Q, + Q, + tma_atom_QT, + tma_tensor_QT, + tma_atom_dO, + tma_tensor_dO, + tma_atom_dOT, + tma_tensor_dOT, + dK, + dV, + scaled_LSE, + scale_softmax, + sum_OdO, + problem_shape, + cumulative_s_q, + cumulative_s_k, + self.cluster_layout_vmnk, + K_smem_layout_staged, + Q_smem_layout_staged, + V_smem_layout_staged, + dO_smem_layout_staged, + dS_smem_layout_staged, + dST_smem_layout_staged, + QT_smem_layout_staged, + dOT_smem_layout_staged, + P_smem_layout_staged, + LSE_smem_layout, + sum_OdO_smem_layout, + self.tile_sched_params, + ).launch( + grid=bwd_grid, + block=[self.threads_per_cta, 1, 1], + cluster=self.cluster_shape_mnk, + smem=self.shared_storage.size_in_bytes(), # type: ignore [attr-defined] + stream=stream, + min_blocks_per_mp=1, + ) + + @cute.kernel + def dkdv_bwd( + self, + KQ_tiled_mma: cute.TiledMma, + VdO_tiled_mma: cute.TiledMma, + PdO_tiled_mma: cute.TiledMma, + dSQ_tiled_mma: cute.TiledMma, + tma_atom_K: cute.CopyAtom, + K_in: cute.Tensor, + K_ref: cute.Tensor, + tma_atom_V: cute.CopyAtom, + V_in: cute.Tensor, + tma_atom_Q: cute.CopyAtom, + Q_in: cute.Tensor, + Q_ref: cute.Tensor, + tma_atom_QT: cute.CopyAtom, + QT_in: cute.Tensor, + tma_atom_dO: cute.CopyAtom, + dO_in: cute.Tensor, + tma_atom_dOT: cute.CopyAtom, + dOT_in: cute.Tensor, + dK: cute.Tensor, + dV: cute.Tensor, + LSE: cute.Tensor, + scale_softmax: cutlass.Float32, + sum_OdO: cute.Tensor, + problem_shape: tuple[Int32, Int32, Int32, tuple[tuple[Int32, Int32], Int32]], + cumulative_s_q: cute.Tensor | None, + cumulative_s_k: cute.Tensor | None, + cluster_layout_vmnk: cute.Layout, + K_smem_layout_staged: cute.ComposedLayout, + Q_smem_layout_staged: cute.ComposedLayout, + V_smem_layout_staged: cute.ComposedLayout, + dO_smem_layout_staged: cute.ComposedLayout, + dS_smem_layout_staged: cute.ComposedLayout, + dST_smem_layout_staged: cute.ComposedLayout, + QT_smem_layout_staged: cute.ComposedLayout, + dOT_smem_layout_staged: cute.ComposedLayout, + P_smem_layout_staged: cute.ComposedLayout, + LSE_smem_layout: cute.Layout, + sum_OdO_smem_layout: cute.Layout, + tile_sched_params: FmhaStaticTileSchedulerParams | FmhaClcDynamicTileSchedulerParams, + ): + """Core CuTeDSL backward kernel.""" + bidx, bidy, bidz = cute.arch.block_idx() + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + varlen = cumulative_s_q is not None or cumulative_s_k is not None + + mma_tile_coord_v = bidx % cute.size(KQ_tiled_mma.thr_id.shape) + + if warp_idx == self.load_warp_id: + cpasync.prefetch_descriptor(tma_atom_K) + cpasync.prefetch_descriptor(tma_atom_Q) + cpasync.prefetch_descriptor(tma_atom_QT) + cpasync.prefetch_descriptor(tma_atom_V) + cpasync.prefetch_descriptor(tma_atom_dO) + cpasync.prefetch_descriptor(tma_atom_dOT) + + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + load_mma_Q_producer, load_mma_Q_consumer = pipeline.PipelineTmaUmma.create( + num_stages=self.load_mma_Q_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + tx_count=self.tma_copy_Q_bytes, + barrier_storage=storage.load_mma_Q_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + load_mma_K_producer, load_mma_K_consumer = pipeline.PipelineTmaUmma.create( + num_stages=self.load_mma_K_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + tx_count=self.tma_copy_K_bytes, + barrier_storage=storage.load_mma_K_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + load_mma_V_producer, load_mma_V_consumer = pipeline.PipelineTmaUmma.create( + num_stages=self.load_mma_V_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + tx_count=self.tma_copy_V_bytes, + barrier_storage=storage.load_mma_V_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + load_mma_QT_producer, load_mma_QT_consumer = pipeline.PipelineTmaUmma.create( + num_stages=self.load_mma_QT_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + tx_count=self.tma_copy_Q_bytes, + barrier_storage=storage.load_mma_QT_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + load_mma_dO_producer, load_mma_dO_consumer = pipeline.PipelineTmaUmma.create( + num_stages=self.load_mma_dO_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + tx_count=self.tma_copy_dO_bytes, + barrier_storage=storage.load_mma_dO_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + load_mma_dOT_producer, load_mma_dOT_consumer = pipeline.PipelineTmaUmma.create( + num_stages=self.load_mma_dO_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + tx_count=self.tma_copy_dO_bytes, + barrier_storage=storage.load_mma_dOT_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + load_compute_LSE_producer, load_compute_LSE_consumer = pipeline.PipelineCpAsync.create( + num_stages=self.load_compute_LSE_stage, + producer_group=make_thread_cooperative_group(self.threads_per_warp), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * self.num_compute_warps + ), + barrier_storage=storage.load_compute_lse_mbar_ptr.data_ptr(), + ).make_participants() + load_compute_sum_OdO_producer, load_compute_sum_OdO_consumer = ( + pipeline.PipelineCpAsync.create( + num_stages=self.load_compute_sum_OdO_stage, + producer_group=make_thread_cooperative_group(self.threads_per_warp), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * self.num_compute_warps + ), + barrier_storage=storage.load_compute_sum_OdO_mbar_ptr.data_ptr(), + ).make_participants() + ) + mma_compute_S_producer, mma_compute_S_consumer = pipeline.PipelineUmmaAsync.create( + num_stages=self.mma_compute_S_stage, + producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + consumer_group=make_thread_cooperative_group( + self.num_compute_warps * self.threads_per_warp * cluster_layout_vmnk.shape[0][0] + ), + barrier_storage=storage.mma_compute_S_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + mma_compute_dP_producer, mma_compute_dP_consumer = pipeline.PipelineUmmaAsync.create( + num_stages=self.mma_compute_dP_stage, + producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + consumer_group=make_thread_cooperative_group( + self.num_compute_warps * self.threads_per_warp * cluster_layout_vmnk.shape[0][0] + ), + barrier_storage=storage.mma_compute_dP_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + compute_mma_P_producer, compute_mma_P_consumer = pipeline.PipelineAsyncUmma.create( + num_stages=self.compute_mma_P_stage, + producer_group=make_thread_cooperative_group( + self.num_compute_warps * self.threads_per_warp * cluster_layout_vmnk.shape[0][0] + ), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + barrier_storage=storage.compute_mma_P_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + compute_mma_dS_producer, compute_mma_dS_consumer = pipeline.PipelineAsyncUmma.create( + num_stages=self.compute_mma_dS_stage, + producer_group=make_thread_cooperative_group( + self.num_compute_warps * self.threads_per_warp * cluster_layout_vmnk.shape[0][0] + ), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + barrier_storage=storage.compute_mma_dS_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + mma_compute_dKdV_producer, mma_compute_dKdV_consumer = pipeline.PipelineUmmaAsync.create( + num_stages=self.mma_compute_dKdV_stage, + producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + consumer_group=make_thread_cooperative_group( + self.num_compute_warps * self.threads_per_warp * cluster_layout_vmnk.shape[0][0] + ), + barrier_storage=storage.mma_compute_dKdV_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + + cute.arch.barrier(barrier_id=self.cta_sync_bar_id, number_of_threads=self.threads_per_cta) + + # setup mma + sQ = storage.sQ.get_tensor(Q_smem_layout_staged.outer, swizzle=Q_smem_layout_staged.inner) + sK = storage.sK.get_tensor(K_smem_layout_staged.outer, swizzle=K_smem_layout_staged.inner) + sV = storage.sV.get_tensor(V_smem_layout_staged.outer, swizzle=V_smem_layout_staged.inner) + sdO = storage.sdO.get_tensor( + dO_smem_layout_staged.outer, swizzle=dO_smem_layout_staged.inner + ) + sLSE = storage.sLSE.get_tensor(LSE_smem_layout) + sSum_OdO = storage.sSum_OdO.get_tensor(sum_OdO_smem_layout) + tmem_holding_buf = storage.tmem_holding_buf + # for 2cta, QT use different mem from Q + + sQT = storage.sQT.get_tensor( + QT_smem_layout_staged.outer, swizzle=QT_smem_layout_staged.inner + ) + sdST = storage.sdST.get_tensor( + dST_smem_layout_staged.outer, swizzle=dST_smem_layout_staged.inner + ) + tP_fake_ptr = cute.make_ptr(sQ.element_type, 0, cute.AddressSpace.tmem) + tP = cute.make_tensor(tP_fake_ptr, P_smem_layout_staged.outer) + + sP = storage.sP.get_tensor(P_smem_layout_staged.outer, swizzle=P_smem_layout_staged.inner) + + sdOT = storage.sdOT.get_tensor( + dOT_smem_layout_staged.outer, swizzle=dOT_smem_layout_staged.inner + ) + + # tSTrK shape : (MMA, MMA_M, MMA_K, STAGE) + tSTrK = KQ_tiled_mma.make_fragment_A(sK) + # tSTrQ shape : (MMA, MMA_N, MMA_K, STAGE) + tSTrQ = KQ_tiled_mma.make_fragment_B(sQ) + + # tdPTrV shape : (MMA, MMA_M, MMA_K, STAGE) + tdPTrV = VdO_tiled_mma.make_fragment_A(sV) + # tdPTrdO shape : (MMA, MMA_N, MMA_K, STAGE) + tdPTrdO = VdO_tiled_mma.make_fragment_B(sdO) + + # tdKrdST shape: (MMA, MMA_M, MMA_K, STAGE) + tdKrdST = dSQ_tiled_mma.make_fragment_A(sdST) + # tdKrQT shape : (MMA, MMA_N, MMA_K, STAGE) + tdKrQT = dSQ_tiled_mma.make_fragment_B(sQT) + + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=self.tmem_alloc_sync_bar_id, + num_threads=self.threads_per_cta, + ) + + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=tmem_alloc_barrier, + allocator_warp_id=self.load_warp_id, + is_two_cta=True, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + + tmem.allocate(self.tmem_alloc_cols) + + # wait for tmem allocation and retrieve the pointer + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + + # Cluster arrive after barrier init + # is_relaxed=False has memory consistency guarantee + pipeline_init_arrive(cluster_shape_mn=cluster_layout_vmnk, is_relaxed=False) + + if cutlass.const_expr(self.use_clc_scheduler): + clc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + cluster_size = cute.size(self.cluster_shape_mnk) + num_clc_consumer_threads = self.threads_per_warp * ( + 1 # sched_warp (CTA 0 only) + + cluster_size + * ( + len(self.compute_warp_id) + + 1 # mma_warp + + 1 # load_warp + ) + ) + clc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_clc_consumer_threads + ) + clc_response_ptr = storage.clc_response.data_ptr() + clc = ClcState.create( + hw_scheduler=ClcDynamicPersistentTileScheduler.create( + self.tile_sched_params.clc_hw_params(), + cute.arch.block_idx(), + cute.arch.grid_dim(), + clc_response_ptr, + ), + pipeline=pipeline.PipelineClcFetchAsync.create( + barrier_storage=storage.clc_mbar_ptr.data_ptr(), + num_stages=self.num_clc_stage, + producer_group=clc_pipeline_producer_group, + consumer_group=clc_pipeline_consumer_group, + tx_count=self.num_clc_response_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ), + consumer_state=pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_clc_stage + ), + producer_state=pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_clc_stage + ), + ) + tile_sched = FmhaClcDynamicTileScheduler.create( + tile_sched_params, + cute.arch.block_idx(), + cute.arch.grid_dim(), + clc_response_ptr, + clc, + ) + work_tile = tile_sched.initial_work_tile_info() + else: + clc = None + clc_response_ptr = None + + tSTtST_shape = KQ_tiled_mma.partition_shape_C(cute.select(self.KQ_mma_tiler, mode=[0, 1])) + tSTtST = KQ_tiled_mma.make_fragment_C(tSTtST_shape) + # tSTtST shape : (MMA, MMA_M, MMA_N) + tSTtST = cute.make_tensor(tmem_ptr + self.tmem_S_offset, tSTtST.layout) + + # tdVrP shape : (MMA, MMA_M, MMA_K, STAGE) + tdVrP = PdO_tiled_mma.make_fragment_A(sP) + # tdVrdOT shape : (MMA, MMA_N, MMA_K, STAGE) + tdVrdOT = PdO_tiled_mma.make_fragment_B(sdOT) + + tdPTtdPT_shape = VdO_tiled_mma.partition_shape_C( + cute.select(self.VdO_mma_tiler, mode=[0, 1]) + ) + tdPTtdPT = VdO_tiled_mma.make_fragment_C(tdPTtdPT_shape) + # tdPTtdPT shape : (MMA, MMA_M, MMA_N) + tdPTtdPT = cute.make_tensor(tmem_ptr + self.tmem_dP_offset, tdPTtdPT.layout) + + tdKtdK_shape = dSQ_tiled_mma.partition_shape_C(cute.select(self.dSQ_mma_tiler, mode=[0, 1])) + tdKtdK = dSQ_tiled_mma.make_fragment_C(tdKtdK_shape) + # tdKtdK shape : (MMA, MMA_M, MMA_N) + tdKtdK = cute.make_tensor(tmem_ptr + self.tmem_dK_offset, tdKtdK.layout) + + tdVtdV_shape = PdO_tiled_mma.partition_shape_C(cute.select(self.PdO_mma_tiler, mode=[0, 1])) + tdVtdV = PdO_tiled_mma.make_fragment_C(tdVtdV_shape) + # tdVtdV shape : (MMA, MMA_M, MMA_N) + tdVtdV = cute.make_tensor(tmem_ptr + self.tmem_dV_offset, tdVtdV.layout) + + # get the current batch problem shape + + if cutlass.const_expr(self.use_clc_scheduler): + # ===== CLC PERSISTENT PATH: per-warp while loops ===== + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + is_first_cta_in_cluster = cta_rank_in_cluster == 0 + is_sched_warp = warp_idx == self.sched_warp_id and is_first_cta_in_cluster + + # Register allocation ONCE (before any loop) + if warp_idx == self.load_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_load) + elif warp_idx == self.mma_warp_id: + cute.arch.warpgroup_reg_alloc(self.num_regs_mma) + elif warp_idx >= self.compute_warp_id[0] and warp_idx <= self.compute_warp_id[-1]: + cute.arch.warpgroup_reg_alloc(self.num_regs_compute) + else: + cute.arch.warpgroup_reg_dealloc(self.num_regs_empty) + + # Cluster wait + pipeline_init_wait(cluster_shape_mn=cluster_layout_vmnk) + + # ===== SCHEDULER WARP ===== + if is_sched_warp: + while work_tile.is_valid_tile: + tile_sched.prefetch_next_work() + work_tile = tile_sched.advance_to_next_work() + tile_sched.producer_tail() + + # ===== LOAD WARP ===== + elif warp_idx == self.load_warp_id: + while work_tile.is_valid_tile: + # Decode coordinates from CLC work tile + blk_coord_k = work_tile.tile_idx[0] + blk_coord_b = work_tile.tile_idx[2][0] + blk_coord_h_k = work_tile.tile_idx[2][1] + blk_coord = ( + Int32(0), + blk_coord_k, + Int32(0), + ((Int32(0), blk_coord_h_k), blk_coord_b), + ) + seqlen_q_cur_batch = Q_ref.shape[0] + seqlen_k_cur_batch = K_ref.shape[0] + blk_offset = (Int32(0), Int32(0), Int32(0), ((Int32(0), Int32(0)), Int32(0))) + if cutlass.const_expr(varlen): + assert isinstance(cumulative_s_q, cute.Tensor) + assert isinstance(cumulative_s_k, cute.Tensor) + seqlen_q_cur_batch = ( + cumulative_s_q[blk_coord_b + 1] - cumulative_s_q[blk_coord_b] + ) + seqlen_k_cur_batch = ( + cumulative_s_k[blk_coord_b + 1] - cumulative_s_k[blk_coord_b] + ) + blk_offset = ( + cumulative_s_q[blk_coord_b], + cumulative_s_k[blk_coord_b], + Int32(0), + ((Int32(0), Int32(0)), Int32(0)), + ) + iter_start, iter_end = self.get_Q_block_min_max( + seqlen_q_cur_batch, + seqlen_k_cur_batch, + blk_coord_k, + is_2cta=True, + ) + iter_count = (iter_end - iter_start) * problem_shape[3][0][0] + if iter_count <= 0: + if blk_coord_k * self.tile_shape_K < seqlen_k_cur_batch: + problem_shape_cur_batch = ( + seqlen_q_cur_batch, + seqlen_k_cur_batch, + problem_shape[2], + problem_shape[3], + ) + self.epilogue_clear( + blk_coord, blk_offset, problem_shape_cur_batch, dK, dV + ) + else: + problem_shape_cur_batch = ( + seqlen_q_cur_batch, + seqlen_k_cur_batch, + problem_shape[2], + problem_shape[3], + ) + ( + load_mma_Q_producer, + load_mma_K_producer, + load_mma_V_producer, + load_compute_LSE_producer, + load_mma_dO_producer, + load_mma_dOT_producer, + load_compute_sum_OdO_producer, + load_mma_QT_producer, + ) = self.load( + K_in, + V_in, + Q_in, + QT_in, + dO_in, + dOT_in, + LSE, + sum_OdO, + sK, + sQ, + sQT, + sV, + sdO, + sdOT, + sLSE, + sSum_OdO, + KQ_tiled_mma, + VdO_tiled_mma, + PdO_tiled_mma, + dSQ_tiled_mma, + tma_atom_K, + tma_atom_Q, + tma_atom_QT, + tma_atom_V, + tma_atom_dO, + tma_atom_dOT, + blk_offset, + problem_shape_cur_batch, + varlen, + iter_count, + iter_start, + iter_end, + load_mma_Q_producer, + load_mma_Q_consumer, + load_mma_K_producer, + load_mma_K_consumer, + load_mma_V_producer, + load_mma_V_consumer, + load_compute_LSE_producer, + load_compute_LSE_consumer, + load_mma_dO_producer, + load_mma_dO_consumer, + load_mma_dOT_producer, + load_mma_dOT_consumer, + load_compute_sum_OdO_producer, + load_compute_sum_OdO_consumer, + load_mma_QT_producer, + load_mma_QT_consumer, + blk_coord_k, + blk_coord_h_k, + blk_coord_b, + ) + # CLC advance + work_tile = tile_sched.advance_to_next_work() + # producer_tail after loop + load_mma_K_producer.tail() + load_mma_V_producer.tail() + load_mma_Q_producer.tail() + load_compute_LSE_producer.tail() + load_mma_dO_producer.tail() + load_mma_dOT_producer.tail() + load_compute_sum_OdO_producer.tail() + load_mma_QT_producer.tail() + + # ===== MMA WARP ===== + elif warp_idx == self.mma_warp_id: + while work_tile.is_valid_tile: + blk_coord_k = work_tile.tile_idx[0] + blk_coord_b = work_tile.tile_idx[2][0] + blk_coord_h_k = work_tile.tile_idx[2][1] + blk_coord = ( + Int32(0), + blk_coord_k, + Int32(0), + ((Int32(0), blk_coord_h_k), blk_coord_b), + ) + seqlen_q_cur_batch = Q_ref.shape[0] + seqlen_k_cur_batch = K_ref.shape[0] + blk_offset = (Int32(0), Int32(0), Int32(0), ((Int32(0), Int32(0)), Int32(0))) + if cutlass.const_expr(varlen): + assert isinstance(cumulative_s_q, cute.Tensor) + assert isinstance(cumulative_s_k, cute.Tensor) + seqlen_q_cur_batch = ( + cumulative_s_q[blk_coord_b + 1] - cumulative_s_q[blk_coord_b] + ) + seqlen_k_cur_batch = ( + cumulative_s_k[blk_coord_b + 1] - cumulative_s_k[blk_coord_b] + ) + blk_offset = ( + cumulative_s_q[blk_coord_b], + cumulative_s_k[blk_coord_b], + Int32(0), + ((Int32(0), Int32(0)), Int32(0)), + ) + iter_start, iter_end = self.get_Q_block_min_max( + seqlen_q_cur_batch, + seqlen_k_cur_batch, + blk_coord_k, + is_2cta=True, + ) + iter_count = (iter_end - iter_start) * problem_shape[3][0][0] + if iter_count <= 0: + if blk_coord_k * self.tile_shape_K < seqlen_k_cur_batch: + problem_shape_cur_batch = ( + seqlen_q_cur_batch, + seqlen_k_cur_batch, + problem_shape[2], + problem_shape[3], + ) + self.epilogue_clear( + blk_coord, blk_offset, problem_shape_cur_batch, dK, dV + ) + else: + ( + mma_compute_S_producer, + mma_compute_dP_producer, + mma_compute_dKdV_producer, + load_mma_Q_consumer, + load_mma_K_consumer, + load_mma_V_consumer, + load_mma_dO_consumer, + load_mma_dOT_consumer, + compute_mma_P_consumer, + compute_mma_dS_consumer, + load_mma_QT_consumer, + ) = self.mma_2cta( + KQ_tiled_mma, + VdO_tiled_mma, + PdO_tiled_mma, + dSQ_tiled_mma, + tSTtST, + tSTrQ, + tSTrK, + tdPTtdPT, + tdPTrV, + tdPTrdO, + tdVtdV, + tdVrP, + tdVrdOT, + tdKrdST, + tdKtdK, + tdKrQT, + iter_count, + load_mma_Q_consumer, + load_mma_K_consumer, + load_mma_V_consumer, + mma_compute_S_producer, + load_mma_dO_consumer, + mma_compute_dP_producer, + load_mma_dOT_consumer, + compute_mma_P_consumer, + compute_mma_dS_consumer, + load_mma_QT_consumer, + mma_compute_dKdV_producer, + ) + # CLC advance + work_tile = tile_sched.advance_to_next_work() + # producer_tail after loop + mma_compute_S_producer.tail() + mma_compute_dP_producer.tail() + mma_compute_dKdV_producer.tail() + + # ===== COMPUTE WARPS ===== + elif warp_idx >= self.compute_warp_id[0] and warp_idx <= self.compute_warp_id[-1]: + while work_tile.is_valid_tile: + blk_coord_k = work_tile.tile_idx[0] + blk_coord_b = work_tile.tile_idx[2][0] + blk_coord_h_k = work_tile.tile_idx[2][1] + blk_coord = ( + Int32(0), + blk_coord_k, + Int32(0), + ((Int32(0), blk_coord_h_k), blk_coord_b), + ) + seqlen_q_cur_batch = Q_ref.shape[0] + seqlen_k_cur_batch = K_ref.shape[0] + blk_offset = (Int32(0), Int32(0), Int32(0), ((Int32(0), Int32(0)), Int32(0))) + if cutlass.const_expr(varlen): + assert isinstance(cumulative_s_q, cute.Tensor) + assert isinstance(cumulative_s_k, cute.Tensor) + seqlen_q_cur_batch = ( + cumulative_s_q[blk_coord_b + 1] - cumulative_s_q[blk_coord_b] + ) + seqlen_k_cur_batch = ( + cumulative_s_k[blk_coord_b + 1] - cumulative_s_k[blk_coord_b] + ) + blk_offset = ( + cumulative_s_q[blk_coord_b], + cumulative_s_k[blk_coord_b], + Int32(0), + ((Int32(0), Int32(0)), Int32(0)), + ) + iter_start, iter_end = self.get_Q_block_min_max( + seqlen_q_cur_batch, + seqlen_k_cur_batch, + blk_coord_k, + is_2cta=True, + ) + iter_count = (iter_end - iter_start) * problem_shape[3][0][0] + if iter_count <= 0: + if blk_coord_k * self.tile_shape_K < seqlen_k_cur_batch: + problem_shape_cur_batch = ( + seqlen_q_cur_batch, + seqlen_k_cur_batch, + problem_shape[2], + problem_shape[3], + ) + self.epilogue_clear( + blk_coord, blk_offset, problem_shape_cur_batch, dK, dV + ) + else: + problem_shape_cur_batch = ( + seqlen_q_cur_batch, + seqlen_k_cur_batch, + problem_shape[2], + problem_shape[3], + ) + ( + compute_mma_P_producer, + compute_mma_dS_producer, + mma_compute_S_consumer, + compute_mma_P_consumer, + load_compute_LSE_consumer, + load_compute_sum_OdO_consumer, + mma_compute_dP_consumer, + compute_mma_dS_consumer, + mma_compute_dKdV_consumer, + ) = self.compute( + tSTtST, + tdPTtdPT, + tdVrP, + sP, + sLSE, + sdST, + sSum_OdO, + dK, + dV, + tdKtdK, + tdVtdV, + PdO_tiled_mma, + dSQ_tiled_mma, + blk_coord, + blk_offset, + problem_shape_cur_batch, + iter_count, + iter_start, + iter_end, + scale_softmax, + mma_compute_S_producer, + mma_compute_S_consumer, + compute_mma_P_producer, + compute_mma_P_consumer, + load_compute_LSE_producer, + load_compute_LSE_consumer, + load_compute_sum_OdO_producer, + load_compute_sum_OdO_consumer, + mma_compute_dP_producer, + mma_compute_dP_consumer, + compute_mma_dS_producer, + compute_mma_dS_consumer, + mma_compute_dKdV_producer, + mma_compute_dKdV_consumer, + varlen, + sK, + seqlen_k_cur_batch, + ) + cute.arch.barrier( + barrier_id=self.epilogue_sync_bar_id, + number_of_threads=self.num_compute_warps * self.threads_per_warp, + ) + # CLC advance + work_tile = tile_sched.advance_to_next_work() + # producer_tail after loop + compute_mma_P_producer.tail() + compute_mma_dS_producer.tail() + + else: + # ===== STATIC PATH: original non-persistent code ===== + blk_coord = (Int32(0), bidx, Int32(0), ((Int32(0), bidy), bidz)) + seqlen_q_cur_batch = Q_ref.shape[0] + seqlen_k_cur_batch = K_ref.shape[0] + blk_offset = (Int32(0), Int32(0), Int32(0), ((Int32(0), Int32(0)), Int32(0))) + if cutlass.const_expr(varlen): + assert isinstance(cumulative_s_q, cute.Tensor) + assert isinstance(cumulative_s_k, cute.Tensor) + seqlen_q_cur_batch = cumulative_s_q[bidz + 1] - cumulative_s_q[bidz] + seqlen_k_cur_batch = cumulative_s_k[bidz + 1] - cumulative_s_k[bidz] + blk_offset = ( + cumulative_s_q[bidz], + cumulative_s_k[bidz], + Int32(0), + ((Int32(0), Int32(0)), Int32(0)), + ) + + iter_start, iter_end = self.get_Q_block_min_max( + seqlen_q_cur_batch, + seqlen_k_cur_batch, + blk_coord[1], + is_2cta=True, + ) + + # Cluster wait + pipeline_init_wait(cluster_shape_mn=cluster_layout_vmnk) + + iter_count = (iter_end - iter_start) * problem_shape[3][0][0] + problem_shape_cur_batch = ( + seqlen_q_cur_batch, + seqlen_k_cur_batch, + problem_shape[2], + problem_shape[3], + ) + if iter_count <= 0: + if bidx * self.tile_shape_K < seqlen_k_cur_batch: + self.epilogue_clear( + blk_coord, + blk_offset, + problem_shape_cur_batch, + dK, + dV, + ) + # /////////////////////////////////////////////////////////////////////////////// + # LOAD + # /////////////////////////////////////////////////////////////////////////////// + elif warp_idx == self.load_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_load) + + self.load( + K_in, + V_in, + Q_in, + QT_in, + dO_in, + dOT_in, + LSE, + sum_OdO, + sK, + sQ, + sQT, + sV, + sdO, + sdOT, + sLSE, + sSum_OdO, + KQ_tiled_mma, + VdO_tiled_mma, + PdO_tiled_mma, + dSQ_tiled_mma, + tma_atom_K, + tma_atom_Q, + tma_atom_QT, + tma_atom_V, + tma_atom_dO, + tma_atom_dOT, + blk_offset, + problem_shape_cur_batch, + varlen, + iter_count, + iter_start, + iter_end, + load_mma_Q_producer, + load_mma_Q_consumer, + load_mma_K_producer, + load_mma_K_consumer, + load_mma_V_producer, + load_mma_V_consumer, + load_compute_LSE_producer, + load_compute_LSE_consumer, + load_mma_dO_producer, + load_mma_dO_consumer, + load_mma_dOT_producer, + load_mma_dOT_consumer, + load_compute_sum_OdO_producer, + load_compute_sum_OdO_consumer, + load_mma_QT_producer, + load_mma_QT_consumer, + ) + + # /////////////////////////////////////////////////////////////////////////////// + # MMA + # /////////////////////////////////////////////////////////////////////////////// + elif warp_idx == self.mma_warp_id: + cute.arch.warpgroup_reg_alloc(self.num_regs_mma) + + self.mma_2cta( + KQ_tiled_mma, + VdO_tiled_mma, + PdO_tiled_mma, + dSQ_tiled_mma, + tSTtST, + tSTrQ, + tSTrK, + tdPTtdPT, + tdPTrV, + tdPTrdO, + tdVtdV, + tdVrP, + tdVrdOT, + tdKrdST, + tdKtdK, + tdKrQT, + iter_count, + load_mma_Q_consumer, + load_mma_K_consumer, + load_mma_V_consumer, + mma_compute_S_producer, + load_mma_dO_consumer, + mma_compute_dP_producer, + load_mma_dOT_consumer, + compute_mma_P_consumer, + compute_mma_dS_consumer, + load_mma_QT_consumer, + mma_compute_dKdV_producer, + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Compute + # /////////////////////////////////////////////////////////////////////////////// + elif warp_idx >= self.compute_warp_id[0] and warp_idx <= self.compute_warp_id[-1]: + cute.arch.warpgroup_reg_alloc(self.num_regs_compute) + + self.compute( + tSTtST, + tdPTtdPT, + tdVrP, + sP, + sLSE, + sdST, + sSum_OdO, + dK, + dV, + tdKtdK, + tdVtdV, + PdO_tiled_mma, + dSQ_tiled_mma, + blk_coord, + blk_offset, + problem_shape_cur_batch, + iter_count, + iter_start, + iter_end, + scale_softmax, + mma_compute_S_producer, + mma_compute_S_consumer, + compute_mma_P_producer, + compute_mma_P_consumer, + load_compute_LSE_producer, + load_compute_LSE_consumer, + load_compute_sum_OdO_producer, + load_compute_sum_OdO_consumer, + mma_compute_dP_producer, + mma_compute_dP_consumer, + compute_mma_dS_producer, + compute_mma_dS_consumer, + mma_compute_dKdV_producer, + mma_compute_dKdV_consumer, + varlen, + sK, + seqlen_k_cur_batch, + ) + + cute.arch.barrier( + barrier_id=self.epilogue_sync_bar_id, + number_of_threads=self.num_compute_warps * self.threads_per_warp, + ) + + else: + cute.arch.warpgroup_reg_dealloc(self.num_regs_empty) + + cute.arch.cluster_arrive() + cute.arch.cluster_wait() + # dishengbin Deallocate tmem for early exit + # Dealloc the tensor memory + tmem.relinquish_alloc_permit() + tmem.free(tmem_ptr) + + @cute.jit + def get_Q_block_min_max( + self, + seq_Q: Int32, + seq_K: Int32, + blk_coord_k: Int32, + is_2cta: bool, + ): + """Get Q tiles range.""" + Q_block_max = cute.ceil_div(seq_Q, self.tile_shape_Q) + Q_block_min = cutlass.Int32(0) + if cutlass.const_expr(self.has_sliding_window): + # For 2cta, use the last K block in the cluster so both CTAs get the same Q_block_max + blk_coord_k_for_max = (blk_coord_k // 2) * 2 + 1 + Q_block_max_tmp = cute.ceil_div( + (blk_coord_k_for_max + 1) * self.tile_shape_K + + seq_Q + - seq_K + + self.window_size_left, + self.tile_shape_Q, + ) + Q_block_max = min(Q_block_max, Q_block_max_tmp) + if cutlass.const_expr(self.is_causal or self.has_sliding_window): + # For 2cta, use the first K block in the cluster so both CTAs get the same Q_block_min. + # This ensures both CTAs in a cluster run the same number of pipeline iterations, + # avoiding hang from mismatched producer_commit / consumer_wait counts. + blk_coord_k_for_min = (blk_coord_k // 2) * 2 + Q_block_min_tmp = ( + blk_coord_k_for_min * self.tile_shape_K + seq_Q - seq_K - self.window_size_right + ) // self.tile_shape_Q + # Consider the case of 2cta, we need to ensure the K block is aligned to 2 + Q_block_min_tmp = Q_block_min_tmp - Q_block_min_tmp % 2 + Q_block_min = max(Q_block_min_tmp, Q_block_min) + return Q_block_min, Q_block_max + + @cute.jit + def load( + self, + K_in: cute.Tensor, + V_in: cute.Tensor, + Q_in: cute.Tensor, + QT_in: cute.Tensor, + dO_in: cute.Tensor, + dOT_in: cute.Tensor, + LSE_in: cute.Tensor, + sum_OdO_in: cute.Tensor, + sK: cute.Tensor, + sQ: cute.Tensor, + sQT: cute.Tensor, + sV: cute.Tensor, + sdO: cute.Tensor, + sdOT: cute.Tensor, + sLSE: cute.Tensor, + sSum_OdO: cute.Tensor, + KQ_tiled_mma: cute.TiledMma, + VdO_tiled_mma: cute.TiledMma, + PdO_tiled_mma: cute.TiledMma, + dSQ_tiled_mma: cute.TiledMma, + tma_atom_K: cute.CopyAtom, + tma_atom_Q: cute.CopyAtom, + tma_atom_QT: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + tma_atom_dO: cute.CopyAtom, + tma_atom_dOT: cute.CopyAtom, + blk_offset: cute.Shape, + problem_shape: tuple[Int32, Int32, Int32, tuple[tuple[Int32, Int32], Int32]], + varlen: bool, + iter_count: Int32, + iter_start: Int32, + iter_end: Int32, + load_mma_Q_producer, + load_mma_Q_consumer, + load_mma_K_producer, + load_mma_K_consumer, + load_mma_V_producer, + load_mma_V_consumer, + load_compute_LSE_producer, + load_compute_LSE_consumer, + load_mma_dO_producer, + load_mma_dO_consumer, + load_mma_dOT_producer, + load_mma_dOT_consumer, + load_compute_sum_OdO_producer, + load_compute_sum_OdO_consumer, + load_mma_QT_producer, + load_mma_QT_consumer, + blk_coord_k_override: Int32 = Int32(-1), + blk_coord_h_k_override: Int32 = Int32(-1), + blk_coord_b_override: Int32 = Int32(-1), + ): + """TMA load.""" + tidx, _, _ = cute.arch.thread_idx() + if cutlass.const_expr(self.use_clc_scheduler): + blk_coord_k = blk_coord_k_override + blk_coord_h_k = blk_coord_h_k_override + blk_coord_b = blk_coord_b_override + else: + blk_coord_k, blk_coord_h_k, blk_coord_b = cute.arch.block_idx() + blk_coord_h_r = Int32(0) + blk_coord_h = (blk_coord_h_r, blk_coord_h_k) + iter_index = iter_start + mma_tile_coord_v = blk_coord_k % cute.size(KQ_tiled_mma.thr_id.shape) + mma_tile_coord_m = blk_coord_k // cute.size(KQ_tiled_mma.thr_id.shape) + + K = cute.domain_offset(cute.select(blk_offset, mode=[1, 2, 3]), K_in) + V = cute.domain_offset(cute.select(blk_offset, mode=[1, 2, 3]), V_in) + Q = cute.domain_offset(cute.select(blk_offset, mode=[0, 2, 3]), Q_in) + QT = cute.domain_offset(cute.select(blk_offset, mode=[2, 0, 3]), QT_in) + dO = cute.domain_offset(cute.select(blk_offset, mode=[0, 2, 3]), dO_in) + dOT = cute.domain_offset(cute.select(blk_offset, mode=[2, 0, 3]), dOT_in) + blk_offset_stats = blk_offset + if cutlass.const_expr(varlen): + cuseqlen_q_stats = cute.assume( + (blk_offset[0] + blk_coord_b * self.tile_shape_Q) + // self.tile_shape_Q + * self.tile_shape_Q, + divby=self.tile_shape_Q, + ) + blk_offset_stats = ( + cuseqlen_q_stats, + blk_offset[1], + blk_offset[2], + blk_offset[3], + ) + LSE = cute.domain_offset(cute.select(blk_offset_stats, mode=[0, 3]), LSE_in) + sum_OdO = cute.domain_offset(cute.select(blk_offset_stats, mode=[0, 3]), sum_OdO_in) + + gK = cute.local_tile(K, cute.select(self.KQ_mma_tiler, mode=[0, 2]), (None, None, None)) + gQ = cute.local_tile(Q, cute.select(self.KQ_mma_tiler, mode=[1, 2]), (None, None, None)) + gQT = cute.local_tile(QT, cute.select(self.dSQ_mma_tiler, mode=[1, 2]), (None, None, None)) + gV = cute.local_tile(V, cute.select(self.VdO_mma_tiler, mode=[0, 2]), (None, None, None)) + gdO = cute.local_tile(dO, cute.select(self.VdO_mma_tiler, mode=[1, 2]), (None, None, None)) + gdOT = cute.local_tile( + dOT, cute.select(self.PdO_mma_tiler, mode=[1, 2]), (None, None, None) + ) + + KQ_thr_mma = KQ_tiled_mma.get_slice(mma_tile_coord_v) + VdO_thr_mma = VdO_tiled_mma.get_slice(mma_tile_coord_v) + PdO_thr_mma = PdO_tiled_mma.get_slice(mma_tile_coord_v) + dSQ_thr_mma = dSQ_tiled_mma.get_slice(mma_tile_coord_v) + + tSTgK = KQ_thr_mma.partition_A(gK) + tSTgQ = KQ_thr_mma.partition_B(gQ) + tdKgQT = dSQ_thr_mma.partition_B(gQT) + tdPTgV = VdO_thr_mma.partition_A(gV) + tdPTgdO = VdO_thr_mma.partition_B(gdO) + tdVgdOT = PdO_thr_mma.partition_B(gdOT) + + cta_layout_mnk = cute.make_layout(self.cluster_shape_mnk) + cta_layout_vmnk = cute.tiled_divide(cta_layout_mnk, (KQ_tiled_mma.thr_id,)) + cta_in_cluster_coord_vmnk = cta_layout_vmnk.get_flat_coord(cute.arch.block_idx_in_cluster()) + + tKsK, tKgK_mkl = cute.nvgpu.cpasync.tma_partition( + tma_atom_K, + cta_in_cluster_coord_vmnk[2], + cute.make_layout(cute.size(cta_layout_vmnk, mode=[2])), + cute.group_modes(sK, 0, 3), + cute.group_modes(tSTgK, 0, 3), + ) + tQsQ, tQgQ_mkl = cute.nvgpu.cpasync.tma_partition( + tma_atom_Q, + cta_in_cluster_coord_vmnk[1], + cute.make_layout(cute.size(cta_layout_vmnk, mode=[1])), + cute.group_modes(sQ, 0, 3), + cute.group_modes(tSTgQ, 0, 3), + ) + tQTsQT, tQTgQT_mkl = cute.nvgpu.cpasync.tma_partition( + tma_atom_QT, + cta_in_cluster_coord_vmnk[1], + cute.make_layout(cute.size(cta_layout_vmnk, mode=[1])), + cute.group_modes(sQT, 0, 3), + cute.group_modes(tdKgQT, 0, 3), + ) + tVsV, tVgV_mkl = cute.nvgpu.cpasync.tma_partition( + tma_atom_V, + cta_in_cluster_coord_vmnk[2], + cute.make_layout(cute.size(cta_layout_vmnk, mode=[2])), + cute.group_modes(sV, 0, 3), + cute.group_modes(tdPTgV, 0, 3), + ) + tdOsdO, tdOgdO_mkl = cute.nvgpu.cpasync.tma_partition( + tma_atom_dO, + cta_in_cluster_coord_vmnk[1], + cute.make_layout(cute.size(cta_layout_vmnk, mode=[1])), + cute.group_modes(sdO, 0, 3), + cute.group_modes(tdPTgdO, 0, 3), + ) + tdOTsdOT, tdOTgdOT_mkl = cute.nvgpu.cpasync.tma_partition( + tma_atom_dOT, + cta_in_cluster_coord_vmnk[1], + cute.make_layout(cute.size(cta_layout_vmnk, mode=[1])), + cute.group_modes(sdOT, 0, 3), + cute.group_modes(tdVgdOT, 0, 3), + ) + + k_handle = load_mma_K_producer.acquire_and_advance() + cute.copy( + tma_atom_K, + tKgK_mkl[(None, mma_tile_coord_m, 0, (blk_coord_h, blk_coord_b))], + tKsK[None, 0], + tma_bar_ptr=k_handle.barrier, + ) + + q_handle = load_mma_Q_producer.acquire_and_advance() + cute.copy( + tma_atom_Q, + tQgQ_mkl[(None, iter_index, 0, (blk_coord_h, blk_coord_b))], + tQsQ[None, q_handle.index], + tma_bar_ptr=q_handle.barrier, + ) + + lse_handle = load_compute_LSE_producer.acquire_and_advance() + thread_idx = tidx % self.threads_per_warp + async_copy_num_elts = sLSE.shape[0] // self.threads_per_warp + atom_async_copy = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.ALWAYS), + self.acc_dtype, + num_bits_per_copy=self.acc_dtype.width, + ) + sLSE_for_copy = cute.flat_divide(sLSE, (1,)) + LSE_for_copy = cute.flat_divide(LSE, (1,)) + for i in cutlass.range_constexpr(async_copy_num_elts): + LSE_idx = self.tile_shape_Q * iter_index + thread_idx * async_copy_num_elts + if cute.elem_less(LSE_idx + i, problem_shape[0]): + cute.copy( + atom_async_copy, + LSE_for_copy[None, LSE_idx + i, (blk_coord_h, blk_coord_b)], + sLSE_for_copy[ + None, + thread_idx * async_copy_num_elts + i, + lse_handle.index, + ], + ) + else: + sLSE_for_copy[ + None, + thread_idx * async_copy_num_elts + i, + lse_handle.index, + ].fill(0.0) + lse_handle.commit() + + v_handle = load_mma_V_producer.acquire_and_advance() + cute.copy( + tma_atom_V, + tVgV_mkl[(None, mma_tile_coord_m, 0, (blk_coord_h, blk_coord_b))], + tVsV[(None, 0)], + tma_bar_ptr=v_handle.barrier, + ) + + do_handle = load_mma_dO_producer.acquire_and_advance() + cute.copy( + tma_atom_dO, + tdOgdO_mkl[(None, iter_index, 0, (blk_coord_h, blk_coord_b))], + tdOsdO[(None, do_handle.index)], + tma_bar_ptr=do_handle.barrier, + ) + + sum_odo_handle = load_compute_sum_OdO_producer.acquire_and_advance() + sSum_OdO_for_copy = cute.flat_divide(sSum_OdO, (1,)) + sum_OdO_for_copy = cute.flat_divide(sum_OdO, (1,)) + for i in cutlass.range_constexpr(async_copy_num_elts): + sum_OdO_idx = self.tile_shape_Q * iter_index + thread_idx * async_copy_num_elts + if cute.elem_less(sum_OdO_idx + i, problem_shape[0]): + cute.copy( + atom_async_copy, + sum_OdO_for_copy[None, sum_OdO_idx + i, (blk_coord_h, blk_coord_b)], + sSum_OdO_for_copy[ + None, + thread_idx * async_copy_num_elts + i, + sum_odo_handle.index, + ], + ) + else: + sSum_OdO_for_copy[ + None, + thread_idx * async_copy_num_elts + i, + sum_odo_handle.index, + ].fill(0.0) + sum_odo_handle.commit() + + dot_handle = load_mma_dOT_producer.acquire_and_advance() + cute.copy( + tma_atom_dOT, + tdOTgdOT_mkl[(None, 0, iter_index, (blk_coord_h, blk_coord_b))], + tdOTsdOT[None, dot_handle.index], + tma_bar_ptr=dot_handle.barrier, + ) + + qt_handle = load_mma_QT_producer.acquire_and_advance() + cute.copy( + tma_atom_QT, + tQTgQT_mkl[(None, 0, iter_index, (blk_coord_h, blk_coord_b))], + tQTsQT[None, qt_handle.index], + tma_bar_ptr=qt_handle.barrier, + ) + + iter_count -= 1 + iter_index += 1 + + while iter_count > 0: + if iter_index == iter_end: + iter_index = iter_start + blk_coord_h_r += 1 + blk_coord_h = (blk_coord_h_r, blk_coord_h_k) + + q_handle = load_mma_Q_producer.acquire_and_advance() + cute.copy( + tma_atom_Q, + tQgQ_mkl[(None, iter_index, 0, (blk_coord_h, blk_coord_b))], + tQsQ[None, q_handle.index], + tma_bar_ptr=q_handle.barrier, + ) + + lse_handle = load_compute_LSE_producer.acquire_and_advance() + sLSE_for_copy = cute.flat_divide(sLSE, (1,)) + LSE_for_copy = cute.flat_divide(LSE, (1,)) + for i in cutlass.range_constexpr(async_copy_num_elts): + LSE_idx = self.tile_shape_Q * iter_index + thread_idx * async_copy_num_elts + if cute.elem_less(LSE_idx + i, problem_shape[0]): + cute.copy( + atom_async_copy, + LSE_for_copy[None, LSE_idx + i, (blk_coord_h, blk_coord_b)], + sLSE_for_copy[ + None, + thread_idx * async_copy_num_elts + i, + lse_handle.index, + ], + ) + else: + sLSE_for_copy[ + None, + thread_idx * async_copy_num_elts + i, + lse_handle.index, + ].fill(0.0) + lse_handle.commit() + + do_handle = load_mma_dO_producer.acquire_and_advance() + cute.copy( + tma_atom_dO, + tdOgdO_mkl[(None, iter_index, 0, (blk_coord_h, blk_coord_b))], + tdOsdO[None, do_handle.index], + tma_bar_ptr=do_handle.barrier, + ) + + sum_odo_handle = load_compute_sum_OdO_producer.acquire_and_advance() + sSum_OdO_for_copy = cute.flat_divide(sSum_OdO, (1,)) + sum_OdO_for_copy = cute.flat_divide(sum_OdO, (1,)) + for i in cutlass.range_constexpr(async_copy_num_elts): + sum_OdO_idx = self.tile_shape_Q * iter_index + thread_idx * async_copy_num_elts + if cute.elem_less(sum_OdO_idx + i, problem_shape[0]): + cute.copy( + atom_async_copy, + sum_OdO_for_copy[None, sum_OdO_idx + i, (blk_coord_h, blk_coord_b)], + sSum_OdO_for_copy[ + None, + thread_idx * async_copy_num_elts + i, + sum_odo_handle.index, + ], + ) + else: + sSum_OdO_for_copy[ + None, + thread_idx * async_copy_num_elts + i, + sum_odo_handle.index, + ].fill(0.0) + sum_odo_handle.commit() + + dot_handle = load_mma_dOT_producer.acquire_and_advance() + cute.copy( + tma_atom_dOT, + tdOTgdOT_mkl[(None, 0, iter_index, (blk_coord_h, blk_coord_b))], + tdOTsdOT[None, dot_handle.index], + tma_bar_ptr=dot_handle.barrier, + ) + + qt_handle = load_mma_QT_producer.acquire_and_advance() + cute.copy( + tma_atom_QT, + tQTgQT_mkl[(None, 0, iter_index, (blk_coord_h, blk_coord_b))], + tQTsQT[None, qt_handle.index], + tma_bar_ptr=qt_handle.barrier, + ) + + iter_count -= 1 + iter_index += 1 + + if not cutlass.const_expr(self.use_clc_scheduler): + load_mma_K_producer.tail() + load_mma_V_producer.tail() + load_mma_Q_producer.tail() + load_compute_LSE_producer.tail() + load_mma_dO_producer.tail() + load_mma_dOT_producer.tail() + load_compute_sum_OdO_producer.tail() + load_mma_QT_producer.tail() + + return ( + load_mma_Q_producer, + load_mma_K_producer, + load_mma_V_producer, + load_compute_LSE_producer, + load_mma_dO_producer, + load_mma_dOT_producer, + load_compute_sum_OdO_producer, + load_mma_QT_producer, + ) + + @cute.jit + def mma_2cta( + self, + KQ_tiled_mma: cute.TiledMma, + VdO_tiled_mma: cute.TiledMma, + PdO_tiled_mma: cute.TiledMma, + dSQ_tiled_mma: cute.TiledMma, + tSTtST: cute.Tensor, + tSTrQ: cute.Tensor, + tSTrK: cute.Tensor, + tdPTtdPT: cute.Tensor, + tdPTrV: cute.Tensor, + tdPTrdO: cute.Tensor, + tdVtdV: cute.Tensor, + tdVrP: cute.Tensor, + tdVrdOT: cute.Tensor, + tdKrdST: cute.Tensor, + tdKtdK: cute.Tensor, + tdKrQT: cute.Tensor, + iter_count: Int32, + load_mma_Q_consumer, + load_mma_K_consumer, + load_mma_V_consumer, + mma_compute_S_producer, + load_mma_dO_consumer, + mma_compute_dP_producer, + load_mma_dOT_consumer, + compute_mma_P_consumer, + compute_mma_dS_consumer, + load_mma_QT_consumer, + mma_compute_dKdV_producer, + ): + """CuTeDSL kernel for mma pipeline.""" + load_mma_Q_releaser = load_mma_Q_consumer.clone() + load_mma_K_releaser = load_mma_K_consumer.clone() + load_mma_V_releaser = load_mma_V_consumer.clone() + + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + is_leader_cta = cta_rank_in_cluster % 2 == 0 + + if is_leader_cta: + s_handle = mma_compute_S_producer.acquire_and_advance() + k_handle = load_mma_K_consumer.wait_and_advance() + q_handle = load_mma_Q_consumer.wait_and_advance() + + # Compute S = K * Q + for k_block in cutlass.range(0, cute.size(tSTrQ, mode=[2]), unroll_full=True): + KQ_tiled_mma.set(tcgen05.Field.ACCUMULATE, k_block != 0) + cute.gemm( + KQ_tiled_mma, + tSTtST, + tSTrK[None, None, k_block, 0], + tSTrQ[None, None, k_block, q_handle.index], + tSTtST, + ) + q_handle.release() + + cute.arch.fence_view_async_tmem_store() + s_handle.commit() + + do_handle = load_mma_dO_consumer.wait_and_advance() + v_handle = load_mma_V_consumer.wait_and_advance() + + dp_handle = mma_compute_dP_producer.acquire_and_advance() + + # Compute dP = V * dO + VdO_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + for k_block in cutlass.range(0, cute.size(tdPTrV, mode=[2]), unroll_full=True): + cute.gemm( + VdO_tiled_mma, + tdPTtdPT, + tdPTrV[None, None, k_block, 0], + tdPTrdO[None, None, k_block, do_handle.index], + tdPTtdPT, + ) + VdO_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + dp_handle.commit() + do_handle.release() + # V only produced once by load(); hold v_handle until end, release there via releaser + + p_handle = compute_mma_P_consumer.wait_and_advance() + dot_handle = load_mma_dOT_consumer.wait_and_advance() + + # Compute dV = P * dO (First iteration) + PdO_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + for k_block in cutlass.range(0, cute.size(tdVrP, mode=[2]), unroll_full=True): + cute.gemm( + PdO_tiled_mma, + tdVtdV, + tdVrP[None, None, k_block, 0], + tdVrdOT[None, None, k_block, dot_handle.index], + tdVtdV, + ) + PdO_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + dot_handle.release() + p_handle.release() + + iter_count -= 1 + + dSQ_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + while iter_count > 0: + if is_leader_cta: + q_handle = load_mma_Q_consumer.wait_and_advance() + s_handle = mma_compute_S_producer.acquire_and_advance() + + # Compute S = K * Q + for k_block in cutlass.range(0, cute.size(tSTrQ, mode=[2]), unroll_full=True): + KQ_tiled_mma.set(tcgen05.Field.ACCUMULATE, k_block != 0) + cute.gemm( + KQ_tiled_mma, + tSTtST, + tSTrK[None, None, k_block, 0], + tSTrQ[None, None, k_block, q_handle.index], + tSTtST, + ) + q_handle.release() + s_handle.commit() + + if is_leader_cta: + qt_handle = load_mma_QT_consumer.wait_and_advance() + ds_handle = compute_mma_dS_consumer.wait_and_advance() + + # Compute dK = dS * QT + for k_block in cutlass.range(0, cute.size(tdKrdST, mode=[2]), unroll_full=True): + cute.gemm( + dSQ_tiled_mma, + tdKtdK, + tdKrdST[ + None, + None, + k_block, + ds_handle.index, + ], + tdKrQT[None, None, k_block, qt_handle.index], + tdKtdK, + ) + dSQ_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + qt_handle.release() + ds_handle.release() + + if is_leader_cta: + dp_handle = mma_compute_dP_producer.acquire_and_advance() + do_handle = load_mma_dO_consumer.wait_and_advance() + # V only produced once by load(); reuse same V (index 0) for all loop iterations + VdO_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + for k_block in cutlass.range(0, cute.size(tdPTrV, mode=[2]), unroll_full=True): + cute.gemm( + VdO_tiled_mma, + tdPTtdPT, + tdPTrV[None, None, k_block, 0], + tdPTrdO[None, None, k_block, do_handle.index], + tdPTtdPT, + ) + VdO_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + dp_handle.commit() + do_handle.release() + + if is_leader_cta: + p_handle = compute_mma_P_consumer.wait_and_advance() + dot_handle = load_mma_dOT_consumer.wait_and_advance() + + # Compute dV = P * dO (Loop iterations) + for k_block in cutlass.range(0, cute.size(tdVrP, mode=[2]), unroll_full=True): + cute.gemm( + PdO_tiled_mma, + tdVtdV, + tdVrP[None, None, k_block, 0], + tdVrdOT[None, None, k_block, dot_handle.index], + tdVtdV, + ) + PdO_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + p_handle.release() + dot_handle.release() + + iter_count -= 1 + + if is_leader_cta: + dkdv_handle = mma_compute_dKdV_producer.acquire_and_advance() + dkdv_handle.commit() + + load_mma_K_releaser.release() + load_mma_K_releaser.advance() + load_mma_V_releaser.release() + load_mma_V_releaser.advance() + + if is_leader_cta: + dkdv_handle = mma_compute_dKdV_producer.acquire_and_advance() + + ds_handle = compute_mma_dS_consumer.wait_and_advance() + qt_handle = load_mma_QT_consumer.wait_and_advance() + + # Compute dK = dS * Q + for k_block in cutlass.range(0, cute.size(tdKrdST, mode=[2]), unroll_full=True): + cute.gemm( + dSQ_tiled_mma, + tdKtdK, + tdKrdST[None, None, k_block, ds_handle.index], + tdKrQT[None, None, k_block, qt_handle.index], + tdKtdK, + ) + dSQ_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + dkdv_handle.commit() + qt_handle.release() + ds_handle.release() + + if not cutlass.const_expr(self.use_clc_scheduler): + mma_compute_S_producer.tail() + mma_compute_dP_producer.tail() + mma_compute_dKdV_producer.tail() + + return ( + mma_compute_S_producer, + mma_compute_dP_producer, + mma_compute_dKdV_producer, + load_mma_Q_consumer, + load_mma_K_consumer, + load_mma_V_consumer, + load_mma_dO_consumer, + load_mma_dOT_consumer, + compute_mma_P_consumer, + compute_mma_dS_consumer, + load_mma_QT_consumer, + ) + + @cute.jit + def reg_to_smem_mma64x64( + self, + regs: cute.Tensor, + smem: cute.Tensor, + index: Int32, + tiler_mn: tuple[Int32, Int32], + dp_idx: Int32, + wg_idx: Int32, + ): + smem_slice = smem[None, None, None, index] + # TODO: double check the layout of the data in reg. + # TODO: this may introduce additional smem transpose. + thread_layout = cute.make_ordered_layout( + tiler_mn, + (0, 1), # TODO: (0,1) or (1,0) ??? + ) + smem_slice_tmp = cute.composition(smem_slice, thread_layout) + + # TODO: temporary code for tile 64 x 64. + tmp_shape = ((8, 2, 4), (2, 4, 2, 2, 2)) + tmp_stride = ((64, 512, 1024), (1, 2, 8, 16, 32)) + smem_copy = cute.composition(smem_slice_tmp, cute.make_layout(tmp_shape, stride=tmp_stride)) + + # TODO: the following code is only for tile 64 x 64. + # TODO: need to modify the code for other tile sizes. + lane_idx = dp_idx % 32 + reg_shape = regs.shape + atom_loops = reg_shape[0][0][2] + block_loops = reg_shape[2] + # | 00 ~ 07 | 08 ~ 15 | 16 ~ 23 | 24 ~ 31 | 32 ~ 39 | 40 ~ 47 | 48 ~ 55 | 56 ~ 63 | + # |---- atom size ----|---- atom size ----|---- atom size ----|---- atom size ----| + # |---- wg0 ----|---- wg1 ----|---- wg0 ----|---- wg1 ----| + for ia in cutlass.range(atom_loops): + for ib in cutlass.range(block_loops): + # the lower 8 lines + regs_copy = regs[((None, 0, ia), 0), 0, ib] # two elements + smem_copy_slice = smem_copy[ + (lane_idx // 4, 0, dp_idx // 32), + (None, lane_idx % 4, ia, wg_idx, ib), + ] + cute.autovec_copy(regs_copy, smem_copy_slice) + # the upper 8 lines + regs_copy = regs[((None, 1, ia), 0), 0, ib] + smem_copy_slice = smem_copy[ + (lane_idx // 4, 1, dp_idx // 32), + (None, lane_idx % 4, ia, wg_idx, ib), + ] + cute.autovec_copy(regs_copy, smem_copy_slice) + + @cute.jit + def reg_to_smem_mma128x128_2cta( + self, + regs: cute.Tensor, + smem: cute.Tensor, + index: Int32, + tiler_mn: tuple[Int32, Int32], + dp_idx: Int32, + wg_idx: Int32, + smem_RowMajor: bool = True, + ): + smem_slice = smem[None, None, None, index] + # K>> smem_slice: tensor, S<3,4,3>> o ((64,16),1,(4,2)):((64,1),0,(16,4096))> + thread_layout = cute.make_ordered_layout( + # (tileN, tileM) + tiler_mn, + (0, 1), + ) + # K>> thread_layout: (64,128):(128,1) + smem_slice_tmp = cute.composition(smem_slice, thread_layout) + + # NOTE: hardcode for tcgen05.ld.32x32b.x8 & mma128x64+2cta + # tmp_shape = ((32, 2), (8, 2, 2, 2)) # for 64x64 tile + # tmp_stride = ((64, 32*64), (1, 8, 16, 32)) + # NOTE: hardcode for tcgen05.ld.32x32b.x16 & mma128x64+2cta + tmp_shape = ((32, 2), (16, 2, 2, 2)) # for 128x64 tile + tmp_stride = ((64, 32 * 64), (1, 16, 32, 64 * 64)) + # smem_copy = cute.composition(smem_slice_tmp, cute.make_layout(tmp_shape, stride=tmp_stride)) + smem_copy = cute.make_tensor( + smem_slice_tmp.iterator, cute.make_layout(tmp_shape, stride=tmp_stride) + ) + + warp_idx = dp_idx // 32 + warp_row_idx = warp_idx % 2 + warp_col_idx = warp_idx // 2 # corresponding to the second 64 cols in smem + lane_idx = dp_idx % 32 + reg_shape = ( + regs.shape + ) # ((8,1),1,2):((1,0),0,8) for 64x64, ((16,1),1,2):((1,0),0,16) for 128x64 + block_loops = reg_shape[2] + + # TODO: maybe can use cp.async for optimization + for ib in cutlass.range(block_loops): + regs_copy = regs[(None, 0), 0, ib] + smem_copy_slice = smem_copy[(lane_idx, warp_row_idx), (None, wg_idx, ib, warp_col_idx)] + cute.autovec_copy(regs_copy, smem_copy_slice) + + @cute.jit + def reg_to_smem_mma128x64_2cta( + self, + regs: cute.Tensor, + smem: cute.Tensor, + index: Int32, + tiler_mn: tuple[Int32, Int32], + dp_idx: Int32, + wg_idx: Int32, + smem_RowMajor: bool = True, + ): + smem_slice = smem[None, None, None, index] + thread_layout = cute.make_ordered_layout( + # (tileN, tileM) + tiler_mn, + (1, 0) if smem_RowMajor else (0, 1), + ) + smem_slice_tmp = cute.composition(smem_slice, thread_layout) + # NOTE: hardcode for tcgen05.ld.32x32b.x8 & mma128x64+2cta + tmp_shape = ((32, 2), (8, 2, 2, 2)) + tmp_stride = ((64, 32 * 64), (1, 8, 16, 32)) + smem_copy = cute.composition(smem_slice_tmp, cute.make_layout(tmp_shape, stride=tmp_stride)) + + warp_idx = dp_idx // 32 + warp_row_idx = warp_idx % 2 + warp_col_idx = warp_idx // 2 + lane_idx = dp_idx % 32 + reg_shape = regs.shape # ((8,1),1,2):((1,0),0,8) + block_loops = reg_shape[2] + + # TODO: maybe can use cp.async for optimization + for ib in cutlass.range(block_loops): + regs_copy = regs[(None, 0), 0, ib] + smem_copy_slice = smem_copy[(lane_idx, warp_row_idx), (None, wg_idx, ib, warp_col_idx)] + cute.autovec_copy(regs_copy, smem_copy_slice) + + @cute.jit + def compute( + self, + tSTtST: cute.Tensor, + tdPTtdPT: cute.Tensor, + tdVrP: cute.Tensor, + sP: cute.Tensor, + sLSE: cute.Tensor, + # sdS: cute.Tensor, + sdST: cute.Tensor, + sSum_OdO: cute.Tensor, + dK: cute.Tensor, + dV: cute.Tensor, + tdKtdK: cute.Tensor, + tdVtdV: cute.Tensor, + PdO_tiled_mma: cute.TiledMma, + dSQ_tiled_mma: cute.TiledMma, + blk_coord: cute.Coord, + blk_offset: cute.Shape, + problem_shape: tuple[Int32, Int32, Int32, tuple[tuple[Int32, Int32], Int32]], + iter_count: Int32, + iter_start: Int32, + iter_end: Int32, + scale_softmax: cutlass.Float32, + mma_compute_S_producer, + mma_compute_S_consumer, + compute_mma_P_producer, + compute_mma_P_consumer, + load_compute_LSE_producer, + load_compute_LSE_consumer, + load_compute_sum_OdO_producer, + load_compute_sum_OdO_consumer, + mma_compute_dP_producer, + mma_compute_dP_consumer, + compute_mma_dS_producer, + compute_mma_dS_consumer, + mma_compute_dKdV_producer, + mma_compute_dKdV_consumer, + varlen: bool, + sK: cute.Tensor, + problem_shape_k_cur_batch: Int32, + ): + """CuTeDSL kernel for recomputing softmax and producing dk and dv.""" + tidx, _, _ = cute.arch.thread_idx() + Q, K, _, _ = problem_shape + _, blk_coord_k, _, _ = blk_coord + + iter_index = iter_start + + # adi: TMEM_ST, TMEM_DPT + tmem_load_op = tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(16)) + tmem_load_atom = cute.make_copy_atom( + tmem_load_op, + self.acc_dtype, + ) + + tSTtST = tSTtST[(None, None), 0, 0] + tdPTtdPT = tdPTtdPT[(None, None), 0, 0] + + cST = cute.make_identity_tensor(cute.select(self.cta_tiler, mode=[1, 0])) + cdPT = cute.make_identity_tensor(cute.select(self.cta_tiler, mode=[1, 0])) + + num_warp_groups = self.num_compute_warps // 4 + dp_idx = tidx % 128 + wg_idx = (tidx % (self.num_compute_warps * self.threads_per_warp)) // 128 + tiled_t2r = tcgen05.make_tmem_copy(tmem_load_atom, tSTtST) + thr_t2r = tiled_t2r.get_slice(dp_idx) + + tTR_cST = thr_t2r.partition_D(cST) + tTR_cST = split_wg(tTR_cST, num_warp_groups, wg_idx) + tTR_rST = cute.make_rmem_tensor(tTR_cST.shape, self.acc_dtype) + + tTR_tST = thr_t2r.partition_S(tSTtST) + tTR_tST = split_wg(tTR_tST, num_warp_groups, wg_idx) + + tTR_cdPT_p = thr_t2r.partition_D(cdPT) + tTR_cdPT = split_wg(tTR_cdPT_p, num_warp_groups, wg_idx) + tTR_rdPT = cute.make_rmem_tensor(tTR_cdPT.shape, self.acc_dtype) + + tTR_tdPT = thr_t2r.partition_S(tdPTtdPT) + tTR_tdPT = split_wg(tTR_tdPT, num_warp_groups, wg_idx) + + tdVcST = PdO_tiled_mma.get_slice(0).partition_A(cST) + + is_residual_k = blk_coord_k * self.tile_shape_K + self.tile_shape_K > K + last_iter = iter_end - 1 + + while iter_count > 0: + s_handle = mma_compute_S_consumer.wait_and_advance() + p_handle = compute_mma_P_producer.acquire_and_advance() + lse_handle = load_compute_LSE_consumer.wait_and_advance() + + leading_causal_masking = cutlass.Boolean(False) + if cutlass.const_expr(self.is_causal): + # TODO: could be optimized by specify an exact iter_index + # + # NOTE (causal + 2CTA correctness): + # `iter_start` can be rounded down to an even index for 2CTA. When the true + # causal boundary Q tile is odd, it becomes (iter_start + 1). In that case, + # we must treat both (iter_start, iter_start + 1) as "masked tiles" so the + # per-element causal mask is applied on the boundary tile too. + leading_causal_masking = iter_index == iter_start + q_block_min_unaligned = ( + blk_coord_k * self.tile_shape_K + Q - K - self.window_size_right + ) // self.tile_shape_Q + boundary_in_second = (q_block_min_unaligned % 2) == 1 + leading_causal_masking = leading_causal_masking or ( + boundary_in_second and (iter_index == iter_start + 1) + ) + offset_partial_tile = (K - Q) % self.tile_shape_K + need_additional_mask = offset_partial_tile and iter_index == iter_start + 1 + leading_causal_masking = leading_causal_masking or need_additional_mask + leading_causal_masking = cute.arch.shuffle_sync(leading_causal_masking, 0) + + trailing_residual_masking = cutlass.Boolean(False) + trailing_residual_masking = iter_index == last_iter or is_residual_k + trailing_residual_masking = cute.arch.shuffle_sync(trailing_residual_masking, 0) + + # For causal, every tile may contain (q,k) with k > q; we must apply per-element mask for all Q tiles. + is_masked_tile = ( + leading_causal_masking + or trailing_residual_masking + or self.has_sliding_window + or cutlass.const_expr(self.is_causal) + ) + + # Compute P = softmax(S, LSE) + cute.copy(tiled_t2r, tTR_tST, tTR_rST) + + if is_masked_tile: + for i in cutlass.range(cute.size(tTR_rST), unroll_full=True): + c_transpose = tTR_cST[i] + pos = ( + cute.get(c_transpose, mode=[1]) + iter_index * self.tile_shape_Q, + cute.get(c_transpose, mode=[0]) + blk_coord_k * self.tile_shape_K, + ) + if cutlass.const_expr(self.has_sliding_window): + if cutlass.const_expr(self.window_size_left < 0): + tTR_rST[i] = ( + -cutlass.Float32.inf + if pos[1] > pos[0] + K - Q + self.window_size_right + else tTR_rST[i] + ) + else: + max_K_index = min(pos[0] + K - Q + self.window_size_right, K) + min_K_index = max(0, pos[0] + K - Q - self.window_size_left) + tTR_rST[i] = ( + -cutlass.Float32.inf + if pos[1] > max_K_index or pos[1] < min_K_index + else tTR_rST[i] + ) + if cutlass.const_expr(self.is_causal) and ( + pos[0] + K - Q < pos[1] or not cute.elem_less(pos, (Q, K)) + ): + tTR_rST[i] = -cutlass.Float32.inf + if not cute.elem_less(pos, (Q, K)): + tTR_rST[i] = -cutlass.Float32.inf + + log2_e = cutlass.Float32(math.log2(math.e)) + softmax_scale_log2_e = scale_softmax * log2_e + + for i in cutlass.range(0, cute.size(tTR_rST), 2, unroll_full=True): + lse = ( + -sLSE[ + cute.get(tTR_cST[i], mode=[1]), + lse_handle.index, + ], + -sLSE[ + cute.get(tTR_cST[i + 1], mode=[1]), + lse_handle.index, + ], + ) + tTR_rST[i], tTR_rST[i + 1] = cute.arch.fma_packed_f32x2( + (tTR_rST[i], tTR_rST[i + 1]), + (softmax_scale_log2_e, softmax_scale_log2_e), + lse, + ) + tTR_rST[i] = cute.math.exp2(tTR_rST[i], fastmath=True) + tTR_rST[i + 1] = cute.math.exp2(tTR_rST[i + 1], fastmath=True) + + # convert fp32 P to fp16 P which will be used in the PdO + tTR_rPT = self.quantize(tTR_rST, dV.element_type) # tTR_rST is ST in fp32 in RF. + self.reg_to_smem_mma128x128_2cta( + tTR_rPT, + sP, + p_handle.index, + (self.tile_shape_K, self.tile_shape_Q), + dp_idx, + wg_idx, + ) + cute.arch.fence_view_async_shared() + cute.arch.barrier( + barrier_id=self.compute_sync_bar_id, + number_of_threads=self.num_compute_warps * self.threads_per_warp, + ) + + p_handle.commit() + + s_handle.release() + lse_handle.release() + + sum_odo_handle = load_compute_sum_OdO_consumer.wait_and_advance() + dp_handle = mma_compute_dP_consumer.wait_and_advance() + ds_handle = compute_mma_dS_producer.acquire_and_advance() + + # Compute dS = dsoftmax(P, dP, sum_OdO) + cute.copy(tiled_t2r, tTR_tdPT, tTR_rdPT) + + for i in cutlass.range(0, cute.size(tTR_rdPT), 2, unroll_full=True): + dpsum_0 = -sSum_OdO[ + cute.get(tTR_cdPT[i], mode=[1]), + sum_odo_handle.index, + ] + dpsum_1 = -sSum_OdO[ + cute.get(tTR_cdPT[i + 1], mode=[1]), + sum_odo_handle.index, + ] + if cutlass.const_expr(varlen): + if not cute.elem_less(cute.get(tTR_cdPT[i], mode=[1]), Q): + dpsum_0 = 0.0 + if not cute.elem_less(cute.get(tTR_cdPT[i + 1], mode=[1]), Q): + dpsum_1 = 0.0 + tTR_rdPT[i], tTR_rdPT[i + 1] = cute.arch.add_packed_f32x2( + (tTR_rdPT[i], tTR_rdPT[i + 1]), + (dpsum_0, dpsum_1), + ) + tTR_rdPT[i], tTR_rdPT[i + 1] = cute.arch.mul_packed_f32x2( + (tTR_rdPT[i], tTR_rdPT[i + 1]), (tTR_rST[i], tTR_rST[i + 1]) + ) + # For causal, force dS to zero at masked (q,k) so dK/dV accumulation is correct + if cutlass.const_expr(self.is_causal): + for i in cutlass.range(cute.size(tTR_rdPT), unroll_full=True): + c_transpose = tTR_cdPT[i] + pos = ( + cute.get(c_transpose, mode=[1]) + iter_index * self.tile_shape_Q, + cute.get(c_transpose, mode=[0]) + blk_coord_k * self.tile_shape_K, + ) + if pos[0] + K - Q < pos[1] or not cute.elem_less(pos, (Q, K)): + tTR_rdPT[i] = cutlass.Float32(0.0) + # convert fp32 dS to fp16 dS which will be used in the computation of dK and DQ + tTR_rdST = self.quantize(tTR_rdPT, dV.element_type) + + cute.arch.fence_view_async_tmem_load() + dp_handle.release() + + self.reg_to_smem_mma128x128_2cta( + tTR_rdST, + sdST, + ds_handle.index, + (self.tile_shape_K, self.tile_shape_Q), + dp_idx, + wg_idx, + ) + cute.arch.fence_view_async_shared() + cute.arch.barrier( + barrier_id=self.compute_sync_bar_id, + number_of_threads=self.num_compute_warps * self.threads_per_warp, + ) + + ds_handle.commit() + sum_odo_handle.release() + + iter_count -= 1 + iter_index += 1 + if iter_index == iter_end: + iter_index = iter_start + + # Epilogue + mma_compute_dKdV_consumer = self.epilogue( + blk_coord, + blk_offset, + problem_shape, + dK, + dV, + tdKtdK, + tdVtdV, + scale_softmax, + mma_compute_dKdV_producer, + mma_compute_dKdV_consumer, + problem_shape_k_cur_batch, + ) + + if not cutlass.const_expr(self.use_clc_scheduler): + compute_mma_P_producer.tail() + compute_mma_dS_producer.tail() + + return ( + compute_mma_P_producer, + compute_mma_dS_producer, + mma_compute_S_consumer, + compute_mma_P_consumer, + load_compute_LSE_consumer, + load_compute_sum_OdO_consumer, + mma_compute_dP_consumer, + compute_mma_dS_consumer, + mma_compute_dKdV_consumer, + ) + + @cute.jit + def quantize( + self, + input_t: cute.Tensor, + element_dtype: type[cutlass.Numeric], + ) -> cute.Tensor: + """Convert Float32 to element dtype.""" + output = cute.make_rmem_tensor(input_t.shape, element_dtype) + output.store(input_t.load().to(element_dtype)) + return output + + @cute.jit + def store( + self, + gmem: cute.Tensor, + regs: cute.Tensor, + coord: cute.Tensor, + tensor_shape: cute.Shape, + ): + for i in cutlass.range(cute.size(coord, mode=[2]), unroll_full=True): + coord_i = coord[None, 0, i] + gmem_i = gmem[None, 0, i] + regs_i = regs[None, 0, i] + if cute.elem_less(coord_i[0], tensor_shape): + gmem_i.store(regs_i.load().to(gmem.element_type)) + + @cute.jit + def epilogue_clear( + self, + blk_coord: cute.Coord, + blk_offset: cute.Shape, + problem_shape: tuple[Int32, Int32, Int32, tuple[tuple[Int32, Int32], Int32]], + dK: cute.Tensor, + dV: cute.Tensor, + ): + """Early stopping needs to clear dK and dV.""" + tidx, _, _ = cute.arch.thread_idx() + block_dim_x, _, _ = cute.arch.block_dim() + _, K, _, HB = problem_shape + _, blk_coord_k, _, blk_coord_batch = blk_coord + + mdK_offset = cute.assume(blk_offset[1] * dK.stride[0], divby=64) + mdK = cute.make_tensor( + dK.iterator + mdK_offset, + cute.make_layout((K, self.tile_shape_dQ_K, HB), stride=dK.stride), + ) + gdK = cute.local_tile( + mdK, (self.dSQ_mma_tiler[0], self.dSQ_mma_tiler[1]), (None, None, None) + ) + gdK = gdK[None, None, blk_coord_k, 0, blk_coord_batch] + cdK = cute.domain_offset( + (blk_coord_k * self.tile_shape_K, 0), + cute.make_identity_tensor((self.dSQ_mma_tiler[0], self.dSQ_mma_tiler[1])), + ) + + mdV_offset = cute.assume(blk_offset[1] * dV.stride[0], divby=64) + mdV = cute.make_tensor( + dV.iterator + mdV_offset, + cute.make_layout((K, self.tile_shape_dV_dO, HB), stride=dV.stride), + ) + gdV = cute.local_tile( + mdV, (self.PdO_mma_tiler[0], self.PdO_mma_tiler[1]), (None, None, None) + ) + gdV = gdV[None, None, blk_coord_k, 0, blk_coord_batch] + cdV = cute.domain_offset( + (blk_coord_k * self.tile_shape_K, 0), + cute.make_identity_tensor((self.PdO_mma_tiler[0], self.PdO_mma_tiler[1])), + ) + + for i in cutlass.range(tidx * 8, cute.size(gdK), block_dim_x * 8): + if cute.elem_less(cdK[i], cute.select(problem_shape, mode=[1, 2])): + gdK_i = cute.make_tensor(gdK.iterator + cute.assume(i, divby=8), (8)) + gdK_i.fill(0) + + for i in cutlass.range(tidx * 8, cute.size(gdV), block_dim_x * 8): + if cute.elem_less(cdV[i], cute.select(problem_shape, mode=[1, 2])): + gdV_i = cute.make_tensor(gdV.iterator + cute.assume(i, divby=8), (8)) + gdV_i.fill(0) + + @cute.jit + def epilogue( + self, + blk_coord: cute.Coord, + blk_offset: cute.Shape, + problem_shape: tuple[Int32, Int32, Int32, tuple[tuple[Int32, Int32], Int32]], + dK: cute.Tensor, + dV: cute.Tensor, + tdKtdK: cute.Tensor, + tdVtdV: cute.Tensor, + scale_softmax: cutlass.Float32, + mma_compute_dKdV_producer, + mma_compute_dKdV_consumer, + problem_shape_k_cur_batch: Int32, + ): + """Epilogue phase to store result from tensor memory to register, then global memory.""" + tidx, _, _ = cute.arch.thread_idx() + _, K, D, HB = problem_shape + _, blk_coord_k, _, blk_coord_batch = blk_coord + + # adi: TMEM_DK, TMEM_DV + tmem_copy_op = tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)) + load_op = cute.make_copy_atom( + tmem_copy_op, + self.acc_dtype, + ) + + tdKtdK = tdKtdK[(None, None), 0, 0] + mdK_offset = cute.assume(blk_offset[1] * dK.stride[0], divby=64) + mdK = cute.make_tensor( + dK.iterator + mdK_offset, + cute.make_layout((K, self.tile_shape_dQ_K, HB), stride=dK.stride), + ) + gdK = cute.local_tile(mdK, (self.cta_tiler[1], self.cta_tiler[2]), (None, None, None)) + gdK = gdK[None, None, blk_coord_k, 0, blk_coord_batch] + cdK = cute.domain_offset( + (blk_coord_k * self.tile_shape_K, 0), + cute.make_identity_tensor((self.cta_tiler[1], self.cta_tiler[2])), + ) + + num_warp_groups = self.num_compute_warps // 4 + dp_idx = tidx % 128 + wg_idx = (tidx % (self.num_compute_warps * self.threads_per_warp)) // 128 + + tiled_t2r_dK = tcgen05.make_tmem_copy(load_op, tdKtdK) + thread_t2r_dK = tiled_t2r_dK.get_slice(dp_idx) + + tTR_cdK = thread_t2r_dK.partition_D(cdK) + tTR_cdK = split_wg(tTR_cdK, num_warp_groups, wg_idx) + tTR_gdK = thread_t2r_dK.partition_D(gdK) + tTR_gdK = split_wg(tTR_gdK, num_warp_groups, wg_idx) + tTR_rdK = cute.make_rmem_tensor(tTR_cdK.shape, self.acc_dtype) + tTR_tdK = thread_t2r_dK.partition_S(tdKtdK) + tTR_tdK = split_wg(tTR_tdK, num_warp_groups, wg_idx) + + mdV_in = cute.make_tensor( + dV.iterator, cute.make_layout((K, self.cta_tiler[2], HB), stride=dV.stride) + ) + offset_mdV = cute.assume(blk_offset[1] * mdV_in.stride[0], divby=64) + mdV = cute.make_tensor(mdV_in.iterator + offset_mdV, mdV_in.layout) + gdV = cute.local_tile(mdV, (self.cta_tiler[1], self.cta_tiler[2]), (None, None, None)) + gdV = gdV[None, None, blk_coord_k, 0, blk_coord_batch] + + cdV = cute.domain_offset( + (blk_coord_k * self.cta_tiler[1], 0), + cute.make_identity_tensor((self.cta_tiler[1], self.cta_tiler[2])), + ) + + tdVtdV = tdVtdV[(None, None), 0, 0] + + tiled_t2r_dV = tcgen05.make_tmem_copy(load_op, tdVtdV) + thread_t2r_dV = tiled_t2r_dV.get_slice(dp_idx) + + tTR_cdV = thread_t2r_dV.partition_D(cdV) + tTR_cdV = split_wg(tTR_cdV, num_warp_groups, wg_idx) + tTR_gdV = thread_t2r_dV.partition_D(gdV) + tTR_gdV = split_wg(tTR_gdV, num_warp_groups, wg_idx) + tTR_rdV = cute.make_rmem_tensor(tTR_cdV.shape, self.acc_dtype) + tTR_tdV = thread_t2r_dV.partition_S(tdVtdV) + tTR_tdV = split_wg(tTR_tdV, num_warp_groups, wg_idx) + + dkdv_handle = mma_compute_dKdV_consumer.wait_and_advance() + + if blk_coord_k * self.tile_shape_K < problem_shape_k_cur_batch: + cute.copy(tiled_t2r_dV, tTR_tdV, tTR_rdV) + self.store(tTR_gdV, tTR_rdV, tTR_cdV, (K, D)) + + cute.arch.fence_view_async_tmem_load() + dkdv_handle.release() + + dkdv_handle = mma_compute_dKdV_consumer.wait_and_advance() + + if blk_coord_k * self.tile_shape_K < problem_shape_k_cur_batch: + cute.copy(tiled_t2r_dK, tTR_tdK, tTR_rdK) + + for i in cutlass.range(cute.size(tTR_rdK), unroll_full=True): + tTR_rdK[i] = scale_softmax * tTR_rdK[i] + + self.store(tTR_gdK, tTR_rdK, tTR_cdK, (K, D)) + + cute.arch.fence_view_async_tmem_load() + dkdv_handle.release() + + return mma_compute_dKdV_consumer + + def get_workspace_tensor( + self, + problem_shape: tuple[Int32, Int32, Int32, tuple[tuple[Int32, Int32], Int32]], + workspace: cute.Tensor, + acc_dtype: type[cutlass.Numeric], + varlen: bool, + ) -> tuple[cute.Tensor, cute.Tensor, cute.Tensor]: + """Get workspace tensor.""" + D = problem_shape[2] + H, B = cute.size(problem_shape[3][0]), cute.size(problem_shape[3][1]) + H_r, H_k = problem_shape[3][0] + D = cute.round_up(D, 8) + + # b = 1 for varlen, else batch_size + b = workspace.shape[0] + # s_q_sum for varlen, else s_q_max, already rounded to 8 + S_Q = workspace.shape[1] + + acc_bytes = acc_dtype.width // 8 + sum_OdO_bytes = cute.assume(b * H * S_Q * acc_bytes, divby=acc_bytes) + scaled_lse_bytes = cute.assume(b * H * S_Q * acc_bytes, divby=acc_bytes) + + sum_OdO_iter = workspace.iterator + scaled_lse_iter = sum_OdO_iter + sum_OdO_bytes + + sum_OdO_iter = cute.recast_ptr(sum_OdO_iter, dtype=self.acc_dtype) + scaled_lse_iter = cute.recast_ptr(scaled_lse_iter, dtype=self.acc_dtype) + + sum_OdO = cute.make_tensor( + sum_OdO_iter, + cute.make_layout( + (S_Q, ((H_r, H_k), B)), + stride=(1, ((S_Q, S_Q * H_r), 0 if varlen else S_Q * H)), + ), + ) + scaled_lse = cute.make_tensor( + scaled_lse_iter, + cute.make_layout( + (S_Q, ((H_r, H_k), B)), + stride=(1, ((S_Q, S_Q * H_r), 0 if varlen else S_Q * H)), + ), + ) + + return sum_OdO, scaled_lse + + @staticmethod + def _compute_sum_OdO_grid( + problem_shape: tuple[Int32, Int32, Int32, tuple[tuple[Int32, Int32], Int32]], + block_q: int, + ) -> tuple[Int32, Int32, Int32]: + """Compute grid shape for sum_OdO kernel.""" + return ( + cute.ceil_div(cute.size(problem_shape[0]), block_q), + cute.size(problem_shape[3][0]), # H + cute.size(problem_shape[3][1]), # B + ) + + @staticmethod + def _compute_bwd_grid( + problem_shape: tuple[Int32, Int32, Int32, tuple[tuple[Int32, Int32], Int32]], + block_k: int, + ) -> tuple[Int32, Int32, Int32]: + """Compute grid shape for bwd kernel.""" + K = problem_shape[1] + _, H_K = problem_shape[3][0] + B = problem_shape[3][1] + return (cute.ceil_div(K, block_k), cute.size(H_K), cute.size(B)) + + @staticmethod + def get_workspace_size(s_q: int, d: int, h: int, b: int, acc_dtype: type[cutlass.Numeric]): + """Get workspace size.""" + d = (d + 7) // 8 * 8 # round up to 8 + s_q = (s_q + 7) // 8 * 8 # round up to 8 + workspace_bytes = 0 + # OdO vector + workspace_bytes += acc_dtype.width // 8 + # scaled LSE vector + workspace_bytes += acc_dtype.width // 8 + # FP32 versions of outputs that are churned (start off with Q only) + workspace_bytes += d * acc_dtype.width // 8 + return (b, s_q, h, workspace_bytes) + + def make_and_init_load_mma_K_pipeline(self, load_mma_K_mbar_ptr, cluster_layout_vmnk): + """Create and initialize barrier for load mma Q.""" + load_mma_K_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.load_warp_id]) + ) + load_mma_K_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + return pipeline.PipelineTmaUmma.create( + barrier_storage=load_mma_K_mbar_ptr, + num_stages=self.load_mma_K_stage, + producer_group=load_mma_K_producer_group, + consumer_group=load_mma_K_consumer_group, + tx_count=self.tma_copy_K_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + def make_and_init_load_mma_V_pipeline(self, load_mma_V_mbar_ptr, cluster_layout_vmnk): + """Create and initialize barrier for load mma Q.""" + load_mma_V_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.load_warp_id]) + ) + load_mma_V_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + return pipeline.PipelineTmaUmma.create( + barrier_storage=load_mma_V_mbar_ptr, + num_stages=self.load_mma_V_stage, + producer_group=load_mma_V_producer_group, + consumer_group=load_mma_V_consumer_group, + tx_count=self.tma_copy_V_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + def make_and_init_load_mma_Q_pipeline(self, load_mma_Q_mbar_ptr, cluster_layout_vmnk): + """Create and initialize barrier for load mma Q.""" + load_mma_Q_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.load_warp_id]) + ) + load_mma_Q_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + return pipeline.PipelineTmaUmma.create( + barrier_storage=load_mma_Q_mbar_ptr, + num_stages=self.load_mma_Q_stage, + producer_group=load_mma_Q_producer_group, + consumer_group=load_mma_Q_consumer_group, + tx_count=self.tma_copy_Q_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + def make_and_init_load_mma_QT_pipeline(self, load_mma_QT_mbar_ptr, cluster_layout_vmnk): + """Create and initialize barrier for load mma QT.""" + load_mma_QT_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.load_warp_id]) + ) + load_mma_QT_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + return pipeline.PipelineTmaUmma.create( + barrier_storage=load_mma_QT_mbar_ptr, + num_stages=self.load_mma_QT_stage, + producer_group=load_mma_QT_producer_group, + consumer_group=load_mma_QT_consumer_group, + tx_count=self.tma_copy_Q_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + def make_and_init_load_mma_dO_pipeline(self, load_mma_dO_mbar_ptr, cluster_layout_vmnk): + """Create and initialize barrier for load mma dO.""" + load_mma_dO_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.load_warp_id]) + ) + load_mma_dO_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + return pipeline.PipelineTmaUmma.create( + barrier_storage=load_mma_dO_mbar_ptr, + num_stages=self.load_mma_dO_stage, + producer_group=load_mma_dO_producer_group, + consumer_group=load_mma_dO_consumer_group, + tx_count=self.tma_copy_dO_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + def make_and_init_load_compute_LSE_pipeline(self, load_compute_lse_mbar_ptr): + """Create and initialize barrier for load compute lse.""" + load_compute_lse_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp, + # self.threads_per_warp, + ) + load_compute_lse_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * self.num_compute_warps, + # self.threads_per_warp * self.num_compute_warps, + ) + return pipeline.PipelineCpAsync.create( + barrier_storage=load_compute_lse_mbar_ptr, + num_stages=self.load_compute_LSE_stage, + producer_group=load_compute_lse_producer_group, + consumer_group=load_compute_lse_consumer_group, + ) + + def make_and_init_load_compute_sum_OdO_pipeline(self, load_compute_sum_OdO_mbar_ptr): + """Create and initialize barrier for load sum OdO.""" + load_compute_sum_OdO_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp, + # self.threads_per_warp, + ) + load_compute_sum_OdO_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * self.num_compute_warps, + # self.threads_per_warp * self.num_compute_warps, + ) + return pipeline.PipelineCpAsync.create( + barrier_storage=load_compute_sum_OdO_mbar_ptr, + num_stages=self.load_compute_sum_OdO_stage, + producer_group=load_compute_sum_OdO_producer_group, + consumer_group=load_compute_sum_OdO_consumer_group, + ) + + def make_and_init_mma_compute_S_pipeline(self, mma_compute_S_mbar_ptr, cluster_layout_vmnk): + """Create and initialize barrier for mma S.""" + mma_compute_S_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + len([self.mma_warp_id]), + ) + mma_compute_S_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.num_compute_warps * self.threads_per_warp * cluster_layout_vmnk.shape[0][0], + ) + return pipeline.PipelineUmmaAsync.create( + barrier_storage=mma_compute_S_mbar_ptr, + num_stages=self.mma_compute_S_stage, + producer_group=mma_compute_S_producer_group, + consumer_group=mma_compute_S_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + # Barrier to between dP = v * dO and consume of dP in compute() + def make_and_init_mma_compute_dP_pipeline(self, mma_compute_dP_mbar_ptr, cluster_layout_vmnk): + """Create and initialize barrier for mma Q.""" + mma_compute_dP_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + len([self.mma_warp_id]), + ) + mma_compute_dP_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.num_compute_warps * self.threads_per_warp * cluster_layout_vmnk.shape[0][0], + ) + return pipeline.PipelineUmmaAsync.create( + barrier_storage=mma_compute_dP_mbar_ptr, + num_stages=self.mma_compute_dP_stage, + producer_group=mma_compute_dP_producer_group, + consumer_group=mma_compute_dP_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + def make_and_init_compute_mma_P_pipeline(self, compute_mma_P_mbar_ptr, cluster_layout_vmnk): + """Create and initialize barrier for mma P.""" + compute_mma_P_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.num_compute_warps * self.threads_per_warp * cluster_layout_vmnk.shape[0][0], + # self.num_compute_warps * self.threads_per_warp, + ) + compute_mma_P_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + len([self.mma_warp_id]), + ) + return pipeline.PipelineAsyncUmma.create( + barrier_storage=compute_mma_P_mbar_ptr, + num_stages=self.compute_mma_P_stage, + producer_group=compute_mma_P_producer_group, + consumer_group=compute_mma_P_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + def make_and_init_compute_mma_dS_pipeline(self, compute_mma_dS_mbar_ptr, cluster_layout_vmnk): + """Create and initialize barrier for mma dS.""" + + compute_mma_dS_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.num_compute_warps * self.threads_per_warp * cluster_layout_vmnk.shape[0][0], + # self.num_compute_warps * self.threads_per_warp, + ) + compute_mma_dS_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + len([self.mma_warp_id]), + ) + + return pipeline.PipelineAsyncUmma.create( + barrier_storage=compute_mma_dS_mbar_ptr, + num_stages=self.compute_mma_dS_stage, + producer_group=compute_mma_dS_producer_group, + consumer_group=compute_mma_dS_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + def make_and_init_mma_compute_dKdV_pipeline( + self, mma_compute_dKdV_mbar_ptr, cluster_layout_vmnk + ): + """Create and initialize barrier for mma dKdV.""" + mma_compute_dKdV_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + len([self.mma_warp_id]), + ) + mma_compute_dKdV_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.num_compute_warps * self.threads_per_warp * cluster_layout_vmnk.shape[0][0], + # self.num_compute_warps * self.threads_per_warp, + ) + return pipeline.PipelineUmmaAsync.create( + barrier_storage=mma_compute_dKdV_mbar_ptr, + num_stages=self.mma_compute_dKdV_stage, + producer_group=mma_compute_dKdV_producer_group, + consumer_group=mma_compute_dKdV_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + ) diff --git a/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dqkernel.py b/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dqkernel.py new file mode 100644 index 00000000000..346f3c1b90d --- /dev/null +++ b/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dqkernel.py @@ -0,0 +1,2145 @@ +# Copyright (c) 2025, Siyu Wang, Shengbin Di, Yuxi Chi, Johnsonms, Linfeng Zheng, Haoyan Huang, Lanbo Li, Yun Zhong, Man Yuan, Minmin Sun, Yong Li, Wei Lin. + +from typing import Type, Tuple, Optional + +import cuda.bindings.driver as cuda + +import math +import cutlass +import cutlass.cute as cute +import cutlass.cute.nvgpu.tcgen05 as tcgen05 +from cutlass.cute.nvgpu import cpasync +import cutlass.utils as utils +import cutlass.pipeline as pipeline +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cute.typing import Int32, Int64, Float32 + +from cutlass.utils import ClcDynamicPersistentTileScheduler +from flash_attn.cute.tile_scheduler import ( + ClcState, + compute_sm100_fmha_grid as compute_grid, + compute_sm100_fmha_grid_clc as compute_grid_clc, + make_sm100_thread_cooperative_group as make_thread_cooperative_group, + Sm100FmhaStaticTileScheduler as FmhaStaticTileScheduler, + Sm100FmhaStaticTileSchedulerParams as FmhaStaticTileSchedulerParams, + Sm100FmhaClcDynamicTileScheduler as FmhaClcDynamicTileScheduler, + Sm100FmhaClcDynamicTileSchedulerParams as FmhaClcDynamicTileSchedulerParams, +) +from flash_attn.cute.mask import ( + Sm100FusedMask as FusedMask, +) +from flash_attn.cute.tile_scheduler import SM100_TMEM_CAPACITY_COLUMNS + + +class BlackwellFusedMultiHeadAttentionBackwardDQKernel: + def __init__( + self, + acc_dtype: Type[cutlass.Numeric], + mma_tiler: Tuple[int, int, int], + is_causal: bool, + window_size_left: int | None, + window_size_right: int | None, + is_persistent: bool, + split_head: bool, + use_clc_scheduler: bool = False, + ): + self.acc_dtype = acc_dtype + self.mma_tiler = mma_tiler + self.is_causal = is_causal + self.window_size_left = window_size_left + # Keep original behavior (known-good in this repo) + window_size_left = ( + None + if (window_size_left is None or window_size_left < 0) + else cutlass.Int32(window_size_left) + ) + window_size_right = ( + None + if (window_size_right is None or window_size_right < 0) + else cutlass.Int32(window_size_right) + ) + self.window_size_left = None if self.is_causal else window_size_left + self.window_size_right = cutlass.Int32(0) if self.is_causal else window_size_right + self.is_local = (not self.is_causal) and ( + self.window_size_left is not None or self.window_size_right is not None + ) + assert mma_tiler[0] == 128 and mma_tiler[1] == 128, "Only 128x128 tile impl is supported" + assert mma_tiler[2] == 256, "Only 256 is supported for 128x128 tile impl" + self.cta_tiler = ( + mma_tiler[0], + mma_tiler[1], + mma_tiler[2], + ) + self.qk_mma_tiler = ( + 2 * mma_tiler[0], + mma_tiler[1], + min(self.cta_tiler[2], 128) if split_head else self.cta_tiler[2], + ) + self.dov_mma_tiler = self.qk_mma_tiler + self.dsk_mma_tiler = ( + 2 * mma_tiler[0], + min(self.cta_tiler[2], 128) if split_head else self.cta_tiler[2], + mma_tiler[1], + ) + + self.dsk_block_tiler = ( + self.dsk_mma_tiler[0] // 2, + self.dsk_mma_tiler[1], + self.dsk_mma_tiler[2], + ) + self.iterations_qk = self.cta_tiler[2] // self.qk_mma_tiler[2] + self.iterations_dov = self.cta_tiler[2] // self.dov_mma_tiler[2] + self.iterations_dsk = self.cta_tiler[2] // self.dsk_mma_tiler[1] + self.cluster_shape_mn = (2, 1) + self.tmem_warp_shape_mn = (4, 1) + self.is_persistent = is_persistent + self.use_clc_scheduler = use_clc_scheduler + self.use_semantic_trip_range = self.is_causal or self.is_local + + self.compute_warp_ids = (0, 1, 2, 3) + self.epilogue_warp_ids = (4, 5, 6, 7) + self.mma_warp_id = 8 + self.load_warp_id = 9 + self.empty_warp_id = (10, 11) + self.sched_warp_id = self.empty_warp_id[0] if use_clc_scheduler else None + self.tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS + self.num_compute_warps = len(self.compute_warp_ids) + + self.cta_sync_bar_id = 0 + self.tmem_alloc_sync_bar_id = 1 + self.compute_sync_bar_id = 2 + + self.threads_per_warp = 32 + self.threads_per_cta = self.threads_per_warp * len( + ( + *self.compute_warp_ids, # this is to get a round num threads + *self.epilogue_warp_ids, + self.mma_warp_id, + self.load_warp_id, + *self.empty_warp_id, + ) + ) + + self.tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=1, + num_threads=self.threads_per_cta, + ) + + self.tmem_s_offset = 0 + self.tmem_dp_offset = 128 + self.tmem_dq_offset = 256 + + self.num_regs_compute = 256 + self.num_regs_epilogue = 160 + self.num_regs_other = 32 + + self.buffer_align_bytes = 1024 + + def _setup_attributes(self): + self.q_stage = self.iterations_qk + self.k_stage = self.iterations_qk + self.do_stage = self.iterations_dov + self.v_stage = self.iterations_dov + self.kt_stage = 1 + self.qk_acc_stage = 1 + self.dov_acc_stage = 1 + self.dsk_acc_stage = 1 + self.epi_stage = 1 + self.load_compute_LSE_stage = 1 + self.load_compute_sum_OdO_stage = 1 + if cutlass.const_expr(self.use_clc_scheduler): + self.num_clc_stage = 1 + self.num_clc_response_bytes = 16 + + @cute.jit + def __call__( + self, + q_tensor: cute.Tensor, + k_tensor: cute.Tensor, + v_tensor: cute.Tensor, + dq_tensor: cute.Tensor, + do_tensor: cute.Tensor, + lse_tensor: cute.Tensor, + sum_odo_tensor: cute.Tensor, + cum_seqlen_q: Optional[cute.Tensor], + cum_seqlen_k: Optional[cute.Tensor], + scale_softmax: cutlass.Float32, + stream: cuda.CUstream, + ): + varlen = cum_seqlen_q is not None or cum_seqlen_k is not None + # Infer shape metadata from normalized 5D tensors (B, S, H_k, H_r, D), + # similar to the dedicated hd256 forward path. + s_q = q_tensor.shape[1] + s_k = k_tensor.shape[1] + d = q_tensor.shape[4] + h_k = q_tensor.shape[2] + h_r = q_tensor.shape[3] + if cutlass.const_expr(cum_seqlen_q is not None): + b = cum_seqlen_q.shape[0] - 1 + elif cutlass.const_expr(cum_seqlen_k is not None): + b = cum_seqlen_k.shape[0] - 1 + else: + b = q_tensor.shape[0] + # `lse_tensor` / `sum_odo_tensor` are preallocated FP32 buffers (lse_log2 / dpsum) + # whose sequence dimension is padded (rounded up) by the caller. + # Use their leading dimension as the LSE length to ensure correct batch strides. + s_lse = lse_tensor.shape[0] + s_q64 = Int64(s_q) + s_k64 = Int64(s_k) + s_lse64 = Int64(s_lse) + d64 = cute.assume(Int64(d), divby=128) + h_r64 = Int64(h_r) + h_k64 = Int64(h_k) + b64 = Int64(b) + # Packed-varlen representation uses batch-dim = 1 and sequence-dim = total_{q,k}. + # Keep the *physical* sequence extent in the tensor layouts so that applying + # `cuseqlen_*` offsets stays within the tensor domain. + s_q_total = q_tensor.shape[1] if cum_seqlen_q is not None else s_q64 + s_k_total = k_tensor.shape[1] if cum_seqlen_k is not None else s_k64 + stride_b_qo = h_r64 * h_k64 * s_q64 * d64 if cum_seqlen_q is None else 0 + stride_b_kv = h_k64 * s_k64 * d64 if cum_seqlen_k is None else 0 + b_lse = b64 if cum_seqlen_q is None else 1 + stride_b_lse = h_r64 * h_k64 * s_lse64 if cum_seqlen_q is None else 0 + + # (s, d, ((h_r, h_k), b)) + q_layout = cute.make_layout( + (s_q_total, d, ((h_r, h_k), b)), + stride=(d64 * h_r64 * h_k64, 1, ((d64, d64 * h_r64), stride_b_qo)), + ) + q = cute.make_tensor(q_tensor.iterator, q_layout) + # (s, d, ((h_r, h_k), b)) + do_layout = cute.make_layout( + (s_q_total, d, ((h_r, h_k), b)), + stride=(d64 * h_r64 * h_k64, 1, ((d64, d64 * h_r64), stride_b_qo)), + ) + do = cute.make_tensor(do_tensor.iterator, do_layout) + # (s, d, ((h_r, h_k), b)), 0-stride for h_r to broadcast + k_layout = cute.make_layout( + (s_k_total, d, ((h_r, h_k), b)), + stride=(d64 * h_k64, 1, ((0, d64), stride_b_kv)), + ) + k = cute.make_tensor(k_tensor.iterator, k_layout) + # (d, s, ((h_r, h_k), b)), 0-stride for h_r to broadcast + kt_layout = cute.make_layout( + (d, s_k_total, ((h_r, h_k), b)), + stride=(1, d64 * h_k64, ((0, d64), stride_b_kv)), + ) + kt = cute.make_tensor(k_tensor.iterator, kt_layout) + # (s, d, ((h_r, h_k), b)), 0-stride for h_r to broadcast + v_layout = cute.make_layout( + (s_k_total, d, ((h_r, h_k), b)), + stride=(d64 * h_k64, 1, ((0, d64), stride_b_kv)), + ) + v = cute.make_tensor(v_tensor.iterator, v_layout) + # (s, ((h_r, h_k), b)) + lse = cute.make_tensor(lse_tensor.iterator, lse_tensor.layout) + # (s, ((h_r, h_k), b)) + sum_odo_layout = cute.make_layout( + (s_lse64, ((h_r, h_k), b_lse)), + stride=(1, ((s_lse64, h_r64 * s_lse64), stride_b_lse)), + ) + sum_odo = cute.make_tensor(sum_odo_tensor.iterator, sum_odo_layout) + # (s, d, ((h_r, h_k), b)) + dq_layout = cute.make_layout( + (s_q_total, d, ((h_r, h_k), b)), + stride=(d64 * h_r64 * h_k64, 1, ((d64, d64 * h_r64), stride_b_qo)), + ) + dq = cute.make_tensor(dq_tensor.iterator, dq_layout) + + # setup static attributes before smem/grid/tma computation + self.q_dtype = q.element_type + self.k_dtype = k.element_type + self.v_dtype = v.element_type + self.do_dtype = do.element_type + self.dq_dtype = dq.element_type + self.tilePlikeFP32 = self.qk_mma_tiler[1] // Float32.width * self.q_dtype.width + + if cutlass.const_expr(self.use_clc_scheduler): + self.tile_sched_params, grid = compute_grid_clc( + (s_q, dq.shape[1], dq.shape[2]) if cum_seqlen_q is not None else dq.shape, + self.cta_tiler, + (*self.cluster_shape_mn, 1), + ) + else: + self.tile_sched_params, grid = compute_grid( + (s_q, dq.shape[1], dq.shape[2]) if cum_seqlen_q is not None else dq.shape, + self.cta_tiler, + self.is_persistent, + ) + + self.q_major_mode = utils.LayoutEnum.from_tensor(q).mma_major_mode() + self.do_major_mode = utils.LayoutEnum.from_tensor(do).mma_major_mode() + self.k_major_mode = utils.LayoutEnum.from_tensor(k).mma_major_mode() + self.v_major_mode = utils.LayoutEnum.from_tensor(v).mma_major_mode() + self.dq_layout = utils.LayoutEnum.from_tensor(dq) + + if cutlass.const_expr(self.q_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of q is not supported") + if cutlass.const_expr(self.k_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of k is not supported") + if cutlass.const_expr(self.v_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of v is not supported") + if cutlass.const_expr(self.do_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of v is not supported") + + # check type consistency + if cutlass.const_expr(self.q_dtype != self.k_dtype): + raise TypeError(f"Type mismatch: {self.q_dtype} != {self.k_dtype}") + if cutlass.const_expr(self.q_dtype != self.v_dtype): + raise TypeError(f"Type mismatch: {self.q_dtype} != {self.v_dtype}") + if cutlass.const_expr(self.q_dtype != self.do_dtype): + raise TypeError(f"Type mismatch: {self.q_dtype} != {self.do_dtype}") + + self._setup_attributes() + + cta_group = tcgen05.CtaGroup.TWO + # the intermediate tensor p is from tmem & k-major + ds_source = tcgen05.OperandSource.TMEM + ds_major_mode = tcgen05.OperandMajorMode.K + k_trans_major_mode = tcgen05.OperandMajorMode.MN + qk_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.q_dtype, + self.q_major_mode, + self.k_major_mode, + self.acc_dtype, + cta_group, + self.qk_mma_tiler[:2], + ) + dov_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.do_dtype, + self.do_major_mode, + self.v_major_mode, + self.acc_dtype, + cta_group, + self.dov_mma_tiler[:2], + ) + dsk_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.q_dtype, + ds_major_mode, + k_trans_major_mode, + self.acc_dtype, + cta_group, + self.dsk_mma_tiler[:2], + ds_source, + ) + + self.cluster_shape_mnk = (*self.cluster_shape_mn, 1) + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), + (qk_tiled_mma.thr_id.shape,), + ) + + self.epi_tile = self.dsk_block_tiler[:2] + + q_smem_layout_staged = sm100_utils.make_smem_layout_a( + qk_tiled_mma, + self.qk_mma_tiler, + self.q_dtype, + self.q_stage, + ) + k_smem_layout_staged = sm100_utils.make_smem_layout_b( + qk_tiled_mma, + self.qk_mma_tiler, + self.k_dtype, + self.k_stage, + ) + do_smem_layout_staged = sm100_utils.make_smem_layout_a( + dov_tiled_mma, + self.dov_mma_tiler, + self.do_dtype, + self.do_stage, + ) + v_smem_layout_staged = sm100_utils.make_smem_layout_b( + dov_tiled_mma, + self.dov_mma_tiler, + self.v_dtype, + self.v_stage, + ) + ds_tmem_layout_staged = sm100_utils.make_smem_layout_a( + dsk_tiled_mma, + self.dsk_mma_tiler, + self.q_dtype, + self.qk_acc_stage, + ) + ds_tmem_layout = cute.select(ds_tmem_layout_staged, mode=[0, 1, 2]) + kt_smem_layout_staged = sm100_utils.make_smem_layout_b( + dsk_tiled_mma, + self.dsk_mma_tiler, + self.k_dtype, + self.dsk_acc_stage, + ) + + # TMA load for Q + tma_load_op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp(cta_group) + + q_smem_layout = cute.select(q_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_q, tma_tensor_q = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + q, + q_smem_layout, + self.qk_mma_tiler, + qk_tiled_mma, + self.cluster_layout_vmnk.shape, + ) + # TMA load for K + k_smem_layout = cute.select(k_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_k, tma_tensor_k = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + k, + k_smem_layout, + self.qk_mma_tiler, + qk_tiled_mma, + self.cluster_layout_vmnk.shape, + ) + # TMA load for dO + do_smem_layout = cute.select(do_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_do, tma_tensor_do = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + do, + do_smem_layout, + self.dov_mma_tiler, + dov_tiled_mma, + self.cluster_layout_vmnk.shape, + ) + # TMA load for V + v_smem_layout = cute.select(v_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_v, tma_tensor_v = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + v, + v_smem_layout, + self.dov_mma_tiler, + dov_tiled_mma, + self.cluster_layout_vmnk.shape, + ) + # TMA load for KT + kt_smem_layout = cute.select(kt_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_kt, tma_tensor_kt = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + kt, + kt_smem_layout, + self.dsk_mma_tiler, + dsk_tiled_mma, + self.cluster_layout_vmnk.shape, + ) + lse_smem_layout = cute.make_layout((self.cta_tiler[0], self.load_compute_LSE_stage)) + sum_odo_smem_layout = cute.make_layout((self.cta_tiler[0], self.load_compute_sum_OdO_stage)) + + q_copy_size = cute.size_in_bytes(self.q_dtype, q_smem_layout) + k_copy_size = cute.size_in_bytes(self.k_dtype, k_smem_layout) + do_copy_size = cute.size_in_bytes(self.do_dtype, do_smem_layout) + v_copy_size = cute.size_in_bytes(self.v_dtype, v_smem_layout) + kt_copy_size = cute.size_in_bytes(self.k_dtype, kt_smem_layout) + self.tma_copy_q_bytes = q_copy_size * cute.size(qk_tiled_mma.thr_id.shape) + self.tma_copy_k_bytes = k_copy_size * cute.size(qk_tiled_mma.thr_id.shape) + self.tma_copy_do_bytes = do_copy_size * cute.size(qk_tiled_mma.thr_id.shape) + self.tma_copy_v_bytes = v_copy_size * cute.size(qk_tiled_mma.thr_id.shape) + self.tma_copy_kt_bytes = kt_copy_size * cute.size(qk_tiled_mma.thr_id.shape) + + @cute.struct + class SharedStorage: + # TMA G2S load barriers: LOAD warp (producer) -> MMA warp (consumer) + load_q_mbar_ptr: cute.struct.MemRange[ + Int64, self.q_stage * 2 + ] # load_q_{producer,consumer} + load_do_mbar_ptr: cute.struct.MemRange[ + Int64, self.do_stage * 2 + ] # load_do_{producer,consumer} + load_k_mbar_ptr: cute.struct.MemRange[ + Int64, self.k_stage * 2 + ] # load_k_{producer,consumer} + load_kt_mbar_ptr: cute.struct.MemRange[ + Int64, self.kt_stage * 2 + ] # load_kt_{producer,consumer} + load_v_mbar_ptr: cute.struct.MemRange[ + Int64, self.v_stage * 2 + ] # load_v_{producer,consumer} + mma_s_mbar_ptr: cute.struct.MemRange[Int64, self.qk_acc_stage * 2] + mma_dp_mbar_ptr: cute.struct.MemRange[Int64, self.dov_acc_stage * 2] + mma_dq_mbar_ptr: cute.struct.MemRange[Int64, self.epi_stage * 2] + ds_mma_mbar_ptr: cute.struct.MemRange[Int64, self.dsk_acc_stage * 2] + lse_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.load_compute_LSE_stage * 2] + sum_odo_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.load_compute_sum_OdO_stage * 2 + ] + # A CTA-wide "TMEM lifetime" barrier used to safely deallocate TMEM after all users finish. + tmem_dealloc_mbar_ptr: Int64 + # Tmem holding buffer + tmem_holding_buf: Int32 + # CLC pipeline barriers and response buffer + clc_mbar_ptr: cute.struct.MemRange[Int64, 2] + clc_response: cute.struct.MemRange[Int32, 4] + + self.shared_storage = SharedStorage + + grid = cute.round_up(grid, self.cluster_shape_mnk) + # Launch the kernel synchronously + self.kernel( + qk_tiled_mma, + dov_tiled_mma, + dsk_tiled_mma, + tma_atom_q, + tma_tensor_q, + tma_atom_k, + tma_tensor_k, + tma_atom_v, + tma_tensor_v, + tma_atom_do, + tma_tensor_do, + tma_atom_kt, + tma_tensor_kt, + lse, + sum_odo, + dq, + cum_seqlen_q, + cum_seqlen_k, + scale_softmax, + self.window_size_left, + self.window_size_right, + self.cluster_layout_vmnk, + q_smem_layout_staged, + k_smem_layout_staged, + v_smem_layout_staged, + do_smem_layout_staged, + kt_smem_layout_staged, + ds_tmem_layout, + lse_smem_layout, + sum_odo_smem_layout, + self.tile_sched_params, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=self.cluster_shape_mnk, + stream=stream, + min_blocks_per_mp=1, + ) + + # GPU device kernel + @cute.kernel + def kernel( + self, + qk_tiled_mma: cute.TiledMma, + dov_tiled_mma: cute.TiledMma, + dsk_tiled_mma: cute.TiledMma, + tma_atom_q: cute.CopyAtom, + mQ_qdl: cute.Tensor, + tma_atom_k: cute.CopyAtom, + mK_kdl: cute.Tensor, + tma_atom_v: cute.CopyAtom, + mV_dkl: cute.Tensor, + tma_atom_do: cute.CopyAtom, + mdO_qdl: cute.Tensor, + tma_atom_kt: cute.CopyAtom, + mK_dkl: cute.Tensor, + mLSE: cute.Tensor, + mSum_OdO: cute.Tensor, + mdQ_qdl: cute.Tensor, + cum_seqlen_q: Optional[cute.Tensor], + cum_seqlen_k: Optional[cute.Tensor], + scale_softmax: Float32, + window_size_left: Optional[Int32], + window_size_right: Optional[Int32], + cluster_layout_vmnk: cute.Layout, + q_smem_layout_staged: cute.ComposedLayout, + k_smem_layout_staged: cute.ComposedLayout, + v_smem_layout_staged: cute.ComposedLayout, + do_smem_layout_staged: cute.ComposedLayout, + kt_smem_layout_staged: cute.ComposedLayout, + ds_tmem_layout_staged: cute.ComposedLayout, + lse_smem_layout: cute.Layout, + sum_odo_smem_layout: cute.Layout, + tile_sched_params: FmhaStaticTileSchedulerParams | FmhaClcDynamicTileSchedulerParams, + ): + # llvm.inline_asm( + # None, + # [], + # '.pragma "global knob CommonIntoMultiBlockLoop=1";', + # "", + # has_side_effects=True, + # is_align_stack=False, + # asm_dialect=llvm.AsmDialect.AD_ATT, + # ) + + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + # + # Prefetch tma desc + # + if warp_idx == self.load_warp_id: + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_q) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_k) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_v) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_do) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_kt) + + bidx, bidy, bidz = cute.arch.block_idx() + tidx, _, _ = cute.arch.thread_idx() + varlen = cum_seqlen_q is not None or cum_seqlen_k is not None + mma_tile_coord_v = bidx % cute.size(qk_tiled_mma.thr_id.shape) + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster) + + # Alloc + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + load_q_producer, load_q_consumer = pipeline.PipelineTmaUmma.create( + num_stages=self.q_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + tx_count=self.tma_copy_q_bytes, + barrier_storage=storage.load_q_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + load_k_producer, load_k_consumer = pipeline.PipelineTmaUmma.create( + num_stages=self.k_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + tx_count=self.tma_copy_k_bytes, + barrier_storage=storage.load_k_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + load_v_producer, load_v_consumer = pipeline.PipelineTmaUmma.create( + num_stages=self.v_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + tx_count=self.tma_copy_v_bytes, + barrier_storage=storage.load_v_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + load_do_producer, load_do_consumer = pipeline.PipelineTmaUmma.create( + num_stages=self.do_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + tx_count=self.tma_copy_do_bytes, + barrier_storage=storage.load_do_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + load_kt_producer, load_kt_consumer = pipeline.PipelineTmaUmma.create( + num_stages=self.kt_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + tx_count=self.tma_copy_kt_bytes, + barrier_storage=storage.load_kt_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + mma_s_producer, mma_s_consumer = pipeline.PipelineUmmaAsync.create( + num_stages=self.qk_acc_stage, + producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + consumer_group=make_thread_cooperative_group( + len(self.compute_warp_ids) * self.threads_per_warp * self.cluster_shape_mnk[0], + ), + barrier_storage=storage.mma_s_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + mma_dp_producer, mma_dp_consumer = pipeline.PipelineUmmaAsync.create( + num_stages=self.dov_acc_stage, + producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + consumer_group=make_thread_cooperative_group( + len(self.compute_warp_ids) * self.threads_per_warp * self.cluster_shape_mnk[0], + ), + barrier_storage=storage.mma_dp_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + ds_mma_producer, ds_mma_consumer = pipeline.PipelineAsyncUmma.create( + num_stages=self.dsk_acc_stage, + producer_group=make_thread_cooperative_group( + len(self.compute_warp_ids) * self.threads_per_warp * self.cluster_shape_mnk[0], + ), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + barrier_storage=storage.ds_mma_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + mma_dq_producer, mma_dq_consumer = pipeline.PipelineUmmaAsync.create( + num_stages=self.epi_stage, + producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + consumer_group=make_thread_cooperative_group( + len(self.epilogue_warp_ids) * self.threads_per_warp * self.cluster_shape_mnk[0], + ), + barrier_storage=storage.mma_dq_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + + load_lse_producer, load_lse_consumer = pipeline.PipelineCpAsync.create( + num_stages=self.load_compute_LSE_stage, + producer_group=make_thread_cooperative_group(self.threads_per_warp), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * self.num_compute_warps + ), + barrier_storage=storage.lse_mbar_ptr.data_ptr(), + ).make_participants() + load_sum_odo_producer, load_sum_odo_consumer = pipeline.PipelineCpAsync.create( + num_stages=self.load_compute_sum_OdO_stage, + producer_group=make_thread_cooperative_group(self.threads_per_warp), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * self.num_compute_warps + ), + barrier_storage=storage.sum_odo_mbar_ptr.data_ptr(), + ).make_participants() + + # Tensor memory dealloc barrier init + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=self.tmem_alloc_barrier, + allocator_warp_id=self.epilogue_warp_ids[0], + is_two_cta=True, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + tmem.allocate(self.tmem_alloc_cols) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + + # Initialize CLC state if using dynamic scheduler + if cutlass.const_expr(self.use_clc_scheduler): + clc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + cluster_size = cute.size(self.cluster_shape_mnk) + num_clc_consumer_threads = self.threads_per_warp * ( + 1 # sched_warp (CTA 0 only) + + cluster_size + * ( + len(self.compute_warp_ids) + + len(self.epilogue_warp_ids) + + 1 # mma_warp + + 1 # load_warp + ) + ) + clc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_clc_consumer_threads + ) + clc_response_ptr = storage.clc_response.data_ptr() + clc = ClcState.create( + hw_scheduler=ClcDynamicPersistentTileScheduler.create( + self.tile_sched_params.clc_hw_params(), + cute.arch.block_idx(), + cute.arch.grid_dim(), + clc_response_ptr, + ), + pipeline=pipeline.PipelineClcFetchAsync.create( + barrier_storage=storage.clc_mbar_ptr.data_ptr(), + num_stages=self.num_clc_stage, + producer_group=clc_pipeline_producer_group, + consumer_group=clc_pipeline_consumer_group, + tx_count=self.num_clc_response_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ), + consumer_state=pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_clc_stage + ), + producer_state=pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_clc_stage + ), + ) + else: + clc = None + clc_response_ptr = None + + # Cluster arrive after barrier init + pipeline.pipeline_init_arrive(cluster_shape_mn=cluster_layout_vmnk, is_relaxed=True) + + sQ = smem.allocate_tensor( + element_type=self.q_dtype, + layout=q_smem_layout_staged.outer, + swizzle=q_smem_layout_staged.inner, + byte_alignment=128, + ) + sK = smem.allocate_tensor( + element_type=self.k_dtype, + layout=k_smem_layout_staged.outer, + swizzle=k_smem_layout_staged.inner, + byte_alignment=128, + ) + # K and V now use separate memory since we removed the transform stage + sV = smem.allocate_tensor( + element_type=self.v_dtype, + layout=v_smem_layout_staged.outer, + swizzle=v_smem_layout_staged.inner, + byte_alignment=128, + ) + sdO = smem.allocate_tensor( + element_type=self.do_dtype, + layout=do_smem_layout_staged.outer, + swizzle=do_smem_layout_staged.inner, + byte_alignment=128, + ) + sKT = smem.allocate_tensor( + element_type=self.k_dtype, + layout=kt_smem_layout_staged.outer, + swizzle=kt_smem_layout_staged.inner, + byte_alignment=128, + ) + sLSE = smem.allocate_tensor( + element_type=self.acc_dtype, + layout=lse_smem_layout, + byte_alignment=128, + ) + sSum_OdO = smem.allocate_tensor( + element_type=self.acc_dtype, + layout=sum_odo_smem_layout, + byte_alignment=128, + ) + qk_thr_mma = qk_tiled_mma.get_slice(mma_tile_coord_v) # default 1sm + dov_thr_mma = dov_tiled_mma.get_slice(mma_tile_coord_v) # default 1sm + dsk_thr_mma = dsk_tiled_mma.get_slice(mma_tile_coord_v) # default 1sm + tSrQ = qk_thr_mma.make_fragment_A(sQ) + tSrK = qk_thr_mma.make_fragment_B(sK) + tdPrdO = dov_thr_mma.make_fragment_A(sdO) + tdPrV = dov_thr_mma.make_fragment_B(sV) + tdQrKT = dsk_thr_mma.make_fragment_B(sKT) + qk_acc_shape = qk_thr_mma.partition_shape_C((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tStS = qk_thr_mma.make_fragment_C(cute.append(qk_acc_shape, self.qk_acc_stage)) + dov_acc_shape = dov_thr_mma.partition_shape_C( + (self.dov_mma_tiler[0], self.dov_mma_tiler[1]) + ) + tdPtdP = dov_thr_mma.make_fragment_C(cute.append(dov_acc_shape, self.dov_acc_stage)) + dsk_acc_shape = dsk_thr_mma.partition_shape_C( + (self.dsk_mma_tiler[0], self.dsk_mma_tiler[1]) + ) + tdQtdQ = dsk_thr_mma.make_fragment_C(dsk_acc_shape) + tdQtdQ_layout = cute.append( + tdQtdQ.layout, + cute.make_layout( + self.iterations_dsk, + stride=self.dsk_mma_tiler[1] // self.tmem_warp_shape_mn[1], + ), + ) + tStS = cute.make_tensor(tStS.iterator + self.tmem_s_offset, tStS.layout) + tdPtdP = cute.make_tensor(tdPtdP.iterator + self.tmem_dp_offset, tdPtdP.layout) + tdQtdQ_staged = cute.make_tensor(tdQtdQ.iterator + self.tmem_dq_offset, tdQtdQ_layout) + + # /////////////////////////////////////////////////////////////////////////////// + # EMPTY + # /////////////////////////////////////////////////////////////////////////////// + for _i in cutlass.range_constexpr(len(self.empty_warp_id)): + if warp_idx == self.empty_warp_id[_i]: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + + if cutlass.const_expr(self.use_clc_scheduler): + tile_sched = FmhaClcDynamicTileScheduler.create( + tile_sched_params, + cute.arch.block_idx(), + cute.arch.grid_dim(), + clc_response_ptr, + clc, + ) + else: + blk_idx = cute.arch.block_idx() + tile_sched = FmhaStaticTileScheduler( + tile_sched_params, blk_idx[0], blk_idx, cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + # Cluster wait + pipeline.pipeline_init_wait(cluster_shape_mn=cluster_layout_vmnk) + + # /////////////////////////////////////////////////////////////////////////////// + # LOAD + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx == self.load_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + + while work_tile.is_valid_tile: + curr_block_coord = work_tile.tile_idx + mma_block_coord = ( + curr_block_coord[0] // cute.size(qk_tiled_mma.thr_id.shape), + curr_block_coord[1], + curr_block_coord[2], + ) + continue_cond = False + batch_coord = curr_block_coord[2][1] + seqlen_q = mQ_qdl.shape[0] + seqlen_k = mK_kdl.shape[0] + cuseqlen_q = Int32(0) + cuseqlen_k = Int32(0) + block_offset = ( + Int32(0), + Int32(0), + Int32(0), + ((Int32(0), Int32(0)), Int32(0)), + ) + if cutlass.const_expr(cum_seqlen_q is not None): + cuseqlen_q = cum_seqlen_q[batch_coord] + seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + block_offset = ( + cuseqlen_q, + cuseqlen_k, + Int32(0), + ((Int32(0), Int32(0)), Int32(0)), + ) + continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.qk_mma_tiler[0], + mma_block_coord[0], + seqlen_q, + ) + if not continue_cond: + mQ_qdl_ = cute.domain_offset(cute.select(block_offset, mode=[0, 2, 3]), mQ_qdl) + mK_kdl_ = cute.domain_offset(cute.select(block_offset, mode=[1, 2, 3]), mK_kdl) + mdO_qdl_ = cute.domain_offset( + cute.select(block_offset, mode=[0, 2, 3]), mdO_qdl + ) + mV_dkl_ = cute.domain_offset(cute.select(block_offset, mode=[1, 2, 3]), mV_dkl) + mK_dkl_ = cute.domain_offset(cute.select(block_offset, mode=[2, 1, 3]), mK_dkl) + block_offset_stats = block_offset + if cutlass.const_expr(cum_seqlen_q is not None): + cuseqlen_q_stats = cute.assume( + (cuseqlen_q + batch_coord * self.cta_tiler[0]) + // self.cta_tiler[0] + * self.cta_tiler[0], + divby=self.cta_tiler[0], + ) + block_offset_stats = ( + cuseqlen_q_stats, + block_offset[1], + block_offset[2], + block_offset[3], + ) + LSE = cute.domain_offset(cute.select(block_offset_stats, mode=[0, 3]), mLSE) + sum_OdO = cute.domain_offset( + cute.select(block_offset_stats, mode=[0, 3]), mSum_OdO + ) + + # Local tile partition global tensors + q_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + # (bM, bK, loopM, loopK, loopL) + gQ_qdl = cute.flat_divide(mQ_qdl_, cute.select(self.qk_mma_tiler, mode=[0, 2])) + tSgQ_qdl = qk_thr_mma.partition_A(gQ_qdl) + tQsQ, tQgQ_qdl = cute.nvgpu.cpasync.tma_partition( + tma_atom_q, + block_in_cluster_coord_vmnk[2], + q_cta_layout, + cute.group_modes(sQ, 0, 3), + cute.group_modes(tSgQ_qdl, 0, 3), + ) + k_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + gK_kdl = cute.flat_divide(mK_kdl_, cute.select(self.qk_mma_tiler, mode=[1, 2])) + tSgK_kdl = qk_thr_mma.partition_B(gK_kdl) + tKsK, tKgK_kdl = cute.nvgpu.cpasync.tma_partition( + tma_atom_k, + block_in_cluster_coord_vmnk[1], + k_cta_layout, + cute.group_modes(sK, 0, 3), + cute.group_modes(tSgK_kdl, 0, 3), + ) + do_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + # (bM, bK, loopM, loopK, loopL) + gdO_qdl = cute.flat_divide( + mdO_qdl_, cute.select(self.dov_mma_tiler, mode=[0, 2]) + ) + tdPgdO_qdl = dov_thr_mma.partition_A(gdO_qdl) + tdOsdO, tdOgdO_qdl = cute.nvgpu.cpasync.tma_partition( + tma_atom_do, + block_in_cluster_coord_vmnk[2], + do_cta_layout, + cute.group_modes(sdO, 0, 3), + cute.group_modes(tdPgdO_qdl, 0, 3), + ) + v_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + gV_dkl = cute.flat_divide(mV_dkl_, cute.select(self.dov_mma_tiler, mode=[1, 2])) + tSgV_dkl = dov_thr_mma.partition_B(gV_dkl) + tVsV, tVgV_dkl = cute.nvgpu.cpasync.tma_partition( + tma_atom_v, + block_in_cluster_coord_vmnk[1], + v_cta_layout, + cute.group_modes(sV, 0, 3), + cute.group_modes(tSgV_dkl, 0, 3), + ) + # kt layout + kt_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + gK_dkl = cute.flat_divide(mK_dkl_, cute.select(self.dsk_mma_tiler, mode=[1, 2])) + tdQgK_dkl = dsk_thr_mma.partition_B(gK_dkl) + tKTsKT, tKgK_dkl = cute.nvgpu.cpasync.tma_partition( + tma_atom_kt, + block_in_cluster_coord_vmnk[1], + kt_cta_layout, + cute.group_modes(sKT, 0, 3), + cute.group_modes(tdQgK_dkl, 0, 3), + ) + # ((atom_v, rest_v), RestK) + tQgQ = tQgQ_qdl[None, mma_block_coord[0], None, mma_block_coord[2]] + # ((atom_v, rest_v), RestK) + tdOgdO = tdOgdO_qdl[None, mma_block_coord[0], None, mma_block_coord[2]] + # ((atom_v, rest_v), RestN, RestK) + tKgK = tKgK_kdl[None, None, None, mma_block_coord[2]] + # ((atom_v, rest_v), RestN, RestK) + tVgV = tVgV_dkl[None, None, None, mma_block_coord[2]] + # ((atom_v, rest_v), RestN, RestK) + tKTgKT = tKgK_dkl[None, None, None, mma_block_coord[2]] + + seqlen_kv_loop_start, seqlen_kv_loop_steps = ( + FusedMask.get_trip_start_count_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + ) + # LSE + lse_handle = load_lse_producer.acquire_and_advance() + # 32 threads loading 128 values of 32b each + # so 4*32b = 128b + thread_idx = tidx % self.threads_per_warp + async_copy_num_elts = sLSE.shape[0] // self.threads_per_warp + atom_async_copy = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.ALWAYS), + self.acc_dtype, + num_bits_per_copy=self.acc_dtype.width, + ) + sLSE_for_copy = cute.flat_divide(sLSE, (1,)) + LSE_for_copy = cute.flat_divide(LSE, (1,)) + for i in cutlass.range_constexpr(async_copy_num_elts): + LSE_idx = ( + self.cta_tiler[0] * curr_block_coord[0] + + thread_idx * async_copy_num_elts + ) + if cute.elem_less(LSE_idx + i, seqlen_q): + cute.copy( + atom_async_copy, + LSE_for_copy[None, LSE_idx + i, curr_block_coord[2]], + sLSE_for_copy[ + None, + thread_idx * async_copy_num_elts + i, + lse_handle.index, + ], + ) + else: + sLSE_for_copy[ + None, + thread_idx * async_copy_num_elts + i, + lse_handle.index, + ].fill(0.0) + lse_handle.commit() + + sum_odo_handle = load_sum_odo_producer.acquire_and_advance() + sSum_OdO_for_copy = cute.flat_divide(sSum_OdO, (1,)) + sum_OdO_for_copy = cute.flat_divide(sum_OdO, (1,)) + for i in cutlass.range_constexpr(async_copy_num_elts): + sum_OdO_idx = ( + self.cta_tiler[0] * curr_block_coord[0] + + thread_idx * async_copy_num_elts + ) + if cute.elem_less(sum_OdO_idx + i, seqlen_q): + cute.copy( + atom_async_copy, + sum_OdO_for_copy[None, sum_OdO_idx + i, curr_block_coord[2]], + sSum_OdO_for_copy[ + None, + thread_idx * async_copy_num_elts + i, + sum_odo_handle.index, + ], + ) + else: + sSum_OdO_for_copy[ + None, + thread_idx * async_copy_num_elts + i, + sum_odo_handle.index, + ].fill(0.0) + sum_odo_handle.commit() + + # Q + for iter in cutlass.range(self.iterations_qk, unroll=1): + q_handle = load_q_producer.acquire_and_advance() + cute.copy( + tma_atom_q, + tQgQ[None, iter], + tQsQ[None, q_handle.index], + tma_bar_ptr=q_handle.barrier, + ) + # dO + for iter in cutlass.range(self.iterations_dov, unroll=1): + do_handle = load_do_producer.acquire_and_advance() + cute.copy( + tma_atom_do, + tdOgdO[None, iter], + tdOsdO[None, do_handle.index], + tma_bar_ptr=do_handle.barrier, + ) + + kv_coord = seqlen_kv_loop_start + for i in cutlass.range(0, seqlen_kv_loop_steps, 1, unroll=1): + # Ki + for iter in cutlass.range(self.iterations_qk, unroll=1): + k_handle = load_k_producer.acquire_and_advance() + cute.copy( + tma_atom_k, + tKgK[None, kv_coord, iter], + tKsK[None, k_handle.index], + tma_bar_ptr=k_handle.barrier, + ) + # Vi + for iter in cutlass.range(self.iterations_dov, unroll=1): + v_handle = load_v_producer.acquire_and_advance() + cute.copy( + tma_atom_v, + tVgV[None, kv_coord, iter], + tVsV[None, v_handle.index], + tma_bar_ptr=v_handle.barrier, + ) + # KTi + for iter in cutlass.range(self.iterations_dsk, unroll=1): + kt_handle = load_kt_producer.acquire_and_advance() + cute.copy( + tma_atom_kt, + tKTgKT[None, iter, kv_coord], + tKTsKT[None, kt_handle.index], + tma_bar_ptr=kt_handle.barrier, + ) + kv_coord += 1 + + work_tile = tile_sched.advance_to_next_work() + # End of persistent scheduler loop + load_k_producer.tail() + load_v_producer.tail() + load_kt_producer.tail() + load_q_producer.tail() + load_do_producer.tail() + load_lse_producer.tail() + load_sum_odo_producer.tail() + + # /////////////////////////////////////////////////////////////////////////////// + # MMA + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx == self.mma_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + + while work_tile.is_valid_tile: + curr_block_coord = work_tile.tile_idx + mma_block_coord = ( + curr_block_coord[0] // cute.size(qk_tiled_mma.thr_id.shape), + curr_block_coord[1], + curr_block_coord[2], + ) + continue_cond = False + seqlen_q = mQ_qdl.shape[0] + seqlen_k = mK_kdl.shape[0] + batch_coord = curr_block_coord[2][1] + if cutlass.const_expr(cum_seqlen_q is not None): + cuseqlen_q = cum_seqlen_q[batch_coord] + seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q + continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.qk_mma_tiler[0], + mma_block_coord[0], + seqlen_q, + ) + + if not continue_cond: + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + + seqlen_kv_loop_start, seqlen_kv_loop_steps = ( + FusedMask.get_trip_start_count_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + ) + + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + is_leader_cta = cta_rank_in_cluster % 2 == 0 + # dq_handle = mma_dq_producer.acquire_and_advance() + load_q_releaser = load_q_consumer.clone() + load_do_releaser = load_do_consumer.clone() + dsk_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + + num_innerloop = 8 + + if is_leader_cta: + dq_handle = mma_dq_producer.acquire_and_advance() + if seqlen_kv_loop_steps > 1: + # QK0 + s_handle = mma_s_producer.acquire_and_advance() + tStS_slice = tStS[None, None, None, s_handle.index] + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + for iter in cutlass.range(self.iterations_qk, unroll=1): + load_q_consumer.wait_and_advance() + tSrQ_slice = tSrQ[None, None, None, iter] + + k_handle = load_k_consumer.wait_and_advance() + tSrK_trans_slice = tSrK[None, None, None, k_handle.index] + num_kphases = cute.size(tSrQ_slice, mode=[2]) + if cutlass.const_expr(num_kphases % num_innerloop == 0): + num_outer_iter = num_kphases // num_innerloop + for outer_iter in cutlass.range(num_outer_iter, unroll=1): + for kphase_idx in cutlass.range( + num_innerloop, unroll_full=True + ): + kphase_coord = ( + None, + None, + outer_iter * num_innerloop + kphase_idx, + ) + cute.gemm( + qk_tiled_mma, + tStS_slice, + tSrQ_slice[kphase_coord], + tSrK_trans_slice[kphase_coord], + tStS_slice, + ) + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + else: + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord = (None, None, kphase_idx) + cute.gemm( + qk_tiled_mma, + tStS_slice, + tSrQ_slice[kphase_coord], + tSrK_trans_slice[kphase_coord], + tStS_slice, + ) + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + k_handle.release() + cute.arch.fence_view_async_tmem_store() + s_handle.commit() + + # dOV0 + dp_handle = mma_dp_producer.acquire_and_advance() + tdPtdP_slice = tdPtdP[None, None, None, dp_handle.index] + dov_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + for iter in cutlass.range(self.iterations_dov, unroll=1): + load_do_consumer.wait_and_advance() + tdPrdO_slice = tdPrdO[None, None, None, iter] + v_handle = load_v_consumer.wait_and_advance() + tdPrV_trans_slice = tdPrV[None, None, None, v_handle.index] + num_kphases = cute.size(tdPrdO_slice, mode=[2]) + if cutlass.const_expr(num_kphases % num_innerloop == 0): + num_outer_iter = num_kphases // num_innerloop + for outer_iter in cutlass.range(num_outer_iter, unroll=1): + for kphase_idx in cutlass.range( + num_innerloop, unroll_full=True + ): + kphase_coord = ( + None, + None, + outer_iter * num_innerloop + kphase_idx, + ) + cute.gemm( + dov_tiled_mma, + tdPtdP_slice, + tdPrdO_slice[kphase_coord], + tdPrV_trans_slice[kphase_coord], + tdPtdP_slice, + ) + dov_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + else: + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord = (None, None, kphase_idx) + cute.gemm( + dov_tiled_mma, + tdPtdP_slice, + tdPrdO_slice[kphase_coord], + tdPrV_trans_slice[kphase_coord], + tdPtdP_slice, + ) + dov_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + v_handle.release() + cute.arch.fence_view_async_tmem_store() + dp_handle.commit() + + for i in cutlass.range(1, seqlen_kv_loop_steps - 1, 1, unroll=1): + # QKi + s_handle = mma_s_producer.acquire_and_advance() + + tStS_slice = tStS[None, None, None, s_handle.index] + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + for iter in cutlass.range(self.iterations_qk, unroll=1): + tSrQ_slice = tSrQ[None, None, None, iter] + k_handle = load_k_consumer.wait_and_advance() + tSrK_trans_slice = tSrK[None, None, None, k_handle.index] + num_kphases = cute.size(tSrQ_slice, mode=[2]) + if cutlass.const_expr(num_kphases % num_innerloop == 0): + num_outer_iter = num_kphases // num_innerloop + for outer_iter in cutlass.range(num_outer_iter, unroll=1): + for kphase_idx in cutlass.range( + num_innerloop, unroll_full=True + ): + kphase_coord = ( + None, + None, + outer_iter * num_innerloop + kphase_idx, + ) + cute.gemm( + qk_tiled_mma, + tStS_slice, + tSrQ_slice[kphase_coord], + tSrK_trans_slice[kphase_coord], + tStS_slice, + ) + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + else: + for kphase_idx in cutlass.range( + num_kphases, unroll_full=True + ): + kphase_coord = (None, None, kphase_idx) + cute.gemm( + qk_tiled_mma, + tStS_slice, + tSrQ_slice[kphase_coord], + tSrK_trans_slice[kphase_coord], + tStS_slice, + ) + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + k_handle.release() + s_handle.commit() + + # dSKTi + ds_handle = ds_mma_consumer.wait_and_advance() + dsk_whether_acc = dsk_tiled_mma.get(tcgen05.Field.ACCUMULATE) + for iter in cutlass.range(self.iterations_dsk, unroll=1): + kt_handle = load_kt_consumer.wait_and_advance() + dsk_tiled_mma.set(tcgen05.Field.ACCUMULATE, dsk_whether_acc) + tdQtdQ_slice = tdQtdQ_staged[None, None, None, iter] + tdStdS_slice = tdPtdP[None, None, None, ds_handle.index] + tdS = cute.make_tensor( + tdStdS_slice.iterator, ds_tmem_layout_staged.outer + ) + tdQrdS = dsk_thr_mma.make_fragment_A(tdS) + tdQrdS_slice = cute.make_tensor( + cute.recast_ptr(tdStdS_slice.iterator, dtype=self.q_dtype), + tdQrdS.layout, + ) + + tdQrKT_slice = tdQrKT[None, None, None, kt_handle.index] + num_kphases = cute.size(tdQrKT_slice, mode=[2]) + if cutlass.const_expr(num_kphases % num_innerloop == 0): + num_outer_iter = num_kphases // num_innerloop + for outer_iter in cutlass.range(num_outer_iter, unroll=1): + for kphase_idx in cutlass.range( + num_innerloop, unroll_full=True + ): + kphase_coord = ( + None, + None, + outer_iter * num_innerloop + kphase_idx, + ) + cute.gemm( + dsk_tiled_mma, + tdQtdQ_slice, + tdQrdS_slice[kphase_coord], + tdQrKT_slice[kphase_coord], + tdQtdQ_slice, + ) + dsk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + else: + for kphase_idx in cutlass.range( + num_kphases, unroll_full=True + ): + kphase_coord = (None, None, kphase_idx) + cute.gemm( + dsk_tiled_mma, + tdQtdQ_slice, + tdQrdS_slice[kphase_coord], + tdQrKT_slice[kphase_coord], + tdQtdQ_slice, + ) + dsk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + kt_handle.release() + ds_handle.release() + + # dOVi + dp_handle = mma_dp_producer.acquire_and_advance() + tdPtdP_slice = tdPtdP[None, None, None, dp_handle.index] + dov_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + for iter in cutlass.range(self.iterations_dov, unroll=1): + tdPrdO_slice = tdPrdO[None, None, None, iter] + v_handle = load_v_consumer.wait_and_advance() + tdPrV_trans_slice = tdPrV[None, None, None, v_handle.index] + num_kphases = cute.size(tdPrdO_slice, mode=[2]) + if cutlass.const_expr(num_kphases % num_innerloop == 0): + num_outer_iter = num_kphases // num_innerloop + for outer_iter in cutlass.range(num_outer_iter, unroll=1): + for kphase_idx in cutlass.range( + num_innerloop, unroll_full=True + ): + kphase_coord = ( + None, + None, + outer_iter * num_innerloop + kphase_idx, + ) + cute.gemm( + dov_tiled_mma, + tdPtdP_slice, + tdPrdO_slice[kphase_coord], + tdPrV_trans_slice[kphase_coord], + tdPtdP_slice, + ) + dov_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + else: + for kphase_idx in cutlass.range( + num_kphases, unroll_full=True + ): + kphase_coord = (None, None, kphase_idx) + cute.gemm( + dov_tiled_mma, + tdPtdP_slice, + tdPrdO_slice[kphase_coord], + tdPrV_trans_slice[kphase_coord], + tdPtdP_slice, + ) + dov_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + v_handle.release() + dp_handle.commit() + + # QKend + s_handle = mma_s_producer.acquire_and_advance() + tStS_slice = tStS[None, None, None, s_handle.index] + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + for iter in cutlass.range(self.iterations_qk, unroll=1): + tSrQ_slice = tSrQ[None, None, None, iter] + k_handle = load_k_consumer.wait_and_advance() + + tSrK_trans_slice = tSrK[None, None, None, k_handle.index] + + num_kphases = cute.size(tSrQ_slice, mode=[2]) + if cutlass.const_expr(num_kphases % num_innerloop == 0): + num_outer_iter = num_kphases // num_innerloop + for outer_iter in cutlass.range(num_outer_iter, unroll=1): + for kphase_idx in cutlass.range( + num_innerloop, unroll_full=True + ): + kphase_coord = ( + None, + None, + outer_iter * num_innerloop + kphase_idx, + ) + cute.gemm( + qk_tiled_mma, + tStS_slice, + tSrQ_slice[kphase_coord], + tSrK_trans_slice[kphase_coord], + tStS_slice, + ) + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + else: + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord = (None, None, kphase_idx) + cute.gemm( + qk_tiled_mma, + tStS_slice, + tSrQ_slice[kphase_coord], + tSrK_trans_slice[kphase_coord], + tStS_slice, + ) + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + k_handle.release() + load_q_releaser.release() + load_q_releaser.advance() + s_handle.commit() + + # dSKTend - 1 + ds_handle = ds_mma_consumer.wait_and_advance() + dsk_whether_acc = dsk_tiled_mma.get(tcgen05.Field.ACCUMULATE) + for iter in cutlass.range(self.iterations_dsk, unroll=1): + kt_handle = load_kt_consumer.wait_and_advance() + dsk_tiled_mma.set(tcgen05.Field.ACCUMULATE, dsk_whether_acc) + tdQtdQ_slice = tdQtdQ_staged[None, None, None, iter] + tdStdS_slice = tdPtdP[None, None, None, ds_handle.index] + tdS = cute.make_tensor( + tdStdS_slice.iterator, ds_tmem_layout_staged.outer + ) + tdQrdS = dsk_thr_mma.make_fragment_A(tdS) + tdQrdS_slice = cute.make_tensor( + cute.recast_ptr(tdStdS_slice.iterator, dtype=self.q_dtype), + tdQrdS.layout, + ) + + tdQrKT_slice = tdQrKT[None, None, None, kt_handle.index] + num_kphases = cute.size(tdQrKT_slice, mode=[2]) + if cutlass.const_expr(num_kphases % num_innerloop == 0): + num_outer_iter = num_kphases // num_innerloop + for outer_iter in cutlass.range(num_outer_iter, unroll=1): + for kphase_idx in cutlass.range( + num_innerloop, unroll_full=True + ): + kphase_coord = ( + None, + None, + outer_iter * num_innerloop + kphase_idx, + ) + cute.gemm( + dsk_tiled_mma, + tdQtdQ_slice, + tdQrdS_slice[kphase_coord], + tdQrKT_slice[kphase_coord], + tdQtdQ_slice, + ) + dsk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + else: + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord = (None, None, kphase_idx) + cute.gemm( + dsk_tiled_mma, + tdQtdQ_slice, + tdQrdS_slice[kphase_coord], + tdQrKT_slice[kphase_coord], + tdQtdQ_slice, + ) + dsk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + kt_handle.release() + ds_handle.release() + + # dOVend + dp_handle = mma_dp_producer.acquire_and_advance() + tdPtdP_slice = tdPtdP[None, None, None, dp_handle.index] + dov_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + for iter in cutlass.range(self.iterations_dov, unroll=1): + tdPrdO_slice = tdPrdO[None, None, None, iter] + v_handle = load_v_consumer.wait_and_advance() + tdPrV_trans_slice = tdPrV[None, None, None, v_handle.index] + num_kphases = cute.size(tdPrdO_slice, mode=[2]) + if cutlass.const_expr(num_kphases % num_innerloop == 0): + num_outer_iter = num_kphases // num_innerloop + for outer_iter in cutlass.range(num_outer_iter, unroll=1): + for kphase_idx in cutlass.range( + num_innerloop, unroll_full=True + ): + kphase_coord = ( + None, + None, + outer_iter * num_innerloop + kphase_idx, + ) + cute.gemm( + dov_tiled_mma, + tdPtdP_slice, + tdPrdO_slice[kphase_coord], + tdPrV_trans_slice[kphase_coord], + tdPtdP_slice, + ) + dov_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + else: + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord = (None, None, kphase_idx) + cute.gemm( + dov_tiled_mma, + tdPtdP_slice, + tdPrdO_slice[kphase_coord], + tdPrV_trans_slice[kphase_coord], + tdPtdP_slice, + ) + dov_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + v_handle.release() + load_do_releaser.release() + load_do_releaser.advance() + dp_handle.commit() + # dSKTend + ds_handle = ds_mma_consumer.wait_and_advance() + dsk_whether_acc = dsk_tiled_mma.get(tcgen05.Field.ACCUMULATE) + for iter in cutlass.range(self.iterations_dsk, unroll=1): + kt_handle = load_kt_consumer.wait_and_advance() + dsk_tiled_mma.set(tcgen05.Field.ACCUMULATE, dsk_whether_acc) + tdQtdQ_slice = tdQtdQ_staged[None, None, None, iter] + tdStdS_slice = tdPtdP[None, None, None, ds_handle.index] + tdS = cute.make_tensor( + tdStdS_slice.iterator, ds_tmem_layout_staged.outer + ) + tdQrdS = dsk_thr_mma.make_fragment_A(tdS) + tdQrdS_slice = cute.make_tensor( + cute.recast_ptr(tdStdS_slice.iterator, dtype=self.q_dtype), + tdQrdS.layout, + ) + + tdQrKT_slice = tdQrKT[None, None, None, kt_handle.index] + num_kphases = cute.size(tdQrKT_slice, mode=[2]) + if cutlass.const_expr(num_kphases % num_innerloop == 0): + num_outer_iter = num_kphases // num_innerloop + for outer_iter in cutlass.range(num_outer_iter, unroll=1): + for kphase_idx in cutlass.range( + num_innerloop, unroll_full=True + ): + kphase_coord = ( + None, + None, + outer_iter * num_innerloop + kphase_idx, + ) + cute.gemm( + dsk_tiled_mma, + tdQtdQ_slice, + tdQrdS_slice[kphase_coord], + tdQrKT_slice[kphase_coord], + tdQtdQ_slice, + ) + dsk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + else: + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord = (None, None, kphase_idx) + cute.gemm( + dsk_tiled_mma, + tdQtdQ_slice, + tdQrdS_slice[kphase_coord], + tdQrKT_slice[kphase_coord], + tdQtdQ_slice, + ) + dsk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + kt_handle.release() + ds_handle.release() + else: + # QK0 + s_handle = mma_s_producer.acquire_and_advance() + tStS_slice = tStS[None, None, None, s_handle.index] + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + + for iter in cutlass.range(self.iterations_qk, unroll=1): + load_q_consumer.wait_and_advance() + tSrQ_slice = tSrQ[None, None, None, iter] + k_handle = load_k_consumer.wait_and_advance() + tSrK_trans_slice = tSrK[None, None, None, k_handle.index] + num_kphases = cute.size(tSrQ_slice, mode=[2]) + if cutlass.const_expr(num_kphases % num_innerloop == 0): + num_outer_iter = num_kphases // num_innerloop + for outer_iter in cutlass.range(num_outer_iter, unroll=1): + for kphase_idx in cutlass.range( + num_innerloop, unroll_full=True + ): + kphase_coord = ( + None, + None, + outer_iter * num_innerloop + kphase_idx, + ) + cute.gemm( + qk_tiled_mma, + tStS_slice, + tSrQ_slice[kphase_coord], + tSrK_trans_slice[kphase_coord], + tStS_slice, + ) + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + else: + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord = (None, None, kphase_idx) + cute.gemm( + qk_tiled_mma, + tStS_slice, + tSrQ_slice[kphase_coord], + tSrK_trans_slice[kphase_coord], + tStS_slice, + ) + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + k_handle.release() + load_q_releaser.release() + load_q_releaser.advance() + s_handle.commit() + + # dOV0 + dp_handle = mma_dp_producer.acquire_and_advance() + tdPtdP_slice = tdPtdP[None, None, None, dp_handle.index] + dov_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + for iter in cutlass.range(self.iterations_dov, unroll=1): + load_do_consumer.wait_and_advance() + tdPrdO_slice = tdPrdO[None, None, None, iter] + v_handle = load_v_consumer.wait_and_advance() + tdPrV_trans_slice = tdPrV[None, None, None, v_handle.index] + num_kphases = cute.size(tdPrdO_slice, mode=[2]) + if cutlass.const_expr(num_kphases % num_innerloop == 0): + num_outer_iter = num_kphases // num_innerloop + for outer_iter in cutlass.range(num_outer_iter, unroll=1): + for kphase_idx in cutlass.range( + num_innerloop, unroll_full=True + ): + kphase_coord = ( + None, + None, + outer_iter * num_innerloop + kphase_idx, + ) + cute.gemm( + dov_tiled_mma, + tdPtdP_slice, + tdPrdO_slice[kphase_coord], + tdPrV_trans_slice[kphase_coord], + tdPtdP_slice, + ) + dov_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + else: + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord = (None, None, kphase_idx) + cute.gemm( + dov_tiled_mma, + tdPtdP_slice, + tdPrdO_slice[kphase_coord], + tdPrV_trans_slice[kphase_coord], + tdPtdP_slice, + ) + dov_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + v_handle.release() + load_do_releaser.release() + load_do_releaser.advance() + dp_handle.commit() + + # dSKT0 + ds_handle = ds_mma_consumer.wait_and_advance() + dsk_whether_acc = dsk_tiled_mma.get(tcgen05.Field.ACCUMULATE) + for iter in cutlass.range(self.iterations_dsk, unroll=1): + kt_handle = load_kt_consumer.wait_and_advance() + dsk_tiled_mma.set(tcgen05.Field.ACCUMULATE, dsk_whether_acc) + tdQtdQ_slice = tdQtdQ_staged[None, None, None, iter] + tdStdS_slice = tdPtdP[None, None, None, ds_handle.index] + tdS = cute.make_tensor( + tdStdS_slice.iterator, ds_tmem_layout_staged.outer + ) + tdQrdS = dsk_thr_mma.make_fragment_A(tdS) + tdQrdS_slice = cute.make_tensor( + cute.recast_ptr(tdStdS_slice.iterator, dtype=self.q_dtype), + tdQrdS.layout, + ) + + tdQrKT_slice = tdQrKT[None, None, None, kt_handle.index] + num_kphases = cute.size(tdQrKT_slice, mode=[2]) + if cutlass.const_expr(num_kphases % num_innerloop == 0): + num_outer_iter = num_kphases // num_innerloop + for outer_iter in cutlass.range(num_outer_iter, unroll=1): + for kphase_idx in cutlass.range( + num_innerloop, unroll_full=True + ): + kphase_coord = ( + None, + None, + outer_iter * num_innerloop + kphase_idx, + ) + cute.gemm( + dsk_tiled_mma, + tdQtdQ_slice, + tdQrdS_slice[kphase_coord], + tdQrKT_slice[kphase_coord], + tdQtdQ_slice, + ) + dsk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + else: + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord = (None, None, kphase_idx) + cute.gemm( + dsk_tiled_mma, + tdQtdQ_slice, + tdQrdS_slice[kphase_coord], + tdQrKT_slice[kphase_coord], + tdQtdQ_slice, + ) + dsk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + kt_handle.release() + ds_handle.release() + dq_handle.commit() + work_tile = tile_sched.advance_to_next_work() + # End of persistent scheduler loop + mma_s_producer.tail() + mma_dp_producer.tail() + mma_dq_producer.tail() + + # Softmax and dSoftmax warp + if warp_idx >= self.compute_warp_ids[0] and warp_idx <= self.compute_warp_ids[-1]: + # increase register after decreasing + cute.arch.warpgroup_reg_alloc(self.num_regs_compute) + while work_tile.is_valid_tile: + curr_block_coord = work_tile.tile_idx + mma_block_coord = ( + curr_block_coord[0] // cute.size(qk_tiled_mma.thr_id.shape), + curr_block_coord[1], + curr_block_coord[2], + ) + batch_coord = curr_block_coord[2][1] + continue_cond = False + seqlen_q = mQ_qdl.shape[0] + seqlen_k = mK_kdl.shape[0] + cuseqlen_q = Int32(0) + if cutlass.const_expr(cum_seqlen_q is not None): + cuseqlen_q = cum_seqlen_q[batch_coord] + seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q + continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.qk_mma_tiler[0], + mma_block_coord[0], + seqlen_q, + ) + if not continue_cond: + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + + start_count, trip_count = FusedMask.get_trip_start_count_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + end_count = start_count + trip_count + if cutlass.const_expr(self.use_semantic_trip_range): + n_block_min_causal_local_mask, n_block_min_before_local_mask = ( + FusedMask.get_trip_mask_bounds_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + ) + + cS_base = cute.make_identity_tensor( + (self.qk_mma_tiler[0], self.qk_mma_tiler[1]) + ) + cS = cute.domain_offset((mma_block_coord[0] * self.qk_mma_tiler[0], 0), cS_base) + tScS = qk_thr_mma.partition_C(cS) + + cdP_base = cute.make_identity_tensor( + (self.dov_mma_tiler[0], self.dov_mma_tiler[1]) + ) + cdP = cute.domain_offset( + (mma_block_coord[0] * self.dov_mma_tiler[0], 0), cdP_base + ) + tdPcdP = dov_thr_mma.partition_C(cdP) + + lse_handle = load_lse_consumer.wait_and_advance() + sum_odo_handle = load_sum_odo_consumer.wait_and_advance() + for step in cutlass.range(start_count, end_count, 1, unroll=1): + cS_iter = cute.domain_offset((0, step * self.qk_mma_tiler[1]), cS) + tScS_iter = qk_thr_mma.partition_C(cS_iter) + + cdP_iter = cute.domain_offset((0, step * self.dov_mma_tiler[1]), cdP) + + tdPcdP_iter = dov_thr_mma.partition_C(cdP_iter) + + # Si, dPi -> dSi + if cutlass.const_expr(self.use_semantic_trip_range): + need_apply_mask = ( + step >= n_block_min_causal_local_mask + or step < n_block_min_before_local_mask + ) + else: + need_apply_mask = step == end_count - 1 + mma_s_consumer, mma_dp_consumer, ds_mma_producer = self.compute_step( + (need_apply_mask, window_size_left, window_size_right), + ( + seqlen_q, + seqlen_k, + scale_softmax, + batch_coord, + curr_block_coord[0], + varlen, + ), + (tStS, tScS_iter, tdPtdP, tdPcdP_iter, sLSE, sSum_OdO), + ( + mma_s_consumer, + mma_dp_consumer, + ds_mma_producer, + lse_handle, + sum_odo_handle, + ), + step, + ) + lse_handle.release() + sum_odo_handle.release() + work_tile = tile_sched.advance_to_next_work() + ds_mma_producer.tail() + + # /////////////////////////////////////////////////////////////////////////////// + # Epilogue + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx >= self.epilogue_warp_ids[0] and warp_idx <= self.epilogue_warp_ids[-1]: + cute.arch.warpgroup_reg_dealloc(self.num_regs_epilogue) + + while work_tile.is_valid_tile: + curr_block_coord = work_tile.tile_idx + mma_block_coord = ( + curr_block_coord[0] // cute.size(qk_tiled_mma.thr_id.shape), + curr_block_coord[1], + curr_block_coord[2], + ) + batch_coord = curr_block_coord[2][1] + # cute.printf("batch_coord={}", batch_coord) + seqlen_q = mQ_qdl.shape[0] + seqlen_k = mK_kdl.shape[0] + continue_cond = False + cuseqlen_q = Int32(0) + if cutlass.const_expr(cum_seqlen_q is not None): + cuseqlen_q = cum_seqlen_q[batch_coord] + seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q + continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.qk_mma_tiler[0], + mma_block_coord[0], + seqlen_q, + ) + + if not continue_cond: + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + + mdQ_qdl_eff = mdQ_qdl + if cutlass.const_expr(cum_seqlen_q is not None): + block_offset_dQ = ( + cuseqlen_q, + Int32(0), + Int32(0), + ((Int32(0), Int32(0)), Int32(0)), + ) + mdQ_qdl_eff = cute.domain_offset( + cute.select(block_offset_dQ, mode=[0, 2, 3]), mdQ_qdl + ) + + # (bM, bN, loopM, loopN, loopL) + gdQ_qdl = cute.flat_divide( + mdQ_qdl_eff, cute.select(self.dsk_block_tiler, mode=[0, 1]) + ) + cdQ_qdl = cute.flat_divide( + cute.make_identity_tensor(mdQ_qdl_eff.shape), + cute.select(self.dsk_block_tiler, mode=[0, 1]), + ) + + gdQ_staged = gdQ_qdl[None, None, curr_block_coord[0], None, curr_block_coord[2]] + cdQ_staged = cdQ_qdl[None, None, curr_block_coord[0], None, curr_block_coord[2]] + + # dQ TMEM to GMEM + mma_dq_consumer = self.dQ_epilogue( + (seqlen_q, cuseqlen_q, mQ_qdl.shape[0], batch_coord), + (mma_dq_consumer, gdQ_staged, cdQ_staged, tdQtdQ_staged), + self.epi_tile, + ) + work_tile = tile_sched.advance_to_next_work() + # NOTE: tmem.free() moved to kernel end to enable cluster-wide sync + + # /////////////////////////////////////////////////////////////////////////////// + # Scheduler Warp (only for CLC dynamic scheduler) + # /////////////////////////////////////////////////////////////////////////////// + if cutlass.const_expr(self.use_clc_scheduler): + is_first_cta_in_cluster = cta_rank_in_cluster == 0 + + if warp_idx == self.sched_warp_id and is_first_cta_in_cluster: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + while work_tile.is_valid_tile: + tile_sched.prefetch_next_work() + work_tile = tile_sched.advance_to_next_work() + tile_sched.producer_tail() + + # /////////////////////////////////////////////////////////////////////////////// + # Empty warps reg dealloc + # /////////////////////////////////////////////////////////////////////////////// + if cutlass.const_expr(self.use_clc_scheduler): + if warp_idx > self.load_warp_id: + if not (warp_idx == self.sched_warp_id and is_first_cta_in_cluster): + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + else: + if warp_idx > self.load_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + + # /////////////////////////////////////////////////////////////////////////////// + # Cooperative TMEM Deallocation (2CTA) + # /////////////////////////////////////////////////////////////////////////////// + # All warps (including scheduler) have finished by this point. + # Cluster-wide sync ensures both CTAs reach here before dealloc. + cute.arch.cluster_arrive() + cute.arch.cluster_wait() + tmem.relinquish_alloc_permit() + tmem.free(tmem_ptr) + + return + + @cute.jit + def compute_step( + self, + mask_args: Tuple, + value_args: Tuple, + tensor_args: Tuple, + pipeline_args: Tuple, + step: Int32, + ) -> Tuple[Float32, Float32, pipeline.PipelineConsumer, pipeline.PipelineProducer]: + need_apply_mask, window_size_left, window_size_right = mask_args + seqlen_q, seqlen_k, scale_softmax, batch_coord, block_m_idx, varlen = value_args + tStS, tScS, tdPtdP, tdPcdP, sLSE, sSum_OdO = tensor_args + mma_s_consumer, mma_dp_consumer, ds_mma_producer, lse_handle, sum_odo_handle = pipeline_args + + bidx = block_m_idx + tidx, _, _ = cute.arch.thread_idx() + thread_idx = tidx % (self.threads_per_warp * len(self.compute_warp_ids)) + s_handle = mma_s_consumer.wait_and_advance() + tStS_slice = tStS[(None, None), 0, 0, s_handle.index] + tScS_slice = tScS[(None, None), 0, 0] + tmem_load_atom = cute.make_copy_atom( + tcgen05.Ld32x32bOp(tcgen05.Repetition(16)), self.acc_dtype + ) + tmem_tiled_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS_slice) + thr_load = tmem_tiled_load.get_slice(thread_idx) + tTMEM_LOADtS = thr_load.partition_S(tStS_slice) + tTMEM_LOADcS = thr_load.partition_D(tScS_slice) + tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.acc_dtype) + cute.copy(tmem_tiled_load, tTMEM_LOADtS, tTMEM_LOADrS) + cute.arch.fence_view_async_tmem_load() + s_handle.release() + if need_apply_mask: + FusedMask.apply_mask_via_causal_local( + tTMEM_LOADrS, + tTMEM_LOADcS, + seqlen_q, + seqlen_k, + self.use_semantic_trip_range, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + + log2_e = cutlass.Float32(math.log2(math.e)) + softmax_scale_log2_e = scale_softmax * log2_e + tTMEM_STORErP = cute.make_rmem_tensor(tTMEM_LOADrS.shape, self.q_dtype) + for k in cutlass.range(0, cute.size(tTMEM_LOADrS), 2, unroll_full=True): + lse = ( + -sLSE[ + cute.get(tTMEM_LOADcS[k], mode=[0]) - bidx * self.cta_tiler[0], + lse_handle.index, + ], + -sLSE[ + cute.get(tTMEM_LOADcS[k + 1], mode=[0]) - bidx * self.cta_tiler[0], + lse_handle.index, + ], + ) + tTMEM_LOADrS[k], tTMEM_LOADrS[k + 1] = cute.arch.fma_packed_f32x2( + (tTMEM_LOADrS[k], tTMEM_LOADrS[k + 1]), + (softmax_scale_log2_e, softmax_scale_log2_e), + lse, + ) + tTMEM_LOADrS[k] = cute.math.exp2(tTMEM_LOADrS[k], fastmath=True) + tTMEM_LOADrS[k + 1] = cute.math.exp2(tTMEM_LOADrS[k + 1], fastmath=True) + + dp_handle = mma_dp_consumer.wait_and_advance() + tdPtdP_slice = tdPtdP[(None, None), 0, 0, dp_handle.index] + tdPcdP_slice = tdPcdP[(None, None), 0, 0] + thr_load = tmem_tiled_load.get_slice(thread_idx) + tTMEM_LOADtdP = thr_load.partition_S(tdPtdP_slice) + tTMEM_LOADcdP = thr_load.partition_D(tdPcdP_slice) + tTMEM_LOADrdP = cute.make_rmem_tensor(tTMEM_LOADcdP.shape, self.acc_dtype) + cute.copy(tmem_tiled_load, tTMEM_LOADtdP, tTMEM_LOADrdP) + cute.arch.fence_view_async_tmem_load() + dp_handle.release() + tTMEM_STORErdP = cute.make_rmem_tensor(tTMEM_LOADrdP.shape, self.q_dtype) + + for k in cutlass.range(0, cute.size(tTMEM_LOADrdP), 2, unroll_full=True): + dpsum_0 = -sSum_OdO[ + cute.get(tTMEM_LOADcdP[k], mode=[0]) - bidx * self.cta_tiler[0], + sum_odo_handle.index, + ] + dpsum_1 = -sSum_OdO[ + cute.get(tTMEM_LOADcdP[k + 1], mode=[0]) - bidx * self.cta_tiler[0], + sum_odo_handle.index, + ] + if cutlass.const_expr(varlen): + if not cute.elem_less(cute.get(tTMEM_LOADcdP[k], mode=[0]), seqlen_q): + dpsum_0 = 0.0 + if not cute.elem_less(cute.get(tTMEM_LOADcdP[k + 1], mode=[0]), seqlen_q): + dpsum_1 = 0.0 + tTMEM_LOADrdP[k], tTMEM_LOADrdP[k + 1] = cute.arch.add_packed_f32x2( + (tTMEM_LOADrdP[k], tTMEM_LOADrdP[k + 1]), + (dpsum_0, dpsum_1), + ) + tTMEM_LOADrdP[k], tTMEM_LOADrdP[k + 1] = cute.arch.mul_packed_f32x2( + (tTMEM_LOADrdP[k], tTMEM_LOADrdP[k + 1]), (tTMEM_LOADrS[k], tTMEM_LOADrS[k + 1]) + ) + tTMEM_LOADrdP[k], tTMEM_LOADrdP[k + 1] = cute.arch.mul_packed_f32x2( + (tTMEM_LOADrdP[k], tTMEM_LOADrdP[k + 1]), (scale_softmax, scale_softmax) + ) + dp_vec = tTMEM_LOADrdP.load() + tTMEM_STORErdP.store(dp_vec.to(self.q_dtype)) + + ds_handle = ds_mma_producer.acquire_and_advance() + tmem_store_atom = cute.make_copy_atom( + tcgen05.St32x32bOp(tcgen05.Repetition(32)), self.acc_dtype + ) + tilePlikeFP32 = tdPtdP_slice.shape[1] // Float32.width * self.q_dtype.width + tdPtdP_dS_layout = cute.composition( + tdPtdP_slice.layout, cute.make_layout((tdPtdP_slice.shape[0], tilePlikeFP32)) + ) + tdPtdP_dS = cute.make_tensor(tdPtdP_slice.iterator, tdPtdP_dS_layout) + tdPcdP_dS_layout = cute.composition( + tdPcdP_slice.layout, cute.make_layout((tdPcdP_slice.shape[0], tilePlikeFP32)) + ) + tdPcdP_dS = cute.make_tensor(tdPcdP_slice.iterator, tdPcdP_dS_layout) + tmem_tiled_store = tcgen05.make_tmem_copy(tmem_store_atom, tdPtdP_dS) + + thr_store = tmem_tiled_store.get_slice(thread_idx) + tTMEM_STOREtdS = thr_store.partition_D(tdPtdP_dS) + tTMEM_STOREcdP = thr_store.partition_S(tdPcdP_dS) + tTMEM_STORErdS_ = cute.make_tensor( + cute.recast_ptr(tTMEM_STORErdP.iterator, dtype=self.acc_dtype), + tTMEM_STOREcdP.shape, + ) + cute.copy(tmem_tiled_store, tTMEM_STORErdS_, tTMEM_STOREtdS) + cute.arch.fence_view_async_tmem_store() + ds_handle.commit() + return mma_s_consumer, mma_dp_consumer, ds_mma_producer + + @cute.jit + def dQ_epilogue( + self, + value_args: Tuple, + dq_args: Tuple, + epi_tile: cute.Tile, + ) -> Tuple[pipeline.PipelineConsumer, pipeline.PipelineProducer]: + seqlen_q, cuseqlen_q, total_q, batch_coord = value_args + (mma_dq_consumer, gdQ_staged, cdQ_staged, tdQtdQ_staged) = dq_args + dq_handle = mma_dq_consumer.wait_and_advance() + tidx, _, _ = cute.arch.thread_idx() + bidx, bidy, bidz = cute.arch.block_idx() + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + cute.arch.fence_view_async_shared() + + for iter in cutlass.range(self.iterations_dsk): + gdQ = gdQ_staged[None, None, iter] + cdQ = cdQ_staged[None, None, iter] + tdQtdQ = tdQtdQ_staged[(None, None), 0, 0, iter] + tdQtdQ_epi = cute.zipped_divide(tdQtdQ, epi_tile) + cdQ_epi = cute.zipped_divide(cdQ, epi_tile) + gdQ_epi = cute.zipped_divide(gdQ, epi_tile) + tidx, _, _ = cute.arch.thread_idx() + thread_idx = tidx % (self.threads_per_warp * len(self.epilogue_warp_ids)) + tmem_copy_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype + ) + tiled_tmem_load = tcgen05.make_tmem_copy(tmem_copy_atom, tdQtdQ_epi) + thr_tmem_load = tiled_tmem_load.get_slice(thread_idx) + tTMEM_LOADtdQ = thr_tmem_load.partition_S(tdQtdQ_epi) + tTMEM_LOADgdQ = thr_tmem_load.partition_D(gdQ_epi) + tTMEM_LOADcdQ = thr_tmem_load.partition_D(cdQ_epi) + + for i in cutlass.range(cute.size(tTMEM_LOADtdQ, mode=[1]), unroll_full=True): + tTMEM_LOADtdQ_i = tTMEM_LOADtdQ[None, i, 0] + tTMEM_LOADgdQ_i = tTMEM_LOADgdQ[None, i, 0] + tTMEM_LOADcdQ_i = tTMEM_LOADcdQ[None, i, 0] + tTMrdQ = cute.make_rmem_tensor(tTMEM_LOADcdQ[None, 0, i].shape, self.acc_dtype) + cute.copy(tiled_tmem_load, tTMEM_LOADtdQ_i, tTMrdQ) + tSMrdQ = cute.make_rmem_tensor(tTMrdQ.shape, self.q_dtype) + dq_vec = tTMrdQ.load() + tSMrdQ.store(dq_vec.to(self.q_dtype)) + if cute.elem_less(tTMEM_LOADcdQ_i[0][0], seqlen_q): + cute.autovec_copy(tSMrdQ, tTMEM_LOADgdQ_i) + dq_handle.release() + return mma_dq_consumer diff --git a/flash_attn/cute/sm100_hd256_2cta_fmha_forward.py b/flash_attn/cute/sm100_hd256_2cta_fmha_forward.py new file mode 100644 index 00000000000..03108e25a3d --- /dev/null +++ b/flash_attn/cute/sm100_hd256_2cta_fmha_forward.py @@ -0,0 +1,1735 @@ +# Copyright (c) 2025, Siyu Wang, Shengbin Di, Yuxi Chi, Johnsonms, Linfeng Zheng, Haoyan Huang, Lanbo Li, Yun Zhong, Man Yuan, Minmin Sun, Yong Li, Wei Lin. + +import math +from typing import Tuple, Optional + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +import cutlass.cute.nvgpu.tcgen05 as tcgen05 +import cutlass.utils as utils +import cutlass.pipeline as pipeline +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cute.typing import Int32, Int64, Float32 + +from cutlass.utils import ClcDynamicPersistentTileScheduler +from flash_attn.cute.tile_scheduler import ( + ClcState, + compute_sm100_fmha_grid as compute_grid, + compute_sm100_fmha_grid_clc as compute_grid_clc, + make_sm100_thread_cooperative_group as make_thread_cooperative_group, + Sm100FmhaStaticTileScheduler as FmhaStaticTileScheduler, + Sm100FmhaStaticTileSchedulerParams as FmhaStaticTileSchedulerParams, + Sm100FmhaClcDynamicTileScheduler as FmhaClcDynamicTileScheduler, + Sm100FmhaClcDynamicTileSchedulerParams as FmhaClcDynamicTileSchedulerParams, +) +from flash_attn.cute.mask import ( + Sm100FusedMask as FusedMask, +) +from flash_attn.cute.tile_scheduler import SM100_TMEM_CAPACITY_COLUMNS +from flash_attn.cute.flash_fwd_sm100 import DescaleTensors + + +class BlackwellFusedMultiHeadAttentionForward: + def __init__( + self, + head_dim: int, + head_dim_v: Optional[int] = None, + qhead_per_kvhead: int = 1, + is_causal: bool = False, + is_local: bool = False, + is_split_kv: bool = False, + pack_gqa: bool = False, + q_subtile_factor: int | None = None, + m_block_size: int = 128, + n_block_size: int = 128, + q_stage: int = 2, + is_persistent: bool = True, + score_mod=None, + mask_mod=None, + has_aux_tensors: bool = False, + paged_kv_non_tma: bool = False, + is_varlen_q: bool = False, + use_2cta_instrs: bool = False, + use_clc_scheduler: bool = False, + ): + head_dim_v = head_dim if head_dim_v is None else head_dim_v + assert head_dim == 256 and head_dim_v == 256, ( + "SM100 dedicated kernel only supports (head_dim, head_dim_v) = (256, 256)" + ) + assert score_mod is None, "SM100 forward with head_dim=256 does not support score_mod" + assert mask_mod is None, "SM100 forward with head_dim=256 does not support mask_mod" + assert not has_aux_tensors, "SM100 forward with head_dim=256 does not support aux tensors" + assert not paged_kv_non_tma, "SM100 forward with head_dim=256 does not support paged KV" + assert not pack_gqa, "SM100 forward with head_dim=256 does not support pack_gqa" + assert not is_split_kv, "SM100 forward with head_dim=256 does not support SplitKV" + assert q_subtile_factor is None, ( + "SM100 forward with head_dim=256 does not support q_subtile_factor" + ) + assert m_block_size == 128 and n_block_size == 128, ( + "SM100 dedicated kernel only supports tile_m=128 and tile_n=128" + ) + # q_stage / persistence / scheduler knobs are accepted for interface parity, + # but this dedicated kernel uses fixed internal settings. + + qk_acc_dtype = cutlass.Float32 + pv_acc_dtype = cutlass.Float32 + mma_tiler = (128, 128, head_dim) + self.qk_acc_dtype = qk_acc_dtype + self.pv_acc_dtype = pv_acc_dtype + self.mma_tiler = mma_tiler + assert mma_tiler[0] == 128 and mma_tiler[1] == 128, "Only 128x128 tile impl is supported" + assert mma_tiler[2] == 256, "Only 256 is supported for 128x128 tile impl" + self.cta_tiler = ( + mma_tiler[0], + mma_tiler[1], + mma_tiler[2], + ) + self.qk_mma_tiler = ( + 2 * mma_tiler[0], + mma_tiler[1], + min(self.cta_tiler[2], 128), + ) + self.pv_mma_tiler = self.qk_mma_tiler + self.pv_block_tiler = ( + self.pv_mma_tiler[0] // 2, + self.pv_mma_tiler[1], + self.pv_mma_tiler[2], + ) + self.iterations_qk = self.cta_tiler[2] // self.qk_mma_tiler[2] + self.iterations_pv = self.cta_tiler[2] // self.pv_mma_tiler[1] + self.cluster_shape_mn = (2, 1) + self.tmem_warp_shape_mn = (4, 1) + # Dedicated hd256 kernel uses fixed scheduling policy. + self.is_persistent = False + self.is_causal = is_causal + self.is_local = is_local + self.use_semantic_trip_range = is_causal or is_local + self.use_clc_scheduler = False + + self.softmax_warp_ids = (0, 1, 2, 3) + self.correction_warp_ids = (4, 5, 6, 7) + self.mma_warp_id = 8 + self.load_warp_id = 9 + self.empty_warp_id = (10, 11) + self.sched_warp_id = self.empty_warp_id[0] if use_clc_scheduler else None + self.tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS + + self.threads_per_warp = 32 + self.threads_per_cta = self.threads_per_warp * len( + ( + *self.softmax_warp_ids, # this is to get a round num threads + *self.correction_warp_ids, + self.mma_warp_id, + self.load_warp_id, + *self.empty_warp_id, + ) + ) + + self.tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=1, + num_threads=self.threads_per_cta, + ) + + self.tmem_s_offset = 0 + self.tmem_o_offset = 256 + self.tmem_p_offset = self.tmem_s_offset + + self.num_regs_softmax = 256 + self.num_regs_correction = 160 + self.num_regs_other = 32 + + self.buffer_align_bytes = 1024 + + def _setup_attributes(self): + self.q_stage = self.iterations_qk + self.kv_stage = 4 + self.qk_acc_stage = 2 + self.mma_corr_stage = 1 + if cutlass.const_expr(self.use_clc_scheduler): + self.num_clc_stage = 1 + self.num_clc_response_bytes = 16 + + @cute.jit + def __call__( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + softmax_scale: Float32, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mCuSeqlensK: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, + mPageTable: Optional[cute.Tensor] = None, + window_size_left: Int32 | int | None = None, + window_size_right: Int32 | int | None = None, + learnable_sink: Optional[cute.Tensor] = None, + descale_tensors: Optional[DescaleTensors] = None, + blocksparse_tensors: Optional[cute.Tensor] = None, + aux_tensors: Optional[list] = None, + stream: cuda.CUstream = None, + ): + # Keep parity with FlashAttentionForwardSm100.__call__ interface. + # (TODO@wangsiyu) Implement these features. + assert mSeqUsedQ is None and mSeqUsedK is None, ( + "SM100 forward with head_dim=256 does not support seqused_q/seqused_k" + ) + assert mPageTable is None, "SM100 forward with head_dim=256 does not support paged KV" + assert learnable_sink is None, ( + "SM100 forward with head_dim=256 does not support learnable_sink" + ) + assert blocksparse_tensors is None, ( + "SM100 forward with head_dim=256 does not support block sparsity" + ) + assert aux_tensors is None, "SM100 forward with head_dim=256 does not support aux_tensors" + assert not self.is_local, ( + "SM100 forward with head_dim=256 does not support local attention yet" + ) + assert window_size_left is None and window_size_right is None, ( + "SM100 forward with head_dim=256 does not support runtime window_size overrides" + ) + assert descale_tensors is None, ( + "SM100 forward with head_dim=256 does not support descale_tensors" + ) + + q_tensor, k_tensor, v_tensor, o_tensor = mQ, mK, mV, mO + lse_tensor = mLSE + cum_seqlen_q = mCuSeqlensQ + cum_seqlen_k = mCuSeqlensK + + q_rank = len(mQ.shape) + k_rank = len(mK.shape) + if cutlass.const_expr(cum_seqlen_q is not None): + # Varlen path accepts either legacy 5D tensors or standard 3D tensors. + if cutlass.const_expr(q_rank == 5): + s_q = mQ.shape[1] + h_q = mQ.shape[2] * mQ.shape[3] + d = mQ.shape[4] + elif cutlass.const_expr(q_rank == 3): + s_q = mQ.shape[0] + h_q = mQ.shape[1] + d = mQ.shape[2] + else: + raise RuntimeError(f"hd256 forward varlen expects q rank 3 or 5, got rank {q_rank}") + else: + # Non-varlen path accepts either legacy 5D tensors or standard 4D tensors. + if cutlass.const_expr(q_rank == 5): + s_q = mQ.shape[1] + h_q = mQ.shape[2] * mQ.shape[3] + d = mQ.shape[4] + elif cutlass.const_expr(q_rank == 4): + s_q = mQ.shape[1] + h_q = mQ.shape[2] + d = mQ.shape[3] + else: + raise RuntimeError( + f"hd256 forward non-varlen expects q rank 4 or 5, got rank {q_rank}" + ) + + if cutlass.const_expr(cum_seqlen_k is not None): + if cutlass.const_expr(k_rank == 5): + s_k = mK.shape[1] + h_k = mK.shape[2] + elif cutlass.const_expr(k_rank == 3): + s_k = mK.shape[0] + h_k = mK.shape[1] + else: + raise RuntimeError(f"hd256 forward varlen expects k rank 3 or 5, got rank {k_rank}") + else: + if cutlass.const_expr(k_rank == 5): + s_k = mK.shape[1] + h_k = mK.shape[2] + elif cutlass.const_expr(k_rank == 4): + s_k = mK.shape[1] + h_k = mK.shape[2] + else: + raise RuntimeError( + f"hd256 forward non-varlen expects k rank 4 or 5, got rank {k_rank}" + ) + if cutlass.const_expr(cum_seqlen_q is not None): + b = mCuSeqlensQ.shape[0] - 1 + elif cutlass.const_expr(cum_seqlen_k is not None): + b = mCuSeqlensK.shape[0] - 1 + else: + b = mQ.shape[0] + + scale_softmax = softmax_scale + scale_softmax_log2 = softmax_scale * math.log2(math.exp(1.0)) + scale_output = 1.0 + s_lse = s_q + h_r = h_q // h_k + s_q64 = Int64(s_q) + s_k64 = Int64(s_k) + s_lse64 = Int64(s_lse) + d64 = cute.assume(Int64(d), divby=128) + h_r64 = Int64(h_r) + h_k64 = Int64(h_k) + b64 = Int64(b) + s_q_total = ( + q_tensor.shape[1] + if cum_seqlen_q is not None and q_rank == 5 + else (q_tensor.shape[0] if cum_seqlen_q is not None else s_q64) + ) + s_k_total = ( + k_tensor.shape[1] + if cum_seqlen_k is not None and k_rank == 5 + else (k_tensor.shape[0] if cum_seqlen_k is not None else s_k64) + ) + stride_b_qo = h_r64 * h_k64 * s_q64 * d64 if cum_seqlen_q is None else 0 + stride_b_kv = h_k64 * s_k64 * d64 if cum_seqlen_k is None else 0 + b_lse = b64 if cum_seqlen_q is None else 1 + stride_b_lse = h_r64 * h_k64 * s_lse64 if cum_seqlen_q is None else 0 + + # (s, d, ((h_r, h_k), b)) + q_layout = cute.make_layout( + (s_q_total, d, ((h_r, h_k), b)), + stride=(d64 * h_r64 * h_k64, 1, ((d64, d64 * h_r64), stride_b_qo)), + ) + q = cute.make_tensor(q_tensor.iterator, q_layout) + # (s, d, ((h_r, h_k), b)), 0-stride for h_r to broadcast + k_layout = cute.make_layout( + (s_k_total, d, ((h_r, h_k), b)), + stride=(d64 * h_k64, 1, ((0, d64), stride_b_kv)), + ) + k = cute.make_tensor(k_tensor.iterator, k_layout) + # (d, s, ((h_r, h_k), b)), 0-stride for h_r to broadcast + v_layout = cute.make_layout( + (d, s_k_total, ((h_r, h_k), b)), + stride=(1, d64 * h_k64, ((0, d64), stride_b_kv)), + ) + v = cute.make_tensor(v_tensor.iterator, v_layout) + # (s, d, ((h_r, h_k), b)) + o_layout = cute.make_layout( + (s_q_total, d, ((h_r, h_k), b)), + stride=(d64 * h_r64 * h_k64, 1, ((d64, d64 * h_r64), stride_b_qo)), + ) + o = cute.make_tensor(o_tensor.iterator, o_layout) + if cutlass.const_expr(lse_tensor is not None): + # (s, ((h_r, h_k), b)) + lse_layout = cute.make_layout( + (s_lse64, ((h_r, h_k), b_lse)), + stride=(1, ((s_lse64, h_r64 * s_lse64), stride_b_lse)), + ) + lse = cute.make_tensor(lse_tensor.iterator, lse_layout) + else: + lse = None + + # setup static attributes before smem/grid/tma computation + self.q_dtype = q.element_type + self.k_dtype = k.element_type + self.v_dtype = v.element_type + self.o_dtype = o.element_type + self.tilePlikeFP32 = self.qk_mma_tiler[1] // Float32.width * self.q_dtype.width + + if cutlass.const_expr(self.use_clc_scheduler): + self.tile_sched_params, grid = compute_grid_clc( + (s_q, o.shape[1], o.shape[2]) if cum_seqlen_q is not None else o.shape, + self.cta_tiler, + (*self.cluster_shape_mn, 1), + ) + else: + self.tile_sched_params, grid = compute_grid( + (s_q, o.shape[1], o.shape[2]) if cum_seqlen_q is not None else o.shape, + self.cta_tiler, + self.is_persistent, + ) + + self.q_major_mode = utils.LayoutEnum.from_tensor(q).mma_major_mode() + self.k_major_mode = utils.LayoutEnum.from_tensor(k).mma_major_mode() + self.v_major_mode = utils.LayoutEnum.from_tensor(v).mma_major_mode() + self.o_layout = utils.LayoutEnum.from_tensor(o) + + if cutlass.const_expr(self.q_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of q is not supported") + if cutlass.const_expr(self.k_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of k is not supported") + if cutlass.const_expr(self.v_major_mode != tcgen05.OperandMajorMode.MN): + raise RuntimeError("The layout of v is not supported") + + # check type consistency + if cutlass.const_expr(self.q_dtype != self.k_dtype): + raise TypeError(f"Type mismatch: {self.q_dtype} != {self.k_dtype}") + if cutlass.const_expr(self.q_dtype != self.v_dtype): + raise TypeError(f"Type mismatch: {self.q_dtype} != {self.v_dtype}") + self._setup_attributes() + + cta_group = tcgen05.CtaGroup.TWO + # the intermediate tensor p is from tmem & k-major + p_source = tcgen05.OperandSource.TMEM + p_major_mode = tcgen05.OperandMajorMode.K + qk_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.q_dtype, + self.q_major_mode, + self.k_major_mode, + self.qk_acc_dtype, + cta_group, + self.qk_mma_tiler[:2], + ) + pv_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.v_dtype, + p_major_mode, + self.v_major_mode, + self.pv_acc_dtype, + cta_group, + self.pv_mma_tiler[:2], + p_source, + ) + + self.cluster_shape_mnk = (*self.cluster_shape_mn, 1) + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), + (qk_tiled_mma.thr_id.shape,), + ) + + self.epi_tile = self.pv_block_tiler[:2] + + q_smem_layout_staged = sm100_utils.make_smem_layout_a( + qk_tiled_mma, + self.qk_mma_tiler, + self.q_dtype, + self.q_stage, + ) + k_smem_layout_staged = sm100_utils.make_smem_layout_b( + qk_tiled_mma, + self.qk_mma_tiler, + self.k_dtype, + self.kv_stage, + ) + p_tmem_layout_staged = sm100_utils.make_smem_layout_a( + pv_tiled_mma, + self.pv_mma_tiler, + self.q_dtype, + self.qk_acc_stage, + ) + p_tmem_layout = cute.select(p_tmem_layout_staged, mode=[0, 1, 2]) + v_smem_layout_staged = sm100_utils.make_smem_layout_b( + pv_tiled_mma, + self.pv_mma_tiler, + self.v_dtype, + self.kv_stage, + ) + # TMA load for Q + tma_load_op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp(cta_group) + + q_smem_layout = cute.select(q_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_q, tma_tensor_q = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + q, + q_smem_layout, + self.qk_mma_tiler, + qk_tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # TMA load for K + k_smem_layout = cute.select(k_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_k, tma_tensor_k = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + k, + k_smem_layout, + self.qk_mma_tiler, + qk_tiled_mma, + self.cluster_layout_vmnk.shape, + ) + # TMA load for V + v_smem_layout = cute.select(v_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_v, tma_tensor_v = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + v, + v_smem_layout, + self.pv_mma_tiler, + pv_tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + q_copy_size = cute.size_in_bytes(self.q_dtype, q_smem_layout) + k_copy_size = cute.size_in_bytes(self.k_dtype, k_smem_layout) + self.tma_copy_q_bytes = q_copy_size * cute.size(qk_tiled_mma.thr_id.shape) + self.tma_copy_kv_bytes = k_copy_size * cute.size(qk_tiled_mma.thr_id.shape) + + @cute.struct + class SharedStorage: + # TMA G2S load barriers: LOAD warp (producer) -> MMA warp (consumer) + load_q_mbar_ptr: cute.struct.MemRange[ + Int64, self.q_stage * 2 + ] # load_q_{producer,consumer} + load_kv_mbar_ptr: cute.struct.MemRange[ + Int64, self.kv_stage * 2 + ] # load_kv_{producer,consumer} + mma_s_mbar_ptr: cute.struct.MemRange[Int64, self.qk_acc_stage * 2] + p_mma_mbar_ptr: cute.struct.MemRange[Int64, self.qk_acc_stage * 2] + # Softmax -> Correction signaling barriers (row_max/row_sum vec ready) + s_corr_mbar_ptr: cute.struct.MemRange[ + Int64, self.qk_acc_stage * 2 + ] # s_corr_{producer,consumer} + sum_mbar_ptr: cute.struct.MemRange[Int64, 2] + # MMA -> Correction ownership barriers for O_partial tokens (online rescale/finalize) + mma_corr_mbar_ptr: cute.struct.MemRange[ + Int64, self.mma_corr_stage * 2 + ] # mma_corr_{producer,consumer} + # A CTA-wide "TMEM lifetime" barrier used to safely deallocate TMEM after all users finish. + tmem_dealloc_mbar_ptr: Int64 + # Tmem holding buffer + tmem_holding_buf: Int32 + # CLC pipeline barriers and response buffer + clc_mbar_ptr: cute.struct.MemRange[Int64, 2] + clc_response: cute.struct.MemRange[Int32, 4] + + self.shared_storage = SharedStorage + + grid = cute.round_up(grid, self.cluster_shape_mnk) + # Launch the kernel synchronously + self.kernel( + qk_tiled_mma, + pv_tiled_mma, + tma_atom_q, + tma_tensor_q, + tma_atom_k, + tma_tensor_k, + tma_atom_v, + tma_tensor_v, + o, + cum_seqlen_q, + cum_seqlen_k, + lse, + scale_softmax_log2, + scale_softmax, + scale_output, + window_size_left, + window_size_right, + self.cluster_layout_vmnk, + q_smem_layout_staged, + k_smem_layout_staged, + p_tmem_layout, + v_smem_layout_staged, + self.tile_sched_params, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=self.cluster_shape_mnk, + stream=stream, + min_blocks_per_mp=1, + ) + + # GPU device kernel + @cute.kernel + def kernel( + self, + qk_tiled_mma: cute.TiledMma, + pv_tiled_mma: cute.TiledMma, + tma_atom_q: cute.CopyAtom, + mQ_qdl: cute.Tensor, + tma_atom_k: cute.CopyAtom, + mK_kdl: cute.Tensor, + tma_atom_v: cute.CopyAtom, + mV_dkl: cute.Tensor, + mO_qdl: cute.Tensor, + cum_seqlen_q: Optional[cute.Tensor], + cum_seqlen_k: Optional[cute.Tensor], + mLSE: Optional[cute.Tensor], + scale_softmax_log2: Float32, + scale_softmax: Float32, + scale_output: Float32, + window_size_left: Optional[Int32], + window_size_right: Optional[Int32], + cluster_layout_vmnk: cute.Layout, + q_smem_layout_staged: cute.ComposedLayout, + k_smem_layout_staged: cute.ComposedLayout, + p_tmem_layout_staged: cute.ComposedLayout, + v_smem_layout_staged: cute.ComposedLayout, + tile_sched_params: FmhaStaticTileSchedulerParams | FmhaClcDynamicTileSchedulerParams, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + + # + # Prefetch tma desc + # + if warp_idx == self.load_warp_id: + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_q) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_k) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_v) + + bidx, _, _ = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(qk_tiled_mma.thr_id.shape) + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster) + + # Alloc + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + load_q_producer, load_q_consumer = pipeline.PipelineTmaUmma.create( + num_stages=self.q_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + tx_count=self.tma_copy_q_bytes, + barrier_storage=storage.load_q_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + load_kv_producer, load_kv_consumer = pipeline.PipelineTmaUmma.create( + num_stages=self.kv_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + tx_count=self.tma_copy_kv_bytes, + barrier_storage=storage.load_kv_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + mma_s_producer, mma_s_consumer = pipeline.PipelineUmmaAsync.create( + num_stages=self.qk_acc_stage, + producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + consumer_group=make_thread_cooperative_group( + len(self.softmax_warp_ids) * self.threads_per_warp * self.cluster_shape_mnk[0], + ), + barrier_storage=storage.mma_s_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + p_mma_producer, p_mma_consumer = pipeline.PipelineAsyncUmma.create( + num_stages=self.qk_acc_stage, + producer_group=make_thread_cooperative_group( + len(self.softmax_warp_ids) * self.threads_per_warp * self.cluster_shape_mnk[0], + ), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + barrier_storage=storage.p_mma_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + s_corr_producer, s_corr_consumer = pipeline.PipelineAsync.create( + num_stages=self.qk_acc_stage, + producer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.softmax_warp_ids) + ), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.correction_warp_ids) + ), + barrier_storage=storage.s_corr_mbar_ptr.data_ptr(), + defer_sync=True, + ).make_participants() + sum_producer, sum_consumer = pipeline.PipelineAsync.create( + num_stages=1, + producer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.softmax_warp_ids) + ), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.correction_warp_ids) + ), + barrier_storage=storage.sum_mbar_ptr.data_ptr(), + defer_sync=True, + ).make_participants() + mma_corr_producer, mma_corr_consumer = pipeline.PipelineUmmaAsync.create( + num_stages=self.mma_corr_stage, + producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + consumer_group=make_thread_cooperative_group( + len(self.correction_warp_ids) * self.threads_per_warp * self.cluster_shape_mnk[0], + ), + barrier_storage=storage.mma_corr_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + # Tensor memory dealloc barrier init + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=self.tmem_alloc_barrier, + allocator_warp_id=self.correction_warp_ids[0], + is_two_cta=True, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + tmem.allocate(self.tmem_alloc_cols) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) + # Initialize CLC state if using dynamic scheduler + if cutlass.const_expr(self.use_clc_scheduler): + clc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + cluster_size = cute.size(self.cluster_shape_mnk) + num_clc_consumer_threads = self.threads_per_warp * ( + 1 # sched_warp (CTA 0 only) + + cluster_size + * ( + len(self.softmax_warp_ids) + + len(self.correction_warp_ids) + + 1 # mma_warp + + 1 # load_warp + ) + ) + clc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_clc_consumer_threads + ) + clc_response_ptr = storage.clc_response.data_ptr() + clc = ClcState.create( + hw_scheduler=ClcDynamicPersistentTileScheduler.create( + self.tile_sched_params.clc_hw_params(), + cute.arch.block_idx(), + cute.arch.grid_dim(), + clc_response_ptr, + ), + pipeline=pipeline.PipelineClcFetchAsync.create( + barrier_storage=storage.clc_mbar_ptr.data_ptr(), + num_stages=self.num_clc_stage, + producer_group=clc_pipeline_producer_group, + consumer_group=clc_pipeline_consumer_group, + tx_count=self.num_clc_response_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ), + consumer_state=pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_clc_stage + ), + producer_state=pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_clc_stage + ), + ) + else: + clc = None + clc_response_ptr = None + + # Cluster arrive after barrier init + pipeline.pipeline_init_arrive(cluster_shape_mn=cluster_layout_vmnk, is_relaxed=True) + + sQ = smem.allocate_tensor( + element_type=self.q_dtype, + layout=q_smem_layout_staged.outer, + swizzle=q_smem_layout_staged.inner, + byte_alignment=128, + ) + sK = smem.allocate_tensor( + element_type=self.k_dtype, + layout=k_smem_layout_staged.outer, + swizzle=k_smem_layout_staged.inner, + byte_alignment=128, + ) + # K and V now use separate memory since we removed the transform stage + sV = smem.allocate_tensor( + element_type=self.v_dtype, + layout=v_smem_layout_staged.outer, + swizzle=v_smem_layout_staged.inner, + byte_alignment=128, + ) + + sSum = smem.allocate_tensor( + element_type=self.qk_acc_dtype, + layout=cute.make_layout(len(self.softmax_warp_ids) * self.threads_per_warp), + byte_alignment=128, + ) + qk_thr_mma = qk_tiled_mma.get_slice(mma_tile_coord_v) # default 1sm + pv_thr_mma = pv_tiled_mma.get_slice(mma_tile_coord_v) # default 1sm + tSrQ = qk_thr_mma.make_fragment_A(sQ) + tSrK = qk_thr_mma.make_fragment_B(sK) + tOrV = pv_thr_mma.make_fragment_B(sV) + qk_acc_shape = qk_thr_mma.partition_shape_C((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tStS = qk_thr_mma.make_fragment_C(cute.append(qk_acc_shape, self.qk_acc_stage)) + pv_acc_shape = pv_thr_mma.partition_shape_C((self.pv_mma_tiler[0], self.pv_mma_tiler[1])) + tOtO = pv_thr_mma.make_fragment_C(pv_acc_shape) + tOtO_layout = cute.append( + tOtO.layout, + cute.make_layout( + self.iterations_pv, + stride=self.pv_mma_tiler[1] // self.tmem_warp_shape_mn[1], + ), + ) + tStS = cute.make_tensor(tStS.iterator + self.tmem_s_offset, tStS.layout) + tOtO_staged = cute.make_tensor(tOtO.iterator + self.tmem_o_offset, tOtO_layout) + + # /////////////////////////////////////////////////////////////////////////////// + # EMPTY + # /////////////////////////////////////////////////////////////////////////////// + for _i in cutlass.range_constexpr(len(self.empty_warp_id)): + if warp_idx == self.empty_warp_id[_i]: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + + if cutlass.const_expr(self.use_clc_scheduler): + tile_sched = FmhaClcDynamicTileScheduler.create( + tile_sched_params, + cute.arch.block_idx(), + cute.arch.grid_dim(), + clc_response_ptr, + clc, + ) + else: + blk_idx = cute.arch.block_idx() + tile_sched = FmhaStaticTileScheduler( + tile_sched_params, blk_idx[0], blk_idx, cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + # Cluster wait + pipeline.pipeline_init_wait(cluster_shape_mn=cluster_layout_vmnk) + + # /////////////////////////////////////////////////////////////////////////////// + # LOAD + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx == self.load_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + while work_tile.is_valid_tile: + curr_block_coord = work_tile.tile_idx # (q_tile_idx, 0, (head_idx, batch_idx)) + mma_block_coord = ( + curr_block_coord[0] // cute.size(qk_tiled_mma.thr_id.shape), + curr_block_coord[1], + curr_block_coord[2], + ) + continue_cond = False + batch_coord = curr_block_coord[2][1] + seqlen_q = mQ_qdl.shape[0] + seqlen_k = mK_kdl.shape[0] + cuseqlen_q = Int32(0) + cuseqlen_k = Int32(0) + block_offset = ( + Int32(0), + Int32(0), + Int32(0), + ((Int32(0), Int32(0)), Int32(0)), + ) + if cutlass.const_expr(cum_seqlen_q is not None): + cuseqlen_q = cum_seqlen_q[batch_coord] + seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + block_offset = ( + cuseqlen_q, + cuseqlen_k, + Int32(0), + ((Int32(0), Int32(0)), Int32(0)), + ) + continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.qk_mma_tiler[0], + mma_block_coord[0], + seqlen_q, + ) + if not continue_cond: + mQ_qdl_ = cute.domain_offset(cute.select(block_offset, mode=[0, 2, 3]), mQ_qdl) + mK_kdl_ = cute.domain_offset(cute.select(block_offset, mode=[1, 2, 3]), mK_kdl) + mV_dkl_ = cute.domain_offset(cute.select(block_offset, mode=[2, 1, 3]), mV_dkl) + # Local tile partition global tensors + q_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + # (bM, bK, loopM, loopK, loopL) + gQ_qdl = cute.flat_divide(mQ_qdl_, cute.select(self.qk_mma_tiler, mode=[0, 2])) + tSgQ_qdl = qk_thr_mma.partition_A(gQ_qdl) + tQsQ, tQgQ_qdl = cute.nvgpu.cpasync.tma_partition( + tma_atom_q, + block_in_cluster_coord_vmnk[2], + q_cta_layout, + cute.group_modes(sQ, 0, 3), + cute.group_modes(tSgQ_qdl, 0, 3), + ) + kv_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + gK_kdl = cute.flat_divide(mK_kdl_, cute.select(self.qk_mma_tiler, mode=[1, 2])) + tSgK_kdl = qk_thr_mma.partition_B(gK_kdl) + tKsK, tKgK_kdl = cute.nvgpu.cpasync.tma_partition( + tma_atom_k, + block_in_cluster_coord_vmnk[1], + kv_cta_layout, + cute.group_modes(sK, 0, 3), + cute.group_modes(tSgK_kdl, 0, 3), + ) + + gV_dkl = cute.flat_divide(mV_dkl_, cute.select(self.pv_mma_tiler, mode=[1, 2])) + tSgV_dkl = pv_thr_mma.partition_B(gV_dkl) + tVsV, tVgV_dkl = cute.nvgpu.cpasync.tma_partition( + tma_atom_v, + block_in_cluster_coord_vmnk[1], + kv_cta_layout, + cute.group_modes(sV, 0, 3), + cute.group_modes(tSgV_dkl, 0, 3), + ) + # ((atom_v, rest_v), RestK) + tQgQ = tQgQ_qdl[None, mma_block_coord[0], None, mma_block_coord[2]] + # ((atom_v, rest_v), RestN, RestK) + tKgK = tKgK_kdl[None, None, None, mma_block_coord[2]] + # ((atom_v, rest_v), RestN, RestK) + tVgV = tVgV_dkl[None, None, None, mma_block_coord[2]] + + seqlen_kv_loop_start, seqlen_kv_loop_steps = ( + FusedMask.get_trip_start_count_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + ) + seqlen_kv_loop_end = seqlen_kv_loop_start + seqlen_kv_loop_steps + # Q + for iter in cutlass.range(self.iterations_qk, unroll=1): + q_handle = load_q_producer.acquire_and_advance() + cute.copy( + tma_atom_q, + tQgQ[None, iter], + tQsQ[None, q_handle.index], + tma_bar_ptr=q_handle.barrier, + ) + + # K0 + kv_coord = seqlen_kv_loop_start + for iter in cutlass.range(self.iterations_qk, unroll=1): + k_handle = load_kv_producer.acquire_and_advance() + cute.copy( + tma_atom_k, + tKgK[None, kv_coord, iter], + tKsK[None, k_handle.index], + tma_bar_ptr=k_handle.barrier, + ) + kv_coord += 1 + + for i in cutlass.range(1, seqlen_kv_loop_steps, 1, unroll=1): + # Ki + for iter in cutlass.range(self.iterations_qk, unroll=1): + k_handle = load_kv_producer.acquire_and_advance() + cute.copy( + tma_atom_k, + tKgK[None, kv_coord, iter], + tKsK[None, k_handle.index], + tma_bar_ptr=k_handle.barrier, + ) + # Vi-1 + for iter in cutlass.range(self.iterations_pv, unroll=1): + v_handle = load_kv_producer.acquire_and_advance() + cute.copy( + tma_atom_v, + tVgV[None, iter, kv_coord - 1], + tVsV[None, v_handle.index], + tma_bar_ptr=v_handle.barrier, + ) + kv_coord += 1 + # Vend + for iter in cutlass.range(self.iterations_pv, unroll=1): + v_handle = load_kv_producer.acquire_and_advance() + cute.copy( + tma_atom_v, + tVgV[None, iter, seqlen_kv_loop_end - 1], + tVsV[None, v_handle.index], + tma_bar_ptr=v_handle.barrier, + ) + + work_tile = tile_sched.advance_to_next_work() + # End of persistent scheduler loop + load_kv_producer.tail() + load_q_producer.tail() + + # /////////////////////////////////////////////////////////////////////////////// + # MMA + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx == self.mma_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + + while work_tile.is_valid_tile: + curr_block_coord = work_tile.tile_idx + mma_block_coord = ( + curr_block_coord[0] // cute.size(qk_tiled_mma.thr_id.shape), + curr_block_coord[1], + curr_block_coord[2], + ) + continue_cond = False + seqlen_q = mQ_qdl.shape[0] + seqlen_k = mK_kdl.shape[0] + batch_coord = curr_block_coord[2][1] + if cutlass.const_expr(cum_seqlen_q is not None): + cuseqlen_q = cum_seqlen_q[batch_coord] + seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q + continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.qk_mma_tiler[0], + mma_block_coord[0], + seqlen_q, + ) + + if not continue_cond: + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + + seqlen_kv_loop_start, seqlen_kv_loop_steps = ( + FusedMask.get_trip_start_count_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + ) + seqlen_kv_loop_end = seqlen_kv_loop_start + seqlen_kv_loop_steps + + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + is_leader_cta = cta_rank_in_cluster % 2 == 0 + load_q_releaser = load_q_consumer.clone() + pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + if seqlen_kv_loop_steps > 1: + # QK0 + if is_leader_cta: + s_handle = mma_s_producer.acquire_and_advance() + tStS_slice = tStS[None, None, None, s_handle.index] + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + for iter in cutlass.range(self.iterations_qk, unroll=1): + load_q_consumer.wait_and_advance() + tSrQ_slice = tSrQ[None, None, None, iter] + k_handle = load_kv_consumer.wait_and_advance() + tSrK_trans_slice = tSrK[None, None, None, k_handle.index] + num_kphases = cute.size(tSrQ_slice, mode=[2]) + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord = (None, None, kphase_idx) + cute.gemm( + qk_tiled_mma, + tStS_slice, + tSrQ_slice[kphase_coord], + tSrK_trans_slice[kphase_coord], + tStS_slice, + ) + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + k_handle.release() + s_handle.commit() + for i in cutlass.range(1, seqlen_kv_loop_steps - 1, 1, unroll=1): + # QKi + if is_leader_cta: + s_handle = mma_s_producer.acquire_and_advance() + tStS_slice = tStS[None, None, None, s_handle.index] + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + for iter in cutlass.range(self.iterations_qk, unroll=1): + tSrQ_slice = tSrQ[None, None, None, iter] + k_handle = load_kv_consumer.wait_and_advance() + tSrK_trans_slice = tSrK[None, None, None, k_handle.index] + num_kphases = cute.size(tSrQ_slice, mode=[2]) + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord = (None, None, kphase_idx) + cute.gemm( + qk_tiled_mma, + tStS_slice, + tSrQ_slice[kphase_coord], + tSrK_trans_slice[kphase_coord], + tStS_slice, + ) + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + k_handle.release() + s_handle.commit() + + # PVi-1 + p_handle = p_mma_consumer.wait_and_advance() + o_handle = mma_corr_producer.acquire_and_advance() + pv_whether_acc = pv_tiled_mma.get(tcgen05.Field.ACCUMULATE) + for iter in cutlass.range(self.iterations_pv, unroll=1): + v_handle = load_kv_consumer.wait_and_advance() + pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, pv_whether_acc) + tOtO_slice = tOtO_staged[None, None, None, iter] + tStS_slice = tStS[None, None, None, p_handle.index] + tP = cute.make_tensor( + tStS_slice.iterator, p_tmem_layout_staged.outer + ) + tOrP = pv_thr_mma.make_fragment_A(tP) + tOrP_slice = cute.make_tensor( + cute.recast_ptr(tStS_slice.iterator, dtype=self.q_dtype), + tOrP.layout, + ) + tOrV_slice = tOrV[None, None, None, v_handle.index] + num_kphases = cute.size(tOrV_slice, mode=[2]) + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord = (None, None, kphase_idx) + cute.gemm( + pv_tiled_mma, + tOtO_slice, + tOrP_slice[kphase_coord], + tOrV_slice[kphase_coord], + tOtO_slice, + ) + pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + v_handle.release() + o_handle.commit() + p_handle.release() + if is_leader_cta: + # QKend + s_handle = mma_s_producer.acquire_and_advance() + tStS_slice = tStS[None, None, None, s_handle.index] + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + for iter in cutlass.range(self.iterations_qk, unroll=1): + tSrQ_slice = tSrQ[None, None, None, iter] + k_handle = load_kv_consumer.wait_and_advance() + tSrK_trans_slice = tSrK[None, None, None, k_handle.index] + num_kphases = cute.size(tSrQ_slice, mode=[2]) + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord = (None, None, kphase_idx) + cute.gemm( + qk_tiled_mma, + tStS_slice, + tSrQ_slice[kphase_coord], + tSrK_trans_slice[kphase_coord], + tStS_slice, + ) + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + k_handle.release() + load_q_releaser.release() + load_q_releaser.advance() + s_handle.commit() + + # PVend-1 + p_handle = p_mma_consumer.wait_and_advance() + o_handle = mma_corr_producer.acquire_and_advance() + pv_whether_acc = pv_tiled_mma.get(tcgen05.Field.ACCUMULATE) + for iter in cutlass.range(self.iterations_pv, unroll=1): + v_handle = load_kv_consumer.wait_and_advance() + pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, pv_whether_acc) + tOtO_slice = tOtO_staged[None, None, None, iter] + tStS_slice = tStS[None, None, None, p_handle.index] + tP = cute.make_tensor( + tStS_slice.iterator, p_tmem_layout_staged.outer + ) + tOrP = pv_thr_mma.make_fragment_A(tP) + tOrP_slice = cute.make_tensor( + cute.recast_ptr(tStS_slice.iterator, dtype=self.q_dtype), + tOrP.layout, + ) + tOrV_slice = tOrV[None, None, None, v_handle.index] + num_kphases = cute.size(tOrV_slice, mode=[2]) + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord = (None, None, kphase_idx) + cute.gemm( + pv_tiled_mma, + tOtO_slice, + tOrP_slice[kphase_coord], + tOrV_slice[kphase_coord], + tOtO_slice, + ) + pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + v_handle.release() + o_handle.commit() + p_handle.release() + else: + if is_leader_cta: + # QK0 + s_handle = mma_s_producer.acquire_and_advance() + tStS_slice = tStS[None, None, None, s_handle.index] + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + for iter in cutlass.range(self.iterations_qk, unroll=1): + load_q_consumer.wait_and_advance() + tSrQ_slice = tSrQ[None, None, None, iter] + k_handle = load_kv_consumer.wait_and_advance() + tSrK_trans_slice = tSrK[None, None, None, k_handle.index] + num_kphases = cute.size(tSrQ_slice, mode=[2]) + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord = (None, None, kphase_idx) + cute.gemm( + qk_tiled_mma, + tStS_slice, + tSrQ_slice[kphase_coord], + tSrK_trans_slice[kphase_coord], + tStS_slice, + ) + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + k_handle.release() + load_q_releaser.release() + load_q_releaser.advance() + s_handle.commit() + + if is_leader_cta: + # PVend + p_handle = p_mma_consumer.wait_and_advance() + o_handle = mma_corr_producer.acquire_and_advance() + pv_whether_acc = pv_tiled_mma.get(tcgen05.Field.ACCUMULATE) + for iter in cutlass.range(self.iterations_pv, unroll=1): + v_handle = load_kv_consumer.wait_and_advance() + pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, pv_whether_acc) + tOtO_slice = tOtO_staged[None, None, None, iter] + tStS_slice = tStS[None, None, None, p_handle.index] + tP = cute.make_tensor(tStS_slice.iterator, p_tmem_layout_staged.outer) + tOrP = pv_thr_mma.make_fragment_A(tP) + tOrP_slice = cute.make_tensor( + cute.recast_ptr(tStS_slice.iterator, dtype=self.q_dtype), + tOrP.layout, + ) + tOrV_slice = tOrV[None, None, None, v_handle.index] + num_kphases = cute.size(tOrV_slice, mode=[2]) + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord = (None, None, kphase_idx) + cute.gemm( + pv_tiled_mma, + tOtO_slice, + tOrP_slice[kphase_coord], + tOrV_slice[kphase_coord], + tOtO_slice, + ) + pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + v_handle.release() + o_handle.commit() + p_handle.release() + work_tile = tile_sched.advance_to_next_work() + # End of persistent scheduler loop + mma_s_producer.tail() + mma_corr_producer.tail() + + if warp_idx < self.correction_warp_ids[0] and warp_idx >= self.softmax_warp_ids[0]: + # increase register after decreasing + cute.arch.warpgroup_reg_alloc(self.num_regs_softmax) + + while work_tile.is_valid_tile: + curr_block_coord = work_tile.tile_idx + mma_block_coord = ( + curr_block_coord[0] // cute.size(qk_tiled_mma.thr_id.shape), + curr_block_coord[1], + curr_block_coord[2], + ) + batch_coord = curr_block_coord[2][1] + continue_cond = False + seqlen_q = mQ_qdl.shape[0] + seqlen_k = mK_kdl.shape[0] + cuseqlen_q = Int32(0) + if cutlass.const_expr(cum_seqlen_q is not None): + cuseqlen_q = cum_seqlen_q[batch_coord] + seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q + continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.qk_mma_tiler[0], + mma_block_coord[0], + seqlen_q, + ) + if not continue_cond: + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + + row_max = -Float32.inf + row_max_prev = -Float32.inf + row_sum = 0.0 + + start_count, trip_count = FusedMask.get_trip_start_count_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + end_count = start_count + trip_count + if cutlass.const_expr(self.use_semantic_trip_range): + n_block_min_causal_local_mask, n_block_min_before_local_mask = ( + FusedMask.get_trip_mask_bounds_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + ) + cS_base = cute.make_identity_tensor( + (self.qk_mma_tiler[0], self.qk_mma_tiler[1]) + ) + cS = cute.domain_offset((mma_block_coord[0] * self.qk_mma_tiler[0], 0), cS_base) + tScS = qk_thr_mma.partition_C(cS) + + for step in cutlass.range(start_count, end_count, 1, unroll=1): + cS_iter = cute.domain_offset((0, step * self.qk_mma_tiler[1]), cS) + tScS_iter = qk_thr_mma.partition_C(cS_iter) + if cutlass.const_expr(self.use_semantic_trip_range): + need_apply_mask = ( + step >= n_block_min_causal_local_mask + or step < n_block_min_before_local_mask + ) + else: + # Residual path only needs seqlen masking on the last K tile. + need_apply_mask = step == end_count - 1 + # Si -> Pi + ( + row_max, + row_sum, + mma_s_consumer, + p_mma_producer, + s_corr_producer, + ) = self.softmax_step( + (need_apply_mask, window_size_left, window_size_right), + ( + row_max_prev, + row_sum, + seqlen_q, + seqlen_k, + scale_softmax_log2, + ), + (tStS, tScS_iter), + (mma_s_consumer, p_mma_producer, s_corr_producer), + ) + row_max_prev = row_max + sum_producer = self.store_sum_max( + row_max, + mLSE, + row_sum, + sSum, + sum_producer, + curr_block_coord, + seqlen_q, + cum_seqlen_q, + cuseqlen_q, + scale_softmax, + ) + work_tile = tile_sched.advance_to_next_work() + p_mma_producer.tail() + s_corr_producer.tail() + + # /////////////////////////////////////////////////////////////////////////////// + # Correction + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx >= self.correction_warp_ids[0] and warp_idx < self.mma_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_correction) + + while work_tile.is_valid_tile: + curr_block_coord = work_tile.tile_idx + mma_block_coord = ( + curr_block_coord[0] // cute.size(qk_tiled_mma.thr_id.shape), + curr_block_coord[1], + curr_block_coord[2], + ) + batch_coord = curr_block_coord[2][1] + seqlen_q = mQ_qdl.shape[0] + seqlen_k = mK_kdl.shape[0] + continue_cond = False + cuseqlen_q = Int32(0) + if cutlass.const_expr(cum_seqlen_q is not None): + cuseqlen_q = cum_seqlen_q[batch_coord] + seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q + continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.qk_mma_tiler[0], + mma_block_coord[0], + seqlen_q, + ) + + if not continue_cond: + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + + mO_qdl_eff = mO_qdl + if cutlass.const_expr(cum_seqlen_q is not None): + block_offset_o = ( + cuseqlen_q, + Int32(0), + Int32(0), + ((Int32(0), Int32(0)), Int32(0)), + ) + mO_qdl_eff = cute.domain_offset( + cute.select(block_offset_o, mode=[0, 2, 3]), mO_qdl + ) + + # (bM, bN, loopM, loopN, loopL) + gO_qdl = cute.flat_divide( + mO_qdl_eff, cute.select(self.pv_block_tiler, mode=[0, 1]) + ) + cO_qdl = cute.flat_divide( + cute.make_identity_tensor(mO_qdl_eff.shape), + cute.select(self.pv_block_tiler, mode=[0, 1]), + ) + + _, seqlen_kv_loop_steps = FusedMask.get_trip_start_count_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + gO_staged = gO_qdl[None, None, curr_block_coord[0], None, curr_block_coord[2]] + cO_staged = cO_qdl[None, None, curr_block_coord[0], None, curr_block_coord[2]] + cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tScS = qk_thr_mma.partition_C(cS) + + # Empty step as the first step is no need for correction + stats_handle = s_corr_consumer.wait_and_advance() + stats_handle.release() + for step in cutlass.range(1, seqlen_kv_loop_steps, 1, unroll=1): + # Oi-1 -> Oi + mma_corr_consumer, s_corr_consumer = self.correction_rescale( + scale_softmax_log2, + (s_corr_consumer, tStS, tScS), + (mma_corr_consumer, tOtO_staged, cO_staged), + self.epi_tile, + ) + # O_partial -> O_final + mma_corr_consumer, sum_consumer = self.correction_epilog( + (seqlen_q, scale_output), + (sum_consumer, sSum), + (mma_corr_consumer, gO_staged, cO_staged, tOtO_staged), + self.epi_tile, + ) + work_tile = tile_sched.advance_to_next_work() + # NOTE: tmem.free() moved to kernel end to enable cluster-wide sync + + # /////////////////////////////////////////////////////////////////////////////// + # Scheduler Warp (only for CLC dynamic scheduler) + # /////////////////////////////////////////////////////////////////////////////// + if cutlass.const_expr(self.use_clc_scheduler): + is_first_cta_in_cluster = cta_rank_in_cluster == 0 + + if warp_idx == self.sched_warp_id and is_first_cta_in_cluster: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + while work_tile.is_valid_tile: + tile_sched.prefetch_next_work() + work_tile = tile_sched.advance_to_next_work() + tile_sched.producer_tail() + + # /////////////////////////////////////////////////////////////////////////////// + # Empty warps reg dealloc + # /////////////////////////////////////////////////////////////////////////////// + if cutlass.const_expr(self.use_clc_scheduler): + if warp_idx > self.load_warp_id: + if not (warp_idx == self.sched_warp_id and is_first_cta_in_cluster): + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + else: + if warp_idx > self.load_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + + # /////////////////////////////////////////////////////////////////////////////// + # Cooperative TMEM Deallocation (2CTA) + # /////////////////////////////////////////////////////////////////////////////// + # All warps (including scheduler) have finished by this point. + # Cluster-wide sync ensures both CTAs reach here before dealloc. + cute.arch.cluster_arrive() + cute.arch.cluster_wait() + tmem.relinquish_alloc_permit() + tmem.free(tmem_ptr) + + return + + @cute.jit + def softmax_step( + self, + mask_args: Tuple, + value_args: Tuple, + tensor_args: Tuple, + pipeline_args: Tuple, + ) -> Tuple[Float32, Float32, pipeline.PipelineConsumer, pipeline.PipelineProducer]: + need_apply_mask, window_size_left, window_size_right = mask_args + row_max, row_sum, seqlen_q, seqlen_k, scale_softmax_log2 = value_args + tStS, tScS = tensor_args + mma_s_consumer, p_mma_producer, s_corr_producer = pipeline_args + tidx, _, _ = cute.arch.thread_idx() + thread_idx = tidx % (self.threads_per_warp * len(self.softmax_warp_ids)) + s_handle = mma_s_consumer.wait_and_advance() + tStS_slice = tStS[(None, None), 0, 0, s_handle.index] + tScS_slice = tScS[(None, None), 0, 0] + tmem_load_atom = cute.make_copy_atom( + tcgen05.Ld32x32bOp(tcgen05.Repetition(32)), self.qk_acc_dtype + ) + tmem_tiled_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS_slice) + thr_load = tmem_tiled_load.get_slice(thread_idx) + tTMEM_LOADtS = thr_load.partition_S(tStS_slice) + tTMEM_LOADcS = thr_load.partition_D(tScS_slice) + tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype) + cute.copy(tmem_tiled_load, tTMEM_LOADtS, tTMEM_LOADrS) + + cute.arch.fence_view_async_tmem_load() + s_handle.release() + if need_apply_mask: + FusedMask.apply_mask_via_causal_local( + tTMEM_LOADrS, + tTMEM_LOADcS, + seqlen_q, + seqlen_k, + self.use_semantic_trip_range, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + old_row_max = row_max + row_max = tTMEM_LOADrS.load().reduce(cute.ReductionOp.MAX, row_max, 0) + row_max_safe = row_max + if row_max == -cutlass.Float32.inf: + row_max_safe = 0.0 + + stats_handle = s_corr_producer.acquire_and_advance() + stats_layout = cute.composition( + tStS_slice.layout, cute.make_layout((tStS_slice.shape[0], 2)) + ) + stats_c_layout = cute.composition( + tScS_slice.layout, cute.make_layout((tScS_slice.shape[0], 2)) + ) + tOtStats = cute.make_tensor(tStS_slice.iterator + self.tilePlikeFP32, stats_layout) + tOcStats = cute.make_tensor(tScS_slice.iterator, stats_c_layout) + tmem_store_stats_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(2)), + self.qk_acc_dtype, + ) + tiled_tmem_store_stats = tcgen05.make_tmem_copy(tmem_store_stats_atom, tOtStats) + thr_tmem_store_stats = tiled_tmem_store_stats.get_slice(thread_idx) + tTMEM_STOREcStats = thr_tmem_store_stats.partition_S(tOcStats) + tTMEM_STORErStats = cute.make_rmem_tensor(tTMEM_STOREcStats.shape, self.qk_acc_dtype) + tTMEM_STORErStats[0] = old_row_max + tTMEM_STORErStats[1] = row_max_safe + tTMEM_STOREtStats = thr_tmem_store_stats.partition_D(tOtStats) + cute.copy(tiled_tmem_store_stats, tTMEM_STORErStats, tTMEM_STOREtStats) + cute.arch.fence_view_async_tmem_store() + stats_handle.commit() + + scale = scale_softmax_log2 + minus_row_max_scale = (0.0 - row_max_safe) * scale + tTMEM_STORErP = cute.make_rmem_tensor(tTMEM_LOADrS.shape, self.q_dtype) + for k in range(0, cute.size(tTMEM_LOADrS), 2): + tTMEM_LOADrS[k], tTMEM_LOADrS[k + 1] = cute.arch.fma_packed_f32x2( + (tTMEM_LOADrS[k], tTMEM_LOADrS[k + 1]), + (scale, scale), + (minus_row_max_scale, minus_row_max_scale), + ) + tTMEM_LOADrS[k] = cute.math.exp2(tTMEM_LOADrS[k], fastmath=True) + tTMEM_LOADrS[k + 1] = cute.math.exp2(tTMEM_LOADrS[k + 1], fastmath=True) + s_vec = tTMEM_LOADrS.load() + tTMEM_STORErP.store(s_vec.to(self.q_dtype)) + + p_handle = p_mma_producer.acquire_and_advance() + tmem_store_atom = cute.make_copy_atom( + tcgen05.St32x32bOp(tcgen05.Repetition(32)), self.qk_acc_dtype + ) + tilePlikeFP32 = tStS_slice.shape[1] // Float32.width * self.q_dtype.width + tStS_P_layout = cute.composition( + tStS_slice.layout, cute.make_layout((tStS_slice.shape[0], tilePlikeFP32)) + ) + tStS_P = cute.make_tensor(tStS_slice.iterator, tStS_P_layout) + tScS_P_layout = cute.composition( + tScS_slice.layout, cute.make_layout((tScS_slice.shape[0], tilePlikeFP32)) + ) + tScS_P = cute.make_tensor(tScS_slice.iterator, tScS_P_layout) + tmem_tiled_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P) + thr_store = tmem_tiled_store.get_slice(thread_idx) + tTMEM_STOREtP = thr_store.partition_D(tStS_P) + tTMEM_STOREcS = thr_store.partition_S(tScS_P) + tTMEM_STORErP_ = cute.make_tensor( + cute.recast_ptr(tTMEM_STORErP.iterator, dtype=self.qk_acc_dtype), + tTMEM_STOREcS.shape, + ) + cute.copy(tmem_tiled_store, tTMEM_STORErP_, tTMEM_STOREtP) + cute.arch.fence_view_async_tmem_store() + + p_handle.commit() + acc_scale_ = scale * (old_row_max - row_max_safe) + acc_scale = cute.math.exp2(acc_scale_, fastmath=True) * 0.5 + # TODO: calc row sum with TensorSSA + row_sum *= acc_scale + local_row_sum_0 = (row_sum, row_sum) + local_row_sum_1 = (0.0, 0.0) + local_row_sum_2 = (0.0, 0.0) + local_row_sum_3 = (0.0, 0.0) + reduction_unroll = 4 + frg_tile = cute.size(tTMEM_LOADrS) // reduction_unroll + tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile)) + for j in cutlass.range_constexpr(0, cute.size(tTMEM_LOADrS_frg, mode=[0]), 2): + local_row_sum_0 = cute.arch.add_packed_f32x2( + local_row_sum_0, (tTMEM_LOADrS_frg[j, 0], tTMEM_LOADrS_frg[j + 1, 0]) + ) + local_row_sum_1 = cute.arch.add_packed_f32x2( + local_row_sum_1, (tTMEM_LOADrS_frg[j, 1], tTMEM_LOADrS_frg[j + 1, 1]) + ) + local_row_sum_2 = cute.arch.add_packed_f32x2( + local_row_sum_2, (tTMEM_LOADrS_frg[j, 2], tTMEM_LOADrS_frg[j + 1, 2]) + ) + local_row_sum_3 = cute.arch.add_packed_f32x2( + local_row_sum_3, (tTMEM_LOADrS_frg[j, 3], tTMEM_LOADrS_frg[j + 1, 3]) + ) + local_row_sum_0 = cute.arch.add_packed_f32x2(local_row_sum_0, local_row_sum_1) + local_row_sum_2 = cute.arch.add_packed_f32x2(local_row_sum_2, local_row_sum_3) + local_row_sum_0 = cute.arch.add_packed_f32x2(local_row_sum_0, local_row_sum_2) + row_sum = local_row_sum_0[0] + local_row_sum_0[1] + return row_max, row_sum, mma_s_consumer, p_mma_producer, s_corr_producer + + @cute.jit + def correction_rescale( + self, + scale_softmax_log2: Float32, + stats_args: tuple, + o_args: tuple, + epi_tile: cute.Tile, + ) -> pipeline.PipelineConsumer: + (s_corr_consumer, tStS, tScS) = stats_args + (mma_o_consumer, tOtO_staged, cO_staged) = o_args + tidx, _, _ = cute.arch.thread_idx() + thread_idx = tidx % (self.threads_per_warp * len(self.softmax_warp_ids)) + + stats_handle = s_corr_consumer.wait_and_advance() + tStS_slice = tStS[(None, None), 0, 0, stats_handle.index] + tScS_slice = tScS[(None, None), 0, 0] + stats_layout = cute.composition( + tStS_slice.layout, cute.make_layout((tStS_slice.shape[0], 2)) + ) + stats_c_layout = cute.composition( + tScS_slice.layout, cute.make_layout((tScS_slice.shape[0], 2)) + ) + tOtStats = cute.make_tensor(tStS_slice.iterator + self.tilePlikeFP32, stats_layout) + tOcStats = cute.make_tensor(tScS_slice.iterator, stats_c_layout) + tmem_load_stats_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(2)), + self.qk_acc_dtype, + ) + tiled_tmem_load_stats = tcgen05.make_tmem_copy(tmem_load_stats_atom, tOtStats) + thr_tmem_load_stats = tiled_tmem_load_stats.get_slice(thread_idx) + tTMEM_LOADtStats = thr_tmem_load_stats.partition_S(tOtStats) + tTMEM_LOADcStats = thr_tmem_load_stats.partition_D(tOcStats) + tTMEM_LOADrStats = cute.make_rmem_tensor(tTMEM_LOADcStats.shape, self.qk_acc_dtype) + cute.copy(tiled_tmem_load_stats, tTMEM_LOADtStats, tTMEM_LOADrStats) + + scale = scale_softmax_log2 * (tTMEM_LOADrStats[0] - tTMEM_LOADrStats[1]) + scale = cute.math.exp2(scale, fastmath=True) + stats_handle.release() + o_handle = mma_o_consumer.wait_and_advance() + for iter in cutlass.range(self.iterations_pv, unroll_full=True): + tOtO = tOtO_staged[(None, None), 0, 0, iter] + cO = cO_staged[None, None, iter] + tOtO_epi = cute.zipped_divide(tOtO, epi_tile) + cO_epi = cute.zipped_divide(cO, epi_tile) + tmem_load_atom = cute.make_copy_atom( + tcgen05.Ld32x32bOp(tcgen05.Repetition(16)), + self.pv_acc_dtype, + ) + tmem_tiled_load = tcgen05.make_tmem_copy(tmem_load_atom, tOtO_epi) + thr_load = tmem_tiled_load.get_slice(thread_idx) + tmem_store_atom = cute.make_copy_atom( + tcgen05.St32x32bOp(tcgen05.Repetition(16)), + self.pv_acc_dtype, + ) + tmem_store_atom = tcgen05.make_tmem_copy(tmem_store_atom, tOtO_epi) + thr_store = tmem_store_atom.get_slice(thread_idx) + tTMEM_LOADtO = thr_load.partition_S(tOtO_epi) + tTMEM_LOADcO = thr_load.partition_D(cO_epi) + tTMEM_STOREtO = thr_store.partition_D(tOtO_epi) + tTMrO = cute.make_rmem_tensor_like( + cute.append( + cute.make_layout(tTMEM_LOADcO[None, 0, 0].shape), + cute.make_layout(2, stride=cute.size(tTMEM_LOADcO[None, 0, 0].shape)), + ), + self.pv_acc_dtype, + ) + tTMEM_LOADtO_0 = tTMEM_LOADtO[None, 0, 0] + cute.copy(tmem_tiled_load, tTMEM_LOADtO_0, tTMrO[None, 0]) + iter_num = cute.size(tTMEM_LOADtO, mode=[1]) + for i in cutlass.range(1, iter_num, unroll_full=True): + tTMEM_LOADtO_i = tTMEM_LOADtO[None, i, 0] + cute.copy(tmem_tiled_load, tTMEM_LOADtO_i, tTMrO[None, i % 2]) + for j in cutlass.range(0, cute.size(tTMrO, mode=[0]), 2, unroll_full=True): + tTMrO[j, (i - 1) % 2], tTMrO[j + 1, (i - 1) % 2] = cute.arch.mul_packed_f32x2( + (tTMrO[j, (i - 1) % 2], tTMrO[j + 1, (i - 1) % 2]), + (scale, scale), + ) + tTMEM_STOREtO_prev_i = tTMEM_STOREtO[None, i - 1, 0] + cute.copy(tmem_store_atom, tTMrO[None, (i - 1) % 2], tTMEM_STOREtO_prev_i) + + for j in cutlass.range(0, cute.size(tTMrO, mode=[0]), 2, unroll_full=True): + tTMrO[j, (iter_num - 1) % 2], tTMrO[j + 1, (iter_num - 1) % 2] = ( + cute.arch.mul_packed_f32x2( + ( + tTMrO[j, (iter_num - 1) % 2], + tTMrO[j + 1, (iter_num - 1) % 2], + ), + (scale, scale), + ) + ) + cute.copy( + tmem_store_atom, + tTMrO[None, (iter_num - 1) % 2], + tTMEM_STOREtO[None, iter_num - 1, 0], + ) + cute.arch.fence_view_async_tmem_store() + o_handle.release() + return mma_o_consumer, s_corr_consumer + + @cute.jit + def correction_epilog( + self, + value_args: Tuple, + sum_args: Tuple, + o_args: Tuple, + epi_tile: cute.Tile, + ) -> Tuple[pipeline.PipelineConsumer, pipeline.PipelineProducer]: + (seqlen_q, scale_output) = value_args + (sum_consumer, sSum) = sum_args + (mma_o_consumer, gO_staged, cO_staged, tOtO_staged) = o_args + tidx, _, _ = cute.arch.thread_idx() + thread_idx = tidx % (self.threads_per_warp * len(self.softmax_warp_ids)) + sum_handle = sum_consumer.wait_and_advance() + row_sum = sSum[thread_idx] + cute.arch.fence_view_async_shared() + sum_handle.release() + scale = scale_output / row_sum + o_handle = mma_o_consumer.wait_and_advance() + for iter in cutlass.range(self.iterations_pv): + gO = gO_staged[None, None, iter] + cO = cO_staged[None, None, iter] + tOtO = tOtO_staged[(None, None), 0, 0, iter] + tOtO_epi = cute.zipped_divide(tOtO, epi_tile) + cO_epi = cute.zipped_divide(cO, epi_tile) + gO_epi = cute.zipped_divide(gO, epi_tile) + tidx, _, _ = cute.arch.thread_idx() + thread_idx = tidx % (self.threads_per_warp * len(self.softmax_warp_ids)) + tmem_copy_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.pv_acc_dtype + ) + tiled_tmem_load = tcgen05.make_tmem_copy(tmem_copy_atom, tOtO_epi) + thr_tmem_load = tiled_tmem_load.get_slice(thread_idx) + tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_epi) + tTMEM_LOADgO = thr_tmem_load.partition_D(gO_epi) + tTMEM_LOADcO = thr_tmem_load.partition_D(cO_epi) + for i in cutlass.range(cute.size(tTMEM_LOADtO, mode=[1]), unroll_full=True): + tTMEM_LOADtO_i = tTMEM_LOADtO[None, i, 0] + tTMEM_LOADgO_i = tTMEM_LOADgO[None, i, 0] + tTMEM_LOADcO_i = tTMEM_LOADcO[None, i, 0] + tTMrO = cute.make_rmem_tensor(tTMEM_LOADcO[None, 0, i].shape, self.pv_acc_dtype) + cute.copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO) + for j in cutlass.range(0, cute.size(tTMrO), 2, unroll_full=True): + tTMrO[j], tTMrO[j + 1] = cute.arch.mul_packed_f32x2( + (tTMrO[j], tTMrO[j + 1]), + (scale, scale), + ) + tSMrO = cute.make_rmem_tensor(tTMrO.shape, self.o_dtype) + o_vec = tTMrO.load() + tSMrO.store(o_vec.to(self.o_dtype)) + if cute.elem_less(tTMEM_LOADcO_i[0][0], seqlen_q): + cute.autovec_copy(tSMrO, tTMEM_LOADgO_i) + o_handle.release() + return mma_o_consumer, sum_consumer + + @cute.jit + def store_sum_max( + self, + row_max, + mLSE, + row_sum, + sSum, + sum_producer, + current_block_coord, + seqlen_q, + cum_seqlen_q, + cuseqlen_q, + scale_softmax, + ): + tidx, _, _ = cute.arch.thread_idx() + thread_idx = tidx % (self.threads_per_warp * len(self.softmax_warp_ids)) + sum_handle = sum_producer.acquire_and_advance() + sSum[thread_idx] = row_sum + cute.arch.fence_view_async_shared() + sum_handle.commit() + + if cutlass.const_expr(mLSE is not None): + q_idx = current_block_coord[0] * self.cta_tiler[0] + tidx + hb_idx = ( + (current_block_coord[2][0], Int32(0)) + if cutlass.const_expr(cum_seqlen_q is not None) + else current_block_coord[2] + ) + lse_value = scale_softmax * row_max + cute.math.log(row_sum, fastmath=True) + if cute.elem_less(q_idx, seqlen_q): + global_q_idx = ( + q_idx + cuseqlen_q if cutlass.const_expr(cum_seqlen_q is not None) else q_idx + ) + mLSE[global_q_idx, hb_idx] = lse_value + return sum_producer diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py index ae57858acd5..ff820e59626 100644 --- a/flash_attn/cute/tile_scheduler.py +++ b/flash_attn/cute/tile_scheduler.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, Tri Dao. +# Copyright (c) 2025, Tri Dao, Siyu Wang, Shengbin Di, Yuxi Chi, Johnsonms, Linfeng Zheng, Haoyan Huang, Lanbo Li, Yun Zhong, Man Yuan, Minmin Sun, Yong Li, Wei Lin. from enum import IntEnum, auto from typing import Optional, Tuple, Protocol, runtime_checkable @@ -16,6 +16,13 @@ from cutlass import Int32, const_expr from cutlass.cute import FastDivmodDivisor from cutlass.utils import ClcDynamicPersistentTileScheduler, ClcDynamicPersistentTileSchedulerParams +from cutlass.cute.typing import Boolean +from cutlass.cutlass_dsl import ( + min as dsl_min, + extract_mlir_values, + new_from_mlir_values, +) +from cutlass.utils.hardware_info import HardwareInfo from quack.cute_dsl_utils import ParamsBase @@ -1094,3 +1101,536 @@ def __new_from_mlir_values__(self, values): obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) values = values[n_items:] return self.__class__(*obj_list, loc=self._loc) + + +# ----------------------------------------------------------------------------- +# SM100 FMHA-specific schedulers (kept separate from generic schedulers). +# ----------------------------------------------------------------------------- + + +class Sm100FmhaStaticTileSchedulerParams: + """A class to represent parameters for the FMHA (Fused Multi-Head Attention) static tile scheduler. + + This class holds the configuration parameters needed to initialize and configure + the tile scheduler for FMHA operations. + + :ivar is_persistent: Whether to use persistent kernel mode. + :type is_persistent: bool + :ivar problem_shape_mbh: Problem shape in (M, B, H) format. + :type problem_shape_mbh: cute.Shape + """ + + def __init__( + self, + is_persistent: bool, + problem_shape_mbh: cute.Shape, + *, + loc=None, + ip=None, + ): + """ + Initializes the Sm100FmhaStaticTileSchedulerParams with the given parameters. + + :param is_persistent: Whether to use persistent kernel mode. + :type is_persistent: bool + :param problem_shape_mbh: Problem shape in (M, B, H) format. + :type problem_shape_mbh: cute.Shape + """ + self.is_persistent = is_persistent + self.problem_shape_mbh = problem_shape_mbh + self._loc = loc + self._ip = ip + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.problem_shape_mbh]: + obj_values = extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip([self.problem_shape_mbh], self._values_pos): + obj_list.append(new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return Sm100FmhaStaticTileSchedulerParams( + self.is_persistent, *(tuple(obj_list)), loc=self._loc + ) + + +class Sm100FmhaStaticTileScheduler: + """A static tile scheduler for FMHA (Fused Multi-Head Attention) operations. + + This class manages the scheduling of work tiles for FMHA kernels, supporting + both persistent and non-persistent kernel modes. It tracks the current work + position and advances through the problem space efficiently. + + :ivar _params: Scheduler parameters. + :type _params: Sm100FmhaStaticTileSchedulerParams + :ivar _blk_coord: Block coordinates. + :type _blk_coord: cute.Coord + :ivar _grid_shape: Grid shape for the kernel. + :type _grid_shape: cute.Shape + :ivar _is_persistent: Whether to use persistent kernel mode. + :type _is_persistent: bool + :ivar _current_work_linear_idx: Current linear work index. + :type _current_work_linear_idx: Int32 + :ivar _problem_shape_mbh: Problem shape in (M, B, H) format. + :type _problem_shape_mbh: cute.Layout + :ivar _num_blocks: Number of blocks in the problem. + :type _num_blocks: Int32 + :ivar _is_first_block: Whether this is the first block. + :type _is_first_block: bool + :ivar num_persistent_sm: Number of persistent SMs. + :type num_persistent_sm: Int32 + """ + + def __init__( + self, + params: Sm100FmhaStaticTileSchedulerParams, + current_work_linear_idx: Int32, + blk_coord: cute.Coord, + grid_shape: cute.Shape, + *, + loc=None, + ip=None, + ): + """ + Initializes the Sm100FmhaStaticTileScheduler with the given parameters. + + :param params: Scheduler parameters. + :type params: Sm100FmhaStaticTileSchedulerParams + :param current_work_linear_idx: Current linear work index. + :type current_work_linear_idx: Int32 + :param blk_coord: Block coordinates. + :type blk_coord: cute.Coord + :param grid_shape: Grid shape for the kernel. + :type grid_shape: cute.Shape + """ + self._params = params + self._blk_coord = blk_coord + self._grid_shape = grid_shape + self._is_persistent = params.is_persistent + self._current_work_linear_idx = current_work_linear_idx + self._problem_shape_mbh = cute.make_layout(params.problem_shape_mbh, loc=loc, ip=ip) + self._num_blocks = cute.size(self._problem_shape_mbh, loc=loc, ip=ip) + self._is_first_block = True + self.num_persistent_sm = cute.size(grid_shape, loc=loc, ip=ip) + self._loc = loc + self._ip = ip + + # called by host + @staticmethod + def get_grid_shape( + params: Sm100FmhaStaticTileSchedulerParams, + *, + loc=None, + ip=None, + ) -> cute.Shape: + """ + Determine the grid shape for the FMHA kernel. + + For persistent kernels, the grid shape is limited by the number of SMs + (Streaming Multiprocessors) available on the device. For non-persistent + kernels, the grid shape matches the problem shape. + + :param params: Scheduler parameters. + :type params: Sm100FmhaStaticTileSchedulerParams + + :return: Grid shape as (M, B, H) tuple. + :rtype: cute.Shape + """ + if params.is_persistent: + hardware_info = HardwareInfo() + sm_count = hardware_info.get_device_multiprocessor_count() + return ( + dsl_min(sm_count, cute.size(params.problem_shape_mbh, loc=loc, ip=ip)), + 1, + 1, + ) + else: + return params.problem_shape_mbh + + @staticmethod + def check_valid_work_for_seqlen_q( + q_tiler: int, + current_idx: Int32, + seqlen_q: Int32, + ) -> Boolean: + """ + Check if the current work index is valid for the given query sequence length. + + This method verifies that the current work tile index multiplied by the + query tiler size is within the bounds of the query sequence length. + + :param q_tiler: Query tiler size. + :type q_tiler: int + :param current_idx: Current work index. + :type current_idx: Int32 + :param seqlen_q: Query sequence length. + :type seqlen_q: Int32 + + :return: True if the work is valid, False otherwise. + :rtype: Boolean + """ + return current_idx * q_tiler < seqlen_q + + def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: + """ + Get information about the current work tile. + + Determines if the current work is valid and computes the tile coordinates + based on whether the kernel is persistent or non-persistent. + + :return: WorkTileInfo containing tile coordinates and validity flag. + :rtype: WorkTileInfo + """ + is_valid = ( + self._current_work_linear_idx < self._num_blocks + if self._is_persistent + else self._is_first_block + ) + + blk_coord = (0, 0, 0) + if self._is_persistent: + blk_coord = self._problem_shape_mbh.get_hier_coord( + self._current_work_linear_idx, loc=loc, ip=ip + ) + else: + blk_coord = self._blk_coord + + # cur_tile_coord is (mid, 0, (bid, hid)) + cur_tile_coord = ( + blk_coord[0], + 0, + (blk_coord[1], blk_coord[2]), + ) + + return cutlass.utils.WorkTileInfo(cur_tile_coord, is_valid) + + def initial_work_tile_info(self, *, loc=None, ip=None): + """ + Get the initial work tile information. + + :return: Initial WorkTileInfo. + :rtype: WorkTileInfo + """ + return self.get_current_work(loc=loc, ip=ip) + + def advance_to_next_work(self, *, advance_count=1, loc=None, ip=None): + """ + Advance to the next work tile and return it. + + For persistent kernels, advances by the number of persistent SMs. + For non-persistent kernels, marks that the first block has been processed. + """ + if self._is_persistent: + self._current_work_linear_idx += advance_count * self.num_persistent_sm + self._is_first_block = False + return self.get_current_work() + + def prefetch_next_work(self, *, loc=None, ip=None): + """No-op for static scheduler.""" + pass + + def producer_tail(self, *, loc=None, ip=None): + """No-op for static scheduler.""" + pass + + def __extract_mlir_values__(self): + values = extract_mlir_values(self._params) + values.extend(extract_mlir_values(self._current_work_linear_idx)) + values.extend(extract_mlir_values(self._blk_coord)) + values.extend(extract_mlir_values(self._grid_shape)) + return values + + def __new_from_mlir_values__(self, values): + assert len(values) == 10 + new_params = new_from_mlir_values(self._params, values[0:3]) + new_current_work_linear_idx = new_from_mlir_values( + self._current_work_linear_idx, [values[3]] + ) + new_blk_coord = new_from_mlir_values(self._blk_coord, values[4:7]) + new_grid_shape = new_from_mlir_values(self._grid_shape, values[7:]) + return Sm100FmhaStaticTileScheduler( + new_params, new_current_work_linear_idx, new_blk_coord, new_grid_shape + ) + + +def compute_sm100_fmha_grid( + o_shape: cute.Shape, + cta_tiler: Tuple[int, int, int], + is_persistent: bool, +) -> Tuple[Sm100FmhaStaticTileSchedulerParams, Tuple[int, int, int]]: + """Compute grid parameters for FMHA (static scheduler). + + The output tensor o has shape (s, d, ((h_r, h_k), b)). + """ + tile_sched_params = Sm100FmhaStaticTileSchedulerParams( + is_persistent, + ( + cute.ceil_div(cute.size(o_shape[0]), cta_tiler[0]), + cute.size(o_shape[2][0]), + cute.size(o_shape[2][1]), + ), + ) + grid = Sm100FmhaStaticTileScheduler.get_grid_shape(tile_sched_params) + return tile_sched_params, grid + + +############################################################################## +# Fmha CLC dynamic tile scheduler +############################################################################## + + +class Sm100FmhaClcDynamicTileSchedulerParams: + """Parameters for FMHA CLC dynamic persistent tile scheduler. + + This class manages the layout of tiles for CLC (Cluster Launch Control) + based dynamic scheduling, adapted for FMHA's (M, B, H) problem shape. + + :ivar problem_shape_mbh: Problem shape in (M, B, H) format. + :type problem_shape_mbh: cute.Shape + :ivar cluster_shape_mnk: Cluster shape in (M, N, K) format. + :type cluster_shape_mnk: cute.Shape + """ + + def __init__( + self, + problem_shape_mbh: cute.Shape, + cluster_shape_mnk: cute.Shape, + *, + loc=None, + ip=None, + ): + self.problem_shape_mbh = problem_shape_mbh + self._cluster_shape_mnk = cluster_shape_mnk + self.cluster_shape_mn = cluster_shape_mnk[:2] + self._loc = loc + self._ip = ip + + # FMHA uses linear indexing over (M, B, H), convert to (M, N, L) style + # For FMHA: M dim is tile count along sequence, N=1, L=(B*H) + self.problem_shape_ntile_mnl = ( + problem_shape_mbh[0], # M tiles + 1, # N tiles (always 1 for FMHA) + problem_shape_mbh[1] * problem_shape_mbh[2], # L = B * H + ) + + # Create layout for cluster-to-tile mapping + self.problem_layout_ncluster_mnl = cute.make_layout( + cute.ceil_div(self.problem_shape_ntile_mnl, cluster_shape_mnk[:2]), + loc=loc, + ip=ip, + ) + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [ + self.problem_shape_mbh, + self._cluster_shape_mnk, + ]: + obj_values = extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + values_copy = list(values) + for obj, n_items in zip( + [self.problem_shape_mbh, self._cluster_shape_mnk], + self._values_pos, + ): + obj_list.append(new_from_mlir_values(obj, values_copy[:n_items])) + values_copy = values_copy[n_items:] + return Sm100FmhaClcDynamicTileSchedulerParams(*(tuple(obj_list)), loc=self._loc) + + def get_grid_shape(self, *, loc=None, ip=None) -> Tuple[int, int, int]: + """Compute grid shape aligned with cluster shape.""" + return cute.round_up(self.problem_shape_ntile_mnl, self._cluster_shape_mnk) + + def clc_hw_params(self) -> ClcDynamicPersistentTileSchedulerParams: + """Return params for the upstream CLC hardware scheduler.""" + return ClcDynamicPersistentTileSchedulerParams( + problem_shape_ntile_mnl=self.problem_shape_ntile_mnl, + cluster_shape_mnk=self._cluster_shape_mnk, + ) + + +class Sm100FmhaClcDynamicTileScheduler: + """CLC dynamic persistent tile scheduler for FMHA. + + This scheduler uses Blackwell's Cluster Launch Control hardware mechanism + for dynamic tile distribution, providing automatic load balancing. + Adapted for FMHA's (M, B, H) problem shape. + """ + + def __init__( + self, + params: Sm100FmhaClcDynamicTileSchedulerParams, + cta_id_in_cluster: cute.Coord, + num_tiles_executed: Int32, + clc_response_ptr: cute.Pointer, + block_idx: Tuple, + clc: ClcState = None, + *, + loc=None, + ip=None, + ): + self.params = params + self.cta_id_in_cluster = cta_id_in_cluster + self._num_tiles_executed = num_tiles_executed + self._clc_response_ptr = clc_response_ptr + self._block_idx = block_idx + self.clc = clc + self._loc = loc + self._ip = ip + + def __extract_mlir_values__(self): + values = extract_mlir_values(self.cta_id_in_cluster) + values.extend(extract_mlir_values(self._num_tiles_executed)) + values.extend(extract_mlir_values(self._clc_response_ptr)) + values.extend(extract_mlir_values(self._block_idx)) + if self.clc is not None: + values.extend(extract_mlir_values(self.clc)) + return values + + def __new_from_mlir_values__(self, values): + new_cta_id_in_cluster = new_from_mlir_values(self.cta_id_in_cluster, values[0:3]) + new_num_tiles_executed = new_from_mlir_values(self._num_tiles_executed, [values[3]]) + new_clc_response_ptr = new_from_mlir_values(self._clc_response_ptr, [values[4]]) + new_block_idx = new_from_mlir_values(self._block_idx, values[5:8]) + new_clc = None + if self.clc is not None: + new_clc = new_from_mlir_values(self.clc, values[8:]) + return Sm100FmhaClcDynamicTileScheduler( + self.params, + new_cta_id_in_cluster, + new_num_tiles_executed, + new_clc_response_ptr, + new_block_idx, + new_clc, + ) + + @staticmethod + def create( + params: Sm100FmhaClcDynamicTileSchedulerParams, + block_idx: Tuple, + grid_dim: Tuple, + clc_response_ptr: cute.Pointer, + clc: ClcState = None, + *, + loc=None, + ip=None, + ): + """Create a CLC dynamic tile scheduler instance.""" + bidx, bidy, bidz = block_idx + + # CTA id in cluster + cta_id_in_cluster = ( + Int32(bidx % params.cluster_shape_mn[0]), + Int32(bidy % params.cluster_shape_mn[1]), + Int32(0), + ) + + num_tiles_executed = Int32(0) + + return Sm100FmhaClcDynamicTileScheduler( + params, + cta_id_in_cluster, + num_tiles_executed, + clc_response_ptr, + block_idx, + clc, + ) + + @staticmethod + def get_grid_shape( + params: Sm100FmhaClcDynamicTileSchedulerParams, + *, + loc=None, + ip=None, + ) -> Tuple[int, int, int]: + """Get grid shape for kernel launch.""" + return params.get_grid_shape(loc=loc, ip=ip) + + def work_tile_info_from_clc_response(self, result_addr: cute.Pointer, *, loc=None, ip=None): + """Parse CLC response and convert to FMHA tile coordinates.""" + m_idx, n_idx, l_idx, vld = cute.arch.clc_response(result_addr, loc=loc, ip=ip) + cute.arch.fence_proxy("async.shared", space="cta") + + # CLC returns first CTA coordinates: m_idx=x, l_idx=z + # l_idx is the L (batch) dimension; decode to (bid, hid) + hid = l_idx % self.params.problem_shape_mbh[2] + bid = l_idx // self.params.problem_shape_mbh[2] + + cta_idx_in_cluster, cta_idy_in_cluster, _ = self.cta_id_in_cluster + cur_tile_coord = ( + m_idx + cta_idx_in_cluster, # M dimension + 0, # N always 0 for FMHA + (bid, hid), # (B, H) packed + ) + + return cutlass.utils.WorkTileInfo(cur_tile_coord, vld) + + def get_current_work(self, *, loc=None, ip=None): + """Get current work tile from CLC response.""" + return self.work_tile_info_from_clc_response(self._clc_response_ptr, loc=loc, ip=ip) + + def initial_work_tile_info(self, *, loc=None, ip=None): + """Get initial work tile based on block index.""" + bidx, bidy, bidz = self._block_idx + # bidz is the L (batch) dimension; decode to (bid, hid) + hid = bidz % self.params.problem_shape_mbh[2] + bid = bidz // self.params.problem_shape_mbh[2] + return cutlass.utils.WorkTileInfo((bidx, 0, (bid, hid)), True) + + def advance_to_next_work(self, *, loc=None, ip=None): + """Consumer-side advance: wait for next tile, read coordinates, release.""" + self.clc.consumer_wait(loc=loc, ip=ip) + work = self.get_current_work(loc=loc, ip=ip) + self.clc.consumer_release(loc=loc, ip=ip) + self._num_tiles_executed += Int32(1) + return work + + def prefetch_next_work(self, *, loc=None, ip=None): + """Producer-side: issue CLC query for next tile.""" + self.clc.prefetch_next_work(loc=loc, ip=ip) + + def producer_tail(self, *, loc=None, ip=None): + """Producer-side cleanup after last tile.""" + self.clc.producer_tail(loc=loc, ip=ip) + + @property + def num_tiles_executed(self) -> Int32: + return self._num_tiles_executed + + +def compute_sm100_fmha_grid_clc( + o_shape: cute.Shape, + cta_tiler: Tuple[int, int, int], + cluster_shape_mnk: Tuple[int, int, int], +) -> Tuple[Sm100FmhaClcDynamicTileSchedulerParams, Tuple[int, int, int]]: + """Compute grid parameters for FMHA with CLC dynamic scheduling.""" + problem_shape_mbh = ( + cute.ceil_div(cute.size(o_shape[0]), cta_tiler[0]), + cute.size(o_shape[2][0]), + cute.size(o_shape[2][1]), + ) + tile_sched_params = Sm100FmhaClcDynamicTileSchedulerParams(problem_shape_mbh, cluster_shape_mnk) + grid = Sm100FmhaClcDynamicTileScheduler.get_grid_shape(tile_sched_params) + return tile_sched_params, grid + + +############################################################################## +# Fused Mask +############################################################################## + + +def make_sm100_thread_cooperative_group(size: int): + return cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, size) + + +SM100_TMEM_CAPACITY_COLUMNS = 512 diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index a39230520e4..b9dc4f5c112 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -4,6 +4,7 @@ import hashlib import inspect import os +from functools import partial from typing import Type, Callable, Optional, Tuple, overload import cutlass @@ -213,6 +214,34 @@ def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Te ) +def convert_from_dlpack_compact_dynamic( + x, + *, + dynamic_modes: tuple[int, ...], + alignment: int = 16, + stride_order=None, + divisibility: int = 1, + enable_tvm_ffi: bool = False, +) -> cute.Tensor: + """Convert via DLPack and mark selected compact dimensions as dynamic.""" + if isinstance(dynamic_modes, int): + dynamic_modes = (dynamic_modes,) + if stride_order is None: + stride_order = x.dim_order() + t = ( + from_dlpack(x, assumed_align=alignment, enable_tvm_ffi=True) + if enable_tvm_ffi + else from_dlpack(x, assumed_align=alignment) + ) + for mode in dynamic_modes: + t = t.mark_compact_shape_dynamic( + mode=mode, + stride_order=stride_order, + divisibility=divisibility, + ) + return t + + def convert_from_dlpack_leading_static( x, leading_dim, alignment=16, static_modes=None, stride_order=None ) -> cute.Tensor: @@ -820,6 +849,92 @@ def domain_offset_aligned( return cute.make_tensor(new_ptr, tensor.layout) +@dsl_user_op +def warp_reduction( + val: cute.Numeric, op: Callable, *, threads_in_group: int = 32, loc=None, ip=None +) -> cute.Numeric: + """Warp-wide reduction helper for a custom binary op.""" + offset = threads_in_group // 2 + while offset > 0: + val = op( + val, + cute.arch.shuffle_sync_bfly( + val, offset=offset, mask=-1, mask_and_clamp=31, loc=loc, ip=ip + ), + ) + offset //= 2 + return val + + +warp_reduction_max = partial( + warp_reduction, op=lambda x, y: fmax(x, y) if isinstance(x, Float32) else max(x, y) +) +warp_reduction_sum = partial(warp_reduction, op=lambda x, y: x + y) # noqa: FURB118 + + +@dsl_user_op +def make_cotiled_copy( + atom: cute.CopyAtom, atom_layout_tv: cute.Layout, data_layout: cute.Layout, *, loc=None, ip=None +) -> cute.TiledCopy: + """Compatibility wrapper for deprecated CuTeDSL `make_cotiled_copy`.""" + assert cute.is_static(atom_layout_tv.type), "atom_layout_tv must be static" + assert cute.is_static(data_layout.type), "data_layout must be static" + + inv_layout_ = cute.left_inverse(data_layout, loc=loc, ip=ip) + inv_data_layout = cute.make_layout( + (inv_layout_.shape, (1)), stride=(inv_layout_.stride, (0)), loc=loc, ip=ip + ) + layout_tv_data = cute.composition(inv_data_layout, atom_layout_tv, loc=loc, ip=ip) + + atom_layout_v_to_check = cute.coalesce( + cute.make_layout(atom_layout_tv.shape[1], stride=atom_layout_tv.stride[1], loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + data_layout_v_to_check = cute.coalesce( + cute.composition( + data_layout, + cute.make_layout( + layout_tv_data.shape[1], stride=layout_tv_data.stride[1], loc=loc, ip=ip + ), + loc=loc, + ip=ip, + ), + loc=loc, + ip=ip, + ) + assert data_layout_v_to_check == atom_layout_v_to_check, ( + "the memory pointed to by atom_layout_tv does not exist in the data_layout." + ) + + flat_data_shape = cute.product_each(data_layout.shape, loc=loc, ip=ip) + tiler = tuple( + cute.filter( + cute.composition( + cute.make_layout( + flat_data_shape, + stride=tuple(0 if j != i else 1 for j in range(cute.rank(flat_data_shape))), + loc=loc, + ip=ip, + ), + layout_tv_data, + loc=loc, + ip=ip, + ), + loc=loc, + ip=ip, + ) + for i in range(cute.rank(flat_data_shape)) + ) + tile2data = cute.composition( + cute.make_layout(flat_data_shape, loc=loc, ip=ip), tiler, loc=loc, ip=ip + ) + layout_tv = cute.composition( + cute.left_inverse(tile2data, loc=loc, ip=ip), layout_tv_data, loc=loc, ip=ip + ) + return cute.make_tiled_copy(atom, layout_tv, tiler, loc=loc, ip=ip) + + @cute.jit def scalar_to_ssa(a: cute.Numeric, dtype) -> cute.TensorSSA: """Convert a scalar to a cute TensorSSA of shape (1,) and given dtype""" diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index e2fa31bdf88..b260efd8648 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao, Siyu Wang, Shengbin Di, Yuxi Chi, Johnsonms, Linfeng Zheng, Haoyan Huang, Lanbo Li, Yun Zhong, Man Yuan, Minmin Sun, Yong Li, Wei Lin. import math import itertools @@ -95,9 +95,11 @@ def wrapper(*args, **kwargs): (3, 3), (64, 32), (64, 128), + (64, 1), # SM100 hd256 2CTA test case (128, 128), (128, 192), (256, 256), + (255, 256), # SM100 hd256 2CTA test case (239, 1), (799, 3), (113, 203), @@ -142,6 +144,19 @@ def test_flash_attn_output( pytest.xfail("has_qv: local not supported yet") if has_qv and has_learnable_sink: pytest.xfail("has_qv: learnable sink not supported yet") + # TODO(wangsiyu): SM100/SM110 head_dim=256 2CTA kernel currently does not support the following features. + # Remove these skips when support is added. + if d == 256 and IS_SM100: + if has_learnable_sink: + pytest.skip("SM100 head_dim=256 2CTA kernel does not support learnable_sink yet") + if local: + pytest.skip("SM100 head_dim=256 2CTA kernel does not support local attention yet") + if softcap > 0.0: + pytest.skip("SM100 head_dim=256 2CTA kernel does not support softcap yet") + if deterministic: + pytest.skip("SM100 head_dim=256 2CTA kernel does not support deterministic mode yet") + if causal and seqlen_q > seqlen_k: + pytest.skip("SM100 head_dim=256 2CTA kernel does not support causal attention with seqlen_q > seqlen_k yet") device = "cuda" # set seed seed = 0 @@ -288,8 +303,16 @@ def test_flash_attn_output( # SplitKV not supported on SM90 - skip this iteration if IS_SM90 and num_splits > 1: continue - if IS_SM100 and (d >= 192 and dv >= 192): # hdim 192 and 256 not support on SM100 + if IS_SM100 and (d >= 192 and dv >= 192) and not (d == 256 and dv == 256): continue + # TODO(wangsiyu): SM100 head_dim=256 2CTA kernel does not support pack_gqa yet. + # pack_gqa=None means auto-enable for GQA/MQA (qhead_per_kvhead > 1) + # Remove this when support is added. + if d == 256 and IS_SM100: + if pack_gqa is True: + continue + if pack_gqa is None and mha_type != "mha": + continue out, lse = flash_attn_func( q, k, @@ -327,7 +350,11 @@ def test_flash_attn_output( and not dv > 256 and not attention_chunk != 0 and softcap == 0.0 - and ((dv == d and d <= 128) or (d == 192 and dv == 128)) + and ( + (dv == d and d <= 128) + or (d == 192 and dv == 128) + or (IS_SM100 and d == 256 and dv == 256) + ) and learnable_sink is None # and False and not ((causal or local) and seqlen_k < seqlen_q) @@ -438,7 +465,7 @@ def test_flash_attn_output( # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128]) # @pytest.mark.parametrize("d", [128, 192]) -@pytest.mark.parametrize("d", [64, 128, 192]) +@pytest.mark.parametrize("d", [64, 128, 192, 256]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -462,6 +489,11 @@ def test_flash_attn_output( (1023, 1024), (1024, 1023), (2048, 2048), + # SM100 hd256 2CTA test cases + (64, 1), + (255, 256), + (4096, 4096), + (4224, 4224), ], ) @pytest.mark.parametrize("varlen_mode", ["random", "third", "full"]) @@ -508,6 +540,23 @@ def test_flash_attn_varlen_output( local = local_enum > 0 if local and causal: pytest.skip() + # TODO(wangsiyu): SM100/SM110 head_dim=256 2CTA kernel currently does not support the following features. + # Remove these skips when support is added. + if d == 256 and IS_SM100: + if has_learnable_sink: + pytest.skip("SM100 head_dim=256 2CTA kernel does not support learnable_sink yet") + if local: + pytest.skip("SM100 head_dim=256 2CTA kernel does not support local attention yet") + if softcap > 0.0: + pytest.skip("SM100 head_dim=256 2CTA kernel does not support softcap yet") + if deterministic: + pytest.skip("SM100 head_dim=256 2CTA kernel does not support deterministic mode yet") + if causal and seqlen_q > seqlen_k: + pytest.skip("SM100 head_dim=256 2CTA kernel does not support causal attention with seqlen_q > seqlen_k yet") + if zero_lengths_q or zero_lengths_k: + pytest.skip("SM100 head_dim=256 2CTA kernel does not support zero-length sequences yet") + if not unpad_q or not unpad_kv: + pytest.skip("SM100 head_dim=256 2CTA kernel does not support seqused_q/seqused_k mode yet (requires unpad_q=True and unpad_kv=True)") if ( causal or local ): # Right now reference only supports causal attention with seqlen_k == seqlen_q @@ -524,6 +573,8 @@ def test_flash_attn_varlen_output( dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d]) + if d == 256: + dv_vals = [256] # SM100 hd=256 2CTA kernel only supports dv=256 if dtype == torch.float8_e4m3fn: dv_vals = [d] # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k else [0] @@ -721,6 +772,14 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): # SplitKV not supported on SM90 - skip this iteration if IS_SM90 and num_splits > 1: continue + # TODO(wangsiyu): SM100 head_dim=256 2CTA kernel does not support pack_gqa yet. + # pack_gqa=None means auto-enable for GQA/MQA (qhead_per_kvhead > 1) + # Remove this when support is added. + if d == 256 and IS_SM100: + if pack_gqa is True: + continue + if pack_gqa is None and mha_type != "mha": + continue out_unpad, lse = flash_attn_varlen_func( q_unpad if unpad_q else q, k_unpad if unpad_kv else k, @@ -777,7 +836,11 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): and not has_qv and not dv > 256 and not attention_chunk != 0 - and ((dv == d and d <= 128) or (d == 192 and dv == 128)) + and ( + (dv == d and d <= 128) + or (d == 192 and dv == 128) + or (IS_SM100 and d == 256 and dv == 256) + ) and not has_learnable_sink # and False ): diff --git a/tests/cute/test_flash_attn_varlen.py b/tests/cute/test_flash_attn_varlen.py index 24e55315671..ad6dcd6da43 100644 --- a/tests/cute/test_flash_attn_varlen.py +++ b/tests/cute/test_flash_attn_varlen.py @@ -7,7 +7,7 @@ @pytest.mark.parametrize("B", [1, 7, 20]) @pytest.mark.parametrize("H", [1, 4, 6]) -@pytest.mark.parametrize("D", [64, 128]) +@pytest.mark.parametrize("D", [64, 128, 256]) @pytest.mark.parametrize("min_seq_len", [1, 32, 128]) @pytest.mark.parametrize("max_seq_len", [8, 64, 2048]) @pytest.mark.parametrize("causal", [True, False]) From b21e2049e510ce8a97c82a6c97819a1f14b46d43 Mon Sep 17 00:00:00 2001 From: Johnsonms Date: Thu, 23 Apr 2026 14:24:22 -0700 Subject: [PATCH 24/44] [Cute,hd256] Post-merge cleanup: dead code, duplicate imports (#2487) Follow-up polish on the freshly-merged hd256 feature (#2412), sourced from Copilot AI review comments on the original PR. interface.py: drop duplicate `from cutlass import Int32` (already imported at line 17) and unused `from flash_attn.cute.mask import Sm100MaskEnum as MaskEnum`, which is never referenced. mask.py: remove two dead `tidx, tidy, tidx = cute.arch.thread_idx()` lines in Sm100FusedMask.apply_mask and apply_mask_via_causal_local. Neither `tidx` nor `tidy` is ever read in the function bodies; these calls are leftover debug scaffolding (consistent with the commented-out `cute.printf("tidx = ...")` lines nearby at 490/525/665). test_flash_attn.py: drop the stray "/SM110" from two TODO comments. The skip guard is `IS_SM100` only (capability major == 10), and the hd256 2CTA kernel path is only taken when `arch // 10 == 10` (interface.py:573, 1310), never on SM110 (major == 11). --- flash_attn/cute/interface.py | 3 --- flash_attn/cute/mask.py | 3 --- tests/cute/test_flash_attn.py | 4 ++-- 3 files changed, 2 insertions(+), 8 deletions(-) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index afdda11f351..30a42bfc475 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -45,12 +45,9 @@ from flash_attn.cute.flash_fwd_combine import FlashAttentionForwardCombine from flash_attn.cute.flash_fwd_mla_sm100 import FlashAttentionMLAForwardSm100 -from cutlass import Int32 - # SM100 head_dim=256 2CTA kernel imports from flash_attn.cute.sm100_hd256_2cta_fmha_forward import BlackwellFusedMultiHeadAttentionForward from flash_attn.cute.sm100_hd256_2cta_fmha_backward import BlackwellFusedMultiHeadAttentionBackward -from flash_attn.cute.mask import Sm100MaskEnum as MaskEnum from flash_attn.cute.block_sparsity import ( BlockSparseTensorsTorch, diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index a1971366183..e9810f51844 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -1376,8 +1376,6 @@ def apply_mask( :param window_size_right: Right-side sliding window size for attention masking. :type window_size_right: Optional[int] """ - - tidx, tidy, tidx = cute.arch.thread_idx() offset = 0 # NOTE: causal masking in this repo aligns the *end* of Q with the *end* of K # when seqlen_k != seqlen_q (same as the test/reference implementation): @@ -1439,7 +1437,6 @@ def apply_mask_via_causal_local( - If apply_semantic_window=True, apply causal/local window constraints. - Always apply residual OOB masking (index_k>=seqlen_k or index_q>=seqlen_q). """ - tidx, tidy, tidx = cute.arch.thread_idx() offset = 0 if cutlass.const_expr(apply_semantic_window): # Match WINDOW_MASK_INFERENCE semantics: end-align Q/K when lengths differ. diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index b260efd8648..b069d8b7580 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -144,7 +144,7 @@ def test_flash_attn_output( pytest.xfail("has_qv: local not supported yet") if has_qv and has_learnable_sink: pytest.xfail("has_qv: learnable sink not supported yet") - # TODO(wangsiyu): SM100/SM110 head_dim=256 2CTA kernel currently does not support the following features. + # TODO(wangsiyu): SM100 head_dim=256 2CTA kernel currently does not support the following features. # Remove these skips when support is added. if d == 256 and IS_SM100: if has_learnable_sink: @@ -540,7 +540,7 @@ def test_flash_attn_varlen_output( local = local_enum > 0 if local and causal: pytest.skip() - # TODO(wangsiyu): SM100/SM110 head_dim=256 2CTA kernel currently does not support the following features. + # TODO(wangsiyu): SM100 head_dim=256 2CTA kernel currently does not support the following features. # Remove these skips when support is added. if d == 256 and IS_SM100: if has_learnable_sink: From ac6f2eb5413748c68192aa384a40a38d60ad6abd Mon Sep 17 00:00:00 2001 From: Reuben Stern <107093092+reubenconducts@users.noreply.github.com> Date: Thu, 23 Apr 2026 19:02:40 -0400 Subject: [PATCH 25/44] [CuTe,Flex] Wire up interface for flex autograd support (#2485) --- flash_attn/cute/interface.py | 45 +++++++++++++-- tests/cute/test_mask_mod.py | 105 +++++++++++++++++++++++------------ tests/cute/test_score_mod.py | 76 ++++++++++++++++--------- 3 files changed, 161 insertions(+), 65 deletions(-) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 30a42bfc475..a7d504c4778 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -1870,7 +1870,10 @@ def forward( num_splits: int = 1, pack_gqa: Optional[bool] = None, deterministic: bool = False, + score_mod: Optional[Callable] = None, + score_mod_bwd: Optional[Callable] = None, mask_mod: Optional[Callable] = None, + aux_tensors: Optional[list] = None, full_block_cnt: Optional[torch.Tensor] = None, full_block_idx: Optional[torch.Tensor] = None, mask_block_cnt: Optional[torch.Tensor] = None, @@ -1901,24 +1904,30 @@ def forward( softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, + score_mod=score_mod, mask_mod=mask_mod, + aux_tensors=aux_tensors, block_sparse_tensors=block_sparse_tensors, return_lse=return_lse, gather_kv_indices=gather_kv_indices, ) - ctx.save_for_backward(q, k, v, out, lse) + ctx.save_for_backward(q, k, v, out, lse, *(aux_tensors or ())) ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size ctx.softcap = softcap ctx.deterministic = deterministic ctx.return_lse = return_lse + ctx.score_mod = score_mod + ctx.score_mod_bwd = score_mod_bwd + ctx.mask_mod = mask_mod ctx.set_materialize_grads(False) return out, lse @staticmethod def backward(ctx, dout, dlse): - q, k, v, out, lse = ctx.saved_tensors + q, k, v, out, lse, *aux = ctx.saved_tensors + aux_tensors = aux if aux else None if not ctx.return_lse: dlse = None if dout is None: @@ -1936,6 +1945,10 @@ def backward(ctx, dout, dlse): window_size_left=ctx.window_size[0], window_size_right=ctx.window_size[1], deterministic=ctx.deterministic, + score_mod=ctx.score_mod, + score_mod_bwd=ctx.score_mod_bwd, + mask_mod=ctx.mask_mod, + aux_tensors=aux_tensors, dlse=dlse, ) return dq, dk, dv, *((None,) * 30) # Extra Nones is fine @@ -1967,6 +1980,7 @@ def forward( pack_gqa: Optional[bool] = None, deterministic: bool = False, score_mod: Optional[Callable] = None, + score_mod_bwd: Optional[Callable] = None, aux_tensors: Optional[list] = None, return_lse: bool = False, ): @@ -1996,7 +2010,18 @@ def forward( return_lse=return_lse, gather_kv_indices=gather_kv_indices, ) - ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) + ctx.save_for_backward( + q, + k, + v, + out, + lse, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + *(aux_tensors or ()), + ) ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size @@ -2005,12 +2030,15 @@ def forward( ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_k = max_seqlen_k ctx.return_lse = return_lse + ctx.score_mod = score_mod + ctx.score_mod_bwd = score_mod_bwd ctx.set_materialize_grads(False) return out, lse @staticmethod def backward(ctx, dout, dlse): - q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors + q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, *aux = ctx.saved_tensors + aux_tensors = aux if aux else None if not ctx.return_lse: dlse = None if dout is None: @@ -2034,6 +2062,9 @@ def backward(ctx, dout, dlse): max_seqlen_q=ctx.max_seqlen_q, max_seqlen_k=ctx.max_seqlen_k, deterministic=ctx.deterministic, + score_mod=ctx.score_mod, + score_mod_bwd=ctx.score_mod_bwd, + aux_tensors=aux_tensors, dlse=dlse, ) @@ -2054,7 +2085,10 @@ def flash_attn_func( num_splits: int = 1, pack_gqa: Optional[bool] = None, deterministic: bool = False, + score_mod: Optional[Callable] = None, + score_mod_bwd: Optional[Callable] = None, mask_mod: Optional[Callable] = None, + aux_tensors: Optional[list] = None, full_block_cnt: Optional[torch.Tensor] = None, full_block_idx: Optional[torch.Tensor] = None, mask_block_cnt: Optional[torch.Tensor] = None, @@ -2076,7 +2110,10 @@ def flash_attn_func( num_splits, pack_gqa, deterministic, + score_mod, + score_mod_bwd, mask_mod, + aux_tensors, full_block_cnt, full_block_idx, mask_block_cnt, diff --git a/tests/cute/test_mask_mod.py b/tests/cute/test_mask_mod.py index 26e0a5e1353..c84d3ad66a1 100644 --- a/tests/cute/test_mask_mod.py +++ b/tests/cute/test_mask_mod.py @@ -22,7 +22,7 @@ from torch.nn.attention.flex_attention import create_block_mask, flex_attention import torch.nn.functional as F -from flash_attn.cute.interface import _flash_attn_fwd, _flash_attn_bwd +from flash_attn.cute.interface import _flash_attn_fwd, _flash_attn_bwd, flash_attn_func from flash_attn.cute.block_sparsity import ( BlockSparseTensorsTorch, fast_sampling, @@ -253,6 +253,7 @@ def _run_mask_test( tile_n, use_block_sparsity, needs_backward=False, + use_autograd=False, ): torch.manual_seed(42) @@ -402,35 +403,62 @@ def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias): else None ) - out_tuple = _flash_attn_fwd( - q=tensors["q"], - k=tensors["k"], - v=tensors["v"], - out=tensors["out"], - lse=tensors["lse"], - cu_seqlens_q=None, - cu_seqlens_k=None, - seqused_q=None, - seqused_k=None, - page_table=None, - softmax_scale=softmax_scale, - causal=causal, - softcap=None, - window_size_left=window_left, - window_size_right=window_right, - learnable_sink=None, - tile_mn=(tile_m, tile_n), - pack_gqa=pack_gqa, - _arch=None, - score_mod=None, - mask_mod=mask_mod_cute, - block_sparse_tensors=block_sparse_mask_fwd, - return_lse=True, - aux_tensors=aux_tensors_arg, - ) + if use_autograd: + q_ag = tensors["q"].detach().requires_grad_(True) + k_ag = tensors["k"].detach().requires_grad_(True) + v_ag = tensors["v"].detach().requires_grad_(True) + bs_kwargs = {} + if block_sparse_mask_fwd is not None: + bs_kwargs = dict( + mask_block_cnt=block_sparse_mask_fwd.mask_block_cnt, + mask_block_idx=block_sparse_mask_fwd.mask_block_idx, + full_block_cnt=block_sparse_mask_fwd.full_block_cnt, + full_block_idx=block_sparse_mask_fwd.full_block_idx, + block_size=block_sparse_mask_fwd.block_size, + ) + out_cute, lse_cute = flash_attn_func( + q_ag, + k_ag, + v_ag, + softmax_scale=softmax_scale, + causal=causal, + window_size=(window_left, window_right), + pack_gqa=pack_gqa, + mask_mod=mask_mod_cute, + aux_tensors=aux_tensors_arg, + return_lse=True, + **bs_kwargs, + ) + else: + out_tuple = _flash_attn_fwd( + q=tensors["q"], + k=tensors["k"], + v=tensors["v"], + out=tensors["out"], + lse=tensors["lse"], + cu_seqlens_q=None, + cu_seqlens_k=None, + seqused_q=None, + seqused_k=None, + page_table=None, + softmax_scale=softmax_scale, + causal=causal, + softcap=None, + window_size_left=window_left, + window_size_right=window_right, + learnable_sink=None, + tile_mn=(tile_m, tile_n), + pack_gqa=pack_gqa, + _arch=None, + score_mod=None, + mask_mod=mask_mod_cute, + block_sparse_tensors=block_sparse_mask_fwd, + return_lse=True, + aux_tensors=aux_tensors_arg, + ) - out_cute = out_tuple[0] - lse_cute = out_tuple[1] + out_cute = out_tuple[0] + lse_cute = out_tuple[1] tensors_fp32 = { k: v.float() if v.dtype in [torch.float16, torch.bfloat16] else v for k, v in tensors.items() @@ -488,11 +516,16 @@ def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias): device="cuda", BLOCK_SIZE=(tile_m, tile_n), ) - dq_cute, dk_cute, dv_cute = run_cute_mask_bwd( - q, k, v, out_cute, lse_cute, grad_out, mask_mod_cute, - block_sparse_mask_bwd=block_sparse_mask_bwd, tile_m=tile_m, tile_n=tile_n, - aux_tensors=aux_tensors_arg, - ) + if use_autograd: + dq_cute, dk_cute, dv_cute = torch.autograd.grad( + out_cute, (q_ag, k_ag, v_ag), grad_out + ) + else: + dq_cute, dk_cute, dv_cute = run_cute_mask_bwd( + q, k, v, out_cute, lse_cute, grad_out, mask_mod_cute, + block_sparse_mask_bwd=block_sparse_mask_bwd, tile_m=tile_m, tile_n=tile_n, + aux_tensors=aux_tensors_arg, + ) _, dq_ref_fp32, dk_ref_fp32, dv_ref_fp32 = run_flex_reference_bwd( q, k, v, flex_block_mask, grad_out, dtype=torch.float32 ) @@ -760,8 +793,9 @@ def test_static_masks( ], ) @pytest.mark.parametrize("tile_m,tile_n", [(128, 128), (128, 112), (64, 128)]) +@pytest.mark.parametrize("use_autograd", [True, False]) def test_parameterized_masks( - seqlen_q, seqlen_k, nheads, kv_mode, headdim, dtype, use_block_sparsity, mask_name, window_size, tile_m, tile_n + seqlen_q, seqlen_k, nheads, kv_mode, headdim, dtype, use_block_sparsity, mask_name, window_size, tile_m, tile_n, use_autograd, ): """Test parameterized masks that require recompilation per seqlen pair. @@ -790,6 +824,7 @@ def test_parameterized_masks( tile_n=tile_n, use_block_sparsity=use_block_sparsity, needs_backward=True, + use_autograd=use_autograd, ) diff --git a/tests/cute/test_score_mod.py b/tests/cute/test_score_mod.py index 43bf62e7d54..95a05a1d60b 100644 --- a/tests/cute/test_score_mod.py +++ b/tests/cute/test_score_mod.py @@ -5,7 +5,12 @@ from cutlass._mlir.dialects import math as mlir_math import operator from torch.nn.attention.flex_attention import create_block_mask, flex_attention -from flash_attn.cute.interface import _flash_attn_fwd, _flash_attn_bwd, _tile_size_bwd_sm90 +from flash_attn.cute.interface import ( + flash_attn_func, + _flash_attn_fwd, + _flash_attn_bwd, + _tile_size_bwd_sm90, +) from flash_attn.cute.block_sparsity import BlockSparseTensorsTorch COMPUTE_CAPABILITY = torch.cuda.get_device_capability()[0] @@ -753,37 +758,55 @@ def score_squared_eager(score, b, h, q_idx, kv_idx): def run_cute_flash_bwd( - q, k, v, cute_score_mod, cute_score_mod_bwd, aux_tensors=None, pack_gqa=False + q, k, v, cute_score_mod, cute_score_mod_bwd, aux_tensors=None, pack_gqa=False, use_autograd=True, ): """Run flash attention forward + backward with score_mod.""" q_t = q.transpose(1, 2) k_t = k.transpose(1, 2) v_t = v.transpose(1, 2) - out, lse = _flash_attn_fwd( - q_t, - k_t, - v_t, - return_lse=True, - score_mod=cute_score_mod, - aux_tensors=aux_tensors, - pack_gqa=pack_gqa, - ) + if use_autograd: + q_t = q_t.detach().requires_grad_(True) + k_t = k_t.detach().requires_grad_(True) + v_t = v_t.detach().requires_grad_(True) + out, lse = flash_attn_func( + q_t, + k_t, + v_t, + score_mod=cute_score_mod, + score_mod_bwd=cute_score_mod_bwd, + aux_tensors=aux_tensors, + pack_gqa=pack_gqa, + ) - grad_out = torch.randn_like(out) + grad_out = torch.randn_like(out) - dq, dk, dv = _flash_attn_bwd( - q_t, - k_t, - v_t, - out, - grad_out, - lse, - score_mod=cute_score_mod, - score_mod_bwd=cute_score_mod_bwd, - aux_tensors=aux_tensors, - pack_gqa=pack_gqa, - ) + dq, dk, dv = torch.autograd.grad(out, (q_t, k_t, v_t), grad_out) + else: + out, lse = _flash_attn_fwd( + q_t, + k_t, + v_t, + return_lse=True, + score_mod=cute_score_mod, + aux_tensors=aux_tensors, + pack_gqa=pack_gqa, + ) + + grad_out = torch.randn_like(out) + + dq, dk, dv = _flash_attn_bwd( + q_t, + k_t, + v_t, + out, + grad_out, + lse, + score_mod=cute_score_mod, + score_mod_bwd=cute_score_mod_bwd, + aux_tensors=aux_tensors, + pack_gqa=pack_gqa, + ) return ( out.transpose(1, 2), @@ -982,7 +1005,8 @@ def run_flex_block_sparse_score_mod_ref(q_ref, k_ref, v_ref, grad_out_ref, ref_d @pytest.mark.parametrize("dim", [64, 128]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("score_mod_triple", BWD_TEST_PAIRS) -def test_cute_vs_flex_attention_backward(seqlen_q, seqlen_kv, dim, dtype, score_mod_triple): +@pytest.mark.parametrize("use_autograd", [True, False]) +def test_cute_vs_flex_attention_backward(seqlen_q, seqlen_kv, dim, dtype, score_mod_triple, use_autograd): """Test backward pass with score_mod against flex_attention reference.""" if COMPUTE_CAPABILITY == 9 and dim == 64: pytest.skip("head_dim=64 not supported on SM90 for backward") @@ -994,7 +1018,7 @@ def test_cute_vs_flex_attention_backward(seqlen_q, seqlen_kv, dim, dtype, score_ seqlen_q=seqlen_q, seqlen_kv=seqlen_kv, num_heads=4, dim=dim, dtype=dtype ) - out_cute, grad_out, dq_cute, dk_cute, dv_cute = run_cute_flash_bwd(q, k, v, cute_fwd, cute_bwd) + out_cute, grad_out, dq_cute, dk_cute, dv_cute = run_cute_flash_bwd(q, k, v, cute_fwd, cute_bwd, use_autograd=use_autograd) out_ref_fp32, dq_ref_fp32, dk_ref_fp32, dv_ref_fp32 = run_flex_reference_bwd( q, k, v, eager_ref, grad_out, dtype=torch.float32 ) From c9a560f44002da92d82680036e3482d0ac4939a3 Mon Sep 17 00:00:00 2001 From: Reuben Stern <107093092+reubenconducts@users.noreply.github.com> Date: Sun, 26 Apr 2026 13:01:50 -0400 Subject: [PATCH 26/44] add missing score_mod_bwd param to FlashAttnVarlenFunc (#2496) --- flash_attn/cute/interface.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index a7d504c4778..779987bf24e 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -2146,6 +2146,7 @@ def flash_attn_varlen_func( pack_gqa: Optional[bool] = None, deterministic: bool = False, score_mod: Optional[Callable] = None, + score_mod_bwd: Optional[Callable] = None, aux_tensors: Optional[list] = None, return_lse: bool = False, ): @@ -2159,7 +2160,7 @@ def flash_attn_varlen_func( gather_kv_indices: a tensor of shape (batch, seqlen_q, gather_kv_length) or (total_q, gather_kv_length) if there is cu_seqlens_q. Currently, only used for topk sparsity with MLA absorption kernel. - + min_seqlen_k: for varlen, specifies the minimum kv sequence length for any batch. Used with gather_kv_indices to determine if we need oob masking. """ @@ -2186,6 +2187,7 @@ def flash_attn_varlen_func( pack_gqa, deterministic, score_mod, + score_mod_bwd, aux_tensors, return_lse, ) From eefa86de0e674cfd5843cf4af4a6169edbdbc788 Mon Sep 17 00:00:00 2001 From: baycore Date: Mon, 27 Apr 2026 22:52:05 +0800 Subject: [PATCH 27/44] fix: typos and missing comments in FA4 cute kernel files (#2502) - interface.py: remove extra space in is_deepseek_mla_absorbed_shape condition - softmax.py: fix comment typo "my" -> "may" in apply_score_mod_inner - mask.py: fix docstring typo "conventio" -> "convention" in backward mask - flash_bwd.py: clarify why hdim_multiple_of=32 differs from fwd's 16 Co-authored-by: watt --- flash_attn/cute/flash_bwd.py | 3 ++- flash_attn/cute/interface.py | 2 +- flash_attn/cute/mask.py | 2 +- flash_attn/cute/softmax.py | 2 +- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/flash_attn/cute/flash_bwd.py b/flash_attn/cute/flash_bwd.py index eeb7615b1d3..81c8ac68bd9 100644 --- a/flash_attn/cute/flash_bwd.py +++ b/flash_attn/cute/flash_bwd.py @@ -65,7 +65,8 @@ def __init__( :param is_causal: is causal """ self.dtype = dtype - # padding head_dim to a multiple of 16 as k_block_size + # padding head_dim to a multiple of 32 (stricter than fwd's 16) due to + # backward kernel register layout requirements for dQ/dK/dV accumulation hdim_multiple_of = 32 self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) head_dim_v = head_dim_v if head_dim_v is not None else head_dim diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 779987bf24e..c8b388c8d92 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -89,7 +89,7 @@ def _get_device_arch(): def _validate_head_dims(head_dim: int, head_dim_v: int, compute_capability: int, alignment: int) -> None: """Validate head dimension constraints based on compute capability.""" is_deepseek_shape = head_dim == 192 and head_dim_v == 128 - is_deepseek_mla_absorbed_shape = head_dim == 64 and head_dim_v == 512 + is_deepseek_mla_absorbed_shape = head_dim == 64 and head_dim_v == 512 is_dedicate_kernel_shape = head_dim == 256 and head_dim_v == 256 is_standard_range = 8 <= head_dim <= 128 and 8 <= head_dim_v <= 128 diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index e9810f51844..9c171ba9865 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -563,7 +563,7 @@ def apply_mask_sm100_transposed( """ Backward pass: mask S = K @ Q.T where n_block tiles seqlen_k and m_block tiles seqlen_q. - Coordinate conventio: + Coordinate convention: - ROW corresponds to Q (m_block) - COL corresponds to KV (n_block) diff --git a/flash_attn/cute/softmax.py b/flash_attn/cute/softmax.py index 9369e0d49ca..cc9b9d401d4 100644 --- a/flash_attn/cute/softmax.py +++ b/flash_attn/cute/softmax.py @@ -434,7 +434,7 @@ def apply_score_mod_inner( q_idx_vec = cute.make_rmem_tensor(vec_size, cutlass.Int32) # For Pack-GQA with non-constant q_idx, we need per-element head indices - # since a thread my process multiple query head indices + # since a thread may process multiple query head indices if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None): head_idx_vec = cute.make_rmem_tensor(vec_size, cutlass.Int32) From 547031aaac1fa613b63ec03a1babc1cf5fd1967d Mon Sep 17 00:00:00 2001 From: geruome <85235464+geruome@users.noreply.github.com> Date: Tue, 28 Apr 2026 00:49:00 +0800 Subject: [PATCH 28/44] [SM100] Guard gO None in empty-tile correction (#2504) --- flash_attn/cute/block_sparse_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/flash_attn/cute/block_sparse_utils.py b/flash_attn/cute/block_sparse_utils.py index b19dcd37cf8..0f81f863673 100644 --- a/flash_attn/cute/block_sparse_utils.py +++ b/flash_attn/cute/block_sparse_utils.py @@ -729,6 +729,8 @@ def handle_block_sparse_empty_tile_correction_sm100( if const_expr(gmem_tiled_copy_O is None): pipeline_o_epi.producer_acquire_w_index_phase(stage, corr_epi_producer_phase) + + gO_stage = gO[None, None, stage] if const_expr(gO is not None) else None correction_epilogue( thr_mma_pv, tOtO[None, None, None, stage], @@ -739,7 +741,7 @@ def handle_block_sparse_empty_tile_correction_sm100( Float32(0.0), # zero scale ensures empty tile writes zeros into staged outputs sO[None, None, stage], mO_cur, - gO[None, None, stage], + gO_stage, gmem_tiled_copy_O, ) if const_expr(gmem_tiled_copy_O is None): From 6b52632d8afb3a88323b2f2e6bde2015e08261cd Mon Sep 17 00:00:00 2001 From: Reuben Stern <107093092+reubenconducts@users.noreply.github.com> Date: Mon, 27 Apr 2026 17:13:16 -0400 Subject: [PATCH 29/44] [CuTe, Flex] simplify blocksparse interface in flash_attn_func (#2506) * simplify blocksparse tensors interface in flash_attn_func * remove blocksparse kwargs --- flash_attn/cute/interface.py | 35 +++++++++-------------------------- tests/cute/test_mask_mod.py | 13 +++---------- 2 files changed, 12 insertions(+), 36 deletions(-) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index c8b388c8d92..7adb11e1da5 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -1874,23 +1874,10 @@ def forward( score_mod_bwd: Optional[Callable] = None, mask_mod: Optional[Callable] = None, aux_tensors: Optional[list] = None, - full_block_cnt: Optional[torch.Tensor] = None, - full_block_idx: Optional[torch.Tensor] = None, - mask_block_cnt: Optional[torch.Tensor] = None, - mask_block_idx: Optional[torch.Tensor] = None, - block_size: Optional[Tuple[int, int]] = None, + block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None, + block_sparse_tensors_bwd: Optional[BlockSparseTensorsTorch] = None, return_lse: bool = False, ): - # Only create block sparse tensors if at least one block sparse parameter is provided - block_sparse_tensors = None - if any(t is not None for t in [full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx]): - block_sparse_tensors = BlockSparseTensorsTorch( - full_block_cnt=full_block_cnt, - full_block_idx=full_block_idx, - mask_block_cnt=mask_block_cnt, - mask_block_idx=mask_block_idx, - block_size=block_size, - ) out, lse = _flash_attn_fwd( q, k, @@ -1921,6 +1908,7 @@ def forward( ctx.score_mod = score_mod ctx.score_mod_bwd = score_mod_bwd ctx.mask_mod = mask_mod + ctx.block_sparse_tensors_bwd = block_sparse_tensors_bwd ctx.set_materialize_grads(False) return out, lse @@ -1949,6 +1937,7 @@ def backward(ctx, dout, dlse): score_mod_bwd=ctx.score_mod_bwd, mask_mod=ctx.mask_mod, aux_tensors=aux_tensors, + block_sparse_tensors=ctx.block_sparse_tensors_bwd, dlse=dlse, ) return dq, dk, dv, *((None,) * 30) # Extra Nones is fine @@ -2089,11 +2078,8 @@ def flash_attn_func( score_mod_bwd: Optional[Callable] = None, mask_mod: Optional[Callable] = None, aux_tensors: Optional[list] = None, - full_block_cnt: Optional[torch.Tensor] = None, - full_block_idx: Optional[torch.Tensor] = None, - mask_block_cnt: Optional[torch.Tensor] = None, - mask_block_idx: Optional[torch.Tensor] = None, - block_size: Optional[Tuple[int, int]] = None, + block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None, + block_sparse_tensors_bwd: Optional[BlockSparseTensorsTorch] = None, return_lse: bool = False, ): return FlashAttnFunc.apply( @@ -2114,11 +2100,8 @@ def flash_attn_func( score_mod_bwd, mask_mod, aux_tensors, - full_block_cnt, - full_block_idx, - mask_block_cnt, - mask_block_idx, - block_size, + block_sparse_tensors, + block_sparse_tensors_bwd, return_lse, ) @@ -2430,4 +2413,4 @@ def flash_attn_combine( seqused, varlen_batch_idx=varlen_batch_idx, ) - return out, lse \ No newline at end of file + return out, lse diff --git a/tests/cute/test_mask_mod.py b/tests/cute/test_mask_mod.py index c84d3ad66a1..484f5191725 100644 --- a/tests/cute/test_mask_mod.py +++ b/tests/cute/test_mask_mod.py @@ -407,15 +407,7 @@ def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias): q_ag = tensors["q"].detach().requires_grad_(True) k_ag = tensors["k"].detach().requires_grad_(True) v_ag = tensors["v"].detach().requires_grad_(True) - bs_kwargs = {} - if block_sparse_mask_fwd is not None: - bs_kwargs = dict( - mask_block_cnt=block_sparse_mask_fwd.mask_block_cnt, - mask_block_idx=block_sparse_mask_fwd.mask_block_idx, - full_block_cnt=block_sparse_mask_fwd.full_block_cnt, - full_block_idx=block_sparse_mask_fwd.full_block_idx, - block_size=block_sparse_mask_fwd.block_size, - ) + out_cute, lse_cute = flash_attn_func( q_ag, k_ag, @@ -426,8 +418,9 @@ def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias): pack_gqa=pack_gqa, mask_mod=mask_mod_cute, aux_tensors=aux_tensors_arg, + block_sparse_tensors=block_sparse_mask_fwd, + block_sparse_tensors_bwd=block_sparse_mask_bwd, return_lse=True, - **bs_kwargs, ) else: out_tuple = _flash_attn_fwd( From 519445aa2952715101d0fdbc23fadea04c91907e Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Mon, 27 Apr 2026 17:25:04 -0400 Subject: [PATCH 30/44] Fix (#2505) Signed-off-by: Matthew Bonanni --- flash_attn/cute/flash_fwd_mla_sm100.py | 1 + 1 file changed, 1 insertion(+) diff --git a/flash_attn/cute/flash_fwd_mla_sm100.py b/flash_attn/cute/flash_fwd_mla_sm100.py index 07cd99f71e9..2987b4c0460 100644 --- a/flash_attn/cute/flash_fwd_mla_sm100.py +++ b/flash_attn/cute/flash_fwd_mla_sm100.py @@ -718,6 +718,7 @@ def make_tma(make_fn, mX, smem_layout, mma_tiler, tiled_mma): ), cluster=self.cluster_shape_mnk, smem=SharedStorage.size_in_bytes(), + stream=stream, ) @cute.kernel From b86e0cc1dcf0caf9ff104d1cfe8c3aab4619e824 Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Mon, 27 Apr 2026 14:28:39 -0700 Subject: [PATCH 31/44] Fix clc scheduling request bug (#2508) stack-info: PR: https://github.com/Dao-AILab/flash-attention/pull/2508, branch: drisspg/stack/35 --- flash_attn/cute/interface.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 7adb11e1da5..a11c8debe2b 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -715,7 +715,7 @@ def _flash_attn_fwd( q_subtile_factor, mma_pv_is_rs, intra_wg_overlap, - requested_use_clc_scheduler, + use_clc_scheduler, qv is not None, gather_kv_length, sparse_kv, @@ -889,7 +889,7 @@ def _flash_attn_fwd( is_varlen_q=cu_seqlens_q is not None or seqused_q is not None, q_subtile_factor=q_subtile_factor, use_2cta_instrs=use_2cta_instrs, - use_clc_scheduler=requested_use_clc_scheduler, + use_clc_scheduler=use_clc_scheduler, ) elif arch // 10 == 12: # SM120 (Blackwell GeForce / DGX Spark): uses SM80 MMA with SM120 SMEM capacity From 89ce84bc2054cde1b94e0733252ea92b14908903 Mon Sep 17 00:00:00 2001 From: Johnsonms Date: Mon, 27 Apr 2026 15:11:33 -0700 Subject: [PATCH 32/44] [Tests,MLA] Close coverage gaps in test_flash_attn_mla_absorbed{,_varlen} (#2483) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Expand the two MLA-absorbed tests with a few low-risk axes and unblock a stale nheads parametrize. Axes stay inside what the MLA-absorbed kernel supports per the guards in `interface.py:633-648` (the existing skips for softcap/learnable_sink/local/paged/fp8/split_kv/score_mod/ mask_mod under `qv` are untouched). Non-varlen `test_flash_attn_mla_absorbed`: - Comment out the hardcoded `nheads = 128` override at line 1827 so the `nheads [16, 128]` parametrize is actually respected. The override was unintentional and the `nheads=16` path is meant to check the kernel works under head sharding (it's already exercised that way in the varlen test). The existing `if kv_sparsity and nheads != 128: pytest.skip()` at line 1816 is left in place — that's a separate intentional guard. - `gather_kv_length: [2048] -> [1024, 2048]` — exercise a second value on the `kv_sparsity=True` path. - Add seqlen `(4096, 4096)` — the table had `(2048, 2048)` and `(1, 8192)` but nothing with both Q and K in the mid-prefill range. - Assert `lse` is NaN-free. LSE is consumed by backward and by split-KV output combine; NaNs there aren't caught by the existing `out` diff check and silently corrupt any downstream use. Varlen `test_flash_attn_mla_absorbed_varlen`: - `gather_kv_length: [2048] -> [1024, 2048]` (same rationale). - LSE NaN check, gated on `unpad_q`. The padded path can legitimately contain uninit tail beyond `seqused_q`, so the check only runs on the packed-unpad path where every LSE slot is live. Context on a narrower scope than the first draft: tiny seqlens ((1,1), (1,3), (2,1)), `zero_lengths_q/k` with True, and `add_unused_qkv=True` were tried but triggered order-of-magnitude output divergence on MLA varlen (FA4 returning zeros where the reference has real values; `Output max diff ~5.25` against a `Pytorch max diff` of ~0.03). The combinations that fire it are `add_unused_qkv=True` with short K (seqlen_k=1) or with `kv_sparsity=True`, and `zero_lengths_k=True` across most unpad/ varlen_mode combos. Those look like real kernel issues orthogonal to a coverage expansion, so they're left out of this PR and can be filed separately. --- tests/cute/test_flash_attn.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index b069d8b7580..4446c053b5b 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -1819,7 +1819,7 @@ def test_flash_attn_invalid_head_dim(head_dim): @pytest.mark.parametrize("nheads", [16, 128]) @pytest.mark.parametrize("kv_sparsity", [False, True]) # @pytest.mark.parametrize("kv_sparsity", [True]) -@pytest.mark.parametrize("gather_kv_length", [2048]) +@pytest.mark.parametrize("gather_kv_length", [1024, 2048]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -1846,6 +1846,7 @@ def test_flash_attn_invalid_head_dim(head_dim): (1024, 1023), (2048, 2048), (1, 8192), + (4096, 4096), ], ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) @@ -1886,7 +1887,7 @@ def test_flash_attn_mla_absorbed( torch.cuda.synchronize() batch_size = 9 if seqlen_k <= 2048 else 2 # batch_size = 2 - nheads = 128 + # nheads = 128 nheads_kv = nheads if mha_type == "mha" else (8 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype dv_vals = [512] @@ -2043,6 +2044,7 @@ def test_flash_attn_mla_absorbed( assert (out - out_ref).abs().max().item() <= rtol * ( out_pt - out_ref ).abs().max().item() + fwd_atol + assert not torch.isnan(lse).any(), "LSE contains NaN" repeats = 1000 for iter in range(repeats): @@ -2078,10 +2080,11 @@ def test_flash_attn_mla_absorbed( @pytest.mark.parametrize("local_enum", [0]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [False]) +# @pytest.mark.parametrize("add_unused_qkv", [False, True]) @pytest.mark.parametrize("add_unused_qkv", [False]) @pytest.mark.parametrize("kv_sparsity", [False, True]) # @pytest.mark.parametrize("kv_sparsity", [False]) -@pytest.mark.parametrize("gather_kv_length", [2048]) +@pytest.mark.parametrize("gather_kv_length", [1024, 2048]) @pytest.mark.parametrize("d", [64]) @pytest.mark.parametrize("nheads", [16, 128]) # @pytest.mark.parametrize("nheads", [128]) @@ -2435,6 +2438,10 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): assert (out_cmp - out_ref_cmp).abs().max().item() <= rtol * ( out_pt_cmp - out_ref_cmp ).abs().max().item() + fwd_atol + # LSE sanity: only valid positions (packed unpad path; padded path + # can legitimately contain uninit tail beyond seqused_q). + if unpad_q: + assert not torch.isnan(lse).any(), "LSE contains NaN" repeats = 1000 for iter in range(repeats): From 96bd151b00add1dc5744115cfc350c0a809c3fa2 Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Mon, 27 Apr 2026 19:48:32 -0700 Subject: [PATCH 33/44] Add cache utils logging test (#2509) stack-info: PR: https://github.com/Dao-AILab/flash-attention/pull/2509, branch: drisspg/stack/36 --- tests/cute/test_cache_utils.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 tests/cute/test_cache_utils.py diff --git a/tests/cute/test_cache_utils.py b/tests/cute/test_cache_utils.py new file mode 100644 index 00000000000..4b1e690920b --- /dev/null +++ b/tests/cute/test_cache_utils.py @@ -0,0 +1,30 @@ +import logging +from types import SimpleNamespace + +import flash_attn.cute.cache_utils as cache_utils +from flash_attn.cute import fa_logging + + +def test_persistent_cache_hit_logs_at_host_level_only(tmp_path, monkeypatch, caplog): + caplog.set_level(logging.INFO, logger="flash_attn") + original_level = fa_logging.get_fa_log_level() + key = ("test-key",) + cache = cache_utils.JITPersistentCache(tmp_path) + obj_path = tmp_path / f"{cache._key_to_hash(key)}.o" + obj_path.write_bytes(b"cache-hit") + monkeypatch.setattr( + cache_utils.cute.runtime, + "load_module", + lambda *_args, **_kwargs: SimpleNamespace(func=object()), + ) + try: + monkeypatch.setattr(fa_logging, "_fa_log_level", 0) + assert cache_utils.JITPersistentCache(tmp_path)._try_load_from_storage(key) + assert "Loading compiled function from disk" not in caplog.text + + caplog.clear() + monkeypatch.setattr(fa_logging, "_fa_log_level", 1) + assert cache_utils.JITPersistentCache(tmp_path)._try_load_from_storage(key) + assert "Loading compiled function from disk" in caplog.text + finally: + monkeypatch.setattr(fa_logging, "_fa_log_level", original_level) From 6c73fb506fa84424c3fa04880f85c789b4a498e2 Mon Sep 17 00:00:00 2001 From: Johnsonms Date: Mon, 27 Apr 2026 21:42:56 -0700 Subject: [PATCH 34/44] [hd256] Improve forward kernel with exp2 FMA emulation (+3% to +9% performance gain) (#2488) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [hd256] Improve forward kernel with exp2 FMA emulation Rebased cherry-pick of `e122e67` from `Johnsonms/exp2-emu-hd256` on top of merged main (hd256 PR #2412, `27b4eb9`). The original branch was based on a pre-merge snapshot; the other five commits in that branch were absorbed into the squash-merge, leaving this one novel change. ## Change Replace a fraction of hardware `exp2` (SFU) instructions with a polynomial FMA emulation (`ex2_emulation_2`) in the softmax P-tile computation. The key insight: SM100's SFU throughput is a bottleneck for hdim=256 due to the large tile size. By substituting 3 out of every 4 `exp2` calls (`ex2_emu_freq=4`, `ex2_emu_res=3`) with packed FMA polynomial approximation, we shift pressure onto the underutilized FMA pipeline. Additionally, the P write-slot acquisition is moved earlier to overlap any pipeline stall with the `exp2` compute. Kernel-only change; no API change. Backward is untouched. ## Validation (this PR vs `origin/main` @ `b21e204` — includes #2412 hd256 base and #2487 post-merge cleanup) B200, bf16, hdim=256, MHA (32:32) and GQA (32:2), 3-run means, locked clocks @ 1755 MHz, seqlens 4k..128k. ### FWD delta vs `origin/main` (TFLOPS mean, 3 runs each) | seqlen | causal | MHA 32:32 | GQA 32:2 | |-------:|:------:|:---------:|:--------:| | 4k | F | -0.2% | +0.7% | | 8k | F | +0.4% | +0.3% | | 16k | F | +0.7% | -1.0% | | 32k | F | +2.3% | -0.3% | | 64k | F | **+5.4%** | **+5.2%** | | 128k | F | **+7.4%** | **+7.3%** | | 4k | T | +0.3% | +0.9% | | 8k | T | +0.8% | +0.9% | | 16k | T | +0.8% | +1.1% | | 32k | T | **+5.2%** | -0.4% | | 64k | T | +1.1% | +1.1% | | 128k | T | **+3.4%** | **+2.6%** | - 19 of 24 cells positive; 4 slightly negative, all within the batch-quantization noise band that `origin/main` itself already showed in our 3-run regression sweep. - Peak gain **+7.4%** (MHA 128k non-causal) — exactly where softmax SFU pressure is worst, consistent with the theory above. - Averages: **MHA fwd +2.3%**, **GQA fwd +1.5%**. ### Correctness smoke `pytest tests/cute/test_flash_attn.py::test_flash_attn_output -k "256-False-0-0.0-False-False"` on B200: **78 passed, 78 skipped, 0 failed** — identical pass/skip count to `origin/main`. ## Caveat Exp2 FMA emulation introduces small numerical differences vs hardware `exp2`. The existing test tolerances accept the delta. * [hd256] Wire ex2_emu params through _TUNING_CONFIG with tuned values The exp2 emulation knobs (ex2_emu_freq, ex2_emu_res, ex2_emu_start_frg) and softmax register counts for the hd256 forward kernel were hardcoded in BlackwellFusedMultiHeadAttentionForward.__init__, invisible to the central _TUNING_CONFIG table used by all other kernel configs. - flash_fwd_sm100.py: add hd256 entries to _TUNING_CONFIG (causal and non-causal; always 2cta, no sm103 variant). New ex2_emu_res field is hd256-specific; existing entries are unaffected. hd256 uses a fixed num_regs_other=32 (not derived from the 512-budget formula). - sm100_hd256_2cta_fmha_forward.py: replace hardcoded self.* assignments with a _TUNING_CONFIG lookup. Tuned values (B200, bf16, locked clocks): freq=14, res=6, start_frg=0 for both causal and non-causal. The inner loop steps k by 2, so k%freq only takes even values; freq=14/res=6 gives ~43% emulation (3 out of 7 even k%14 steps), replacing the previous 50:50 split (freq=4/res=3). --- flash_attn/cute/flash_fwd_sm100.py | 12 ++-- .../cute/sm100_hd256_2cta_fmha_forward.py | 63 ++++++++++++++----- 2 files changed, 55 insertions(+), 20 deletions(-) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 75e767cdc44..9027da189ac 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -73,19 +73,23 @@ # Values: # ex2_emu_freq: int — how often to use emulated exp2 (0=all hardware exp2, higher=more emulation). # SM103 has fast native exp2, so set freq=0 there. +# ex2_emu_res: int — (hd256 only) number of fragment-pairs per freq period to emulate. # ex2_emu_start_frg: int — fragment index to start emulation from # num_regs_softmax: int — register count for softmax warps (multiple of 8) # num_regs_correction: int — register count for correction warps (multiple of 8) # num_regs_other is derived: 512 - num_regs_softmax * 2 - num_regs_correction +# (hd256 exception: num_regs_other is fixed at 32, not derived) _TUNING_CONFIG = { - (True, False, 128, False): {'ex2_emu_freq': 10, 'ex2_emu_start_frg': 1, 'num_regs_softmax': 176, 'num_regs_correction': 88}, - (False, True, 128, False): {'ex2_emu_freq': 16, 'ex2_emu_start_frg': 1, 'num_regs_softmax': 192, 'num_regs_correction': 72}, + (True, False, 128, False): {"ex2_emu_freq": 10, "ex2_emu_start_frg": 1, "num_regs_softmax": 176, "num_regs_correction": 88}, + (False, True, 128, False): {"ex2_emu_freq": 16, "ex2_emu_start_frg": 1, "num_regs_softmax": 192, "num_regs_correction": 72}, (True, False, 192, False): {"ex2_emu_freq": 16, "ex2_emu_start_frg": 0, "num_regs_softmax": 184, "num_regs_correction": 80}, (False, True, 192, False): {"ex2_emu_freq": 32, "ex2_emu_start_frg": 1, "num_regs_softmax": 192, "num_regs_correction": 72}, (True, False, 128, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 80}, (False, True, 128, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 64}, (True, False, 192, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 64}, (False, True, 192, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 72}, + (True, False, 256, False): {"ex2_emu_freq": 14, "ex2_emu_res": 6, "ex2_emu_start_frg": 0, "num_regs_softmax": 256, "num_regs_correction": 160}, + (True, True, 256, False): {"ex2_emu_freq": 14, "ex2_emu_res": 6, "ex2_emu_start_frg": 0, "num_regs_softmax": 256, "num_regs_correction": 160}, } _FP8_TUNING_CONFIG = { (True, False, 128, False): {'ex2_emu_freq': 10, 'ex2_emu_start_frg': 1, 'num_regs_softmax': 160, 'num_regs_correction': 72}, @@ -1565,8 +1569,8 @@ def mma( # idesc=qk_mma_idesc, smem_desc_base_b=k_smem_base, tCrB_layout=tSrK[None, None, None, 0].layout, - smem_var_name_prefix=f"fa_fwd_q_smem_desc", - idesc_var_name=f"fa_fwd_qk_mma_idesc", + smem_var_name_prefix="fa_fwd_q_smem_desc", + idesc_var_name="fa_fwd_qk_mma_idesc", kind=qk_mma_kind, smem_offset=-sQ_stage_stride if stage == 0 else sQ_stage_stride, zero_init=True, diff --git a/flash_attn/cute/sm100_hd256_2cta_fmha_forward.py b/flash_attn/cute/sm100_hd256_2cta_fmha_forward.py index 03108e25a3d..a9463222718 100644 --- a/flash_attn/cute/sm100_hd256_2cta_fmha_forward.py +++ b/flash_attn/cute/sm100_hd256_2cta_fmha_forward.py @@ -28,7 +28,8 @@ Sm100FusedMask as FusedMask, ) from flash_attn.cute.tile_scheduler import SM100_TMEM_CAPACITY_COLUMNS -from flash_attn.cute.flash_fwd_sm100 import DescaleTensors +from flash_attn.cute.flash_fwd_sm100 import DescaleTensors, _TUNING_CONFIG +from flash_attn.cute.utils import ex2_emulation_2 class BlackwellFusedMultiHeadAttentionForward: @@ -136,9 +137,14 @@ def __init__( self.tmem_o_offset = 256 self.tmem_p_offset = self.tmem_s_offset - self.num_regs_softmax = 256 - self.num_regs_correction = 160 - self.num_regs_other = 32 + _tune_key = (True, is_causal, 256, False) # hd256: always 2cta, no sm103 variant + _tune = _TUNING_CONFIG.get(_tune_key, {}) + self.num_regs_softmax = _tune.get("num_regs_softmax", 256) + self.num_regs_correction = _tune.get("num_regs_correction", 160) + self.num_regs_other = 32 # fixed for hd256; not derived from 512 budget like other kernels + self.ex2_emu_freq = _tune.get("ex2_emu_freq", 4) + self.ex2_emu_res = _tune.get("ex2_emu_res", 3) + self.ex2_emu_start_frg = _tune.get("ex2_emu_start_frg", 0) self.buffer_align_bytes = 1024 @@ -1477,19 +1483,44 @@ def softmax_step( scale = scale_softmax_log2 minus_row_max_scale = (0.0 - row_max_safe) * scale - tTMEM_STORErP = cute.make_rmem_tensor(tTMEM_LOADrS.shape, self.q_dtype) - for k in range(0, cute.size(tTMEM_LOADrS), 2): - tTMEM_LOADrS[k], tTMEM_LOADrS[k + 1] = cute.arch.fma_packed_f32x2( - (tTMEM_LOADrS[k], tTMEM_LOADrS[k + 1]), - (scale, scale), - (minus_row_max_scale, minus_row_max_scale), - ) - tTMEM_LOADrS[k] = cute.math.exp2(tTMEM_LOADrS[k], fastmath=True) - tTMEM_LOADrS[k + 1] = cute.math.exp2(tTMEM_LOADrS[k + 1], fastmath=True) - s_vec = tTMEM_LOADrS.load() - tTMEM_STORErP.store(s_vec.to(self.q_dtype)) - + # Acquire P write slot early — overlaps any pipeline stall with exp2 compute p_handle = p_mma_producer.acquire_and_advance() + # Fragment-based FMA + exp2 + bf16 conversion + # Trades SFU for FMA via polynomial emulation on a fraction of elements + ex2_frg_tile = 32 + ex2_frg_cnt = cute.size(tTMEM_LOADrS) // ex2_frg_tile + tTMEM_LOADrS_ex2 = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(ex2_frg_tile)) + tTMEM_STORErP = cute.make_rmem_tensor(tTMEM_LOADrS.shape, self.q_dtype) + tTMEM_STORErP_ex2 = cute.logical_divide(tTMEM_STORErP, cute.make_layout(ex2_frg_tile)) + for j in cutlass.range_constexpr(ex2_frg_cnt): + for k in cutlass.range_constexpr(0, ex2_frg_tile, 2): + tTMEM_LOADrS_ex2[k, j], tTMEM_LOADrS_ex2[k + 1, j] = cute.arch.fma_packed_f32x2( + (tTMEM_LOADrS_ex2[k, j], tTMEM_LOADrS_ex2[k + 1, j]), + (scale, scale), + (minus_row_max_scale, minus_row_max_scale), + ) + if cutlass.const_expr(self.ex2_emu_freq == 0): + tTMEM_LOADrS_ex2[k, j] = cute.math.exp2(tTMEM_LOADrS_ex2[k, j], fastmath=True) + tTMEM_LOADrS_ex2[k + 1, j] = cute.math.exp2( + tTMEM_LOADrS_ex2[k + 1, j], fastmath=True + ) + else: + if cutlass.const_expr( + k % self.ex2_emu_freq < self.ex2_emu_freq - self.ex2_emu_res + or j >= ex2_frg_cnt - 1 + or j < self.ex2_emu_start_frg + ): + tTMEM_LOADrS_ex2[k, j] = cute.math.exp2( + tTMEM_LOADrS_ex2[k, j], fastmath=True + ) + tTMEM_LOADrS_ex2[k + 1, j] = cute.math.exp2( + tTMEM_LOADrS_ex2[k + 1, j], fastmath=True + ) + else: + tTMEM_LOADrS_ex2[k, j], tTMEM_LOADrS_ex2[k + 1, j] = ex2_emulation_2( + tTMEM_LOADrS_ex2[k, j], tTMEM_LOADrS_ex2[k + 1, j] + ) + tTMEM_STORErP_ex2[None, j].store(tTMEM_LOADrS_ex2[None, j].load().to(self.q_dtype)) tmem_store_atom = cute.make_copy_atom( tcgen05.St32x32bOp(tcgen05.Repetition(32)), self.qk_acc_dtype ) From ebeff90781648b33d487231c9d006776bd4390d1 Mon Sep 17 00:00:00 2001 From: EduardDurech <39579228+EduardDurech@users.noreply.github.com> Date: Tue, 28 Apr 2026 11:39:53 +0200 Subject: [PATCH 35/44] SM90 FA4 QuACK 0.4 Compatibility (#2513) * SM90 FA4 QuACK 0.4 Compatibility * Require QuACK>=0.4 --- CLAUDE.md | 2 +- flash_attn/cute/flash_bwd_sm90.py | 4 ---- flash_attn/cute/pyproject.toml | 2 +- 3 files changed, 2 insertions(+), 6 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 35d96195fd8..3b5f9672b77 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -20,7 +20,7 @@ pip install flash-attn-4 pip install -e "flash_attn/cute[dev]" ``` -Dependencies: `nvidia-cutlass-dsl>=4.4.1`, `torch`, `einops`, `apache-tvm-ffi`, `quack-kernels>=0.2.10`. +Dependencies: `nvidia-cutlass-dsl>=4.4.1`, `torch`, `einops`, `apache-tvm-ffi`, `quack-kernels>=0.4.0`. ## Running Tests diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py index c9a690d1e90..9a5f5d6c471 100644 --- a/flash_attn/cute/flash_bwd_sm90.py +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -1200,7 +1200,6 @@ def mma( tiled_mma_SdP, sP_cpy, tidx, - self.arch, transpose=self.SdP_swapAB, position_independent=True, major_mode_size=mms_PdS, @@ -1210,7 +1209,6 @@ def mma( tiled_mma_SdP, sdS_cpy, tidx, - self.arch, transpose=self.SdP_swapAB, position_independent=True, major_mode_size=mms_PdS, @@ -1644,7 +1642,6 @@ def epilogue_dKV( tiled_mma_dV, sdV, tidx, - self.arch, transpose=self.dKV_swapAB, position_independent=True, ) @@ -1652,7 +1649,6 @@ def epilogue_dKV( tiled_mma_dK, sdK, tidx, - self.arch, transpose=self.dKV_swapAB, position_independent=True, ) diff --git a/flash_attn/cute/pyproject.toml b/flash_attn/cute/pyproject.toml index 6ecf64d4e1d..cb1c3bb884f 100644 --- a/flash_attn/cute/pyproject.toml +++ b/flash_attn/cute/pyproject.toml @@ -28,7 +28,7 @@ dependencies = [ "typing_extensions", "apache-tvm-ffi>=0.1.5,<0.2", "torch-c-dlpack-ext", - "quack-kernels>=0.3.3", + "quack-kernels>=0.4.0", ] [project.optional-dependencies] From ba59def94cd7a0c12e2a8c673b0a4655be67c5c4 Mon Sep 17 00:00:00 2001 From: Johnsonms Date: Tue, 28 Apr 2026 09:50:54 -0700 Subject: [PATCH 36/44] ci: use /tmp for apptainer tmpdir to fix xattrerror on VAST (#2511) --- .github/actions/gpu-test/action.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/actions/gpu-test/action.yml b/.github/actions/gpu-test/action.yml index 54adc89b3b2..e03680e41dc 100644 --- a/.github/actions/gpu-test/action.yml +++ b/.github/actions/gpu-test/action.yml @@ -45,10 +45,10 @@ runs: PULL_REF=$(echo "$FA4_IMAGE" | sed 's/:[^@]*@/@/') echo "PULL_REF=$PULL_REF" echo "SIF=$SIF" - mkdir -p "$CI_WORK_DIR/apptainer_tmp" "$CI_WORK_DIR/apptainer_cache" + mkdir -p "$CI_WORK_DIR/apptainer_cache" /tmp/apptainer_tmp if [ ! -f "$SIF" ]; then echo "Pulling $PULL_REF → $SIF" - APPTAINER_TMPDIR="$CI_WORK_DIR/apptainer_tmp" \ + APPTAINER_TMPDIR="/tmp/apptainer_tmp" \ APPTAINER_CACHEDIR="$CI_WORK_DIR/apptainer_cache" \ apptainer pull "$SIF" "docker://$PULL_REF" else From b995b246d9c2493a8629ef8c657c4bf992c67e0b Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Thu, 30 Apr 2026 00:17:56 +0900 Subject: [PATCH 37/44] Fix long MSVC linker commands on Windows (#2517) Use a temporary response file when setuptools emits an oversized link.exe command so Windows builds with many object files can complete. Made-with: Cursor --- setup.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/setup.py b/setup.py index dcb89f85efe..a8f912f7862 100644 --- a/setup.py +++ b/setup.py @@ -8,6 +8,7 @@ import ast import glob import shutil +import tempfile from pathlib import Path from typing import Literal, Optional from packaging.version import parse, Version @@ -637,6 +638,35 @@ def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) + def build_extensions(self) -> None: + original_spawn = None + if sys.platform == "win32" and self.compiler.compiler_type == "msvc": + original_spawn = self.compiler.spawn + + def spawn(cmd): + if not cmd or Path(str(cmd[0])).name.lower() != "link.exe": + return original_spawn(cmd) + cmd = [str(arg) for arg in cmd] + if len(subprocess.list2cmdline(cmd)) <= 32767: + return original_spawn(cmd) + # Temporary workaround adapted from https://github.com/pypa/distutils/pull/406 + # until setuptools/distutils ships response-file handling for long MSVC links. + with tempfile.TemporaryDirectory() as tmpdir: + rsp_path = Path(tmpdir) / "cmdline.txt" + rsp_path.write_text( + "\n".join(subprocess.list2cmdline([arg]) for arg in cmd[1:]) + "\n", + encoding="ascii", + ) + return original_spawn([cmd[0], f"@{rsp_path}"]) + + self.compiler.spawn = spawn + + try: + super().build_extensions() + finally: + if original_spawn is not None: + self.compiler.spawn = original_spawn + # Build install_requires based on platform if ROCM_BACKEND == "triton": From 06259d9f67c95d63b73365fe238ff52d86c2a675 Mon Sep 17 00:00:00 2001 From: Henry Tsang Date: Thu, 30 Apr 2026 11:25:07 -0700 Subject: [PATCH 38/44] Fix test_flash_attn_fast varlen call after qv positional insert (#2527) flash_attn_varlen_func now takes qv as the 4th positional parameter, which shifted cu_seqlens_q/k and max_seqlen_q/k by one slot in the existing positional call and caused 132 AttributeError failures at interface.py:370 (`'int' object has no attribute 'shape'`). Switch the call to keyword args. --- tests/cute/test_flash_attn_fast.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/cute/test_flash_attn_fast.py b/tests/cute/test_flash_attn_fast.py index 433859d94d8..32deb4b5168 100644 --- a/tests/cute/test_flash_attn_fast.py +++ b/tests/cute/test_flash_attn_fast.py @@ -139,8 +139,8 @@ def test_flash_attn_varlen_output(seqlen, d, causal, mha_type, dtype): out_varlen, lse = flash_attn_varlen_func( q_varlen, k_varlen, v_varlen, - cu_seqlens, cu_seqlens, - seqlen, seqlen, + cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, + max_seqlen_q=seqlen, max_seqlen_k=seqlen, causal=causal, ) From c5f6ff4905b924d20b5fc2fe74551638cd1421f4 Mon Sep 17 00:00:00 2001 From: Markus Hoehnerbach Date: Thu, 30 Apr 2026 12:06:15 -0700 Subject: [PATCH 39/44] [Cute,Bwd,Sm90] Fix determinism for GQA, port Sm100 approach in (#2510) * [Cute,Bwd,Sm90] Fix determinism for GQA, port Sm100 approach in * add tests --- flash_attn/cute/flash_bwd_sm90.py | 43 +++++++++++++++++++- tests/cute/test_flash_attn_race_condition.py | 2 - 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py index 9a5f5d6c471..2e420924e92 100644 --- a/flash_attn/cute/flash_bwd_sm90.py +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -402,6 +402,15 @@ def _qkv_transpose(t): if const_expr(self.deterministic): assert mdQ_semaphore is not None mdQ_semaphore = layout_utils.select(mdQ_semaphore, mode=[2, 3, 1, 0]) + if const_expr(self.deterministic and self.qhead_per_kvhead > 1): + assert mdK_semaphore is not None + assert mdV_semaphore is not None + mdK_semaphore, mdV_semaphore = [ + layout_utils.select(t, mode=[2, 3, 1, 0]) for t in (mdK_semaphore, mdV_semaphore) + ] + else: + mdK_semaphore = None + mdV_semaphore = None self.num_mma_threads = tiled_mma_SdP.size assert self.num_mma_threads + 128 == self.num_threads @@ -598,6 +607,8 @@ def _qkv_transpose(t): blocksparse_tensors, qhead_per_kvhead_divmod, mdQ_semaphore, + mdK_semaphore, + mdV_semaphore, window_size_left, window_size_right, ).launch( @@ -651,6 +662,8 @@ def kernel( blocksparse_tensors: Optional[BlockSparseTensors] = None, qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None, mdQ_semaphore: Optional[cute.Tensor] = None, + mdK_semaphore: Optional[cute.Tensor] = None, + mdV_semaphore: Optional[cute.Tensor] = None, window_size_left: Optional[Int32] = None, window_size_right: Optional[Int32] = None, ): @@ -787,6 +800,8 @@ def kernel( tiled_mma_dQ, mdK, mdV, + mdK_semaphore, + mdV_semaphore, mdQaccum, sQ, sK, @@ -1092,6 +1107,8 @@ def mma( tiled_mma_dQ: cute.TiledMma, mdK: cute.Tensor, mdV: cute.Tensor, + mdK_semaphore: Optional[cute.Tensor], + mdV_semaphore: Optional[cute.Tensor], mdQaccum: cute.Tensor, sQ: cute.Tensor, sK: cute.Tensor, @@ -1388,6 +1405,8 @@ def mma( head_idx, batch_idx, qhead_per_kvhead_divmod, + mdK_semaphore, + mdV_semaphore, ) else: # KV tile with zero Q blocks produces no dK/dV; write zeros. @@ -1411,6 +1430,8 @@ def mma( head_idx, batch_idx, qhead_per_kvhead_divmod, + mdK_semaphore, + mdV_semaphore, ) tile_scheduler.advance_to_next_work() @@ -1615,6 +1636,8 @@ def epilogue_dKV( head_idx: Int32, batch_idx: Int32, qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None, + mdK_semaphore: Optional[cute.Tensor] = None, + mdV_semaphore: Optional[cute.Tensor] = None, ): epi_barrier = cutlass.pipeline.NamedBarrier( barrier_id=int(NamedBarrierBwd.Epilogue), num_threads=self.num_mma_threads @@ -1669,11 +1692,18 @@ def epilogue_dKV( store_dK() cute.arch.cp_async_bulk_commit_group() else: + deterministic_KV = self.deterministic and self.qhead_per_kvhead > 1 sdKaccum_shape0 = self.tile_n * self.tile_hdim // self.num_wg_mma sdVaccum_shape0 = self.tile_n * self.tile_hdimv // self.num_wg_mma sdKaccum_layout = cute.make_layout((sdKaccum_shape0, self.num_wg_mma)) sdVaccum_layout = cute.make_layout((sdVaccum_shape0, self.num_wg_mma)) head_idx_kv = head_idx // qhead_per_kvhead_divmod + if const_expr(deterministic_KV): + assert mdK_semaphore is not None + assert mdV_semaphore is not None + mdK_semaphore_cur = mdK_semaphore[n_block, None, head_idx_kv, batch_idx] + mdV_semaphore_cur = mdV_semaphore[n_block, None, head_idx_kv, batch_idx] + lock_value = head_idx % self.qhead_per_kvhead mdKaccum_cur = seqlen.offset_batch_K( mdK, batch_idx, dim=2, padded=True, multiple=self.tile_hdim )[None, head_idx_kv] @@ -1697,7 +1727,10 @@ def epilogue_dKV( tdKsdKaccum = thr_copy_dKVaccum_r2s.partition_D(sdKaccum) tdVsdVaccum = thr_copy_dKVaccum_r2s.partition_D(sdVaccum) - cute.arch.cp_async_bulk_wait_group(0, read=True) + read_flag = const_expr(not deterministic_KV) + cute.arch.cp_async_bulk_wait_group(0, read=read_flag) + if const_expr(deterministic_KV): + barrier.wait_eq(mdK_semaphore_cur.iterator, tidx, 0, lock_value) epi_barrier.arrive_and_wait() tdKrdKaccum_flat = cute.make_tensor(acc_dK.iterator, tdKsdKaccum.shape) cute.autovec_copy(tdKrdKaccum_flat, tdKsdKaccum) @@ -1713,7 +1746,10 @@ def epilogue_dKV( ) cute.arch.cp_async_bulk_commit_group() - cute.arch.cp_async_bulk_wait_group(0, read=True) + cute.arch.cp_async_bulk_wait_group(0, read=read_flag) + if const_expr(deterministic_KV): + barrier.arrive_inc(mdK_semaphore_cur.iterator, tidx, 0, 1) + barrier.wait_eq(mdV_semaphore_cur.iterator, tidx, 0, lock_value) epi_barrier.arrive_and_wait() tdVrdVaccum_flat = cute.make_tensor(acc_dV.iterator, tdVsdVaccum.shape) cute.autovec_copy(tdVrdVaccum_flat, tdVsdVaccum) @@ -1728,6 +1764,9 @@ def epilogue_dKV( self.tma_copy_bytes["dVacc"] // self.num_wg_mma, ) cute.arch.cp_async_bulk_commit_group() + if const_expr(deterministic_KV): + cute.arch.cp_async_bulk_wait_group(0, read=read_flag) + barrier.arrive_inc(mdV_semaphore_cur.iterator, tidx, 0, 1) @cute.jit def dQaccum_store( diff --git a/tests/cute/test_flash_attn_race_condition.py b/tests/cute/test_flash_attn_race_condition.py index 18295e01843..12f4c15ae3e 100644 --- a/tests/cute/test_flash_attn_race_condition.py +++ b/tests/cute/test_flash_attn_race_condition.py @@ -251,8 +251,6 @@ def test_flash_attn_output( and learnable_sink is None # and False ): - if IS_SM90 and mha_type != "mha": - pytest.xfail("SM90 backward: GQA/MQA has tensor layout issue (qhead_per_kvhead > 1)") if IS_SM90 and local: pytest.xfail("SM90 backward: local attention not supported yet") g = torch.randn_like(out) From a7f6fbd9cf51101786157bb087438b584e0f6747 Mon Sep 17 00:00:00 2001 From: Johnsonms Date: Fri, 1 May 2026 10:25:05 -0700 Subject: [PATCH 40/44] benchmarks/tune_ex2_emu: hd256 sweep support and clock lock/unlock (#2495) - Add ex2_emu_res as a third sweep dimension for hd256 keys; skip Phase 2 (num_regs_other is fixed for hd256). - Upgrade clock handling to lock/unlock: setup_clocks() locks at startup and unlocks via atexit, with --lock-clocks/--no-lock-clocks flag. - Fix nvidia-smi GPU targeting to respect CUDA_VISIBLE_DEVICES. --- benchmarks/tune_ex2_emu.py | 184 +++++++++++++++++++++++++++++++------ 1 file changed, 157 insertions(+), 27 deletions(-) diff --git a/benchmarks/tune_ex2_emu.py b/benchmarks/tune_ex2_emu.py index a2a74422e5f..c21eac93cc5 100644 --- a/benchmarks/tune_ex2_emu.py +++ b/benchmarks/tune_ex2_emu.py @@ -10,6 +10,7 @@ Requires: the _TUNING_CONFIG dict in flash_attn/cute/flash_fwd_sm100.py. """ import argparse +import atexit import json import re import subprocess @@ -86,6 +87,98 @@ def detect_sm103(): print(f"GPU: {torch.cuda.get_device_name()}, SM{sm}, is_sm103={is_sm103}") return is_sm103 +def _get_gpu_selector(): + """Return the nvidia-smi GPU selector (-i argument) for the current device. + + Resolves CUDA_VISIBLE_DEVICES so that nvidia-smi targets the same physical + GPU that PyTorch is using. + """ + visible = os.environ.get("CUDA_VISIBLE_DEVICES", "") + if visible: + entries = [e.strip() for e in visible.split(",") if e.strip()] + if entries: + idx = torch.cuda.current_device() if torch.cuda.is_available() else 0 + return entries[idx] if idx < len(entries) else entries[0] + if torch.cuda.is_available(): + return str(torch.cuda.current_device()) + return None + +def _nvidia_smi_cmd(*args): + """Build nvidia-smi command for the current GPU, prepending sudo when not root.""" + prefix = [] if os.geteuid() == 0 else ["sudo"] + cmd = prefix + ["nvidia-smi"] + selector = _get_gpu_selector() + if selector is not None: + cmd += ["-i", selector] + return cmd + list(args) + +def _query_clocks(): + """Return (current_mhz_str, max_mhz_str) or (None, None) on failure.""" + try: + result = subprocess.run( + _nvidia_smi_cmd( + "--query-gpu=clocks.current.graphics,clocks.max.graphics", + "--format=csv,noheader,nounits", + ), + capture_output=True, text=True, + ) + except OSError: + return None, None + if result.returncode != 0: + return None, None + lines = [ln.strip() for ln in result.stdout.strip().splitlines() if ln.strip()] + if not lines: + return None, None + fields = [f.strip() for f in lines[0].split(",")] + if len(fields) < 2 or not fields[0] or not fields[1]: + return None, None + return fields[0], fields[1] + +def lock_clocks(max_mhz): + """Lock GPU clocks to max_mhz. Returns True on success.""" + try: + result = subprocess.run( + _nvidia_smi_cmd("--lock-gpu-clocks", str(max_mhz)), + capture_output=True, text=True, + ) + except OSError as e: + print(f"WARNING: Could not lock GPU clocks ({e}).") + return False + if result.returncode == 0: + print(f"Locked GPU clocks to {max_mhz} MHz.") + return True + print(f"WARNING: Could not lock GPU clocks ({result.stderr.strip()}).") + return False + +def unlock_clocks(): + """Unlock GPU clocks (best-effort, called at exit).""" + try: + subprocess.run(_nvidia_smi_cmd("--reset-gpu-clocks"), capture_output=True) + except OSError: + pass + +def setup_clocks(do_lock): + """Query clock state; if do_lock, attempt to lock and register unlock at exit.""" + cur, max_clk = _query_clocks() + if cur is None: + return + if do_lock: + if cur == max_clk: + print(f"GPU clocks already at max ({max_clk} MHz).") + elif lock_clocks(max_clk): + atexit.register(unlock_clocks) + print("GPU clocks will be unlocked on exit.") + else: + lock_cmd = " ".join(_nvidia_smi_cmd("--lock-gpu-clocks", max_clk)) + print(f" To lock manually: {lock_cmd}") + else: + if cur != max_clk: + print(f"WARNING: GPU clocks not locked ({cur} MHz, max {max_clk} MHz).") + print(" Benchmark results may vary between runs.") + lock_cmd = " ".join(_nvidia_smi_cmd("--lock-gpu-clocks", max_clk)) + print(f" To lock: {lock_cmd}") + print() + def run_benchmark(causal_flag, headdim_str, seqlen, rep=20, warmup=10): """Run benchmark, return (ms, tflops, mfu) or (None, None, None). @@ -125,11 +218,15 @@ def parse_args(): p.add_argument("--seqlen", type=str, default="8192") p.add_argument("--rep", type=int, default=20) p.add_argument("--warmup", type=int, default=10) + p.add_argument("--lock-clocks", action=argparse.BooleanOptionalAction, default=True, + help="Lock GPU clocks before tuning (requires sudo); use --no-lock-clocks to warn only") return p.parse_args() def main(): args = parse_args() + setup_clocks(args.lock_clocks) original_src = read_file() + atexit.register(write_file, original_src) config = parse_tuning_config(original_src) hdim, hdim_v = parse_headdim(args.headdim) @@ -148,51 +245,81 @@ def main(): print(f"Tuning hdim={args.headdim}, hdim_padded={hdim_padded}, seqlen={args.seqlen}, is_sm103={is_sm103}") print(f"Keys to tune: {keys_to_tune}\n") - # ── Phase 1: ex2_emu_freq + ex2_emu_start_frg sweep ── + # ── Phase 1: ex2_emu_freq + ex2_emu_start_frg sweep (+ ex2_emu_res for hd256) ── freq_values = [0, 6, 8, 10, 12, 14, 16, 20, 24, 32] start_frg_values = [0, 1] + # hd256 inner loop steps k by 2, so k%freq only takes even values. + # Meaningful res values that produce distinct hw:emu ratios on this sweep grid: + # freq=6, res=3 → 50:50 | freq=8, res=6 → 25:75 | freq=8, res=4 → 50:50 (diff freq) + hd256_res_values = [3, 6, 4] for key in keys_to_tune: - use_2cta, is_causal, _, _ = key + use_2cta, is_causal, key_hdim, _ = key causal_flag = "true" if is_causal else "false" causal_label = "causal" if is_causal else "non-causal" cta_label = "2CTA" if use_2cta else "1CTA" + is_hd256 = key_hdim == 256 + res_values = hd256_res_values if is_hd256 else [None] print("=" * 70) print(f"Phase 1: ex2_emu sweep for {causal_label} ({cta_label})") print("=" * 70) - print(f"{'freq':>5} {'start':>6} {'ms':>8} {'tflops':>10} {'mfu':>8}") - print("-" * 45) + if is_hd256: + print(f"{'freq':>5} {'res':>4} {'start':>6} {'ms':>8} {'tflops':>10} {'mfu':>8}") + print("-" * 50) + else: + print(f"{'freq':>5} {'start':>6} {'ms':>8} {'tflops':>10} {'mfu':>8}") + print("-" * 45) best_freq = config[key]["ex2_emu_freq"] + best_res = config[key].get("ex2_emu_res", None) best_start = config[key]["ex2_emu_start_frg"] best_tflops = 0 for start_frg in start_frg_values: for freq in freq_values: - test_config = dict(config) - test_config[key] = {**config[key], "ex2_emu_freq": freq, "ex2_emu_start_frg": start_frg} - write_file(patch_config(original_src, test_config)) - try: - ms, tflops, mfu = run_benchmark(causal_flag, args.headdim, args.seqlen, args.rep, args.warmup) - if tflops is None: - print(f"{freq:>5} {start_frg:>6} ERROR") - continue - marker = " ***" if tflops > best_tflops else "" - print(f"{freq:>5} {start_frg:>6} {ms:>8.2f} {tflops:>10.0f} {mfu:>8.1f}{marker}") - if tflops > best_tflops: - best_tflops = tflops - best_freq = freq - best_start = start_frg - except Exception as e: - print(f"{freq:>5} {start_frg:>6} ERROR: {e}") - sys.stdout.flush() - - print(f"\n Best: freq={best_freq}, start_frg={best_start}, {best_tflops:.0f} TFLOPS") - config[key] = {**config[key], "ex2_emu_freq": best_freq, "ex2_emu_start_frg": best_start} + for res in res_values: + test_config = dict(config) + patch = {"ex2_emu_freq": freq, "ex2_emu_start_frg": start_frg} + if res is not None: + patch["ex2_emu_res"] = res + test_config[key] = {**config[key], **patch} + write_file(patch_config(original_src, test_config)) + try: + ms, tflops, mfu = run_benchmark(causal_flag, args.headdim, args.seqlen, args.rep, args.warmup) + if tflops is None: + if is_hd256: + print(f"{freq:>5} {res:>4} {start_frg:>6} ERROR") + else: + print(f"{freq:>5} {start_frg:>6} ERROR") + continue + marker = " ***" if tflops > best_tflops else "" + if is_hd256: + print(f"{freq:>5} {res:>4} {start_frg:>6} {ms:>8.2f} {tflops:>10.0f} {mfu:>8.1f}{marker}") + else: + print(f"{freq:>5} {start_frg:>6} {ms:>8.2f} {tflops:>10.0f} {mfu:>8.1f}{marker}") + if tflops > best_tflops: + best_tflops = tflops + best_freq = freq + best_res = res + best_start = start_frg + except Exception as e: + if is_hd256: + print(f"{freq:>5} {res:>4} {start_frg:>6} ERROR: {e}") + else: + print(f"{freq:>5} {start_frg:>6} ERROR: {e}") + sys.stdout.flush() + + if is_hd256: + print(f"\n Best: freq={best_freq}, res={best_res}, start_frg={best_start}, {best_tflops:.0f} TFLOPS") + config[key] = {**config[key], "ex2_emu_freq": best_freq, "ex2_emu_res": best_res, "ex2_emu_start_frg": best_start} + else: + print(f"\n Best: freq={best_freq}, start_frg={best_start}, {best_tflops:.0f} TFLOPS") + config[key] = {**config[key], "ex2_emu_freq": best_freq, "ex2_emu_start_frg": best_start} # ── Phase 2: Register count sweep (softmax, correction; other = 512 - 2*softmax - correction) ── + # hd256 skipped: its num_regs_other=32 is fixed and the 512-budget formula does not apply. reg_combos = [] for softmax in [176, 184, 192, 200]: @@ -202,7 +329,7 @@ def main(): reg_combos.append((softmax, correction, other)) for key in keys_to_tune: - use_2cta, is_causal, _, _ = key + use_2cta, is_causal, key_hdim, _ = key causal_flag = "true" if is_causal else "false" causal_label = "causal" if is_causal else "non-causal" cta_label = "2CTA" if use_2cta else "1CTA" @@ -210,6 +337,11 @@ def main(): print("\n" + "=" * 70) print(f"Phase 2: Register sweep for {causal_label} ({cta_label})") print("=" * 70) + + if key_hdim == 256: + print(" Skipping: hd256 uses fixed num_regs_other=32; 512-budget formula does not apply.") + continue + print(f"{'softmax':>8} {'corr':>6} {'other':>6} {'ms':>8} {'tflops':>10} {'mfu':>8}") print("-" * 55) @@ -255,10 +387,8 @@ def main(): val_parts = ", ".join(f'"{k}": {json.dumps(v)}' for k, v in config[key].items()) print(f" {key!r}: {{{val_parts}}},") - print(f"\nTo apply, update _TUNING_CONFIG in {KERNEL_FILE}") - # Restore original write_file(original_src) print("Restored original file.") From 956366b9791645fc5250b9b4b16be26e48f8e399 Mon Sep 17 00:00:00 2001 From: Johnsonms Date: Fri, 1 May 2026 10:25:39 -0700 Subject: [PATCH 41/44] [FA4][hd256] Backward TMA bulk-store epilogue + LSE/dpsum coalesce (#2497) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [FA4][hd256] Coalesce LSE/dpsum per-K-iter loads in dkdv Switch LSE and sum_OdO GMEM→SMEM loads from scatter indexing (thread_idx*N + i) to warp-coalesced indexing (thread_idx + i*32). Applied at all three load sites in the dkdv accumulation loop. * [FA4][hd256] TMA bulk-store epilogue for dK/dV and dQ dkdv: replace per-thread scattered GMEM stores with a cooperative TMA bulk-store path. Both warp-groups write into a CTA-shared (64, 256) SMEM staging tile aliased onto the dead sP+sdST buffers; WG 0 fires 4x(64, 64) cp.async.bulk stores. Per-thread store retained for varlen. dq: single-stage TMA bulk store aliased onto the consumed sdO buffer. Per-thread store retained for varlen. --- ...100_hd256_2cta_fmha_backward_dkdvkernel.py | 315 +++++++++++++++--- ...sm100_hd256_2cta_fmha_backward_dqkernel.py | 116 ++++++- 2 files changed, 364 insertions(+), 67 deletions(-) diff --git a/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dkdvkernel.py b/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dkdvkernel.py index e413d6f638a..7a8cdeede6a 100644 --- a/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dkdvkernel.py +++ b/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dkdvkernel.py @@ -548,6 +548,51 @@ def __call__( self.tma_copy_V_bytes = cute.size_in_bytes(V.element_type, V_smem_layout) * atom_thr_size self.tma_copy_dO_bytes = cute.size_in_bytes(dO.element_type, dO_smem_layout) * atom_thr_size + # Variant 3a epilogue: TMA store atoms (S2G) for dK / dV. + # Each compute warp group owns half the hd_v output via split_wg, so + # the per-WG epi tile is (cta_tiler[1], cta_tiler[2] / num_compute_wgs). + # SMEM staging will alias onto sP+sdST in subsequent commits; for now + # the atoms and layouts are built and threaded through but unused. + tma_store_op = cpasync.CopyBulkTensorTileS2GOp() + num_compute_wgs = self.num_compute_warps // 4 + # Variant 3a Path 2: CTA-shared epilogue SMEM. epi_tile is (M, gcd(128B, ...)) + # = (M, 64) for bf16 hd=256. Total stages = num_compute_wgs * num_epi_stages + # = 4 stages of (64, 64), virtually a per-CTA (64, 256) buffer aliased onto + # sP+sdST. Both warp-groups cooperatively populate this buffer; TMA fires + # one (64, 64) box per stage to the corresponding (64, 64) GMEM slice. + epi_cols_dKV = math.gcd( + 128 // (dK.element_type.width // 8), self.cta_tiler[2] // num_compute_wgs + ) + num_epi_stages_dKV = (self.cta_tiler[2] // num_compute_wgs) // epi_cols_dKV + epi_tile_dKV = (self.cta_tiler[1], epi_cols_dKV) + total_epi_stages = num_compute_wgs * num_epi_stages_dKV + dK_layout_enum = utils.LayoutEnum.from_tensor(dK) + dV_layout_enum = utils.LayoutEnum.from_tensor(dV) + sdK_epi_layout = sm100_utils.make_smem_layout_epi( + dK.element_type, + dK_layout_enum, + epi_tile_dKV, + total_epi_stages, + ) + sdV_epi_layout = sm100_utils.make_smem_layout_epi( + dV.element_type, + dV_layout_enum, + epi_tile_dKV, + total_epi_stages, + ) + tma_atom_dK, tma_tensor_dK = cpasync.make_tiled_tma_atom( + tma_store_op, + dK, + cute.select(sdK_epi_layout, mode=[0, 1]), + epi_tile_dKV, + ) + tma_atom_dV, tma_tensor_dV = cpasync.make_tiled_tma_atom( + tma_store_op, + dV, + cute.select(sdV_epi_layout, mode=[0, 1]), + epi_tile_dKV, + ) + @cute.struct class SharedStorage: # Pipeline barriers @@ -674,6 +719,10 @@ class SharedStorage: tma_tensor_dOT, dK, dV, + tma_atom_dK, + tma_tensor_dK, + tma_atom_dV, + tma_tensor_dV, scaled_LSE, scale_softmax, sum_OdO, @@ -692,6 +741,8 @@ class SharedStorage: P_smem_layout_staged, LSE_smem_layout, sum_OdO_smem_layout, + sdK_epi_layout, + sdV_epi_layout, self.tile_sched_params, ).launch( grid=bwd_grid, @@ -725,6 +776,10 @@ def dkdv_bwd( dOT_in: cute.Tensor, dK: cute.Tensor, dV: cute.Tensor, + tma_atom_dK: cute.CopyAtom, + dK_tma: cute.Tensor, + tma_atom_dV: cute.CopyAtom, + dV_tma: cute.Tensor, LSE: cute.Tensor, scale_softmax: cutlass.Float32, sum_OdO: cute.Tensor, @@ -743,6 +798,8 @@ def dkdv_bwd( P_smem_layout_staged: cute.ComposedLayout, LSE_smem_layout: cute.Layout, sum_OdO_smem_layout: cute.Layout, + sdK_epi_layout: cute.ComposedLayout, + sdV_epi_layout: cute.ComposedLayout, tile_sched_params: FmhaStaticTileSchedulerParams | FmhaClcDynamicTileSchedulerParams, ): """Core CuTeDSL backward kernel.""" @@ -1359,6 +1416,7 @@ def dkdv_bwd( sP, sLSE, sdST, + sdOT, sSum_OdO, dK, dV, @@ -1390,6 +1448,12 @@ def dkdv_bwd( varlen, sK, seqlen_k_cur_batch, + tma_atom_dK, + dK_tma, + tma_atom_dV, + dV_tma, + sdK_epi_layout, + sdV_epi_layout, ) cute.arch.barrier( barrier_id=self.epilogue_sync_bar_id, @@ -1552,6 +1616,7 @@ def dkdv_bwd( sP, sLSE, sdST, + sdOT, sSum_OdO, dK, dV, @@ -1583,6 +1648,12 @@ def dkdv_bwd( varlen, sK, seqlen_k_cur_batch, + tma_atom_dK, + dK_tma, + tma_atom_dV, + dV_tma, + sdK_epi_layout, + sdV_epi_layout, ) cute.arch.barrier( @@ -1821,24 +1892,19 @@ def load( ) sLSE_for_copy = cute.flat_divide(sLSE, (1,)) LSE_for_copy = cute.flat_divide(LSE, (1,)) + # Warp-coalesced: at each i, lane T accesses index `T + i*W` (stride-1 + # across the warp) instead of `T*N + i` (stride-N across the warp). for i in cutlass.range_constexpr(async_copy_num_elts): - LSE_idx = self.tile_shape_Q * iter_index + thread_idx * async_copy_num_elts - if cute.elem_less(LSE_idx + i, problem_shape[0]): + LSE_idx = self.tile_shape_Q * iter_index + thread_idx + i * self.threads_per_warp + sLSE_idx = thread_idx + i * self.threads_per_warp + if cute.elem_less(LSE_idx, problem_shape[0]): cute.copy( atom_async_copy, - LSE_for_copy[None, LSE_idx + i, (blk_coord_h, blk_coord_b)], - sLSE_for_copy[ - None, - thread_idx * async_copy_num_elts + i, - lse_handle.index, - ], + LSE_for_copy[None, LSE_idx, (blk_coord_h, blk_coord_b)], + sLSE_for_copy[None, sLSE_idx, lse_handle.index], ) else: - sLSE_for_copy[ - None, - thread_idx * async_copy_num_elts + i, - lse_handle.index, - ].fill(0.0) + sLSE_for_copy[None, sLSE_idx, lse_handle.index].fill(0.0) lse_handle.commit() v_handle = load_mma_V_producer.acquire_and_advance() @@ -1861,23 +1927,16 @@ def load( sSum_OdO_for_copy = cute.flat_divide(sSum_OdO, (1,)) sum_OdO_for_copy = cute.flat_divide(sum_OdO, (1,)) for i in cutlass.range_constexpr(async_copy_num_elts): - sum_OdO_idx = self.tile_shape_Q * iter_index + thread_idx * async_copy_num_elts - if cute.elem_less(sum_OdO_idx + i, problem_shape[0]): + sum_OdO_idx = self.tile_shape_Q * iter_index + thread_idx + i * self.threads_per_warp + sSum_OdO_idx = thread_idx + i * self.threads_per_warp + if cute.elem_less(sum_OdO_idx, problem_shape[0]): cute.copy( atom_async_copy, - sum_OdO_for_copy[None, sum_OdO_idx + i, (blk_coord_h, blk_coord_b)], - sSum_OdO_for_copy[ - None, - thread_idx * async_copy_num_elts + i, - sum_odo_handle.index, - ], + sum_OdO_for_copy[None, sum_OdO_idx, (blk_coord_h, blk_coord_b)], + sSum_OdO_for_copy[None, sSum_OdO_idx, sum_odo_handle.index], ) else: - sSum_OdO_for_copy[ - None, - thread_idx * async_copy_num_elts + i, - sum_odo_handle.index, - ].fill(0.0) + sSum_OdO_for_copy[None, sSum_OdO_idx, sum_odo_handle.index].fill(0.0) sum_odo_handle.commit() dot_handle = load_mma_dOT_producer.acquire_and_advance() @@ -1917,23 +1976,16 @@ def load( sLSE_for_copy = cute.flat_divide(sLSE, (1,)) LSE_for_copy = cute.flat_divide(LSE, (1,)) for i in cutlass.range_constexpr(async_copy_num_elts): - LSE_idx = self.tile_shape_Q * iter_index + thread_idx * async_copy_num_elts - if cute.elem_less(LSE_idx + i, problem_shape[0]): + LSE_idx = self.tile_shape_Q * iter_index + thread_idx + i * self.threads_per_warp + sLSE_idx = thread_idx + i * self.threads_per_warp + if cute.elem_less(LSE_idx, problem_shape[0]): cute.copy( atom_async_copy, - LSE_for_copy[None, LSE_idx + i, (blk_coord_h, blk_coord_b)], - sLSE_for_copy[ - None, - thread_idx * async_copy_num_elts + i, - lse_handle.index, - ], + LSE_for_copy[None, LSE_idx, (blk_coord_h, blk_coord_b)], + sLSE_for_copy[None, sLSE_idx, lse_handle.index], ) else: - sLSE_for_copy[ - None, - thread_idx * async_copy_num_elts + i, - lse_handle.index, - ].fill(0.0) + sLSE_for_copy[None, sLSE_idx, lse_handle.index].fill(0.0) lse_handle.commit() do_handle = load_mma_dO_producer.acquire_and_advance() @@ -1948,23 +2000,18 @@ def load( sSum_OdO_for_copy = cute.flat_divide(sSum_OdO, (1,)) sum_OdO_for_copy = cute.flat_divide(sum_OdO, (1,)) for i in cutlass.range_constexpr(async_copy_num_elts): - sum_OdO_idx = self.tile_shape_Q * iter_index + thread_idx * async_copy_num_elts - if cute.elem_less(sum_OdO_idx + i, problem_shape[0]): + sum_OdO_idx = ( + self.tile_shape_Q * iter_index + thread_idx + i * self.threads_per_warp + ) + sSum_OdO_idx = thread_idx + i * self.threads_per_warp + if cute.elem_less(sum_OdO_idx, problem_shape[0]): cute.copy( atom_async_copy, - sum_OdO_for_copy[None, sum_OdO_idx + i, (blk_coord_h, blk_coord_b)], - sSum_OdO_for_copy[ - None, - thread_idx * async_copy_num_elts + i, - sum_odo_handle.index, - ], + sum_OdO_for_copy[None, sum_OdO_idx, (blk_coord_h, blk_coord_b)], + sSum_OdO_for_copy[None, sSum_OdO_idx, sum_odo_handle.index], ) else: - sSum_OdO_for_copy[ - None, - thread_idx * async_copy_num_elts + i, - sum_odo_handle.index, - ].fill(0.0) + sSum_OdO_for_copy[None, sSum_OdO_idx, sum_odo_handle.index].fill(0.0) sum_odo_handle.commit() dot_handle = load_mma_dOT_producer.acquire_and_advance() @@ -2379,6 +2426,7 @@ def compute( sLSE: cute.Tensor, # sdS: cute.Tensor, sdST: cute.Tensor, + sdOT: cute.Tensor, sSum_OdO: cute.Tensor, dK: cute.Tensor, dV: cute.Tensor, @@ -2410,6 +2458,12 @@ def compute( varlen: bool, sK: cute.Tensor, problem_shape_k_cur_batch: Int32, + tma_atom_dK: cute.CopyAtom, + dK_tma: cute.Tensor, + tma_atom_dV: cute.CopyAtom, + dV_tma: cute.Tensor, + sdK_epi_layout: cute.ComposedLayout, + sdV_epi_layout: cute.ComposedLayout, ): """CuTeDSL kernel for recomputing softmax and producing dk and dv.""" tidx, _, _ = cute.arch.thread_idx() @@ -2649,6 +2703,15 @@ def compute( mma_compute_dKdV_producer, mma_compute_dKdV_consumer, problem_shape_k_cur_batch, + tma_atom_dK, + dK_tma, + tma_atom_dV, + dV_tma, + sdK_epi_layout, + sdV_epi_layout, + varlen, + sdOT, + sP, ) if not cutlass.const_expr(self.use_clc_scheduler): @@ -2760,8 +2823,28 @@ def epilogue( mma_compute_dKdV_producer, mma_compute_dKdV_consumer, problem_shape_k_cur_batch: Int32, + tma_atom_dK: cute.CopyAtom, + dK_tma: cute.Tensor, + tma_atom_dV: cute.CopyAtom, + dV_tma: cute.Tensor, + sdK_epi_layout: cute.ComposedLayout, + sdV_epi_layout: cute.ComposedLayout, + varlen: bool, + sdOT: cute.Tensor, + sP: cute.Tensor, ): - """Epilogue phase to store result from tensor memory to register, then global memory.""" + """Variant 3a (5/5) Path 2: CTA-shared SMEM with cooperative WG writes + TMA bulk store. + + Both warp-groups cooperatively populate a per-CTA (64, 256) virtual SMEM + buffer (4 stages of (64, 64) aliased onto sP+sdST). Per-thread t2r N + coverage is interleaved across the full hd=256, so per-WG TMA is not + viable — instead we treat SMEM as one shared per-CTA buffer and let + each thread's `self.store`-equivalent write into it via a (64, 256) + virtual tensor whose N axis maps (n%64, n//64) → (N_within, stage). + After an inter-WG barrier (256 threads), the leader warp fires 4 TMA + bulk stores, one per stage, to the corresponding (64, 64) GMEM slice. + Varlen falls back to per-thread self.store as in flash_bwd_sm100.py. + """ tidx, _, _ = cute.arch.thread_idx() _, K, D, HB = problem_shape _, blk_coord_k, _, blk_coord_batch = blk_coord @@ -2789,12 +2872,41 @@ def epilogue( num_warp_groups = self.num_compute_warps // 4 dp_idx = tidx % 128 wg_idx = (tidx % (self.num_compute_warps * self.threads_per_warp)) // 128 + leader_warp = (cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4) == 0 + + # Path 2 SMEM staging. dV stages through sdOT (already-consumed by the + # dV MMA before the dV epilogue begins). dK stages through sP+sdST + # (dead after dK MMA completes, before dK epilogue runs). + s_epi_dK = cute.make_tensor( + cute.recast_ptr(sP.iterator, sdK_epi_layout.inner, dK.element_type), + sdK_epi_layout.outer, + ) + s_epi_dV = cute.make_tensor( + cute.recast_ptr(sdOT.iterator, sdV_epi_layout.inner, dV.element_type), + sdV_epi_layout.outer, + ) + + # Compile-time: stage tile shape and number of stages. + epi_cols_dKV = math.gcd( + 128 // (dV.element_type.width // 8), self.cta_tiler[2] // num_warp_groups + ) + num_epi_stages_dKV = (self.cta_tiler[2] // num_warp_groups) // epi_cols_dKV + total_epi_stages = num_warp_groups * num_epi_stages_dKV + epi_tile_dKV = (self.cta_tiler[1], epi_cols_dKV) + + # Local (M, N) coord tensor for SMEM indexing (no global domain offset + # — cdK/cdV are domain-offset by blk_coord_k * tile_shape_K to match + # the GMEM destination, but the SMEM indexing must be per-CTA-local). + cdV_local = cute.make_identity_tensor((self.cta_tiler[1], self.cta_tiler[2])) + cdK_local = cdV_local tiled_t2r_dK = tcgen05.make_tmem_copy(load_op, tdKtdK) thread_t2r_dK = tiled_t2r_dK.get_slice(dp_idx) tTR_cdK = thread_t2r_dK.partition_D(cdK) tTR_cdK = split_wg(tTR_cdK, num_warp_groups, wg_idx) + tTR_cdK_local = thread_t2r_dK.partition_D(cdK_local) + tTR_cdK_local = split_wg(tTR_cdK_local, num_warp_groups, wg_idx) tTR_gdK = thread_t2r_dK.partition_D(gdK) tTR_gdK = split_wg(tTR_gdK, num_warp_groups, wg_idx) tTR_rdK = cute.make_rmem_tensor(tTR_cdK.shape, self.acc_dtype) @@ -2821,17 +2933,80 @@ def epilogue( tTR_cdV = thread_t2r_dV.partition_D(cdV) tTR_cdV = split_wg(tTR_cdV, num_warp_groups, wg_idx) + tTR_cdV_local = thread_t2r_dV.partition_D(cdV_local) + tTR_cdV_local = split_wg(tTR_cdV_local, num_warp_groups, wg_idx) tTR_gdV = thread_t2r_dV.partition_D(gdV) tTR_gdV = split_wg(tTR_gdV, num_warp_groups, wg_idx) tTR_rdV = cute.make_rmem_tensor(tTR_cdV.shape, self.acc_dtype) tTR_tdV = thread_t2r_dV.partition_S(tdVtdV) tTR_tdV = split_wg(tTR_tdV, num_warp_groups, wg_idx) + # GMEM destinations for the multi-stage TMA path (gated on not-varlen). + if cutlass.const_expr(not varlen): + mdV_tma_3d = cute.make_tensor( + dV_tma.iterator, + cute.make_layout((K, self.cta_tiler[2], HB), stride=dV_tma.stride), + ) + mdV_tma_cur = mdV_tma_3d[None, None, blk_coord_batch] + gdV_tma = cute.local_tile( + mdV_tma_cur, (self.cta_tiler[1], self.cta_tiler[2]), (blk_coord_k, 0) + ) + gdV_tma_epi = cute.local_tile(gdV_tma, epi_tile_dKV, (0, None)) + + mdK_tma_3d = cute.make_tensor( + dK_tma.iterator, + cute.make_layout((K, self.cta_tiler[2], HB), stride=dK_tma.stride), + ) + mdK_tma_cur = mdK_tma_3d[None, None, blk_coord_batch] + gdK_tma = cute.local_tile( + mdK_tma_cur, (self.cta_tiler[1], self.cta_tiler[2]), (blk_coord_k, 0) + ) + gdK_tma_epi = cute.local_tile(gdK_tma, epi_tile_dKV, (0, None)) + + cta_threads = self.num_compute_warps * self.threads_per_warp + dkdv_handle = mma_compute_dKdV_consumer.wait_and_advance() if blk_coord_k * self.tile_shape_K < problem_shape_k_cur_batch: cute.copy(tiled_t2r_dV, tTR_tdV, tTR_rdV) - self.store(tTR_gdV, tTR_rdV, tTR_cdV, (K, D)) + tTR_rdV_cast = cute.make_rmem_tensor(tTR_rdV.shape, dV.element_type) + tTR_rdV_cast.store(tTR_rdV.load().to(dV.element_type)) + + if cutlass.const_expr(not varlen): + # reg -> SMEM via per-element indexed stores using tTR_cdV's + # per-thread (M, N) coords. (M, N) is per-CTA cdV space (M=0..63, + # N=0..255). We map N=(n%epi_cols, n//epi_cols) → (N_within, stage) + # of the 3D s_epi_dV tensor. + for _i in cutlass.range_constexpr(cute.size(tTR_cdV_local, mode=[2])): + for _j in cutlass.range_constexpr(cute.size(tTR_cdV_local[None, 0, _i])): + c = tTR_cdV_local[None, 0, _i][_j] + m_pos = c[0] + n_pos = c[1] + stage_pos = n_pos // epi_cols_dKV + n_within_pos = n_pos % epi_cols_dKV + v = tTR_rdV_cast[None, 0, _i][_j] + s_epi_dV[m_pos, n_within_pos, stage_pos] = v + cute.arch.fence_view_async_shared() + # Inter-WG barrier — both warp-groups must finish their writes + # before the leader warp reads SMEM via TMA. + cute.arch.barrier(barrier_id=5, number_of_threads=cta_threads) + # TMA bulk store, one (64, 64) box per stage. + if leader_warp and wg_idx == 0: + for _stage in cutlass.range_constexpr(total_epi_stages): + sdV_stage = s_epi_dV[None, None, _stage] + gdV_stage = gdV_tma_epi[None, None, _stage] + td_sdV, td_gdV = cpasync.tma_partition( + tma_atom_dV, + 0, + cute.make_layout(1), + cute.group_modes(sdV_stage, 0, 2), + cute.group_modes(gdV_stage, 0, 2), + ) + cute.copy(tma_atom_dV, td_sdV, td_gdV) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=True) + else: + self.store(tTR_gdV, tTR_rdV, tTR_cdV, (K, D)) cute.arch.fence_view_async_tmem_load() dkdv_handle.release() @@ -2844,7 +3019,37 @@ def epilogue( for i in cutlass.range(cute.size(tTR_rdK), unroll_full=True): tTR_rdK[i] = scale_softmax * tTR_rdK[i] - self.store(tTR_gdK, tTR_rdK, tTR_cdK, (K, D)) + tTR_rdK_cast = cute.make_rmem_tensor(tTR_rdK.shape, dK.element_type) + tTR_rdK_cast.store(tTR_rdK.load().to(dK.element_type)) + + if cutlass.const_expr(not varlen): + for _i in cutlass.range_constexpr(cute.size(tTR_cdK_local, mode=[2])): + for _j in cutlass.range_constexpr(cute.size(tTR_cdK_local[None, 0, _i])): + c = tTR_cdK_local[None, 0, _i][_j] + m_pos = c[0] + n_pos = c[1] + stage_pos = n_pos // epi_cols_dKV + n_within_pos = n_pos % epi_cols_dKV + v = tTR_rdK_cast[None, 0, _i][_j] + s_epi_dK[m_pos, n_within_pos, stage_pos] = v + cute.arch.fence_view_async_shared() + cute.arch.barrier(barrier_id=6, number_of_threads=cta_threads) + if leader_warp and wg_idx == 0: + for _stage in cutlass.range_constexpr(total_epi_stages): + sdK_stage = s_epi_dK[None, None, _stage] + gdK_stage = gdK_tma_epi[None, None, _stage] + td_sdK, td_gdK = cpasync.tma_partition( + tma_atom_dK, + 0, + cute.make_layout(1), + cute.group_modes(sdK_stage, 0, 2), + cute.group_modes(gdK_stage, 0, 2), + ) + cute.copy(tma_atom_dK, td_sdK, td_gdK) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=True) + else: + self.store(tTR_gdK, tTR_rdK, tTR_cdK, (K, D)) cute.arch.fence_view_async_tmem_load() dkdv_handle.release() diff --git a/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dqkernel.py b/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dqkernel.py index 346f3c1b90d..b25ca48f007 100644 --- a/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dqkernel.py +++ b/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dqkernel.py @@ -435,6 +435,24 @@ def __call__( self.tma_copy_v_bytes = v_copy_size * cute.size(qk_tiled_mma.thr_id.shape) self.tma_copy_kt_bytes = kt_copy_size * cute.size(qk_tiled_mma.thr_id.shape) + # dQ epilogue TMA store. Use a single (M, 64) SMEM stage and rotate it + # across the hd=256 N dimension to stay within the SMEM budget. + tma_store_op = cpasync.CopyBulkTensorTileS2GOp() + epi_cols_dQ = math.gcd(128 // (dq.element_type.width // 8), self.epi_tile[1]) + epi_tile_dQ = (self.epi_tile[0], epi_cols_dQ) + sdQ_epi_layout = sm100_utils.make_smem_layout_epi( + dq.element_type, + self.dq_layout, + epi_tile_dQ, + 1, + ) + tma_atom_dQ, tma_tensor_dQ = cpasync.make_tiled_tma_atom( + tma_store_op, + dq, + cute.select(sdQ_epi_layout, mode=[0, 1]), + epi_tile_dQ, + ) + @cute.struct class SharedStorage: # TMA G2S load barriers: LOAD warp (producer) -> MMA warp (consumer) @@ -487,6 +505,8 @@ class SharedStorage: tma_tensor_do, tma_atom_kt, tma_tensor_kt, + tma_atom_dQ, + tma_tensor_dQ, lse, sum_odo, dq, @@ -502,6 +522,7 @@ class SharedStorage: do_smem_layout_staged, kt_smem_layout_staged, ds_tmem_layout, + sdQ_epi_layout, lse_smem_layout, sum_odo_smem_layout, self.tile_sched_params, @@ -530,6 +551,8 @@ def kernel( mdO_qdl: cute.Tensor, tma_atom_kt: cute.CopyAtom, mK_dkl: cute.Tensor, + tma_atom_dQ: cute.CopyAtom, + mdQ_tma: cute.Tensor, mLSE: cute.Tensor, mSum_OdO: cute.Tensor, mdQ_qdl: cute.Tensor, @@ -545,6 +568,7 @@ def kernel( do_smem_layout_staged: cute.ComposedLayout, kt_smem_layout_staged: cute.ComposedLayout, ds_tmem_layout_staged: cute.ComposedLayout, + sdQ_epi_layout: cute.ComposedLayout, lse_smem_layout: cute.Layout, sum_odo_smem_layout: cute.Layout, tile_sched_params: FmhaStaticTileSchedulerParams | FmhaClcDynamicTileSchedulerParams, @@ -569,6 +593,7 @@ def kernel( cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_v) cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_do) cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_kt) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_dQ) bidx, bidy, bidz = cute.arch.block_idx() tidx, _, _ = cute.arch.thread_idx() @@ -785,6 +810,14 @@ def kernel( layout=sum_odo_smem_layout, byte_alignment=128, ) + # Alias the dQ TMA epilogue staging buffer onto sdO. A standalone + # (128, 64) bf16 allocation exceeds B200's per-block SMEM limit for + # this kernel, while sdO is no longer needed by the epilogue warp once + # the dQ accumulator is ready to store. + s_epi_dQ = cute.make_tensor( + cute.recast_ptr(sdO.iterator, sdQ_epi_layout.inner, self.dq_dtype), + sdQ_epi_layout.outer, + ) qk_thr_mma = qk_tiled_mma.get_slice(mma_tile_coord_v) # default 1sm dov_thr_mma = dov_tiled_mma.get_slice(mma_tile_coord_v) # default 1sm dsk_thr_mma = dsk_tiled_mma.get_slice(mma_tile_coord_v) # default 1sm @@ -1919,12 +1952,21 @@ def kernel( gdQ_staged = gdQ_qdl[None, None, curr_block_coord[0], None, curr_block_coord[2]] cdQ_staged = cdQ_qdl[None, None, curr_block_coord[0], None, curr_block_coord[2]] + gdQ_tma_staged = gdQ_staged + if cutlass.const_expr(not varlen): + gdQ_tma_qdl = cute.flat_divide( + mdQ_tma, cute.select(self.dsk_block_tiler, mode=[0, 1]) + ) + gdQ_tma_staged = gdQ_tma_qdl[ + None, None, curr_block_coord[0], None, curr_block_coord[2] + ] # dQ TMEM to GMEM mma_dq_consumer = self.dQ_epilogue( (seqlen_q, cuseqlen_q, mQ_qdl.shape[0], batch_coord), (mma_dq_consumer, gdQ_staged, cdQ_staged, tdQtdQ_staged), self.epi_tile, + (tma_atom_dQ, gdQ_tma_staged, s_epi_dQ, varlen), ) work_tile = tile_sched.advance_to_next_work() # NOTE: tmem.free() moved to kernel end to enable cluster-wide sync @@ -2103,15 +2145,22 @@ def dQ_epilogue( value_args: Tuple, dq_args: Tuple, epi_tile: cute.Tile, + tma_args: Tuple, ) -> Tuple[pipeline.PipelineConsumer, pipeline.PipelineProducer]: seqlen_q, cuseqlen_q, total_q, batch_coord = value_args (mma_dq_consumer, gdQ_staged, cdQ_staged, tdQtdQ_staged) = dq_args + tma_atom_dQ, gdQ_tma_staged, s_epi_dQ, varlen = tma_args dq_handle = mma_dq_consumer.wait_and_advance() tidx, _, _ = cute.arch.thread_idx() bidx, bidy, bidz = cute.arch.block_idx() cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) cute.arch.fence_view_async_shared() + epi_cols_dQ = math.gcd(128 // (self.dq_dtype.width // 8), epi_tile[1]) + num_epi_stages_dQ = epi_tile[1] // epi_cols_dQ + epi_tile_dQ = (epi_tile[0], epi_cols_dQ) + leader_warp = (cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4) == 0 + for iter in cutlass.range(self.iterations_dsk): gdQ = gdQ_staged[None, None, iter] cdQ = cdQ_staged[None, None, iter] @@ -2119,6 +2168,8 @@ def dQ_epilogue( tdQtdQ_epi = cute.zipped_divide(tdQtdQ, epi_tile) cdQ_epi = cute.zipped_divide(cdQ, epi_tile) gdQ_epi = cute.zipped_divide(gdQ, epi_tile) + cdQ_local = cute.make_identity_tensor(epi_tile) + cdQ_local_epi = cute.zipped_divide(cdQ_local, epi_tile) tidx, _, _ = cute.arch.thread_idx() thread_idx = tidx % (self.threads_per_warp * len(self.epilogue_warp_ids)) tmem_copy_atom = cute.make_copy_atom( @@ -2129,17 +2180,58 @@ def dQ_epilogue( tTMEM_LOADtdQ = thr_tmem_load.partition_S(tdQtdQ_epi) tTMEM_LOADgdQ = thr_tmem_load.partition_D(gdQ_epi) tTMEM_LOADcdQ = thr_tmem_load.partition_D(cdQ_epi) - - for i in cutlass.range(cute.size(tTMEM_LOADtdQ, mode=[1]), unroll_full=True): - tTMEM_LOADtdQ_i = tTMEM_LOADtdQ[None, i, 0] - tTMEM_LOADgdQ_i = tTMEM_LOADgdQ[None, i, 0] - tTMEM_LOADcdQ_i = tTMEM_LOADcdQ[None, i, 0] - tTMrdQ = cute.make_rmem_tensor(tTMEM_LOADcdQ[None, 0, i].shape, self.acc_dtype) - cute.copy(tiled_tmem_load, tTMEM_LOADtdQ_i, tTMrdQ) - tSMrdQ = cute.make_rmem_tensor(tTMrdQ.shape, self.q_dtype) - dq_vec = tTMrdQ.load() - tSMrdQ.store(dq_vec.to(self.q_dtype)) - if cute.elem_less(tTMEM_LOADcdQ_i[0][0], seqlen_q): - cute.autovec_copy(tSMrdQ, tTMEM_LOADgdQ_i) + tTMEM_LOADcdQ_local = thr_tmem_load.partition_D(cdQ_local_epi) + + if cutlass.const_expr(not varlen): + gdQ_tma = gdQ_tma_staged[None, None, iter] + gdQ_tma_epi = cute.local_tile(gdQ_tma, epi_tile_dQ, (0, None)) + sdQ_stage = s_epi_dQ[None, None, 0] + + for stage_k in cutlass.range_constexpr(num_epi_stages_dQ): + for i in cutlass.range(cute.size(tTMEM_LOADtdQ, mode=[1]), unroll_full=True): + tTMEM_LOADtdQ_i = tTMEM_LOADtdQ[None, i, 0] + tTMEM_LOADcdQ_i_local = tTMEM_LOADcdQ_local[None, i, 0] + tTMrdQ = cute.make_rmem_tensor(tTMEM_LOADcdQ_i_local.shape, self.acc_dtype) + cute.copy(tiled_tmem_load, tTMEM_LOADtdQ_i, tTMrdQ) + tSMrdQ = cute.make_rmem_tensor(tTMrdQ.shape, self.q_dtype) + dq_vec = tTMrdQ.load() + tSMrdQ.store(dq_vec.to(self.q_dtype)) + for j in cutlass.range_constexpr(cute.size(tTMEM_LOADcdQ_i_local)): + c = tTMEM_LOADcdQ_i_local[j] + m_pos = c[0] + n_pos = c[1] + if n_pos // epi_cols_dQ == stage_k: + s_epi_dQ[m_pos, n_pos % epi_cols_dQ, 0] = tSMrdQ[j] + + cute.arch.fence_view_async_shared() + cute.arch.barrier(barrier_id=3, number_of_threads=128) + + if leader_warp: + gdQ_stage = gdQ_tma_epi[None, None, stage_k] + td_sdQ, td_gdQ = cpasync.tma_partition( + tma_atom_dQ, + 0, + cute.make_layout(1), + cute.group_modes(sdQ_stage, 0, 2), + cute.group_modes(gdQ_stage, 0, 2), + ) + cute.copy(tma_atom_dQ, td_sdQ, td_gdQ) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=True) + # Non-issuing threads must not rotate the SMEM buffer + # until the leader warp's TMA read has completed. + cute.arch.barrier(barrier_id=3, number_of_threads=128) + else: + for i in cutlass.range(cute.size(tTMEM_LOADtdQ, mode=[1]), unroll_full=True): + tTMEM_LOADtdQ_i = tTMEM_LOADtdQ[None, i, 0] + tTMEM_LOADgdQ_i = tTMEM_LOADgdQ[None, i, 0] + tTMEM_LOADcdQ_i = tTMEM_LOADcdQ[None, i, 0] + tTMrdQ = cute.make_rmem_tensor(tTMEM_LOADcdQ[None, 0, i].shape, self.acc_dtype) + cute.copy(tiled_tmem_load, tTMEM_LOADtdQ_i, tTMrdQ) + tSMrdQ = cute.make_rmem_tensor(tTMrdQ.shape, self.q_dtype) + dq_vec = tTMrdQ.load() + tSMrdQ.store(dq_vec.to(self.q_dtype)) + if cute.elem_less(tTMEM_LOADcdQ_i[0][0], seqlen_q): + cute.autovec_copy(tSMrdQ, tTMEM_LOADgdQ_i) dq_handle.release() return mma_dq_consumer From cb213fce11c3baf9168f7fa607bc7f22e3323554 Mon Sep 17 00:00:00 2001 From: Johnsonms Date: Fri, 1 May 2026 10:26:08 -0700 Subject: [PATCH 42/44] [hd256] Add TMA paged KV support to SM100 2CTA forward kernel (#2489) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [hd256] Add TMA paged KV support to SM100 2CTA forward kernel Rebased cherry-pick of `49fe257` from `Johnsonms/paged-kv-hd256` on top of merged main (hd256 PR #2412, `27b4eb9` + post-merge cleanup #2487, `b21e204`). Original branch was based on a pre-merge snapshot; its base commits were absorbed into the squash-merge. ## Change Adds paged KV support to the SM100 hd256 2CTA forward kernel. The paged path reuses the dense TMA load path — logical KV blocks are remapped to physical page indices through the page table at load time, so each page maps to exactly one TMA tile. **Constraint:** `page_size` must equal `tile_n = 128`. ### `flash_attn/cute/sm100_hd256_2cta_fmha_forward.py` - Conditional K/V tensor layout in `__call__`: dense `(s_k, d, ((h_r, h_k), b))` vs paged `(page_size, d, h_k, num_pages)` for K (and transposed for V). - Conditional K/V TMA setup in the load warp: dense uses `domain_offset` + batch indexing; paged uses `head_kv` slicing and keeps `num_pages` as the outer mode for per-load `page_idx` lookup. - Conditional per-load `page_idx`: K uses mode-2 subtile + mode-3 page; V uses mode-1 page. - Plumb `mPageTable` + `max_seqlen_k` through the kernel signature. `seqlen_k` in each of the 4 warp sections now uses `max_seqlen_k` for the paged path. - Store `qhead_per_kvhead` on `self` and derive `head_kv_coord` via integer divide (matches the `flash_fwd_sm100` convention for contiguous GQA grouping). - Relax `mPageTable` / `paged_kv_non_tma` assertions. ### `flash_attn/cute/paged_kv.py` - Extract `_flatten_smem_sm100` / `_copy_row_async` helpers from `load_KV` — pure refactor, no behavior change for existing callers. ### `tests/cute/test_flash_attn.py` - `test_flash_attn_paged_hd256_sm100_tma`: bit-exact vs dense varlen reference + determinism check, parametrized over `seqlen_q`. - `test_flash_attn_paged_hd256_sm100_tma_gqa`: same check for GQA with `nheads_kv in {2, 4, 8}` — exercises `qhead_per_kvhead > 1`, which a modulo-aliasing bug would fail. ## Validation (this PR vs `origin/main` @ `b21e204`) ### Correctness smoke `pytest tests/cute/test_flash_attn.py` on B200, filter combines the 6 new paged tests with the existing d=256 dense subset: ``` -k "paged_hd256_sm100_tma or (test_flash_attn_output and 256-False-0-0.0-False-False)" ``` Result: **84 passed, 78 skipped, 0 failed** in 2 min — 78 from the dense d=256 subset (identical pass/skip count to `origin/main`) and **6 from the new `paged_hd256_sm100_tma[_gqa]` tests**. ### FWD perf delta vs `origin/main` (TFLOPS mean, 3 runs each) B200, bf16, hdim=256, locked clocks @ 1755 MHz. | seqlen | causal | MHA 32:32 | GQA 32:2 | |-------:|:------:|:---------:|:--------:| | 4k | F | +0.2% | +0.3% | | 8k | F | 0.0% | -0.1% | | 16k | F | -0.2% | 0.0% | | 32k | F | -1.0% | **+2.3%** | | 64k | F | **+2.1%** | +0.8% | | 128k | F | -0.8% | -1.7% | | 4k | T | +0.2% | +0.2% | | 8k | T | +0.1% | +0.1% | | 16k | T | +0.3% | +0.1% | | 32k | T | **+3.0%** | -1.2% | | 64k | T | -0.3% | -0.1% | | 128k | T | +0.1% | -0.7% | - **22 of 24 cells within ±2%.** - Two `> 2%` outliers are both **positive** and in the batch- quantization noise zone that `origin/main` itself showed run-spread in during our 3-run baseline sweep — not regressions. - **Aggregated means: MHA +0.31%, GQA +0.00%.** - Paged-KV path isn't exercised by `benchmark_attn.py` (which uses contiguous KV); dense-path perf parity is the regression-critical property and is preserved. ## Caveat - **page_size == tile_n == 128 is a hard constraint.** Callers that want a different page size will need a separate path. - The paged-KV path itself is correctness-tested by the two new `paged_hd256_sm100_tma` tests (bit-exact vs dense reference, with and without GQA). Perf of the paged path was not benchmarked. * [hd256] Address review comments on TMA paged KV - interface.py: assert max_seqlen_k % page_size == 0, page_table sized to exact seqlen, and page_table fully contiguous for hd256 paged path - tests: add shuffled-page-table test; allclose for correctness checks - paged_kv.py: trim _flatten_smem_sm100 docstring to one line - sm100_hd256_2cta_fmha_forward.py: cut multi-line comment blocks * [hd256] Prefetch page indices and eliminate redundant V page reads in TMA paged KV K and V for the same KV block share the same physical page, so the separate mPageTable read issued for V was always fetching the same index already loaded for K. Carry k_page_idx forward as v_page_idx_prev and drop all V-side page-table reads. Additionally, issue the next K page read immediately after K TMA dispatch (while V TMA is being issued) so the ~25-cycle L2 latency is hidden behind in-flight work. Together these changes halve the number of scalar GMEM page-table reads per kernel call. NCU (B=4 S=8192 H=8 D=256): executed instructions −0.4 % L2 elapsed cycles −2.2 % (overhead vs dense: +3.5 % → +1.2 %) Benchmark — paged vs. dense latency overhead GPU 0 locked 1965 MHz, non-causal, bf16, page_size=128: seqlen B before after delta ------ -- ------ ------ ------ 1024 32 +0.2 % +0.4 % −0.2 % 2048 16 +0.4 % +0.4 % 0.0 % 4096 8 −0.1 % +0.2 % −0.3 % 8192 4 +4.9 % +1.8 % −3.1 % 16384 2 +7.7 % +5.2 % −2.5 % 32768 1 +4.9 % +0.4 % −4.5 % 65536 1 +0.9 % −1.8 % −2.7 % No effect at short sequences (TMEM-bound); −2.5 to −4.5 % overhead reduction at medium-to-long sequences where page-table reads were on the producer warp's critical path. --- flash_attn/cute/interface.py | 17 ++ flash_attn/cute/paged_kv.py | 64 +++--- .../cute/sm100_hd256_2cta_fmha_forward.py | 206 +++++++++++++----- tests/cute/test_flash_attn.py | 201 +++++++++++++++++ 4 files changed, 412 insertions(+), 76 deletions(-) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index a11c8debe2b..6441b0cc13c 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -857,6 +857,23 @@ def _flash_attn_fwd( "SM100 forward with head_dim=256 does not support learnable_sink" assert seqused_q is None and seqused_k is None, \ "SM100 forward with head_dim=256 does not support seqused_q/seqused_k" + if page_table is not None: + assert max_seqlen_k % page_size == 0, ( + f"SM100 hd256 2CTA paged KV requires max_seqlen_k divisible by " + f"page_size ({page_size}), got max_seqlen_k={max_seqlen_k}" + ) + assert page_table.shape[1] == max_seqlen_k // page_size, ( + f"SM100 hd256 2CTA paged KV requires page_table.shape[1] == " + f"max_seqlen_k // page_size ({max_seqlen_k} // {page_size} = " + f"{max_seqlen_k // page_size}), got {page_table.shape[1]}; " + f"pass page_table[:, :{max_seqlen_k // page_size}] to slice to " + f"the actual sequence length" + ) + assert page_table.stride(0) == page_table.shape[1], ( + f"SM100 hd256 2CTA paged KV requires a fully contiguous page_table " + f"(stride(0)={page_table.stride(0)} must equal " + f"shape[1]={page_table.shape[1]})" + ) # pack_gqa is an auto-selected optimization; disable it for hd256 kernel pack_gqa = False diff --git a/flash_attn/cute/paged_kv.py b/flash_attn/cute/paged_kv.py index bf11acbc24e..efcf71202f2 100644 --- a/flash_attn/cute/paged_kv.py +++ b/flash_attn/cute/paged_kv.py @@ -169,6 +169,42 @@ def compute_X_ptr(self, K_or_V: str): tPrXPtr[i] = utils.elem_pointer(mX, (page_offset, 0, page)).toint() return tPrXPtr + @cute.jit + def _flatten_smem_sm100(self, sX: cute.Tensor, K_or_V: str): + """Flatten SM100 smem ((a,b), cta_split, k) to (a,(b,k)); transpose V to (d,page_size).""" + sX_pi = cute.make_tensor( + sX.iterator, + cute.make_layout( + (sX.shape[0][0], (sX.shape[0][1], sX.shape[2])), + stride=(sX.stride[0][0], (sX.stride[0][1], sX.stride[2])), + ), + ) + if const_expr(K_or_V == "V"): + sX_pi = cute.make_tensor(sX_pi.iterator, cute.select(sX_pi.layout, mode=[1, 0])) + return sX_pi + + @cute.jit + def _copy_row_async( + self, + tXsX: cute.Tensor, + tXcX: cute.Tensor, + mX_paged_cur_copy: cute.Tensor, + m: Int32, + should_load: cute.Tensor, + ): + """Issue cp.async copies for one row across all k-tiles.""" + for k in cutlass.range_constexpr(cute.size(tXsX, mode=[2])): + ki = tXcX[0, 0, k][1] // self.async_copy_elems + mX_paged_cur_copy_ki = mX_paged_cur_copy[None, ki] + tXsX_k = tXsX[None, m, k] + mX_paged_cur_copy_ki = cute.make_tensor(mX_paged_cur_copy_ki.iterator, tXsX_k.layout) + cute.copy( + self.gmem_tiled_copy_KV, + mX_paged_cur_copy_ki, + tXsX_k, + pred=should_load, + ) + @cute.jit def load_KV(self, n_block: Int32, sX: cute.Tensor, K_or_V: str): assert K_or_V in ("K", "V") @@ -181,18 +217,7 @@ def load_KV(self, n_block: Int32, sX: cute.Tensor, K_or_V: str): sX_pi = cute.group_modes(sX, 0, 1) # SM90 does NOT transpose V here (it's transposed via utils.transpose_view before MMA) else: - # SM100: Finesse sX layout to be (M, N). - sX_pi = cute.make_tensor( - sX.iterator, - cute.make_layout( - (sX.shape[0][0], (sX.shape[0][1], sX.shape[2])), - stride=(sX.stride[0][0], (sX.stride[0][1], sX.stride[2])), - ), - ) - - if const_expr(K_or_V == "V"): - # Transpose smem V to match transposed gmem layout - sX_pi = cute.make_tensor(sX_pi.iterator, cute.select(sX_pi.layout, mode=[1, 0])) + sX_pi = self._flatten_smem_sm100(sX, K_or_V) head_dim = self.head_dim_v_padded if const_expr(K_or_V == "V") else self.head_dim_padded cX = cute.make_identity_tensor((self.n_block_size, head_dim)) @@ -218,17 +243,4 @@ def load_KV(self, n_block: Int32, sX: cute.Tensor, K_or_V: str): ) mX_paged_cur = cute.make_tensor(x_gmem_ptr, cute.make_layout((head_dim,))) mX_paged_cur_copy = cute.tiled_divide(mX_paged_cur, (self.async_copy_elems,)) - - for k in cutlass.range_constexpr(cute.size(tXsX, mode=[2])): - ki = tXcX[0, 0, k][1] // self.async_copy_elems - mX_paged_cur_copy_ki = mX_paged_cur_copy[None, ki] - tXsX_k = tXsX[None, m, k] - mX_paged_cur_copy_ki = cute.make_tensor( - mX_paged_cur_copy_ki.iterator, tXsX_k.layout - ) - cute.copy( - self.gmem_tiled_copy_KV, - mX_paged_cur_copy_ki, - tXsX_k, - pred=should_load, - ) + self._copy_row_async(tXsX, tXcX, mX_paged_cur_copy, m, should_load) diff --git a/flash_attn/cute/sm100_hd256_2cta_fmha_forward.py b/flash_attn/cute/sm100_hd256_2cta_fmha_forward.py index a9463222718..28087125f47 100644 --- a/flash_attn/cute/sm100_hd256_2cta_fmha_forward.py +++ b/flash_attn/cute/sm100_hd256_2cta_fmha_forward.py @@ -62,7 +62,9 @@ def __init__( assert score_mod is None, "SM100 forward with head_dim=256 does not support score_mod" assert mask_mod is None, "SM100 forward with head_dim=256 does not support mask_mod" assert not has_aux_tensors, "SM100 forward with head_dim=256 does not support aux tensors" - assert not paged_kv_non_tma, "SM100 forward with head_dim=256 does not support paged KV" + assert not paged_kv_non_tma, ( + "SM100 hd256 2CTA supports TMA paged KV only (page_size must equal tile_n=128)" + ) assert not pack_gqa, "SM100 forward with head_dim=256 does not support pack_gqa" assert not is_split_kv, "SM100 forward with head_dim=256 does not support SplitKV" assert q_subtile_factor is None, ( @@ -79,6 +81,7 @@ def __init__( mma_tiler = (128, 128, head_dim) self.qk_acc_dtype = qk_acc_dtype self.pv_acc_dtype = pv_acc_dtype + self.qhead_per_kvhead = qhead_per_kvhead self.mma_tiler = mma_tiler assert mma_tiler[0] == 128 and mma_tiler[1] == 128, "Only 128x128 tile impl is supported" assert mma_tiler[2] == 256, "Only 256 is supported for 128x128 tile impl" @@ -184,7 +187,6 @@ def __call__( assert mSeqUsedQ is None and mSeqUsedK is None, ( "SM100 forward with head_dim=256 does not support seqused_q/seqused_k" ) - assert mPageTable is None, "SM100 forward with head_dim=256 does not support paged KV" assert learnable_sink is None, ( "SM100 forward with head_dim=256 does not support learnable_sink" ) @@ -296,18 +298,42 @@ def __call__( stride=(d64 * h_r64 * h_k64, 1, ((d64, d64 * h_r64), stride_b_qo)), ) q = cute.make_tensor(q_tensor.iterator, q_layout) - # (s, d, ((h_r, h_k), b)), 0-stride for h_r to broadcast - k_layout = cute.make_layout( - (s_k_total, d, ((h_r, h_k), b)), - stride=(d64 * h_k64, 1, ((0, d64), stride_b_kv)), - ) - k = cute.make_tensor(k_tensor.iterator, k_layout) - # (d, s, ((h_r, h_k), b)), 0-stride for h_r to broadcast - v_layout = cute.make_layout( - (d, s_k_total, ((h_r, h_k), b)), - stride=(1, d64 * h_k64, ((0, d64), stride_b_kv)), - ) - v = cute.make_tensor(v_tensor.iterator, v_layout) + if cutlass.const_expr(mPageTable is not None): + # Paged: K layout (num_pages, page_size, h_k, d); page_table maps kv_coord→physical page. + num_pages = k_tensor.shape[0] + page_size = k_tensor.shape[1] + page_size64 = Int64(page_size) + max_seqlen_k_paged = Int32(mPageTable.shape[1] * page_size) + k_paged_layout = cute.make_layout( + (page_size, d, h_k, num_pages), + stride=(d64 * h_k64, 1, d64, page_size64 * d64 * h_k64), + ) + k = cute.make_tensor(k_tensor.iterator, k_paged_layout) + v_paged_layout = cute.make_layout( + (d, page_size, h_k, num_pages), + stride=(1, d64 * h_k64, d64, page_size64 * d64 * h_k64), + ) + v = cute.make_tensor(v_tensor.iterator, v_paged_layout) + page_table_layout = cute.make_layout( + (b, mPageTable.shape[1]), + stride=(Int64(mPageTable.shape[1]), 1), + ) + page_table = cute.make_tensor(mPageTable.iterator, page_table_layout) + else: + # (s, d, ((h_r, h_k), b)), 0-stride for h_r to broadcast + k_layout = cute.make_layout( + (s_k_total, d, ((h_r, h_k), b)), + stride=(d64 * h_k64, 1, ((0, d64), stride_b_kv)), + ) + k = cute.make_tensor(k_tensor.iterator, k_layout) + # (d, s, ((h_r, h_k), b)), 0-stride for h_r to broadcast + v_layout = cute.make_layout( + (d, s_k_total, ((h_r, h_k), b)), + stride=(1, d64 * h_k64, ((0, d64), stride_b_kv)), + ) + v = cute.make_tensor(v_tensor.iterator, v_layout) + page_table = None + max_seqlen_k_paged = None # (s, d, ((h_r, h_k), b)) o_layout = cute.make_layout( (s_q_total, d, ((h_r, h_k), b)), @@ -505,6 +531,8 @@ class SharedStorage: scale_softmax_log2, scale_softmax, scale_output, + page_table, + max_seqlen_k_paged, window_size_left, window_size_right, self.cluster_layout_vmnk, @@ -540,6 +568,8 @@ def kernel( scale_softmax_log2: Float32, scale_softmax: Float32, scale_output: Float32, + mPageTable: Optional[cute.Tensor], + max_seqlen_k: Optional[Int32], window_size_left: Optional[Int32], window_size_right: Optional[Int32], cluster_layout_vmnk: cute.Layout, @@ -781,7 +811,9 @@ def kernel( continue_cond = False batch_coord = curr_block_coord[2][1] seqlen_q = mQ_qdl.shape[0] - seqlen_k = mK_kdl.shape[0] + seqlen_k = ( + mK_kdl.shape[0] if cutlass.const_expr(mPageTable is None) else max_seqlen_k + ) cuseqlen_q = Int32(0) cuseqlen_k = Int32(0) block_offset = ( @@ -809,8 +841,6 @@ def kernel( ) if not continue_cond: mQ_qdl_ = cute.domain_offset(cute.select(block_offset, mode=[0, 2, 3]), mQ_qdl) - mK_kdl_ = cute.domain_offset(cute.select(block_offset, mode=[1, 2, 3]), mK_kdl) - mV_dkl_ = cute.domain_offset(cute.select(block_offset, mode=[2, 1, 3]), mV_dkl) # Local tile partition global tensors q_cta_layout = cute.make_layout( cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape @@ -828,31 +858,70 @@ def kernel( kv_cta_layout = cute.make_layout( cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape ) - gK_kdl = cute.flat_divide(mK_kdl_, cute.select(self.qk_mma_tiler, mode=[1, 2])) - tSgK_kdl = qk_thr_mma.partition_B(gK_kdl) - tKsK, tKgK_kdl = cute.nvgpu.cpasync.tma_partition( - tma_atom_k, - block_in_cluster_coord_vmnk[1], - kv_cta_layout, - cute.group_modes(sK, 0, 3), - cute.group_modes(tSgK_kdl, 0, 3), - ) - - gV_dkl = cute.flat_divide(mV_dkl_, cute.select(self.pv_mma_tiler, mode=[1, 2])) - tSgV_dkl = pv_thr_mma.partition_B(gV_dkl) - tVsV, tVgV_dkl = cute.nvgpu.cpasync.tma_partition( - tma_atom_v, - block_in_cluster_coord_vmnk[1], - kv_cta_layout, - cute.group_modes(sV, 0, 3), - cute.group_modes(tSgV_dkl, 0, 3), - ) + if cutlass.const_expr(mPageTable is None): + # Dense path: domain_offset K/V by batch block, select batch via mma_block_coord[2]. + mK_kdl_ = cute.domain_offset( + cute.select(block_offset, mode=[1, 2, 3]), mK_kdl + ) + mV_dkl_ = cute.domain_offset( + cute.select(block_offset, mode=[2, 1, 3]), mV_dkl + ) + gK_kdl = cute.flat_divide( + mK_kdl_, cute.select(self.qk_mma_tiler, mode=[1, 2]) + ) + tSgK_kdl = qk_thr_mma.partition_B(gK_kdl) + tKsK, tKgK_kdl = cute.nvgpu.cpasync.tma_partition( + tma_atom_k, + block_in_cluster_coord_vmnk[1], + kv_cta_layout, + cute.group_modes(sK, 0, 3), + cute.group_modes(tSgK_kdl, 0, 3), + ) + gV_dkl = cute.flat_divide( + mV_dkl_, cute.select(self.pv_mma_tiler, mode=[1, 2]) + ) + tSgV_dkl = pv_thr_mma.partition_B(gV_dkl) + tVsV, tVgV_dkl = cute.nvgpu.cpasync.tma_partition( + tma_atom_v, + block_in_cluster_coord_vmnk[1], + kv_cta_layout, + cute.group_modes(sV, 0, 3), + cute.group_modes(tSgV_dkl, 0, 3), + ) + # ((atom_v, rest_v), RestN, RestK) + tKgK = tKgK_kdl[None, None, None, mma_block_coord[2]] + tVgV = tVgV_dkl[None, None, None, mma_block_coord[2]] + else: + # Paged path: slice K/V by KV head, keep num_pages dim for page_idx-based TMA. + head_kv_coord = curr_block_coord[2][0] // self.qhead_per_kvhead + mK_kdl_ = mK_kdl[None, None, head_kv_coord, None] + mV_dkl_ = mV_dkl[None, None, head_kv_coord, None] + gK_kdl = cute.flat_divide( + mK_kdl_, cute.select(self.qk_mma_tiler, mode=[1, 2]) + ) + tSgK_kdl = qk_thr_mma.partition_B(gK_kdl) + tKsK, tKgK_kdl = cute.nvgpu.cpasync.tma_partition( + tma_atom_k, + block_in_cluster_coord_vmnk[1], + kv_cta_layout, + cute.group_modes(sK, 0, 3), + cute.group_modes(tSgK_kdl, 0, 3), + ) + gV_dkl = cute.flat_divide( + mV_dkl_, cute.select(self.pv_mma_tiler, mode=[1, 2]) + ) + tSgV_dkl = pv_thr_mma.partition_B(gV_dkl) + tVsV, tVgV_dkl = cute.nvgpu.cpasync.tma_partition( + tma_atom_v, + block_in_cluster_coord_vmnk[1], + kv_cta_layout, + cute.group_modes(sV, 0, 3), + cute.group_modes(tSgV_dkl, 0, 3), + ) + tKgK = tKgK_kdl + tVgV = tVgV_dkl # ((atom_v, rest_v), RestK) tQgQ = tQgQ_qdl[None, mma_block_coord[0], None, mma_block_coord[2]] - # ((atom_v, rest_v), RestN, RestK) - tKgK = tKgK_kdl[None, None, None, mma_block_coord[2]] - # ((atom_v, rest_v), RestN, RestK) - tVgV = tVgV_dkl[None, None, None, mma_block_coord[2]] seqlen_kv_loop_start, seqlen_kv_loop_steps = ( FusedMask.get_trip_start_count_via_block_info( @@ -879,42 +948,73 @@ def kernel( # K0 kv_coord = seqlen_kv_loop_start + k_page_idx = ( + mPageTable[batch_coord, kv_coord] + if cutlass.const_expr(mPageTable is not None) + else None + ) for iter in cutlass.range(self.iterations_qk, unroll=1): k_handle = load_kv_producer.acquire_and_advance() cute.copy( tma_atom_k, - tKgK[None, kv_coord, iter], + tKgK[None, kv_coord, iter] + if cutlass.const_expr(mPageTable is None) + else tKgK[None, 0, iter, k_page_idx], tKsK[None, k_handle.index], tma_bar_ptr=k_handle.barrier, ) kv_coord += 1 + # v_page_idx_prev carries K[i-1]'s page index for use as V[i-1]'s page + # (K and V for the same KV block share the same physical page). + # Also serves as the Vend page index when seqlen_kv_loop_steps == 1. + v_page_idx_prev = ( + k_page_idx if cutlass.const_expr(mPageTable is not None) else None + ) + # Prefetch K1 page after K0 TMA dispatch to hide L2 latency. + if cutlass.const_expr(mPageTable is not None): + if seqlen_kv_loop_steps > 1: + k_page_idx = mPageTable[batch_coord, kv_coord] for i in cutlass.range(1, seqlen_kv_loop_steps, 1, unroll=1): - # Ki + # Ki: k_page_idx was prefetched at end of previous iteration + # (or in the prologue for i==1); L2 latency already hidden. for iter in cutlass.range(self.iterations_qk, unroll=1): k_handle = load_kv_producer.acquire_and_advance() cute.copy( tma_atom_k, - tKgK[None, kv_coord, iter], + tKgK[None, kv_coord, iter] + if cutlass.const_expr(mPageTable is None) + else tKgK[None, 0, iter, k_page_idx], tKsK[None, k_handle.index], tma_bar_ptr=k_handle.barrier, ) - # Vi-1 + # Vi-1: reuse v_page_idx_prev (= K[i-1]'s page), no extra GMEM read. for iter in cutlass.range(self.iterations_pv, unroll=1): v_handle = load_kv_producer.acquire_and_advance() cute.copy( tma_atom_v, - tVgV[None, iter, kv_coord - 1], + tVgV[None, iter, kv_coord - 1] + if cutlass.const_expr(mPageTable is None) + else tVgV[None, iter, 0, v_page_idx_prev], tVsV[None, v_handle.index], tma_bar_ptr=v_handle.barrier, ) + v_page_idx_prev = ( + k_page_idx if cutlass.const_expr(mPageTable is not None) else None + ) kv_coord += 1 - # Vend + # Prefetch next K page while V TMA is in flight. + if cutlass.const_expr(mPageTable is not None): + if kv_coord < seqlen_kv_loop_end: + k_page_idx = mPageTable[batch_coord, kv_coord] + # Vend: reuse v_page_idx_prev (= K[end-1]'s page), no extra GMEM read. for iter in cutlass.range(self.iterations_pv, unroll=1): v_handle = load_kv_producer.acquire_and_advance() cute.copy( tma_atom_v, - tVgV[None, iter, seqlen_kv_loop_end - 1], + tVgV[None, iter, seqlen_kv_loop_end - 1] + if cutlass.const_expr(mPageTable is None) + else tVgV[None, iter, 0, v_page_idx_prev], tVsV[None, v_handle.index], tma_bar_ptr=v_handle.barrier, ) @@ -939,7 +1039,9 @@ def kernel( ) continue_cond = False seqlen_q = mQ_qdl.shape[0] - seqlen_k = mK_kdl.shape[0] + seqlen_k = ( + mK_kdl.shape[0] if cutlass.const_expr(mPageTable is None) else max_seqlen_k + ) batch_coord = curr_block_coord[2][1] if cutlass.const_expr(cum_seqlen_q is not None): cuseqlen_q = cum_seqlen_q[batch_coord] @@ -1189,7 +1291,9 @@ def kernel( batch_coord = curr_block_coord[2][1] continue_cond = False seqlen_q = mQ_qdl.shape[0] - seqlen_k = mK_kdl.shape[0] + seqlen_k = ( + mK_kdl.shape[0] if cutlass.const_expr(mPageTable is None) else max_seqlen_k + ) cuseqlen_q = Int32(0) if cutlass.const_expr(cum_seqlen_q is not None): cuseqlen_q = cum_seqlen_q[batch_coord] @@ -1300,7 +1404,9 @@ def kernel( ) batch_coord = curr_block_coord[2][1] seqlen_q = mQ_qdl.shape[0] - seqlen_k = mK_kdl.shape[0] + seqlen_k = ( + mK_kdl.shape[0] if cutlass.const_expr(mPageTable is None) else max_seqlen_k + ) continue_cond = False cuseqlen_q = Int32(0) if cutlass.const_expr(cum_seqlen_q is not None): diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 4446c053b5b..f96412dd7c7 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -1790,6 +1790,207 @@ def test_flash_attn_paged_deepseek(seqlen_q, page_size): assert torch.equal(out, out_ref) +@pytest.mark.parametrize("seqlen_q", [128, 512, 2048]) +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_paged_hd256_sm100_tma(seqlen_q): + """TMA paged KV in the SM100 hd256 2CTA forward kernel. + + Verifies paged KV (page_table + TMA) matches the non-paged varlen reference + and is deterministic across runs. page_size must equal tile_n=128. + """ + if not IS_SM100: + pytest.skip("SM100-specific paged hd256 test") + device = "cuda" + dtype = torch.bfloat16 + d = 256 + batch_size = 2 + nheads = 16 + nheads_kv = 16 + page_size = 128 + assert seqlen_q % page_size == 0 + + torch.random.manual_seed(0) + q = torch.randn(batch_size * seqlen_q, nheads, d, device=device, dtype=dtype) + k = torch.randn(batch_size * seqlen_q, nheads_kv, d, device=device, dtype=dtype) + v = torch.randn(batch_size * seqlen_q, nheads_kv, d, device=device, dtype=dtype) + cu_seqlens_q = torch.arange(0, batch_size + 1, dtype=torch.int32, device=device) * seqlen_q + cu_seqlens_k = cu_seqlens_q.clone() + + # Non-paged reference (varlen). + out_ref, _ = flash_attn_varlen_func( + q, k, v, + cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=seqlen_q, max_seqlen_k=seqlen_q, + ) + + # Repack into paged layout: (total_pages, page_size, nheads_kv, d). + num_pages_per_seq = seqlen_q // page_size + total_pages = batch_size * num_pages_per_seq + k_paged = torch.zeros(total_pages, page_size, nheads_kv, d, device=device, dtype=dtype) + v_paged = torch.zeros(total_pages, page_size, nheads_kv, d, device=device, dtype=dtype) + for b in range(batch_size): + for s in range(seqlen_q): + pi = b * num_pages_per_seq + s // page_size + po = s % page_size + k_paged[pi, po] = k[b * seqlen_q + s] + v_paged[pi, po] = v[b * seqlen_q + s] + page_table = torch.arange(total_pages, dtype=torch.int32, device=device).reshape( + batch_size, num_pages_per_seq + ) + + # Paged via hd256 2CTA TMA paged path — run twice for determinism. + out_paged_0, _ = flash_attn_varlen_func( + q, k_paged, v_paged, + cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=None, + max_seqlen_q=seqlen_q, max_seqlen_k=seqlen_q, + page_table=page_table, + ) + out_paged_1, _ = flash_attn_varlen_func( + q, k_paged, v_paged, + cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=None, + max_seqlen_q=seqlen_q, max_seqlen_k=seqlen_q, + page_table=page_table, + ) + + if is_fake_mode(): + return + + print(f"Paged vs non-paged max diff: {(out_paged_0 - out_ref).abs().max().item()}") + print(f"Paged determinism diff: {(out_paged_1 - out_paged_0).abs().max().item()}") + assert torch.allclose(out_paged_0, out_ref, atol=1e-3, rtol=1e-3), "Paged output does not match non-paged reference" + assert torch.equal(out_paged_1, out_paged_0), "Paged output is not deterministic" + + +@pytest.mark.parametrize("nheads_kv", [2, 4, 8]) +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_paged_hd256_sm100_tma_gqa(nheads_kv): + """TMA paged KV for SM100 hd256 2CTA with GQA (nheads_q > nheads_kv). + + Exercises the head_kv_coord derivation for qhead_per_kvhead > 1 — the MHA + test passes by coincidence since modulo and integer division agree when + qhead_per_kvhead == 1. + """ + if not IS_SM100: + pytest.skip("SM100-specific paged hd256 test") + device = "cuda" + dtype = torch.bfloat16 + d = 256 + batch_size = 2 + nheads = 16 + page_size = 128 + seqlen_q = 512 + assert nheads % nheads_kv == 0 and seqlen_q % page_size == 0 + + torch.random.manual_seed(0) + q = torch.randn(batch_size * seqlen_q, nheads, d, device=device, dtype=dtype) + k = torch.randn(batch_size * seqlen_q, nheads_kv, d, device=device, dtype=dtype) + v = torch.randn(batch_size * seqlen_q, nheads_kv, d, device=device, dtype=dtype) + cu_seqlens_q = torch.arange(0, batch_size + 1, dtype=torch.int32, device=device) * seqlen_q + cu_seqlens_k = cu_seqlens_q.clone() + + out_ref, _ = flash_attn_varlen_func( + q, k, v, + cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=seqlen_q, max_seqlen_k=seqlen_q, + ) + + num_pages_per_seq = seqlen_q // page_size + total_pages = batch_size * num_pages_per_seq + k_paged = torch.zeros(total_pages, page_size, nheads_kv, d, device=device, dtype=dtype) + v_paged = torch.zeros(total_pages, page_size, nheads_kv, d, device=device, dtype=dtype) + for b in range(batch_size): + for s in range(seqlen_q): + pi = b * num_pages_per_seq + s // page_size + po = s % page_size + k_paged[pi, po] = k[b * seqlen_q + s] + v_paged[pi, po] = v[b * seqlen_q + s] + page_table = torch.arange(total_pages, dtype=torch.int32, device=device).reshape( + batch_size, num_pages_per_seq + ) + + out_paged, _ = flash_attn_varlen_func( + q, k_paged, v_paged, + cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=None, + max_seqlen_q=seqlen_q, max_seqlen_k=seqlen_q, + page_table=page_table, + ) + + if is_fake_mode(): + return + + print(f"GQA nheads_kv={nheads_kv} paged vs non-paged max diff: {(out_paged - out_ref).abs().max().item()}") + assert torch.allclose(out_paged, out_ref, atol=1e-3, rtol=1e-3), ( + f"Paged GQA output does not match non-paged reference (nheads_kv={nheads_kv})" + ) + + +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_paged_hd256_sm100_tma_shuffled(): + """TMA paged KV for SM100 hd256 2CTA with a non-identity (shuffled) page_table. + + An identity page_table passes even if the kernel ignores it. This test + shuffles physical pages so a kernel that bypasses page_table would silently + read wrong data, proving the remapping path is exercised. + """ + if not IS_SM100: + pytest.skip("SM100-specific paged hd256 test") + device = "cuda" + dtype = torch.bfloat16 + d = 256 + batch_size = 2 + nheads = 16 + nheads_kv = 16 + page_size = 128 + seqlen_q = 512 + num_pages_per_seq = seqlen_q // page_size + total_pages = batch_size * num_pages_per_seq + + torch.random.manual_seed(42) + q = torch.randn(batch_size * seqlen_q, nheads, d, device=device, dtype=dtype) + k = torch.randn(batch_size * seqlen_q, nheads_kv, d, device=device, dtype=dtype) + v = torch.randn(batch_size * seqlen_q, nheads_kv, d, device=device, dtype=dtype) + cu_seqlens_q = torch.arange(0, batch_size + 1, dtype=torch.int32, device=device) * seqlen_q + cu_seqlens_k = cu_seqlens_q.clone() + + out_ref, _ = flash_attn_varlen_func( + q, k, v, + cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=seqlen_q, max_seqlen_k=seqlen_q, + ) + + # Shuffle physical pages: reverse order within each batch item. + # Build as Python list of ints to avoid .item() calls on FakeTensors during compilation. + perm = [ + list(range((b + 1) * num_pages_per_seq - 1, b * num_pages_per_seq - 1, -1)) + for b in range(batch_size) + ] + page_table = torch.tensor(perm, dtype=torch.int32, device=device) + + k_paged = torch.zeros(total_pages, page_size, nheads_kv, d, device=device, dtype=dtype) + v_paged = torch.zeros(total_pages, page_size, nheads_kv, d, device=device, dtype=dtype) + for b in range(batch_size): + for s in range(seqlen_q): + phys = perm[b][s // page_size] # Python int, safe in FakeTensorMode + po = s % page_size + k_paged[phys, po] = k[b * seqlen_q + s] + v_paged[phys, po] = v[b * seqlen_q + s] + + out_paged, _ = flash_attn_varlen_func( + q, k_paged, v_paged, + cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=None, + max_seqlen_q=seqlen_q, max_seqlen_k=seqlen_q, + page_table=page_table, + ) + + if is_fake_mode(): + return + + print(f"Shuffled paged vs non-paged max diff: {(out_paged - out_ref).abs().max().item()}") + assert torch.allclose(out_paged, out_ref, atol=1e-3, rtol=1e-3), ( + "Shuffled paged output does not match non-paged reference" + ) + + @pytest.mark.parametrize("head_dim", [4, 148, 288]) def test_flash_attn_invalid_head_dim(head_dim): device = "cuda" From 2e53092aa70fccd3f04013a01a52dc20c619e62b Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Mon, 4 May 2026 15:31:03 -0700 Subject: [PATCH 43/44] Deterministic backward for blocksparse impl (#2253) --- flash_attn/cute/block_info.py | 17 + flash_attn/cute/block_sparse_utils.py | 22 +- flash_attn/cute/block_sparsity.py | 213 ++++++++- flash_attn/cute/cache_utils.py | 1 - flash_attn/cute/compute_block_sparsity.py | 11 +- flash_attn/cute/flash_bwd_sm100.py | 171 ++++--- flash_attn/cute/flash_fwd_sm100.py | 1 - flash_attn/cute/interface.py | 52 ++- tests/cute/test_mask_mod.py | 525 +++++++++++++++++++++- 9 files changed, 913 insertions(+), 100 deletions(-) diff --git a/flash_attn/cute/block_info.py b/flash_attn/cute/block_info.py index f21013891b4..422da2b66a0 100644 --- a/flash_attn/cute/block_info.py +++ b/flash_attn/cute/block_info.py @@ -137,3 +137,20 @@ def get_n_block_min_before_local_mask( n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q n_idx_left = n_idx - self.window_size_left return cutlass.max(n_block_min, cute.ceil_div(n_idx_left, self.tile_n)) + + @cute.jit + def get_n_block_max_for_m_block( + self, + seqlen_info: SeqlenInfoQK, + m_block: Int32, + n_block_global_max: Int32, + ) -> Int32: + if const_expr(self.is_causal or self.window_size_right is not None): + m_idx_max = (m_block + 1) * self.tile_m + if const_expr(self.qhead_per_kvhead_packgqa > 1): + m_idx_max = cute.ceil_div(m_idx_max, self.qhead_per_kvhead_packgqa) + n_idx_right = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q + if const_expr(self.window_size_right is not None): + n_idx_right += self.window_size_right + return min(n_block_global_max, cute.ceil_div(n_idx_right, self.tile_n)) + return n_block_global_max diff --git a/flash_attn/cute/block_sparse_utils.py b/flash_attn/cute/block_sparse_utils.py index 0f81f863673..d664b16dc64 100644 --- a/flash_attn/cute/block_sparse_utils.py +++ b/flash_attn/cute/block_sparse_utils.py @@ -188,7 +188,7 @@ def produce_block_sparse_loads( must be converted to unpacked for sparse tensor indexing. """ - mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = blocksparse_tensors m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor) @@ -332,7 +332,7 @@ def consume_block_sparse_loads( must be converted to unpacked for sparse tensor indexing. """ - mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = blocksparse_tensors m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor) @@ -552,7 +552,7 @@ def produce_block_sparse_loads_sm100( """ m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor) - mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = blocksparse_tensors curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse] curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None] @@ -629,7 +629,7 @@ def get_total_block_count( ): m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor) - mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = blocksparse_tensors if const_expr(full_block_cnt is not None): return ( mask_block_cnt[batch_idx, head_idx, m_block_sparse] @@ -780,7 +780,7 @@ def softmax_block_sparse_sm100( warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor) - mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = blocksparse_tensors curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse] curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None] @@ -795,8 +795,6 @@ def softmax_block_sparse_sm100( total_block_cnt = curr_mask_block_cnt + curr_full_block_cnt if total_block_cnt == 0: - # See NOTE [SM100 block-sparse empty tiles: mbarrier contract]. - # pipeline_sm_stats.producer_commit_w_index(stage_idx) sm_stats_barrier.arrive_w_index(index=stage_idx * 4 + warp_idx) else: if curr_mask_block_cnt > 0: @@ -907,7 +905,7 @@ def get_total_q_block_count_bwd( m_block_max: int = 0, ): """Count total tile iterations for given n_block (KV tile) in backward.""" - q_block_cnt, _, full_block_cnt, _ = blocksparse_tensors + q_block_cnt, _, full_block_cnt, _, *_ = blocksparse_tensors total = q_block_cnt[batch_idx, head_idx, n_block] if const_expr(full_block_cnt is not None): total = total + full_block_cnt[batch_idx, head_idx, n_block] @@ -1051,7 +1049,7 @@ def get_block_sparse_iteration_info_bwd( Returns (curr_q_cnt, curr_q_idx, curr_full_cnt, curr_full_idx, total_count). """ - q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors + q_cnt, q_idx, full_cnt, full_idx, *_ = blocksparse_tensors curr_q_cnt = q_cnt[batch_idx, head_idx, n_block] curr_q_idx = q_idx[batch_idx, head_idx, n_block, None] @@ -1175,7 +1173,7 @@ def produce_block_sparse_q_loads_bwd_sm90( Returns updated (producer_state_Q, producer_state_dO). """ - q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors + q_cnt, q_idx, full_cnt, full_idx, *_ = blocksparse_tensors curr_q_cnt = q_cnt[batch_idx, head_idx, n_block] curr_q_idx = q_idx[batch_idx, head_idx, n_block, None] @@ -1270,7 +1268,7 @@ def consume_block_sparse_mma_bwd_sm90( Returns updated (consumer_state_Q, consumer_state_dO). """ - q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors + q_cnt, q_idx, full_cnt, full_idx, *_ = blocksparse_tensors curr_q_cnt = q_cnt[batch_idx, head_idx, n_block] curr_q_idx = q_idx[batch_idx, head_idx, n_block, None] @@ -1396,7 +1394,7 @@ def dQaccum_store_block_sparse_bwd_sm90( Iterates partial blocks first, then full blocks, matching producer/consumer order. """ - q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors + q_cnt, q_idx, full_cnt, full_idx, *_ = blocksparse_tensors curr_q_cnt = q_cnt[batch_idx, head_idx, n_block] curr_q_idx = q_idx[batch_idx, head_idx, n_block, None] diff --git a/flash_attn/cute/block_sparsity.py b/flash_attn/cute/block_sparsity.py index 3fad8c9f491..4a5726b7493 100644 --- a/flash_attn/cute/block_sparsity.py +++ b/flash_attn/cute/block_sparsity.py @@ -19,9 +19,13 @@ class BlockSparseTensors(NamedTuple): mask_block_idx: cute.Tensor full_block_cnt: cute.Tensor | None full_block_idx: cute.Tensor | None + dq_write_order: cute.Tensor | None = None + dq_write_order_full: cute.Tensor | None = None def __new_from_mlir_values__(self, values): if len(values) == 2: + values = (*values, None, None, None, None) + elif len(values) == 4: values = (*values, None, None) return BlockSparseTensors(*values) @@ -32,6 +36,138 @@ class BlockSparseTensorsTorch(NamedTuple): full_block_cnt: torch.Tensor | None = None full_block_idx: torch.Tensor | None = None block_size: tuple[int, int] | None = None + dq_write_order: torch.Tensor | None = None + dq_write_order_full: torch.Tensor | None = None + spt: bool | None = None + + +def _ordered_to_dense_simple( + num_blocks: torch.Tensor, + indices: torch.Tensor, + num_cols: int, +) -> torch.Tensor: + """Convert ordered sparse representation to dense binary matrix. + + Args: + num_blocks: [B, H, num_rows] count of valid entries per row + indices: [B, H, num_rows, max_entries] column indices (valid entries packed left) + num_cols: total number of columns + + Returns: + dense: [B, H, num_rows, num_cols] binary int32 matrix + """ + B, H, num_rows, max_entries = indices.shape + device = indices.device + dense = torch.zeros(B, H, num_rows, num_cols + 1, dtype=torch.int32, device=device) + col_range = torch.arange(max_entries, device=device) + valid = col_range[None, None, None, :] < num_blocks[:, :, :, None] + safe_indices = torch.where(valid, indices.long(), num_cols) + row_idx = torch.arange(num_rows, device=device)[None, None, :, None].expand_as(indices) + b_idx = torch.arange(B, device=device)[:, None, None, None].expand_as(indices) + h_idx = torch.arange(H, device=device)[None, :, None, None].expand_as(indices) + dense[b_idx, h_idx, row_idx, safe_indices] = 1 + return dense[:, :, :, :num_cols] + + +def compute_dq_write_order( + fwd_mask_cnt: torch.Tensor, + fwd_mask_idx: torch.Tensor, + fwd_full_cnt: torch.Tensor | None, + fwd_full_idx: torch.Tensor | None, + bwd_mask_cnt: torch.Tensor, + bwd_mask_idx: torch.Tensor, + bwd_full_cnt: torch.Tensor | None, + bwd_full_idx: torch.Tensor | None, + spt: bool = False, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """Compute dQ write-order metadata for deterministic block-sparse backward. + + For each (n_block, i) in the backward iteration, computes the semaphore + lock value: the rank of n_block in the combined (partial + full) sorted + contributor list for the target m_block. + + Lock values are assigned in ascending n_block order (or descending if spt=True) + to guarantee deadlock-freedom with the CTA scheduling order. + + Args: + fwd_mask_cnt: [B, H, num_m_blocks] partial contributor counts per m_block + fwd_mask_idx: [B, H, num_m_blocks, max_kv] partial contributor n_block indices (ascending) + fwd_full_cnt: [B, H, num_m_blocks] full contributor counts per m_block (optional) + fwd_full_idx: [B, H, num_m_blocks, max_kv] full contributor n_block indices (optional) + bwd_mask_cnt: [B, H, num_n_blocks] partial iteration counts per n_block + bwd_mask_idx: [B, H, num_n_blocks, max_q] partial iteration m_block indices + bwd_full_cnt: [B, H, num_n_blocks] full iteration counts per n_block (optional) + bwd_full_idx: [B, H, num_n_blocks, max_q] full iteration m_block indices (optional) + spt: if True, reverse ordering (highest n_block gets lock_value=0) + + Returns: + (dq_write_order, dq_write_order_full): tensors parallel to bwd_mask_idx + and bwd_full_idx respectively, containing lock values. + """ + device = fwd_mask_idx.device + B, H, num_m, max_kv_partial = fwd_mask_idx.shape + _, _, num_n, max_q_partial = bwd_mask_idx.shape + + has_full = fwd_full_cnt is not None and fwd_full_idx is not None + + dense_partial = _ordered_to_dense_simple(fwd_mask_cnt, fwd_mask_idx, num_n) + if has_full: + dense_full = _ordered_to_dense_simple(fwd_full_cnt, fwd_full_idx, num_n) + dense = (dense_partial + dense_full).clamp(max=1) + else: + dense = dense_partial + + cumsum = dense.cumsum(dim=-1) + rank_table = (cumsum - dense).to(torch.int32) + + if spt: + total_per_m = cumsum[:, :, :, -1:] + rank_table = (total_per_m - 1 - rank_table).to(torch.int32) + + def _gather_write_order(bwd_idx, bwd_cnt): + b_i = torch.arange(B, device=device)[:, None, None, None].expand_as(bwd_idx) + h_i = torch.arange(H, device=device)[None, :, None, None].expand_as(bwd_idx) + n_i = torch.arange(bwd_idx.shape[2], device=device)[None, None, :, None].expand_as(bwd_idx) + m_vals = bwd_idx.long().clamp(0, num_m - 1) + return rank_table[b_i, h_i, m_vals, n_i].to(torch.int32) + + dq_write_order = _gather_write_order(bwd_mask_idx, bwd_mask_cnt) + + dq_write_order_full = None + if has_full and bwd_full_cnt is not None and bwd_full_idx is not None: + dq_write_order_full = _gather_write_order(bwd_full_idx, bwd_full_cnt) + + return dq_write_order, dq_write_order_full + + +def compute_dq_write_order_from_block_mask( + block_mask, + spt: bool = False, +) -> tuple[torch.Tensor, torch.Tensor | None]: + ( + _seq_q, + _seq_k, + kv_mask_cnt, + kv_mask_idx, + full_kv_cnt, + full_kv_idx, + q_mask_cnt, + q_mask_idx, + full_q_cnt, + full_q_idx, + *_, + ) = block_mask.as_tuple() + return compute_dq_write_order( + kv_mask_cnt, + kv_mask_idx, + full_kv_cnt, + full_kv_idx, + q_mask_cnt, + q_mask_idx, + full_q_cnt, + full_q_idx, + spt=spt, + ) def get_sparse_q_block_size( @@ -110,6 +246,25 @@ def _check_and_expand_block( return expanded_cnt, expanded_idx +def _check_and_expand_metadata_tensor( + name: str, + tensor: torch.Tensor | None, + expected_shape: Tuple[int, ...], + context: str | None, + hint: str | Callable[[], str] | None, + device: torch.device, +) -> torch.Tensor | None: + if tensor is None: + return None + if tensor.dtype != torch.int32: + raise ValueError(f"{name} must have dtype torch.int32") + if tensor.device != device: + raise ValueError(f"{name} must be on the same device as block sparse tensors") + if not tensor.is_cuda: + raise ValueError(f"{name} must live on CUDA") + return _expand_sparsity_tensor(tensor, expected_shape, name, context, hint) + + def get_block_sparse_expected_shapes( batch_size: int, num_head: int, @@ -279,12 +434,37 @@ def normalize_block_sparse_tensors( if full_cnt is not None and mask_cnt.device != full_cnt.device: raise ValueError("All block sparse tensors must be on the same device") + dq_write_order = _check_and_expand_metadata_tensor( + "dq_write_order", + tensors.dq_write_order, + tuple(mask_idx.shape), + context, + hint, + mask_cnt.device, + ) + dq_write_order_full = _check_and_expand_metadata_tensor( + "dq_write_order_full", + tensors.dq_write_order_full, + tuple(full_idx.shape) if full_idx is not None else expected_index_shape, + context, + hint, + mask_cnt.device, + ) + spt = tensors.spt + if spt is not None and not isinstance(spt, bool): + raise ValueError("spt must be a bool when provided") + if spt is not None and dq_write_order is None: + raise ValueError("spt requires dq_write_order to be provided") + return BlockSparseTensorsTorch( mask_block_cnt=mask_cnt, mask_block_idx=mask_idx, full_block_cnt=full_cnt, full_block_idx=full_idx, block_size=tensors.block_size, + dq_write_order=dq_write_order, + dq_write_order_full=dq_write_order_full, + spt=spt, ) @@ -316,6 +496,8 @@ def get_block_sparse_broadcast_pattern( tensors.mask_block_idx, tensors.full_block_cnt, tensors.full_block_idx, + tensors.dq_write_order, + tensors.dq_write_order_full, ): if tensor is not None: patterns.append(get_broadcast_dims(tensor)) @@ -423,30 +605,21 @@ def to_cute_block_sparse_tensors( """Convert torch block sparsity tensors to CuTe tensors, optionally for tvm ffi""" if not is_block_sparsity_enabled(tensors): return None - - ( - mask_block_cnt, - mask_block_idx, - full_block_cnt, - full_block_idx, - *_, - ) = tensors - - ( - mask_block_cnt_tensor, - mask_block_idx_tensor, - ) = [ + mask_block_cnt_tensor, mask_block_idx_tensor = [ to_cute_tensor(t, assumed_align=4, leading_dim=-1, enable_tvm_ffi=enable_tvm_ffi) - for t in (mask_block_cnt, mask_block_idx) + for t in (tensors.mask_block_cnt, tensors.mask_block_idx) ] - ( - full_block_cnt_tensor, - full_block_idx_tensor, - ) = [ + full_block_cnt_tensor, full_block_idx_tensor = [ + to_cute_tensor(t, assumed_align=4, leading_dim=-1, enable_tvm_ffi=enable_tvm_ffi) + if t is not None + else None + for t in (tensors.full_block_cnt, tensors.full_block_idx) + ] + dq_write_order_tensor, dq_write_order_full_tensor = [ to_cute_tensor(t, assumed_align=4, leading_dim=-1, enable_tvm_ffi=enable_tvm_ffi) if t is not None else None - for t in (full_block_cnt, full_block_idx) + for t in (tensors.dq_write_order, tensors.dq_write_order_full) ] return BlockSparseTensors( @@ -454,6 +627,8 @@ def to_cute_block_sparse_tensors( mask_block_idx_tensor, full_block_cnt_tensor, full_block_idx_tensor, + dq_write_order_tensor, + dq_write_order_full_tensor, ) diff --git a/flash_attn/cute/cache_utils.py b/flash_attn/cute/cache_utils.py index f1b59700448..658a8d5b656 100644 --- a/flash_attn/cute/cache_utils.py +++ b/flash_attn/cute/cache_utils.py @@ -30,7 +30,6 @@ CompileKeyType: TypeAlias = tuple[Hashable, ...] CallableFunction: TypeAlias = JitCompiledFunction | tvm_ffi.Function - # Enable cache via `FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1` CUTE_DSL_CACHE_ENABLED: bool = os.getenv("FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED", "0") == "1" diff --git a/flash_attn/cute/compute_block_sparsity.py b/flash_attn/cute/compute_block_sparsity.py index a2dd98e41d2..69e8309a028 100644 --- a/flash_attn/cute/compute_block_sparsity.py +++ b/flash_attn/cute/compute_block_sparsity.py @@ -54,7 +54,7 @@ def __call__( seqlen_k: Int32, aux_tensors: Optional[list] = None, ): - self.mask_cnt, self.mask_idx, self.full_cnt, self.full_idx = blocksparse_tensors + self.mask_cnt, self.mask_idx, self.full_cnt, self.full_idx, *_ = blocksparse_tensors if const_expr(self.compute_full_blocks): assert self.full_cnt is not None and self.full_idx is not None, ( @@ -366,7 +366,14 @@ def compute_block_sparsity( ) compute_block_sparsity.compile_cache[compile_key]( - blocksparse_tensors_torch[:4], + ( + blocksparse_tensors_torch.mask_block_cnt, + blocksparse_tensors_torch.mask_block_idx, + blocksparse_tensors_torch.full_block_cnt, + blocksparse_tensors_torch.full_block_idx, + blocksparse_tensors_torch.dq_write_order, + blocksparse_tensors_torch.dq_write_order_full, + ), seqlen_q, seqlen_k, aux_tensors, diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 4b4083eda9e..9184ddeb029 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -58,6 +58,7 @@ def __init__( tile_n: int = 128, is_persistent: bool = False, deterministic: bool = False, + spt: Optional[bool] = None, cluster_size: int = 1, use_2cta_instrs: bool = False, score_mod: cutlass.Constexpr | None = None, @@ -116,6 +117,7 @@ def __init__( self.qhead_per_kvhead = qhead_per_kvhead self.pack_gqa = False self.deterministic = deterministic + self.spt_override = spt # Score mod and mask mod support self.score_mod = score_mod @@ -705,7 +707,11 @@ def __call__( TileScheduler = SingleTileLPTBwdScheduler else: TileScheduler = SingleTileScheduler - self.spt = (self.is_causal or self.is_local) and self.deterministic + if const_expr(self.spt_override is None): + self.spt = (self.is_causal or self.is_local) and self.deterministic + else: + assert self.spt_override is not None + self.spt = self.spt_override and self.deterministic tile_sched_args = TileSchedulerArguments( cute.ceil_div(cute.size(mK.shape[0]), self.cta_tiler[0]), # num_blocks cute.size(mQ.shape[2]), # num_heads = num_query_heads @@ -2989,8 +2995,17 @@ def compute_loop( # prefetch_LSE = not self.is_causal prefetch_LSE = False - # some tiles might be empty due to block sparsity + + curr_q_cnt = Int32(0) + curr_q_idx = None + curr_full_cnt = Int32(0) + curr_full_idx = None + loop_count = m_block_max - m_block_min + process_tile = ( + const_expr(not self.is_local and not self.is_varlen_q) or m_block_min < m_block_max + ) if const_expr(self.use_block_sparsity): + assert blocksparse_tensors is not None ( curr_q_cnt, curr_q_idx, @@ -3006,17 +3021,14 @@ def compute_loop( m_block_max=m_block_max, ) process_tile = loop_count > Int32(0) - else: - process_tile = ( - const_expr(not self.is_local and not self.is_varlen_q) - or m_block_min < m_block_max - ) - loop_count = m_block_max - m_block_min # Mainloop # Block sparsity: iterate over sparse m_block count and derive actual m_block # from Q_IDX/FULL_Q_IDX tensors. Dense: iterate m_block_min..m_block_max directly. for iter_idx in cutlass.range(loop_count, unroll=1): + m_block = m_block_min + iter_idx + m_block_oob = False + is_full_block = False if const_expr(self.use_block_sparsity): m_block, is_full_block = get_m_block_from_iter_bwd( iter_idx, @@ -3028,10 +3040,6 @@ def compute_loop( m_block_max=m_block_max, ) m_block_oob = m_block >= m_block_max - else: - m_block = m_block_min + iter_idx - m_block_oob = False - is_full_block = False # Prefetch 1 stage of LSE pipeline_LSE.consumer_wait(consumer_state_LSE) tSrLSE_s2r = cute.make_fragment(tScS_t2r[None, 0, 0, 0].shape, Float32) @@ -3412,6 +3420,38 @@ def compute_loop( tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() + @cute.jit + def _dq_semaphore_lock_value( + self, + iter_idx: Int32, + curr_q_cnt: Int32, + curr_dq_write_order: Optional[cute.Tensor], + curr_dq_write_order_full: Optional[cute.Tensor], + blocksparse_tensors: Optional[BlockSparseTensors], + block_info: BlockInfo, + seqlen, + m_block: Int32, + n_block: Int32, + n_block_global_max: Int32, + ) -> Int32: + lock_value = n_block + if const_expr(self.spt): + n_block_max_for_m_block = block_info.get_n_block_max_for_m_block( + seqlen, m_block, n_block_global_max + ) + lock_value = n_block_max_for_m_block - 1 - n_block + if const_expr(self.use_block_sparsity): + assert blocksparse_tensors is not None + if const_expr(blocksparse_tensors.dq_write_order is not None): + sparse_iter = iter_idx // self.subtile_factor + if sparse_iter < curr_q_cnt: + assert curr_dq_write_order is not None + lock_value = curr_dq_write_order[sparse_iter] + else: + assert curr_dq_write_order_full is not None + lock_value = curr_dq_write_order_full[sparse_iter - curr_q_cnt] + return lock_value + @cute.jit def dQacc_reduce( self, @@ -3484,13 +3524,24 @@ def dQacc_reduce( ) if const_expr(self.deterministic): + assert mdQ_semaphore is not None mdQ_semaphore_cur = mdQ_semaphore[None, None, head_idx, batch_idx] - # delay_semaphore_release = self.is_causal and not self.tile_hdim == 192 - delay_semaphore_release = not self.tile_hdim == 192 + delay_semaphore_release = not self.tile_hdim == 192 and not self.use_block_sparsity + n_block_global_max = cute.ceil_div(seqlen.seqlen_k, self.tile_n) - # some tiles might be empty due to block sparsity + curr_q_cnt = Int32(0) + curr_q_idx = None + curr_full_cnt = Int32(0) + curr_full_idx = None + curr_dq_write_order = None + curr_dq_write_order_full = None + loop_count = m_block_max - m_block_min + process_tile = ( + const_expr(not self.is_local and not self.is_varlen_q) or m_block_min < m_block_max + ) if const_expr(self.use_block_sparsity): + assert blocksparse_tensors is not None ( curr_q_cnt, curr_q_idx, @@ -3506,17 +3557,25 @@ def dQacc_reduce( m_block_max=m_block_max, ) process_tile = loop_count > Int32(0) - else: - process_tile = ( - const_expr(not self.is_local and not self.is_varlen_q) - or m_block_min < m_block_max - ) - loop_count = m_block_max - m_block_min + if const_expr(self.deterministic and self.use_block_sparsity): + assert blocksparse_tensors is not None + if const_expr(blocksparse_tensors.dq_write_order is not None): + assert blocksparse_tensors.dq_write_order is not None + curr_dq_write_order = blocksparse_tensors.dq_write_order[ + batch_idx, head_idx, n_block, None + ] + if const_expr(blocksparse_tensors.dq_write_order_full is not None): + assert blocksparse_tensors.dq_write_order_full is not None + curr_dq_write_order_full = blocksparse_tensors.dq_write_order_full[ + batch_idx, head_idx, n_block, None + ] # dQacc_reduce mainloop # Block sparsity: iterate over sparse m_block count and derive actual m_block # from Q_IDX/FULL_Q_IDX tensors. Dense: iterate m_block_min..m_block_max directly. for iter_idx in cutlass.range(loop_count, unroll=1): + m_block = m_block_min + iter_idx + m_block_oob_upper = False if const_expr(self.use_block_sparsity): m_block, _ = get_m_block_from_iter_bwd( iter_idx, @@ -3527,10 +3586,7 @@ def dQacc_reduce( subtile_factor=self.subtile_factor, m_block_max=m_block_max, ) - if m_block_max > 0: - m_block = cutlass.min(m_block, m_block_max - 1) - else: - m_block = m_block_min + iter_idx + m_block_oob_upper = m_block >= m_block_max pipeline_dQ.consumer_wait(dQ_consumer_state) # TMEM -> RMEM tdQrdQ_t2r = cute.make_fragment(tdQrdQ_t2r_shape, Float32) @@ -3541,6 +3597,8 @@ def dQacc_reduce( pipeline_dQ.consumer_release(dQ_consumer_state) dQ_consumer_state.advance() + if m_block_max > 0: + m_block = cutlass.min(m_block, m_block_max - 1) gdQaccum_cur = gdQaccum[None, None, m_block] tdQrdQ_shape = ( @@ -3558,22 +3616,28 @@ def dQacc_reduce( cute.arch.fence_view_async_shared() # semaphore acquire if const_expr(self.deterministic and stage == 0): - if const_expr(self.spt): - _, n_block_max_for_m_block = block_info.get_n_block_min_max( - seqlen, m_block + if not m_block_oob_upper: + lock_value = self._dq_semaphore_lock_value( + iter_idx, + curr_q_cnt, + curr_dq_write_order, + curr_dq_write_order_full, + blocksparse_tensors, + block_info, + seqlen, + m_block, + n_block_cta_group, + n_block_global_max, + ) + barrier.wait_eq( + mdQ_semaphore_cur[(m_block, None)].iterator, + tidx, + cta_rank_in_cluster, + lock_value, ) - lock_value = n_block_max_for_m_block - 1 - n_block_cta_group - else: - lock_value = n_block_cta_group - barrier.wait_eq( - mdQ_semaphore_cur[(m_block, None)].iterator, - tidx, - cta_rank_in_cluster, - lock_value, - ) self.reduce_sync_barrier.arrive_and_wait() # Copy from shared memory to global memory - if is_tma_warp: + if is_tma_warp and not m_block_oob_upper: with cute.arch.elect_one(): copy_utils.cpasync_reduce_bulk_add_f32( sdQaccum[None, smem_idx].iterator, @@ -3582,20 +3646,12 @@ def dQacc_reduce( ) cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(self.sdQaccum_stage - 1, read=read_flag) + elif is_tma_warp: + # Drain pending TMA stores so SMEM buffers are safe to reuse + cute.arch.cp_async_bulk_wait_group(0, read=read_flag) self.reduce_sync_barrier.arrive_and_wait() dQ_tma_store_producer_state.advance() - # Directly add to gmem, much slower - # tdQgdQ = thr_copy_dQaccum_r2s.partition_D(gdQaccum[None, stage, m_block]) - # assert cute.size(tdQrdQ_r2s) == cute.size(tdQgdQ) - # for i in cutlass.range(cute.size(tdQrdQ_r2s) // 4, unroll_full=True): - # copy_utils.atomic_add_fp32x4( - # tdQrdQ_r2s[4 * i], - # tdQrdQ_r2s[4 * i + 1], - # tdQrdQ_r2s[4 * i + 2], - # tdQrdQ_r2s[4 * i + 3], - # utils.elem_pointer(tdQgdQ, 4 * i), - # ) - # semaphore release for prior m_block + if const_expr(self.deterministic and stage == 0 and delay_semaphore_release): if m_block > m_block_min: barrier.arrive_inc( @@ -3617,12 +3673,13 @@ def dQacc_reduce( # NOTE: arrive_inc calls red_release which issues membar if const_expr(self.deterministic and not delay_semaphore_release): if const_expr(self.sdQaccum_stage > 1 and not self.tile_hdim == 192): - if is_tma_warp: + if is_tma_warp and not m_block_oob_upper: cute.arch.cp_async_bulk_wait_group(0, read=read_flag) self.reduce_sync_barrier.arrive_and_wait() - barrier.arrive_inc( - mdQ_semaphore_cur[m_block, None].iterator, tidx, cta_rank_in_cluster, 1 - ) + if not m_block_oob_upper: + barrier.arrive_inc( + mdQ_semaphore_cur[m_block, None].iterator, tidx, cta_rank_in_cluster, 1 + ) if process_tile: if is_tma_warp: @@ -3638,7 +3695,10 @@ def dQacc_reduce( ) if const_expr( - self.deterministic and not self.spt and block_info.window_size_left is not None + self.deterministic + and not self.spt + and not self.use_block_sparsity + and block_info.window_size_left is not None ): m_block_global_max = cute.ceil_div(seqlen.seqlen_q, self.tile_m) for m_block in cutlass.range(m_block_max, m_block_global_max, unroll=1): @@ -3863,6 +3923,7 @@ def epilogue_dK_or_dV_tma( deterministic_KV = self.deterministic and self.qhead_per_kvhead > 1 if const_expr(deterministic_KV): + assert mdKV_semaphore is not None mdKV_semaphore_cur = mdKV_semaphore[n_block, None, head_idx_kv, batch_idx] if const_expr(not self.dKV_postprocess): diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 9027da189ac..42acbeaec86 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -2036,7 +2036,6 @@ def softmax_loop( ) if const_expr(self.use_block_sparsity) or has_work: - # See block_sparse_utils.py NOTE [SM100 block-sparse empty tiles: mbarrier contract]. pipeline_sm_stats.producer_acquire_w_index_phase(stage, sm_stats_producer_phase) sm_stats_producer_phase ^= 1 diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 6441b0cc13c..d12ea7b80b4 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -981,10 +981,6 @@ def _flash_attn_fwd( compile_args.insert(-3, descale_tensors_tensor) _flash_attn_fwd.compile_cache[compile_key] = cute.compile(*compile_args, options="--enable-tvm-ffi") - # In "fake mode", we will take torch fake tensors as input and the expected behaviors are: - # - Use those fake metadata to populate compilation cache - # - Return "fake" output tensors, which could be needed in follow-up fake operations - # Thus, we skip the actual kernel invocation here. if not is_fake_mode(): q_call, k_call, v_call = q.detach(), k.detach(), v.detach() qv_call = qv.detach() if qv is not None else None @@ -1038,7 +1034,16 @@ def _flash_attn_fwd( if arch // 10 in [10, 11]: call_args.append(descale_tensors) call_args.extend([ - normalized_block_sparse_tensors[:4] if normalized_block_sparse_tensors is not None else None, + ( + normalized_block_sparse_tensors.mask_block_cnt, + normalized_block_sparse_tensors.mask_block_idx, + normalized_block_sparse_tensors.full_block_cnt, + normalized_block_sparse_tensors.full_block_idx, + normalized_block_sparse_tensors.dq_write_order, + normalized_block_sparse_tensors.dq_write_order_full, + ) + if normalized_block_sparse_tensors is not None + else None, aux_tensors, ]) _flash_attn_fwd.compile_cache[compile_key](*call_args) @@ -1566,6 +1571,30 @@ def _flash_attn_bwd( block_size=(m_block_size, n_block_size), subtile_factor=subtile_factor, ) + if deterministic: + if normalized_block_sparse_tensors.dq_write_order is None: + raise ValueError( + "deterministic block-sparse backward requires dq_write_order in block_sparse_tensors" + ) + if ( + normalized_block_sparse_tensors.full_block_cnt is not None + and normalized_block_sparse_tensors.dq_write_order_full is None + ): + raise ValueError( + "deterministic block-sparse backward requires dq_write_order_full when full blocks are present" + ) + if normalized_block_sparse_tensors.spt is None: + raise ValueError( + "deterministic block-sparse backward requires block_sparse_tensors.spt " + "to match dq_write_order direction" + ) + if ( + normalized_block_sparse_tensors is not None + and normalized_block_sparse_tensors.spt is not None + ): + spt = normalized_block_sparse_tensors.spt and deterministic + else: + spt = (causal or local) and deterministic if arch // 10 in [8, 9, 12]: compile_key = ( @@ -1627,6 +1656,7 @@ def _flash_attn_bwd( cluster_size, use_2cta_instrs, deterministic, + spt, score_mod_hash, score_mod_bwd_hash, mask_mod_hash, @@ -1762,6 +1792,7 @@ def _flash_attn_bwd( cluster_size=cluster_size, use_2cta_instrs=use_2cta_instrs, deterministic=deterministic, + spt=spt, score_mod=score_mod, score_mod_bwd=score_mod_bwd, mask_mod=mask_mod, @@ -1825,7 +1856,16 @@ def _flash_attn_bwd( dK_semaphore, dV_semaphore, aux_tensors, - normalized_block_sparse_tensors[:4] if normalized_block_sparse_tensors is not None else None, + ( + normalized_block_sparse_tensors.mask_block_cnt, + normalized_block_sparse_tensors.mask_block_idx, + normalized_block_sparse_tensors.full_block_cnt, + normalized_block_sparse_tensors.full_block_idx, + normalized_block_sparse_tensors.dq_write_order, + normalized_block_sparse_tensors.dq_write_order_full, + ) + if normalized_block_sparse_tensors is not None + else None, ) # Postprocess: convert dq_accum from float32 to dq in bf16/fp16 # hd=256 2CTA backward has its own internal postprocess, skip here. diff --git a/tests/cute/test_mask_mod.py b/tests/cute/test_mask_mod.py index 484f5191725..ceef6500b97 100644 --- a/tests/cute/test_mask_mod.py +++ b/tests/cute/test_mask_mod.py @@ -27,6 +27,8 @@ BlockSparseTensorsTorch, fast_sampling, normalize_block_sparse_config, + compute_dq_write_order, + compute_dq_write_order_from_block_mask, ) from flash_attn.cute.cache_utils import get_jit_cache from flash_attn.cute import utils @@ -677,8 +679,12 @@ def mask_mod_flex(b, h, q_idx, kv_idx, doc_ids=doc_ids): full_block_idx=full_q_idx, block_size=(sparse_tile_m, tile_n), ) - - + dq_write_order = compute_dq_write_order_from_block_mask(bm, spt=False) + block_sparse_mask_bwd = block_sparse_mask_bwd._replace( + dq_write_order=dq_write_order[0], + dq_write_order_full=dq_write_order[1], + spt=False, + ) out_tuple = _flash_attn_fwd( q=q, k=k, v=v, out=out, lse=lse, cu_seqlens_q=None, cu_seqlens_k=None, @@ -1119,7 +1125,7 @@ def wrapped_normalize(*args, **kwargs): def run_cute_mask_bwd( q, k, v, out, lse, grad_out, mask_mod_cute, block_sparse_mask_bwd=None, tile_m=128, tile_n=128, - aux_tensors=None, + aux_tensors=None, deterministic=False, causal=False, window_size_left=None, window_size_right=None, ): """Run flash attention backward with mask_mod. @@ -1132,6 +1138,7 @@ def run_cute_mask_bwd( block_sparse_mask_bwd: Block sparse tensors for backward pass tile_m, tile_n: Tile sizes aux_tensors: Auxiliary tensors for mask_mod (e.g., doc_ids for document masking) + deterministic: Whether to enable deterministic backward Returns (dq, dk, dv) all in BSHD format. """ @@ -1142,12 +1149,15 @@ def run_cute_mask_bwd( out=out, dout=grad_out, lse=lse, - causal=False, + causal=causal, m_block_size=tile_m, n_block_size=tile_n, + window_size_left=window_size_left, + window_size_right=window_size_right, mask_mod=mask_mod_cute, block_sparse_tensors=block_sparse_mask_bwd, aux_tensors=aux_tensors, + deterministic=deterministic, ) return dq, dk, dv @@ -1739,6 +1749,513 @@ def test_persistent_blocksparse_empty_tiles(): assert not out.isnan().any() +def _build_dense_from_ordered(num_blocks, indices, num_cols): + """Build dense binary matrix from ordered sparse representation (test helper).""" + B, H, num_rows, max_entries = indices.shape + batch_is_broadcast = B == 1 or (indices.stride(0) == 0 and num_blocks.stride(0) == 0) + head_is_broadcast = H == 1 or (indices.stride(1) == 0 and num_blocks.stride(1) == 0) + batch_size = 1 if batch_is_broadcast else B + head_size = 1 if head_is_broadcast else H + indices_view = indices[:batch_size, :head_size] + num_blocks_view = num_blocks[:batch_size, :head_size] + dense = torch.zeros( + batch_size, + head_size, + num_rows, + num_cols + 1, + dtype=torch.int32, + device=indices.device, + ) + valid = ( + torch.arange(max_entries, device=indices.device)[None, None, None, :] + < num_blocks_view[:, :, :, None] + ) + safe_indices = torch.where(valid, indices_view.long(), num_cols) + dense.scatter_(-1, safe_indices, valid.to(torch.int32)) + dense = dense[:, :, :, :num_cols] + if batch_size != B or head_size != H: + return dense.expand(B, H, num_rows, num_cols) + return dense + + +def _verify_deadlock_freedom( + kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx, + q_mask_cnt, q_mask_idx, full_q_cnt, full_q_idx, + dq_wo, dq_wo_full, spt=False, +): + """Verify the critical deadlock-freedom invariant for all m_blocks. + + For non-spt: the lowest n_block contributor to each m_block must have lock_value=0. + For spt: the highest n_block contributor must have lock_value=0. + """ + B, H, num_m = kv_mask_cnt.shape + num_n = kv_mask_idx.shape[-1] + + dense = _build_dense_from_ordered(kv_mask_cnt, kv_mask_idx, num_n) + if full_kv_cnt is not None: + dense = dense | _build_dense_from_ordered(full_kv_cnt, full_kv_idx, num_n) + + for b in range(B): + for h in range(H): + for m in range(num_m): + contributors = dense[b, h, m].nonzero(as_tuple=True)[0] + if len(contributors) == 0: + continue + target_n = contributors[-1].item() if spt else contributors[0].item() + + found = False + cnt_partial = q_mask_cnt[b, h, target_n].item() + for i in range(cnt_partial): + if q_mask_idx[b, h, target_n, i].item() == m: + assert dq_wo[b, h, target_n, i].item() == 0, ( + f"n_block={target_n} should get lock_value=0 for m_block={m} (spt={spt})" + ) + found = True + break + if not found and full_q_cnt is not None: + cnt_full = full_q_cnt[b, h, target_n].item() + for i in range(cnt_full): + if full_q_idx[b, h, target_n, i].item() == m: + assert dq_wo_full[b, h, target_n, i].item() == 0, ( + f"n_block={target_n} (full) should get lock_value=0 for m_block={m} (spt={spt})" + ) + found = True + break + assert found, f"target n_block={target_n} not found in backward lists for m_block={m}" + + +def _verify_unique_ranks_per_m_block( + kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx, + q_mask_cnt, q_mask_idx, full_q_cnt, full_q_idx, + dq_wo, dq_wo_full, +): + """Verify that for each m_block, the lock values form a contiguous 0..N-1 range.""" + B, H, num_m = kv_mask_cnt.shape + num_n = kv_mask_idx.shape[-1] + + dense = _build_dense_from_ordered(kv_mask_cnt, kv_mask_idx, num_n) + if full_kv_cnt is not None: + dense = dense | _build_dense_from_ordered(full_kv_cnt, full_kv_idx, num_n) + + for b in range(B): + for h in range(H): + for m in range(num_m): + contributors = dense[b, h, m].nonzero(as_tuple=True)[0] + total = len(contributors) + if total == 0: + continue + lock_vals = set() + for n in contributors.tolist(): + cnt_p = q_mask_cnt[b, h, n].item() + for i in range(cnt_p): + if q_mask_idx[b, h, n, i].item() == m: + lock_vals.add(dq_wo[b, h, n, i].item()) + if full_q_cnt is not None: + cnt_f = full_q_cnt[b, h, n].item() + for i in range(cnt_f): + if full_q_idx[b, h, n, i].item() == m: + lock_vals.add(dq_wo_full[b, h, n, i].item()) + assert lock_vals == set(range(total)), ( + f"m_block={m}: expected ranks {{0..{total-1}}}, got {lock_vals}" + ) + + +def _run_write_order_test(mask_mod_flex, seqlen_q, seqlen_k, block_size, B=1, H=4, spt=False): + bm = create_block_mask( + mask_mod_flex, B, H, seqlen_q, seqlen_k, + device="cuda", BLOCK_SIZE=(block_size, block_size), + ) + ( + _, _, + kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx, + q_mask_cnt, q_mask_idx, full_q_cnt, full_q_idx, *_, + ) = bm.as_tuple() + + dq_wo, dq_wo_full = compute_dq_write_order( + kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx, + q_mask_cnt, q_mask_idx, full_q_cnt, full_q_idx, + spt=spt, + ) + + _verify_deadlock_freedom( + kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx, + q_mask_cnt, q_mask_idx, full_q_cnt, full_q_idx, + dq_wo, dq_wo_full, spt=spt, + ) + if not spt: + _verify_unique_ranks_per_m_block( + kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx, + q_mask_cnt, q_mask_idx, full_q_cnt, full_q_idx, + dq_wo, dq_wo_full, + ) + + +def _build_block_sparse_masks_for_bwd( + mask_mod_flex, + batch_size, + nheads, + seqlen_q, + seqlen_k, + tile_m, + tile_n, + spt, +): + sparse_tile_m = 2 * tile_m if COMPUTE_CAPABILITY == 10 else tile_m + bm = create_block_mask( + mask_mod_flex, + batch_size, + nheads, + seqlen_q, + seqlen_k, + device="cuda", + BLOCK_SIZE=(sparse_tile_m, tile_n), + ) + ( + _seq_q, + _seq_k, + kv_mask_cnt, + kv_mask_idx, + full_kv_cnt, + full_kv_idx, + q_mask_cnt, + q_mask_idx, + full_q_cnt, + full_q_idx, + *_, + ) = bm.as_tuple() + + block_sparse_mask_fwd = BlockSparseTensorsTorch( + mask_block_cnt=kv_mask_cnt, + mask_block_idx=kv_mask_idx, + full_block_cnt=full_kv_cnt, + full_block_idx=full_kv_idx, + block_size=(sparse_tile_m, tile_n), + ) + block_sparse_mask_bwd = BlockSparseTensorsTorch( + mask_block_cnt=q_mask_cnt, + mask_block_idx=q_mask_idx, + full_block_cnt=full_q_cnt, + full_block_idx=full_q_idx, + block_size=(sparse_tile_m, tile_n), + ) + dq_write_order = compute_dq_write_order_from_block_mask(bm, spt=spt) + return block_sparse_mask_fwd, block_sparse_mask_bwd._replace( + dq_write_order=dq_write_order[0], + dq_write_order_full=dq_write_order[1], + spt=spt, + ) + + +@pytest.mark.skipif(COMPUTE_CAPABILITY not in (10, 11), reason="deterministic bwd only supported on sm100/sm110") +@pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 256), (512, 512), (383, 769)]) +@pytest.mark.parametrize( + "mask_name,window_size", + [ + ("block_diagonal", None), + ("causal", None), + ("sliding_window", 256), + ("document", None), + ], +) +@pytest.mark.parametrize("spt", [False, True]) +@pytest.mark.parametrize("kv_mode", ["mha", "gqa"]) +def test_block_sparse_bwd_deterministic(seqlen_q, seqlen_k, mask_name, window_size, spt, kv_mode): + torch.manual_seed(42) + if mask_name == "sliding_window" and seqlen_q > seqlen_k: + pytest.skip("sliding_window requires seqlen_q <= seqlen_k") + if spt and mask_name not in ("sliding_window", "causal"): + pytest.skip("spt path is only exercised for sliding_window and causal in this test") + + batch_size = 1 + nheads = 4 + nheads_kv = 1 if kv_mode == "gqa" else nheads + pack_gqa = nheads != nheads_kv + headdim = 128 + tile_m = 128 + tile_n = 128 + dtype = torch.bfloat16 + + mask_mod_cute, mask_mod_flex = get_mask_pair( + mask_name, seqlen_q=seqlen_q, seqlen_k=seqlen_k, window_size=window_size + ) + + aux_tensors_arg = None + if mask_name == "document": + doc_ids = random_doc_id_tensor(nheads, batch_size, max(seqlen_q, seqlen_k), device="cuda").to( + torch.int32 + ) + original_flex_mask = mask_mod_flex + + def mask_mod_flex(b, h, q_idx, kv_idx, doc_ids=doc_ids): + return original_flex_mask(b, h, q_idx, kv_idx, doc_ids) + + aux_tensors_arg = [doc_ids] + + tensors = create_tensors(batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, headdim, headdim, dtype) + q = tensors["q"] + k = tensors["k"] + v = tensors["v"] + block_mask_nheads = 1 if pack_gqa else nheads + block_sparse_mask_fwd, block_sparse_mask_bwd = _build_block_sparse_masks_for_bwd( + mask_mod_flex=mask_mod_flex, + batch_size=batch_size, + nheads=block_mask_nheads, + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + tile_m=tile_m, + tile_n=tile_n, + spt=spt, + ) + causal_arg = spt and mask_name == "causal" + window_size_left_arg = window_size if spt and mask_name == "sliding_window" else None + window_size_right_arg = 0 if spt and mask_name == "sliding_window" else None + mask_mod_arg = mask_mod_cute if not spt else None + + out_cute, lse_cute = _flash_attn_fwd( + q=q, + k=k, + v=v, + out=torch.empty(batch_size, seqlen_q, nheads, headdim, device="cuda", dtype=dtype), + lse=torch.empty(batch_size, nheads, seqlen_q, device="cuda", dtype=torch.float32), + cu_seqlens_q=None, + cu_seqlens_k=None, + seqused_q=None, + seqused_k=None, + page_table=None, + softmax_scale=1.0 / math.sqrt(headdim), + causal=causal_arg, + softcap=None, + window_size_left=window_size_left_arg, + window_size_right=window_size_right_arg, + learnable_sink=None, + tile_mn=(tile_m, tile_n), + pack_gqa=pack_gqa, + _arch=None, + score_mod=None, + mask_mod=mask_mod_arg, + block_sparse_tensors=block_sparse_mask_fwd, + return_lse=True, + aux_tensors=aux_tensors_arg, + ) + + grad_out = torch.randn_like(out_cute) + dq0, dk0, dv0 = run_cute_mask_bwd( + q, + k, + v, + out_cute, + lse_cute, + grad_out, + mask_mod_arg, + block_sparse_mask_bwd=block_sparse_mask_bwd, + tile_m=tile_m, + tile_n=tile_n, + aux_tensors=aux_tensors_arg, + deterministic=True, + causal=causal_arg, + window_size_left=window_size_left_arg, + window_size_right=window_size_right_arg, + ) + + num_repeats = 3 if spt else 50 + for _ in range(num_repeats): + dq, dk, dv = run_cute_mask_bwd( + q, + k, + v, + out_cute, + lse_cute, + grad_out, + mask_mod_arg, + block_sparse_mask_bwd=block_sparse_mask_bwd, + tile_m=tile_m, + tile_n=tile_n, + aux_tensors=aux_tensors_arg, + deterministic=True, + causal=causal_arg, + window_size_left=window_size_left_arg, + window_size_right=window_size_right_arg, + ) + assert torch.equal(dq, dq0) + assert torch.equal(dk, dk0) + assert torch.equal(dv, dv0) + + +@pytest.mark.skipif(COMPUTE_CAPABILITY not in (10, 11), reason="deterministic bwd only supported on sm100/sm110") +def _setup_block_sparse_deterministic_validation_case(): + torch.manual_seed(42) + batch_size = 1 + nheads = 4 + seqlen_q = 256 + seqlen_k = 256 + headdim = 128 + tile_m = 128 + tile_n = 128 + dtype = torch.bfloat16 + + _, mask_mod_flex = get_mask_pair( + "block_diagonal", seqlen_q=seqlen_q, seqlen_k=seqlen_k + ) + + q = torch.randn(batch_size, seqlen_q, nheads, headdim, device="cuda", dtype=dtype) + k = torch.randn(batch_size, seqlen_k, nheads, headdim, device="cuda", dtype=dtype) + v = torch.randn(batch_size, seqlen_k, nheads, headdim, device="cuda", dtype=dtype) + block_sparse_mask_fwd, block_sparse_mask_bwd = _build_block_sparse_masks_for_bwd( + mask_mod_flex=mask_mod_flex, + batch_size=batch_size, + nheads=nheads, + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + tile_m=tile_m, + tile_n=tile_n, + spt=False, + ) + out_cute, lse_cute = _flash_attn_fwd( + q=q, + k=k, + v=v, + out=torch.empty(batch_size, seqlen_q, nheads, headdim, device="cuda", dtype=dtype), + lse=torch.empty(batch_size, nheads, seqlen_q, device="cuda", dtype=torch.float32), + softmax_scale=1.0 / math.sqrt(headdim), + tile_mn=(tile_m, tile_n), + block_sparse_tensors=block_sparse_mask_fwd, + return_lse=True, + ) + + return q, k, v, out_cute, lse_cute, torch.randn_like(out_cute), block_sparse_mask_bwd, tile_m, tile_n + + +@pytest.mark.skipif(COMPUTE_CAPABILITY not in (10, 11), reason="deterministic bwd only supported on sm100/sm110") +def test_block_sparse_bwd_deterministic_missing_dq_write_order_raises(): + q, k, v, out_cute, lse_cute, grad_out, block_sparse_mask_bwd, tile_m, tile_n = ( + _setup_block_sparse_deterministic_validation_case() + ) + block_sparse_mask_bwd_no_dq_write_order = block_sparse_mask_bwd._replace( + dq_write_order=None, + dq_write_order_full=None, + spt=None, + ) + + with pytest.raises(ValueError, match="requires dq_write_order in block_sparse_tensors"): + run_cute_mask_bwd( + q, + k, + v, + out_cute, + lse_cute, + grad_out, + None, + block_sparse_mask_bwd=block_sparse_mask_bwd_no_dq_write_order, + tile_m=tile_m, + tile_n=tile_n, + deterministic=True, + ) + + +@pytest.mark.skipif(COMPUTE_CAPABILITY not in (10, 11), reason="deterministic bwd only supported on sm100/sm110") +def test_block_sparse_bwd_deterministic_missing_dq_write_order_full_raises(): + q, k, v, out_cute, lse_cute, grad_out, block_sparse_mask_bwd, tile_m, tile_n = ( + _setup_block_sparse_deterministic_validation_case() + ) + block_sparse_mask_bwd_no_dq_write_order_full = block_sparse_mask_bwd._replace( + full_block_cnt=torch.zeros_like(block_sparse_mask_bwd.mask_block_cnt), + full_block_idx=torch.zeros_like(block_sparse_mask_bwd.mask_block_idx), + dq_write_order_full=None, + spt=False, + ) + + with pytest.raises(ValueError, match="requires dq_write_order_full when full blocks are present"): + run_cute_mask_bwd( + q, + k, + v, + out_cute, + lse_cute, + grad_out, + None, + block_sparse_mask_bwd=block_sparse_mask_bwd_no_dq_write_order_full, + tile_m=tile_m, + tile_n=tile_n, + deterministic=True, + ) + + +@pytest.mark.skipif(COMPUTE_CAPABILITY not in (10, 11), reason="deterministic bwd only supported on sm100/sm110") +def test_block_sparse_bwd_deterministic_missing_spt_raises(): + q, k, v, out_cute, lse_cute, grad_out, block_sparse_mask_bwd, tile_m, tile_n = ( + _setup_block_sparse_deterministic_validation_case() + ) + block_sparse_mask_bwd_no_spt = block_sparse_mask_bwd._replace(spt=None) + + with pytest.raises(ValueError, match="requires block_sparse_tensors.spt"): + run_cute_mask_bwd( + q, + k, + v, + out_cute, + lse_cute, + grad_out, + None, + block_sparse_mask_bwd=block_sparse_mask_bwd_no_spt, + tile_m=tile_m, + tile_n=tile_n, + deterministic=True, + ) + + +WRITE_ORDER_SEQLENS = [ + (256, 256), + (512, 512), + (1024, 1024), + (2048, 2048), + (4096, 4096), + (512, 1024), + (1024, 512), + (384, 768), +] + + +@pytest.mark.parametrize("seqlen_q,seqlen_k", WRITE_ORDER_SEQLENS) +@pytest.mark.parametrize("mask_name", ["block_diagonal", "mini_causal", "prefix_lm", "dilated_sliding_window"]) +@pytest.mark.parametrize("spt", [False, True]) +def test_dq_write_order_static_masks(seqlen_q, seqlen_k, mask_name, spt): + torch.manual_seed(42) + _, mask_mod_flex = get_mask_pair(mask_name, seqlen_q=seqlen_q, seqlen_k=seqlen_k) + _run_write_order_test(mask_mod_flex, seqlen_q, seqlen_k, block_size=128, spt=spt) + + +@pytest.mark.parametrize("seqlen_q,seqlen_k", WRITE_ORDER_SEQLENS) +@pytest.mark.parametrize( + "mask_name,window_size", + [ + ("causal", None), + ("block_causal", None), + ("sliding_window", 128), + ("sliding_window", 256), + ("sliding_window", 512), + ], +) +@pytest.mark.parametrize("spt", [False, True]) +def test_dq_write_order_parameterized_masks(seqlen_q, seqlen_k, mask_name, window_size, spt): + torch.manual_seed(42) + if mask_name == "sliding_window" and seqlen_q > seqlen_k: + pytest.skip("sliding_window requires seqlen_q <= seqlen_k") + _, mask_mod_flex = get_mask_pair(mask_name, seqlen_q=seqlen_q, seqlen_k=seqlen_k, window_size=window_size) + _run_write_order_test(mask_mod_flex, seqlen_q, seqlen_k, block_size=128, spt=spt) + + +@pytest.mark.parametrize("seqlen_q,seqlen_k", [(512, 512), (1024, 1024), (2048, 2048)]) +@pytest.mark.parametrize("spt", [False, True]) +def test_dq_write_order_document_mask(seqlen_q, seqlen_k, spt): + torch.manual_seed(42) + B, H = 1, 4 + doc_ids = random_doc_id_tensor(H, B, max(seqlen_q, seqlen_k), device="cuda").to(torch.int32) + + def doc_mask(b, h, q_idx, kv_idx): + return doc_ids[b, h, q_idx] == doc_ids[b, h, kv_idx] + + _run_write_order_test(doc_mask, seqlen_q, seqlen_k, block_size=128, B=B, H=H, spt=spt) + def test_compact_block_sparse_indices(): """Test that compact block sparse index tensors (idx.shape[3] < n_blocks) work correctly. From fe1ddad90f961a1fbc85304dff06220de2f1920f Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Mon, 11 May 2026 21:56:39 +0000 Subject: [PATCH 44/44] Revert "expose num_splits for FA2 and add option for kernel blocksize alignment (#2448)" This reverts commit b322ae2675065ad96ad3c248fd6ef0f32252808f. --- csrc/flash_attn/flash_api.cpp | 7 ++++--- csrc/flash_attn/src/flash_fwd_launch_template.h | 12 ++---------- 2 files changed, 6 insertions(+), 13 deletions(-) diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index 99294b466f5..2811e7c4551 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -533,8 +533,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s int window_size_right, const float softcap, const bool return_softmax, - std::optional gen_, - int num_splits = 0) { + int num_splits, + std::optional gen_) { // Otherwise the kernel will be launched from cuda:0 device at::cuda::CUDAGuard device_guard{q.device()}; @@ -702,7 +702,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s params.page_block_size = page_block_size; // Keep references to these tensors to extend their lifetime at::Tensor softmax_lse_accum, out_accum; - if (paged_KV || seqlenq_ngroups_swapped) { + if (seqlenq_ngroups_swapped) { + // Only apply split-k for decoding std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(params, batch_size, num_heads, head_size, max_seqlen_k, max_seqlen_q, head_size_rounded, diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h index b7831c5e832..934e7b9114b 100644 --- a/csrc/flash_attn/src/flash_fwd_launch_template.h +++ b/csrc/flash_attn/src/flash_fwd_launch_template.h @@ -162,19 +162,11 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr static int kBlockM = 64; + constexpr static int kBlockM = 64; // Fixed for all head dimensions // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256, // and for headdim 192 with block size 64 x 128. constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64); - // if user specifies num_splits=1, we assume they want bitwise identical - // numerics across the split KV and standard kernels so we align kBLockN to - // match - if (params.num_splits == 1) { - constexpr static int kBlockN_standard = Headdim <= 64 ? 128 : 64; - run_flash_splitkv_fwd, Is_causal>(params, stream); - } else { - run_flash_splitkv_fwd, Is_causal>(params, stream); - } + run_flash_splitkv_fwd, Is_causal>(params, stream); } template