diff --git a/.gitignore b/.gitignore index c5a6a3501..14efeef18 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,8 @@ src/dispatch.inc src/generated/ python/csrc/generated/ python/flashinfer/_build_meta.py +python/flashinfer/jit/aot_config.py +flashinfer-aot/csrc_aot/generated/ # Generated documentation files docs/generated diff --git a/3rdparty/cutlass b/3rdparty/cutlass index 033d9efd2..f7b19de32 160000 --- a/3rdparty/cutlass +++ b/3rdparty/cutlass @@ -1 +1 @@ -Subproject commit 033d9efd2db0bbbcf3b3b0650acde6c472f3948e +Subproject commit f7b19de32c5d1f3cedfc735c2849f12b537522ee diff --git a/CMakeLists.txt b/CMakeLists.txt index 68c2b6cb7..54c116449 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,7 +40,6 @@ flashinfer_option(FLASHINFER_TVM_SOURCE_DIR "The path to tvm for building tvm bi # The following configurations can impact the binary # size of the generated library flashinfer_option(FLASHINFER_GEN_HEAD_DIMS "Head dims to enable" 64 128 256) -flashinfer_option(FLASHINFER_GEN_LOGITS_POST_HOOKS "Logits post hooks" 0 1) flashinfer_option(FLASHINFER_GEN_POS_ENCODING_MODES "Pos encodings to enable" 0 1 2) flashinfer_option(FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS "QK reductions to enable" "false" "true") flashinfer_option(FLASHINFER_GEN_MASK_MODES "Mask modes to enable" 0 1 2) @@ -81,7 +80,6 @@ endif(FLASHINFER_ENABLE_BF16) # generate kernel inst set (HEAD_DIMS ${FLASHINFER_GEN_HEAD_DIMS}) -set (LOGITS_POST_HOOKS ${FLASHINFER_GEN_LOGITS_POST_HOOKS}) set (POS_ENCODING_MODES ${FLASHINFER_GEN_POS_ENCODING_MODES}) set (ALLOW_FP16_QK_REDUCTIONS ${FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS}) set (MASK_MODES ${FLASHINFER_GEN_MASK_MODES}) @@ -112,8 +110,8 @@ file(MAKE_DIRECTORY ${PROJECT_SOURCE_DIR}/src/generated) set(dispatch_inc_file ${PROJECT_SOURCE_DIR}/src/dispatch.inc) add_custom_command( OUTPUT ${dispatch_inc_file} - COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_dispatch_inc.py --path ${PROJECT_SOURCE_DIR}/src/dispatch.inc --head_dims ${HEAD_DIMS} --logits_post_hooks ${LOGITS_POST_HOOKS} --pos_encoding_modes ${POS_ENCODING_MODES} --allow_fp16_qk_reductions ${ALLOW_FP16_QK_REDUCTIONS} --mask_modes ${MASK_MODES} - DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_dispatch_inc.py + COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_dispatch_inc.py --path ${PROJECT_SOURCE_DIR}/src/dispatch.inc --head_dims ${HEAD_DIMS} --pos_encoding_modes ${POS_ENCODING_MODES} --allow_fp16_qk_reductions ${ALLOW_FP16_QK_REDUCTIONS} --mask_modes ${MASK_MODES} + DEPENDS ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_dispatch_inc.py COMMENT "Generating additional source file ${generated_dispatch_inc}" VERBATIM ) @@ -121,182 +119,172 @@ add_custom_target(dispatch_inc DEPENDS ${dispatch_inc_file}) # single decode kernel inst generation foreach(head_dim IN LISTS HEAD_DIMS) - foreach(logits_post_hook IN LISTS LOGITS_POST_HOOKS) - foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES) + foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES) + foreach(dtype IN LISTS DECODE_DTYPES) + set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_decode_head_${head_dim}_posenc_${pos_encoding_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}.cu) + add_custom_command( + OUTPUT ${generated_kernel_src} + COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_single_decode_inst.py ${generated_kernel_src} + DEPENDS ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_single_decode_inst.py + COMMENT "Generating additional source file 
${generated_kernel_src}" + VERBATIM + ) + list(APPEND single_decode_kernels_src ${generated_kernel_src}) + endforeach(dtype) + + # fp8 kv-cache + foreach(dtype_kv IN LISTS DECODE_FP8_DTYPES) + set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_decode_head_${head_dim}_posenc_${pos_encoding_mode}_dtypeq_f16_dtypekv_${dtype_kv}_dtypeout_f16.cu) + add_custom_command( + OUTPUT ${generated_kernel_src} + COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_single_decode_inst.py ${generated_kernel_src} + DEPENDS ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_single_decode_inst.py + COMMENT "Generating additional source file ${generated_kernel_src}" + VERBATIM + ) + list(APPEND single_decode_kernels_src ${generated_kernel_src}) + endforeach(dtype_kv) + endforeach(pos_encoding_mode) +endforeach(head_dim) + +# batch decode kernel inst generation +foreach(head_dim IN LISTS HEAD_DIMS) + foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES) + # paged kv-cache + foreach(idtype IN LISTS IDTYPES) foreach(dtype IN LISTS DECODE_DTYPES) - set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_decode_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}.cu) + set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_decode_head_${head_dim}_posenc_${pos_encoding_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu) add_custom_command( OUTPUT ${generated_kernel_src} - COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_single_decode_inst.py ${generated_kernel_src} - DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_single_decode_inst.py + COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_batch_paged_decode_inst.py ${generated_kernel_src} + DEPENDS ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_batch_paged_decode_inst.py COMMENT "Generating additional source file ${generated_kernel_src}" VERBATIM ) - list(APPEND single_decode_kernels_src ${generated_kernel_src}) + list(APPEND batch_decode_kernels_src ${generated_kernel_src}) endforeach(dtype) # fp8 kv-cache foreach(dtype_kv IN LISTS DECODE_FP8_DTYPES) - set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_decode_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_dtypeq_f16_dtypekv_${dtype_kv}_dtypeout_f16.cu) + set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_decode_head_${head_dim}_posenc_${pos_encoding_mode}_dtypeq_f16_dtypekv_${dtype_kv}_dtypeout_f16_idtype_${idtype}.cu) add_custom_command( OUTPUT ${generated_kernel_src} - COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_single_decode_inst.py ${generated_kernel_src} - DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_single_decode_inst.py + COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_batch_paged_decode_inst.py ${generated_kernel_src} + DEPENDS ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_batch_paged_decode_inst.py COMMENT "Generating additional source file ${generated_kernel_src}" VERBATIM ) - list(APPEND single_decode_kernels_src ${generated_kernel_src}) + list(APPEND batch_decode_kernels_src ${generated_kernel_src}) endforeach(dtype_kv) - endforeach(pos_encoding_mode) - endforeach(logits_post_hook) + endforeach(idtype) + endforeach(pos_encoding_mode) endforeach(head_dim) -# batch decode kernel inst generation +add_library(decode_kernels STATIC ${single_decode_kernels_src} ${batch_decode_kernels_src}) 
+target_include_directories(decode_kernels PRIVATE ${FLASHINFER_INCLUDE_DIR}) +target_compile_options(decode_kernels PRIVATE -Xcompiler=-fPIC --fatbin-options -compress-all) + +# single prefill kernel inst generation foreach(head_dim IN LISTS HEAD_DIMS) - foreach(logits_post_hook IN LISTS LOGITS_POST_HOOKS) - foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES) - # paged kv-cache - foreach(idtype IN LISTS IDTYPES) - foreach(dtype IN LISTS DECODE_DTYPES) - set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_decode_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu) + foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES) + foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS) + foreach(mask_mode IN LISTS MASK_MODES) + foreach(dtype IN LISTS PREFILL_DTYPES) + set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_prefill_head_${head_dim}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}.cu) add_custom_command( OUTPUT ${generated_kernel_src} - COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_decode_inst.py ${generated_kernel_src} - DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_decode_inst.py + COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_single_prefill_inst.py ${generated_kernel_src} + DEPENDS ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_single_prefill_inst.py COMMENT "Generating additional source file ${generated_kernel_src}" VERBATIM ) - list(APPEND batch_decode_kernels_src ${generated_kernel_src}) + list(APPEND single_prefill_kernels_src ${generated_kernel_src}) endforeach(dtype) - # fp8 kv-cache - foreach(dtype_kv IN LISTS DECODE_FP8_DTYPES) - set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_decode_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_dtypeq_f16_dtypekv_${dtype_kv}_dtypeout_f16_idtype_${idtype}.cu) + foreach(dtype_kv IN LISTS PREFILL_FP8_DTYPES) + set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_prefill_head_${head_dim}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_f16_dtypekv_${dtype_kv}_dtypeout_f16.cu) add_custom_command( OUTPUT ${generated_kernel_src} - COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_decode_inst.py ${generated_kernel_src} - DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_decode_inst.py + COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_single_prefill_inst.py ${generated_kernel_src} + DEPENDS ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_single_prefill_inst.py COMMENT "Generating additional source file ${generated_kernel_src}" VERBATIM ) - list(APPEND batch_decode_kernels_src ${generated_kernel_src}) + list(APPEND single_prefill_kernels_src ${generated_kernel_src}) endforeach(dtype_kv) - endforeach(idtype) - endforeach(pos_encoding_mode) - endforeach(logits_post_hook) + endforeach(mask_mode) + endforeach(allow_fp16_qk_reduction) + endforeach(pos_encoding_mode) endforeach(head_dim) -add_library(decode_kernels STATIC ${single_decode_kernels_src} ${batch_decode_kernels_src}) -target_include_directories(decode_kernels PRIVATE ${FLASHINFER_INCLUDE_DIR}) -target_compile_options(decode_kernels PRIVATE -Xcompiler=-fPIC --fatbin-options -compress-all) - -# single prefill kernel inst 
generation +# batch paged prefill kernel inst generation foreach(head_dim IN LISTS HEAD_DIMS) - foreach(logits_post_hook IN LISTS LOGITS_POST_HOOKS) - foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES) - foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS) - foreach(mask_mode IN LISTS MASK_MODES) + foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES) + foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS) + foreach(mask_mode IN LISTS MASK_MODES) + foreach(idtype IN LISTS IDTYPES) foreach(dtype IN LISTS PREFILL_DTYPES) - set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}.cu) + set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_prefill_head_${head_dim}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu) add_custom_command( OUTPUT ${generated_kernel_src} - COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_single_prefill_inst.py ${generated_kernel_src} - DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_single_prefill_inst.py + COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_batch_paged_prefill_inst.py ${generated_kernel_src} + DEPENDS ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_batch_paged_prefill_inst.py COMMENT "Generating additional source file ${generated_kernel_src}" VERBATIM ) - list(APPEND single_prefill_kernels_src ${generated_kernel_src}) + list(APPEND batch_paged_prefill_kernels_src ${generated_kernel_src}) endforeach(dtype) foreach(dtype_kv IN LISTS PREFILL_FP8_DTYPES) - set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_f16_dtypekv_${dtype_kv}_dtypeout_f16.cu) + set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_prefill_head_${head_dim}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_f16_dtypekv_${dtype_kv}_dtypeout_f16_idtype_${idtype}.cu) add_custom_command( OUTPUT ${generated_kernel_src} - COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_single_prefill_inst.py ${generated_kernel_src} - DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_single_prefill_inst.py + COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_batch_paged_prefill_inst.py ${generated_kernel_src} + DEPENDS ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_batch_paged_prefill_inst.py COMMENT "Generating additional source file ${generated_kernel_src}" VERBATIM ) - list(APPEND single_prefill_kernels_src ${generated_kernel_src}) + list(APPEND batch_paged_prefill_kernels_src ${generated_kernel_src}) endforeach(dtype_kv) - endforeach(mask_mode) - endforeach(allow_fp16_qk_reduction) - endforeach(pos_encoding_mode) - endforeach(logits_post_hook) -endforeach(head_dim) - -# batch paged prefill kernel inst generation -foreach(head_dim IN LISTS HEAD_DIMS) - foreach(logits_post_hook IN LISTS LOGITS_POST_HOOKS) - foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES) - foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS) - foreach(mask_mode IN LISTS MASK_MODES) - foreach(idtype IN LISTS IDTYPES) - foreach(dtype IN LISTS PREFILL_DTYPES) - 
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu) - add_custom_command( - OUTPUT ${generated_kernel_src} - COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py ${generated_kernel_src} - DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py - COMMENT "Generating additional source file ${generated_kernel_src}" - VERBATIM - ) - list(APPEND batch_paged_prefill_kernels_src ${generated_kernel_src}) - endforeach(dtype) - - foreach(dtype_kv IN LISTS PREFILL_FP8_DTYPES) - set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_f16_dtypekv_${dtype_kv}_dtypeout_f16_idtype_${idtype}.cu) - add_custom_command( - OUTPUT ${generated_kernel_src} - COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py ${generated_kernel_src} - DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py - COMMENT "Generating additional source file ${generated_kernel_src}" - VERBATIM - ) - list(APPEND batch_paged_prefill_kernels_src ${generated_kernel_src}) - endforeach(dtype_kv) - endforeach(idtype) - endforeach(mask_mode) - endforeach(allow_fp16_qk_reduction) - endforeach(pos_encoding_mode) - endforeach(logits_post_hook) + endforeach(idtype) + endforeach(mask_mode) + endforeach(allow_fp16_qk_reduction) + endforeach(pos_encoding_mode) endforeach(head_dim) # batch ragged prefill kernel inst generation foreach(head_dim IN LISTS HEAD_DIMS) - foreach(logits_post_hook IN LISTS LOGITS_POST_HOOKS) - foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES) - foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS) - foreach(mask_mode IN LISTS MASK_MODES) - foreach(idtype IN LISTS IDTYPES) - foreach(dtype IN LISTS PREFILL_DTYPES) - set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_ragged_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu) - add_custom_command( - OUTPUT ${generated_kernel_src} - COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_ragged_prefill_inst.py ${generated_kernel_src} - DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_ragged_prefill_inst.py - COMMENT "Generating additional source file ${generated_kernel_src}" - VERBATIM - ) - list(APPEND batch_ragged_prefill_kernels_src ${generated_kernel_src}) - endforeach(dtype) - - foreach(dtype_kv IN LISTS PREFILL_FP8_DTYPES) - set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_ragged_prefill_head_${head_dim}_logitshook_${logits_post_hook}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_f16_dtypekv_${dtype_kv}_dtypeout_f16_idtype_${idtype}.cu) - add_custom_command( - OUTPUT ${generated_kernel_src} - COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_ragged_prefill_inst.py ${generated_kernel_src} - DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_ragged_prefill_inst.py - COMMENT "Generating additional source file ${generated_kernel_src}" - VERBATIM - ) - list(APPEND 
batch_ragged_prefill_kernels_src ${generated_kernel_src}) - endforeach(dtype_kv) - endforeach(idtype) - endforeach(mask_mode) - endforeach(allow_fp16_qk_reduction) - endforeach(pos_encoding_mode) - endforeach(logits_post_hook) + foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES) + foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS) + foreach(mask_mode IN LISTS MASK_MODES) + foreach(idtype IN LISTS IDTYPES) + foreach(dtype IN LISTS PREFILL_DTYPES) + set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_ragged_prefill_head_${head_dim}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu) + add_custom_command( + OUTPUT ${generated_kernel_src} + COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_batch_ragged_prefill_inst.py ${generated_kernel_src} + DEPENDS ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_batch_ragged_prefill_inst.py + COMMENT "Generating additional source file ${generated_kernel_src}" + VERBATIM + ) + list(APPEND batch_ragged_prefill_kernels_src ${generated_kernel_src}) + endforeach(dtype) + + foreach(dtype_kv IN LISTS PREFILL_FP8_DTYPES) + set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_ragged_prefill_head_${head_dim}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypeq_f16_dtypekv_${dtype_kv}_dtypeout_f16_idtype_${idtype}.cu) + add_custom_command( + OUTPUT ${generated_kernel_src} + COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_batch_ragged_prefill_inst.py ${generated_kernel_src} + DEPENDS ${PROJECT_SOURCE_DIR}/flashinfer-aot/generate_batch_ragged_prefill_inst.py + COMMENT "Generating additional source file ${generated_kernel_src}" + VERBATIM + ) + list(APPEND batch_ragged_prefill_kernels_src ${generated_kernel_src}) + endforeach(dtype_kv) + endforeach(idtype) + endforeach(mask_mode) + endforeach(allow_fp16_qk_reduction) + endforeach(pos_encoding_mode) endforeach(head_dim) add_library(prefill_kernels STATIC ${single_prefill_kernels_src} ${batch_paged_prefill_kernels_src} ${batch_ragged_prefill_kernels_src}) @@ -488,8 +476,6 @@ if(FLASHINFER_FASTDEQUANT_TEST) target_link_libraries(test_fast_dequant PRIVATE gtest gtest_main) endif(FLASHINFER_FASTDEQUANT_TEST) - - if (FLASHINFER_DISTRIBUTED) find_package(MPI REQUIRED) diff --git a/cmake/config.cmake b/cmake/config.cmake index 0d51e4916..475721add 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -24,7 +24,6 @@ set(FLASHINFER_FASTDEQUANT_TEST ON) set(FLASHINFER_DISTRIBUTED ON) # The following configurations can impact the binary # size of the generated library -set(FLASHINFER_GEN_LOGITS_POST_HOOKS 0) set(FLASHINFER_GEN_HEAD_DIMS 64 128 256) set(FLASHINFER_GEN_KV_LAYOUTS 0 1) set(FLASHINFER_GEN_POS_ENCODING_MODES 0 1 2) diff --git a/flashinfer-aot/3rdparty b/flashinfer-aot/3rdparty new file mode 120000 index 000000000..303a6484e --- /dev/null +++ b/flashinfer-aot/3rdparty @@ -0,0 +1 @@ +../3rdparty \ No newline at end of file diff --git a/flashinfer-aot/MANIFEST.in b/flashinfer-aot/MANIFEST.in new file mode 100644 index 000000000..b20747fef --- /dev/null +++ b/flashinfer-aot/MANIFEST.in @@ -0,0 +1,12 @@ +# sdist & wheel +include version.txt +recursive-include include * +recursive-include csrc * +recursive-include 3rdparty/cutlass * + +# wheel-only +exclude flashinfer/_build_meta.py + +# Unneeded files +prune */__pycache__ +global-exclude *.so diff --git a/flashinfer-aot/csrc 
b/flashinfer-aot/csrc new file mode 120000 index 000000000..bf5627220 --- /dev/null +++ b/flashinfer-aot/csrc @@ -0,0 +1 @@ +../python/csrc \ No newline at end of file diff --git a/python/csrc/activation.cu b/flashinfer-aot/csrc_aot/activation.cu similarity index 81% rename from python/csrc/activation.cu rename to flashinfer-aot/csrc_aot/activation.cu index ef3a781a2..d2866ccc4 100644 --- a/python/csrc/activation.cu +++ b/flashinfer-aot/csrc_aot/activation.cu @@ -18,11 +18,25 @@ #include -#include "flashinfer_ops.h" #include "pytorch_extension_utils.h" using namespace flashinfer; +__device__ __forceinline__ float silu(const float& val) { + return val / (1.0f + __expf(-val)); +} + +__device__ __forceinline__ float gelu(const float& val) { + constexpr float kAlpha = M_SQRT1_2; + return val * 0.5f * (1.0f + ::erf(val * kAlpha)); +} + +__device__ __forceinline__ float gelu_tanh(const float& val) { + const float cdf = + 0.5f * (1.0f + math::tanh((0.7978845608028654f * (val + 0.044715f * val * val * val)))); + return val * cdf; +} + void silu_and_mul(torch::Tensor& out, torch::Tensor& input) { int d = input.size(-1) / 2; int64_t num_tokens = input.numel() / input.size(-1); @@ -33,7 +47,7 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input) { DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] { uint32_t vec_size = 16 / sizeof(c_type); dim3 block(std::min(d / vec_size, 1024U)); - flashinfer::activation::act_and_mul_kernel + flashinfer::activation::act_and_mul_kernel <<>>(static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); @@ -51,7 +65,7 @@ void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input) { DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] { uint32_t vec_size = 16 / sizeof(c_type); dim3 block(std::min(d / vec_size, 1024U)); - flashinfer::activation::act_and_mul_kernel + flashinfer::activation::act_and_mul_kernel <<>>(static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); @@ -69,7 +83,7 @@ void gelu_and_mul(torch::Tensor& out, torch::Tensor& input) { DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] { uint32_t vec_size = 16 / sizeof(c_type); dim3 block(std::min(d / vec_size, 1024U)); - flashinfer::activation::act_and_mul_kernel + flashinfer::activation::act_and_mul_kernel <<>>(static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); diff --git a/flashinfer-aot/csrc_aot/batch_decode.cu b/flashinfer-aot/csrc_aot/batch_decode.cu new file mode 100644 index 000000000..f9e4796f8 --- /dev/null +++ b/flashinfer-aot/csrc_aot/batch_decode.cu @@ -0,0 +1,205 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
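The activation.cu change above drops the flashinfer_ops.h include and inlines silu, gelu, and gelu_tanh as __device__ helpers, which appear to be selected through act_and_mul_kernel's template parameter. For reference, the same formulas in PyTorch, assuming the usual gated-activation convention act(x[..., :d]) * x[..., d:]; this is a host-side sketch for checking outputs, not part of the extension:

```python
# Host-side reference for silu_and_mul / gelu_and_mul / gelu_tanh_and_mul,
# matching the formulas inlined in activation.cu (assumed convention: the first
# half of the last dimension is activated and multiplied with the second half).
import math
import torch

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    a, b = x[..., :d], x[..., d:]
    return (a / (1.0 + torch.exp(-a))) * b

def gelu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    a, b = x[..., :d], x[..., d:]
    return a * 0.5 * (1.0 + torch.erf(a * math.sqrt(0.5))) * b

def gelu_tanh_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    a, b = x[..., :d], x[..., d:]
    cdf = 0.5 * (1.0 + torch.tanh(0.7978845608028654 * (a + 0.044715 * a ** 3)))
    return a * cdf * b
```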
+ */ +#include + +#include +#include +#include +#include + +#include "pytorch_extension_utils.h" + +namespace flashinfer { + +template +cudaError_t BatchDecodeWithPagedKVCacheDispatched(typename AttentionVariant::ParamsT params, + typename AttentionVariant::DTypeO* tmp_v, + float* tmp_s, cudaStream_t stream); + +} // namespace flashinfer + +std::vector BatchDecodeWithPagedKVCachePlan( + bool use_logits_soft_cap, unsigned int head_dim, torch::Tensor empty_q_data, + torch::Tensor empty_kv_data, torch::Tensor float_workspace_buffer, + torch::Tensor int_workspace_buffer, torch::Tensor page_locked_int_workspace_buffer, + torch::Tensor indptr, unsigned int batch_size, unsigned int num_qo_heads, + unsigned int num_kv_heads, unsigned int page_size, bool enable_cuda_graph) { + size_t float_workspace_size_in_bytes = + float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); + size_t int_workspace_size_in_bytes = + int_workspace_buffer.size(0) * int_workspace_buffer.element_size(); + auto device = float_workspace_buffer.device(); + cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); + indptr = indptr.to(torch::kCPU); + + DecodePlanInfo plan_info; + + using IdType = int32_t; + // check indptr has idtype int32 + TORCH_CHECK(indptr.scalar_type() == torch::kInt32, "indptr must be int32"); + constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; + + auto q_scalar_type = empty_q_data.scalar_type(); + auto kv_scalar_type = empty_kv_data.scalar_type(); + + DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_scalar_type, kv_scalar_type, q_type, kv_type, [&] { + using DTypeQ = q_type; + using DTypeKV = kv_type; + using DTypeO = DTypeQ; + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_LOGITS_SOFT_CAP(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, [&] { + using ParamsT = BatchDecodeParams; + using AttentionVariant = + ComposedAttention; + + cudaError_t status = DecodePlan( + static_cast(float_workspace_buffer.data_ptr()), float_workspace_size_in_bytes, + static_cast(int_workspace_buffer.data_ptr()), + static_cast(page_locked_int_workspace_buffer.data_ptr()), + int_workspace_size_in_bytes, plan_info, static_cast(indptr.data_ptr()), + batch_size, num_qo_heads, num_kv_heads, page_size, enable_cuda_graph, + /*stream=*/torch_current_stream); + + TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ", + cudaGetErrorString(status)); + return true; + }); + }); + }); + + return plan_info.ToVector(); +} + +std::vector BatchDecodeWithPagedKVCacheRun( + torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, + std::vector plan_info_vec, torch::Tensor q, + std::optional paged_kv_cache, std::optional paged_k_cache, + std::optional paged_v_cache, torch::Tensor paged_kv_indptr, + torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len, + std::optional alibi_slopes, unsigned int kv_layout_code, int window_left, + float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse) { + DecodePlanInfo plan_info; + plan_info.FromVector(plan_info_vec); + QKVLayout kv_layout = static_cast(kv_layout_code); + bool paged_kv_defined = paged_kv_cache.has_value(); + auto device = q.device(); + int64_t batch_size = q.size(0); + int64_t num_qo_heads = q.size(1); + int64_t num_kv_heads, page_size; + if (paged_kv_defined) { + if (kv_layout == QKVLayout::kHND) { + num_kv_heads = paged_kv_cache->size(2); + page_size = paged_kv_cache->size(3); + } else { + page_size = paged_kv_cache->size(2); + num_kv_heads = 
paged_kv_cache->size(3); + } + } else { + if (kv_layout == QKVLayout::kHND) { + num_kv_heads = paged_k_cache->size(1); + page_size = paged_k_cache->size(2); + } else { + page_size = paged_k_cache->size(1); + num_kv_heads = paged_k_cache->size(2); + } + } + uint32_t head_dim = q.size(2); + + cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); + torch::Tensor o = torch::empty_like(q); + torch::Tensor lse; + if (return_lse) { + lse = torch::empty({batch_size, num_qo_heads}, q.options().dtype((torch::kFloat32))); + } + + TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative"); + + void* float_buffer = static_cast(float_workspace_buffer.data_ptr()); + void* int_buffer = static_cast(int_workspace_buffer.data_ptr()); + + using IdType = int32_t; + constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; + + // get q_scalar_type and kv_scalar_type + auto q_scalar_type = q.scalar_type(); + auto kv_scalar_type = + paged_kv_cache.has_value() ? paged_kv_cache->scalar_type() : paged_k_cache->scalar_type(); + + DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_scalar_type, kv_scalar_type, q_type, kv_type, [&] { + using DTypeQ = q_type; + using DTypeKV = kv_type; + using DTypeO = DTypeQ; + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_LOGITS_SOFT_CAP(logits_soft_cap > 0, USE_LOGITS_SOFT_CAP, [&] { + using ParamsT = BatchDecodeParams; + using AttentionVariant = + ComposedAttention; + + paged_kv_t paged_kv( + num_kv_heads, page_size, HEAD_DIM, batch_size, kv_layout, + static_cast(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr() + : nullptr), + static_cast(paged_k_cache.has_value() ? paged_k_cache->data_ptr() : nullptr), + static_cast(paged_v_cache.has_value() ? paged_v_cache->data_ptr() : nullptr), + static_cast(paged_kv_indices.data_ptr()), + static_cast(paged_kv_indptr.data_ptr()), + static_cast(paged_kv_last_page_len.data_ptr())); + ParamsT params(static_cast(q.data_ptr()), + /*q_offset=*/nullptr, paged_kv, static_cast(o.data_ptr()), + /*lse=*/(return_lse ? 
static_cast(lse.data_ptr()) : nullptr), + /*alibi_slopes=*/nullptr, num_qo_heads, window_left, logits_soft_cap, + sm_scale, rope_scale, rope_theta); + + DTypeO* tmp_v = nullptr; + float* tmp_s = nullptr; + params.request_indices = + GetPtrFromBaseOffset(int_buffer, plan_info.request_indices_offset); + params.kv_tile_indices = + GetPtrFromBaseOffset(int_buffer, plan_info.kv_tile_indices_offset); + params.o_indptr = GetPtrFromBaseOffset(int_buffer, plan_info.o_indptr_offset); + params.kv_chunk_size_ptr = + GetPtrFromBaseOffset(int_buffer, plan_info.kv_chunk_size_ptr_offset); + if (plan_info.split_kv) { + tmp_v = GetPtrFromBaseOffset(float_buffer, plan_info.v_offset); + tmp_s = GetPtrFromBaseOffset(float_buffer, plan_info.s_offset); + if (plan_info.enable_cuda_graph) { + params.block_valid_mask = + GetPtrFromBaseOffset(int_buffer, plan_info.block_valid_mask_offset); + } + } + params.padded_batch_size = plan_info.padded_batch_size; + + cudaError_t status = + flashinfer::BatchDecodeWithPagedKVCacheDispatched( + params, tmp_v, tmp_s, /*stream=*/torch_current_stream); + TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ", + cudaGetErrorString(status)); + return true; + }); + }); + }); + + if (return_lse) { + return {o, lse}; + } else { + return {o}; + } +} diff --git a/flashinfer-aot/csrc_aot/batch_prefill.cu b/flashinfer-aot/csrc_aot/batch_prefill.cu new file mode 100644 index 000000000..ce9433789 --- /dev/null +++ b/flashinfer-aot/csrc_aot/batch_prefill.cu @@ -0,0 +1,340 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
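batch_decode.cu above splits execution into a Plan call, which writes scheduling metadata into the workspace buffers and returns it as a flat vector of integers, and a Run call that consumes that vector on every decode step. A hedged usage sketch of the two ops as registered later in flashinfer_ops_decode.cu; `_decode_ext` stands in for the loaded extension module, and all shapes, workspace sizes, and layout/option codes are illustrative assumptions:

```python
# Illustrative plan/run flow for the batch decode ops declared in this diff.
import torch

_decode_ext = ...  # placeholder: the compiled TORCH_EXTENSION_NAME decode module

num_qo_heads, num_kv_heads, head_dim, page_size = 32, 8, 128, 16
batch_size, pages_per_req, num_pages, device = 4, 4, 64, "cuda"

q = torch.randn(batch_size, num_qo_heads, head_dim, dtype=torch.float16, device=device)
# combined paged KV-cache, NHD layout: [num_pages, 2, page_size, num_kv_heads, head_dim]
paged_kv_cache = torch.randn(num_pages, 2, page_size, num_kv_heads, head_dim,
                             dtype=torch.float16, device=device)
kv_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device=device) * pages_per_req
kv_indices = torch.arange(batch_size * pages_per_req, dtype=torch.int32, device=device)
kv_last_page_len = torch.full((batch_size,), page_size, dtype=torch.int32, device=device)

float_ws = torch.empty(128 << 20, dtype=torch.uint8, device=device)  # sizes are guesses
int_ws = torch.empty(8 << 20, dtype=torch.uint8, device=device)
pinned_int_ws = torch.empty(8 << 20, dtype=torch.uint8, pin_memory=True)

# Plan once per batch composition; the result is DecodePlanInfo::ToVector().
plan_info = _decode_ext.batch_decode_with_paged_kv_cache_plan(
    False, head_dim, q.new_empty(0), paged_kv_cache.new_empty(0),
    float_ws, int_ws, pinned_int_ws, kv_indptr,
    batch_size, num_qo_heads, num_kv_heads, page_size, False)

# Run on every decode step, reusing the workspaces and plan info.
(o,) = _decode_ext.batch_decode_with_paged_kv_cache_run(
    float_ws, int_ws, plan_info, q,
    paged_kv_cache, None, None, kv_indptr, kv_indices, kv_last_page_len,
    None,                        # alibi_slopes
    0,                           # kv_layout_code (NHD assumed to be 0)
    -1,                          # window_left: no sliding window
    0.0,                         # logits_soft_cap disabled
    head_dim ** -0.5, 1.0, 1e4,  # sm_scale, rope_scale, rope_theta
    False)                       # return_lse
```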
+ */ +#include + +#include +#include +#include +#include +#include + +#include "pytorch_extension_utils.h" + +namespace flashinfer { + +template +cudaError_t BatchPrefillWithPagedKVCacheDispatched(typename AttentionVariant::ParamsT params, + typename AttentionVariant::DTypeO* tmp_v, + float* tmp_s, cudaStream_t stream); + +template +cudaError_t BatchPrefillWithRaggedKVCacheDispatched(typename AttentionVariant::ParamsT params, + typename AttentionVariant::DTypeO* tmp_v, + float* tmp_s, cudaStream_t stream); + +} // namespace flashinfer + +std::vector BatchPrefillWithKVCachePlan( + unsigned int head_dim, torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, + torch::Tensor page_locked_int_workspace_buffer, torch::Tensor qo_indptr, + torch::Tensor kv_indptr, unsigned int batch_size, unsigned int num_qo_heads, + unsigned int num_kv_heads, unsigned int page_size, bool enable_cuda_graph) { + size_t float_workspace_size_in_bytes = + float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); + size_t int_workspace_size_in_bytes = + int_workspace_buffer.size(0) * int_workspace_buffer.element_size(); + + auto device = float_workspace_buffer.device(); + cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); + qo_indptr = qo_indptr.to(torch::kCPU); + kv_indptr = kv_indptr.to(torch::kCPU); + + PrefillPlanInfo plan_info; + + using IdType = int32_t; + + cudaError_t status = PrefillPlan( + float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes, + int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(), + int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr(), + kv_indptr.data_ptr(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, + enable_cuda_graph, /*sizeof_dtype_o=*/2, torch_current_stream); + + TORCH_CHECK(status == cudaSuccess, + "Failed to plan prefill with error: ", cudaGetErrorString(status)); + + return plan_info.ToVector(); +} + +std::vector BatchPrefillWithRaggedKVCacheRun( + unsigned int mask_mode_code, torch::Tensor float_workspace_buffer, + torch::Tensor int_workspace_buffer, std::vector plan_info_vec, torch::Tensor q, + torch::Tensor k, torch::Tensor v, std::optional maybe_custom_mask, + std::optional maybe_alibi_slopes, torch::Tensor qo_indptr, + torch::Tensor kv_indptr, std::optional maybe_qk_indptr, unsigned int layout, + int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, + bool return_lse) { + PrefillPlanInfo plan_info; + plan_info.FromVector(plan_info_vec); + QKVLayout kv_layout = static_cast(layout); + + int64_t num_qo_heads = q.size(1); + int64_t head_dim = q.size(2); + int64_t num_kv_heads = (kv_layout == QKVLayout::kNHD) ? 
k.size(1) : k.size(0); + uint32_t q_stride_n = q.stride(0), q_stride_h = q.stride(1), kv_stride_n, kv_stride_h; + if (kv_layout == QKVLayout::kNHD) { + kv_stride_n = k.stride(0); + kv_stride_h = k.stride(1); + } else { + kv_stride_h = k.stride(0); + kv_stride_n = k.stride(1); + } + + auto device = float_workspace_buffer.device(); + cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); + auto o = torch::empty_like(q, q.options()); + int64_t nnz_qo = q.size(0); + torch::Tensor lse = torch::empty({0}); + if (return_lse) { + lse = torch::empty({nnz_qo, num_qo_heads}, q.options().dtype(torch::kFloat32)); + } + + void* float_buffer_ptr = float_workspace_buffer.data_ptr(); + void* int_buffer_ptr = int_workspace_buffer.data_ptr(); + + constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; + const MaskMode mask_mode = static_cast(mask_mode_code); + const bool use_logits_soft_cap = logits_soft_cap > 0.f; + using IdType = int32_t; + + auto q_scalar_type = q.scalar_type(); + auto kv_scalar_type = k.scalar_type(); + + DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_scalar_type, kv_scalar_type, q_type, kv_type, [&] { + using DTypeQ = q_type; + using DTypeKV = kv_type; + using DTypeO = DTypeQ; + return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_LOGITS_SOFT_CAP(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, [&] { + using RaggedParamsT = BatchPrefillRaggedParams; + using RaggedAttentionVariant = + ComposedAttention; + + RaggedParamsT params( + static_cast(q.data_ptr()), static_cast(k.data_ptr()), + static_cast(v.data_ptr()), + maybe_custom_mask.has_value() ? static_cast(maybe_custom_mask->data_ptr()) + : nullptr, + static_cast(qo_indptr.data_ptr()), + static_cast(kv_indptr.data_ptr()), + maybe_qk_indptr.has_value() ? static_cast(maybe_qk_indptr->data_ptr()) + : nullptr, + /*q_offset=*/nullptr, + /*k_rope_pos_offset=*/nullptr, static_cast(o.data_ptr()), + /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, + /*alibi_slopes=*/nullptr, num_qo_heads, num_kv_heads, q_stride_n, q_stride_h, + kv_stride_n, kv_stride_h, window_left, logits_soft_cap, sm_scale, rope_scale, + rope_theta); + + DTypeO* tmp_v = nullptr; + float* tmp_s = nullptr; + + params.request_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.request_indices_offset); + params.qo_tile_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_tile_indices_offset); + params.kv_tile_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_tile_indices_offset); + params.o_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.o_indptr_offset); + params.kv_chunk_size_ptr = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_chunk_size_ptr_offset); + if (plan_info.split_kv) { + params.merge_indptr = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.merge_indptr_offset); + tmp_v = GetPtrFromBaseOffset(float_buffer_ptr, plan_info.v_offset); + tmp_s = GetPtrFromBaseOffset(float_buffer_ptr, plan_info.s_offset); + if (plan_info.enable_cuda_graph) { + params.block_valid_mask = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.block_valid_mask_offset); + } + } + params.total_num_rows = plan_info.total_num_rows; + params.padded_batch_size = plan_info.padded_batch_size; + + WarpLayout warp_layout = WarpLayout(plan_info.warp_layout_code); + cudaError_t status = cudaSuccess; + + DISPATCH_WARP_LAYOUT(warp_layout, WARP_LAYOUT, { + status = flashinfer::BatchPrefillWithRaggedKVCacheDispatched< + WARP_LAYOUT, HEAD_DIM, POS_ENCODING_MODE, + /*use_fp16_qk_reduction=*/false, MASK_MODE, RaggedAttentionVariant>( + params, tmp_v, tmp_s, torch_current_stream); + }); + + TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithRaggedKVCache failed with error ", + cudaGetErrorString(status)); + return true; + }); + }); + }); + }); + + if (return_lse) { + return {o, lse}; + } else { + return {o}; + } +} + +std::vector BatchPrefillWithPagedKVCacheRun( + unsigned int mask_mode_code, torch::Tensor float_workspace_buffer, + torch::Tensor int_workspace_buffer, std::vector plan_info_vec, torch::Tensor q, + std::optional paged_kv_cache, std::optional paged_k_cache, + std::optional paged_v_cache, std::optional maybe_custom_mask, + std::optional maybe_alibi_slopes, torch::Tensor qo_indptr, + torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, + torch::Tensor paged_kv_last_page_len, std::optional maybe_qk_indptr, + unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta, bool return_lse) { + PrefillPlanInfo plan_info; + plan_info.FromVector(plan_info_vec); + QKVLayout kv_layout = static_cast(layout); + bool paged_kv_defined = paged_kv_cache.has_value(); + auto device = q.device(); + int64_t batch_size = paged_kv_indptr.size(0) - 1; + int64_t num_qo_heads = q.size(1); + int64_t num_kv_heads, page_size; + uint32_t head_dim = q.size(2); + if (paged_kv_defined) { + if (kv_layout == QKVLayout::kHND) { + num_kv_heads = paged_kv_cache->size(2); + page_size = paged_kv_cache->size(3); + } else { + page_size = paged_kv_cache->size(2); + num_kv_heads = paged_kv_cache->size(3); + } + } else { + if (kv_layout == QKVLayout::kHND) { + num_kv_heads = paged_k_cache->size(1); + page_size = paged_k_cache->size(2); + } else { + page_size = paged_k_cache->size(1); + num_kv_heads = paged_k_cache->size(2); + } + } + + cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); + auto o = torch::empty_like(q, q.options()); + int64_t nnz_qo = q.size(0); + 
torch::Tensor lse = torch::empty({0}); + if (return_lse) { + lse = torch::empty({nnz_qo, num_qo_heads}, q.options().dtype(torch::kFloat32)); + } + + void* float_buffer_ptr = static_cast(float_workspace_buffer.data_ptr()); + void* int_buffer_ptr = static_cast(int_workspace_buffer.data_ptr()); + + constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; + const MaskMode mask_mode = static_cast(mask_mode_code); + using IdType = int32_t; + bool use_logits_soft_cap = logits_soft_cap > 0.f; + auto q_scalar_type = q.scalar_type(); + auto kv_scalar_type = + paged_kv_cache.has_value() ? paged_kv_cache->scalar_type() : paged_k_cache->scalar_type(); + + DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_scalar_type, kv_scalar_type, q_type, kv_type, [&] { + using DTypeQ = q_type; + using DTypeKV = kv_type; + using DTypeO = DTypeQ; + return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_LOGITS_SOFT_CAP(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, [&] { + paged_kv_t paged_kv( + num_kv_heads, page_size, HEAD_DIM, batch_size, kv_layout, + static_cast(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr() + : nullptr), + static_cast(paged_k_cache.has_value() ? paged_k_cache->data_ptr() + : nullptr), + static_cast(paged_v_cache.has_value() ? paged_v_cache->data_ptr() + : nullptr), + static_cast(paged_kv_indices.data_ptr()), + static_cast(paged_kv_indptr.data_ptr()), + static_cast(paged_kv_last_page_len.data_ptr())); + + using PagedParamsT = BatchPrefillPagedParams; + using PagedAttentionVariant = + ComposedAttention; + + PagedParamsT params( + static_cast(q.data_ptr()), paged_kv, + maybe_custom_mask.has_value() ? static_cast(maybe_custom_mask->data_ptr()) + : nullptr, + static_cast(qo_indptr.data_ptr()), + maybe_qk_indptr.has_value() ? static_cast(maybe_qk_indptr->data_ptr()) + : nullptr, + /*q_offset=*/nullptr, static_cast(o.data_ptr()), + /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, + /*alibi_slopes=*/nullptr, num_qo_heads, window_left, logits_soft_cap, sm_scale, + rope_scale, rope_theta); + + DTypeO* tmp_v = nullptr; + float* tmp_s = nullptr; + + params.request_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.request_indices_offset); + params.qo_tile_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_tile_indices_offset); + params.kv_tile_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_tile_indices_offset); + params.o_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.o_indptr_offset); + params.kv_chunk_size_ptr = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_chunk_size_ptr_offset); + if (plan_info.split_kv) { + params.merge_indptr = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.merge_indptr_offset); + tmp_v = GetPtrFromBaseOffset(float_buffer_ptr, plan_info.v_offset); + tmp_s = GetPtrFromBaseOffset(float_buffer_ptr, plan_info.s_offset); + if (plan_info.enable_cuda_graph) { + params.block_valid_mask = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.block_valid_mask_offset); + } + } + params.total_num_rows = plan_info.total_num_rows; + params.padded_batch_size = plan_info.padded_batch_size; + + WarpLayout warp_layout = WarpLayout(plan_info.warp_layout_code); + cudaError_t status = cudaSuccess; + + DISPATCH_WARP_LAYOUT(warp_layout, WARP_LAYOUT, { + status = flashinfer::BatchPrefillWithPagedKVCacheDispatched< + WARP_LAYOUT, HEAD_DIM, POS_ENCODING_MODE, + /*use_fp16_qk_reduction=*/false, MASK_MODE, PagedAttentionVariant>( + params, tmp_v, tmp_s, torch_current_stream); + }); + + TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ", + cudaGetErrorString(status)); + return true; + }); + }); + }); + }); + + if (return_lse) { + return {o, lse}; + } else { + return {o}; + } +} diff --git a/python/csrc/flashinfer_ops.h b/flashinfer-aot/csrc_aot/flashinfer_ops.cu similarity index 68% rename from python/csrc/flashinfer_ops.h rename to flashinfer-aot/csrc_aot/flashinfer_ops.cu index de2ba379e..ee2091bfa 100644 --- a/python/csrc/flashinfer_ops.h +++ b/flashinfer-aot/csrc_aot/flashinfer_ops.cu @@ -16,10 +16,6 @@ #pragma once #include -#include -#include -#include - void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, torch::Tensor append_indptr, std::optional paged_kv_cache, std::optional paged_k_cache, @@ -114,18 +110,46 @@ torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr, void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D, torch::Tensor& A_scale, torch::Tensor& B_scale); -class CutlassSegmentGEMMPyTorchWrapper { - public: - void RegisterWorkspaceBuffer(torch::Tensor workspace_buffer); - - torch::Tensor Run(torch::Tensor seg_indptr, torch::Tensor weight_indices, torch::Tensor x, - torch::Tensor weight, unsigned int batch_size, bool weight_column_major); - - CutlassSegmentGEMMPyTorchWrapper(torch::Tensor workspace_buffer) - : handler_(std::make_shared()) { - RegisterWorkspaceBuffer(workspace_buffer); - } - - private: - std::shared_ptr handler_; -}; +torch::Tensor CutlassSegmentGEMM(torch::Tensor workspace_buffer, torch::Tensor seg_indptr, + torch::Tensor weight_indices, torch::Tensor x, + torch::Tensor weight, unsigned int batch_size, + bool weight_column_major); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("append_paged_kv_cache", &append_paged_kv_cache, "Append paged KV-Cache operator"); + m.def("merge_state", &merge_state, "Merge two self-attention states"); + 
m.def("merge_state_in_place", &merge_state_in_place, + "Merge another self-attention state in-place."); + m.def("merge_states", &merge_states, "Merge multiple self-attention states"); + m.def("sampling_from_probs", &sampling_from_probs, "Sample from probabilities"); + m.def("top_k_sampling_from_probs", &top_k_sampling_from_probs, + "Top-k sampling from probabilities"); + m.def("min_p_sampling_from_probs", &min_p_sampling_from_probs, + "Min-p sampling from probabilities"); + m.def("top_p_sampling_from_probs", &top_p_sampling_from_probs, + "Top-p sampling from probabilities"); + m.def("top_k_top_p_sampling_from_probs", &top_k_top_p_sampling_from_probs, + "Top-k and top-p sampling from probabilities"); + m.def("top_k_renorm_probs", &top_k_renorm_probs, "Renormalize probabilities by top-k mask"); + m.def("top_p_renorm_probs", &top_p_renorm_probs, "Renormalize probabilities by top-p mask"); + m.def("top_k_mask_logits", &top_k_mask_logits, "Mask logits by top-k mask"); + m.def("chain_speculative_sampling", &chain_speculative_sampling, + "Speculative sampling from sequence of probabilities"); + m.def("rmsnorm", &rmsnorm, "Root mean square normalization"); + m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused add root mean square normalization"); + m.def("gemma_rmsnorm", &gemma_rmsnorm, "Gemma Root mean square normalization"); + m.def("gemma_fused_add_rmsnorm", &gemma_fused_add_rmsnorm, + "Gemma Fused add root mean square normalization"); + m.def("silu_and_mul", &silu_and_mul, "Fused SiLU and Mul"); + m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Fused GeLU Tanh and Mul"); + m.def("gelu_and_mul", &gelu_and_mul, "Fused GeLU and Mul"); + m.def("apply_rope_inplace", &apply_rope_inplace, "Apply RoPE in-place"); + m.def("apply_llama31_rope_inplace", &apply_llama31_rope_inplace, + "Apply Llama 3.1 style RoPE in-place"); + m.def("apply_rope", &apply_rope, "Apply RoPE"); + m.def("apply_llama31_rope", &apply_llama31_rope, "Apply Llama 3.1 style RoPE"); + m.def("packbits", &packbits, "GPU packbits operator"); + m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator"); + m.def("cutlass_segment_gemm", &CutlassSegmentGEMM, "Cutlass Segment GEMM operator"); + m.def("bmm_fp8", &bmm_fp8, "BMM FP8"); +} diff --git a/flashinfer-aot/csrc_aot/flashinfer_ops_decode.cu b/flashinfer-aot/csrc_aot/flashinfer_ops_decode.cu new file mode 100644 index 000000000..a9a666a31 --- /dev/null +++ b/flashinfer-aot/csrc_aot/flashinfer_ops_decode.cu @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torch::Tensor v, + torch::Tensor tmp, + std::optional alibi_slopes, + unsigned int layout, int window_left, + float logits_soft_cap, float sm_scale, float rope_scale, + float rope_theta); + +std::vector BatchDecodeWithPagedKVCachePlan( + bool use_logits_soft_cap, unsigned int head_dim, torch::Tensor empty_q_data, + torch::Tensor empty_kv_data, torch::Tensor float_workspace_buffer, + torch::Tensor int_workspace_buffer, torch::Tensor page_locked_int_workspace_buffer, + torch::Tensor indptr, unsigned int batch_size, unsigned int num_qo_heads, + unsigned int num_kv_heads, unsigned int page_size, bool enable_cuda_graph); + +std::vector BatchDecodeWithPagedKVCacheRun( + torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, + std::vector plan_info_vec, torch::Tensor q, + std::optional paged_kv_cache, std::optional paged_k_cache, + std::optional paged_v_cache, torch::Tensor paged_kv_indptr, + torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len, + std::optional alibi_slopes, unsigned int kv_layout_code, int window_left, + float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("single_decode_with_kv_cache", &single_decode_with_kv_cache, + "Single-request decode with KV-Cache operator"); + m.def("batch_decode_with_paged_kv_cache_plan", &BatchDecodeWithPagedKVCachePlan); + m.def("batch_decode_with_paged_kv_cache_run", &BatchDecodeWithPagedKVCacheRun); +} diff --git a/flashinfer-aot/csrc_aot/flashinfer_ops_prefill.cu b/flashinfer-aot/csrc_aot/flashinfer_ops_prefill.cu new file mode 100644 index 000000000..955a6cb1f --- /dev/null +++ b/flashinfer-aot/csrc_aot/flashinfer_ops_prefill.cu @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
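flashinfer_ops_prefill.cu below re-declares the prefill plan/run entry points implemented earlier in batch_prefill.cu; they follow the same two-phase pattern as decode. A hedged sketch of the ragged-KV variant; `_prefill_ext` is a placeholder for the loaded module, and the layout and mask-mode codes are assumptions:

```python
# Illustrative plan/run flow for the ragged-KV batch prefill ops in this diff.
import torch

_prefill_ext = ...  # placeholder: the compiled TORCH_EXTENSION_NAME prefill module

num_qo_heads = num_kv_heads = 8
head_dim, batch_size, device = 128, 2, "cuda"

# two requests with 5 and 7 tokens each, packed along dim 0 (ragged layout)
qo_indptr = torch.tensor([0, 5, 12], dtype=torch.int32, device=device)
kv_indptr = qo_indptr.clone()
q = torch.randn(12, num_qo_heads, head_dim, dtype=torch.float16, device=device)
k = torch.randn(12, num_kv_heads, head_dim, dtype=torch.float16, device=device)
v = torch.randn_like(k)

float_ws = torch.empty(128 << 20, dtype=torch.uint8, device=device)
int_ws = torch.empty(8 << 20, dtype=torch.uint8, device=device)
pinned_int_ws = torch.empty(8 << 20, dtype=torch.uint8, pin_memory=True)

plan_info = _prefill_ext.batch_prefill_with_kv_cache_plan(
    head_dim, float_ws, int_ws, pinned_int_ws, qo_indptr, kv_indptr,
    batch_size, num_qo_heads, num_kv_heads,
    1,       # page_size: 1 for the ragged (non-paged) path, an assumption
    False)   # enable_cuda_graph

(o,) = _prefill_ext.batch_prefill_with_ragged_kv_cache_run(
    1,                           # mask_mode_code (causal assumed to be 1)
    float_ws, int_ws, plan_info, q, k, v,
    None, None,                  # custom mask, alibi slopes
    qo_indptr, kv_indptr, None,  # qk_indptr
    0,                           # layout: NHD assumed to be 0
    -1, 0.0,                     # window_left, logits_soft_cap
    head_dim ** -0.5, 1.0, 1e4,  # sm_scale, rope_scale, rope_theta
    False)                       # return_lse
```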
+ */ +#include + +std::vector single_prefill_with_kv_cache( + unsigned int mask_mode_code, torch::Tensor q, torch::Tensor k, torch::Tensor v, + std::optional maybe_packed_custom_mask, torch::Tensor tmp, + std::optional maybe_alibi_slopes, unsigned int layout, int32_t window_left, + float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse); + +std::vector BatchPrefillWithKVCachePlan( + unsigned int head_dim, torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, + torch::Tensor page_locked_int_workspace_buffer, torch::Tensor qo_indptr, + torch::Tensor kv_indptr, unsigned int batch_size, unsigned int num_qo_heads, + unsigned int num_kv_heads, unsigned int page_size, bool enable_cuda_graph); + +std::vector BatchPrefillWithRaggedKVCacheRun( + unsigned int mask_mode_code, torch::Tensor float_workspace_buffer, + torch::Tensor int_workspace_buffer, std::vector plan_info_vec, torch::Tensor q, + torch::Tensor k, torch::Tensor v, std::optional maybe_custom_mask, + std::optional maybe_alibi_slopes, torch::Tensor qo_indptr, + torch::Tensor kv_indptr, std::optional maybe_qk_indptr, unsigned int layout, + int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, + bool return_lse); + +std::vector BatchPrefillWithPagedKVCacheRun( + unsigned int mask_mode_code, torch::Tensor float_workspace_buffer, + torch::Tensor int_workspace_buffer, std::vector plan_info_vec, torch::Tensor q, + std::optional paged_kv_cache, std::optional paged_k_cache, + std::optional paged_v_cache, std::optional maybe_custom_mask, + std::optional maybe_alibi_slopes, torch::Tensor qo_indptr, + torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, + torch::Tensor paged_kv_last_page_len, std::optional maybe_qk_indptr, + unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta, bool return_lse); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("single_prefill_with_kv_cache", &single_prefill_with_kv_cache, + "Single-request prefill attention with KV-Cache operator"); + m.def("batch_prefill_with_kv_cache_plan", &BatchPrefillWithKVCachePlan); + m.def("batch_prefill_with_ragged_kv_cache_run", &BatchPrefillWithRaggedKVCacheRun); + m.def("batch_prefill_with_paged_kv_cache_run", &BatchPrefillWithPagedKVCacheRun); +} diff --git a/flashinfer-aot/csrc_aot/pytorch_extension_utils.h b/flashinfer-aot/csrc_aot/pytorch_extension_utils.h new file mode 100644 index 000000000..d7545ce5b --- /dev/null +++ b/flashinfer-aot/csrc_aot/pytorch_extension_utils.h @@ -0,0 +1,275 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include +#include +#include +#include + +#include +#include + +#include "generated/dispatch.inc" + +using namespace flashinfer; + +#ifdef FLASHINFER_ENABLE_BF16 +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...) 
\ + [&]() -> bool { \ + switch (pytorch_dtype) { \ + case at::ScalarType::Half: { \ + using c_type = nv_half; \ + return __VA_ARGS__(); \ + } \ + case at::ScalarType::BFloat16: { \ + using c_type = nv_bfloat16; \ + return __VA_ARGS__(); \ + } \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() +#else +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...) \ + [&]() -> bool { \ + switch (pytorch_dtype) { \ + case at::ScalarType::Half: { \ + using c_type = nv_half; \ + return __VA_ARGS__(); \ + } \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() +#endif + +#ifdef FLASHINFER_ENABLE_FP8 +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(pytorch_dtype, c_type, ...) \ + [&]() -> bool { \ + switch (pytorch_dtype) { \ + case at::ScalarType::Float8_e4m3fn: { \ + using c_type = __nv_fp8_e4m3; \ + return __VA_ARGS__(); \ + } \ + case at::ScalarType::Float8_e5m2: { \ + using c_type = __nv_fp8_e5m2; \ + return __VA_ARGS__(); \ + } \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch fp8 data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() +#else +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(pytorch_dtype, c_type, ...) \ + [&]() -> bool { \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch fp8 data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + }() +#endif + +#if defined(FLASHINFER_ENABLE_BF16) && defined(FLASHINFER_ENABLE_FP8) +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \ + [&]() -> bool { \ + switch (pytorch_dtype) { \ + case at::ScalarType::Half: { \ + using c_type = nv_half; \ + return __VA_ARGS__(); \ + } \ + case at::ScalarType::BFloat16: { \ + using c_type = nv_bfloat16; \ + return __VA_ARGS__(); \ + } \ + case at::ScalarType::Float8_e4m3fn: { \ + using c_type = __nv_fp8_e4m3; \ + return __VA_ARGS__(); \ + } \ + case at::ScalarType::Float8_e5m2: { \ + using c_type = __nv_fp8_e5m2; \ + return __VA_ARGS__(); \ + } \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() +#elif defined(FLASHINFER_ENABLE_BF16) +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \ + [&]() -> bool { \ + switch (pytorch_dtype) { \ + case at::ScalarType::Half: { \ + using c_type = nv_half; \ + return __VA_ARGS__(); \ + } \ + case at::ScalarType::BFloat16: { \ + using c_type = nv_bfloat16; \ + return __VA_ARGS__(); \ + } \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() +#elif defined(FLASHINFER_ENABLE_FP8) +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) 
\ + [&]() -> bool { \ + switch (pytorch_dtype) { \ + case at::ScalarType::Half: { \ + using c_type = nv_half; \ + return __VA_ARGS__(); \ + } \ + case at::ScalarType::Float8_e4m3fn: { \ + using c_type = __nv_fp8_e4m3; \ + return __VA_ARGS__(); \ + } \ + case at::ScalarType::Float8_e5m2: { \ + using c_type = __nv_fp8_e5m2; \ + return __VA_ARGS__(); \ + } \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch fp8 data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() +#else +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \ + [&]() -> bool { \ + switch (pytorch_dtype) { \ + case at::ScalarType::Half: { \ + using c_type = nv_half; \ + return __VA_ARGS__(); \ + } \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() +#endif + +#define DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_dtype, kv_dtype, c_type_q, c_type_kv, ...) \ + [&]() -> bool { \ + if (kv_dtype == q_dtype) { \ + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_dtype, c_type_q, [&] { \ + using c_type_kv = c_type_q; \ + return __VA_ARGS__(); \ + }); \ + } else { \ + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_dtype, c_type_q, [&] { \ + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(kv_dtype, c_type_kv, \ + [&] { return __VA_ARGS__(); }); \ + }); \ + } \ + }() + +#define _DISPATCH_SWITCH(var_name, cond, ...) \ + [&]() -> bool { \ + switch (cond) { \ + __VA_ARGS__ \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch " var_name " " << int(cond); \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() + +#define _DISPATCH_CASE(case_expr, case_var, ...) \ + case case_expr: { \ + constexpr auto case_var = case_expr; \ + return __VA_ARGS__(); \ + } + +#define DISPATCH_head_dim(expr, const_expr, ...) \ + _DISPATCH_SWITCH("head_dim", expr, _DISPATCH_CASES_head_dim(const_expr, __VA_ARGS__)) + +#define DISPATCH_pos_encoding_mode(expr, const_expr, ...) \ + _DISPATCH_SWITCH("positional encoding mode", expr, \ + _DISPATCH_CASES_pos_encoding_mode(const_expr, __VA_ARGS__)) + +#define DISPATCH_allow_fp16_qk_reduction(expr, const_expr, ...) \ + _DISPATCH_SWITCH("allow_fp16_qk_reduction", expr, \ + _DISPATCH_CASES_allow_fp16_qk_reduction(const_expr, __VA_ARGS__)) + +#define DISPATCH_mask_mode(expr, const_expr, ...) \ + _DISPATCH_SWITCH("mask_mode", expr, _DISPATCH_CASES_mask_mode(const_expr, __VA_ARGS__)) + +#define DISPATCH_LOGITS_SOFT_CAP(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, ...) \ + [&]() -> bool { \ + if (use_logits_soft_cap) { \ + constexpr bool USE_LOGITS_SOFT_CAP = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool USE_LOGITS_SOFT_CAP = false; \ + return __VA_ARGS__(); \ + } \ + }() + +inline void check_shape(const torch::Tensor& a, const torch::Tensor& b, const char* a_name, + const char* b_name) { + TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). 
", a.dim(), " vs ", + b.dim()); + for (int i = 0; i < a.dim(); ++i) { + TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name, ".size(", i, ")"); + } +} + +inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) { + return (uint32_t(a) << 16) | uint32_t(b); +} + +#define CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads) \ + TORCH_CHECK(num_qo_heads % num_kv_heads == 0, "num_qo_heads(", num_qo_heads, \ + ") must be divisible by num_kv_heads(", num_kv_heads, ")") + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor") + +#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b) + +#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b) + +#define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b) + +inline bool is_float8_tensor(const torch::Tensor& tensor) { + return tensor.scalar_type() == at::ScalarType::Float8_e4m3fn || + tensor.scalar_type() == at::ScalarType::Float8_e5m2; +} diff --git a/flashinfer-aot/csrc_aot/single_decode.cu b/flashinfer-aot/csrc_aot/single_decode.cu new file mode 100644 index 000000000..ee62dc921 --- /dev/null +++ b/flashinfer-aot/csrc_aot/single_decode.cu @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include +#include + +#include "flashinfer/pos_enc.cuh" +#include "pytorch_extension_utils.h" + +namespace flashinfer { + +template +cudaError_t SingleDecodeWithKVCacheDispatched(typename AttentionVariant::ParamsT params, + typename AttentionVariant::DTypeO* tmp, + cudaStream_t stream); +} // namespace flashinfer + +torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torch::Tensor v, + torch::Tensor tmp, + std::optional alibi_slopes, + unsigned int layout, int window_left, + float logits_soft_cap, float sm_scale, float rope_scale, + float rope_theta) { + CHECK_INPUT(q); + CHECK_INPUT(k); + CHECK_INPUT(v); + CHECK_INPUT(tmp); + auto device = q.device(); + CHECK_EQ(k.device(), device); + CHECK_EQ(v.device(), device); + CHECK_EQ(tmp.device(), device); + CHECK_DIM(2, q); + CHECK_DIM(3, k); + CHECK_DIM(3, v); + CHECK_SHAPE(k, v); + CHECK_EQ(q.size(1), k.size(2)); + CHECK_EQ(v.scalar_type(), k.scalar_type()); + unsigned int num_qo_heads = q.size(0); + unsigned int head_dim = q.size(1); + unsigned int kv_len, num_kv_heads; + QKVLayout kv_layout = static_cast(layout); + if (kv_layout == QKVLayout::kNHD) { + kv_len = k.size(0); + num_kv_heads = k.size(1); + } else { + num_kv_heads = k.size(0); + kv_len = k.size(1); + } + CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); + cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); + auto o = torch::empty_like(q); + + TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative"); + + auto q_scalar_type = q.scalar_type(); + auto kv_scalar_type = k.scalar_type(); + + constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; + + DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_scalar_type, kv_scalar_type, q_type, kv_type, [&] { + using DTypeQ = q_type; + using DTypeKV = kv_type; + using DTypeO = DTypeQ; + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_LOGITS_SOFT_CAP(logits_soft_cap > 0, USE_LOGITS_SOFT_CAP, [&] { + using ParamsT = SingleDecodeParams; + using AttentionVariant = + ComposedAttention; + ParamsT params(static_cast(q.data_ptr()), static_cast(k.data_ptr()), + static_cast(v.data_ptr()), static_cast(o.data_ptr()), + /*alibi_slopes=*/nullptr, kv_len, num_qo_heads, num_kv_heads, kv_layout, + head_dim, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta); + cudaError_t status = + flashinfer::SingleDecodeWithKVCacheDispatched( + params, static_cast(tmp.data_ptr()), torch_current_stream); + TORCH_CHECK(status == cudaSuccess, "SingleDecodeWithKVCache kernel launch failed, error: " + + std::string(cudaGetErrorString(status))); + return true; + }); + }); + }); + + return o; +} diff --git a/flashinfer-aot/csrc_aot/single_prefill.cu b/flashinfer-aot/csrc_aot/single_prefill.cu new file mode 100644 index 000000000..c406ce959 --- /dev/null +++ b/flashinfer-aot/csrc_aot/single_prefill.cu @@ -0,0 +1,117 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
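// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the diff): calling the single-request decode
// entry point above from C++ (libtorch), mainly to spell out the shapes it
// checks for the NHD layout: q is [num_qo_heads, head_dim], k/v are
// [kv_len, num_kv_heads, head_dim]. Assumptions for illustration only: the
// stripped std::optional argument is std::optional<torch::Tensor>, layout code
// 0 selects the NHD layout, and the scratch-buffer size is a placeholder (the
// Python wrapper owns the real workspace sizing).
// ---------------------------------------------------------------------------
#include <torch/torch.h>

#include <cmath>
#include <optional>

// Assumed declaration matching the definition above (angle brackets restored).
torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torch::Tensor v,
                                          torch::Tensor tmp,
                                          std::optional<torch::Tensor> alibi_slopes,
                                          unsigned int layout, int window_left,
                                          float logits_soft_cap, float sm_scale, float rope_scale,
                                          float rope_theta);

torch::Tensor single_decode_example() {
  const int64_t num_qo_heads = 32, num_kv_heads = 8, head_dim = 128, kv_len = 1024;
  auto opts = torch::dtype(torch::kHalf).device(torch::kCUDA);
  auto q = torch::randn({num_qo_heads, head_dim}, opts);          // [H_qo, D]
  auto k = torch::randn({kv_len, num_kv_heads, head_dim}, opts);  // [N, H_kv, D] (NHD)
  auto v = torch::randn({kv_len, num_kv_heads, head_dim}, opts);
  auto tmp = torch::empty({16 * 1024 * 1024}, opts);              // split-KV scratch (size assumed)
  // The returned output has the same shape and dtype as q.
  return single_decode_with_kv_cache(
      q, k, v, tmp, /*alibi_slopes=*/std::nullopt, /*layout=*/0, /*window_left=*/-1,
      /*logits_soft_cap=*/0.f, /*sm_scale=*/1.f / std::sqrt(float(head_dim)),
      /*rope_scale=*/1.f, /*rope_theta=*/1e4f);
}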
+ */ +#include + +#include +#include +#include +#include + +#include "pytorch_extension_utils.h" + +namespace flashinfer { + +template +cudaError_t SinglePrefillWithKVCacheDispatched(typename AttentionVariant::ParamsT params, + typename AttentionVariant::DTypeO* tmp, + cudaStream_t stream); + +} // namespace flashinfer + +std::vector single_prefill_with_kv_cache( + unsigned int mask_mode_code, torch::Tensor q, torch::Tensor k, torch::Tensor v, + std::optional maybe_packed_custom_mask, torch::Tensor tmp, + std::optional maybe_alibi_slopes, unsigned int layout, + int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, + bool return_lse) { + auto device = q.device(); + unsigned int head_dim = q.size(2); + unsigned int kv_len, qo_len, num_kv_heads, num_qo_heads; + QKVLayout kv_layout = static_cast(layout); + qo_len = q.size(0); + num_qo_heads = q.size(1); + uint32_t q_stride_n = q.stride(0), q_stride_h = q.stride(1), kv_stride_n, kv_stride_h; + if (kv_layout == QKVLayout::kNHD) { + kv_len = k.size(0); + num_kv_heads = k.size(1); + kv_stride_n = k.stride(0); + kv_stride_h = k.stride(1); + } else { + kv_len = k.size(1); + num_kv_heads = k.size(0); + kv_stride_h = k.stride(0); + kv_stride_n = k.stride(1); + } + cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); + auto o = torch::empty_like(q, q.options()); + torch::Tensor lse = torch::empty({0}); + if (return_lse) { + lse = torch::empty({qo_len, num_qo_heads}, q.options().dtype(torch::kFloat32)); + } + + constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; + const MaskMode mask_mode = static_cast(mask_mode_code); + bool use_logits_soft_cap = logits_soft_cap > 0.f; + + auto q_scalar_type = q.scalar_type(); + auto kv_scalar_type = k.scalar_type(); + + DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_scalar_type, kv_scalar_type, q_type, kv_type, [&] { + using DTypeQ = q_type; + using DTypeKV = kv_type; + using DTypeO = DTypeQ; + return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_LOGITS_SOFT_CAP(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, [&] { + using ParamsT = SinglePrefillParams; + using AttentionVariant = + ComposedAttention; + + ParamsT params(static_cast(q.data_ptr()), static_cast(k.data_ptr()), + static_cast(v.data_ptr()), + maybe_packed_custom_mask.has_value() + ? static_cast(maybe_packed_custom_mask->data_ptr()) + : nullptr, + static_cast(o.data_ptr()), + /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, + /*alibi_slopes=*/nullptr, num_qo_heads, num_kv_heads, qo_len, kv_len, + q_stride_n, q_stride_h, kv_stride_n, kv_stride_h, head_dim, window_left, + logits_soft_cap, sm_scale, rope_scale, rope_theta); + + cudaError_t status = + flashinfer::SinglePrefillWithKVCacheDispatched( + params, static_cast(tmp.data_ptr()), torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "SinglePrefillWithKVCache kernel launch failed, error: " + + std::string(cudaGetErrorString(status))); + return true; + }); + }); + }); + }); + + if (return_lse) { + return {o, lse}; + } else { + return {o}; + } +} diff --git a/flashinfer-aot/flashinfer b/flashinfer-aot/flashinfer new file mode 120000 index 000000000..c5f9b1c7a --- /dev/null +++ b/flashinfer-aot/flashinfer @@ -0,0 +1 @@ +../python/flashinfer \ No newline at end of file diff --git a/python/generate_batch_paged_decode_inst.py b/flashinfer-aot/generate_batch_paged_decode_inst.py similarity index 68% rename from python/generate_batch_paged_decode_inst.py rename to flashinfer-aot/generate_batch_paged_decode_inst.py index 6b98adcdf..efd1945b4 100644 --- a/python/generate_batch_paged_decode_inst.py +++ b/flashinfer-aot/generate_batch_paged_decode_inst.py @@ -20,14 +20,12 @@ pos_encoding_mode_literal, dtype_literal, idtype_literal, - logits_hook_literal, ) from pathlib import Path def get_cu_file_str( head_dim, - logits_hook, pos_encoding_mode, dtype_q, dtype_kv, @@ -38,20 +36,21 @@ def get_cu_file_str( namespace flashinfer {{ -constexpr PageStorage page_storage = PageStorage::kIndices; +using ParamsT = BatchDecodeParams<{dtype_q}, {dtype_kv}, {dtype_out}, {idtype}>; -template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{head_dim}, page_storage, {logits_hook}, {pos_encoding_mode}, {dtype_q}, {dtype_kv}, {dtype_out}, {idtype}>( - {dtype_q}* q, {idtype}* q_offset, - paged_kv_t paged_kv, - kv_partition_info_t<{idtype}> kv_partition_info, - {dtype_out}* o, {dtype_out}* tmp_v, float* tmp_s, float* lse, - bool* block_valid_mask, uint32_t padded_batch_size, uint32_t num_qo_heads, - int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, - float rope_theta, cudaStream_t stream); +template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{head_dim}, {pos_encoding_mode}, ComposedAttention>( + ParamsT params, + {dtype_out}* tmp_v, float* tmp_s, + cudaStream_t stream); +template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{head_dim}, {pos_encoding_mode}, ComposedAttention>( + ParamsT params, + {dtype_out}* tmp_v, float* tmp_s, + cudaStream_t stream); }} """.format( - logits_hook=logits_hook_literal[int(logits_hook)], head_dim=head_dim, pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], dtype_q=dtype_literal[dtype_q], @@ -64,7 +63,7 @@ def get_cu_file_str( if __name__ == "__main__": pattern = ( - r"batch_paged_decode_head_([0-9]+)_logitshook_([0-9]+)_posenc_([0-9]+)_" + r"batch_paged_decode_head_([0-9]+)_posenc_([0-9]+)_" r"dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu" ) diff --git a/flashinfer-aot/generate_batch_paged_prefill_inst.py b/flashinfer-aot/generate_batch_paged_prefill_inst.py new file mode 100644 index 000000000..daf37e613 --- /dev/null +++ b/flashinfer-aot/generate_batch_paged_prefill_inst.py @@ -0,0 +1,99 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import sys +import re +import itertools +from literal_map import ( + mask_mode_literal, + pos_encoding_mode_literal, + warp_layout_literal, + dtype_literal, + idtype_literal, +) +from pathlib import Path + + +def get_cu_file_str( + head_dim, + pos_encoding_mode, + allow_fp16_qk_reduction, + mask_mode, + dtype_q, + dtype_kv, + dtype_out, + idtype, +): + warp_layout_choice = [0, 1, 2] + + def get_insts(attention_variant, dtype_out): + return "\n".join( + [ + """template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{warp_layout}, {head_dim}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, {attention_variant}>( + ParamsT params, + {dtype_out}* tmp_v, + float* tmp_s, cudaStream_t stream); + """.format( + warp_layout=warp_layout_literal[warp_layout], + head_dim=head_dim, + pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], + allow_fp16_qk_reduction=allow_fp16_qk_reduction, + mask_mode=mask_mode_literal[int(mask_mode)], + attention_variant=attention_variant, + dtype_out=dtype_out, + ) + for warp_layout in warp_layout_choice + ] + ) + + use_custom_mask = "true" if int(mask_mode) == 2 else "false" + dtype_q = dtype_literal[dtype_q] + dtype_kv = dtype_literal[dtype_kv] + dtype_out = dtype_literal[dtype_out] + idtype = idtype_literal[idtype] + + content = f"""#include + +namespace flashinfer {{ + +using ParamsT = BatchPrefillPagedParams<{dtype_q}, {dtype_kv}, {dtype_out}, {idtype}>; + +using AttentionVariant1 = ComposedAttention; + +{get_insts("AttentionVariant1", dtype_out)} + +using AttentionVariant2 = ComposedAttention; + +{get_insts("AttentionVariant2", dtype_out)} + +}}""" + return content + + +if __name__ == "__main__": + pattern = ( + r"batch_paged_prefill_head_([0-9]+)_posenc_([0-9]+)_" + r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu" + ) + compiled_pattern = re.compile(pattern) + path = Path(sys.argv[1]) + fname = path.name + match = compiled_pattern.match(fname) + + with open(path, "w") as f: + f.write(get_cu_file_str(*match.groups())) diff --git a/python/generate_batch_ragged_prefill_inst.py b/flashinfer-aot/generate_batch_ragged_prefill_inst.py similarity index 59% rename from python/generate_batch_ragged_prefill_inst.py rename to flashinfer-aot/generate_batch_ragged_prefill_inst.py index 2a8c05f5a..6310269a7 100644 --- a/python/generate_batch_ragged_prefill_inst.py +++ b/flashinfer-aot/generate_batch_ragged_prefill_inst.py @@ -22,14 +22,12 @@ warp_layout_literal, dtype_literal, idtype_literal, - logits_hook_literal, ) from pathlib import Path def get_cu_file_str( head_dim, - logits_hook, pos_encoding_mode, allow_fp16_qk_reduction, mask_mode, @@ -39,39 +37,48 @@ def get_cu_file_str( idtype, ): warp_layout_choice = [0, 1, 2] - insts = "\n".join( + def get_insts(attention_variant, dtype_out): + return "\n".join( [ - """template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{warp_layout}, {head_dim}, {logits_hook}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, {dtype_q}, {dtype_kv}, {dtype_out}, {idtype}>( - {dtype_q}* q, {idtype}* request_indices, {idtype}* 
q_tile_indices, {idtype}* kv_tile_indices, - {idtype}* q_indptr, {dtype_kv}* k, {dtype_kv}* v, {idtype}* kv_indptr, - uint8_t* custom_mask, {idtype}* qk_indptr, {idtype}* q_offset, {idtype}* k_rope_pos_offset, - {idtype}* o_indptr, {dtype_out}* o, {dtype_out}* tmp_v, float* tmp_s, float* lse, {idtype}* merge_indptr, - bool* block_valid_mask, {idtype}* kv_chunk_size_ptr, uint32_t total_num_rows, uint32_t num_qo_heads, - uint32_t padded_batch_size, uint32_t num_kv_heads, uint32_t q_stride_n, uint32_t q_stride_h, - uint32_t kv_stride_n, uint32_t kv_stride_h, int32_t window_left, - float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, - cudaStream_t stream); + """template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{warp_layout}, {head_dim}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, {attention_variant}>( + ParamsT params, + {dtype_out}* tmp_v, + float* tmp_s, cudaStream_t stream); """.format( warp_layout=warp_layout_literal[warp_layout], - logits_hook=logits_hook_literal[int(logits_hook)], head_dim=head_dim, pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], allow_fp16_qk_reduction=allow_fp16_qk_reduction, mask_mode=mask_mode_literal[int(mask_mode)], - dtype_q=dtype_literal[dtype_q], - dtype_kv=dtype_literal[dtype_kv], - dtype_out=dtype_literal[dtype_out], - idtype=idtype_literal[idtype], + attention_variant=attention_variant, + dtype_out=dtype_out, ) for warp_layout in warp_layout_choice ] ) + use_custom_mask = "true" if int(mask_mode) == 2 else "false" + + dtype_q = dtype_literal[dtype_q] + dtype_kv = dtype_literal[dtype_kv] + dtype_out = dtype_literal[dtype_out] + idtype = idtype_literal[idtype] + content = f"""#include namespace flashinfer {{ -{insts} +using ParamsT = BatchPrefillRaggedParams<{dtype_q}, {dtype_kv}, {dtype_out}, {idtype}>; + +using AttentionVariant1 = ComposedAttention; + +{get_insts("AttentionVariant1", dtype_out)} + +using AttentionVariant2 = ComposedAttention; + +{get_insts("AttentionVariant2", dtype_out)} }} """ @@ -80,7 +87,7 @@ def get_cu_file_str( if __name__ == "__main__": pattern = ( - r"batch_ragged_prefill_head_([0-9]+)_logitshook_([0-9]+)_posenc_([0-9]+)_" + r"batch_ragged_prefill_head_([0-9]+)_posenc_([0-9]+)_" r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu" ) compiled_pattern = re.compile(pattern) diff --git a/python/generate_dispatch_inc.py b/flashinfer-aot/generate_dispatch_inc.py similarity index 84% rename from python/generate_dispatch_inc.py rename to flashinfer-aot/generate_dispatch_inc.py index a55be1d58..f3ad9db88 100644 --- a/python/generate_dispatch_inc.py +++ b/flashinfer-aot/generate_dispatch_inc.py @@ -20,7 +20,6 @@ pos_encoding_mode_literal, bool_literal, mask_mode_literal, - logits_hook_literal, ) @@ -35,19 +34,6 @@ def get_dispatch_inc_str(args: argparse.Namespace) -> str: dispatch_head_dims_str = f"""#define _DISPATCH_CASES_head_dim(case_var, ...) \\ {dispatch_head_dims_entries} // EOL -""" - # logits post hooks - dispatch_logits_post_hooks_entries = "\n".join( - [ - " _DISPATCH_CASE({}, case_var, __VA_ARGS__) \\".format( - logits_hook_literal[_] - ) - for _ in args.logits_post_hooks - ] - ) - dispatch_logits_post_hooks_str = f"""#define _DISPATCH_CASES_logits_post_hook(case_var, ...) 
\\ -{dispatch_logits_post_hooks_entries} -// EOL """ # positional encoding modes dispatch_pos_encoding_modes_entries = "\n".join( @@ -90,7 +76,6 @@ def get_dispatch_inc_str(args: argparse.Namespace) -> str: return "\n".join( [ dispatch_head_dims_str, - dispatch_logits_post_hooks_str, dispatch_pos_encoding_modes_str, dispatch_allow_fp16_qk_reductions_str, dispatch_mask_mode_str, @@ -106,13 +91,6 @@ def get_dispatch_inc_str(args: argparse.Namespace) -> str: parser.add_argument( "--head_dims", type=int, required=True, nargs="+", help="Head dimensions" ) - parser.add_argument( - "--logits_post_hooks", - type=int, - required=True, - nargs="+", - help="Logit post hooks", - ) parser.add_argument( "--pos_encoding_modes", type=int, diff --git a/python/generate_single_decode_inst.py b/flashinfer-aot/generate_single_decode_inst.py similarity index 68% rename from python/generate_single_decode_inst.py rename to flashinfer-aot/generate_single_decode_inst.py index fc57f36c9..754e185f4 100644 --- a/python/generate_single_decode_inst.py +++ b/flashinfer-aot/generate_single_decode_inst.py @@ -19,14 +19,12 @@ from literal_map import ( pos_encoding_mode_literal, dtype_literal, - logits_hook_literal, ) from pathlib import Path def get_cu_file_str( head_dim, - logits_hook, pos_encoding_mode, dtype_q, dtype_kv, @@ -36,15 +34,21 @@ def get_cu_file_str( namespace flashinfer {{ -template cudaError_t SingleDecodeWithKVCacheDispatched<{head_dim}, {logits_hook}, {pos_encoding_mode}, {dtype_q}, {dtype_kv}, {dtype_out}>( - {dtype_q}* q, {dtype_kv}* k, {dtype_kv}* v, {dtype_out}* o, - {dtype_out}* tmp, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t seq_len, - QKVLayout kv_layout, int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, - float rope_theta, cudaStream_t stream); +using ParamsT = SingleDecodeParams<{dtype_q}, {dtype_kv}, {dtype_out}>; +template cudaError_t SingleDecodeWithKVCacheDispatched<{head_dim}, {pos_encoding_mode}, ComposedAttention>( + ParamsT params, + {dtype_out}* tmp, + cudaStream_t stream); + +template cudaError_t SingleDecodeWithKVCacheDispatched<{head_dim}, {pos_encoding_mode}, ComposedAttention>( + ParamsT params, + {dtype_out}* tmp, + cudaStream_t stream); }} """.format( - logits_hook=logits_hook_literal[int(logits_hook)], head_dim=head_dim, pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], dtype_q=dtype_literal[dtype_q], @@ -56,7 +60,7 @@ def get_cu_file_str( if __name__ == "__main__": pattern = ( - r"single_decode_head_([0-9]+)_logitshook_([0-9]+)_posenc_([0-9]+)_" + r"single_decode_head_([0-9]+)_posenc_([0-9]+)_" r"dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)\.cu" ) diff --git a/python/generate_single_prefill_inst.py b/flashinfer-aot/generate_single_prefill_inst.py similarity index 67% rename from python/generate_single_prefill_inst.py rename to flashinfer-aot/generate_single_prefill_inst.py index 3fc6284a5..eb54ed4e5 100644 --- a/python/generate_single_prefill_inst.py +++ b/flashinfer-aot/generate_single_prefill_inst.py @@ -20,14 +20,12 @@ pos_encoding_mode_literal, dtype_literal, mask_mode_literal, - logits_hook_literal, ) from pathlib import Path def get_cu_file_str( head_dim, - logits_hook, pos_encoding_mode, allow_fp16_qk_reduction, mask_mode, @@ -40,15 +38,22 @@ def get_cu_file_str( namespace flashinfer {{ -template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {logits_hook}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, {dtype_q}, {dtype_kv}, {dtype_out}>( - {dtype_q}* q, {dtype_kv}* k, 
{dtype_kv}* v, uint8_t* custom_mask, {dtype_out}* o, - {dtype_out}* tmp, float* lse, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, - uint32_t q_stride_n, uint32_t q_stride_h, uint32_t kv_stride_n, uint32_t kv_stride_h, int32_t window_left, - float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); +using ParamsT = SinglePrefillParams<{dtype_q}, {dtype_kv}, {dtype_out}>; + +template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, ComposedAttention>( + ParamsT params, + {dtype_out}* tmp, + cudaStream_t stream); + +template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, ComposedAttention>( + ParamsT params, + {dtype_out}* tmp, + cudaStream_t stream); }} """.format( - logits_hook=logits_hook_literal[int(logits_hook)], head_dim=head_dim, pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], allow_fp16_qk_reduction=allow_fp16_qk_reduction, @@ -56,13 +61,14 @@ def get_cu_file_str( dtype_q=dtype_literal[dtype_q], dtype_kv=dtype_literal[dtype_kv], dtype_out=dtype_literal[dtype_out], + use_custom_mask="true" if int(mask_mode) == 2 else "false", ) return content if __name__ == "__main__": pattern = ( - r"single_prefill_head_([0-9]+)_logitshook_([0-9]+)_posenc_([0-9]+)_" + r"single_prefill_head_([0-9]+)_posenc_([0-9]+)_" r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)\.cu" ) diff --git a/flashinfer-aot/include b/flashinfer-aot/include new file mode 120000 index 000000000..f5030fe88 --- /dev/null +++ b/flashinfer-aot/include @@ -0,0 +1 @@ +../include \ No newline at end of file diff --git a/python/literal_map.py b/flashinfer-aot/literal_map.py similarity index 92% rename from python/literal_map.py rename to flashinfer-aot/literal_map.py index f122ca619..7b1289a9a 100644 --- a/python/literal_map.py +++ b/flashinfer-aot/literal_map.py @@ -20,11 +20,6 @@ 2: "MaskMode::kCustom", } -logits_hook_literal = { - 0: "LogitsPostHook::kNone", - 1: "LogitsPostHook::kSoftCap", -} - warp_layout_literal = { 0: "WarpLayout::k4x1x2", 1: "WarpLayout::k4x1x1", @@ -40,6 +35,7 @@ dtype_literal = { "f16": "half", "bf16": "nv_bfloat16", + "f32": "float", "e4m3": "__nv_fp8_e4m3", "e5m2": "__nv_fp8_e5m2", } diff --git a/flashinfer-aot/setup.py b/flashinfer-aot/setup.py new file mode 100644 index 000000000..396cf334d --- /dev/null +++ b/flashinfer-aot/setup.py @@ -0,0 +1,431 @@ +""" +Copyright (c) 2023 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from typing import List, Tuple + +import pathlib +import os +import re +import itertools +import subprocess +import platform + +import setuptools +import argparse +import torch +import torch.utils.cpp_extension as torch_cpp_ext +from collections import namedtuple + +import generate_single_decode_inst, generate_single_prefill_inst, generate_batch_paged_decode_inst, generate_batch_paged_prefill_inst, generate_batch_ragged_prefill_inst, generate_dispatch_inc + + +root = pathlib.Path(__name__).parent + + +# cuda arch check for fp8 at the moment. +for cuda_arch_flags in torch_cpp_ext._get_cuda_arch_flags(): + arch = int(re.search("compute_\d+", cuda_arch_flags).group()[-2:]) + if arch < 75: + raise RuntimeError("FlashInfer requires sm75+") + +enable_bf16 = os.environ.get("FLASHINFER_ENABLE_BF16", "1") == "1" +enable_fp8 = os.environ.get("FLASHINFER_ENABLE_FP8", "1") == "1" + +if enable_bf16: + torch_cpp_ext.COMMON_NVCC_FLAGS.append("-DFLASHINFER_ENABLE_BF16") +if enable_fp8: + torch_cpp_ext.COMMON_NVCC_FLAGS.append("-DFLASHINFER_ENABLE_FP8") + + +def write_if_different(path: pathlib.Path, content: str) -> None: + if path.exists(): + with open(path, "r") as f: + if f.read() == content: + return + with open(path, "w") as f: + f.write(content) + + +def get_instantiation_cu() -> Tuple[List[str], List[str], List[str]]: + path = root / "csrc_aot" / "generated" + path.mkdir(parents=True, exist_ok=True) + + head_dims = os.environ.get("FLASHINFER_HEAD_DIMS", "64,128,256").split(",") + pos_encoding_modes = os.environ.get("FLASHINFER_POS_ENCODING_MODES", "0").split(",") + allow_fp16_qk_reduction_options = os.environ.get( + "FLASHINFER_ALLOW_FP16_QK_REDUCTION_OPTIONS", "0" + ).split(",") + mask_modes = os.environ.get("FLASHINFER_MASK_MODES", "0,1,2").split(",") + + head_dims = list(map(int, head_dims)) + pos_encoding_modes = list(map(int, pos_encoding_modes)) + allow_fp16_qk_reduction_options = list(map(int, allow_fp16_qk_reduction_options)) + mask_modes = list(map(int, mask_modes)) + # dispatch.inc + write_if_different( + path / "dispatch.inc", + generate_dispatch_inc.get_dispatch_inc_str( + argparse.Namespace( + head_dims=head_dims, + pos_encoding_modes=pos_encoding_modes, + allow_fp16_qk_reductions=allow_fp16_qk_reduction_options, + mask_modes=mask_modes, + ) + ), + ) + + idtypes = ["i32"] + prefill_dtypes = ["f16"] + decode_dtypes = ["f16"] + fp16_dtypes = ["f16"] + fp8_dtypes = ["e4m3", "e5m2"] + if enable_bf16: + prefill_dtypes.append("bf16") + decode_dtypes.append("bf16") + fp16_dtypes.append("bf16") + if enable_fp8: + decode_dtypes.extend(fp8_dtypes) + + files_decode = [] + files_prefill = [] + single_decode_uris = [] + # single decode files + for ( + head_dim, + pos_encoding_mode, + ) in itertools.product( + head_dims, + pos_encoding_modes, + ): + for dtype_q, dtype_kv in list(zip(decode_dtypes, decode_dtypes)) + list( + itertools.product(fp16_dtypes, fp8_dtypes) + ): + dtype_out = dtype_q + fname = f"single_decode_head_{head_dim}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}.cu" + files_decode.append(str(path / fname)) + content = generate_single_decode_inst.get_cu_file_str( + head_dim, + pos_encoding_mode, + dtype_q, + dtype_kv, + dtype_out, + ) + for use_sliding_window in [True, False]: + for use_logits_soft_cap in [True, False]: + single_decode_uris.append( + f"single_decode_with_kv_cache_dtype_q_{dtype_q}_" + f"dtype_kv_{dtype_kv}_" + f"dtype_o_{dtype_out}_" + f"head_dim_{head_dim}_" + f"posenc_{pos_encoding_mode}_" + 
f"use_swa_{use_sliding_window}_" + f"use_logits_cap_{use_logits_soft_cap}" + ) + write_if_different(path / fname, content) + + # batch decode files + batch_decode_uris = [] + for ( + head_dim, + pos_encoding_mode, + ) in itertools.product( + head_dims, + pos_encoding_modes, + ): + for idtype in idtypes: + for dtype_q, dtype_kv in list(zip(decode_dtypes, decode_dtypes)) + list( + itertools.product(fp16_dtypes, fp8_dtypes) + ): + dtype_out = dtype_q + fname = f"batch_paged_decode_head_{head_dim}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}_idtype_{idtype}.cu" + files_decode.append(str(path / fname)) + content = generate_batch_paged_decode_inst.get_cu_file_str( + head_dim, + pos_encoding_mode, + dtype_q, + dtype_kv, + dtype_out, + idtype, + ) + for use_sliding_window in [True, False]: + for use_logits_soft_cap in [True, False]: + batch_decode_uris.append( + f"batch_decode_with_kv_cache_dtype_q_{dtype_q}_" + f"dtype_kv_{dtype_kv}_" + f"dtype_o_{dtype_out}_" + f"dtype_idx_{idtype}_" + f"head_dim_{head_dim}_" + f"posenc_{pos_encoding_mode}_" + f"use_swa_{use_sliding_window}_" + f"use_logits_cap_{use_logits_soft_cap}" + ) + write_if_different(path / fname, content) + + # single prefill files + single_prefill_uris = [] + for ( + head_dim, + pos_encoding_mode, + allow_fp16_qk_reduction, + mask_mode, + ) in itertools.product( + head_dims, + pos_encoding_modes, + allow_fp16_qk_reduction_options, + mask_modes, + ): + for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)) + list( + itertools.product(prefill_dtypes, fp8_dtypes) + ): + fname = f"single_prefill_head_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}.cu" + files_prefill.append(str(path / fname)) + content = generate_single_prefill_inst.get_cu_file_str( + head_dim, + pos_encoding_mode, + allow_fp16_qk_reduction, + mask_mode, + dtype_q, # dtype_q + dtype_kv, # dtype_kv + dtype_q, # dtype_out + ) + for use_sliding_window in [True, False]: + for use_logits_soft_cap in [True, False]: + single_prefill_uris.append( + f"single_prefill_with_kv_cache_dtype_q_{dtype_q}_" + f"dtype_kv_{dtype_kv}_" + f"dtype_o_{dtype_q}_" + f"head_dim_{head_dim}_" + f"posenc_{pos_encoding_mode}_" + f"mask_{mask_mode}_" + f"use_swa_{use_sliding_window}_" + f"use_logits_cap_{use_logits_soft_cap}_" + f"f16qk_{bool(allow_fp16_qk_reduction)}" + ) + write_if_different(path / fname, content) + + # batch prefill files + batch_prefill_uris = [] + for ( + head_dim, + pos_encoding_mode, + allow_fp16_qk_reduction, + mask_mode, + idtype, + ) in itertools.product( + head_dims, + pos_encoding_modes, + allow_fp16_qk_reduction_options, + mask_modes, + idtypes, + ): + for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)) + list( + itertools.product(prefill_dtypes, fp8_dtypes) + ): + fname = f"batch_paged_prefill_head_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}.cu" + files_prefill.append(str(path / fname)) + content = generate_batch_paged_prefill_inst.get_cu_file_str( + head_dim, + pos_encoding_mode, + allow_fp16_qk_reduction, + mask_mode, + dtype_q, # dtype_q + dtype_kv, # dtype_kv + dtype_q, # dtype_out + idtype, + ) + write_if_different(path / fname, content) + + fname = 
f"batch_ragged_prefill_head_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}.cu" + files_prefill.append(str(path / fname)) + content = generate_batch_ragged_prefill_inst.get_cu_file_str( + head_dim, + pos_encoding_mode, + allow_fp16_qk_reduction, + mask_mode, + dtype_q, # dtype_q + dtype_kv, # dtype_kv + dtype_q, # dtype_out + idtype, + ) + write_if_different(path / fname, content) + + for sliding_window in [True, False]: + for logits_soft_cap in [True, False]: + batch_prefill_uris.append( + f"batch_prefill_with_kv_cache_dtype_q_{dtype_q}_" + f"dtype_kv_{dtype_kv}_" + f"dtype_o_{dtype_q}_" + f"dtype_idx_{idtype}_" + f"head_dim_{head_dim}_" + f"posenc_{pos_encoding_mode}_" + f"mask_{mask_mode}_" + f"use_swa_{sliding_window}_" + f"use_logits_cap_{logits_soft_cap}_" + f"f16qk_{bool(allow_fp16_qk_reduction)}" + ) + + return ( + files_prefill, + files_decode, + single_decode_uris + + batch_decode_uris + + single_prefill_uris + + batch_prefill_uris, + ) + + +def get_version(): + version = os.getenv("FLASHINFER_BUILD_VERSION") + if version is None: + with open(root / "version.txt") as f: + version = f.read().strip() + return version + + +def get_cuda_version() -> Tuple[int, int]: + if torch_cpp_ext.CUDA_HOME is None: + nvcc = "nvcc" + else: + nvcc = os.path.join(torch_cpp_ext.CUDA_HOME, "bin/nvcc") + txt = subprocess.check_output([nvcc, "--version"], text=True) + major, minor = map(int, re.findall(r"release (\d+)\.(\d+),", txt)[0]) + return major, minor + + +def generate_build_meta() -> None: + d = {} + version = get_version() + d["cuda_major"], d["cuda_minor"] = get_cuda_version() + d["torch"] = torch.__version__ + d["python"] = platform.python_version() + d["TORCH_CUDA_ARCH_LIST"] = os.environ.get("TORCH_CUDA_ARCH_LIST", None) + with open(root / "flashinfer" / "_build_meta.py", "w") as f: + f.write(f"__version__ = {version!r}\n") + f.write(f"build_meta = {d!r}") + + +def generate_aot_config(aot_kernel_uris: List[str]) -> None: + aot_config_str = f"""prebuilt_ops_uri = set({aot_kernel_uris})""" + with open(root / "flashinfer" / "jit" / "aot_config.py", "w") as f: + f.write(aot_config_str) + + +def remove_unwanted_pytorch_nvcc_flags(): + REMOVE_NVCC_FLAGS = [ + "-D__CUDA_NO_HALF_OPERATORS__", + "-D__CUDA_NO_HALF_CONVERSIONS__", + "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-D__CUDA_NO_HALF2_OPERATORS__", + ] + for flag in REMOVE_NVCC_FLAGS: + try: + torch_cpp_ext.COMMON_NVCC_FLAGS.remove(flag) + except ValueError: + pass + + +class NinjaBuildExtension(torch_cpp_ext.BuildExtension): + def __init__(self, *args, **kwargs) -> None: + # do not override env MAX_JOBS if already exists + if not os.environ.get("MAX_JOBS"): + max_num_jobs_cores = max(1, os.cpu_count()) + os.environ["MAX_JOBS"] = str(max_num_jobs_cores) + + super().__init__(*args, **kwargs) + + +if __name__ == "__main__": + remove_unwanted_pytorch_nvcc_flags() + generate_build_meta() + files_prefill, files_decode, aot_kernel_uris = get_instantiation_cu() + generate_aot_config(aot_kernel_uris) + include_dirs = [ + str(root.resolve() / "include"), + str(root.resolve() / "3rdparty" / "cutlass" / "include"), # for group gemm + ] + extra_compile_args = { + "cxx": [ + "-O3", + "-Wno-switch-bool", + ], + "nvcc": [ + "-O3", + "-std=c++17", + "--threads", + "1", + "-Xfatbin", + "-compress-all", + "-use_fast_math", + ], + } + ext_modules = [] + ext_modules.append( + torch_cpp_ext.CUDAExtension( + name="flashinfer._kernels", + sources=[ + 
"csrc/cascade.cu", + "csrc/page.cu", + "csrc/sampling.cu", + "csrc/norm.cu", + "csrc_aot/activation.cu", + "csrc/rope.cu", + "csrc/quantization.cu", + "csrc/group_gemm.cu", + "csrc/bmm_fp8.cu", + "csrc_aot/flashinfer_ops.cu", + ], + include_dirs=include_dirs, + extra_compile_args=extra_compile_args, + ) + ) + ext_modules.append( + torch_cpp_ext.CUDAExtension( + name="flashinfer._decode_kernels", + sources=[ + "csrc_aot/single_decode.cu", + "csrc_aot/flashinfer_ops_decode.cu", + "csrc_aot/batch_decode.cu", + ] + + files_decode, + include_dirs=include_dirs, + extra_compile_args=extra_compile_args, + ) + ) + ext_modules.append( + torch_cpp_ext.CUDAExtension( + name="flashinfer._prefill_kernels", + sources=[ + "csrc_aot/single_prefill.cu", + "csrc_aot/flashinfer_ops_prefill.cu", + "csrc_aot/batch_prefill.cu", + ] + + files_prefill, + include_dirs=include_dirs, + extra_compile_args=extra_compile_args, + ) + ) + setuptools.setup( + name="flashinfer", + version=get_version(), + packages=setuptools.find_packages(), + author="FlashInfer team", + license="Apache License 2.0", + description="FlashInfer: Kernel Library for LLM Serving", + url="https://github.com/flashinfer-ai/flashinfer", + python_requires=">=3.8", + ext_modules=ext_modules, + cmdclass={"build_ext": NinjaBuildExtension}, + ) diff --git a/flashinfer-aot/version.txt b/flashinfer-aot/version.txt new file mode 120000 index 000000000..aa4e5bece --- /dev/null +++ b/flashinfer-aot/version.txt @@ -0,0 +1 @@ +../version.txt \ No newline at end of file diff --git a/include/flashinfer/activation.cuh b/include/flashinfer/activation.cuh index 67bee024d..0fc6715cb 100644 --- a/include/flashinfer/activation.cuh +++ b/include/flashinfer/activation.cuh @@ -25,25 +25,6 @@ namespace flashinfer { namespace activation { -// https://github.com/NVIDIA/FasterTransformer/blob/d21dc02bc5f70bc7dc0d18ba5801ae263565e68e/src/fastertransformer/kernels/activation_kernels.cu#L126-L133 -__device__ __forceinline__ float silu_kernel(const float& val) { - // NOTE(Zihao): use __fdividef might be faster, at the cost of precision - return val / (1.0f + __expf(-val)); -} - -// https://github.com/pytorch/pytorch/blob/f48038527792814b06dafa6d471acb04c837b972/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38 -__device__ __forceinline__ float gelu_kernel(const float& val) { - constexpr float kAlpha = M_SQRT1_2; - return val * 0.5f * (1.0f + ::erf(val * kAlpha)); -} - -template -__device__ __forceinline__ T gelu_tanh_kernel(const T& val) { - const float cdf = - 0.5f * (1.0f + math::tanh((0.7978845608028654f * (val + 0.044715f * val * val * val)))); - return val * cdf; -} - template __global__ void act_and_mul_kernel(T* __restrict__ out, const T* __restrict__ input, const int d) { constexpr uint32_t vec_size = 16 / sizeof(T); diff --git a/include/flashinfer/allocator.h b/include/flashinfer/allocator.h index e3bbb9f6a..7a1e40375 100644 --- a/include/flashinfer/allocator.h +++ b/include/flashinfer/allocator.h @@ -22,16 +22,23 @@ namespace flashinfer { +// create a function that returns T* from base pointer and offset +template +T* GetPtrFromBaseOffset(void* base_ptr, int64_t offset) { + return reinterpret_cast(reinterpret_cast(base_ptr) + offset); +} + struct AlignedAllocator { - void* ptr; - size_t space; - AlignedAllocator(void* buf, size_t space) : ptr(buf), space(space) {} + void* base_ptr; + void* cur_ptr; + size_t remaining_space; + AlignedAllocator(void* buf, size_t space) : base_ptr(buf), cur_ptr(buf), remaining_space(space) {} template T* 
aligned_alloc(size_t size, size_t alignment, std::string name) { - if (std::align(alignment, size, ptr, space)) { - T* result = reinterpret_cast(ptr); - ptr = (char*)ptr + size; - space -= size; + if (std::align(alignment, size, cur_ptr, remaining_space)) { + T* result = reinterpret_cast(cur_ptr); + cur_ptr = (char*)cur_ptr + size; + remaining_space -= size; return result; } else { std::ostringstream oss; @@ -41,6 +48,12 @@ struct AlignedAllocator { } return nullptr; } + + size_t aligned_alloc_offset(size_t size, size_t alignment, std::string name) { + return (char*)aligned_alloc(size, alignment, name) - (char*)base_ptr; + } + + size_t num_allocated_bytes() { return (char*)cur_ptr - (char*)base_ptr; } }; } // namespace flashinfer diff --git a/include/flashinfer/attention/cascade.cuh b/include/flashinfer/attention/cascade.cuh index 9d71e7bf1..3e191fc1c 100644 --- a/include/flashinfer/attention/cascade.cuh +++ b/include/flashinfer/attention/cascade.cuh @@ -33,7 +33,7 @@ using cp_async::SharedMemFillMode; * \brief The CUDA kernel that merges the self-attention state of two index sets A and B. * \tparam vec_size The vector size used in the kernel. * \tparam DTypeIn The data type of v_a and v_b. - * \tparam DTypeOut The data type of v_merged. + * \tparam DTypeO The data type of v_merged. * \param v_a The partial v of index set A. (n, h, d) * \param s_a The logsumexp value of index set A. (n, h) * \param v_b The partial v of index set B. (n, h, d) @@ -44,10 +44,10 @@ using cp_async::SharedMemFillMode; * \param head_dim The dimension of each head. * \note Both s_a and s_b are logsumexp values with base 2. */ -template +template __global__ void MergeStateKernel(DTypeIn* __restrict__ v_a, float* __restrict__ s_a, DTypeIn* __restrict__ v_b, float* __restrict__ s_b, - DTypeOut* __restrict__ v_merged, float* __restrict__ s_merged, + DTypeO* __restrict__ v_merged, float* __restrict__ s_merged, uint32_t num_heads, uint32_t head_dim) { uint32_t tx = threadIdx.x, ty = threadIdx.y; uint32_t pos = blockIdx.x; @@ -141,7 +141,7 @@ __device__ __forceinline__ void threadblock_sync_state(state_t& st, DT * \brief The CUDA kernel that merges self-attention states of a list of index sets. * \param vec_size The vector size used in the kernel. * \tparam DTypeIn The data type of v. - * \tparam DTypeOut The data type of v_merged. + * \tparam DTypeO The data type of v_merged. * \param v The partial v of index sets. (n, num_index_sets, h, d) * \param s The logsumexp value of index sets. (n, num_index_sets, h) * \param v_merged The merged v of index sets union. (n, h, d) @@ -150,26 +150,26 @@ __device__ __forceinline__ void threadblock_sync_state(state_t& st, DT * \param head_dim The dimension of each head. * \note s are logsumexp values with base 2. 
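// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the diff): the reworked AlignedAllocator now
// remembers its base pointer, so plan-time code can record byte offsets
// (aligned_alloc_offset / num_allocated_bytes) and run-time code can rebuild
// typed pointers with GetPtrFromBaseOffset from the same workspace buffer.
// Buffer size, names, and the include path are assumptions for illustration.
// ---------------------------------------------------------------------------
#include <cstdint>
#include <cstdlib>
#include <iostream>

#include "flashinfer/allocator.h"  // assuming -I include

int main() {
  constexpr size_t workspace_bytes = 1 << 20;
  void* workspace = std::malloc(workspace_bytes);

  // "Plan" phase: hand out offsets only; they stay meaningful even if the
  // buffer is later re-created at a different address.
  flashinfer::AlignedAllocator alloc(workspace, workspace_bytes);
  size_t indptr_offset = alloc.aligned_alloc_offset(1024 * sizeof(int32_t), 16, "example_indptr");
  size_t tmp_v_offset = alloc.aligned_alloc_offset(4096 * sizeof(float), 16, "example_tmp_v");
  std::cout << "planned bytes: " << alloc.num_allocated_bytes() << "\n";

  // "Run" phase: rebuild typed pointers from the base pointer + saved offsets.
  int32_t* indptr = flashinfer::GetPtrFromBaseOffset<int32_t>(workspace, indptr_offset);
  float* tmp_v = flashinfer::GetPtrFromBaseOffset<float>(workspace, tmp_v_offset);
  indptr[0] = 0;
  tmp_v[0] = 0.f;

  std::free(workspace);
  return 0;
}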
*/ -template +template __global__ void MergeStatesKernel(DTypeIn* __restrict__ V, float* __restrict__ S, - DTypeOut* __restrict__ v_merged, float* __restrict__ s_merged, + DTypeO* __restrict__ v_merged, float* __restrict__ s_merged, uint32_t num_index_sets, uint32_t num_heads, uint32_t head_dim) { uint32_t tx = threadIdx.x, ty = threadIdx.y; uint32_t pos = blockIdx.x; uint32_t head_idx = ty; if (num_index_sets == 0) { - vec_t v; - v.fill(DTypeOut(0.f)); + vec_t v; + v.fill(DTypeO(0.f)); v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size); if (s_merged != nullptr) { - s_merged[pos * num_heads + head_idx] = -5e4; + s_merged[pos * num_heads + head_idx] = -math::inf; } return; } if (num_index_sets == 1) { - vec_t v; + vec_t v; v.cast_load(V + (pos * num_heads + head_idx) * head_dim + tx * vec_size); v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size); if (s_merged != nullptr) { @@ -205,7 +205,7 @@ __global__ void MergeStatesKernel(DTypeIn* __restrict__ V, float* __restrict__ S * \tparam bdy The blockDim.y used in the kernel. * \tparam num_smem_stages The number of stages of shared memory used in the kernel. * \tparam DTypeIn The data type of v. - * \tparam DTypeOut The data type of v_merged. + * \tparam DTypeO The data type of v_merged. * \param V The partial v of index sets. (n, num_index_sets, h, d) * \param S The logsumexp value of index sets. (n, num_index_sets, h) * \param v_merged The merged v of index sets union. (n, h, d) @@ -215,9 +215,9 @@ __global__ void MergeStatesKernel(DTypeIn* __restrict__ V, float* __restrict__ S * \note s are logsumexp values with base 2. */ template + typename DTypeO> __global__ void MergeStatesLargeNumIndexSetsKernel(DTypeIn* __restrict__ V, float* __restrict__ S, - DTypeOut* __restrict__ v_merged, + DTypeO* __restrict__ v_merged, float* __restrict__ s_merged, uint32_t num_index_sets, uint32_t num_heads) { uint32_t tx = threadIdx.x, ty = threadIdx.y; @@ -289,7 +289,7 @@ __global__ void MergeStatesLargeNumIndexSetsKernel(DTypeIn* __restrict__ V, floa * \tparam bdy The blockDim.y used in the kernel. * \tparam num_smem_stages The number of stages of shared memory used in the kernel. * \tparam DTypeIn The data type of v. - * \tparam DTypeOut The data type of v_merged. + * \tparam DTypeO The data type of v_merged. * \param V The partial v of index sets. (nnz, h, d) * \param S The logsumexp value of index sets. (nnz, h) * \param indptr The start offsets of each position in the variable length array. @@ -300,10 +300,10 @@ __global__ void MergeStatesLargeNumIndexSetsKernel(DTypeIn* __restrict__ V, floa * \note s are logsumexp values with base 2. 
*/ template + typename DTypeO, typename IdType> __global__ void PersistentVariableLengthMergeStatesKernel(DTypeIn* __restrict__ V, float* __restrict__ S, IdType* indptr, - DTypeOut* __restrict__ v_merged, + DTypeO* __restrict__ v_merged, float* __restrict__ s_merged, uint32_t seq_len, uint32_t num_heads) { uint32_t tx = threadIdx.x, ty = threadIdx.y; @@ -324,17 +324,17 @@ __global__ void PersistentVariableLengthMergeStatesKernel(DTypeIn* __restrict__ const uint32_t num_index_sets = indptr[pos + 1] - indptr[pos]; if (num_index_sets == 0) { - vec_t v; - v.fill(DTypeOut(0.f)); + vec_t v; + v.fill(DTypeO(0.f)); v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size); if (s_merged != nullptr) { - s_merged[pos * num_heads + head_idx] = -5e4; + s_merged[pos * num_heads + head_idx] = -math::inf; } continue; } if (num_index_sets == 1) { - vec_t v; + vec_t v; v.cast_load(V + (indptr[pos] * num_heads + head_idx) * head_dim + tx * vec_size); v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size); if (s_merged != nullptr) { @@ -395,7 +395,7 @@ __global__ void PersistentVariableLengthMergeStatesKernel(DTypeIn* __restrict__ /*! * \brief Merge the self-attention state of two index sets A and B. * \tparam DTypeIn The data type of v_a and v_b. - * \tparam DTypeOut The data type of v_merged. + * \tparam DTypeO The data type of v_merged. * \param v_a The partial v of index set A (n, h, d) * \param s_a The logsumexp value of index set A. (n, h) * \param v_b The partial v of index set B. (n, h, d) @@ -409,8 +409,8 @@ __global__ void PersistentVariableLengthMergeStatesKernel(DTypeIn* __restrict__ * \return status Indicates whether CUDA calls are successful * \note Both s_a and s_b are logsumexp values with base 2. */ -template -cudaError_t MergeState(DTypeIn* v_a, float* s_a, DTypeIn* v_b, float* s_b, DTypeOut* v_merged, +template +cudaError_t MergeState(DTypeIn* v_a, float* s_a, DTypeIn* v_b, float* s_b, DTypeO* v_merged, float* s_merged, uint32_t seq_len, uint32_t num_heads, uint32_t head_dim, cudaStream_t stream = nullptr) { DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { @@ -419,7 +419,7 @@ cudaError_t MergeState(DTypeIn* v_a, float* s_a, DTypeIn* v_b, float* s_b, DType uint32_t bdy = num_heads; dim3 nblks(seq_len); dim3 nthrs(bdx, bdy); - auto kernel = MergeStateKernel; + auto kernel = MergeStateKernel; void* args[] = {&v_a, &s_a, &v_b, &s_b, &v_merged, &s_merged, &num_heads, &head_dim}; FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); }); @@ -461,7 +461,7 @@ cudaError_t MergeStateInPlace(DType* v, float* s, DType* v_other, float* s_other /*! * \brief Merge self-attention states of a list of index sets. * \tparam DTypeIn The data type of v. - * \tparam DTypeOut The data type of v_merged. + * \tparam DTypeO The data type of v_merged. * \param v The partial v of index sets. (n, num_index_sets, h, d) * \param s The logsumexp value of index sets. (n, num_index_sets, h) * \param v_merged The merged v of index sets union. (n, h, d) @@ -474,8 +474,8 @@ cudaError_t MergeStateInPlace(DType* v, float* s, DType* v_other, float* s_other * \return status Indicates whether CUDA calls are successful * \note s are logsumexp values with base 2. 
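// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the diff): a scalar reference for the merge
// rule behind MergeState/MergeStates/VariableLengthMergeStates, written from
// the definition of a base-2 logsumexp (as the doc comments above note). This
// is only the per-element math the kernels implement, not the CUDA code.
// ---------------------------------------------------------------------------
#include <algorithm>
#include <cmath>
#include <cstdio>

struct ScalarState {
  float v;  // one element of the partial attention output
  float s;  // log2 of the softmax normalizer accumulated so far
};

ScalarState merge(ScalarState a, ScalarState b) {
  float m = std::max(a.s, b.s);
  float w_a = std::exp2(a.s - m);  // 2^(s_a - m)
  float w_b = std::exp2(b.s - m);  // 2^(s_b - m)
  return {(w_a * a.v + w_b * b.v) / (w_a + w_b), m + std::log2(w_a + w_b)};
}

int main() {
  // Merging is associative and commutative (up to rounding), which is what lets
  // KV-cache chunks be attended to independently and combined afterwards.
  ScalarState a{0.25f, 3.0f}, b{0.75f, 1.0f};
  ScalarState ab = merge(a, b), ba = merge(b, a);
  std::printf("v=%f s=%f (swapped: v=%f s=%f)\n", ab.v, ab.s, ba.v, ba.s);
  return 0;
}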
*/ -template -cudaError_t MergeStates(DTypeIn* v, float* s, DTypeOut* v_merged, float* s_merged, +template +cudaError_t MergeStates(DTypeIn* v, float* s, DTypeO* v_merged, float* s_merged, uint32_t num_index_sets, uint32_t seq_len, uint32_t num_heads, uint32_t head_dim, cudaStream_t stream = nullptr) { DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { @@ -487,8 +487,8 @@ cudaError_t MergeStates(DTypeIn* v, float* s, DTypeOut* v_merged, float* s_merge dim3 nblks(seq_len, num_heads); dim3 nthrs(bdx, bdy); constexpr uint32_t num_smem_stages = 4; - auto kernel = MergeStatesLargeNumIndexSetsKernel; + auto kernel = + MergeStatesLargeNumIndexSetsKernel; void* args[] = {&v, &s, &v_merged, &s_merged, &num_index_sets, &num_heads}; uint32_t smem_size = num_smem_stages * bdy * head_dim * sizeof(DTypeIn) + num_threads * sizeof(float); @@ -499,7 +499,7 @@ cudaError_t MergeStates(DTypeIn* v, float* s, DTypeOut* v_merged, float* s_merge uint32_t bdy = num_heads; dim3 nblks(seq_len); dim3 nthrs(bdx, bdy); - auto kernel = MergeStatesKernel; + auto kernel = MergeStatesKernel; void* args[] = {&v, &s, &v_merged, &s_merged, &num_index_sets, &num_heads, &head_dim}; FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); } @@ -507,8 +507,8 @@ cudaError_t MergeStates(DTypeIn* v, float* s, DTypeOut* v_merged, float* s_merge return cudaSuccess; } -template -cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTypeOut* v_merged, +template +cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTypeO* v_merged, float* s_merged, uint32_t seq_len, uint32_t num_heads, uint32_t head_dim, cudaStream_t stream = nullptr) { int dev_id = 0; @@ -526,7 +526,7 @@ cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTyp uint32_t smem_size = num_smem_stages * bdy * head_dim * sizeof(DTypeIn) + num_threads * sizeof(float); auto kernel = PersistentVariableLengthMergeStatesKernel; + DTypeIn, DTypeO, IdType>; FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel, num_threads, smem_size)); num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(seq_len * num_heads, num_sms)); diff --git a/include/flashinfer/attention/decode.cuh b/include/flashinfer/attention/decode.cuh index 73434b989..bd4497bb4 100644 --- a/include/flashinfer/attention/decode.cuh +++ b/include/flashinfer/attention/decode.cuh @@ -35,7 +35,6 @@ #include "../utils.cuh" #include "../vec_dtypes.cuh" #include "cascade.cuh" -#include "logits_post_hook.cuh" #include "state.cuh" namespace flashinfer { @@ -48,7 +47,6 @@ namespace { /*! 
* \brief Load k tile from smem and compute qk - * \tparam logits_post_hook The logits post hook used in the kernel * \tparam pos_encoding_mode The positional encoding mode used in the kernel * \tparam head_dim A template integer indicates the head dimension * \tparam vec_size A template integer indicates the vector size @@ -62,19 +60,18 @@ namespace { * in shared memory of different pipeline stages * \param kv_idx A integer indicates the thread-local kv position in kv-cache * \param compute_stage_idx A integer indicates the compute stage index in the pipeline - * \param sm_scale A float indicates the scale applied to pre-softmax logits * \param s A float indicates the thread-local result of qk * \param st The self-attention state to be updated */ -template -__device__ __forceinline__ void compute_qk(const T* smem, uint32_t compute_stage_idx, - const vec_t& q_vec, +template +__device__ __forceinline__ void compute_qk(const typename AttentionVariant::ParamsT& params, + AttentionVariant variant, const uint32_t batch_idx, + const T* smem, const vec_t& q_vec, const vec_t& freq, uint32_t kv_idx_base, - uint32_t iter_base, uint32_t left_close_bound, - uint32_t iter_bound, const int32_t q_offset, - float alibi_slope, float* s, state_t& st, - const float logits_soft_cap) { + uint32_t iter_base, uint32_t iter_bound, + uint32_t qo_head_idx, uint32_t kv_head_idx, float* s, + state_t& st) { uint32_t tx = threadIdx.x, tz = threadIdx.z; float m_prev = st.m; #pragma unroll @@ -97,12 +94,12 @@ __device__ __forceinline__ void compute_qk(const T* smem, uint32_t compute_stage for (uint32_t offset = bdx / 2; offset > 0; offset /= 2) { s[j] += math::shfl_xor_sync(s[j], offset); } - s[j] = apply_logits_post_hook(s[j], logits_soft_cap); const uint32_t pos = kv_idx_base + tz * tile_size + j; - s[j] = (iter_base + tz * tile_size + j < iter_bound && pos >= left_close_bound) ? s[j] : -5e4; - if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { - s[j] += alibi_slope * float(int(pos) - q_offset); - } + s[j] = variant.LogitsTransform(params, s[j], batch_idx, /*qo_idx=*/0, /*kv_idx=*/pos, + qo_head_idx, kv_head_idx); + bool mask = variant.LogitsMask(params, batch_idx, /*qo_idx=*/0, /*kv_idx=*/pos, qo_head_idx, + kv_head_idx); + s[j] = (iter_base + tz * tile_size + j < iter_bound && mask) ? s[j] : -math::inf; st.m = max(st.m, s[j]); } @@ -182,20 +179,17 @@ __device__ __forceinline__ void sync_state(state_t& st, float* smem, f /*! 
* \brief FlashAttention decoding cuda kernel with kv-cache for a single request - * \tparam logits_post_hook The logits post hook used in the kernel * \tparam pos_encoding_mode The positional encoding mode * \tparam vec_size A template integer indicates the vector size * \tparam bdx A template integer indicates the block size in x dimension * \tparam bdy A template integer indicates the block size in y dimension * \tparam DTypeQ A template type indicates the query data type * \tparam DTypeKV A template type indicates the key-value data type - * \tparam DTypeOut A template type indicates the output data type + * \tparam DTypeO A template type indicates the output data type * \param q [num_qo_heads, head_dim] The query matrix * \param k [seq_len, num_kv_heads, head_dim] The key matrix in kv-cache * \param v [seq_len, num_kv_heads, head_dim] The value matrix in kv-cache * \param o [num_qo_heads, head_dim] The output matrix - * \param info The tensor info of k/v matrices - * \param sm_scale A float indicates the scale applied to pre-softmax logits * \param head_dim A integer indicates the head dimension * \param rope_rcp_scale A floating number indicate the reciprocal * of scaling ratio used in PI(Position Interpolation) for RoPE (Rotary @@ -204,31 +198,34 @@ __device__ __forceinline__ void sync_state(state_t& st, float* smem, f * of "theta" used in RoPE (Rotary Positional Embeddings) * \param kv_chunk_size A integer indicates the kv-chunk size */ -template -__global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* __restrict__ k, - DTypeKV* __restrict__ v, DTypeOut* __restrict__ o, - float* __restrict__ lse, tensor_info_t info, - int32_t window_left, float logits_soft_cap, - float sm_scale, float rope_rcp_scale, - float rope_rcp_theta, uint32_t kv_chunk_size) { +template +__global__ void SingleDecodeWithKVCacheKernel(const __grid_constant__ + typename AttentionVariant::ParamsT params) { + using DTypeQ = typename AttentionVariant::DTypeQ; + using DTypeKV = typename AttentionVariant::DTypeKV; + using DTypeO = typename AttentionVariant::DTypeO; + const DTypeQ* q = params.q; + const DTypeKV* k = params.k; + const DTypeKV* v = params.v; + DTypeO* o = params.o; + float* lse = params.lse; + const float rope_rcp_scale = params.rope_rcp_scale; + const float rope_rcp_theta = params.rope_rcp_theta; + uint32_t kv_chunk_size = params.kv_chunk_size; + auto block = cg::this_thread_block(); auto grid = cg::this_grid(); - sm_scale *= - (logits_post_hook == LogitsPostHook::kNone ? math::log2e : math::ptx_rcp(logits_soft_cap)); constexpr uint32_t head_dim = bdx * vec_size; uint32_t kv_head_idx = blockIdx.y; uint32_t qo_head_idx = kv_head_idx * bdy + threadIdx.y; uint32_t kv_chunk_idx = blockIdx.x; - uint32_t num_qo_heads = info.num_qo_heads; - const float alibi_slope = get_alibi_slope(qo_head_idx, num_qo_heads) * math::log2e; - uint32_t seq_len = info.kv_len; - uint32_t left_close_bound = - (window_left >= 0) ? 
sub_if_greater_or_zero(seq_len, window_left + 1) : 0; + uint32_t num_qo_heads = params.num_qo_heads; extern __shared__ uint8_t smem[]; + AttentionVariant variant(params, /*batch_idx=*/0, smem); + const uint32_t seq_len = variant.kv_len; DTypeKV* k_smem = (DTypeKV*)smem; DTypeKV* v_smem = (DTypeKV*)(smem + num_stages_smem * bdy * tile_size_per_bdx * bdz * head_dim * sizeof(DTypeKV)); @@ -246,16 +243,16 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _ float(2 * ((tx * vec_size + i) % (head_dim / 2))) / float(head_dim)); } // apply rotary embedding to q matrix - q_vec = vec_apply_llama_rope(q + info.get_q_elem_offset(0, qo_head_idx, 0), freq, - seq_len - 1); + q_vec = vec_apply_llama_rope(q + params.get_q_elem_offset(0, qo_head_idx, 0), + freq, seq_len - 1); } else { // do not apply rotary embedding to q matrix - q_vec.cast_load(q + info.get_q_elem_offset(0, qo_head_idx, tx * vec_size)); + q_vec.cast_load(q + params.get_q_elem_offset(0, qo_head_idx, tx * vec_size)); } // multiple q_vec by sm_scale #pragma unroll for (uint32_t i = 0; i < vec_size; ++i) { - q_vec[i] *= sm_scale; + q_vec[i] = variant.QueryTransform(params, q_vec[i]); } block.sync(); @@ -272,7 +269,7 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _ cp_async::pred_load( k_smem + (((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim + tx * vec_size, - k + info.get_kv_elem_offset( + k + params.get_kv_elem_offset( producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j, kv_head_idx, tx * vec_size), producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j < chunk_end); @@ -282,7 +279,7 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _ cp_async::pred_load( v_smem + (((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim + tx * vec_size, - v + info.get_kv_elem_offset( + v + params.get_kv_elem_offset( producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j, kv_head_idx, tx * vec_size), producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j < chunk_end); @@ -301,17 +298,18 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _ // compute qk cp_async::wait_group<2 * num_stages_smem - 1>(); block.sync(); - compute_qk( - k_smem + (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, stage_idx, q_vec, - freq, consumer_kv_idx_base, iter * bdy * tile_size_per_bdx * bdz, left_close_bound, - kv_chunk_size, seq_len - 1, alibi_slope, s, st_local, logits_soft_cap); + compute_qk( + params, variant, /*batch_idx=*/0, + k_smem + (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, q_vec, freq, + consumer_kv_idx_base, iter * bdy * tile_size_per_bdx * bdz, kv_chunk_size, qo_head_idx, + kv_head_idx, s, st_local); block.sync(); // load k for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { cp_async::pred_load( k_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim + tx * vec_size, - k + info.get_kv_elem_offset( + k + params.get_kv_elem_offset( producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j, kv_head_idx, tx * vec_size), producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j < chunk_end); @@ -331,7 +329,7 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _ cp_async::pred_load( v_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim + tx * vec_size, - v + info.get_kv_elem_offset( + v + params.get_kv_elem_offset( producer_kv_idx_base + (tz * bdy + ty) * 
tile_size_per_bdx + j, kv_head_idx, tx * vec_size), producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j < chunk_end); @@ -357,16 +355,14 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _ /*! * \brief FlashAttention decoding cuda kernel with paged kv-cache for multiple requests - * \tparam logits_post_hook The logits post hook used in the kernel * \tparam pos_encoding_mode The positional encoding mode * \tparam vec_size A template integer indicates the vector size * \tparam bdx A template integer indicates the block size in x dimension * \tparam bdy A template integer indicates the block size in y dimension * \tparam bdz A template integer indicates the block size in z dimension - * \tparam page_storage Whether to store indices or pointers of each active page * \tparam DTypeQ A template type indicates the query data type * \tparam DTypeKV A template type indicates the key-value data type - * \tparam DTypeOut A template type indicates the output data type + * \tparam DTypeO A template type indicates the output data type * \tparam IdType A template type indicates the index data type * \param q [batch_size, num_qo_heads, head_dim] The query matrix * \param paged_kv The paged kv-cache data structure @@ -380,47 +376,46 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _ * \param rope_rcp_theta A floating number indicate the reciprocal * of "theta" used in RoPE (Rotary Positional Embeddings) */ -template -__global__ void BatchDecodeWithPagedKVCacheKernel( - DTypeQ* __restrict__ q, IdType* __restrict__ q_offset, - paged_kv_t paged_kv, - kv_partition_info_t kv_partition_info, DTypeOut* __restrict__ o, - float* __restrict__ lse, bool* __restrict__ block_valid_mask, bool partition_kv, - int32_t window_left, float logits_soft_cap, float sm_scale, float rope_rcp_scale, - float rope_rcp_theta) { +template +__global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__ + typename AttentionVariant::ParamsT params) { auto block = cg::this_thread_block(); - sm_scale *= - (logits_post_hook == LogitsPostHook::kNone ? math::log2e : math::ptx_rcp(logits_soft_cap)); + using DTypeQ = typename AttentionVariant::DTypeQ; + using DTypeKV = typename AttentionVariant::DTypeKV; + using DTypeO = typename AttentionVariant::DTypeO; + using IdType = typename AttentionVariant::IdType; + const DTypeQ* q = params.q; + DTypeO* o = params.o; + float* lse = params.lse; + const auto paged_kv = params.paged_kv; + const IdType* q_offset = params.q_offset; + const bool* block_valid_mask = params.block_valid_mask; + const uint32_t padded_batch_size = params.padded_batch_size; + const uint32_t num_qo_heads = params.num_qo_heads; + const float rope_rcp_scale = params.rope_rcp_scale; + const float rope_rcp_theta = params.rope_rcp_theta; + const bool partition_kv = params.partition_kv; constexpr uint32_t head_dim = bdx * vec_size; - const uint32_t batch_idx = blockIdx.x; - const uint32_t kv_head_idx = blockIdx.y; + const uint32_t bx = blockIdx.x, by = blockIdx.y; + const uint32_t batch_idx = params.request_indices[bx]; + const uint32_t kv_tile_idx = params.kv_tile_indices[bx]; + const uint32_t kv_head_idx = by; const uint32_t qo_head_idx = kv_head_idx * bdy + threadIdx.y; - const uint32_t num_qo_heads = gridDim.y * bdy; - const float alibi_slope = get_alibi_slope(qo_head_idx, num_qo_heads) * math::log2e; - const uint32_t cur_chunk_start = partition_kv ? 
kv_partition_info.chunk_start_pos[batch_idx] : 0U; - const uint32_t cur_page_indptr_begin = paged_kv.indptr[batch_idx], - cur_page_indptr_end = paged_kv.indptr[batch_idx + 1]; // NOTE(Zihao): when CUDAGraph is enabled, we will launch more blocks than // the actual batch size, so we need to check if the current batch is valid - if (block_valid_mask && !block_valid_mask[batch_idx]) return; - const uint32_t cur_last_page_len = paged_kv.last_page_len[batch_idx]; - const uint32_t kv_chunk_len = - cur_page_indptr_begin != cur_page_indptr_end - ? (cur_page_indptr_end - cur_page_indptr_begin - 1) * paged_kv.page_size + - cur_last_page_len - : 0; - const uint32_t seq_len = - partition_kv ? kv_partition_info.seq_lens_before_partition[batch_idx] : kv_chunk_len; - const uint32_t left_close_bound = - (window_left >= 0) ? sub_if_greater_or_zero(seq_len, window_left + 1) : 0; - const uint32_t mapped_batch_idx = - partition_kv ? kv_partition_info.batch_idx_map[batch_idx] : batch_idx; + if (block_valid_mask && !block_valid_mask[bx]) return; + const uint32_t kv_chunk_size = *(params.kv_chunk_size_ptr); + const uint32_t kv_len = paged_kv.get_length(batch_idx); + const uint32_t max_chunk_size = partition_kv ? kv_chunk_size : kv_len; + const uint32_t chunk_start = partition_kv ? kv_tile_idx * max_chunk_size : 0; + const uint32_t chunk_end = + partition_kv ? min((kv_tile_idx + 1) * max_chunk_size, kv_len) : kv_len; + const uint32_t chunk_size = chunk_end - chunk_start; extern __shared__ uint8_t smem[]; + AttentionVariant variant(params, batch_idx, smem); DTypeKV* k_smem = (DTypeKV*)smem; DTypeKV* v_smem = (DTypeKV*)(smem + num_stages_smem * tile_size_per_bdx * bdy * bdz * head_dim * sizeof(DTypeKV)); @@ -432,8 +427,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( const uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z; vec_t q_vec; vec_t freq; - int32_t q_offset_val = q_offset == nullptr ? (seq_len - 1) : q_offset[mapped_batch_idx]; - if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { + int32_t q_offset_val = q_offset == nullptr ? 
(kv_len - 1) : q_offset[batch_idx]; + if constexpr (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { #pragma unroll for (uint32_t i = 0; i < vec_size; ++i) { freq[i] = rope_rcp_scale * @@ -442,32 +437,30 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( } // apply rotary embedding to q matrix q_vec = vec_apply_llama_rope( - q + (mapped_batch_idx * num_qo_heads + qo_head_idx) * head_dim, freq, q_offset_val); + q + (batch_idx * num_qo_heads + qo_head_idx) * head_dim, freq, q_offset_val); } else { // do not apply rotary embedding to q matrix - q_vec.cast_load(q + (mapped_batch_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size); + q_vec.cast_load(q + (batch_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size); } #pragma unroll for (uint32_t i = 0; i < vec_size; ++i) { - q_vec[i] *= sm_scale; + q_vec[i] = variant.QueryTransform(params, q_vec[i]); } block.sync(); // preload k/v tiles uint32_t stage_idx = 0; constexpr uint32_t vec_bits = sizeof(DTypeKV) * vec_size * 8; - // NOTE(Zihao): when CUDAGraph is disabled, gridDim.x = batch_size, otherwise, - // we guarantee that indptr array length is greater than or equal to batch_size + 1, - // so we can safely access paged_kv.indptr[batch_idx + 1] - const IdType last_indptr = paged_kv.indptr[gridDim.x]; + const IdType last_indptr = paged_kv.indptr[paged_kv.batch_size]; static_assert(num_stages_smem <= bdx); + uint32_t packed_page_iter_base = paged_kv.indptr[batch_idx] * paged_kv.page_size + chunk_start; #pragma unroll for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { uint32_t q, r; - paged_kv.page_size.divmod(((j * bdz + tz) * bdy + ty) * bdx + tx, q, r); + paged_kv.page_size.divmod(packed_page_iter_base + ((j * bdz + tz) * bdy + ty) * bdx + tx, q, r); k_ptrs_smem[((j * bdz + tz) * bdy + ty) * bdx + tx] = - paged_kv.protective_get_k_ptr(cur_page_indptr_begin + q, kv_head_idx, r, 0, last_indptr); + paged_kv.protective_get_k_ptr(q, kv_head_idx, r, 0, last_indptr); } block.sync(); @@ -484,7 +477,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( cp_async::pred_load( k_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim + tx * vec_size, - k_ptrs[j], ((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j < kv_chunk_len); + k_ptrs[j], ((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j < chunk_size); } cp_async::commit_group(); #pragma unroll @@ -493,7 +486,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( cp_async::pred_load( v_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim + tx * vec_size, - v_ptr, ((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j < kv_chunk_len); + v_ptr, ((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j < chunk_size); } cp_async::commit_group(); stage_idx = (stage_idx + 1) % num_stages_smem; @@ -503,28 +496,28 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( float s[bdy * tile_size_per_bdx]; #pragma unroll 2 - for (uint32_t iter = 0; iter < ceil_div(kv_chunk_len, tile_size_per_bdx * bdy * bdz); ++iter) { + for (uint32_t iter = 0; iter < ceil_div(chunk_size, tile_size_per_bdx * bdy * bdz); ++iter) { if ((iter + num_stages_smem) % bdx == 0) { #pragma unroll for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { uint32_t q, r; - paged_kv.page_size.divmod(((iter + num_stages_smem) * tile_size_per_bdx * bdy * bdz + - ((j * bdz + tz) * bdy + ty) * bdx + tx), - q, r); - k_ptrs_smem[((j * bdz + tz) * bdy + ty) * bdx + tx] = paged_kv.protective_get_k_ptr( - cur_page_indptr_begin + q, kv_head_idx, r, 0, last_indptr); + 
paged_kv.page_size.divmod( + packed_page_iter_base + ((iter + num_stages_smem) * tile_size_per_bdx * bdy * bdz + + ((j * bdz + tz) * bdy + ty) * bdx + tx), + q, r); + k_ptrs_smem[((j * bdz + tz) * bdy + ty) * bdx + tx] = + paged_kv.protective_get_k_ptr(q, kv_head_idx, r, 0, last_indptr); } } // compute qk cp_async::wait_group<2 * num_stages_smem - 1>(); block.sync(); - compute_qk( - k_smem + (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, stage_idx, q_vec, - freq, - (paged_kv.rope_pos_offset == nullptr ? 0 : paged_kv.rope_pos_offset[mapped_batch_idx]) + - cur_chunk_start + iter * tile_size_per_bdx * bdy * bdz, - iter * tile_size_per_bdx * bdy * bdz, left_close_bound, kv_chunk_len, q_offset_val, - alibi_slope, s, st, logits_soft_cap); + compute_qk( + params, variant, batch_idx, + k_smem + (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, q_vec, freq, + (paged_kv.rope_pos_offset == nullptr ? 0 : paged_kv.rope_pos_offset[batch_idx]) + + chunk_start + iter * tile_size_per_bdx * bdy * bdz, + iter * tile_size_per_bdx * bdy * bdz, chunk_size, qo_head_idx, kv_head_idx, s, st); block.sync(); #pragma unroll @@ -542,8 +535,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( k_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim + tx * vec_size, k_ptrs[j], - (((iter + num_stages_smem) * bdz + tz) * bdy + ty) * tile_size_per_bdx + j < - kv_chunk_len); + (((iter + num_stages_smem) * bdz + tz) * bdy + ty) * tile_size_per_bdx + j < chunk_size); } cp_async::commit_group(); @@ -562,8 +554,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( v_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim + tx * vec_size, v_ptr, - (((iter + num_stages_smem) * bdz + tz) * bdy + ty) * tile_size_per_bdx + j < - kv_chunk_len); + (((iter + num_stages_smem) * bdz + tz) * bdy + ty) * tile_size_per_bdx + j < chunk_size); } cp_async::commit_group(); stage_idx = (stage_idx + 1) % num_stages_smem; @@ -576,10 +567,10 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( st.normalize(); if (tz == 0) { - st.o.cast_store(o + (batch_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size); + st.o.cast_store(o + (bx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size); // write lse if (lse != nullptr) { - lse[batch_idx * num_qo_heads + qo_head_idx] = st.get_lse(); + lse[bx * num_qo_heads + qo_head_idx] = st.get_lse(); } } } @@ -605,7 +596,7 @@ constexpr uint32_t get_heuristic_num_threads(uint32_t group_size, uint32_t sizeo * \brief FlashAttention decoding with kv-cache for a single request * \tparam DTypeQ A template type indicates the query data type * \tparam DTypeKV A template type indicates the key-value data type - * \tparam DTypeOut A template type indicates the output data type + * \tparam DTypeO A template type indicates the output data type * \param q The query matrix, shape: [num_qo_heads, head_dim] * \param k The key matrix in kv-cache, shape: [seq_len, num_kv_heads, head_dim] * for NHD layout, [num_kv_heads, seq_len, head_dim] for HND layout @@ -623,17 +614,17 @@ constexpr uint32_t get_heuristic_num_threads(uint32_t group_size, uint32_t sizeo * \param stream The cuda stream to launch the kernel * \return status Indicates whether CUDA calls are successful */ -template -cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o, - DTypeOut* tmp, uint32_t num_qo_heads, - uint32_t num_kv_heads, uint32_t seq_len, - QKVLayout kv_layout, int32_t window_left, - float logits_soft_cap, float sm_scale, - 
float rope_scale, float rope_theta, +template +cudaError_t SingleDecodeWithKVCacheDispatched(typename AttentionVariant::ParamsT params, + typename AttentionVariant::DTypeO* tmp, cudaStream_t stream) { - const float rope_rcp_scale = 1.f / rope_scale; - const float rope_rcp_theta = 1.f / rope_theta; + using DTypeQ = typename AttentionVariant::DTypeQ; + using DTypeKV = typename AttentionVariant::DTypeKV; + using DTypeO = typename AttentionVariant::DTypeO; + const uint32_t num_qo_heads = params.num_qo_heads; + const uint32_t num_kv_heads = params.num_kv_heads; + const uint32_t seq_len = params.kv_len; + constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL); constexpr uint32_t bdx = HEAD_DIM / vec_size; auto compute_capacity = GetCudaComputeCapability(); @@ -643,34 +634,22 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, constexpr uint32_t num_threads = std::max(get_heuristic_num_threads(GROUP_SIZE, sizeof(DTypeKV)), bdx * bdy); constexpr uint32_t bdz = num_threads / (bdx * bdy); - tensor_info_t info(1, seq_len, num_qo_heads, num_kv_heads, kv_layout, HEAD_DIM); constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 2U : 8U) : 1U; DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM(compute_capacity, NUM_STAGES_SMEM, { const uint32_t smem_size = 2U * NUM_STAGES_SMEM * bdy * tile_size_per_bdx * bdz * HEAD_DIM * sizeof(DTypeKV) + 2U * bdy * bdz * sizeof(float); - auto kernel = SingleDecodeWithKVCacheKernel; + auto kernel = + SingleDecodeWithKVCacheKernel; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); if (seq_len <= 256 || tmp == nullptr) { // no need to use partition-kv kernel dim3 nblks = dim3(1, num_kv_heads); dim3 nthrs = dim3(bdx, bdy, bdz); - float* lse = nullptr; - void* args[] = {(void*)&q, - (void*)&k, - (void*)&v, - (void*)&o, - (void*)&lse, - (void*)&info, - (void*)&window_left, - (void*)&logits_soft_cap, - (void*)&sm_scale, - (void*)&rope_rcp_scale, - (void*)&rope_rcp_theta, - (void*)&seq_len}; + params.kv_chunk_size = seq_len; + void* args[] = {(void*)¶ms}; FLASHINFER_CUDA_CALL( cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); } else { @@ -695,18 +674,11 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, } dim3 nthrs = dim3(bdx, bdy, bdz); float* tmp_lse = (float*)(tmp + num_chunks * num_qo_heads * HEAD_DIM); - void* args[] = {(void*)&q, - (void*)&k, - (void*)&v, - (void*)&tmp, - (void*)&tmp_lse, - (void*)&info, - (void*)&window_left, - (void*)&logits_soft_cap, - (void*)&sm_scale, - (void*)&rope_rcp_scale, - (void*)&rope_rcp_theta, - (void*)&kv_chunk_size}; + auto o = params.o; + params.o = tmp; + params.lse = tmp_lse; + params.kv_chunk_size = kv_chunk_size; + void* args[] = {(void*)¶ms}; FLASHINFER_CUDA_CALL( cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); FLASHINFER_CUDA_CALL( @@ -717,18 +689,17 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, return cudaSuccess; } -template -cudaError_t BatchDecodeWithPagedKVCacheDispatched( - DTypeQ* q, IdType* q_offset, paged_kv_t paged_kv, - kv_partition_info_t kv_partition_info, DTypeOut* o, DTypeOut* tmp_v, float* tmp_s, - float* lse, bool* block_valid_mask, uint32_t padded_batch_size, uint32_t num_qo_heads, - int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, - cudaStream_t stream) { - const float rope_rcp_scale = 1.f / rope_scale; - const float 
rope_rcp_theta = 1.f / rope_theta; - const uint32_t num_kv_heads = paged_kv.num_heads; +template +cudaError_t BatchDecodeWithPagedKVCacheDispatched(typename AttentionVariant::ParamsT params, + typename AttentionVariant::DTypeO* tmp_v, + float* tmp_s, cudaStream_t stream) { + using DTypeQ = typename AttentionVariant::DTypeQ; + using DTypeKV = typename AttentionVariant::DTypeKV; + using DTypeO = typename AttentionVariant::DTypeO; + using IdType = typename AttentionVariant::IdType; + const uint32_t num_qo_heads = params.num_qo_heads; + const uint32_t num_kv_heads = params.paged_kv.num_heads; + const uint32_t padded_batch_size = params.padded_batch_size; constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL); auto compute_capacity = GetCudaComputeCapability(); @@ -745,55 +716,34 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched( std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*), 2 * bdy * bdz * sizeof(float)); auto kernel = - BatchDecodeWithPagedKVCacheKernel; + BatchDecodeWithPagedKVCacheKernel; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + if (tmp_v == nullptr) { // do not use partition-kv kernel - bool partition_kv = false; dim3 nblks(padded_batch_size, num_kv_heads); dim3 nthrs(bdx, bdy, bdz); - - void* args[] = {(void*)&q, - (void*)&q_offset, - (void*)&paged_kv, - (void*)&kv_partition_info, - (void*)&o, - (void*)&lse, - (void*)&block_valid_mask, - (void*)&partition_kv, - (void*)&window_left, - (void*)&logits_soft_cap, - (void*)&sm_scale, - (void*)&rope_rcp_scale, - (void*)&rope_rcp_theta}; + params.partition_kv = false; + void* args[] = {(void*)¶ms}; FLASHINFER_CUDA_CALL( cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); } else { // use partition-kv kernel - bool partition_kv = true; - void* args[] = {(void*)&q, - (void*)&q_offset, - (void*)&paged_kv, - (void*)&kv_partition_info, - (void*)&tmp_v, - (void*)&tmp_s, - (void*)&block_valid_mask, - (void*)&partition_kv, - (void*)&window_left, - (void*)&logits_soft_cap, - (void*)&sm_scale, - (void*)&rope_rcp_scale, - (void*)&rope_rcp_theta}; + params.partition_kv = true; + auto o = params.o; + auto lse = params.lse; + params.o = tmp_v; + params.lse = tmp_s; + void* args[] = {(void*)¶ms}; dim3 nblks(padded_batch_size, num_kv_heads); dim3 nthrs(bdx, bdy, bdz); FLASHINFER_CUDA_CALL( cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - FLASHINFER_CUDA_CALL(VariableLengthMergeStates( - tmp_v, tmp_s, kv_partition_info.chunk_indptr, o, lse, - kv_partition_info.batch_size_before_partition, num_qo_heads, HEAD_DIM, stream)); + FLASHINFER_CUDA_CALL(VariableLengthMergeStates(tmp_v, tmp_s, params.o_indptr, o, lse, + params.paged_kv.batch_size, num_qo_heads, + HEAD_DIM, stream)); } }); }); diff --git a/include/flashinfer/attention/decode_params.cuh b/include/flashinfer/attention/decode_params.cuh new file mode 100644 index 000000000..bb988b533 --- /dev/null +++ b/include/flashinfer/attention/decode_params.cuh @@ -0,0 +1,160 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
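
The partition-kv branch above redirects params.o and params.lse to the temporary buffers, launches one block per KV chunk, and then combines the per-chunk partial results with the merge kernel. The standalone sketch below is illustrative only and not part of this patch: chunk_bounds mirrors the kv_tile_idx / kv_chunk_size arithmetic used in BatchDecodeWithPagedKVCacheKernel, and merge_state shows why carrying the per-chunk log-sum-exp (the value written via st.get_lse()) lets the partial outputs be combined exactly.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <utility>
#include <vector>

// Map a KV tile index to its [chunk_start, chunk_end) range, as in the
// partition-kv path of the batch decode kernel.
inline std::pair<uint32_t, uint32_t> chunk_bounds(uint32_t kv_tile_idx, uint32_t kv_chunk_size,
                                                  uint32_t kv_len) {
  uint32_t start = kv_tile_idx * kv_chunk_size;
  uint32_t end = std::min((kv_tile_idx + 1) * kv_chunk_size, kv_len);
  return {start, end};
}

// Merge two partial attention states (o, lse) computed over disjoint KV
// chunks into the state for the union of the chunks (host-side analogue of
// the merge kernel, one output row at a time).
void merge_state(std::vector<float>& o_a, float& lse_a, const std::vector<float>& o_b,
                 float lse_b) {
  float m = std::max(lse_a, lse_b);
  float w_a = std::exp(lse_a - m), w_b = std::exp(lse_b - m);
  float w_sum = w_a + w_b;
  for (size_t i = 0; i < o_a.size(); ++i) {
    o_a[i] = (o_a[i] * w_a + o_b[i] * w_b) / w_sum;
  }
  lse_a = m + std::log(w_sum);
}
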
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_DECODE_PARAMS_CUH_ +#define FLASHINFER_DECODE_PARAMS_CUH_ + +#include + +#include + +#include "../layout.cuh" +#include "../page.cuh" + +namespace flashinfer { + +template +struct DecodeParamsBase { + using DTypeQ = DTypeQ_; + using DTypeKV = DTypeKV_; + using DTypeO = DTypeO_; + DTypeQ* q; + DTypeO* o; + float* lse; + float sm_scale; +}; + +template +struct SingleDecodeParams : public DecodeParamsBase { + using IdType = int32_t; + DTypeKV* k; + DTypeKV* v; + float* alibi_slopes; + uint32_t kv_len; + uint32_t num_qo_heads; + uint32_t num_kv_heads; + uint32_t q_stride_n; + uint32_t q_stride_h; + uint32_t kv_stride_n; + uint32_t kv_stride_h; + uint32_t head_dim; + int32_t window_left; + float logits_soft_cap; + float rope_rcp_scale; + float rope_rcp_theta; + uint32_t kv_chunk_size; + + __device__ __host__ SingleDecodeParams(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeO* o, + float* alibi_slopes, uint32_t seq_len, + uint32_t num_qo_heads, uint32_t num_kv_heads, + QKVLayout kv_layout, uint32_t head_dim, + int32_t window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta) + : DecodeParamsBase{q, o, /*lse=*/nullptr, sm_scale}, + k(k), + v(v), + alibi_slopes(alibi_slopes), + kv_len(seq_len), + num_qo_heads(num_qo_heads), + num_kv_heads(num_kv_heads), + q_stride_n(num_qo_heads * head_dim), + q_stride_h(head_dim), + kv_stride_n((kv_layout == QKVLayout::kNHD) ? num_kv_heads * head_dim : head_dim), + kv_stride_h((kv_layout == QKVLayout::kNHD) ? 
head_dim : seq_len * head_dim), + head_dim(head_dim), + window_left(window_left), + logits_soft_cap(logits_soft_cap), + rope_rcp_scale(1.f / rope_scale), + rope_rcp_theta(1.f / rope_theta), + kv_chunk_size(0) {} + + __host__ __device__ __forceinline__ size_t get_q_elem_offset(uint32_t qo_idx, + uint32_t qo_head_idx, + uint32_t feat_idx) const { + return get_elem_offset_impl(qo_idx, qo_head_idx, feat_idx, q_stride_n, q_stride_h); + } + + __host__ __device__ __forceinline__ size_t get_o_elem_offset(uint32_t qo_idx, + uint32_t qo_head_idx, + uint32_t feat_idx) const { + return get_elem_offset_impl(qo_idx, qo_head_idx, feat_idx, num_qo_heads * head_dim, head_dim); + } + + __host__ __device__ __forceinline__ size_t get_kv_elem_offset(uint32_t kv_idx, + uint32_t kv_head_idx, + uint32_t feat_idx) const { + return get_elem_offset_impl(kv_idx, kv_head_idx, feat_idx, kv_stride_n, kv_stride_h); + } + + __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const { return 1; } + + __host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const { + return kv_len; + } +}; + +template +struct BatchDecodeParams : public DecodeParamsBase { + using IdType = IdType_; + + IdType* q_offset; + paged_kv_t paged_kv; + float* alibi_slopes; + uint32_t padded_batch_size; + uint32_t num_qo_heads; + int32_t window_left; + float logits_soft_cap; + float rope_rcp_scale; + float rope_rcp_theta; + + IdType* request_indices; + IdType* kv_tile_indices; + IdType* o_indptr; + IdType* kv_chunk_size_ptr; + bool* block_valid_mask; + bool partition_kv; + + __device__ __host__ BatchDecodeParams(DTypeQ* q, IdType* q_offset, + paged_kv_t paged_kv, DTypeO* o, float* lse, + float* alibi_slopes, uint32_t num_qo_heads, + int32_t window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta) + : DecodeParamsBase{q, o, lse, sm_scale}, + q_offset(q_offset), + paged_kv(paged_kv), + alibi_slopes(alibi_slopes), + padded_batch_size(0), + num_qo_heads(num_qo_heads), + window_left(window_left), + logits_soft_cap(logits_soft_cap), + rope_rcp_scale(1.f / rope_scale), + rope_rcp_theta(1.f / rope_theta), + request_indices(nullptr), + kv_tile_indices(nullptr), + o_indptr(nullptr), + kv_chunk_size_ptr(nullptr), + block_valid_mask(nullptr), + partition_kv(false) {} + + __host__ __device__ __forceinline__ int32_t get_qo_len(int32_t batch_idx) const { return 1; } + + __host__ __device__ __forceinline__ int32_t get_kv_len(int32_t batch_idx) const { + return paged_kv.get_length(batch_idx); + } +}; + +} // namespace flashinfer + +#endif // FLASHINFER_DECODE_PARAMS_CUH_ diff --git a/include/flashinfer/attention/handler.cuh b/include/flashinfer/attention/handler.cuh deleted file mode 100644 index e29b99c49..000000000 --- a/include/flashinfer/attention/handler.cuh +++ /dev/null @@ -1,871 +0,0 @@ -/* - * Copyright (c) 2023 by FlashInfer team. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
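
decode_params.cuh above only carries data; the behavior that used to be hard-coded in the kernels (the sm_scale multiplication, the logits soft cap, the sliding-window bound, ALiBi) moves into an AttentionVariant class that the kernels template on and that exposes QueryTransform, LogitsTransform and LogitsMask, as seen at the compute_qk call sites. The real interface lives in variants.cuh, which is not shown in this hunk, so the following is only a simplified, host-compilable sketch of how such a variant could reproduce the old default behavior; the ToyParams/ToyVariant names and signatures are invented for illustration, and the log2e folding used by the exp2-based softmax is omitted here.

#include <cmath>
#include <cstdint>

// Stand-in for the fields a variant reads from SingleDecodeParams / BatchDecodeParams.
struct ToyParams {
  float sm_scale;         // softmax scale
  float logits_soft_cap;  // <= 0 means disabled
  int32_t window_left;    // < 0 means no sliding window
  uint32_t kv_len;
};

// Sketch of the hook interface used by compute_qk.  In the actual kernels the
// hooks are __device__ members constructed as AttentionVariant(params, batch_idx, smem).
struct ToyVariant {
  uint32_t kv_len;
  ToyVariant(const ToyParams& params, uint32_t /*batch_idx*/) : kv_len(params.kv_len) {}

  // Was `q_vec[i] *= sm_scale` before the refactor.
  float QueryTransform(const ToyParams& params, float q) const { return q * params.sm_scale; }

  // Was apply_logits_post_hook: optional Grok-style tanh cap on each q.k score.
  float LogitsTransform(const ToyParams& params, float logits, uint32_t /*qo_idx*/,
                        uint32_t /*kv_idx*/) const {
    if (params.logits_soft_cap > 0.f) {
      logits = params.logits_soft_cap * std::tanh(logits / params.logits_soft_cap);
    }
    return logits;
  }

  // Replaces the old left_close_bound test for sliding-window attention.
  bool LogitsMask(const ToyParams& params, uint32_t qo_idx, uint32_t kv_idx) const {
    if (params.window_left < 0) return true;
    uint32_t q_pos = kv_len - 1 + qo_idx;  // decode: the single query sits at the KV end
    return kv_idx + uint32_t(params.window_left) >= q_pos;
  }
};

Keeping these decisions in a compile-time variant lets one kernel template serve plain, soft-capped, windowed and custom attention without branching on a runtime enum such as the removed LogitsPostHook.
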
- */ -#ifndef FLASHINFER_ATTENTION_HANDLER_CUH_ -#define FLASHINFER_ATTENTION_HANDLER_CUH_ - -#include -#include - -#include -#include -#include -#include -#include - -#include "../allocator.h" -#include "../page.cuh" -#include "../pos_enc.cuh" -#include "../utils.cuh" -#include "logits_post_hook.cuh" -#include "warp_layout.cuh" - -namespace flashinfer { - -template -__global__ void BatchDecodeWithPagedKVCacheKernel( - DTypeQ* __restrict__ q, IdType* __restrict__ q_offset, - paged_kv_t paged_kv, - kv_partition_info_t kv_partition_info, DTypeOut* __restrict__ o, - float* __restrict__ lse, bool* __restrict__ block_valid_mask, bool partition_kv, - int maybe_window_left, float logits_soft_cap, float sm_scale, float rope_rcp_scale, - float rope_rcp_theta); - -/*! - * \brief Compute the maximum number of pages per batch and the new batch size - * after we partition Paged KV-Cache into multiple chunks on KV sequence length - * dimension. - * \tparam IdType A template type indicates the index data type - * \param max_grid_size The maximum grid size of the kernel - * \param num_kv_heads The number of KV heads - * \param num_pages The number of pages per request in the batch - * \param max_num_pages_per_batch_lb The pre-set lower bound of maximum number of - * pages per batch, default to 1 - * \return (max_num_pages_per_batch, new_batch_size) The number of pages per batch and - * the new batch size after the partition. - */ -template -std::pair PartitionPagedKVCacheBinarySearchMinNumPagePerBatch( - const uint32_t max_grid_size, const uint32_t num_kv_heads, const std::vector& num_pages, - const uint32_t min_num_pages_per_batch = 1) { - uint32_t low = min_num_pages_per_batch, high = 0; - for (const IdType& elem : num_pages) { - high = max(high, elem); - } - uint32_t new_batch_size; - while (low < high) { - uint32_t mid = (low + high) / 2; - new_batch_size = 0; - for (const IdType& elem : num_pages) { - new_batch_size += ceil_div(elem, mid); - } - if (new_batch_size * num_kv_heads > max_grid_size) { - low = mid + 1; - } else { - high = mid; - } - } - new_batch_size = 0; - for (const IdType& elem : num_pages) { - new_batch_size += ceil_div(std::max(elem, 1), low); - } - return {low, new_batch_size}; -} - -inline std::tuple PrefillBinarySearchKVChunkSize( - const uint32_t max_grid_size, const uint32_t num_kv_heads, - const std::vector& packed_qo_len_arr, const std::vector& kv_len_arr, - const uint32_t qo_chunk_size, const uint32_t min_kv_chunk_size = 1) { - int64_t low = min_kv_chunk_size, high = 0; - int64_t batch_size = packed_qo_len_arr.size(); - int64_t max_kv_len = 0; - for (const int64_t& kv_len : kv_len_arr) { - max_kv_len = std::max(max_kv_len, kv_len); - } - high = max_kv_len; - int64_t new_batch_size; - while (low < high) { - int64_t mid = (low + high) / 2; - new_batch_size = 0; - for (uint32_t i = 0; i < batch_size; ++i) { - new_batch_size += - ceil_div(packed_qo_len_arr[i], qo_chunk_size) * ceil_div(kv_len_arr[i], mid); - } - if (new_batch_size * num_kv_heads > max_grid_size) { - low = mid + 1; - } else { - high = mid; - } - } - new_batch_size = 0; - for (uint32_t i = 0; i < batch_size; ++i) { - new_batch_size += ceil_div(packed_qo_len_arr[i], qo_chunk_size) * - ceil_div(std::max(int(kv_len_arr[i]), 1), low); - } - return {low < max_kv_len, low, new_batch_size}; -} - -/*! 
- * \brief Estimate the temporary buffer size and the maximum grid size for the - * partition-kv BatchDecodeWithPagedKVCache kernel - * \tparam page_storage Whether to store indices or pointers of each active page - * \tparam DTypeKV A template type indicates the key-value data type - * \tparam DTypeOut A template type indicates the output data type - * \tparam IdType A template type indicates the index data type - * \param split_kv Whether to split the KV cache into multiple chunks - * \param max_grid_size The maximum grid size that can be used in a partiton-kv kernel - * \param max_num_pages_per_batch The maximum number of pages per batch - * \param new_batch_size The new batch size after the partition - * \param paged_kv The paged kv cache data structure - * \param num_qo_heads A integer indicates the number of heads of query and output - * \param pos_encoding_mode The positional encoding mode - * \param stream The cuda stream to launch the kernel - * \return status Indicates whether CUDA calls are successful - */ -template -cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched( - bool& split_kv, uint32_t& max_grid_size, uint32_t& max_num_pages_per_batch, - uint32_t& new_batch_size, uint32_t batch_size, IdType* kv_indptr_h, const uint32_t num_qo_heads, - const uint32_t page_size, bool enable_cuda_graph, cudaStream_t stream) { - constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL); - auto compute_capacity = GetCudaComputeCapability(); - DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM(compute_capacity, NUM_STAGES_SMEM, { - constexpr uint32_t bdx = HEAD_DIM / vec_size; - static_assert(bdx <= 32); - constexpr uint32_t bdy = GROUP_SIZE; - constexpr uint32_t num_threads = std::max(128U, bdx * bdy); - constexpr uint32_t bdz = num_threads / (bdx * bdy); - constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 2U : 4U) : 1U; - const uint32_t num_kv_heads = num_qo_heads / GROUP_SIZE; - const uint32_t smem_size = - 2 * NUM_STAGES_SMEM * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) + - std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*), 2 * bdy * bdz * sizeof(float)); - - auto kernel = - BatchDecodeWithPagedKVCacheKernel; - int num_blocks_per_sm = 0; - int num_sm = 0; - int dev_id = 0; - FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); - FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); - FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel, - num_threads, smem_size)); - max_grid_size = num_blocks_per_sm * num_sm; - if (batch_size * num_kv_heads >= max_grid_size) { - split_kv = false; - new_batch_size = batch_size; - } else { - // compute max_num_pages_per_batch and new_batch_size - std::vector num_pages(batch_size); - for (uint32_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { - num_pages[batch_idx] = kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx]; - } - std::tie(max_num_pages_per_batch, new_batch_size) = - PartitionPagedKVCacheBinarySearchMinNumPagePerBatch( - max_grid_size, num_kv_heads, num_pages, std::max(128 / page_size, 1U)); - if (new_batch_size == batch_size && !enable_cuda_graph) { - // do not use partition-kv kernel for short sequence, when not using CUDAGraph - split_kv = false; - } else { - // when using CUDAGraph, we always use partition-kv kernel - split_kv = true; - } - } - return cudaSuccess; - }) -} - -/*! 
- * \brief Partition Paged KV-Cache into multiple chunks on KV sequence length - * \tparam IdType A template type indicates the index data type - * \param old_batch_size The batch size of the old Paged KV-Cache - * \param old_page_indptr_h The host-side page indptr of the old Paged KV-Cache - * \param old_last_page_len_h The host-side last page offset of the old Paged KV-Cache - * \param max_num_pages_per_batch The maximum number of pages per batch - * \param new_paged_kv_d The device-side new Paged KV-Cache - * \param stream The cuda stream to launch the kernel - * \return status Indicates whether CUDA calls are successful - */ -template -cudaError_t PartitionPagedKVCacheComputeAuxiliaryInfo( - const uint32_t max_num_pages_per_batch, const uint32_t old_batch_size, - const uint32_t padded_batch_size, const uint32_t page_size, IdType* old_indptr_h, - IdType* old_last_page_len_h, IdType* new_page_indptr_h, IdType* new_last_page_len_h, - IdType* chunk_indptr_h, IdType* batch_idx_map_h, IdType* chunk_start_pos_h, - IdType* seq_lens_before_partition_h, bool* block_valid_mask_h, void* device_buffer, - void* host_buffer, size_t num_bytes_to_copy, cudaStream_t stream = nullptr) { - std::vector new_page_indptr_vec, new_last_page_len_vec, chunk_indptr_vec, - batch_idx_map_vec, chunk_start_pos_vec, seq_lens_before_partition_vec; - std::vector block_valid_mask_vec; - - new_page_indptr_vec.push_back(0); - chunk_indptr_vec.push_back(0); - - for (uint32_t batch_idx = 0; batch_idx < old_batch_size; batch_idx++) { - uint32_t num_chunks = - ceil_div(old_indptr_h[batch_idx + 1] - old_indptr_h[batch_idx], max_num_pages_per_batch); - chunk_indptr_vec.push_back(chunk_indptr_vec.back() + std::max(num_chunks, 1U)); - if (num_chunks == 0) { - new_page_indptr_vec.push_back(old_indptr_h[batch_idx]); - new_last_page_len_vec.push_back(0); - if (block_valid_mask_h != nullptr) { - block_valid_mask_vec.push_back(true); - } - batch_idx_map_vec.push_back(batch_idx); - chunk_start_pos_vec.push_back(0); - seq_lens_before_partition_vec.push_back(0); - } else { - uint32_t seq_len_before_partition = - (old_indptr_h[batch_idx + 1] - old_indptr_h[batch_idx] - 1) * page_size + - old_last_page_len_h[batch_idx]; - for (uint32_t j = 0; j < num_chunks; ++j) { - bool is_last = (j + 1) == num_chunks; - new_page_indptr_vec.push_back( - min(old_indptr_h[batch_idx] + (j + 1) * max_num_pages_per_batch, - old_indptr_h[batch_idx + 1])); - new_last_page_len_vec.push_back(is_last ? 
old_last_page_len_h[batch_idx] : page_size); - if (block_valid_mask_h != nullptr) { - block_valid_mask_vec.push_back(true); - } - batch_idx_map_vec.push_back(batch_idx); - chunk_start_pos_vec.push_back(j * max_num_pages_per_batch * page_size); - seq_lens_before_partition_vec.push_back(seq_len_before_partition); - } - } - } - IdType last_page_indptr = new_page_indptr_vec.back(); - while (new_page_indptr_vec.size() < padded_batch_size + 1) { - new_page_indptr_vec.push_back(last_page_indptr); - } - std::copy(new_page_indptr_vec.begin(), new_page_indptr_vec.end(), new_page_indptr_h); - std::copy(new_last_page_len_vec.begin(), new_last_page_len_vec.end(), new_last_page_len_h); - std::copy(chunk_indptr_vec.begin(), chunk_indptr_vec.end(), chunk_indptr_h); - std::copy(batch_idx_map_vec.begin(), batch_idx_map_vec.end(), batch_idx_map_h); - std::copy(chunk_start_pos_vec.begin(), chunk_start_pos_vec.end(), chunk_start_pos_h); - std::copy(seq_lens_before_partition_vec.begin(), seq_lens_before_partition_vec.end(), - seq_lens_before_partition_h); - if (block_valid_mask_h != nullptr) { - std::copy(block_valid_mask_vec.begin(), block_valid_mask_vec.end(), block_valid_mask_h); - } - - FLASHINFER_CUDA_CALL(cudaMemcpyAsync(device_buffer, host_buffer, num_bytes_to_copy, - cudaMemcpyHostToDevice, stream)); - return cudaSuccess; -} - -class BatchDecodeHandler { - public: - template - DType* GetTempV() const { - return (DType*)tmp_v_; - } - float* GetTempS() const { return tmp_s_; } - template - IdType* GetNewIndPtr() const { - return (IdType*)new_indptr_; - } - template - IdType* GetNewLastPageLen() const { - return (IdType*)new_last_page_len_; - } - template - IdType* GetChunkIndPtr() const { - return (IdType*)chunk_indptr_; - } - template - IdType* GetBatchIdxMap() const { - return (IdType*)batch_idx_map_; - } - template - IdType* GetChunkStartPos() const { - return (IdType*)chunk_start_pos_; - } - template - IdType* GetSeqLengthsBeforePartition() const { - return (IdType*)seq_lengths_before_partition_; - } - - uint32_t GetPaddedBatchSize() const { return padded_batch_size_; } - - bool* GetBlockValidMask() const { return block_valid_mask_; } - - template - cudaError_t PlanDispatched(void* float_buffer, size_t float_workspace_size_in_bytes, - void* int_buffer, size_t int_workspace_size_in_bytes, IdType* indptr_h, - IdType* last_page_len_h, uint32_t batch_size, uint32_t num_qo_heads, - uint32_t num_kv_heads, uint32_t page_size) { - Clear(); - batch_size_before_partition_ = batch_size; - bool split_kv; - uint32_t max_grid_size, max_num_pages_per_batch, new_batch_size; - DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, { - auto work_estimation_func = - BatchDecodeWithPagedKVCacheWorkEstimationDispatched; - FLASHINFER_CUDA_CALL( - work_estimation_func(split_kv, max_grid_size, max_num_pages_per_batch, new_batch_size, - batch_size, indptr_h, num_qo_heads, page_size, - /*enable_cuda_graph=*/IsCUDAGraphEnabled(), stream_)); - batch_size_after_partition_ = new_batch_size; - if (IsCUDAGraphEnabled()) { - if (batch_size != fixed_batch_size_) { - std::ostringstream err_msg; - err_msg << "The running batch size " << batch_size - << " is not compatible with the fixed batch size " << fixed_batch_size_ - << " initialized for CUDAGraph"; - throw std::runtime_error(err_msg.str()); - } - size_t padded_batch_size = max_grid_size / num_kv_heads; - if (split_kv) { - padded_batch_size_ = padded_batch_size; - AlignedAllocator float_allocator(float_buffer, float_workspace_size_in_bytes); - tmp_v_ = 
float_allocator.aligned_alloc( - num_qo_heads * padded_batch_size * HEAD_DIM * sizeof(DTypeOut), 16, - "batch_decode_tmp_v"); - tmp_s_ = float_allocator.aligned_alloc( - num_qo_heads * padded_batch_size * sizeof(float), 16, "batch_decode_tmp_s"); - AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes); - new_indptr_ = int_allocator.aligned_alloc((padded_batch_size + 1) * sizeof(IdType), - 16, "batch_decode_new_indptr"); - - void* new_indptr_h_ = page_locked_buffer_; - new_last_page_len_ = int_allocator.aligned_alloc( - padded_batch_size * sizeof(IdType), 16, "batch_decode_new_last_page_len"); - void* new_last_page_len_h_ = - (char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_); - chunk_indptr_ = int_allocator.aligned_alloc( - (padded_batch_size + 1) * sizeof(IdType), 16, "batch_decode_chunk_indptr"); - void* chunk_indptr_h_ = - (char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_); - batch_idx_map_ = int_allocator.aligned_alloc(padded_batch_size * sizeof(IdType), 16, - "batch_decode_batch_idx_map"); - void* batch_idx_map_h_ = - (char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_); - chunk_start_pos_ = int_allocator.aligned_alloc(padded_batch_size * sizeof(IdType), - 16, "batch_decode_chunk_start_pos"); - void* chunk_start_pos_h_ = - (char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_); - seq_lengths_before_partition_ = int_allocator.aligned_alloc( - padded_batch_size * sizeof(IdType), 16, "batch_decode_seq_lengths_before_partition"); - void* seq_lengths_before_partition_h_ = - (char*)page_locked_buffer_ + - ((char*)seq_lengths_before_partition_ - (char*)new_indptr_); - block_valid_mask_ = int_allocator.aligned_alloc( - padded_batch_size * sizeof(bool), 16, "batch_decode_block_valid_mask"); - bool* block_valid_mask_h_ = - (bool*)page_locked_buffer_ + ((bool*)block_valid_mask_ - (bool*)new_indptr_); - std::fill(block_valid_mask_h_, block_valid_mask_h_ + padded_batch_size, 0); - - size_t num_bytes_to_copy = (char*)int_allocator.ptr - (char*)new_indptr_; - FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo( - max_num_pages_per_batch, batch_size, padded_batch_size, page_size, indptr_h, - last_page_len_h, (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_, - (IdType*)chunk_indptr_h_, (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_, - (IdType*)seq_lengths_before_partition_h_, block_valid_mask_h_, - /*device_buffer=*/new_indptr_, - /*host_buffer=*/page_locked_buffer_, num_bytes_to_copy, stream_)); - } else { - block_valid_mask_ = nullptr; - padded_batch_size_ = batch_size; - } - } else { - // NOTE(Zihao): we don't use block_valid_mask when CUDAGraph is disabled. 
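
For context on the allocations above: the deleted handler decides how to split each request's pages into KV chunks with PartitionPagedKVCacheBinarySearchMinNumPagePerBatch, searching for the smallest pages-per-chunk such that the resulting number of (chunk, kv_head) blocks still fits the persistent grid. The standalone function below restates that search outside the handler; it is an illustration of the deleted logic, and how the replacement plan code sizes the split is not visible in this hunk.

#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

// Smallest pages-per-chunk so that, after splitting every request into
// ceil(num_pages / pages_per_chunk) chunks, chunks * num_kv_heads <= max_grid_size.
// Returns {pages_per_chunk, new_batch_size}.
std::pair<uint32_t, uint32_t> min_pages_per_chunk(uint32_t max_grid_size, uint32_t num_kv_heads,
                                                  const std::vector<uint32_t>& num_pages,
                                                  uint32_t min_pages_per_batch = 1) {
  auto ceil_div = [](uint32_t a, uint32_t b) { return (a + b - 1) / b; };
  uint32_t low = min_pages_per_batch, high = 0;
  for (uint32_t p : num_pages) high = std::max(high, p);
  while (low < high) {
    uint32_t mid = (low + high) / 2, new_batch_size = 0;
    for (uint32_t p : num_pages) new_batch_size += ceil_div(p, mid);
    if (new_batch_size * num_kv_heads > max_grid_size) {
      low = mid + 1;  // too many chunks: make each chunk larger
    } else {
      high = mid;  // fits: try smaller chunks for more parallelism
    }
  }
  uint32_t new_batch_size = 0;
  for (uint32_t p : num_pages) new_batch_size += ceil_div(std::max(p, 1u), low);
  return {low, new_batch_size};
}
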
- block_valid_mask_ = nullptr; - // do not pad the batch size when not using CUDAGraph - padded_batch_size_ = batch_size_after_partition_; - if (split_kv) { - AlignedAllocator float_allocator(float_buffer, float_workspace_size_in_bytes); - tmp_v_ = float_allocator.aligned_alloc( - num_qo_heads * new_batch_size * HEAD_DIM * sizeof(DTypeOut), 16, - "batch_decode_tmp_v"); - tmp_s_ = float_allocator.aligned_alloc( - num_qo_heads * new_batch_size * sizeof(float), 16, "batch_decode_tmp_s"); - AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes); - new_indptr_ = int_allocator.aligned_alloc( - (batch_size_after_partition_ + 1) * sizeof(IdType), 16, "batch_decode_new_indptr"); - void* new_indptr_h_ = page_locked_buffer_; - new_last_page_len_ = int_allocator.aligned_alloc( - batch_size_after_partition_ * sizeof(IdType), 16, "batch_decode_new_last_page_len"); - void* new_last_page_len_h_ = - (char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_); - chunk_indptr_ = int_allocator.aligned_alloc( - (batch_size_before_partition_ + 1) * sizeof(IdType), 16, "batch_decode_chunk_indptr"); - void* chunk_indptr_h_ = - (char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_); - batch_idx_map_ = int_allocator.aligned_alloc( - batch_size_after_partition_ * sizeof(IdType), 16, "batch_decode_batch_idx_map"); - void* batch_idx_map_h_ = - (char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_); - chunk_start_pos_ = int_allocator.aligned_alloc( - batch_size_after_partition_ * sizeof(IdType), 16, "batch_decode_chunk_start_pos"); - void* chunk_start_pos_h_ = - (char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_); - seq_lengths_before_partition_ = - int_allocator.aligned_alloc(batch_size_after_partition_ * sizeof(IdType), 16, - "batch_decode_seq_lengths_before_partition"); - void* seq_lengths_before_partition_h_ = - (char*)page_locked_buffer_ + - ((char*)seq_lengths_before_partition_ - (char*)new_indptr_); - size_t num_bytes_to_copy = (char*)int_allocator.ptr - (char*)new_indptr_; - FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo( - max_num_pages_per_batch, batch_size, batch_size_after_partition_, page_size, indptr_h, - last_page_len_h, (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_, - (IdType*)chunk_indptr_h_, (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_, - (IdType*)seq_lengths_before_partition_h_, - /*block_valid_mask_h=*/nullptr, - /*device_buffer=*/new_indptr_, - /*host_buffer=*/page_locked_buffer_, num_bytes_to_copy, stream_)); - } - } - }); - return cudaSuccess; - } - - void Clear() { - padded_batch_size_ = 0; - batch_size_before_partition_ = 0; - batch_size_after_partition_ = 0; - block_valid_mask_ = nullptr; - tmp_v_ = nullptr; - tmp_s_ = nullptr; - new_indptr_ = nullptr; - new_last_page_len_ = nullptr; - chunk_indptr_ = nullptr; - batch_idx_map_ = nullptr; - chunk_start_pos_ = nullptr; - seq_lengths_before_partition_ = nullptr; - } - - void UpdatePageLockedBufferSize(size_t int_workspace_size_in_bytes) { - cudaFreeHost(page_locked_buffer_); - cudaMallocHost(&page_locked_buffer_, int_workspace_size_in_bytes); - } - - uint32_t GetBatchSizeBeforePartition() const { return batch_size_before_partition_; } - - uint32_t GetBatchSizeAfterPartition() const { return batch_size_after_partition_; } - - cudaStream_t GetCUDAStream() const { return stream_; } - - void SetCUDAStream(cudaStream_t stream) { stream_ = stream; } - - /*! 
- * \brief Constructor of BatchDecodeHandler - * \param enable_cuda_graph A boolean indicates whether to enable CUDA graph - * \param batch_size If enable_cuda_graph is true, we must specify a fixed batch_size - */ - BatchDecodeHandler(bool enable_cuda_graph = false, uint32_t batch_size = 0) - : batch_size_after_partition_(0U), - tmp_v_(nullptr), - tmp_s_(nullptr), - block_valid_mask_(nullptr), - new_indptr_(nullptr), - new_last_page_len_(nullptr), - chunk_indptr_(nullptr), - batch_idx_map_(nullptr), - chunk_start_pos_(nullptr), - seq_lengths_before_partition_(nullptr), - cuda_graph_enabled_(enable_cuda_graph), - fixed_batch_size_(batch_size), - stream_(nullptr) { - cudaMallocHost(&page_locked_buffer_, 8 * 1024 * 1024); - } - ~BatchDecodeHandler() { cudaFreeHost(page_locked_buffer_); } - - bool IsCUDAGraphEnabled() const { return cuda_graph_enabled_; } - - protected: - uint32_t batch_size_before_partition_; - uint32_t batch_size_after_partition_; - void* page_locked_buffer_; - void* tmp_v_; - float* tmp_s_; - bool* block_valid_mask_; - void* new_indptr_; - void* new_last_page_len_; - void* chunk_indptr_; - void* batch_idx_map_; - void* chunk_start_pos_; - void* seq_lengths_before_partition_; - bool cuda_graph_enabled_; - uint32_t padded_batch_size_; - uint32_t fixed_batch_size_; - cudaStream_t stream_; -}; - -template -cudaError_t PrefillSplitQOKVIndptr(bool& split_kv, uint32_t& split_max_batch_size, - uint32_t& total_num_tiles_q, uint32_t& new_batch_size, - WarpLayout& warp_layout, uint32_t& kv_chunk_size, - uint32_t& total_num_rows, std::vector& request_indices, - std::vector& qo_tile_indices, - std::vector& kv_tile_indices, - std::vector& merge_indptr, std::vector& o_indptr, - IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size, - uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, - uint32_t page_size) { - request_indices.clear(); - qo_tile_indices.clear(); - kv_tile_indices.clear(); - merge_indptr.clear(); - o_indptr.clear(); - merge_indptr.push_back(0); - o_indptr.push_back(0); - - const uint32_t gqa_group_size = num_qo_heads / num_kv_heads; - total_num_rows = qo_indptr_h[batch_size]; - - // step 0: get the number of SMs - int num_sm = 0; - int dev_id = 0; - FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); - FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); - int num_blocks_per_sm = 2; - int max_grid_size = num_blocks_per_sm * num_sm; - split_max_batch_size = max_grid_size / num_kv_heads; - - // step 1: compute qo_chunk_size - std::vector packed_qo_len_arr(batch_size), kv_len_arr(batch_size); - int64_t sum_packed_qo_len = 0; - for (uint32_t i = 0; i < batch_size; ++i) { - packed_qo_len_arr[i] = int64_t(qo_indptr_h[i + 1] - qo_indptr_h[i]) * int64_t(gqa_group_size); - kv_len_arr[i] = int64_t(kv_indptr_h[i + 1] - kv_indptr_h[i]); - sum_packed_qo_len += packed_qo_len_arr[i]; - } - int64_t avg_packed_qo_len = sum_packed_qo_len / batch_size; - if (avg_packed_qo_len > 64 && head_dim < 256) { - warp_layout = WarpLayout::k4x1x2; // (num_warps_x = 4, num_warps_z = 1, num_frags_x = 2) - } else { - auto compute_capacity = GetCudaComputeCapability(); - if (compute_capacity.first >= 8) { - // Ampere or newer - if (avg_packed_qo_len > 16) { - warp_layout = WarpLayout::k4x1x1; // (num_warps_x = 4, num_warps_z = 1, num_frags_x = 1) - } else { - // avg_packed_qo_len <= 16 - warp_layout = WarpLayout::k1x4x1; // (num_warps_x = 1, num_warps_z = 4, num_frags_x = 1) - } - } else { - // NOTE(Zihao): not enough shared memory on Turing for 
1x4x1 layout - warp_layout = WarpLayout::k4x1x1; - } - } - const uint32_t qo_chunk_size = get_num_rows_per_cta(warp_layout); - - // step 2: determine kv_chunk_size - std::tie(split_kv, kv_chunk_size, new_batch_size) = PrefillBinarySearchKVChunkSize( - max_grid_size, num_kv_heads, packed_qo_len_arr, kv_len_arr, qo_chunk_size, - /*min_kv_chunk_size=*/std::max((128 / page_size), 1U)); - - // step 3: split qo_indptr and kv_indptr - total_num_tiles_q = 0; - for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) { - int64_t packed_qo_len = packed_qo_len_arr[request_idx], - kv_len = std::max(int(kv_len_arr[request_idx]), 1); - int64_t num_tiles_q = ceil_div(packed_qo_len, qo_chunk_size), - num_tiles_kv = ceil_div(kv_len, kv_chunk_size); - total_num_tiles_q += num_tiles_q; - for (uint32_t q_tile_idx = 0; q_tile_idx < num_tiles_q; ++q_tile_idx) { - for (uint32_t kv_tile_idx = 0; kv_tile_idx < num_tiles_kv; ++kv_tile_idx) { - request_indices.push_back(request_idx); - qo_tile_indices.push_back(q_tile_idx); - kv_tile_indices.push_back(kv_tile_idx); - } - } - - int64_t qo_len = packed_qo_len / gqa_group_size; - for (uint32_t row = 0; row < qo_len; ++row) { - merge_indptr.push_back(merge_indptr.back() + num_tiles_kv); - } - o_indptr.push_back(o_indptr.back() + qo_len * num_tiles_kv); - } - - // step 4: multiply kv_chunk_size by page_size - kv_chunk_size *= page_size; - - return cudaSuccess; -} - -class BatchPrefillHandler { - public: - template - IdType* GetRequestIndices() const { - return (IdType*)request_indices_; - } - - template - IdType* GetQOTileIndices() const { - return (IdType*)qo_tile_indices_; - } - - template - IdType* GetKVTileIndices() const { - return (IdType*)kv_tile_indices_; - } - - template - IdType* GetMergeIndptr() const { - return (IdType*)merge_indptr_; - } - - template - IdType* GetOIndptr() const { - return (IdType*)o_indptr_; - } - - template - IdType* GetKVChunkSizePtr() const { - return (IdType*)kv_chunk_size_ptr_; - } - - template - DType* GetTempV() const { - return (DType*)tmp_v_; - } - - bool* GetBlockValidMask() const { return block_valid_mask_; } - - float* GetTempS() const { return tmp_s_; } - - uint32_t GetPaddedBatchSize() const { return padded_batch_size_; } - - WarpLayout GetWarpLayout() const { return warp_layout_; } - - uint32_t GetTotalNumRows() const { return total_num_rows_; } - - void UpdatePageLockedBufferSize(size_t int_workspace_size_in_bytes) { - cudaFreeHost(page_locked_buffer_); - cudaMallocHost(&page_locked_buffer_, int_workspace_size_in_bytes); - } - - template - cudaError_t Plan(void* float_buffer, size_t float_workspace_size_in_bytes, void* int_buffer, - size_t int_workspace_size_in_bytes, IdType* qo_indptr_h, IdType* kv_indptr_h, - uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, - uint32_t head_dim, uint32_t page_size) { - Clear(); - if (num_qo_heads % num_kv_heads != 0) { - std::ostringstream err_msg; - err_msg << "num_qo_heads " << num_qo_heads << " should be divisible by num_kv_heads " - << num_kv_heads; - throw std::invalid_argument(err_msg.str()); - } - bool split_kv; - uint32_t split_max_batch_size, new_batch_size, total_num_tiles_q, kv_chunk_size; - std::vector request_indices_vec, qo_tile_indices_vec, kv_tile_indices_vec, - merge_indptr_vec, o_indptr_vec; - FLASHINFER_CUDA_CALL(PrefillSplitQOKVIndptr( - split_kv, split_max_batch_size, total_num_tiles_q, new_batch_size, warp_layout_, - kv_chunk_size, total_num_rows_, request_indices_vec, qo_tile_indices_vec, - kv_tile_indices_vec, merge_indptr_vec, 
o_indptr_vec, qo_indptr_h, kv_indptr_h, batch_size, - num_qo_heads, num_kv_heads, head_dim, page_size)); - const uint32_t qo_tile_size = get_num_rows_per_cta(warp_layout_); - - if (IsCUDAGraphEnabled()) { - padded_batch_size_ = std::max(split_max_batch_size, total_num_tiles_q); - AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes); - request_indices_ = int_allocator.aligned_alloc(sizeof(IdType) * padded_batch_size_, 16, - "batch_prefill_request_indices"); - void* request_indices_h_ = page_locked_buffer_; - qo_tile_indices_ = int_allocator.aligned_alloc(sizeof(IdType) * padded_batch_size_, 16, - "batch_prefill_qo_tile_indices"); - void* qo_tile_indices_h_ = - (char*)page_locked_buffer_ + ((char*)qo_tile_indices_ - (char*)request_indices_); - kv_tile_indices_ = int_allocator.aligned_alloc(sizeof(IdType) * padded_batch_size_, 16, - "batch_prefill_kv_tile_indices"); - void* kv_tile_indices_h_ = - (char*)page_locked_buffer_ + ((char*)kv_tile_indices_ - (char*)request_indices_); - o_indptr_ = int_allocator.aligned_alloc(sizeof(IdType) * (batch_size + 1), 16, - "batch_prefill_o_indptr"); - void* o_indptr_h_ = (char*)page_locked_buffer_ + ((char*)o_indptr_ - (char*)request_indices_); - kv_chunk_size_ptr_ = - int_allocator.aligned_alloc(sizeof(IdType), 1, "batch_prefill_kv_chunk_size_ptr"); - void* kv_chunk_size_ptr_h_ = - (char*)page_locked_buffer_ + ((char*)kv_chunk_size_ptr_ - (char*)request_indices_); - *(IdType*)kv_chunk_size_ptr_h_ = kv_chunk_size; - if (total_num_tiles_q < split_max_batch_size) { - // need merge_indptr - merge_indptr_ = int_allocator.aligned_alloc(sizeof(IdType) * (total_num_rows_ + 1), - 16, "batch_prefill_merge_indptr"); - void* merge_indptr_h_ = - (char*)page_locked_buffer_ + ((char*)merge_indptr_ - (char*)request_indices_); - std::copy(merge_indptr_vec.begin(), merge_indptr_vec.end(), (IdType*)merge_indptr_h_); - block_valid_mask_ = int_allocator.aligned_alloc(sizeof(bool) * padded_batch_size_, 16, - "batch_prefill_block_valid_mask"); - bool* block_valid_mask_h_ = - (bool*)page_locked_buffer_ + ((bool*)block_valid_mask_ - (bool*)request_indices_); - for (uint32_t i = 0; i < padded_batch_size_; ++i) { - block_valid_mask_h_[i] = i < new_batch_size; - } - } else { - // total_num_tiles_q >= split_max_batch_size, we don't need to perform the second round at - // all. 
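
The Plan routine above consumes the index vectors produced by the deleted PrefillSplitQOKVIndptr, which flattens every request into (request, qo tile, kv tile) work items so that one CTA handles one tile pair. The sketch below restates that expansion step in isolation (step 3 of the deleted function); merge_indptr, the CUDA-graph padding and the warp-layout heuristic are left out, and the names are illustrative.

#include <algorithm>
#include <cstdint>
#include <vector>

struct WorkTiles {
  std::vector<int32_t> request_indices, qo_tile_indices, kv_tile_indices, o_indptr;
};

// Expand per-request (qo_len, kv_len) into flat work tiles: one entry per
// (request, qo tile, kv tile), plus o_indptr giving where each request's
// partial outputs start in the temporary buffer.
WorkTiles split_qo_kv(const std::vector<int64_t>& qo_len, const std::vector<int64_t>& kv_len,
                      int64_t gqa_group_size, int64_t qo_chunk_size, int64_t kv_chunk_size) {
  auto ceil_div = [](int64_t a, int64_t b) { return (a + b - 1) / b; };
  WorkTiles w;
  w.o_indptr.push_back(0);
  for (size_t i = 0; i < qo_len.size(); ++i) {
    // qo rows are packed with the GQA group, so one qo tile covers several heads
    int64_t packed_qo_len = qo_len[i] * gqa_group_size;
    int64_t num_tiles_q = ceil_div(packed_qo_len, qo_chunk_size);
    int64_t num_tiles_kv = ceil_div(std::max<int64_t>(kv_len[i], 1), kv_chunk_size);
    for (int64_t qt = 0; qt < num_tiles_q; ++qt) {
      for (int64_t kt = 0; kt < num_tiles_kv; ++kt) {
        w.request_indices.push_back(int32_t(i));
        w.qo_tile_indices.push_back(int32_t(qt));
        w.kv_tile_indices.push_back(int32_t(kt));
      }
    }
    w.o_indptr.push_back(w.o_indptr.back() + int32_t(qo_len[i] * num_tiles_kv));
  }
  return w;
}
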
- merge_indptr_ = nullptr; - block_valid_mask_ = nullptr; - } - std::copy(request_indices_vec.begin(), request_indices_vec.end(), - (IdType*)request_indices_h_); - std::copy(qo_tile_indices_vec.begin(), qo_tile_indices_vec.end(), - (IdType*)qo_tile_indices_h_); - std::copy(kv_tile_indices_vec.begin(), kv_tile_indices_vec.end(), - (IdType*)kv_tile_indices_h_); - std::copy(o_indptr_vec.begin(), o_indptr_vec.end(), (IdType*)o_indptr_h_); - - size_t num_bytes_to_copy = (char*)int_allocator.ptr - (char*)request_indices_; - FLASHINFER_CUDA_CALL(cudaMemcpyAsync(request_indices_, page_locked_buffer_, num_bytes_to_copy, - cudaMemcpyHostToDevice, stream_)) - - if (total_num_tiles_q < split_max_batch_size) { - AlignedAllocator float_allocator(float_buffer, float_workspace_size_in_bytes); - tmp_v_ = float_allocator.aligned_alloc( - num_qo_heads * split_max_batch_size * qo_tile_size * head_dim * sizeof(DTypeOut), 16, - "batch_prefill_tmp_v"); - tmp_s_ = float_allocator.aligned_alloc( - num_qo_heads * split_max_batch_size * qo_tile_size * sizeof(float), 16, - "batch_prefill_tmp_s"); - } else { - tmp_v_ = nullptr; - tmp_s_ = nullptr; - } - } else { - padded_batch_size_ = new_batch_size; - AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes); - request_indices_ = int_allocator.aligned_alloc( - sizeof(IdType) * request_indices_vec.size(), 16, "batch_prefill_request_indices"); - void* request_indices_h_ = page_locked_buffer_; - qo_tile_indices_ = int_allocator.aligned_alloc( - sizeof(IdType) * qo_tile_indices_vec.size(), 16, "batch_prefill_qo_tile_indices"); - void* qo_tile_indices_h_ = - (char*)page_locked_buffer_ + ((char*)qo_tile_indices_ - (char*)request_indices_); - kv_tile_indices_ = int_allocator.aligned_alloc( - sizeof(IdType) * kv_tile_indices_vec.size(), 16, "batch_prefill_kv_tile_indices"); - void* kv_tile_indices_h_ = - (char*)page_locked_buffer_ + ((char*)kv_tile_indices_ - (char*)request_indices_); - if (split_kv) { - // need merge_indptr when split_kv is true - merge_indptr_ = int_allocator.aligned_alloc(sizeof(IdType) * merge_indptr_vec.size(), - 16, "batch_prefill_merge_indptr"); - void* merge_indptr_h_ = - (char*)page_locked_buffer_ + ((char*)merge_indptr_ - (char*)request_indices_); - std::copy(merge_indptr_vec.begin(), merge_indptr_vec.end(), (IdType*)merge_indptr_h_); - } - o_indptr_ = int_allocator.aligned_alloc(sizeof(IdType) * o_indptr_vec.size(), 16, - "batch_prefill_o_indptr"); - void* o_indptr_h_ = (char*)page_locked_buffer_ + ((char*)o_indptr_ - (char*)request_indices_); - kv_chunk_size_ptr_ = - int_allocator.aligned_alloc(sizeof(IdType), 1, "batch_prefill_kv_chunk_size_ptr"); - void* kv_chunk_size_ptr_h_ = - (char*)page_locked_buffer_ + ((char*)kv_chunk_size_ptr_ - (char*)request_indices_); - *(IdType*)kv_chunk_size_ptr_h_ = kv_chunk_size; - std::copy(request_indices_vec.begin(), request_indices_vec.end(), - (IdType*)request_indices_h_); - std::copy(qo_tile_indices_vec.begin(), qo_tile_indices_vec.end(), - (IdType*)qo_tile_indices_h_); - std::copy(kv_tile_indices_vec.begin(), kv_tile_indices_vec.end(), - (IdType*)kv_tile_indices_h_); - std::copy(o_indptr_vec.begin(), o_indptr_vec.end(), (IdType*)o_indptr_h_); - size_t num_bytes_to_copy = (char*)int_allocator.ptr - (char*)request_indices_; - - FLASHINFER_CUDA_CALL(cudaMemcpyAsync(request_indices_, page_locked_buffer_, num_bytes_to_copy, - cudaMemcpyHostToDevice, stream_)) - - if (split_kv) { - AlignedAllocator float_allocator(float_buffer, float_workspace_size_in_bytes); - tmp_v_ = 
float_allocator.aligned_alloc( - num_qo_heads * new_batch_size * qo_tile_size * head_dim * sizeof(DTypeOut), 16, - "batch_prefill_tmp_v"); - tmp_s_ = float_allocator.aligned_alloc( - num_qo_heads * new_batch_size * qo_tile_size * sizeof(float), 16, - "batch_prefill_tmp_s"); - } else { - tmp_v_ = nullptr; - tmp_s_ = nullptr; - } - - block_valid_mask_ = nullptr; - } - return cudaSuccess; - } - - void Clear() { - request_indices_ = nullptr; - qo_tile_indices_ = nullptr; - kv_tile_indices_ = nullptr; - merge_indptr_ = nullptr; - o_indptr_ = nullptr; - kv_chunk_size_ptr_ = nullptr; - tmp_v_ = nullptr; - tmp_s_ = nullptr; - block_valid_mask_ = nullptr; - total_num_rows_ = 0U; - padded_batch_size_ = 0U; - warp_layout_ = WarpLayout::k4x1x2; - } - - cudaStream_t GetCUDAStream() const { return stream_; } - - void SetCUDAStream(cudaStream_t stream) { stream_ = stream; } - - bool IsCUDAGraphEnabled() const { return enable_cuda_graph_; } - - BatchPrefillHandler(bool enable_cuda_graph = false) - : request_indices_(nullptr), - qo_tile_indices_(nullptr), - kv_tile_indices_(nullptr), - merge_indptr_(nullptr), - o_indptr_(nullptr), - kv_chunk_size_ptr_(nullptr), - tmp_v_(nullptr), - tmp_s_(nullptr), - block_valid_mask_(nullptr), - total_num_rows_(0U), - padded_batch_size_(0U), - warp_layout_(WarpLayout::k4x1x2), - enable_cuda_graph_(enable_cuda_graph), - stream_(nullptr) { - cudaMallocHost(&page_locked_buffer_, 8 * 1024 * 1024); - } - ~BatchPrefillHandler() { cudaFreeHost(page_locked_buffer_); } - - protected: - void* page_locked_buffer_; - void* request_indices_; - void* qo_tile_indices_; - void* kv_tile_indices_; - void* merge_indptr_; - void* o_indptr_; - void* kv_chunk_size_ptr_; - void* tmp_v_; - float* tmp_s_; - bool* block_valid_mask_; - uint32_t total_num_rows_; - uint32_t padded_batch_size_; - WarpLayout warp_layout_; - bool enable_cuda_graph_; - cudaStream_t stream_; -}; - -} // namespace flashinfer -#endif // FLASHINFER_ATTENTION_HANDLER_CUH_ diff --git a/include/flashinfer/attention/logits_post_hook.cuh b/include/flashinfer/attention/logits_post_hook.cuh deleted file mode 100644 index 816b94d7d..000000000 --- a/include/flashinfer/attention/logits_post_hook.cuh +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Copyright (c) 2024 by FlashInfer team. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef FLASHINFER_ATTENTION_LOGITS_POST_HOOK_CUH_ -#define FLASHINFER_ATTENTION_LOGITS_POST_HOOK_CUH_ - -#include "../math.cuh" - -namespace flashinfer { - -enum class LogitsPostHook { - kNone = 0U, - kSoftCap = 1U, -}; - -/*! 
- * \brief Grok's logits cap function - * \ref - * https://github.com/xai-org/grok-1/blob/7050ed204b8206bb8645c7b7bbef7252f79561b0/model.py#L864-L865 - */ -__forceinline__ __device__ float logits_soft_cap_impl(float x, const float soft_cap) { - return (soft_cap * math::log2e) * math::tanh(x); -} - -__forceinline__ __device__ half2 logits_soft_cap_impl(half2 x, const float soft_cap) { - return __hmul2(__float2half2_rn(soft_cap * math::log2e), math::tanh(x)); -} - -template -__forceinline__ __device__ T apply_logits_post_hook(T x, const float soft_cap); - -template <> -__forceinline__ __device__ float apply_logits_post_hook( - float x, const float soft_cap) { - return x; -} - -template <> -__forceinline__ __device__ float apply_logits_post_hook( - float x, const float soft_cap) { - return logits_soft_cap_impl(x, soft_cap); -} - -template <> -__forceinline__ __device__ half2 -apply_logits_post_hook(half2 x, const float soft_cap) { - return x; -} - -template <> -__forceinline__ __device__ half2 -apply_logits_post_hook(half2 x, const float soft_cap) { - return logits_soft_cap_impl(x, soft_cap); -} - -} // namespace flashinfer - -#endif // FLASHINFER_ATTENTION_LOGITS_POST_HOOK_CUH_ diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 5ad6988c7..54642f5a2 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -32,8 +32,8 @@ #include "../pos_enc.cuh" #include "../utils.cuh" #include "cascade.cuh" -#include "logits_post_hook.cuh" #include "mask.cuh" +#include "variants.cuh" #include "warp_layout.cuh" namespace flashinfer { @@ -46,7 +46,7 @@ constexpr uint32_t warp_size = 32; namespace { -template +template constexpr bool is_invalid_configuration(uint32_t num_frags_x, uint32_t num_frags_y, uint32_t num_frags_z, uint32_t num_warps_x, uint32_t num_warps_z) { @@ -54,7 +54,7 @@ constexpr bool is_invalid_configuration(uint32_t num_frags_x, uint32_t num_frags (num_frags_y > 4 && num_frags_y % (2 * num_warps_x) != 0) || (num_frags_x * (8 * num_frags_y + 2 * sizeof(DTypeQKAccum) * num_frags_z) >= 256) || (sizeof(DTypeKV) == 1 && num_frags_z * 2 % num_warps_x != 0) || - (sizeof(DTypeKV) == 1 && pos_encoding_mode == PosEncodingMode::kRoPELlama)); + (sizeof(DTypeKV) == 1 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama)); } template @@ -88,15 +88,13 @@ __device__ __forceinline__ uint32_t get_warp_idx() { * \param x_second_half Second fragment x[offset:offset*16, j*16+d/2:(j+1)*16+d/2] * \param rope_freq Rope frequency * \param offset The offset of the first row in both fragments. - * \param scale A scale factor applied to the result (used to multiply sm_scale). * \note The sin/cos computation is slow, especially for A100 GPUs which has low * non tensor-ops flops, will optimize in the future. 
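
Note: logits_post_hook.cuh is deleted and prefill.cuh now pulls in variants.cuh, which is not part of this hunk. Inferring the interface only from the call sites visible later in this patch (variant.QueryTransform(params, x), variant.LogitsTransform(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx), variant.LogitsMask(...)), the removed soft-cap hook could be expressed as a variant roughly like the sketch below. This is an illustration under those assumptions, not the actual contents of variants.cuh; operand types are simplified to float.

#include <cstdint>
#include "../math.cuh"  // math::log2e, math::tanh, math::ptx_rcp (as in the deleted hook)

// Sketch only: interface inferred from the kernel call sites in this patch.
template <typename ParamsT_>
struct SoftCapAttentionSketch {
  using ParamsT = ParamsT_;
  // DTypeQ/DTypeKV/DTypeO/IdType typedefs, qo_len/kv_len/window_left members,
  // and the (params, batch_idx, smem) constructor are elided here.

  // Fold sm_scale and 1/soft_cap into the query, mirroring how the old kernels
  // folded them into sm_scale before applying the hook.
  __device__ __forceinline__ float QueryTransform(const ParamsT& params, float q) const {
    return q * params.sm_scale * math::ptx_rcp(params.logits_soft_cap);
  }

  // Same formula as the deleted logits_soft_cap_impl: soft_cap * log2e * tanh(x),
  // keeping logits in log2 space.
  __device__ __forceinline__ float LogitsTransform(const ParamsT& params, float logits,
                                                   uint32_t /*batch_idx*/, uint32_t /*qo_idx*/,
                                                   uint32_t /*kv_idx*/, uint32_t /*qo_head_idx*/,
                                                   uint32_t /*kv_head_idx*/) const {
    return (params.logits_soft_cap * math::log2e) * math::tanh(logits);
  }

  // No extra masking beyond the causal/sliding-window bounds handled in logits_mask.
  __device__ __forceinline__ bool LogitsMask(const ParamsT&, uint32_t, uint32_t, uint32_t,
                                             uint32_t, uint32_t) const {
    return true;
  }
};
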
*/ template __device__ __forceinline__ void k_frag_apply_llama_rope(T* x_first_half, T* x_second_half, const float* rope_freq, - const uint32_t kv_offset, - float scale = 1.f) { + const uint32_t kv_offset) { static_assert(sizeof(T) == 2); #pragma unroll for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { @@ -107,8 +105,8 @@ __device__ __forceinline__ void k_frag_apply_llama_rope(T* x_first_half, T* x_se uint32_t i = reg_id / 4, j = (reg_id % 4) / 2; __sincosf(float(kv_offset + 8 * i) * rope_freq[2 * j + reg_id % 2], &sin, &cos); tmp = x_first_half[reg_id]; - x_first_half[reg_id] = (tmp * cos - (float)x_second_half[reg_id] * sin) * scale; - x_second_half[reg_id] = ((float)x_second_half[reg_id] * cos + tmp * sin) * scale; + x_first_half[reg_id] = (tmp * cos - (float)x_second_half[reg_id] * sin); + x_second_half[reg_id] = ((float)x_second_half[reg_id] * cos + tmp * sin); } } @@ -116,8 +114,7 @@ template __device__ __forceinline__ void q_frag_apply_llama_rope(T* x_first_half, T* x_second_half, const float* rope_freq, const uint32_t qo_packed_offset, - const uint_fastdiv group_size, - float scale = 1.f) { + const uint_fastdiv group_size) { #pragma unroll for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { float cos, sin, tmp; @@ -128,15 +125,17 @@ __device__ __forceinline__ void q_frag_apply_llama_rope(T* x_first_half, T* x_se __sincosf(float((qo_packed_offset + 8 * i) / group_size) * rope_freq[2 * j + reg_id % 2], &sin, &cos); tmp = x_first_half[reg_id]; - x_first_half[reg_id] = (tmp * cos - (float)x_second_half[reg_id] * sin) * scale; - x_second_half[reg_id] = ((float)x_second_half[reg_id] * cos + tmp * sin) * scale; + x_first_half[reg_id] = (tmp * cos - (float)x_second_half[reg_id] * sin); + x_second_half[reg_id] = ((float)x_second_half[reg_id] * cos + tmp * sin); } } template -__device__ __forceinline__ void q_frag_apply_llama_rope_with_pos( - T* x_first_half, T* x_second_half, const float* rope_freq, const uint32_t qo_packed_offset, - const uint_fastdiv group_size, const IdType* q_offset, float scale = 1.f) { +__device__ __forceinline__ void q_frag_apply_llama_rope_with_pos(T* x_first_half, T* x_second_half, + const float* rope_freq, + const uint32_t qo_packed_offset, + const uint_fastdiv group_size, + const IdType* q_offset) { float pos[2] = {static_cast(q_offset[qo_packed_offset / group_size]), static_cast(q_offset[(qo_packed_offset + 8) / group_size])}; #pragma unroll @@ -148,8 +147,8 @@ __device__ __forceinline__ void q_frag_apply_llama_rope_with_pos( uint32_t i = ((reg_id % 4) / 2), j = (reg_id / 4); __sincosf(pos[i] * rope_freq[2 * j + reg_id % 2], &sin, &cos); tmp = x_first_half[reg_id]; - x_first_half[reg_id] = (tmp * cos - (float)x_second_half[reg_id] * sin) * scale; - x_second_half[reg_id] = ((float)x_second_half[reg_id] * cos + tmp * sin) * scale; + x_first_half[reg_id] = (tmp * cos - (float)x_second_half[reg_id] * sin); + x_second_half[reg_id] = ((float)x_second_half[reg_id] * cos + tmp * sin); } } @@ -162,7 +161,6 @@ __device__ __forceinline__ void q_frag_apply_llama_rope_with_pos( * \tparam T The data type of the input tensor. * \param smem The shared memory to store kv fragments. * \param gptr The global memory pointer. - * \param qkv_info The tensor info of the input tensor. * \param kv_idx_base The base kv index. * \param kv_len The length of kv tensor. 
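
Note on the k_frag/q_frag_apply_llama_rope changes above: the trailing scale argument (previously used to fuse sm_scale into the rotation) is dropped; query scaling now happens exactly once in q_smem_inplace_transform via the variant's QueryTransform. The rotation itself is the usual pairwise RoPE update; a scalar restatement with illustrative names:

// Rotate one (first-half, second-half) register pair by angle = pos * freq.
// This is what the fragment-level helpers do per register, minus the fused
// sm_scale multiply that this diff removes.
__device__ __forceinline__ void rope_rotate_pair(float& x_first, float& x_second,
                                                 float pos, float freq) {
  float sin, cos;
  __sincosf(pos * freq, &sin, &cos);
  float tmp = x_first;
  x_first = tmp * cos - x_second * sin;
  x_second = x_second * cos + tmp * sin;
}
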
*/ @@ -213,10 +211,9 @@ __device__ __forceinline__ void produce_kv(smem_t smem, uint32_t* } template + uint32_t num_frags_z, SwizzleMode swizzle_mode, typename DType, typename IdType> __device__ __forceinline__ void page_produce_kv(smem_t smem, uint32_t* smem_offset, - paged_kv_t& paged_kv, + const paged_kv_t& paged_kv, const uint32_t kv_idx_base, const size_t* kv_offset, const uint32_t kv_len) { // NOTE(Zihao): for fp8, this function doesn't work for head_dim = 64 at the moment @@ -297,7 +294,7 @@ __device__ __forceinline__ void init_states(float (*o_frag)[num_frags_y][8], DTy for (uint32_t fx = 0; fx < num_frags_x; ++fx) { #pragma unroll for (uint32_t j = 0; j < 2; ++j) { - m[fx][j] = DTypeQKAccum(-5e4); + m[fx][j] = DTypeQKAccum(-math::inf); d[fx][j] = 1.f; } } @@ -345,10 +342,10 @@ __device__ __forceinline__ void load_q_global_smem(uint32_t packed_offset, template -__device__ __forceinline__ void q_smem_inplace_apply_rotary_multiply_sm_scale( +__device__ __forceinline__ void q_smem_inplace_apply_rotary( const uint32_t q_packed_idx, const uint32_t qo_len, const uint32_t kv_len, const uint_fastdiv group_size, smem_t* q_smem, uint32_t* q_smem_offset_r, - float (*rope_freq)[4], const float sm_scale) { + float (*rope_freq)[4]) { if (get_warp_idx_z() == 0) { constexpr uint32_t head_dim = num_frags_y * 16; constexpr uint32_t channel_size_128b_q = head_dim / num_elems_per_128b(); @@ -367,7 +364,7 @@ __device__ __forceinline__ void q_smem_inplace_apply_rotary_multiply_sm_scale( q_frag_apply_llama_rope( (DTypeQ*)q_frag_local[0], (DTypeQ*)q_frag_local[1], rope_freq[fyi], q_packed_idx + kv_len * group_size - qo_len * group_size + fx * 16 + lane_idx / 4, - group_size, sm_scale); + group_size); q_smem->stmatrix_m8n8x4(q_smem_offset_r_last_half, q_frag_local[1]); q_smem->stmatrix_m8n8x4(q_smem_offset_r_first_half, q_frag_local[0]); q_smem_offset_r_first_half = @@ -381,10 +378,9 @@ __device__ __forceinline__ void q_smem_inplace_apply_rotary_multiply_sm_scale( template -__device__ __forceinline__ void q_smem_inplace_apply_rotary_with_pos_multiply_sm_scale( +__device__ __forceinline__ void q_smem_inplace_apply_rotary_with_pos( const uint32_t q_packed_idx_base, const IdType* q_offset, smem_t* q_smem, - const uint_fastdiv group_size, uint32_t* q_smem_offset_r, float (*rope_freq)[4], - const float sm_scale) { + const uint_fastdiv group_size, uint32_t* q_smem_offset_r, float (*rope_freq)[4]) { if (get_warp_idx_z() == 0) { constexpr uint32_t head_dim = num_frags_y * 16; constexpr uint32_t channel_size_128b_q = head_dim / num_elems_per_128b(); @@ -402,7 +398,7 @@ __device__ __forceinline__ void q_smem_inplace_apply_rotary_with_pos_multiply_sm q_smem->ldmatrix_m8n8x4(q_smem_offset_r_last_half, q_frag_local[1]); q_frag_apply_llama_rope_with_pos( (DTypeQ*)q_frag_local[0], (DTypeQ*)q_frag_local[1], rope_freq[fyi], - q_packed_idx_base + fx * 16 + lane_idx / 4, group_size, q_offset, sm_scale); + q_packed_idx_base + fx * 16 + lane_idx / 4, group_size, q_offset); q_smem->stmatrix_m8n8x4(q_smem_offset_r_last_half, q_frag_local[1]); q_smem->stmatrix_m8n8x4(q_smem_offset_r_first_half, q_frag_local[0]); q_smem_offset_r_first_half = @@ -415,9 +411,11 @@ __device__ __forceinline__ void q_smem_inplace_apply_rotary_with_pos_multiply_sm } template -__device__ __forceinline__ void q_smem_inplace_multiply_sm_scale(smem_t* q_smem, - const float sm_scale) { + SwizzleMode swizzle_mode, typename AttentionVariant> +__device__ __forceinline__ void q_smem_inplace_transform( + const typename AttentionVariant::ParamsT& params, 
AttentionVariant variant, + smem_t* q_smem) { + using DTypeQ = typename AttentionVariant::DTypeQ; const uint32_t warp_idx = get_warp_idx(), lane_idx = threadIdx.x; constexpr uint32_t head_dim = num_frags_y * 16; constexpr uint32_t channel_size_128b_q = head_dim / num_elems_per_128b(); @@ -428,7 +426,7 @@ __device__ __forceinline__ void q_smem_inplace_multiply_sm_scale(smem_tbase) + (i * num_warps + warp_idx) * 256 + lane_idx * 8); #pragma unroll for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { - tmp[reg_id] *= sm_scale; + tmp[reg_id] = variant.QueryTransform(params, tmp[reg_id]); } tmp.store((DTypeQ*)(q_smem->base) + (i * num_warps + warp_idx) * 256 + lane_idx * 8); } @@ -514,12 +512,14 @@ __device__ __forceinline__ void k_smem_inplace_apply_rotary(const uint32_t kv_id } } -template -__device__ __forceinline__ void compute_qk( - smem_t* q_smem, uint32_t* q_smem_offset_r, smem_t* k_smem, - uint32_t* k_smem_offset_r, DTypeQKAccum (*s_frag)[num_frags_z][8], const float soft_cap) { +template +__device__ __forceinline__ void compute_qk(smem_t* q_smem, + uint32_t* q_smem_offset_r, + smem_t* k_smem, + uint32_t* k_smem_offset_r, + DTypeQKAccum (*s_frag)[num_frags_z][8]) { constexpr uint32_t head_dim = num_frags_y * 16; constexpr uint32_t channel_size_128b_q = head_dim / num_elems_per_128b(); constexpr uint32_t channel_size_128b_kv = head_dim / num_elems_per_128b(); @@ -587,92 +587,65 @@ __device__ __forceinline__ void compute_qk( } *q_smem_offset_r -= num_frags_y * 2; *k_smem_offset_r -= num_frags_y * sizeof(DTypeKV); +} - if constexpr (std::is_same::value) { -#pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { -#pragma unroll - for (uint32_t fz = 0; fz < num_frags_z; ++fz) { -#pragma unroll - for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { - s_frag[fx][fz][reg_id] = - apply_logits_post_hook(s_frag[fx][fz][reg_id], soft_cap); - } - } - } - } else { - static_assert(std::is_same::value); +template +__device__ __forceinline__ void logits_transform(const typename AttentionVariant::ParamsT& params, + AttentionVariant variant, const uint32_t batch_idx, + const uint32_t qo_packed_idx_base, + const uint32_t kv_idx_base, const uint32_t qo_len, + const uint32_t kv_len, + const uint_fastdiv group_size, + DTypeQKAccum (*s_frag)[num_frags_z][8]) { + const uint32_t lane_idx = threadIdx.x, kv_head_idx = blockIdx.z; #pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { #pragma unroll - for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { #pragma unroll - for (uint32_t reg_id = 0; reg_id < 4; ++reg_id) { - *(half2*)(&s_frag[fx][fz][reg_id * 2]) = apply_logits_post_hook( - *(half2*)(&s_frag[fx][fz][reg_id * 2]), soft_cap); - } - } - } - } -} - -template -__device__ __forceinline__ void apply_alibi_bias(const uint32_t qo_packed_idx_base, - const uint32_t kv_idx_base, const int32_t q_offset, - const uint_fastdiv group_size, - float (*alibi_slope)[2], - T (*s_frag)[num_frags_z][8]) { - const int32_t lane_idx = threadIdx.x; -#pragma unroll - for (int32_t fx = 0; fx < num_frags_x; ++fx) { -#pragma unroll - for (int32_t fz = 0; fz < num_frags_z; ++fz) { -#pragma unroll - for (int32_t reg_id = 0; reg_id < 8; ++reg_id) { - const int32_t q_idx = - (qo_packed_idx_base + fx * 16 + lane_idx / 4 + 8 * ((reg_id % 4) / 2)) / - group_size, - kv_idx = kv_idx_base + fz * 16 + 2 * (lane_idx % 4) + 8 * (reg_id / 4) + - reg_id % 2; - s_frag[fx][fz][reg_id] += - T(alibi_slope[fx][(reg_id % 4) / 2]) * T(kv_idx - q_idx 
- q_offset); + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + uint32_t q, r; + group_size.divmod(qo_packed_idx_base + fx * 16 + lane_idx / 4 + 8 * ((reg_id % 4) / 2), q, + r); + const uint32_t q_idx = q, kv_idx = kv_idx_base + fz * 16 + 2 * (lane_idx % 4) + + 8 * (reg_id / 4) + reg_id % 2; + const uint32_t qo_head_idx = kv_head_idx * group_size + r; + s_frag[fx][fz][reg_id] = variant.LogitsTransform(params, s_frag[fx][fz][reg_id], batch_idx, + q_idx, kv_idx, qo_head_idx, kv_head_idx); } } } } -template -__device__ __forceinline__ void mask_s(const uint32_t qo_packed_idx_base, - const uint32_t kv_idx_base, const uint32_t qo_len, - const uint32_t kv_len, const uint32_t window_left, - const uint32_t chunk_end, const uint_fastdiv group_size, - uint8_t* custom_mask, - DTypeQKAccum (*s_frag)[num_frags_z][8]) { - const uint32_t lane_idx = threadIdx.x; +template +__device__ __forceinline__ void logits_mask(const typename AttentionVariant::ParamsT& params, + AttentionVariant variant, const uint32_t batch_idx, + const uint32_t qo_packed_idx_base, + const uint32_t kv_idx_base, const uint32_t qo_len, + const uint32_t kv_len, const uint32_t chunk_end, + const uint_fastdiv group_size, + DTypeQKAccum (*s_frag)[num_frags_z][8]) { + const uint32_t lane_idx = threadIdx.x, kv_head_idx = blockIdx.z; #pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { #pragma unroll for (uint32_t fz = 0; fz < num_frags_z; ++fz) { #pragma unroll for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { - const uint32_t q_idx = - (qo_packed_idx_base + fx * 16 + lane_idx / 4 + 8 * ((reg_id % 4) / 2)) / - group_size, - kv_idx = kv_idx_base + fz * 16 + 2 * (lane_idx % 4) + 8 * (reg_id / 4) + - reg_id % 2; - const bool out_of_boundary = - (mask_mode == MaskMode::kCausal - ? (kv_idx + qo_len > kv_len + q_idx || (kv_idx >= chunk_end) || - kv_idx + qo_len + window_left < kv_len + q_idx) - : kv_idx >= chunk_end || kv_idx + qo_len + window_left < kv_len + q_idx); - s_frag[fx][fz][reg_id] = - (out_of_boundary || - (mask_mode == MaskMode::kCustom && q_idx < qo_len && - !((custom_mask[(q_idx * kv_len + kv_idx) / 8] >> ((q_idx * kv_len + kv_idx) % 8)) & - 1))) - ? DTypeQKAccum(-5e4) - : s_frag[fx][fz][reg_id]; + uint32_t q, r; + group_size.divmod(qo_packed_idx_base + fx * 16 + lane_idx / 4 + 8 * ((reg_id % 4) / 2), q, + r); + const uint32_t q_idx = q, kv_idx = kv_idx_base + fz * 16 + 2 * (lane_idx % 4) + + 8 * (reg_id / 4) + reg_id % 2; + const uint32_t qo_head_idx = kv_head_idx * group_size + r; + const bool mask = + (!(MASK_MODE == MaskMode::kCausal + ? (kv_idx + qo_len > kv_len + q_idx || (kv_idx >= chunk_end)) + : kv_idx >= chunk_end)) && + variant.LogitsMask(params, batch_idx, q_idx, kv_idx, qo_head_idx, kv_head_idx); + s_frag[fx][fz][reg_id] = (mask) ? s_frag[fx][fz][reg_id] : DTypeQKAccum(-math::inf); } } } @@ -839,7 +812,7 @@ __device__ __forceinline__ void normalize_d(float (*o_frag)[num_frags_y][8], DTy for (uint32_t fx = 0; fx < num_frags_x; ++fx) { #pragma unroll for (uint32_t j = 0; j < 2; ++j) { - d_rcp[fx][j] = (m[fx][j] != DTypeQKAccum(-5e4)) ? math::ptx_rcp(d[fx][j]) : 0.f; + d_rcp[fx][j] = (m[fx][j] != DTypeQKAccum(-math::inf)) ? 
math::ptx_rcp(d[fx][j]) : 0.f; } } @@ -898,7 +871,7 @@ __device__ __forceinline__ void threadblock_sync_mdo_states(float (*o_frag)[num_ float o_scale[2][num_warps_z]; #pragma unroll for (uint32_t j = 0; j < 2; ++j) { - float m_new = -5e4, d_new = 1.f; + float m_new = -math::inf, d_new = 1.f; #pragma unroll for (uint32_t i = 0; i < num_warps_z; ++i) { float2 md = smem_md[(((i * num_warps_x + get_warp_idx_x()) * @@ -956,13 +929,13 @@ __device__ __forceinline__ void threadblock_sync_mdo_states(float (*o_frag)[num_ } template + SwizzleMode swizzle_mode, typename DTypeO> __device__ __forceinline__ void write_o_reg_gmem( - float (*o_frag)[num_frags_y][8], smem_t* o_smem, DTypeOut* o_ptr_base, + float (*o_frag)[num_frags_y][8], smem_t* o_smem, DTypeO* o_ptr_base, const uint32_t o_packed_idx_base, const uint32_t qo_upper_bound, const uint32_t o_stride_n, const uint32_t o_stride_h, const uint_fastdiv group_size) { constexpr uint32_t head_dim = num_frags_y * 16; - constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b(); + constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b(); const uint32_t warp_idx_x = get_warp_idx_x(); const uint32_t lane_idx = threadIdx.x; @@ -972,7 +945,7 @@ __device__ __forceinline__ void write_o_reg_gmem( #pragma unroll for (uint32_t fy = 0; fy < num_frags_y; ++fy) { uint32_t o_frag_f16[4]; - vec_cast::cast<8>((DTypeOut*)o_frag_f16, o_frag[fx][fy]); + vec_cast::cast<8>((DTypeO*)o_frag_f16, o_frag[fx][fy]); #ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED uint32_t o_smem_offset_w = o_smem->get_permuted_offset( (warp_idx_x * num_frags_x + fx) * 16 + lane_idx % 16, fy * 2 + lane_idx / 16); @@ -1000,13 +973,13 @@ __device__ __forceinline__ void write_o_reg_gmem( uint32_t q, r; group_size.divmod(o_packed_idx_base + lane_idx / 8 + fx * 16 + j * 4, q, r); const uint32_t o_idx = q; - DTypeOut* o_ptr = o_ptr_base + q * o_stride_n + r * o_stride_h; + DTypeO* o_ptr = o_ptr_base + q * o_stride_n + r * o_stride_h; #pragma unroll for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { if (o_idx < qo_upper_bound) { o_smem->store_128b(o_smem_offset_w, o_ptr); } - o_ptr += 8 * num_elems_per_128b(); + o_ptr += 8 * num_elems_per_128b(); o_smem_offset_w = o_smem->template advance_offset_by_column<8>(o_smem_offset_w, fyo); } o_smem_offset_w = @@ -1023,56 +996,63 @@ __device__ __forceinline__ void write_o_reg_gmem( * \brief FlashAttention prefill CUDA kernel for a single request. * \tparam partition_kv Whether to split kv_len into chunks. * \tparam mask_mode The mask mode used in the attention operation. - * \tparam pos_encoding_mode The positional encoding mode. + * \tparam POS_ENCODING_MODE The positional encoding mode. * \tparam num_frags_x The number of fragments in x dimension. * \tparam num_frags_y The number of fragments in y dimension. * \tparam num_frags_z The number of fragments in z dimension. * \tparam num_warps The number of warps in the threadblock. * \tparam DTypeQ The data type of the query tensor. * \tparam DTypeKV The data type of the key/value tensor. - * \tparam DTypeOut The data type of the output tensor. + * \tparam DTypeO The data type of the output tensor. * \param q The query tensor. * \param k The key tensor. * \param v The value tensor. * \param o The output tensor. * \param tmp The temporary buffer (used when partition_kv is true). * \param lse The logsumexp value. - * \param qkv_info The tensor info of the input tensor. - * \param sm_scale The scale factor applied to the softmax score. 
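
Note on the -5e4 to -math::inf change in normalize_d and threadblock_sync_mdo_states: initializing the running max with a true -inf makes the cross-warp merge exact (a fully masked partial state contributes nothing) and lets normalize_d still detect empty rows via m == -inf. For reference, the per-element combination that threadblock_sync_mdo_states implements is the standard online-softmax state merge; a scalar sketch, with logits kept in log2 space so exp2 replaces exp:

// Scalar sketch of merging two partial (m, d, o) online-softmax states.
struct MDOState {
  float m;  // running max of (log2-scaled) logits
  float d;  // running sum of exp2(logit - m)
  float o;  // one element of the unnormalized output accumulator
};

__device__ __forceinline__ MDOState merge_mdo(MDOState a, MDOState b) {
  MDOState out;
  out.m = fmaxf(a.m, b.m);
  float a_scale = exp2f(a.m - out.m);  // vanishes when a.m == -inf and out.m is finite
  float b_scale = exp2f(b.m - out.m);
  out.d = a.d * a_scale + b.d * b_scale;
  out.o = a.o * a_scale + b.o * b_scale;
  return out;
}
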
* \param log2_rope_rcp_scale log2(1/(rope_scale)), where rope_scale is the scaling * factor used in RoPE interpolation. * \param log2_rope_rcp_theta log2(1/(rope_theta)), where rope_theta is the theta * used in RoPE. */ -template +template __global__ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVCacheKernel( - DTypeQ* __restrict__ q, DTypeKV* __restrict__ k, DTypeKV* __restrict__ v, - uint8_t* __restrict__ custom_mask, DTypeOut* __restrict__ o, float* __restrict__ lse, - const uint32_t qo_len, const uint32_t kv_len, const bool partition_kv, - const uint_fastdiv group_size, const uint32_t q_stride_n, const uint32_t q_stride_h, - const uint32_t kv_stride_n, const uint32_t kv_stride_h, const int32_t maybe_window_left, - const float logits_soft_cap, float sm_scale, const float log2_rope_rcp_scale, - const float log2_rope_rcp_theta) { + const uint_fastdiv group_size, + const __grid_constant__ typename AttentionVariant::ParamsT params) { + using DTypeQ = typename AttentionVariant::DTypeQ; #if (__CUDA_ARCH__ < 800) if constexpr (std::is_same::value) { FLASHINFER_RUNTIME_ASSERT("Prefill kernels do not support bf16 on sm75."); } else { #endif + using DTypeKV = typename AttentionVariant::DTypeKV; + using DTypeO = typename AttentionVariant::DTypeO; + DTypeQ* q = params.q; + DTypeKV* k = params.k; + DTypeKV* v = params.v; + DTypeO* o = params.o; + float* lse = params.lse; + const uint32_t qo_len = params.qo_len; + const uint32_t kv_len = params.kv_len; + const bool partition_kv = params.partition_kv; + const uint32_t q_stride_n = params.q_stride_n; + const uint32_t q_stride_h = params.q_stride_h; + const uint32_t kv_stride_n = params.kv_stride_n; + const uint32_t kv_stride_h = params.kv_stride_h; + const int32_t maybe_window_left = params.window_left; + const float log2_rope_rcp_scale = params.log2_rope_rcp_scale; + const float log2_rope_rcp_theta = params.log2_rope_rcp_theta; static_assert(sizeof(DTypeQ) == 2); - static_assert(sizeof(DTypeOut) == 2); - sm_scale *= - (logits_post_hook == LogitsPostHook::kNone ? math::log2e : math::ptx_rcp(logits_soft_cap)); + static_assert(sizeof(DTypeO) == 2); const uint32_t lane_idx = threadIdx.x, warp_idx = get_warp_idx(); const uint32_t bx = blockIdx.x, chunk_idx = blockIdx.y, kv_head_idx = blockIdx.z; const uint32_t num_kv_heads = gridDim.z, num_qo_heads = num_kv_heads * group_size; constexpr uint32_t num_rows_per_cta = num_frags_x * num_warps_x * 16; const tensor_info_t qkv_info(qo_len, kv_len, num_qo_heads, num_kv_heads, q_stride_n, q_stride_h, kv_stride_n, kv_stride_h, /*head_dim=*/num_frags_y * 16); - float alibi_slopes[num_frags_x][2]; const uint32_t num_chunks = gridDim.y; const uint32_t max_chunk_size = partition_kv ? ceil_div(kv_len, num_chunks) : kv_len; @@ -1082,21 +1062,21 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC const uint32_t chunk_size = chunk_end - chunk_start; auto block = cg::this_thread_block(); - const uint32_t window_left = (maybe_window_left >= 0) ? 
maybe_window_left : kv_len; + extern __shared__ uint8_t smem[]; + AttentionVariant variant(params, /*batch_idx=*/0, smem); + const uint32_t window_left = variant.window_left; constexpr uint32_t head_dim = num_frags_y * 16; constexpr uint32_t channel_size_128b_q = head_dim / num_elems_per_128b(); constexpr uint32_t channel_size_128b_kv = head_dim / num_elems_per_128b(); - constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b(); - - extern __shared__ uint8_t smem[]; + constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b(); DTypeQKAccum s_frag[num_frags_x][num_frags_z][8]; float o_frag[num_frags_x][num_frags_y][8]; DTypeQKAccum m[num_frags_x][2]; float d[num_frags_x][2]; float rope_freq[num_frags_y / 2][4]; - if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { + if constexpr (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { init_rope_freq(rope_freq, log2_rope_rcp_scale, log2_rope_rcp_theta); } init_states(o_frag, m, d); @@ -1109,13 +1089,14 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC DTypeQ* q_ptr_base = q + qkv_info.get_q_elem_offset(0, kv_head_idx * group_size, (lane_idx % 8) * num_elems_per_128b()); - DTypeOut* o_ptr_base = + DTypeO* o_ptr_base = partition_kv ? o + chunk_idx * num_qo_heads * head_dim + qkv_info.get_o_elem_offset(0, kv_head_idx * group_size, - (lane_idx % 8) * num_elems_per_128b()) + (lane_idx % 8) * num_elems_per_128b()) : o + qkv_info.get_o_elem_offset(0, kv_head_idx * group_size, - (lane_idx % 8) * num_elems_per_128b()); + (lane_idx % 8) * num_elems_per_128b()); + uint32_t q_smem_offset_r = qo_smem.get_permuted_offset( get_warp_idx_x() * num_frags_x * 16 + lane_idx % 16, lane_idx / 16); @@ -1127,28 +1108,14 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC cp_async::wait_group<0>(); block.sync(); - if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { - q_smem_inplace_apply_rotary_multiply_sm_scale( - qo_packed_idx_base, qo_len, kv_len, group_size, &qo_smem, &q_smem_offset_r, rope_freq, - sm_scale); - } else { - q_smem_inplace_multiply_sm_scale(&qo_smem, sm_scale); - } - - if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { -#pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { -#pragma unroll - for (uint32_t j = 0; j < 2; ++j) { - const uint32_t qo_head_idx = - kv_head_idx * group_size + - (qo_packed_idx_base + lane_idx / 4 + j * 8 + fx * 16) % group_size; - alibi_slopes[fx][j] = get_alibi_slope(qo_head_idx, num_qo_heads) * math::log2e; - } - } + if constexpr (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { + q_smem_inplace_apply_rotary( + qo_packed_idx_base, qo_len, kv_len, group_size, &qo_smem, &q_smem_offset_r, rope_freq); + block.sync(); } + q_smem_inplace_transform( + params, variant, &qo_smem); constexpr SwizzleMode swizzle_mode_kv = (sizeof(DTypeKV) == 1 && head_dim == 64) ? SwizzleMode::k64B : SwizzleMode::k128B; @@ -1161,7 +1128,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC 16 * head_dim); const uint32_t num_iterations = ceil_div( - mask_mode == MaskMode::kCausal + MASK_MODE == MaskMode::kCausal ? min(chunk_size, sub_if_greater_or_zero( kv_len - qo_len + ((bx + 1) * num_rows_per_cta) / group_size, chunk_start)) @@ -1174,7 +1141,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC (16 * num_warps_z * num_frags_z)); const uint32_t mask_iteration = - (mask_mode == MaskMode::kCausal + (MASK_MODE == MaskMode::kCausal ? 
min(chunk_size, sub_if_greater_or_zero(kv_len + (bx * num_rows_per_cta) / group_size - qo_len, chunk_start)) @@ -1189,6 +1156,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC v + qkv_info.get_kv_elem_offset( chunk_start + warp_idx * kv_frag_rows + lane_idx / kv_frag_cols, kv_head_idx, (lane_idx % kv_frag_cols) * num_elems_per_128b()); + uint32_t k_smem_offset_r = k_smem.get_permuted_offset( get_warp_idx_z() * num_frags_z * 16 + 8 * (lane_idx / 16) + lane_idx % 8, @@ -1210,7 +1178,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC cp_async::wait_group<1>(); block.sync(); - if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { + if constexpr (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { k_smem_inplace_apply_rotary( chunk_start + iter * 16 * num_warps_z * num_frags_z, &k_smem, &k_smem_offset_r, @@ -1219,32 +1187,22 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC } // compute attention score - compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, - &k_smem_offset_r, s_frag, logits_soft_cap); + compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); + + logits_transform( + params, variant, /*batch_idx=*/0, qo_packed_idx_base, + chunk_start + + (iter * num_warps_z + get_warp_idx_z()) * num_frags_z * 16, + qo_len, kv_len, group_size, s_frag); - if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { - apply_alibi_bias( - qo_packed_idx_base, - chunk_start + (iter * num_warps_z + get_warp_idx_z()) * - num_frags_z * 16, - int(kv_len) - int(qo_len), group_size, alibi_slopes, s_frag); - } // apply mask - if constexpr (mask_mode == MaskMode::kCustom) { - mask_s( - qo_packed_idx_base, + if (MASK_MODE == MaskMode::kCustom || (iter >= mask_iteration || iter < window_iteration)) { + logits_mask( + params, variant, /*batch_idx=*/0, qo_packed_idx_base, chunk_start + (iter * num_warps_z + get_warp_idx_z()) * num_frags_z * 16, - qo_len, kv_len, window_left, chunk_end, group_size, custom_mask, s_frag); - } else { - if (iter >= mask_iteration || iter < window_iteration) { - mask_s( - qo_packed_idx_base, - chunk_start + (iter * num_warps_z + get_warp_idx_z()) * - num_frags_z * 16, - qo_len, kv_len, window_left, chunk_end, group_size, nullptr, s_frag); - } + qo_len, kv_len, chunk_end, group_size, s_frag); } // compute m,d states in online softmax @@ -1314,34 +1272,190 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC #endif } -template +template +cudaError_t SinglePrefillWithKVCacheDispatched(typename AttentionVariant::ParamsT params, + typename AttentionVariant::DTypeO* tmp, + cudaStream_t stream) { + using DTypeQ = typename AttentionVariant::DTypeQ; + using DTypeKV = typename AttentionVariant::DTypeKV; + using DTypeO = typename AttentionVariant::DTypeO; + const uint32_t num_qo_heads = params.num_qo_heads; + const uint32_t num_kv_heads = params.num_kv_heads; + const uint32_t qo_len = params.qo_len; + const uint32_t kv_len = params.kv_len; + if (kv_len < qo_len && MASK_MODE == MaskMode::kCausal) { + std::ostringstream err_msg; + err_msg << "When mask_mode is set to MaskMode::kCausal, kv_len must be greater than or equal " + "to qo_len, got kv_len" + << kv_len << " and qo_len " << qo_len; + throw std::invalid_argument(err_msg.str()); + } + + const uint32_t group_size = num_qo_heads / num_kv_heads; + const uint_fastdiv group_size_fastdiv(group_size); + constexpr uint32_t num_frags_y = HEAD_DIM / 16; + WarpLayout warp_layout; + int64_t 
unpacked_qo_len = qo_len * group_size; + if (unpacked_qo_len > 64 && HEAD_DIM < 256) { + warp_layout = WarpLayout::k4x1x2; + } else { + auto compute_capacity = GetCudaComputeCapability(); + if (compute_capacity.first >= 8) { + // Ampere or newer + if (unpacked_qo_len > 16) { + warp_layout = WarpLayout::k4x1x1; + } else { + warp_layout = WarpLayout::k1x4x1; + } + } else { + // NOTE(Zihao): not enough shared memory on Turing for 1x4x1 layout + warp_layout = WarpLayout::k4x1x1; + } + } + + DISPATCH_WARP_LAYOUT(warp_layout, WARP_LAYOUT, { + constexpr uint32_t num_frags_x = get_num_frags_x(); + using DTypeQKAccum = + typename std::conditional::value, + half, float>::type; + + int dev_id = 0; + FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); + int max_smem_per_sm = 0; + FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute( + &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id)); + // we expect each sm execute two threadblocks + // TODO(Zihao): fix the following computation + const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM * sizeof(DTypeQ) * 16) ? 2 : 1; + const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm; + + constexpr uint32_t num_warps_x = get_num_warps_x(); + constexpr uint32_t num_warps_z = get_num_warps_z(); + const uint32_t max_num_frags_z_reg = + (HEAD_DIM >= 128 && num_frags_x == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && + !ALLOW_FP16_QK_REDUCTION) + ? 2 + : (8 / num_frags_x); + // TODO(Zihao): fix the following computation + const uint32_t max_num_frags_z_smem = + (max_smem_per_threadblock / (16 * HEAD_DIM * sizeof(DTypeQ)) - num_frags_x * num_warps_x) / + (2 * num_warps_z); + + // control num_frags_z for maximum warp occupancy + DISPATCH_NUM_FRAGS_Z(min(max_num_frags_z_smem, max_num_frags_z_reg), num_frags_z, { + if constexpr (is_invalid_configuration( + num_frags_x, num_frags_y, num_frags_z, num_warps_x, num_warps_z)) { + // Invalid configuration, skip + std::ostringstream err_msg; + err_msg << "FlashInfer Internal Error: Invalid configuration : num_frags_x=" << num_frags_x + << " num_frags_y=" << num_frags_y << " num_frags_z=" << num_frags_z + << " num_warps_x=" << num_warps_x << " num_warps_z=" << num_warps_z + << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)" + " and report the issue to the developers."; + throw std::invalid_argument(err_msg.str()); + } else { + constexpr uint32_t num_threads = (num_warps_x * num_warps_z) * warp_size; + constexpr uint32_t num_rows_per_cta = num_frags_x * num_warps_x * 16; + auto kernel = SinglePrefillWithKVCacheKernel; + // TODO(Zihao): fix the following computation + uint32_t smem_size = (num_frags_x * num_warps_x * sizeof(DTypeQ) + + num_frags_z * num_warps_z * 2 * sizeof(DTypeQ)) * + 16 * HEAD_DIM; + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + int num_blocks_per_sm = 0; + int num_sm = 0; + FLASHINFER_CUDA_CALL( + cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); + FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &num_blocks_per_sm, kernel, num_threads, smem_size)); + uint32_t max_num_kv_chunks = + (num_blocks_per_sm * num_sm) / + (num_kv_heads * ceil_div(qo_len * group_size, num_rows_per_cta)); + uint32_t num_chunks; + if (max_num_kv_chunks > 0) { + uint32_t chunk_size = max(ceil_div(kv_len, max_num_kv_chunks), 256); + num_chunks = ceil_div(kv_len, chunk_size); + } else { + num_chunks = 0; + } + + if (num_chunks <= 1 || tmp == nullptr) { + 
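
Note on the split-kv heuristic computed just above: the branch that follows launches a single unpartitioned grid when num_chunks <= 1 (or no tmp buffer was provided); otherwise it redirects params.o/params.lse to tmp/tmp_lse, launches a (qo_tiles, num_chunks, num_kv_heads) grid, and merges the per-chunk partials with MergeStates. A host-side restatement of the chunk-count computation, as an illustrative helper rather than part of the FlashInfer API:

#include <algorithm>
#include <cstdint>

static inline uint32_t ceil_div_u32(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

// Returns the number of kv chunks; <= 1 means "enough parallelism, don't split".
inline uint32_t PlanNumKVChunks(uint32_t num_blocks_per_sm, uint32_t num_sm,
                                uint32_t num_kv_heads, uint32_t qo_len,
                                uint32_t group_size, uint32_t num_rows_per_cta,
                                uint32_t kv_len) {
  // CTAs needed by the qo-tile/head grid alone, vs. CTAs the GPU can keep resident.
  uint32_t grid_ctas = num_kv_heads * ceil_div_u32(qo_len * group_size, num_rows_per_cta);
  uint32_t max_num_kv_chunks = (num_blocks_per_sm * num_sm) / grid_ctas;
  if (max_num_kv_chunks == 0) return 0;
  // Chunks span at least 256 kv positions so each partition does meaningful work.
  uint32_t chunk_size = std::max(ceil_div_u32(kv_len, max_num_kv_chunks), 256u);
  return ceil_div_u32(kv_len, chunk_size);
}
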
// Enough parallelism, do not split-kv + params.partition_kv = false; + void* args[] = {(void*)&group_size_fastdiv, (void*)¶ms}; + dim3 nblks(ceil_div(qo_len * group_size, num_rows_per_cta), 1, num_kv_heads); + dim3 nthrs(32, num_warps_x, num_warps_z); + FLASHINFER_CUDA_CALL( + cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + } else { + // Use cooperative groups to increase occupancy + params.partition_kv = true; + float* tmp_lse = (float*)(tmp + num_chunks * qo_len * num_qo_heads * HEAD_DIM); + auto o = params.o; + auto lse = params.lse; + params.o = tmp; + params.lse = tmp_lse; + void* args[] = {(void*)&group_size_fastdiv, (void*)¶ms}; + dim3 nblks(ceil_div(qo_len * group_size, num_rows_per_cta), num_chunks, num_kv_heads); + dim3 nthrs(32, num_warps_x, num_warps_z); + FLASHINFER_CUDA_CALL( + cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + FLASHINFER_CUDA_CALL(MergeStates(tmp, tmp_lse, o, lse, num_chunks, qo_len, num_qo_heads, + HEAD_DIM, stream)); + } + } + }) + }); + return cudaSuccess; +} + +template __global__ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRaggedKVCacheKernel( - DTypeQ* __restrict__ q, IdType* __restrict__ request_indices, - IdType* __restrict__ q_tile_indices, IdType* __restrict__ kv_tile_indices, - IdType* __restrict__ q_indptr, DTypeKV* __restrict__ k, DTypeKV* __restrict__ v, - IdType* __restrict__ kv_indptr, uint8_t* __restrict__ custom_mask, - IdType* __restrict__ qk_indptr, IdType* __restrict__ q_offset, - IdType* __restrict__ k_rope_pos_offset, IdType* __restrict__ o_indptr, DTypeOut* __restrict__ o, - float* __restrict__ lse, bool* __restrict__ block_valid_mask, - IdType* __restrict__ kv_chunk_size_ptr, const bool partition_kv, const uint_fastdiv group_size, - const uint32_t q_stride_n, const uint32_t q_stride_h, const uint32_t kv_stride_n, - const uint32_t kv_stride_h, const int32_t maybe_window_left, const float logits_soft_cap, - float sm_scale, const float log2_rope_rcp_scale, const float log2_rope_rcp_theta) { + const uint_fastdiv group_size, + const __grid_constant__ typename AttentionVariant::ParamsT params) { + using DTypeQ = typename AttentionVariant::DTypeQ; #if (__CUDA_ARCH__ < 800) if constexpr (std::is_same::value) { FLASHINFER_RUNTIME_ASSERT("Prefill kernels do not support bf16 on sm75."); } else { #endif + using DTypeKV = typename AttentionVariant::DTypeKV; + using DTypeO = typename AttentionVariant::DTypeO; + using IdType = typename AttentionVariant::IdType; + DTypeQ* q = params.q; + IdType* request_indices = params.request_indices; + IdType* qo_tile_indices = params.qo_tile_indices; + IdType* kv_tile_indices = params.kv_tile_indices; + IdType* q_indptr = params.q_indptr; + IdType* kv_indptr = params.kv_indptr; + DTypeKV* k = params.k; + DTypeKV* v = params.v; + IdType* q_offset = params.q_offset; + IdType* k_rope_pos_offset = params.k_rope_pos_offset; + IdType* o_indptr = params.o_indptr; + DTypeO* o = params.o; + float* lse = params.lse; + bool* block_valid_mask = params.block_valid_mask; + const bool partition_kv = params.partition_kv; + const uint32_t q_stride_n = params.q_stride_n; + const uint32_t q_stride_h = params.q_stride_h; + const uint32_t kv_stride_n = params.kv_stride_n; + const uint32_t kv_stride_h = params.kv_stride_h; + const int32_t maybe_window_left = params.window_left; + const float log2_rope_rcp_scale = params.log2_rope_rcp_scale; + const float log2_rope_rcp_theta = params.log2_rope_rcp_theta; + static_assert(sizeof(DTypeQ) == 2); - 
static_assert(sizeof(DTypeOut) == 2); - sm_scale *= - (logits_post_hook == LogitsPostHook::kNone ? math::log2e : math::ptx_rcp(logits_soft_cap)); + static_assert(sizeof(DTypeO) == 2); constexpr uint32_t head_dim = num_frags_y * 16; - const uint32_t kv_chunk_size = *kv_chunk_size_ptr; + const uint32_t kv_chunk_size = *(params.kv_chunk_size_ptr); auto block = cg::this_thread_block(); const uint32_t bx = blockIdx.x, lane_idx = threadIdx.x, @@ -1350,13 +1464,14 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg return; } const uint32_t num_kv_heads = gridDim.z, num_qo_heads = group_size * num_kv_heads; - const uint32_t request_idx = request_indices[bx], qo_tile_idx = q_tile_indices[bx], + const uint32_t request_idx = request_indices[bx], qo_tile_idx = qo_tile_indices[bx], kv_tile_idx = kv_tile_indices[bx]; constexpr uint32_t num_rows_per_cta = num_frags_x * num_warps_x * 16; - const uint32_t qo_len = q_indptr[request_idx + 1] - q_indptr[request_idx], - kv_len = kv_indptr[request_idx + 1] - kv_indptr[request_idx]; + extern __shared__ uint8_t smem[]; + AttentionVariant variant(params, /*batch_idx=*/request_idx, smem); + const uint32_t qo_len = variant.qo_len, kv_len = variant.kv_len, + window_left = variant.window_left; const uint32_t kv_len_safe = kv_len > 0 ? kv_len : 1; - const uint32_t window_left = (maybe_window_left >= 0) ? maybe_window_left : kv_len; const uint32_t max_chunk_size = partition_kv ? kv_chunk_size : kv_len; const uint32_t chunk_start = partition_kv ? kv_tile_idx * max_chunk_size : 0; const uint32_t chunk_end = @@ -1364,15 +1479,12 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg const uint32_t chunk_size = chunk_end - chunk_start; const tensor_info_t qkv_info(qo_len, kv_len, num_qo_heads, num_kv_heads, q_stride_n, q_stride_h, kv_stride_n, kv_stride_h, /*head_dim=*/num_frags_y * 16); - float alibi_slopes[num_frags_x][2]; const uint32_t qo_upper_bound = min(qo_len, ceil_div((qo_tile_idx + 1) * num_rows_per_cta, group_size)); constexpr uint32_t channel_size_128b_q = head_dim / num_elems_per_128b(); constexpr uint32_t channel_size_128b_kv = head_dim / num_elems_per_128b(); - constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b(); - - extern __shared__ uint8_t smem[]; + constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b(); DTypeQKAccum s_frag[num_frags_x][num_frags_z][8]; float o_frag[num_frags_x][num_frags_y][8]; @@ -1380,7 +1492,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg float d[num_frags_x][2]; float rope_freq[num_frags_y / 2][4]; - if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { + if constexpr (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { init_rope_freq(rope_freq, log2_rope_rcp_scale, log2_rope_rcp_theta); } init_states(o_frag, m, d); @@ -1394,13 +1506,13 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg q + qkv_info.get_q_elem_offset(q_indptr[request_idx], kv_head_idx * group_size, (lane_idx % 8) * num_elems_per_128b()); - DTypeOut* o_ptr_base = + DTypeO* o_ptr_base = partition_kv ? 
o + kv_tile_idx * num_qo_heads * head_dim + qkv_info.get_o_elem_offset(o_indptr[request_idx], kv_head_idx * group_size, - (lane_idx % 8) * num_elems_per_128b()) + (lane_idx % 8) * num_elems_per_128b()) : o + qkv_info.get_o_elem_offset(o_indptr[request_idx], kv_head_idx * group_size, - (lane_idx % 8) * num_elems_per_128b()); + (lane_idx % 8) * num_elems_per_128b()); uint32_t q_smem_offset_r = qo_smem.get_permuted_offset( get_warp_idx_x() * num_frags_x * 16 + lane_idx % 16, @@ -1414,38 +1526,24 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg cp_async::wait_group<0>(); block.sync(); - if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { + if constexpr (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { if (!q_offset) { - q_smem_inplace_apply_rotary_multiply_sm_scale( - qo_packed_idx_base, qo_len, kv_len, group_size, &qo_smem, &q_smem_offset_r, rope_freq, - sm_scale); + q_smem_inplace_apply_rotary( + qo_packed_idx_base, qo_len, kv_len, group_size, &qo_smem, &q_smem_offset_r, rope_freq); } else { - q_smem_inplace_apply_rotary_with_pos_multiply_sm_scale< - num_warps_x, num_warps_z, num_frags_x, num_frags_y, swizzle_mode_q, DTypeQ>( + q_smem_inplace_apply_rotary_with_pos( qo_packed_idx_base, q_offset + q_indptr[request_idx], &qo_smem, group_size, - &q_smem_offset_r, rope_freq, sm_scale); - } - } else { - q_smem_inplace_multiply_sm_scale(&qo_smem, sm_scale); - } - - if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { -#pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { -#pragma unroll - for (uint32_t j = 0; j < 2; ++j) { - const uint32_t qo_head_idx = - kv_head_idx * group_size + - (qo_packed_idx_base + lane_idx / 4 + j * 8 + fx * 16) % group_size; - alibi_slopes[fx][j] = get_alibi_slope(qo_head_idx, num_qo_heads) * math::log2e; - } + &q_smem_offset_r, rope_freq); } + block.sync(); } + q_smem_inplace_transform( + params, variant, &qo_smem); const uint32_t num_iterations = ceil_div( - (mask_mode == MaskMode::kCausal + (MASK_MODE == MaskMode::kCausal ? min(chunk_size, sub_if_greater_or_zero( kv_len - qo_len + ((qo_tile_idx + 1) * num_rows_per_cta) / group_size, @@ -1459,7 +1557,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg (16 * num_warps_z * num_frags_z)); const uint32_t mask_iteration = - (mask_mode == MaskMode::kCausal + (MASK_MODE == MaskMode::kCausal ? min(chunk_size, sub_if_greater_or_zero( kv_len + (qo_tile_idx * num_rows_per_cta) / group_size - qo_len, chunk_start)) @@ -1509,7 +1607,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg cp_async::wait_group<1>(); block.sync(); - if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { + if constexpr (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { k_smem_inplace_apply_rotary( (k_rope_pos_offset == nullptr ? 
0 : k_rope_pos_offset[request_idx]) + chunk_start + @@ -1519,34 +1617,22 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg } // compute attention score - compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, - &k_smem_offset_r, s_frag, logits_soft_cap); - - if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { - // TODO(Zihao): handle the case that q_offset is specified - apply_alibi_bias( - qo_packed_idx_base, - chunk_start + (iter * num_warps_z + get_warp_idx_z()) * - num_frags_z * 16, - int(kv_len) - int(qo_len), group_size, alibi_slopes, s_frag); - } + compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); + + logits_transform( + params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base, + chunk_start + + (iter * num_warps_z + get_warp_idx_z()) * num_frags_z * 16, + qo_len, kv_len, group_size, s_frag); + // apply mask - if constexpr (mask_mode == MaskMode::kCustom) { - mask_s( - qo_packed_idx_base, + if (MASK_MODE == MaskMode::kCustom || (iter >= mask_iteration || iter < window_iteration)) { + logits_mask( + params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base, chunk_start + (iter * num_warps_z + get_warp_idx_z()) * num_frags_z * 16, - qo_len, kv_len, window_left, chunk_end, group_size, - custom_mask + qk_indptr[request_idx], s_frag); - } else { - if (iter >= mask_iteration || iter < window_iteration) { - mask_s( - qo_packed_idx_base, - chunk_start + (iter * num_warps_z + get_warp_idx_z()) * - num_frags_z * 16, - qo_len, kv_len, window_left, chunk_end, group_size, nullptr, s_frag); - } + qo_len, kv_len, chunk_end, group_size, s_frag); } // compute m,d states in online softmax @@ -1618,32 +1704,42 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg #endif } -template +template __global__ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPagedKVCacheKernel( - IdType* __restrict__ request_indices, IdType* __restrict__ q_tile_indices, - IdType* __restrict__ kv_tile_indices, DTypeQ* __restrict__ q, - paged_kv_t paged_kv, IdType* __restrict__ q_indptr, - uint8_t* __restrict__ custom_mask, IdType* __restrict__ qk_indptr, - IdType* __restrict__ q_offset, IdType* __restrict__ o_indptr, DTypeOut* __restrict__ o, - float* __restrict__ lse, bool* __restrict__ block_valid_mask, - IdType* __restrict__ kv_chunk_size_ptr, const bool partition_kv, const uint_fastdiv group_size, - int32_t maybe_window_left, const float logits_soft_cap, float sm_scale, - const float log2_rope_rcp_scale, const float log2_rope_rcp_theta) { + const uint_fastdiv group_size, + const __grid_constant__ typename AttentionVariant::ParamsT params) { + using DTypeQ = typename AttentionVariant::DTypeQ; #if (__CUDA_ARCH__ < 800) if constexpr (std::is_same::value) { FLASHINFER_RUNTIME_ASSERT("Prefill kernels do not support bf16 on sm75."); } else { #endif + using DTypeKV = typename AttentionVariant::DTypeKV; + using DTypeO = typename AttentionVariant::DTypeO; + using IdType = typename AttentionVariant::IdType; + IdType* request_indices = params.request_indices; + IdType* qo_tile_indices = params.qo_tile_indices; + IdType* kv_tile_indices = params.kv_tile_indices; + DTypeQ* q = params.q; + IdType* q_indptr = params.q_indptr; + IdType* q_offset = params.q_offset; + IdType* o_indptr = params.o_indptr; + DTypeO* o = params.o; + float* lse = params.lse; + bool* block_valid_mask = params.block_valid_mask; + const paged_kv_t& paged_kv = params.paged_kv; + const bool partition_kv = params.partition_kv; + const 
int32_t maybe_window_left = params.window_left; + const float log2_rope_rcp_scale = params.log2_rope_rcp_scale; + const float log2_rope_rcp_theta = params.log2_rope_rcp_theta; + static_assert(sizeof(DTypeQ) == 2); - static_assert(sizeof(DTypeOut) == 2); - sm_scale *= - (logits_post_hook == LogitsPostHook::kNone ? math::log2e : math::ptx_rcp(logits_soft_cap)); + static_assert(sizeof(DTypeO) == 2); auto block = cg::this_thread_block(); - const uint32_t kv_chunk_size = *kv_chunk_size_ptr; + const uint32_t kv_chunk_size = *(params.kv_chunk_size_ptr); const uint32_t bx = blockIdx.x, lane_idx = threadIdx.x, warp_idx = get_warp_idx(), kv_head_idx = blockIdx.z; @@ -1651,19 +1747,15 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage return; } const uint32_t num_kv_heads = gridDim.z, num_qo_heads = num_kv_heads * group_size; - float alibi_slopes[num_frags_x][2]; - const uint32_t request_idx = request_indices[bx], qo_tile_idx = q_tile_indices[bx], + const uint32_t request_idx = request_indices[bx], qo_tile_idx = qo_tile_indices[bx], kv_tile_idx = kv_tile_indices[bx]; constexpr uint32_t num_rows_per_cta = num_frags_x * num_warps_x * 16; - const uint32_t qo_len = q_indptr[request_idx + 1] - q_indptr[request_idx], - kv_len = (paged_kv.indptr[request_idx + 1] != paged_kv.indptr[request_idx]) - ? (paged_kv.indptr[request_idx + 1] - paged_kv.indptr[request_idx] - - 1) * paged_kv.page_size + - paged_kv.last_page_len[request_idx] - : 0; + extern __shared__ uint8_t smem[]; + AttentionVariant variant(params, /*batch_idx=*/request_idx, smem); + const uint32_t qo_len = variant.qo_len, kv_len = variant.kv_len, + window_left = variant.window_left; const uint32_t kv_len_safe = kv_len > 0 ? kv_len : 1; - const uint32_t window_left = (maybe_window_left >= 0) ? maybe_window_left : kv_len; const uint32_t max_chunk_size = partition_kv ? kv_chunk_size : kv_len; const uint32_t chunk_start = partition_kv ? kv_tile_idx * max_chunk_size : 0; const uint32_t chunk_end = @@ -1675,9 +1767,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage constexpr uint32_t head_dim = num_frags_y * 16; constexpr uint32_t channel_size_128b_q = head_dim / num_elems_per_128b(); constexpr uint32_t channel_size_128b_kv = head_dim / num_elems_per_128b(); - constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b(); - - extern __shared__ uint8_t smem[]; + constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b(); DTypeQKAccum s_frag[num_frags_x][num_frags_z][8]; float o_frag[num_frags_x][num_frags_y][8]; @@ -1685,7 +1775,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage float d[num_frags_x][2]; float rope_freq[num_frags_y / 2][4]; - if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { + if constexpr (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { init_rope_freq(rope_freq, log2_rope_rcp_scale, log2_rope_rcp_theta); } init_states(o_frag, m, d); @@ -1698,13 +1788,13 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage DTypeQ* q_ptr_base = q + get_elem_offset_impl(q_indptr[request_idx], kv_head_idx * group_size, (lane_idx % 8) * num_elems_per_128b(), q_stride_n, q_stride_h); - DTypeOut* o_ptr_base = + DTypeO* o_ptr_base = partition_kv ? 
o + kv_tile_idx * num_qo_heads * head_dim + get_elem_offset_impl(o_indptr[request_idx], kv_head_idx * group_size, - (lane_idx % 8) * num_elems_per_128b(), + (lane_idx % 8) * num_elems_per_128b(), num_qo_heads * head_dim, head_dim) : o + get_elem_offset_impl(o_indptr[request_idx], kv_head_idx * group_size, - (lane_idx % 8) * num_elems_per_128b(), + (lane_idx % 8) * num_elems_per_128b(), num_qo_heads * head_dim, head_dim); uint32_t q_smem_offset_r = qo_smem.get_permuted_offset( get_warp_idx_x() * num_frags_x * 16 + lane_idx % 16, @@ -1718,35 +1808,21 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage cp_async::wait_group<0>(); block.sync(); - if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { + if constexpr (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { if (q_offset == nullptr) { - q_smem_inplace_apply_rotary_multiply_sm_scale( - qo_packed_idx_base, qo_len, kv_len, group_size, &qo_smem, &q_smem_offset_r, rope_freq, - sm_scale); + q_smem_inplace_apply_rotary( + qo_packed_idx_base, qo_len, kv_len, group_size, &qo_smem, &q_smem_offset_r, rope_freq); } else { - q_smem_inplace_apply_rotary_with_pos_multiply_sm_scale< - num_warps_x, num_warps_z, num_frags_x, num_frags_y, swizzle_mode_q, DTypeQ>( + q_smem_inplace_apply_rotary_with_pos( qo_packed_idx_base, q_offset + q_indptr[request_idx], &qo_smem, group_size, - &q_smem_offset_r, rope_freq, sm_scale); - } - } else { - q_smem_inplace_multiply_sm_scale(&qo_smem, sm_scale); - } - - if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { -#pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { -#pragma unroll - for (uint32_t j = 0; j < 2; ++j) { - const uint32_t qo_head_idx = - kv_head_idx * group_size + - (qo_packed_idx_base + lane_idx / 4 + j * 8 + fx * 16) % group_size; - alibi_slopes[fx][j] = get_alibi_slope(qo_head_idx, num_qo_heads) * math::log2e; - } + &q_smem_offset_r, rope_freq); } + block.sync(); } + q_smem_inplace_transform( + params, variant, &qo_smem); constexpr SwizzleMode swizzle_mode_kv = (sizeof(DTypeKV) == 1 && head_dim == 64) ? SwizzleMode::k64B : SwizzleMode::k128B; @@ -1794,7 +1870,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage cp_async::commit_group(); const uint32_t num_iterations = ceil_div( - (mask_mode == MaskMode::kCausal + (MASK_MODE == MaskMode::kCausal ? min(chunk_size, sub_if_greater_or_zero( kv_len - qo_len + ((qo_tile_idx + 1) * num_rows_per_cta) / group_size, @@ -1808,7 +1884,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage (16 * num_warps_z * num_frags_z)); const uint32_t mask_iteration = - (mask_mode == MaskMode::kCausal + (MASK_MODE == MaskMode::kCausal ? min(chunk_size, sub_if_greater_or_zero( kv_len + (qo_tile_idx * num_rows_per_cta) / group_size - qo_len, chunk_start)) @@ -1835,7 +1911,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage cp_async::wait_group<1>(); block.sync(); - if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { + if constexpr (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { k_smem_inplace_apply_rotary( (paged_kv.rope_pos_offset == nullptr ? 
0 : paged_kv.rope_pos_offset[request_idx]) + @@ -1845,34 +1921,22 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage } // compute attention score - compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, - &k_smem_offset_r, s_frag, logits_soft_cap); - - if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { - // TODO(Zihao): handle the case that q_offset is specified - apply_alibi_bias( - qo_packed_idx_base, - chunk_start + (iter * num_warps_z + get_warp_idx_z()) * - num_frags_z * 16, - int(kv_len) - int(qo_len), group_size, alibi_slopes, s_frag); - } + compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); + + logits_transform( + params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base, + chunk_start + + (iter * num_warps_z + get_warp_idx_z()) * num_frags_z * 16, + qo_len, kv_len, group_size, s_frag); + // apply mask - if constexpr (mask_mode == MaskMode::kCustom) { - mask_s( - qo_packed_idx_base, + if (MASK_MODE == MaskMode::kCustom || (iter >= mask_iteration || iter < window_iteration)) { + logits_mask( + params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base, chunk_start + (iter * num_warps_z + get_warp_idx_z()) * num_frags_z * 16, - qo_len, kv_len, window_left, chunk_end, group_size, - custom_mask + qk_indptr[request_idx], s_frag); - } else { - if (iter >= mask_iteration || iter < window_iteration) { - mask_s( - qo_packed_idx_base, - chunk_start + (iter * num_warps_z + get_warp_idx_z()) * - num_frags_z * 16, - qo_len, kv_len, window_left, chunk_end, group_size, nullptr, s_frag); - } + qo_len, kv_len, chunk_end, group_size, s_frag); } // compute m,d states in online softmax @@ -1944,199 +2008,21 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage #endif } -template -cudaError_t SinglePrefillWithKVCacheDispatched( - DTypeQ* q, DTypeKV* k, DTypeKV* v, uint8_t* custom_mask, DTypeOut* o, DTypeOut* tmp, float* lse, - uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, - uint32_t q_stride_n, uint32_t q_stride_h, uint32_t kv_stride_n, uint32_t kv_stride_h, - int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, - cudaStream_t stream) { - const float log2_rope_rcp_scale = -std::log2f(rope_scale); - const float log2_rope_rcp_theta = -std::log2f(rope_theta); - if (kv_len < qo_len && MASK_MODE == MaskMode::kCausal) { - std::ostringstream err_msg; - err_msg << "When mask_mode is set to MaskMode::kCausal, kv_len must be greater than or equal " - "to qo_len, got kv_len" - << kv_len << " and qo_len " << qo_len; - throw std::invalid_argument(err_msg.str()); - } - - const uint32_t group_size = num_qo_heads / num_kv_heads; - const uint_fastdiv group_size_fastdiv(group_size); - constexpr uint32_t num_frags_y = HEAD_DIM / 16; - WarpLayout warp_layout; - int64_t unpacked_qo_len = qo_len * group_size; - if (unpacked_qo_len > 64 && HEAD_DIM < 256) { - warp_layout = WarpLayout::k4x1x2; - } else { - auto compute_capacity = GetCudaComputeCapability(); - if (compute_capacity.first >= 8) { - // Ampere or newer - if (unpacked_qo_len > 16) { - warp_layout = WarpLayout::k4x1x1; - } else { - warp_layout = WarpLayout::k1x4x1; - } - } else { - // NOTE(Zihao): not enough shared memory on Turing for 1x4x1 layout - warp_layout = WarpLayout::k4x1x1; - } - } - - DISPATCH_WARP_LAYOUT(warp_layout, WARP_LAYOUT, { - constexpr uint32_t num_frags_x = get_num_frags_x(); - using DTypeQKAccum = - typename std::conditional::value, - half, float>::type; - - int dev_id = 
0; - FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); - int max_smem_per_sm = 0; - FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute( - &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id)); - // we expect each sm execute two threadblocks - // TODO(Zihao): fix the following computation - const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM * sizeof(DTypeQ) * 16) ? 2 : 1; - const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm; - - constexpr uint32_t num_warps_x = get_num_warps_x(); - constexpr uint32_t num_warps_z = get_num_warps_z(); - const uint32_t max_num_frags_z_reg = - (HEAD_DIM >= 128 && num_frags_x == 2 && pos_encoding_mode == PosEncodingMode::kRoPELlama && - !ALLOW_FP16_QK_REDUCTION) - ? 2 - : (8 / num_frags_x); - // TODO(Zihao): fix the following computation - const uint32_t max_num_frags_z_smem = - (max_smem_per_threadblock / (16 * HEAD_DIM * sizeof(DTypeQ)) - num_frags_x * num_warps_x) / - (2 * num_warps_z); - - // control num_frags_z for maximum warp occupancy - DISPATCH_NUM_FRAGS_Z(min(max_num_frags_z_smem, max_num_frags_z_reg), num_frags_z, { - if constexpr (is_invalid_configuration( - num_frags_x, num_frags_y, num_frags_z, num_warps_x, num_warps_z)) { - // Invalid configuration, skip - std::ostringstream err_msg; - err_msg << "FlashInfer Internal Error: Invalid configuration : num_frags_x=" << num_frags_x - << " num_frags_y=" << num_frags_y << " num_frags_z=" << num_frags_z - << " num_warps_x=" << num_warps_x << " num_warps_z=" << num_warps_z - << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)" - " and report the issue to the developers."; - throw std::invalid_argument(err_msg.str()); - } else { - constexpr uint32_t num_threads = (num_warps_x * num_warps_z) * warp_size; - constexpr uint32_t num_rows_per_cta = num_frags_x * num_warps_x * 16; - auto kernel = - SinglePrefillWithKVCacheKernel; - // TODO(Zihao): fix the following computation - uint32_t smem_size = (num_frags_x * num_warps_x * sizeof(DTypeQ) + - num_frags_z * num_warps_z * 2 * sizeof(DTypeQ)) * - 16 * HEAD_DIM; - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - int num_blocks_per_sm = 0; - int num_sm = 0; - FLASHINFER_CUDA_CALL( - cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); - FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &num_blocks_per_sm, kernel, num_threads, smem_size)); - uint32_t max_num_kv_chunks = - (num_blocks_per_sm * num_sm) / - (num_kv_heads * ceil_div(qo_len * group_size, num_rows_per_cta)); - uint32_t num_chunks; - if (max_num_kv_chunks > 0) { - uint32_t chunk_size = max(ceil_div(kv_len, max_num_kv_chunks), 256); - num_chunks = ceil_div(kv_len, chunk_size); - } else { - num_chunks = 0; - } - - if (num_chunks <= 1 || tmp == nullptr) { - // Enough parallelism, do not split-kv - bool partition_kv = false; - void* args[] = {(void*)&q, - (void*)&k, - (void*)&v, - (void*)&custom_mask, - (void*)&o, - (void*)&lse, - (void*)&qo_len, - (void*)&kv_len, - (void*)&partition_kv, - (void*)&group_size_fastdiv, - (void*)&q_stride_n, - (void*)&q_stride_h, - (void*)&kv_stride_n, - (void*)&kv_stride_h, - (void*)&window_left, - (void*)&logits_soft_cap, - (void*)&sm_scale, - (void*)&log2_rope_rcp_scale, - (void*)&log2_rope_rcp_theta}; - dim3 nblks(ceil_div(qo_len * group_size, num_rows_per_cta), 1, num_kv_heads); - dim3 nthrs(32, num_warps_x, num_warps_z); - - FLASHINFER_CUDA_CALL( - cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 
smem_size, stream)); - } else { - // Use cooperative groups to increase occupancy - bool partition_kv = true; - float* tmp_lse = (float*)(tmp + num_chunks * qo_len * num_qo_heads * HEAD_DIM); - void* args[] = {(void*)&q, - (void*)&k, - (void*)&v, - (void*)&custom_mask, - (void*)&tmp, - (void*)&tmp_lse, - (void*)&qo_len, - (void*)&kv_len, - (void*)&partition_kv, - (void*)&group_size_fastdiv, - (void*)&q_stride_n, - (void*)&q_stride_h, - (void*)&kv_stride_n, - (void*)&kv_stride_h, - (void*)&window_left, - (void*)&logits_soft_cap, - (void*)&sm_scale, - (void*)&log2_rope_rcp_scale, - (void*)&log2_rope_rcp_theta}; - dim3 nblks(ceil_div(qo_len * group_size, num_rows_per_cta), num_chunks, num_kv_heads); - dim3 nthrs(32, num_warps_x, num_warps_z); - FLASHINFER_CUDA_CALL( - cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - FLASHINFER_CUDA_CALL(MergeStates(tmp, tmp_lse, o, lse, num_chunks, qo_len, num_qo_heads, - HEAD_DIM, stream)); - } - } - }) - }); - return cudaSuccess; -} - -template -cudaError_t BatchPrefillWithRaggedKVCacheDispatched( - DTypeQ* q, IdType* request_indices, IdType* q_tile_indices, IdType* kv_tile_indices, - IdType* q_indptr, DTypeKV* k, DTypeKV* v, IdType* kv_indptr, uint8_t* custom_mask, - IdType* qk_indptr, IdType* q_offset, IdType* k_rope_pos_offset, IdType* o_indptr, DTypeOut* o, - DTypeOut* tmp_v, float* tmp_s, float* lse, IdType* merge_indptr, bool* block_valid_mask, - IdType* kv_chunk_size_ptr, uint32_t total_num_rows, uint32_t num_qo_heads, - uint32_t padded_batch_size, uint32_t num_kv_heads, uint32_t q_stride_n, uint32_t q_stride_h, - uint32_t kv_stride_n, uint32_t kv_stride_h, int32_t window_left, float logits_soft_cap, - float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream = nullptr) { - const float log2_rope_rcp_scale = -std::log2f(rope_scale); - const float log2_rope_rcp_theta = -std::log2f(rope_theta); +template +cudaError_t BatchPrefillWithRaggedKVCacheDispatched(typename AttentionVariant::ParamsT params, + typename AttentionVariant::DTypeO* tmp_v, + float* tmp_s, cudaStream_t stream) { + using DTypeQ = typename AttentionVariant::DTypeQ; + using DTypeKV = typename AttentionVariant::DTypeKV; + const uint32_t padded_batch_size = params.padded_batch_size; + const uint32_t num_qo_heads = params.num_qo_heads; + const uint32_t num_kv_heads = params.num_kv_heads; + const uint_fastdiv group_size_fastdiv(num_qo_heads / num_kv_heads); + const uint32_t total_num_rows = params.total_num_rows; constexpr uint32_t num_frags_x = get_num_frags_x(); constexpr uint32_t num_warps_x = get_num_warps_x(); constexpr uint32_t num_warps_z = get_num_warps_z(); - const uint32_t group_size = num_qo_heads / num_kv_heads; - const uint_fastdiv group_size_fastdiv(group_size); if (padded_batch_size == 0) { // No request, skip @@ -2162,7 +2048,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm; const uint32_t max_num_frags_z_reg = - (HEAD_DIM >= 128 && num_frags_x == 2 && pos_encoding_mode == PosEncodingMode::kRoPELlama && + (HEAD_DIM >= 128 && num_frags_x == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && !ALLOW_FP16_QK_REDUCTION) ? 
2 : (8 / num_frags_x); @@ -2172,7 +2058,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( (2 * num_warps_z); DISPATCH_NUM_FRAGS_Z(min(max_num_frags_z_smem, max_num_frags_z_reg), num_frags_z, { - if constexpr (is_invalid_configuration( + if constexpr (is_invalid_configuration( num_frags_x, num_frags_y, num_frags_z, num_warps_x, num_warps_z)) { // Invalid configuration, skip std::ostringstream err_msg; @@ -2187,106 +2073,52 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( uint32_t smem_size = (num_frags_x * num_warps_x * sizeof(DTypeQ) + num_frags_z * num_warps_z * 2 * sizeof(DTypeQ)) * 16 * HEAD_DIM; - auto kernel = BatchPrefillWithRaggedKVCacheKernel< - LOGITS_POST_HOOK, MASK_MODE, pos_encoding_mode, num_frags_x, num_frags_y, num_frags_z, - num_warps_x, num_warps_z, DTypeQ, DTypeKV, DTypeQKAccum, DTypeOut, IdType>; + auto kernel = + BatchPrefillWithRaggedKVCacheKernel; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); if (tmp_v == nullptr) { // do not partition kv - bool partition_kv = false; - - void* args[] = {(void*)&q, - (void*)&request_indices, - (void*)&q_tile_indices, - (void*)&kv_tile_indices, - (void*)&q_indptr, - (void*)&k, - (void*)&v, - (void*)&kv_indptr, - (void*)&custom_mask, - (void*)&qk_indptr, - (void*)&q_offset, - (void*)&k_rope_pos_offset, - (void*)&o_indptr, - (void*)&o, - (void*)&lse, - (void*)&block_valid_mask, - (void*)&kv_chunk_size_ptr, - (void*)&partition_kv, - (void*)&group_size_fastdiv, - (void*)&q_stride_n, - (void*)&q_stride_h, - (void*)&kv_stride_n, - (void*)&kv_stride_h, - (void*)&window_left, - (void*)&logits_soft_cap, - (void*)&sm_scale, - (void*)&log2_rope_rcp_scale, - (void*)&log2_rope_rcp_theta}; + params.partition_kv = false; + void* args[] = {(void*)&group_size_fastdiv, (void*)¶ms}; FLASHINFER_CUDA_CALL( cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); } else { // partition kv - bool partition_kv = true; - void* args[] = {(void*)&q, - (void*)&request_indices, - (void*)&q_tile_indices, - (void*)&kv_tile_indices, - (void*)&q_indptr, - (void*)&k, - (void*)&v, - (void*)&kv_indptr, - (void*)&custom_mask, - (void*)&qk_indptr, - (void*)&q_offset, - (void*)&k_rope_pos_offset, - (void*)&o_indptr, - (void*)&tmp_v, - (void*)&tmp_s, - (void*)&block_valid_mask, - (void*)&kv_chunk_size_ptr, - (void*)&partition_kv, - (void*)&group_size_fastdiv, - (void*)&q_stride_n, - (void*)&q_stride_h, - (void*)&kv_stride_n, - (void*)&kv_stride_h, - (void*)&window_left, - (void*)&logits_soft_cap, - (void*)&sm_scale, - (void*)&log2_rope_rcp_scale, - (void*)&log2_rope_rcp_theta}; + params.partition_kv = true; + auto o = params.o; + auto lse = params.lse; + params.o = tmp_v; + params.lse = tmp_s; + void* args[] = {(void*)&group_size_fastdiv, (void*)¶ms}; FLASHINFER_CUDA_CALL( cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - FLASHINFER_CUDA_CALL(VariableLengthMergeStates( - tmp_v, tmp_s, merge_indptr, o, lse, total_num_rows, num_qo_heads, HEAD_DIM, stream)); + FLASHINFER_CUDA_CALL(VariableLengthMergeStates(tmp_v, tmp_s, params.merge_indptr, o, lse, + total_num_rows, num_qo_heads, HEAD_DIM, + stream)); } } }); return cudaSuccess; } -template -cudaError_t BatchPrefillWithPagedKVCacheDispatched( - DTypeQ* q, IdType* request_indices, IdType* q_tile_indices, IdType* kv_tile_indices, - IdType* q_indptr, IdType* q_offset, paged_kv_t paged_kv, - uint8_t* custom_mask, IdType* qk_indptr, IdType* o_indptr, DTypeOut* o, DTypeOut* tmp_v, - float* tmp_s, float* lse, IdType* 
merge_indptr, bool* block_valid_mask, - IdType* kv_chunk_size_ptr, uint32_t total_num_rows, uint32_t num_qo_heads, - uint32_t padded_batch_size, int32_t window_left, float logits_soft_cap, float sm_scale, - float rope_scale, float rope_theta, cudaStream_t stream) { - const float log2_rope_rcp_scale = -std::log2f(rope_scale); - const float log2_rope_rcp_theta = -std::log2f(rope_theta); +template +cudaError_t BatchPrefillWithPagedKVCacheDispatched(typename AttentionVariant::ParamsT params, + typename AttentionVariant::DTypeO* tmp_v, + float* tmp_s, cudaStream_t stream) { + using DTypeQ = typename AttentionVariant::DTypeQ; + using DTypeKV = typename AttentionVariant::DTypeKV; + const uint32_t padded_batch_size = params.padded_batch_size; + const uint32_t num_qo_heads = params.num_qo_heads; + const uint32_t num_kv_heads = params.paged_kv.num_heads; + const uint_fastdiv group_size_fastdiv(num_qo_heads / num_kv_heads); + const uint32_t total_num_rows = params.total_num_rows; constexpr uint32_t num_frags_x = get_num_frags_x(); constexpr uint32_t num_warps_x = get_num_warps_x(); constexpr uint32_t num_warps_z = get_num_warps_z(); - const uint32_t num_kv_heads = paged_kv.num_heads; - const uint32_t group_size = num_qo_heads / num_kv_heads; - const uint_fastdiv group_size_fastdiv(group_size); if (padded_batch_size == 0) { // No request, skip @@ -2313,7 +2145,7 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched( const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm; const uint32_t max_num_frags_z_reg = - (HEAD_DIM >= 128 && num_frags_x == 2 && pos_encoding_mode == PosEncodingMode::kRoPELlama && + (HEAD_DIM >= 128 && num_frags_x == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && !ALLOW_FP16_QK_REDUCTION) ? 2 : (8 / num_frags_x); @@ -2323,7 +2155,7 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched( (2 * num_warps_z); DISPATCH_NUM_FRAGS_Z(min(max_num_frags_z_smem, max_num_frags_z_reg), num_frags_z, { - if constexpr (is_invalid_configuration( + if constexpr (is_invalid_configuration( num_frags_x, num_frags_y, num_frags_z, num_warps_x, num_warps_z)) { // Invalid configuration, skip std::ostringstream err_msg; @@ -2338,64 +2170,29 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched( uint32_t smem_size = (num_frags_x * num_warps_x * sizeof(DTypeQ) + num_frags_z * num_warps_z * 2 * sizeof(DTypeQ)) * 16 * HEAD_DIM; - auto kernel = BatchPrefillWithPagedKVCacheKernel< - LOGITS_POST_HOOK, MASK_MODE, pos_encoding_mode, num_frags_x, num_frags_y, num_frags_z, - num_warps_x, num_warps_z, page_storage, DTypeQ, DTypeKV, DTypeQKAccum, DTypeOut, IdType>; + auto kernel = BatchPrefillWithPagedKVCacheKernel; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); if (tmp_v == nullptr) { // do not partition kv - bool partition_kv = false; - void* args[] = {(void*)&request_indices, - (void*)&q_tile_indices, - (void*)&kv_tile_indices, - (void*)&q, - (void*)&paged_kv, - (void*)&q_indptr, - (void*)&custom_mask, - (void*)&qk_indptr, - (void*)&q_offset, - (void*)&o_indptr, - (void*)&o, - (void*)&lse, - (void*)&block_valid_mask, - (void*)&kv_chunk_size_ptr, - (void*)&partition_kv, - (void*)&group_size_fastdiv, - (void*)&window_left, - (void*)&logits_soft_cap, - (void*)&sm_scale, - (void*)&log2_rope_rcp_scale, - (void*)&log2_rope_rcp_theta}; + params.partition_kv = false; + void* args[] = {(void*)&group_size_fastdiv, (void*)¶ms}; FLASHINFER_CUDA_CALL( cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); } else { - bool 
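// Partition-kv path below: the dispatcher temporarily redirects params.o /
// params.lse to the tmp_v / tmp_s workspace, launches the same kernel over the
// KV chunks, and then reduces the per-chunk partial states with
// VariableLengthMergeStates via params.merge_indptr, mirroring the ragged-KV
// dispatcher above.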
partition_kv = true; - void* args[] = {(void*)&request_indices, - (void*)&q_tile_indices, - (void*)&kv_tile_indices, - (void*)&q, - (void*)&paged_kv, - (void*)&q_indptr, - (void*)&custom_mask, - (void*)&qk_indptr, - (void*)&q_offset, - (void*)&o_indptr, - (void*)&tmp_v, - (void*)&tmp_s, - (void*)&block_valid_mask, - (void*)&kv_chunk_size_ptr, - (void*)&partition_kv, - (void*)&group_size_fastdiv, - (void*)&window_left, - (void*)&logits_soft_cap, - (void*)&sm_scale, - (void*)&log2_rope_rcp_scale, - (void*)&log2_rope_rcp_theta}; + params.partition_kv = true; + auto o = params.o; + auto lse = params.lse; + params.o = tmp_v; + params.lse = tmp_s; + void* args[] = {(void*)&group_size_fastdiv, (void*)¶ms}; FLASHINFER_CUDA_CALL( cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - FLASHINFER_CUDA_CALL(VariableLengthMergeStates( - tmp_v, tmp_s, merge_indptr, o, lse, total_num_rows, num_qo_heads, HEAD_DIM, stream)); + FLASHINFER_CUDA_CALL(VariableLengthMergeStates(tmp_v, tmp_s, params.merge_indptr, o, lse, + total_num_rows, num_qo_heads, HEAD_DIM, + stream)); } } }); diff --git a/include/flashinfer/attention/prefill_params.cuh b/include/flashinfer/attention/prefill_params.cuh new file mode 100644 index 000000000..75a8f17bf --- /dev/null +++ b/include/flashinfer/attention/prefill_params.cuh @@ -0,0 +1,264 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef FLASHINFER_PREFILL_PARAMS_CUH_ +#define FLASHINFER_PREFILL_PARAMS_CUH_ + +#include + +#include +#include + +#include "../fastdiv.cuh" +#include "../layout.cuh" +#include "../page.cuh" + +namespace flashinfer { + +template +struct PrefillParamsBase { + using DTypeQ = DTypeQ_; + using DTypeKV = DTypeKV_; + using DTypeO = DTypeO_; + DTypeQ* q; + uint8_t* custom_mask; + DTypeO* o; + float* lse; + float sm_scale; +}; + +template +struct SinglePrefillParams : public PrefillParamsBase { + using IdType = int32_t; + DTypeKV* k; + DTypeKV* v; + float* alibi_slopes; + uint32_t qo_len; + uint32_t kv_len; + uint32_t num_qo_heads; + uint32_t num_kv_heads; + uint32_t q_stride_n; + uint32_t q_stride_h; + uint32_t kv_stride_n; + uint32_t kv_stride_h; + uint32_t head_dim; + int32_t window_left; + float logits_soft_cap; + float log2_rope_rcp_scale; + float log2_rope_rcp_theta; + + bool partition_kv; + + __host__ SinglePrefillParams(DTypeQ* q, DTypeKV* k, DTypeKV* v, uint8_t* custom_mask, DTypeO* o, + float* lse, float* alibi_slopes, uint32_t num_qo_heads, + uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, + uint32_t q_stride_n, uint32_t q_stride_h, uint32_t kv_stride_n, + uint32_t kv_stride_h, uint32_t head_dim, int32_t window_left, + float logits_soft_cap, float sm_scale, float rope_scale, + float rope_theta) + : PrefillParamsBase{q, custom_mask, o, lse, sm_scale}, + k(k), + v(v), + alibi_slopes(alibi_slopes), + num_qo_heads(num_qo_heads), + num_kv_heads(num_kv_heads), + // group_size_fastdiv(num_qo_heads / num_kv_heads), + qo_len(qo_len), + kv_len(kv_len), + q_stride_n(q_stride_n), + q_stride_h(q_stride_h), + kv_stride_n(kv_stride_n), + kv_stride_h(kv_stride_h), + head_dim(head_dim), + window_left(window_left), + logits_soft_cap(logits_soft_cap), + log2_rope_rcp_scale(-std::log2f(rope_scale)), + log2_rope_rcp_theta(-std::log2f(rope_theta)), + partition_kv(false) {} + + __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const { + return qo_len; + } + + __host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const { + return kv_len; + } + + __host__ __device__ __forceinline__ uint8_t* get_batch_local_mask_ptr(uint32_t batch_idx) const { + return this->custom_mask; + } +}; + +template +struct BatchPrefillRaggedParams : public PrefillParamsBase { + using IdType = IdType_; + + DTypeKV* k; + DTypeKV* v; + IdType* q_indptr; + IdType* kv_indptr; + IdType* qk_indptr; + IdType* q_offset; // q_offset is only used for fused-rope attention + IdType* k_rope_pos_offset; // k_rope_pos_offset is only used for fused-rope attention + float* alibi_slopes; + uint32_t num_qo_heads; + uint32_t num_kv_heads; + uint32_t q_stride_n; + uint32_t q_stride_h; + uint32_t kv_stride_n; + uint32_t kv_stride_h; + int32_t window_left; + float logits_soft_cap; + float log2_rope_rcp_scale; + float log2_rope_rcp_theta; + + IdType* request_indices; + IdType* qo_tile_indices; + IdType* kv_tile_indices; + IdType* merge_indptr; + IdType* o_indptr; + IdType* kv_chunk_size_ptr; + bool* block_valid_mask; + uint32_t total_num_rows; + uint32_t padded_batch_size; + bool partition_kv; + + __host__ BatchPrefillRaggedParams(DTypeQ* q, DTypeKV* k, DTypeKV* v, uint8_t* custom_mask, + IdType* q_indptr, IdType* kv_indptr, IdType* qk_indptr, + IdType* q_offset, IdType* k_rope_pos_offset, DTypeO* o, + float* lse, float* alibi_slopes, uint32_t num_qo_heads, + uint32_t num_kv_heads, uint32_t q_stride_n, uint32_t q_stride_h, + uint32_t kv_stride_n, uint32_t kv_stride_h, int32_t window_left, + float 
logits_soft_cap, float sm_scale, float rope_scale, + float rope_theta) + : PrefillParamsBase{q, custom_mask, o, lse, sm_scale}, + k(k), + v(v), + q_indptr(q_indptr), + kv_indptr(kv_indptr), + qk_indptr(qk_indptr), + q_offset(q_offset), + k_rope_pos_offset(k_rope_pos_offset), + alibi_slopes(alibi_slopes), + num_qo_heads(num_qo_heads), + num_kv_heads(num_kv_heads), + q_stride_n(q_stride_n), + q_stride_h(q_stride_h), + kv_stride_n(kv_stride_n), + kv_stride_h(kv_stride_h), + window_left(window_left), + logits_soft_cap(logits_soft_cap), + log2_rope_rcp_scale(-std::log2f(rope_scale)), + log2_rope_rcp_theta(-std::log2f(rope_theta)), + request_indices(nullptr), + qo_tile_indices(nullptr), + kv_tile_indices(nullptr), + merge_indptr(nullptr), + o_indptr(nullptr), + kv_chunk_size_ptr(nullptr), + block_valid_mask(nullptr), + total_num_rows(0), + padded_batch_size(0), + partition_kv(false) {} + + __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const { + return q_indptr[batch_idx + 1] - q_indptr[batch_idx]; + } + + __host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const { + return kv_indptr[batch_idx + 1] - kv_indptr[batch_idx]; + } + + __host__ __device__ __forceinline__ uint8_t* get_batch_local_mask_ptr(uint32_t batch_idx) const { + return this->custom_mask + qk_indptr[batch_idx]; + } +}; + +template +struct BatchPrefillPagedParams : public PrefillParamsBase { + using IdType = IdType_; + + paged_kv_t paged_kv; + IdType* q_indptr; + IdType* qk_indptr; + IdType* q_offset; // q_offset is only used for fused-rope attention + float* alibi_slopes; + uint32_t num_qo_heads; + int32_t window_left; + float logits_soft_cap; + float log2_rope_rcp_scale; + float log2_rope_rcp_theta; + + IdType* request_indices; + IdType* qo_tile_indices; + IdType* kv_tile_indices; + IdType* merge_indptr; + IdType* o_indptr; + bool* block_valid_mask; + IdType* kv_chunk_size_ptr; + uint32_t total_num_rows; + uint32_t padded_batch_size; + bool partition_kv; + + __host__ BatchPrefillPagedParams(DTypeQ* q, paged_kv_t paged_kv, + uint8_t* custom_mask, IdType* q_indptr, IdType* qk_indptr, + IdType* q_offset, DTypeO* o, float* lse, float* alibi_slopes, + uint32_t num_qo_heads, int32_t window_left, + float logits_soft_cap, float sm_scale, float rope_scale, + float rope_theta) + : PrefillParamsBase{q, custom_mask, o, lse, sm_scale}, + paged_kv(paged_kv), + q_indptr(q_indptr), + qk_indptr(qk_indptr), + q_offset(q_offset), + alibi_slopes(alibi_slopes), + num_qo_heads(num_qo_heads), + window_left(window_left), + logits_soft_cap(logits_soft_cap), + log2_rope_rcp_scale(-std::log2f(rope_scale)), + log2_rope_rcp_theta(-std::log2f(rope_theta)), + request_indices(nullptr), + qo_tile_indices(nullptr), + kv_tile_indices(nullptr), + merge_indptr(nullptr), + o_indptr(nullptr), + block_valid_mask(nullptr), + kv_chunk_size_ptr(nullptr), + total_num_rows(0), + padded_batch_size(0), + partition_kv(false) {} + + __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const { + return q_indptr[batch_idx + 1] - q_indptr[batch_idx]; + } + + __host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const { + return paged_kv.get_length(batch_idx); + } + + __host__ __device__ __forceinline__ uint8_t* get_batch_local_mask_ptr(uint32_t batch_idx) const { + return this->custom_mask + qk_indptr[batch_idx]; + } + + __host__ __device__ __forceinline__ uint32_t get_mask_offset(uint32_t batch_idx, uint32_t qo_idx, + uint32_t kv_idx, + uint32_t kv_len) const { + return 
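// Design note on these params structs: the constructors deliberately leave the
// scheduler-owned fields (request_indices, qo_tile_indices, kv_tile_indices,
// merge_indptr, o_indptr, kv_chunk_size_ptr, block_valid_mask, total_num_rows,
// padded_batch_size) as nullptr / 0; the caller is expected to fill them in
// from the PrefillPlan() output before launch. get_qo_len / get_kv_len /
// get_batch_local_mask_ptr give the attention variants a uniform per-request
// view whether the KV cache is ragged or paged.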
qk_indptr[batch_idx] * 8 + qo_idx * kv_len + kv_idx; + } +}; + +} // namespace flashinfer + +#endif // FLASHINFER_DECODE_PARAMS_CUH_ \ No newline at end of file diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh new file mode 100644 index 000000000..7b83b074c --- /dev/null +++ b/include/flashinfer/attention/scheduler.cuh @@ -0,0 +1,615 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_ATTENTION_SCHEDULER_CUH_ +#define FLASHINFER_ATTENTION_SCHEDULER_CUH_ + +#include + +#include +#include +#include +#include +#include +#include + +#include "../allocator.h" +#include "../pos_enc.cuh" +#include "../utils.cuh" +#include "warp_layout.cuh" + +namespace flashinfer { + +template +__global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__ + typename AttentionVariant::ParamsT params); + +/*! + * \brief Compute the maximum number of pages per batch and the new batch size + * after we partition Paged KV-Cache into multiple chunks on KV sequence length + * dimension. + * \tparam IdType A template type indicates the index data type + * \param max_grid_size The maximum grid size of the kernel + * \param num_kv_heads The number of KV heads + * \param num_pages The number of pages per request in the batch + * \param max_num_pages_per_batch_lb The pre-set lower bound of maximum number of + * pages per batch, default to 1 + * \return (max_num_pages_per_batch, new_batch_size) The number of pages per batch and + * the new batch size after the partition. 
+ */ +template +std::pair PartitionPagedKVCacheBinarySearchMinNumPagePerBatch( + const uint32_t max_grid_size, const uint32_t num_kv_heads, const std::vector& num_pages, + const uint32_t min_num_pages_per_batch = 1) { + uint32_t low = min_num_pages_per_batch, high = 0; + for (const IdType& elem : num_pages) { + high = max(high, elem); + } + uint32_t new_batch_size; + while (low < high) { + uint32_t mid = (low + high) / 2; + new_batch_size = 0; + for (const IdType& elem : num_pages) { + new_batch_size += ceil_div(elem, mid); + } + if (new_batch_size * num_kv_heads > max_grid_size) { + low = mid + 1; + } else { + high = mid; + } + } + new_batch_size = 0; + for (const IdType& elem : num_pages) { + new_batch_size += ceil_div(std::max(elem, 1), low); + } + return {low, new_batch_size}; +} + +inline std::tuple PrefillBinarySearchKVChunkSize( + const uint32_t max_grid_size, const uint32_t num_kv_heads, + const std::vector& packed_qo_len_arr, const std::vector& kv_len_arr, + const uint32_t qo_chunk_size, const uint32_t min_kv_chunk_size = 1) { + int64_t low = min_kv_chunk_size, high = 0; + int64_t batch_size = packed_qo_len_arr.size(); + int64_t max_kv_len = 0; + for (const int64_t& kv_len : kv_len_arr) { + max_kv_len = std::max(max_kv_len, kv_len); + } + high = max_kv_len; + int64_t new_batch_size; + while (low < high) { + int64_t mid = (low + high) / 2; + new_batch_size = 0; + for (uint32_t i = 0; i < batch_size; ++i) { + new_batch_size += + ceil_div(packed_qo_len_arr[i], qo_chunk_size) * ceil_div(kv_len_arr[i], mid); + } + if (new_batch_size * num_kv_heads > max_grid_size) { + low = mid + 1; + } else { + high = mid; + } + } + new_batch_size = 0; + for (uint32_t i = 0; i < batch_size; ++i) { + new_batch_size += ceil_div(packed_qo_len_arr[i], qo_chunk_size) * + ceil_div(std::max(int(kv_len_arr[i]), 1), low); + } + return {low < max_kv_len, low, new_batch_size}; +} + +/*! 
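// Worked example for PrefillBinarySearchKVChunkSize (illustrative numbers):
// with max_grid_size = 132, num_kv_heads = 4, a single request with
// packed_qo_len = 128, kv_len = 4096 and qo_chunk_size = 64, the search needs
// ceil(128/64) * ceil(4096/kv_chunk) * 4 <= 132, i.e. ceil(4096/kv_chunk) <= 16,
// so it settles on kv_chunk_size = 256 and returns
// {split_kv = true (256 < 4096), kv_chunk_size = 256, new_batch_size = 32}.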
+ * \brief Estimate the temporary buffer size and the maximum grid size for the + * partition-kv BatchDecodeWithPagedKVCache kernel + * \tparam DTypeKV A template type indicates the key-value data type + * \tparam DTypeO A template type indicates the output data type + * \tparam IdType A template type indicates the index data type + * \param split_kv Whether to split the KV cache into multiple chunks + * \param max_grid_size The maximum grid size that can be used in a partiton-kv kernel + * \param max_num_pages_per_batch The maximum number of pages per batch + * \param new_batch_size The new batch size after the partition + * \param paged_kv The paged kv cache data structure + * \param num_qo_heads A integer indicates the number of heads of query and output + * \param pos_encoding_mode The positional encoding mode + * \param stream The cuda stream to launch the kernel + * \return status Indicates whether CUDA calls are successful + */ +template +cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched( + bool& split_kv, uint32_t& max_grid_size, uint32_t& max_num_pages_per_batch, + uint32_t& new_batch_size, uint32_t batch_size, typename AttentionVariant::IdType* kv_indptr_h, + const uint32_t num_qo_heads, const uint32_t page_size, bool enable_cuda_graph, + cudaStream_t stream) { + using DTypeKV = typename AttentionVariant::DTypeKV; + using IdType = typename AttentionVariant::IdType; + constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL); + auto compute_capacity = GetCudaComputeCapability(); + DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM(compute_capacity, NUM_STAGES_SMEM, { + constexpr uint32_t bdx = HEAD_DIM / vec_size; + static_assert(bdx <= 32); + constexpr uint32_t bdy = GROUP_SIZE; + constexpr uint32_t num_threads = std::max(128U, bdx * bdy); + constexpr uint32_t bdz = num_threads / (bdx * bdy); + constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 
2U : 4U) : 1U; + const uint32_t num_kv_heads = num_qo_heads / GROUP_SIZE; + const uint32_t smem_size = + 2 * NUM_STAGES_SMEM * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) + + std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*), 2 * bdy * bdz * sizeof(float)); + + auto kernel = + BatchDecodeWithPagedKVCacheKernel; + int num_blocks_per_sm = 0; + int num_sm = 0; + int dev_id = 0; + FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); + FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); + FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel, + num_threads, smem_size)); + max_grid_size = num_blocks_per_sm * num_sm; + if (batch_size * num_kv_heads >= max_grid_size) { + split_kv = false; + max_num_pages_per_batch = 1; + for (uint32_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { + max_num_pages_per_batch = std::max( + max_num_pages_per_batch, kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx]); + } + new_batch_size = batch_size; + } else { + // compute max_num_pages_per_batch and new_batch_size + std::vector num_pages(batch_size); + for (uint32_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { + num_pages[batch_idx] = kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx]; + } + std::tie(max_num_pages_per_batch, new_batch_size) = + PartitionPagedKVCacheBinarySearchMinNumPagePerBatch( + max_grid_size, num_kv_heads, num_pages, std::max(128 / page_size, 1U)); + if (new_batch_size == batch_size && !enable_cuda_graph) { + // do not use partition-kv kernel for short sequence, when not using CUDAGraph + split_kv = false; + } else { + // when using CUDAGraph, we always use partition-kv kernel + split_kv = true; + } + } + return cudaSuccess; + }) +} + +/*! 
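// Recap of the split decision above: KV is only split when the un-split grid
// (batch_size * num_kv_heads blocks) would under-fill the GPU; in that
// under-filled case with CUDA graphs enabled, the partition-kv kernel is
// always chosen so the captured graph shape stays independent of the actual
// sequence lengths.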
+ * \brief Partition Paged KV-Cache into multiple chunks on KV sequence length + * \tparam IdType A template type indicates the index data type + * \param old_batch_size The batch size of the old Paged KV-Cache + * \param old_page_indptr_h The host-side page indptr of the old Paged KV-Cache + * \param max_num_pages_per_batch The maximum number of pages per batch + * \param new_paged_kv_d The device-side new Paged KV-Cache + * \param stream The cuda stream to launch the kernel + * \return status Indicates whether CUDA calls are successful + */ +template +std::tuple, std::vector, std::vector> DecodeSplitKVIndptr( + IdType* indptr_h, uint32_t batch_size, uint32_t kv_chunk_size) { + std::vector request_indices, kv_tile_indices, o_indptr; + o_indptr.push_back(0); + + for (uint32_t batch_idx = 0; batch_idx < batch_size; batch_idx++) { + uint32_t num_tiles_kv = ceil_div( + std::max(indptr_h[batch_idx + 1] - indptr_h[batch_idx], 1U), kv_chunk_size); + for (uint32_t kv_tile_idx = 0; kv_tile_idx < num_tiles_kv; ++kv_tile_idx) { + request_indices.push_back(batch_idx); + kv_tile_indices.push_back(kv_tile_idx); + } + o_indptr.push_back(o_indptr.back() + num_tiles_kv); + } + + return {request_indices, kv_tile_indices, o_indptr}; +} + +struct DecodePlanInfo { + int64_t padded_batch_size; + int64_t v_offset; + int64_t s_offset; + int64_t request_indices_offset; + int64_t kv_tile_indices_offset; + int64_t o_indptr_offset; + int64_t block_valid_mask_offset; + int64_t kv_chunk_size_ptr_offset; + bool enable_cuda_graph; + bool split_kv; + + DecodePlanInfo() + : padded_batch_size(0), + v_offset(0), + s_offset(0), + request_indices_offset(0), + kv_tile_indices_offset(0), + o_indptr_offset(0), + block_valid_mask_offset(0), + kv_chunk_size_ptr_offset(0), + enable_cuda_graph(false), + split_kv(false) {} + + // convert DecodePlanInfo to std::vector + std::vector ToVector() const { + return {padded_batch_size, + v_offset, + s_offset, + request_indices_offset, + kv_tile_indices_offset, + o_indptr_offset, + block_valid_mask_offset, + kv_chunk_size_ptr_offset, + enable_cuda_graph, + split_kv}; + } + + // From std::vector to DecodePlanInfo + void FromVector(const std::vector& vec) { + if (vec.size() != 10) { + std::ostringstream err_msg; + err_msg << "DecodePlanInfo::FromVector: vec.size() should be 10, but got " << vec.size(); + throw std::invalid_argument(err_msg.str()); + } + padded_batch_size = vec[0]; + v_offset = vec[1]; + s_offset = vec[2]; + request_indices_offset = vec[3]; + kv_tile_indices_offset = vec[4]; + o_indptr_offset = vec[5]; + block_valid_mask_offset = vec[6]; + kv_chunk_size_ptr_offset = vec[7]; + enable_cuda_graph = vec[8]; + split_kv = vec[9]; + } +}; + +template +cudaError_t DecodePlan(void* float_buffer, size_t float_workspace_size_in_bytes, void* int_buffer, + void* page_locked_int_buffer, size_t int_workspace_size_in_bytes, + DecodePlanInfo& plan_info, typename AttentionVariant::IdType* indptr_h, + uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, + uint32_t page_size, bool enable_cuda_graph, cudaStream_t stream) { + using DTypeO = typename AttentionVariant::DTypeO; + using IdType = typename AttentionVariant::IdType; + bool split_kv; + uint32_t max_grid_size, kv_chunk_size_in_pages, new_batch_size; + DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, { + auto work_estimation_func = + BatchDecodeWithPagedKVCacheWorkEstimationDispatched; + FLASHINFER_CUDA_CALL(work_estimation_func(split_kv, max_grid_size, kv_chunk_size_in_pages, + new_batch_size, batch_size, 
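// Minimal host-side sketch of the DecodePlanInfo ToVector()/FromVector()
// contract defined above. Assumes the element type elided in this listing is
// int64_t and that this translation unit is compiled with the repo's include/
// directory on the include path; decode_plan_info_roundtrip_example is a
// made-up name for illustration, not part of the library.
#include <cassert>

#include <flashinfer/attention/scheduler.cuh>

inline void decode_plan_info_roundtrip_example() {
  flashinfer::DecodePlanInfo info;
  info.padded_batch_size = 8;
  info.split_kv = true;
  flashinfer::DecodePlanInfo restored;
  restored.FromVector(info.ToVector());  // the vector always carries 10 entries
  assert(restored.padded_batch_size == 8 && restored.split_kv);
}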
indptr_h, num_qo_heads, + page_size, enable_cuda_graph, stream)); + size_t padded_batch_size; + plan_info.enable_cuda_graph = enable_cuda_graph; + plan_info.split_kv = split_kv; + padded_batch_size = (enable_cuda_graph) ? (split_kv ? max_grid_size / num_kv_heads : batch_size) + : new_batch_size; + plan_info.padded_batch_size = padded_batch_size; + + auto [request_indices_vec, kv_tile_indices_vec, o_indptr_vec] = + DecodeSplitKVIndptr(indptr_h, batch_size, kv_chunk_size_in_pages); + + AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes); + plan_info.request_indices_offset = int_allocator.aligned_alloc_offset( + padded_batch_size * sizeof(IdType), 16, "batch_decode_request_indices"); + plan_info.kv_tile_indices_offset = int_allocator.aligned_alloc_offset( + padded_batch_size * sizeof(IdType), 16, "batch_decode_kv_tile_indices"); + plan_info.o_indptr_offset = int_allocator.aligned_alloc_offset( + (padded_batch_size + 1) * sizeof(IdType), 16, "batch_decode_o_indptr"); + plan_info.kv_chunk_size_ptr_offset = + int_allocator.aligned_alloc_offset(sizeof(IdType), 1, "batch_decode_kv_chunk_size_ptr"); + IdType* request_indices_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.request_indices_offset); + IdType* kv_tile_indices_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.kv_tile_indices_offset); + IdType* o_indptr_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.o_indptr_offset); + IdType* kv_chunk_size_ptr_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.kv_chunk_size_ptr_offset); + std::copy(request_indices_vec.begin(), request_indices_vec.end(), request_indices_h); + std::copy(kv_tile_indices_vec.begin(), kv_tile_indices_vec.end(), kv_tile_indices_h); + std::copy(o_indptr_vec.begin(), o_indptr_vec.end(), o_indptr_h); + kv_chunk_size_ptr_h[0] = kv_chunk_size_in_pages * page_size; + + if (split_kv) { + AlignedAllocator float_allocator(float_buffer, float_workspace_size_in_bytes); + plan_info.v_offset = float_allocator.aligned_alloc_offset( + num_qo_heads * padded_batch_size * HEAD_DIM * sizeof(DTypeO), 16, "batch_decode_tmp_v"); + plan_info.s_offset = float_allocator.aligned_alloc_offset( + num_qo_heads * padded_batch_size * sizeof(float), 16, "batch_decode_tmp_s"); + + plan_info.block_valid_mask_offset = int_allocator.aligned_alloc_offset( + padded_batch_size * sizeof(bool), 16, "batch_decode_block_valid_mask"); + bool* block_valid_mask_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.block_valid_mask_offset); + for (uint32_t i = 0; i < padded_batch_size; ++i) { + block_valid_mask_h[i] = i < new_batch_size; + } + } + + size_t num_bytes_to_copy = int_allocator.num_allocated_bytes(); + + FLASHINFER_CUDA_CALL(cudaMemcpyAsync(int_buffer, page_locked_int_buffer, num_bytes_to_copy, + cudaMemcpyHostToDevice, stream)); + }); + return cudaSuccess; +} + +template +std::tuple, + std::vector, std::vector, std::vector, std::vector> +PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size, + uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, + uint32_t page_size, uint32_t max_grid_size, uint32_t max_batch_size_if_split, + bool enable_cuda_graph) { + std::vector request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr; + merge_indptr.push_back(0); + o_indptr.push_back(0); + + const uint32_t gqa_group_size = num_qo_heads / num_kv_heads; + uint32_t total_num_rows = qo_indptr_h[batch_size]; + + // step 1: compute qo_chunk_size + std::vector packed_qo_len_arr(batch_size), 
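// The warp-layout heuristic below keys off the *packed* query length, i.e.
// qo_len * (num_qo_heads / num_kv_heads). For example, with num_qo_heads = 32
// and num_kv_heads = 8 (GQA group size 4), a single query token per request
// packs to 4 <= 16, so Ampere or newer picks WarpLayout::k1x4x1, while a
// 32-token chunked-prefill request packs to 128 > 64 and (for head_dim < 256)
// selects WarpLayout::k4x1x2.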
kv_len_arr(batch_size); + int64_t sum_packed_qo_len = 0; + for (uint32_t i = 0; i < batch_size; ++i) { + packed_qo_len_arr[i] = int64_t(qo_indptr_h[i + 1] - qo_indptr_h[i]) * int64_t(gqa_group_size); + kv_len_arr[i] = int64_t(kv_indptr_h[i + 1] - kv_indptr_h[i]); + sum_packed_qo_len += packed_qo_len_arr[i]; + } + int64_t avg_packed_qo_len = sum_packed_qo_len / batch_size; + WarpLayout warp_layout; + if (avg_packed_qo_len > 64 && head_dim < 256) { + warp_layout = WarpLayout::k4x1x2; // (num_warps_x = 4, num_warps_z = 1, num_frags_x = 2) + } else { + auto compute_capacity = GetCudaComputeCapability(); + if (compute_capacity.first >= 8) { + // Ampere or newer + if (avg_packed_qo_len > 16) { + warp_layout = WarpLayout::k4x1x1; // (num_warps_x = 4, num_warps_z = 1, num_frags_x = 1) + } else { + // avg_packed_qo_len <= 16 + warp_layout = WarpLayout::k1x4x1; // (num_warps_x = 1, num_warps_z = 4, num_frags_x = 1) + } + } else { + // NOTE(Zihao): not enough shared memory on Turing for 1x4x1 layout + warp_layout = WarpLayout::k4x1x1; + } + } + const uint32_t qo_chunk_size = get_num_rows_per_cta(warp_layout); + + // step 2: determine kv_chunk_size + auto [split_kv, kv_chunk_size, new_batch_size] = PrefillBinarySearchKVChunkSize( + max_grid_size, num_kv_heads, packed_qo_len_arr, kv_len_arr, qo_chunk_size, + /*min_kv_chunk_size=*/std::max((128 / page_size), 1U)); + + // step 3: split qo_indptr and kv_indptr + uint32_t total_num_tiles_q = 0; + for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) { + int64_t packed_qo_len = packed_qo_len_arr[request_idx], + kv_len = std::max(int(kv_len_arr[request_idx]), 1); + int64_t num_tiles_q = ceil_div(packed_qo_len, qo_chunk_size), + num_tiles_kv = ceil_div(kv_len, kv_chunk_size); + total_num_tiles_q += num_tiles_q; + for (uint32_t q_tile_idx = 0; q_tile_idx < num_tiles_q; ++q_tile_idx) { + for (uint32_t kv_tile_idx = 0; kv_tile_idx < num_tiles_kv; ++kv_tile_idx) { + request_indices.push_back(request_idx); + qo_tile_indices.push_back(q_tile_idx); + kv_tile_indices.push_back(kv_tile_idx); + } + } + + int64_t qo_len = packed_qo_len / gqa_group_size; + for (uint32_t row = 0; row < qo_len; ++row) { + merge_indptr.push_back(merge_indptr.back() + num_tiles_kv); + } + o_indptr.push_back(o_indptr.back() + qo_len * num_tiles_kv); + } + + // step 4: reset split_kv if enable_cuda_graph is true + if (enable_cuda_graph) { + split_kv = total_num_tiles_q < max_batch_size_if_split; + } + + // step 5: multiply kv_chunk_size by page_size + kv_chunk_size *= page_size; + + return {split_kv, + total_num_tiles_q, + new_batch_size, + warp_layout, + kv_chunk_size, + total_num_rows, + std::move(request_indices), + std::move(qo_tile_indices), + std::move(kv_tile_indices), + std::move(merge_indptr), + std::move(o_indptr)}; +} + +struct PrefillPlanInfo { + int64_t padded_batch_size; + int64_t total_num_rows; + int64_t warp_layout_code; + int64_t request_indices_offset; + int64_t qo_tile_indices_offset; + int64_t kv_tile_indices_offset; + int64_t merge_indptr_offset; + int64_t o_indptr_offset; + int64_t kv_chunk_size_ptr_offset; + int64_t v_offset; + int64_t s_offset; + int64_t block_valid_mask_offset; + bool enable_cuda_graph; + bool split_kv; + + PrefillPlanInfo() + : padded_batch_size(0), + total_num_rows(0), + warp_layout_code(0), + request_indices_offset(0), + qo_tile_indices_offset(0), + kv_tile_indices_offset(0), + merge_indptr_offset(0), + o_indptr_offset(0), + kv_chunk_size_ptr_offset(0), + v_offset(0), + s_offset(0), + block_valid_mask_offset(0), + 
enable_cuda_graph(false), + split_kv(false) {} + + // convert PrefillPlanInfo to std::vector + std::vector ToVector() const { + return {padded_batch_size, + total_num_rows, + warp_layout_code, + request_indices_offset, + qo_tile_indices_offset, + kv_tile_indices_offset, + merge_indptr_offset, + o_indptr_offset, + kv_chunk_size_ptr_offset, + v_offset, + s_offset, + block_valid_mask_offset, + enable_cuda_graph, + split_kv}; + } + + // From std::vector to PrefillPlanInfo + void FromVector(const std::vector& vec) { + if (vec.size() != 14) { + std::ostringstream err_msg; + err_msg << "PrefillPlanInfo::FromVector: vec.size() should be 14, but got " << vec.size(); + throw std::invalid_argument(err_msg.str()); + } + padded_batch_size = vec[0]; + total_num_rows = vec[1]; + warp_layout_code = vec[2]; + request_indices_offset = vec[3]; + qo_tile_indices_offset = vec[4]; + kv_tile_indices_offset = vec[5]; + merge_indptr_offset = vec[6]; + o_indptr_offset = vec[7]; + kv_chunk_size_ptr_offset = vec[8]; + v_offset = vec[9]; + s_offset = vec[10]; + block_valid_mask_offset = vec[11]; + enable_cuda_graph = vec[12]; + split_kv = vec[13]; + } +}; + +template +cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_in_bytes, void* int_buffer, + void* page_locked_int_buffer, size_t int_workspace_size_in_bytes, + PrefillPlanInfo& plan_info, IdType* qo_indptr_h, IdType* kv_indptr_h, + uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, + uint32_t head_dim, uint32_t page_size, bool enable_cuda_graph, + uint32_t sizeof_dtype_o, cudaStream_t stream) { + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads " << num_qo_heads << " should be divisible by num_kv_heads " + << num_kv_heads; + throw std::invalid_argument(err_msg.str()); + } + + // step 0: get the number of SMs + int num_sm = 0; + int dev_id = 0; + FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); + FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); + int num_blocks_per_sm = 2; + int max_grid_size = num_blocks_per_sm * num_sm; + uint32_t max_batch_size_if_split = max_grid_size / num_kv_heads; + + // step 2: determine kv_chunk_size + auto [split_kv, total_num_tiles_q, new_batch_size, warp_layout, kv_chunk_size, total_num_rows, + request_indices_vec, qo_tile_indices_vec, kv_tile_indices_vec, merge_indptr_vec, + o_indptr_vec] = PrefillSplitQOKVIndptr(qo_indptr_h, kv_indptr_h, batch_size, num_qo_heads, + num_kv_heads, head_dim, page_size, max_grid_size, + max_batch_size_if_split, enable_cuda_graph); + const uint32_t qo_tile_size = get_num_rows_per_cta(warp_layout); + plan_info.warp_layout_code = static_cast(warp_layout); + plan_info.total_num_rows = total_num_rows; + + plan_info.enable_cuda_graph = enable_cuda_graph; + size_t padded_batch_size = + enable_cuda_graph ? 
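// The rest of PrefillPlan is bookkeeping: an AlignedAllocator carves named
// offsets out of the int workspace (plus the float workspace when split_kv is
// set), the host-side index vectors are written into the page-locked staging
// buffer at those offsets, and one cudaMemcpyAsync at the end pushes them to
// the device buffer; the offsets travel back to the caller inside
// PrefillPlanInfo.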
std::max(max_batch_size_if_split, total_num_tiles_q) : new_batch_size; + plan_info.padded_batch_size = padded_batch_size; + plan_info.split_kv = split_kv; + + AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes); + plan_info.request_indices_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * padded_batch_size, 16, "batch_prefill_request_indices"); + plan_info.qo_tile_indices_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * padded_batch_size, 16, "batch_prefill_qo_tile_indices"); + plan_info.kv_tile_indices_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * padded_batch_size, 16, "batch_prefill_kv_tile_indices"); + plan_info.o_indptr_offset = int_allocator.aligned_alloc_offset(sizeof(IdType) * (batch_size + 1), + 16, "batch_prefill_o_indptr"); + plan_info.kv_chunk_size_ptr_offset = + int_allocator.aligned_alloc_offset(sizeof(IdType), 1, "batch_prefill_kv_chunk_size_ptr"); + + IdType* request_indices_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.request_indices_offset); + IdType* qo_tile_indices_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.qo_tile_indices_offset); + IdType* kv_tile_indices_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.kv_tile_indices_offset); + IdType* o_indptr_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.o_indptr_offset); + IdType* kv_chunk_size_ptr_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.kv_chunk_size_ptr_offset); + std::copy(request_indices_vec.begin(), request_indices_vec.end(), request_indices_h); + std::copy(qo_tile_indices_vec.begin(), qo_tile_indices_vec.end(), qo_tile_indices_h); + std::copy(kv_tile_indices_vec.begin(), kv_tile_indices_vec.end(), kv_tile_indices_h); + std::copy(o_indptr_vec.begin(), o_indptr_vec.end(), o_indptr_h); + kv_chunk_size_ptr_h[0] = kv_chunk_size; + + if (split_kv) { + AlignedAllocator float_allocator(float_buffer, float_workspace_size_in_bytes); + plan_info.v_offset = float_allocator.aligned_alloc_offset( + num_qo_heads * padded_batch_size * qo_tile_size * head_dim * sizeof_dtype_o, 16, + "batch_prefill_tmp_v"); + plan_info.s_offset = float_allocator.aligned_alloc_offset( + num_qo_heads * padded_batch_size * qo_tile_size * sizeof(float), 16, "batch_prefill_tmp_s"); + plan_info.merge_indptr_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * (plan_info.total_num_rows + 1), 16, "batch_prefill_merge_indptr"); + plan_info.block_valid_mask_offset = int_allocator.aligned_alloc_offset( + sizeof(bool) * padded_batch_size, 16, "batch_prefill_block_valid_mask"); + IdType* merge_indptr_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.merge_indptr_offset); + bool* block_valid_mask_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.block_valid_mask_offset); + std::copy(merge_indptr_vec.begin(), merge_indptr_vec.end(), merge_indptr_h); + for (uint32_t i = 0; i < padded_batch_size; ++i) { + block_valid_mask_h[i] = i < new_batch_size; + } + } + + size_t num_bytes_to_copy = int_allocator.num_allocated_bytes(); + FLASHINFER_CUDA_CALL(cudaMemcpyAsync(int_buffer, page_locked_int_buffer, num_bytes_to_copy, + cudaMemcpyHostToDevice, stream)); + + return cudaSuccess; +} + +} // namespace flashinfer +#endif // FLASHINFER_ATTENTION_SCHEDULER_CUH_ diff --git a/include/flashinfer/attention/state.cuh b/include/flashinfer/attention/state.cuh index b748ee68c..107898101 100644 --- a/include/flashinfer/attention/state.cuh +++ b/include/flashinfer/attention/state.cuh @@ -36,7 +36,7 @@ struct state_t 
{ __device__ __forceinline__ void init() { o.fill(0.f); - m = -5e4; + m = -math::inf; d = 1.f; } diff --git a/include/flashinfer/attention/variants.cuh b/include/flashinfer/attention/variants.cuh new file mode 100644 index 000000000..52ba8eda2 --- /dev/null +++ b/include/flashinfer/attention/variants.cuh @@ -0,0 +1,270 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_ATTENTION_VARIANTS_CUH_ +#define FLASHINFER_ATTENTION_VARIANTS_CUH_ +#include + +#include +#include + +#include "../math.cuh" + +namespace flashinfer { + +// Query Transform function that multiplies the query matrix by sm_scale +template +struct StandardAttention { + using ParamsT = ParamsT_; + using DTypeQ = typename ParamsT::DTypeQ; + using DTypeKV = typename ParamsT::DTypeKV; + using DTypeO = typename ParamsT::DTypeO; + using IdType = typename ParamsT::IdType; + + // Create closure + __device__ __host__ StandardAttention(const ParamsT& params, uint32_t batch_idx, + uint8_t* smem_ptr) {} + + template + __device__ __forceinline__ T QueryTransform(const ParamsT& params, T q) { + return float(q) * params.sm_scale * math::log2e; + } + + template + __device__ __forceinline__ T LogitsTransform(const ParamsT& params, T logits, uint32_t batch_idx, + uint32_t qo_idx, uint32_t kv_idx, + uint32_t qo_head_idx, uint32_t kv_head_idx) { + return logits; + } + + __device__ __forceinline__ bool LogitsMask(const ParamsT& params, uint32_t batch_idx, + uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx, + uint32_t kv_head_idx) { + return true; + } +}; + +template +struct CustomMaskAttention { + using ParamsT = ParamsT_; + using DTypeQ = typename ParamsT::DTypeQ; + using DTypeKV = typename ParamsT::DTypeKV; + using DTypeO = typename ParamsT::DTypeO; + + uint8_t* custom_mask_ptr; + uint32_t qo_len, kv_len; + + // Create closure + __device__ __host__ CustomMaskAttention(const ParamsT& params, uint32_t batch_idx, + uint8_t* smem_ptr) { + custom_mask_ptr = params.get_batch_local_mask_ptr(batch_idx); + qo_len = params.get_qo_len(batch_idx); + kv_len = params.get_kv_len(batch_idx); + } + + template + __device__ __forceinline__ T QueryTransform(const ParamsT& params, T q) { + return float(q) * params.sm_scale * math::log2e; + } + + template + __device__ __forceinline__ T LogitsTransform(const ParamsT& params, T logits, uint32_t batch_idx, + uint32_t qo_idx, uint32_t kv_idx, + uint32_t qo_head_idx, uint32_t kv_head_idx) { + return logits; + } + + __device__ __forceinline__ bool LogitsMask(const ParamsT& params, uint32_t batch_idx, + uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx, + uint32_t kv_head_idx) { + const uint32_t offset = qo_idx * kv_len + kv_idx; + return ((custom_mask_ptr[offset / 8] >> (offset % 8)) & 1); + } +}; + +template +struct SlidingWindowAttention { + using ParamsT = ParamsT_; + using DTypeQ = typename ParamsT::DTypeQ; + using DTypeKV = typename ParamsT::DTypeKV; + using DTypeO = typename ParamsT::DTypeO; + using IdType = typename 
ParamsT::IdType; + + uint32_t window_left, qo_len, kv_len; + + // Create closure + __device__ __host__ __forceinline__ SlidingWindowAttention(const ParamsT& params, + uint32_t batch_idx, + uint8_t* smem_ptr) { + qo_len = params.get_qo_len(batch_idx); + kv_len = params.get_kv_len(batch_idx); + window_left = (params.window_left >= 0) ? params.window_left : kv_len; + } + + template + __device__ __forceinline__ T QueryTransform(const ParamsT& params, T q) { + return float(q) * params.sm_scale * math::log2e; + } + + template + __device__ __forceinline__ T LogitsTransform(const ParamsT& params, T logits, uint32_t batch_idx, + uint32_t qo_idx, uint32_t kv_idx, + uint32_t qo_head_idx, uint32_t kv_head_idx) { + return logits; + } + + __device__ __forceinline__ bool LogitsMask(const ParamsT& params, uint32_t batch_idx, + uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx, + uint32_t kv_head_idx) { + return (kv_idx + qo_len + window_left >= kv_len + qo_idx); + } +}; + +template +struct LogitsSoftCap { + using DTypeQ = typename ParamsT::DTypeQ; + using DTypeKV = typename ParamsT::DTypeKV; + using DTypeO = typename ParamsT::DTypeO; + + __device__ __host__ LogitsSoftCap(const ParamsT& params, uint32_t batch_idx, uint8_t* smem_ptr) {} + + template + __device__ __forceinline__ T QueryTransform(const ParamsT& params, T q) { + return float(q) * params.sm_scale * math::ptx_rcp(params.logits_soft_cap); + } + + template + __device__ __forceinline__ T LogitsTransform(const ParamsT& params, T logits, uint32_t batch_idx, + uint32_t qo_idx, uint32_t kv_idx, + uint32_t qo_head_idx, uint32_t kv_head_idx) { + return params.logits_soft_cap * math::log2e * float(math::tanh(logits)); + } + + __device__ __forceinline__ bool LogitsMask(const ParamsT& params, uint32_t batch_idx, + uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx, + uint32_t kv_head_idx) { + return true; + } +}; + +template +struct ALIBIAttention { + using DTypeQ = typename ParamsT::DTypeQ; + using DTypeKV = typename ParamsT::DTypeKV; + using DTypeO = typename ParamsT::DTypeO; + using IdType = typename ParamsT::IdType; + + __device__ __host__ ALIBIAttention(const ParamsT& params, uint32_t batch_idx, uint8_t* smem_ptr) { + } + + template + __device__ __forceinline__ T QueryTransform(const ParamsT& params, T q) { + return float(q) * params.sm_scale * math::log2e; + } + + template + __device__ __forceinline__ T LogitsTransform(const ParamsT& params, T logits, uint32_t batch_idx, + uint32_t qo_idx, uint32_t kv_idx, + uint32_t qo_head_idx, uint32_t kv_head_idx) { + return logits + params.alibi_slopes[qo_head_idx] * float(int(kv_idx) - int(qo_idx)); + } + + __device__ __forceinline__ bool LogitsMask(const ParamsT& params, uint32_t batch_idx, + uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx, + uint32_t kv_head_idx) { + return true; + } +}; + +constexpr uint32_t CUSTOM_MASK = 1U; +constexpr uint32_t SLIDING_WINDOW = 2U; +constexpr uint32_t LOGITS_SOFT_CAP = 4U; +constexpr uint32_t ALIBI = 8U; + +constexpr uint32_t get_variant_code(bool use_custom_mask, bool use_sliding_window, + bool use_logits_soft_cap, bool use_alibi) { + return (use_custom_mask ? CUSTOM_MASK : 0U) | (use_sliding_window ? SLIDING_WINDOW : 0U) | + (use_logits_soft_cap ? LOGITS_SOFT_CAP : 0U) | (use_alibi ? 
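// Note on LogitsSoftCap above: QueryTransform folds sm_scale / logits_soft_cap
// into q, so the value reaching LogitsTransform is (q . k) * sm_scale /
// logits_soft_cap, and the transform produces
//   log2e * logits_soft_cap * tanh((q . k) * sm_scale / logits_soft_cap),
// i.e. the usual soft-capped score, pre-multiplied by log2e because the
// softmax path exponentiates in base 2 (StandardAttention bakes the same
// log2e factor into QueryTransform instead).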
ALIBI : 0U); +} + +template +struct ComposedAttention { + using ParamsT = ParamsT_; + using DTypeQ = typename ParamsT::DTypeQ; + using DTypeKV = typename ParamsT::DTypeKV; + using DTypeO = typename ParamsT::DTypeO; + using IdType = typename ParamsT::IdType; + static constexpr bool use_custom_mask = (VARIANT_CODE & CUSTOM_MASK) != 0; + static constexpr bool use_sliding_window = (VARIANT_CODE & SLIDING_WINDOW) != 0; + static constexpr bool use_logits_soft_cap = (VARIANT_CODE & LOGITS_SOFT_CAP) != 0; + static constexpr bool use_alibi = (VARIANT_CODE & ALIBI) != 0; + + uint32_t qo_len, kv_len; + uint8_t* custom_mask_ptr; + uint32_t window_left; + + // Create closure + __device__ __host__ ComposedAttention(const ParamsT& params, uint32_t batch_idx, + uint8_t* smem_ptr) { + qo_len = params.get_qo_len(batch_idx); + kv_len = params.get_kv_len(batch_idx); + if constexpr (use_custom_mask) { + custom_mask_ptr = params.get_batch_local_mask_ptr(batch_idx); + } + if constexpr (use_sliding_window) { + window_left = (params.window_left >= 0) ? params.window_left : kv_len; + } + } + + template + __device__ __forceinline__ T QueryTransform(const ParamsT& params, T q) { + if constexpr (use_logits_soft_cap) { + return float(q) * params.sm_scale * math::ptx_rcp(params.logits_soft_cap); + } else { + return float(q) * params.sm_scale * math::log2e; + } + } + + template + __device__ __forceinline__ T LogitsTransform(const ParamsT& params, T logits, uint32_t batch_idx, + uint32_t qo_idx, uint32_t kv_idx, + uint32_t qo_head_idx, uint32_t kv_head_idx) { + if constexpr (use_alibi) { + logits = logits + params.alibi_slopes[qo_head_idx] * float(int(kv_idx) - int(qo_idx)); + } + if constexpr (use_logits_soft_cap) { + logits = params.logits_soft_cap * math::log2e * float(math::tanh(logits)); + } + return logits; + } + + __device__ __forceinline__ bool LogitsMask(const ParamsT& params, uint32_t batch_idx, + uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx, + uint32_t kv_head_idx) { + bool mask = true; + if constexpr (use_custom_mask) { + const uint32_t offset = qo_idx * kv_len + kv_idx; + mask &= ((custom_mask_ptr[offset / 8] >> (offset % 8)) & 1); + } + if constexpr (use_sliding_window) { + mask &= (kv_idx + qo_len + window_left >= kv_len + qo_idx); + } + return mask; + } +}; + +} // namespace flashinfer + +#endif // FLASHINFER_ATTENTION_VARIANTS_CUH_ \ No newline at end of file diff --git a/include/flashinfer/attention_impl.cuh b/include/flashinfer/attention_impl.cuh index b375bcd01..db5c545ec 100644 --- a/include/flashinfer/attention_impl.cuh +++ b/include/flashinfer/attention_impl.cuh @@ -18,6 +18,9 @@ #include "attention/cascade.cuh" #include "attention/decode.cuh" +#include "attention/decode_params.cuh" #include "attention/prefill.cuh" +#include "attention/prefill_params.cuh" +#include "attention/variants.cuh" #endif // FLASHINFER_ATTENTION_IMPL_CUH_ diff --git a/include/flashinfer/decode_attention_decl.cuh b/include/flashinfer/decode_attention_decl.cuh deleted file mode 100644 index 6f9ccf6f6..000000000 --- a/include/flashinfer/decode_attention_decl.cuh +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Copyright (c) 2024 by FlashInfer team. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
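// Compile-time sketch of how a variant code is composed for ComposedAttention
// above; assumes include/flashinfer/attention/variants.cuh from this patch is
// on the include path and the file is built as CUDA/C++17. kCode and the
// anonymous namespace are illustrative, not part of the library.
#include <flashinfer/attention/variants.cuh>

namespace {
using namespace flashinfer;
// custom mask + logits soft cap, no sliding window, no ALiBi:
constexpr uint32_t kCode = get_variant_code(/*use_custom_mask=*/true,
                                            /*use_sliding_window=*/false,
                                            /*use_logits_soft_cap=*/true,
                                            /*use_alibi=*/false);
static_assert(kCode == (CUSTOM_MASK | LOGITS_SOFT_CAP), "expected bits 0 and 2");
// A ComposedAttention instantiated with kCode (template-parameter order is
// elided in this listing) enables only the use_custom_mask and
// use_logits_soft_cap branches; the others compile away.
}  // namespace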
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef FLASHINFER_DECODE_ATTENTION_DECL_CUH_ -#define FLASHINFER_DECODE_ATTENTION_DECL_CUH_ - -#include - -#include "attention/handler.cuh" -#include "attention/logits_post_hook.cuh" -#include "layout.cuh" -#include "page.cuh" -#include "pos_enc.cuh" -#include "utils.cuh" - -namespace flashinfer { - -template -cudaError_t SingleDecodeWithKVCacheDispatched( - DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o, DTypeOut* tmp, uint32_t num_qo_heads, - uint32_t num_kv_heads, uint32_t seq_len, QKVLayout kv_layout, int32_t window_left, - float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); - -template -cudaError_t BatchDecodeWithPagedKVCacheDispatched( - DTypeQ* q, IdType* q_offset, paged_kv_t paged_kv, - kv_partition_info_t kv_partition_info, DTypeOut* o, DTypeOut* tmp_v, float* tmp_s, - float* lse, bool* block_valid_mask, uint32_t padded_batch_size, uint32_t num_qo_heads, - int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, - cudaStream_t stream); - -template -cudaError_t BatchDecodeWithPagedKVCacheWrapperDispatched( - BatchDecodeHandler* handler, DTypeQ* q, IdType* q_offset, - paged_kv_t paged_kv, DTypeOut* o, float* lse, - uint32_t num_qo_heads, int32_t window_left, float logits_soft_cap, float sm_scale, - float rope_scale, float rope_theta, cudaStream_t stream) { - paged_kv_t new_paged_kv = paged_kv; - kv_partition_info_t kv_partition_info; - DTypeOut* tmp_v = handler->GetTempV(); - float* tmp_s = handler->GetTempS(); - - if (tmp_v != nullptr) { - // create auxiliary information for cooperative kernels - new_paged_kv.batch_size = handler->GetBatchSizeAfterPartition(); - new_paged_kv.indptr = handler->GetNewIndPtr(); - new_paged_kv.last_page_len = handler->GetNewLastPageLen(); - kv_partition_info.batch_size_before_partition = handler->GetBatchSizeBeforePartition(); - kv_partition_info.chunk_indptr = handler->GetChunkIndPtr(); - kv_partition_info.batch_idx_map = handler->GetBatchIdxMap(); - kv_partition_info.chunk_start_pos = handler->GetChunkStartPos(); - kv_partition_info.seq_lens_before_partition = handler->GetSeqLengthsBeforePartition(); - } - - return BatchDecodeWithPagedKVCacheDispatched( - q, q_offset, new_paged_kv, kv_partition_info, o, tmp_v, tmp_s, lse, - handler->GetBlockValidMask(), handler->GetPaddedBatchSize(), num_qo_heads, window_left, - logits_soft_cap, sm_scale, rope_scale, rope_theta, stream); -} - -} // namespace flashinfer - -#endif // FLASHINFER_DECODE_ATTENTION_DECL_CUH_ diff --git a/include/flashinfer/bmm_fp8.cuh b/include/flashinfer/gemm/bmm_fp8.cuh similarity index 96% rename from include/flashinfer/bmm_fp8.cuh rename to include/flashinfer/gemm/bmm_fp8.cuh index 98d61baec..90be4da23 100644 --- a/include/flashinfer/bmm_fp8.cuh +++ b/include/flashinfer/gemm/bmm_fp8.cuh @@ -13,9 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#ifndef FLASHINFER_BMM_FP8_CUH_ -#define FLASHINFER_BMM_FP8_CUH_ +#ifndef FLASHINFER_GEMM_BMM_FP8_CUH_ +#define FLASHINFER_GEMM_BMM_FP8_CUH_ +// NOTE(Zihao): we should leave pytorch related includes outside of the header files. #include #include #include @@ -170,6 +171,7 @@ void bmm_fp8_internal_cublaslt(const AT* A, const BT* B, DT* D, int batch_size, TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, at::cuda::blas::_cublasGetErrorEnum(status)); } +// NOTE(Zihao): templates should not be initialized in the header files! template void bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>( const __nv_fp8_e4m3* A, const __nv_fp8_e4m3* B, __nv_bfloat16* D, int batch_size, int m, int n, int k, const float* A_scale, const float* B_scale); @@ -197,4 +199,4 @@ template void bmm_fp8_internal_cublaslt<__nv_fp8_e5m2, __nv_fp8_e4m3, half>( } // namespace bmm_fp8 } // namespace flashinfer -#endif // FLASHINFER_BMM_FP8_CUH_ +#endif // FLASHINFER_GEMM_BMM_FP8_CUH_ diff --git a/include/flashinfer/group_gemm/wrapper.cuh b/include/flashinfer/gemm/group_gemm.cuh similarity index 88% rename from include/flashinfer/group_gemm/wrapper.cuh rename to include/flashinfer/gemm/group_gemm.cuh index 03f0b3eb5..968662f97 100644 --- a/include/flashinfer/group_gemm/wrapper.cuh +++ b/include/flashinfer/gemm/group_gemm.cuh @@ -13,13 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef FLASHINFER_GROUP_GEMM_WRAPPER_CUH_ -#define FLASHINFER_GROUP_GEMM_WRAPPER_CUH_ +#ifndef FLASHINFER_GEMM_GROUP_GEMM_CUH_ +#define FLASHINFER_GEMM_GROUP_GEMM_CUH_ #include #include "../allocator.h" -#include "handler.cuh" +#include "group_gemm_cutlass.cuh" namespace flashinfer { @@ -35,12 +35,12 @@ namespace group_gemm { } template -cudaError_t CutlassSegmentGEMMWrapper(CutlassSegmentGEMMHandler* handler, DType* x, DType* w, - DType* y, int64_t* xy_indptr_d, int64_t* w_indices_d, - unsigned int batch_size, unsigned int d_in, - unsigned int d_out, bool weight_column_major, - cudaStream_t stream) { - AlignedAllocator allocator(handler->GetWorkspace(), handler->GetWorkspaceSizeInBytes()); +cudaError_t CutlassSegmentGEMMRun(void* workspace_buffer, size_t workspace_buffer_size_in_bytes, + DType* x, DType* w, DType* y, int64_t* xy_indptr_d, + int64_t* w_indices_d, unsigned int batch_size, unsigned int d_in, + unsigned int d_out, bool weight_column_major, + cudaStream_t stream) { + AlignedAllocator allocator(workspace_buffer, workspace_buffer_size_in_bytes); cutlass::gemm::GemmCoord* problem_sizes_device = allocator.aligned_alloc( batch_size * sizeof(cutlass::gemm::GemmCoord), 16, "problem_sizes_device"); @@ -116,4 +116,4 @@ cudaError_t CutlassSegmentGEMMWrapper(CutlassSegmentGEMMHandler* handler, DType* } // namespace flashinfer -#endif // FLASHINFER_GROUP_GEMM_WRAPPER_CUH_ +#endif // FLASHINFER_GEMM_GROUP_GEMM_CUH_ diff --git a/include/flashinfer/group_gemm/group_gemm_cutlass.cuh b/include/flashinfer/gemm/group_gemm_cutlass.cuh similarity index 100% rename from include/flashinfer/group_gemm/group_gemm_cutlass.cuh rename to include/flashinfer/gemm/group_gemm_cutlass.cuh diff --git a/include/flashinfer/group_gemm/group_gemm_lora.cuh b/include/flashinfer/gemm/group_gemm_lora.cuh similarity index 100% rename from include/flashinfer/group_gemm/group_gemm_lora.cuh rename to include/flashinfer/gemm/group_gemm_lora.cuh diff --git a/include/flashinfer/group_gemm/group_gemv.cuh b/include/flashinfer/gemm/group_gemv.cuh similarity index 100% rename from 
include/flashinfer/group_gemm/group_gemv.cuh rename to include/flashinfer/gemm/group_gemv.cuh diff --git a/include/flashinfer/group_gemm/handler.cuh b/include/flashinfer/group_gemm/handler.cuh deleted file mode 100644 index 39ef0f783..000000000 --- a/include/flashinfer/group_gemm/handler.cuh +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Copyright (c) 2024 by FlashInfer team. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef FLASHINFER_GROUP_GEMM_HANDLER_CUH_ -#define FLASHINFER_GROUP_GEMM_HANDLER_CUH_ - -#include - -#include "../allocator.h" -#include "../utils.cuh" -#include "group_gemm_cutlass.cuh" -#include "group_gemm_lora.cuh" -#include "group_gemv.cuh" - -namespace flashinfer { - -namespace group_gemm { - -enum class GroupGEMMKernelConfig { - kGeneral, // large d_in, d_out - kShrink, // large d_in, small d_out - kExpand, // small d_in, large d_out -}; - -class CutlassSegmentGEMMHandler { - public: - void RegisterWorkspace(void* buffer, size_t size) { - buffer_ = buffer; - workspace_size_in_bytes_ = size; - } - - void* GetWorkspace() const { return buffer_; } - - size_t GetWorkspaceSizeInBytes() const { return workspace_size_in_bytes_; } - - cudaStream_t GetCUDAStream() const { return stream_; } - - void SetCUDAStream(cudaStream_t stream) { stream_ = stream; } - - CutlassSegmentGEMMHandler() {} - - ~CutlassSegmentGEMMHandler() {} - - private: - void* buffer_; - size_t workspace_size_in_bytes_; - cudaStream_t stream_; -}; - -} // namespace group_gemm - -} // namespace flashinfer - -#endif // FLASHINFER_GROUP_GEMM_HANDLER_CUH_ diff --git a/include/flashinfer/math.cuh b/include/flashinfer/math.cuh index c2401c7e1..c752a49fa 100644 --- a/include/flashinfer/math.cuh +++ b/include/flashinfer/math.cuh @@ -25,6 +25,8 @@ namespace math { // log2(e) constexpr float log2e = 1.44269504088896340736f; +constexpr float inf = 5e4; + __forceinline__ __device__ half2 uint32_as_half2(uint32_t x) { return *(half2*)&x; } __forceinline__ __device__ uint32_t half2_as_uint32(half2 x) { return *(uint32_t*)&x; } diff --git a/include/flashinfer/page.cuh b/include/flashinfer/page.cuh index d79a5ff00..18c1af20b 100644 --- a/include/flashinfer/page.cuh +++ b/include/flashinfer/page.cuh @@ -25,50 +25,13 @@ namespace flashinfer { -enum class PageStorage { - kIndices = 0U, // Store the pointer to the buffer allocated for paged kv-cache, and indices of - // each active offset. - kPointer = 1U, // Store the pointers to each active page. -}; - -/*! 
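// -----------------------------------------------------------------------------------------
// A note on the `constexpr float inf = 5e4` added to math.cuh above (a reading, not stated
// in the diff): 5e4 is a deliberately *finite* "infinity". It still fits in fp16 (whose
// maximum is about 65504), so a masked logit of -math::inf can be kept in half-precision
// registers without becoming a true infinity, and running-max updates never hit the
// NaN-producing inf - inf case. Minimal sketch of the intended use (assumption):
//   #include <cuda_fp16.h>
//   __device__ __forceinline__ half masked_logit() { return __float2half(-5e4f); }
// -----------------------------------------------------------------------------------------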
- * \brief The auxiliary information about kv sequence partitioning - */ -template -struct kv_partition_info_t { - uint32_t batch_size_before_partition; - IdType* chunk_indptr; - IdType* batch_idx_map; - IdType* chunk_start_pos; - IdType* seq_lens_before_partition; - - __host__ __device__ __forceinline__ kv_partition_info_t(uint32_t batch_size_before_partition, - IdType* chunk_indptr, - IdType* batch_idx_map, - IdType* chunk_start_pos, - IdType* seq_lens_before_partition) - : batch_size_before_partition(batch_size_before_partition), - chunk_indptr(chunk_indptr), - batch_idx_map(batch_idx_map), - chunk_start_pos(chunk_start_pos), - seq_lens_before_partition(seq_lens_before_partition) {} - - __host__ __device__ __forceinline__ kv_partition_info_t() - : batch_size_before_partition(0), - chunk_indptr(nullptr), - batch_idx_map(nullptr), - chunk_start_pos(nullptr), - seq_lens_before_partition(nullptr) {} -}; - /*! * \brief Paged key-value cache - * \tparam page_storage Whether to store indices or pointers of each active page * \tparam layout The layout of last 3 dimensions in KV-Cache. * \tparam DType The data type of the key-value cache * \tparam IdType The index data type of the kv-cache */ -template +template struct paged_kv_t { uint_fastdiv page_size; uint32_t num_heads; @@ -78,16 +41,12 @@ struct paged_kv_t { uint32_t stride_n; uint32_t stride_h; - // The flattened key-value cache, used when page_storage == kIndices // Internal layout: // [max_num_pages, num_heads, page_size, head_dim] if layout == HND // [max_num_pages, page_size, num_heads, head_dim] if layout == NHD DType* k_data; DType* v_data; - // [nnz_pages] The page indices array, used when page_storage == kIndices IdType* indices; - // [nnz_pages] The page pointers array, used when page_storage == kPointer - DType** kv_ptrs; // [batch_size + 1] The page indptr array, with the first element 0, the last element nnz_pages IdType* indptr; @@ -110,7 +69,6 @@ struct paged_kv_t { k_data(nullptr), v_data(nullptr), indices(nullptr), - kv_ptrs(nullptr), indptr(nullptr), last_page_len(nullptr), rope_pos_offset(nullptr) {} @@ -129,7 +87,6 @@ struct paged_kv_t { * \param indptr The page indptr array * \param last_page_len The offset of the last page for each request in the batch * \param rope_pos_offset The start position of each request in the batch. - * \note This constructor should only be used when page_storage == kIndices */ __host__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size, uint32_t head_dim, uint32_t batch_size, QKVLayout layout, DType* kv_data, @@ -170,7 +127,6 @@ struct paged_kv_t { * \param indptr The page indptr array * \param last_page_len The offset of the last page for each request in the batch * \param rope_pos_offset The start position of each request in the batch. - * \note This constructor should only be used when page_storage == kIndices */ __host__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size, uint32_t head_dim, uint32_t batch_size, QKVLayout layout, DType* k_data, @@ -203,7 +159,6 @@ struct paged_kv_t { * \param indptr The page indptr array * \param last_page_len The offset of the last page for each request in the batch * \param rope_pos_offset The start position of each request in the batch. 
- * \note This constructor should only be used when page_storage == kIndices */ __host__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size, uint32_t head_dim, uint32_t batch_size, QKVLayout layout, DType* kv_data, @@ -224,40 +179,15 @@ struct paged_kv_t { stride_h = layout == QKVLayout::kHND ? page_size * head_dim : head_dim; } - /*! - * \brief Construct a paged key-value cache - * \param num_heads The number of heads - * \param page_size The size of each page - * \param head_dim The dimension of each head - * \param batch_size The batch size - * \param layout The layout of last 3 dimensions in KV-Cache. - * \param kv_ptrs The array of pointers to each active kv page - * \param indptr The page indptr array - * \param last_page_len The offset of the last page for each request in the batch - * \param rope_pos_offset The start position of each request in the batch. - * \note This constructor should only be used when page_storage == kIndices - */ - __host__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size, uint32_t head_dim, - uint32_t batch_size, QKVLayout layout, DType** kv_ptrs, - IdType* indptr, IdType* last_page_len, - IdType* rope_pos_offset = nullptr) - : num_heads(num_heads), - page_size(page_size), - head_dim(head_dim), - batch_size(batch_size), - kv_ptrs(kv_ptrs), - indptr(indptr), - last_page_len(last_page_len), - rope_pos_offset(rope_pos_offset) { - stride_page = 2 * num_heads * page_size * head_dim; - stride_n = layout == QKVLayout::kHND ? head_dim : num_heads * head_dim; - stride_h = layout == QKVLayout::kHND ? page_size * head_dim : head_dim; + __host__ __device__ __forceinline__ int64_t kv_ptr_delta() const { + return (int64_t(v_data) - int64_t(k_data)) / sizeof(DType); } - __host__ __device__ __forceinline__ int64_t kv_ptr_delta() const { - return page_storage == PageStorage::kPointer - ? num_heads * page_size * head_dim - : (int64_t(v_data) - int64_t(k_data)) / sizeof(DType); + __host__ __device__ __forceinline__ uint32_t get_length(uint32_t batch_idx) const { + if (indptr[batch_idx + 1] == indptr[batch_idx]) { + return 0; + } + return (indptr[batch_idx + 1] - indptr[batch_idx] - 1) * page_size + last_page_len[batch_idx]; } /*! 
@@ -266,7 +196,6 @@ struct paged_kv_t { * \param head_idx The head index * \param entry_idx The page entry index * \param feat_idx The feature index - * \note This function should only be used when page_storage == kIndices */ __host__ __device__ __forceinline__ size_t get_elem_offset(size_t page_idx, size_t head_idx, size_t entry_idx, @@ -288,57 +217,31 @@ struct paged_kv_t { __device__ __forceinline__ DType* get_k_ptr(IdType page_iter, uint32_t head_idx, uint32_t entry_idx, uint32_t feat_idx) const { - if constexpr (page_storage == PageStorage::kIndices) { - return k_data + get_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx); - } else { - return kv_ptrs[page_iter] + get_elem_offset_in_page(head_idx, entry_idx, feat_idx); - } + return k_data + get_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx); } __device__ __forceinline__ DType* protective_get_k_ptr(IdType page_iter, uint32_t head_idx, uint32_t entry_idx, uint32_t feat_idx, IdType last_indptr) const { - if constexpr (page_storage == PageStorage::kIndices) { - if (page_iter < last_indptr) { - return k_data + get_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx); - } else { - return k_data; - } + if (page_iter < last_indptr) { + return k_data + get_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx); } else { - if (page_iter < last_indptr) { - return kv_ptrs[page_iter] + get_elem_offset_in_page(head_idx, entry_idx, feat_idx); - } else { - return *kv_ptrs; - } + return k_data; } } __device__ __forceinline__ DType* get_v_ptr(IdType page_iter, uint32_t head_idx, uint32_t entry_idx, uint32_t feat_idx) const { - if constexpr (page_storage == PageStorage::kIndices) { - return v_data + get_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx); - } else { - return (kv_ptrs[page_iter] + kv_ptr_delta()) + - get_elem_offset_in_page(head_idx, entry_idx, feat_idx); - } + return v_data + get_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx); } __device__ __forceinline__ DType* protective_get_v_ptr(IdType page_iter, uint32_t head_idx, uint32_t entry_idx, uint32_t feat_idx, IdType last_indptr) const { - if constexpr (page_storage == PageStorage::kIndices) { - if (page_iter < last_indptr) { - return v_data + get_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx); - } else { - return v_data; - } + if (page_iter < last_indptr) { + return v_data + get_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx); } else { - if (page_iter < last_indptr) { - return (kv_ptrs[page_iter] + kv_ptr_delta()) + - get_elem_offset_in_page(head_idx, entry_idx, feat_idx); - } else { - return *kv_ptrs; - } + return v_data; } } }; @@ -347,16 +250,14 @@ struct paged_kv_t { * \brief CUDA kernel to append new keys/values to the paged key-value cache in the decode phase * \tparam head_dim The dimension of each head * \tparam vec_size The vector size used in the kernel - * \tparam page_storage Whether to store indices or pointers of each active page * \tparam DType The data type of the key-value cache * \tparam IdType The index data type of the kv-cache * \param paged_kv The paged key-value cache * \param key The key to be appended * \param value The value to be appended */ -template -__global__ void AppendPagedKVCacheDecodeKernel(paged_kv_t paged_kv, +template +__global__ void AppendPagedKVCacheDecodeKernel(paged_kv_t paged_kv, DType* __restrict__ key, DType* __restrict__ value) { uint32_t tx = threadIdx.x, ty = threadIdx.y; 
uint32_t num_heads = paged_kv.num_heads; @@ -383,7 +284,6 @@ __global__ void AppendPagedKVCacheDecodeKernel(paged_kv_t -__global__ void AppendPagedKVCachePrefillKernel(paged_kv_t paged_kv, +template +__global__ void AppendPagedKVCachePrefillKernel(paged_kv_t paged_kv, DType* __restrict__ key, DType* __restrict__ value, IdType* __restrict__ append_indptr) { uint32_t tx = threadIdx.x, ty = threadIdx.y; @@ -427,7 +326,6 @@ __global__ void AppendPagedKVCachePrefillKernel(paged_kv_t -cudaError_t AppendPagedKVCacheDecode(paged_kv_t paged_kv, DType* key, - DType* value, cudaStream_t stream = nullptr) { +template +cudaError_t AppendPagedKVCacheDecode(paged_kv_t paged_kv, DType* key, DType* value, + cudaStream_t stream = nullptr) { uint32_t head_dim = paged_kv.head_dim; uint32_t batch_size = paged_kv.batch_size; uint32_t num_heads = paged_kv.num_heads; @@ -449,7 +347,7 @@ cudaError_t AppendPagedKVCacheDecode(paged_kv_t pag // NOTE(Zihao): could be slow for small batch size, will optimize later dim3 nblks(batch_size); dim3 nthrs(bdx, bdy); - auto kernel = AppendPagedKVCacheDecodeKernel; + auto kernel = AppendPagedKVCacheDecodeKernel; void* args[] = {(void*)&paged_kv, (void*)&key, (void*)&value}; FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); }); @@ -458,7 +356,6 @@ cudaError_t AppendPagedKVCacheDecode(paged_kv_t pag /*! * \brief Append new keys/values to the paged key-value cache - * \tparam page_storage Whether to store indices or pointers of each active page * \tparam layout The layout of last 3 dimension in KV-Cache * \tparam DType The data type of the key-value cache * \tparam IdType The index data type of the kv-cache @@ -469,9 +366,9 @@ cudaError_t AppendPagedKVCacheDecode(paged_kv_t pag * \param stream The CUDA stream to execute kernels. * \return status Indicates whether CUDA calls are successful */ -template -cudaError_t AppendPagedKVCache(paged_kv_t paged_kv, DType* key, - DType* value, IdType* append_indptr, cudaStream_t stream = nullptr) { +template +cudaError_t AppendPagedKVCache(paged_kv_t paged_kv, DType* key, DType* value, + IdType* append_indptr, cudaStream_t stream = nullptr) { uint32_t head_dim = paged_kv.head_dim; uint32_t batch_size = paged_kv.batch_size; uint32_t num_heads = paged_kv.num_heads; @@ -482,7 +379,7 @@ cudaError_t AppendPagedKVCache(paged_kv_t paged_kv, // NOTE(Zihao): could be slow for small batch size, will optimize later dim3 nblks(batch_size); dim3 nthrs(bdx, bdy); - auto kernel = AppendPagedKVCachePrefillKernel; + auto kernel = AppendPagedKVCachePrefillKernel; void* args[] = {(void*)&paged_kv, (void*)&key, (void*)&value, (void*)&append_indptr}; FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); }); diff --git a/include/flashinfer/pos_enc.cuh b/include/flashinfer/pos_enc.cuh index 15b4a8d94..d6f96e4c1 100644 --- a/include/flashinfer/pos_enc.cuh +++ b/include/flashinfer/pos_enc.cuh @@ -265,7 +265,7 @@ __global__ void BatchQKApplyRotaryKernel(DType* __restrict__ q, DType* __restric } else { k_vec = vec_apply_llama_rope(k_ptr, freq, offset + i * bdy + ty); } - k_vec.cast_store(k_rope_ptr + +tx * vec_size); + k_vec.cast_store(k_rope_ptr + tx * vec_size); } } } diff --git a/include/flashinfer/prefill_attention_decl.cuh b/include/flashinfer/prefill_attention_decl.cuh deleted file mode 100644 index 46b152097..000000000 --- a/include/flashinfer/prefill_attention_decl.cuh +++ /dev/null @@ -1,155 +0,0 @@ -/* - * Copyright (c) 2024 by FlashInfer team. 
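// -----------------------------------------------------------------------------------------
// Worked example (a sketch, not part of this diff) for the get_length() helper added to
// paged_kv_t earlier in this diff: full pages contribute page_size tokens each and the last
// page contributes last_page_len tokens.
#include <cstdint>
#include <vector>

// Host-side mirror of paged_kv_t::get_length, for illustration only.
inline uint32_t paged_kv_length(const std::vector<int32_t>& indptr,
                                const std::vector<int32_t>& last_page_len,
                                uint32_t page_size, uint32_t batch_idx) {
  if (indptr[batch_idx + 1] == indptr[batch_idx]) return 0;  // empty request
  return (indptr[batch_idx + 1] - indptr[batch_idx] - 1) * page_size + last_page_len[batch_idx];
}
// Example: page_size = 16, indptr = {0, 3, 5}, last_page_len = {5, 16}
//   request 0: (3 - 1) * 16 + 5  = 37 tokens
//   request 1: (2 - 1) * 16 + 16 = 32 tokens
// -----------------------------------------------------------------------------------------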
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef FLASHINFER_PREFILL_ATTENTION_DECL_CUH_ -#define FLASHINFER_PREFILL_ATTENTION_DECL_CUH_ - -#include - -#include "attention/handler.cuh" -#include "attention/logits_post_hook.cuh" -#include "attention/mask.cuh" -#include "flashinfer/attention/warp_layout.cuh" -#include "layout.cuh" -#include "page.cuh" -#include "pos_enc.cuh" -#include "utils.cuh" - -namespace flashinfer { - -template -cudaError_t SinglePrefillWithKVCacheDispatched( - DTypeQ* q, DTypeKV* k, DTypeKV* v, uint8_t* custom_mask, DTypeOut* o, DTypeOut* tmp, float* lse, - uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, - uint32_t q_stride_n, uint32_t q_stride_h, uint32_t kv_stride_n, uint32_t kv_stride_h, - int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, - cudaStream_t stream); - -template -cudaError_t BatchPrefillWithRaggedKVCacheDispatched( - DTypeQ* q, IdType* request_indices, IdType* q_tile_indices, IdType* kv_tile_indices, - IdType* q_indptr, DTypeKV* k, DTypeKV* v, IdType* kv_indptr, uint8_t* custom_mask, - IdType* qk_indptr, IdType* q_offset, IdType* k_rope_pos_offset, IdType* o_indptr, DTypeOut* o, - DTypeOut* tmp_v, float* tmp_s, float* lse, IdType* merge_indptr, bool* block_valid_mask, - IdType* kv_chunk_size_ptr, uint32_t total_num_rows, uint32_t num_qo_heads, - uint32_t padded_batch_size, uint32_t num_kv_heads, uint32_t q_stride_n, uint32_t q_stride_h, - uint32_t kv_stride_n, uint32_t kv_stride_h, int32_t window_left, float logits_soft_cap, - float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream = nullptr); - -template -cudaError_t BatchPrefillWithPagedKVCacheDispatched( - DTypeQ* q, IdType* request_indices, IdType* q_tile_indices, IdType* kv_tile_indices, - IdType* q_indptr, IdType* q_offset, paged_kv_t paged_kv, - uint8_t* custom_mask, IdType* qk_indptr, IdType* o_indptr, DTypeOut* o, DTypeOut* tmp_v, - float* tmp_s, float* lse, IdType* merge_indptr, bool* block_valid_mask, - IdType* kv_chunk_size_ptr, uint32_t total_num_rows, uint32_t num_qo_heads, - uint32_t padded_batch_size, int32_t window_left, float logits_soft_cap, float sm_scale, - float rope_scale, float rope_theta, cudaStream_t stream); - -template -cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched( - BatchPrefillHandler* handler, DTypeQ* q, IdType* q_indptr, IdType* q_offset, - paged_kv_t paged_kv, uint8_t* custom_mask, IdType* qk_indptr, - DTypeOut* o, float* lse, uint32_t num_qo_heads, int32_t window_left, float logits_soft_cap, - float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { - DTypeOut* tmp_v = nullptr; - float* tmp_s = nullptr; - IdType *request_indices = nullptr, *qo_tile_indices = nullptr, *kv_tile_indices = nullptr, - *o_indptr = nullptr, *merge_indptr = nullptr, *kv_chunk_size_ptr = nullptr; - bool* block_valid_mask = nullptr; - WarpLayout warp_layout; - uint32_t padded_batch_size = 0U; - uint32_t total_num_rows = 0U; - tmp_v = handler->GetTempV(); 
- tmp_s = handler->GetTempS(); - request_indices = handler->GetRequestIndices(); - qo_tile_indices = handler->GetQOTileIndices(); - kv_tile_indices = handler->GetKVTileIndices(); - block_valid_mask = handler->GetBlockValidMask(); - o_indptr = handler->GetOIndptr(); - merge_indptr = handler->GetMergeIndptr(); - kv_chunk_size_ptr = handler->GetKVChunkSizePtr(); - warp_layout = handler->GetWarpLayout(); - padded_batch_size = handler->GetPaddedBatchSize(); - total_num_rows = handler->GetTotalNumRows(); - - DISPATCH_WARP_LAYOUT(warp_layout, WARP_LAYOUT, { - return BatchPrefillWithPagedKVCacheDispatched< - PAGE_STORAGE, WARP_LAYOUT, HEAD_DIM, LOGITS_POST_HOOK, POS_ENCODING_MODE, - ALLOW_FP16_QK_REDUCTION, MASK_MODE, DTypeQ, DTypeKV, DTypeOut, IdType>( - q, request_indices, qo_tile_indices, kv_tile_indices, q_indptr, q_offset, paged_kv, - custom_mask, qk_indptr, o_indptr, o, tmp_v, tmp_s, lse, merge_indptr, block_valid_mask, - kv_chunk_size_ptr, total_num_rows, num_qo_heads, padded_batch_size, window_left, - logits_soft_cap, sm_scale, rope_scale, rope_theta, stream); - }); - return cudaSuccess; -} - -template -cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched( - BatchPrefillHandler* handler, DTypeQ* q, IdType* q_indptr, DTypeKV* k, DTypeKV* v, - IdType* kv_indptr, uint8_t* custom_mask, IdType* qk_indptr, IdType* q_offset, - IdType* k_rope_pos_offset, DTypeOut* o, float* lse, uint32_t num_qo_heads, - uint32_t num_kv_heads, uint32_t q_stride_n, uint32_t q_stride_h, uint32_t kv_stride_n, - uint32_t kv_stride_h, int32_t window_left, float logits_soft_cap, float sm_scale, - float rope_scale, float rope_theta, cudaStream_t stream) { - DTypeOut* tmp_v = nullptr; - float* tmp_s = nullptr; - IdType *request_indices = nullptr, *qo_tile_indices = nullptr, *kv_tile_indices = nullptr, - *o_indptr = nullptr, *merge_indptr = nullptr, *kv_chunk_size_ptr = nullptr; - bool* block_valid_mask = nullptr; - WarpLayout warp_layout; - uint32_t padded_batch_size = 0U; - uint32_t total_num_rows = 0U; - tmp_v = handler->GetTempV(); - tmp_s = handler->GetTempS(); - request_indices = handler->GetRequestIndices(); - qo_tile_indices = handler->GetQOTileIndices(); - kv_tile_indices = handler->GetKVTileIndices(); - block_valid_mask = handler->GetBlockValidMask(); - o_indptr = handler->GetOIndptr(); - merge_indptr = handler->GetMergeIndptr(); - kv_chunk_size_ptr = handler->GetKVChunkSizePtr(); - warp_layout = handler->GetWarpLayout(); - padded_batch_size = handler->GetPaddedBatchSize(); - total_num_rows = handler->GetTotalNumRows(); - - DISPATCH_WARP_LAYOUT(warp_layout, WARP_LAYOUT, { - return BatchPrefillWithRaggedKVCacheDispatched( - q, request_indices, qo_tile_indices, kv_tile_indices, q_indptr, k, v, kv_indptr, - custom_mask, qk_indptr, q_offset, k_rope_pos_offset, o_indptr, o, tmp_v, tmp_s, lse, - merge_indptr, block_valid_mask, kv_chunk_size_ptr, total_num_rows, num_qo_heads, - padded_batch_size, num_kv_heads, q_stride_n, q_stride_h, kv_stride_n, kv_stride_h, - window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta, stream); - }); - return cudaSuccess; -} - -} // namespace flashinfer - -#endif // FLASHINFER_PREFILL_ATTENTION_DECL_CUH_ diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh index 4df2a006b..d7b9b02bd 100644 --- a/include/flashinfer/sampling.cuh +++ b/include/flashinfer/sampling.cuh @@ -19,6 +19,7 @@ #include #include #include +#include #include #include "math.cuh" @@ -939,7 +940,7 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType const 
uint32_t bx = blockIdx.x, tx = threadIdx.x; const uint32_t row_idx = bx; uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx]; - float pivot = -std::numeric_limits::infinity(); + float pivot = -cuda::std::numeric_limits::infinity(); vec_t logits_vec; if (k < d) { extern __shared__ __align__(alignof(RenormTempStorage)) @@ -948,8 +949,8 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType reinterpret_cast&>(smem_renorm); DType logits_greater_than_pivot[VEC_SIZE]; // pivot initialized to 0 - DType threadlocal_max_val = DType(-std::numeric_limits::infinity()), - threadlocal_min_val = DType(std::numeric_limits::infinity()); + DType threadlocal_max_val = DType(-cuda::std::numeric_limits::infinity()), + threadlocal_min_val = DType(cuda::std::numeric_limits::infinity()); for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { logits_vec.fill(DType(0)); if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { @@ -1047,8 +1048,9 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType } #pragma unroll for (uint32_t j = 0; j < VEC_SIZE; ++j) { - logits_vec[j] = - (logits_vec[j] > pivot) ? logits_vec[j] : DType(-std::numeric_limits::infinity()); + logits_vec[j] = (logits_vec[j] > pivot) + ? logits_vec[j] + : DType(-cuda::std::numeric_limits::infinity()); } if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { logits_vec.store(masked_logits + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); @@ -1063,7 +1065,7 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* const uint32_t bx = blockIdx.x, tx = threadIdx.x; const uint32_t row_idx = bx; uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx]; - float pivot = -std::numeric_limits::infinity(), normalizer = 1; + float pivot = -cuda::std::numeric_limits::infinity(), normalizer = 1; vec_t probs_vec; if (k < d) { extern __shared__ __align__(alignof(RenormTempStorage)) diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh index 856b53255..46a925a32 100644 --- a/include/flashinfer/utils.cuh +++ b/include/flashinfer/utils.cuh @@ -138,19 +138,6 @@ } \ } -#define DISPATCH_LOGITS_POST_HOOK(logits_soft_cap, LOGITS_POST_HOOK, ...) \ - if (logits_soft_cap > 0.f) { \ - constexpr LogitsPostHook LOGITS_POST_HOOK = LogitsPostHook::kSoftCap; \ - __VA_ARGS__ \ - } else if (logits_soft_cap == 0.f) { \ - constexpr LogitsPostHook LOGITS_POST_HOOK = LogitsPostHook::kNone; \ - __VA_ARGS__ \ - } else { \ - std::ostringstream err_msg; \ - err_msg << "Invalid logits_soft_cap (should be >= 0): " << logits_soft_cap; \ - throw std::invalid_argument(err_msg.str()); \ - } - #define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) 
\ switch (head_dim) { \ case 64: { \ diff --git a/python/MANIFEST.in b/python/MANIFEST.in index d7ad61771..b20747fef 100644 --- a/python/MANIFEST.in +++ b/python/MANIFEST.in @@ -1,19 +1,11 @@ # sdist & wheel include version.txt -include generate_batch_paged_decode_inst.py -include generate_batch_paged_prefill_inst.py -include generate_batch_ragged_prefill_inst.py -include generate_dispatch_inc.py -include generate_single_decode_inst.py -include generate_single_prefill_inst.py -include literal_map.py recursive-include include * recursive-include csrc * recursive-include 3rdparty/cutlass * # wheel-only exclude flashinfer/_build_meta.py -exclude tests/ # Unneeded files prune */__pycache__ diff --git a/python/csrc/batch_decode.cu b/python/csrc/batch_decode.cu deleted file mode 100644 index 88bbf5ab9..000000000 --- a/python/csrc/batch_decode.cu +++ /dev/null @@ -1,277 +0,0 @@ -/* - * Copyright (c) 2023 by FlashInfer team. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include - -#include "flashinfer_ops_decode.h" -#include "pytorch_extension_utils.h" - -using namespace flashinfer; - -void BatchDecodeWithPagedKVCachePyTorchWrapper::Plan( - torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, torch::Tensor indptr, - torch::Tensor last_page_len, unsigned int batch_size, unsigned int num_qo_heads, - unsigned int num_kv_heads, unsigned int head_dim, unsigned int page_size, - unsigned int pos_encoding_mode, float logits_soft_cap, torch::Tensor empty_q_data, - torch::Tensor empty_kv_data) { - CHECK_INPUT(float_workspace_buffer); - CHECK_INPUT(int_workspace_buffer); - // NOTE(zihao): not necessary to be CUDA tensor - CHECK_CONTIGUOUS(indptr); - CHECK_CONTIGUOUS(last_page_len); - CHECK_DIM(1, indptr); - CHECK_DIM(1, last_page_len); - CHECK_DIM(1, float_workspace_buffer); - CHECK_DIM(1, int_workspace_buffer); - CHECK_EQ(indptr.scalar_type(), torch::kInt32); - CHECK_EQ(indptr.scalar_type(), torch::kInt32); - CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); - size_t float_workspace_size_in_bytes = - float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); - size_t int_workspace_size_in_bytes = - int_workspace_buffer.size(0) * int_workspace_buffer.element_size(); - auto device = float_workspace_buffer.device(); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - handler_->SetCUDAStream(torch_current_stream); - indptr = indptr.to(torch::kCPU); - last_page_len = last_page_len.to(torch::kCPU); - - TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative"); - const LogitsPostHook logits_post_hook = - logits_soft_cap > 0.f ? 
LogitsPostHook::kSoftCap : LogitsPostHook::kNone; - - auto q_scalar_type = empty_q_data.scalar_type(); - auto kv_scalar_type = empty_kv_data.scalar_type(); - - if (q_scalar_type == kv_scalar_type) { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q_scalar_type, qkv_type, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = - handler_ - ->PlanDispatched( - static_cast(float_workspace_buffer.data_ptr()), - float_workspace_size_in_bytes, - static_cast(int_workspace_buffer.data_ptr()), - int_workspace_size_in_bytes, static_cast(indptr.data_ptr()), - static_cast(last_page_len.data_ptr()), batch_size, - num_qo_heads, num_kv_heads, page_size); - TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ", - cudaGetErrorString(status)); - return true; - }); - }); - }); - }); - } else { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_scalar_type, q_type, [&] { - return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(kv_scalar_type, kv_type, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = - handler_->PlanDispatched( - static_cast(float_workspace_buffer.data_ptr()), - float_workspace_size_in_bytes, - static_cast(int_workspace_buffer.data_ptr()), - int_workspace_size_in_bytes, static_cast(indptr.data_ptr()), - static_cast(last_page_len.data_ptr()), batch_size, num_qo_heads, - num_kv_heads, page_size); - TORCH_CHECK(status == cudaSuccess, - "BatchDecodeWithPagedKVCache failed with error ", - cudaGetErrorString(status)); - return true; - }); - }); - }); - }); - }); - } -} - -void BatchDecodeWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize( - unsigned int int_workspace_size_in_bytes) { - handler_->UpdatePageLockedBufferSize(int_workspace_size_in_bytes); -} - -std::vector BatchDecodeWithPagedKVCachePyTorchWrapper::Run( - torch::Tensor q, std::optional paged_kv_cache, - std::optional paged_k_cache, std::optional paged_v_cache, - torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, - torch::Tensor paged_kv_last_page_len, unsigned int pos_encoding_mode, int window_left, - float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse) { - CHECK_INPUT(q); - bool paged_kv_defined = paged_kv_cache.has_value(); - if (paged_kv_defined) { - CHECK_INPUT(paged_kv_cache.value()); - } else { - CHECK_INPUT(paged_k_cache.value()); - CHECK_INPUT(paged_v_cache.value()); - } - CHECK_INPUT(paged_kv_indptr); - CHECK_INPUT(paged_kv_indices); - CHECK_INPUT(paged_kv_last_page_len); - auto device = q.device(); - if (paged_kv_defined) { - CHECK_EQ(paged_kv_cache->device(), device); - } else { - CHECK_EQ(paged_k_cache->device(), device); - CHECK_EQ(paged_v_cache->device(), device); - } - CHECK_EQ(paged_kv_indices.device(), device); - CHECK_EQ(paged_kv_indptr.device(), device); - CHECK_EQ(paged_kv_last_page_len.device(), device); - CHECK_DIM(3, q); // (B, H_qo, D) - CHECK_DIM(1, paged_kv_last_page_len); // (B,) - CHECK_DIM(1, paged_kv_indptr); // (B+1,) - CHECK_DIM(1, paged_kv_indices); // (nnz,) - if (paged_kv_defined) { - // (num_max_pages, 2, H_kv, page_size, head_dim) for HND - // (num_max_pages, 2, page_size, H_kv, head_dim) for NHD - CHECK_DIM(5, paged_kv_cache.value()); - } 
else { - // (num_max_pages, H_kv, page_size, head_dim) for HND - // (num_max_pages, page_size, H_kv, head_dim) for NHD - CHECK_DIM(4, paged_k_cache.value()); - CHECK_DIM(4, paged_v_cache.value()); - } - int64_t batch_size = q.size(0); - int64_t num_qo_heads = q.size(1); - int64_t head_dim = q.size(2); - int64_t num_kv_heads, page_size; - if (paged_kv_defined) { - CHECK_EQ(paged_kv_cache->size(1), 2); - CHECK_EQ(paged_kv_cache->size(4), head_dim); - if (kv_layout_ == QKVLayout::kHND) { - num_kv_heads = paged_kv_cache->size(2); - page_size = paged_kv_cache->size(3); - } else { - page_size = paged_kv_cache->size(2); - num_kv_heads = paged_kv_cache->size(3); - } - } else { - CHECK_EQ(paged_k_cache->size(3), head_dim); - CHECK_EQ(paged_v_cache->size(3), head_dim); - if (kv_layout_ == QKVLayout::kHND) { - num_kv_heads = paged_k_cache->size(1); - page_size = paged_k_cache->size(2); - } else { - page_size = paged_k_cache->size(1); - num_kv_heads = paged_k_cache->size(2); - } - } - CHECK_GE(paged_kv_indptr.size(0), batch_size + 1); - CHECK_GE(paged_kv_last_page_len.size(0), batch_size); - // TODO(Zihao): support dispatching to different data types - CHECK_EQ(paged_kv_indptr.scalar_type(), torch::kInt32); - CHECK_EQ(paged_kv_indices.scalar_type(), torch::kInt32); - CHECK_EQ(paged_kv_last_page_len.scalar_type(), torch::kInt32); - CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); - - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - torch::Tensor o = torch::empty_like(q); - torch::Tensor lse; - if (return_lse) { - lse = torch::empty({batch_size, num_qo_heads}, q.options().dtype((torch::kFloat32))); - } - - TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative"); - const LogitsPostHook logits_post_hook = - logits_soft_cap > 0.f ? LogitsPostHook::kSoftCap : LogitsPostHook::kNone; - - auto q_scalar_type = q.scalar_type(); - auto kv_scalar_type = - paged_kv_defined ? paged_kv_cache->scalar_type() : paged_k_cache->scalar_type(); - - if (q_scalar_type == kv_scalar_type) { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q_scalar_type, qkv_type, [&] { - return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - paged_kv_t paged_kv( - num_kv_heads, page_size, head_dim, batch_size, kv_layout_, - static_cast(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr() - : nullptr), - static_cast(paged_k_cache.has_value() ? paged_k_cache->data_ptr() - : nullptr), - static_cast(paged_v_cache.has_value() ? paged_v_cache->data_ptr() - : nullptr), - static_cast(paged_kv_indices.data_ptr()), - static_cast(paged_kv_indptr.data_ptr()), - static_cast(paged_kv_last_page_len.data_ptr())); - cudaError_t status = BatchDecodeWithPagedKVCacheWrapperDispatched< - PageStorage::kIndices, HEAD_DIM, LOGITS_POST_HOOK, POS_ENCODING_MODE, qkv_type, - qkv_type, qkv_type, int32_t>( - handler_.get(), static_cast(q.data_ptr()), - /*q_offset=*/nullptr, paged_kv, static_cast(o.data_ptr()), - /*lse=*/(return_lse ? 
static_cast(lse.data_ptr()) : nullptr), - num_qo_heads, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta, - /*stream=*/torch_current_stream); - TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ", - cudaGetErrorString(status)); - return true; - }); - }); - }); - }); - } else { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_scalar_type, q_type, [&] { - return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(kv_scalar_type, kv_type, [&] { - return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - paged_kv_t paged_kv( - num_kv_heads, page_size, head_dim, batch_size, kv_layout_, - static_cast(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr() - : nullptr), - static_cast(paged_k_cache.has_value() ? paged_k_cache->data_ptr() - : nullptr), - static_cast(paged_v_cache.has_value() ? paged_v_cache->data_ptr() - : nullptr), - static_cast(paged_kv_indices.data_ptr()), - static_cast(paged_kv_indptr.data_ptr()), - static_cast(paged_kv_last_page_len.data_ptr())); - cudaError_t status = BatchDecodeWithPagedKVCacheWrapperDispatched< - PageStorage::kIndices, HEAD_DIM, LOGITS_POST_HOOK, POS_ENCODING_MODE, q_type, - kv_type, q_type, int32_t>( - handler_.get(), static_cast(q.data_ptr()), - /*q_offset=*/nullptr, paged_kv, static_cast(o.data_ptr()), - /*lse=*/(return_lse ? static_cast(lse.data_ptr()) : nullptr), - num_qo_heads, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta, - /*stream=*/torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchDecodeWithPagedKVCache failed with error ", - cudaGetErrorString(status)); - return true; - }); - }); - }); - }); - }); - } - - if (return_lse) { - return {o, lse}; - } else { - return {o}; - } -} diff --git a/python/csrc/batch_prefill.cu b/python/csrc/batch_prefill.cu deleted file mode 100644 index 776c7c636..000000000 --- a/python/csrc/batch_prefill.cu +++ /dev/null @@ -1,714 +0,0 @@ -/* - * Copyright (c) 2023 by FlashInfer team. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include - -#include "flashinfer_ops_prefill.h" -#include "pytorch_extension_utils.h" - -using namespace flashinfer; - -void BatchPrefillWithPagedKVCachePyTorchWrapper::Plan( - torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, - torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr, unsigned int batch_size, - unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim, - unsigned int page_size, torch::Tensor empty_q_data) { - CHECK_INPUT(float_workspace_buffer); - CHECK_INPUT(int_workspace_buffer); - // NOTE(Zihao): not necessary to be a CUDA tensor - CHECK_CONTIGUOUS(qo_indptr); - CHECK_CONTIGUOUS(paged_kv_indptr); - CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); - CHECK_DIM(1, qo_indptr); - CHECK_DIM(1, paged_kv_indptr); - CHECK_DIM(1, float_workspace_buffer); - CHECK_DIM(1, int_workspace_buffer); - CHECK_EQ(qo_indptr.size(0), batch_size + 1); - CHECK_EQ(paged_kv_indptr.size(0), batch_size + 1); - qo_indptr = qo_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU)); - paged_kv_indptr = paged_kv_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU)); - auto device = float_workspace_buffer.device(); - size_t float_workspace_size_in_bytes = - float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); - size_t int_workspace_size_in_bytes = - int_workspace_buffer.size(0) * int_workspace_buffer.element_size(); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - handler_->SetCUDAStream(torch_current_stream); - - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(empty_q_data.scalar_type(), q_type, [&] { - cudaError_t status = handler_->Plan( - static_cast(float_workspace_buffer.data_ptr()), float_workspace_size_in_bytes, - static_cast(int_workspace_buffer.data_ptr()), int_workspace_size_in_bytes, - static_cast(qo_indptr.data_ptr()), - static_cast(paged_kv_indptr.data_ptr()), batch_size, num_qo_heads, num_kv_heads, - head_dim, page_size); - TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ", - cudaGetErrorString(status)); - return true; - }); -} - -void BatchPrefillWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize( - unsigned int int_workspace_size_in_bytes) { - handler_->UpdatePageLockedBufferSize(int_workspace_size_in_bytes); -} - -std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Run( - torch::Tensor q, torch::Tensor qo_indptr, std::optional paged_kv_cache, - std::optional paged_k_cache, std::optional paged_v_cache, - torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, - torch::Tensor paged_kv_last_page_len, bool causal, unsigned int pos_encoding_mode, - bool allow_fp16_qk_reduction, int window_left, float logits_soft_cap, float sm_scale, - float rope_scale, float rope_theta, bool return_lse) { - bool paged_kv_defined = paged_kv_cache.has_value(); - CHECK_INPUT(q); - CHECK_INPUT(qo_indptr); - if (paged_kv_defined) { - CHECK_INPUT(paged_kv_cache.value()); - } else { - CHECK_INPUT(paged_k_cache.value()); - CHECK_INPUT(paged_v_cache.value()); - } - CHECK_INPUT(paged_kv_indptr); - CHECK_INPUT(paged_kv_indices); - CHECK_INPUT(paged_kv_last_page_len); - auto device = q.device(); - CHECK_EQ(device, qo_indptr.device()); - if (paged_kv_defined) { - CHECK_EQ(device, paged_kv_cache->device()); - } else { - CHECK_EQ(device, paged_k_cache->device()); - CHECK_EQ(device, paged_v_cache->device()); - } - CHECK_EQ(device, paged_kv_indptr.device()); - CHECK_EQ(device, paged_kv_indices.device()); - CHECK_EQ(device, paged_kv_last_page_len.device()); - 
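// Illustration (not part of the original file or this diff): the combined paged KV-cache
// tensor validated by the checks below is expected to have shape
//   (max_num_pages, 2, num_kv_heads, page_size, head_dim)  for the kHND layout, or
//   (max_num_pages, 2, page_size, num_kv_heads, head_dim)  for the kNHD layout,
// where dimension 1 holds the K and V planes. A hypothetical allocation for 128 pages,
// 8 KV heads, page_size 16 and head_dim 128 in HND layout (libtorch C++ API) would be:
//   torch::Tensor paged_kv_cache =
//       torch::empty({128, 2, 8, 16, 128}, torch::dtype(torch::kHalf).device(torch::kCUDA));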
CHECK_DIM(3, q); // (nnz_qo, H_qo, D) - CHECK_DIM(1, qo_indptr); // (B + 1,) - - if (paged_kv_defined) { - // [max_num_pages, 2, num_kv_heads, page_size, head_dim] for HND - // [max_num_pages, 2, page_size, num_kv_heads, head_dim] for HND - CHECK_DIM(5, paged_kv_cache.value()); - } else { - // [max_num_pages, num_kv_heads, page_size, head_dim] for HND - // [max_num_pages, page_size, num_kv_heads, head_dim] for HND - CHECK_DIM(4, paged_k_cache.value()); - CHECK_DIM(4, paged_v_cache.value()); - } - - CHECK_DIM(1, paged_kv_indptr); // (B + 1,) - CHECK_DIM(1, paged_kv_indices); // (nnz_kv,) - CHECK_DIM(1, paged_kv_last_page_len); // (B,) - int64_t batch_size = qo_indptr.size(0) - 1; - int64_t nnz_qo = q.size(0); - int64_t num_qo_heads = q.size(1); - int64_t head_dim = q.size(2); - int64_t num_kv_heads, page_size; - - if (paged_kv_defined) { - CHECK_EQ(paged_kv_cache->size(1), 2); - CHECK_EQ(paged_kv_cache->size(4), head_dim); - if (kv_layout_ == QKVLayout::kHND) { - num_kv_heads = paged_kv_cache->size(2); - page_size = paged_kv_cache->size(3); - } else { - page_size = paged_kv_cache->size(2); - num_kv_heads = paged_kv_cache->size(3); - } - } else { - CHECK_EQ(paged_k_cache->size(3), head_dim); - CHECK_EQ(paged_v_cache->size(3), head_dim); - if (kv_layout_ == QKVLayout::kHND) { - num_kv_heads = paged_k_cache->size(1); - page_size = paged_k_cache->size(2); - } else { - page_size = paged_k_cache->size(1); - num_kv_heads = paged_k_cache->size(2); - } - } - CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); - CHECK_GE(qo_indptr.size(0), batch_size + 1); - CHECK_GE(paged_kv_indptr.size(0), batch_size + 1); - CHECK_GE(paged_kv_last_page_len.size(0), batch_size); - qo_indptr = qo_indptr.to(torch::kInt32); - paged_kv_indptr = paged_kv_indptr.to(torch::kInt32); - paged_kv_indices = paged_kv_indices.to(torch::kInt32); - paged_kv_last_page_len = paged_kv_last_page_len.to(torch::kInt32); - - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - torch::Tensor o = torch::empty_like(q, q.options()); - torch::Tensor lse = torch::empty({0}); - if (return_lse) { - lse = torch::empty({nnz_qo, num_qo_heads}, q.options().dtype(torch::kFloat32)); - } - MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone; - TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative"); - const LogitsPostHook logits_post_hook = - logits_soft_cap > 0.f ? LogitsPostHook::kSoftCap : LogitsPostHook::kNone; - - auto q_scalar_type = q.scalar_type(); - auto kv_scalar_type = - paged_kv_defined ? paged_kv_cache->scalar_type() : paged_k_cache->scalar_type(); - - if (q_scalar_type == kv_scalar_type) { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_scalar_type, c_type, [&] { - return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { - paged_kv_t paged_kv( - num_kv_heads, page_size, head_dim, batch_size, kv_layout_, - static_cast(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr() : nullptr), - static_cast(paged_k_cache.has_value() ? paged_k_cache->data_ptr() : nullptr), - static_cast(paged_v_cache.has_value() ? 
paged_v_cache->data_ptr() : nullptr), - static_cast(paged_kv_indices.data_ptr()), - static_cast(paged_kv_indptr.data_ptr()), - static_cast(paged_kv_last_page_len.data_ptr())); - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { - return DISPATCH_allow_fp16_qk_reduction( - allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = BatchPrefillWithPagedKVCacheWrapperDispatched< - PageStorage::kIndices, HEAD_DIM, LOGITS_POST_HOOK, POS_ENCODING_MODE, - ALLOW_FP16_QK_REDUCTION, MASK_MODE, c_type, c_type, c_type, int32_t>( - handler_.get(), static_cast(q.data_ptr()), - static_cast(qo_indptr.data_ptr()), - /*q_offset=*/nullptr, paged_kv, - /*custom_mask=*/nullptr, - /*qk_indptr=*/nullptr, static_cast(o.data_ptr()), - /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, - num_qo_heads, window_left, logits_soft_cap, sm_scale, rope_scale, - rope_theta, - /*stream=*/torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchPrefillWithPagedKVCache failed with error code ", - cudaGetErrorString(status)); - return true; - }); - }); - }); - }); - }); - }); - } else { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_scalar_type, q_type, [&] { - return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(kv_scalar_type, kv_type, [&] { - return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { - paged_kv_t paged_kv( - num_kv_heads, page_size, head_dim, batch_size, kv_layout_, - static_cast(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr() - : nullptr), - static_cast(paged_k_cache.has_value() ? paged_k_cache->data_ptr() - : nullptr), - static_cast(paged_v_cache.has_value() ? paged_v_cache->data_ptr() - : nullptr), - static_cast(paged_kv_indices.data_ptr()), - static_cast(paged_kv_indptr.data_ptr()), - static_cast(paged_kv_last_page_len.data_ptr())); - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { - return DISPATCH_allow_fp16_qk_reduction( - allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = BatchPrefillWithPagedKVCacheWrapperDispatched< - PageStorage::kIndices, HEAD_DIM, LOGITS_POST_HOOK, POS_ENCODING_MODE, - ALLOW_FP16_QK_REDUCTION, MASK_MODE, q_type, kv_type, q_type, int32_t>( - handler_.get(), static_cast(q.data_ptr()), - static_cast(qo_indptr.data_ptr()), - /*q_offset=*/nullptr, paged_kv, - /*custom_mask=*/nullptr, - /*qk_indptr=*/nullptr, static_cast(o.data_ptr()), - /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, - num_qo_heads, window_left, logits_soft_cap, sm_scale, rope_scale, - rope_theta, - /*stream=*/torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchPrefillWithPagedKVCache failed with error code ", - cudaGetErrorString(status)); - return true; - }); - }); - }); - }); - }); - }); - }); - } - - if (return_lse) { - return {o, lse}; - } else { - return {o}; - } -} - -std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::RunCustomMask( - torch::Tensor q, torch::Tensor qo_indptr, std::optional paged_kv_cache, - std::optional paged_k_cache, std::optional paged_v_cache, - torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, - torch::Tensor paged_kv_last_page_len, torch::Tensor custom_mask, torch::Tensor qk_indptr, - unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, int window_left, - float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse) { - bool paged_kv_defined = paged_kv_cache.has_value(); - CHECK_INPUT(q); - CHECK_INPUT(qo_indptr); - if (paged_kv_defined) { - CHECK_INPUT(paged_kv_cache.value()); - } else { - CHECK_INPUT(paged_k_cache.value()); - CHECK_INPUT(paged_v_cache.value()); - } - CHECK_INPUT(paged_kv_indptr); - CHECK_INPUT(paged_kv_indices); - CHECK_INPUT(paged_kv_last_page_len); - CHECK_INPUT(custom_mask); - CHECK_INPUT(qk_indptr); - auto device = q.device(); - CHECK_EQ(device, qo_indptr.device()); - if (paged_kv_defined) { - CHECK_EQ(device, paged_kv_cache->device()); - } else { - CHECK_EQ(device, paged_k_cache->device()); - CHECK_EQ(device, paged_v_cache->device()); - } - CHECK_EQ(device, paged_kv_indptr.device()); - CHECK_EQ(device, paged_kv_indices.device()); - CHECK_EQ(device, paged_kv_last_page_len.device()); - CHECK_EQ(device, custom_mask.device()); - CHECK_EQ(device, qk_indptr.device()); - CHECK_DIM(3, q); // (nnz_qo, H_qo, D) - CHECK_DIM(1, qo_indptr); // (B + 1,) - - if (paged_kv_defined) { - // [max_num_pages, 2, num_kv_heads, page_size, head_dim] for HND - // [max_num_pages, 2, page_size, num_kv_heads, head_dim] for NHD - CHECK_DIM(5, paged_kv_cache.value()); - } else { - // [max_num_pages, num_kv_heads, page_size, head_dim] for HND - // [max_num_pages, page_size, num_kv_heads, head_dim] for NHD - CHECK_DIM(4, paged_k_cache.value()); - CHECK_DIM(4, paged_v_cache.value()); - } - CHECK_DIM(1, paged_kv_indptr); // (B + 1,) - CHECK_DIM(1, paged_kv_indices); // (nnz_kv,) - CHECK_DIM(1, paged_kv_last_page_len); // (B,) - CHECK_DIM(1, custom_mask); // (nnz_qk,) - CHECK_DIM(1, qk_indptr); // (B + 1,) - int64_t batch_size = qo_indptr.size(0) - 1; - int64_t nnz_qo = q.size(0); - int64_t num_qo_heads = q.size(1); - int64_t head_dim = q.size(2); - int64_t num_kv_heads, page_size; - - if (paged_kv_defined) { - CHECK_EQ(paged_kv_cache->size(1), 2); - CHECK_EQ(paged_kv_cache->size(4), head_dim); - if (kv_layout_ == QKVLayout::kHND) { - num_kv_heads = paged_kv_cache->size(2); - page_size = paged_kv_cache->size(3); - } else { - page_size = paged_kv_cache->size(2); - num_kv_heads = paged_kv_cache->size(3); - } - } else { - CHECK_EQ(paged_k_cache->size(3), head_dim); - CHECK_EQ(paged_v_cache->size(3), head_dim); - if (kv_layout_ == QKVLayout::kHND) { - num_kv_heads = paged_k_cache->size(1); - page_size = paged_k_cache->size(2); - } else { - page_size = paged_k_cache->size(1); - num_kv_heads = paged_k_cache->size(2); - } - } - CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); - CHECK_GE(qo_indptr.size(0), batch_size + 1); - CHECK_GE(paged_kv_indptr.size(0), batch_size + 1); - 
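// Illustration (not part of the original file or this diff): the custom mask consumed by
// MaskMode::kCustom is bit-packed, one bit per (qo_idx, kv_idx) pair, matching the lookup in
// ComposedAttention::LogitsMask earlier in this diff; a set bit keeps the corresponding
// logit. qk_indptr presumably records the per-request starting offset into the shared mask
// buffer. Standalone sketch of the bit extraction (assumes <cstdint>):
//   inline bool read_mask_bit(const uint8_t* mask, uint32_t qo_idx, uint32_t kv_idx,
//                             uint32_t kv_len) {
//     const uint32_t offset = qo_idx * kv_len + kv_idx;
//     return (mask[offset / 8] >> (offset % 8)) & 1;
//   }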
CHECK_GE(paged_kv_last_page_len.size(0), batch_size); - CHECK_GE(qk_indptr.size(0), batch_size + 1); - qo_indptr = qo_indptr.to(torch::kInt32); - paged_kv_indptr = paged_kv_indptr.to(torch::kInt32); - paged_kv_indices = paged_kv_indices.to(torch::kInt32); - paged_kv_last_page_len = paged_kv_last_page_len.to(torch::kInt32); - qk_indptr = qk_indptr.to(torch::kInt32); - - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - torch::Tensor o = torch::empty_like(q, q.options()); - torch::Tensor lse = torch::empty({0}); - if (return_lse) { - lse = torch::empty({nnz_qo, num_qo_heads}, q.options().dtype(torch::kFloat32)); - } - constexpr MaskMode MASK_MODE = MaskMode::kCustom; - TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative"); - const LogitsPostHook logits_post_hook = - logits_soft_cap > 0.f ? LogitsPostHook::kSoftCap : LogitsPostHook::kNone; - - auto q_scalar_type = q.scalar_type(); - auto kv_scalar_type = - paged_kv_defined ? paged_kv_cache->scalar_type() : paged_k_cache->scalar_type(); - - if (q_scalar_type == kv_scalar_type) { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_scalar_type, c_type, [&] { - return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { - paged_kv_t paged_kv( - num_kv_heads, page_size, head_dim, batch_size, kv_layout_, - static_cast(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr() : nullptr), - static_cast(paged_k_cache.has_value() ? paged_k_cache->data_ptr() : nullptr), - static_cast(paged_v_cache.has_value() ? paged_v_cache->data_ptr() : nullptr), - static_cast(paged_kv_indices.data_ptr()), - static_cast(paged_kv_indptr.data_ptr()), - static_cast(paged_kv_last_page_len.data_ptr())); - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_allow_fp16_qk_reduction( - allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = BatchPrefillWithPagedKVCacheWrapperDispatched< - PageStorage::kIndices, HEAD_DIM, LOGITS_POST_HOOK, POS_ENCODING_MODE, - ALLOW_FP16_QK_REDUCTION, MASK_MODE, c_type, c_type, c_type, int32_t>( - handler_.get(), static_cast(q.data_ptr()), - static_cast(qo_indptr.data_ptr()), - /*q_offset=*/nullptr, paged_kv, - static_cast(custom_mask.data_ptr()), - static_cast(qk_indptr.data_ptr()), - static_cast(o.data_ptr()), - /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, - num_qo_heads, window_left, logits_soft_cap, sm_scale, rope_scale, - rope_theta, - /*stream=*/torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchPrefillWithPagedKVCache failed with error code ", - cudaGetErrorString(status)); - return true; - }); - }); - }); - }); - }); - } else { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_scalar_type, q_type, [&] { - return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(kv_scalar_type, kv_type, [&] { - return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { - paged_kv_t paged_kv( - num_kv_heads, page_size, head_dim, batch_size, kv_layout_, - static_cast(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr() - : nullptr), - static_cast(paged_k_cache.has_value() ? paged_k_cache->data_ptr() - : nullptr), - static_cast(paged_v_cache.has_value() ? 
paged_v_cache->data_ptr() - : nullptr), - static_cast(paged_kv_indices.data_ptr()), - static_cast(paged_kv_indptr.data_ptr()), - static_cast(paged_kv_last_page_len.data_ptr())); - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_allow_fp16_qk_reduction( - allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = BatchPrefillWithPagedKVCacheWrapperDispatched< - PageStorage::kIndices, HEAD_DIM, LOGITS_POST_HOOK, POS_ENCODING_MODE, - ALLOW_FP16_QK_REDUCTION, MASK_MODE, q_type, kv_type, q_type, int32_t>( - handler_.get(), static_cast(q.data_ptr()), - static_cast(qo_indptr.data_ptr()), - /*q_offset=*/nullptr, paged_kv, - static_cast(custom_mask.data_ptr()), - static_cast(qk_indptr.data_ptr()), - static_cast(o.data_ptr()), - /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, - num_qo_heads, window_left, logits_soft_cap, sm_scale, rope_scale, - rope_theta, - /*stream=*/torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchPrefillWithPagedKVCache failed with error code ", - cudaGetErrorString(status)); - return true; - }); - }); - }); - }); - }); - }); - } - - if (return_lse) { - return {o, lse}; - } else { - return {o}; - } -} - -void BatchPrefillWithRaggedKVCachePyTorchWrapper::Plan( - torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, - torch::Tensor qo_indptr, torch::Tensor kv_indptr, unsigned int batch_size, - unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim, - torch::Tensor empty_q_data) { - CHECK_INPUT(float_workspace_buffer); - CHECK_INPUT(int_workspace_buffer); - // NOTE(Zihao): not necessary to be a CUDA tensor - CHECK_CONTIGUOUS(qo_indptr); - CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); - CHECK_DIM(1, qo_indptr); - CHECK_DIM(1, kv_indptr); - CHECK_DIM(1, float_workspace_buffer); - CHECK_DIM(1, int_workspace_buffer); - CHECK_EQ(qo_indptr.size(0), batch_size + 1); - CHECK_EQ(kv_indptr.size(0), batch_size + 1); - qo_indptr = qo_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU)); - kv_indptr = kv_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU)); - size_t float_workspace_size_in_bytes = - float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); - size_t int_workspace_size_in_bytes = - int_workspace_buffer.size(0) * int_workspace_buffer.element_size(); - auto device = float_workspace_buffer.device(); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - handler_->SetCUDAStream(torch_current_stream); - - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(empty_q_data.scalar_type(), q_type, [&] { - cudaError_t status = handler_->Plan( - static_cast(float_workspace_buffer.data_ptr()), float_workspace_size_in_bytes, - static_cast(int_workspace_buffer.data_ptr()), int_workspace_size_in_bytes, - static_cast(qo_indptr.data_ptr()), static_cast(kv_indptr.data_ptr()), - batch_size, num_qo_heads, num_kv_heads, head_dim, - /*page_size=*/1); - TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ", - cudaGetErrorString(status)); - return true; - }); -} - -void BatchPrefillWithRaggedKVCachePyTorchWrapper::UpdatePageLockedBufferSize( - unsigned int int_workspace_size_in_bytes) { - handler_->UpdatePageLockedBufferSize(int_workspace_size_in_bytes); -} - -std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Run( - torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v, - torch::Tensor 
kv_indptr, bool causal, unsigned int pos_encoding_mode, - bool allow_fp16_qk_reduction, int window_left, float logits_soft_cap, float sm_scale, - float rope_scale, float rope_theta, bool return_lse) { - CHECK_INPUT(qo_indptr); - CHECK_CUDA(q); - CHECK_CUDA(k); - CHECK_CUDA(v); - CHECK_INPUT(kv_indptr); - auto device = q.device(); - CHECK_EQ(device, qo_indptr.device()); - CHECK_EQ(device, k.device()); - CHECK_EQ(device, v.device()); - CHECK_EQ(device, kv_indptr.device()); - CHECK_DIM(3, q); // (nnz_qo, H_qo, D) - CHECK_DIM(1, qo_indptr); // (B + 1,) - CHECK_DIM(3, k); // (nnz_kv, H_kv, D) if NHD else (H_kv, nnz_kv, D) - CHECK_DIM(3, v); // (nnz_kv, H_kv, D) if NHD else (H_kv, nnz_kv, D) - CHECK_DIM(1, kv_indptr); // (B + 1,) - CHECK_EQ(q.scalar_type(), k.scalar_type()); - CHECK_EQ(q.scalar_type(), v.scalar_type()); - int64_t batch_size = qo_indptr.size(0) - 1; - int64_t nnz_qo = q.size(0); - int64_t num_qo_heads = q.size(1); - int64_t head_dim = q.size(2); - CHECK_GE(kv_indptr.size(0), batch_size + 1); - int64_t num_kv_heads = (kv_layout_ == QKVLayout::kNHD) ? k.size(1) : k.size(0); - CHECK_EQ(q.stride(2), 1); - CHECK_EQ(k.stride(2), 1); - CHECK_EQ(v.stride(2), 1); - CHECK_EQ(k.size(0), v.size(0)); - CHECK_EQ(k.size(1), v.size(1)); - CHECK_EQ(k.size(2), v.size(2)); - CHECK_EQ(k.size(2), head_dim); - CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); - uint32_t q_stride_n = q.stride(0), q_stride_h = q.stride(1), kv_stride_n, kv_stride_h; - if (kv_layout_ == QKVLayout::kNHD) { - kv_stride_n = k.stride(0); - kv_stride_h = k.stride(1); - } else { - kv_stride_h = k.stride(0); - kv_stride_n = k.stride(1); - } - qo_indptr = qo_indptr.to(torch::kInt32); - kv_indptr = kv_indptr.to(torch::kInt32); - - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - torch::Tensor o = torch::empty_like(q, q.options()); - torch::Tensor lse = torch::empty({0}); - if (return_lse) { - lse = torch::empty({nnz_qo, num_qo_heads}, q.options().dtype(torch::kFloat32)); - } - - MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone; - TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative"); - const LogitsPostHook logits_post_hook = - logits_soft_cap > 0.f ? LogitsPostHook::kSoftCap : LogitsPostHook::kNone; - - auto q_scalar_type = q.scalar_type(); - auto kv_scalar_type = k.scalar_type(); - - TORCH_CHECK(q_scalar_type == kv_scalar_type, - "q and k must have the same scalar type, but got q: ", q_scalar_type, - " and k: ", kv_scalar_type); - - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_scalar_type, c_type, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { - return DISPATCH_allow_fp16_qk_reduction( - allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { - cudaError_t status = BatchPrefillWithRaggedKVCacheWrapperDispatched< - HEAD_DIM, LOGITS_POST_HOOK, POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION, - MASK_MODE, c_type, c_type, c_type, int32_t>( - handler_.get(), static_cast(q.data_ptr()), - static_cast(qo_indptr.data_ptr()), - static_cast(k.data_ptr()), static_cast(v.data_ptr()), - static_cast(kv_indptr.data_ptr()), - /*custom_mask=*/nullptr, /*qk_indptr=*/nullptr, - /*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr, - static_cast(o.data_ptr()), - /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, - num_qo_heads, num_kv_heads, q_stride_n, q_stride_h, kv_stride_n, - kv_stride_h, window_left, logits_soft_cap, sm_scale, rope_scale, - rope_theta, - /*stream=*/torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchPrefillWithRaggedKVCache failed with error ", - cudaGetErrorString(status)); - return true; - }); - }); - }); - }); - }); - }); - - if (return_lse) { - return {o, lse}; - } else { - return {o}; - } -} - -std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::RunCustomMask( - torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v, - torch::Tensor kv_indptr, torch::Tensor custom_mask, torch::Tensor qk_indptr, - unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, int window_left, - float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse) { - CHECK_INPUT(qo_indptr); - CHECK_CUDA(q); - CHECK_CUDA(k); - CHECK_CUDA(v); - CHECK_INPUT(kv_indptr); - CHECK_INPUT(custom_mask); - CHECK_INPUT(qk_indptr); - auto device = q.device(); - CHECK_EQ(device, qo_indptr.device()); - CHECK_EQ(device, k.device()); - CHECK_EQ(device, v.device()); - CHECK_EQ(device, kv_indptr.device()); - CHECK_EQ(device, custom_mask.device()); - CHECK_EQ(device, qk_indptr.device()); - CHECK_DIM(3, q); // (nnz_qo, H_qo, D) - CHECK_DIM(1, qo_indptr); // (B + 1,) - CHECK_DIM(3, k); // (nnz_kv, H_kv, D) if NHD else (H_kv, nnz_kv, D) - CHECK_DIM(3, v); // (nnz_kv, H_kv, D) if NHD else (H_kv, nnz_kv, D) - CHECK_DIM(1, kv_indptr); // (B + 1,) - CHECK_DIM(1, custom_mask); // (nnz_qk,) - CHECK_DIM(1, qk_indptr); // (B + 1,) - CHECK_EQ(q.scalar_type(), k.scalar_type()); - CHECK_EQ(q.scalar_type(), v.scalar_type()); - int64_t batch_size = qo_indptr.size(0) - 1; - int64_t nnz_qo = q.size(0); - int64_t num_qo_heads = q.size(1); - int64_t head_dim = q.size(2); - CHECK_GE(kv_indptr.size(0), batch_size + 1); - CHECK_GE(qk_indptr.size(0), batch_size + 1); - int64_t num_kv_heads = (kv_layout_ == QKVLayout::kNHD) ? k.size(1) : k.size(0); - CHECK_EQ(q.stride(2), 1); - CHECK_EQ(k.stride(2), 1); - CHECK_EQ(v.stride(2), 1); - CHECK_EQ(k.size(0), v.size(0)); - CHECK_EQ(k.size(1), v.size(1)); - CHECK_EQ(k.size(2), v.size(2)); - CHECK_EQ(k.size(2), head_dim); - CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); - uint32_t q_stride_n = q.stride(0), q_stride_h = q.stride(1), kv_stride_n, kv_stride_h; - if (kv_layout_ == QKVLayout::kNHD) { - kv_stride_n = k.stride(0); - kv_stride_h = k.stride(1); - } else { - kv_stride_h = k.stride(0); - kv_stride_n = k.stride(1); - } - qo_indptr = qo_indptr.to(torch::kInt32); - kv_indptr = kv_indptr.to(torch::kInt32); - qk_indptr = qk_indptr.to(torch::kInt32); - - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - torch::Tensor o = torch::empty_like(q, q.options()); - torch::Tensor lse = torch::empty({0}); - if (return_lse) { - lse = torch::empty({nnz_qo, num_qo_heads}, q.options().dtype((torch::kFloat32))); - } - - constexpr MaskMode MASK_MODE = MaskMode::kCustom; - TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative"); - const LogitsPostHook logits_post_hook = - logits_soft_cap > 0.f ? 
LogitsPostHook::kSoftCap : LogitsPostHook::kNone; - - auto q_scalar_type = q.scalar_type(); - auto kv_scalar_type = k.scalar_type(); - TORCH_CHECK(q_scalar_type == kv_scalar_type, - "q and k must have the same scalar type, but got q: ", q_scalar_type, - " and k: ", kv_scalar_type); - - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_allow_fp16_qk_reduction( - allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { - cudaError_t status = BatchPrefillWithRaggedKVCacheWrapperDispatched< - HEAD_DIM, LOGITS_POST_HOOK, POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION, - MASK_MODE, c_type, c_type, c_type, int32_t>( - handler_.get(), static_cast(q.data_ptr()), - static_cast(qo_indptr.data_ptr()), - static_cast(k.data_ptr()), static_cast(v.data_ptr()), - static_cast(kv_indptr.data_ptr()), - static_cast(custom_mask.data_ptr()), - static_cast(qk_indptr.data_ptr()), - /*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr, - static_cast(o.data_ptr()), - /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, - num_qo_heads, num_kv_heads, q_stride_n, q_stride_h, kv_stride_n, - kv_stride_h, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta, - /*stream=*/torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchPrefillWithRaggedKVCache failed with error ", - cudaGetErrorString(status)); - return true; - }); - }); - }); - }); - }); - - if (return_lse) { - return {o, lse}; - } else { - return {o}; - } -} diff --git a/python/csrc/bmm_fp8.cu b/python/csrc/bmm_fp8.cu index 0f7da2129..6d27a1088 100644 --- a/python/csrc/bmm_fp8.cu +++ b/python/csrc/bmm_fp8.cu @@ -17,9 +17,8 @@ #include #include -#include +#include -#include "flashinfer_ops.h" #include "pytorch_extension_utils.h" using namespace flashinfer; diff --git a/python/csrc/cascade.cu b/python/csrc/cascade.cu index e5198f898..a9309e437 100644 --- a/python/csrc/cascade.cu +++ b/python/csrc/cascade.cu @@ -15,7 +15,6 @@ */ #include -#include "flashinfer_ops.h" #include "pytorch_extension_utils.h" using namespace flashinfer; diff --git a/python/csrc/flashinfer_cascade_ops.cu b/python/csrc/flashinfer_cascade_ops.cu new file mode 100644 index 000000000..2b3270381 --- /dev/null +++ b/python/csrc/flashinfer_cascade_ops.cu @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
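Before the new per-module binding files below, a note on the index layout that the prefill wrappers above keep asserting: `q` is packed as `(nnz_qo, num_qo_heads, head_dim)` and `qo_indptr` has `batch_size + 1` entries, so request `i` owns rows `qo_indptr[i]:qo_indptr[i+1]` of `q`. A minimal Python sketch of constructing such inputs (shapes and values here are illustrative only, not taken from the diff):

```python
import torch

batch_size, num_qo_heads, head_dim = 2, 32, 128
qo_lens = torch.tensor([5, 7], dtype=torch.int32)

# qo_indptr has batch_size + 1 entries; request i owns rows qo_indptr[i]:qo_indptr[i+1].
qo_indptr = torch.zeros(batch_size + 1, dtype=torch.int32)
qo_indptr[1:] = torch.cumsum(qo_lens, dim=0)

nnz_qo = int(qo_indptr[-1])  # total query tokens in the batch
q = torch.randn(nnz_qo, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
qo_indptr = qo_indptr.cuda()  # the wrappers expect int32 index tensors
```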
+ */
+#include <torch/extension.h>
+
+std::vector<torch::Tensor> merge_state(torch::Tensor v_a, torch::Tensor s_a, torch::Tensor v_b,
+                                       torch::Tensor s_b);
+
+void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_other,
+                          torch::Tensor s_other, std::optional<torch::Tensor> mask = std::nullopt);
+
+std::vector<torch::Tensor> merge_states(torch::Tensor v, torch::Tensor s);
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("merge_state", &merge_state, "Merge two self-attention states");
+  m.def("merge_state_in_place", &merge_state_in_place,
+        "Merge another self-attention state in-place.");
+  m.def("merge_states", &merge_states, "Merge multiple self-attention states");
+}
diff --git a/python/csrc/flashinfer_gemm_ops.cu b/python/csrc/flashinfer_gemm_ops.cu
new file mode 100644
index 000000000..69bd47c46
--- /dev/null
+++ b/python/csrc/flashinfer_gemm_ops.cu
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) 2023 by FlashInfer team.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <torch/extension.h>
+
+void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D,
+             torch::Tensor& A_scale, torch::Tensor& B_scale);
+
+torch::Tensor CutlassSegmentGEMM(torch::Tensor workspace_buffer, torch::Tensor seg_indptr,
+                                 torch::Tensor weight_indices, torch::Tensor x,
+                                 torch::Tensor weight, unsigned int batch_size,
+                                 bool weight_column_major);
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("cutlass_segment_gemm", &CutlassSegmentGEMM, "Cutlass Segment GEMM");
+  m.def("bmm_fp8", &bmm_fp8, "BMM FP8");
+}
diff --git a/python/csrc/flashinfer_norm_ops.cu b/python/csrc/flashinfer_norm_ops.cu
new file mode 100644
index 000000000..8c3f33850
--- /dev/null
+++ b/python/csrc/flashinfer_norm_ops.cu
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) 2023 by FlashInfer team.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
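The practical consequence of splitting the bindings into separate `flashinfer_*_ops.cu` files is that each op family compiles into its own `TORCH_EXTENSION_NAME` module. As a rough illustration only (not the package's actual build path, which goes through `load_cuda_ops` or the AOT wheel, and with extra include paths omitted), the cascade ops could be built and exercised with `torch.utils.cpp_extension.load`:

```python
import torch
from torch.utils.cpp_extension import load

# Hypothetical stand-alone build of the cascade ops registered above.
cascade_ops = load(
    name="flashinfer_cascade_ops",
    sources=[
        "python/csrc/cascade.cu",                 # kernel launchers
        "python/csrc/flashinfer_cascade_ops.cu",  # pybind11 registration
    ],
)

seq_len, num_heads, head_dim = 2048, 32, 128
v_a = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
s_a = torch.randn(seq_len, num_heads, dtype=torch.float32, device="cuda")
v_b, s_b = torch.randn_like(v_a), torch.randn_like(s_a)
v_merged, s_merged = cascade_ops.merge_state(v_a, s_a, v_b, s_b)
```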
+ */ +#include + +void rmsnorm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, double eps); + +void fused_add_rmsnorm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, + double eps); + +void gemma_rmsnorm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, double eps); + +void gemma_fused_add_rmsnorm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, + double eps); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("rmsnorm", &rmsnorm, "Root mean square normalization"); + m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused add root mean square normalization"); + m.def("gemma_rmsnorm", &gemma_rmsnorm, "Gemma Root mean square normalization"); + m.def("gemma_fused_add_rmsnorm", &gemma_fused_add_rmsnorm, + "Gemma Fused add root mean square normalization"); +} diff --git a/python/csrc/flashinfer_ops.cu b/python/csrc/flashinfer_ops.cu deleted file mode 100644 index cb71b0831..000000000 --- a/python/csrc/flashinfer_ops.cu +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Copyright (c) 2023 by FlashInfer team. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include - -#include "flashinfer_ops.h" - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("append_paged_kv_cache", &append_paged_kv_cache, "Append paged KV-Cache operator"); - m.def("merge_state", &merge_state, "Merge two self-attention states"); - m.def("merge_state_in_place", &merge_state_in_place, - "Merge another self-attention state in-place."); - m.def("merge_states", &merge_states, "Merge multiple self-attention states"); - m.def("sampling_from_probs", &sampling_from_probs, "Sample from probabilities"); - m.def("top_k_sampling_from_probs", &top_k_sampling_from_probs, - "Top-k sampling from probabilities"); - m.def("min_p_sampling_from_probs", &min_p_sampling_from_probs, - "Min-p sampling from probabilities"); - m.def("top_p_sampling_from_probs", &top_p_sampling_from_probs, - "Top-p sampling from probabilities"); - m.def("top_k_top_p_sampling_from_probs", &top_k_top_p_sampling_from_probs, - "Top-k and top-p sampling from probabilities"); - m.def("top_k_renorm_probs", &top_k_renorm_probs, "Renormalize probabilities by top-k mask"); - m.def("top_p_renorm_probs", &top_p_renorm_probs, "Renormalize probabilities by top-p mask"); - m.def("top_k_mask_logits", &top_k_mask_logits, "Mask logits by top-k mask"); - m.def("chain_speculative_sampling", &chain_speculative_sampling, - "Speculative sampling from sequence of probabilities"); - m.def("rmsnorm", &rmsnorm, "Root mean square normalization"); - m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused add root mean square normalization"); - m.def("gemma_rmsnorm", &gemma_rmsnorm, "Gemma Root mean square normalization"); - m.def("gemma_fused_add_rmsnorm", &gemma_fused_add_rmsnorm, - "Gemma Fused add root mean square normalization"); - m.def("silu_and_mul", &silu_and_mul, "Fused SiLU and Mul"); - m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Fused GeLU Tanh and Mul"); - m.def("gelu_and_mul", &gelu_and_mul, "Fused GeLU 
and Mul"); - m.def("apply_rope_inplace", &apply_rope_inplace, "Apply RoPE in-place"); - m.def("apply_llama31_rope_inplace", &apply_llama31_rope_inplace, - "Apply Llama 3.1 style RoPE in-place"); - m.def("apply_rope", &apply_rope, "Apply RoPE"); - m.def("apply_llama31_rope", &apply_llama31_rope, "Apply Llama 3.1 style RoPE"); - m.def("packbits", &packbits, "GPU packbits operator"); - m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator"); - m.def("bmm_fp8", &bmm_fp8, "BMM FP8"); - py::class_(m, "CutlassSegmentGEMMPyTorchWrapper") - .def(py::init()) - .def("register_workspace", &CutlassSegmentGEMMPyTorchWrapper::RegisterWorkspaceBuffer) - .def("run", &CutlassSegmentGEMMPyTorchWrapper::Run); -} diff --git a/python/csrc/flashinfer_ops_decode.cu b/python/csrc/flashinfer_ops_decode.cu deleted file mode 100644 index 7bcde40b0..000000000 --- a/python/csrc/flashinfer_ops_decode.cu +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2023 by FlashInfer team. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include - -#include "flashinfer_ops_decode.h" - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("single_decode_with_kv_cache", &single_decode_with_kv_cache, - "Single-request decode with KV-Cache operator"); - py::class_(m, - "BatchDecodeWithPagedKVCachePyTorchWrapper") - .def(py::init()) - .def("plan", &BatchDecodeWithPagedKVCachePyTorchWrapper::Plan) - .def("is_cuda_graph_enabled", &BatchDecodeWithPagedKVCachePyTorchWrapper::IsCUDAGraphEnabled) - .def("update_page_locked_buffer_size", - &BatchDecodeWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize) - .def("run", &BatchDecodeWithPagedKVCachePyTorchWrapper::Run); -} diff --git a/python/csrc/flashinfer_ops_decode.h b/python/csrc/flashinfer_ops_decode.h deleted file mode 100644 index 48c68f7e1..000000000 --- a/python/csrc/flashinfer_ops_decode.h +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright (c) 2023 by FlashInfer team. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#pragma once -#include - -#include -#include -#include - -torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torch::Tensor v, - torch::Tensor tmp, unsigned int pos_encoding_mode, - unsigned int layout, int window_left, - float logits_soft_cap, float sm_scale, float rope_scale, - float rope_theta); - -class BatchDecodeWithPagedKVCachePyTorchWrapper { - public: - void Plan(torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, - torch::Tensor indptr, torch::Tensor last_page_len, unsigned int batch_size, - unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim, - unsigned int page_size, unsigned int pos_encoding_mode, float logits_soft_cap, - torch::Tensor empty_q_data, torch::Tensor empty_kv_data); - void UpdatePageLockedBufferSize(uint32_t int_workspace_size_in_bytes); - bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); } - std::vector Run(torch::Tensor q, std::optional paged_kv_cache, - std::optional paged_k_cache, - std::optional paged_v_cache, - torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, - torch::Tensor paged_kv_last_page_len, - unsigned int pos_encoding_mode, int window_left, - float logits_soft_cap, float sm_scale, float rope_scale, - float rope_theta, bool return_lse); - BatchDecodeWithPagedKVCachePyTorchWrapper( - std::shared_ptr handler_ptr, flashinfer::QKVLayout kv_layout) - : handler_(handler_ptr), kv_layout_(kv_layout) {} - BatchDecodeWithPagedKVCachePyTorchWrapper(unsigned int layout, bool enable_cuda_graph, - unsigned int fixed_batch_size) - : kv_layout_(flashinfer::QKVLayout(layout)), - handler_(std::make_shared(enable_cuda_graph, - fixed_batch_size)) {} - - protected: - std::shared_ptr handler_; - flashinfer::QKVLayout kv_layout_; -}; diff --git a/python/csrc/flashinfer_ops_prefill.cu b/python/csrc/flashinfer_ops_prefill.cu deleted file mode 100644 index 75906e069..000000000 --- a/python/csrc/flashinfer_ops_prefill.cu +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright (c) 2023 by FlashInfer team. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include - -#include "flashinfer_ops_prefill.h" - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("single_prefill_with_kv_cache", &single_prefill_with_kv_cache, - "Single-request prefill with KV-Cache operator, return logsumexp"); - m.def( - "single_prefill_with_kv_cache_custom_mask", &single_prefill_with_kv_cache_custom_mask, - "Single-request prefill with KV-Cache operator, user defined custom mask, return logsumexp"); - py::class_( - m, "BatchPrefillWithPagedKVCachePyTorchWrapper") - .def(py::init()) - .def("plan", &BatchPrefillWithPagedKVCachePyTorchWrapper::Plan) - .def("is_cuda_graph_enabled", &BatchPrefillWithPagedKVCachePyTorchWrapper::IsCUDAGraphEnabled) - .def("update_page_locked_buffer_size", - &BatchPrefillWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize) - .def("run", &BatchPrefillWithPagedKVCachePyTorchWrapper::Run) - .def("run_custom_mask", &BatchPrefillWithPagedKVCachePyTorchWrapper::RunCustomMask); - py::class_( - m, "BatchPrefillWithRaggedKVCachePyTorchWrapper") - .def(py::init()) - .def("plan", &BatchPrefillWithRaggedKVCachePyTorchWrapper::Plan) - .def("is_cuda_graph_enabled", - &BatchPrefillWithRaggedKVCachePyTorchWrapper::IsCUDAGraphEnabled) - .def("update_page_locked_buffer_size", - &BatchPrefillWithRaggedKVCachePyTorchWrapper::UpdatePageLockedBufferSize) - .def("run", &BatchPrefillWithRaggedKVCachePyTorchWrapper::Run) - .def("run_custom_mask", &BatchPrefillWithRaggedKVCachePyTorchWrapper::RunCustomMask); -} diff --git a/python/csrc/flashinfer_ops_prefill.h b/python/csrc/flashinfer_ops_prefill.h deleted file mode 100644 index 27a0edcfc..000000000 --- a/python/csrc/flashinfer_ops_prefill.h +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Copyright (c) 2023 by FlashInfer team. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#pragma once -#include - -#include -#include - -std::vector single_prefill_with_kv_cache( - torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor tmp, bool causal, - unsigned int layout, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, - int window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, - bool return_lse); - -std::vector single_prefill_with_kv_cache_custom_mask( - torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor packed_custom_mask, - torch::Tensor tmp, unsigned int layout, unsigned int pos_encoding_mode, - bool allow_fp16_qk_reduction, int window_left, float logits_soft_cap, float sm_scale, - float rope_scale, float rope_theta, bool return_lse); - -class BatchPrefillWithPagedKVCachePyTorchWrapper { - public: - void Plan(torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, - torch::Tensor qo_indptr, torch::Tensor page_kv_indptr, unsigned int batch_size, - unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim, - unsigned page_size, torch::Tensor empty_q_data); - bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); } - void UpdatePageLockedBufferSize(uint32_t int_workspace_size_in_bytes); - std::vector Run(torch::Tensor q, torch::Tensor qo_indptr, - std::optional paged_kv_cache, - std::optional paged_k_cache, - std::optional paged_v_cache, - torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, - torch::Tensor paged_kv_last_page_len, bool causal, - unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, - int window_left, float logits_soft_cap, float sm_scale, - float rope_scale, float rope_theta, bool return_lse); - std::vector RunCustomMask( - torch::Tensor q, torch::Tensor qo_indptr, std::optional paged_kv_cache, - std::optional paged_k_cache, std::optional paged_v_cache, - torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, - torch::Tensor paged_kv_last_page_len, torch::Tensor packed_custom_mask, - torch::Tensor qk_indptr, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, - int window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, - bool return_lse); - BatchPrefillWithPagedKVCachePyTorchWrapper(unsigned int layout, bool enable_cuda_graph) - : kv_layout_(flashinfer::QKVLayout(layout)), - handler_(std::make_shared(enable_cuda_graph)) {} - - private: - std::shared_ptr handler_; - flashinfer::QKVLayout kv_layout_; -}; - -class BatchPrefillWithRaggedKVCachePyTorchWrapper { - public: - void Plan(torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, - torch::Tensor qo_indptr, torch::Tensor kv_indptr, unsigned int batch_size, - unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim, - torch::Tensor empty_q_data); - bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); } - void UpdatePageLockedBufferSize(uint32_t int_workspace_size_in_bytes); - std::vector Run(torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, - torch::Tensor v, torch::Tensor kv_indptr, bool causal, - unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, - int window_left, float logits_soft_cap, float sm_scale, - float rope_scale, float rope_theta, bool return_lse); - std::vector RunCustomMask( - torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v, - torch::Tensor kv_indptr, torch::Tensor packed_custom_mask, torch::Tensor qk_indptr, - unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, int window_left, - 
float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse); - BatchPrefillWithRaggedKVCachePyTorchWrapper(unsigned int layout, bool enable_cuda_graph) - : kv_layout_(flashinfer::QKVLayout(layout)), - handler_(std::make_shared(enable_cuda_graph)) {} - - private: - std::shared_ptr handler_; - flashinfer::QKVLayout kv_layout_; -}; diff --git a/python/csrc/flashinfer_page_ops.cu b/python/csrc/flashinfer_page_ops.cu new file mode 100644 index 000000000..39caf24fa --- /dev/null +++ b/python/csrc/flashinfer_page_ops.cu @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, + torch::Tensor append_indptr, std::optional paged_kv_cache, + std::optional paged_k_cache, + std::optional paged_v_cache, torch::Tensor kv_indices, + torch::Tensor kv_indptr, torch::Tensor kv_last_page_len, + unsigned int layout); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("append_paged_kv_cache", &append_paged_kv_cache, "Append paged KV-Cache operator"); +} diff --git a/python/csrc/flashinfer_quantization_ops.cu b/python/csrc/flashinfer_quantization_ops.cu new file mode 100644 index 000000000..7f2886091 --- /dev/null +++ b/python/csrc/flashinfer_quantization_ops.cu @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +torch::Tensor packbits(torch::Tensor x, const std::string& bitorder); + +torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr, + torch::Tensor output_indptr, const std::string& bitorder); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("packbits", &packbits, "GPU packbits operator"); + m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator"); +} diff --git a/python/csrc/flashinfer_rope_ops.cu b/python/csrc/flashinfer_rope_ops.cu new file mode 100644 index 000000000..4075930b5 --- /dev/null +++ b/python/csrc/flashinfer_rope_ops.cu @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
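The quantization bindings above expose `packbits` and `segment_packbits`, which pack boolean mask elements into `uint8` words, eight per byte. A usage sketch, where `quantization_ops` stands for the loaded module and the `"little"` bit order is chosen purely for illustration:

```python
import torch

# A flattened boolean attention mask (contents are arbitrary here).
mask = torch.rand(1000, device="cuda") < 0.5

# Pack 8 boolean elements per output byte; with "little", element 0 lands in bit 0.
packed = quantization_ops.packbits(mask, "little")
# packed holds ceil(mask.numel() / 8) uint8 values.
```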
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, + torch::Tensor offsets, bool interleave, float rope_scale, float rope_theta); + +void apply_llama31_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, + torch::Tensor offsets, bool interleave, float rope_scale, + float rope_theta, float low_freq_factor, float high_freq_factor, + float old_context_length); + +std::vector apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, + torch::Tensor offsets, bool interleave, float rope_scale, + float rope_theta); + +std::vector apply_llama31_rope(torch::Tensor q, torch::Tensor k, + torch::Tensor indptr, torch::Tensor offsets, + bool interleave, float rope_scale, float rope_theta, + float low_freq_factor, float high_freq_factor, + float old_context_length); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("apply_rope_inplace", &apply_rope_inplace, "Apply RoPE in-place"); + m.def("apply_llama31_rope_inplace", &apply_llama31_rope_inplace, + "Apply Llama 3.1 style RoPE in-place"); + m.def("apply_rope", &apply_rope, "Apply RoPE"); + m.def("apply_llama31_rope", &apply_llama31_rope, "Apply Llama 3.1 style RoPE"); +} diff --git a/python/csrc/flashinfer_sampling_ops.cu b/python/csrc/flashinfer_sampling_ops.cu new file mode 100644 index 000000000..0ab59fc9c --- /dev/null +++ b/python/csrc/flashinfer_sampling_ops.cu @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_samples, + bool deterministic); + +std::vector top_p_sampling_from_probs(torch::Tensor probs, + torch::Tensor uniform_samples, + std::optional maybe_top_p_arr, + double top_p_val, bool deterministic); + +std::vector top_k_sampling_from_probs(torch::Tensor probs, + torch::Tensor uniform_samples, + std::optional maybe_top_k_arr, + unsigned int top_k_val, bool deterministic); + +std::vector min_p_sampling_from_probs(torch::Tensor probs, + torch::Tensor uniform_samples, + std::optional maybe_min_p_arr, + double min_p_val, bool deterministic); + +std::vector top_k_top_p_sampling_from_probs( + torch::Tensor probs, torch::Tensor uniform_samples, + std::optional maybe_top_k_arr, double top_k_val, + std::optional maybe_top_p_arr, double top_p_val, bool deterministic); + +torch::Tensor top_p_renorm_probs(torch::Tensor probs, std::optional maybe_top_p_arr, + double top_p_val); + +torch::Tensor top_k_renorm_probs(torch::Tensor probs, std::optional maybe_top_k_arr, + unsigned int top_k_val); + +torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional maybe_top_k_arr, + unsigned int top_k_val); + +std::vector chain_speculative_sampling( + torch::Tensor draft_probs, torch::Tensor draft_token_ids, torch::Tensor uniform_samples, + torch::Tensor target_probs, std::optional maybe_output_accepted_token_num, + std::optional maybe_output_emitted_token_num, bool deterministic); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("sampling_from_probs", &sampling_from_probs, "Sample from probabilities"); + m.def("top_k_sampling_from_probs", &top_k_sampling_from_probs, + "Top-k sampling from probabilities"); + m.def("min_p_sampling_from_probs", &min_p_sampling_from_probs, + "Min-p sampling from probabilities"); + m.def("top_p_sampling_from_probs", &top_p_sampling_from_probs, + "Top-p sampling from probabilities"); + m.def("top_k_top_p_sampling_from_probs", &top_k_top_p_sampling_from_probs, + "Top-k and top-p sampling from probabilities"); + m.def("top_k_renorm_probs", &top_k_renorm_probs, "Renormalize probabilities by top-k mask"); + m.def("top_p_renorm_probs", &top_p_renorm_probs, "Renormalize probabilities by top-p mask"); + m.def("top_k_mask_logits", &top_k_mask_logits, "Mask logits by top-k mask"); + m.def("chain_speculative_sampling", &chain_speculative_sampling, + "Speculative sampling from sequence of probabilities"); +} diff --git a/python/csrc/group_gemm.cu b/python/csrc/group_gemm.cu index e3ed472b1..7954f5336 100644 --- a/python/csrc/group_gemm.cu +++ b/python/csrc/group_gemm.cu @@ -13,22 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. 
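Of the sampling bindings above, `sampling_from_probs(probs, uniform_samples, deterministic)` is the simplest entry point. A hedged sketch of a call, with `sampling_ops` standing for the loaded module and the shape of `uniform_samples` assumed to be one value per batch row:

```python
import torch

batch_size, vocab_size = 4, 32000
probs = torch.softmax(torch.randn(batch_size, vocab_size, device="cuda"), dim=-1)

# Assumed: one uniform draw per row drives the categorical sample.
uniform_samples = torch.rand(batch_size, device="cuda")

samples = sampling_ops.sampling_from_probs(probs, uniform_samples, True)  # deterministic=True
```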
*/ -#include +#include -#include "flashinfer_ops.h" #include "pytorch_extension_utils.h" using namespace flashinfer::group_gemm; -void CutlassSegmentGEMMPyTorchWrapper::RegisterWorkspaceBuffer(torch::Tensor workspace_buffer) { - handler_->RegisterWorkspace(static_cast(workspace_buffer.data_ptr()), - workspace_buffer.size(0) * workspace_buffer.element_size()); -} - -torch::Tensor CutlassSegmentGEMMPyTorchWrapper::Run(torch::Tensor seg_indptr, - torch::Tensor weight_indices, torch::Tensor x, - torch::Tensor weight, unsigned int batch_size, - bool weight_column_major) { +torch::Tensor CutlassSegmentGEMM(torch::Tensor workspace_buffer, torch::Tensor seg_indptr, + torch::Tensor weight_indices, torch::Tensor x, + torch::Tensor weight, unsigned int batch_size, + bool weight_column_major) { // TODO(Zihao): Add more checks here CHECK_INPUT(seg_indptr); CHECK_INPUT(x); @@ -56,10 +50,10 @@ torch::Tensor CutlassSegmentGEMMPyTorchWrapper::Run(torch::Tensor seg_indptr, DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(x.scalar_type(), c_type, [&] { using cutlass_t = typename cutlass_dtype::type; - auto status = CutlassSegmentGEMMWrapper( - handler_.get(), static_cast(x.data_ptr()), - static_cast(weight.data_ptr()), static_cast(y.data_ptr()), - static_cast(seg_indptr.data_ptr()), + auto status = CutlassSegmentGEMMRun( + workspace_buffer.data_ptr(), workspace_buffer.element_size() * workspace_buffer.size(0), + static_cast(x.data_ptr()), static_cast(weight.data_ptr()), + static_cast(y.data_ptr()), static_cast(seg_indptr.data_ptr()), weight_indices_defined ? static_cast(weight_indices.data_ptr()) : nullptr, batch_size, d_in, d_out, weight_column_major, torch_current_stream); TORCH_CHECK(status == cudaSuccess, diff --git a/python/csrc/norm.cu b/python/csrc/norm.cu index 855050756..3de264656 100644 --- a/python/csrc/norm.cu +++ b/python/csrc/norm.cu @@ -15,7 +15,6 @@ */ #include -#include "flashinfer_ops.h" #include "pytorch_extension_utils.h" using namespace flashinfer; diff --git a/python/csrc/page.cu b/python/csrc/page.cu index 12461c827..787aa1aa6 100644 --- a/python/csrc/page.cu +++ b/python/csrc/page.cu @@ -15,7 +15,6 @@ */ #include -#include "flashinfer_ops.h" #include "pytorch_extension_utils.h" using namespace flashinfer; @@ -71,7 +70,6 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, CHECK_EQ(kv_indptr.device(), device); CHECK_EQ(kv_last_page_len.device(), device); - constexpr PageStorage page_storage = PageStorage::kIndices; QKVLayout kv_layout = QKVLayout(layout); unsigned int num_heads, page_size, head_dim; @@ -106,7 +104,7 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, paged_kv_cache.has_value() ? paged_kv_cache->scalar_type() : paged_k_cache->scalar_type(); bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(kv_scalar_dtype, c_type, [&] { - paged_kv_t paged_kv( + paged_kv_t paged_kv( num_heads, page_size, head_dim, batch_size, kv_layout, static_cast(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr() : nullptr), static_cast(paged_k_cache.has_value() ? 
paged_k_cache->data_ptr() : nullptr), diff --git a/python/csrc/pytorch_extension_utils.h b/python/csrc/pytorch_extension_utils.h index d6895041c..2526c9407 100644 --- a/python/csrc/pytorch_extension_utils.h +++ b/python/csrc/pytorch_extension_utils.h @@ -20,11 +20,6 @@ #include #include -#include -#include - -#include "generated/dispatch.inc" - using namespace flashinfer; #ifdef FLASHINFER_ENABLE_BF16 @@ -196,24 +191,6 @@ using namespace flashinfer; return __VA_ARGS__(); \ } -#define DISPATCH_head_dim(expr, const_expr, ...) \ - _DISPATCH_SWITCH("head_dim", expr, _DISPATCH_CASES_head_dim(const_expr, __VA_ARGS__)) - -#define DISPATCH_logits_post_hook(expr, const_expr, ...) \ - _DISPATCH_SWITCH("logits post hook", expr, \ - _DISPATCH_CASES_logits_post_hook(const_expr, __VA_ARGS__)) - -#define DISPATCH_pos_encoding_mode(expr, const_expr, ...) \ - _DISPATCH_SWITCH("positional encoding mode", expr, \ - _DISPATCH_CASES_pos_encoding_mode(const_expr, __VA_ARGS__)) - -#define DISPATCH_allow_fp16_qk_reduction(expr, const_expr, ...) \ - _DISPATCH_SWITCH("allow_fp16_qk_reduction", expr, \ - _DISPATCH_CASES_allow_fp16_qk_reduction(const_expr, __VA_ARGS__)) - -#define DISPATCH_mask_mode(expr, const_expr, ...) \ - _DISPATCH_SWITCH("mask_mode", expr, _DISPATCH_CASES_mask_mode(const_expr, __VA_ARGS__)) - inline void check_shape(const torch::Tensor& a, const torch::Tensor& b, const char* a_name, const char* b_name) { TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ", a.dim(), " vs ", diff --git a/python/csrc/quantization.cu b/python/csrc/quantization.cu index 9e358421c..7832340f9 100644 --- a/python/csrc/quantization.cu +++ b/python/csrc/quantization.cu @@ -15,7 +15,6 @@ */ #include -#include "flashinfer_ops.h" #include "pytorch_extension_utils.h" using namespace flashinfer; diff --git a/python/csrc/rope.cu b/python/csrc/rope.cu index 572d7e9c5..bb8d5a196 100644 --- a/python/csrc/rope.cu +++ b/python/csrc/rope.cu @@ -15,7 +15,6 @@ */ #include -#include "flashinfer_ops.h" #include "pytorch_extension_utils.h" using namespace flashinfer; diff --git a/python/csrc/sampling.cu b/python/csrc/sampling.cu index 190ad6bf8..db4c0a5c5 100644 --- a/python/csrc/sampling.cu +++ b/python/csrc/sampling.cu @@ -15,7 +15,6 @@ */ #include -#include "flashinfer_ops.h" #include "pytorch_extension_utils.h" using namespace flashinfer; diff --git a/python/csrc/single_decode.cu b/python/csrc/single_decode.cu deleted file mode 100644 index abbe81dcb..000000000 --- a/python/csrc/single_decode.cu +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Copyright (c) 2023 by FlashInfer team. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include - -#include "flashinfer_ops_decode.h" -#include "pytorch_extension_utils.h" - -using namespace flashinfer; - -torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torch::Tensor v, - torch::Tensor tmp, unsigned int pos_encoding_mode, - unsigned int layout, int window_left, - float logits_soft_cap, float sm_scale, float rope_scale, - float rope_theta) { - CHECK_INPUT(q); - CHECK_INPUT(k); - CHECK_INPUT(v); - CHECK_INPUT(tmp); - auto device = q.device(); - CHECK_EQ(k.device(), device); - CHECK_EQ(v.device(), device); - CHECK_EQ(tmp.device(), device); - CHECK_DIM(2, q); - CHECK_DIM(3, k); - CHECK_DIM(3, v); - CHECK_SHAPE(k, v); - CHECK_EQ(q.size(1), k.size(2)); - CHECK_EQ(v.scalar_type(), k.scalar_type()); - unsigned int num_qo_heads = q.size(0); - unsigned int head_dim = q.size(1); - unsigned int kv_len, num_kv_heads; - QKVLayout kv_layout = static_cast(layout); - if (kv_layout == QKVLayout::kNHD) { - kv_len = k.size(0); - num_kv_heads = k.size(1); - } else { - num_kv_heads = k.size(0); - kv_len = k.size(1); - } - CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - auto o = torch::empty_like(q); - - TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative"); - const LogitsPostHook logits_post_hook = - logits_soft_cap > 0.f ? LogitsPostHook::kSoftCap : LogitsPostHook::kNone; - - auto q_scalar_type = q.scalar_type(); - auto kv_scalar_type = k.scalar_type(); - - if (q_scalar_type == kv_scalar_type) { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q_scalar_type, qkv_type, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = SingleDecodeWithKVCacheDispatched( - static_cast(q.data_ptr()), static_cast(k.data_ptr()), - static_cast(v.data_ptr()), static_cast(o.data_ptr()), - static_cast(tmp.data_ptr()), num_qo_heads, num_kv_heads, kv_len, - kv_layout, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta, - torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "SingleDecodeWithKVCache kernel launch failed, error: " + - std::string(cudaGetErrorString(status))); - return true; - }); - }); - }); - }); - } else { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_scalar_type, q_type, [&] { - return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(kv_scalar_type, kv_type, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = SingleDecodeWithKVCacheDispatched( - static_cast(q.data_ptr()), static_cast(k.data_ptr()), - static_cast(v.data_ptr()), static_cast(o.data_ptr()), - static_cast(tmp.data_ptr()), num_qo_heads, num_kv_heads, kv_len, - kv_layout, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta, - torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "SingleDecodeWithKVCache kernel launch failed, error: " + - std::string(cudaGetErrorString(status))); - return true; - }); - }); - }); - }); - }); - } - - return o; -} diff --git a/python/csrc/single_prefill.cu b/python/csrc/single_prefill.cu deleted file mode 100644 index 320d2c353..000000000 --- a/python/csrc/single_prefill.cu +++ /dev/null @@ -1,217 +0,0 @@ -/* - * Copyright (c) 2023 by FlashInfer team. 
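The deleted single-request decode entry point above, like the batch wrappers, derives `kv_len` and `num_kv_heads` from the `QKVLayout` flag: `kNHD` means K/V are stored as `(kv_len, num_kv_heads, head_dim)`, while `kHND` means `(num_kv_heads, kv_len, head_dim)`. The same bookkeeping in Python form, as a standalone sketch:

```python
import torch


def kv_dims(k: torch.Tensor, kv_layout: str):
    """Return (kv_len, num_kv_heads) of a 3-D K/V tensor under NHD or HND layout."""
    if kv_layout == "NHD":
        kv_len, num_kv_heads = k.size(0), k.size(1)
    else:  # "HND"
        num_kv_heads, kv_len = k.size(0), k.size(1)
    return kv_len, num_kv_heads


k_nhd = torch.randn(1024, 8, 128, dtype=torch.float16)  # (kv_len, H_kv, D)
print(kv_dims(k_nhd, "NHD"))  # (1024, 8)
```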
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include - -#include "flashinfer_ops_prefill.h" -#include "pytorch_extension_utils.h" - -using namespace flashinfer; - -std::vector single_prefill_with_kv_cache( - torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor tmp, bool causal, - unsigned int layout, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, - int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, - bool return_lse) { - CHECK_CUDA(q); - CHECK_CUDA(k); - CHECK_CUDA(v); - CHECK_INPUT(tmp); - auto device = q.device(); - CHECK_EQ(k.device(), device); - CHECK_EQ(v.device(), device); - CHECK_EQ(tmp.device(), device); - CHECK_DIM(3, q); - CHECK_DIM(3, k); - CHECK_DIM(3, v); - CHECK_SHAPE(k, v); - CHECK_EQ(q.stride(2), 1); - CHECK_EQ(k.stride(2), 1); - CHECK_EQ(v.stride(2), 1); - CHECK_EQ(q.size(2), k.size(2)); - CHECK_EQ(q.scalar_type(), k.scalar_type()); - CHECK_EQ(q.scalar_type(), v.scalar_type()); - unsigned int head_dim = q.size(2); - unsigned int kv_len, qo_len, num_kv_heads, num_qo_heads; - QKVLayout kv_layout = static_cast(layout); - qo_len = q.size(0); - num_qo_heads = q.size(1); - uint32_t q_stride_n = q.stride(0), q_stride_h = q.stride(1), kv_stride_n, kv_stride_h; - if (kv_layout == QKVLayout::kNHD) { - kv_len = k.size(0); - num_kv_heads = k.size(1); - kv_stride_n = k.stride(0); - kv_stride_h = k.stride(1); - } else { // QKVLayout::kHND - kv_len = k.size(1); - num_kv_heads = k.size(0); - kv_stride_h = k.stride(0); - kv_stride_n = k.stride(1); - } - CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - auto o = torch::empty_like(q, q.options()); - torch::Tensor lse = torch::empty({0}); - if (return_lse) { - lse = torch::empty({qo_len, num_qo_heads}, q.options().dtype(torch::kFloat32)); - } - - const MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone; - TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative"); - const LogitsPostHook logits_post_hook = - logits_soft_cap > 0.f ? 
LogitsPostHook::kSoftCap : LogitsPostHook::kNone; - - auto q_scalar_type = q.scalar_type(); - auto kv_scalar_type = k.scalar_type(); - TORCH_CHECK(q_scalar_type == kv_scalar_type, - "q and k must have the same scalar type, but got q: ", q_scalar_type, - " and k: ", kv_scalar_type); - - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_scalar_type, c_type, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { - return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { - return DISPATCH_allow_fp16_qk_reduction( - allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = - SinglePrefillWithKVCacheDispatched( - static_cast(q.data_ptr()), - static_cast(k.data_ptr()), - static_cast(v.data_ptr()), - /*custom_mask=*/nullptr, static_cast(o.data_ptr()), - static_cast(tmp.data_ptr()), - /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, - num_qo_heads, num_kv_heads, qo_len, kv_len, q_stride_n, q_stride_h, - kv_stride_n, kv_stride_h, window_left, logits_soft_cap, sm_scale, - rope_scale, rope_theta, torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "SinglePrefillWithKVCache kernel launch failed, error: " + - std::string(cudaGetErrorString(status))); - return true; - }); - }); - }); - }); - }); - }); - - if (return_lse) { - return {o, lse}; - } else { - return {o}; - } -} - -std::vector single_prefill_with_kv_cache_custom_mask( - torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor packed_custom_mask, - torch::Tensor tmp, unsigned int layout, unsigned int pos_encoding_mode, - bool allow_fp16_qk_reduction, int32_t window_left, float logits_soft_cap, float sm_scale, - float rope_scale, float rope_theta, bool return_lse) { - CHECK_CUDA(q); - CHECK_CUDA(k); - CHECK_CUDA(v); - CHECK_INPUT(packed_custom_mask); - auto device = q.device(); - CHECK_EQ(k.device(), device); - CHECK_EQ(v.device(), device); - CHECK_EQ(packed_custom_mask.device(), device); - CHECK_DIM(3, q); - CHECK_DIM(3, k); - CHECK_DIM(3, v); - CHECK_DIM(1, packed_custom_mask); - CHECK_SHAPE(k, v); - CHECK_EQ(q.stride(2), 1); - CHECK_EQ(k.stride(2), 1); - CHECK_EQ(v.stride(2), 1); - CHECK_EQ(q.size(2), k.size(2)); - // packed_custom_mask must be uint8 - TORCH_CHECK(packed_custom_mask.scalar_type() == torch::kUInt8, - "packed_custom_mask must be uint8"); - unsigned int head_dim = q.size(2); - unsigned int kv_len, qo_len, num_kv_heads, num_qo_heads; - QKVLayout kv_layout = static_cast(layout); - qo_len = q.size(0); - num_qo_heads = q.size(1); - uint32_t q_stride_n = q.stride(0), q_stride_h = q.stride(1), kv_stride_n, kv_stride_h; - if (kv_layout == QKVLayout::kNHD) { - kv_len = k.size(0); - num_kv_heads = k.size(1); - kv_stride_n = k.stride(0); - kv_stride_h = k.stride(1); - } else { - kv_len = k.size(1); - num_kv_heads = k.size(0); - kv_stride_h = k.stride(0); - kv_stride_n = k.stride(1); - } - CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - auto o = torch::empty_like(q, q.options()); - torch::Tensor lse = torch::empty({0}); - if (return_lse) { - lse = torch::empty({qo_len, num_qo_heads}, q.options().dtype(torch::kFloat32)); - } - - constexpr MaskMode MASK_MODE = MaskMode::kCustom; - TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative"); - const LogitsPostHook logits_post_hook = - logits_soft_cap > 0.f ? 
LogitsPostHook::kSoftCap : LogitsPostHook::kNone; - - auto q_scalar_type = q.scalar_type(); - auto kv_scalar_type = k.scalar_type(); - TORCH_CHECK(q_scalar_type == kv_scalar_type, - "q and k must have the same scalar type, but got q: ", q_scalar_type, - " and k: ", kv_scalar_type); - - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { - return DISPATCH_allow_fp16_qk_reduction( - allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = - SinglePrefillWithKVCacheDispatched( - static_cast(q.data_ptr()), static_cast(k.data_ptr()), - static_cast(v.data_ptr()), - static_cast(packed_custom_mask.data_ptr()), - static_cast(o.data_ptr()), - static_cast(tmp.data_ptr()), - /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, - num_qo_heads, num_kv_heads, qo_len, kv_len, q_stride_n, q_stride_h, - kv_stride_n, kv_stride_h, window_left, logits_soft_cap, sm_scale, - rope_scale, rope_theta, torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "SinglePrefillWithKVCache kernel launch failed, error: " + - std::string(cudaGetErrorString(status))); - return true; - }); - }); - }); - }); - }); - - if (return_lse) { - return {o, lse}; - } else { - return {o}; - } -} diff --git a/python/flashinfer/__init__.py b/python/flashinfer/__init__.py index 9084503b6..2f0d8b488 100644 --- a/python/flashinfer/__init__.py +++ b/python/flashinfer/__init__.py @@ -63,3 +63,8 @@ except ImportError: with open("version.txt") as f: __version__ = f.read().strip() + +try: + import aot_config +except ImportError: + aot_config = None diff --git a/python/flashinfer/activation.py b/python/flashinfer/activation.py index eb9732b2d..0b5c18e50 100644 --- a/python/flashinfer/activation.py +++ b/python/flashinfer/activation.py @@ -15,21 +15,70 @@ """ from typing import Optional +from .jit import ( + load_cuda_ops, + FLASHINFER_GEN_SRC_DIR, + gen_act_and_mul_cu, + has_prebuilt_ops, +) import torch -# mypy: disable-error-code="attr-defined" -try: - from . import _kernels -except ImportError as e: - import logging - import os - if os.environ.get("BUILD_DOC", "0") == "1": - _kernels = None - logging.warning("Kernels are not loaded in documentation build mode.") - else: - raise e +silu_def_cu_str = r""" +__device__ __forceinline__ float silu(const float& val) { + return val / (1.0f + __expf(-val)); +} +""" + +gelu_def_cu_str = r""" +__device__ __forceinline__ float gelu(const float& val) { + constexpr float kAlpha = M_SQRT1_2; + return val * 0.5f * (1.0f + ::erf(val * kAlpha)); +} +""" + +gelu_def_tanh_cu_str = r""" +__device__ __forceinline__ float gelu_tanh(const float& val) { + const float cdf = + 0.5f * (1.0f + math::tanh((0.7978845608028654f * (val + 0.044715f * val * val * val)))); + return val * cdf; +} +""" + +act_func_def_str = { + "silu": silu_def_cu_str, + "gelu": gelu_def_cu_str, + "gelu_tanh": gelu_def_tanh_cu_str, +} + + +def compile_act_and_mul_module(name: str, act_func_def: str, verbose: bool = False): + gen_act_and_mul_cu(name, act_func_def) + return load_cuda_ops( + f"{name}_and_mul", + [ + FLASHINFER_GEN_SRC_DIR / f"{name}_and_mul.cu", + ], + verbose=verbose, + ) + + +_jit_modules = {} + + +def get_act_and_mul_module(act_func_name: str): + global _jit_modules + if act_func_name not in _jit_modules: + if has_prebuilt_ops: + from . 
import _kernels + + _jit_modules[act_func_name] = _kernels + else: + _jit_modules[act_func_name] = compile_act_and_mul_module( + act_func_name, act_func_def_str[act_func_name] + ) + return _jit_modules[act_func_name] def _check_shape(input: torch.Tensor, output: torch.Tensor): @@ -68,7 +117,7 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: device=input.device, dtype=input.dtype, ) - _kernels.silu_and_mul(out, input) + get_act_and_mul_module("silu").silu_and_mul(out, input) return out @@ -98,7 +147,7 @@ def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Te device=input.device, dtype=input.dtype, ) - _kernels.gelu_tanh_and_mul(out, input) + get_act_and_mul_module("gelu_tanh").gelu_tanh_and_mul(out, input) return out @@ -128,5 +177,5 @@ def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: device=input.device, dtype=input.dtype, ) - _kernels.gelu_and_mul(out, input) + get_act_and_mul_module("gelu").gelu_and_mul(out, input) return out diff --git a/python/flashinfer/cascade.py b/python/flashinfer/cascade.py index 86fb4a5f9..7c10e5cbf 100644 --- a/python/flashinfer/cascade.py +++ b/python/flashinfer/cascade.py @@ -16,20 +16,30 @@ import math from typing import Optional, Tuple, List +from .jit import load_cuda_ops, FLASHINFER_CSRC_DIR, has_prebuilt_ops import torch -# mypy: disable-error-code="attr-defined" -try: - from . import _kernels -except ImportError as e: - import os - import logging - if os.environ.get("BUILD_DOC", "0") == "1": - _kernels = None - logging.warning("Kernels are not loaded in documentation build mode.") - else: - raise e +_cascade_module = None + + +def get_cascade_module(): + global _cascade_module + if _cascade_module is None: + if has_prebuilt_ops: + from . import _kernels + + _cascade_module = _kernels + else: + _cascade_module = load_cuda_ops( + "cascade", + [ + FLASHINFER_CSRC_DIR / "cascade.cu", + FLASHINFER_CSRC_DIR / "flashinfer_cascade_ops.cu", + ], + ) + return _cascade_module + from .decode import ( BatchDecodeWithPagedKVCacheWrapper, @@ -88,7 +98,7 @@ def merge_state( >>> s_merged.shape torch.Size([2048, 32]) """ - return _kernels.merge_state(v_a, s_a, v_b, s_b) + return get_cascade_module().merge_state(v_a, s_a, v_b, s_b) def merge_state_in_place( @@ -134,7 +144,7 @@ def merge_state_in_place( >>> s_other = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0") >>> flashinfer.merge_state_in_place(v, s, v_other, s_other) """ - _kernels.merge_state_in_place(v, s, v_other, s_other, mask) + get_cascade_module().merge_state_in_place(v, s, v_other, s_other, mask) def merge_states(v: torch.Tensor, s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: @@ -174,7 +184,7 @@ def merge_states(v: torch.Tensor, s: torch.Tensor) -> Tuple[torch.Tensor, torch. >>> s_merged.shape torch.Size([2048, 32]) """ - return _kernels.merge_states(v, s) + return get_cascade_module().merge_states(v, s) class MultiLevelCascadeAttentionWrapper: diff --git a/python/flashinfer/decode.py b/python/flashinfer/decode.py index afd1a040f..316940309 100644 --- a/python/flashinfer/decode.py +++ b/python/flashinfer/decode.py @@ -16,48 +16,106 @@ import math from typing import Optional, Union, Dict, Tuple +from types import SimpleNamespace import torch - -# mypy: disable-error-code="attr-defined" -try: - from . import _decode - from . 
import _prefill -except ImportError as e: - import os - import logging - - if os.environ.get("BUILD_DOC", "0") == "1": - _decode = None - _prefill = None - logging.warning("Kernels are not loaded in documentation build mode.") - else: - raise e - +import functools + +from .jit import ( + load_cuda_ops, + FLASHINFER_GEN_SRC_DIR, + gen_single_decode_cu, + get_single_decode_uri, + gen_batch_decode_cu, + get_batch_decode_uri, + has_prebuilt_ops, + prebuilt_ops_uri, +) +from .prefill import get_single_prefill_module, get_batch_prefill_module from .utils import ( PosEncodingMode, TensorLayout, + MaskMode, + canonicalize_torch_dtype, _check_pos_encoding_mode, _check_kv_layout, _unpack_paged_kv_cache, + _get_cache_buf, + _get_cache_alibi_slopes_buf, + _get_range_buf, ) -_cache_buf: Dict[Tuple[str, torch.device], torch.Tensor] = {} - - -def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor: - key = (name, device) - buf = _cache_buf.get(key) - if buf is None: - buf = torch.empty(bytes, dtype=torch.uint8, device=device) - _cache_buf[key] = buf - return buf - -def _grouped_size_compiled_for_decode_kernels( - num_qo_heads: int, num_kv_heads: int -) -> bool: - return (num_qo_heads // num_kv_heads) in [1, 2, 4, 8] +def compile_single_decode_module( + *args, + verbose: bool = False, +): + gen_single_decode_cu(*args) + uri = get_single_decode_uri(*args) + return load_cuda_ops( + uri, + [FLASHINFER_GEN_SRC_DIR / f"{uri}.cu"], + verbose=verbose, + ) + + +def compile_batch_decode_module( + *args, + verbose: bool = False, +): + gen_batch_decode_cu(*args) + uri = get_batch_decode_uri(*args) + return load_cuda_ops( + uri, + [FLASHINFER_GEN_SRC_DIR / f"{uri}.cu"], + verbose=verbose, + ) + + +_single_decode_modules = {} +_batch_decode_modules = {} + + +def get_single_decode_module(*args): + global _single_decode_modules + if args not in _single_decode_modules: + if has_prebuilt_ops and get_single_decode_uri(*args) in prebuilt_ops_uri: + from . import _decode_kernels + + _single_decode_modules[args] = SimpleNamespace( + run=_decode_kernels.single_decode_with_kv_cache, + ) + else: + _single_decode_modules[args] = compile_single_decode_module(*args) + return _single_decode_modules[args] + + +def get_batch_decode_module(*args): + global _batch_decode_modules + if args not in _batch_decode_modules: + if has_prebuilt_ops and get_batch_decode_uri(*args) in prebuilt_ops_uri: + from . 
import _decode_kernels + + # NOTE(Zihao): we should avoid hard-coded index like this, refactor it later + dtype_q = args[0] + dtype_kv = args[1] + head_dim = args[4] + use_logits_cap = args[7] + plan_func = lambda *plan_args: _decode_kernels.batch_decode_with_paged_kv_cache_plan( + use_logits_cap, + head_dim, + torch.empty(0, dtype=dtype_q), + torch.empty(0, dtype=dtype_kv), + *plan_args, + ) + run_func = _decode_kernels.batch_decode_with_paged_kv_cache_run + _batch_decode_modules[args] = SimpleNamespace( + plan=plan_func, + run=run_func, + ) + else: + _batch_decode_modules[args] = compile_batch_decode_module(*args) + return _batch_decode_modules[args] def single_decode_with_kv_cache( @@ -165,38 +223,52 @@ def single_decode_with_kv_cache( if rope_theta is None: rope_theta = 1e4 num_qo_heads = q.shape[0] - num_kv_heads = k.shape[1] if kv_layout == "NHD" else k.shape[0] - if not _grouped_size_compiled_for_decode_kernels(num_qo_heads, num_kv_heads): - raise RuntimeError( - "Please set `use_tensor_cores=True` in single_decode_with_kv_cache for group size {}.".format( - num_qo_heads // num_kv_heads - ) - ) if use_tensor_cores: - out = _prefill.single_prefill_with_kv_cache( - q.unsqueeze(0), - k, - v, - tmp, - False, # causal - TensorLayout[kv_layout].value, - PosEncodingMode[pos_encoding_mode].value, - False, # allow_fp16_qk_reduction - window_left, - logits_soft_cap, - sm_scale, - rope_scale, - rope_theta, - False, # return_lse - )[0].squeeze(0) + out = ( + get_single_prefill_module( + q.dtype, + k.dtype, + q.dtype, + head_dim, + MaskMode.NON_CAUSAL.value, + PosEncodingMode[pos_encoding_mode].value, + window_left != -1, # use_sliding_window + logits_soft_cap > 0, # use_logits_soft_cap + False, # allow_fp16_qk_reduction + ) + .run( + q.unsqueeze(0), + k, + v, + None, # packed_custom_mask + tmp, + _get_cache_alibi_slopes_buf(num_qo_heads, q.device), + TensorLayout[kv_layout].value, + window_left, + logits_soft_cap, + sm_scale, + rope_scale, + rope_theta, + False, # return_lse + )[0] + .squeeze(0) + ) else: - out = _decode.single_decode_with_kv_cache( + out = get_single_decode_module( + q.dtype, + k.dtype, + q.dtype, + head_dim, + PosEncodingMode[pos_encoding_mode].value, + window_left != -1, # use_sliding_window + logits_soft_cap > 0, # use_logits_soft_cap + ).run( q, k, v, tmp, - PosEncodingMode[pos_encoding_mode].value, + _get_cache_alibi_slopes_buf(num_qo_heads, q.device), TensorLayout[kv_layout].value, window_left, logits_soft_cap, @@ -204,6 +276,7 @@ def single_decode_with_kv_cache( rope_scale, rope_theta, ) + if v_scale is not None: out *= v_scale return out @@ -325,8 +398,12 @@ def __init__( _check_kv_layout(kv_layout) self._kv_layout = kv_layout self._float_workspace_buffer = float_workspace_buffer + self.device = float_workspace_buffer.device self._int_workspace_buffer = torch.empty( - (8 * 1024 * 1024,), dtype=torch.uint8, device=float_workspace_buffer.device + (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device + ) + self._pin_memory_int_workspace_buffer = torch.empty( + (8 * 1024 * 1024,), dtype=torch.uint8, pin_memory=True ) if use_cuda_graph: @@ -353,26 +430,16 @@ def __init__( self._paged_kv_indptr_buf = paged_kv_indptr_buffer self._paged_kv_indices_buf = paged_kv_indices_buffer self._paged_kv_last_page_len_buf = paged_kv_last_page_len_buffer + self._use_tensor_cores = use_tensor_cores + self._use_cuda_graph = use_cuda_graph if use_tensor_cores: - self._use_tensor_cores = True - self._wrapper = _prefill.BatchPrefillWithPagedKVCachePyTorchWrapper( - 
TensorLayout[kv_layout].value, - use_cuda_graph, - ) if use_cuda_graph: self._qo_indptr_buf = torch.arange( self._fixed_batch_size + 1, dtype=torch.int32, device=float_workspace_buffer.device, ) - else: - self._use_tensor_cores = False - self._wrapper = _decode.BatchDecodeWithPagedKVCachePyTorchWrapper( - TensorLayout[kv_layout].value, - use_cuda_graph, - self._fixed_batch_size, - ) @property def use_tensor_cores(self) -> bool: @@ -380,10 +447,10 @@ def use_tensor_cores(self) -> bool: @property def is_cuda_graph_enabled(self) -> bool: - return self._wrapper.is_cuda_graph_enabled() + return self._use_cuda_graph def reset_workspace_buffer( - self, float_workspace_buffer: torch.Tensor, int_workspace_buffer + self, float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor ) -> None: r"""Reset the workspace buffer. @@ -399,8 +466,10 @@ def reset_workspace_buffer( """ self._float_workspace_buffer = float_workspace_buffer self._int_workspace_buffer = int_workspace_buffer - self._wrapper.update_page_locked_buffer_size( - int_workspace_buffer.numel() * int_workspace_buffer.element_size() + self._pin_memory_int_workspace_buffer = torch.empty( + self._int_workspace_buffer.shape, + dtype=self._int_workspace_buffer.dtype, + pin_memory=True, ) def plan( @@ -473,6 +542,8 @@ def plan( if logits_soft_cap is None: logits_soft_cap = 0.0 + qo_indptr = _get_range_buf(batch_size + 1, indptr.device) + if self.is_cuda_graph_enabled: if batch_size != self._fixed_batch_size: raise ValueError( @@ -488,72 +559,66 @@ def plan( self._paged_kv_indptr_buf.copy_(indptr) self._paged_kv_indices_buf[: len(indices)] = indices self._paged_kv_last_page_len_buf.copy_(last_page_len) + if self.use_tensor_cores: + self._qo_indptr_buf.copy_(qo_indptr) else: - self._paged_kv_indptr_buf = indptr - self._paged_kv_indices_buf = indices - self._paged_kv_last_page_len_buf = last_page_len + self._paged_kv_indptr_buf = indptr.to(self.device) + self._paged_kv_indices_buf = indices.to(self.device) + self._paged_kv_last_page_len_buf = last_page_len.to(self.device) + if self.use_tensor_cores: + self._qo_indptr_buf = qo_indptr.to(self.device) - # NOTE(Zihao): the following tensors acts as placeholder to pass dtype info + data_type = canonicalize_torch_dtype(data_type) if not q_data_type: q_data_type = data_type - empty_q_data = torch.empty( - 0, - dtype=( - getattr(torch, q_data_type) - if isinstance(q_data_type, str) - else q_data_type - ), - ) - empty_kv_cache = torch.empty( - 0, - dtype=( - getattr(torch, data_type) if isinstance(data_type, str) else data_type - ), - ) - - if not _grouped_size_compiled_for_decode_kernels(num_qo_heads, num_kv_heads): - if not self.use_tensor_cores: - # NOTE(Zihao): group size not compiled for decode (cuda cores) kernels, user should use prefill (tensor cores) kernels instead - raise RuntimeError( - "Please set `use_tensor_cores=True` in BatchDecodeWithPagedKVCacheWrapper for group size {}.".format( - num_qo_heads // num_kv_heads - ) - ) + q_data_type = canonicalize_torch_dtype(q_data_type) if self.use_tensor_cores: - if not self.is_cuda_graph_enabled: - # when not using cudagraph, we need to create the indptr buffer, otherwise - # the buffer is already created during initialization - self._qo_indptr_buf = torch.arange( - batch_size + 1, dtype=torch.int32, device=indptr.device - ) - self._wrapper.plan( + self._cached_module = get_batch_prefill_module( + q_data_type, + data_type, + q_data_type, + indptr.dtype, + head_dim, + PosEncodingMode[pos_encoding_mode].value, + MaskMode.NON_CAUSAL.value, + 
window_left != -1, # use_sliding_window + logits_soft_cap > 0, # use_logits_soft_cap + False, # allow_fp16_qk_reduction + ) + self._plan_info = self._cached_module.plan( self._float_workspace_buffer, self._int_workspace_buffer, - self._qo_indptr_buf, + self._pin_memory_int_workspace_buffer, + qo_indptr, indptr, batch_size, num_qo_heads, num_kv_heads, - head_dim, page_size, - empty_q_data, + self.is_cuda_graph_enabled, ) else: - self._wrapper.plan( + self._cached_module = get_batch_decode_module( + q_data_type, + data_type, + q_data_type, + indptr.dtype, + head_dim, + PosEncodingMode[pos_encoding_mode].value, + window_left != -1, # use_sliding_window + logits_soft_cap > 0, # use_logits_soft_cap + ) + self._plan_info = self._cached_module.plan( self._float_workspace_buffer, self._int_workspace_buffer, + self._pin_memory_int_workspace_buffer, indptr, - last_page_len, batch_size, num_qo_heads, num_kv_heads, - head_dim, page_size, - PosEncodingMode[pos_encoding_mode].value, - logits_soft_cap, - empty_q_data, - empty_kv_cache, + self.is_cuda_graph_enabled, ) self._pos_encoding_mode = pos_encoding_mode @@ -597,7 +662,8 @@ def run( q_scale: Optional[float] = None, k_scale: Optional[float] = None, v_scale: Optional[float] = None, - ) -> torch.Tensor: + return_lse: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: r"""Compute batch decode attention between query and paged kv cache. Parameters @@ -624,11 +690,17 @@ def run( The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``. v_scale : Optional[float] The calibration scale of value for fp8 input, if not provided, will be set to ``1.0``. + return_lse : bool + Whether to return the logsumexp of attention scores, defaults to ``False``. Returns ------- - torch.Tensor - The attention output, shape: ``[batch_size, num_qo_heads, head_dim]``. + Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + If :attr:`return_lse` is ``False``, the attention output, shape: ``[batch_size, num_qo_heads, head_dim]``. + If :attr:`return_lse` is ``True``, a tuple of two tensors: + + * attention output, shape: ``[batch_size, num_qo_heads, head_dim]`` + * logsumexp of attention scores, shape: ``[batch_size, num_qo_heads]``. 
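+
+        Example
+        -------
+        A minimal illustrative sketch (not part of the original docs); it
+        assumes ``wrapper`` is this :class:`BatchDecodeWithPagedKVCacheWrapper`
+        and that :meth:`plan` has already been called with matching metadata:
+
+        >>> o = wrapper.run(q, paged_kv_cache)
+        >>> o, lse = wrapper.run(q, paged_kv_cache, return_lse=True)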
""" pos_encoding_mode = self._pos_encoding_mode window_left = self._window_left @@ -652,41 +724,50 @@ def run( rope_theta = 1e4 if self.use_tensor_cores: - out = self._wrapper.run( + out = self._cached_module.paged_run( + self._float_workspace_buffer, + self._int_workspace_buffer, + self._plan_info, q, - self._qo_indptr_buf, *_unpack_paged_kv_cache(paged_kv_cache, self._kv_layout), + None, # packed_custom_mask + _get_cache_alibi_slopes_buf(q.shape[1], q.device), + self._qo_indptr_buf, self._paged_kv_indptr_buf, self._paged_kv_indices_buf, self._paged_kv_last_page_len_buf, - False, # causal - PosEncodingMode[pos_encoding_mode].value, - False, # allow_fp16_qk_reduction + None, # qk_indptr_buf + TensorLayout[self._kv_layout].value, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta, - False, # return_lse - )[0] + return_lse, + ) else: - out = self._wrapper.run( + out = self._cached_module.run( + self._float_workspace_buffer, + self._int_workspace_buffer, + self._plan_info, q, *_unpack_paged_kv_cache(paged_kv_cache, self._kv_layout), self._paged_kv_indptr_buf, self._paged_kv_indices_buf, self._paged_kv_last_page_len_buf, - PosEncodingMode[pos_encoding_mode].value, + _get_cache_alibi_slopes_buf(q.shape[1], q.device), + TensorLayout[self._kv_layout].value, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta, - False, # return_lse - )[0] + return_lse, + ) if v_scale is not None: - out *= v_scale - return out + out[0] *= v_scale + + return out if return_lse else out[0] def forward_return_lse( self, @@ -713,110 +794,7 @@ def forward_return_lse( q, paged_kv_cache, q_scale=q_scale, k_scale=k_scale, v_scale=v_scale ) - def run_return_lse( - self, - q: torch.Tensor, - paged_kv_cache: torch.Tensor, - q_scale: Optional[float] = None, - k_scale: Optional[float] = None, - v_scale: Optional[float] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - r"""Compute batch decode attention with paged kv cache, return attention output - and logsumexp of attention scores. - - Parameters - ---------- - q : torch.Tensor - The query tensor, shape: ``[batch_size, num_qo_heads, head_dim]`` - paged_kv_cache : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] - The paged KV-Cache stored as a tuple of tensors or a single tensor: - - * a tuple ``(k_cache, v_cache)`` of 4-D tensors, each with shape: - ``[max_num_pages, page_size, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``, - and ``[max_num_pages, num_kv_heads, page_size, head_dim]`` if :attr:`kv_layout` is ``HND``. - - * a single 5-D tensor with shape: - ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if - :attr:`kv_layout` is ``NHD``, and - ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if - :attr:`kv_layout` is ``HND``. Where ``paged_kv_cache[:, 0]`` is the key-cache and - ``paged_kv_cache[:, 1]`` is the value-cache. - - q_scale : Optional[float] - The calibration scale of query for fp8 input, if not provided, will be set to ``1.0``. - k_scale : Optional[float] - The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``. - v_scale : Optional[float] - The calibration scale of value for fp8 input, if not provided, will be set to ``1.0``. - - Returns - ------- - V : torch.Tensor - The attention output, shape: ``[batch_size, num_qo_heads, head_dim]``. - S : torch.Tensor - The logsumexp of attention scores, Shape: ``[batch_size, num_qo_heads]``. - - Note - ---- - Please refer to the :ref:`tutorial ` for a detailed - explanation of the log-sum-exp function and attention states. 
- """ - pos_encoding_mode = self._pos_encoding_mode - window_left = self._window_left - logits_soft_cap = self._logits_soft_cap - sm_scale = self._sm_scale - rope_scale = self._rope_scale - rope_theta = self._rope_theta - _check_pos_encoding_mode(pos_encoding_mode) - if logits_soft_cap is None: - logits_soft_cap = 0.0 - if sm_scale is None: - head_dim = q.shape[-1] - sm_scale = 1.0 / math.sqrt(head_dim) - if q_scale is not None: - sm_scale *= q_scale - if k_scale is not None: - sm_scale *= k_scale - if rope_scale is None: - rope_scale = 1.0 - if rope_theta is None: - rope_theta = 1e4 - if self.use_tensor_cores: - V, s = self._wrapper.run( - q, - self._qo_indptr_buf, - *_unpack_paged_kv_cache(paged_kv_cache, self._kv_layout), - self._paged_kv_indptr_buf, - self._paged_kv_indices_buf, - self._paged_kv_last_page_len_buf, - False, # causal - PosEncodingMode[pos_encoding_mode].value, - False, # allow_fp16_qk_reduction - window_left, - logits_soft_cap, - sm_scale, - rope_scale, - rope_theta, - True, # return_lse - ) - else: - V, s = self._wrapper.run( - q, - *_unpack_paged_kv_cache(paged_kv_cache, self._kv_layout), - self._paged_kv_indptr_buf, - self._paged_kv_indices_buf, - self._paged_kv_last_page_len_buf, - PosEncodingMode[pos_encoding_mode].value, - window_left, - logits_soft_cap, - sm_scale, - rope_scale, - rope_theta, - True, # return_lse - ) - if v_scale is not None: - V *= v_scale - return V, s + run_return_lse = functools.partialmethod(run, return_lse=True) def end_forward(self) -> None: r"""Warning: this function is deprecated and has no effect.""" diff --git a/python/flashinfer/gemm.py b/python/flashinfer/gemm.py index 0bd2b373b..e7e7a0515 100644 --- a/python/flashinfer/gemm.py +++ b/python/flashinfer/gemm.py @@ -19,20 +19,30 @@ import torch from .utils import get_indptr +from .jit import load_cuda_ops, FLASHINFER_CSRC_DIR, has_prebuilt_ops from typing import Optional -# mypy: disable-error-code="attr-defined" -try: - from . import _kernels -except ImportError as e: - import logging - import os - if os.environ.get("BUILD_DOC", "0") == "1": - _kernels = None - logging.warning("Kernels are not loaded in documentation build mode.") - else: - raise e +_gemm_module = None + + +def get_gemm_module(): + global _gemm_module + if _gemm_module is None: + if has_prebuilt_ops: + from . import _kernels + + _gemm_module = _kernels + else: + _gemm_module = load_cuda_ops( + "gemm", + [ + FLASHINFER_CSRC_DIR / "group_gemm.cu", + FLASHINFER_CSRC_DIR / "bmm_fp8.cu", + FLASHINFER_CSRC_DIR / "flashinfer_gemm_ops.cu", + ], + ) + return _gemm_module class SegmentGEMMWrapper: @@ -96,9 +106,6 @@ def __init__(self, workspace_buffer: torch.Tensor) -> None: size is proportional to the number of segments (batch size), 1MB workspace is enough for most cases. """ self._workspace_buffer = workspace_buffer - self._wrapper = _kernels.CutlassSegmentGEMMPyTorchWrapper( - self._workspace_buffer - ) def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor) -> None: r"""Reset the workspace buffer. @@ -110,7 +117,6 @@ def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor) -> None: be the same as the device of the input tensors. 
""" self._workspace_buffer = new_workspace_buffer - self._wrapper.register_workspace_buffer(new_workspace_buffer) def run( self, @@ -187,7 +193,8 @@ def run( if weight_indices is None: # create an empty CPU tensor as placeholder weight_indices = torch.empty(0, dtype=torch.int64) - return self._wrapper.run( + return get_gemm_module().cutlass_segment_gemm( + self._workspace_buffer, seg_indptr, weight_indices, x, @@ -264,5 +271,5 @@ def bmm_fp8( device=A.device, dtype=dtype, ) - _kernels.bmm_fp8(A, B, out, A_scale, B_scale) + get_gemm_module().bmm_fp8(A, B, out, A_scale, B_scale) return out diff --git a/python/flashinfer/jit/__init__.py b/python/flashinfer/jit/__init__.py new file mode 100644 index 000000000..bbd85729c --- /dev/null +++ b/python/flashinfer/jit/__init__.py @@ -0,0 +1,151 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import os +import re +import logging +import torch.utils.cpp_extension as torch_cpp_ext +from typing import List +from .env import ( + FLASHINFER_WORKSPACE_DIR, + FLASHINFER_JIT_DIR, + FLASHINFER_GEN_SRC_DIR, + FLASHINFER_INCLUDE_DIR, + FLASHINFER_CSRC_DIR, + CUTLASS_INCLUDE_DIR, +) +from .activation import get_act_and_mul_cu_str, gen_act_and_mul_cu +from .attention import ( + gen_single_decode_cu, + get_single_decode_uri, + gen_batch_decode_cu, + get_batch_decode_uri, + gen_single_prefill_cu, + get_single_prefill_uri, + gen_batch_prefill_cu, + get_batch_prefill_uri, +) + +try: + from .aot_config import prebuilt_ops_uri + + has_prebuilt_ops = True +except ImportError as e: + prebuilt_ops_uri = set() + has_prebuilt_ops = False + +if not os.path.exists(FLASHINFER_WORKSPACE_DIR): + os.makedirs(FLASHINFER_WORKSPACE_DIR) + + +class FlashInferJITLogger(logging.Logger): + def __init__(self, name): + super().__init__(name) + self.setLevel(logging.INFO) + self.addHandler(logging.StreamHandler()) + log_path = FLASHINFER_WORKSPACE_DIR / "flashinfer_jit.log" + if not os.path.exists(log_path): + # create an empty file + with open(log_path, "w") as f: + pass + self.addHandler(logging.FileHandler(log_path)) + # set the format of the log + self.handlers[0].setFormatter( + logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") + ) + self.handlers[1].setFormatter( + logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") + ) + + def info(self, msg): + super().info("flashinfer.jit: " + msg) + + +logger = FlashInferJITLogger("flashinfer.jit") + + +def check_cuda_arch(): + # cuda arch check for fp8 at the moment. 
+ for cuda_arch_flags in torch_cpp_ext._get_cuda_arch_flags(): + arch = int(re.search("compute_\d+", cuda_arch_flags).group()[-2:]) + if arch < 75: + raise RuntimeError("FlashInfer requires sm75+") + + +def clear_cache_dir(): + if os.path.exists(FLASHINFER_JIT_DIR): + for file in os.listdir(FLASHINFER_JIT_DIR): + os.remove(os.path.join(FLASHINFER_JIT_DIR, file)) + + +def remove_unwanted_pytorch_nvcc_flags(): + REMOVE_NVCC_FLAGS = [ + "-D__CUDA_NO_HALF_OPERATORS__", + "-D__CUDA_NO_HALF_CONVERSIONS__", + "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-D__CUDA_NO_HALF2_OPERATORS__", + ] + for flag in REMOVE_NVCC_FLAGS: + try: + torch_cpp_ext.COMMON_NVCC_FLAGS.remove(flag) + except ValueError: + pass + + +remove_unwanted_pytorch_nvcc_flags() + + +def load_cuda_ops( + name: str, + sources: List[str], + extra_cflags: List[str] = ["-O3", "-Wno-switch-bool"], + extra_cuda_cflags: List[str] = [ + "-O3", + "-std=c++17", + "--threads", + "4", + # "-Xfatbin", + # "-compress-all", + "-use_fast_math", + "-DFLASHINFER_ENABLE_BF16", + "-DFLASHINFER_ENABLE_FP8", + ], + extra_ldflags=None, + extra_include_paths=None, + verbose=False, +): + logger.info(f"Loading JIT ops: {name}") + check_cuda_arch() + build_directory = FLASHINFER_JIT_DIR / name + if not os.path.exists(build_directory): + os.makedirs(build_directory) + if extra_include_paths is None: + extra_include_paths = [ + FLASHINFER_INCLUDE_DIR, + CUTLASS_INCLUDE_DIR, + FLASHINFER_CSRC_DIR, + ] + return torch_cpp_ext.load( + name, + list(map(lambda _: str(_), sources)), + extra_cflags=extra_cflags, + extra_cuda_cflags=extra_cuda_cflags, + extra_ldflags=extra_ldflags, + extra_include_paths=list(map(lambda _: str(_), extra_include_paths)), + build_directory=build_directory, + verbose=verbose, + with_cuda=True, + ) diff --git a/python/flashinfer/jit/activation.py b/python/flashinfer/jit/activation.py new file mode 100644 index 000000000..b07a964ea --- /dev/null +++ b/python/flashinfer/jit/activation.py @@ -0,0 +1,71 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import jinja2 +import os +from .env import FLASHINFER_GEN_SRC_DIR +from .utils import write_if_different + +activation_templ = r""" +#include +#include +#include +#include +#include "pytorch_extension_utils.h" + +{% set func_name = act_func_name ~ '_and_mul' %} + +using namespace flashinfer; + +{{ act_func_def }} + +void {{ func_name }}(torch::Tensor& out, torch::Tensor& input) { + int d = input.size(-1) / 2; + int64_t num_tokens = input.numel() / input.size(-1); + dim3 grid(num_tokens); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] { + uint32_t vec_size = 16 / sizeof(c_type); + dim3 block(std::min(d / vec_size, 1024U)); + flashinfer::activation::act_and_mul_kernel + <<>>(static_cast(out.data_ptr()), + static_cast(input.data_ptr()), d); + + return true; + }); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("{{ func_name }}", &{{ func_name }}, "Fused {{ act_func_name }} and Mul"); +} +""" + + +def get_act_and_mul_cu_str(act_func_name: str, act_func_def: str) -> str: + template = jinja2.Template(activation_templ) + return template.render(act_func_name=act_func_name, act_func_def=act_func_def) + + +def gen_act_and_mul_cu(act_func_name: str, act_func_def: str) -> None: + gen_directory = FLASHINFER_GEN_SRC_DIR + if not os.path.exists(gen_directory): + os.makedirs(gen_directory) + write_if_different( + gen_directory / f"{act_func_name}_and_mul.cu", + get_act_and_mul_cu_str(act_func_name, act_func_def), + ) diff --git a/python/flashinfer/jit/attention.py b/python/flashinfer/jit/attention.py new file mode 100644 index 000000000..6a5bb3941 --- /dev/null +++ b/python/flashinfer/jit/attention.py @@ -0,0 +1,263 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import torch +import jinja2 +import os +from .env import FLASHINFER_GEN_SRC_DIR +from .utils import ( + write_if_different, + dtype_map, + filename_safe_dtype_map, + pos_encoding_mode_literal, + mask_mode_literal, +) +from .single_decode_templ import single_decode_templ +from .batch_decode_templ import batch_decode_templ +from .single_prefill_templ import single_prefill_templ +from .batch_prefill_templ import batch_prefill_templ + + +def get_single_decode_cu_str( + dtype_q: torch.dtype, + dtype_kv: torch.dtype, + dtype_o: torch.dtype, + head_dim: int, + pos_encoding_mode: int, + use_sliding_window: bool, + use_logits_soft_cap: bool, +) -> str: + template = jinja2.Template(single_decode_templ) + return template.render( + dtype_q=dtype_map[dtype_q], + dtype_kv=dtype_map[dtype_kv], + dtype_o=dtype_map[dtype_o], + head_dim=head_dim, + pos_encoding_mode=pos_encoding_mode_literal[pos_encoding_mode], + use_sliding_window="true" if use_sliding_window else "false", + use_logits_soft_cap="true" if use_logits_soft_cap else "false", + ) + + +def get_single_decode_uri( + dtype_q: torch.dtype, + dtype_kv: torch.dtype, + dtype_o: torch.dtype, + head_dim: int, + pos_encoding_mode: int, + use_sliding_window: bool, + use_logits_soft_cap: bool, +) -> str: + return ( + f"single_decode_with_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_" + f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_" + f"dtype_o_{filename_safe_dtype_map[dtype_o]}_" + f"head_dim_{head_dim}_" + f"posenc_{pos_encoding_mode}_" + f"use_swa_{use_sliding_window}_" + f"use_logits_cap_{use_logits_soft_cap}" + ) + + +def gen_single_decode_cu(*args) -> None: + gen_directory = FLASHINFER_GEN_SRC_DIR + if not os.path.exists(gen_directory): + os.makedirs(gen_directory) + file_name = f"{get_single_decode_uri(*args)}.cu" + write_if_different( + gen_directory / file_name, + get_single_decode_cu_str(*args), + ) + + +def get_batch_decode_cu_str( + dtype_q: torch.dtype, + dtype_kv: torch.dtype, + dtype_o: torch.dtype, + dtype_idx: torch.dtype, + head_dim: int, + pos_encoding_mode: int, + use_sliding_window: bool, + use_logits_soft_cap: bool, +) -> str: + template = jinja2.Template(batch_decode_templ) + return template.render( + dtype_q=dtype_map[dtype_q], + dtype_kv=dtype_map[dtype_kv], + dtype_o=dtype_map[dtype_o], + dtype_idx=dtype_map[dtype_idx], + head_dim=head_dim, + pos_encoding_mode=pos_encoding_mode_literal[pos_encoding_mode], + use_sliding_window="true" if use_sliding_window else "false", + use_logits_soft_cap="true" if use_logits_soft_cap else "false", + ) + + +def get_batch_decode_uri( + dtype_q: torch.dtype, + dtype_kv: torch.dtype, + dtype_o: torch.dtype, + dtype_idx: torch.dtype, + head_dim: int, + pos_encoding_mode: int, + use_sliding_window: bool, + use_logits_soft_cap: bool, +) -> str: + return ( + f"batch_decode_with_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_" + f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_" + f"dtype_o_{filename_safe_dtype_map[dtype_o]}_" + f"dtype_idx_{filename_safe_dtype_map[dtype_idx]}_" + f"head_dim_{head_dim}_" + f"posenc_{pos_encoding_mode}_" + f"use_swa_{use_sliding_window}_" + f"use_logits_cap_{use_logits_soft_cap}" + ) + + +def gen_batch_decode_cu(*args) -> None: + gen_directory = FLASHINFER_GEN_SRC_DIR + if not os.path.exists(gen_directory): + os.makedirs(gen_directory) + file_name = f"{get_batch_decode_uri(*args)}.cu" + write_if_different( + gen_directory / file_name, + get_batch_decode_cu_str(*args), + ) + + +def get_single_prefill_cu_str( + dtype_q: torch.dtype, + dtype_kv: 
torch.dtype, + dtype_o: torch.dtype, + head_dim: int, + pos_encoding_mode: int, + mask_mode: int, + use_sliding_window: bool, + use_logits_soft_cap: bool, + use_fp16_qk_reduction: bool, +) -> str: + template = jinja2.Template(single_prefill_templ) + return template.render( + dtype_q=dtype_map[dtype_q], + dtype_kv=dtype_map[dtype_kv], + dtype_o=dtype_map[dtype_o], + head_dim=head_dim, + pos_encoding_mode=pos_encoding_mode_literal[pos_encoding_mode], + mask_mode=mask_mode_literal[mask_mode], + use_sliding_window="true" if use_sliding_window else "false", + use_logits_soft_cap="true" if use_logits_soft_cap else "false", + use_fp16_qk_reduction="true" if use_fp16_qk_reduction else "false", + ) + + +def get_single_prefill_uri( + dtype_q: torch.dtype, + dtype_kv: torch.dtype, + dtype_o: torch.dtype, + head_dim: int, + pos_encoding_mode: int, + mask_mode: int, + use_sliding_window: bool, + use_logits_soft_cap: bool, + use_fp16_qk_reduction: bool, +) -> str: + return ( + f"single_prefill_with_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_" + f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_" + f"dtype_o_{filename_safe_dtype_map[dtype_o]}_" + f"head_dim_{head_dim}_" + f"posenc_{pos_encoding_mode}_" + f"mask_{mask_mode}_" + f"use_swa_{use_sliding_window}_" + f"use_logits_cap_{use_logits_soft_cap}_" + f"f16qk_{use_fp16_qk_reduction}" + ) + + +def gen_single_prefill_cu(*args) -> None: + gen_directory = FLASHINFER_GEN_SRC_DIR + if not os.path.exists(gen_directory): + os.makedirs(gen_directory) + file_name = f"{get_single_prefill_uri(*args)}.cu" + write_if_different( + gen_directory / file_name, + get_single_prefill_cu_str(*args), + ) + + +def get_batch_prefill_cu_str( + dtype_q: torch.dtype, + dtype_kv: torch.dtype, + dtype_o: torch.dtype, + dtype_idx: torch.dtype, + head_dim: int, + pos_encoding_mode: int, + mask_mode: int, + use_sliding_window: bool, + use_logits_soft_cap: bool, + use_fp16_qk_reduction: bool, +) -> str: + template = jinja2.Template(batch_prefill_templ) + return template.render( + dtype_q=dtype_map[dtype_q], + dtype_kv=dtype_map[dtype_kv], + dtype_o=dtype_map[dtype_o], + dtype_idx=dtype_map[dtype_idx], + head_dim=head_dim, + pos_encoding_mode=pos_encoding_mode_literal[pos_encoding_mode], + mask_mode=mask_mode_literal[mask_mode], + use_sliding_window="true" if use_sliding_window else "false", + use_logits_soft_cap="true" if use_logits_soft_cap else "false", + use_fp16_qk_reduction="true" if use_fp16_qk_reduction else "false", + ) + + +def get_batch_prefill_uri( + dtype_q: torch.dtype, + dtype_kv: torch.dtype, + dtype_o: torch.dtype, + dtype_idx: torch.dtype, + head_dim: int, + pos_encoding_mode: int, + mask_mode: int, + use_sliding_window: bool, + use_logits_soft_cap: bool, + use_fp16_qk_reduction: bool, +) -> str: + return ( + f"batch_prefill_with_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_" + f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_" + f"dtype_o_{filename_safe_dtype_map[dtype_o]}_" + f"dtype_idx_{filename_safe_dtype_map[dtype_idx]}_" + f"head_dim_{head_dim}_" + f"posenc_{pos_encoding_mode}_" + f"mask_{mask_mode}_" + f"use_swa_{use_sliding_window}_" + f"use_logits_cap_{use_logits_soft_cap}_" + f"f16qk_{use_fp16_qk_reduction}" + ) + + +def gen_batch_prefill_cu(*args) -> None: + gen_directory = FLASHINFER_GEN_SRC_DIR + if not os.path.exists(gen_directory): + os.makedirs(gen_directory) + file_name = f"{get_batch_prefill_uri(*args)}.cu" + write_if_different( + gen_directory / file_name, + get_batch_prefill_cu_str(*args), + ) diff --git 
a/python/flashinfer/jit/batch_decode_templ.py b/python/flashinfer/jit/batch_decode_templ.py new file mode 100644 index 000000000..5ae587370 --- /dev/null +++ b/python/flashinfer/jit/batch_decode_templ.py @@ -0,0 +1,165 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +batch_decode_templ = r""" +#include +#include +#include +#include +#include +#include +#include "pytorch_extension_utils.h" + +using namespace flashinfer; + +{% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %} +using ParamsT = BatchDecodeParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ dtype_idx }}>; +using AttentionVariant = ComposedAttention; + +std::vector BatchDecodeWithPagedKVCachePlan( + torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, + torch::Tensor page_locked_int_workspace_buffer, + torch::Tensor indptr, + unsigned int batch_size, unsigned int num_qo_heads, + unsigned int num_kv_heads, unsigned int page_size, + bool enable_cuda_graph) { + size_t float_workspace_size_in_bytes = + float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); + size_t int_workspace_size_in_bytes = + int_workspace_buffer.size(0) * int_workspace_buffer.element_size(); + auto device = float_workspace_buffer.device(); + cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); + indptr = indptr.to(torch::kCPU); + + DecodePlanInfo plan_info; + + cudaError_t status = DecodePlan<{{ head_dim }}, {{ pos_encoding_mode }}, AttentionVariant>( + static_cast(float_workspace_buffer.data_ptr()), + float_workspace_size_in_bytes, + static_cast(int_workspace_buffer.data_ptr()), + static_cast(page_locked_int_workspace_buffer.data_ptr()), + int_workspace_size_in_bytes, + plan_info, + static_cast<{{ dtype_idx }}*>(indptr.data_ptr()), + batch_size, num_qo_heads, num_kv_heads, page_size, enable_cuda_graph, /*stream=*/torch_current_stream); + + TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ", + cudaGetErrorString(status)); + + return plan_info.ToVector(); +} + +std::vector BatchDecodeWithPagedKVCacheRun( + torch::Tensor float_workspace_buffer, + torch::Tensor int_workspace_buffer, + std::vector plan_info_vec, + torch::Tensor q, std::optional paged_kv_cache, + std::optional paged_k_cache, std::optional paged_v_cache, + torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, + torch::Tensor paged_kv_last_page_len, + std::optional alibi_slopes, + unsigned int kv_layout_code, int window_left, + float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse) { + DecodePlanInfo plan_info; + plan_info.FromVector(plan_info_vec); + QKVLayout kv_layout = static_cast(kv_layout_code); + bool paged_kv_defined = paged_kv_cache.has_value(); + auto device = q.device(); + int64_t batch_size = q.size(0); + int64_t num_qo_heads = q.size(1); + int64_t num_kv_heads, page_size; + if (paged_kv_defined) { + if (kv_layout == QKVLayout::kHND) { + num_kv_heads = 
paged_kv_cache->size(2); + page_size = paged_kv_cache->size(3); + } else { + page_size = paged_kv_cache->size(2); + num_kv_heads = paged_kv_cache->size(3); + } + } else { + if (kv_layout == QKVLayout::kHND) { + num_kv_heads = paged_k_cache->size(1); + page_size = paged_k_cache->size(2); + } else { + page_size = paged_k_cache->size(1); + num_kv_heads = paged_k_cache->size(2); + } + } + + cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); + torch::Tensor o = torch::empty_like(q); + torch::Tensor lse; + if (return_lse) { + lse = torch::empty({batch_size, num_qo_heads}, q.options().dtype((torch::kFloat32))); + } + + TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative"); + + void* float_buffer = static_cast(float_workspace_buffer.data_ptr()); + void* int_buffer = static_cast(int_workspace_buffer.data_ptr()); + + paged_kv_t<{{ dtype_kv }}, {{ dtype_idx }}> paged_kv( + num_kv_heads, page_size, {{ head_dim }}, + batch_size, kv_layout, + static_cast<{{ dtype_kv }}*>(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr() + : nullptr), + static_cast<{{ dtype_kv }} *>(paged_k_cache.has_value() ? paged_k_cache->data_ptr() + : nullptr), + static_cast<{{ dtype_kv }}*>(paged_v_cache.has_value() ? paged_v_cache->data_ptr() + : nullptr), + static_cast<{{ dtype_idx }}*>(paged_kv_indices.data_ptr()), + static_cast<{{ dtype_idx }}*>(paged_kv_indptr.data_ptr()), + static_cast<{{ dtype_idx }}*>(paged_kv_last_page_len.data_ptr())); + ParamsT params( + static_cast<{{ dtype_q }}*>(q.data_ptr()), + /*q_offset=*/nullptr, paged_kv, static_cast<{{ dtype_o }}*>(o.data_ptr()), + /*lse=*/(return_lse ? static_cast(lse.data_ptr()) : nullptr), + {% if use_alibi == "true" %}static_cast(alibi_slopes->data_ptr()){% else %}nullptr{% endif %}, + num_qo_heads, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta); + + {{ dtype_o }}* tmp_v = nullptr; + float* tmp_s = nullptr; + params.request_indices = GetPtrFromBaseOffset<{{ dtype_idx }}>(int_buffer, plan_info.request_indices_offset); + params.kv_tile_indices = GetPtrFromBaseOffset<{{ dtype_idx }}>(int_buffer, plan_info.kv_tile_indices_offset); + params.o_indptr = GetPtrFromBaseOffset<{{ dtype_idx }}>(int_buffer, plan_info.o_indptr_offset); + params.kv_chunk_size_ptr = GetPtrFromBaseOffset<{{ dtype_idx }}>(int_buffer, plan_info.kv_chunk_size_ptr_offset); + if (plan_info.split_kv) { + tmp_v = GetPtrFromBaseOffset<{{ dtype_o }}>(float_buffer, plan_info.v_offset); + tmp_s = GetPtrFromBaseOffset(float_buffer, plan_info.s_offset); + if (plan_info.enable_cuda_graph) { + params.block_valid_mask = GetPtrFromBaseOffset(int_buffer, plan_info.block_valid_mask_offset); + } + } + params.padded_batch_size = plan_info.padded_batch_size; + + cudaError_t status = BatchDecodeWithPagedKVCacheDispatched< + {{ head_dim }}, {{ pos_encoding_mode }}, AttentionVariant>( + params, tmp_v, tmp_s, /*stream=*/torch_current_stream); + TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ", + cudaGetErrorString(status)); + + if (return_lse) { + return {o, lse}; + } else { + return {o}; + } +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("plan", &BatchDecodeWithPagedKVCachePlan); + m.def("run", &BatchDecodeWithPagedKVCacheRun); +} +""" diff --git a/python/flashinfer/jit/batch_prefill_templ.py b/python/flashinfer/jit/batch_prefill_templ.py new file mode 100644 index 000000000..a6b0c0ac6 --- /dev/null +++ b/python/flashinfer/jit/batch_prefill_templ.py @@ -0,0 +1,278 @@ +""" +Copyright (c) 2024 by FlashInfer 
team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +batch_prefill_templ = r""" +#include +#include +#include +#include +#include +#include +#include "pytorch_extension_utils.h" + +using namespace flashinfer; + +{% set use_custom_mask = "true" if mask_mode == "MaskMode::kCustom" else "false" %} +{% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %} +using RaggedParamsT = BatchPrefillRaggedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ dtype_idx }}>; +using RaggedAttentionVariant = ComposedAttention; +using PagedParamsT = BatchPrefillPagedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ dtype_idx }}>; +using PagedAttentionVariant = ComposedAttention; + +std::vector BatchPrefillWithKVCachePlan( + torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, + torch::Tensor page_locked_int_workspace_buffer, + torch::Tensor qo_indptr, + torch::Tensor kv_indptr, + unsigned int batch_size, + unsigned int num_qo_heads, + unsigned int num_kv_heads, + unsigned int page_size, + bool enable_cuda_graph) { + size_t float_workspace_size_in_bytes = + float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); + size_t int_workspace_size_in_bytes = + int_workspace_buffer.size(0) * int_workspace_buffer.element_size(); + + auto device = float_workspace_buffer.device(); + cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); + qo_indptr = qo_indptr.to(torch::kCPU); + kv_indptr = kv_indptr.to(torch::kCPU); + + PrefillPlanInfo plan_info; + + cudaError_t status = PrefillPlan<{{ dtype_idx }}>( + float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes, + int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(), + int_workspace_size_in_bytes, + plan_info, qo_indptr.data_ptr<{{ dtype_idx }}>(), kv_indptr.data_ptr<{{ dtype_idx }}>(), + batch_size, num_qo_heads, num_kv_heads, {{ head_dim }}, page_size, enable_cuda_graph, + sizeof({{ dtype_o }}), torch_current_stream); + + TORCH_CHECK(status == cudaSuccess, "Failed to plan prefill with error: ", cudaGetErrorString(status)); + + return plan_info.ToVector(); +} + +std::vector BatchPrefillWithRaggedKVCacheRun( + torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, + std::vector plan_info_vec, + torch::Tensor q, torch::Tensor k, torch::Tensor v, + std::optional maybe_custom_mask, + std::optional maybe_alibi_slopes, + torch::Tensor qo_indptr, torch::Tensor kv_indptr, + std::optional maybe_qk_indptr, + unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta, bool return_lse) { + PrefillPlanInfo plan_info; + plan_info.FromVector(plan_info_vec); + QKVLayout kv_layout = static_cast(layout); + + int64_t num_qo_heads = q.size(1); + int64_t head_dim = q.size(2); + int64_t num_kv_heads = (kv_layout == QKVLayout::kNHD) ? 
k.size(1) : k.size(0); + uint32_t q_stride_n = q.stride(0), q_stride_h = q.stride(1), kv_stride_n, kv_stride_h; + if (kv_layout == QKVLayout::kNHD) { + kv_stride_n = k.stride(0); + kv_stride_h = k.stride(1); + } else { + kv_stride_h = k.stride(0); + kv_stride_n = k.stride(1); + } + + auto device = float_workspace_buffer.device(); + cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); + auto o = torch::empty_like(q, q.options()); + int64_t nnz_qo = q.size(0); + torch::Tensor lse = torch::empty({0}); + if (return_lse) { + lse = torch::empty({nnz_qo, num_qo_heads}, q.options().dtype(torch::kFloat32)); + } + + void* float_buffer_ptr = float_workspace_buffer.data_ptr(); + void* int_buffer_ptr = int_workspace_buffer.data_ptr(); + + RaggedParamsT params( + static_cast<{{ dtype_q }}*>(q.data_ptr()), static_cast<{{ dtype_kv }}*>(k.data_ptr()), + static_cast<{{ dtype_kv }}*>(v.data_ptr()), + {% if mask_mode == "MaskMode::kCustom" %}static_cast(maybe_custom_mask->data_ptr()){% else %}nullptr{% endif %}, + static_cast<{{ dtype_idx }}*>(qo_indptr.data_ptr()), + static_cast<{{ dtype_idx }}*>(kv_indptr.data_ptr()), + {% if mask_mode == "MaskMode::kCustom" %}static_cast<{{ dtype_idx }}*>(maybe_qk_indptr->data_ptr()){% else %}nullptr{% endif %}, + /*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr, + static_cast<{{ dtype_o }}*>(o.data_ptr()), + /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, + {% if use_alibi == "true" %}static_cast(maybe_alibi_slopes->data_ptr()){% else %}nullptr{% endif %}, + num_qo_heads, num_kv_heads, q_stride_n, q_stride_h, kv_stride_n, kv_stride_h, + window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta); + + {{ dtype_o }}* tmp_v = nullptr; + float* tmp_s = nullptr; + + params.request_indices = GetPtrFromBaseOffset<{{ dtype_idx }}>(int_buffer_ptr, plan_info.request_indices_offset); + params.qo_tile_indices = GetPtrFromBaseOffset<{{ dtype_idx }}>(int_buffer_ptr, plan_info.qo_tile_indices_offset); + params.kv_tile_indices = GetPtrFromBaseOffset<{{ dtype_idx }}>(int_buffer_ptr, plan_info.kv_tile_indices_offset); + params.o_indptr = GetPtrFromBaseOffset<{{ dtype_idx }}>(int_buffer_ptr, plan_info.o_indptr_offset); + params.kv_chunk_size_ptr = GetPtrFromBaseOffset<{{ dtype_idx }}>(int_buffer_ptr, plan_info.kv_chunk_size_ptr_offset); + if (plan_info.split_kv) { + params.merge_indptr = GetPtrFromBaseOffset<{{ dtype_idx }}>(int_buffer_ptr, plan_info.merge_indptr_offset); + tmp_v = GetPtrFromBaseOffset<{{ dtype_o }}>(float_buffer_ptr, plan_info.v_offset); + tmp_s = GetPtrFromBaseOffset(float_buffer_ptr, plan_info.s_offset); + if (plan_info.enable_cuda_graph) { + params.block_valid_mask = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.block_valid_mask_offset); + } + } + params.total_num_rows = plan_info.total_num_rows; + params.padded_batch_size = plan_info.padded_batch_size; + + WarpLayout warp_layout = WarpLayout(plan_info.warp_layout_code); + cudaError_t status = cudaSuccess; + + DISPATCH_WARP_LAYOUT(warp_layout, WARP_LAYOUT, { + status = BatchPrefillWithRaggedKVCacheDispatched< + WARP_LAYOUT, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, {{ mask_mode }}, RaggedAttentionVariant>( + params, tmp_v, tmp_s, torch_current_stream); + }); + + TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithRaggedKVCache failed with error ", cudaGetErrorString(status)); + + if (return_lse) { + return {o, lse}; + } else { + return {o}; + } +} + +std::vector BatchPrefillWithPagedKVCacheRun( + torch::Tensor float_workspace_buffer, 
torch::Tensor int_workspace_buffer, + std::vector plan_info_vec, + torch::Tensor q, + std::optional paged_kv_cache, + std::optional paged_k_cache, + std::optional paged_v_cache, + std::optional maybe_custom_mask, + std::optional maybe_alibi_slopes, + torch::Tensor qo_indptr, + torch::Tensor paged_kv_indptr, + torch::Tensor paged_kv_indices, + torch::Tensor paged_kv_last_page_len, + std::optional maybe_qk_indptr, + unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta, bool return_lse) { + PrefillPlanInfo plan_info; + plan_info.FromVector(plan_info_vec); + QKVLayout kv_layout = static_cast(layout); + bool paged_kv_defined = paged_kv_cache.has_value(); + auto device = q.device(); + int64_t batch_size = paged_kv_indptr.size(0) - 1; + int64_t num_qo_heads = q.size(1); + int64_t num_kv_heads, page_size; + if (paged_kv_defined) { + if (kv_layout == QKVLayout::kHND) { + num_kv_heads = paged_kv_cache->size(2); + page_size = paged_kv_cache->size(3); + } else { + page_size = paged_kv_cache->size(2); + num_kv_heads = paged_kv_cache->size(3); + } + } else { + if (kv_layout == QKVLayout::kHND) { + num_kv_heads = paged_k_cache->size(1); + page_size = paged_k_cache->size(2); + } else { + page_size = paged_k_cache->size(1); + num_kv_heads = paged_k_cache->size(2); + } + } + + cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); + auto o = torch::empty_like(q, q.options()); + int64_t nnz_qo = q.size(0); + torch::Tensor lse = torch::empty({0}); + if (return_lse) { + lse = torch::empty({nnz_qo, num_qo_heads}, q.options().dtype(torch::kFloat32)); + } + + void* float_buffer_ptr = static_cast(float_workspace_buffer.data_ptr()); + void* int_buffer_ptr = static_cast(int_workspace_buffer.data_ptr()); + + paged_kv_t<{{ dtype_kv }}, {{ dtype_idx }}> paged_kv( + num_kv_heads, page_size, {{ head_dim }}, + batch_size, kv_layout, + static_cast<{{ dtype_kv }}*>(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr() + : nullptr), + static_cast<{{ dtype_kv }} *>(paged_k_cache.has_value() ? paged_k_cache->data_ptr() + : nullptr), + static_cast<{{ dtype_kv }}*>(paged_v_cache.has_value() ? paged_v_cache->data_ptr() + : nullptr), + static_cast<{{ dtype_idx }}*>(paged_kv_indices.data_ptr()), + static_cast<{{ dtype_idx }}*>(paged_kv_indptr.data_ptr()), + static_cast<{{ dtype_idx }}*>(paged_kv_last_page_len.data_ptr())); + + PagedParamsT params( + static_cast<{{ dtype_q }}*>(q.data_ptr()), paged_kv, + {% if mask_mode == "MaskMode::kCustom" %}static_cast(maybe_custom_mask->data_ptr()){% else %}nullptr{% endif %}, + static_cast<{{ dtype_idx }}*>(qo_indptr.data_ptr()), + {% if mask_mode == "MaskMode::kCustom" %}static_cast<{{ dtype_idx }}*>(maybe_qk_indptr->data_ptr()){% else %}nullptr{% endif %}, + /*q_offset=*/nullptr, + static_cast<{{ dtype_o }}*>(o.data_ptr()), + /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, + {% if use_alibi == "true" %}static_cast(maybe_alibi_slopes->data_ptr()){% else %}nullptr{% endif %}, + num_qo_heads, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta); + + {{ dtype_o }}* tmp_v = nullptr; + float* tmp_s = nullptr; + + params.request_indices = GetPtrFromBaseOffset<{{ dtype_idx }}>(int_buffer_ptr, plan_info.request_indices_offset); + params.qo_tile_indices = GetPtrFromBaseOffset<{{ dtype_idx }}>(int_buffer_ptr, plan_info.qo_tile_indices_offset); + params.kv_tile_indices = GetPtrFromBaseOffset<{{ dtype_idx }}>(int_buffer_ptr, plan_info.kv_tile_indices_offset); + params.o_indptr = GetPtrFromBaseOffset<{{ dtype_idx }}>(int_buffer_ptr, plan_info.o_indptr_offset); + params.kv_chunk_size_ptr = GetPtrFromBaseOffset<{{ dtype_idx }}>(int_buffer_ptr, plan_info.kv_chunk_size_ptr_offset); + if (plan_info.split_kv) { + params.merge_indptr = GetPtrFromBaseOffset<{{ dtype_idx }}>(int_buffer_ptr, plan_info.merge_indptr_offset); + tmp_v = GetPtrFromBaseOffset<{{ dtype_o }}>(float_buffer_ptr, plan_info.v_offset); + tmp_s = GetPtrFromBaseOffset(float_buffer_ptr, plan_info.s_offset); + if (plan_info.enable_cuda_graph) { + params.block_valid_mask = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.block_valid_mask_offset); + } + } + params.total_num_rows = plan_info.total_num_rows; + params.padded_batch_size = plan_info.padded_batch_size; + + WarpLayout warp_layout = WarpLayout(plan_info.warp_layout_code); + cudaError_t status = cudaSuccess; + + DISPATCH_WARP_LAYOUT(warp_layout, WARP_LAYOUT, { + status = BatchPrefillWithPagedKVCacheDispatched< + WARP_LAYOUT, {{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, {{ mask_mode }}, PagedAttentionVariant>( + params, tmp_v, tmp_s, torch_current_stream); + }); + + TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ", cudaGetErrorString(status)); + + if (return_lse) { + return {o, lse}; + } else { + return {o}; + } +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("plan", &BatchPrefillWithKVCachePlan); + m.def("ragged_run", &BatchPrefillWithRaggedKVCacheRun); + m.def("paged_run", &BatchPrefillWithPagedKVCacheRun); +} +""" diff --git a/python/flashinfer/jit/env.py b/python/flashinfer/jit/env.py new file mode 100644 index 000000000..65b47eed0 --- /dev/null +++ b/python/flashinfer/jit/env.py @@ -0,0 +1,26 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import pathlib + +# use pathlib +FLASHINFER_WORKSPACE_DIR = pathlib.Path.home() / ".flashinfer" +FLASHINFER_JIT_DIR = FLASHINFER_WORKSPACE_DIR / "cached_ops" +FLASHINFER_GEN_SRC_DIR = FLASHINFER_WORKSPACE_DIR / "generated" +_project_root = pathlib.Path(__file__).resolve().parent.parent.parent +FLASHINFER_INCLUDE_DIR = _project_root / "include" +FLASHINFER_CSRC_DIR = _project_root / "csrc" +CUTLASS_INCLUDE_DIR = _project_root / "3rdparty" / "cutlass" / "include" diff --git a/python/flashinfer/jit/single_decode_templ.py b/python/flashinfer/jit/single_decode_templ.py new file mode 100644 index 000000000..79619be32 --- /dev/null +++ b/python/flashinfer/jit/single_decode_templ.py @@ -0,0 +1,68 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +single_decode_templ = r""" +#include +#include +#include +#include +#include +#include "pytorch_extension_utils.h" + +{% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %} +torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torch::Tensor v, + torch::Tensor tmp, std::optional alibi_slopes, + unsigned int layout, int window_left, + float logits_soft_cap, float sm_scale, float rope_scale, + float rope_theta) { + auto device = q.device(); + unsigned int num_qo_heads = q.size(0); + unsigned int head_dim = q.size(1); + unsigned int kv_len, num_kv_heads; + QKVLayout kv_layout = static_cast(layout); + if (kv_layout == QKVLayout::kNHD) { + kv_len = k.size(0); + num_kv_heads = k.size(1); + } else { + num_kv_heads = k.size(0); + kv_len = k.size(1); + } + cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); + auto o = torch::empty_like(q); + + using ParamsT = SingleDecodeParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}>; + using AttentionVariant = ComposedAttention; + ParamsT params( + static_cast<{{ dtype_q }}*>(q.data_ptr()), static_cast<{{ dtype_kv }}*>(k.data_ptr()), + static_cast<{{ dtype_kv }}*>(v.data_ptr()), static_cast<{{ dtype_o }}*>(o.data_ptr()), + {% if use_alibi == "true" %}static_cast(alibi_slopes->data_ptr()){% else %}nullptr{% endif %}, + kv_len, num_qo_heads, num_kv_heads, kv_layout, head_dim, window_left, + logits_soft_cap, sm_scale, rope_scale, rope_theta); + + cudaError_t status = SingleDecodeWithKVCacheDispatched<{{ head_dim }}, {{ pos_encoding_mode }}, AttentionVariant>( + params, static_cast<{{ dtype_o }}*>(tmp.data_ptr()), torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "SingleDecodeWithKVCache kernel launch failed, error: " + + std::string(cudaGetErrorString(status))); + + return o; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("run", &single_decode_with_kv_cache, + "Single-request decode with KV-Cache operator"); +} +""" diff --git a/python/flashinfer/jit/single_prefill_templ.py b/python/flashinfer/jit/single_prefill_templ.py new file mode 100644 index 000000000..297395168 --- /dev/null +++ b/python/flashinfer/jit/single_prefill_templ.py @@ -0,0 +1,90 @@ +""" +Copyright (c) 2024 by 
FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +single_prefill_templ = r""" +#include +#include +#include +#include +#include +#include "pytorch_extension_utils.h" + +using namespace flashinfer; + +{% set use_custom_mask = "true" if mask_mode == "MaskMode::kCustom" else "false" %} +{% set use_alibi = "true" if pos_encoding_mode == "PosEncodingMode::kALiBi" else "false" %} +using ParamsT = SinglePrefillParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}>; +using AttentionVariant = ComposedAttention; + +std::vector single_prefill_with_kv_cache( + torch::Tensor q, torch::Tensor k, torch::Tensor v, std::optional maybe_packed_custom_mask, + torch::Tensor tmp, std::optional maybe_alibi_slopes, unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta, bool return_lse) { + auto device = q.device(); + unsigned int head_dim = q.size(2); + unsigned int kv_len, qo_len, num_kv_heads, num_qo_heads; + QKVLayout kv_layout = static_cast(layout); + qo_len = q.size(0); + num_qo_heads = q.size(1); + uint32_t q_stride_n = q.stride(0), q_stride_h = q.stride(1), kv_stride_n, kv_stride_h; + if (kv_layout == QKVLayout::kNHD) { + kv_len = k.size(0); + num_kv_heads = k.size(1); + kv_stride_n = k.stride(0); + kv_stride_h = k.stride(1); + } else { + kv_len = k.size(1); + num_kv_heads = k.size(0); + kv_stride_h = k.stride(0); + kv_stride_n = k.stride(1); + } + cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); + auto o = torch::empty_like(q, q.options()); + torch::Tensor lse = torch::empty({0}); + if (return_lse) { + lse = torch::empty({qo_len, num_qo_heads}, q.options().dtype(torch::kFloat32)); + } + + ParamsT params( + static_cast<{{ dtype_q }}*>(q.data_ptr()), static_cast<{{ dtype_kv }}*>(k.data_ptr()), + static_cast<{{ dtype_kv }}*>(v.data_ptr()), + {% if mask_mode == "MaskMode::kCustom" %}static_cast(maybe_packed_custom_mask->data_ptr()){% else %}nullptr{% endif %}, + static_cast<{{ dtype_o }}*>(o.data_ptr()), + /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, + {% if use_alibi == "true" %}static_cast(maybe_alibi_slopes->data_ptr()){% else %}nullptr{% endif %}, + num_qo_heads, num_kv_heads, qo_len, kv_len, q_stride_n, q_stride_h, + kv_stride_n, kv_stride_h, head_dim, window_left, logits_soft_cap, sm_scale, + rope_scale, rope_theta); + + cudaError_t status = + SinglePrefillWithKVCacheDispatched<{{ head_dim }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, {{ mask_mode }}, AttentionVariant>( + params, static_cast<{{ dtype_o }}*>(tmp.data_ptr()), torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "SinglePrefillWithKVCache kernel launch failed, error: " + + std::string(cudaGetErrorString(status))); + + if (return_lse) { + return {o, lse}; + } else { + return {o}; + } +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("run", &single_prefill_with_kv_cache, + "Single-request prefill attention with KV-Cache operator"); +} +""" diff --git a/python/flashinfer/jit/utils.py b/python/flashinfer/jit/utils.py new file mode 100644 index 000000000..f8671702d --- /dev/null +++ b/python/flashinfer/jit/utils.py @@ -0,0 +1,66 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import pathlib +import torch + + +def write_if_different(path: pathlib.Path, content: str) -> None: + if path.exists(): + with open(path, "r") as f: + if f.read() == content: + return + with open(path, "w") as f: + f.write(content) + + +dtype_map = { + torch.float16: "half", + torch.bfloat16: "nv_bfloat16", + torch.float8_e4m3fn: "__nv_fp8_e4m3", + torch.float8_e5m2: "__nv_fp8_e5m2", + torch.int8: "int8_t", + torch.uint8: "uint8_t", + torch.int32: "int32_t", + torch.uint32: "uint32_t", + torch.int64: "int64_t", + torch.uint64: "uint64_t", +} + +filename_safe_dtype_map = { + torch.float16: "f16", + torch.bfloat16: "bf16", + torch.float8_e4m3fn: "e4m3", + torch.float8_e5m2: "e5m2", + torch.int8: "i8", + torch.uint8: "u8", + torch.int32: "i32", + torch.uint32: "u32", + torch.int64: "i64", + torch.uint64: "u64", +} + +pos_encoding_mode_literal = { + 0: "PosEncodingMode::kNone", + 1: "PosEncodingMode::kRoPELlama", + 2: "PosEncodingMode::kALiBi", +} + +mask_mode_literal = { + 0: "MaskMode::kNone", + 1: "MaskMode::kCausal", + 2: "MaskMode::kCustom", +} diff --git a/python/flashinfer/norm.py b/python/flashinfer/norm.py index 6a504e030..742e44e51 100644 --- a/python/flashinfer/norm.py +++ b/python/flashinfer/norm.py @@ -16,18 +16,27 @@ import torch -# mypy: disable-error-code="attr-defined" -try: - from . import _kernels -except ImportError as e: - import logging - import os +from .jit import load_cuda_ops, FLASHINFER_CSRC_DIR, has_prebuilt_ops - if os.environ.get("BUILD_DOC", "0") == "1": - _kernels = None - logging.warning("Kernels are not loaded in documentation build mode.") - else: - raise e +_norm_module = None + + +def get_norm_module(): + global _norm_module + if _norm_module is None: + if has_prebuilt_ops: + from . 
import _kernels + + _norm_module = _kernels + else: + _norm_module = load_cuda_ops( + "norm", + [ + FLASHINFER_CSRC_DIR / "norm.cu", + FLASHINFER_CSRC_DIR / "flashinfer_norm_ops.cu", + ], + ) + return _norm_module def rmsnorm( @@ -56,7 +65,7 @@ def rmsnorm( """ if out is None: out = torch.empty_like(input) - _kernels.rmsnorm(out, input, weight, eps) + get_norm_module().rmsnorm(out, input, weight, eps) return out @@ -76,7 +85,7 @@ def fused_add_rmsnorm( eps: float Epsilon for numerical stability. """ - _kernels.fused_add_rmsnorm(input, residual, weight, eps) + get_norm_module().fused_add_rmsnorm(input, residual, weight, eps) def gemma_rmsnorm( @@ -105,7 +114,7 @@ def gemma_rmsnorm( """ if out is None: out = torch.empty_like(input) - _kernels.gemma_rmsnorm(out, input, weight, eps) + get_norm_module().gemma_rmsnorm(out, input, weight, eps) return out @@ -125,4 +134,4 @@ def gemma_fused_add_rmsnorm( eps: float Epsilon for numerical stability. """ - _kernels.gemma_fused_add_rmsnorm(input, residual, weight, eps) + get_norm_module().gemma_fused_add_rmsnorm(input, residual, weight, eps) diff --git a/python/flashinfer/page.py b/python/flashinfer/page.py index 101e18ad0..0a649f8b4 100644 --- a/python/flashinfer/page.py +++ b/python/flashinfer/page.py @@ -15,21 +15,29 @@ """ import torch +from .utils import TensorLayout, _check_kv_layout, _unpack_paged_kv_cache +from .jit import load_cuda_ops, FLASHINFER_CSRC_DIR, has_prebuilt_ops -# mypy: disable-error-code="attr-defined" -try: - from . import _kernels -except ImportError as e: - import os - import logging - if os.environ.get("BUILD_DOC", "0") == "1": - _kernels = None - logging.warning("Kernels are not loaded in documentation build mode.") - else: - raise e +_page_module = None -from .utils import _check_kv_layout, TensorLayout, _unpack_paged_kv_cache + +def get_page_module(): + global _page_module + if _page_module is None: + if has_prebuilt_ops: + from . import _kernels + + _page_module = _kernels + else: + _page_module = load_cuda_ops( + "page", + [ + FLASHINFER_CSRC_DIR / "page.cu", + FLASHINFER_CSRC_DIR / "flashinfer_page_ops.cu", + ], + ) + return _page_module def append_paged_kv_cache( @@ -127,7 +135,7 @@ def append_paged_kv_cache( incorporated appended k/v. """ _check_kv_layout(kv_layout) - _kernels.append_paged_kv_cache( + get_page_module().append_paged_kv_cache( append_key, append_value, append_indptr, diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py index 180621fdb..e9d49cd20 100644 --- a/python/flashinfer/prefill.py +++ b/python/flashinfer/prefill.py @@ -16,43 +16,117 @@ import math from typing import Optional, Dict, Tuple, Union +from types import SimpleNamespace +import functools import torch import logging -# mypy: disable-error-code="attr-defined" -try: - from . 
import _prefill -except ImportError as e: - import os - import logging - - if os.environ.get("BUILD_DOC", "0") == "1": - _prefill = None - logging.warning("Kernels are not loaded in documentation build mode.") - else: - raise e - +from .jit import ( + load_cuda_ops, + FLASHINFER_GEN_SRC_DIR, + gen_single_prefill_cu, + get_single_prefill_uri, + gen_batch_prefill_cu, + get_batch_prefill_uri, + has_prebuilt_ops, + prebuilt_ops_uri, +) from .utils import ( PosEncodingMode, + MaskMode, TensorLayout, _check_pos_encoding_mode, _check_kv_layout, _unpack_paged_kv_cache, is_float8, + canonicalize_torch_dtype, + _get_cache_buf, + _get_cache_alibi_slopes_buf, ) from .quantization import packbits, segment_packbits -_cache_buf: Dict[Tuple[str, torch.device], torch.Tensor] = {} +if has_prebuilt_ops: + from . import _prefill_kernels + + +def compile_single_prefill_module( + *args, + verbose: bool = False, +): + gen_single_prefill_cu(*args) + uri = get_single_prefill_uri(*args) + return load_cuda_ops( + uri, + [FLASHINFER_GEN_SRC_DIR / f"{uri}.cu"], + verbose=verbose, + ) + + +def compile_batch_prefill_module( + *args, + verbose: bool = False, +): + gen_batch_prefill_cu(*args) + uri = get_batch_prefill_uri(*args) + return load_cuda_ops( + uri, + [FLASHINFER_GEN_SRC_DIR / f"{uri}.cu"], + verbose=verbose, + ) -def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor: - key = (name, device) - buf = _cache_buf.get(key) - if buf is None: - buf = torch.empty(bytes, dtype=torch.uint8, device=device) - _cache_buf[key] = buf - return buf +_single_prefill_modules = {} +_batch_prefill_modules = {} + + +def get_single_prefill_module(*args): + global _single_prefill_modules + if args not in _single_prefill_modules: + if has_prebuilt_ops and get_single_prefill_uri(*args) in prebuilt_ops_uri: + # NOTE(Zihao): we should avoid hard-coded index like this, refactor it later + mask_mode = args[5] + run_func = lambda *run_args: _prefill_kernels.single_prefill_with_kv_cache( + mask_mode, + *run_args, + ) + _single_prefill_modules[args] = SimpleNamespace( + run=run_func, + ) + else: + _single_prefill_modules[args] = compile_single_prefill_module(*args) + return _single_prefill_modules[args] + + +def get_batch_prefill_module(*args): + global _batch_prefill_modules + if args not in _batch_prefill_modules: + if has_prebuilt_ops and get_batch_prefill_uri(*args) in prebuilt_ops_uri: + # NOTE(Zihao): we should avoid hard-coded index like this, refactor it later + head_dim = args[4] + plan_func = ( + lambda *plan_args: _prefill_kernels.batch_prefill_with_kv_cache_plan( + head_dim, + *plan_args, + ) + ) + mask_mode = args[6] + ragged_run_func = lambda *run_args: _prefill_kernels.batch_prefill_with_ragged_kv_cache_run( + mask_mode, + *run_args, + ) + paged_run_func = lambda *run_args: _prefill_kernels.batch_prefill_with_paged_kv_cache_run( + mask_mode, + *run_args, + ) + _batch_prefill_modules[args] = SimpleNamespace( + plan=plan_func, + ragged_run=ragged_run_func, + paged_run=paged_run_func, + ) + else: + _batch_prefill_modules[args] = compile_batch_prefill_module(*args) + return _batch_prefill_modules[args] def single_prefill_with_kv_cache( @@ -70,7 +144,8 @@ def single_prefill_with_kv_cache( logits_soft_cap: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, -) -> torch.Tensor: + return_lse: bool = False, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: r"""Prefill/Append attention with KV cache for single request, return the attention output. 
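The prefill.py hunks above replace the eager `from . import _prefill` import with per-configuration modules that are resolved lazily: `get_single_prefill_module` / `get_batch_prefill_module` cache one compiled (or prebuilt) module per argument tuple and expose a uniform `run`/`plan` surface either way. A minimal, illustrative reduction of that pattern follows; the names `get_module`, `jit_compile`, and `PREBUILT` are stand-ins for this sketch, not FlashInfer identifiers.

from types import SimpleNamespace

_modules: dict = {}   # one module per unique configuration tuple, as in _single_prefill_modules
PREBUILT: dict = {}   # uri -> prebuilt kernel entry point; empty when no AOT build is installed

def jit_compile(*config):
    # Stand-in for load_cuda_ops(uri, [FLASHINFER_GEN_SRC_DIR / f"{uri}.cu"]).
    return SimpleNamespace(run=lambda *args, **kwargs: ("jit", config))

def get_module(*config):
    if config not in _modules:
        uri = "_".join(map(str, config))          # stands in for get_*_uri(*config)
        if uri in PREBUILT:
            # Wrap the prebuilt kernel so both branches expose the same .run(...) attribute.
            _modules[config] = SimpleNamespace(run=PREBUILT[uri])
        else:
            _modules[config] = jit_compile(*config)   # JIT-compile only on first use
    return _modules[config]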
@@ -125,11 +200,17 @@ def single_prefill_with_kv_cache( The scale used in RoPE interpolation, if not provided, will be set to 1.0. rope_theta : Optional[float] The theta used in RoPE, if not provided, will be set to 1e4. + return_lse : bool + Whether to return the log sum exp value of the attention logits. Returns ------- - torch.Tensor - The attention output, shape: ``[qo_len, num_qo_heads, head_dim]``. + Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + If :attr:`return_lse` is ``False``, the attention output, shape: ``[qo_len, num_qo_heads, head_dim]``. + If :attr:`return_lse` is ``True``, a tuple of two tensors: + + * The attention output, shape: ``[qo_len, num_qo_heads, head_dim]``. + * The log sum exp value, shape: ``[qo_len, num_qo_heads]``. Examples -------- @@ -186,225 +267,47 @@ def single_prefill_with_kv_cache( packed_custom_mask = packbits( custom_mask.contiguous().view(-1), bitorder="little" ) + if packed_custom_mask is not None: - return _prefill.single_prefill_with_kv_cache_custom_mask( - q, - k, - v, - packed_custom_mask, - tmp, - TensorLayout[kv_layout].value, - PosEncodingMode[pos_encoding_mode].value, - allow_fp16_qk_reduction, - window_left, - logits_soft_cap, - sm_scale, - rope_scale, - rope_theta, - False, # return lse - )[0] + mask_mode = MaskMode.CUSTOM.value else: - return _prefill.single_prefill_with_kv_cache( - q, - k, - v, - tmp, - causal, - TensorLayout[kv_layout].value, - PosEncodingMode[pos_encoding_mode].value, - allow_fp16_qk_reduction, - window_left, - logits_soft_cap, - sm_scale, - rope_scale, - rope_theta, - False, # return lse - )[0] - - -def single_prefill_with_kv_cache_return_lse( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - custom_mask: Optional[torch.Tensor] = None, - packed_custom_mask: Optional[torch.Tensor] = None, - causal: bool = False, - kv_layout: str = "NHD", - pos_encoding_mode: str = "NONE", - allow_fp16_qk_reduction: bool = False, - window_left: int = -1, - logits_soft_cap: Optional[float] = None, - sm_scale: Optional[float] = None, - rope_scale: Optional[float] = None, - rope_theta: Optional[float] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: - r"""Prefill/Append attention with KV cache for single request, return attention - output and logsumexp of attention scores. - - Parameters - ---------- - q : torch.Tensor - The query tensor, shape: ``[qo_len, num_qo_heads, head_dim]``. - k : torch.Tensor - The key tensor, shape: ``[kv_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` - is ``NHD``, or ``[num_kv_heads, kv_len, head_dim]`` if :attr:`kv_layout` is - ``HND``. - v : torch.Tensor - The key tensor, shape: ``[kv_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` - is ``NHD``, or ``[num_kv_heads, kv_len, head_dim]`` if :attr:`kv_layout` is - ``HND``. - custom_mask : Optional[torch.Tensor] - The custom bool mask tensor, shape: ``[qo_len, kv_len]``. - The elements in the mask tensor should be either ``True`` or ``False``, - where ``False`` means the corresponding element in the attention matrix will be - masked out. - - When :attr:`custom_mask` is provided, and :attr:`packed_custom_mask` is not, the - function will pack the custom mask tensor into a 1D packed mask tensor, which introduces - additional overhead. - packed_custom_mask : Optional[torch.Tensor] - The 1D packed mask tensor, if provided, the :attr:`custom_mask` will be ignored. - The packed mask tensor is generated by :func:`flashinfer.quantization.packbits`. - causal : bool - Whether to apply causal mask to the attention matrix. 
- This is only effective when :attr:`custom_mask` is not provided. - kv_layout : str - The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. - pos_encoding_mode : str - The position encoding applied inside attention kernels, could be - ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. - Default is ``NONE``. - allow_fp16_qk_reduction : bool - Whether to use f16 for qk reduction (faster at the cost of slight precision - loss). - window_left : int - The left (inclusive) window size for the attention window, when set to ``-1``, the window - size will be set to the full length of the sequence. Defaults to ``-1``. - logits_soft_cap : Optional[float] - The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not - provided, will be set to ``0``. If greater than 0, the logits will be capped according to - formula: - :math:`\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})`, - where :math:`x` is the input logits. - sm_scale : Optional[float] - The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. - rope_scale : Optional[float] - The scale used in RoPE interpolation, if not provided, will be set to ``1.0``. - rope_theta : Optional[float] - The theta used in RoPE, if not provided, will be set to ``1e4``. - - Returns - ------- - V : torch.Tensor - The attention output, shape: ``[qo_len, num_qo_heads, head_dim]``. - S : torch.Tensor - The logsumexp value, shape: ``[qo_len, num_qo_heads]`` - - Examples - -------- + if causal: + mask_mode = MaskMode.CAUSAL.value + else: + mask_mode = MaskMode.NON_CAUSAL.value + + out = get_single_prefill_module( + q.dtype, + k.dtype, + q.dtype, + q.shape[-1], + PosEncodingMode[pos_encoding_mode].value, + mask_mode, + window_left >= 0, # use_sliding_window + logits_soft_cap > 0, # use_logits_soft_cap + allow_fp16_qk_reduction, + ).run( + q, + k, + v, + packed_custom_mask, + tmp, + _get_cache_alibi_slopes_buf(q.shape[1], q.device), + TensorLayout[kv_layout].value, + window_left, + logits_soft_cap, + sm_scale, + rope_scale, + rope_theta, + return_lse, + ) - >>> import torch - >>> import flashinfer - >>> qo_len = 128 - >>> kv_len = 4096 - >>> num_qo_heads = 32 - >>> num_kv_heads = 4 - >>> head_dim = 128 - >>> q = torch.randn(qo_len, num_qo_heads, head_dim).half().to("cuda:0") - >>> k = torch.randn(kv_len, num_kv_heads, head_dim).half().to("cuda:0") - >>> v = torch.randn(kv_len, num_kv_heads, head_dim).half().to("cuda:0") - >>> V, S = flashinfer.single_prefill_with_kv_cache_return_lse(q, k, v, causal=True) - >>> V.shape - torch.Size([128, 32, 128]) - >>> S.shape - torch.Size([128, 32]) - >>> mask = torch.tril( - >>> torch.full((qo_len, kv_len), True, device="cuda:0"), - >>> diagonal=(kv_len - qo_len), - >>> ) - >>> mask - tensor([[ True, True, True, ..., False, False, False], - [ True, True, True, ..., False, False, False], - [ True, True, True, ..., False, False, False], - ..., - [ True, True, True, ..., True, False, False], - [ True, True, True, ..., True, True, False], - [ True, True, True, ..., True, True, True]], device='cuda:0') - >>> V_custom, S_custom = flashinfer.single_prefill_with_kv_cache_return_lse(q, k, v, custom_mask=mask) - >>> torch.allclose(V, V_custom, rtol=1e-3, atol=1e-3) - True - >>> torch.allclose(S, S_custom, rtol=1e-3, atol=1e-3) - True + return out if return_lse else out[0] - Note - ---- - Please refer to the :ref:`tutorial ` for a detailed - explanation of the log-sum-exp function and attention states. 
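A note on the log-sum-exp value mentioned in the docstring above: returning the LSE alongside the attention output is what makes partial results over disjoint KV chunks mergeable. The sketch below is the standard online-softmax merge algebra, not FlashInfer code; whether the kernels fold the softmax scale into the returned LSE or use a different log base is not visible in this patch, so treat it as a conceptual reference only.

import torch

def merge_attention_states(v1, lse1, v2, lse2):
    # v1, v2: [qo_len, num_heads, head_dim] partial outputs over disjoint KV chunks
    # lse1, lse2: [qo_len, num_heads] log-sum-exp of the corresponding attention logits
    lse = torch.logaddexp(lse1, lse2)            # merged log-sum-exp
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)     # renormalization weight of chunk 1
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)     # renormalization weight of chunk 2
    return w1 * v1 + w2 * v2, lse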
- The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads`` is - not equal to ``num_kv_heads``, the function will use - `grouped query attention `_. - """ - _check_pos_encoding_mode(pos_encoding_mode) - _check_kv_layout(kv_layout) - tmp = _get_cache_buf( - "single_prefill_with_kv_cache_return_lse_tmp", 8 * 1024 * 1024, q.device - ) - if logits_soft_cap is None: - logits_soft_cap = 0.0 - if sm_scale is None: - sm_scale = 1.0 / math.sqrt(q.size(-1)) - if rope_scale is None: - rope_scale = 1.0 - if rope_theta is None: - rope_theta = 1e4 - if is_float8(q): - logging.warning( - "Our current prefill kernel implementation needs f16 input, the f8 inputs " - " are casted to f16, which could result in performance degradation." - ) - q = q.to(torch.float16) - k = k.to(torch.float16) - v = v.to(torch.float16) - if custom_mask is not None and packed_custom_mask is None: - # convert custom mask to packed mask - packed_custom_mask = packbits( - custom_mask.contiguous().view(-1), bitorder="little" - ) - if packed_custom_mask is not None: - return _prefill.single_prefill_with_kv_cache_custom_mask( - q, - k, - v, - packed_custom_mask, - tmp, - TensorLayout[kv_layout].value, - PosEncodingMode[pos_encoding_mode].value, - allow_fp16_qk_reduction, - window_left, - logits_soft_cap, - sm_scale, - rope_scale, - rope_theta, - True, # return lse - ) - else: - return _prefill.single_prefill_with_kv_cache( - q, - k, - v, - tmp, - causal, - TensorLayout[kv_layout].value, - PosEncodingMode[pos_encoding_mode].value, - allow_fp16_qk_reduction, - window_left, - logits_soft_cap, - sm_scale, - rope_scale, - rope_theta, - True, # return lse - ) +single_prefill_with_kv_cache_return_lse = functools.partial( + single_prefill_with_kv_cache, return_lse=True +) def _compute_page_qk_indptr( @@ -597,13 +500,16 @@ def __init__( _check_kv_layout(kv_layout) self._kv_layout = kv_layout self._float_workspace_buffer = float_workspace_buffer + self.device = float_workspace_buffer.device self._int_workspace_buffer = torch.empty( - (8 * 1024 * 1024,), dtype=torch.uint8, device=float_workspace_buffer.device + (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device ) - self._wrapper = _prefill.BatchPrefillWithPagedKVCachePyTorchWrapper( - TensorLayout[kv_layout].value, - use_cuda_graph, + self._pin_memory_int_workspace_buffer = torch.empty( + self._int_workspace_buffer.shape, + dtype=self._int_workspace_buffer.dtype, + pin_memory=True, ) + self._use_cuda_graph = use_cuda_graph if use_cuda_graph: if not torch.is_tensor(qo_indptr_buf): raise ValueError( @@ -643,7 +549,7 @@ def __init__( @property def is_cuda_graph_enabled(self) -> bool: - return self._wrapper.is_cuda_graph_enabled() + return self._use_cuda_graph def reset_workspace_buffer( self, float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor @@ -662,8 +568,10 @@ def reset_workspace_buffer( """ self._float_workspace_buffer = float_workspace_buffer self._int_workspace_buffer = int_workspace_buffer - self._wrapper.update_page_locked_buffer_size( - int_workspace_buffer.numel() * int_workspace_buffer.element_size() + self._pin_memory_int_workspace_buffer = torch.empty( + self._int_workspace_buffer.shape, + dtype=self._int_workspace_buffer.dtype, + pin_memory=True, ) def plan( @@ -686,7 +594,8 @@ def plan( logits_soft_cap: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, - q_data_type: str = "float16", + q_data_type: Union[str, torch.dtype] = "float16", + kv_data_type: Optional[Union[str, torch.dtype]] = 
None, ) -> None: r"""Plan batch prefill/append attention on Paged KV-Cache for given problem specification. @@ -752,8 +661,10 @@ def plan( ``1.0``. rope_theta : Optional[float] The theta used in RoPE, if not provided, will be set to ``1e4``. - q_data_type : Optional[Union[str, torch.dtype]] - The data type of the query tensor. If None, will be set to torch.float16. + q_data_type : Union[str, torch.dtype] + The data type of the query tensor, defaults torch.float16. + kv_data_type : Optional[Union[str, torch.dtype]] + The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`. Note ---- @@ -765,8 +676,15 @@ def plan( is not equal to ``num_kv_heads``, the function will use `grouped query attention `_. """ - batch_size = len(qo_indptr) - 1 + q_data_type = canonicalize_torch_dtype(q_data_type) + if kv_data_type is None: + kv_data_type = q_data_type + kv_data_type = canonicalize_torch_dtype(kv_data_type) + if logits_soft_cap is None: + logits_soft_cap = 0.0 + + batch_size = len(qo_indptr) - 1 if custom_mask is not None or packed_custom_mask is not None: qk_indptr = _compute_page_qk_indptr( qo_indptr, @@ -814,32 +732,45 @@ def plan( # NOTE(Zihao): qk_indptr has the same length as qo_indptr self._qk_indptr_buf.copy_(qk_indptr) else: - self._qo_indptr_buf = qo_indptr - self._paged_kv_indptr_buf = paged_kv_indptr - self._paged_kv_indices_buf = paged_kv_indices - self._paged_kv_last_page_len_buf = paged_kv_last_page_len + self._qo_indptr_buf = qo_indptr.to(self.device) + self._paged_kv_indptr_buf = paged_kv_indptr.to(self.device) + self._paged_kv_indices_buf = paged_kv_indices.to(self.device) + self._paged_kv_last_page_len_buf = paged_kv_last_page_len.to(self.device) if packed_custom_mask is not None: - self._custom_mask_buf = packed_custom_mask - self._qk_indptr_buf = qk_indptr - empty_q_data = torch.empty( - 0, - dtype=( - getattr(torch, q_data_type) - if isinstance(q_data_type, str) - else q_data_type - ), + self._custom_mask_buf = packed_custom_mask.to(self.device) + self._qk_indptr_buf = qk_indptr.to(self.device) + + if packed_custom_mask is not None: + mask_mode = MaskMode.CUSTOM.value + else: + if causal: + mask_mode = MaskMode.CAUSAL.value + else: + mask_mode = MaskMode.NON_CAUSAL.value + + self._cached_module = get_batch_prefill_module( + q_data_type, + kv_data_type, + q_data_type, + paged_kv_indptr.dtype, + head_dim, + PosEncodingMode[pos_encoding_mode].value, + mask_mode, + window_left >= 0, # use_sliding_window + logits_soft_cap > 0, # use_logits_soft_cap + allow_fp16_qk_reduction, ) - self._wrapper.plan( + self._plan_info = self._cached_module.plan( self._float_workspace_buffer, self._int_workspace_buffer, + self._pin_memory_int_workspace_buffer, qo_indptr, paged_kv_indptr, batch_size, num_qo_heads, num_kv_heads, - head_dim, page_size, - empty_q_data, + self.is_cuda_graph_enabled, ) self._causal = causal self._pos_encoding_mode = pos_encoding_mode @@ -884,7 +815,8 @@ def run( paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], k_scale: Optional[float] = None, v_scale: Optional[float] = None, - ) -> torch.Tensor: + return_lse: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: r"""Compute batch prefill/append attention between query and paged kv-cache. Parameters @@ -909,21 +841,23 @@ def run( The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``. v_scale : Optional[float] The calibration scale of value for fp8 input, if not provided, will be set to ``1.0``. 
+ return_lse : bool + Whether to return the logsumexp of attention output Returns ------- - torch.Tensor - The attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``. + Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + If :attr:`return_lse` is ``False``, the attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``. + If :attr:`return_lse` is ``True``, a tuple of two tensors: + + * The attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``. + * The logsumexp of attention output, shape: ``[qo_indptr[-1], num_qo_heads]``. """ - causal = self._causal - pos_encoding_mode = self._pos_encoding_mode - allow_fp16_qk_reduction = self._allow_fp16_qk_reduction window_left = self._window_left logits_soft_cap = self._logits_soft_cap sm_scale = self._sm_scale rope_scale = self._rope_scale rope_theta = self._rope_theta - _check_pos_encoding_mode(pos_encoding_mode) if logits_soft_cap is None: logits_soft_cap = 0.0 if sm_scale is None: @@ -935,46 +869,34 @@ def run( if rope_theta is None: rope_theta = 1e4 - if self._custom_mask_buf is None: - out = self._wrapper.run( - q, - self._qo_indptr_buf, - *_unpack_paged_kv_cache(paged_kv_cache, self._kv_layout), - self._paged_kv_indptr_buf, - self._paged_kv_indices_buf, - self._paged_kv_last_page_len_buf, - causal, - PosEncodingMode[pos_encoding_mode].value, - allow_fp16_qk_reduction, - window_left, - logits_soft_cap, - sm_scale, - rope_scale, - rope_theta, - False, # return LSE - )[0] - else: - out = self._wrapper.run_custom_mask( - q, - self._qo_indptr_buf, - *_unpack_paged_kv_cache(paged_kv_cache, self._kv_layout), - self._paged_kv_indptr_buf, - self._paged_kv_indices_buf, - self._paged_kv_last_page_len_buf, - self._custom_mask_buf, - self._qk_indptr_buf, - PosEncodingMode[pos_encoding_mode].value, - allow_fp16_qk_reduction, - window_left, - logits_soft_cap, - sm_scale, - rope_scale, - rope_theta, - False, # return LSE - )[0] + out = self._cached_module.paged_run( + self._float_workspace_buffer, + self._int_workspace_buffer, + self._plan_info, + q, + *_unpack_paged_kv_cache(paged_kv_cache, self._kv_layout), + self._custom_mask_buf, + _get_cache_alibi_slopes_buf(q.shape[1], q.device), + self._qo_indptr_buf, + self._paged_kv_indptr_buf, + self._paged_kv_indices_buf, + self._paged_kv_last_page_len_buf, + self._qk_indptr_buf, + TensorLayout[self._kv_layout].value, + window_left, + logits_soft_cap, + sm_scale, + rope_scale, + rope_theta, + return_lse, + ) + if v_scale is not None: - out *= v_scale - return out + out[0] *= v_scale + + return out if return_lse else out[0] + + run_return_lse = functools.partialmethod(run, return_lse=True) def forward_return_lse( self, @@ -1002,108 +924,6 @@ def forward_return_lse( self._rope_theta = rope_theta return self.run_return_lse(q, paged_kv_cache, k_scale, v_scale) - def run_return_lse( - self, - q: torch.Tensor, - paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], - k_scale: Optional[float] = None, - v_scale: Optional[float] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - r"""Compute batch prefill/append attention paged kv-cache. 
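The `run_return_lse` method removed below is now just `functools.partialmethod(run, return_lse=True)`, so both call forms stay available on the wrapper. A hedged usage sketch, assuming a `BatchPrefillWithPagedKVCacheWrapper` whose `plan(...)` has already been called with the shapes documented in `run` above:

def attention_with_optional_lse(wrapper, q, paged_kv_cache, want_lse: bool):
    # `wrapper` is a planned BatchPrefillWithPagedKVCacheWrapper; q and paged_kv_cache
    # follow the shapes documented in run() above.
    if want_lse:
        # Equivalent to: out, lse = wrapper.run_return_lse(q, paged_kv_cache)
        out, lse = wrapper.run(q, paged_kv_cache, return_lse=True)
        return out, lse
    return wrapper.run(q, paged_kv_cache)   # attention output only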
- - Parameters - ---------- - q : torch.Tensor - The query tensor, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]`` - paged_kv_cache : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] - The paged KV-Cache stored as a tuple of tensors or a single tensor: - - * a tuple ``(k_cache, v_cache)`` of 4-D tensors, each with shape: - ``[max_num_pages, page_size, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``, - and ``[max_num_pages, num_kv_heads, page_size, head_dim]`` if :attr:`kv_layout` is ``HND``. - - * a single 5-D tensor with shape: - ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if - :attr:`kv_layout` is ``NHD``, and - ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if - :attr:`kv_layout` is ``HND``. Where ``paged_kv_cache[:, 0]`` is the key-cache and - ``paged_kv_cache[:, 1]`` is the value-cache. - - k_scale : Optional[float] - The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``. - v_scale : Optional[float] - The calibration scale of value for fp8 input, if not provided, will be set to ``1.0``. - - Returns - ------- - V : torch.Tensor - The attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``. - S : torch.Tensor - The logsumexp of attention output, shape: - ``[qo_indptr[-1], num_qo_heads, head_dim]``. - """ - causal = self._causal - pos_encoding_mode = self._pos_encoding_mode - allow_fp16_qk_reduction = self._allow_fp16_qk_reduction - window_left = self._window_left - logits_soft_cap = self._logits_soft_cap - sm_scale = self._sm_scale - rope_scale = self._rope_scale - rope_theta = self._rope_theta - _check_pos_encoding_mode(pos_encoding_mode) - if logits_soft_cap is None: - logits_soft_cap = 0.0 - if sm_scale is None: - sm_scale = 1.0 / math.sqrt(q.size(-1)) - if k_scale is not None: - sm_scale *= k_scale - if rope_scale is None: - rope_scale = 1.0 - if rope_theta is None: - rope_theta = 1e4 - - if self._custom_mask_buf is None: - out, lse = self._wrapper.run( - q, - self._qo_indptr_buf, - *_unpack_paged_kv_cache(paged_kv_cache, self._kv_layout), - self._paged_kv_indptr_buf, - self._paged_kv_indices_buf, - self._paged_kv_last_page_len_buf, - causal, - PosEncodingMode[pos_encoding_mode].value, - allow_fp16_qk_reduction, - window_left, - logits_soft_cap, - sm_scale, - rope_scale, - rope_theta, - True, # return LSE - ) - else: - out, lse = self._wrapper.run_custom_mask( - q, - self._qo_indptr_buf, - *_unpack_paged_kv_cache(paged_kv_cache, self._kv_layout), - self._paged_kv_indptr_buf, - self._paged_kv_indices_buf, - self._paged_kv_last_page_len_buf, - self._custom_mask_buf, - self._qk_indptr_buf, - PosEncodingMode[pos_encoding_mode].value, - allow_fp16_qk_reduction, - window_left, - logits_soft_cap, - sm_scale, - rope_scale, - rope_theta, - True, # return LSE - ) - - if v_scale is not None: - out *= v_scale - return out, lse - def end_forward(self) -> None: r"""Warning: this function is deprecated and has no effect.""" pass @@ -1265,13 +1085,14 @@ def __init__( _check_kv_layout(kv_layout) self._kv_layout = kv_layout self._float_workspace_buffer = float_workspace_buffer + self.device = float_workspace_buffer.device self._int_workspace_buffer = torch.empty( - (8 * 1024 * 1024,), dtype=torch.uint8, device=float_workspace_buffer.device + (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device ) - self._wrapper = _prefill.BatchPrefillWithRaggedKVCachePyTorchWrapper( - TensorLayout[kv_layout].value, - use_cuda_graph, + self._pin_memory_int_workspace_buffer = torch.empty( + (8 * 1024 * 1024,), dtype=torch.uint8, 
pin_memory=True ) + self._use_cuda_graph = use_cuda_graph if use_cuda_graph: if not torch.is_tensor(qo_indptr_buf): raise ValueError( @@ -1298,7 +1119,7 @@ def __init__( @property def is_cuda_graph_enabled(self) -> bool: - return self._wrapper.is_cuda_graph_enabled() + return self._use_cuda_graph def reset_workspace_buffer( self, float_workspace_buffer: torch.Tensor, int_workspace_buffer @@ -1317,8 +1138,10 @@ def reset_workspace_buffer( """ self._float_workspace_buffer = float_workspace_buffer self._int_workspace_buffer = int_workspace_buffer - self._wrapper.update_page_locked_buffer_size( - int_workspace_buffer.numel() * int_workspace_buffer.element_size() + self._pin_memory_int_workspace_buffer = torch.empty( + self._int_workspace_buffer.shape, + dtype=self._int_workspace_buffer.dtype, + pin_memory=True, ) def plan( @@ -1339,6 +1162,7 @@ def plan( rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, q_data_type: str = "float16", + kv_data_type: Optional[str] = None, ) -> None: r"""Plan batch prefill/append attention on Ragged KV-Cache for given problem specification. @@ -1399,9 +1223,10 @@ def plan( ``1.0``. rope_theta : Optional[float] The theta used in RoPE, if not provided, will be set to ``1e4``. - - q_data_type : Optional[Union[str, torch.dtype]] - The data type of the query tensor. If None, will be set to torch.float16. + q_data_type : Union[str, torch.dtype] + The data type of the query tensor, defaults to torch.float16. + kv_data_type : Optional[Union[str, torch.dtype]] + The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`. Note ---- @@ -1413,6 +1238,14 @@ def plan( is not equal to ``num_kv_heads``, the function will use `grouped query attention `_. """ + q_data_type = canonicalize_torch_dtype(q_data_type) + if kv_data_type is None: + kv_data_type = q_data_type + kv_data_type = canonicalize_torch_dtype(kv_data_type) + + if logits_soft_cap is None: + logits_soft_cap = 0.0 + batch_size = len(qo_indptr) - 1 if len(kv_indptr) != batch_size + 1: raise ValueError( @@ -1450,29 +1283,43 @@ def plan( self._custom_mask_buf[: len(packed_custom_mask)] = packed_custom_mask self._qk_indptr_buf.copy_(qk_indptr) else: - self._qo_indptr_buf = qo_indptr - self._kv_indptr_buf = kv_indptr + self._qo_indptr_buf = qo_indptr.to(self.device) + self._kv_indptr_buf = kv_indptr.to(self.device) if packed_custom_mask is not None: - self._custom_mask_buf = packed_custom_mask - self._qk_indptr_buf = qk_indptr - empty_q_data = torch.empty( - 0, - dtype=( - getattr(torch, q_data_type) - if isinstance(q_data_type, str) - else q_data_type - ), + self._custom_mask_buf = packed_custom_mask.to(self.device) + self._qk_indptr_buf = qk_indptr.to(self.device) + + if packed_custom_mask is not None: + mask_mode = MaskMode.CUSTOM.value + else: + if causal: + mask_mode = MaskMode.CAUSAL.value + else: + mask_mode = MaskMode.NON_CAUSAL.value + + self._cached_module = get_batch_prefill_module( + q_data_type, + kv_data_type, + q_data_type, + kv_indptr.dtype, + head_dim, + PosEncodingMode[pos_encoding_mode].value, + mask_mode, + window_left >= 0, # use_sliding_window + logits_soft_cap > 0, # use_logits_soft_cap + allow_fp16_qk_reduction, ) - self._wrapper.plan( + self._plan_info = self._cached_module.plan( self._float_workspace_buffer, self._int_workspace_buffer, + self._pin_memory_int_workspace_buffer, qo_indptr, kv_indptr, batch_size, num_qo_heads, num_kv_heads, - head_dim, - empty_q_data, + 1, # page_size + self.is_cuda_graph_enabled, ) self._causal = causal 
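As the ragged `plan(...)` above shows, the mask configuration is now resolved into a single `MaskMode` at plan time (a packed custom mask takes precedence, otherwise `causal` decides), and that mode is baked into the module returned by `get_batch_prefill_module`. A hedged restatement of that selection follows; the local `MaskMode` enum below is a stand-in assumed to match the `mask_mode_literal` table added in `jit/utils.py`.

from enum import Enum

class MaskMode(Enum):
    NON_CAUSAL = 0   # MaskMode::kNone
    CAUSAL = 1       # MaskMode::kCausal
    CUSTOM = 2       # MaskMode::kCustom

def resolve_mask_mode(packed_custom_mask, causal: bool) -> int:
    # A packed custom mask always wins; otherwise the causal flag chosen at plan() decides.
    if packed_custom_mask is not None:
        return MaskMode.CUSTOM.value
    return MaskMode.CAUSAL.value if causal else MaskMode.NON_CAUSAL.value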
self._pos_encoding_mode = pos_encoding_mode @@ -1515,7 +1362,8 @@ def run( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - ) -> torch.Tensor: + return_lse: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: r"""Compute batch prefill/append attention between query and kv-cache stored as ragged tensor. @@ -1527,21 +1375,23 @@ def run( The key tensor, shape: ``[kv_indptr[-1], num_kv_heads, head_dim]`` v : torch.Tensor The value tensor, shape: ``[kv_indptr[-1], num_kv_heads, head_dim]`` + return_lse : bool + Whether to return the logsumexp of attention output Returns ------- - torch.Tensor - The attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``. + Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + If :attr:`return_lse` is ``False``, the attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``. + If :attr:`return_lse` is ``True``, a tuple of two tensors: + + * The attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``. + * The logsumexp of attention output, shape: ``[qo_indptr[-1], num_qo_heads]``. """ - causal = self._causal - pos_encoding_mode = self._pos_encoding_mode - allow_fp16_qk_reduction = self._allow_fp16_qk_reduction window_left = self._window_left logits_soft_cap = self._logits_soft_cap sm_scale = self._sm_scale rope_scale = self._rope_scale rope_theta = self._rope_theta - _check_pos_encoding_mode(pos_encoding_mode) if logits_soft_cap is None: logits_soft_cap = 0.0 if sm_scale is None: @@ -1558,41 +1408,31 @@ def run( q = q.to(torch.float16) k = k.to(torch.float16) v = v.to(torch.float16) - if self._custom_mask_buf is None: - return self._wrapper.run( - q, - self._qo_indptr_buf, - k, - v, - self._kv_indptr_buf, - causal, - PosEncodingMode[pos_encoding_mode].value, - allow_fp16_qk_reduction, - window_left, - logits_soft_cap, - sm_scale, - rope_scale, - rope_theta, - False, - )[0] - else: - return self._wrapper.run_custom_mask( - q, - self._qo_indptr_buf, - k, - v, - self._kv_indptr_buf, - self._custom_mask_buf, - self._qk_indptr_buf, - PosEncodingMode[pos_encoding_mode].value, - allow_fp16_qk_reduction, - window_left, - logits_soft_cap, - sm_scale, - rope_scale, - rope_theta, - False, - )[0] + + out = self._cached_module.ragged_run( + self._float_workspace_buffer, + self._int_workspace_buffer, + self._plan_info, + q, + k, + v, + self._custom_mask_buf, + _get_cache_alibi_slopes_buf(q.shape[1], self.device), + self._qo_indptr_buf, + self._kv_indptr_buf, + self._qk_indptr_buf, + TensorLayout[self._kv_layout].value, + window_left, + logits_soft_cap, + sm_scale, + rope_scale, + rope_theta, + return_lse, + ) + + return out if return_lse else out[0] + + run_return_lse = functools.partialmethod(run, return_lse=True) def forward_return_lse( self, @@ -1619,119 +1459,6 @@ def forward_return_lse( self._rope_theta = rope_theta return self.run_return_lse(q, k, v) - def run_return_lse( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - r"""Compute batch prefill/append attention between query and kv-cache stored as - ragged tensor. Return attention output and logsumexp of attention scores. - - Parameters - ---------- - q : torch.Tensor - The query tensor, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]`` - k : torch.Tensor - The key tensor, shape: ``[kv_indptr[-1], num_kv_heads, head_dim]`` - v : torch.Tensor - The value tensor, shape: ``[kv_indptr[-1], num_kv_heads, head_dim]`` - causal : bool - Whether to apply causal mask to the attention matrix. 
- This argument is ignored if ``mask`` is provided in :meth:`plan`. - pos_encoding_mode : str - The position encoding applied inside attention kernels, could be - ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. - Default is ``NONE``. - allow_fp16_qk_reduction : bool - Whether to use f16 for qk reduction (faster at the cost of slight precision - loss). - window_left : int - The left (inclusive) window size for the attention window, when set to ``-1``, the window - size will be set to the full length of the sequence. Defaults to ``-1``. - logits_soft_cap : Optional[float] - The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not - provided, will be set to ``0``. If greater than 0, the logits will be capped according to - formula: - :math:`\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})`, - where :math:`x` is the input logits. - sm_scale : Optional[float] - The scale used in softmax, if not provided, will be set to - ``1.0 / sqrt(head_dim)``. - rope_scale : Optional[float] - The scale used in RoPE interpolation, if not provided, will be set to ``1.0``. - rope_theta : Optional[float] - The theta used in RoPE, if not provided, will be set to ``1e4``. - - Returns - ------- - V : torch.Tensor - The attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``. - S : torch.Tensor - The logsumexp of attention output, shape: - ``[qo_indptr[-1], num_qo_heads, head_dim]``. - """ - causal = self._causal - pos_encoding_mode = self._pos_encoding_mode - allow_fp16_qk_reduction = self._allow_fp16_qk_reduction - window_left = self._window_left - logits_soft_cap = self._logits_soft_cap - sm_scale = self._sm_scale - rope_scale = self._rope_scale - rope_theta = self._rope_theta - _check_pos_encoding_mode(pos_encoding_mode) - if logits_soft_cap is None: - logits_soft_cap = 0.0 - if sm_scale is None: - sm_scale = 1.0 / math.sqrt(q.size(-1)) - if rope_scale is None: - rope_scale = 1.0 - if rope_theta is None: - rope_theta = 1e4 - if is_float8(q): - logging.warning( - "Our current prefill kernel implementation needs f16 input, the f8 inputs " - " are casted to f16, which could result in performance degradation." - ) - q = q.to(torch.float16) - k = k.to(torch.float16) - v = v.to(torch.float16) - if self._custom_mask_buf is None: - return self._wrapper.run( - q, - self._qo_indptr_buf, - k, - v, - self._kv_indptr_buf, - causal, - PosEncodingMode[pos_encoding_mode].value, - allow_fp16_qk_reduction, - window_left, - logits_soft_cap, - sm_scale, - rope_scale, - rope_theta, - True, - ) - else: - return self._wrapper.run_custom_mask( - q, - self._qo_indptr_buf, - k, - v, - self._kv_indptr_buf, - self._custom_mask_buf, - self._qk_indptr_buf, - PosEncodingMode[pos_encoding_mode].value, - allow_fp16_qk_reduction, - window_left, - logits_soft_cap, - sm_scale, - rope_scale, - rope_theta, - True, - ) - def end_forward(self) -> None: r"""Warning: this function is deprecated and has no effect.""" pass diff --git a/python/flashinfer/quantization.py b/python/flashinfer/quantization.py index 73c8840ab..919d51774 100644 --- a/python/flashinfer/quantization.py +++ b/python/flashinfer/quantization.py @@ -16,19 +16,28 @@ import torch from typing import Tuple +from .jit import load_cuda_ops, FLASHINFER_CSRC_DIR, has_prebuilt_ops -# mypy: disable-error-code="attr-defined" -try: - from . 
import _kernels -except ImportError as e: - import os - import logging - if os.environ.get("BUILD_DOC", "0") == "1": - _kernels = None - logging.warning("Kernels are not loaded in documentation build mode.") - else: - raise e +_quantization_module = None + + +def get_quantization_module(): + global _quantization_module + if _quantization_module is None: + if has_prebuilt_ops: + from . import _kernels + + _quantization_module = _kernels + else: + _quantization_module = load_cuda_ops( + "quantization", + [ + FLASHINFER_CSRC_DIR / "quantization.cu", + FLASHINFER_CSRC_DIR / "flashinfer_quantization_ops.cu", + ], + ) + return _quantization_module def packbits(x: torch.Tensor, bitorder: str = "big") -> torch.Tensor: @@ -62,7 +71,7 @@ def packbits(x: torch.Tensor, bitorder: str = "big") -> torch.Tensor: -------- segment_packbits """ - return _kernels.packbits(x, bitorder) + return get_quantization_module().packbits(x, bitorder) def segment_packbits( @@ -111,4 +120,7 @@ def segment_packbits( packed_len = (seglen + 7) // 8 indptr_new = torch.zeros(len(indptr), dtype=indptr.dtype, device=indptr.device) indptr_new[1:] = torch.cumsum(packed_len, 0) - return _kernels.segment_packbits(x, indptr, indptr_new, bitorder), indptr_new + return ( + get_quantization_module().segment_packbits(x, indptr, indptr_new, bitorder), + indptr_new, + ) diff --git a/python/flashinfer/rope.py b/python/flashinfer/rope.py index 626b162ac..3fb71435b 100644 --- a/python/flashinfer/rope.py +++ b/python/flashinfer/rope.py @@ -15,19 +15,28 @@ """ import torch +from .jit import load_cuda_ops, FLASHINFER_CSRC_DIR, has_prebuilt_ops -# mypy: disable-error-code="attr-defined" -try: - from . import _kernels -except ImportError as e: - import os - import logging - if os.environ.get("BUILD_DOC", "0") == "1": - _kernels = None - logging.warning("Kernels are not loaded in documentation build mode.") - else: - raise e +_rope_module = None + + +def get_rope_module(): + global _rope_module + if _rope_module is None: + if has_prebuilt_ops: + from . import _kernels + + _rope_module = _kernels + else: + _rope_module = load_cuda_ops( + "rope", + [ + FLASHINFER_CSRC_DIR / "rope.cu", + FLASHINFER_CSRC_DIR / "flashinfer_rope_ops.cu", + ], + ) + return _rope_module def apply_rope_inplace( @@ -105,7 +114,7 @@ def apply_rope_inplace( -------- apply_rope """ - return _kernels.apply_rope_inplace( + return get_rope_module().apply_rope_inplace( q, k, indptr, offsets, interleave, rope_scale, rope_theta ) @@ -195,7 +204,7 @@ def apply_llama31_rope_inplace( -------- apply_llama31_rope """ - return _kernels.apply_llama31_rope_inplace( + return get_rope_module().apply_llama31_rope_inplace( q, k, indptr, @@ -295,7 +304,7 @@ def apply_rope( -------- apply_rope_inplace """ - return _kernels.apply_rope( + return get_rope_module().apply_rope( q, k, indptr, offsets, interleave, rope_scale, rope_theta ) @@ -396,7 +405,7 @@ def apply_llama31_rope( -------- apply_llama31_rope_inplace """ - return _kernels.apply_llama31_rope( + return get_rope_module().apply_llama31_rope( q, k, indptr, diff --git a/python/flashinfer/sampling.py b/python/flashinfer/sampling.py index 23add4580..6ef0ddeb3 100644 --- a/python/flashinfer/sampling.py +++ b/python/flashinfer/sampling.py @@ -16,19 +16,28 @@ import torch from typing import Tuple, Union, Optional +from .jit import load_cuda_ops, FLASHINFER_CSRC_DIR, has_prebuilt_ops -# mypy: disable-error-code="attr-defined" -try: - from . 
import _kernels -except ImportError as e: - import os - import logging - if os.environ.get("BUILD_DOC", "0") == "1": - _kernels = None - logging.warning("Kernels are not loaded in documentation build mode.") - else: - raise e +_sampling_module = None + + +def get_sampling_module(): + global _sampling_module + if _sampling_module is None: + if has_prebuilt_ops: + from . import _kernels + + _sampling_module = _kernels + else: + _sampling_module = load_cuda_ops( + "sampling", + [ + FLASHINFER_CSRC_DIR / "sampling.cu", + FLASHINFER_CSRC_DIR / "flashinfer_sampling_ops.cu", + ], + ) + return _sampling_module def _to_tensor_scalar_tuple(x): @@ -90,7 +99,9 @@ def sampling_from_probs( if check_nan: if torch.any(torch.isnan(probs)): raise ValueError("Input probs contains NaN.") - return _kernels.sampling_from_probs(probs, uniform_samples, deterministic) + return get_sampling_module().sampling_from_probs( + probs, uniform_samples, deterministic + ) def top_p_sampling_from_probs( @@ -170,7 +181,7 @@ def top_p_sampling_from_probs( if check_nan: if torch.any(torch.isnan(probs)): raise ValueError("Input probs contains NaN.") - return _kernels.top_p_sampling_from_probs( + return get_sampling_module().top_p_sampling_from_probs( probs, uniform_samples, *_to_tensor_scalar_tuple(top_p), deterministic ) @@ -252,7 +263,7 @@ def top_k_sampling_from_probs( if check_nan: if torch.any(torch.isnan(probs)): raise ValueError("Input probs contains NaN.") - return _kernels.top_k_sampling_from_probs( + return get_sampling_module().top_k_sampling_from_probs( probs, uniform_samples, *_to_tensor_scalar_tuple(top_k), deterministic ) @@ -330,7 +341,7 @@ def min_p_sampling_from_probs( if check_nan: if torch.any(torch.isnan(probs)): raise ValueError("Input probs contains NaN.") - return _kernels.min_p_sampling_from_probs( + return get_sampling_module().min_p_sampling_from_probs( probs, uniform_samples, *_to_tensor_scalar_tuple(min_p), deterministic ) @@ -441,7 +452,7 @@ def top_k_top_p_sampling_from_logits( if check_nan: if torch.any(torch.isnan(probs)): raise ValueError("Input probs contains NaN.") - return _kernels.top_k_top_p_sampling_from_probs( + return get_sampling_module().top_k_top_p_sampling_from_probs( probs, uniform_samples, *_to_tensor_scalar_tuple(top_k), @@ -549,7 +560,7 @@ def top_k_top_p_sampling_from_probs( if check_nan: if torch.any(torch.isnan(probs)): raise ValueError("Input probs contains NaN.") - return _kernels.top_k_top_p_sampling_from_probs( + return get_sampling_module().top_k_top_p_sampling_from_probs( probs, uniform_samples, *_to_tensor_scalar_tuple(top_k), @@ -617,7 +628,9 @@ def top_p_renorm_probs( sampling_from_probs top_k_renorm_probs """ - return _kernels.top_p_renorm_probs(probs, *_to_tensor_scalar_tuple(top_p)) + return get_sampling_module().top_p_renorm_probs( + probs, *_to_tensor_scalar_tuple(top_p) + ) top_p_renorm_prob = top_p_renorm_probs @@ -679,7 +692,9 @@ def top_k_renorm_probs( sampling_from_probs top_p_renorm_probs """ - return _kernels.top_k_renorm_probs(probs, *_to_tensor_scalar_tuple(top_k)) + return get_sampling_module().top_k_renorm_probs( + probs, *_to_tensor_scalar_tuple(top_k) + ) top_k_renorm_prob = top_k_renorm_probs @@ -736,7 +751,9 @@ def top_k_mask_logits( -------- top_k_renorm_probs """ - return _kernels.top_k_mask_logits(logits, *_to_tensor_scalar_tuple(top_k)) + return get_sampling_module().top_k_mask_logits( + logits, *_to_tensor_scalar_tuple(top_k) + ) def chain_speculative_sampling( @@ -830,7 +847,7 @@ def chain_speculative_sampling( >>> output_emitted_token_num 
tensor([1], device='cuda:0') """ - return _kernels.chain_speculative_sampling( + return get_sampling_module().chain_speculative_sampling( draft_probs, draft_token_ids, uniform_samples, diff --git a/python/flashinfer/sparse.py b/python/flashinfer/sparse.py index b005d6afb..bcb28f3e4 100644 --- a/python/flashinfer/sparse.py +++ b/python/flashinfer/sparse.py @@ -15,29 +15,20 @@ """ import math -from typing import Optional +from typing import Optional, Union, Tuple import logging import torch -from .prefill import _compute_page_qk_indptr +from .prefill import _compute_page_qk_indptr, get_batch_prefill_module from .quantization import segment_packbits from .utils import ( _check_pos_encoding_mode, + _get_cache_alibi_slopes_buf, + canonicalize_torch_dtype, PosEncodingMode, + MaskMode, TensorLayout, ) -# mypy: disable-error-code="attr-defined" -try: - from . import _prefill -except ImportError as e: - import os - - if os.environ.get("BUILD_DOC", "0") == "1": - _prefill = None - logging.warning("Kernels are not loaded in documentation build mode.") - else: - raise e - def convert_bsr_mask_layout(mask: torch.Tensor, indptr: torch.Tensor) -> torch.Tensor: r"""Convert mask from BSR data layout to flashinfer's flattened mask layout. @@ -123,13 +114,17 @@ def __init__( buffer should be the same as the device of the input tensors. """ self._float_workspace_buffer = float_workspace_buffer + self.device = float_workspace_buffer.device self._int_workspace_buffer = torch.empty( (8 * 1024 * 1024,), dtype=torch.uint8, device=float_workspace_buffer.device ) - self._wrapper = _prefill.BatchPrefillWithPagedKVCachePyTorchWrapper( - TensorLayout["NHD"].value, - False, # use_cuda_graph + self._pin_memory_int_workspace_buffer = torch.empty( + self._int_workspace_buffer.shape, + dtype=self._int_workspace_buffer.dtype, + pin_memory=True, ) + self._use_cuda_graph = False + self._kv_layout = "NHD" self._qo_indptr: Optional[torch.Tensor] = None self._paged_kv_indptr_buf: Optional[torch.Tensor] = None self._paged_kv_indices_buf: Optional[torch.Tensor] = None @@ -142,7 +137,7 @@ def __init__( self.N: Optional[int] = None def reset_workspace_buffer( - self, float_workspace_buffer: torch.Tensor, int_workspace_buffer + self, float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor ) -> None: r"""Reset the workspace buffer. @@ -158,8 +153,10 @@ def reset_workspace_buffer( """ self._float_workspace_buffer = float_workspace_buffer self._int_workspace_buffer = int_workspace_buffer - self._wrapper.update_page_locked_buffer_size( - int_workspace_buffer.numel() * int_workspace_buffer.element_size() + self._pin_memory_int_workspace_buffer = torch.empty( + self._int_workspace_buffer.shape, + dtype=self._int_workspace_buffer.dtype, + pin_memory=True, ) def plan( @@ -181,7 +178,8 @@ def plan( sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, - q_data_type: str = "float16", + q_data_type: Union[str, torch.dtype] = "float16", + kv_data_type: Optional[Union[str, torch.dtype]] = None, ) -> None: r"""Create auxiliary data structures for block sparse attention. @@ -235,10 +233,10 @@ def plan( ``1.0``. rope_theta : Optional[float] The theta used in RoPE, if not provided, will be set to ``1e4``. - - q_data_type : str, optional The data type of the query tensor. + kv_data_type : Optional[Union[str, torch.dtype]] + The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`. 
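        An illustrative sketch (not part of this patch) of how these dtype arguments are
        normalized: both accept a string name or a ``torch.dtype``, and ``kv_data_type``
        falls back to ``q_data_type`` when omitted. This assumes ``canonicalize_torch_dtype``
        from ``flashinfer.utils``, introduced later in this diff.

            import torch
            from flashinfer.utils import canonicalize_torch_dtype

            q_data_type = canonicalize_torch_dtype("float16")       # string name -> torch.float16
            kv_data_type = None                                      # not provided by the caller
            if kv_data_type is None:                                 # fall back to the query dtype
                kv_data_type = q_data_type
            kv_data_type = canonicalize_torch_dtype(kv_data_type)    # torch.dtype passes through unchanged
            assert q_data_type is torch.float16 and kv_data_type is torch.float16

        Passing ``kv_data_type`` explicitly (e.g. an fp8 dtype) only changes which kernel
        module is selected; the normalization logic above stays the same.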
The :meth:`plan` method should be called before any :meth:`run` or :meth:`run_return_lse` calls, auxiliary data structures will be created @@ -248,10 +246,18 @@ def plan( is not equal to ``num_kv_heads``, the function will use `grouped query attention `_. """ + q_data_type = canonicalize_torch_dtype(q_data_type) + if kv_data_type is None: + kv_data_type = q_data_type + kv_data_type = canonicalize_torch_dtype(kv_data_type) + + if logits_soft_cap is None: + logits_soft_cap = 0.0 + num_blocks_row = len(indptr) - 1 qo_indptr_host = R * torch.arange(num_blocks_row + 1, dtype=torch.int32) qo_indptr_host[-1] = M - self._qo_indptr = qo_indptr_host.to(indptr.device) + qo_indptr = qo_indptr_host.to(indptr.device) if indices.max().item() * C > N: raise ValueError("indices out of bound") last_block_len = torch.full( @@ -260,7 +266,7 @@ def plan( if mask is not None or packed_mask is not None: qk_indptr = _compute_page_qk_indptr( - self._qo_indptr, + qo_indptr, indptr, # paged_kv_indptr last_block_len, # paged_kv_last_page_len C, # page_size @@ -273,40 +279,48 @@ def plan( mask.contiguous().view(-1), qk_indptr, bitorder="little" ) - self._paged_kv_indptr_buf = indptr - self._paged_kv_indices_buf = indices - self._paged_kv_last_page_len = last_block_len + self._qo_indptr = qo_indptr.to(self.device) + self._paged_kv_indptr_buf = indptr.to(self.device) + self._paged_kv_indices_buf = indices.to(self.device) + self._paged_kv_last_page_len = last_block_len.to(self.device) if packed_mask is not None: - self._packed_mask_buf = packed_mask - self._qk_indptr_buf = qk_indptr + self._packed_mask_buf = packed_mask.to(self.device) + self._qk_indptr_buf = qk_indptr.to(self.device) + mask_mode = MaskMode.CUSTOM.value else: self._packed_mask_buf = None - - empty_q_data = torch.empty( - 0, - dtype=( - getattr(torch, q_data_type) - if isinstance(q_data_type, str) - else q_data_type - ), - ) + self._qk_indptr_buf = None + mask_mode = MaskMode.NON_CAUSAL.value self.M = M self.N = N self.R = R self.C = C - self._wrapper.plan( + self._cached_module = get_batch_prefill_module( + q_data_type, + kv_data_type, + q_data_type, + indptr.dtype, + head_dim, + PosEncodingMode[pos_encoding_mode].value, + mask_mode, + False, # use_sliding_window + logits_soft_cap > 0, # use_logits_soft_cap + allow_fp16_qk_reduction, + ) + + self._plan_info = self._cached_module.plan( self._float_workspace_buffer, self._int_workspace_buffer, - self._qo_indptr, - self._paged_kv_indptr_buf, + self._pin_memory_int_workspace_buffer, + qo_indptr, + indptr, num_blocks_row, num_qo_heads, num_kv_heads, - head_dim, C, - empty_q_data, + False, # is_cuda_graph_enabled ) self._pos_encoding_mode = pos_encoding_mode @@ -344,7 +358,8 @@ def run( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - ) -> torch.Tensor: + return_lse: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: r"""Compute block-sparse attention between Q/K/V tensors. Parameters @@ -355,11 +370,18 @@ def run( The key tensor with shape ``(N, num_kv_heads, head_dim)``. v : torch.Tensor The value tensor with shape ``(N, num_kv_heads, head_dim)``. + return_lse : bool + Whether to return the logsumexp of attention output + Returns ------- - torch.Tensor - The attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``. + Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + If :attr:`return_lse` is ``False``, the attention output, shape: ``[M, num_qo_heads, head_dim]``. 
+ If :attr:`return_lse` is ``True``, a tuple of two tensors: + + * The attention output, shape: ``[M, num_qo_heads, head_dim]``. + * The logsumexp of attention output, shape: ``[M, num_qo_heads]``. """ pos_encoding_mode = self._pos_encoding_mode allow_fp16_qk_reduction = self._allow_fp16_qk_reduction @@ -379,47 +401,32 @@ def run( k = k.reshape(-1, self.C, *k.shape[-2:]).contiguous() v = v.reshape(-1, self.C, *v.shape[-2:]).contiguous() - if self._packed_mask_buf is None: - return self._wrapper.run( - q, - self._qo_indptr, - None, - k, - v, - self._paged_kv_indptr_buf, - self._paged_kv_indices_buf, - self._paged_kv_last_page_len, - False, # causal - PosEncodingMode[pos_encoding_mode].value, - allow_fp16_qk_reduction, - -1, # window_left - logits_soft_cap, - sm_scale, - rope_scale, - rope_theta, - False, # return LSE - )[0] - else: - return self._wrapper.run_custom_mask( - q, - self._qo_indptr, - None, - k, - v, - self._paged_kv_indptr_buf, - self._paged_kv_indices_buf, - self._paged_kv_last_page_len, - self._packed_mask_buf, - self._qk_indptr_buf, - PosEncodingMode[pos_encoding_mode].value, - allow_fp16_qk_reduction, - -1, # window_left - logits_soft_cap, - sm_scale, - rope_scale, - rope_theta, - False, # return LSE - )[0] + + out = self._cached_module.paged_run( + self._float_workspace_buffer, + self._int_workspace_buffer, + self._plan_info, + q, + None, + k, + v, + self._packed_mask_buf, + _get_cache_alibi_slopes_buf(q.shape[1], self.device), + self._qo_indptr, + self._paged_kv_indptr_buf, + self._paged_kv_indices_buf, + self._paged_kv_last_page_len, + self._qk_indptr_buf, + TensorLayout[self._kv_layout].value, + -1, # window_left + logits_soft_cap, + sm_scale, + rope_scale, + rope_theta, + return_lse, + ) + + return out if return_lse else out[0] def end_forward(self) -> None: r"""Warning: This method is deprecated and has no effect.""" diff --git a/python/flashinfer/utils.py b/python/flashinfer/utils.py index 83e3806ce..e8e40ee5c 100644 --- a/python/flashinfer/utils.py +++ b/python/flashinfer/utils.py @@ -15,8 +15,9 @@ """ import torch +import math from enum import Enum -from typing import Optional, Tuple, Union +from typing import Optional, Tuple, Union, Dict class PosEncodingMode(Enum): @@ -25,11 +26,20 @@ class PosEncodingMode(Enum): ALIBI = 2 +class MaskMode(Enum): + NON_CAUSAL = 0 + CAUSAL = 1 + CUSTOM = 2 + + class TensorLayout(Enum): NHD = 0 HND = 1 +log2e = 1.44269504088896340736 + + def _expand_5d(x: torch.Tensor, kv_layout: str) -> torch.Tensor: if not x.ndim in [4, 5]: raise ValueError("x must be 4D or 5D") @@ -106,3 +116,63 @@ def _unpack_paged_kv_cache( type(paged_kv_cache) ) ) + + +def get_alibi_slopes(n_heads: int) -> torch.Tensor: + n = 2 ** math.floor(math.log2(n_heads)) + m_0 = 2.0 ** (-8.0 / n) + m = torch.pow(m_0, torch.arange(1, 1 + n)) + if n < n_heads: + m_hat_0 = 2.0 ** (-4.0 / n) + m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (n_heads - n), 2)) + m = torch.cat([m, m_hat]) + return m.float() + + +_cache_buf: Dict[Tuple[str, torch.device], torch.Tensor] = {} + + +def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor: + key = (name, device) + buf = _cache_buf.get(key) + if buf is None: + buf = torch.empty(bytes, dtype=torch.uint8, device=device) + _cache_buf[key] = buf + return buf + + +# find the least power of 2 that is greater than or equal to x +def _ceil_pow2(x: int) -> int: + return 1 << (x - 1).bit_length() + + +def _get_range_buf(seq_len: int, device: torch.device) -> torch.Tensor: + seq_len_pow2 = _ceil_pow2(seq_len) + key = 
(f"range_{seq_len_pow2}", device) + buf = _cache_buf.get(key) + if buf is None: + buf = torch.arange(seq_len_pow2, device=device, dtype=torch.int32) + _cache_buf[key] = buf + return buf[:seq_len] + + +def _get_cache_alibi_slopes_buf( + num_qo_heads: int, device: torch.device +) -> torch.Tensor: + key = (f"alibi_slopes_{num_qo_heads}", device) + buf = _cache_buf.get(key) + if buf is None: + buf = (get_alibi_slopes(num_qo_heads) * log2e).to(device) + _cache_buf[key] = buf + return buf + + +def canonicalize_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype: + if isinstance(dtype, str): + return getattr(torch, dtype) + elif isinstance(dtype, torch.dtype): + return dtype + else: + raise TypeError( + "dtype must be a string or torch.dtype, got {}".format(type(dtype)) + ) diff --git a/python/generate_batch_paged_prefill_inst.py b/python/generate_batch_paged_prefill_inst.py deleted file mode 100644 index 62e66b90c..000000000 --- a/python/generate_batch_paged_prefill_inst.py +++ /dev/null @@ -1,92 +0,0 @@ -""" -Copyright (c) 2024 by FlashInfer team. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -import sys -import re -import itertools -from literal_map import ( - mask_mode_literal, - pos_encoding_mode_literal, - warp_layout_literal, - dtype_literal, - idtype_literal, - logits_hook_literal, -) -from pathlib import Path - - -def get_cu_file_str( - head_dim, - logits_hook, - pos_encoding_mode, - allow_fp16_qk_reduction, - mask_mode, - dtype_q, - dtype_kv, - dtype_out, - idtype, -): - warp_layout_choice = [0, 1, 2] - insts = "\n".join( - [ - """template cudaError_t BatchPrefillWithPagedKVCacheDispatched( - {dtype_q}* q, {idtype}* request_indices, {idtype}* q_tile_indices, {idtype}* kv_tile_indices, - {idtype}* q_indptr, {idtype}* q_offset, - paged_kv_t paged_kv, uint8_t* custom_mask, - {idtype}* qk_indptr, {idtype}* o_indptr, {dtype_out}* o, {dtype_out}* tmp_v, float* tmp_s, float* lse, - {idtype}* merge_indptr, bool* block_valid_mask, {idtype}* kv_chunk_size_ptr, uint32_t max_num_rows, - uint32_t num_qo_heads, uint32_t padded_batch_size, int32_t window_left, - float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); - """.format( - logits_hook=logits_hook_literal[int(logits_hook)], - warp_layout=warp_layout_literal[warp_layout], - head_dim=head_dim, - pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], - allow_fp16_qk_reduction=allow_fp16_qk_reduction, - mask_mode=mask_mode_literal[int(mask_mode)], - dtype_q=dtype_literal[dtype_q], - dtype_kv=dtype_literal[dtype_kv], - dtype_out=dtype_literal[dtype_out], - idtype=idtype_literal[idtype], - ) - for warp_layout in warp_layout_choice - ] - ) - - content = f"""#include - -namespace flashinfer {{ - -constexpr PageStorage page_storage = PageStorage::kIndices; - -{insts} - -}}""" - return content - - -if __name__ == "__main__": - pattern = ( - r"batch_paged_prefill_head_([0-9]+)_logitshook_([0-9]+)_posenc_([0-9]+)_" - 
r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu" - ) - compiled_pattern = re.compile(pattern) - path = Path(sys.argv[1]) - fname = path.name - match = compiled_pattern.match(fname) - - with open(path, "w") as f: - f.write(get_cu_file_str(*match.groups())) diff --git a/python/setup.py b/python/setup.py index 22d2878af..ffb3debb2 100644 --- a/python/setup.py +++ b/python/setup.py @@ -18,237 +18,10 @@ import pathlib import os -import re -import itertools -import subprocess -import platform - import setuptools -import argparse -import torch -import torch.utils.cpp_extension as torch_cpp_ext - -import generate_single_decode_inst, generate_single_prefill_inst, generate_batch_paged_decode_inst, generate_batch_paged_prefill_inst, generate_batch_ragged_prefill_inst, generate_dispatch_inc root = pathlib.Path(__name__).parent -# cuda arch check for fp8 at the moment. -for cuda_arch_flags in torch_cpp_ext._get_cuda_arch_flags(): - arch = int(re.search("compute_\d+", cuda_arch_flags).group()[-2:]) - if arch < 75: - raise RuntimeError("FlashInfer requires sm75+") - -enable_bf16 = os.environ.get("FLASHINFER_ENABLE_BF16", "1") == "1" -enable_fp8 = os.environ.get("FLASHINFER_ENABLE_FP8", "1") == "1" - -if enable_bf16: - torch_cpp_ext.COMMON_NVCC_FLAGS.append("-DFLASHINFER_ENABLE_BF16") -if enable_fp8: - torch_cpp_ext.COMMON_NVCC_FLAGS.append("-DFLASHINFER_ENABLE_FP8") - - -def write_if_different(path: pathlib.Path, content: str) -> None: - if path.exists(): - with open(path, "r") as f: - if f.read() == content: - return - with open(path, "w") as f: - f.write(content) - - -def get_instantiation_cu() -> Tuple[List[str], List[str]]: - prefix = "csrc/generated" - (root / prefix).mkdir(parents=True, exist_ok=True) - - logits_hooks = os.environ.get("FLASHINFER_LOGITS_POST_HOOKS", "0,1").split(",") - head_dims = os.environ.get("FLASHINFER_HEAD_DIMS", "64,128,256").split(",") - pos_encoding_modes = os.environ.get("FLASHINFER_POS_ENCODING_MODES", "0,1,2").split( - "," - ) - allow_fp16_qk_reduction_options = os.environ.get( - "FLASHINFER_ALLOW_FP16_QK_REDUCTION_OPTIONS", "0,1" - ).split(",") - mask_modes = os.environ.get("FLASHINFER_MASK_MODES", "0,1,2").split(",") - # dispatch.inc - path = root / prefix / "dispatch.inc" - write_if_different( - path, - generate_dispatch_inc.get_dispatch_inc_str( - argparse.Namespace( - head_dims=map(int, head_dims), - logits_post_hooks=map(int, logits_hooks), - pos_encoding_modes=map(int, pos_encoding_modes), - allow_fp16_qk_reductions=map(int, allow_fp16_qk_reduction_options), - mask_modes=map(int, mask_modes), - ) - ), - ) - - idtypes = ["i32"] - prefill_dtypes = ["f16"] - decode_dtypes = ["f16"] - fp16_dtypes = ["f16"] - fp8_dtypes = ["e4m3", "e5m2"] - if enable_bf16: - prefill_dtypes.append("bf16") - decode_dtypes.append("bf16") - fp16_dtypes.append("bf16") - if enable_fp8: - decode_dtypes.extend(fp8_dtypes) - - files_decode = [] - files_prefill = [] - # single decode files - for ( - head_dim, - logits_hook, - pos_encoding_mode, - ) in itertools.product( - head_dims, - logits_hooks, - pos_encoding_modes, - ): - for dtype_q, dtype_kv in list(zip(decode_dtypes, decode_dtypes)) + list( - itertools.product(fp16_dtypes, fp8_dtypes) - ): - dtype_out = dtype_q - fname = f"single_decode_head_{head_dim}_logitshook_{logits_hook}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}.cu" - files_decode.append(prefix + "/" + fname) - content = generate_single_decode_inst.get_cu_file_str( - 
head_dim, - logits_hook, - pos_encoding_mode, - dtype_q, - dtype_kv, - dtype_out, - ) - write_if_different(root / prefix / fname, content) - - # batch decode files - for ( - head_dim, - logits_hook, - pos_encoding_mode, - ) in itertools.product( - head_dims, - logits_hooks, - pos_encoding_modes, - ): - for idtype in idtypes: - for dtype_q, dtype_kv in list(zip(decode_dtypes, decode_dtypes)) + list( - itertools.product(fp16_dtypes, fp8_dtypes) - ): - dtype_out = dtype_q - fname = f"batch_paged_decode_head_{head_dim}_logitshook_{logits_hook}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}_idtype_{idtype}.cu" - files_decode.append(prefix + "/" + fname) - content = generate_batch_paged_decode_inst.get_cu_file_str( - head_dim, - logits_hook, - pos_encoding_mode, - dtype_q, - dtype_kv, - dtype_out, - idtype, - ) - write_if_different(root / prefix / fname, content) - - # single prefill files - for ( - head_dim, - logits_hook, - pos_encoding_mode, - allow_fp16_qk_reduction, - mask_mode, - ) in itertools.product( - head_dims, - logits_hooks, - pos_encoding_modes, - allow_fp16_qk_reduction_options, - mask_modes, - ): - for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)): - fname = f"single_prefill_head_{head_dim}_logitshook_{logits_hook}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}.cu" - files_prefill.append(prefix + "/" + fname) - content = generate_single_prefill_inst.get_cu_file_str( - head_dim, - logits_hook, - pos_encoding_mode, - allow_fp16_qk_reduction, - mask_mode, - dtype_q, # dtype_q - dtype_kv, # dtype_kv - dtype_q, # dtype_out - ) - write_if_different(root / prefix / fname, content) - - # batch paged prefill files - for ( - head_dim, - logits_hook, - pos_encoding_mode, - allow_fp16_qk_reduction, - mask_mode, - idtype, - ) in itertools.product( - head_dims, - logits_hooks, - pos_encoding_modes, - allow_fp16_qk_reduction_options, - mask_modes, - idtypes, - ): - for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)) + list( - itertools.product(prefill_dtypes, fp8_dtypes) - ): - fname = f"batch_paged_prefill_head_{head_dim}_logitshook_{logits_hook}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}.cu" - files_prefill.append(prefix + "/" + fname) - content = generate_batch_paged_prefill_inst.get_cu_file_str( - head_dim, - logits_hook, - pos_encoding_mode, - allow_fp16_qk_reduction, - mask_mode, - dtype_q, # dtype_q - dtype_kv, # dtype_kv - dtype_q, # dtype_out - idtype, - ) - write_if_different(root / prefix / fname, content) - - # batch ragged prefill files - for ( - head_dim, - logits_hook, - pos_encoding_mode, - allow_fp16_qk_reduction, - mask_mode, - idtype, - ) in itertools.product( - head_dims, - logits_hooks, - pos_encoding_modes, - allow_fp16_qk_reduction_options, - mask_modes, - idtypes, - ): - for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)): - fname = f"batch_ragged_prefill_head_{head_dim}_logitshook_{logits_hook}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}.cu" - files_prefill.append(prefix + "/" + fname) - content = generate_batch_ragged_prefill_inst.get_cu_file_str( - head_dim, - logits_hook, - pos_encoding_mode, - allow_fp16_qk_reduction, - mask_mode, - dtype_q, # dtype_q - dtype_kv, # dtype_kv - dtype_q, # 
dtype_out - idtype, - ) - write_if_different(root / prefix / fname, content) - - return files_prefill, files_decode - def get_version(): version = os.getenv("FLASHINFER_BUILD_VERSION") @@ -258,121 +31,21 @@ def get_version(): return version -def get_cuda_version() -> Tuple[int, int]: - if torch_cpp_ext.CUDA_HOME is None: - nvcc = "nvcc" - else: - nvcc = os.path.join(torch_cpp_ext.CUDA_HOME, "bin/nvcc") - txt = subprocess.check_output([nvcc, "--version"], text=True) - major, minor = map(int, re.findall(r"release (\d+)\.(\d+),", txt)[0]) - return major, minor - - def generate_build_meta() -> None: - d = {} version = get_version() - d["cuda_major"], d["cuda_minor"] = get_cuda_version() - d["torch"] = torch.__version__ - d["python"] = platform.python_version() - d["TORCH_CUDA_ARCH_LIST"] = os.environ.get("TORCH_CUDA_ARCH_LIST", None) with open(root / "flashinfer/_build_meta.py", "w") as f: f.write(f"__version__ = {version!r}\n") - f.write(f"build_meta = {d!r}") -def remove_unwanted_pytorch_nvcc_flags(): - REMOVE_NVCC_FLAGS = [ - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - ] - for flag in REMOVE_NVCC_FLAGS: - try: - torch_cpp_ext.COMMON_NVCC_FLAGS.remove(flag) - except ValueError: - pass - - -class NinjaBuildExtension(torch_cpp_ext.BuildExtension): - def __init__(self, *args, **kwargs) -> None: - # do not override env MAX_JOBS if already exists - if not os.environ.get("MAX_JOBS"): - max_num_jobs_cores = max(1, os.cpu_count()) - os.environ["MAX_JOBS"] = str(max_num_jobs_cores) - - super().__init__(*args, **kwargs) +def clear_aot_config(): + # remove aot_config.py + aot_config_path = root / "flashinfer" / "jit" / "aot_config.py" + if os.path.exists(aot_config_path): + os.remove(aot_config_path) if __name__ == "__main__": - remove_unwanted_pytorch_nvcc_flags() generate_build_meta() - files_prefill, files_decode = get_instantiation_cu() - include_dirs = [ - str(root.resolve() / "include"), - str(root.resolve() / "3rdparty" / "cutlass" / "include"), # for group gemm - ] - extra_compile_args = { - "cxx": [ - "-O3", - "-Wno-switch-bool", - ], - "nvcc": [ - "-O3", - "-std=c++17", - "--threads", - "1", - "-Xfatbin", - "-compress-all", - "-use_fast_math", - ], - } - ext_modules = [] - ext_modules.append( - torch_cpp_ext.CUDAExtension( - name="flashinfer._kernels", - sources=[ - "csrc/cascade.cu", - "csrc/page.cu", - "csrc/flashinfer_ops.cu", - "csrc/sampling.cu", - "csrc/norm.cu", - "csrc/activation.cu", - "csrc/rope.cu", - "csrc/group_gemm.cu", - "csrc/quantization.cu", - "csrc/bmm_fp8.cu", - ], - include_dirs=include_dirs, - extra_compile_args=extra_compile_args, - ) - ) - ext_modules.append( - torch_cpp_ext.CUDAExtension( - name="flashinfer._decode", - sources=[ - "csrc/single_decode.cu", - "csrc/flashinfer_ops_decode.cu", - "csrc/batch_decode.cu", - ] - + files_decode, - include_dirs=include_dirs, - extra_compile_args=extra_compile_args, - ) - ) - ext_modules.append( - torch_cpp_ext.CUDAExtension( - name="flashinfer._prefill", - sources=[ - "csrc/single_prefill.cu", - "csrc/flashinfer_ops_prefill.cu", - "csrc/batch_prefill.cu", - ] - + files_prefill, - include_dirs=include_dirs, - extra_compile_args=extra_compile_args, - ) - ) setuptools.setup( name="flashinfer", version=get_version(), @@ -382,6 +55,4 @@ def __init__(self, *args, **kwargs) -> None: description="FlashInfer: Kernel Library for LLM Serving", url="https://github.com/flashinfer-ai/flashinfer", python_requires=">=3.8", - 
ext_modules=ext_modules, - cmdclass={"build_ext": NinjaBuildExtension}, ) diff --git a/src/bench_batch_decode.cu b/src/bench_batch_decode.cu index 8b219f3e4..d69d93cec 100644 --- a/src/bench_batch_decode.cu +++ b/src/bench_batch_decode.cu @@ -37,7 +37,6 @@ void bench_flashinfer_batch_decode(nvbench::state& state) { size_t page_size = state.get_int64("page_size"); size_t num_qo_heads = state.get_int64("num_qo_heads"); size_t num_kv_heads = state.get_int64("num_kv_heads"); - bool cooperative = state.get_int64("cooperative"); // KV cache: auto pages_per_seq = (seqlen + page_size - 1) / page_size; @@ -56,11 +55,11 @@ void bench_flashinfer_batch_decode(nvbench::state& state) { thrust::device_vector kv_indptr(kv_indptr_host); thrust::device_vector kv_indices(kv_indicies_host); thrust::device_vector kv_last_page_len(kv_last_page_len_host); - paged_kv_t paged_kv( - num_kv_heads, page_size, head_dim, batch_size, kv_layout, - thrust::raw_pointer_cast(kv_data.data()), thrust::raw_pointer_cast(kv_indices.data()), - thrust::raw_pointer_cast(kv_indptr.data()), - thrust::raw_pointer_cast(kv_last_page_len.data())); + paged_kv_t paged_kv(num_kv_heads, page_size, head_dim, batch_size, kv_layout, + thrust::raw_pointer_cast(kv_data.data()), + thrust::raw_pointer_cast(kv_indices.data()), + thrust::raw_pointer_cast(kv_indptr.data()), + thrust::raw_pointer_cast(kv_last_page_len.data())); // Allocate input data: thrust::device_vector q(batch_size * num_qo_heads * head_dim); thrust::device_vector o(batch_size * num_qo_heads * head_dim); @@ -71,38 +70,24 @@ void bench_flashinfer_batch_decode(nvbench::state& state) { state.add_global_memory_writes(vec_bytes(o), "Write"); BatchDecodeHandler handler; - if (cooperative) { - size_t float_workspace_size_in_bytes = 32 * 1024 * 1024; - thrust::device_vector float_buffer(float_workspace_size_in_bytes); - size_t int_workspace_size_in_bytes = 8 * 1024 * 1024; - thrust::device_vector int_buffer(int_workspace_size_in_bytes); - // begin forward - BatchDecodeHandlerPlan( - &handler, (void*)thrust::raw_pointer_cast(float_buffer.data()), - float_workspace_size_in_bytes, (void*)thrust::raw_pointer_cast(int_buffer.data()), - int_workspace_size_in_bytes, kv_indptr_host.data(), kv_last_page_len_host.data(), - batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, pos_encoding_mode); - state.exec([&](nvbench::launch&) { - cudaError_t status = - BatchDecodeWithPagedKVCacheWrapper( - &handler, thrust::raw_pointer_cast(q.data()), /*q_offset=*/nullptr, paged_kv, - thrust::raw_pointer_cast(o.data()), /*lse=*/nullptr, num_qo_heads, pos_encoding_mode); - if (status != cudaSuccess) { - state.skip("CUDA error: " + std::string(cudaGetErrorString(status))); - } - }); - } else { - state.exec([&](nvbench::launch&) { - cudaError_t status = - BatchDecodeWithPagedKVCacheNoSplitKV( - thrust::raw_pointer_cast(q.data()), /*q_offset=*/nullptr, paged_kv, - kv_partition_info_t(), thrust::raw_pointer_cast(o.data()), - /*lse=*/nullptr, num_qo_heads, pos_encoding_mode); - if (status != cudaSuccess) { - state.skip("CUDA error: " + std::string(cudaGetErrorString(status))); - } - }); - } + size_t float_workspace_size_in_bytes = 32 * 1024 * 1024; + thrust::device_vector float_buffer(float_workspace_size_in_bytes); + size_t int_workspace_size_in_bytes = 8 * 1024 * 1024; + thrust::device_vector int_buffer(int_workspace_size_in_bytes); + // begin forward + BatchDecodeHandlerPlan( + &handler, (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes, + 
(void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes, + kv_indptr_host.data(), kv_last_page_len_host.data(), batch_size, num_qo_heads, num_kv_heads, + head_dim, page_size, pos_encoding_mode); + state.exec([&](nvbench::launch&) { + cudaError_t status = BatchDecodeWithPagedKVCacheWrapper( + &handler, thrust::raw_pointer_cast(q.data()), /*q_offset=*/nullptr, paged_kv, + thrust::raw_pointer_cast(o.data()), /*lse=*/nullptr, num_qo_heads, pos_encoding_mode); + if (status != cudaSuccess) { + state.skip("CUDA error: " + std::string(cudaGetErrorString(status))); + } + }); } template @@ -132,11 +117,11 @@ void bench_flashinfer_batch_decode_with_prefill(nvbench::state& state) { thrust::device_vector kv_indptr(kv_indptr_host); thrust::device_vector kv_indices(kv_indicies_host); thrust::device_vector kv_last_page_len(kv_last_page_len_host); - paged_kv_t paged_kv( - num_kv_heads, page_size, head_dim, batch_size, kv_layout, - thrust::raw_pointer_cast(kv_data.data()), thrust::raw_pointer_cast(kv_indices.data()), - thrust::raw_pointer_cast(kv_indptr.data()), - thrust::raw_pointer_cast(kv_last_page_len.data())); + paged_kv_t paged_kv(num_kv_heads, page_size, head_dim, batch_size, kv_layout, + thrust::raw_pointer_cast(kv_data.data()), + thrust::raw_pointer_cast(kv_indices.data()), + thrust::raw_pointer_cast(kv_indptr.data()), + thrust::raw_pointer_cast(kv_last_page_len.data())); // Allocate input data: thrust::device_vector q(batch_size * num_qo_heads * head_dim); @@ -164,13 +149,11 @@ void bench_flashinfer_batch_decode_with_prefill(nvbench::state& state) { batch_size, num_qo_heads, num_kv_heads, head_dim, page_size); state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) { - cudaError_t status = - BatchPrefillWithPagedKVCacheWrapper( - &handler, thrust::raw_pointer_cast(q.data()), - thrust::raw_pointer_cast(qo_indptr_d.data()), - /*q_offset=*/nullptr, paged_kv, thrust::raw_pointer_cast(o.data()), - /*lse=*/nullptr, num_qo_heads, - /*causal=*/false, pos_encoding_mode); + cudaError_t status = BatchPrefillWithPagedKVCacheWrapper( + &handler, thrust::raw_pointer_cast(q.data()), thrust::raw_pointer_cast(qo_indptr_d.data()), + /*q_offset=*/nullptr, paged_kv, thrust::raw_pointer_cast(o.data()), + /*lse=*/nullptr, num_qo_heads, + /*causal=*/false, pos_encoding_mode); }); } @@ -188,8 +171,7 @@ void bench_flashinfer_batch_decode_with_prefill(nvbench::state& state) { 160, 192, 224, 256, 320, 384, 448, 512, 640, 768, 896, 1024}) \ .add_int64_axis("page_size", {4, 8, 16, 32, 64}) \ .add_int64_axis("num_qo_heads", {32}) \ - .add_int64_axis("num_kv_heads", {32, 4}) \ - .add_int64_axis("cooperative", {0, 1}) + .add_int64_axis("num_kv_heads", {32, 4}) #define BENCH_FLASHINFER_BATCH_DECODE_WITH_PREFILL(dtype) \ auto bench_flashinfer_batch_decode_with_prefill_##dtype##_ = \ diff --git a/src/bench_batch_prefill.cu b/src/bench_batch_prefill.cu index 512637f57..802bbb1fa 100644 --- a/src/bench_batch_prefill.cu +++ b/src/bench_batch_prefill.cu @@ -21,7 +21,7 @@ #include #include -#include "flashinfer/attention/handler.cuh" +#include "flashinfer/attention/scheduler.cuh" #include "flashinfer/layout.cuh" #include "flashinfer/pos_enc.cuh" #include "flashinfer_ops.cuh" diff --git a/src/bench_cascade.cu b/src/bench_cascade.cu index 00f374c46..d4b604888 100644 --- a/src/bench_cascade.cu +++ b/src/bench_cascade.cu @@ -87,7 +87,6 @@ void bench_two_level_single_prefix_cascade_decode(nvbench::state& state) { thrust::device_vector kv_data_d(kv_data_h); thrust::device_vector q_d(q_h); - constexpr PageStorage 
page_storage = PageStorage::kIndices; state.add_global_memory_reads(kv_data_h.size() + q_h.size(), "Read"); state.add_global_memory_writes(q_h.size(), "Write"); @@ -101,7 +100,7 @@ void bench_two_level_single_prefix_cascade_decode(nvbench::state& state) { thrust::device_vector kv_indptr_unique_d(kv_indptr_unique_h), kv_indices_unique_d(kv_indices_unique_h), kv_last_page_len_unique_d(kv_last_page_len_unique_h); - paged_kv_t paged_kv_casacde_d( + paged_kv_t paged_kv_casacde_d( num_kv_heads, page_size, head_dim, batch_size, kv_layout, thrust::raw_pointer_cast(kv_data_d.data()), thrust::raw_pointer_cast(kv_indices_unique_d.data()), @@ -112,7 +111,7 @@ void bench_two_level_single_prefix_cascade_decode(nvbench::state& state) { thrust::device_vector float_buffer(float_workspace_size_in_bytes); size_t int_workspace_size_in_bytes = 8 * 1024 * 1024; thrust::device_vector int_buffer(int_workspace_size_in_bytes); - BatchDecodeHandlerPlan( + BatchDecodeHandlerPlan( &cascade_handler, (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes, (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes, kv_indptr_unique_h.data(), kv_last_page_len_unique_h.data(), @@ -134,7 +133,7 @@ void bench_two_level_single_prefix_cascade_decode(nvbench::state& state) { std::string(cudaGetErrorString(status))); } - status = BatchDecodeWithPagedKVCacheWrapper( + status = BatchDecodeWithPagedKVCacheWrapper( &cascade_handler, thrust::raw_pointer_cast(q_d.data()), /*q_offset=*/nullptr, paged_kv_casacde_d, thrust::raw_pointer_cast(o_cascade_1_d.data()), /*lse=*/thrust::raw_pointer_cast(lse_cascade_1_d.data()), num_qo_heads, @@ -162,7 +161,7 @@ void bench_two_level_single_prefix_cascade_decode(nvbench::state& state) { thrust::device_vector kv_indptr_combined_d(kv_indptr_combined_h), kv_indices_combined_d(kv_indices_combined_h), kv_last_page_len_combined_d(kv_last_page_len_combined_h); - paged_kv_t paged_kv_baseline_d( + paged_kv_t paged_kv_baseline_d( num_kv_heads, page_size, head_dim, batch_size, kv_layout, thrust::raw_pointer_cast(kv_data_d.data()), thrust::raw_pointer_cast(kv_indices_combined_d.data()), @@ -173,7 +172,7 @@ void bench_two_level_single_prefix_cascade_decode(nvbench::state& state) { thrust::device_vector float_buffer(float_workspace_size_in_bytes); size_t int_workspace_size_in_bytes = 8 * 1024 * 1024; thrust::device_vector int_buffer(int_workspace_size_in_bytes); - BatchDecodeHandlerPlan( + BatchDecodeHandlerPlan( &baseline_handler, (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes, (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes, kv_indptr_combined_h.data(), @@ -182,7 +181,7 @@ void bench_two_level_single_prefix_cascade_decode(nvbench::state& state) { state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) { timer.start(); - cudaError_t status = BatchDecodeWithPagedKVCacheWrapper( + cudaError_t status = BatchDecodeWithPagedKVCacheWrapper( &baseline_handler, thrust::raw_pointer_cast(q_d.data()), /*q_offset=*/nullptr, paged_kv_baseline_d, thrust::raw_pointer_cast(o_baseline_d.data()), /*lse=*/nullptr, num_qo_heads, PosEncodingMode::kNone); @@ -228,7 +227,6 @@ void bench_two_level_single_prefix_cascade_append(nvbench::state& state) { thrust::device_vector kv_data_d(kv_data_h); thrust::device_vector q_d(q_h); thrust::device_vector qo_indptr_d(qo_indptr_h); - constexpr PageStorage page_storage = PageStorage::kIndices; state.add_global_memory_reads(kv_data_h.size() + q_h.size(), 
"Read"); state.add_global_memory_writes(q_h.size(), "Write"); @@ -242,7 +240,7 @@ void bench_two_level_single_prefix_cascade_append(nvbench::state& state) { thrust::device_vector kv_indptr_unique_d(kv_indptr_unique_h), kv_indices_unique_d(kv_indices_unique_h), kv_last_page_len_unique_d(kv_last_page_len_unique_h); - paged_kv_t paged_kv_casacde_d( + paged_kv_t paged_kv_casacde_d( num_kv_heads, page_size, head_dim, batch_size, kv_layout, thrust::raw_pointer_cast(kv_data_d.data()), thrust::raw_pointer_cast(kv_indices_unique_d.data()), @@ -275,7 +273,7 @@ void bench_two_level_single_prefix_cascade_append(nvbench::state& state) { std::string(cudaGetErrorString(status))); } - status = BatchPrefillWithPagedKVCacheWrapper( + status = BatchPrefillWithPagedKVCacheWrapper( &cascade_handler, thrust::raw_pointer_cast(q_d.data()), thrust::raw_pointer_cast(qo_indptr_d.data()), /*q_offset=*/nullptr, paged_kv_casacde_d, thrust::raw_pointer_cast(o_cascade_1_d.data()), @@ -303,7 +301,7 @@ void bench_two_level_single_prefix_cascade_append(nvbench::state& state) { thrust::device_vector kv_indptr_combined_d(kv_indptr_combined_h), kv_indices_combined_d(kv_indices_combined_h), kv_last_page_len_combined_d(kv_last_page_len_combined_h); - paged_kv_t paged_kv_baseline_d( + paged_kv_t paged_kv_baseline_d( num_kv_heads, page_size, head_dim, batch_size, kv_layout, thrust::raw_pointer_cast(kv_data_d.data()), thrust::raw_pointer_cast(kv_indices_combined_d.data()), @@ -321,7 +319,7 @@ void bench_two_level_single_prefix_cascade_append(nvbench::state& state) { head_dim, page_size); state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) { timer.start(); - cudaError_t status = BatchPrefillWithPagedKVCacheWrapper( + cudaError_t status = BatchPrefillWithPagedKVCacheWrapper( &baseline_handler, thrust::raw_pointer_cast(q_d.data()), thrust::raw_pointer_cast(qo_indptr_d.data()), /*q_offset=*/nullptr, paged_kv_baseline_d, thrust::raw_pointer_cast(o_baseline_d.data()), diff --git a/src/cpu_reference.h b/src/cpu_reference.h index a5c8fb5df..73174da73 100644 --- a/src/cpu_reference.h +++ b/src/cpu_reference.h @@ -156,8 +156,7 @@ std::vector single_mha(const std::vector& q, const std::vect } template -void append_paged_kv_cache(paged_kv_t page_cpu, - const std::vector>& keys, +void append_paged_kv_cache(paged_kv_t page_cpu, const std::vector>& keys, const std::vector>& values, const std::vector& append_indptr) { size_t batch_size = page_cpu.batch_size; diff --git a/src/flashinfer_ops.cuh b/src/flashinfer_ops.cuh index 84081658e..d6fa6051c 100644 --- a/src/flashinfer_ops.cuh +++ b/src/flashinfer_ops.cuh @@ -13,20 +13,242 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#include -#include +#include +#include +#include +#include #include -#include "flashinfer/attention/logits_post_hook.cuh" +#include "flashinfer/allocator.h" #include "flashinfer/attention/mask.cuh" +#include "flashinfer/attention/scheduler.cuh" +#include "flashinfer/attention/warp_layout.cuh" +#include "flashinfer/layout.cuh" #include "utils.h" namespace flashinfer { -template +template +cudaError_t BatchDecodeWithPagedKVCacheDispatched(typename AttentionVariant::ParamsT params, + typename AttentionVariant::DTypeO* tmp_v, + float* tmp_s, cudaStream_t stream); + +class BatchDecodeHandler { + public: + template + cudaError_t PlanDispatched(void* float_buffer, size_t float_workspace_size_in_bytes, + void* int_buffer, size_t int_workspace_size_in_bytes, IdType* indptr_h, + IdType* last_page_len_h, uint32_t batch_size, uint32_t num_qo_heads, + uint32_t num_kv_heads, uint32_t page_size) { + int_buffer_ = int_buffer; + float_buffer_ = float_buffer; + using ParamsT = BatchDecodeParams; + using AttentionVariant = + ComposedAttention; + return DecodePlan( + float_buffer, float_workspace_size_in_bytes, int_buffer, page_locked_buffer_, + int_workspace_size_in_bytes, plan_info_, indptr_h, batch_size, num_qo_heads, num_kv_heads, + page_size, cuda_graph_enabled_, stream_); + } + + void UpdatePageLockedBufferSize(size_t int_workspace_size_in_bytes) { + cudaFreeHost(page_locked_buffer_); + cudaMallocHost(&page_locked_buffer_, int_workspace_size_in_bytes); + } + + cudaStream_t GetCUDAStream() const { return stream_; } + + void SetCUDAStream(cudaStream_t stream) { stream_ = stream; } + + /*! + * \brief Constructor of BatchDecodeHandler + * \param enable_cuda_graph A boolean indicates whether to enable CUDA graph + * \param batch_size If enable_cuda_graph is true, we must specify a fixed batch_size + */ + BatchDecodeHandler(bool enable_cuda_graph = false, uint32_t batch_size = 0) + : cuda_graph_enabled_(enable_cuda_graph), stream_(nullptr) { + cudaMallocHost(&page_locked_buffer_, 8 * 1024 * 1024); + } + ~BatchDecodeHandler() { cudaFreeHost(page_locked_buffer_); } + + bool IsCUDAGraphEnabled() const { return cuda_graph_enabled_; } + + DecodePlanInfo GetPlanInfo() const { return plan_info_; } + + template + IdType* GetRequestIndices() { + return GetPtrFromBaseOffset(int_buffer_, plan_info_.request_indices_offset); + } + + template + IdType* GetKVTileIndices() { + return GetPtrFromBaseOffset(int_buffer_, plan_info_.kv_tile_indices_offset); + } + + template + IdType* GetOIndptr() { + return GetPtrFromBaseOffset(int_buffer_, plan_info_.o_indptr_offset); + } + + template + IdType* GetKVChunkSizePtr() { + return GetPtrFromBaseOffset(int_buffer_, plan_info_.kv_chunk_size_ptr_offset); + } + + template + DTypeO* GetTmpV() { + if (plan_info_.split_kv) { + return GetPtrFromBaseOffset(float_buffer_, plan_info_.v_offset); + } + return nullptr; + } + + float* GetTmpS() { + if (plan_info_.split_kv) { + return GetPtrFromBaseOffset(float_buffer_, plan_info_.s_offset); + } + return nullptr; + } + + bool* GetBlockValidMask() { + if (plan_info_.split_kv && plan_info_.enable_cuda_graph) { + return GetPtrFromBaseOffset(int_buffer_, plan_info_.block_valid_mask_offset); + } + return nullptr; + } + + protected: + void* page_locked_buffer_; + void* int_buffer_; + void* float_buffer_; + DecodePlanInfo plan_info_; + bool cuda_graph_enabled_; + cudaStream_t stream_; +}; + +template +cudaError_t BatchPrefillWithRaggedKVCacheDispatched(typename AttentionVariant::ParamsT params, + typename AttentionVariant::DTypeO* tmp_v, + float* tmp_s, 
cudaStream_t stream); + +template +cudaError_t BatchPrefillWithPagedKVCacheDispatched(typename AttentionVariant::ParamsT params, + typename AttentionVariant::DTypeO* tmp_v, + float* tmp_s, cudaStream_t stream); + +class BatchPrefillHandler { + public: + void UpdatePageLockedBufferSize(size_t int_workspace_size_in_bytes) { + cudaFreeHost(page_locked_buffer_); + cudaMallocHost(&page_locked_buffer_, int_workspace_size_in_bytes); + } + + template + cudaError_t Plan(void* float_buffer, size_t float_workspace_size_in_bytes, void* int_buffer, + size_t int_workspace_size_in_bytes, IdType* qo_indptr_h, IdType* kv_indptr_h, + uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, + uint32_t head_dim, uint32_t page_size) { + int_buffer_ = int_buffer; + float_buffer_ = float_buffer; + return PrefillPlan(float_buffer, float_workspace_size_in_bytes, int_buffer, + page_locked_buffer_, int_workspace_size_in_bytes, plan_info_, + qo_indptr_h, kv_indptr_h, batch_size, num_qo_heads, num_kv_heads, + head_dim, page_size, enable_cuda_graph_, sizeof(DTypeO), stream_); + } + + cudaStream_t GetCUDAStream() const { return stream_; } + + void SetCUDAStream(cudaStream_t stream) { stream_ = stream; } + + bool IsCUDAGraphEnabled() const { return enable_cuda_graph_; } + + BatchPrefillHandler(bool enable_cuda_graph = false) + : enable_cuda_graph_(enable_cuda_graph), stream_(nullptr) { + cudaMallocHost(&page_locked_buffer_, 8 * 1024 * 1024); + } + ~BatchPrefillHandler() { cudaFreeHost(page_locked_buffer_); } + + PrefillPlanInfo GetPlanInfo() const { return plan_info_; } + + template + IdType* GetRequestIndices() { + return GetPtrFromBaseOffset(int_buffer_, plan_info_.request_indices_offset); + } + + template + IdType* GetQOTileIndices() { + return GetPtrFromBaseOffset(int_buffer_, plan_info_.qo_tile_indices_offset); + } + + template + IdType* GetKVTileIndices() { + return GetPtrFromBaseOffset(int_buffer_, plan_info_.kv_tile_indices_offset); + } + + template + IdType* GetOIndptr() { + return GetPtrFromBaseOffset(int_buffer_, plan_info_.o_indptr_offset); + } + + template + IdType* GetKVChunkSizePtr() { + return GetPtrFromBaseOffset(int_buffer_, plan_info_.kv_chunk_size_ptr_offset); + } + + template + IdType* GetMergeIndptr() { + if (plan_info_.split_kv) { + return GetPtrFromBaseOffset(int_buffer_, plan_info_.merge_indptr_offset); + } + return nullptr; + } + + template + DTypeO* GetTmpV() { + if (plan_info_.split_kv) { + return GetPtrFromBaseOffset(float_buffer_, plan_info_.v_offset); + } + return nullptr; + } + + float* GetTmpS() { + if (plan_info_.split_kv) { + return GetPtrFromBaseOffset(float_buffer_, plan_info_.s_offset); + } + return nullptr; + } + + bool* GetBlockValidMask() { + if (plan_info_.split_kv && plan_info_.enable_cuda_graph) { + return GetPtrFromBaseOffset(int_buffer_, plan_info_.block_valid_mask_offset); + } + return nullptr; + } + + protected: + void* page_locked_buffer_; + void* int_buffer_; + void* float_buffer_; + PrefillPlanInfo plan_info_; + bool enable_cuda_graph_; + cudaStream_t stream_; +}; + +template +cudaError_t SinglePrefillWithKVCacheDispatched(typename AttentionVariant::ParamsT params, + typename AttentionVariant::DTypeO* tmp, + cudaStream_t stream); + +template cudaError_t SinglePrefillWithKVCacheCustomMask( - DTypeIn* q, DTypeIn* k, DTypeIn* v, uint8_t* custom_mask, DTypeOut* o, DTypeOut* tmp, - float* lse, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, + DTypeIn* q, DTypeIn* k, DTypeIn* v, uint8_t* custom_mask, DTypeO* o, DTypeO* tmp, float* lse, 
+ uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, uint32_t head_dim, QKVLayout kv_layout = QKVLayout::kNHD, PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, bool allow_fp16_qk_reduction = false, std::optional maybe_sm_scale = std::nullopt, @@ -38,13 +260,20 @@ cudaError_t SinglePrefillWithKVCacheCustomMask( allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, {DISPATCH_head_dim( head_dim, HEAD_DIM, {DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { - return SinglePrefillWithKVCacheDispatched( - q, k, v, custom_mask, o, tmp, lse, num_qo_heads, num_kv_heads, qo_len, kv_len, - qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h, - /*window_left=*/-1, - /*logits_soft_cap*/ 0.f, sm_scale, rope_scale, rope_theta, stream); + using ParamsT = SinglePrefillParams; + using AttentionVariant = + ComposedAttention; + ParamsT params(q, k, v, custom_mask, o, lse, + /*alibi_slopes=*/nullptr, num_qo_heads, num_kv_heads, qo_len, kv_len, + qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h, head_dim, + /*window_left=*/-1, + /*logits_soft_cap=*/0.f, sm_scale, rope_scale, rope_theta); + return SinglePrefillWithKVCacheDispatched(params, tmp, stream); })})}); return cudaSuccess; } @@ -52,7 +281,7 @@ cudaError_t SinglePrefillWithKVCacheCustomMask( /*! * \brief FlashAttention prefill CUDA function for a single request. * \tparam DTypeIn The data type of input - * \tparam DTypeOut The data type of output + * \tparam DTypeO The data type of output * \param q The query tensor. * \param k The key tensor. * \param v The value tensor. @@ -73,8 +302,8 @@ cudaError_t SinglePrefillWithKVCacheCustomMask( * \param stream The cuda stream to execute the kernel on. * \return status Indicates whether CUDA calls are successful */ -template -cudaError_t SinglePrefillWithKVCache(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o, DTypeOut* tmp, +template +cudaError_t SinglePrefillWithKVCache(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeO* o, DTypeO* tmp, float* lse, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, uint32_t head_dim, bool causal = true, QKVLayout kv_layout = QKVLayout::kNHD, @@ -94,21 +323,29 @@ cudaError_t SinglePrefillWithKVCache(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut {DISPATCH_head_dim( head_dim, HEAD_DIM, {DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { - return SinglePrefillWithKVCacheDispatched( - q, k, v, /*custom_mask=*/nullptr, o, tmp, lse, num_qo_heads, num_kv_heads, - qo_len, kv_len, qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h, - /*window_left=*/-1, - /*logits_soft_cap=*/0.f, sm_scale, rope_scale, rope_theta, stream); + using ParamsT = SinglePrefillParams; + using AttentionVariant = + ComposedAttention; + ParamsT params(q, k, v, /*custom_mask=*/nullptr, o, lse, + /*alibi_slopes=*/nullptr, num_qo_heads, num_kv_heads, qo_len, kv_len, + qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h, head_dim, + /*window_left=*/-1, + /*logits_soft_cap=*/0.f, sm_scale, rope_scale, rope_theta); + return SinglePrefillWithKVCacheDispatched(params, tmp, stream); })})})}); return cudaSuccess; } -template +template cudaError_t BatchPrefillWithRaggedKVCacheWrapper( BatchPrefillHandler* handler, DTypeQ* q, IdType* qo_indptr, DTypeKV* k, DTypeKV* v, - IdType* kv_indptr, IdType* q_offset, IdType* k_rope_pos_offset, DTypeOut* o, float* lse, + IdType* kv_indptr, IdType* q_offset, IdType* k_rope_pos_offset, DTypeO* o, float* lse, const uint32_t batch_size, const uint32_t num_qo_heads, const uint32_t num_kv_heads, const uint32_t 
head_dim, bool causal = true, QKVLayout kv_layout = QKVLayout::kNHD, PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, @@ -118,38 +355,60 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapper( const MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone; auto [qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h] = get_qkv_strides(kv_layout, 0, num_qo_heads, num_kv_heads, head_dim); + auto plan_info = handler->GetPlanInfo(); + auto warp_layout = WarpLayout(plan_info.warp_layout_code); DISPATCH_head_dim( head_dim, HEAD_DIM, {DISPATCH_mask_mode( mask_mode, MASK_MODE, {DISPATCH_pos_encoding_mode( - pos_encoding_mode, pos_encoding_mode, + pos_encoding_mode, POS_ENCODING_MODE, {DISPATCH_allow_fp16_qk_reduction(allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, { - return BatchPrefillWithRaggedKVCacheWrapperDispatched< - HEAD_DIM, LogitsPostHook::kNone, pos_encoding_mode, ALLOW_FP16_QK_REDUCTION, - MASK_MODE, DTypeQ, DTypeKV, DTypeOut, IdType>( - handler, q, qo_indptr, k, v, kv_indptr, /*custom_mask=*/nullptr, - /*qk_indptr=*/nullptr, q_offset, k_rope_pos_offset, o, lse, num_qo_heads, - num_kv_heads, qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h, - /*window_left=*/-1, - /*logits_soft_cap=*/0.f, sm_scale, rope_scale, rope_theta, stream); + using ParamsT = BatchPrefillRaggedParams; + using AttentionVariant = + ComposedAttention; + ParamsT params(q, k, v, /*custom_mask=*/nullptr, qo_indptr, kv_indptr, + /*qk_indptr=*/nullptr, q_offset, k_rope_pos_offset, o, lse, + /*alibi_slopes=*/nullptr, num_qo_heads, num_kv_heads, qo_stride_n, + qo_stride_h, kv_stride_n, kv_stride_h, /*window_left=*/-1, + /*logits_soft_cap=*/0.f, sm_scale, rope_scale, rope_theta); + params.request_indices = handler->GetRequestIndices(); + params.qo_tile_indices = handler->GetQOTileIndices(); + params.kv_tile_indices = handler->GetKVTileIndices(); + params.o_indptr = handler->GetOIndptr(); + params.kv_chunk_size_ptr = handler->GetKVChunkSizePtr(); + params.merge_indptr = handler->GetMergeIndptr(); + params.block_valid_mask = handler->GetBlockValidMask(); + params.total_num_rows = plan_info.total_num_rows; + params.padded_batch_size = plan_info.padded_batch_size; + + DISPATCH_WARP_LAYOUT(warp_layout, WARP_LAYOUT, { + BatchPrefillWithRaggedKVCacheDispatched( + params, handler->GetTmpV(), handler->GetTmpS(), stream); + }); })})})}); return cudaSuccess; } -template +template cudaError_t BatchPrefillWithPagedKVCacheWrapper( BatchPrefillHandler* handler, DTypeQ* q, IdType* qo_indptr, IdType* q_offset, - paged_kv_t paged_kv, DTypeOut* o, float* lse, - uint32_t num_qo_heads, bool causal = true, - PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, + paged_kv_t paged_kv, DTypeO* o, float* lse, uint32_t num_qo_heads, + bool causal = true, PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, bool allow_fp16_qk_reduction = false, std::optional maybe_sm_scale = std::nullopt, float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr) { const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(paged_kv.head_dim))); const uint32_t num_kv_heads = paged_kv.num_heads; const uint32_t head_dim = paged_kv.head_dim; const MaskMode mask_mode = causal ? 
MaskMode::kCausal : MaskMode::kNone; + auto plan_info = handler->GetPlanInfo(); + auto warp_layout = WarpLayout(plan_info.warp_layout_code); DISPATCH_head_dim( head_dim, HEAD_DIM, {DISPATCH_mask_mode( @@ -157,19 +416,44 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapper( {DISPATCH_pos_encoding_mode( pos_encoding_mode, POS_ENCODING_MODE, {DISPATCH_allow_fp16_qk_reduction(allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, { - return BatchPrefillWithPagedKVCacheWrapperDispatched< - PAGE_STORAGE, HEAD_DIM, LogitsPostHook::kNone, POS_ENCODING_MODE, - ALLOW_FP16_QK_REDUCTION, MASK_MODE, DTypeQ, DTypeKV, DTypeOut, IdType>( - handler, q, qo_indptr, q_offset, paged_kv, - /*custom_mask=*/nullptr, - /*qk_indptr=*/nullptr, o, lse, num_qo_heads, /*window_left=*/-1, - /*logits_soft_cap=*/0.f, sm_scale, rope_scale, rope_theta, stream); + using ParamsT = BatchPrefillPagedParams; + using AttentionVariant = + ComposedAttention; + ParamsT params(q, paged_kv, /*custom_mask=*/nullptr, qo_indptr, + /*qk_indptr=*/nullptr, q_offset, o, lse, + /*alibi_slopes=*/nullptr, num_qo_heads, + /*window_left=*/-1, /*logits_soft_cap=*/0.f, sm_scale, rope_scale, + rope_theta); + params.request_indices = handler->GetRequestIndices(); + params.qo_tile_indices = handler->GetQOTileIndices(); + params.kv_tile_indices = handler->GetKVTileIndices(); + params.o_indptr = handler->GetOIndptr(); + params.kv_chunk_size_ptr = handler->GetKVChunkSizePtr(); + params.merge_indptr = handler->GetMergeIndptr(); + params.block_valid_mask = handler->GetBlockValidMask(); + params.total_num_rows = plan_info.total_num_rows; + params.padded_batch_size = plan_info.padded_batch_size; + DISPATCH_WARP_LAYOUT(warp_layout, WARP_LAYOUT, { + return BatchPrefillWithPagedKVCacheDispatched< + WARP_LAYOUT, HEAD_DIM, POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION, MASK_MODE, + AttentionVariant>(params, handler->GetTmpV(), handler->GetTmpS(), + stream); + }) })})})}); return cudaSuccess; } -template -cudaError_t SingleDecodeWithKVCache(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o, DTypeOut* tmp, +template +cudaError_t SingleDecodeWithKVCacheDispatched(typename AttentionVariant::ParamsT params, + typename AttentionVariant::DTypeO* tmp, + cudaStream_t stream); + +template +cudaError_t SingleDecodeWithKVCache(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeO* o, DTypeO* tmp, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t seq_len, uint32_t head_dim, QKVLayout kv_layout = QKVLayout::kNHD, PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, @@ -186,54 +470,27 @@ cudaError_t SingleDecodeWithKVCache(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* DISPATCH_head_dim( head_dim, HEAD_DIM, {DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { - SingleDecodeWithKVCacheDispatched( - q, k, v, o, tmp, num_qo_heads, num_kv_heads, seq_len, kv_layout, - /*window_left=*/-1, - /*logits_soft_cap=*/0.f, sm_scale, rope_scale, rope_theta, stream); - })}); - return cudaSuccess; -} + using ParamsT = SingleDecodeParams; + using AttentionVariant = + ComposedAttention; + ParamsT params(q, k, v, o, /*alibi_slopes=*/nullptr, seq_len, num_qo_heads, num_kv_heads, + kv_layout, head_dim, /*window_left=*/-1, /*logits_soft_cap=*/0.f, sm_scale, + rope_scale, rope_theta); -template -cudaError_t BatchDecodeWithPagedKVCacheNoSplitKV( - DTypeQ* q, IdType* q_offset, paged_kv_t paged_kv, - kv_partition_info_t kv_partition_info, DTypeOut* o, float* lse, uint32_t num_qo_heads, - PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, - std::optional maybe_sm_scale = std::nullopt, float 
rope_scale = 1.f, - float rope_theta = 1e4, cudaStream_t stream = nullptr) { - const uint32_t num_kv_heads = paged_kv.num_heads; - const uint32_t head_dim = paged_kv.head_dim; - const uint32_t batch_size = paged_kv.batch_size; - const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim))); - if (num_qo_heads % num_kv_heads != 0) { - std::ostringstream err_msg; - err_msg << "num_qo_heads " << num_qo_heads << " is not a multiple of num_kv_heads " - << num_kv_heads; - throw std::invalid_argument(err_msg.str()); - } - - DISPATCH_head_dim( - head_dim, HEAD_DIM, {DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { - return BatchDecodeWithPagedKVCacheDispatched( - q, q_offset, paged_kv, kv_partition_info, o, /*tmp_v=*/nullptr, /*tmp_s=*/nullptr, lse, - /*block_valid_mask=*/nullptr, /*padded_batch_size=*/paged_kv.batch_size, num_qo_heads, - /*window_left=*/-1, - /*logits_soft_cap=*/0.f, sm_scale, rope_scale, rope_theta, stream); + SingleDecodeWithKVCacheDispatched( + params, tmp, stream); })}); - return cudaSuccess; } /*! * \brief Wrapper of BatchDecodeWithPagedKVCache function, and caches the temporary buffer * for cooperative kernels. - * \tparam page_storage Whether to store indices or pointers of each active page * \tparam DTypeQ The data type of query tensor. * \tparam DTypeKV The data type of key-value tensor. - * \tparam DTypeOut The data type of output tensor. + * \tparam DTypeO The data type of output tensor. * \tparam IdType The data type of index tensor. * \param handler The handler for the batch decode forward request. * \param q The input tensor. @@ -246,12 +503,11 @@ cudaError_t BatchDecodeWithPagedKVCacheNoSplitKV( * \param rope_theta The theta of rope. * \param stream The CUDA stream. */ -template +template cudaError_t BatchDecodeWithPagedKVCacheWrapper( - BatchDecodeHandler* handler, DTypeQ* q, IdType* q_offset, - paged_kv_t paged_kv, DTypeOut* o, float* lse, - uint32_t num_qo_heads, PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, + BatchDecodeHandler* handler, DTypeQ* q, IdType* q_offset, paged_kv_t paged_kv, + DTypeO* o, float* lse, uint32_t num_qo_heads, + PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, std::optional maybe_sm_scale = std::nullopt, float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr) { float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(paged_kv.head_dim))); @@ -263,20 +519,31 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapper( throw std::invalid_argument(err_msg.str()); } - DISPATCH_head_dim(paged_kv.head_dim, HEAD_DIM, - {DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { - return BatchDecodeWithPagedKVCacheWrapperDispatched< - PAGE_STORAGE, HEAD_DIM, LogitsPostHook::kNone, POS_ENCODING_MODE, DTypeQ, - DTypeKV, DTypeOut, IdType>( - handler, q, q_offset, paged_kv, o, lse, num_qo_heads, - /*window_left=*/-1, - /*logits_soft_cap=*/0.f, sm_scale, rope_scale, rope_theta, stream); - })}); + DISPATCH_head_dim( + paged_kv.head_dim, HEAD_DIM, + {DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { + using ParamsT = BatchDecodeParams; + using AttentionVariant = + ComposedAttention; + ParamsT params(q, q_offset, paged_kv, o, lse, /*alibi_slopes=*/nullptr, num_qo_heads, + /*window_left=*/-1, /*logits_soft_cap=*/0.f, sm_scale, rope_scale, + rope_theta); + params.request_indices = handler->GetRequestIndices(); + params.kv_tile_indices = handler->GetKVTileIndices(); + params.o_indptr = handler->GetOIndptr(); + params.kv_chunk_size_ptr = 
handler->GetKVChunkSizePtr(); + params.block_valid_mask = handler->GetBlockValidMask(); + params.padded_batch_size = handler->GetPlanInfo().padded_batch_size; + + return BatchDecodeWithPagedKVCacheDispatched( + params, handler->GetTmpV(), handler->GetTmpS(), stream); + })}); return cudaSuccess; } -template +template cudaError_t BatchDecodeHandlerPlan(BatchDecodeHandler* handler, void* float_buffer, size_t float_workspace_size_in_bytes, void* int_buffer, size_t int_workspace_size_in_bytes, IdType* indptr_h, @@ -291,8 +558,7 @@ cudaError_t BatchDecodeHandlerPlan(BatchDecodeHandler* handler, void* float_buff } DISPATCH_head_dim(head_dim, HEAD_DIM, { DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { - return handler->PlanDispatched( + return handler->PlanDispatched( float_buffer, float_workspace_size_in_bytes, int_buffer, int_workspace_size_in_bytes, indptr_h, last_page_len_h, batch_size, num_qo_heads, num_kv_heads, page_size); }); diff --git a/src/test_attn_all_reduce.cu b/src/test_attn_all_reduce.cu index cb69f814f..45c5e110a 100644 --- a/src/test_attn_all_reduce.cu +++ b/src/test_attn_all_reduce.cu @@ -13,7 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include #include #include #include diff --git a/src/test_batch_decode.cu b/src/test_batch_decode.cu index 8098661fe..7862a03ea 100644 --- a/src/test_batch_decode.cu +++ b/src/test_batch_decode.cu @@ -28,8 +28,7 @@ constexpr QKVLayout kv_layout = QKVLayout::kNHD; template void _TestBatchDecodingKernelCorrectness(size_t page_size, size_t batch_size, size_t num_qo_heads, size_t num_kv_heads, size_t head_dim, - flashinfer::PosEncodingMode pos_encoding_mode, - bool cooperative) { + flashinfer::PosEncodingMode pos_encoding_mode) { std::vector seq_lens(batch_size); utils::vec_randint_(seq_lens, 1, 1024); std::vector append_indptr{0}; @@ -77,7 +76,7 @@ void _TestBatchDecodingKernelCorrectness(size_t page_size, size_t batch_size, si assert(q.size() == batch_size * num_qo_heads * head_dim); assert(o_ref.size() == batch_size * num_qo_heads * head_dim); - flashinfer::paged_kv_t paged_kv_cpu( + flashinfer::paged_kv_t paged_kv_cpu( num_kv_heads, page_size, head_dim, batch_size, kv_layout, kv_data.data(), kv_indices.data(), kv_indptr.data(), kv_last_page_len.data()); cpu_reference::append_paged_kv_cache(paged_kv_cpu, keys, values, append_indptr); @@ -91,7 +90,7 @@ void _TestBatchDecodingKernelCorrectness(size_t page_size, size_t batch_size, si thrust::device_vector o_device(o_ref.size()); // create paged_kv object - flashinfer::paged_kv_t paged_kv( + flashinfer::paged_kv_t paged_kv( num_kv_heads, page_size, head_dim, batch_size, kv_layout, thrust::raw_pointer_cast(kv_data_device.data()), thrust::raw_pointer_cast(kv_indices_device.data()), @@ -102,30 +101,18 @@ void _TestBatchDecodingKernelCorrectness(size_t page_size, size_t batch_size, si thrust::device_vector float_buffer(float_workspace_size_in_bytes); size_t int_workspace_size_in_bytes = 8 * 1024 * 1024; thrust::device_vector int_buffer(int_workspace_size_in_bytes); - BatchDecodeHandlerPlan( + BatchDecodeHandlerPlan( &handler, (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes, (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes, kv_indptr.data(), kv_last_page_len.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, pos_encoding_mode); - if (!cooperative) { - // use non-cooperative kernel - cudaError_t status = - 
flashinfer::BatchDecodeWithPagedKVCacheNoSplitKV( - thrust::raw_pointer_cast(q_device.data()), /*q_offset=*/nullptr, paged_kv, - kv_partition_info_t(), thrust::raw_pointer_cast(o_device.data()), - /*lse=*/nullptr, num_qo_heads, pos_encoding_mode); - EXPECT_EQ(status, cudaSuccess) << "CUDA error: " + std::string(cudaGetErrorString(status)); - } else { - cudaError_t status = - flashinfer::BatchDecodeWithPagedKVCacheWrapper( - &handler, thrust::raw_pointer_cast(q_device.data()), /*q_offset=*/nullptr, paged_kv, - thrust::raw_pointer_cast(o_device.data()), /*lse=*/nullptr, num_qo_heads, - pos_encoding_mode); - EXPECT_EQ(status, cudaSuccess) << "CUDA error: " + std::string(cudaGetErrorString(status)); - } + cudaError_t status = + flashinfer::BatchDecodeWithPagedKVCacheWrapper( + &handler, thrust::raw_pointer_cast(q_device.data()), /*q_offset=*/nullptr, paged_kv, + thrust::raw_pointer_cast(o_device.data()), /*lse=*/nullptr, num_qo_heads, + pos_encoding_mode); + EXPECT_EQ(status, cudaSuccess) << "CUDA error: " + std::string(cudaGetErrorString(status)); // compare result thrust::host_vector o_host = o_device; size_t num_result_errors_atol_1e_3_rtol_1e_3 = 0; @@ -150,25 +137,6 @@ void _TestBatchDecodingKernelCorrectness(size_t page_size, size_t batch_size, si template void TestBatchDecodeKernelCorrectness() { - for (size_t page_size : {1, 3, 7, 16}) { - for (size_t batch_size : {1, 7, 37, 61}) { - for (size_t num_qo_heads : {32}) { - for (size_t num_kv_heads : {32, 8, 4}) { - for (size_t head_dim : {64, 128, 256}) { - for (size_t pos_encoding_mode : {0U, 1U}) { - _TestBatchDecodingKernelCorrectness( - page_size, batch_size, num_qo_heads, num_kv_heads, head_dim, - flashinfer::PosEncodingMode(pos_encoding_mode), false); - } - } - } - } - } - } -} - -template -void TestCooperativeBatchDecodeKernelCorrectness() { for (size_t page_size : {1, 3, 7, 16}) { for (size_t batch_size : {1, 2, 4, 8}) { for (size_t num_qo_heads : {32}) { @@ -177,7 +145,7 @@ void TestCooperativeBatchDecodeKernelCorrectness() { for (size_t pos_encoding_mode : {0U, 1U}) { _TestBatchDecodingKernelCorrectness( page_size, batch_size, num_qo_heads, num_kv_heads, head_dim, - flashinfer::PosEncodingMode(pos_encoding_mode), true); + flashinfer::PosEncodingMode(pos_encoding_mode)); } } } @@ -205,7 +173,3 @@ TEST(FlashInferCorrectnessTest, TestBatchDecodeKernelCorrectnessE5M2) { TestBatchDecodeKernelCorrectness(); } #endif - -TEST(FlashInferCorrectnessTest, TestCooperativeBatchDecodeKernelCorrectnessTestFP16) { - TestCooperativeBatchDecodeKernelCorrectness(); -} diff --git a/src/test_batch_prefill.cu b/src/test_batch_prefill.cu index b4dfe3d71..c1cfdf355 100644 --- a/src/test_batch_prefill.cu +++ b/src/test_batch_prefill.cu @@ -62,7 +62,7 @@ void _TestBatchPagedPrefillKernelOneHotCorrectness(size_t num_kv_heads, size_t n } kv_data.resize(page_counter * 2 * num_kv_heads * page_size * head_dim); - flashinfer::paged_kv_t paged_kv_cpu( + flashinfer::paged_kv_t paged_kv_cpu( num_kv_heads, page_size, head_dim, batch_size, kv_layout, kv_data.data(), kv_indices.data(), kv_indptr.data(), kv_last_page_len.data()); cpu_reference::append_paged_kv_cache(paged_kv_cpu, key, value, append_indptr); @@ -74,7 +74,7 @@ void _TestBatchPagedPrefillKernelOneHotCorrectness(size_t num_kv_heads, size_t n thrust::device_vector kv_last_page_len_device(kv_last_page_len); // create paged_kv object - flashinfer::paged_kv_t paged_kv = paged_kv_cpu; + flashinfer::paged_kv_t paged_kv = paged_kv_cpu; paged_kv.k_data = thrust::raw_pointer_cast(kv_data_device.data()); 
paged_kv.v_data = paged_kv.k_data + paged_kv_cpu.kv_ptr_delta(); paged_kv.indices = thrust::raw_pointer_cast(kv_indices_device.data()); @@ -112,12 +112,12 @@ void _TestBatchPagedPrefillKernelOneHotCorrectness(size_t num_kv_heads, size_t n batch_size, num_qo_heads, num_kv_heads, head_dim, page_size); for (uint32_t num_runs = 0; num_runs < 10; ++num_runs) { - auto status = flashinfer::BatchPrefillWithPagedKVCacheWrapper( - &handler, thrust::raw_pointer_cast(q_device.data()), - thrust::raw_pointer_cast(q_indptr_device.data()), /*q_offset=*/nullptr, paged_kv, - thrust::raw_pointer_cast(o_device.data()), - /*lse=*/nullptr, num_qo_heads, causal, pos_encoding_mode, allow_fp16_qk_reduction); + auto status = + flashinfer::BatchPrefillWithPagedKVCacheWrapper( + &handler, thrust::raw_pointer_cast(q_device.data()), + thrust::raw_pointer_cast(q_indptr_device.data()), /*q_offset=*/nullptr, paged_kv, + thrust::raw_pointer_cast(o_device.data()), + /*lse=*/nullptr, num_qo_heads, causal, pos_encoding_mode, allow_fp16_qk_reduction); EXPECT_EQ(status, cudaSuccess) << "CUDA error: " + std::string(cudaGetErrorString(status)); } @@ -278,7 +278,7 @@ void _TestBatchPagedPrefillKernelShortContextCorrectness(size_t num_kv_heads, si } kv_data.resize(page_counter * 2 * num_kv_heads * page_size * head_dim); - flashinfer::paged_kv_t paged_kv_cpu( + flashinfer::paged_kv_t paged_kv_cpu( num_kv_heads, page_size, head_dim, batch_size, kv_layout, kv_data.data(), kv_indices.data(), kv_indptr.data(), kv_last_page_len.data()); cpu_reference::append_paged_kv_cache(paged_kv_cpu, key, value, append_indptr); @@ -290,7 +290,7 @@ void _TestBatchPagedPrefillKernelShortContextCorrectness(size_t num_kv_heads, si thrust::device_vector kv_last_page_len_device(kv_last_page_len); // create paged_kv object - flashinfer::paged_kv_t paged_kv = paged_kv_cpu; + flashinfer::paged_kv_t paged_kv = paged_kv_cpu; paged_kv.k_data = thrust::raw_pointer_cast(kv_data_device.data()); paged_kv.v_data = paged_kv.k_data + paged_kv_cpu.kv_ptr_delta(); paged_kv.indices = thrust::raw_pointer_cast(kv_indices_device.data()); @@ -334,8 +334,7 @@ void _TestBatchPagedPrefillKernelShortContextCorrectness(size_t num_kv_heads, si int_workspace_size_in_bytes, q_indptr.data(), kv_indptr.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size); - auto status = BatchPrefillWithPagedKVCacheWrapper( + auto status = BatchPrefillWithPagedKVCacheWrapper( &handler, thrust::raw_pointer_cast(q_device.data()), thrust::raw_pointer_cast(q_indptr_device.data()), /*q_offset=*/nullptr, paged_kv, thrust::raw_pointer_cast(o_device.data()), @@ -404,7 +403,7 @@ void _TestBatchPagedPrefillKernelQMinMaxKVMinMaxCorrectness( } kv_data.resize(page_counter * 2 * num_kv_heads * page_size * head_dim); - flashinfer::paged_kv_t paged_kv_cpu( + flashinfer::paged_kv_t paged_kv_cpu( num_kv_heads, page_size, head_dim, batch_size, kv_layout, kv_data.data(), kv_indices.data(), kv_indptr.data(), kv_last_page_len.data()); cpu_reference::append_paged_kv_cache(paged_kv_cpu, key, value, append_indptr); @@ -416,7 +415,7 @@ void _TestBatchPagedPrefillKernelQMinMaxKVMinMaxCorrectness( thrust::device_vector kv_last_page_len_device(kv_last_page_len); // create paged_kv object - flashinfer::paged_kv_t paged_kv = paged_kv_cpu; + flashinfer::paged_kv_t paged_kv = paged_kv_cpu; paged_kv.k_data = thrust::raw_pointer_cast(kv_data_device.data()); paged_kv.v_data = paged_kv.k_data + paged_kv_cpu.kv_ptr_delta(); paged_kv.indices = thrust::raw_pointer_cast(kv_indices_device.data()); @@ -461,8 +460,7 @@ void 
_TestBatchPagedPrefillKernelQMinMaxKVMinMaxCorrectness( int_workspace_size_in_bytes, q_indptr.data(), kv_indptr.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size); - auto status = BatchPrefillWithPagedKVCacheWrapper( + auto status = BatchPrefillWithPagedKVCacheWrapper( &handler, thrust::raw_pointer_cast(q_device.data()), thrust::raw_pointer_cast(q_indptr_device.data()), /*q_offset=*/nullptr, paged_kv, thrust::raw_pointer_cast(o_device.data()), @@ -518,7 +516,7 @@ void _TestBatchPagedPrefillKernelLongContextCorrectness(size_t num_kv_heads, siz } kv_data.resize(page_counter * 1 * 2 * num_kv_heads * page_size * head_dim); - flashinfer::paged_kv_t paged_kv_cpu( + flashinfer::paged_kv_t paged_kv_cpu( num_kv_heads, page_size, head_dim, 1, kv_layout, kv_data.data(), kv_indices.data(), kv_indptr.data(), kv_last_page_len.data()); cpu_reference::append_paged_kv_cache(paged_kv_cpu, {k}, {v}, append_indptr); @@ -530,7 +528,7 @@ void _TestBatchPagedPrefillKernelLongContextCorrectness(size_t num_kv_heads, siz thrust::device_vector kv_last_page_len_device(kv_last_page_len); // create paged_kv object - flashinfer::paged_kv_t paged_kv = paged_kv_cpu; + flashinfer::paged_kv_t paged_kv = paged_kv_cpu; paged_kv.k_data = thrust::raw_pointer_cast(kv_data_device.data()); paged_kv.v_data = paged_kv.k_data + paged_kv_cpu.kv_ptr_delta(); paged_kv.indices = thrust::raw_pointer_cast(kv_indices_device.data()); @@ -561,8 +559,7 @@ void _TestBatchPagedPrefillKernelLongContextCorrectness(size_t num_kv_heads, siz append_indptr.data(), kv_indptr.data(), /*batch_size=*/1, num_qo_heads, num_kv_heads, head_dim, page_size); - auto status = BatchPrefillWithPagedKVCacheWrapper( + auto status = BatchPrefillWithPagedKVCacheWrapper( &handler, thrust::raw_pointer_cast(q_device.data()), thrust::raw_pointer_cast(q_indptr_device.data()), /*q_offset=*/nullptr, paged_kv, thrust::raw_pointer_cast(o_device.data()), diff --git a/src/test_cascade.cu b/src/test_cascade.cu index a74439fbf..22b73bc25 100644 --- a/src/test_cascade.cu +++ b/src/test_cascade.cu @@ -266,16 +266,14 @@ void _TestTwoLevelSinglePrefixCascadeDecodeCorrectness(size_t batch_size, kv_last_page_len_combined_d(kv_last_page_len_combined_h), kv_last_page_len_unique_d(kv_last_page_len_unique_h); - constexpr PageStorage page_storage = PageStorage::kIndices; - - paged_kv_t paged_kv_baseline_d( + paged_kv_t paged_kv_baseline_d( num_kv_heads, page_size, head_dim, batch_size, kv_layout, thrust::raw_pointer_cast(kv_data_d.data()), thrust::raw_pointer_cast(kv_indices_combined_d.data()), thrust::raw_pointer_cast(kv_indptr_combined_d.data()), thrust::raw_pointer_cast(kv_last_page_len_combined_d.data())); - paged_kv_t paged_kv_casacde_d( + paged_kv_t paged_kv_casacde_d( num_kv_heads, page_size, head_dim, batch_size, kv_layout, thrust::raw_pointer_cast(kv_data_d.data()), thrust::raw_pointer_cast(kv_indices_unique_d.data()), @@ -289,20 +287,20 @@ void _TestTwoLevelSinglePrefixCascadeDecodeCorrectness(size_t batch_size, size_t int_workspace_size_in_bytes = 8 * 1024 * 1024; thrust::device_vector int_buffer(int_workspace_size_in_bytes); - BatchDecodeHandlerPlan( + BatchDecodeHandlerPlan( &baseline_handler, (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes, (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes, kv_indptr_combined_h.data(), kv_last_page_len_combined_h.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, PosEncodingMode::kNone); - BatchDecodeHandlerPlan( + BatchDecodeHandlerPlan( 
&cascade_handler, (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes, (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes, kv_indptr_unique_h.data(), kv_last_page_len_unique_h.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, PosEncodingMode::kNone); // Compute result using baseline implementation - cudaError_t status = BatchDecodeWithPagedKVCacheWrapper( + cudaError_t status = BatchDecodeWithPagedKVCacheWrapper( &baseline_handler, thrust::raw_pointer_cast(q_d.data()), /*q_offset=*/nullptr, paged_kv_baseline_d, thrust::raw_pointer_cast(o_baseline_d.data()), /*lse=*/nullptr, num_qo_heads, PosEncodingMode::kNone); @@ -322,7 +320,7 @@ void _TestTwoLevelSinglePrefixCascadeDecodeCorrectness(size_t batch_size, EXPECT_EQ(status, cudaSuccess) << "Cascade implementation prefill failed with error: " << cudaGetErrorString(status); - status = BatchDecodeWithPagedKVCacheWrapper( + status = BatchDecodeWithPagedKVCacheWrapper( &cascade_handler, thrust::raw_pointer_cast(q_d.data()), /*q_offset=*/nullptr, paged_kv_casacde_d, thrust::raw_pointer_cast(o_cascade_1_d.data()), /*lse=*/thrust::raw_pointer_cast(lse_cascade_1_d.data()), num_qo_heads, @@ -394,16 +392,14 @@ void _TestTwoLevelSinglePrefixCascadeAppendCorrectness(size_t batch_size, kv_last_page_len_combined_d(kv_last_page_len_combined_h), kv_last_page_len_unique_d(kv_last_page_len_unique_h); - constexpr PageStorage page_storage = PageStorage::kIndices; - - paged_kv_t paged_kv_baseline_d( + paged_kv_t paged_kv_baseline_d( num_kv_heads, page_size, head_dim, batch_size, kv_layout, thrust::raw_pointer_cast(kv_data_d.data()), thrust::raw_pointer_cast(kv_indices_combined_d.data()), thrust::raw_pointer_cast(kv_indptr_combined_d.data()), thrust::raw_pointer_cast(kv_last_page_len_combined_d.data())); - paged_kv_t paged_kv_casacde_d( + paged_kv_t paged_kv_casacde_d( num_kv_heads, page_size, head_dim, batch_size, kv_layout, thrust::raw_pointer_cast(kv_data_d.data()), thrust::raw_pointer_cast(kv_indices_unique_d.data()), @@ -427,7 +423,7 @@ void _TestTwoLevelSinglePrefixCascadeAppendCorrectness(size_t batch_size, qo_indptr_h.data(), kv_indptr_unique_h.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size); - cudaError_t status = BatchPrefillWithPagedKVCacheWrapper( + cudaError_t status = BatchPrefillWithPagedKVCacheWrapper( &baseline_handler, thrust::raw_pointer_cast(q_d.data()), thrust::raw_pointer_cast(qo_indptr_d.data()), /*q_offset=*/nullptr, paged_kv_baseline_d, thrust::raw_pointer_cast(o_baseline_d.data()), @@ -450,7 +446,7 @@ void _TestTwoLevelSinglePrefixCascadeAppendCorrectness(size_t batch_size, << "Cascade implementation shared prefix prefill failed with error: " << cudaGetErrorString(status); - status = BatchPrefillWithPagedKVCacheWrapper( + status = BatchPrefillWithPagedKVCacheWrapper( &cascade_handler, thrust::raw_pointer_cast(q_d.data()), thrust::raw_pointer_cast(qo_indptr_d.data()), /*r_rope_position=*/nullptr, paged_kv_casacde_d, diff --git a/src/test_page.cu b/src/test_page.cu index ce5f7c7b5..f7b5bacdf 100644 --- a/src/test_page.cu +++ b/src/test_page.cu @@ -79,19 +79,19 @@ void _TestAppendPagedKVKernelCorrectness(size_t page_size, size_t batch_size, si } indptr_cpu.push_back(indptr_cpu.back() + page_indices[i].size()); } - paged_kv_t paged_kv_cpu( - num_heads, page_size, head_dim, batch_size, kv_layout, kv_data_cpu.data(), - indices_cpu.data(), indptr_cpu.data(), last_page_len.data()); + paged_kv_t paged_kv_cpu(num_heads, page_size, head_dim, 
batch_size, kv_layout, + kv_data_cpu.data(), indices_cpu.data(), indptr_cpu.data(), + last_page_len.data()); cpu_reference::append_paged_kv_cache(paged_kv_cpu, keys, values, append_indptr); thrust::device_vector indptr_gpu(indptr_cpu); thrust::device_vector indices_gpu(indices_cpu); thrust::device_vector last_page_len_gpu(last_page_len); - paged_kv_t paged_kv_gpu( - num_heads, page_size, head_dim, batch_size, kv_layout, - thrust::raw_pointer_cast(kv_data_gpu.data()), thrust::raw_pointer_cast(indices_gpu.data()), - thrust::raw_pointer_cast(indptr_gpu.data()), - thrust::raw_pointer_cast(last_page_len_gpu.data())); + paged_kv_t paged_kv_gpu(num_heads, page_size, head_dim, batch_size, kv_layout, + thrust::raw_pointer_cast(kv_data_gpu.data()), + thrust::raw_pointer_cast(indices_gpu.data()), + thrust::raw_pointer_cast(indptr_gpu.data()), + thrust::raw_pointer_cast(last_page_len_gpu.data())); thrust::device_vector append_indptr_gpu(append_indptr); thrust::device_vector keys_gpu(append_indptr.back() * num_heads * head_dim); diff --git a/src/test_single_prefill.cu b/src/test_single_prefill.cu index 08afb71be..a766c76b6 100644 --- a/src/test_single_prefill.cu +++ b/src/test_single_prefill.cu @@ -23,7 +23,7 @@ using namespace flashinfer; -template +template void _TestSinglePrefillKernelCorrectness(size_t qo_len, size_t kv_len, size_t num_qo_heads, size_t num_kv_heads, size_t head_dim, bool causal, QKVLayout kv_layout, PosEncodingMode pos_encoding_mode, @@ -32,7 +32,7 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, size_t kv_len, size_t nu std::vector q(qo_len * num_qo_heads * head_dim); std::vector k(kv_len * num_kv_heads * head_dim); std::vector v(kv_len * num_kv_heads * head_dim); - std::vector o(qo_len * num_qo_heads * head_dim); + std::vector o(qo_len * num_qo_heads * head_dim); utils::vec_normal_(q); utils::vec_normal_(k); @@ -42,10 +42,10 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, size_t kv_len, size_t nu thrust::device_vector q_d(q); thrust::device_vector k_d(k); thrust::device_vector v_d(v); - thrust::device_vector o_d(o); - thrust::device_vector tmp_d(16 * 1024 * 1024); + thrust::device_vector o_d(o); + thrust::device_vector tmp_d(16 * 1024 * 1024); - cudaError_t status = flashinfer::SinglePrefillWithKVCache( + cudaError_t status = flashinfer::SinglePrefillWithKVCache( thrust::raw_pointer_cast(q_d.data()), thrust::raw_pointer_cast(k_d.data()), thrust::raw_pointer_cast(v_d.data()), thrust::raw_pointer_cast(o_d.data()), thrust::raw_pointer_cast(tmp_d.data()), @@ -55,8 +55,8 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, size_t kv_len, size_t nu EXPECT_EQ(status, cudaSuccess) << "SinglePrefillWithKVCache kernel launch failed, error message: " << cudaGetErrorString(status); - thrust::host_vector o_h(o_d); - std::vector o_ref = cpu_reference::single_mha( + thrust::host_vector o_h(o_d); + std::vector o_ref = cpu_reference::single_mha( q, k, v, qo_len, kv_len, num_qo_heads, num_kv_heads, head_dim, causal, kv_layout, pos_encoding_mode); size_t num_results_error_atol = 0; @@ -83,7 +83,7 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, size_t kv_len, size_t nu EXPECT_FALSE(nan_detected) << "Nan detected in the result."; } -template +template void TestSinglePrefillKernelLongContextCorrectness(bool allow_fp16_qk_reduction) { for (size_t qo_len : {1, 31, 63, 127}) { for (size_t kv_len : {31717}) { @@ -92,7 +92,7 @@ void TestSinglePrefillKernelLongContextCorrectness(bool allow_fp16_qk_reduction) for (bool causal : {false, true}) { for (size_t 
pos_encoding_mode : {0, 1}) { for (size_t kv_layout : {0, 1}) { - _TestSinglePrefillKernelCorrectness( + _TestSinglePrefillKernelCorrectness( qo_len, kv_len, num_heads, num_heads, head_dim, causal, QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), allow_fp16_qk_reduction); } @@ -125,10 +125,10 @@ void TestSinglePrefillFP8KernelLongContextCorrectness(bool allow_fp16_qk_reducti } } -template +template void TestSinglePrefillKernelShortContextCorrectness(bool allow_fp16_qk_reduction) { - float rtol = std::is_same::value ? 1e-2 : 1e-3; - float atol = std::is_same::value ? 1e-2 : 1e-3; + float rtol = std::is_same::value ? 1e-2 : 1e-3; + float atol = std::is_same::value ? 1e-2 : 1e-3; for (size_t qkv_len : {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37}) { for (size_t num_qo_heads : {32}) { for (size_t num_kv_heads : {4, 8, 32}) { @@ -136,7 +136,7 @@ void TestSinglePrefillKernelShortContextCorrectness(bool allow_fp16_qk_reduction for (bool causal : {false, true}) { for (size_t pos_encoding_mode : {0, 1}) { for (size_t kv_layout : {0, 1}) { - _TestSinglePrefillKernelCorrectness( + _TestSinglePrefillKernelCorrectness( qkv_len, qkv_len, num_qo_heads, num_kv_heads, head_dim, causal, QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), allow_fp16_qk_reduction, rtol, atol); @@ -173,7 +173,7 @@ void TestSinglePrefillFP8KernelShortContextCorrectness(bool allow_fp16_qk_reduct } } -template +template void TestSinglePrefillKernelCorrectness(bool allow_fp16_qk_reduction) { for (size_t qo_len : {399, 400, 401}) { for (size_t kv_len : {533, 534, 535}) { @@ -182,7 +182,7 @@ void TestSinglePrefillKernelCorrectness(bool allow_fp16_qk_reduction) { for (bool causal : {false, true}) { for (size_t pos_encoding_mode : {0, 1}) { for (size_t kv_layout : {0, 1}) { - _TestSinglePrefillKernelCorrectness( + _TestSinglePrefillKernelCorrectness( qo_len, kv_len, num_heads, num_heads, head_dim, causal, QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), allow_fp16_qk_reduction); } diff --git a/src/tvm_wrapper.cu b/src/tvm_wrapper.cu index fdfa6f28a..d9793fa3b 100644 --- a/src/tvm_wrapper.cu +++ b/src/tvm_wrapper.cu @@ -235,7 +235,6 @@ void _FlashInferAttentionPrefillWithPagedKVCache(int64_t handler_id, DLTensor* q CHECK_EQ(k_rope_pos_offset->ndim, 1); CHECK_EQ(k_rope_pos_offset->shape[0], num_total_seqs); - constexpr PageStorage page_storage = PageStorage::kIndices; constexpr QKVLayout kv_layout = QKVLayout::kHND; const float sm_scale = attn_score_scaling_factor / std::sqrt(static_cast(nfeat)); @@ -243,7 +242,7 @@ void _FlashInferAttentionPrefillWithPagedKVCache(int64_t handler_id, DLTensor* q pages->dtype, dtype_in, {DISPATCH_TVM_CUDA_DTYPE( output->dtype, dtype_out, {DISPATCH_TVM_CUDA_IDTYPE(page_table_values->dtype, dtype_idx, { - paged_kv_t cache( + paged_kv_t cache( nhead_kv, page_size, nfeat, num_total_seqs, kv_layout, static_cast(pages->data), static_cast(page_table_values->data) + @@ -254,17 +253,19 @@ void _FlashInferAttentionPrefillWithPagedKVCache(int64_t handler_id, DLTensor* q last_page_len->byte_offset / sizeof(dtype_idx), static_cast(k_rope_pos_offset->data) + k_rope_pos_offset->byte_offset / sizeof(dtype_idx)); - cudaError_t status = BatchPrefillWithPagedKVCacheWrapper< - page_storage, dtype_in, dtype_in, dtype_out, dtype_idx>( - &batch_prefill_paged_kv_handlers[handler_id], static_cast(q_data->data), - static_cast(qo_indptr->data) + - qo_indptr->byte_offset / sizeof(dtype_idx), - static_cast(q_offset->data) + q_offset->byte_offset / sizeof(dtype_idx), - cache, static_cast(output->data), - 
/*lse=*/static_cast(lse->data), nhead_qo, - /*causal=*/causal, PosEncodingMode(pos_encoding_mode), - /*allow_fp16_qk_reduction=*/false, sm_scale, rope_scale, rope_theta, - /*stream=*/0); + cudaError_t status = + BatchPrefillWithPagedKVCacheWrapper( + &batch_prefill_paged_kv_handlers[handler_id], + static_cast(q_data->data), + static_cast(qo_indptr->data) + + qo_indptr->byte_offset / sizeof(dtype_idx), + static_cast(q_offset->data) + + q_offset->byte_offset / sizeof(dtype_idx), + cache, static_cast(output->data), + /*lse=*/static_cast(lse->data), nhead_qo, + /*causal=*/causal, PosEncodingMode(pos_encoding_mode), + /*allow_fp16_qk_reduction=*/false, sm_scale, rope_scale, rope_theta, + /*stream=*/0); if (status != cudaSuccess) { LOG(FATAL) << "FlashInfer CUDA kernel error " << cudaGetErrorString(status); } @@ -381,7 +382,6 @@ void _FlashInferAttentionDecodeWithPagedKVCache(int64_t handler_id, DLTensor* q_ CHECK_EQ(k_rope_pos_offset->ndim, 1); CHECK_EQ(k_rope_pos_offset->shape[0], num_total_seqs); - constexpr PageStorage page_storage = PageStorage::kIndices; constexpr QKVLayout kv_layout = QKVLayout::kHND; const float sm_scale = attn_score_scaling_factor / std::sqrt(static_cast(nfeat)); @@ -389,7 +389,7 @@ void _FlashInferAttentionDecodeWithPagedKVCache(int64_t handler_id, DLTensor* q_ pages->dtype, dtype_in, {DISPATCH_TVM_CUDA_DTYPE( output->dtype, dtype_out, {DISPATCH_TVM_CUDA_IDTYPE(page_table_values->dtype, dtype_idx, { - paged_kv_t cache( + paged_kv_t cache( nhead_kv, page_size, nfeat, num_total_seqs, kv_layout, static_cast(pages->data), static_cast(page_table_values->data) + @@ -400,14 +400,15 @@ void _FlashInferAttentionDecodeWithPagedKVCache(int64_t handler_id, DLTensor* q_ last_page_len->byte_offset / sizeof(dtype_idx), static_cast(k_rope_pos_offset->data) + k_rope_pos_offset->byte_offset / sizeof(dtype_idx)); - cudaError_t status = BatchDecodeWithPagedKVCacheWrapper( - &batch_decode_handlers[handler_id], static_cast(q_data->data), - static_cast(q_offset->data) + q_offset->byte_offset / sizeof(dtype_idx), - cache, static_cast(output->data), - /*lse=*/static_cast(lse->data), nhead_qo, - PosEncodingMode(pos_encoding_mode), sm_scale, rope_scale, rope_theta, - /*stream=*/0); + cudaError_t status = + BatchDecodeWithPagedKVCacheWrapper( + &batch_decode_handlers[handler_id], static_cast(q_data->data), + static_cast(q_offset->data) + + q_offset->byte_offset / sizeof(dtype_idx), + cache, static_cast(output->data), + /*lse=*/static_cast(lse->data), nhead_qo, + PosEncodingMode(pos_encoding_mode), sm_scale, rope_scale, rope_theta, + /*stream=*/0); if (status != cudaSuccess) { LOG(FATAL) << "FlashInfer CUDA kernel error " << cudaGetErrorString(status); } @@ -427,7 +428,6 @@ void _FlashInferAttentionDecodeWithPagedKVCachePlan( int_workspace_buffer->shape[0] * int_workspace_buffer->dtype.bits / 8; CHECK_LT(handler_idx, max_num_handlers) << "The handler id must be less than " << max_num_handlers; - constexpr PageStorage page_storage = PageStorage::kIndices; // NOTE(Zihao): here we presume the input data type is half, in the future we should // leave a parameter for the input data type. 
using dtype_in = half; @@ -435,17 +435,16 @@ void _FlashInferAttentionDecodeWithPagedKVCachePlan( cudaStream_t original_stream = batch_decode_handlers[handler_idx].GetCUDAStream(); batch_decode_handlers[handler_idx].SetCUDAStream(static_cast(copy_stream)); DISPATCH_TVM_CUDA_IDTYPE(page_table_indptr->dtype, dtype_idx, { - cudaError_t status = - BatchDecodeHandlerPlan( - batch_decode_handlers + handler_idx, static_cast(float_workspace_buffer->data), - float_workspace_size_in_bytes, static_cast(int_workspace_buffer->data), - int_workspace_size_in_bytes, - static_cast(page_table_indptr->data) + - page_table_indptr->byte_offset / sizeof(dtype_idx), - static_cast(last_page_len->data) + - last_page_len->byte_offset / sizeof(dtype_idx), - batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, - PosEncodingMode(pos_encoding_mode)); + cudaError_t status = BatchDecodeHandlerPlan( + batch_decode_handlers + handler_idx, static_cast(float_workspace_buffer->data), + float_workspace_size_in_bytes, static_cast(int_workspace_buffer->data), + int_workspace_size_in_bytes, + static_cast(page_table_indptr->data) + + page_table_indptr->byte_offset / sizeof(dtype_idx), + static_cast(last_page_len->data) + + last_page_len->byte_offset / sizeof(dtype_idx), + batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, + PosEncodingMode(pos_encoding_mode)); if (status != cudaSuccess) { LOG(FATAL) << "FlashInfer decode Plan error " << cudaGetErrorString(status); } diff --git a/python/tests/alibi_reference.py b/tests/alibi_reference.py similarity index 100% rename from python/tests/alibi_reference.py rename to tests/alibi_reference.py diff --git a/python/tests/rope_reference.py b/tests/rope_reference.py similarity index 100% rename from python/tests/rope_reference.py rename to tests/rope_reference.py diff --git a/python/tests/test_activation.py b/tests/test_activation.py similarity index 91% rename from python/tests/test_activation.py rename to tests/test_activation.py index cbbecfd2f..13301f2e0 100644 --- a/python/tests/test_activation.py +++ b/tests/test_activation.py @@ -27,7 +27,7 @@ def test_fused_silu_mul(dim, batch_size, seq_len): y_ref = x[..., dim:] * torch.nn.functional.silu(x[..., :dim]) y = flashinfer.activation.silu_and_mul(x) torch.testing.assert_close( - y_ref.cpu().numpy(), y.cpu().numpy(), rtol=1e-3, atol=1e-3 + y_ref, y, rtol=1e-3, atol=1e-3 ) @@ -39,7 +39,7 @@ def test_fused_gelu_tanh_mul(dim, batch_size, seq_len): y_ref = x[..., dim:] * torch.nn.functional.gelu(x[..., :dim], approximate="tanh") y = flashinfer.activation.gelu_tanh_and_mul(x) torch.testing.assert_close( - y_ref.cpu().numpy(), y.cpu().numpy(), rtol=1e-3, atol=1e-3 + y_ref, y, rtol=1e-3, atol=1e-3 ) @@ -51,5 +51,5 @@ def test_fused_gelu_mul(dim, batch_size, seq_len): y_ref = x[..., dim:] * torch.nn.functional.gelu(x[..., :dim], approximate="none") y = flashinfer.activation.gelu_and_mul(x) torch.testing.assert_close( - y_ref.cpu().numpy(), y.cpu().numpy(), rtol=1e-3, atol=1e-3 + y_ref, y, rtol=1e-3, atol=1e-3 ) diff --git a/python/tests/test_alibi.py b/tests/test_alibi.py similarity index 86% rename from python/tests/test_alibi.py rename to tests/test_alibi.py index a35dcb119..b9605f46a 100644 --- a/python/tests/test_alibi.py +++ b/tests/test_alibi.py @@ -14,14 +14,16 @@ limitations under the License. 
""" -import flashinfer +import numpy import pytest import torch +import flashinfer + from alibi_reference import alibi_attention -@pytest.mark.parametrize("seq_len", [1, 9, 81, 729, 33001]) +@pytest.mark.parametrize("seq_len", [1, 9, 81, 729]) @pytest.mark.parametrize("num_heads", [4, 8, 32]) @pytest.mark.parametrize("head_dim", [128, 256]) def test_single_decode_alibi( @@ -37,12 +39,12 @@ def test_single_decode_alibi( mask = torch.ones(1, seq_len, dtype=torch.bool).to(0) o_ref = alibi_attention(q.unsqueeze(0), k, v, mask).squeeze(0) torch.testing.assert_close( - o.cpu().numpy(), o_ref.cpu().numpy(), rtol=1e-3, atol=1e-3 + o, o_ref, rtol=1e-3, atol=1e-3 ) @pytest.mark.parametrize("q_len", [1, 17, 81, 987]) -@pytest.mark.parametrize("kv_len", [1, 17, 81, 987, 31111]) +@pytest.mark.parametrize("kv_len", [1, 17, 81, 987]) @pytest.mark.parametrize("num_heads", [4, 8, 32]) @pytest.mark.parametrize("head_dim", [128, 256]) @pytest.mark.parametrize("causal", [False, True]) @@ -67,10 +69,10 @@ def test_single_prefill_alibi( mask = torch.tril(mask, diagonal=kv_len - q_len) o_ref = alibi_attention(q, k, v, mask) torch.testing.assert_close( - o.cpu().numpy(), o_ref.cpu().numpy(), rtol=1e-2, atol=1e-2 + o, o_ref, rtol=1e-2, atol=1e-2 ) if __name__ == "__main__": - test_single_decode_alibi(9, 32, 128) - test_single_prefill_alibi(1, 64, 1, 128, False) + test_single_decode_alibi(4096, 32, 128) + test_single_prefill_alibi(128, 128, 8, 128, False) diff --git a/python/tests/test_batch_decode_kernels.py b/tests/test_batch_decode_kernels.py similarity index 96% rename from python/tests/test_batch_decode_kernels.py rename to tests/test_batch_decode_kernels.py index 6808483e4..aba07289f 100644 --- a/python/tests/test_batch_decode_kernels.py +++ b/tests/test_batch_decode_kernels.py @@ -122,9 +122,7 @@ def test_batch_decode_with_paged_kv_cache( pos_encoding_mode=pos_encoding_mode, logits_soft_cap=logits_soft_cap, ) - o_i_np = o[i].cpu().numpy() - o_ref_i_np = o_ref_i.cpu().numpy() - torch.testing.assert_close(o_i_np, o_ref_i_np, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(o[i], o_ref_i, rtol=1e-3, atol=1e-3) @pytest.mark.parametrize("batch_size", [12, 17]) @@ -234,9 +232,7 @@ def test_batch_decode_with_tuple_paged_kv_cache( pos_encoding_mode=pos_encoding_mode, logits_soft_cap=logits_soft_cap, ) - o_i_np = o[i].cpu().numpy() - o_ref_i_np = o_ref_i.cpu().numpy() - torch.testing.assert_close(o_i_np, o_ref_i_np, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(o[i], o_ref_i, rtol=1e-3, atol=1e-3) @pytest.mark.parametrize("batch_size", [12, 17]) @@ -398,9 +394,7 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache( o_ref_i = flashinfer.single_decode_with_kv_cache( qi, ki, vi, pos_encoding_mode=pos_encoding_mode ) - o_i_np = o[i].cpu().numpy() - o_ref_i_np = o_ref_i.cpu().numpy() - torch.testing.assert_close(o_i_np, o_ref_i_np, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(o[i], o_ref_i, rtol=1e-3, atol=1e-3) if __name__ == "__main__": diff --git a/python/tests/test_batch_prefill_kernels.py b/tests/test_batch_prefill_kernels.py similarity index 95% rename from python/tests/test_batch_prefill_kernels.py rename to tests/test_batch_prefill_kernels.py index 878593bc0..eb8978341 100644 --- a/python/tests/test_batch_prefill_kernels.py +++ b/tests/test_batch_prefill_kernels.py @@ -29,7 +29,7 @@ @pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) @pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA", "ALIBI"]) 
-@pytest.mark.parametrize("use_cuda_graph", [False, True]) +@pytest.mark.parametrize("use_cuda_graph", [True]) @pytest.mark.parametrize("logits_soft_cap", [0.0, 30.0]) @pytest.mark.parametrize("return_lse", [True, False]) def test_batch_prefill_with_paged_kv_cache( @@ -110,6 +110,7 @@ def test_batch_prefill_with_paged_kv_cache( kv_last_page_len_warmup = torch.full( (batch_size,), page_size, dtype=torch.int32 ) + wrapper.plan( q_indptr_warmup, kv_indptr_warmup, @@ -140,7 +141,7 @@ def test_batch_prefill_with_paged_kv_cache( if return_lse: o, _ = wrapper.run_return_lse(q, kv_data) else: - o = wrapper.run(q, kv_data) + o = wrapper.run(q, kv_data) wrapper.plan( q_indptr_cpu, @@ -204,9 +205,8 @@ def test_batch_prefill_with_paged_kv_cache( pos_encoding_mode=pos_encoding_mode, logits_soft_cap=logits_soft_cap, ) - o_i_np = o[q_indptr_cpu[i] : q_indptr_cpu[i + 1]].cpu().numpy() - o_ref_i_np = o_ref_i.cpu().numpy() - torch.testing.assert_close(o_i_np, o_ref_i_np, rtol=1e-3, atol=1e-3) + o_i = o[q_indptr_cpu[i] : q_indptr_cpu[i + 1]] + torch.testing.assert_close(o_i, o_ref_i, rtol=1e-3, atol=1e-3) @pytest.mark.parametrize("batch_size", [12, 17]) @@ -394,9 +394,8 @@ def test_batch_prefill_with_tuple_paged_kv_cache( pos_encoding_mode=pos_encoding_mode, logits_soft_cap=logits_soft_cap, ) - o_i_np = o[q_indptr_cpu[i] : q_indptr_cpu[i + 1]].cpu().numpy() - o_ref_i_np = o_ref_i.cpu().numpy() - torch.testing.assert_close(o_i_np, o_ref_i_np, rtol=1e-3, atol=1e-3) + o_i = o[q_indptr_cpu[i] : q_indptr_cpu[i + 1]] + torch.testing.assert_close(o_i, o_ref_i, rtol=1e-3, atol=1e-3) @pytest.mark.parametrize("batch_size", [12, 17]) @@ -491,7 +490,7 @@ def test_batch_prefill_with_paged_kv_cache_custom_mask( else: o_causal = wrapper.run(q, kv_data) torch.testing.assert_close( - o_custom.cpu().numpy(), o_causal.cpu().numpy(), rtol=1e-3, atol=1e-3 + o_custom, o_causal, rtol=1e-3, atol=1e-3 ) @@ -553,9 +552,8 @@ def test_batch_prefill_with_ragged_kv_cache( pos_encoding_mode=pos_encoding_mode, logits_soft_cap=logits_soft_cap, ) - o_i_np = o[q_indptr[i] : q_indptr[i + 1]].cpu().numpy() - o_ref_i_np = o_ref_i.cpu().numpy() - torch.testing.assert_close(o_i_np, o_ref_i_np, rtol=1e-3, atol=1e-3) + o_i = o[q_indptr[i] : q_indptr[i + 1]] + torch.testing.assert_close(o_i, o_ref_i, rtol=1e-3, atol=1e-3) @pytest.mark.parametrize("batch_size", [12, 17]) @@ -632,7 +630,7 @@ def test_batch_prefill_with_ragged_kv_cache_custom_mask( else: o_causal = wrapper.run(q, k, v) torch.testing.assert_close( - o_custom.cpu().numpy(), o_causal.cpu().numpy(), rtol=1e-3, atol=1e-3 + o_custom, o_causal, rtol=1e-3, atol=1e-3 ) @@ -647,11 +645,11 @@ def test_batch_prefill_with_ragged_kv_cache_custom_mask( 12, 54, 37, 1, 8, 8, 128, True, "HND", "NONE", False, 0.0, False ) test_batch_prefill_with_paged_kv_cache_custom_mask( - 12, 137, 137, 1, 8, 8, 128, "HND", "NONE", 0.0, False + 1, 137, 137, 1, 8, 8, 128, "HND", "NONE", 0.0, False ) test_batch_prefill_with_ragged_kv_cache( 12, 54, 37, 8, 8, 128, True, "NONE", 0.0, False ) test_batch_prefill_with_ragged_kv_cache_custom_mask( - 12, 137, 137, 8, 8, 128, "NONE", 0.0, False + 1, 137, 137, 8, 8, 128, "NONE", 0.0, False ) diff --git a/python/tests/test_block_sparse.py b/tests/test_block_sparse.py similarity index 97% rename from python/tests/test_block_sparse.py rename to tests/test_block_sparse.py index 96213c81e..f9ca24dee 100644 --- a/python/tests/test_block_sparse.py +++ b/tests/test_block_sparse.py @@ -86,7 +86,7 @@ def test_block_sparse_attention( ) o = sparse_attention_wrapper.run(q, k, v) - 
torch.testing.assert_close(o_ref.cpu(), o.cpu(), atol=1e-2, rtol=1e-3) + torch.testing.assert_close(o_ref, o, atol=1e-2, rtol=1e-3) if __name__ == "__main__": diff --git a/python/tests/test_bmm_fp8.py b/tests/test_bmm_fp8.py similarity index 100% rename from python/tests/test_bmm_fp8.py rename to tests/test_bmm_fp8.py diff --git a/python/tests/test_decode_fp8_calibration_scale.py b/tests/test_decode_fp8_calibration_scale.py similarity index 97% rename from python/tests/test_decode_fp8_calibration_scale.py rename to tests/test_decode_fp8_calibration_scale.py index d31a240e5..aca709c0f 100644 --- a/python/tests/test_decode_fp8_calibration_scale.py +++ b/tests/test_decode_fp8_calibration_scale.py @@ -68,7 +68,7 @@ def test_single_decode_fp8_calibration_scale( ) torch.testing.assert_close( - o_fp16.cpu().numpy(), o_fp8.cpu().numpy(), atol=1e-2, rtol=2e-2 + o_fp16, o_fp8, atol=1e-2, rtol=2e-2 ) @@ -152,7 +152,7 @@ def test_batch_decode_with_paged_kv_cache_fp8_calibration_scale( o_fp8 = wrapper.run(q, kv_data_fp8.to(dtype), k_scale=k_scale, v_scale=v_scale) torch.testing.assert_close( - o_fp16.cpu().numpy(), o_fp8.cpu().numpy(), atol=1e-2, rtol=2e-1 + o_fp16, o_fp8, atol=1e-2, rtol=2e-1 ) diff --git a/python/tests/test_decode_prefill_lse.py b/tests/test_decode_prefill_lse.py similarity index 93% rename from python/tests/test_decode_prefill_lse.py rename to tests/test_decode_prefill_lse.py index e7afdcb6a..2d5e745dc 100644 --- a/python/tests/test_decode_prefill_lse.py +++ b/tests/test_decode_prefill_lse.py @@ -65,11 +65,14 @@ def test_mlc_failed_case(): ) o_1_tc, lse_1_tc = wrapper_tensor_cores.run_return_lse(q, kv_data) + print(lse_1, lse_1_tc) + print(o_1, o_1_tc) + torch.testing.assert_close( - lse_1.cpu().numpy(), lse_1_tc.cpu().numpy(), rtol=1e-3, atol=1e-3 + lse_1, lse_1_tc, rtol=1e-3, atol=1e-3 ) torch.testing.assert_close( - o_1.cpu().numpy(), o_1_tc.cpu().numpy(), rtol=1e-3, atol=1e-3 + o_1, o_1_tc, rtol=1e-3, atol=1e-3 ) diff --git a/python/tests/test_fp8_prefill.py b/tests/test_fp8_prefill.py similarity index 97% rename from python/tests/test_fp8_prefill.py rename to tests/test_fp8_prefill.py index 9a5a9559d..e6475f3f9 100644 --- a/python/tests/test_fp8_prefill.py +++ b/tests/test_fp8_prefill.py @@ -106,7 +106,7 @@ def test_batch_prefill_with_paged_kv_cache_fp8_calibration_scale( ) torch.testing.assert_close( - o_fp16.cpu().numpy(), o_fp8.cpu().numpy(), atol=1e-2, rtol=2e-1 + o_fp16, o_fp8, atol=1e-2, rtol=2e-1 ) @@ -184,7 +184,7 @@ def test_batch_decode_with_prefill_with_paged_kv_cache( o_decode_fp8 = decode_wrapper.run(q, kv_data) torch.testing.assert_close( - o_decode_fp8.cpu().numpy(), o_fp8.cpu().numpy(), atol=1e-2, rtol=1e-2 + o_decode_fp8, o_fp8, atol=1e-2, rtol=1e-2 ) diff --git a/python/tests/test_group_gemm.py b/tests/test_group_gemm.py similarity index 96% rename from python/tests/test_group_gemm.py rename to tests/test_group_gemm.py index 40b58cc00..96c48fb88 100644 --- a/python/tests/test_group_gemm.py +++ b/tests/test_group_gemm.py @@ -46,19 +46,13 @@ def test_segment_gemm( torch.manual_seed(42) workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(device) segment_gemm = flashinfer.gemm.SegmentGEMMWrapper(workspace_buffer) - x = torch.randn(batch_size * num_rows_per_batch, d_in, dtype=dtype).to( - device - ) + x = torch.randn(batch_size * num_rows_per_batch, d_in, dtype=dtype).to(device) if use_weight_indices: num_weights = 1024 if column_major: - weight = torch.randn(num_weights, d_out, d_in, dtype=dtype).to( - device - ) + weight = 
torch.randn(num_weights, d_out, d_in, dtype=dtype).to(device) else: - weight = torch.randn(num_weights, d_in, d_out, dtype=dtype).to( - device - ) + weight = torch.randn(num_weights, d_in, d_out, dtype=dtype).to(device) else: if column_major: weight = torch.randn(batch_size, d_out, d_in, dtype=dtype).to(device) diff --git a/python/tests/test_logits_cap.py b/tests/test_logits_cap.py similarity index 92% rename from python/tests/test_logits_cap.py rename to tests/test_logits_cap.py index 2d3bb8abd..93b1a6bb4 100644 --- a/python/tests/test_logits_cap.py +++ b/tests/test_logits_cap.py @@ -48,7 +48,7 @@ def test_single_decode_logits_soft_cap( o = flashinfer.single_decode_with_kv_cache(q, k, v, logits_soft_cap=soft_cap) o_ref = attention_logits_soft_cap_torch(q.unsqueeze(0), k, v, soft_cap).squeeze(0) torch.testing.assert_close( - o.cpu().numpy(), o_ref.cpu().numpy(), rtol=1e-3, atol=1e-3 + o, o_ref, rtol=1e-3, atol=1e-3 ) @@ -71,10 +71,10 @@ def test_single_prefill_logits_soft_cap( o = flashinfer.single_prefill_with_kv_cache(q, k, v, logits_soft_cap=soft_cap) o_ref = attention_logits_soft_cap_torch(q, k, v, soft_cap) torch.testing.assert_close( - o.cpu().numpy(), o_ref.cpu().numpy(), rtol=1e-2, atol=1e-2 + o, o_ref, rtol=1e-2, atol=1e-2 ) if __name__ == "__main__": test_single_decode_logits_soft_cap(9, 32, 128, 30.0) - test_single_prefill_logits_soft_cap(1, 64, 1, 128, 30.0) + test_single_prefill_logits_soft_cap(64, 64, 1, 128, 30.0) diff --git a/python/tests/test_non_contiguous_prefill.py b/tests/test_non_contiguous_prefill.py similarity index 95% rename from python/tests/test_non_contiguous_prefill.py rename to tests/test_non_contiguous_prefill.py index d098385c6..53f8d688d 100644 --- a/python/tests/test_non_contiguous_prefill.py +++ b/tests/test_non_contiguous_prefill.py @@ -50,7 +50,7 @@ def test_single_prefill_packed_input( q.contiguous(), k.contiguous(), v.contiguous(), causal=causal ) - torch.testing.assert_close(o_packed.cpu(), o_contiguous.cpu(), rtol=1e-3, atol=1e-3) + torch.testing.assert_close(o_packed, o_contiguous, rtol=1e-3, atol=1e-3) @pytest.mark.parametrize("batch_size", [1, 19, 99]) @@ -93,7 +93,7 @@ def test_batch_ragged_prefill_packed_input( o_packed = wrapper.run(q, k, v) o_contiguous = wrapper.run(q.contiguous(), k.contiguous(), v.contiguous()) - torch.testing.assert_close(o_packed.cpu(), o_contiguous.cpu(), rtol=1e-3, atol=1e-3) + torch.testing.assert_close(o_packed, o_contiguous, rtol=1e-3, atol=1e-3) if __name__ == "__main__": diff --git a/python/tests/test_norm.py b/tests/test_norm.py similarity index 97% rename from python/tests/test_norm.py rename to tests/test_norm.py index 923f34f6d..9bb4f6305 100644 --- a/python/tests/test_norm.py +++ b/tests/test_norm.py @@ -77,7 +77,7 @@ def test_norm(batch_size, hidden_size, dtype, specify_out): y = flashinfer.norm.rmsnorm(x, w) torch.testing.assert_close( - y_ref.cpu().numpy(), y.cpu().numpy(), rtol=1e-3, atol=1e-3 + y_ref, y, rtol=1e-3, atol=1e-3 ) @@ -119,7 +119,7 @@ def test_gemma_norm(batch_size, hidden_size, dtype, specify_out): y = flashinfer.norm.gemma_rmsnorm(x, w) torch.testing.assert_close( - y_ref.cpu().numpy(), y.cpu().numpy(), rtol=1e-3, atol=1e-3 + y_ref, y, rtol=1e-3, atol=1e-3 ) diff --git a/python/tests/test_quantization.py b/tests/test_quantization.py similarity index 100% rename from python/tests/test_quantization.py rename to tests/test_quantization.py diff --git a/python/tests/test_rope.py b/tests/test_rope.py similarity index 93% rename from python/tests/test_rope.py rename to tests/test_rope.py index 
304cc144a..1750ed34a 100644 --- a/python/tests/test_rope.py +++ b/tests/test_rope.py @@ -70,10 +70,10 @@ def test_llama_rope_inplace( # compare torch.testing.assert_close( - q_rope_ref.cpu().numpy(), q.cpu().numpy(), rtol=1e-3, atol=1e-3 + q_rope_ref, q, rtol=1e-3, atol=1e-3 ) torch.testing.assert_close( - k_rope_ref.cpu().numpy(), k.cpu().numpy(), rtol=1e-3, atol=1e-3 + k_rope_ref, k, rtol=1e-3, atol=1e-3 ) @@ -126,10 +126,10 @@ def test_llama_rope( # compare torch.testing.assert_close( - q_rope_ref.cpu().numpy(), q_rope.cpu().numpy(), rtol=1e-3, atol=1e-3 + q_rope_ref, q_rope, rtol=1e-3, atol=1e-3 ) torch.testing.assert_close( - k_rope_ref.cpu().numpy(), k_rope.cpu().numpy(), rtol=1e-3, atol=1e-3 + k_rope_ref, k_rope, rtol=1e-3, atol=1e-3 ) @@ -182,10 +182,10 @@ def test_llama31_rope_inplace( # compare torch.testing.assert_close( - q_rope_ref.cpu().numpy(), q.cpu().numpy(), rtol=1e-3, atol=1e-3 + q_rope_ref, q, rtol=1e-3, atol=1e-3 ) torch.testing.assert_close( - k_rope_ref.cpu().numpy(), k.cpu().numpy(), rtol=1e-3, atol=1e-3 + k_rope_ref, k, rtol=1e-3, atol=1e-3 ) @@ -238,10 +238,10 @@ def test_llama31_rope( # compare torch.testing.assert_close( - q_rope_ref.cpu().numpy(), q_rope.cpu().numpy(), rtol=1e-3, atol=1e-3 + q_rope_ref, q_rope, rtol=1e-3, atol=1e-3 ) torch.testing.assert_close( - k_rope_ref.cpu().numpy(), k_rope.cpu().numpy(), rtol=1e-3, atol=1e-3 + k_rope_ref, k_rope, rtol=1e-3, atol=1e-3 ) diff --git a/python/tests/test_sampling.py b/tests/test_sampling.py similarity index 98% rename from python/tests/test_sampling.py rename to tests/test_sampling.py index ea781451b..9c5e312ab 100644 --- a/python/tests/test_sampling.py +++ b/tests/test_sampling.py @@ -246,8 +246,8 @@ def test_top_p_renorm_probs(batch_size, vocab_size, p): renorm_prob = flashinfer.sampling.top_p_renorm_probs(normalized_prob, p) torch.testing.assert_close( - renorm_prob_ground_truth.cpu().numpy(), - renorm_prob.cpu().numpy(), + renorm_prob_ground_truth, + renorm_prob, rtol=1e-3, atol=1e-3, ) @@ -273,8 +273,8 @@ def test_top_k_renorm_probs(batch_size, vocab_size, k): renorm_prob = flashinfer.sampling.top_k_renorm_probs(normalized_prob, k) torch.testing.assert_close( - renorm_prob_ground_truth.cpu().numpy(), - renorm_prob.cpu().numpy(), + renorm_prob_ground_truth, + renorm_prob, rtol=1e-3, atol=1e-3, ) @@ -294,8 +294,8 @@ def test_top_k_mask_logits(batch_size, vocab_size, k): renormed_probs_ref = flashinfer.sampling.top_k_renorm_prob(probs, k) torch.testing.assert_close( - renormed_probs.cpu().numpy(), - renormed_probs_ref.cpu().numpy(), + renormed_probs, + renormed_probs_ref, rtol=1e-3, atol=1e-3, ) diff --git a/python/tests/test_shared_prefix_kernels.py b/tests/test_shared_prefix_kernels.py similarity index 90% rename from python/tests/test_shared_prefix_kernels.py rename to tests/test_shared_prefix_kernels.py index 06a3b7167..34518ca52 100644 --- a/python/tests/test_shared_prefix_kernels.py +++ b/tests/test_shared_prefix_kernels.py @@ -128,7 +128,7 @@ def test_batch_attention_with_shared_prefix_paged_kv_cache( qo_indptr_top = torch.tensor([0, q.shape[0]], dtype=torch.int32).to(0) if stage == "decode": - qo_indptr_bottom = torch.arange(0, batch_size + 1).to(0) + qo_indptr_bottom = torch.arange(0, batch_size + 1, dtype=torch.int32).to(0) multi_level_wrapper.plan( [qo_indptr_top, qo_indptr_bottom], [shared_kv_indptr, unique_kv_indptr], @@ -141,7 +141,9 @@ def test_batch_attention_with_shared_prefix_paged_kv_cache( ) o_multi_level = multi_level_wrapper.run(q, kv_data) else: - qo_indptr_bottom = torch.arange(0, 
batch_size + 1).to(0) * unique_kv_len + qo_indptr_bottom = ( + torch.arange(0, batch_size + 1, dtype=torch.int32).to(0) * unique_kv_len + ) multi_level_wrapper.plan( [qo_indptr_top, qo_indptr_bottom], [shared_kv_indptr, unique_kv_indptr], @@ -184,7 +186,7 @@ def test_batch_attention_with_shared_prefix_paged_kv_cache( ) torch.testing.assert_close( - o_multi_level.cpu().numpy(), o_two_level.cpu().numpy(), rtol=1e-3, atol=1e-3 + o_multi_level, o_two_level, rtol=1e-3, atol=1e-3 ) @@ -216,10 +218,10 @@ def test_merge_state_in_place_with_mask(seed, num_tries): va_merged = va sa_merged = sa torch.testing.assert_close( - va_merged.cpu().numpy(), va_merged_ref.cpu().numpy(), rtol=1e-3, atol=1e-3 + va_merged, va_merged_ref, rtol=1e-3, atol=1e-3 ) torch.testing.assert_close( - sa_merged.cpu().numpy(), sa_merged_ref.cpu().numpy(), rtol=1e-3, atol=1e-3 + sa_merged, sa_merged_ref, rtol=1e-3, atol=1e-3 ) # Mask with all zeros. Input and output should be identical. @@ -230,10 +232,10 @@ def test_merge_state_in_place_with_mask(seed, num_tries): va_merged = va sa_merged = sa torch.testing.assert_close( - va_merged.cpu().numpy(), va_orginal.cpu().numpy(), rtol=1e-3, atol=1e-3 + va_merged, va_orginal, rtol=1e-3, atol=1e-3 ) torch.testing.assert_close( - sa_merged.cpu().numpy(), sa_original.cpu().numpy(), rtol=1e-3, atol=1e-3 + sa_merged, sa_original, rtol=1e-3, atol=1e-3 ) # Test some random masks. @@ -253,26 +255,26 @@ def test_merge_state_in_place_with_mask(seed, num_tries): sa_merged = sa torch.testing.assert_close( - va_merged[false_indices].cpu().numpy(), - va_orginal[false_indices].cpu().numpy(), + va_merged[false_indices], + va_orginal[false_indices], rtol=1e-3, atol=1e-3, ) torch.testing.assert_close( - sa_merged[false_indices].cpu().numpy(), - sa_original[false_indices].cpu().numpy(), + sa_merged[false_indices], + sa_original[false_indices], rtol=1e-3, atol=1e-3, ) torch.testing.assert_close( - va_merged[true_indices].cpu().numpy(), - va_merged_ref[true_indices].cpu().numpy(), + va_merged[true_indices], + va_merged_ref[true_indices], rtol=1e-3, atol=1e-3, ) torch.testing.assert_close( - sa_merged[true_indices].cpu().numpy(), - sa_merged_ref[true_indices].cpu().numpy(), + sa_merged[true_indices], + sa_merged_ref[true_indices], rtol=1e-3, atol=1e-3, ) diff --git a/python/tests/test_sliding_window.py b/tests/test_sliding_window.py similarity index 95% rename from python/tests/test_sliding_window.py rename to tests/test_sliding_window.py index 1b5896fde..5e2d5577c 100644 --- a/python/tests/test_sliding_window.py +++ b/tests/test_sliding_window.py @@ -121,9 +121,7 @@ def test_batch_decode_sliding_window( vi, window_left=window_left, ) - o_i_np = o[i].cpu().numpy() - o_ref_i_np = o_ref_i.cpu().numpy() - torch.testing.assert_close(o_i_np, o_ref_i_np, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(o[i], o_ref_i, rtol=1e-3, atol=1e-3) @pytest.mark.parametrize("seq_len", [1, 3, 19, 99, 199, 1999]) @@ -275,9 +273,8 @@ def test_batch_paged_prefill_sliding_window( window_left=window_left, causal=True, ) - o_i_np = o[q_indptr[i] : q_indptr[i + 1]].cpu().numpy() - o_ref_i_np = o_ref_i.cpu().numpy() - torch.testing.assert_close(o_i_np, o_ref_i_np, rtol=1e-3, atol=1e-3) + o_i = o[q_indptr[i] : q_indptr[i + 1]] + torch.testing.assert_close(o_i, o_ref_i, rtol=1e-3, atol=1e-3) @pytest.mark.parametrize("batch_size", [12, 17]) @@ -336,12 +333,12 @@ def test_batch_ragged_prefill_sliding_window( window_left=window_left, causal=True, ) - o_i_np = o[q_indptr[i] : q_indptr[i + 1]].cpu().numpy() - o_ref_i_np = 
o_ref_i.cpu().numpy() - torch.testing.assert_close(o_i_np, o_ref_i_np, rtol=1e-3, atol=1e-3) + o_i = o[q_indptr[i] : q_indptr[i + 1]] + torch.testing.assert_close(o_i, o_ref_i, rtol=1e-3, atol=1e-3) if __name__ == "__main__": + test_single_decode_sliding_window(13, 20, 1, 4, 128) test_single_prefill_sliding_window(13, 20, 1, 4, 128) test_batch_paged_prefill_sliding_window(12, 54, 37, 13, 1, 4, 128, 1) test_batch_ragged_prefill_sliding_window(12, 54, 37, 13, 1, 4, 128) diff --git a/python/tests/test_tensor_cores_decode.py b/tests/test_tensor_cores_decode.py similarity index 97% rename from python/tests/test_tensor_cores_decode.py rename to tests/test_tensor_cores_decode.py index b4495c609..a7b420a7a 100644 --- a/python/tests/test_tensor_cores_decode.py +++ b/tests/test_tensor_cores_decode.py @@ -54,7 +54,7 @@ def test_single_decode_tensor_cores( ) torch.testing.assert_close( - o.cpu().numpy(), o_tensor_cores.cpu().numpy(), rtol=1e-3, atol=1e-3 + o, o_tensor_cores, rtol=1e-3, atol=1e-3 ) @@ -126,7 +126,7 @@ def test_batch_decode_tensor_cores( o_tensor_cores = wrapper_tensor_cores.run(q, kv_data) torch.testing.assert_close( - o.cpu().numpy(), o_tensor_cores.cpu().numpy(), rtol=1e-3, atol=1e-3 + o, o_tensor_cores, rtol=1e-3, atol=1e-3 ) @@ -242,5 +242,5 @@ def test_batch_decode_tensor_cores_cuda_graph( g.replay() torch.testing.assert_close( - o.cpu().numpy(), o_tensor_cores.cpu().numpy(), rtol=1e-3, atol=1e-3 + o, o_tensor_cores, rtol=1e-3, atol=1e-3 ) diff --git a/python/tests/test_triton_cascade.py b/tests/test_triton_cascade.py similarity index 100% rename from python/tests/test_triton_cascade.py rename to tests/test_triton_cascade.py
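
Note on the test comparisons moved under tests/: throughout the relocated test files this patch drops the .cpu().numpy() round-trip and passes the device tensors straight to torch.testing.assert_close, which compares CUDA tensors directly and additionally checks dtype and device. A minimal sketch of the two styles, assuming a CUDA-enabled PyTorch install; the tensors here are illustrative stand-ins for a kernel output `o` and a reference `o_ref`, not the actual test fixtures:

    import torch

    # Stand-ins for a kernel output and its reference result on the GPU.
    o_ref = torch.randn(32, 128, dtype=torch.float16, device="cuda")
    o = o_ref.clone()

    # Pre-patch style: round-trip through NumPy on the host before comparing.
    torch.testing.assert_close(
        o.cpu().numpy(), o_ref.cpu().numpy(), rtol=1e-3, atol=1e-3
    )

    # Post-patch style: compare the device tensors directly; assert_close also
    # verifies dtype and device, which the NumPy round-trip used to hide.
    torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3)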
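
Note on the dtype pin in tests/test_shared_prefix_kernels.py: torch.arange with integer endpoints produces int64 by default, while the indptr and last-page-length tensors built elsewhere in these tests use torch.int32, so the patch adds dtype=torch.int32 when constructing qo_indptr_bottom. A small self-contained check of that behavior; batch_size is an arbitrary value chosen only for illustration:

    import torch

    batch_size = 7

    # Default integer dtype of torch.arange is int64.
    qo_indptr_default = torch.arange(0, batch_size + 1)
    assert qo_indptr_default.dtype == torch.int64

    # The patched test pins the dtype so the indptr matches the int32 index
    # tensors (e.g. kv_indptr, kv_last_page_len) used alongside it.
    qo_indptr = torch.arange(0, batch_size + 1, dtype=torch.int32)
    assert qo_indptr.dtype == torch.int32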