Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,6 @@ if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0)
# FLASHATTENTION_DISABLE_LOCAL
FLASHATTENTION_DISABLE_PYBIND
FLASHATTENTION_DISABLE_FP8 # TODO Enable FP8
FLASHATTENTION_VARLEN_ONLY # Custom flag to save on binary size
)
elseif(${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 12.0)
message(STATUS "FA3 is disabled because CUDA version is not 12.0 or later.")
Expand Down
21 changes: 21 additions & 0 deletions csrc/flash_attn/flash_api_torch_lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,21 @@ namespace FLASH_NAMESPACE {

////////////////////////////// From flash_api.cpp //////////////////////////////

std::vector<at::Tensor>
mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
std::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout,
const float softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool return_softmax,
std::optional<at::Generator> gen_);

std::vector<at::Tensor>
mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
Expand Down Expand Up @@ -105,6 +120,12 @@ mha_varlen_fwd_sparse(at::Tensor &q, // total_q x num_heads x head_size, total_
* Torch Library Registration
*/
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("fwd(Tensor! q, Tensor k, Tensor v, Tensor!? out, Tensor? alibi_slopes, "
"float p_dropout, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, "
"float softcap, bool return_softmax, Generator? gen)"
"-> Tensor[]");
ops.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));

ops.def("varlen_fwd(Tensor! q, Tensor k, Tensor v, Tensor!? out, Tensor cu_seqlens_q, "
"Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? leftpad_k, Tensor? block_table, Tensor? alibi_slopes, "
"int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale, bool zero_tensors, "
Expand Down
8 changes: 0 additions & 8 deletions hopper/static_switch.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,14 +117,6 @@
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
}()
#elif defined(FLASHATTENTION_VARLEN_ONLY)
#define VARLEN_SWITCH(COND, CONST_NAME, ...) \
[&] { \
TORCH_CHECK(COND, "This flash attention build only supports varlen " \
"(for build size reasons)."); \
constexpr static bool CONST_NAME = true; \
return __VA_ARGS__(); \
}()
#else
#define VARLEN_SWITCH BOOL_SWITCH
#endif
Expand Down
102 changes: 96 additions & 6 deletions tests/test_vllm_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@
from einops import rearrange, repeat

from vllm_flash_attn.flash_attn_interface import (
flash_attn_func,
flash_attn_varlen_func,
flash_attn_with_kvcache,
is_fa_version_supported
is_fa_version_supported,
fa_version_unsupported_reason
)

NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
Expand All @@ -23,15 +25,49 @@
# one value large enough to test overflow in index calculation.
# one value small enough to test the schema op check
NUM_BLOCKS = [32768, 2048]
VERSIONS = \
([2] if is_fa_version_supported(2) else []) + \
([3] if is_fa_version_supported(3) else [])
VERSIONS = [2, 3]


def construct_local_mask(
seqlen_q,
seqlen_k,
window_size=(-1, -1), # -1 means infinite window size
query_padding_mask=None,
key_padding_mask=None,
device=None,
key_leftpad=None,
):
row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
if key_leftpad is not None:
key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
sk = (
seqlen_k
if key_padding_mask is None
else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
)
sq = (
seqlen_q
if query_padding_mask is None
else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
)
if window_size[0] < 0:
return col_idx > row_idx + sk - sq + window_size[1]
else:
sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
return torch.logical_or(
col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
col_idx < row_idx + sk - sq - window_size[0],
)


def ref_attn(
q,
k,
v,
scale,
query_padding_mask=None,
key_padding_mask=None,
attn_bias=None,
Expand Down Expand Up @@ -74,10 +110,11 @@ def ref_attn(
k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
d = q.shape[-1]
q *= scale
if not reorder_ops:
scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
scores = torch.einsum("bthd,bshd->bhts", q, k)
else:
scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
scores = torch.einsum("bthd,bshd->bhts", q, k)

lse_ref = scores.logsumexp(dim=-1)

Expand Down Expand Up @@ -178,6 +215,59 @@ def ref_paged_attn(
return torch.cat(outputs, dim=0)


@pytest.mark.parametrize("seq_len", [1, 10, 256, 533])
@pytest.mark.parametrize("batch_size", [1, 7, 32])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("fa_version", VERSIONS)
@torch.inference_mode()
def test_flash_attn(
seq_len: int,
batch_size: int,
num_heads: Tuple[int, int],
head_size: int,
dtype: torch.dtype,
soft_cap: Optional[float],
fa_version: int,
) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
scale = head_size**-0.5

query = torch.randn(
batch_size, seq_len, num_query_heads, head_size, dtype=dtype)
key = torch.randn(
batch_size, seq_len, num_kv_heads, head_size, dtype=dtype)
value = torch.randn(
batch_size, seq_len, num_kv_heads, head_size, dtype=dtype)

output = flash_attn_func(
query,
key,
value,
softmax_scale=scale,
causal=True,
softcap=soft_cap if soft_cap is not None else 0,
fa_version=fa_version,
)

ref_output, _ = ref_attn(
q=query,
k=key,
v=value,
scale=scale,
causal=True,
softcap=soft_cap if soft_cap is not None else 0,
)
torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"


@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
Expand Down
1 change: 1 addition & 0 deletions vllm_flash_attn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

# Use relative import to support build-from-source installation in vLLM
from .flash_attn_interface import (
flash_attn_func,
flash_attn_varlen_func,
flash_attn_with_kvcache,
sparse_attn_func,
Expand Down
109 changes: 109 additions & 0 deletions vllm_flash_attn/flash_attn_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,115 @@ def maybe_contiguous(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x


def flash_attn_func(
q,
k,
v,
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
softcap=0.0, # 0.0 means deactivated
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
*,
return_softmax_lse=False,
out=None,
fa_version: int = DEFAULT_FA_VERSION,
):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.

If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

Arguments:
q: (batch_size, seqlen, nheads, headdim)
k: (batch_size, seqlen, nheads_k, headdim)
v: (batch_size, seqlen, nheads_k, headdim)
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (batch_size, seqlen, nheads, headdim).
softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
"""
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)

if fa_version == 2:
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, softmax_lse = torch.ops._vllm_fa2_C.fwd(
q,
k,
v,
out,
alibi_slopes,
dropout_p,
softmax_scale,
causal,
window_size[0], window_size[1],
softcap,
return_softmax_lse and dropout_p > 0,
None,
)
elif fa_version == 3:
out, softmax_lse, _, _ = torch.ops._vllm_fa3_C.fwd(
q, k, v,
None, None, # k_new, v_new
out,
None, None, # cu_seqlens_q, cu_seqlens_k
None, # cu_seqlens_k_new
None, None, # seqused_q, seqused_k
None, None, # max_seqlen_q, max_seqlen_k
None,
alibi_slopes,
None, # kv_batch_idx
None, None, # rotary_cos, rotary_sin
None, None, None, # q_descale, k_descale, v_descale
softmax_scale,
causal,
window_size[0], window_size[1],
0, # sink_token_length
softcap,
True, # rotary_interleaved
0, # num_splits
None, # pack_gqa
0, # sm_margin
)

return (out, softmax_lse) if return_softmax_lse else out


def flash_attn_varlen_func(
q,
k,
Expand Down