In `multitoken_paged_attention_sm100_task_impl` (`include/mirage/persistent_kernel/tasks/blackwell/attention_sm100.cuh`), the pipelined K/V prefetch loop uses `curr_iter_len` as its bound. That is the length of the tile currently being consumed by MMA. It should use `next_iter_len`, the length of the tile actually being loaded. The kernel double-buffers its K/V tiles: while iteration `i` runs MMA on a tile of length `curr_iter_len`, it asynchronously prefetches iteration `i+1`'s tile, of length `next_iter_len`, into shared memory. The prefetch loop's per-thread chunk index must therefore range over the elements of the tile being loaded, not the tile being consumed. Whenever the two lengths differ (for example, when the next tile is a shorter final tile), the current bound either copies rows past the end of the next tile or leaves some of its rows unloaded:
```cpp
// Unfixed (buggy): bounds the loop by the size of the tile being CONSUMED.
for (int chunk_idx = threadIdx.x;
     chunk_idx < curr_iter_len * HEAD_DIM / CP_CHUNK_SIZE;  // ← bug, should be next_iter_len
     chunk_idx += NUM_THREADS) {
  int dst_row = chunk_idx / (HEAD_DIM / CP_CHUNK_SIZE);
  int col = (chunk_idx % (HEAD_DIM / CP_CHUNK_SIZE)) * CP_CHUNK_SIZE;
  // copies row `dst_row` of the NEXT tile from
  // page `page_indices[cp_finished_seq_len / PAGE_SIZE]`
  ...
}
```