Bug in FlashMultiheadAttention? #239

@manuelknott

The FlashMultiheadAttention module in submodule.py appears to have been originally written for flash_attn_func (from
the flash_attn library), which was later replaced with torch.nn.functional.scaled_dot_product_attention. These two
functions have different dimension conventions:

  • flash_attn_func: expects (batch, seqlen, nheads, headdim) — computes attention over dim 1 (seqlen)
  • F.scaled_dot_product_attention: expects (batch, ..., seqlen, headdim) — computes attention over the second-to-last
    dim

The tensor layout was not updated when switching libraries in c1ee07e. As a result, SDPA interprets num_heads (4) as the sequence length, computing 4x4 cross-head attention instead of the intended self-attention along the disparity dimension as described in Section 3.2 of the paper.
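To make the layout mismatch concrete, here is a minimal sketch, not the repository's code. The tensor sizes are illustrative assumptions (only nheads=4 comes from this issue), and the variable names are hypothetical. It builds tensors in the flash_attn_func layout, passes them to F.scaled_dot_product_attention with and without a transpose, and checks which axis the attention actually mixes over.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Illustrative sizes only; nheads=4 matches the value quoted in this issue.
batch, seqlen, nheads, headdim = 2, 6, 4, 8

# Tensors in the flash_attn_func layout: (batch, seqlen, nheads, headdim).
q, k, v = (torch.randn(batch, seqlen, nheads, headdim) for _ in range(3))

# Passing this layout straight to SDPA: dim -2 is nheads, so heads attend to heads.
out_unfixed = F.scaled_dot_product_attention(q, k, v)

# Moving seqlen to dim -2 first: attention now runs along the sequence.
out_fixed = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)

# Manual check: the unfixed call equals a softmax over an (nheads x nheads) score matrix.
scores = q @ k.transpose(-1, -2) / math.sqrt(headdim)  # (batch, seqlen, nheads, nheads)
manual_cross_head = torch.softmax(scores, dim=-1) @ v

# Perturb the values at sequence position 0 only, then recompute both variants.
v_perturbed = v.clone()
v_perturbed[:, 0] += 1.0
out_unfixed_p = F.scaled_dot_product_attention(q, k, v_perturbed)
out_fixed_p = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v_perturbed.transpose(1, 2)
).transpose(1, 2)

# Unfixed: do the other positions change? If not, nothing flows along the sequence.
leaks_unfixed = not torch.allclose(out_unfixed_p[:, 1:], out_unfixed[:, 1:])
# Fixed: the other positions should change, as expected for sequence self-attention.
leaks_fixed = not torch.allclose(out_fixed_p[:, 1:], out_fixed[:, 1:])

print(tuple(scores.shape))
print(torch.allclose(out_unfixed, manual_cross_head, atol=1e-5))
print(leaks_unfixed, leaks_fixed)
```

In this sketch the untransposed call never mixes information between sequence positions, while the transposed call does. Whether the same holds for the tensors inside FlashMultiheadAttention depends on the shapes actually passed there.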

I tested the released checkpoint after adding a .transpose(1, 2) before and after the SDPA call, but it barely made a difference, which is curious. Question: was the released checkpoint trained with flash_attn_func or with F.scaled_dot_product_attention?
