diff --git a/3rdparty/flashinfer b/3rdparty/flashinfer
index f978e02565d7..7e9cc7ff42ca 160000
--- a/3rdparty/flashinfer
+++ b/3rdparty/flashinfer
@@ -1 +1 @@
-Subproject commit f978e02565d7157d57803eb4153369e046fc4106
+Subproject commit 7e9cc7ff42ca283c317061a877305d09a395fad2
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 683ce819dbdb..7575d6c2b4d6 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -960,13 +960,13 @@ option(USE_FLASHINFER "Build TVM with FlashInfer" OFF)
 if (USE_FLASHINFER STREQUAL "ON")
   message(STATUS "Build with FlashInfer")
   set(FLASHINFER_TVM_BINDING ON)
-  set(FLASHINFER_TVM_HOME ${PROJECT_SOURCE_DIR})
-  set(FLASHINFER_ENABLE_FP8 OFF)
-  set(FLASHINFER_ENABLE_BF16 OFF)
+  set(FLASHINFER_TVM_SOURCE_DIR ${PROJECT_SOURCE_DIR})
   set(FLASHINFER_PREFILL OFF)
   set(FLASHINFER_DECODE OFF)
   set(FLASHINFER_PAGE OFF)
   set(FLASHINFER_CASCADE OFF)
+  set(FLASHINFER_SAMPLING OFF)
+  set(FLASHINFER_NORM OFF)
   add_subdirectory(3rdparty/flashinfer)
 else ()
   message(STATUS "Build without FlashInfer")
diff --git a/cmake/config.cmake b/cmake/config.cmake
index ccb449fe2b23..5847acc298b1 100644
--- a/cmake/config.cmake
+++ b/cmake/config.cmake
@@ -444,6 +444,19 @@ set(USE_GTEST AUTO)
 # Need to have USE_CUDA=ON
 set(USE_CUTLASS OFF)
 
+# Whether to enable FlashInfer or not.
+set(USE_FLASHINFER OFF)
+# Options for FlashInfer kernel compilation.
+set(FLASHINFER_ENABLE_FP8 OFF)
+set(FLASHINFER_ENABLE_BF16 OFF)
+set(FLASHINFER_GEN_GROUP_SIZES 1 4 6 8)
+set(FLASHINFER_GEN_PAGE_SIZES 16)
+set(FLASHINFER_GEN_HEAD_DIMS 128)
+set(FLASHINFER_GEN_KV_LAYOUTS 0 1)
+set(FLASHINFER_GEN_POS_ENCODING_MODES 0 1)
+set(FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS "false")
+set(FLASHINFER_GEN_CASUALS "false" "true")
+
 # Enable to show a summary of TVM options
 set(SUMMARIZE OFF)
diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h
index 5bdc883649c9..3eb225fccffe 100644
--- a/include/tvm/runtime/ndarray.h
+++ b/include/tvm/runtime/ndarray.h
@@ -534,6 +534,23 @@ inline bool NDArray::Load(dmlc::Stream* strm) {
   return true;
 }
 
+/*!
+ * \brief Get the preferred host device for the input device.
+ * - For CUDA and ROCm, CUDAHost and ROCMHost are returned, since pinned
+ *   memory reduces copy overhead.
+ * - For other devices, CPU is returned as a fallback.
+ */
+inline Device GetPreferredHostDevice(Device device) {
+  if (device.device_type == DLDeviceType::kDLCUDA) {
+    return Device{DLDeviceType::kDLCUDAHost, 0};
+  } else if (device.device_type == DLDeviceType::kDLROCM) {
+    return Device{DLDeviceType::kDLROCMHost, 0};
+  } else {
+    // Fallback to CPU.
+    return Device{DLDeviceType::kDLCPU, 0};
+  }
+}
+
 }  // namespace runtime
 }  // namespace tvm
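The CUDAHost/ROCMHost choice matters because asynchronous host-to-device copies only overlap with other work when the source buffer is page-locked. A minimal background sketch of the behavior this helper is designed to exploit (plain CUDA runtime calls, not TVM code; `FillAndCopy` and its arguments are illustrative):

```cpp
#include <cuda_runtime.h>
#include <cstdint>

// With a pageable (malloc'd) source, cudaMemcpyAsync degrades to a staged,
// effectively synchronous copy; with pinned memory it can truly overlap
// with kernels running on other streams.
void FillAndCopy(int32_t* device_dst, size_t n, cudaStream_t stream) {
  int32_t* pinned = nullptr;
  cudaMallocHost(&pinned, n * sizeof(int32_t));  // page-locked host memory
  for (size_t i = 0; i < n; ++i) pinned[i] = static_cast<int32_t>(i);
  cudaMemcpyAsync(device_dst, pinned, n * sizeof(int32_t),
                  cudaMemcpyHostToDevice, stream);
  cudaStreamSynchronize(stream);
  cudaFreeHost(pinned);
}
```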
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index a5d2d9f41554..62750d6d7daa 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -194,6 +194,56 @@ enum class RoPEMode : int {
   kInline = 2,
 };
 
+/*!
+ * \brief An int32 vector on host memory with a "std::vector"-style interface.
+ * The backing buffer is allocated on the specified host device at the time
+ * of construction, with a fixed capacity.
+ */
+class HostMemoryVector {
+ public:
+  HostMemoryVector() = default;
+  HostMemoryVector(const HostMemoryVector&) = delete;
+  HostMemoryVector(HostMemoryVector&& other) = default;
+  HostMemoryVector& operator=(const HostMemoryVector&) = delete;
+  HostMemoryVector& operator=(HostMemoryVector&& other) = default;
+
+  explicit HostMemoryVector(int64_t reserved_size, DLDataType dtype, Device device)
+      : reserved_size_(reserved_size) {
+    ICHECK(DataType(dtype) == DataType::Int(32));
+    data_ = NDArray::Empty({reserved_size}, dtype, device);
+  }
+
+  void push_back(int32_t value) {
+    ICHECK_LT(current_size_, reserved_size_);
+    static_cast<int32_t*>(data_->data)[current_size_++] = value;
+  }
+
+  const int32_t& operator[](int64_t idx) const {
+    ICHECK_GE(idx, 0) << "Index " << idx << " is negative.";
+    ICHECK_LT(idx, current_size_) << "Index " << idx << " out of bounds " << current_size_;
+    return static_cast<int32_t*>(data_->data)[idx];
+  }
+
+  int32_t back() const {
+    ICHECK_GT(current_size_, 0) << "Vector is empty";
+    return static_cast<int32_t*>(data_->data)[current_size_ - 1];
+  }
+
+  size_t size() const { return static_cast<size_t>(current_size_); }
+
+  int32_t* data() const { return static_cast<int32_t*>(data_->data); }
+
+  void clear() { current_size_ = 0; }
+
+  /*! \brief Return the vector as an NDArray. */
+  NDArray as_ndarray() { return data_.CreateView({current_size_}, data_->dtype); }
+
+ private:
+  int64_t reserved_size_ = 0;
+  int64_t current_size_ = 0;
+  NDArray data_{nullptr};
+};
+
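A hedged usage sketch of the class above, assuming capacity is always reserved up front (`BuildAppendIndptr`, `max_num_seqs`, and `append_lengths` are illustrative placeholders, not part of the patch):

```cpp
#include <vector>

// Build a CSR-style indptr array directly in (possibly pinned) host memory.
NDArray BuildAppendIndptr(int64_t max_num_seqs,
                          const std::vector<int64_t>& append_lengths) {
  Device host = GetPreferredHostDevice(Device{DLDeviceType::kDLCUDA, 0});
  HostMemoryVector indptr(max_num_seqs + 1, DLDataType{kDLInt, 32, 1}, host);
  indptr.push_back(0);
  for (int64_t length : append_lengths) {
    indptr.push_back(indptr.back() + static_cast<int32_t>(length));
  }
  // Zero-copy view over the filled prefix; it aliases the pinned buffer.
  return indptr.as_ndarray();
}
```

Note the design choice: `push_back` ICHECKs against the reserved capacity instead of growing, so an overflow on the hot path fails loudly rather than silently reallocating into non-pinned memory.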
 /*!
  * \brief The paged attention auxiliary data manager class.
  * This class manages all the int32 auxiliary data on GPU device, such as
@@ -213,8 +263,12 @@ enum class RoPEMode : int {
  */
 class PagedKVCacheAuxDataManager {
  public:
-  PagedKVCacheAuxDataManager(DLDataType dtype_aux, Device device, TVMStreamHandle copy_stream)
-      : dtype_aux_(dtype_aux), device_(device), copy_stream_(copy_stream) {
+  PagedKVCacheAuxDataManager(DLDataType dtype_aux, Device device, Device preferred_host_device,
+                             TVMStreamHandle copy_stream)
+      : dtype_aux_(dtype_aux),
+        device_(device),
+        preferred_host_device_(preferred_host_device),
+        copy_stream_(copy_stream) {
     ICHECK(DataType(dtype_aux) == DataType::Int(32));
   }
 
@@ -222,13 +276,13 @@ class PagedKVCacheAuxDataManager {
   /*! \brief Reset the status of copy manager. */
   virtual void ResetCopy() = 0;
   /*! \brief Copy the indptr array of append lengths after coalescing. (see GetChunkedBlockIds) */
-  virtual NDArray CopyQOIndptrOnDepthAsync(std::vector<int32_t>* data, int depth) = 0;
+  virtual NDArray CopyQOIndptrOnDepthAsync(HostMemoryVector* data, int depth) = 0;
   /*! \brief Copy the indptr array of page table. */
-  virtual NDArray CopyPageIndptrOnDepthAsync(std::vector<int32_t>* data, int depth) = 0;
+  virtual NDArray CopyPageIndptrOnDepthAsync(HostMemoryVector* data, int depth) = 0;
   /*! \brief Copy the indices array of page table. */
-  virtual NDArray CopyPageIndicesOnDepthAsync(std::vector<int32_t>* data, int depth) = 0;
+  virtual NDArray CopyPageIndicesOnDepthAsync(HostMemoryVector* data, int depth) = 0;
   /*! \brief Copy the array of KV slot number used in the last page of the seq. */
-  virtual NDArray CopyLastPageLenOnDepthAsync(std::vector<int32_t>* data, int depth) = 0;
+  virtual NDArray CopyLastPageLenOnDepthAsync(HostMemoryVector* data, int depth) = 0;
   /*!
    * \brief Copy the length information of the sequences.
    * Each NDArray is in shape `(3, n)`. "n" is the number of sequences.
@@ -239,27 +293,27 @@ class PagedKVCacheAuxDataManager {
    * \note When sliding window is not enabled, only the
    * "last_page_len" (a.k.a., the first "n" elements) will be effectively used.
    */
-  virtual NDArray CopyLengthInfoOnDepthAsync(std::vector<int32_t>* last_page_len,
-                                             std::vector<int32_t>* sliding_window_offset,
-                                             std::vector<int32_t>* sink_size, int depth) = 0;
+  virtual NDArray CopyLengthInfoOnDepthAsync(HostMemoryVector* last_page_len,
+                                             HostMemoryVector* sliding_window_offset,
+                                             HostMemoryVector* sink_size, int depth) = 0;
   /*! \brief Copy the k position offset of applying RoPE for each sequence. */
-  virtual NDArray CopyKRoPEPosOffsetOnDepthAsync(std::vector<int32_t>* data, int depth) = 0;
+  virtual NDArray CopyKRoPEPosOffsetOnDepthAsync(HostMemoryVector* data, int depth) = 0;
   /*!
    * \brief Copy the append length indptr array on device.
    * \note Since the Q/K/V data may have raggedness in terms of lengths,
   * we represent the append lengths in CSR format.
    */
-  virtual NDArray CopyCurAppendLengthIndptrAsync(std::vector<int32_t>* data) = 0;
+  virtual NDArray CopyCurAppendLengthIndptrAsync(HostMemoryVector* data) = 0;
   /*! \brief Copy the k position offset of applying RoPE for each sequence. */
-  virtual NDArray CopyKRaggedRoPEPosOffsetAsync(std::vector<int32_t>* data) = 0;
+  virtual NDArray CopyKRaggedRoPEPosOffsetAsync(HostMemoryVector* data) = 0;
   /*! \brief Copy the q position mapping of applying RoPE for each sequence. */
-  virtual NDArray CopyQRoPEPosMapAsync(std::vector<int32_t>* data) = 0;
+  virtual NDArray CopyQRoPEPosMapAsync(HostMemoryVector* data) = 0;
   /*!
    * \brief Copy the corresponding position in global KV cache (pages)
    * for each position along the length dimension of K/V data when
    * appending new K/V data.
    */
-  virtual NDArray CopyAppendPositionMapAsync(std::vector<int32_t>* data) = 0;
+  virtual NDArray CopyAppendPositionMapAsync(HostMemoryVector* data) = 0;
   /*! \brief Commit all the copy operations since the last commit. */
   virtual void CommitCopy() = 0;
 
@@ -268,6 +322,8 @@ class PagedKVCacheAuxDataManager {
   const DLDataType dtype_aux_;
   /*! \brief The device this PagedKVCache runs on. */
   const Device device_;
+  /*! \brief The preferred host device. */
+  const Device preferred_host_device_;
   /*! \brief The device stream for copying auxiliary data structure to GPU. */
   const TVMStreamHandle copy_stream_;
 };
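The interface implies a three-phase protocol per forward step. A hedged sketch of the calling convention (the driver function and its arguments are illustrative; the real call site is `SyncAuxArrayToDevice` further below):

```cpp
NDArray SyncStep(PagedKVCacheAuxDataManager* mgr, HostMemoryVector* qo_indptr_h,
                 HostMemoryVector* append_position_map_h, int depth) {
  mgr->ResetCopy();  // phase 1: rewind the manager's internal offsets
  // Phase 2: each call returns a device-side NDArray view whose contents
  // become valid once the copies are committed.
  NDArray qo_indptr_view = mgr->CopyQOIndptrOnDepthAsync(qo_indptr_h, depth);
  NDArray append_pos_view = mgr->CopyAppendPositionMapAsync(append_position_map_h);
  mgr->CommitCopy();  // phase 3: flush everything on the copy stream
  (void)append_pos_view;
  return qo_indptr_view;  // now safe to hand to attention kernels
}
```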
@@ -280,8 +336,9 @@ class PagedKVCacheAuxDataManager {
 class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager {
  public:
   explicit PlainPagedKVCacheAuxDataManager(int64_t reserved_num_seqs, int64_t num_total_pages,
                                            int64_t prefill_chunk_size, DLDataType dtype_aux,
-                                           DLDevice device, TVMStreamHandle copy_stream)
-      : PagedKVCacheAuxDataManager(dtype_aux, device, copy_stream) {
+                                           Device device, Device preferred_host_device,
+                                           TVMStreamHandle copy_stream)
+      : PagedKVCacheAuxDataManager(dtype_aux, device, preferred_host_device, copy_stream) {
     for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) {
       qo_indptr_on_depths_device_.push_back(
           NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device));
@@ -302,64 +359,64 @@ class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager {
   // The reset of the plain auxiliary data manager is no-op.
   void ResetCopy() final {}
 
-  NDArray CopyQOIndptrOnDepthAsync(std::vector<int32_t>* data, int depth) final {
+  NDArray CopyQOIndptrOnDepthAsync(HostMemoryVector* data, int depth) final {
     NDArray view = qo_indptr_on_depths_device_[depth].CreateView(
         {static_cast<int64_t>(data->size())}, dtype_aux_);
     CopyVecDataToArray(view, data->data());
     return view;
   }
-  NDArray CopyPageIndptrOnDepthAsync(std::vector<int32_t>* data, int depth) final {
+  NDArray CopyPageIndptrOnDepthAsync(HostMemoryVector* data, int depth) final {
     NDArray view = page_indptr_on_depths_device_[depth].CreateView(
         {static_cast<int64_t>(data->size())}, dtype_aux_);
     CopyVecDataToArray(view, data->data());
     return view;
   }
-  NDArray CopyPageIndicesOnDepthAsync(std::vector<int32_t>* data, int depth) final {
+  NDArray CopyPageIndicesOnDepthAsync(HostMemoryVector* data, int depth) final {
     NDArray view = page_indices_on_depths_device_[depth].CreateView(
         {static_cast<int64_t>(data->size())}, dtype_aux_);
     CopyVecDataToArray(view, data->data());
     return view;
   }
-  NDArray CopyLastPageLenOnDepthAsync(std::vector<int32_t>* data, int depth) final {
+  NDArray CopyLastPageLenOnDepthAsync(HostMemoryVector* data, int depth) final {
     NDArray view = length_info_on_depths_device_[depth].CreateView(
         {static_cast<int64_t>(data->size())}, dtype_aux_);
     CopyVecDataToArray(view, data->data());
     return view;
   }
-  NDArray CopyKRoPEPosOffsetOnDepthAsync(std::vector<int32_t>* data, int depth) final {
+  NDArray CopyKRoPEPosOffsetOnDepthAsync(HostMemoryVector* data, int depth) final {
     NDArray view = k_rope_pos_offset_on_depths_device_[depth].CreateView(
         {static_cast<int64_t>(data->size())}, dtype_aux_);
     CopyVecDataToArray(view, data->data());
     return view;
   }
-  NDArray CopyCurAppendLengthIndptrAsync(std::vector<int32_t>* data) final {
+  NDArray CopyCurAppendLengthIndptrAsync(HostMemoryVector* data) final {
     NDArray view = cur_append_length_indptr_device_.CreateView({static_cast<int64_t>(data->size())},
                                                                dtype_aux_);
     CopyVecDataToArray(view, data->data());
     return view;
   }
-  NDArray CopyKRaggedRoPEPosOffsetAsync(std::vector<int32_t>* data) final {
+  NDArray CopyKRaggedRoPEPosOffsetAsync(HostMemoryVector* data) final {
     NDArray view = k_ragged_rope_pos_offset_device_.CreateView({static_cast<int64_t>(data->size())},
                                                                dtype_aux_);
     CopyVecDataToArray(view, data->data());
     return view;
   }
-  NDArray CopyQRoPEPosMapAsync(std::vector<int32_t>* data) final {
+  NDArray CopyQRoPEPosMapAsync(HostMemoryVector* data) final {
     NDArray view =
         q_rope_position_map_device_.CreateView({static_cast<int64_t>(data->size())}, dtype_aux_);
     CopyVecDataToArray(view, data->data());
     return view;
   }
-  NDArray CopyAppendPositionMapAsync(std::vector<int32_t>* data) final {
+  NDArray CopyAppendPositionMapAsync(HostMemoryVector* data) final {
     NDArray view =
         append_position_map_device_.CreateView({static_cast<int64_t>(data->size())}, dtype_aux_);
     CopyVecDataToArray(view, data->data());
     return view;
   }
-  NDArray CopyLengthInfoOnDepthAsync(std::vector<int32_t>* last_page_len,
-                                     std::vector<int32_t>* sliding_window_offset,
-                                     std::vector<int32_t>* sink_size, int depth) final {
+  NDArray CopyLengthInfoOnDepthAsync(HostMemoryVector* last_page_len,
+                                     HostMemoryVector* sliding_window_offset,
+                                     HostMemoryVector* sink_size, int depth) final {
     int n_elem = last_page_len->size();
     ICHECK_GT(n_elem, 0);
     NDArray view = length_info_on_depths_device_[depth].CreateView({3, n_elem}, dtype_aux_);
@@ -412,7 +469,7 @@ class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager {
 
     DLTensor copy_src;
     copy_src.data = vec_data;
-    copy_src.device = Device{kDLCPU, 0};
+    copy_src.device = preferred_host_device_;
     copy_src.ndim = 1;
     copy_src.dtype = array->dtype;
     copy_src.shape = copy_dst.shape;
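With the one-line change above, the source tensor that `CopyVecDataToArray` describes is now tagged with the pinned host device rather than plain CPU, which lets the runtime issue a genuinely asynchronous copy. A self-contained sketch of the pattern, mirroring the fields visible in this hunk (`CopyHostToDeviceAsync` is an illustrative free function, not part of the patch):

```cpp
#include <cstdint>
#include <tvm/runtime/ndarray.h>

using tvm::runtime::NDArray;

void CopyHostToDeviceAsync(int32_t* vec_data, NDArray device_view,
                           Device preferred_host_device,
                           TVMStreamHandle copy_stream) {
  DLTensor copy_dst = *device_view.operator->();
  DLTensor copy_src;
  copy_src.data = vec_data;
  // Marking the source as e.g. kDLCUDAHost tells the runtime the buffer is
  // pinned, avoiding a fallback to a pageable-memory transfer.
  copy_src.device = preferred_host_device;
  copy_src.ndim = 1;
  copy_src.dtype = copy_dst.dtype;
  copy_src.shape = copy_dst.shape;
  copy_src.strides = nullptr;
  copy_src.byte_offset = 0;
  NDArray::CopyFromTo(&copy_src, &copy_dst, copy_stream);
}
```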
@@ -443,15 +500,16 @@
 class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager {
  public:
   explicit CachedPagedKVCacheAuxDataManager(int64_t reserved_num_seqs, int64_t num_total_pages,
                                             int64_t prefill_chunk_size, DLDataType dtype_aux,
-                                            DLDevice device, TVMStreamHandle copy_stream)
-      : PagedKVCacheAuxDataManager(dtype_aux, device, copy_stream),
+                                            DLDevice device, Device preferred_host_device,
+                                            TVMStreamHandle copy_stream)
+      : PagedKVCacheAuxDataManager(dtype_aux, device, preferred_host_device, copy_stream),
         elem_byte_size_((dtype_aux.bits * dtype_aux.lanes + 7) / 8),
         offset_alignment_(cuda_byte_alignment_ / elem_byte_size_) {
     // - Calculate cache size of all the auxiliary arrays in
     // local cache and the large on-device array.
     int64_t cache_size = CalculateCacheSize(reserved_num_seqs, num_total_pages, prefill_chunk_size);
     // - Initialize the host auxiliary data buffer.
-    merged_aux_data_host_.resize(cache_size);
+    merged_aux_data_host_ = HostMemoryVector(cache_size, dtype_aux, preferred_host_device);
     // - Initialize the device auxiliary data buffer.
     memory::Allocator* allocator =
         memory::MemoryManager::GetOrCreateAllocator(device, memory::AllocatorType::kNaive);
@@ -461,34 +519,32 @@ class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager {
   }
 
   void ResetCopy() final { copy_offset_ = 0; }
-  NDArray CopyQOIndptrOnDepthAsync(std::vector<int32_t>* data, int depth) final {
+  NDArray CopyQOIndptrOnDepthAsync(HostMemoryVector* data, int depth) final {
     return CopyVecToCache(data);
   }
-  NDArray CopyPageIndptrOnDepthAsync(std::vector<int32_t>* data, int depth) final {
+  NDArray CopyPageIndptrOnDepthAsync(HostMemoryVector* data, int depth) final {
     return CopyVecToCache(data);
   }
-  NDArray CopyPageIndicesOnDepthAsync(std::vector<int32_t>* data, int depth) final {
+  NDArray CopyPageIndicesOnDepthAsync(HostMemoryVector* data, int depth) final {
     return CopyVecToCache(data);
   }
-  NDArray CopyLastPageLenOnDepthAsync(std::vector<int32_t>* data, int depth) final {
+  NDArray CopyLastPageLenOnDepthAsync(HostMemoryVector* data, int depth) final {
     return CopyVecToCache(data);
   }
-  NDArray CopyKRoPEPosOffsetOnDepthAsync(std::vector<int32_t>* data, int depth) final {
+  NDArray CopyKRoPEPosOffsetOnDepthAsync(HostMemoryVector* data, int depth) final {
     return CopyVecToCache(data);
   }
-  NDArray CopyCurAppendLengthIndptrAsync(std::vector<int32_t>* data) final {
+  NDArray CopyCurAppendLengthIndptrAsync(HostMemoryVector* data) final {
     return CopyVecToCache(data);
   }
-  NDArray CopyKRaggedRoPEPosOffsetAsync(std::vector<int32_t>* data) final {
+  NDArray CopyKRaggedRoPEPosOffsetAsync(HostMemoryVector* data) final {
     return CopyVecToCache(data);
   }
-  NDArray CopyQRoPEPosMapAsync(std::vector<int32_t>* data) final { return CopyVecToCache(data); }
-  NDArray CopyAppendPositionMapAsync(std::vector<int32_t>* data) final {
-    return CopyVecToCache(data);
-  }
-  NDArray CopyLengthInfoOnDepthAsync(std::vector<int32_t>* last_page_len,
-                                     std::vector<int32_t>* sliding_window_offset,
-                                     std::vector<int32_t>* sink_size, int depth) final {
+  NDArray CopyQRoPEPosMapAsync(HostMemoryVector* data) final { return CopyVecToCache(data); }
+  NDArray CopyAppendPositionMapAsync(HostMemoryVector* data) final { return CopyVecToCache(data); }
+  NDArray CopyLengthInfoOnDepthAsync(HostMemoryVector* last_page_len,
+                                     HostMemoryVector* sliding_window_offset,
+                                     HostMemoryVector* sink_size, int depth) final {
     int64_t n_elem = last_page_len->size();
     std::memcpy(merged_aux_data_host_.data() + copy_offset_, last_page_len->data(),
                 n_elem * elem_byte_size_);
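As far as these hunks show, the cached manager's strategy is to pack every auxiliary array into one large pinned buffer at aligned element offsets, so committing needs far fewer host-to-device copies than one per array. A hedged sketch of the packing step (`AppendToMergedBuffer` is illustrative; offsets are in elements, matching `offset_alignment_ = cuda_byte_alignment_ / elem_byte_size_` above):

```cpp
#include <cstdint>
#include <cstring>

int64_t AppendToMergedBuffer(int32_t* merged_host, int64_t copy_offset,
                             const int32_t* src, int64_t n_elem,
                             int64_t elem_byte_size, int64_t offset_alignment) {
  std::memcpy(merged_host + copy_offset, src, n_elem * elem_byte_size);
  // Round the next free offset up to the alignment boundary so every
  // packed sub-array stays suitably aligned for device-side consumers.
  int64_t next = copy_offset + n_elem;
  return (next + offset_alignment - 1) / offset_alignment * offset_alignment;
}
```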
@@ -559,7 +615,7 @@ class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager {
   /*!
    * \brief Copy the input data to the cache at the given offset.
    * And return the NDArray view of the cache starting at the offset.
    */
-  NDArray CopyVecToCache(std::vector<int32_t>* data) {
+  NDArray CopyVecToCache(HostMemoryVector* data) {
     int64_t n_elem = data->size();
     std::memcpy(merged_aux_data_host_.data() + copy_offset_, data->data(),
                 n_elem * elem_byte_size_);
@@ -579,7 +635,7 @@ class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager {
   const int64_t offset_alignment_;
 
   int64_t copy_offset_ = 0;
-  std::vector<int32_t> merged_aux_data_host_;
+  HostMemoryVector merged_aux_data_host_;
   memory::Storage merged_aux_data_device_;
 };
 
@@ -687,17 +743,17 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
   // Below are the auxiliary data structure on CPU.
   // We make them class members to avoid repetitive allocation time in BeginForward.
   //-------------------------------------------
-  std::vector<std::vector<int32_t>> qo_indptr_on_depths_host_;
-  std::vector<std::vector<int32_t>> page_indptr_on_depths_host_;
-  std::vector<std::vector<int32_t>> page_indices_on_depths_host_;
-  std::vector<std::vector<int32_t>> last_page_len_on_depths_host_;
-  std::vector<std::vector<int32_t>> sliding_window_offset_on_depths_host_;
-  std::vector<std::vector<int32_t>> sink_size_on_depths_host_;
-  std::vector<std::vector<int32_t>> k_rope_pos_offset_on_depths_host_;
-  std::vector<int32_t> k_ragged_rope_pos_offset_host_;
-  std::vector<int32_t> q_rope_position_map_host_;
-  std::vector<int32_t> append_position_map_host_;
-  std::vector<int32_t> cur_append_lengths_indptr_host_;
+  std::vector<HostMemoryVector> qo_indptr_on_depths_host_;
+  std::vector<HostMemoryVector> page_indptr_on_depths_host_;
+  std::vector<HostMemoryVector> page_indices_on_depths_host_;
+  std::vector<HostMemoryVector> last_page_len_on_depths_host_;
+  std::vector<HostMemoryVector> sliding_window_offset_on_depths_host_;
+  std::vector<HostMemoryVector> sink_size_on_depths_host_;
+  std::vector<HostMemoryVector> k_rope_pos_offset_on_depths_host_;
+  HostMemoryVector k_ragged_rope_pos_offset_host_;
+  HostMemoryVector q_rope_position_map_host_;
+  HostMemoryVector append_position_map_host_;
+  HostMemoryVector cur_append_lengths_indptr_host_;
 
   //-------------------------------------------
   // For efficient memory management, the actual sizes of the arrays
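Note that `HostMemoryVector` deletes copying but defaults the move operations, which is exactly what these members need: the per-depth vectors are populated by moving temporaries in, and the flat members are move-assigned in the constructor below. A small illustration (`IllustrateMoveOnly` and `host` are assumed, illustrative names):

```cpp
#include <vector>

void IllustrateMoveOnly() {
  Device host = GetPreferredHostDevice(Device{DLDeviceType::kDLCUDA, 0});
  std::vector<HostMemoryVector> per_depth;
  // The temporary is moved into the container, never copied.
  per_depth.push_back(HostMemoryVector(16, DLDataType{kDLInt, 32, 1}, host));

  HostMemoryVector flat;  // default-constructed, empty placeholder
  // Move-assignment replaces the placeholder with a real pinned buffer.
  flat = HostMemoryVector(128, DLDataType{kDLInt, 32, 1}, host);
}
```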
@@ -804,6 +860,33 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       pages_.push_back(
           NDArray::Empty({num_total_pages, 2, num_kv_heads, page_size, head_dim}, dtype, device));
     }
+    // Allocate the host memory.
+    Device preferred_host_device = GetPreferredHostDevice(device);
+    for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) {
+      qo_indptr_on_depths_host_.push_back(
+          HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device));
+      page_indptr_on_depths_host_.push_back(
+          HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device));
+      page_indices_on_depths_host_.push_back(
+          HostMemoryVector(num_total_pages, dtype_aux_, preferred_host_device));
+      last_page_len_on_depths_host_.push_back(
+          HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device));
+      sliding_window_offset_on_depths_host_.push_back(
+          HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device));
+      sink_size_on_depths_host_.push_back(
+          HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device));
+      k_rope_pos_offset_on_depths_host_.push_back(
+          HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device));
+    }
+    k_ragged_rope_pos_offset_host_ =
+        HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device);
+    q_rope_position_map_host_ =
+        HostMemoryVector(prefill_chunk_size, dtype_aux_, preferred_host_device);
+    append_position_map_host_ =
+        HostMemoryVector(prefill_chunk_size, dtype_aux_, preferred_host_device);
+    cur_append_lengths_indptr_host_ =
+        HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device);
+
     for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) {
       temp_attn_workspace_.push_back(
           NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device));
@@ -847,10 +930,12 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     // operations may have issues on other platforms.
     if (device_.device_type == DLDeviceType::kDLCUDA) {
       aux_data_manager_ = std::make_unique<CachedPagedKVCacheAuxDataManager>(
-          reserved_num_seqs, num_total_pages, prefill_chunk_size, dtype_aux_, device, copy_stream_);
+          reserved_num_seqs, num_total_pages, prefill_chunk_size, dtype_aux_, device,
+          preferred_host_device, copy_stream_);
     } else {
       aux_data_manager_ = std::make_unique<PlainPagedKVCacheAuxDataManager>(
-          reserved_num_seqs, num_total_pages, prefill_chunk_size, dtype_aux_, device, copy_stream_);
+          reserved_num_seqs, num_total_pages, prefill_chunk_size, dtype_aux_, device,
+          preferred_host_device, copy_stream_);
     }
   }
 
@@ -1124,7 +1209,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     is_decode_request_ = true;
     sequences.reserve(cur_batch_size_);
     last_block_length_before_append.reserve(cur_batch_size_);
-    k_ragged_rope_pos_offset_host_.resize(cur_batch_size_);
+    k_ragged_rope_pos_offset_host_.clear();
     for (int i = 0; i < cur_batch_size_; ++i) {
       auto it = seq_map_.find(seq_ids[i]);
       CHECK(it != seq_map_.end()) << "The sequence \"" << seq_ids[i]
@@ -1132,7 +1217,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       sequences.push_back(&it->second);
       last_block_length_before_append.push_back(
           global_block_pool_[it->second.last_block_idx].seq_length);
-      k_ragged_rope_pos_offset_host_[i] = it->second.seq_length;
+      k_ragged_rope_pos_offset_host_.push_back(it->second.seq_length);
       it->second.seq_length += append_lengths[i];
       if (append_lengths[i] != 1) {
         is_decode_request_ = false;
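Since `HostMemoryVector` only exposes a const `operator[]`, the old resize-then-assign pattern cannot be kept; writes have to go through `clear()` plus `push_back()`, which also keeps the fill inside the buffer reserved at construction. A condensed, self-contained sketch of the rewritten loop shape (`Sequence` here is a stand-in struct with just the field the sketch needs):

```cpp
#include <cstdint>
#include <vector>

struct Sequence { int32_t seq_length; };

void RecordPosOffsets(HostMemoryVector* k_ragged_rope_pos_offset_host,
                      std::vector<Sequence*>& sequences,
                      const std::vector<int64_t>& append_lengths) {
  k_ragged_rope_pos_offset_host->clear();
  for (size_t i = 0; i < sequences.size(); ++i) {
    // Record the pre-append length, then extend the sequence.
    k_ragged_rope_pos_offset_host->push_back(sequences[i]->seq_length);
    sequences[i]->seq_length += static_cast<int32_t>(append_lengths[i]);
  }
}
```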
@@ -1162,22 +1247,14 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       }
     }
 
-    qo_indptr_on_depths_host_.resize(num_depths_);
-    page_indptr_on_depths_host_.resize(num_depths_);
-    page_indices_on_depths_host_.resize(num_depths_);
-    last_page_len_on_depths_host_.resize(num_depths_);
-    sliding_window_offset_on_depths_host_.resize(num_depths_);
-    sink_size_on_depths_host_.resize(num_depths_);
-    k_rope_pos_offset_on_depths_host_.resize(num_depths_);
-
     for (int d = 0; d < num_depths_; ++d) {
-      std::vector<int32_t>& qo_indptr_h = qo_indptr_on_depths_host_[d];
-      std::vector<int32_t>& page_indptr_h = page_indptr_on_depths_host_[d];
-      std::vector<int32_t>& page_indices_h = page_indices_on_depths_host_[d];
-      std::vector<int32_t>& last_page_len_h = last_page_len_on_depths_host_[d];
-      std::vector<int32_t>& sliding_window_offset_h = sliding_window_offset_on_depths_host_[d];
-      std::vector<int32_t>& sink_size_h = sink_size_on_depths_host_[d];
-      std::vector<int32_t>& k_rope_pos_offset_h = k_rope_pos_offset_on_depths_host_[d];
+      HostMemoryVector& qo_indptr_h = qo_indptr_on_depths_host_[d];
+      HostMemoryVector& page_indptr_h = page_indptr_on_depths_host_[d];
+      HostMemoryVector& page_indices_h = page_indices_on_depths_host_[d];
+      HostMemoryVector& last_page_len_h = last_page_len_on_depths_host_[d];
+      HostMemoryVector& sliding_window_offset_h = sliding_window_offset_on_depths_host_[d];
+      HostMemoryVector& sink_size_h = sink_size_on_depths_host_[d];
+      HostMemoryVector& k_rope_pos_offset_h = k_rope_pos_offset_on_depths_host_[d];
       qo_indptr_h.clear();
       page_indptr_h.clear();
       page_indices_h.clear();
@@ -1198,7 +1275,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       } else {
         const Block& block = global_block_pool_[block_id];
         page_indptr_h.push_back(page_indptr_h.back() + block.page_ids.size());
-        page_indices_h.insert(page_indices_h.end(), block.page_ids.begin(), block.page_ids.end());
+        for (int32_t page_id : block.page_ids) {
+          page_indices_h.push_back(page_id);
+        }
         last_page_len_h.push_back(block.seq_length == 0
                                       ? 0
                                       : (block.seq_length - block.sink_length +
                                          block.sliding_window_offset - 1) %
@@ -1620,14 +1699,15 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     if (append_before_attn_) {
       if (!support_sliding_window_) {
         f_attention_decode_begin_forward_.value()(
-            /*depth=*/0, temp_attn_workspace_[1], page_indptr_on_depths_view_[0],
-            length_info_on_depths_view_[0], num_qo_heads_, num_kv_heads_, head_dim_, page_size_,
+            /*depth=*/0, temp_attn_workspace_[1], page_indptr_on_depths_host_[0].as_ndarray(),
+            last_page_len_on_depths_host_[0].as_ndarray(), num_qo_heads_, num_kv_heads_, head_dim_,
+            page_size_,
             /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, copy_stream_);
       }
     } else {
       f_attention_prefill_ragged_begin_forward_.value()(
-          temp_attn_workspace_[0], cur_append_length_indptr_view_, cur_batch_size_, num_qo_heads_,
-          num_kv_heads_, head_dim_, copy_stream_);
+          temp_attn_workspace_[0], cur_append_lengths_indptr_host_.as_ndarray(), cur_batch_size_,
+          num_qo_heads_, num_kv_heads_, head_dim_, copy_stream_);
       if (support_sliding_window_) {
         return;
       }
@@ -1637,12 +1717,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       }
       if (use_decode_kernel_[d]) {
         f_attention_decode_begin_forward_.value()(
-            d, temp_attn_workspace_[d + 1], page_indptr_on_depths_view_[d],
-            length_info_on_depths_view_[d], num_qo_heads_, num_kv_heads_, head_dim_, page_size_,
+            d, temp_attn_workspace_[d + 1], page_indptr_on_depths_host_[d].as_ndarray(),
+            last_page_len_on_depths_host_[d].as_ndarray(), num_qo_heads_, num_kv_heads_,
+            head_dim_, page_size_,
             /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, copy_stream_);
       } else {
         f_attention_prefill_begin_forward_.value()(
-            /*depth=*/d, temp_attn_workspace_[d + 1], qo_indptr_on_depths_view_[d],
+            /*depth=*/d, temp_attn_workspace_[d + 1], qo_indptr_on_depths_host_[d].as_ndarray(),
             length_info_on_depths_view_[d]->shape[0], num_qo_heads_, num_kv_heads_, head_dim_,
             copy_stream_);
       }
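The begin-forward hooks now receive `as_ndarray()` views of the pinned host vectors instead of the device-side views used before, presumably because the planning step inside these functions reads the indptr/length arrays on the CPU; handing it the host buffer avoids a device round-trip. A hedged illustration of why the view is cheap to consume (`ReadPlanInput` is an illustrative helper):

```cpp
#include <cstdint>

NDArray ReadPlanInput(HostMemoryVector* page_indptr_h) {
  NDArray host_view = page_indptr_h->as_ndarray();  // zero-copy alias
  const int32_t* indptr = static_cast<const int32_t*>(host_view->data);
  int64_t total_pages = indptr[host_view->shape[0] - 1];  // last indptr entry
  (void)total_pages;  // a planner would size its work partition from this
  return host_view;
}
```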
@@ -1732,17 +1813,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
    */
   void SyncAuxArrayToDevice() {
     ICHECK(dtype_aux_.bits == 32 && dtype_aux_.code == kDLInt);
-    ICHECK_EQ(qo_indptr_on_depths_host_.size(), num_depths_);
-    ICHECK_EQ(page_indptr_on_depths_host_.size(), num_depths_);
-    ICHECK_EQ(page_indices_on_depths_host_.size(), num_depths_);
-    ICHECK_EQ(last_page_len_on_depths_host_.size(), num_depths_);
     int64_t total_append_length = 0;
     int num_sequences = cur_append_lengths_.size();
-    cur_append_lengths_indptr_host_.resize(num_sequences + 1);
-    cur_append_lengths_indptr_host_[0] = 0;
+    cur_append_lengths_indptr_host_.clear();
+    cur_append_lengths_indptr_host_.push_back(0);
     for (int i = 0; i < num_sequences; ++i) {
-      cur_append_lengths_indptr_host_[i + 1] =
-          cur_append_lengths_indptr_host_[i] + cur_append_lengths_[i];
+      cur_append_lengths_indptr_host_.push_back(cur_append_lengths_indptr_host_.back() +
+                                                cur_append_lengths_[i]);
     }
     total_append_length = cur_append_lengths_indptr_host_.back();
     ICHECK_EQ(total_append_length, append_position_map_host_.size());