Initial ROCm build working (missing .cpp->.cu copies)
Luka Govedič committed Jan 29, 2025
1 parent 96f4ed7 commit 5af0053
Showing 8 changed files with 315 additions and 113 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -11,6 +11,7 @@ __pycache__/
# Distribution / packaging
bin/
build/
cmake-build-*/
develop-eggs/
dist/
eggs/
288 changes: 192 additions & 96 deletions CMakeLists.txt
@@ -37,6 +37,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11
# Likely should also be in sync with the vLLM version.
#
set(TORCH_SUPPORTED_VERSION_CUDA "2.4.0")
set(TORCH_SUPPORTED_VERSION_ROCM "2.5.1")

find_python_constrained_versions(${PYTHON_SUPPORTED_VERSIONS})

@@ -91,7 +92,19 @@ if (NOT HIP_FOUND AND CUDA_FOUND)
"${CUDA_SUPPORTED_ARCHS}" "${CUDA_ARCHS}")
message(STATUS "CUDA supported target architectures: ${CUDA_ARCHS}")
elseif (HIP_FOUND)
message(FATAL_ERROR "ROCm build is not currently supported for vllm-flash-attn.")
set(VLLM_GPU_LANG "HIP")

# Importing torch recognizes and sets up some HIP/ROCm configuration but does
# not let cmake recognize .hip files. In order to get cmake to understand the
# .hip extension automatically, HIP must be enabled explicitly.
enable_language(HIP)

# ROCm 5.X and 6.X
if (ROCM_VERSION_DEV_MAJOR GREATER_EQUAL 5 AND
NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM})
message(WARNING "Pytorch version >= ${TORCH_SUPPORTED_VERSION_ROCM} "
"expected for ROCm build, saw ${Torch_VERSION} instead.")
endif ()
else ()
message(FATAL_ERROR "Can't find CUDA or HIP installation.")
endif ()
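
The hunk above is the heart of the change: instead of failing fast on ROCm, the build now selects HIP as the GPU language. A condensed, standalone sketch of that detection flow, assuming find_package(Torch) and the ROCm hip config files (which populate ROCM_VERSION_DEV_MAJOR) have already run; it mirrors the diff rather than replacing it:

# Sketch: choose the GPU language once Torch is imported, then let CMake
# compile *.hip sources with the HIP toolchain.
if (NOT HIP_FOUND AND CUDA_FOUND)
  set(VLLM_GPU_LANG "CUDA")
elseif (HIP_FOUND)
  set(VLLM_GPU_LANG "HIP")
  # Importing torch configures HIP/ROCm but does not teach CMake about the
  # .hip extension, so HIP must be enabled explicitly (CMake >= 3.21).
  enable_language(HIP)
  if (ROCM_VERSION_DEV_MAJOR GREATER_EQUAL 5 AND
      NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM})
    message(WARNING "Pytorch ${TORCH_SUPPORTED_VERSION_ROCM} expected for "
                    "the ROCm build, saw ${Torch_VERSION} instead.")
  endif ()
else ()
  message(FATAL_ERROR "Can't find CUDA or HIP installation.")
endif ()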
@@ -110,129 +123,212 @@ if (NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_FA_GPU_FLAGS "--threads=${NVCC_THREADS}")
endif ()

# Replace instead of appending, nvcc doesn't like duplicate -O flags.
string(REPLACE "-O2" "-O3" CMAKE_${VLLM_GPU_LANG}_FLAGS_RELWITHDEBINFO "${CMAKE_${VLLM_GPU_LANG}_FLAGS_RELWITHDEBINFO}")

# Other flags
list(APPEND VLLM_FA_GPU_FLAGS --expt-relaxed-constexpr --expt-extended-lambda --use_fast_math)

# If CUTLASS is compiled on NVCC >= 12.5, it by default uses
# cudaGetDriverEntryPointByVersion as a wrapper to avoid directly calling the
# driver API. This causes problems when linking with earlier versions of CUDA.
# Setting this variable sidesteps the issue by calling the driver directly.
list(APPEND VLLM_FA_GPU_FLAGS -DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
if (VLLM_GPU_LANG STREQUAL "CUDA")
# Other flags
list(APPEND VLLM_FA_GPU_FLAGS --expt-relaxed-constexpr --expt-extended-lambda --use_fast_math)

# Replace instead of appending, nvcc doesn't like duplicate -O flags.
string(REPLACE "-O2" "-O3" CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO}")
# If CUTLASS is compiled on NVCC >= 12.5, it by default uses
# cudaGetDriverEntryPointByVersion as a wrapper to avoid directly calling the
# driver API. This causes problems when linking with earlier versions of CUDA.
# Setting this variable sidesteps the issue by calling the driver directly.
list(APPEND VLLM_FA_GPU_FLAGS -DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)

#
# _C extension
#
#
# _C extension
#

if (FA2_ENABLED)
file(GLOB FA2_GEN_SRCS "csrc/flash_attn/src/flash_fwd_*.cu")
if (FA2_ENABLED)
file(GLOB FA2_GEN_SRCS "csrc/flash_attn/src/flash_fwd_*.cu")

# For CUDA we set the architectures on a per file basis
if (VLLM_GPU_LANG STREQUAL "CUDA")
# For CUDA we set the architectures on a per file basis
cuda_archs_loose_intersection(FA2_ARCHS "8.0;9.0" "${CUDA_ARCHS}")
message(STATUS "FA2_ARCHS: ${FA2_ARCHS}")

set_gencode_flags_for_srcs(
SRCS "${FA2_GEN_SRCS}"
CUDA_ARCHS "${FA2_ARCHS}")
endif()

define_gpu_extension_target(
_vllm_fa2_C
DESTINATION vllm_flash_attn
LANGUAGE ${VLLM_GPU_LANG}
SOURCES
csrc/flash_attn/flash_api.cpp
csrc/flash_attn/flash_api_sparse.cpp
csrc/flash_attn/flash_api_torch_lib.cpp
${FA2_GEN_SRCS}
COMPILE_FLAGS ${VLLM_FA_GPU_FLAGS}
USE_SABI 3
WITH_SOABI)

target_include_directories(_vllm_fa2_C PRIVATE
csrc/flash_attn
csrc/flash_attn/src
csrc/common
csrc/cutlass/include)

# custom definitions
target_compile_definitions(_vllm_fa2_C PRIVATE
FLASHATTENTION_DISABLE_BACKWARD
FLASHATTENTION_DISABLE_DROPOUT
# FLASHATTENTION_DISABLE_ALIBI
# FLASHATTENTION_DISABLE_SOFTCAP
FLASHATTENTION_DISABLE_UNEVEN_K
# FLASHATTENTION_DISABLE_LOCAL
FLASHATTENTION_DISABLE_PYBIND
)
endif ()
SRCS "${FA2_GEN_SRCS}"
CUDA_ARCHS "${FA2_ARCHS}")

define_gpu_extension_target(
_vllm_fa2_C
DESTINATION vllm_flash_attn
LANGUAGE ${VLLM_GPU_LANG}
SOURCES
csrc/flash_attn/flash_api.cpp
csrc/flash_attn/flash_api_sparse.cpp
csrc/flash_attn/flash_api_torch_lib.cpp
${FA2_GEN_SRCS}
COMPILE_FLAGS ${VLLM_FA_GPU_FLAGS}
USE_SABI 3
WITH_SOABI)

target_include_directories(_vllm_fa2_C PRIVATE
csrc/flash_attn
csrc/flash_attn/src
csrc/common
csrc/cutlass/include)

# custom definitions
target_compile_definitions(_vllm_fa2_C PRIVATE
FLASHATTENTION_DISABLE_BACKWARD
FLASHATTENTION_DISABLE_DROPOUT
# FLASHATTENTION_DISABLE_ALIBI
# FLASHATTENTION_DISABLE_SOFTCAP
FLASHATTENTION_DISABLE_UNEVEN_K
# FLASHATTENTION_DISABLE_LOCAL
FLASHATTENTION_DISABLE_PYBIND
)
endif ()

# FA3 requires CUDA 12.0 or later
if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0)
# BF16 source files
file(GLOB FA3_BF16_GEN_SRCS
file(GLOB FA3_BF16_GEN_SRCS
"hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
file(GLOB FA3_BF16_GEN_SRCS_
file(GLOB FA3_BF16_GEN_SRCS_
"hopper/instantiations/flash_fwd_*_bf16_*_sm80.cu")
list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})
# FP16 source files
file(GLOB FA3_FP16_GEN_SRCS
file(GLOB FA3_FP16_GEN_SRCS
"hopper/instantiations/flash_fwd_hdimall_fp16*_sm90.cu")
file(GLOB FA3_FP16_GEN_SRCS_
"hopper/instantiations/flash_fwd_*_fp16_*_sm80.cu")
list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})

# TODO add fp8 source files when FP8 is enabled
set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS})
# TODO add fp8 source files when FP8 is enabled
set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS})

# For CUDA we set the architectures on a per file basis
if (VLLM_GPU_LANG STREQUAL "CUDA")
# For CUDA we set the architectures on a per file basis
cuda_archs_loose_intersection(FA3_ARCHS "8.0;9.0a" "${CUDA_ARCHS}")
message(STATUS "FA3_ARCHS: ${FA3_ARCHS}")

set_gencode_flags_for_srcs(
SRCS "${FA3_GEN_SRCS}"
CUDA_ARCHS "${FA3_ARCHS}")
SRCS "${FA3_GEN_SRCS}"
CUDA_ARCHS "${FA3_ARCHS}")
set_gencode_flags_for_srcs(
SRCS "hopper/flash_fwd_combine.cu"
CUDA_ARCHS "${FA3_ARCHS}")
SRCS "hopper/flash_fwd_combine.cu"
CUDA_ARCHS "${FA3_ARCHS}")


define_gpu_extension_target(
_vllm_fa3_C
DESTINATION vllm_flash_attn
LANGUAGE ${VLLM_GPU_LANG}
SOURCES
hopper/flash_fwd_combine.cu
hopper/flash_api.cpp
hopper/flash_api_torch_lib.cpp
${FA3_GEN_SRCS}
COMPILE_FLAGS ${VLLM_FA_GPU_FLAGS}
ARCHITECTURES ${VLLM_FA_GPU_ARCHES}
USE_SABI 3
WITH_SOABI)

target_include_directories(_vllm_fa3_C PRIVATE
hopper
csrc/common
csrc/cutlass/include)


# custom definitions
target_compile_definitions(_vllm_fa3_C PRIVATE
FLASHATTENTION_DISABLE_BACKWARD
FLASHATTENTION_DISABLE_DROPOUT
# FLASHATTENTION_DISABLE_ALIBI
# FLASHATTENTION_DISABLE_SOFTCAP
FLASHATTENTION_DISABLE_UNEVEN_K
# FLASHATTENTION_DISABLE_LOCAL
FLASHATTENTION_DISABLE_PYBIND
FLASHATTENTION_DISABLE_FP8 # TODO Enable FP8
FLASHATTENTION_VARLEN_ONLY # Custom flag to save on binary size
)
elseif(${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 12.0)
message(STATUS "FA3 is disabled because CUDA version is not 12.0 or later.")
endif()

elseif (VLLM_GPU_LANG STREQUAL "HIP")
# CLang on ROCm
# --offload-compress required to keep size under 2GB (fails with errs)
list(APPEND VLLM_FA_GPU_FLAGS -ffast-math -fgpu-flush-denormals-to-zero --offload-compress)

# CK fails to compile below O2 as inlining is needed for certain inline assembly
string(REGEX REPLACE "-O(g|0)?" "-O2" CMAKE_HIP_FLAGS_DEBUG "${CMAKE_HIP_FLAGS_DEBUG}")

# Generate FA from CK example kernels
# Generate at configure time so we can glob
set(FA_GENERATED_OUTDIR ${CMAKE_CURRENT_BINARY_DIR}/gen)
set(CK_GEN_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/csrc/composable_kernel/example/ck_tile/01_fmha/generate.py)
file(MAKE_DIRECTORY ${FA_GENERATED_OUTDIR})
# TODO(luka) only run if required
foreach (KERNEL IN ITEMS "fwd" "fwd_appendkv" "fwd_splitkv" "bwd")
execute_process(
COMMAND
"${Python_EXECUTABLE}" "${CK_GEN_SCRIPT}" "-d" "${KERNEL}" "--output_dir" "${FA_GENERATED_OUTDIR}" "--receipt" "2"
RESULT_VARIABLE PYTHON_ERROR_CODE
ERROR_VARIABLE PYTHON_STDERR
OUTPUT_VARIABLE PYTHON_OUT
)
if (NOT PYTHON_ERROR_CODE EQUAL 0)
message(FATAL_ERROR "Cannot generate Python sources with error: ${PYTHON_ERROR_CODE}\n
stdout:${PYTHON_OUT}\n
stderr:${PYTHON_STDERR}")
endif ()
endforeach ()

file(GLOB FA3_GEN_SRCS "${FA_GENERATED_OUTDIR}/fmha_*wd*.cpp")
# Copy cpp files to hip because running hipify on them is a no-op as they only contain instantiations
foreach(FILE ${FA3_GEN_SRCS})
string(REGEX REPLACE "\.cpp$" ".hip" FILE_HIP ${FILE})
file(COPY_FILE ${FILE} ${FILE_HIP})
list(APPEND FA3_GEN_SRCS_CU ${FILE_HIP})
endforeach ()

# TODO: copy cpp->cu for correct hipification
# - try copying into gen/ or maybe even directly into build-tree (make sure that it's where hipify would copy it)
define_gpu_extension_target(
_vllm_fa3_C
DESTINATION vllm_flash_attn
LANGUAGE ${VLLM_GPU_LANG}
SOURCES
hopper/flash_fwd_combine.cu
hopper/flash_api.cpp
hopper/flash_api_torch_lib.cpp
${FA3_GEN_SRCS}
COMPILE_FLAGS ${VLLM_FA_GPU_FLAGS}
ARCHITECTURES ${VLLM_FA_GPU_ARCHES}
USE_SABI 3
WITH_SOABI)

target_include_directories(_vllm_fa3_C PRIVATE
hopper
csrc/common
csrc/cutlass/include)

# custom definitions
target_compile_definitions(_vllm_fa3_C PRIVATE
FLASHATTENTION_DISABLE_BACKWARD
FLASHATTENTION_DISABLE_DROPOUT
# FLASHATTENTION_DISABLE_ALIBI
# FLASHATTENTION_DISABLE_SOFTCAP
FLASHATTENTION_DISABLE_UNEVEN_K
# FLASHATTENTION_DISABLE_LOCAL
FLASHATTENTION_DISABLE_PYBIND
FLASHATTENTION_DISABLE_FP8 # TODO Enable FP8
FLASHATTENTION_VARLEN_ONLY # Custom flag to save on binary size
_vllm_fa2_C
DESTINATION vllm_flash_attn
LANGUAGE ${VLLM_GPU_LANG}
SOURCES
# csrc/flash_attn_ck/flash_api.cu # only contains declarations & PyBind
csrc/flash_attn_ck/flash_api_torch_lib.cpp
csrc/flash_attn_ck/flash_common.cu
csrc/flash_attn_ck/mha_bwd.cu
csrc/flash_attn_ck/mha_fwd_kvcache.cu
csrc/flash_attn_ck/mha_fwd.cu
csrc/flash_attn_ck/mha_varlen_bwd.cu
csrc/flash_attn_ck/mha_varlen_fwd.cu
${FA3_GEN_SRCS_CU}
COMPILE_FLAGS ${VLLM_FA_GPU_FLAGS}
USE_SABI 3
WITH_SOABI
# CPP_AS_HIP
)
elseif(${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 12.0)
message(STATUS "FA3 is disabled because CUDA version is not 12.0 or later.")
endif ()

target_include_directories(_vllm_fa2_C PRIVATE
csrc/common
csrc/composable_kernel/include
csrc/composable_kernel/library/include
csrc/composable_kernel/example/ck_tile/01_fmha
)

target_compile_definitions(_vllm_fa2_C PRIVATE
CK_TILE_FMHA_FWD_FAST_EXP2=1
CK_ENABLE_BF16
CK_ENABLE_BF8
CK_ENABLE_FP16
CK_ENABLE_FP32
CK_ENABLE_FP64
CK_ENABLE_FP8
CK_ENABLE_INT8
CK_USE_XDL
USE_PROF_API=1
# FLASHATTENTION_DISABLE_BACKWARD
__HIP_PLATFORM_HCC__=1
FLASHATTENTION_DISABLE_PYBIND
)

# Data section exceeds 2GB, compress HIP binaries
target_link_options(_vllm_fa2_C PRIVATE "--offload-compress")
endif ()
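
In the HIP branch of the hunk above, the CK tile FMHA instantiations are generated at configure time (so they can be globbed) and each generated .cpp is copied to a .hip twin, since hipify is a no-op on files that only contain instantiations; the commit title notes that a similar .cpp->.cu copy step is still missing elsewhere. A minimal sketch of that pattern, isolated from the surrounding target setup. Paths and the --receipt 2 argument come from the diff; variable names such as GEN_RC and the doubled regex escaping are editorial choices, and file(COPY_FILE) needs CMake 3.21 or newer:

# Sketch: configure-time kernel generation, then a .cpp -> .hip copy so the
# HIP toolchain picks the generated files up directly.
set(FA_GENERATED_OUTDIR ${CMAKE_CURRENT_BINARY_DIR}/gen)
set(CK_GEN_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/csrc/composable_kernel/example/ck_tile/01_fmha/generate.py)
file(MAKE_DIRECTORY ${FA_GENERATED_OUTDIR})

foreach (KERNEL IN ITEMS "fwd" "fwd_appendkv" "fwd_splitkv" "bwd")
  execute_process(
    COMMAND "${Python_EXECUTABLE}" "${CK_GEN_SCRIPT}"
            "-d" "${KERNEL}" "--output_dir" "${FA_GENERATED_OUTDIR}" "--receipt" "2"
    RESULT_VARIABLE GEN_RC
    ERROR_VARIABLE GEN_STDERR)
  if (NOT GEN_RC EQUAL 0)
    message(FATAL_ERROR "Kernel generation failed for ${KERNEL}: ${GEN_STDERR}")
  endif ()
endforeach ()

# The generated files only contain template instantiations, so running hipify
# on them changes nothing; copying them to .hip makes them HIP sources.
file(GLOB GENERATED_CPP "${FA_GENERATED_OUTDIR}/fmha_*wd*.cpp")
set(GENERATED_HIP "")
foreach (SRC ${GENERATED_CPP})
  string(REGEX REPLACE "\\.cpp$" ".hip" SRC_HIP ${SRC})
  file(COPY_FILE ${SRC} ${SRC_HIP})
  list(APPEND GENERATED_HIP ${SRC_HIP})
endforeach ()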
13 changes: 8 additions & 5 deletions cmake/utils.cmake
@@ -61,9 +61,11 @@ function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS)
# Split into C++ and non-C++ (i.e. CUDA) sources.
#
set(SRCS ${ORIG_SRCS})
set(CXX_SRCS ${ORIG_SRCS})
list(FILTER SRCS EXCLUDE REGEX "\.(cc)|(cpp)$")
list(FILTER CXX_SRCS INCLUDE REGEX "\.(cc)|(cpp)$")
set(EXCLUDED_SRCS ${ORIG_SRCS})
set(EXCLUDE_REGEX "\.(cc|cpp|hip)$")
list(FILTER SRCS EXCLUDE REGEX ${EXCLUDE_REGEX})
list(FILTER EXCLUDED_SRCS INCLUDE REGEX ${EXCLUDE_REGEX})
message(DEBUG "Excluded source files: ${EXCLUDED_SRCS}")

#
# Generate ROCm/HIP source file names from CUDA file names.
@@ -78,15 +80,16 @@ function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS)
endforeach()

set(CSRC_BUILD_DIR ${CMAKE_CURRENT_BINARY_DIR}/csrc)
set(CSRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/csrc)
add_custom_target(
hipify${NAME}
COMMAND ${CMAKE_SOURCE_DIR}/cmake/hipify.py -p ${CMAKE_SOURCE_DIR}/csrc -o ${CSRC_BUILD_DIR} ${SRCS}
COMMAND ${Python_EXECUTABLE} ${CMAKE_SOURCE_DIR}/cmake/hipify.py -p "${CSRC_DIR}" -o "${CSRC_BUILD_DIR}" ${SRCS}
DEPENDS ${CMAKE_SOURCE_DIR}/cmake/hipify.py ${SRCS}
BYPRODUCTS ${HIP_SRCS}
COMMENT "Running hipify on ${NAME} extension source files.")

# Swap out original extension sources with hipified sources.
list(APPEND HIP_SRCS ${CXX_SRCS})
list(APPEND HIP_SRCS ${EXCLUDED_SRCS})
set(${OUT_SRCS} ${HIP_SRCS} PARENT_SCOPE)
endfunction()

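For orientation, a hedged usage sketch of the helper patched above. The target name and source pairing are hypothetical, but the signature (OUT_SRCS NAME ORIG_SRCS), the output directory, and the hipify${NAME} custom target come from the function as shown:

# Sketch: .cu inputs are hipified into ${CMAKE_CURRENT_BINARY_DIR}/csrc at
# build time, while .cc/.cpp/.hip inputs are now passed through untouched.
set(EXAMPLE_SRCS
    csrc/flash_attn_ck/flash_api_torch_lib.cpp   # excluded from hipify
    csrc/flash_attn_ck/mha_fwd.cu)               # hipified into the build tree
hipify_sources_target(EXAMPLE_HIP_SRCS example_ext "${EXAMPLE_SRCS}")
# EXAMPLE_HIP_SRCS now mixes the build-tree HIP source with the original .cpp,
# and a custom target named hipifyexample_ext runs hipify.py at build time.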
5 changes: 5 additions & 0 deletions csrc/flash_attn_ck/flash_api.cpp
@@ -111,6 +111,10 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
int num_splits);

#ifndef FLASHATTENTION_DISABLE_PYBIND

#include <torch/python.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.doc() = "FlashAttention";
@@ -120,3 +124,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache");
}
#endif
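
The new FLASHATTENTION_DISABLE_PYBIND guard lets this translation unit be compiled without producing a pybind11 module; the macro is the one the extension targets in the CMakeLists hunk already pass via target_compile_definitions. A minimal sketch of that wiring, using the _vllm_fa2_C target name from this commit:

# Sketch: defining the macro compiles out the PYBIND11_MODULE block above;
# omitting it keeps the original standalone pybind11 module.
target_compile_definitions(_vllm_fa2_C PRIVATE
    FLASHATTENTION_DISABLE_PYBIND)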