Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ __pycache__/
# Distribution / packaging
bin/
build/
cmake-build-*/
develop-eggs/
dist/
eggs/
Expand Down
323 changes: 226 additions & 97 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ project(vllm_flash_attn LANGUAGES CXX)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_EXTENSIONS OFF)

set(FA2_ENABLED ON)
set(FA3_ENABLED ON)
option(FA2_ENABLED "Enable building Flash Attention 2" ON)
option(FA3_ENABLED "Enable building Flash Attention 3 (CUDA only)" ON)

option(VLLM_FA_API_ONLY "Only build backend kernels used by vLLM (varlen_fwd and fwd_kvcache)" ON)

# CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... (used by setup.py)
set(VLLM_TARGET_DEVICE "cuda" CACHE STRING "Target device backend for vLLM")
Expand Down Expand Up @@ -37,6 +39,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11
# Likely should also be in sync with the vLLM version.
#
set(TORCH_SUPPORTED_VERSION_CUDA "2.4.0")
set(TORCH_SUPPORTED_VERSION_ROCM "2.5.1")

find_python_constrained_versions(${PYTHON_SUPPORTED_VERSIONS})

Expand Down Expand Up @@ -91,7 +94,19 @@ if (NOT HIP_FOUND AND CUDA_FOUND)
"${CUDA_SUPPORTED_ARCHS}" "${CUDA_ARCHS}")
message(STATUS "CUDA supported target architectures: ${CUDA_ARCHS}")
elseif (HIP_FOUND)
message(FATAL_ERROR "ROCm build is not currently supported for vllm-flash-attn.")
set(VLLM_GPU_LANG "HIP")

# Importing torch recognizes and sets up some HIP/ROCm configuration but does
# not let cmake recognize .hip files. In order to get cmake to understand the
# .hip extension automatically, HIP must be enabled explicitly.
enable_language(HIP)

# ROCm 5.X and 6.X
if (ROCM_VERSION_DEV_MAJOR GREATER_EQUAL 5 AND
NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM})
message(WARNING "Pytorch version >= ${TORCH_SUPPORTED_VERSION_ROCM} "
"expected for ROCm build, saw ${Torch_VERSION} instead.")
endif ()
else ()
message(FATAL_ERROR "Can't find CUDA or HIP installation.")
endif ()
Expand All @@ -110,129 +125,243 @@ if (NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_FA_GPU_FLAGS "--threads=${NVCC_THREADS}")
endif ()

# Replace instead of appending, nvcc doesn't like duplicate -O flags.
string(REPLACE "-O2" "-O3" CMAKE_${VLLM_GPU_LANG}_FLAGS_RELWITHDEBINFO "${CMAKE_${VLLM_GPU_LANG}_FLAGS_RELWITHDEBINFO}")

# Other flags
list(APPEND VLLM_FA_GPU_FLAGS --expt-relaxed-constexpr --expt-extended-lambda --use_fast_math)

