diff --git a/CMakeLists.txt b/CMakeLists.txt index 3961c1696..7ed01c988 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -284,21 +284,39 @@ elseif (VLLM_GPU_LANG STREQUAL "HIP") list(APPEND FA3_GEN_SRCS_CU ${FILE_HIP}) endforeach () - # TODO: copy cpp->cu for correct hipification - # - try copying into gen/ or maybe even directly into build-tree (make sure that it's where hipify would copy it) + # These files are "converted" to .cu before being passed to torch.build_extension on upstream. + # We need to do the same so that hipify treats them correctly. We copy the files in the source tree like upstream. + set(VLLM_FA2_CPP_CU_SRCS + # csrc/flash_attn_ck/flash_api.cpp # only contains declarations & PyBind + csrc/flash_attn_ck/flash_common.cpp + csrc/flash_attn_ck/mha_bwd.cpp + csrc/flash_attn_ck/mha_fwd_kvcache.cpp + csrc/flash_attn_ck/mha_fwd.cpp + csrc/flash_attn_ck/mha_varlen_bwd.cpp + csrc/flash_attn_ck/mha_varlen_fwd.cpp + ) + + foreach(CPP_FILE ${VLLM_FA2_CPP_CU_SRCS}) + string(REGEX REPLACE "\.cpp$" ".cu" CU_FILE ${CPP_FILE}) + set(CU_FILE_ABS ${CMAKE_CURRENT_SOURCE_DIR}/${CU_FILE}) + set(CPP_FILE_ABS ${CMAKE_CURRENT_SOURCE_DIR}/${CPP_FILE}) + add_custom_command( + OUTPUT ${CU_FILE_ABS} + COMMAND ${CMAKE_COMMAND} -E copy ${CPP_FILE_ABS} ${CU_FILE_ABS} + DEPENDS ${CPP_FILE_ABS} + COMMENT "Copying ${CPP_FILE} to ${CU_FILE_ABS}" + ) + list(APPEND VLLM_FA2_CU_SRCS ${CU_FILE}) # relative to source dir + endforeach () + + # This target automatically depends on the copy by depending on copied files define_gpu_extension_target( _vllm_fa2_C DESTINATION vllm_flash_attn LANGUAGE ${VLLM_GPU_LANG} SOURCES - # csrc/flash_attn_ck/flash_api.cu # only contains declarations & PyBind csrc/flash_attn_ck/flash_api_torch_lib.cpp - csrc/flash_attn_ck/flash_common.cu - csrc/flash_attn_ck/mha_bwd.cu - csrc/flash_attn_ck/mha_fwd_kvcache.cu - csrc/flash_attn_ck/mha_fwd.cu - csrc/flash_attn_ck/mha_varlen_bwd.cu - csrc/flash_attn_ck/mha_varlen_fwd.cu + ${VLLM_FA2_CU_SRCS} ${FA3_GEN_SRCS_CU} COMPILE_FLAGS ${VLLM_FA_GPU_FLAGS} USE_SABI 3