Skip to content

Commit

Permalink
[Runtime] Use preferred host memory (pinned memory) in KV cache (#17036)
Browse files Browse the repository at this point in the history
This PR updates the PagedKVCache with pinned memory support,
which reduces the copy overhead between CPU and GPU.

This PR also bumps FlashInfer version, which now supports
* specifying kernels to build via cmake,
* pinned memory as host memory.

We also update CMakeLists.txt and config.cmake to include the
FlashInfer compile options. Prior to this PR, the kernels being
built were hardcoded in FlashInfer header files.
  • Loading branch information
MasterJH5574 authored May 29, 2024
1 parent 8bdd54b commit 71f7af7
Show file tree
Hide file tree
Showing 5 changed files with 205 additions and 98 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/flashinfer
Submodule flashinfer updated 55 files
+12 −5 .github/workflows/release_wheel.yml
+4 −1 .gitignore
+1 −1 .release-please-manifest.json
+16 −0 CHANGELOG.md
+80 −39 CMakeLists.txt
+10 −0 cmake/config.cmake
+3 −0 docs/api/python/decode.rst
+13 −0 docs/api/python/norm.rst
+15 −0 docs/api/python/sampling.rst
+6 −6 docs/conf.py
+2 −1 docs/index.rst
+17 −2 docs/installation.rst
+0 −67 include/flashinfer/attention/decode.cuh
+183 −81 include/flashinfer/attention/handler.cuh
+46 −89 include/flashinfer/attention/prefill.cuh
+0 −146 include/flashinfer/decode_attention_decl.cuh
+0 −95 include/flashinfer/prefill_attention_decl.cuh
+109 −99 include/flashinfer/sampling.cuh
+12 −19 include/flashinfer/utils.cuh
+1 −0 python/MANIFEST.in
+63 −28 python/csrc/batch_decode.cu
+18 −8 python/csrc/batch_prefill.cu
+23 −3 python/csrc/flashinfer_ops.cu
+45 −19 python/csrc/flashinfer_ops.h
+43 −0 python/csrc/norm.cu
+4 −4 python/csrc/pytorch_extension_utils.h
+98 −0 python/csrc/sampling.cu
+7 −0 python/flashinfer/__init__.py
+284 −1 python/flashinfer/decode.py
+49 −0 python/flashinfer/norm.py
+4 −2 python/flashinfer/prefill.py
+190 −0 python/flashinfer/sampling.py
+3 −6 python/generate_batch_paged_prefill_inst.py
+40 −16 python/generate_dispatch_inc.py
+12 −3 python/setup.py
+143 −4 python/tests/test_batch_decode_kernels.py
+47 −0 python/tests/test_norm.py
+101 −0 python/tests/test_sampling.py
+3 −4 src/bench_batch_decode.cu
+5 −6 src/bench_cascade.cu
+4 −4 src/bench_sampling.cu
+2 −2 src/bench_single_decode.cu
+2 −1 src/bench_single_prefill.cu
+62 −63 src/cpu_reference.h
+314 −0 src/flashinfer_ops.cuh
+5 −5 src/test_batch_decode.cu
+9 −8 src/test_batch_prefill.cu
+9 −10 src/test_cascade.cu
+2 −2 src/test_page.cu
+1,707 −9 src/test_sampling.cu
+1 −1 src/test_single_decode.cu
+1 −2 src/test_single_prefill.cu
+23 −45 src/tvm_wrapper.cu
+43 −0 src/utils.h
+1 −1 version.txt
6 changes: 3 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -960,13 +960,13 @@ option(USE_FLASHINFER "Build TVM with FlashInfer" OFF)
# Optionally build TVM against the FlashInfer attention-kernel library
# (vendored as a 3rdparty submodule and built via add_subdirectory below).
# NOTE(review): STREQUAL "ON" matches only the literal string "ON"; other
# truthy cache values such as TRUE or 1 would fall through to the else
# branch. The idiomatic form is `if (USE_FLASHINFER)` — confirm intended.
if (USE_FLASHINFER STREQUAL "ON")
message(STATUS "Build with FlashInfer")
# Enable FlashInfer's TVM binding and point it at this TVM source tree.
set(FLASHINFER_TVM_BINDING ON)
set(FLASHINFER_TVM_HOME ${PROJECT_SOURCE_DIR})
set(FLASHINFER_ENABLE_FP8 OFF)
set(FLASHINFER_ENABLE_BF16 OFF)
set(FLASHINFER_TVM_SOURCE_DIR ${PROJECT_SOURCE_DIR})
# Disable FlashInfer's standalone kernel/test builds; TVM only needs the
# headers and the TVM binding, and which kernels to instantiate is
# controlled by the FLASHINFER_GEN_* options (see cmake/config.cmake).
set(FLASHINFER_PREFILL OFF)
set(FLASHINFER_DECODE OFF)
set(FLASHINFER_PAGE OFF)
set(FLASHINFER_CASCADE OFF)
set(FLASHINFER_SAMPLING OFF)
set(FLASHINFER_NORM OFF)
add_subdirectory(3rdparty/flashinfer)
else ()
message(STATUS "Build without FlashInfer")
Expand Down
13 changes: 13 additions & 0 deletions cmake/config.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,19 @@ set(USE_GTEST AUTO)
# Need to have USE_CUDA=ON
set(USE_CUTLASS OFF)

# Whether to enable FlashInfer or not.
set(USE_FLASHINFER OFF)
# Options for FlashInfer kernel compilation. These control which kernel
# template instantiations the FlashInfer submodule generates; every
# combination of the FLASHINFER_GEN_* values below is compiled, so trimming
# these lists reduces build time and binary size.
set(FLASHINFER_ENABLE_FP8 OFF)
set(FLASHINFER_ENABLE_BF16 OFF)
# GQA group sizes (num_qo_heads / num_kv_heads) to instantiate.
set(FLASHINFER_GEN_GROUP_SIZES 1 4 6 8)
# Paged-KV-cache page sizes to instantiate.
set(FLASHINFER_GEN_PAGE_SIZES 16)
# Attention head dimensions to instantiate.
set(FLASHINFER_GEN_HEAD_DIMS 128)
# KV cache layouts (0/1 are upstream FlashInfer layout enum values — see the
# submodule's docs for the mapping).
set(FLASHINFER_GEN_KV_LAYOUTS 0 1)
# Positional encoding modes (0/1 are upstream FlashInfer enum values).
set(FLASHINFER_GEN_POS_ENCODING_MODES 0 1)
set(FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS "false")
# NOTE(review): "CASUALS" mirrors the upstream FlashInfer variable spelling
# (presumably "causal" masking on/off) — must stay in sync with the submodule.
set(FLASHINFER_GEN_CASUALS "false" "true")

# Enable to show a summary of TVM options
set(SUMMARIZE OFF)

Expand Down
17 changes: 17 additions & 0 deletions include/tvm/runtime/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,23 @@ inline bool NDArray::Load(dmlc::Stream* strm) {
return true;
}

/*!
 * \brief Get the preferred host device for staging copies to/from the input device.
 * - For CUDA and ROCm devices, the corresponding pinned-memory host device
 *   (CUDAHost / ROCMHost) is returned, since pinned memory reduces copy overhead.
 * - For any other device, plain CPU is returned as the fallback.
 */
inline Device GetPreferredHostDevice(Device device) {
  switch (device.device_type) {
    case DLDeviceType::kDLCUDA:
      return Device{DLDeviceType::kDLCUDAHost, 0};
    case DLDeviceType::kDLROCM:
      return Device{DLDeviceType::kDLROCMHost, 0};
    default:
      // No dedicated pinned-memory host type for this device; use plain CPU.
      return Device{DLDeviceType::kDLCPU, 0};
  }
}

} // namespace runtime
} // namespace tvm

Expand Down
Loading

0 comments on commit 71f7af7

Please sign in to comment.