# If CUTLASS is compiled on NVCC >= 12.5, it by default uses
# cudaGetDriverEntryPointByVersion as a wrapper to avoid directly calling the
# driver API. This causes problems when linking with earlier versions of CUDA.
# Setting this variable sidesteps the issue by calling the driver directly.
list(APPEND VLLM_FA_GPU_FLAGS -DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
if (VLLM_GPU_LANG STREQUAL "CUDA")
# Other flags
list(APPEND VLLM_FA_GPU_FLAGS --expt-relaxed-constexpr --expt-extended-lambda --use_fast_math)

# Replace instead of appending, nvcc doesn't like duplicate -O flags.
string(REPLACE "-O2" "-O3" CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO}")
# If CUTLASS is compiled on NVCC >= 12.5, it by default uses
# cudaGetDriverEntryPointByVersion as a wrapper to avoid directly calling the
# driver API. This causes problems when linking with earlier versions of CUDA.
# Setting this variable sidesteps the issue by calling the driver directly.
list(APPEND VLLM_FA_GPU_FLAGS -DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)

#
# _C extension
#
#
# _C extension
#

if (FA2_ENABLED)
file(GLOB FA2_GEN_SRCS "csrc/flash_attn/src/flash_fwd_*.cu")
if (FA2_ENABLED)
file(GLOB FA2_GEN_SRCS "csrc/flash_attn/src/flash_fwd_*.cu")

# For CUDA we set the architectures on a per file basis
if (VLLM_GPU_LANG STREQUAL "CUDA")
# For CUDA we set the architectures on a per file basis
cuda_archs_loose_intersection(FA2_ARCHS "8.0;9.0" "${CUDA_ARCHS}")
message(STATUS "FA2_ARCHS: ${FA2_ARCHS}")

set_gencode_flags_for_srcs(
SRCS "${FA2_GEN_SRCS}"
CUDA_ARCHS "${FA2_ARCHS}")
endif()

define_gpu_extension_target(
_vllm_fa2_C
DESTINATION vllm_flash_attn
LANGUAGE ${VLLM_GPU_LANG}
SOURCES
csrc/flash_attn/flash_api.cpp
csrc/flash_attn/flash_api_sparse.cpp
csrc/flash_attn/flash_api_torch_lib.cpp
${FA2_GEN_SRCS}
COMPILE_FLAGS ${VLLM_FA_GPU_FLAGS}
USE_SABI 3
WITH_SOABI)

target_include_directories(_vllm_fa2_C PRIVATE
csrc/flash_attn
csrc/flash_attn/src
csrc/common
csrc/cutlass/include)

# custom definitions
target_compile_definitions(_vllm_fa2_C PRIVATE
FLASHATTENTION_DISABLE_BACKWARD
FLASHATTENTION_DISABLE_DROPOUT
# FLASHATTENTION_DISABLE_ALIBI
# FLASHATTENTION_DISABLE_SOFTCAP
FLASHATTENTION_DISABLE_UNEVEN_K
# FLASHATTENTION_DISABLE_LOCAL
FLASHATTENTION_DISABLE_PYBIND
)
endif ()
SRCS "${FA2_GEN_SRCS}"
CUDA_ARCHS "${FA2_ARCHS}")

define_gpu_extension_target(
_vllm_fa2_C
DESTINATION vllm_flash_attn
LANGUAGE ${VLLM_GPU_LANG}
SOURCES
csrc/flash_attn/flash_api.cpp
csrc/flash_attn/flash_api_sparse.cpp
csrc/flash_attn/flash_api_torch_lib.cpp
${FA2_GEN_SRCS}
COMPILE_FLAGS ${VLLM_FA_GPU_FLAGS}
USE_SABI 3
WITH_SOABI)

target_include_directories(_vllm_fa2_C PRIVATE
csrc/flash_attn
csrc/flash_attn/src
csrc/common
csrc/cutlass/include)

# custom definitions
target_compile_definitions(_vllm_fa2_C PRIVATE
FLASHATTENTION_DISABLE_BACKWARD
FLASHATTENTION_DISABLE_DROPOUT
# FLASHATTENTION_DISABLE_ALIBI
# FLASHATTENTION_DISABLE_SOFTCAP
FLASHATTENTION_DISABLE_UNEVEN_K
# FLASHATTENTION_DISABLE_LOCAL
FLASHATTENTION_DISABLE_PYBIND
)
endif ()

# FA3 requires CUDA 12.0 or later
if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0)
# BF16 source files
file(GLOB FA3_BF16_GEN_SRCS
file(GLOB FA3_BF16_GEN_SRCS
"hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
file(GLOB FA3_BF16_GEN_SRCS_
file(GLOB FA3_BF16_GEN_SRCS_
"hopper/instantiations/flash_fwd_*_bf16_*_sm80.cu")
list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})
# FP16 source files
file(GLOB FA3_FP16_GEN_SRCS
file(GLOB FA3_FP16_GEN_SRCS
"hopper/instantiations/flash_fwd_hdimall_fp16*_sm90.cu")
file(GLOB FA3_FP16_GEN_SRCS_
"hopper/instantiations/flash_fwd_*_fp16_*_sm80.cu")
list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})

# TODO add fp8 source files when FP8 is enabled
set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS})
# TODO add fp8 source files when FP8 is enabled
set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS})

# For CUDA we set the architectures on a per file basis
if (VLLM_GPU_LANG STREQUAL "CUDA")
# For CUDA we set the architectures on a per file basis
cuda_archs_loose_intersection(FA3_ARCHS "8.0;9.0a" "${CUDA_ARCHS}")
message(STATUS "FA3_ARCHS: ${FA3_ARCHS}")

set_gencode_flags_for_srcs(
SRCS "${FA3_GEN_SRCS}"
CUDA_ARCHS "${FA3_ARCHS}")
SRCS "${FA3_GEN_SRCS}"
CUDA_ARCHS "${FA3_ARCHS}")
set_gencode_flags_for_srcs(
SRCS "hopper/flash_fwd_combine.cu"
CUDA_ARCHS "${FA3_ARCHS}")
SRCS "hopper/flash_fwd_combine.cu"
CUDA_ARCHS "${FA3_ARCHS}")


