From ab1e2ad89f27319f5b4874c5e8b526c1cae43598 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Fri, 14 Jun 2024 02:13:13 -0700 Subject: [PATCH] feat: initial support of logits hook (#298) Implement the #257 feature. --- CMakeLists.txt | 252 ++++++++-------- cmake/config.cmake | 1 + include/flashinfer/attention/decode.cuh | 88 +++--- include/flashinfer/attention/handler.cuh | 62 ++-- .../flashinfer/attention/logits_post_hook.cuh | 66 +++++ include/flashinfer/attention/prefill.cuh | 127 +++++--- include/flashinfer/decode_attention_decl.cuh | 34 ++- include/flashinfer/math.cuh | 31 ++ include/flashinfer/prefill_attention_decl.cuh | 49 ++-- python/csrc/batch_decode.cu | 276 ++++++++++-------- python/csrc/batch_prefill.cu | 238 ++++++++------- python/csrc/flashinfer_ops.h | 42 +-- python/csrc/pytorch_extension_utils.h | 19 +- python/csrc/single_decode.cu | 69 +++-- python/csrc/single_prefill.cu | 105 ++++--- python/flashinfer/decode.py | 68 ++++- python/flashinfer/prefill.py | 68 ++++- python/generate_batch_padded_decode_inst.py | 7 +- python/generate_batch_paged_decode_inst.py | 16 +- python/generate_batch_paged_prefill_inst.py | 7 +- python/generate_batch_ragged_prefill_inst.py | 7 +- python/generate_dispatch_inc.py | 22 ++ python/generate_single_decode_inst.py | 21 +- python/generate_single_prefill_inst.py | 7 +- python/literal_map.py | 5 + python/setup.py | 32 +- python/tests/test_alibi.py | 2 +- python/tests/test_batch_decode_kernels.py | 18 +- python/tests/test_logits_cap.py | 76 +++++ python/tests/test_shared_prefix_kernels.py | 2 + src/bench_batch_decode.cu | 4 +- src/bench_cascade.cu | 6 +- src/flashinfer_ops.cuh | 99 ++++--- src/test_batch_decode.cu | 22 +- src/test_cascade.cu | 11 +- src/tvm_wrapper.cu | 24 +- 36 files changed, 1239 insertions(+), 744 deletions(-) create mode 100644 include/flashinfer/attention/logits_post_hook.cuh create mode 100644 python/tests/test_logits_cap.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 5d28ab317..87bbce877 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -41,6 +41,7 @@ flashinfer_option(FLASHINFER_GEN_GROUP_SIZES "Group sizes to enable" 1 4 5 6 7 8 flashinfer_option(FLASHINFER_GEN_PAGE_SIZES "Prefill page sizes to enable" 1 16 32) flashinfer_option(FLASHINFER_GEN_HEAD_DIMS "Head dims to enable" 64 128 256) flashinfer_option(FLASHINFER_GEN_KV_LAYOUTS "KV layouts to enable" 0 1) +flashinfer_option(FLASHINFER_GEN_LOGITS_POST_HOOKS "Logits post hooks" 0 1) flashinfer_option(FLASHINFER_GEN_POS_ENCODING_MODES "Pos encodings to enable" 0 1 2) flashinfer_option(FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS "QK reductions to enable" "false" "true") flashinfer_option(FLASHINFER_GEN_CASUALS "Casual modes to enable" "false" "true") @@ -83,6 +84,7 @@ endif(FLASHINFER_ENABLE_BF16) set (GROUP_SIZES ${FLASHINFER_GEN_GROUP_SIZES}) set (PAGE_SIZES ${FLASHINFER_GEN_PAGE_SIZES}) set (HEAD_DIMS ${FLASHINFER_GEN_HEAD_DIMS}) +set (LOGITS_POST_HOOKS ${FLASHINFER_GEN_LOGITS_POST_HOOKS}) set (KV_LAYOUTS ${FLASHINFER_GEN_KV_LAYOUTS}) set (POS_ENCODING_MODES ${FLASHINFER_GEN_POS_ENCODING_MODES}) set (ALLOW_FP16_QK_REDUCTIONS ${FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS}) @@ -116,7 +118,7 @@ file(MAKE_DIRECTORY ${PROJECT_SOURCE_DIR}/src/generated) set(dispatch_inc_file ${PROJECT_SOURCE_DIR}/src/dispatch.inc) add_custom_command( OUTPUT ${dispatch_inc_file} - COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_dispatch_inc.py --path ${PROJECT_SOURCE_DIR}/src/dispatch.inc --head_dims ${HEAD_DIMS} --page_sizes ${FLASHINFER_GEN_PAGE_SIZES} --group_sizes 
${GROUP_SIZES} --kv_layouts ${KV_LAYOUTS} --pos_encoding_modes ${POS_ENCODING_MODES} --allow_fp16_qk_reductions ${ALLOW_FP16_QK_REDUCTIONS} --mask_modes ${MASK_MODES} + COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_dispatch_inc.py --path ${PROJECT_SOURCE_DIR}/src/dispatch.inc --head_dims ${HEAD_DIMS} --page_sizes ${FLASHINFER_GEN_PAGE_SIZES} --group_sizes ${GROUP_SIZES} --logits_post_hooks ${LOGITS_POST_HOOKS} --kv_layouts ${KV_LAYOUTS} --pos_encoding_modes ${POS_ENCODING_MODES} --allow_fp16_qk_reductions ${ALLOW_FP16_QK_REDUCTIONS} --mask_modes ${MASK_MODES} DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_dispatch_inc.py COMMENT "Generating additional source file ${generated_dispatch_inc}" VERBATIM @@ -126,97 +128,101 @@ add_custom_target(dispatch_inc DEPENDS ${dispatch_inc_file}) # single decode kernel inst generation foreach(group_size IN LISTS GROUP_SIZES) foreach(head_dim IN LISTS HEAD_DIMS) - foreach(kv_layout IN LISTS KV_LAYOUTS) - foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES) - foreach(dtype IN LISTS DECODE_DTYPES) - set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_decode_group_${group_size}_head_${head_dim}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypein_${dtype}_dtypeout_${dtype}.cu) - add_custom_command( - OUTPUT ${generated_kernel_src} - COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_single_decode_inst.py ${generated_kernel_src} - DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_single_decode_inst.py - COMMENT "Generating additional source file ${generated_kernel_src}" - VERBATIM - ) - list(APPEND single_decode_kernels_src ${generated_kernel_src}) - endforeach(dtype) - - # fp8 in, fp16 out - foreach(dtype IN LISTS DECODE_FP8_DTYPES) - set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_decode_group_${group_size}_head_${head_dim}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypein_${dtype}_dtypeout_f16.cu) - add_custom_command( - OUTPUT ${generated_kernel_src} - COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_single_decode_inst.py ${generated_kernel_src} - DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_single_decode_inst.py - COMMENT "Generating additional source file ${generated_kernel_src}" - VERBATIM - ) - list(APPEND single_decode_kernels_src ${generated_kernel_src}) - endforeach(dtype) - endforeach(pos_encoding_mode) - endforeach(kv_layout) + foreach(logits_post_hook IN LISTS LOGITS_POST_HOOKS) + foreach(kv_layout IN LISTS KV_LAYOUTS) + foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES) + foreach(dtype IN LISTS DECODE_DTYPES) + set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_decode_group_${group_size}_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}.cu) + add_custom_command( + OUTPUT ${generated_kernel_src} + COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_single_decode_inst.py ${generated_kernel_src} + DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_single_decode_inst.py + COMMENT "Generating additional source file ${generated_kernel_src}" + VERBATIM + ) + list(APPEND single_decode_kernels_src ${generated_kernel_src}) + endforeach(dtype) + + # fp8 in, fp16 out + foreach(dtype IN LISTS DECODE_FP8_DTYPES) + set(generated_kernel_src 
${PROJECT_SOURCE_DIR}/src/generated/single_decode_group_${group_size}_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_f16.cu) + add_custom_command( + OUTPUT ${generated_kernel_src} + COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_single_decode_inst.py ${generated_kernel_src} + DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_single_decode_inst.py + COMMENT "Generating additional source file ${generated_kernel_src}" + VERBATIM + ) + list(APPEND single_decode_kernels_src ${generated_kernel_src}) + endforeach(dtype) + endforeach(pos_encoding_mode) + endforeach(kv_layout) + endforeach(logits_post_hook) endforeach(head_dim) endforeach(group_size) # batch decode kernel inst generation foreach(group_size IN LISTS GROUP_SIZES) foreach(head_dim IN LISTS HEAD_DIMS) - foreach(kv_layout IN LISTS KV_LAYOUTS) - foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES) - # paged kv-cache - foreach(idtype IN LISTS IDTYPES) + foreach(logits_post_hook IN LISTS LOGITS_POST_HOOKS) + foreach(kv_layout IN LISTS KV_LAYOUTS) + foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES) + # paged kv-cache + foreach(idtype IN LISTS IDTYPES) + foreach(dtype IN LISTS DECODE_DTYPES) + set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_decode_group_${group_size}_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu) + add_custom_command( + OUTPUT ${generated_kernel_src} + COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_decode_inst.py ${generated_kernel_src} + DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_decode_inst.py + COMMENT "Generating additional source file ${generated_kernel_src}" + VERBATIM + ) + list(APPEND batch_decode_kernels_src ${generated_kernel_src}) + endforeach(dtype) + + # fp8 in, fp16 out + foreach(dtype IN LISTS DECODE_FP8_DTYPES) + set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_decode_group_${group_size}_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_f16_idtype_${idtype}.cu) + add_custom_command( + OUTPUT ${generated_kernel_src} + COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_decode_inst.py ${generated_kernel_src} + DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_decode_inst.py + COMMENT "Generating additional source file ${generated_kernel_src}" + VERBATIM + ) + list(APPEND batch_decode_kernels_src ${generated_kernel_src}) + endforeach() + endforeach(idtype) + + # padded kv-cache foreach(dtype IN LISTS DECODE_DTYPES) - set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_decode_group_${group_size}_head_${head_dim}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu) + set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_padded_decode_group_${group_size}_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}.cu) add_custom_command( OUTPUT ${generated_kernel_src} - COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_decode_inst.py ${generated_kernel_src} - DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_decode_inst.py + COMMAND ${Python3_EXECUTABLE} 
${PROJECT_SOURCE_DIR}/python/generate_batch_padded_decode_inst.py ${generated_kernel_src} + DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_padded_decode_inst.py COMMENT "Generating additional source file ${generated_kernel_src}" VERBATIM ) list(APPEND batch_decode_kernels_src ${generated_kernel_src}) endforeach(dtype) - # fp8 in, fp16 out + # padded kv-cache, fp8 in, fp16 out foreach(dtype IN LISTS DECODE_FP8_DTYPES) - set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_decode_group_${group_size}_head_${head_dim}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypein_${dtype}_dtypeout_f16_idtype_${idtype}.cu) + set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_padded_decode_group_${group_size}_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_f16.cu) add_custom_command( OUTPUT ${generated_kernel_src} - COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_decode_inst.py ${generated_kernel_src} - DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_decode_inst.py + COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_padded_decode_inst.py ${generated_kernel_src} + DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_padded_decode_inst.py COMMENT "Generating additional source file ${generated_kernel_src}" VERBATIM ) list(APPEND batch_decode_kernels_src ${generated_kernel_src}) endforeach() - endforeach(idtype) - - # padded kv-cache - foreach(dtype IN LISTS DECODE_DTYPES) - set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_padded_decode_group_${group_size}_head_${head_dim}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypein_${dtype}_dtypeout_${dtype}.cu) - add_custom_command( - OUTPUT ${generated_kernel_src} - COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_padded_decode_inst.py ${generated_kernel_src} - DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_padded_decode_inst.py - COMMENT "Generating additional source file ${generated_kernel_src}" - VERBATIM - ) - list(APPEND batch_decode_kernels_src ${generated_kernel_src}) - endforeach(dtype) - - # padded kv-cache, fp8 in, fp16 out - foreach(dtype IN LISTS DECODE_FP8_DTYPES) - set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_padded_decode_group_${group_size}_head_${head_dim}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypein_${dtype}_dtypeout_f16.cu) - add_custom_command( - OUTPUT ${generated_kernel_src} - COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_padded_decode_inst.py ${generated_kernel_src} - DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_padded_decode_inst.py - COMMENT "Generating additional source file ${generated_kernel_src}" - VERBATIM - ) - list(APPEND batch_decode_kernels_src ${generated_kernel_src}) - endforeach() - endforeach(pos_encoding_mode) - endforeach(kv_layout) + endforeach(pos_encoding_mode) + endforeach(kv_layout) + endforeach(logits_post_hook) endforeach(head_dim) endforeach(group_size) @@ -227,25 +233,27 @@ target_compile_options(decode_kernels PRIVATE -Xcompiler=-fPIC --fatbin-options # single prefill kernel inst generation foreach(group_size IN LISTS GROUP_SIZES) foreach(head_dim IN LISTS HEAD_DIMS) - foreach(kv_layout IN LISTS KV_LAYOUTS) - foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES) - foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS) - foreach(mask_mode IN LISTS MASK_MODES) - foreach(dtype IN LISTS 
PREFILL_DTYPES) - set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_prefill_group_${group_size}_head_${head_dim}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}.cu) - add_custom_command( - OUTPUT ${generated_kernel_src} - COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_single_prefill_inst.py ${generated_kernel_src} - DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_single_prefill_inst.py - COMMENT "Generating additional source file ${generated_kernel_src}" - VERBATIM - ) - list(APPEND single_prefill_kernels_src ${generated_kernel_src}) - endforeach(dtype) - endforeach(mask_mode) - endforeach(allow_fp16_qk_reduction) - endforeach(pos_encoding_mode) - endforeach(kv_layout) + foreach(logits_post_hook IN LISTS LOGITS_POST_HOOKS) + foreach(kv_layout IN LISTS KV_LAYOUTS) + foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES) + foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS) + foreach(mask_mode IN LISTS MASK_MODES) + foreach(dtype IN LISTS PREFILL_DTYPES) + set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_prefill_group_${group_size}_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}.cu) + add_custom_command( + OUTPUT ${generated_kernel_src} + COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_single_prefill_inst.py ${generated_kernel_src} + DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_single_prefill_inst.py + COMMENT "Generating additional source file ${generated_kernel_src}" + VERBATIM + ) + list(APPEND single_prefill_kernels_src ${generated_kernel_src}) + endforeach(dtype) + endforeach(mask_mode) + endforeach(allow_fp16_qk_reduction) + endforeach(pos_encoding_mode) + endforeach(kv_layout) + endforeach(logits_post_hook) endforeach(head_dim) endforeach(group_size) @@ -253,55 +261,59 @@ endforeach(group_size) foreach(group_size IN LISTS GROUP_SIZES) foreach(page_size IN LISTS PAGE_SIZES) foreach(head_dim IN LISTS HEAD_DIMS) + foreach(logits_post_hook IN LISTS LOGITS_POST_HOOKS) + foreach(kv_layout IN LISTS KV_LAYOUTS) + foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES) + foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS) + foreach(mask_mode IN LISTS MASK_MODES) + foreach(dtype IN LISTS PREFILL_DTYPES) + foreach(idtype IN LISTS IDTYPES) + set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_prefill_group_${group_size}_page_${page_size}_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu) + add_custom_command( + OUTPUT ${generated_kernel_src} + COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py ${generated_kernel_src} + DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py + COMMENT "Generating additional source file ${generated_kernel_src}" + VERBATIM + ) + list(APPEND batch_paged_prefill_kernels_src ${generated_kernel_src}) + endforeach(idtype) + endforeach(dtype) + endforeach(mask_mode) + endforeach(allow_fp16_qk_reduction) + endforeach(pos_encoding_mode) + endforeach(kv_layout) + endforeach(logits_post_hook) + endforeach(head_dim) + endforeach(page_size) +endforeach(group_size) + +# batch ragged prefill kernel inst generation 
+foreach(group_size IN LISTS GROUP_SIZES) + foreach(head_dim IN LISTS HEAD_DIMS) + foreach(logits_post_hook IN LISTS LOGITS_POST_HOOKS) foreach(kv_layout IN LISTS KV_LAYOUTS) foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES) foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS) foreach(mask_mode IN LISTS MASK_MODES) foreach(dtype IN LISTS PREFILL_DTYPES) foreach(idtype IN LISTS IDTYPES) - set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_prefill_group_${group_size}_page_${page_size}_head_${head_dim}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu) + set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_ragged_prefill_group_${group_size}_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu) add_custom_command( OUTPUT ${generated_kernel_src} - COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py ${generated_kernel_src} - DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py + COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_ragged_prefill_inst.py ${generated_kernel_src} + DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_ragged_prefill_inst.py COMMENT "Generating additional source file ${generated_kernel_src}" VERBATIM ) - list(APPEND batch_paged_prefill_kernels_src ${generated_kernel_src}) + list(APPEND batch_ragged_prefill_kernels_src ${generated_kernel_src}) endforeach(idtype) endforeach(dtype) endforeach(mask_mode) endforeach(allow_fp16_qk_reduction) endforeach(pos_encoding_mode) endforeach(kv_layout) - endforeach(head_dim) - endforeach(page_size) -endforeach(group_size) - -# batch ragged prefill kernel inst generation -foreach(group_size IN LISTS GROUP_SIZES) - foreach(head_dim IN LISTS HEAD_DIMS) - foreach(kv_layout IN LISTS KV_LAYOUTS) - foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES) - foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS) - foreach(mask_mode IN LISTS MASK_MODES) - foreach(dtype IN LISTS PREFILL_DTYPES) - foreach(idtype IN LISTS IDTYPES) - set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_ragged_prefill_group_${group_size}_head_${head_dim}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu) - add_custom_command( - OUTPUT ${generated_kernel_src} - COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_ragged_prefill_inst.py ${generated_kernel_src} - DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_ragged_prefill_inst.py - COMMENT "Generating additional source file ${generated_kernel_src}" - VERBATIM - ) - list(APPEND batch_ragged_prefill_kernels_src ${generated_kernel_src}) - endforeach(idtype) - endforeach(dtype) - endforeach(mask_mode) - endforeach(allow_fp16_qk_reduction) - endforeach(pos_encoding_mode) - endforeach(kv_layout) + endforeach(logits_post_hook) endforeach(head_dim) endforeach(group_size) diff --git a/cmake/config.cmake b/cmake/config.cmake index c6c8fa894..c3b860f1d 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -23,6 +23,7 @@ set(FLASHINFER_DISTRIBUTED ON) # The following configurations can impact the binary # size of the generated library set(FLASHINFER_GEN_GROUP_SIZES 1 4 6 
8) +set(FLASHINFER_GEN_LOGITS_POST_HOOKS 0) set(FLASHINFER_GEN_PAGE_SIZES 1 16 32) set(FLASHINFER_GEN_HEAD_DIMS 64 128 256) set(FLASHINFER_GEN_KV_LAYOUTS 0 1) diff --git a/include/flashinfer/attention/decode.cuh b/include/flashinfer/attention/decode.cuh index 67a00bed2..013f8486f 100644 --- a/include/flashinfer/attention/decode.cuh +++ b/include/flashinfer/attention/decode.cuh @@ -36,6 +36,7 @@ #include "../utils.cuh" #include "../vec_dtypes.cuh" #include "cascade.cuh" +#include "logits_post_hook.cuh" #include "state.cuh" namespace flashinfer { @@ -48,6 +49,7 @@ namespace { /*! * \brief Load k tile from smem and compute qk + * \tparam logits_post_hook The logits post hook used in the kernel * \tparam pos_encoding_mode The positional encoding mode used in the kernel * \tparam head_dim A template integer indicates the head dimension * \tparam vec_size A template integer indicates the vector size @@ -65,8 +67,8 @@ namespace { * \param s A float indicates the thread-local result of qk * \param st The self-attention state to be updated */ -template +template __device__ __forceinline__ void compute_qk(const T* smem, uint32_t compute_stage_idx, const vec_t& q_vec, const vec_t& freq, uint32_t kv_idx_base, @@ -96,6 +98,7 @@ __device__ __forceinline__ void compute_qk(const T* smem, uint32_t compute_stage s[j] += math::shfl_xor_sync(s[j], offset); } s[j] = (iter_base + tz * tile_size + j < iter_bound) ? s[j] : -5e4; + s[j] = apply_logits_post_hook(s[j]); if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { s[j] += alibi_slope * float(int(kv_idx_base + tz * tile_size + j) - q_offset); } @@ -178,6 +181,7 @@ __device__ __forceinline__ void sync_state(state_t& st, float* smem, f /*! * \brief FlashAttention decoding cuda kernel with kv-cache for a single request + * \tparam logits_post_hook The logits post hook used in the kernel * \tparam kv_layout The layout of k/v matrices (NHD or HND) * \tparam partition_kv Whether to partition kv-cache on sequence length dimension or not * \tparam pos_encoding_mode The positional encoding mode @@ -202,9 +206,10 @@ __device__ __forceinline__ void sync_state(state_t& st, float* smem, f * of "theta" used in RoPE (Rotary Positional Embeddings) * \param kv_chunk_size A integer indicates the kv-chunk size */ -template +template __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* __restrict__ k, DTypeKV* __restrict__ v, DTypeOut* __restrict__ o, DTypeOut* __restrict__ tmp, @@ -213,7 +218,7 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _ float rope_rcp_theta, uint32_t kv_chunk_size) { auto block = cg::this_thread_block(); auto grid = cg::this_grid(); - sm_scale *= math::log2e; + sm_scale *= (logits_post_hook == LogitsPostHook::kNone ? 
math::log2e : 1.f / 30.f); constexpr uint32_t head_dim = bdx * vec_size; uint32_t kv_head_idx = blockIdx.y; @@ -297,7 +302,7 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _ // compute qk cp_async::wait_group<2 * num_stages_smem - 1>(); block.sync(); - compute_qk( + compute_qk( k_smem + (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, stage_idx, q_vec, freq, consumer_kv_idx_base, iter * bdy * tile_size_per_bdx * bdz, kv_chunk_size, seq_len - 1, alibi_slope, s, st_local); @@ -356,16 +361,16 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _ } } -template +template __global__ void BatchDecodeWithPaddedKVCacheKernel( DTypeQ* __restrict__ q, DTypeKV* __restrict__ k, DTypeKV* __restrict__ v, DTypeOut* __restrict__ o, float* __restrict__ lse, tensor_info_t info, float sm_scale, float rope_rcp_scale, float rope_rcp_theta) { auto block = cg::this_thread_block(); - sm_scale *= math::log2e; + sm_scale *= (logits_post_hook == LogitsPostHook::kNone ? math::log2e : 1.f / 30.f); constexpr uint32_t head_dim = bdx * vec_size; uint32_t kv_head_idx = blockIdx.y; @@ -438,7 +443,7 @@ __global__ void BatchDecodeWithPaddedKVCacheKernel( // compute qk cp_async::wait_group<2 * num_stages_smem - 1>(); block.sync(); - compute_qk( + compute_qk( k_smem + (stage_idx * bdz + tz) * bdy * head_dim, stage_idx, q_vec, freq, consumer_kv_idx_base, iter * bdy * bdz, seq_len, seq_len - 1, alibi_slope, s, st_local); block.sync(); @@ -489,6 +494,7 @@ __global__ void BatchDecodeWithPaddedKVCacheKernel( /*! * \brief FlashAttention decoding cuda kernel with paged kv-cache for multiple requests + * \tparam logits_post_hook The logits post hook used in the kernel * \tparam partition_kv Whether to partition kv-cache on sequence length dimension or not * \tparam pos_encoding_mode The positional encoding mode * \tparam vec_size A template integer indicates the vector size @@ -512,10 +518,10 @@ __global__ void BatchDecodeWithPaddedKVCacheKernel( * \param rope_rcp_theta A floating number indicate the reciprocal * of "theta" used in RoPE (Rotary Positional Embeddings) */ -template +template __global__ void BatchDecodeWithPagedKVCacheKernel( DTypeQ* __restrict__ q, IdType* __restrict__ q_offset, paged_kv_t paged_kv, @@ -524,7 +530,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( bool* __restrict__ block_valid_mask, float sm_scale, float rope_rcp_scale, float rope_rcp_theta) { auto block = cg::this_thread_block(); - sm_scale *= math::log2e; + sm_scale *= (logits_post_hook == LogitsPostHook::kNone ? math::log2e : 1.f / 30.f); constexpr uint32_t head_dim = bdx * vec_size; const uint32_t batch_idx = blockIdx.x; @@ -649,7 +655,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( // compute qk cp_async::wait_group<2 * num_stages_smem - 1>(); block.sync(); - compute_qk( + compute_qk( k_smem + (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, stage_idx, q_vec, freq, (paged_kv.rope_pos_offset == nullptr ? 
0 : paged_kv.rope_pos_offset[mapped_batch_idx]) + @@ -760,8 +766,8 @@ constexpr uint32_t get_heuristic_num_threads(uint32_t group_size, uint32_t sizeo * \param stream The cuda stream to launch the kernel * \return status Indicates whether CUDA calls are successful */ -template cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o, DTypeOut* tmp, uint32_t num_kv_heads, @@ -786,9 +792,9 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, if (seq_len <= 256 || tmp == nullptr) { // no need to use partition-kv kernel auto kernel = - SingleDecodeWithKVCacheKernel; + SingleDecodeWithKVCacheKernel; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); @@ -807,9 +813,10 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); } else { // use partition-kv kernel - auto kernel = SingleDecodeWithKVCacheKernel; + auto kernel = + SingleDecodeWithKVCacheKernel; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); @@ -848,8 +855,9 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, return cudaSuccess; } -template +template cudaError_t BatchDecodeWithPagedKVCacheDispatched( DTypeQ* q, IdType* q_offset, paged_kv_t paged_kv, kv_partition_info_t kv_partition_info, DTypeOut* o, DTypeOut* tmp_v, float* tmp_s, @@ -877,9 +885,10 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched( dim3 nblks(padded_batch_size, num_kv_heads); dim3 nthrs(bdx, bdy, bdz); auto kernel = - BatchDecodeWithPagedKVCacheKernel; + BatchDecodeWithPagedKVCacheKernel; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); void* args[] = {(void*)&q, @@ -898,9 +907,10 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched( } else { // use partition-kv kernel auto partition_kv_kernel = - BatchDecodeWithPagedKVCacheKernel; + BatchDecodeWithPagedKVCacheKernel; FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( partition_kv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); void* args[] = {(void*)&q, @@ -946,8 +956,9 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched( * \param stream The cuda stream to launch the kernel * \return status Indicates whether CUDA calls are successful */ -template +template cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o, DTypeOut* tmp, float* lse, uint32_t batch_size, uint32_t padded_kv_len, uint32_t num_qo_heads, @@ -970,8 +981,9 @@ cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeK dim3 nblks(batch_size, num_kv_heads); dim3 nthrs(bdx, bdy, bdz); - auto kernel = BatchDecodeWithPaddedKVCacheKernel; + auto kernel = BatchDecodeWithPaddedKVCacheKernel; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); tensor_info_t info(1, padded_kv_len, num_kv_heads); diff --git a/include/flashinfer/attention/handler.cuh b/include/flashinfer/attention/handler.cuh index 6d8bf3e05..35b568006 100644 --- a/include/flashinfer/attention/handler.cuh +++ b/include/flashinfer/attention/handler.cuh @@ -26,13 +26,14 @@ #include "../page.cuh" #include "../pos_enc.cuh" #include "../utils.cuh" +#include "logits_post_hook.cuh" namespace flashinfer { -template +template __global__ void BatchDecodeWithPagedKVCacheKernel( DTypeQ* 
__restrict__ q, IdType* __restrict__ q_offset, paged_kv_t paged_kv, @@ -99,8 +100,9 @@ std::pair PartitionPagedKVCacheBinarySearchMinNumPagePerBatc * \param stream The cuda stream to launch the kernel * \return status Indicates whether CUDA calls are successful */ -template +template cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched( uint32_t& tmp_size, uint32_t& max_grid_size, uint32_t& max_num_pages_per_batch, uint32_t& new_batch_size, uint32_t batch_size, IdType* kv_indptr, const uint32_t num_qo_heads, @@ -118,11 +120,11 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched( 2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) + std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*), 2 * bdy * bdz * sizeof(float)); - // Note that the dtype of Q should not impact the cudaOccupancyMaxActiveBlocksPerMultiprocessor - // return, which is why we just use DTypeKV as it simplifies the API. - auto partition_kv_kernel = BatchDecodeWithPagedKVCacheKernel< - /*partition_kv=*/true, POS_ENCODING_MODE, num_stages_smem, tile_size_per_bdx, vec_size, bdx, - bdy, bdz, page_storage, kv_layout, DTypeQ, DTypeKV, DTypeOut, IdType>; + auto partition_kv_kernel = + BatchDecodeWithPagedKVCacheKernel; int num_blocks_per_sm = 0; int num_sm = 0; int dev_id = 0; @@ -295,18 +297,18 @@ class BatchDecodeHandler { bool* GetBlockValidMask() const { return block_valid_mask_; } - template + template cudaError_t BeginForwardDispatched(void* buffer, size_t workspace_size_in_bytes, IdType* indptr, IdType* last_page_len, uint32_t batch_size, uint32_t num_qo_heads, uint32_t page_size) { batch_size_before_partition_ = batch_size; uint32_t num_kv_heads = num_qo_heads / GROUP_SIZE; uint32_t tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size; - auto work_estimation_func = - BatchDecodeWithPagedKVCacheWorkEstimationDispatched; + auto work_estimation_func = BatchDecodeWithPagedKVCacheWorkEstimationDispatched< + GROUP_SIZE, HEAD_DIM, page_storage, LOGITS_POST_HOOK, kv_layout, POS_ENCODING_MODE, DTypeQ, + DTypeKV, DTypeOut, IdType>; FLASHINFER_CUDA_CALL(work_estimation_func(tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size, batch_size, indptr, num_qo_heads, page_size, @@ -326,26 +328,21 @@ class BatchDecodeHandler { AlignedAllocator allocator(buffer, workspace_size_in_bytes); tmp_v_ = allocator.aligned_alloc( num_qo_heads * padded_batch_size * HEAD_DIM * sizeof(DTypeOut), 16); - tmp_s_ = allocator.aligned_alloc( - num_qo_heads * padded_batch_size * 2 * sizeof(float), 16); - new_indptr_ = allocator.aligned_alloc( - (padded_batch_size + 1) * sizeof(IdType), 16); + tmp_s_ = + allocator.aligned_alloc(num_qo_heads * padded_batch_size * 2 * sizeof(float), 16); + new_indptr_ = allocator.aligned_alloc((padded_batch_size + 1) * sizeof(IdType), 16); void* new_indptr_h_ = page_locked_buffer_; - new_last_page_len_ = - allocator.aligned_alloc(padded_batch_size * sizeof(IdType), 16); + new_last_page_len_ = allocator.aligned_alloc(padded_batch_size * sizeof(IdType), 16); void* new_last_page_len_h_ = (char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_); - chunk_indptr_ = allocator.aligned_alloc( - (padded_batch_size + 1) * sizeof(IdType), 16); + chunk_indptr_ = allocator.aligned_alloc((padded_batch_size + 1) * sizeof(IdType), 16); void* chunk_indptr_h_ = (char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_); - batch_idx_map_ = - allocator.aligned_alloc(padded_batch_size * sizeof(IdType), 16); + batch_idx_map_ = 
allocator.aligned_alloc(padded_batch_size * sizeof(IdType), 16); void* batch_idx_map_h_ = (char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_); - chunk_start_pos_ = - allocator.aligned_alloc(padded_batch_size * sizeof(IdType), 16); + chunk_start_pos_ = allocator.aligned_alloc(padded_batch_size * sizeof(IdType), 16); void* chunk_start_pos_h_ = (char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_); seq_lengths_before_partition_ = @@ -353,16 +350,15 @@ class BatchDecodeHandler { void* seq_lengths_before_partition_h_ = (char*)page_locked_buffer_ + ((char*)seq_lengths_before_partition_ - (char*)new_indptr_); - block_valid_mask_ = - allocator.aligned_alloc(padded_batch_size * sizeof(bool), 16); + block_valid_mask_ = allocator.aligned_alloc(padded_batch_size * sizeof(bool), 16); bool* block_valid_mask_h_ = (bool*)page_locked_buffer_ + ((bool*)block_valid_mask_ - (bool*)new_indptr_); std::fill(block_valid_mask_h_, block_valid_mask_h_ + padded_batch_size, 0); size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_; FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo( - max_num_pages_per_batch, batch_size, padded_batch_size, page_size, - indptr, last_page_len, (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_, + max_num_pages_per_batch, batch_size, padded_batch_size, page_size, indptr, + last_page_len, (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_, (IdType*)chunk_indptr_h_, (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_, (IdType*)seq_lengths_before_partition_h_, block_valid_mask_h_, /*device_buffer=*/new_indptr_, diff --git a/include/flashinfer/attention/logits_post_hook.cuh b/include/flashinfer/attention/logits_post_hook.cuh new file mode 100644 index 000000000..fbd7ff81c --- /dev/null +++ b/include/flashinfer/attention/logits_post_hook.cuh @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_ATTENTION_LOGITS_POST_HOOK_CUH_ +#define FLASHINFER_ATTENTION_LOGITS_POST_HOOK_CUH_ + +#include "../math.cuh" + +namespace flashinfer { + +enum class LogitsPostHook { + kNone = 0U, + kCap30 = 1U, +}; + +/*! 
+ * \brief Grok's logits cap function + * \ref + * https://github.com/xai-org/grok-1/blob/7050ed204b8206bb8645c7b7bbef7252f79561b0/model.py#L864-L865 + */ +__forceinline__ __device__ float logits_cap_30(float x) { + return (30 * math::log2e) * math::tanh(x); +} + +__forceinline__ __device__ half2 logits_cap_30(half2 x) { + return __hmul2(__float2half2_rn(30 * math::log2e), math::tanh(x)); +} + +template +__forceinline__ __device__ T apply_logits_post_hook(T x); + +template <> +__forceinline__ __device__ float apply_logits_post_hook(float x) { + return x; +} + +template <> +__forceinline__ __device__ float apply_logits_post_hook(float x) { + return logits_cap_30(x); +} + +template <> +__forceinline__ __device__ half2 apply_logits_post_hook(half2 x) { + return x; +} + +template <> +__forceinline__ __device__ half2 apply_logits_post_hook(half2 x) { + return logits_cap_30(x); +} + +} // namespace flashinfer + +#endif // FLASHINFER_ATTENTION_LOGITS_POST_HOOK_CUH_ diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index cce0773d6..982a98fb8 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -18,6 +18,8 @@ #include #include #include + +#include #ifdef FLASHINFER_ENABLE_FP8 #include #endif @@ -35,6 +37,7 @@ #include "../pos_enc.cuh" #include "../utils.cuh" #include "cascade.cuh" +#include "logits_post_hook.cuh" #include "mask.cuh" namespace flashinfer { @@ -476,8 +479,8 @@ __device__ __forceinline__ void k_smem_inplace_apply_rotary(const uint32_t kv_id } } -template +template __device__ __forceinline__ void compute_qk(smem_t* q_smem, uint32_t* q_smem_offset_r, smem_t* k_smem, uint32_t* k_smem_offset_r, DTypeQKAccum (*s_frag)[num_frags_z][8]) { @@ -525,6 +528,32 @@ __device__ __forceinline__ void compute_qk(smem_t* q_smem, uint32_t* q_smem_offs } *q_smem_offset_r -= num_frags_y * 2; *k_smem_offset_r -= num_frags_y * 2; + + if constexpr (std::is_same::value) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + s_frag[fx][fz][reg_id] = apply_logits_post_hook(s_frag[fx][fz][reg_id]); + } + } + } + } else { + static_assert(std::is_same::value); +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 4; ++reg_id) { + *(half2*)(&s_frag[fx][fz][reg_id * 2]) = + apply_logits_post_hook(*(half2*)(&s_frag[fx][fz][reg_id * 2])); + } + } + } + } } template @@ -897,10 +926,10 @@ __device__ __forceinline__ void write_o_reg_gmem(float (*o_frag)[num_frags_y][8] * \param log2_rope_rcp_theta log2(1/(rope_theta)), where rope_theta is the theta * used in RoPE. */ -template +template __global__ void SinglePrefillWithKVCacheKernel( DTypeIn* __restrict__ q, DTypeIn* __restrict__ k, DTypeIn* __restrict__ v, float* __restrict__ custom_mask, DTypeOut* __restrict__ o, void* __restrict__ tmp, @@ -908,7 +937,7 @@ __global__ void SinglePrefillWithKVCacheKernel( float sm_scale, const float log2_rope_rcp_scale, const float log2_rope_rcp_theta) { static_assert(sizeof(DTypeIn) == 2); static_assert(sizeof(DTypeOut) == 2); - sm_scale *= math::log2e; + sm_scale *= (logits_post_hook == LogitsPostHook::kNone ? 
math::log2e : 1.f / 30.f); const uint32_t qo_len = qkv_info.qo_len; const uint32_t kv_len = qkv_info.kv_len; const uint32_t tx = threadIdx.x, ty = threadIdx.y; @@ -1033,8 +1062,8 @@ __global__ void SinglePrefillWithKVCacheKernel( } // compute attention score - compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, - &k_smem_offset_r, s_frag); + compute_qk( + &qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { apply_alibi_bias( @@ -1111,10 +1140,10 @@ __global__ void SinglePrefillWithKVCacheKernel( } } -template +template __global__ void BatchPrefillWithRaggedKVCacheKernel( DTypeIn* __restrict__ q, IdType* __restrict__ request_indices, IdType* __restrict__ tile_indices, IdType* __restrict__ qo_indptr, DTypeIn* __restrict__ k, @@ -1125,7 +1154,7 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( float log2_rope_rcp_theta) { static_assert(sizeof(DTypeIn) == 2); static_assert(sizeof(DTypeOut) == 2); - sm_scale *= math::log2e; + sm_scale *= (logits_post_hook == LogitsPostHook::kNone ? math::log2e : 1.f / 30.f); constexpr uint32_t head_dim = num_frags_y * 16; auto block = cg::this_thread_block(); @@ -1259,8 +1288,8 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( } // compute attention score - compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, - &k_smem_offset_r, s_frag); + compute_qk( + &qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { // TODO(Zihao): handle the case that q_offset is specified @@ -1328,10 +1357,11 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( } } -template +template __global__ void BatchPrefillWithPagedKVCacheKernel( IdType* __restrict__ request_indices, IdType* __restrict__ tile_indices, DTypeIn* __restrict__ q, paged_kv_t paged_kv, @@ -1342,7 +1372,7 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( constexpr uint32_t aligned_group_size = 16 / rows_per_warp; static_assert(sizeof(DTypeIn) == 2); static_assert(sizeof(DTypeOut) == 2); - sm_scale *= math::log2e; + sm_scale *= (logits_post_hook == LogitsPostHook::kNone ? 
math::log2e : 1.f / 30.f); auto block = cg::this_thread_block(); const uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y, kv_head_idx = blockIdx.z; @@ -1471,8 +1501,8 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( } // compute attention score - compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, - &k_smem_offset_r, s_frag); + compute_qk( + &qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { // TODO(Zihao): handle the case that q_offset is specified @@ -1687,9 +1717,9 @@ cudaError_t SinglePrefillWithKVCacheWorkEstimation( return cudaSuccess; } -template +template cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, float* custom_mask, DTypeOut* o, float* tmp, float* lse, uint32_t num_kv_heads, uint32_t qo_len, @@ -1745,9 +1775,10 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* constexpr uint32_t num_threads = num_warps * warp_size; constexpr uint32_t num_rows_per_cta = num_frags_x * num_warps * 16; auto partition_kv_kernel = - SinglePrefillWithKVCacheKernel; + SinglePrefillWithKVCacheKernel; tensor_info_t qkv_info(qo_len, kv_len, num_kv_heads); uint32_t smem_size = (num_frags_x * num_warps + num_frags_z * 2) * 16 * HEAD_DIM * sizeof(DTypeIn); @@ -1772,9 +1803,11 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* if (num_chunks <= 1 || tmp == nullptr) { // Enough parallelism, do not split-kv - auto kernel = SinglePrefillWithKVCacheKernel< - /*partition_kv=*/false, GROUP_SIZE, MASK_MODE, KV_LAYOUT, pos_encoding_mode, - num_frags_x, num_frags_y, num_frags_z, num_warps, DTypeIn, DTypeQKAccum, DTypeOut>; + auto kernel = + SinglePrefillWithKVCacheKernel; void* args[] = {(void*)&q, (void*)&k, (void*)&v, @@ -1821,9 +1854,10 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* return cudaSuccess; } -template +template cudaError_t BatchPrefillWithRaggedKVCacheDispatched( DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, DTypeIn* k, DTypeIn* v, IdType* kv_indptr, float* custom_mask, IdType* qk_indptr, IdType* q_offset, @@ -1869,10 +1903,9 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( " and report the issue to the developers."; throw std::invalid_argument(err_msg.str()); } else { - auto kernel = - BatchPrefillWithRaggedKVCacheKernel; + auto kernel = BatchPrefillWithRaggedKVCacheKernel< + LOGITS_POST_HOOK, GROUP_SIZE, MASK_MODE, KV_LAYOUT, pos_encoding_mode, num_frags_x, + num_frags_y, num_frags_z, num_warps, DTypeIn, DTypeQKAccum, DTypeOut, IdType>; uint32_t smem_size = (num_frags_x * num_warps + num_frags_z * 2) * 16 * HEAD_DIM * sizeof(DTypeIn); FLASHINFER_CUDA_CALL( @@ -1901,10 +1934,10 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( return cudaSuccess; } -template +template cudaError_t BatchPrefillWithPagedKVCacheDispatched( DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, IdType* q_offset, paged_kv_t paged_kv, float* custom_mask, @@ -1952,9 +1985,11 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched( " and report the issue to the developers."; throw std::invalid_argument(err_msg.str()); } else { - auto kernel = BatchPrefillWithPagedKVCacheKernel< - GROUP_SIZE, PAGE_SIZE, MASK_MODE, pos_encoding_mode, num_frags_x, num_frags_y, - num_frags_z, num_warps, page_storage, kv_layout, DTypeIn, DTypeQKAccum, DTypeOut, IdType>; + auto kernel = + BatchPrefillWithPagedKVCacheKernel; uint32_t smem_size = (num_frags_x * 
num_warps + num_frags_z * 2) * 16 * HEAD_DIM * sizeof(DTypeIn); FLASHINFER_CUDA_CALL( diff --git a/include/flashinfer/decode_attention_decl.cuh b/include/flashinfer/decode_attention_decl.cuh index 409984181..f9d51bd42 100644 --- a/include/flashinfer/decode_attention_decl.cuh +++ b/include/flashinfer/decode_attention_decl.cuh @@ -18,9 +18,8 @@ #include -#include - #include "attention/handler.cuh" +#include "attention/logits_post_hook.cuh" #include "layout.cuh" #include "page.cuh" #include "pos_enc.cuh" @@ -28,36 +27,40 @@ namespace flashinfer { -template +template cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o, DTypeOut* tmp, uint32_t num_kv_heads, uint32_t seq_len, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); -template +template cudaError_t BatchDecodeWithPagedKVCacheDispatched( - DTypeQ* q, IdType* q_offset, paged_kv_t paged_kv, + DTypeQ* q, IdType* q_offset, paged_kv_t paged_kv, kv_partition_info_t kv_partition_info, DTypeOut* o, DTypeOut* tmp_v, float* tmp_s, float* lse, bool* block_valid_mask, uint32_t padded_batch_size, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); -template +template cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o, DTypeOut* tmp, float* lse, uint32_t batch_size, uint32_t padded_kv_len, uint32_t num_qo_heads, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); -template +template cudaError_t BatchDecodeWithPagedKVCacheWrapperDispatched( BatchDecodeHandler* handler, DTypeQ* q, IdType* q_offset, - paged_kv_t paged_kv, DTypeOut* o, float* lse, + paged_kv_t paged_kv, DTypeOut* o, float* lse, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { - paged_kv_t new_paged_kv = paged_kv; + paged_kv_t new_paged_kv = paged_kv; kv_partition_info_t kv_partition_info; DTypeOut* tmp_v = handler->GetTempV(); float* tmp_s = handler->GetTempS(); @@ -81,8 +84,9 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapperDispatched( throw std::runtime_error(err_msg.str()); } - return BatchDecodeWithPagedKVCacheDispatched( + return BatchDecodeWithPagedKVCacheDispatched( q, q_offset, new_paged_kv, kv_partition_info, o, tmp_v, tmp_s, lse, handler->GetBlockValidMask(), handler->GetPaddedBatchSize(), sm_scale, rope_scale, rope_theta, stream); diff --git a/include/flashinfer/math.cuh b/include/flashinfer/math.cuh index 9ecbc2d8b..c2401c7e1 100644 --- a/include/flashinfer/math.cuh +++ b/include/flashinfer/math.cuh @@ -114,6 +114,37 @@ __forceinline__ __device__ float rsqrt(float x) { return y; } +/*! + * \brief Wrapper of PTX tanh.approx.f32 instruction, which computes tanh(x) + * \param x input + */ +__forceinline__ __device__ float tanh(float x) { + float y; + asm volatile("tanh.approx.f32 %0, %1;" : "=f"(y) : "f"(x)); + return y; +} + +/*! + * \brief Wrapper of PTX tanh.approx.f16x2 instruction, which computes tanh(x) + * \param x input + */ +__forceinline__ __device__ half2 tanh(half2 x) { + uint32_t y_u32; + uint32_t x_u32 = half2_as_uint32(x); + asm volatile("tanh.approx.f16x2 %0, %1;" : "=r"(y_u32) : "r"(x_u32)); + return uint32_as_half2(y_u32); +} + +/*! 
+ * \brief Wrapper of PTX tanh.approx.f16 instruction, which computes tanh(x) + * \param x input + */ +__forceinline__ __device__ half tanh(half x) { + ushort y_u16; + asm volatile("tanh.approx.f16 %0, %1;" : "=h"(y_u16) : "h"(__half_as_ushort(x))); + return __ushort_as_half(y_u16); +} + } // namespace math } // namespace flashinfer #endif // FLASHINFER_MATH_CUH_ diff --git a/include/flashinfer/prefill_attention_decl.cuh b/include/flashinfer/prefill_attention_decl.cuh index a59777c0e..7fea6ac74 100644 --- a/include/flashinfer/prefill_attention_decl.cuh +++ b/include/flashinfer/prefill_attention_decl.cuh @@ -21,6 +21,7 @@ #include #include "attention/handler.cuh" +#include "attention/logits_post_hook.cuh" #include "attention/mask.cuh" #include "layout.cuh" #include "page.cuh" @@ -29,18 +30,19 @@ namespace flashinfer { -template +template cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, float* custom_mask, DTypeOut* o, float* tmp, float* lse, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); -template +template cudaError_t BatchPrefillWithRaggedKVCacheDispatched( DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, DTypeIn* k, DTypeIn* v, IdType* kv_indptr, float* custom_mask, IdType* qk_indptr, IdType* q_offset, @@ -48,22 +50,23 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( uint32_t num_qo_tiles, uint32_t num_kv_heads, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream = nullptr); -template +template cudaError_t BatchPrefillWithPagedKVCacheDispatched( DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, IdType* q_offset, - paged_kv_t paged_kv, float* custom_mask, + paged_kv_t paged_kv, float* custom_mask, IdType* qk_indptr, DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_tiles, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); -template +template cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched( BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_offset, - paged_kv_t paged_kv, float* custom_mask, + paged_kv_t paged_kv, float* custom_mask, IdType* qk_indptr, DTypeOut* o, float* lse, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { float* tmp = nullptr; @@ -85,17 +88,17 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched( DISPATCH_NUM_FRAGS_X(num_frags_x, NUM_FRAGS_X, { return BatchPrefillWithPagedKVCacheDispatched< - page_storage, kv_layout, NUM_FRAGS_X, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, pos_encoding_mode, - ALLOW_FP16_QK_REDUCTION, MASK_MODE, DTypeIn, DTypeOut, IdType>( + PAGE_STORAGE, NUM_FRAGS_X, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, + POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION, MASK_MODE, DTypeIn, DTypeOut, IdType>( q, request_indices, tile_indices, qo_indptr, q_offset, paged_kv, custom_mask, qk_indptr, o, tmp, lse, num_qo_tiles, sm_scale, rope_scale, rope_theta, stream); }); return cudaSuccess; } -template +template cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched( BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v, IdType* kv_indptr, float* custom_mask, IdType* qk_indptr, IdType* q_offset, @@ -119,9 +122,9 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched( } DISPATCH_NUM_FRAGS_X(num_frags_x, NUM_FRAGS_X, { - return BatchPrefillWithRaggedKVCacheDispatched( + return BatchPrefillWithRaggedKVCacheDispatched< + NUM_FRAGS_X, 
GROUP_SIZE, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE, + ALLOW_FP16_QK_REDUCTION, MASK_MODE, DTypeIn, DTypeOut, IdType>( q, request_indices, tile_indices, qo_indptr, k, v, kv_indptr, custom_mask, qk_indptr, q_offset, k_rope_pos_offset, o, tmp, lse, batch_size, num_qo_tiles, num_kv_heads, sm_scale, rope_scale, rope_theta, stream); diff --git a/python/csrc/batch_decode.cu b/python/csrc/batch_decode.cu index 1af97f01f..98962f964 100644 --- a/python/csrc/batch_decode.cu +++ b/python/csrc/batch_decode.cu @@ -22,8 +22,8 @@ using namespace flashinfer; std::vector batch_decode_with_padded_kv_cache( torch::Tensor q, torch::Tensor k_padded, torch::Tensor v_padded, unsigned int layout, - unsigned int pos_encoding_mode, float sm_scale, float rope_scale, float rope_theta, - bool return_lse) { + unsigned int pos_encoding_mode, bool logits_cap, float sm_scale, float rope_scale, + float rope_theta, bool return_lse) { CHECK_INPUT(q); CHECK_INPUT(k_padded); CHECK_INPUT(v_padded); @@ -56,6 +56,9 @@ std::vector batch_decode_with_padded_kv_cache( lse = torch::empty({batch_size, num_qo_heads}, q.options()).to(torch::kFloat32); } + const LogitsPostHook logits_post_hook = + logits_cap ? LogitsPostHook::kCap30 : LogitsPostHook::kNone; + if (is_float8_tensor(q)) { DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(q.scalar_type(), q_type, [&] { return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(k_padded.scalar_type(), kv_type, [&] { @@ -63,22 +66,24 @@ std::vector batch_decode_with_padded_kv_cache( return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { return DISPATCH_pos_encoding_mode( PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { - nv_half* tmp = nullptr; - cudaError_t status = - BatchDecodeWithPaddedKVCacheDispatched( - static_cast(q.data_ptr()), - static_cast(k_padded.data_ptr()), - static_cast(v_padded.data_ptr()), - static_cast(o.data_ptr()), - /*tmp=*/tmp, - /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, - batch_size, padded_kv_len, num_qo_heads, sm_scale, rope_scale, rope_theta, - torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchDecodeWithPaddedKVCache failed with error code ", status); - return true; + return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { + return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { + nv_half* tmp = nullptr; + cudaError_t status = BatchDecodeWithPaddedKVCacheDispatched< + GROUP_SIZE, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE, + q_type, kv_type, nv_half>( + static_cast(q.data_ptr()), + static_cast(k_padded.data_ptr()), + static_cast(v_padded.data_ptr()), + static_cast(o.data_ptr()), + /*tmp=*/tmp, + /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, + batch_size, padded_kv_len, num_qo_heads, sm_scale, rope_scale, rope_theta, + torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "BatchDecodeWithPaddedKVCache failed with error code ", status); + return true; + }); }); }); }); @@ -93,18 +98,23 @@ std::vector batch_decode_with_padded_kv_cache( return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { return DISPATCH_pos_encoding_mode( PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { - cudaError_t status = BatchDecodeWithPaddedKVCacheDispatched< - GROUP_SIZE, HEAD_DIM, KV_LAYOUT, POS_ENCODING_MODE, q_type, kv_type, q_type>( - static_cast(q.data_ptr()), static_cast(k_padded.data_ptr()), - static_cast(v_padded.data_ptr()), static_cast(o.data_ptr()), - /*tmp=*/tmp, - /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, - batch_size, padded_kv_len, num_qo_heads, sm_scale, rope_scale, rope_theta, - torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchDecodeWithPaddedKVCache failed with error code ", status); - return true; + return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { + return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { + cudaError_t status = BatchDecodeWithPaddedKVCacheDispatched< + GROUP_SIZE, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE, + q_type, kv_type, q_type>( + static_cast(q.data_ptr()), + static_cast(k_padded.data_ptr()), + static_cast(v_padded.data_ptr()), + static_cast(o.data_ptr()), + /*tmp=*/tmp, + /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, + batch_size, padded_kv_len, num_qo_heads, sm_scale, rope_scale, rope_theta, + torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "BatchDecodeWithPaddedKVCache failed with error code ", status); + return true; + }); }); }); }); @@ -123,7 +133,7 @@ std::vector batch_decode_with_padded_kv_cache( void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward( torch::Tensor workspace_buffer, torch::Tensor indptr, torch::Tensor last_page_len, unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, - unsigned int head_dim, unsigned int page_size, unsigned int pos_encoding_mode, + unsigned int head_dim, unsigned int page_size, unsigned int pos_encoding_mode, bool logits_cap, torch::Tensor empty_q_data, torch::Tensor empty_kv_data) { // NOTE(zihao): not necessary to be CUDA tensor CHECK_CONTIGUOUS(indptr); @@ -139,57 +149,64 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward( cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); handler_->SetCUDAStream(torch_current_stream); + const LogitsPostHook logits_post_hook = + logits_cap ? 
LogitsPostHook::kCap30 : LogitsPostHook::kNone; + if (is_float8_tensor(empty_q_data)) { DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(empty_q_data.scalar_type(), q_type, [&] { - return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(empty_kv_data.scalar_type(), kv_type, [&] { - return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = - handler_->BeginForwardDispatched( - static_cast(workspace_buffer.data_ptr()), workspace_size_in_bytes, - static_cast(indptr.data_ptr()), - static_cast(last_page_len.data_ptr()), batch_size, num_qo_heads, - page_size); - TORCH_CHECK(status == cudaSuccess, - "BatchDecodeWithPagedKVCache failed with error ", - cudaGetErrorString(status)); - return true; + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8( + empty_kv_data.scalar_type(), kv_type, [&] { + return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { + return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + cudaError_t status = handler_->BeginForwardDispatched< + GROUP_SIZE, HEAD_DIM, PageStorage::kIndices, LOGITS_POST_HOOK, + KV_LAYOUT, POS_ENCODING_MODE, q_type, kv_type, nv_half, int32_t>( + static_cast(workspace_buffer.data_ptr()), + workspace_size_in_bytes, static_cast(indptr.data_ptr()), + static_cast(last_page_len.data_ptr()), batch_size, + num_qo_heads, page_size); + TORCH_CHECK(status == cudaSuccess, + "BatchDecodeWithPagedKVCache failed with error ", + cudaGetErrorString(status)); + return true; + }); }); + }); + }); }); }); - }); - }); }); } else { DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_q_data.scalar_type(), q_type, [&] { - return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(empty_kv_data.scalar_type(), kv_type, [&] { - return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = - handler_->BeginForwardDispatched( - static_cast(workspace_buffer.data_ptr()), workspace_size_in_bytes, - static_cast(indptr.data_ptr()), - static_cast(last_page_len.data_ptr()), batch_size, num_qo_heads, - page_size); - TORCH_CHECK(status == cudaSuccess, - "BatchDecodeWithPagedKVCache failed with error ", - cudaGetErrorString(status)); - return true; + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8( + empty_kv_data.scalar_type(), kv_type, [&] { + return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { + return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + cudaError_t status = handler_->BeginForwardDispatched< + GROUP_SIZE, HEAD_DIM, PageStorage::kIndices, LOGITS_POST_HOOK, + KV_LAYOUT, POS_ENCODING_MODE, q_type, kv_type, q_type, int32_t>( + static_cast(workspace_buffer.data_ptr()), + workspace_size_in_bytes, static_cast(indptr.data_ptr()), + 
static_cast(last_page_len.data_ptr()), batch_size, + num_qo_heads, page_size); + TORCH_CHECK(status == cudaSuccess, + "BatchDecodeWithPagedKVCache failed with error ", + cudaGetErrorString(status)); + return true; + }); }); + }); + }); }); }); - }); - }); }); } } @@ -204,8 +221,8 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize( std::vector BatchDecodeWithPagedKVCachePyTorchWrapper::Forward( torch::Tensor q, torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len, - unsigned int pos_encoding_mode, float sm_scale, float rope_scale, float rope_theta, - bool return_lse) { + unsigned int pos_encoding_mode, bool logits_cap, float sm_scale, float rope_scale, + float rope_theta, bool return_lse) { CHECK_INPUT(q); CHECK_INPUT(paged_kv_data); CHECK_INPUT(paged_kv_indptr); @@ -247,69 +264,78 @@ std::vector BatchDecodeWithPagedKVCachePyTorchWrapper::Forward( lse = torch::empty({batch_size, num_qo_heads}, q.options()).to(torch::kFloat32); } + const LogitsPostHook logits_post_hook = + logits_cap ? LogitsPostHook::kCap30 : LogitsPostHook::kNone; + if (is_float8_tensor(q)) { DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(q.scalar_type(), q_type, [&] { - return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(paged_kv_data.scalar_type(), kv_type, [&] { - return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { - return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - paged_kv_t paged_kv( - num_kv_heads, page_size, head_dim, batch_size, - static_cast(paged_kv_data.data_ptr()), - static_cast(paged_kv_indices.data_ptr()), - static_cast(paged_kv_indptr.data_ptr()), - static_cast(paged_kv_last_page_len.data_ptr())); - cudaError_t status = BatchDecodeWithPagedKVCacheWrapperDispatched< - PageStorage::kIndices, KV_LAYOUT, GROUP_SIZE, HEAD_DIM, POS_ENCODING_MODE, - q_type, kv_type, nv_half, int32_t>( - handler_.get(), static_cast(q.data_ptr()), /*q_offset=*/nullptr, - paged_kv, static_cast(o.data_ptr()), - /*lse=*/(return_lse ? static_cast(lse.data_ptr()) : nullptr), - sm_scale, rope_scale, rope_theta, - /*stream=*/torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchDecodeWithPagedKVCache failed with error ", - cudaGetErrorString(status)); - return true; + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8( + paged_kv_data.scalar_type(), kv_type, [&] { + return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { + return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { + return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + paged_kv_t paged_kv( + num_kv_heads, page_size, head_dim, batch_size, + static_cast(paged_kv_data.data_ptr()), + static_cast(paged_kv_indices.data_ptr()), + static_cast(paged_kv_indptr.data_ptr()), + static_cast(paged_kv_last_page_len.data_ptr())); + cudaError_t status = BatchDecodeWithPagedKVCacheWrapperDispatched< + PageStorage::kIndices, GROUP_SIZE, HEAD_DIM, LOGITS_POST_HOOK, + KV_LAYOUT, POS_ENCODING_MODE, q_type, kv_type, nv_half, int32_t>( + handler_.get(), static_cast(q.data_ptr()), + /*q_offset=*/nullptr, paged_kv, static_cast(o.data_ptr()), + /*lse=*/(return_lse ? 
static_cast(lse.data_ptr()) : nullptr), + sm_scale, rope_scale, rope_theta, + /*stream=*/torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "BatchDecodeWithPagedKVCache failed with error ", + cudaGetErrorString(status)); + return true; + }); }); + }); + }); }); }); - }); - }); }); } else { DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), q_type, [&] { - return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(paged_kv_data.scalar_type(), kv_type, [&] { - return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { - return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - paged_kv_t paged_kv( - num_kv_heads, page_size, head_dim, batch_size, - static_cast(paged_kv_data.data_ptr()), - static_cast(paged_kv_indices.data_ptr()), - static_cast(paged_kv_indptr.data_ptr()), - static_cast(paged_kv_last_page_len.data_ptr())); - cudaError_t status = BatchDecodeWithPagedKVCacheWrapperDispatched< - PageStorage::kIndices, KV_LAYOUT, GROUP_SIZE, HEAD_DIM, POS_ENCODING_MODE, - q_type, kv_type, q_type, int32_t>( - handler_.get(), static_cast(q.data_ptr()), /*q_offset=*/nullptr, - paged_kv, static_cast(o.data_ptr()), - /*lse=*/(return_lse ? static_cast(lse.data_ptr()) : nullptr), - sm_scale, rope_scale, rope_theta, - /*stream=*/torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchDecodeWithPagedKVCache failed with error ", - cudaGetErrorString(status)); - return true; + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8( + paged_kv_data.scalar_type(), kv_type, [&] { + return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { + return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { + return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + paged_kv_t paged_kv( + num_kv_heads, page_size, head_dim, batch_size, + static_cast(paged_kv_data.data_ptr()), + static_cast(paged_kv_indices.data_ptr()), + static_cast(paged_kv_indptr.data_ptr()), + static_cast(paged_kv_last_page_len.data_ptr())); + cudaError_t status = BatchDecodeWithPagedKVCacheWrapperDispatched< + PageStorage::kIndices, GROUP_SIZE, HEAD_DIM, LOGITS_POST_HOOK, + KV_LAYOUT, POS_ENCODING_MODE, q_type, kv_type, q_type, int32_t>( + handler_.get(), static_cast(q.data_ptr()), + /*q_offset=*/nullptr, paged_kv, static_cast(o.data_ptr()), + /*lse=*/(return_lse ? 
static_cast(lse.data_ptr()) : nullptr), + sm_scale, rope_scale, rope_theta, + /*stream=*/torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "BatchDecodeWithPagedKVCache failed with error ", + cudaGetErrorString(status)); + return true; + }); }); + }); + }); }); }); - }); - }); }); } diff --git a/python/csrc/batch_prefill.cu b/python/csrc/batch_prefill.cu index 13ab21dfe..bc0ea0bd0 100644 --- a/python/csrc/batch_prefill.cu +++ b/python/csrc/batch_prefill.cu @@ -54,8 +54,8 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len, bool causal, unsigned int pos_encoding_mode, - bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, - bool return_lse) { + bool logits_cap, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, + float rope_theta, bool return_lse) { CHECK_INPUT(q); CHECK_INPUT(qo_indptr); CHECK_INPUT(paged_kv_data); @@ -101,42 +101,46 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( lse = torch::empty({nnz_qo, num_qo_heads}, q.options()).to(torch::kFloat32); } MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone; + const LogitsPostHook logits_post_hook = + logits_cap ? LogitsPostHook::kCap30 : LogitsPostHook::kNone; DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { - return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { - paged_kv_t paged_kv( - num_kv_heads, page_size, head_dim, batch_size, - static_cast(paged_kv_data.data_ptr()), - static_cast(paged_kv_indices.data_ptr()), - static_cast(paged_kv_indptr.data_ptr()), - static_cast(paged_kv_last_page_len.data_ptr())); - return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { - return DISPATCH_allow_fp16_qk_reduction( - allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - return DISPATCH_page_size(page_size, PAGE_SIZE, [&] { - cudaError_t status = BatchPrefillWithPagedKVCacheWrapperDispatched< - PageStorage::kIndices, KV_LAYOUT, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, - POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION, MASK_MODE, c_type, c_type, - int32_t>( - handler_.get(), static_cast(q.data_ptr()), - static_cast(qo_indptr.data_ptr()), - /*q_offset=*/nullptr, paged_kv, - /*custom_mask=*/nullptr, - /*qk_indptr=*/nullptr, static_cast(o.data_ptr()), - /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, - sm_scale, rope_scale, rope_theta, - /*stream=*/torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchPrefillWithPagedKVCache failed with error code ", - cudaGetErrorString(status)); - return true; + return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { + return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { + paged_kv_t paged_kv( + num_kv_heads, page_size, head_dim, batch_size, + static_cast(paged_kv_data.data_ptr()), + static_cast(paged_kv_indices.data_ptr()), + static_cast(paged_kv_indptr.data_ptr()), + static_cast(paged_kv_last_page_len.data_ptr())); + return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { + return DISPATCH_allow_fp16_qk_reduction( + allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + return DISPATCH_page_size(page_size, PAGE_SIZE, [&] { + cudaError_t status = BatchPrefillWithPagedKVCacheWrapperDispatched< + PageStorage::kIndices, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, + LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE, + ALLOW_FP16_QK_REDUCTION, MASK_MODE, c_type, c_type, int32_t>( + handler_.get(), static_cast(q.data_ptr()), + static_cast(qo_indptr.data_ptr()), + /*q_offset=*/nullptr, paged_kv, + /*custom_mask=*/nullptr, + /*qk_indptr=*/nullptr, static_cast(o.data_ptr()), + /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, + sm_scale, rope_scale, rope_theta, + /*stream=*/torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "BatchPrefillWithPagedKVCache failed with error code ", + cudaGetErrorString(status)); + return true; + }); }); - }); - }); + }); + }); }); }); }); @@ -154,8 +158,8 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len, torch::Tensor custom_mask, torch::Tensor qk_indptr, - unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, - float rope_theta, bool return_lse) { + unsigned int pos_encoding_mode, bool logits_cap, bool allow_fp16_qk_reduction, float sm_scale, + float rope_scale, float rope_theta, bool return_lse) { CHECK_INPUT(q); CHECK_INPUT(qo_indptr); CHECK_INPUT(paged_kv_data); @@ -208,42 +212,46 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu lse = torch::empty({nnz_qo, num_qo_heads}, q.options()).to(torch::kFloat32); } constexpr MaskMode MASK_MODE = MaskMode::kCustom; + const LogitsPostHook logits_post_hook = + logits_cap ? 
LogitsPostHook::kCap30 : LogitsPostHook::kNone; DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { - return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { - paged_kv_t paged_kv( - num_kv_heads, page_size, head_dim, batch_size, - static_cast(paged_kv_data.data_ptr()), - static_cast(paged_kv_indices.data_ptr()), - static_cast(paged_kv_indptr.data_ptr()), - static_cast(paged_kv_last_page_len.data_ptr())); - return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_allow_fp16_qk_reduction( - allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - return DISPATCH_page_size(page_size, PAGE_SIZE, [&] { - cudaError_t status = BatchPrefillWithPagedKVCacheWrapperDispatched< - PageStorage::kIndices, KV_LAYOUT, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, - POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION, MASK_MODE, c_type, c_type, - int32_t>( - handler_.get(), static_cast(q.data_ptr()), - static_cast(qo_indptr.data_ptr()), - /*q_offset=*/nullptr, paged_kv, - static_cast(custom_mask.data_ptr()), - static_cast(qk_indptr.data_ptr()), - static_cast(o.data_ptr()), - /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, - sm_scale, rope_scale, rope_theta, - /*stream=*/torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchPrefillWithPagedKVCache failed with error code ", - cudaGetErrorString(status)); - return true; + return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { + return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { + paged_kv_t paged_kv( + num_kv_heads, page_size, head_dim, batch_size, + static_cast(paged_kv_data.data_ptr()), + static_cast(paged_kv_indices.data_ptr()), + static_cast(paged_kv_indptr.data_ptr()), + static_cast(paged_kv_last_page_len.data_ptr())); + return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_allow_fp16_qk_reduction( + allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + return DISPATCH_page_size(page_size, PAGE_SIZE, [&] { + cudaError_t status = BatchPrefillWithPagedKVCacheWrapperDispatched< + PageStorage::kIndices, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, + LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE, + ALLOW_FP16_QK_REDUCTION, MASK_MODE, c_type, c_type, int32_t>( + handler_.get(), static_cast(q.data_ptr()), + static_cast(qo_indptr.data_ptr()), + /*q_offset=*/nullptr, paged_kv, + static_cast(custom_mask.data_ptr()), + static_cast(qk_indptr.data_ptr()), + static_cast(o.data_ptr()), + /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, + sm_scale, rope_scale, rope_theta, + /*stream=*/torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "BatchPrefillWithPagedKVCache failed with error code ", + cudaGetErrorString(status)); + return true; + }); }); - }); - }); + }); + }); }); }); }); @@ -288,7 +296,7 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::UpdatePageLockedBufferSize( std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward( torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v, - torch::Tensor kv_indptr, bool causal, unsigned int pos_encoding_mode, + torch::Tensor kv_indptr, bool causal, unsigned int pos_encoding_mode, bool logits_cap, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, bool return_lse) { CHECK_INPUT(q); @@ -325,6 +333,8 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward( } MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone; + const LogitsPostHook logits_post_hook = + logits_cap ? LogitsPostHook::kCap30 : LogitsPostHook::kNone; DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { @@ -334,24 +344,27 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward( allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { return DISPATCH_pos_encoding_mode( PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { - cudaError_t status = BatchPrefillWithRaggedKVCacheWrapperDispatched< - GROUP_SIZE, HEAD_DIM, KV_LAYOUT, POS_ENCODING_MODE, - ALLOW_FP16_QK_REDUCTION, MASK_MODE, c_type, c_type, int32_t>( - handler_.get(), static_cast(q.data_ptr()), - static_cast(qo_indptr.data_ptr()), - static_cast(k.data_ptr()), static_cast(v.data_ptr()), - static_cast(kv_indptr.data_ptr()), - /*custom_mask=*/nullptr, /*qk_indptr=*/nullptr, - /*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr, - static_cast(o.data_ptr()), - /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, - batch_size, num_kv_heads, sm_scale, rope_scale, rope_theta, - /*stream=*/torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchPrefillWithRaggedKVCache failed with error ", - cudaGetErrorString(status)); - return true; + return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { + return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { + cudaError_t status = BatchPrefillWithRaggedKVCacheWrapperDispatched< + GROUP_SIZE, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE, + ALLOW_FP16_QK_REDUCTION, MASK_MODE, c_type, c_type, int32_t>( + handler_.get(), static_cast(q.data_ptr()), + static_cast(qo_indptr.data_ptr()), + static_cast(k.data_ptr()), + static_cast(v.data_ptr()), + static_cast(kv_indptr.data_ptr()), + /*custom_mask=*/nullptr, /*qk_indptr=*/nullptr, + /*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr, + static_cast(o.data_ptr()), + /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, + batch_size, num_kv_heads, sm_scale, rope_scale, rope_theta, + /*stream=*/torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "BatchPrefillWithRaggedKVCache failed with error ", + cudaGetErrorString(status)); + return true; + }); }); }); }); @@ -370,8 +383,8 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward( std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardCustomMask( torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v, torch::Tensor kv_indptr, torch::Tensor custom_mask, torch::Tensor qk_indptr, - unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, - float rope_theta, bool return_lse) { + unsigned int pos_encoding_mode, bool logits_cap, bool allow_fp16_qk_reduction, float sm_scale, + float rope_scale, float rope_theta, bool return_lse) { CHECK_INPUT(q); CHECK_INPUT(qo_indptr); CHECK_INPUT(k); @@ -413,6 +426,9 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardC } constexpr MaskMode MASK_MODE = MaskMode::kCustom; + const LogitsPostHook logits_post_hook = + logits_cap ? LogitsPostHook::kCap30 : LogitsPostHook::kNone; + DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { @@ -420,25 +436,27 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardC allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { return DISPATCH_pos_encoding_mode( PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { - cudaError_t status = BatchPrefillWithRaggedKVCacheWrapperDispatched< - GROUP_SIZE, HEAD_DIM, KV_LAYOUT, POS_ENCODING_MODE, - ALLOW_FP16_QK_REDUCTION, MASK_MODE, c_type, c_type, int32_t>( - handler_.get(), static_cast(q.data_ptr()), - static_cast(qo_indptr.data_ptr()), - static_cast(k.data_ptr()), static_cast(v.data_ptr()), - static_cast(kv_indptr.data_ptr()), - static_cast(custom_mask.data_ptr()), - static_cast(qk_indptr.data_ptr()), - /*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr, - static_cast(o.data_ptr()), - /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, - batch_size, num_kv_heads, sm_scale, rope_scale, rope_theta, - /*stream=*/torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchPrefillWithRaggedKVCache failed with error ", - cudaGetErrorString(status)); - return true; + return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { + return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { + cudaError_t status = BatchPrefillWithRaggedKVCacheWrapperDispatched< + GROUP_SIZE, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE, + ALLOW_FP16_QK_REDUCTION, MASK_MODE, c_type, c_type, int32_t>( + handler_.get(), static_cast(q.data_ptr()), + static_cast(qo_indptr.data_ptr()), + static_cast(k.data_ptr()), static_cast(v.data_ptr()), + static_cast(kv_indptr.data_ptr()), + static_cast(custom_mask.data_ptr()), + static_cast(qk_indptr.data_ptr()), + /*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr, + static_cast(o.data_ptr()), + /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, + batch_size, num_kv_heads, sm_scale, rope_scale, rope_theta, + /*stream=*/torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "BatchPrefillWithRaggedKVCache failed with error ", + cudaGetErrorString(status)); + return true; + }); }); }); }); diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h index e7ab8a07d..bbfebd213 100644 --- a/python/csrc/flashinfer_ops.h +++ b/python/csrc/flashinfer_ops.h @@ -23,18 +23,20 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor tmp, unsigned int pos_encoding_mode, - unsigned int layout, float sm_scale, float rope_scale, - float rope_theta); + bool logits_cap, unsigned int layout, float sm_scale, + float rope_scale, float rope_theta); std::vector single_prefill_with_kv_cache( torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor tmp, bool causal, - unsigned int layout, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, - float sm_scale, float rope_scale, float rope_theta, bool return_lse); + unsigned int layout, unsigned int pos_encoding_mode, bool logits_cap, + bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, + bool return_lse); std::vector single_prefill_with_kv_cache_custom_mask( torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor custom_mask, torch::Tensor tmp, - unsigned int layout, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, - float sm_scale, float rope_scale, float rope_theta, bool return_lse); + unsigned int layout, unsigned int pos_encoding_mode, bool logits_cap, + bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, + bool return_lse); void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, torch::Tensor append_indptr, torch::Tensor kv_data, @@ -51,8 +53,8 @@ std::vector merge_states(torch::Tensor v, torch::Tensor s); std::vector batch_decode_with_padded_kv_cache( torch::Tensor q, torch::Tensor k_padded, torch::Tensor v_padded, unsigned int layout, - unsigned int pos_encoding_mode, float sm_scale, float rope_scale, float rope_theta, - bool return_lse); + unsigned int pos_encoding_mode, bool logits_cap, float sm_scale, float rope_scale, + float rope_theta, bool return_lse); torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_samples); @@ -77,15 +79,17 @@ class BatchDecodeWithPagedKVCachePyTorchWrapper { void BeginForward(torch::Tensor workspace_buffer, torch::Tensor indptr, torch::Tensor last_page_len, unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim, unsigned int page_size, - unsigned int pos_encoding_mode, torch::Tensor empty_q_data, torch::Tensor empty_kv_data); + unsigned int pos_encoding_mode, bool logits_cap, torch::Tensor empty_q_data, + torch::Tensor empty_kv_data); void EndForward(); void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes); bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); } std::vector Forward(torch::Tensor q, torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len, - unsigned int pos_encoding_mode, float sm_scale, - float rope_scale, float rope_theta, bool return_lse); + unsigned int pos_encoding_mode, bool logits_cap, + float sm_scale, float rope_scale, float rope_theta, + bool return_lse); BatchDecodeWithPagedKVCachePyTorchWrapper( std::shared_ptr handler_ptr, 
flashinfer::QKVLayout kv_layout) : handler_(handler_ptr), kv_layout_(kv_layout) {} @@ -112,14 +116,14 @@ class BatchPrefillWithPagedKVCachePyTorchWrapper { torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len, bool causal, - unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, - float sm_scale, float rope_scale, float rope_theta, - bool return_lse); + unsigned int pos_encoding_mode, bool logits_cap, + bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, + float rope_theta, bool return_lse); std::vector ForwardCustomMask( torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len, torch::Tensor custom_mask, torch::Tensor qk_indptr, - unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale, + unsigned int pos_encoding_mode, bool logits_cap, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, bool return_lse); BatchPrefillWithPagedKVCachePyTorchWrapper(unsigned int layout, bool enable_cuda_graph) : kv_layout_(flashinfer::QKVLayout(layout)), @@ -140,14 +144,14 @@ class BatchPrefillWithRaggedKVCachePyTorchWrapper { void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes); std::vector Forward(torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v, torch::Tensor kv_indptr, bool causal, - unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, - float sm_scale, float rope_scale, float rope_theta, - bool return_lse); + unsigned int pos_encoding_mode, bool logits_cap, + bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, + float rope_theta, bool return_lse); std::vector ForwardCustomMask(torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v, torch::Tensor kv_indptr, torch::Tensor custom_mask, torch::Tensor qk_indptr, - unsigned int pos_encoding_mode, + unsigned int pos_encoding_mode, bool logits_cap, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, bool return_lse); BatchPrefillWithRaggedKVCachePyTorchWrapper(unsigned int layout, bool enable_cuda_graph) diff --git a/python/csrc/pytorch_extension_utils.h b/python/csrc/pytorch_extension_utils.h index f67fa3695..bdaa14c39 100644 --- a/python/csrc/pytorch_extension_utils.h +++ b/python/csrc/pytorch_extension_utils.h @@ -31,7 +31,6 @@ using namespace flashinfer; - #ifdef FLASHINFER_ENABLE_BF16 #define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \ [&]() -> bool { \ @@ -97,8 +96,8 @@ using namespace flashinfer; }() #endif -#if defined (FLASHINFER_ENABLE_BF16) && defined (FLASHINFER_ENABLE_FP8) -#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(pytorch_dtype, c_type, ...) \ +#if defined(FLASHINFER_ENABLE_BF16) && defined(FLASHINFER_ENABLE_FP8) +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(pytorch_dtype, c_type, ...) \ [&]() -> bool { \ switch (pytorch_dtype) { \ case at::ScalarType::Half: { \ @@ -124,8 +123,8 @@ using namespace flashinfer; return false; \ } \ }() -#elif defined (FLASHINFER_ENABLE_BF16) -#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(pytorch_dtype, c_type, ...) \ +#elif defined(FLASHINFER_ENABLE_BF16) +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(pytorch_dtype, c_type, ...) 
\ [&]() -> bool { \ switch (pytorch_dtype) { \ case at::ScalarType::Half: { \ @@ -143,8 +142,8 @@ using namespace flashinfer; return false; \ } \ }() -#elif defined (FLASHINFER_ENABLE_FP8) -#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(pytorch_dtype, c_type, ...) \ +#elif defined(FLASHINFER_ENABLE_FP8) +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(pytorch_dtype, c_type, ...) \ [&]() -> bool { \ switch (pytorch_dtype) { \ case at::ScalarType::Float8_e4m3fn: { \ @@ -163,7 +162,7 @@ using namespace flashinfer; } \ }() #else -#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(pytorch_dtype, c_type, ...) \ +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(pytorch_dtype, c_type, ...) \ [&]() -> bool { \ switch (pytorch_dtype) { \ case at::ScalarType::Half: { \ @@ -206,6 +205,10 @@ using namespace flashinfer; #define DISPATCH_head_dim(expr, const_expr, ...) \ _DISPATCH_SWITCH("head_dim", expr, _DISPATCH_CASES_head_dim(const_expr, __VA_ARGS__)) +#define DISPATCH_logits_post_hook(expr, const_expr, ...) \ + _DISPATCH_SWITCH("logits post hook", expr, \ + _DISPATCH_CASES_logits_post_hook(const_expr, __VA_ARGS__)) + #define DISPATCH_kv_layout(expr, const_expr, ...) \ _DISPATCH_SWITCH("kv layout", expr, _DISPATCH_CASES_kv_layout(const_expr, __VA_ARGS__)) diff --git a/python/csrc/single_decode.cu b/python/csrc/single_decode.cu index 6f591bad3..d0c220b4f 100644 --- a/python/csrc/single_decode.cu +++ b/python/csrc/single_decode.cu @@ -22,8 +22,8 @@ using namespace flashinfer; torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor tmp, unsigned int pos_encoding_mode, - unsigned int layout, float sm_scale, float rope_scale, - float rope_theta) { + bool logits_cap, unsigned int layout, float sm_scale, + float rope_scale, float rope_theta) { CHECK_INPUT(q); CHECK_INPUT(k); CHECK_INPUT(v); @@ -49,26 +49,30 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc auto o = torch::empty_like( q, q.options().dtype(is_float8_tensor(q) ? torch::kFloat16 : q.scalar_type())); + const LogitsPostHook logits_post_hook = + logits_cap ? 
LogitsPostHook::kCap30 : LogitsPostHook::kNone; + if (is_float8_tensor(q)) { DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(q.scalar_type(), q_type, [&] { return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(k.scalar_type(), kv_type, [&] { return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = - SingleDecodeWithKVCacheDispatched( - static_cast(q.data_ptr()), static_cast(k.data_ptr()), - static_cast(v.data_ptr()), static_cast(o.data_ptr()), - static_cast(tmp.data_ptr()), num_kv_heads, kv_len, sm_scale, - rope_scale, rope_theta, torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "SingleDecodeWithKVCache kernel launch failed, error: " + - std::string(cudaGetErrorString(status))); - return true; - }); + return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { + return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + cudaError_t status = SingleDecodeWithKVCacheDispatched< + GROUP_SIZE, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE>( + static_cast(q.data_ptr()), static_cast(k.data_ptr()), + static_cast(v.data_ptr()), static_cast(o.data_ptr()), + static_cast(tmp.data_ptr()), num_kv_heads, kv_len, sm_scale, + rope_scale, rope_theta, torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "SingleDecodeWithKVCache kernel launch failed, error: " + + std::string(cudaGetErrorString(status))); + return true; + }); + }); }); }); }); @@ -79,21 +83,22 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(k.scalar_type(), kv_type, [&] { return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = - SingleDecodeWithKVCacheDispatched( - static_cast(q.data_ptr()), static_cast(k.data_ptr()), - static_cast(v.data_ptr()), static_cast(o.data_ptr()), - static_cast(tmp.data_ptr()), num_kv_heads, kv_len, sm_scale, - rope_scale, rope_theta, torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "SingleDecodeWithKVCache kernel launch failed, error: " + - std::string(cudaGetErrorString(status))); - return true; - }); + return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { + return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + cudaError_t status = SingleDecodeWithKVCacheDispatched< + GROUP_SIZE, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE>( + static_cast(q.data_ptr()), static_cast(k.data_ptr()), + static_cast(v.data_ptr()), static_cast(o.data_ptr()), + static_cast(tmp.data_ptr()), num_kv_heads, kv_len, sm_scale, + rope_scale, rope_theta, torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "SingleDecodeWithKVCache kernel launch failed, error: " + + std::string(cudaGetErrorString(status))); + return true; + }); + }); }); }); }); diff --git a/python/csrc/single_prefill.cu b/python/csrc/single_prefill.cu index 9882b2d6c..4f292e61a 100644 --- a/python/csrc/single_prefill.cu +++ 
b/python/csrc/single_prefill.cu @@ -22,8 +22,9 @@ using namespace flashinfer; std::vector single_prefill_with_kv_cache( torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor tmp, bool causal, - unsigned int layout, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, - float sm_scale, float rope_scale, float rope_theta, bool return_lse) { + unsigned int layout, unsigned int pos_encoding_mode, bool logits_cap, + bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, + bool return_lse) { CHECK_INPUT(q); CHECK_INPUT(k); CHECK_INPUT(v); @@ -57,34 +58,37 @@ std::vector single_prefill_with_kv_cache( } const MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone; + const LogitsPostHook logits_post_hook = + logits_cap ? LogitsPostHook::kCap30 : LogitsPostHook::kNone; bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { - return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { - return DISPATCH_allow_fp16_qk_reduction( - allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = - SinglePrefillWithKVCacheDispatched( - static_cast(q.data_ptr()), - static_cast(k.data_ptr()), - static_cast(v.data_ptr()), - /*custom_mask=*/nullptr, static_cast(o.data_ptr()), - static_cast(tmp.data_ptr()), - /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, - num_kv_heads, qo_len, kv_len, sm_scale, rope_scale, rope_theta, - torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "SinglePrefillWithKVCache kernel launch failed, error: " + - std::string(cudaGetErrorString(status))); - return true; - }); - }); + return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { + return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { + return DISPATCH_allow_fp16_qk_reduction( + allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + cudaError_t status = SinglePrefillWithKVCacheDispatched< + GROUP_SIZE, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE, + ALLOW_FP16_QK_REDUCTION, MASK_MODE>( + static_cast(q.data_ptr()), + static_cast(k.data_ptr()), + static_cast(v.data_ptr()), + /*custom_mask=*/nullptr, static_cast(o.data_ptr()), + static_cast(tmp.data_ptr()), + /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, + num_kv_heads, qo_len, kv_len, sm_scale, rope_scale, rope_theta, + torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "SinglePrefillWithKVCache kernel launch failed, error: " + + std::string(cudaGetErrorString(status))); + return true; + }); + }); + }); }); }); }); @@ -100,8 +104,9 @@ std::vector single_prefill_with_kv_cache( std::vector single_prefill_with_kv_cache_custom_mask( torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor custom_mask, torch::Tensor tmp, - unsigned int layout, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, - float sm_scale, float rope_scale, float rope_theta, bool return_lse) { + unsigned int layout, unsigned int pos_encoding_mode, bool logits_cap, + bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, + bool return_lse) { CHECK_INPUT(q); CHECK_INPUT(k); CHECK_INPUT(v); @@ -137,31 +142,35 @@ std::vector single_prefill_with_kv_cache_custom_mask( } constexpr MaskMode MASK_MODE = MaskMode::kCustom; + const LogitsPostHook logits_post_hook = + logits_cap ? LogitsPostHook::kCap30 : LogitsPostHook::kNone; bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { - return DISPATCH_allow_fp16_qk_reduction( - allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = SinglePrefillWithKVCacheDispatched< - GROUP_SIZE, HEAD_DIM, KV_LAYOUT, POS_ENCODING_MODE, - ALLOW_FP16_QK_REDUCTION, MASK_MODE>( - static_cast(q.data_ptr()), static_cast(k.data_ptr()), - static_cast(v.data_ptr()), - static_cast(custom_mask.data_ptr()), - static_cast(o.data_ptr()), static_cast(tmp.data_ptr()), - /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, - num_kv_heads, qo_len, kv_len, sm_scale, rope_scale, rope_theta, - torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "SinglePrefillWithKVCache kernel launch failed, error: " + - std::string(cudaGetErrorString(status))); - return true; - }); - }); + return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { + return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { + return DISPATCH_allow_fp16_qk_reduction( + allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + cudaError_t status = SinglePrefillWithKVCacheDispatched< + GROUP_SIZE, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE, + ALLOW_FP16_QK_REDUCTION, MASK_MODE>( + static_cast(q.data_ptr()), static_cast(k.data_ptr()), + static_cast(v.data_ptr()), + static_cast(custom_mask.data_ptr()), + static_cast(o.data_ptr()), static_cast(tmp.data_ptr()), + /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, + num_kv_heads, qo_len, kv_len, sm_scale, rope_scale, rope_theta, + torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "SinglePrefillWithKVCache kernel launch failed, error: " + + std::string(cudaGetErrorString(status))); + return true; + }); + }); + }); }); }); }); diff --git a/python/flashinfer/decode.py b/python/flashinfer/decode.py index 3e2913352..4e691fb81 100644 --- a/python/flashinfer/decode.py +++ b/python/flashinfer/decode.py @@ -57,6 +57,7 @@ def single_decode_with_kv_cache( v: torch.Tensor, kv_layout: str = "NHD", pos_encoding_mode: str = "NONE", + logits_cap: bool = False, q_scale: Optional[float] = None, k_scale: Optional[float] = None, v_scale: Optional[float] = None, @@ -81,8 +82,14 @@ def single_decode_with_kv_cache( kv_layout : str The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. pos_encoding_mode : str - Whether to apply RoPE on-the-fly inside attention kernels, could be + The position encoding applied inside attention kernels, could be ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. + Defaults to ``NONE``. + logits_cap : bool + Whether to apply logits cap to attention scores. + If ``True``, the attention scores will be capped according to formula (proposed in + Grok-1): :math:`30 \times \mathrm{tanh}(x / 30)`, where :math:`x` is the input logits. + Defaults to ``False``. q_scale : Optional[float] The calibration scale of query for fp8 input, if not provided, will be set to ``1.0``. k_scale : Optional[float] @@ -143,6 +150,7 @@ def single_decode_with_kv_cache( v, tmp, PosEncodingMode[pos_encoding_mode].value, + logits_cap, TensorLayout[kv_layout].value, sm_scale, rope_scale, @@ -159,6 +167,7 @@ def batch_decode_with_padded_kv_cache( v_padded: torch.Tensor, kv_layout: str = "NHD", pos_encoding_mode: str = "NONE", + logits_cap: bool = False, q_scale: Optional[float] = None, k_scale: Optional[float] = None, v_scale: Optional[float] = None, @@ -186,8 +195,14 @@ def batch_decode_with_padded_kv_cache( kv_layout : str The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. pos_encoding_mode : str - Whether to apply RoPE on-the-fly inside attention kernels, could be + The position encoding applied inside attention kernels, could be ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. + Defaults to ``NONE``. + logits_cap : bool + Whether to apply logits cap to attention scores. + If ``True``, the attention scores will be capped according to formula (proposed in + Grok-1): :math:`30 \times \mathrm{tanh}(x / 30)`, where :math:`x` is the input logits. + Defaults to ``False``. q_scale : Optional[float] The calibration scale of query for fp8 input, if not provided, will be set to ``1.0``. k_scale : Optional[float] @@ -247,6 +262,7 @@ def batch_decode_with_padded_kv_cache( v_padded, TensorLayout[kv_layout].value, PosEncodingMode[pos_encoding_mode].value, + logits_cap, sm_scale, rope_scale, rope_theta, @@ -263,6 +279,7 @@ def batch_decode_with_padded_kv_cache_return_lse( v_padded: torch.Tensor, kv_layout: str = "NHD", pos_encoding_mode: str = "NONE", + logits_cap: bool = False, q_scale: Optional[float] = None, k_scale: Optional[float] = None, v_scale: Optional[float] = None, @@ -291,8 +308,14 @@ def batch_decode_with_padded_kv_cache_return_lse( kv_layout : str The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. 
pos_encoding_mode : str - Whether to apply RoPE on-the-fly inside attention kernels, could be + The position encoding applied inside attention kernels, could be ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. + Defaults to ``NONE``. + logits_cap : bool + Whether to apply logits cap to attention scores. + If ``True``, the attention scores will be capped according to formula (proposed in + Grok-1): :math:`30 \times \mathrm{tanh}(x / 30)`, where :math:`x` is the input logits. + Defaults to ``False``. q_scale : Optional[float] The calibration scale of query for fp8 input, if not provided, will be set to ``1.0``. k_scale : Optional[float] @@ -359,6 +382,7 @@ def batch_decode_with_padded_kv_cache_return_lse( v_padded, TensorLayout[kv_layout].value, PosEncodingMode[pos_encoding_mode].value, + logits_cap, sm_scale, rope_scale, rope_theta, @@ -539,6 +563,7 @@ def begin_forward( head_dim: int, page_size: int, pos_encoding_mode: str = "NONE", + logits_cap: bool = False, data_type: Union[str, torch.dtype] = "float16", q_data_type: Optional[Union[str, torch.dtype]] = None, ): @@ -563,13 +588,19 @@ def begin_forward( page_size : int The page size of the paged kv cache pos_encoding_mode : str - Whether to apply RoPE on-the-fly inside attention kernels, could be + The position encoding applied inside attention kernels, could be ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. + Defaults to ``NONE``. + logits_cap: bool + Whether to apply logits cap to attention scores. + If ``True``, the attention scores will be capped according to formula (proposed in + Grok-1): :math:`30 \times \mathrm{tanh}(x / 30)`, where :math:`x` is the input logits. + Defaults to ``False``. data_type : Union[str, torch.dtype] - The data type of the paged kv cache + The data type of the paged kv cache. Defaults to ``float16``. q_data_type : Optional[Union[str, torch.dtype]] The data type of the query tensor. If None, will be set to - ``data_type``. + ``data_type``. Defaults to ``None``. Note ---- @@ -609,7 +640,9 @@ def begin_forward( empty_q_data = torch.empty( 0, dtype=( - getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type + getattr(torch, q_data_type) + if isinstance(q_data_type, str) + else q_data_type ), ) empty_kv_data = torch.empty( @@ -628,6 +661,7 @@ def begin_forward( head_dim, page_size, PosEncodingMode[pos_encoding_mode].value, + logits_cap, empty_q_data, empty_kv_data, ) @@ -645,6 +679,7 @@ def forward( q: torch.Tensor, paged_kv_data: torch.Tensor, pos_encoding_mode: str = "NONE", + logits_cap: bool = False, q_scale: Optional[float] = None, k_scale: Optional[float] = None, v_scale: Optional[float] = None, @@ -665,8 +700,14 @@ def forward( ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if :attr:`kv_layout` is ``HND``. pos_encoding_mode : str - Whether to apply RoPE on-the-fly inside attention kernels, could be + The position encoding applied inside attention kernels, could be ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. + Defaults to ``NONE``. + logits_cap: bool + Whether to apply logits cap to attention scores. + If ``True``, the attention scores will be capped according to formula (proposed in + Grok-1): :math:`30 \times \mathrm{tanh}(x / 30)`, where :math:`x` is the input logits. + Defaults to ``False``. q_scale : Optional[float] The calibration scale of query for fp8 input, if not provided, will be set to ``1.0``. 
k_scale : Optional[float] @@ -707,6 +748,7 @@ def forward( self._paged_kv_indices_buf, self._paged_kv_last_page_len_buf, PosEncodingMode[pos_encoding_mode].value, + logits_cap, sm_scale, rope_scale, rope_theta, @@ -721,6 +763,7 @@ def forward_return_lse( q: torch.Tensor, paged_kv_data: torch.Tensor, pos_encoding_mode: str = "NONE", + logits_cap: bool = False, q_scale: Optional[float] = None, k_scale: Optional[float] = None, v_scale: Optional[float] = None, @@ -742,8 +785,14 @@ def forward_return_lse( ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if :attr:`kv_layout` is ``HND``. pos_encoding_mode : str - Whether to apply RoPE on-the-fly inside attention kernels, could be + The position encoding applied inside attention kernels, could be ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. + Defaults to ``NONE``. + logits_cap: bool + Whether to apply logits cap to attention scores. + If ``True``, the attention scores will be capped according to formula (proposed in + Grok-1): :math:`30 \times \mathrm{tanh}(x / 30)`, where :math:`x` is the input logits. + Defaults to ``False``. q_scale : Optional[float] The calibration scale of query for fp8 input, if not provided, will be set to ``1.0``. k_scale : Optional[float] @@ -790,6 +839,7 @@ def forward_return_lse( self._paged_kv_indices_buf, self._paged_kv_last_page_len_buf, PosEncodingMode[pos_encoding_mode].value, + logits_cap, sm_scale, rope_scale, rope_theta, diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py index 042967b84..1c05795fa 100644 --- a/python/flashinfer/prefill.py +++ b/python/flashinfer/prefill.py @@ -61,6 +61,7 @@ def single_prefill_with_kv_cache( causal: bool = False, kv_layout: str = "NHD", pos_encoding_mode: str = "NONE", + logits_cap: bool = False, allow_fp16_qk_reduction: bool = False, sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, @@ -91,8 +92,14 @@ def single_prefill_with_kv_cache( kv_layout : str The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. pos_encoding_mode : str - Whether to apply RoPE on-the-fly inside attention kernels, could be + The position encoding applied inside attention kernels, could be ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. + Default is ``NONE``. + logits_cap : bool + Whether to apply logits cap to attention scores. + If ``True``, the attention scores will be capped according to formula (proposed in + Grok-1): :math:`30 \times \mathrm{tanh}(x / 30)`, where :math:`x` is the input logits. + Defaults to ``False``. allow_fp16_qk_reduction : bool Whether to use f16 for qk reduction (faster at the cost of slight precision loss). @@ -165,6 +172,7 @@ def single_prefill_with_kv_cache( tmp, TensorLayout[kv_layout].value, PosEncodingMode[pos_encoding_mode].value, + logits_cap, allow_fp16_qk_reduction, sm_scale, rope_scale, @@ -180,6 +188,7 @@ def single_prefill_with_kv_cache( causal, TensorLayout[kv_layout].value, PosEncodingMode[pos_encoding_mode].value, + logits_cap, allow_fp16_qk_reduction, sm_scale, rope_scale, @@ -196,6 +205,7 @@ def single_prefill_with_kv_cache_return_lse( causal: bool = False, kv_layout: str = "NHD", pos_encoding_mode: str = "NONE", + logits_cap: bool = False, allow_fp16_qk_reduction: bool = False, sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, @@ -226,8 +236,14 @@ def single_prefill_with_kv_cache_return_lse( kv_layout : str The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. 
pos_encoding_mode : str - Whether to apply RoPE on-the-fly inside attention kernels, could be + The position encoding applied inside attention kernels, could be ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. + Default is ``NONE``. + logits_cap : bool + Whether to apply logits cap to attention scores. + If ``True``, the attention scores will be capped according to formula (proposed in + Grok-1): :math:`30 \times \mathrm{tanh}(x / 30)`, where :math:`x` is the input logits. + Defaults to ``False``. allow_fp16_qk_reduction : bool Whether to use f16 for qk reduction (faster at the cost of slight precision loss). @@ -318,6 +334,7 @@ def single_prefill_with_kv_cache_return_lse( tmp, TensorLayout[kv_layout].value, PosEncodingMode[pos_encoding_mode].value, + logits_cap, allow_fp16_qk_reduction, sm_scale, rope_scale, @@ -333,6 +350,7 @@ def single_prefill_with_kv_cache_return_lse( causal, TensorLayout[kv_layout].value, PosEncodingMode[pos_encoding_mode].value, + logits_cap, allow_fp16_qk_reduction, sm_scale, rope_scale, @@ -725,6 +743,7 @@ def forward( paged_kv_data: torch.Tensor, causal: bool = True, pos_encoding_mode: str = "NONE", + logits_cap: bool = False, allow_fp16_qk_reduction: bool = False, sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, @@ -747,8 +766,14 @@ def forward( This is only effective when :attr:`custom_mask` is not provided in :meth:`begin_forward`. pos_encoding_mode : str - Whether to apply RoPE on-the-fly inside attention kernels, could be + The position encoding applied inside attention kernels, could be ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. + Default is ``NONE``. + logits_cap : bool + Whether to apply logits cap to attention scores. + If ``True``, the attention scores will be capped according to formula (proposed in + Grok-1): :math:`30 \times \mathrm{tanh}(x / 30)`, where :math:`x` is the input logits. + Defaults to ``False``. allow_fp16_qk_reduction : bool Whether to use f16 for qk reduction (faster at the cost of slight precision loss). @@ -792,6 +817,7 @@ def forward( self._paged_kv_last_page_len_buf, causal, PosEncodingMode[pos_encoding_mode].value, + logits_cap, allow_fp16_qk_reduction, sm_scale, rope_scale, @@ -809,6 +835,7 @@ def forward( self._custom_mask_buf, self._qk_indptr_buf, PosEncodingMode[pos_encoding_mode].value, + logits_cap, allow_fp16_qk_reduction, sm_scale, rope_scale, @@ -822,6 +849,7 @@ def forward_return_lse( paged_kv_data: torch.Tensor, causal: bool = True, pos_encoding_mode: str = "NONE", + logits_cap: bool = False, allow_fp16_qk_reduction: bool = False, sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, @@ -842,8 +870,14 @@ def forward_return_lse( causal : bool Whether to apply causal mask to the attention matrix. pos_encoding_mode : str - Whether to apply RoPE on-the-fly inside attention kernels, could be + The position encoding applied inside attention kernels, could be ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. + Default is ``NONE``. + logits_cap : bool + Whether to apply logits cap to attention scores. + If ``True``, the attention scores will be capped according to formula (proposed in + Grok-1): :math:`30 \times \mathrm{tanh}(x / 30)`, where :math:`x` is the input logits. + Defaults to ``False``. allow_fp16_qk_reduction : bool Whether to use f16 for qk reduction (faster at the cost of slight precision loss). 
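The docstrings above describe what `logits_cap=True` enables: the Grok-1 style soft cap `30 * tanh(x / 30)` applied to pre-softmax attention scores (the `LogitsPostHook::kCap30` branch on the C++ side). A minimal reference sketch of that transform, assuming the cap is applied to the scaled q·k scores before softmax, with NHD-shaped tensors and a single KV-head group; this is only an illustration, not the fused kernel's implementation:

import torch

def soft_cap_30(x: torch.Tensor) -> torch.Tensor:
    # Grok-1 style cap from the docstrings: 30 * tanh(x / 30).
    return 30.0 * torch.tanh(x / 30.0)

def decode_attention_reference(q, k, v, sm_scale, logits_cap=False):
    # Assumed shapes (NHD layout): q [num_heads, head_dim]; k, v [kv_len, num_heads, head_dim].
    logits = torch.einsum("hd,nhd->hn", q.float(), k.float()) * sm_scale
    if logits_cap:
        logits = soft_cap_30(logits)
    probs = torch.softmax(logits, dim=-1)
    return torch.einsum("hn,nhd->hd", probs, v.float())
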
@@ -880,7 +914,7 @@ def forward_return_lse( paged_kv_data = paged_kv_data.to(torch.float16) paged_kv_data = expand_5d(paged_kv_data, self._kv_layout) - if self._custom_mask is None: + if self._custom_mask_buf is None: return self._wrapper.forward( q, self._qo_indptr_buf, @@ -890,6 +924,7 @@ def forward_return_lse( self._paged_kv_last_page_len_buf, causal, PosEncodingMode[pos_encoding_mode].value, + logits_cap, allow_fp16_qk_reduction, sm_scale, rope_scale, @@ -907,6 +942,7 @@ def forward_return_lse( self._custom_mask_buf, self._qk_indptr_buf, PosEncodingMode[pos_encoding_mode].value, + logits_cap, allow_fp16_qk_reduction, sm_scale, rope_scale, @@ -1215,6 +1251,7 @@ def forward( v: torch.Tensor, causal: bool = True, pos_encoding_mode: str = "NONE", + logits_cap: bool = False, allow_fp16_qk_reduction: bool = False, sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, @@ -1235,8 +1272,14 @@ def forward( Whether to apply causal mask to the attention matrix. This argument is ignored if ``mask`` is provided in :meth:`begin_forward`. pos_encoding_mode : str - Whether to apply RoPE on-the-fly inside attention kernels, could be + The position encoding applied inside attention kernels, could be ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. + Default is ``NONE``. + logits_cap : bool + Whether to apply logits cap to attention scores. + If ``True``, the attention scores will be capped according to formula (proposed in + Grok-1): :math:`30 \times \mathrm{tanh}(x / 30)`, where :math:`x` is the input logits. + Defaults to ``False``. allow_fp16_qk_reduction : bool Whether to use f16 for qk reduction (faster at the cost of slight precision loss). @@ -1278,6 +1321,7 @@ def forward( self._kv_indptr_buf, causal, PosEncodingMode[pos_encoding_mode].value, + logits_cap, allow_fp16_qk_reduction, sm_scale, rope_scale, @@ -1294,6 +1338,7 @@ def forward( self._custom_mask_buf, self._qk_indptr_buf, PosEncodingMode[pos_encoding_mode].value, + logits_cap, allow_fp16_qk_reduction, sm_scale, rope_scale, @@ -1308,6 +1353,7 @@ def forward_return_lse( v: torch.Tensor, causal: bool = True, pos_encoding_mode: str = "NONE", + logits_cap: bool = False, allow_fp16_qk_reduction: bool = False, sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, @@ -1328,8 +1374,14 @@ def forward_return_lse( Whether to apply causal mask to the attention matrix. This argument is ignored if ``mask`` is provided in :meth:`begin_forward`. pos_encoding_mode : str - Whether to apply RoPE on-the-fly inside attention kernels, could be + The position encoding applied inside attention kernels, could be ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. + Default is ``NONE``. + logits_cap : bool + Whether to apply logits cap to attention scores. + If ``True``, the attention scores will be capped according to formula (proposed in + Grok-1): :math:`30 \times \mathrm{tanh}(x / 30)`, where :math:`x` is the input logits. + Defaults to ``False``. allow_fp16_qk_reduction : bool Whether to use f16 for qk reduction (faster at the cost of slight precision loss). 
@@ -1373,6 +1425,7 @@ def forward_return_lse( self._kv_indptr_buf, causal, PosEncodingMode[pos_encoding_mode].value, + logits_cap, allow_fp16_qk_reduction, sm_scale, rope_scale, @@ -1389,6 +1442,7 @@ def forward_return_lse( self._custom_mask_buf, self._qk_indptr_buf, PosEncodingMode[pos_encoding_mode].value, + logits_cap, allow_fp16_qk_reduction, sm_scale, rope_scale, diff --git a/python/generate_batch_padded_decode_inst.py b/python/generate_batch_padded_decode_inst.py index fa5fb973d..63b6df17f 100644 --- a/python/generate_batch_padded_decode_inst.py +++ b/python/generate_batch_padded_decode_inst.py @@ -20,6 +20,7 @@ kv_layout_literal, pos_encoding_mode_literal, dtype_literal, + logits_hook_literal, ) from pathlib import Path @@ -27,6 +28,7 @@ def get_cu_file_str( group_size, head_dim, + logits_hook, kv_layout, pos_encoding_mode, dtype_q, @@ -37,7 +39,7 @@ def get_cu_file_str( namespace flashinfer {{ -template cudaError_t BatchDecodeWithPaddedKVCacheDispatched<{group_size}, {head_dim}, {kv_layout}, {pos_encoding_mode}, {dtype_q}, {dtype_kv}, {dtype_out}>( +template cudaError_t BatchDecodeWithPaddedKVCacheDispatched<{group_size}, {head_dim}, {logits_hook}, {kv_layout}, {pos_encoding_mode}, {dtype_q}, {dtype_kv}, {dtype_out}>( {dtype_q}* q, {dtype_kv}* k, {dtype_kv}* v, {dtype_out}* o, {dtype_out}* tmp, float* lse, uint32_t batch_size, uint32_t padded_kv_len, uint32_t num_qo_heads, @@ -46,6 +48,7 @@ def get_cu_file_str( }} """.format( + logits_hook=logits_hook_literal[int(logits_hook)], kv_layout=kv_layout_literal[int(kv_layout)], group_size=group_size, head_dim=head_dim, @@ -59,7 +62,7 @@ def get_cu_file_str( if __name__ == "__main__": pattern = ( - r"batch_padded_decode_group_([0-9]+)_head_([0-9]+)_layout_([0-9]+)_posenc_([0-9]+)_" + r"batch_padded_decode_group_([0-9]+)_head_([0-9]+)_logitshook_([0-9]+)_layout_([0-9]+)_posenc_([0-9]+)_" r"dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)\.cu" ) diff --git a/python/generate_batch_paged_decode_inst.py b/python/generate_batch_paged_decode_inst.py index bd7d26524..5f4293f39 100644 --- a/python/generate_batch_paged_decode_inst.py +++ b/python/generate_batch_paged_decode_inst.py @@ -21,12 +21,21 @@ pos_encoding_mode_literal, dtype_literal, idtype_literal, + logits_hook_literal, ) from pathlib import Path def get_cu_file_str( - group_size, head_dim, kv_layout, pos_encoding_mode, dtype_q, dtype_kv, dtype_out, idtype + group_size, + head_dim, + logits_hook, + kv_layout, + pos_encoding_mode, + dtype_q, + dtype_kv, + dtype_out, + idtype, ): content = """#include @@ -34,7 +43,7 @@ def get_cu_file_str( constexpr PageStorage page_storage = PageStorage::kIndices; -template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{group_size}, {head_dim}, page_storage, {kv_layout}, {pos_encoding_mode}, {dtype_q}, {dtype_kv}, {dtype_out}, {idtype}>( +template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{group_size}, {head_dim}, page_storage, {logits_hook}, {kv_layout}, {pos_encoding_mode}, {dtype_q}, {dtype_kv}, {dtype_out}, {idtype}>( {dtype_q}* q, {idtype}* q_offset, paged_kv_t paged_kv, kv_partition_info_t<{idtype}> kv_partition_info, @@ -45,6 +54,7 @@ def get_cu_file_str( }} """.format( + logits_hook=logits_hook_literal[int(logits_hook)], kv_layout=kv_layout_literal[int(kv_layout)], group_size=group_size, head_dim=head_dim, @@ -59,7 +69,7 @@ def get_cu_file_str( if __name__ == "__main__": pattern = ( - r"batch_paged_decode_group_([0-9]+)_head_([0-9]+)_layout_([0-9]+)_posenc_([0-9]+)_" + 
r"batch_paged_decode_group_([0-9]+)_head_([0-9]+)_logitshook_([0-9]+)_layout_([0-9]+)_posenc_([0-9]+)_" r"dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu" ) diff --git a/python/generate_batch_paged_prefill_inst.py b/python/generate_batch_paged_prefill_inst.py index 5d301dd20..491af6bc4 100644 --- a/python/generate_batch_paged_prefill_inst.py +++ b/python/generate_batch_paged_prefill_inst.py @@ -23,6 +23,7 @@ pos_encoding_mode_literal, dtype_literal, idtype_literal, + logits_hook_literal, ) from pathlib import Path @@ -31,6 +32,7 @@ def get_cu_file_str( group_size, page_size, head_dim, + logits_hook, kv_layout, pos_encoding_mode, allow_fp16_qk_reduction, @@ -42,7 +44,7 @@ def get_cu_file_str( num_frags_x_choices = [1, 2] insts = "\n".join( [ - """template cudaError_t BatchPrefillWithPagedKVCacheDispatched( + """template cudaError_t BatchPrefillWithPagedKVCacheDispatched( {dtype_in}* q, {idtype}* request_indices, {idtype}* tile_indices, {idtype}* qo_indptr, {idtype}* q_offset, paged_kv_t paged_kv, @@ -52,6 +54,7 @@ def get_cu_file_str( float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); """.format( + logits_hook=logits_hook_literal[int(logits_hook)], kv_layout=kv_layout_literal[int(kv_layout)], num_frags_x=num_frags_x, page_size=page_size, @@ -82,7 +85,7 @@ def get_cu_file_str( if __name__ == "__main__": pattern = ( - r"batch_paged_prefill_group_([0-9]+)_page_([0-9]+)_head_([0-9]+)_layout_([0-9]+)_posenc_([0-9]+)_" + r"batch_paged_prefill_group_([0-9]+)_page_([0-9]+)_head_([0-9]+)_logitshook_([0-9]+)_layout_([0-9]+)_posenc_([0-9]+)_" r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypein_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu" ) compiled_pattern = re.compile(pattern) diff --git a/python/generate_batch_ragged_prefill_inst.py b/python/generate_batch_ragged_prefill_inst.py index 7eeab91ee..a09a8a7a5 100644 --- a/python/generate_batch_ragged_prefill_inst.py +++ b/python/generate_batch_ragged_prefill_inst.py @@ -22,6 +22,7 @@ pos_encoding_mode_literal, dtype_literal, idtype_literal, + logits_hook_literal, ) from pathlib import Path @@ -29,6 +30,7 @@ def get_cu_file_str( group_size, head_dim, + logits_hook, kv_layout, pos_encoding_mode, allow_fp16_qk_reduction, @@ -40,7 +42,7 @@ def get_cu_file_str( num_frags_x_choices = [1, 2] insts = "\n".join( [ - """template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{num_frags_x}, {group_size}, {head_dim}, {kv_layout}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, {dtype_in}, {dtype_out}, {idtype}>( + """template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{num_frags_x}, {group_size}, {head_dim}, {logits_hook}, {kv_layout}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, {dtype_in}, {dtype_out}, {idtype}>( {dtype_in}* q, {idtype}* request_indices, {idtype}* tile_indices, {idtype}* qo_indptr, {dtype_in}* k, {dtype_in}* v, {idtype}* kv_indptr, float* custom_mask, {idtype}* qk_indptr, @@ -51,6 +53,7 @@ def get_cu_file_str( float rope_theta, cudaStream_t stream); """.format( num_frags_x=num_frags_x, + logits_hook=logits_hook_literal[int(logits_hook)], kv_layout=kv_layout_literal[int(kv_layout)], group_size=group_size, head_dim=head_dim, @@ -78,7 +81,7 @@ def get_cu_file_str( if __name__ == "__main__": pattern = ( - r"batch_ragged_prefill_group_([0-9]+)_head_([0-9]+)_layout_([0-9]+)_posenc_([0-9]+)_" + r"batch_ragged_prefill_group_([0-9]+)_head_([0-9]+)_logitshook_([0-9]+)_layout_([0-9]+)_posenc_([0-9]+)_" 
r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypein_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu" ) compiled_pattern = re.compile(pattern) diff --git a/python/generate_dispatch_inc.py b/python/generate_dispatch_inc.py index 03ec819da..5e74d0d75 100644 --- a/python/generate_dispatch_inc.py +++ b/python/generate_dispatch_inc.py @@ -21,6 +21,7 @@ pos_encoding_mode_literal, bool_literal, mask_mode_literal, + logits_hook_literal, ) @@ -57,6 +58,19 @@ def get_dispatch_inc_str(args: argparse.Namespace) -> str: dispatch_page_sizes_str = f"""#define _DISPATCH_CASES_page_size(case_var, ...) \\ {dispatch_page_sizes_entries} // EOL +""" + # logits post hooks + dispatch_logits_post_hooks_entries = "\n".join( + [ + " _DISPATCH_CASE({}, case_var, __VA_ARGS__) \\".format( + logits_hook_literal[_] + ) + for _ in args.logits_post_hooks + ] + ) + dispatch_logits_post_hooks_str = f"""#define _DISPATCH_CASES_logits_post_hook(case_var, ...) \\ +{dispatch_logits_post_hooks_entries} +// EOL """ # kv layouts dispatch_kv_layouts_entries = "\n".join( @@ -114,6 +128,7 @@ def get_dispatch_inc_str(args: argparse.Namespace) -> str: dispatch_head_dims_str, dispatch_group_sizes_str, dispatch_page_sizes_str, + dispatch_logits_post_hooks_str, dispatch_kv_layouts_str, dispatch_pos_encoding_modes_str, dispatch_allow_fp16_qk_reductions_str, @@ -140,6 +155,13 @@ def get_dispatch_inc_str(args: argparse.Namespace) -> str: parser.add_argument( "--group_sizes", type=int, required=True, nargs="+", help="Group sizes" ) + parser.add_argument( + "--logits_post_hooks", + type=int, + required=True, + nargs="+", + help="Logit post hooks", + ) parser.add_argument( "--kv_layouts", type=int, required=True, nargs="+", help="KV layouts" ) diff --git a/python/generate_single_decode_inst.py b/python/generate_single_decode_inst.py index 8fc36218c..bda23e6fe 100644 --- a/python/generate_single_decode_inst.py +++ b/python/generate_single_decode_inst.py @@ -16,18 +16,30 @@ import sys import re -from literal_map import kv_layout_literal, pos_encoding_mode_literal, dtype_literal +from literal_map import ( + kv_layout_literal, + pos_encoding_mode_literal, + dtype_literal, + logits_hook_literal, +) from pathlib import Path def get_cu_file_str( - group_size, head_dim, kv_layout, pos_encoding_mode, dtype_q, dtype_kv, dtype_out + group_size, + head_dim, + logits_hook, + kv_layout, + pos_encoding_mode, + dtype_q, + dtype_kv, + dtype_out, ): content = """#include namespace flashinfer {{ -template cudaError_t SingleDecodeWithKVCacheDispatched<{group_size}, {head_dim}, {kv_layout}, {pos_encoding_mode}, {dtype_q}, {dtype_kv}, {dtype_out}>( +template cudaError_t SingleDecodeWithKVCacheDispatched<{group_size}, {head_dim}, {logits_hook}, {kv_layout}, {pos_encoding_mode}, {dtype_q}, {dtype_kv}, {dtype_out}>( {dtype_q}* q, {dtype_kv}* k, {dtype_kv}* v, {dtype_out}* o, {dtype_out}* tmp, uint32_t num_kv_heads, uint32_t seq_len, float sm_scale, float rope_scale, @@ -35,6 +47,7 @@ def get_cu_file_str( }} """.format( + logits_hook=logits_hook_literal[int(logits_hook)], kv_layout=kv_layout_literal[int(kv_layout)], group_size=group_size, head_dim=head_dim, @@ -48,7 +61,7 @@ def get_cu_file_str( if __name__ == "__main__": pattern = ( - r"single_decode_group_([0-9]+)_head_([0-9]+)_layout_([0-9]+)_posenc_([0-9]+)_" + r"single_decode_group_([0-9]+)_head_([0-9]+)_logitshook_([0-9]+)_layout_([0-9]+)_posenc_([0-9]+)_" r"dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)\.cu" ) diff --git a/python/generate_single_prefill_inst.py b/python/generate_single_prefill_inst.py 
index 7ffad989e..f55e15b02 100644 --- a/python/generate_single_prefill_inst.py +++ b/python/generate_single_prefill_inst.py @@ -21,6 +21,7 @@ pos_encoding_mode_literal, dtype_literal, mask_mode_literal, + logits_hook_literal, ) from pathlib import Path @@ -28,6 +29,7 @@ def get_cu_file_str( group_size, head_dim, + logits_hook, kv_layout, pos_encoding_mode, allow_fp16_qk_reduction, @@ -40,7 +42,7 @@ def get_cu_file_str( namespace flashinfer {{ -template cudaError_t SinglePrefillWithKVCacheDispatched<{group_size}, {head_dim}, {kv_layout}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, {dtype_in}, {dtype_out}>( +template cudaError_t SinglePrefillWithKVCacheDispatched<{group_size}, {head_dim}, {logits_hook}, {kv_layout}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, {dtype_in}, {dtype_out}>( {dtype_in}* q, {dtype_in}* k, {dtype_in}* v, float* custom_mask, {dtype_out}* o, float* tmp, float* lse, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, float sm_scale, float rope_scale, @@ -48,6 +50,7 @@ def get_cu_file_str( }} """.format( + logits_hook=logits_hook_literal[int(logits_hook)], kv_layout=kv_layout_literal[int(kv_layout)], group_size=group_size, head_dim=head_dim, @@ -62,7 +65,7 @@ def get_cu_file_str( if __name__ == "__main__": pattern = ( - r"single_prefill_group_([0-9]+)_head_([0-9]+)_layout_([0-9]+)_posenc_([0-9]+)_" + r"single_prefill_group_([0-9]+)_head_([0-9]+)_logitshook_([0-9]+)_layout_([0-9]+)_posenc_([0-9]+)_" r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypein_([a-z0-9]+)_dtypeout_([a-z0-9]+)\.cu" ) diff --git a/python/literal_map.py b/python/literal_map.py index bf4ac679d..7a8c51ac6 100644 --- a/python/literal_map.py +++ b/python/literal_map.py @@ -20,6 +20,11 @@ 2: "MaskMode::kCustom", } +logits_hook_literal = { + 0: "LogitsPostHook::kNone", + 1: "LogitsPostHook::kCap30", +} + kv_layout_literal = { 0: "QKVLayout::kNHD", 1: "QKVLayout::kHND", diff --git a/python/setup.py b/python/setup.py index c758f394d..139e8dedf 100644 --- a/python/setup.py +++ b/python/setup.py @@ -63,7 +63,8 @@ def get_instantiation_cu() -> List[str]: prefix = "csrc/generated" (root / prefix).mkdir(parents=True, exist_ok=True) - group_sizes = os.environ.get("FLASHINFER_GROUP_SIZES", "1,4,6,8").split(",") + group_sizes = os.environ.get("FLASHINFER_GROUP_SIZES", "1,4,8").split(",") + logits_hooks = os.environ.get("FLASHINFER_LOGITS_POST_HOOKS", "0,1").split(",") page_sizes = os.environ.get("FLASHINFER_PAGE_SIZES", "1,16,32").split(",") head_dims = os.environ.get("FLASHINFER_HEAD_DIMS", "64,128,256").split(",") kv_layouts = os.environ.get("FLASHINFER_KV_LAYOUTS", "0,1").split(",") @@ -83,6 +84,7 @@ def get_instantiation_cu() -> List[str]: group_sizes=map(int, group_sizes), page_sizes=map(int, page_sizes), head_dims=map(int, head_dims), + logits_post_hooks=map(int, logits_hooks), kv_layouts=map(int, kv_layouts), pos_encoding_modes=map(int, pos_encoding_modes), allow_fp16_qk_reductions=map(int, allow_fp16_qk_reduction_options), @@ -106,21 +108,24 @@ def get_instantiation_cu() -> List[str]: for ( group_size, head_dim, + logits_hook, kv_layout, pos_encoding_mode, ) in itertools.product( group_sizes, head_dims, + logits_hooks, kv_layouts, pos_encoding_modes, ): for dtype_q, dtype_kv in itertools.product(decode_dtypes, decode_dtypes): dtype_out = dtype_q if dtype_q not in fp8_dtypes else "f16" - fname = f"single_decode_group_{group_size}_head_{head_dim}_layout_{kv_layout}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}.cu" + fname = 
f"single_decode_group_{group_size}_head_{head_dim}_logitshook_{logits_hook}_layout_{kv_layout}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}.cu" files.append(prefix + "/" + fname) content = generate_single_decode_inst.get_cu_file_str( group_size, head_dim, + logits_hook, kv_layout, pos_encoding_mode, dtype_q, @@ -133,22 +138,25 @@ def get_instantiation_cu() -> List[str]: for ( group_size, head_dim, + logits_hook, kv_layout, pos_encoding_mode, ) in itertools.product( group_sizes, head_dims, + logits_hooks, kv_layouts, pos_encoding_modes, ): for idtype in idtypes: for dtype_q, dtype_kv in itertools.product(decode_dtypes, decode_dtypes): dtype_out = dtype_q if dtype_q not in fp8_dtypes else "f16" - fname = f"batch_paged_decode_group_{group_size}_head_{head_dim}_layout_{kv_layout}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}_idtype_{idtype}.cu" + fname = f"batch_paged_decode_group_{group_size}_head_{head_dim}_logitshook_{logits_hook}_layout_{kv_layout}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}_idtype_{idtype}.cu" files.append(prefix + "/" + fname) content = generate_batch_paged_decode_inst.get_cu_file_str( group_size, head_dim, + logits_hook, kv_layout, pos_encoding_mode, dtype_q, @@ -160,11 +168,12 @@ def get_instantiation_cu() -> List[str]: for dtype_q, dtype_kv in itertools.product(decode_dtypes, decode_dtypes): dtype_out = dtype_q if dtype_q not in fp8_dtypes else "f16" - fname = f"batch_padded_decode_group_{group_size}_head_{head_dim}_layout_{kv_layout}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}.cu" + fname = f"batch_padded_decode_group_{group_size}_head_{head_dim}_logitshook_{logits_hook}_layout_{kv_layout}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}.cu" files.append(prefix + "/" + fname) content = generate_batch_padded_decode_inst.get_cu_file_str( group_size, head_dim, + logits_hook, kv_layout, pos_encoding_mode, dtype_q, @@ -177,6 +186,7 @@ def get_instantiation_cu() -> List[str]: for ( group_size, head_dim, + logits_hook, kv_layout, pos_encoding_mode, allow_fp16_qk_reduction, @@ -184,17 +194,19 @@ def get_instantiation_cu() -> List[str]: ) in itertools.product( group_sizes, head_dims, + logits_hooks, kv_layouts, pos_encoding_modes, allow_fp16_qk_reduction_options, mask_modes, ): for dtype in prefill_dtypes: - fname = f"single_prefill_group_{group_size}_head_{head_dim}_layout_{kv_layout}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypein_{dtype}_dtypeout_{dtype}.cu" + fname = f"single_prefill_group_{group_size}_head_{head_dim}_logitshook_{logits_hook}_layout_{kv_layout}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypein_{dtype}_dtypeout_{dtype}.cu" files.append(prefix + "/" + fname) content = generate_single_prefill_inst.get_cu_file_str( group_size, head_dim, + logits_hook, kv_layout, pos_encoding_mode, allow_fp16_qk_reduction, @@ -209,6 +221,7 @@ def get_instantiation_cu() -> List[str]: group_size, page_size, head_dim, + logits_hook, kv_layout, pos_encoding_mode, allow_fp16_qk_reduction, @@ -218,6 +231,7 @@ def get_instantiation_cu() -> List[str]: group_sizes, page_sizes, head_dims, + logits_hooks, kv_layouts, pos_encoding_modes, allow_fp16_qk_reduction_options, @@ -225,12 +239,13 @@ def get_instantiation_cu() -> List[str]: idtypes, ): for dtype in prefill_dtypes: - fname = 
f"batch_paged_prefill_group_{group_size}_page_{page_size}_head_{head_dim}_layout_{kv_layout}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypein_{dtype}_dtypeout_{dtype}_idtype_{idtype}.cu" + fname = f"batch_paged_prefill_group_{group_size}_page_{page_size}_head_{head_dim}_logitshook_{logits_hook}_layout_{kv_layout}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypein_{dtype}_dtypeout_{dtype}_idtype_{idtype}.cu" files.append(prefix + "/" + fname) content = generate_batch_paged_prefill_inst.get_cu_file_str( group_size, page_size, head_dim, + logits_hook, kv_layout, pos_encoding_mode, allow_fp16_qk_reduction, @@ -245,6 +260,7 @@ def get_instantiation_cu() -> List[str]: for ( group_size, head_dim, + logits_hook, kv_layout, pos_encoding_mode, allow_fp16_qk_reduction, @@ -253,6 +269,7 @@ def get_instantiation_cu() -> List[str]: ) in itertools.product( group_sizes, head_dims, + logits_hooks, kv_layouts, pos_encoding_modes, allow_fp16_qk_reduction_options, @@ -260,11 +277,12 @@ def get_instantiation_cu() -> List[str]: idtypes, ): for dtype in prefill_dtypes: - fname = f"batch_ragged_prefill_group_{group_size}_head_{head_dim}_layout_{kv_layout}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypein_{dtype}_dtypeout_{dtype}_idtype_{idtype}.cu" + fname = f"batch_ragged_prefill_group_{group_size}_head_{head_dim}_logitshook_{logits_hook}_layout_{kv_layout}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypein_{dtype}_dtypeout_{dtype}_idtype_{idtype}.cu" files.append(prefix + "/" + fname) content = generate_batch_ragged_prefill_inst.get_cu_file_str( group_size, head_dim, + logits_hook, kv_layout, pos_encoding_mode, allow_fp16_qk_reduction, diff --git a/python/tests/test_alibi.py b/python/tests/test_alibi.py index d96d21268..06387df13 100644 --- a/python/tests/test_alibi.py +++ b/python/tests/test_alibi.py @@ -62,7 +62,7 @@ def test_single_prefill_alibi( v = torch.randn(kv_len, num_heads, head_dim).to(0).half() o = flashinfer.single_prefill_with_kv_cache( - q, k, v, causal, pos_encoding_mode="ALIBI" + q, k, v, causal=causal, pos_encoding_mode="ALIBI" ) mask = torch.ones(q_len, kv_len, dtype=torch.bool).to(0) if causal: diff --git a/python/tests/test_batch_decode_kernels.py b/python/tests/test_batch_decode_kernels.py index d7dc92a0b..c020e6571 100644 --- a/python/tests/test_batch_decode_kernels.py +++ b/python/tests/test_batch_decode_kernels.py @@ -72,8 +72,8 @@ def test_batch_decode_with_paged_kv_cache( head_dim, page_size, "NONE", - kv_dtype, - q_dtype, + data_type=kv_dtype, + q_data_type=q_dtype, ) o = wrapper.forward(q, kv_data.to(kv_dtype), pos_encoding_mode=pos_encoding_mode) @@ -182,8 +182,8 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache( head_dim, page_size, "NONE", - kv_dtype, - q_dtype, + data_type=kv_dtype, + q_data_type=q_dtype, ) # warmup s = torch.cuda.Stream() @@ -214,8 +214,8 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache( head_dim, page_size, "NONE", - kv_dtype, - q_dtype, + data_type=kv_dtype, + q_data_type=q_dtype, ) g.replay() @@ -235,8 +235,8 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache( head_dim, page_size, "NONE", - kv_dtype, - q_dtype, + data_type=kv_dtype, + q_data_type=q_dtype, ) g.replay() @@ -307,4 +307,4 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache( ) test_cuda_graph_batch_decode_with_paged_kv_cache( 12, 54, 8, 8, 8, 128, "HND", "NONE", torch.float8_e5m2, torch.float16 - ) \ No newline at end of file 
+ ) diff --git a/python/tests/test_logits_cap.py b/python/tests/test_logits_cap.py new file mode 100644 index 000000000..6ae18a849 --- /dev/null +++ b/python/tests/test_logits_cap.py @@ -0,0 +1,76 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import math +import torch +import numpy +import pytest +import flashinfer + + +def attention_logits_cap_torch(q, k, v): + q_len, num_heads, head_dim = q.shape + kv_len = k.shape[0] + scores = torch.einsum("qhd,khd->qkh", q.float(), k.float()) + scores *= 1.0 / math.sqrt(head_dim) + scores = 30 * torch.tanh(scores / 30) + attn = torch.softmax(scores, dim=1) + return torch.einsum("ovh,vhd->ohd", attn, v.float()).to(q) + + +@pytest.mark.parametrize("seq_len", [1, 9, 81, 729, 33001]) +@pytest.mark.parametrize("num_heads", [4, 8, 32]) +@pytest.mark.parametrize("head_dim", [128, 256]) +def test_single_decode_logits_cap( + seq_len, + num_heads, + head_dim, +): + q = torch.randn(num_heads, head_dim).to(0).half() + k = torch.randn(seq_len, num_heads, head_dim).to(0).half() + v = torch.randn(seq_len, num_heads, head_dim).to(0).half() + + o = flashinfer.single_decode_with_kv_cache(q, k, v, logits_cap=True) + o_ref = attention_logits_cap_torch(q.unsqueeze(0), k, v).squeeze(0) + numpy.testing.assert_allclose( + o.cpu().numpy(), o_ref.cpu().numpy(), rtol=1e-3, atol=1e-3 + ) + + +@pytest.mark.parametrize("q_len", [1, 17, 81, 987]) +@pytest.mark.parametrize("kv_len", [1, 17, 81, 987, 31111]) +@pytest.mark.parametrize("num_heads", [4, 8, 32]) +@pytest.mark.parametrize("head_dim", [128, 256]) +def test_single_prefill_logits_cap( + q_len, + kv_len, + num_heads, + head_dim, +): + q = torch.randn(q_len, num_heads, head_dim).to(0).half() + k = torch.randn(kv_len, num_heads, head_dim).to(0).half() + v = torch.randn(kv_len, num_heads, head_dim).to(0).half() + + o = flashinfer.single_prefill_with_kv_cache(q, k, v, logits_cap=True) + o_ref = attention_logits_cap_torch(q, k, v) + numpy.testing.assert_allclose( + o.cpu().numpy(), o_ref.cpu().numpy(), rtol=1e-2, atol=1e-2 + ) + + +if __name__ == "__main__": + test_single_decode_logits_cap(9, 32, 128) + test_single_prefill_logits_cap(1, 64, 1, 128) diff --git a/python/tests/test_shared_prefix_kernels.py b/python/tests/test_shared_prefix_kernels.py index b0d149fea..ba4e752fb 100644 --- a/python/tests/test_shared_prefix_kernels.py +++ b/python/tests/test_shared_prefix_kernels.py @@ -227,6 +227,7 @@ def test_batch_prefill_with_shared_prefix_paged_kv_cache( num_heads, num_heads, head_dim, + page_size, ) o_baseline = baseline_wrapper.forward(q, kv_data, causal=causal) @@ -242,6 +243,7 @@ def test_batch_prefill_with_shared_prefix_paged_kv_cache( num_heads, num_heads, head_dim, + page_size, ) o_cascade = cascade_wrapper.forward(q, k_shared, v_shared, kv_data, causal=causal) diff --git a/src/bench_batch_decode.cu b/src/bench_batch_decode.cu index 9aa1b9199..4c2f746dc 100644 --- a/src/bench_batch_decode.cu +++ b/src/bench_batch_decode.cu @@ -79,7 +79,7 @@ void bench_flashinfer_batch_decode(nvbench::state& 
state) { head_dim, page_size, pos_encoding_mode); state.exec([&](nvbench::launch&) { cudaError_t status = - BatchDecodeWithPagedKVCacheWrapper( + BatchDecodeWithPagedKVCacheWrapper( &handler, thrust::raw_pointer_cast(q.data()), /*q_offset=*/nullptr, paged_kv, thrust::raw_pointer_cast(o.data()), /*lse=*/nullptr, num_qo_heads, pos_encoding_mode); if (status != cudaSuccess) { @@ -89,7 +89,7 @@ void bench_flashinfer_batch_decode(nvbench::state& state) { } else { state.exec([&](nvbench::launch&) { cudaError_t status = - BatchDecodeWithPagedKVCacheNoSplitKV( + BatchDecodeWithPagedKVCacheNoSplitKV( thrust::raw_pointer_cast(q.data()), /*q_offset=*/nullptr, paged_kv, kv_partition_info_t(), thrust::raw_pointer_cast(o.data()), /*lse=*/nullptr, num_qo_heads, pos_encoding_mode); diff --git a/src/bench_cascade.cu b/src/bench_cascade.cu index ec09cdb53..b32b4beed 100644 --- a/src/bench_cascade.cu +++ b/src/bench_cascade.cu @@ -130,7 +130,7 @@ void bench_two_level_single_prefix_cascade_decode(nvbench::state& state) { std::string(cudaGetErrorString(status))); } - status = BatchDecodeWithPagedKVCacheWrapper( + status = BatchDecodeWithPagedKVCacheWrapper( &cascade_handler, thrust::raw_pointer_cast(q_d.data()), /*q_offset=*/nullptr, paged_kv_casacde_d, thrust::raw_pointer_cast(o_cascade_1_d.data()), /*lse=*/thrust::raw_pointer_cast(lse_cascade_1_d.data()), num_qo_heads, @@ -166,7 +166,7 @@ void bench_two_level_single_prefix_cascade_decode(nvbench::state& state) { BatchDecodeHandler baseline_handler; size_t workspace_size_in_bytes = 32 * 1024 * 1024; thrust::device_vector buffer(workspace_size_in_bytes); - BatchDecodeHandlerBeginForward( + BatchDecodeHandlerBeginForward( &baseline_handler, (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, kv_indptr_combined_h.data(), kv_last_page_len_combined_h.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, PosEncodingMode::kNone); @@ -174,7 +174,7 @@ void bench_two_level_single_prefix_cascade_decode(nvbench::state& state) { state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) { timer.start(); cudaError_t status = - BatchDecodeWithPagedKVCacheWrapper( + BatchDecodeWithPagedKVCacheWrapper( &baseline_handler, thrust::raw_pointer_cast(q_d.data()), /*q_offset=*/nullptr, paged_kv_baseline_d, thrust::raw_pointer_cast(o_baseline_d.data()), diff --git a/src/flashinfer_ops.cuh b/src/flashinfer_ops.cuh index 84294e003..51b2b8025 100644 --- a/src/flashinfer_ops.cuh +++ b/src/flashinfer_ops.cuh @@ -17,6 +17,7 @@ #include #include +#include "flashinfer/attention/logits_post_hook.cuh" #include "utils.h" namespace flashinfer { @@ -69,8 +70,8 @@ cudaError_t SinglePrefillWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOu pos_encoding_mode, POS_ENCODING_MODE, {DISPATCH_kv_layout(kv_layout, KV_LAYOUT, { return SinglePrefillWithKVCacheDispatched< - GROUP_SIZE, HEAD_DIM, KV_LAYOUT, POS_ENCODING_MODE, - ALLOW_FP16_QK_REDUCTION, MASK_MODE>( + GROUP_SIZE, HEAD_DIM, LogitsPostHook::kNone, KV_LAYOUT, + POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION, MASK_MODE>( q, k, v, /*custom_mask=*/nullptr, o, tmp, lse, num_kv_heads, qo_len, kv_len, sm_scale, rope_scale, rope_theta, stream); @@ -102,8 +103,9 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapper( {DISPATCH_allow_fp16_qk_reduction( allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, { return BatchPrefillWithRaggedKVCacheWrapperDispatched< - GROUP_SIZE, HEAD_DIM, KV_LAYOUT, pos_encoding_mode, - ALLOW_FP16_QK_REDUCTION, MASK_MODE, DTypeIn, DTypeOut, IdType>( + GROUP_SIZE, HEAD_DIM, 
LogitsPostHook::kNone, KV_LAYOUT, + pos_encoding_mode, ALLOW_FP16_QK_REDUCTION, MASK_MODE, DTypeIn, + DTypeOut, IdType>( handler, q, qo_indptr, k, v, kv_indptr, /*custom_mask=*/nullptr, /*qk_indptr=*/nullptr, q_offset, k_rope_pos_offset, o, lse, batch_size, num_kv_heads, sm_scale, rope_scale, rope_theta, stream); @@ -111,11 +113,11 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapper( return cudaSuccess; } -template cudaError_t BatchPrefillWithPagedKVCacheWrapper( BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_offset, - paged_kv_t paged_kv, DTypeOut* o, float* lse, + paged_kv_t paged_kv, DTypeOut* o, float* lse, uint32_t num_qo_heads, bool causal = true, PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, bool allow_fp16_qk_reduction = false, std::optional maybe_sm_scale = std::nullopt, @@ -128,20 +130,21 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapper( num_qo_heads / num_kv_heads, GROUP_SIZE, {DISPATCH_head_dim( head_dim, HEAD_DIM, - {DISPATCH_mask_mode(mask_mode, MASK_MODE, - {DISPATCH_pos_encoding_mode( - pos_encoding_mode, pos_encoding_mode, - {DISPATCH_allow_fp16_qk_reduction( - allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, - {DISPATCH_page_size(paged_kv.page_size, PAGE_SIZE, { - return BatchPrefillWithPagedKVCacheWrapperDispatched< - page_storage, kv_layout, PAGE_SIZE, GROUP_SIZE, - HEAD_DIM, pos_encoding_mode, ALLOW_FP16_QK_REDUCTION, - MASK_MODE, DTypeIn, DTypeOut, IdType>( - handler, q, qo_indptr, q_offset, paged_kv, - /*custom_mask=*/nullptr, /*qk_indptr=*/nullptr, o, lse, - sm_scale, rope_scale, rope_theta, stream); - })})})})})}); + {DISPATCH_mask_mode( + mask_mode, MASK_MODE, + {DISPATCH_pos_encoding_mode( + pos_encoding_mode, POS_ENCODING_MODE, + {DISPATCH_allow_fp16_qk_reduction( + allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, + {DISPATCH_page_size(paged_kv.page_size, PAGE_SIZE, { + return BatchPrefillWithPagedKVCacheWrapperDispatched< + PAGE_STORAGE, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, LogitsPostHook::kNone, + KV_LAYOUT, POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION, MASK_MODE, + DTypeIn, DTypeOut, IdType>(handler, q, qo_indptr, q_offset, paged_kv, + /*custom_mask=*/nullptr, + /*qk_indptr=*/nullptr, o, lse, sm_scale, + rope_scale, rope_theta, stream); + })})})})})}); return cudaSuccess; } @@ -167,10 +170,10 @@ cudaError_t SingleDecodeWithKVCache(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* head_dim, HEAD_DIM, {DISPATCH_pos_encoding_mode( pos_encoding_mode, POS_ENCODING_MODE, {DISPATCH_kv_layout(kv_layout, KV_LAYOUT, { - SingleDecodeWithKVCacheDispatched(q, k, v, o, tmp, num_kv_heads, - seq_len, sm_scale, rope_scale, - rope_theta, stream); + SingleDecodeWithKVCacheDispatched( + q, k, v, o, tmp, num_kv_heads, seq_len, sm_scale, rope_scale, rope_theta, + stream); })})})}); return cudaSuccess; } @@ -199,18 +202,19 @@ cudaError_t BatchDecodeWithPaddedKVCache(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTyp head_dim, HEAD_DIM, {DISPATCH_pos_encoding_mode( pos_encoding_mode, POS_ENCODING_MODE, {DISPATCH_kv_layout(kv_layout, KV_LAYOUT, { - return BatchDecodeWithPaddedKVCacheDispatched( - q, k, v, o, tmp, lse, batch_size, padded_kv_len, num_qo_heads, sm_scale, - rope_scale, rope_theta, stream); + return BatchDecodeWithPaddedKVCacheDispatched< + GROUP_SIZE, HEAD_DIM, LogitsPostHook::kNone, KV_LAYOUT, POS_ENCODING_MODE, + DTypeQ, DTypeKV, DTypeOut>(q, k, v, o, tmp, lse, batch_size, padded_kv_len, + num_qo_heads, sm_scale, rope_scale, rope_theta, + stream); })})})}); return cudaSuccess; } -template cudaError_t BatchDecodeWithPagedKVCacheNoSplitKV( - DTypeQ* 
q, IdType* q_offset, paged_kv_t paged_kv, + DTypeQ* q, IdType* q_offset, paged_kv_t paged_kv, kv_partition_info_t kv_partition_info, DTypeOut* o, float* lse, uint32_t num_qo_heads, PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, std::optional maybe_sm_scale = std::nullopt, float rope_scale = 1.f, @@ -230,9 +234,9 @@ cudaError_t BatchDecodeWithPagedKVCacheNoSplitKV( num_qo_heads / num_kv_heads, GROUP_SIZE, {DISPATCH_head_dim( head_dim, HEAD_DIM, {DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { - return BatchDecodeWithPagedKVCacheDispatched( + return BatchDecodeWithPagedKVCacheDispatched< + GROUP_SIZE, HEAD_DIM, PAGE_STORAGE, LogitsPostHook::kNone, KV_LAYOUT, + POS_ENCODING_MODE, DTypeQ, DTypeKV, DTypeOut, IdType>( q, q_offset, paged_kv, kv_partition_info, o, /*tmp_v=*/nullptr, /*tmp_s=*/nullptr, lse, /*block_valid_mask=*/nullptr, /*padded_batch_size=*/paged_kv.batch_size, sm_scale, @@ -264,11 +268,11 @@ cudaError_t BatchDecodeWithPagedKVCacheNoSplitKV( * \note This wrapper function should be only called after we call BeginForward function in the * BatchDecodeHandler. */ -template cudaError_t BatchDecodeWithPagedKVCacheWrapper( BatchDecodeHandler* handler, DTypeQ* q, IdType* q_offset, - paged_kv_t paged_kv, DTypeOut* o, float* lse, + paged_kv_t paged_kv, DTypeOut* o, float* lse, uint32_t num_qo_heads, PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, std::optional maybe_sm_scale = std::nullopt, float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr) { @@ -283,19 +287,19 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapper( DISPATCH_group_size( num_qo_heads / num_kv_heads, GROUP_SIZE, - {DISPATCH_head_dim( - paged_kv.head_dim, HEAD_DIM, - {DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { - return BatchDecodeWithPagedKVCacheWrapperDispatched( - handler, q, q_offset, paged_kv, o, lse, sm_scale, rope_scale, rope_theta, stream); - })})}); + {DISPATCH_head_dim(paged_kv.head_dim, HEAD_DIM, + {DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { + return BatchDecodeWithPagedKVCacheWrapperDispatched< + PAGE_STORAGE, GROUP_SIZE, HEAD_DIM, LogitsPostHook::kNone, KV_LAYOUT, + POS_ENCODING_MODE, DTypeQ, DTypeKV, DTypeOut, IdType>( + handler, q, q_offset, paged_kv, o, lse, sm_scale, rope_scale, + rope_theta, stream); + })})}); return cudaSuccess; } -template +template cudaError_t BatchDecodeHandlerBeginForward(BatchDecodeHandler* handler, void* buffer, size_t workspace_size_in_bytes, IdType* indptr, IdType* last_page_len, uint32_t batch_size, @@ -311,8 +315,9 @@ cudaError_t BatchDecodeHandlerBeginForward(BatchDecodeHandler* handler, void* bu DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, { DISPATCH_head_dim(head_dim, HEAD_DIM, { DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { - return handler->BeginForwardDispatched( + return handler->BeginForwardDispatched( buffer, workspace_size_in_bytes, indptr, last_page_len, batch_size, num_qo_heads, page_size); }); diff --git a/src/test_batch_decode.cu b/src/test_batch_decode.cu index b8f77a971..00bfe90a8 100644 --- a/src/test_batch_decode.cu +++ b/src/test_batch_decode.cu @@ -107,18 +107,20 @@ void _TestBatchDecodingKernelCorrectness(size_t page_size, size_t batch_size, si if (!cooperative) { // use non-cooperative kernel - cudaError_t status = flashinfer::BatchDecodeWithPagedKVCacheNoSplitKV( - thrust::raw_pointer_cast(q_device.data()), /*q_offset=*/nullptr, paged_kv, - kv_partition_info_t(), thrust::raw_pointer_cast(o_device.data()), - 
/*lse=*/nullptr, num_qo_heads, pos_encoding_mode); + cudaError_t status = + flashinfer::BatchDecodeWithPagedKVCacheNoSplitKV( + thrust::raw_pointer_cast(q_device.data()), /*q_offset=*/nullptr, paged_kv, + kv_partition_info_t(), thrust::raw_pointer_cast(o_device.data()), + /*lse=*/nullptr, num_qo_heads, pos_encoding_mode); EXPECT_EQ(status, cudaSuccess) << "CUDA error: " + std::string(cudaGetErrorString(status)); } else { - cudaError_t status = flashinfer::BatchDecodeWithPagedKVCacheWrapper( - &handler, thrust::raw_pointer_cast(q_device.data()), /*q_offset=*/nullptr, paged_kv, - thrust::raw_pointer_cast(o_device.data()), /*lse=*/nullptr, num_qo_heads, - pos_encoding_mode); + cudaError_t status = + flashinfer::BatchDecodeWithPagedKVCacheWrapper( + &handler, thrust::raw_pointer_cast(q_device.data()), /*q_offset=*/nullptr, paged_kv, + thrust::raw_pointer_cast(o_device.data()), /*lse=*/nullptr, num_qo_heads, + pos_encoding_mode); EXPECT_EQ(status, cudaSuccess) << "CUDA error: " + std::string(cudaGetErrorString(status)); } // compare result diff --git a/src/test_cascade.cu b/src/test_cascade.cu index 0b1e6d181..a9909152f 100644 --- a/src/test_cascade.cu +++ b/src/test_cascade.cu @@ -294,10 +294,11 @@ void _TestTwoLevelSinglePrefixCascadeDecodeCorrectness(size_t batch_size, batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, PosEncodingMode::kNone); // Compute result using baseline implementation - cudaError_t status = BatchDecodeWithPagedKVCacheWrapper( - &baseline_handler, thrust::raw_pointer_cast(q_d.data()), - /*q_offset=*/nullptr, paged_kv_baseline_d, thrust::raw_pointer_cast(o_baseline_d.data()), - /*lse=*/nullptr, num_qo_heads, PosEncodingMode::kNone); + cudaError_t status = + BatchDecodeWithPagedKVCacheWrapper( + &baseline_handler, thrust::raw_pointer_cast(q_d.data()), + /*q_offset=*/nullptr, paged_kv_baseline_d, thrust::raw_pointer_cast(o_baseline_d.data()), + /*lse=*/nullptr, num_qo_heads, PosEncodingMode::kNone); EXPECT_EQ(status, cudaSuccess) << "Baseline implementation failed with error: " << cudaGetErrorString(status); @@ -314,7 +315,7 @@ void _TestTwoLevelSinglePrefixCascadeDecodeCorrectness(size_t batch_size, EXPECT_EQ(status, cudaSuccess) << "Cascade implementation prefill failed with error: " << cudaGetErrorString(status); - status = BatchDecodeWithPagedKVCacheWrapper( + status = BatchDecodeWithPagedKVCacheWrapper( &cascade_handler, thrust::raw_pointer_cast(q_d.data()), /*q_offset=*/nullptr, paged_kv_casacde_d, thrust::raw_pointer_cast(o_cascade_1_d.data()), /*lse=*/thrust::raw_pointer_cast(lse_cascade_1_d.data()), num_qo_heads, diff --git a/src/tvm_wrapper.cu b/src/tvm_wrapper.cu index b7682972e..d7e3cbe89 100644 --- a/src/tvm_wrapper.cu +++ b/src/tvm_wrapper.cu @@ -392,8 +392,8 @@ void _FlashInferAttentionDecodeWithPagedKVCache(int64_t handler_id, DLTensor* q_ last_page_len->byte_offset / sizeof(dtype_idx), static_cast(k_rope_pos_offset->data) + k_rope_pos_offset->byte_offset / sizeof(dtype_idx)); - cudaError_t status = BatchDecodeWithPagedKVCacheWrapper( + cudaError_t status = BatchDecodeWithPagedKVCacheWrapper< + page_storage, kv_layout, dtype_in, dtype_in, dtype_out, dtype_idx>( &batch_decode_handlers[handler_id], static_cast(q_data->data), static_cast(q_offset->data) + q_offset->byte_offset / sizeof(dtype_idx), cache, static_cast(output->data), @@ -423,16 +423,16 @@ void _FlashInferAttentionDecodeWithPagedKVCacheBeginForward( cudaStream_t original_stream = batch_decode_handlers[handler_idx].GetCUDAStream(); 
batch_decode_handlers[handler_idx].SetCUDAStream(static_cast(copy_stream)); DISPATCH_TVM_CUDA_IDTYPE(page_table_indptr->dtype, dtype_idx, { - cudaError_t status = - BatchDecodeHandlerBeginForward( - batch_decode_handlers + handler_idx, static_cast(workspace_buffer->data), - workspace_size_in_bytes, - static_cast(page_table_indptr->data) + - page_table_indptr->byte_offset / sizeof(dtype_idx), - static_cast(last_page_len->data) + - last_page_len->byte_offset / sizeof(dtype_idx), - batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, - PosEncodingMode(pos_encoding_mode)); + cudaError_t status = BatchDecodeHandlerBeginForward( + batch_decode_handlers + handler_idx, static_cast(workspace_buffer->data), + workspace_size_in_bytes, + static_cast(page_table_indptr->data) + + page_table_indptr->byte_offset / sizeof(dtype_idx), + static_cast(last_page_len->data) + + last_page_len->byte_offset / sizeof(dtype_idx), + batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, + PosEncodingMode(pos_encoding_mode)); if (status != cudaSuccess) { LOG(FATAL) << "FlashInfer decode BeginForward error " << cudaGetErrorString(status); }
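As a side note on the code-generation plumbing changed throughout this patch: every template parameter is encoded into the generated .cu filename (the new logitshook_{0|1} segment), decoded again by each generator script's regex, and mapped to a C++ enumerator through literal_map.py (LogitsPostHook::kNone / LogitsPostHook::kCap30). The sketch below replays that round trip for the single-decode naming scheme shown above; the helper single_decode_fname is illustrative, while the regex and literal map are taken from the patch.

import itertools
import re

logits_hook_literal = {0: "LogitsPostHook::kNone", 1: "LogitsPostHook::kCap30"}

def single_decode_fname(group_size, head_dim, logits_hook, kv_layout, pos_encoding_mode, dtype):
    # mirrors the filename layout used by setup.py / generate_single_decode_inst.py
    return (
        f"single_decode_group_{group_size}_head_{head_dim}_logitshook_{logits_hook}"
        f"_layout_{kv_layout}_posenc_{pos_encoding_mode}"
        f"_dtypeq_{dtype}_dtypekv_{dtype}_dtypeout_{dtype}.cu"
    )

pattern = re.compile(
    r"single_decode_group_([0-9]+)_head_([0-9]+)_logitshook_([0-9]+)_layout_([0-9]+)_posenc_([0-9]+)_"
    r"dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)\.cu"
)

# enumerate a tiny cross product the same way setup.py does, then parse each name back
for group_size, head_dim, logits_hook in itertools.product([1], [128], [0, 1]):
    fname = single_decode_fname(group_size, head_dim, logits_hook, 0, 0, "f16")
    match = pattern.match(fname)
    assert match is not None
    print(fname, "->", logits_hook_literal[int(match.group(3))])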