
[Bug] Incorrect range of K/V prefetch in multitoken_paged_attention kernel #698

@Risc-lt

Description

In multitoken_paged_attention_sm100_task_impl (include/mirage/persistent_kernel/tasks/blackwell/attention_sm100.cuh), the pipelined K/V prefetch loop is bounded by curr_iter_len (the length of the tile currently being consumed by the MMA) when it should be bounded by next_iter_len (the length of the tile actually being loaded). The kernel double-buffers its K/V tiles: while iteration i runs the MMA on a tile of curr_iter_len rows, it asynchronously prefetches iteration i+1's tile of next_iter_len rows into shared memory. The per-thread chunk index of the prefetch loop must therefore range over the elements of the tile being loaded, not the tile being consumed. Whenever the two lengths differ (for example, when the last tile of a sequence is shorter than a full tile), the current loop either copies rows beyond the end of the next tile or leaves part of it unloaded:

// Unfixed (WRONG): the bound uses the size of the tile being CONSUMED.
for (int chunk_idx = threadIdx.x;
     chunk_idx < curr_iter_len * HEAD_DIM / CP_CHUNK_SIZE;   // ← bug, should be next_iter_len
     chunk_idx += NUM_THREADS) {
    int dst_row = chunk_idx / (HEAD_DIM / CP_CHUNK_SIZE);
    int col     = (chunk_idx % (HEAD_DIM / CP_CHUNK_SIZE)) * CP_CHUNK_SIZE;
    // copies row `dst_row` of the NEXT tile from
    // page `page_indices[cp_finished_seq_len / PAGE_SIZE]`
    ...
}
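To make the effect of the wrong bound concrete, here is a small host-side C++ model of the thread-strided chunk loop above. It is not the kernel itself: simulated thread ids stand in for threadIdx.x, and the constants (HEAD_DIM = 128, CP_CHUNK_SIZE = 8, NUM_THREADS = 128, a 64-row maximum tile) as well as the helpers iter_len, prefetch_row_writes and check_bound are hypothetical, chosen only for illustration. The model counts which destination rows of the next tile the loop writes for a given bound and compares that against the rows that actually belong to the tile.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical sizes for illustration only; the real kernel's HEAD_DIM,
// CP_CHUNK_SIZE, NUM_THREADS and K/V tile length may differ.
constexpr int HEAD_DIM       = 128;
constexpr int CP_CHUNK_SIZE  = 8;   // elements moved per async copy
constexpr int NUM_THREADS    = 128;
constexpr int KV_TILE_LEN    = 64;  // hypothetical max rows per K/V tile
constexpr int CHUNKS_PER_ROW = HEAD_DIM / CP_CHUNK_SIZE;

// Rows in K/V tile `iter` when `seq_len` tokens are split into tiles of at
// most KV_TILE_LEN rows; only the last tile can be short (0 past the end).
int iter_len(int seq_len, int iter) {
  return std::max(0, std::min(KV_TILE_LEN, seq_len - iter * KV_TILE_LEN));
}

// Model of the strided prefetch loop: simulated thread `tid` plays the role
// of threadIdx.x and visits chunk_idx = tid, tid + NUM_THREADS, ... while
// chunk_idx < bound_rows * CHUNKS_PER_ROW. Returns, per destination row,
// how many CP_CHUNK_SIZE-wide chunks were written.
std::vector<int> prefetch_row_writes(int bound_rows) {
  assert(bound_rows >= 0 && bound_rows <= KV_TILE_LEN);
  std::vector<int> writes(bound_rows, 0);
  for (int tid = 0; tid < NUM_THREADS; ++tid) {
    for (int chunk_idx = tid; chunk_idx < bound_rows * CHUNKS_PER_ROW;
         chunk_idx += NUM_THREADS) {
      int dst_row = chunk_idx / CHUNKS_PER_ROW;
      writes[dst_row] += 1;
    }
  }
  return writes;
}

struct PrefetchCheck {
  int rows_past_tile;  // rows copied beyond the end of the next tile
  int rows_missing;    // rows of the next tile that were never copied
};

// Compare what a loop bounded by `bound_rows` copies against the
// `next_len` rows that actually belong to the tile being prefetched.
PrefetchCheck check_bound(int bound_rows, int next_len) {
  std::vector<int> writes = prefetch_row_writes(bound_rows);
  PrefetchCheck c{0, 0};
  for (int row = 0; row < bound_rows; ++row)
    if (writes[row] > 0 && row >= next_len) ++c.rows_past_tile;
  for (int row = 0; row < next_len; ++row)
    if (row >= bound_rows || writes[row] == 0) ++c.rows_missing;
  return c;
}
```

For example, a 150-token sequence splits into tiles of 64, 64 and 22 rows under these assumed sizes. At iteration 1, the buggy bound check_bound(64, 22) reports 42 rows copied past the end of the next tile, while the fixed bound check_bound(22, 22) reports none in either direction. The reverse case, check_bound(22, 64), shows 42 rows of the next tile left unloaded.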