define_gpu_extension_target(
_vllm_fa3_C
DESTINATION vllm_flash_attn
LANGUAGE ${VLLM_GPU_LANG}
SOURCES
hopper/flash_fwd_combine.cu
hopper/flash_api.cpp
hopper/flash_api_torch_lib.cpp
${FA3_GEN_SRCS}
COMPILE_FLAGS ${VLLM_FA_GPU_FLAGS}
ARCHITECTURES ${VLLM_FA_GPU_ARCHES}
USE_SABI 3
WITH_SOABI)

target_include_directories(_vllm_fa3_C PRIVATE
hopper
csrc/common
csrc/cutlass/include)


# custom definitions
target_compile_definitions(_vllm_fa3_C PRIVATE
$<IF:${VLLM_FA_API_ONLY},FLASHATTENTION_DISABLE_BACKWARD,> # vLLM API does not require bwd
FLASHATTENTION_DISABLE_DROPOUT
# FLASHATTENTION_DISABLE_ALIBI
# FLASHATTENTION_DISABLE_SOFTCAP
FLASHATTENTION_DISABLE_UNEVEN_K
# FLASHATTENTION_DISABLE_LOCAL
FLASHATTENTION_DISABLE_PYBIND
FLASHATTENTION_DISABLE_FP8 # TODO Enable FP8
FLASHATTENTION_VARLEN_ONLY # Custom flag to save on binary size
)
elseif(${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 12.0)
message(STATUS "FA3 is disabled because CUDA version is not 12.0 or later.")
endif()
elseif (VLLM_GPU_LANG STREQUAL "HIP")
# CLang on ROCm
# --offload-compress required to keep size under 2GB (fails with errs)
list(APPEND VLLM_FA_GPU_FLAGS -ffast-math -fgpu-flush-denormals-to-zero --offload-compress)

# CK fails to compile below O2 as inlining is needed for certain inline assembly
string(REGEX REPLACE "-O(g|0)?" "-O2" CMAKE_HIP_FLAGS_DEBUG "${CMAKE_HIP_FLAGS_DEBUG}")

# Generate FA from CK example kernels
# Generate at configure time so we can glob
set(FA_GENERATED_OUTDIR ${CMAKE_CURRENT_BINARY_DIR}/gen)
set(CK_GEN_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/csrc/composable_kernel/example/ck_tile/01_fmha/generate.py)
file(MAKE_DIRECTORY ${FA_GENERATED_OUTDIR})

# bwd kernels not required for vLLM API
set(ROCm_CK_KERNELS "fwd" "fwd_appendkv" "fwd_splitkv")
if(NOT VLLM_FA_API_ONLY)
list(APPEND ROCm_CK_KERNELS "bwd")
endif()

# TODO(luka) only run if required
foreach (KERNEL IN LISTS ROCm_CK_KERNELS)
execute_process(
COMMAND
"${Python_EXECUTABLE}" "${CK_GEN_SCRIPT}" "-d" "${KERNEL}" "--output_dir" "${FA_GENERATED_OUTDIR}" "--receipt" "2"
RESULT_VARIABLE PYTHON_ERROR_CODE
ERROR_VARIABLE PYTHON_STDERR
OUTPUT_VARIABLE PYTHON_OUT
)
if (NOT PYTHON_ERROR_CODE EQUAL 0)
message(FATAL_ERROR "Cannot generate Python sources with error: ${PYTHON_ERROR_CODE}\n
stdout:${PYTHON_OUT}\n
stderr:${PYTHON_STDERR}")
endif ()
endforeach ()

file(GLOB FA3_GEN_SRCS "${FA_GENERATED_OUTDIR}/fmha_*wd*.cpp")
# Copy cpp files to hip because running hipify on them is a no-op as they only contain instantiations
foreach(FILE ${FA3_GEN_SRCS})
string(REGEX REPLACE "\.cpp$" ".hip" FILE_HIP ${FILE})
file(COPY_FILE ${FILE} ${FILE_HIP})
list(APPEND FA3_GEN_SRCS_CU ${FILE_HIP})
endforeach ()

# These files are "converted" to .cu before being passed to torch.build_extension on upstream.
# We need to do the same so that hipify treats them correctly. We copy the files in the source tree like upstream.
set(VLLM_FA2_CPP_CU_SRCS
# csrc/flash_attn_ck/flash_api.cpp # only contains declarations & PyBind
csrc/flash_attn_ck/flash_common.cpp
csrc/flash_attn_ck/mha_fwd_kvcache.cpp
csrc/flash_attn_ck/mha_varlen_fwd.cpp
)

# The rest are not required for vLLM API
if(NOT VLLM_FA_API_ONLY)
list(APPEND VLLM_FA2_CPP_CU_SRCS
csrc/flash_attn_ck/mha_bwd.cpp
csrc/flash_attn_ck/mha_fwd.cpp
csrc/flash_attn_ck/mha_varlen_bwd.cpp>
)
endif ()

foreach(CPP_FILE ${VLLM_FA2_CPP_CU_SRCS})
string(REGEX REPLACE "\.cpp$" ".cu" CU_FILE ${CPP_FILE})
set(CU_FILE_ABS ${CMAKE_CURRENT_SOURCE_DIR}/${CU_FILE})
set(CPP_FILE_ABS ${CMAKE_CURRENT_SOURCE_DIR}/${CPP_FILE})
add_custom_command(
OUTPUT ${CU_FILE_ABS}
COMMAND ${CMAKE_COMMAND} -E copy ${CPP_FILE_ABS} ${CU_FILE_ABS}
DEPENDS ${CPP_FILE_ABS}
COMMENT "Copying ${CPP_FILE} to ${CU_FILE_ABS}"
)
list(APPEND VLLM_FA2_CU_SRCS ${CU_FILE}) # relative to source dir
endforeach ()

# This target automatically depends on the copy by depending on copied files
define_gpu_extension_target(
_vllm_fa3_C
DESTINATION vllm_flash_attn
LANGUAGE ${VLLM_GPU_LANG}
SOURCES
hopper/flash_fwd_combine.cu
hopper/flash_api.cpp
hopper/flash_api_torch_lib.cpp
${FA3_GEN_SRCS}
COMPILE_FLAGS ${VLLM_FA_GPU_FLAGS}
ARCHITECTURES ${VLLM_FA_GPU_ARCHES}
USE_SABI 3
WITH_SOABI)

target_include_directories(_vllm_fa3_C PRIVATE
hopper
csrc/common
csrc/cutlass/include)

# custom definitions
target_compile_definitions(_vllm_fa3_C PRIVATE
FLASHATTENTION_DISABLE_BACKWARD
FLASHATTENTION_DISABLE_DROPOUT
# FLASHATTENTION_DISABLE_ALIBI
# FLASHATTENTION_DISABLE_SOFTCAP
FLASHATTENTION_DISABLE_UNEVEN_K
# FLASHATTENTION_DISABLE_LOCAL
FLASHATTENTION_DISABLE_PYBIND
FLASHATTENTION_DISABLE_FP8 # TODO Enable FP8
FLASHATTENTION_VARLEN_ONLY # Custom flag to save on binary size
_vllm_fa2_C
DESTINATION vllm_flash_attn
LANGUAGE ${VLLM_GPU_LANG}
SOURCES
csrc/flash_attn_ck/flash_api_torch_lib.cpp
${VLLM_FA2_CU_SRCS}
${FA3_GEN_SRCS_CU}
COMPILE_FLAGS ${VLLM_FA_GPU_FLAGS}
USE_SABI 3
WITH_SOABI
# CPP_AS_HIP
)
elseif(${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 12.0)
message(STATUS "FA3 is disabled because CUDA version is not 12.0 or later.")
endif ()

target_include_directories(_vllm_fa2_C PRIVATE
csrc/common
csrc/composable_kernel/include
csrc/composable_kernel/library/include
csrc/composable_kernel/example/ck_tile/01_fmha
)

target_compile_definitions(_vllm_fa2_C PRIVATE
CK_TILE_FMHA_FWD_FAST_EXP2=1
CK_ENABLE_BF16
CK_ENABLE_BF8
CK_ENABLE_FP16
CK_ENABLE_FP32
CK_ENABLE_FP64
CK_ENABLE_FP8
CK_ENABLE_INT8
CK_USE_XDL
USE_PROF_API=1
# FLASHATTENTION_DISABLE_BACKWARD
__HIP_PLATFORM_HCC__=1
FLASHATTENTION_DISABLE_PYBIND
)

# Data section exceeds 2GB, compress HIP binaries
target_link_options(_vllm_fa2_C PRIVATE "--offload-compress")
endif ()
Loading