[Bug] SM100 attention/mla_prefill use Ampere-era mma instead of Blackwell-native tcgen05.mma (blocks GQA ≥ 8:1) #702

@aganhui

Blocking Bug

Models with GQA ratio ≥ 8:1 (e.g., Qwen3-32B with 64 Q heads / 8 KV heads) fail to compile on Blackwell (B200) due to shared memory overflow in attention_sm100.cuh. The static_assert at line 179 triggers because total smem (~232 KB) exceeds Blackwell's dynamic smem budget (201 KB).

Related: #700 (Qwen3-32B feature request — point 5 states "the attention kernel should handle this without modification," but this is not the case.)


Problem

The SM100 attention and MLA_prefill kernels use Ampere-era mma.m16n8k16 instructions instead of Blackwell-native tcgen05.mma. This is the root cause of the smem overflow that blocks GQA ≥ 8:1 models.

Why mma.m16n8k16 causes smem overflow

With mma.m16n8k16, MMA accumulators reside in thread-private registers. After the MMA, a cross-warp reduction is needed to produce the final output, which requires spilling accumulators to shared memory. This is the S_O_BUFFER in attention_sm100.cuh:

// attention_sm100.cuh:79
constexpr int MMA_ITERS_M = (MAX_TOKENS * NUM_QO_PER_KV + 15) / 16;

// attention_sm100.cuh:176-177
constexpr size_t S_O_BUFFER_SIZE =
    sizeof(float) * MMA_ITERS_M * NUM_THREADS * 64;

For GQA 8:1 with MAX_TOKENS=8:

  • NUM_QO_PER_KV = 8
  • MMA_ITERS_M = (8 × 8 + 15) / 16 = 4
  • S_O_BUFFER_SIZE = 4 × 256 × 64 × 4 = 256 KB (dominates total smem)
  • Total smem ≈ 232 KB → exceeds 201 KB budget → static_assert failure
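The arithmetic above can be checked at compile time. The following is a minimal sketch that mirrors the two expressions quoted from attention_sm100.cuh; the helper names `mma_iters_m` and `s_o_buffer_bytes` are hypothetical, and NUM_THREADS is left as a parameter because its actual value is not shown in the excerpt (the 256 in the bullet above is what the multiplication implies).

```cpp
#include <cstddef>

constexpr std::size_t KB = 1024;

// Same ceil-divide as MMA_ITERS_M in the quoted source (16 = MMA M dimension).
constexpr int mma_iters_m(int max_tokens, int num_qo_per_kv) {
  return (max_tokens * num_qo_per_kv + 15) / 16;
}

// Same product as S_O_BUFFER_SIZE in the quoted source.
// num_threads is a parameter here; its real value is an assumption.
constexpr std::size_t s_o_buffer_bytes(int max_tokens, int num_qo_per_kv,
                                       int num_threads) {
  return sizeof(float) * mma_iters_m(max_tokens, num_qo_per_kv) *
         num_threads * 64;
}

// GQA 8:1, MAX_TOKENS = 8
static_assert(mma_iters_m(8, 8) == 4, "four MMA iterations along M");
static_assert(s_o_buffer_bytes(8, 8, 256) == 256 * KB, "256 threads");
static_assert(s_o_buffer_bytes(8, 8, 128) == 128 * KB, "128 threads");

// GQA 4:1 with the same MAX_TOKENS halves the buffer.
static_assert(s_o_buffer_bytes(8, 4, 256) == 128 * KB, "GQA 4:1");
```

The buffer scales linearly with both MAX_TOKENS × NUM_QO_PER_KV (rounded up to a multiple of 16) and the thread count, which is why raising the GQA ratio from 4:1 to 8:1 doubles it.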

Why tcgen05.mma would fix this

Blackwell's tcgen05.mma operates on much larger tiles (e.g., m128n128k16, versus m16n8k16 for the Ampere-era instruction) and keeps its accumulators in TMEM (Tensor Memory), which is shared across the CTA. Since TMEM is architecturally separate from shared memory, the cross-warp smem spill (S_O_BUFFER) is eliminated entirely.

| Aspect | mma.m16n8k16 (Ampere) | tcgen05.mma (Blackwell) |
|---|---|---|
| MMA tile | 16×8×16 per warp | 128×128×16 (CTA-level) |
| Accumulator | Thread registers | TMEM (CTA-shared) |
| Cross-warp reduction | Required → S_O_BUFFER in smem | Not needed (TMEM is CTA-shared) |
| Async execution | No | Yes |
| Smem impact | S_O_BUFFER = 128–256 KB | Eliminated |
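To give a rough sense of how an accumulator tile would fit in TMEM, here is a small sizing sketch. The TMEM geometry (128 lanes × 512 columns of 32-bit cells per SM) is taken from NVIDIA's PTX description of Tensor Memory and should be treated as an assumption here, as should the simplification that an FP32 accumulator with M = 128 occupies one column per N element. The helper names are hypothetical.

```cpp
#include <cstddef>

// Assumed TMEM geometry (per SM), following the PTX documentation.
constexpr std::size_t TMEM_LANES = 128;     // rows
constexpr std::size_t TMEM_COLUMNS = 512;   // 32-bit cells per row
constexpr std::size_t TMEM_CELL_BYTES = 4;

constexpr std::size_t tmem_bytes() {
  return TMEM_LANES * TMEM_COLUMNS * TMEM_CELL_BYTES;
}

// Simplified model: an FP32 accumulator of shape 128 x n uses n columns
// across all 128 lanes.
constexpr std::size_t accumulator_columns(std::size_t n) { return n; }
constexpr std::size_t accumulator_bytes(std::size_t n) {
  return TMEM_LANES * accumulator_columns(n) * TMEM_CELL_BYTES;
}

static_assert(tmem_bytes() == 256 * 1024, "TMEM holds 256 KB under this model");
static_assert(accumulator_bytes(128) == 64 * 1024,
              "a 128x128 FP32 accumulator takes 64 KB");
static_assert(accumulator_columns(128) * 4 == TMEM_COLUMNS,
              "a quarter of the TMEM columns");
```

Under these assumptions a 128×128 FP32 accumulator takes a quarter of TMEM and no shared memory, which is the property the table relies on.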

Affected kernels

| Kernel | File | Issue |
|---|---|---|
| attention_sm100 | include/mirage/persistent_kernel/tasks/blackwell/attention_sm100.cuh | Uses mma.m16n8k16 + includes from tasks/ampere/ |
| mla_prefill_sm100 | include/mirage/persistent_kernel/tasks/blackwell/mla_prefill_sm100.cuh | Uses mma.m16n8k16 (PF_MMA_M=16, PF_MMA_N=8, PF_MMA_K=16) |

Secondary Bug: MAX_TOKENS is hardcoded to 8

In task_register.cc:1946, the SM100 attention instantiation passes only 9 template arguments (bfloat16 plus eight `$` placeholders), so MAX_TOKENS falls back to its default of 8:

// task_register.cc:1946
code.e("kernel::multitoken_paged_attention_sm100_task_impl<bfloat16, $, $, "
       "$, $, "
       "$, $, $, $>(",
       num_q_heads / num_kv_heads,
       1,             // KV_CACHE_STRIDE
       kv_stride,
       qkv_stride,
       output_size,
       head_dim,
       max_seq_len,
       page_size);
       // MAX_TOKENS missing → defaults to 8

By contrast, the Hopper version (line 1112) dynamically passes max_tokens from the tensor shape. The SM100 path should do the same.
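A sketch of what the pass-through could look like is shown below. It is self-contained, so `fill_placeholders` is a hypothetical stand-in for `code.e` (it only substitutes each `$` with the next argument), and `sm100_attention_call` is a hypothetical wrapper. It also assumes that MAX_TOKENS is the next template parameter after the nine already passed, which the excerpt suggests but does not show.

```cpp
#include <cassert>
#include <initializer_list>
#include <string>

// Hypothetical stand-in for code.e(): replaces each '$' with the next argument.
inline std::string fill_placeholders(const std::string& fmt,
                                     std::initializer_list<int> args) {
  std::string out;
  auto it = args.begin();
  for (char c : fmt) {
    if (c == '$') {
      assert(it != args.end() && "more placeholders than arguments");
      out += std::to_string(*it++);
    } else {
      out += c;
    }
  }
  assert(it == args.end() && "more arguments than placeholders");
  return out;
}

// Same format string as the SM100 path, with a ninth '$' added for MAX_TOKENS.
inline std::string sm100_attention_call(int num_q_heads, int num_kv_heads,
                                        int kv_stride, int qkv_stride,
                                        int output_size, int head_dim,
                                        int max_seq_len, int page_size,
                                        int max_tokens) {
  return fill_placeholders(
      "kernel::multitoken_paged_attention_sm100_task_impl<bfloat16, $, $, "
      "$, $, "
      "$, $, $, $, $>(",
      {num_q_heads / num_kv_heads, 1, kv_stride, qkv_stride, output_size,
       head_dim, max_seq_len, page_size,
       max_tokens});  // assumed to come from the tensor shape, as on Hopper
}
```

Counting placeholders against arguments in the helper is deliberate: a mismatch between the two is exactly the kind of silent omission that lets a defaulted template parameter go unnoticed.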


Short-Term Workaround

For immediate unblocking of GQA ≥ 8:1 models, the smem budget can be increased for CC ≥ 100 in runtime_header.h. Blackwell B200 has 228 KB total shared memory per SM. The current code only has a CC ≥ 90 branch:

// runtime_header.h — current
#if MPK_TARGET_CC >= 90
constexpr int MAX_DYNAMIC_SHARED_MEMORY_SIZE =
    207 * 1024 - WORKER_RESERVED_STATIC_SHARED_MEMORY_SIZE;

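Adding a CC ≥ 100 branch with a higher budget (e.g., 225 KB) would give the current mma.m16n8k16 implementation more headroom for GQA 8:1, at the cost of fewer concurrent CTAs per SM. The kernel's total smem still has to fit under the new budget, so the ~232 KB estimated above would need some trimming as well. This is a stopgap; the proper fix is migrating to tcgen05.mma.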
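The extra branch could be sketched as follows. This is not the actual runtime_header.h: the fallback definition of MPK_TARGET_CC and the 6 KB value for WORKER_RESERVED_STATIC_SHARED_MEMORY_SIZE are assumptions made so the snippet compiles on its own (6 KB is inferred from 207 KB minus the 201 KB budget quoted earlier), and applying the reserved-size subtraction to the 225 KB figure is likewise an assumption that follows the existing pattern.

```cpp
// Assumed for this sketch only; the real values come from the build system
// and from runtime_header.h.
#ifndef MPK_TARGET_CC
#define MPK_TARGET_CC 100
#endif
constexpr int WORKER_RESERVED_STATIC_SHARED_MEMORY_SIZE = 6 * 1024;

#if MPK_TARGET_CC >= 100
// Hypothetical Blackwell branch using the 225 KB figure suggested above.
constexpr int MAX_DYNAMIC_SHARED_MEMORY_SIZE =
    225 * 1024 - WORKER_RESERVED_STATIC_SHARED_MEMORY_SIZE;
#elif MPK_TARGET_CC >= 90
// Existing branch quoted from runtime_header.h.
constexpr int MAX_DYNAMIC_SHARED_MEMORY_SIZE =
    207 * 1024 - WORKER_RESERVED_STATIC_SHARED_MEMORY_SIZE;
#else
#error "this sketch only covers CC >= 90"
#endif

// Guard against exceeding the per-block shared memory limit, which the CUDA
// programming guide lists as 227 KB for compute capability 9.0 and 10.0.
static_assert(MAX_DYNAMIC_SHARED_MEMORY_SIZE +
                      WORKER_RESERVED_STATIC_SHARED_MEMORY_SIZE <=
                  227 * 1024,
              "static + dynamic smem must fit in one thread block");
```

The CC ≥ 100 test has to come before the CC ≥ 90 test, since a Blackwell target satisfies both conditions.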


Proposed Fix

  1. Migrate SM100 attention to tcgen05.mma: Replace mma.m16n8k16 + register accumulators + S_O_BUFFER with tcgen05.mma + TMEM accumulators. This eliminates the S_O_BUFFER entirely and resolves the smem overflow.
  2. Migrate SM100 MLA_prefill to tcgen05.mma: Same approach for the MLA_prefill kernel.
  3. Fix MAX_TOKENS pass-through: Pass max_tokens from tensor shape in task_register.cc, matching the Hopper path.
  4. Add CC ≥ 100 smem budget: As a short-term workaround, add a separate branch in runtime_header.h for CC ≥ 100.

Environment

  • GPU: NVIDIA B200 (Blackwell SM100a)
  • Branch: mpk (default)
  • CUDA: 13.0
  • Model: Qwen3-32B (GQA 8:1, 64 Q heads / 8 KV heads, head_dim=128)
