For RoPEMaskedAttentionHead, the return_attn_weights branch currently reads:
    if return_attn_weights:
        attn_mask = torch.tril(torch.ones((m,m)), diagonal=0)
        attn_weights = torch.bmm(q_rotated, k_rotated.transpose(1,2)) / np.sqrt(d) + attn_mask
        attn_weights = F.softmax(attn_weights, dim=-1)
        return activations, attn_weights
    return activations
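A quick standalone check of the current behavior. This is a sketch: random scores stand in for the scaled q_rotated/k_rotated product, and m = 4 is arbitrary. Because the mask only adds 0 or 1 to the scores, every entry above the diagonal still receives a positive softmax weight.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
m = 4  # sequence length, arbitrary for this illustration
# Random stand-in for bmm(q_rotated, k_rotated^T) / sqrt(d), shape (batch=1, m, m)
scores = torch.randn(1, m, m)

# Mask exactly as in the current code: 1 on/below the diagonal, 0 above it
attn_mask = torch.tril(torch.ones((m, m)), diagonal=0)
weights = F.softmax(scores + attn_mask, dim=-1)

# Boolean selector for "future" positions (key index > query index)
future = torch.triu(torch.ones(m, m), diagonal=1).bool()
print(weights[0][future])  # weights the model puts on future keys
print(bool((weights[0][future] > 0).all()))  # True: future keys are not masked
```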
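I think the 0/1 mask needs to be converted into an additive mask (0 where attention is allowed, -inf for future positions) before it is added to the scores. As written, adding it only raises the allowed scores by 1, so after the softmax every query still attends to future keys. It should be: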
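    if return_attn_weights:
        # Turn the 0/1 lower-triangular matrix into an additive mask:
        # 0 where key position <= query position, -inf for future positions,
        # so softmax gives future positions exactly zero weight.
        attn_mask = torch.tril(torch.ones((m, m), device=q_rotated.device), diagonal=0)
        attn_mask = torch.zeros_like(attn_mask).masked_fill(attn_mask == 0, float('-inf'))
        attn_weights = torch.bmm(q_rotated, k_rotated.transpose(1, 2)) / np.sqrt(d) + attn_mask
        attn_weights = F.softmax(attn_weights, dim=-1)
        return activations, attn_weights
    return activations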
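A minimal self-contained sketch of the corrected masking, assuming batched (batch, m, d) query and key tensors as in the head above. causal_attn_weights is a hypothetical helper name, not part of the original code. It builds the additive mask with torch.triu over a tensor filled with -inf and, as a sanity check, compares the weighted values against torch.nn.functional.scaled_dot_product_attention with is_causal=True (requires PyTorch 2.0 or later).

```python
import math

import torch
import torch.nn.functional as F


def causal_attn_weights(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Causally masked softmax attention weights (hypothetical helper).

    q, k: (batch, m, d) tensors, e.g. the already-rotated queries and keys.
    Returns a (batch, m, m) tensor whose rows sum to 1.
    """
    _, m, d = q.shape
    # Additive mask: -inf strictly above the diagonal (future keys), 0 elsewhere.
    attn_mask = torch.triu(
        torch.full((m, m), float("-inf"), dtype=q.dtype, device=q.device), diagonal=1
    )
    scores = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(d) + attn_mask
    return F.softmax(scores, dim=-1)


torch.manual_seed(0)
q, k, v = (torch.randn(2, 5, 8) for _ in range(3))
w = causal_attn_weights(q, k)
print(w[0])  # zeros above the diagonal, each row sums to 1

# Sanity check against PyTorch's built-in causal attention (PyTorch 2.0+)
ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(w @ v, ref, atol=1e-5))
```

Building the mask from -inf rather than from a 0/1 matrix means exp(-inf) = 0 inside the softmax, so future positions get exactly zero weight instead of merely a smaller one.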