Merged
47 commits
fa5b632
fix noisy logger (#2414)
drisspg Apr 1, 2026
b46c587
[AMD ROCm] Fix NaN in FMHA BWD when seq_q=0 (#2421)
rocking5566 Apr 1, 2026
4bc0ab1
Add FA4 CI: GitHub Actions workflow with Apptainer on B200 runner (#2…
Johnsonms Apr 2, 2026
83f9e45
Fix some bugs of CI (#2423)
Johnsonms Apr 2, 2026
ab5cb6e
[ROCM] Fix windows issues (#2385)
micmelesse Apr 2, 2026
1233b73
fix: add cu13 extra to dev install instructions for CUDA 13 / B200 sy…
Johnsonms Apr 3, 2026
65bfd9a
Fix: disable 2-CTA backward mode when block_sparse_tensors is used (#…
jduprat Apr 3, 2026
15270e6
CI: extend FA4 test matrix with causal/non-causal correctness and fwd…
Johnsonms Apr 4, 2026
14f3627
feat(cute): implement softcap backward pass, correct math formula, an…
CaesarG Apr 11, 2026
79f317c
[Doc] Remove old comment about supported features
tridao Apr 13, 2026
09c93ea
[DSL] Remove cute_compile_patched
tridao Apr 13, 2026
f219c89
[Cute,Sm100,Fwd] add MLA 64/512 with topk sparsity for MQA 128 heads …
jayhshah Apr 15, 2026
628452c
Handle linter for flash mla file (#2459)
jayhshah Apr 15, 2026
83b8b8f
Disable 2CTA fwd non-causal on CUDA 12 to work around codegen regress…
Johnsonms Apr 15, 2026
d7f60e6
Add CLC scheduler heuristic (#2455)
drisspg Apr 16, 2026
b322ae2
expose num_splits for FA2 and add option for kernel blocksize alignme…
liangel-02 Apr 16, 2026
b65ae6b
[Cute,Fwd,Sm100] fp8 e4m3 and e5m2 support (#2109)
dcw02 Apr 17, 2026
881139b
Expose --pack-gqa and --num-splits in benchmark_attn.py (#2473)
Johnsonms Apr 19, 2026
bbda031
Fix: pass num_splits through varlen_fwd Python wrapper (fixes #2448 r…
hsyysy Apr 20, 2026
03cd065
[Cute,Fwd] Fix crash when K/V has seqlen == 0 (#2470)
Johnsonms Apr 20, 2026
41b2ef6
fix causal calcs (#2463)
drisspg Apr 20, 2026
3a7694c
fix (#2481)
geruome Apr 21, 2026
27b4eb9
Feat([FA4][CUTE DSL]) Add head_dim=256 support (forward + backward) (…
wangsiyu Apr 23, 2026
b21e204
[Cute,hd256] Post-merge cleanup: dead code, duplicate imports (#2487)
Johnsonms Apr 23, 2026
ac6f2eb
[CuTe,Flex] Wire up interface for flex autograd support (#2485)
reubenconducts Apr 23, 2026
c9a560f
add missing score_mod_bwd param to FlashAttnVarlenFunc (#2496)
reubenconducts Apr 26, 2026
eefa86d
fix: typos and missing comments in FA4 cute kernel files (#2502)
dxasu Apr 27, 2026
547031a
[SM100] Guard gO None in empty-tile correction (#2504)
geruome Apr 27, 2026
6b52632
[CuTe, Flex] simplify blocksparse interface in flash_attn_func (#2506)
reubenconducts Apr 27, 2026
519445a
Fix (#2505)
MatthewBonanni Apr 27, 2026
b86e0cc
Fix clc scheduling request bug (#2508)
drisspg Apr 27, 2026
89ce84b
[Tests,MLA] Close coverage gaps in test_flash_attn_mla_absorbed{,_var…
Johnsonms Apr 27, 2026
96bd151
Add cache utils logging test (#2509)
drisspg Apr 28, 2026
6c73fb5
[hd256] Improve forward kernel with exp2 FMA emulation (+3% to +9% pe…
Johnsonms Apr 28, 2026
ebeff90
SM90 FA4 QuACK 0.4 Compatibility (#2513)
EduardDurech Apr 28, 2026
ba59def
ci: use /tmp for apptainer tmpdir to fix xattrerror on VAST (#2511)
Johnsonms Apr 28, 2026
b995b24
Fix long MSVC linker commands on Windows (#2517)
jammm Apr 29, 2026
06259d9
Fix test_flash_attn_fast varlen call after qv positional insert (#2527)
henrylhtsang Apr 30, 2026
c5f6ff4
[Cute,Bwd,Sm90] Fix determinism for GQA, port Sm100 approach in (#2510)
v0i0 Apr 30, 2026
a7f6fbd
benchmarks/tune_ex2_emu: hd256 sweep support and clock lock/unlock (#…
Johnsonms May 1, 2026
956366b
[FA4][hd256] Backward TMA bulk-store epilogue + LSE/dpsum coalesce (#…
Johnsonms May 1, 2026
cb213fc
[hd256] Add TMA paged KV support to SM100 2CTA forward kernel (#2489)
Johnsonms May 1, 2026
2e53092
Deterministic backward for blocksparse impl (#2253)
drisspg May 4, 2026
d42fc11
Merge remote-tracking branch 'upstream/main' into merge_upstream
MatthewBonanni Apr 27, 2026
653caa5
Merge remote-tracking branch 'upstream/main' into merge_upstream
MatthewBonanni Apr 27, 2026
3d8af08
Merge remote-tracking branch 'upstream/main' into merge_upstream
MatthewBonanni May 5, 2026
fe1ddad
Revert "expose num_splits for FA2 and add option for kernel blocksize…
MatthewBonanni May 11, 2026
67 changes: 67 additions & 0 deletions .github/actions/gpu-test/action.yml
@@ -0,0 +1,67 @@
name: GPU Test
description: Compile and run FA4 tests (pull SIF from Docker Hub, cache by tag)

inputs:
  test-filter:
    description: pytest -k filter expression
    required: false
    default: ""
  compile-workers:
    description: parallel workers for Pass 1 kernel compilation
    required: false
    default: "64"
  fa4_image_cu129:
    description: Docker image for CUDA 12.9 (used when driver does not support CUDA 13.0)
    required: true
  fa4_image_cu130:
    description: Docker image for CUDA 13.0 (used when driver supports CUDA 13.0)
    required: true

runs:
  using: composite
  steps:
    - name: Select FA4 image based on CUDA version
      shell: bash
      run: |
        # Read max supported CUDA version from nvidia-smi header, e.g. "CUDA Version: 12.9"
        CUDA_VER=$(nvidia-smi | grep -oP "CUDA Version: \K[0-9]+\.[0-9]+")
        CUDA_MAJOR=$(echo "$CUDA_VER" | cut -d. -f1)
        echo "Detected max CUDA version: $CUDA_VER"
        if [ "$CUDA_MAJOR" -ge 13 ]; then
          echo "Using cu130 image"
          echo "FA4_IMAGE=${{ inputs.fa4_image_cu130 }}" >> "$GITHUB_ENV"
        else
          echo "Using cu129 image"
          echo "FA4_IMAGE=${{ inputs.fa4_image_cu129 }}" >> "$GITHUB_ENV"
        fi

    - name: Pull FA4 SIF
      shell: bash
      run: |
        CI_WORK_DIR="${CI_WORK_DIR:-/scratch/user/$USER}"
        TAG=$(echo "$FA4_IMAGE" | tr '/: ' '---')
        SIF="$CI_WORK_DIR/${TAG}.sif"
        # Apptainer doesn't support tag@digest refs; strip the tag, keep digest only.
        PULL_REF=$(echo "$FA4_IMAGE" | sed 's/:[^@]*@/@/')
        echo "PULL_REF=$PULL_REF"
        echo "SIF=$SIF"
        mkdir -p "$CI_WORK_DIR/apptainer_cache" /tmp/apptainer_tmp
        if [ ! -f "$SIF" ]; then
          echo "Pulling $PULL_REF → $SIF"
          APPTAINER_TMPDIR="/tmp/apptainer_tmp" \
          APPTAINER_CACHEDIR="$CI_WORK_DIR/apptainer_cache" \
          apptainer pull "$SIF" "docker://$PULL_REF"
        else
          echo "Using cached SIF: $SIF"
        fi
        # Remove stale SIFs from previous image versions to prevent unbounded disk growth.
        find "$CI_WORK_DIR" -maxdepth 1 -name "*.sif" ! -name "$(basename "$SIF")" -delete
        echo "FA4_SIF=$SIF" >> "$GITHUB_ENV"

    - name: Compile and run tests
      shell: bash
      run: |
        python3 "$GITHUB_WORKSPACE/tools/ci/run_fa4_ci.py" \
          --repo-root "$GITHUB_WORKSPACE" \
          --test-filter "${{ inputs.test-filter }}" \
          --compile-workers "${{ inputs.compile-workers }}"
8 changes: 8 additions & 0 deletions .github/scripts/test_ci_local.sh
@@ -0,0 +1,8 @@
#!/usr/bin/env bash
set -euo pipefail

SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")/../../" && pwd)

python3 "$SCRIPT_DIR/tools/ci/run_fa4_ci.py" \
  --repo-root "$SCRIPT_DIR" \
  "$@"
45 changes: 45 additions & 0 deletions .github/workflows/ci.yml
@@ -0,0 +1,45 @@
name: CI

on:
  push:
    branches: [main, ci-fix]

permissions:
  contents: read

env:
  CI_WORK_DIR: ${{ vars.CI_WORK_DIR || format('/scratch/user/{0}', github.actor) }}
  FA4_TEST_FILTER: "1024-1024-128-True-0-0.0-False-False-False-mha-dtype0 or 1024-1024-128-False-0-0.0-False-False-False-mha-dtype0"

jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Install ruff
        run: pip install ruff
      - name: Ruff check
        run: ruff check flash_attn/cute/ --extend-exclude "flash_attn/cute/flash_bwd.py,flash_attn/cute/flash_fwd.py,flash_attn/cute/flash_fwd_sm100.py,flash_attn/cute/interface.py"
      - name: Ruff format
        run: ruff format --check flash_attn/cute/ --exclude "flash_attn/cute/flash_bwd.py,flash_attn/cute/flash_fwd.py,flash_attn/cute/flash_fwd_sm100.py,flash_attn/cute/interface.py"

  fa4-correctness-and-benchmark:
    strategy:
      fail-fast: false
      matrix:
        gpu: [b200]
    runs-on: [self-hosted, '${{ matrix.gpu }}']
    name: fa4-correctness-and-benchmark (${{ matrix.gpu }})
    timeout-minutes: 60
    steps:
      - uses: actions/checkout@v4
      - name: Log in to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}
      - uses: ./.github/actions/gpu-test
        with:
          test-filter: ${{ env.FA4_TEST_FILTER }}
          fa4_image_cu129: "togethercomputer/training-performance:flash-attn-cu12.9-26.03.25@sha256:304a5c3d2b3a75b151cd2a964cd26d444e0d8b5686d63943df13378c9705f943"
          fa4_image_cu130: "togethercomputer/training-performance:flash-attn-cu13.0-26.04.01@sha256:56e50b056eb4d671410846c3483e843ee7bd0f5b13cb45b6f0d7eb8bd27694a5"
2 changes: 2 additions & 0 deletions .gitignore
@@ -36,6 +36,8 @@ var/

# Dev
venv
agent_space/
benchmarks/results/

# compile-time generated file
flash_attn_config.py
1 change: 1 addition & 0 deletions AGENTS.md
82 changes: 82 additions & 0 deletions AI/CLC_TRACE_DEBUG.md
@@ -0,0 +1,82 @@
# CLC Trace Debugging

Use this when you suspect the CLC work scheduler is making surprising tile assignment decisions and you want a raw scheduler trace from the current kernel.

## Current trace format

SM100 forward kernels emit one trace line per scheduler-warp query at `FA_LOG_LEVEL=3`:

```text
[CLC] query sm=<smid> cta=<blockIdx.x> (m_blk=<m>,h=<h>,b=<b>,s=<s>) valid=<0|1>
```

Current emit sites:
- `flash_attn/cute/flash_fwd_sm100.py`
- `flash_attn/cute/flash_fwd_mla_sm100.py`
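
For quick ad-hoc inspection, the trace line format above can be matched with a single regular expression. The following is a minimal sketch, not the parser shipped in the repository: `CLC_LINE`, `ClcQuery`, and `parse_clc_line` are hypothetical names, and it assumes every field is printed as a plain (possibly negative) decimal integer.

```python
import re
from typing import NamedTuple, Optional

# Assumption: all fields are plain decimal integers, exactly as in the format string above.
CLC_LINE = re.compile(
    r"\[CLC\] query sm=(?P<sm>-?\d+) cta=(?P<cta>-?\d+) "
    r"\(m_blk=(?P<m_blk>-?\d+),h=(?P<h>-?\d+),b=(?P<b>-?\d+),s=(?P<s>-?\d+)\) "
    r"valid=(?P<valid>[01])"
)


class ClcQuery(NamedTuple):
    """One scheduler-warp query (hypothetical container for illustration)."""
    sm: int
    cta: int
    m_blk: int
    h: int
    b: int
    s: int
    valid: bool


def parse_clc_line(line: str) -> Optional[ClcQuery]:
    """Return the parsed query, or None if the line is not a [CLC] trace line."""
    match = CLC_LINE.search(line)
    if match is None:
        return None
    f = {key: int(value) for key, value in match.groupdict().items()}
    return ClcQuery(f["sm"], f["cta"], f["m_blk"], f["h"], f["b"], f["s"], bool(f["valid"]))


# Made-up example line that follows the documented format.
example = "[CLC] query sm=3 cta=17 (m_blk=2,h=5,b=0,s=0) valid=1"
print(parse_clc_line(example))
```

Using `search` rather than `match` lets the pattern tolerate any prefix that the logging layer may prepend to a line.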

## How to capture a trace

Important:
- `FA_LOG_LEVEL=3` is needed for the `[CLC] query ...` device-side prints.
- `FA_CLC=1` only requests CLC; the kernel may still fall back if the shape/features disable it.

Minimal repro pattern:

```bash
FA_LOG_LEVEL=3 FA_CLC=1 CUDA_VISIBLE_DEVICES=0 python - <<'PY' \
> agent_space/clc_trace.log 2>&1
import torch
from flash_attn.cute.interface import flash_attn_func

torch.manual_seed(0)
q = torch.randn(1, 512, 16, 128, device='cuda', dtype=torch.bfloat16)
k = torch.randn(1, 512, 1, 128, device='cuda', dtype=torch.bfloat16)
v = torch.randn(1, 512, 1, 128, device='cuda', dtype=torch.bfloat16)
flash_attn_func(q, k, v, causal=True)
torch.cuda.synchronize()
PY
```

If you want the run to state explicitly whether CLC was selected, keep the host-side `[FA]` log lines in the capture too:

```text
[FA] TileScheduler=SingleTileLPTScheduler, scheduling_mode=CLC, USE_2CTA=False
```
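
To check the host line programmatically, a small helper along these lines may be enough. This is a sketch under the assumption that the host line always has the `key=value, key=value` shape shown above; `parse_host_scheduler_line` and `clc_selected` are hypothetical names, not part of the repository.

```python
import re
from typing import Dict, Optional


def parse_host_scheduler_line(line: str) -> Optional[Dict[str, str]]:
    """Split a host '[FA] ...' scheduler line into its key=value fields."""
    # Assumption: the line carries the [FA] prefix and a scheduling_mode field.
    if "[FA]" not in line or "scheduling_mode=" not in line:
        return None
    return dict(re.findall(r"(\w+)=(\w+)", line))


def clc_selected(log_text: str) -> Optional[bool]:
    """True/False from the first host scheduler line found, None if there is none."""
    for line in log_text.splitlines():
        fields = parse_host_scheduler_line(line)
        if fields is not None:
            return fields.get("scheduling_mode") == "CLC"
    return None


# The example host line from this document.
sample = "[FA] TileScheduler=SingleTileLPTScheduler, scheduling_mode=CLC, USE_2CTA=False"
print(parse_host_scheduler_line(sample))
print(clc_selected(sample))
```

Returning `None` when no host line is present keeps "CLC was not selected" distinct from "the host log was filtered out of the capture".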

## What to look for

- `scheduling_mode=CLC` in host logs confirms the shape actually used the CLC path.
- `valid=1` means the returned work tile is valid.
- `valid=0` means the scheduler is exhausted for that CTA/scheduler warp query.
- `m_blk`, `h`, `b`, `s` are the logical work coordinates after the scheduler mapping.
- `cta` is the physical `blockIdx.x`; for clustered launches multiple CTAs may participate in the same logical tile.
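
The points above suggest two simple aggregates: which CTAs received each logical tile, and how many terminal `valid=0` queries each CTA issued. The sketch below computes both; it is an illustration only, not the logic of `AI/parse_clc_log.py`. The name `summarize_trace` is hypothetical, the sample lines are made up, and fields are assumed to be decimal integers.

```python
import re
from collections import defaultdict
from typing import Dict, Iterable, Set, Tuple

# Assumption: fields are decimal integers in the documented order.
TRACE = re.compile(
    r"\[CLC\] query sm=(-?\d+) cta=(-?\d+) "
    r"\(m_blk=(-?\d+),h=(-?\d+),b=(-?\d+),s=(-?\d+)\) valid=([01])"
)

Tile = Tuple[int, int, int, int]  # (m_blk, h, b, s)


def summarize_trace(lines: Iterable[str]) -> Tuple[Dict[Tile, Set[int]], Dict[int, int]]:
    """Group valid queries by logical tile and count valid=0 queries per CTA."""
    ctas_per_tile: Dict[Tile, Set[int]] = defaultdict(set)
    exhausted_per_cta: Dict[int, int] = defaultdict(int)
    for line in lines:
        match = TRACE.search(line)
        if match is None:
            continue  # skip host logs and any other non-trace output
        _sm, cta, m_blk, h, b, s, valid = map(int, match.groups())
        if valid:
            ctas_per_tile[(m_blk, h, b, s)].add(cta)
        else:
            exhausted_per_cta[cta] += 1
    return dict(ctas_per_tile), dict(exhausted_per_cta)


# Made-up trace: two CTAs share tile (0,0,0,0), as can happen with clustered launches.
sample_lines = [
    "[CLC] query sm=0 cta=0 (m_blk=0,h=0,b=0,s=0) valid=1",
    "[CLC] query sm=0 cta=1 (m_blk=0,h=0,b=0,s=0) valid=1",
    "[CLC] query sm=1 cta=2 (m_blk=1,h=0,b=0,s=0) valid=1",
    "[CLC] query sm=0 cta=0 (m_blk=0,h=0,b=0,s=0) valid=0",
]
tiles, exhausted = summarize_trace(sample_lines)
print(tiles)
print(exhausted)
```

A tile that maps to more than one CTA is not necessarily a bug; it needs to be read against the cluster shape of the launch.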

## Parse the trace

A lightweight parser lives in `AI/parse_clc_log.py`.

Text summary:

```bash
python AI/parse_clc_log.py agent_space/clc_trace.log
```

HTML view:

```bash
python AI/parse_clc_log.py agent_space/clc_trace.log --html -o agent_space/clc_trace.html
```

## Suggested workflow

1. Reproduce the surprising case with `FA_LOG_LEVEL=3 FA_CLC=1`.
2. Save stdout/stderr to `agent_space/clc_trace.log`.
3. Run `AI/parse_clc_log.py` on that log to get a compact per-SM / per-CTA summary.
4. If the trace still looks suspicious, attach or paste that log in the investigation thread / agent notes.
5. Compare against the relevant mapping logic in `flash_attn/cute/tile_scheduler.py`.

## Caveats

- The trace is noisy and expensive; use a single small shape first.
- Because the print happens on scheduler queries, many lines may be terminal `valid=0` queries after work is exhausted.
- Dense noncausal and varlen MHA may intentionally fall back from CLC, depending on the current heuristic in `flash_attn/cute/interface.py`.