The FlashMultiheadAttention module in submodule.py appears to have been written originally for flash_attn_func (from
the flash_attn library) and later switched to torch.nn.functional.scaled_dot_product_attention. These two
functions use different dimension conventions:
flash_attn_func: expects (batch, seqlen, nheads, headdim) and computes attention over dim 1 (seqlen)
F.scaled_dot_product_attention: expects (batch, ..., seqlen, headdim) and computes attention over the second-to-last dim
The tensor layout was not updated when the library was switched in c1ee07e. As a result, SDPA treats num_heads (4) as the sequence length and computes a 4x4 attention across heads at each position, instead of the intended self-attention along the disparity dimension described in Section 3.2 of the paper.
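A minimal sketch that reproduces the effect in isolation, assuming random tensors in the flash_attn_func layout (batch, seqlen, nheads, headdim) with 4 heads; the shapes are illustrative and not taken from the model. Perturbing one sequence position leaves the other positions untouched, while perturbing one head changes the output of the other heads, i.e. SDPA mixes heads rather than sequence positions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, L, H, D = 1, 8, 4, 16  # batch, seqlen (disparity), nheads, headdim (illustrative)

# Tensors in the flash_attn_func layout: (batch, seqlen, nheads, headdim)
q = torch.randn(B, L, H, D)
k = torch.randn(B, L, H, D)
v = torch.randn(B, L, H, D)

# Passed straight to SDPA, dim -2 is nheads, so attention runs over heads
out = F.scaled_dot_product_attention(q, k, v)

# Perturb the values at sequence position 0 only
v_seq = v.clone()
v_seq[:, 0] += 10.0
out_seq = F.scaled_dot_product_attention(q, k, v_seq)

# Perturb the values of head 0 only (at every sequence position)
v_head = v.clone()
v_head[:, :, 0] += 10.0
out_head = F.scaled_dot_product_attention(q, k, v_head)

# Does position 1 see position 0? Does head 1 see head 0?
seq_mixing = not torch.allclose(out_seq[:, 1], out[:, 1])
head_mixing = not torch.allclose(out_head[:, :, 1], out[:, :, 1])
print(seq_mixing, head_mixing)  # expected: False True
```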
I tested the released checkpoint with a .transpose(1, 2) added before and after the SDPA call, and it barely made a difference, which is curious.
Q: Was the released checkpoint trained with flash_attn_func or with F.scaled_dot_product_attention?
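For reference, the transpose-before-and-after fix can be sketched as below. The helper names sdpa_seq_first and reference_attention are hypothetical (not from submodule.py); the sketch assumes inputs in the flash_attn_func layout and checks the result against an explicit softmax(QK^T / sqrt(headdim)) V computed over the sequence dimension.

```python
import math

import torch
import torch.nn.functional as F


def sdpa_seq_first(q, k, v):
    """Hypothetical helper: SDPA for tensors laid out as (batch, seqlen, nheads, headdim).

    SDPA attends over dim -2, so heads are moved in front of seqlen before
    the call and moved back afterwards.
    """
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
    )
    return out.transpose(1, 2)


def reference_attention(q, k, v):
    """Explicit attention over the seqlen dim, same (B, L, H, D) layout."""
    scores = torch.einsum("blhd,bmhd->bhlm", q, k) / math.sqrt(q.shape[-1])
    weights = scores.softmax(dim=-1)  # normalize over key positions m
    return torch.einsum("bhlm,bmhd->blhd", weights, v)


torch.manual_seed(0)
q, k, v = (torch.randn(2, 16, 4, 8) for _ in range(3))
out = sdpa_seq_first(q, k, v)
print(out.shape)  # torch.Size([2, 16, 4, 8])
print(torch.allclose(out, reference_attention(q, k, v), atol=1e-5))  # True
```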