From 3138328207bbe0b519c33a2f59be8ef2cf44d5b7 Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Mon, 26 Aug 2024 21:20:05 -0400
Subject: [PATCH] [Runtime] Support KV cache with RoPE extension factor array
 (#17294)

This PR enhances the KV cache with RoPE extension factor support. With
this PR, the KV cache can support models like Phi3.5, which come with
extension factors.
---
 src/runtime/relax_vm/kv_state.h               |  1 +
 src/runtime/relax_vm/paged_kv_cache.cc        | 63 +++++++++++--------
 ...tin_paged_attention_kv_cache_flashinfer.py |  3 +
 ...me_builtin_paged_attention_kv_cache_tir.py |  1 +
 4 files changed, 43 insertions(+), 25 deletions(-)

diff --git a/src/runtime/relax_vm/kv_state.h b/src/runtime/relax_vm/kv_state.h
index f4d6036b9638..6d30ce998add 100644
--- a/src/runtime/relax_vm/kv_state.h
+++ b/src/runtime/relax_vm/kv_state.h
@@ -167,6 +167,7 @@ class AttentionKVCacheObj : public KVStateObj {
    *        `(total_length, num_qo_heads + 2 * num_kv_heads, head_dim)`.
    * \param mask The input mask data, in layout `(total_sqr_length)`.
    * \param o_data The output O data, in layout `(total_length, num_qo_heads, head_dim)`.
+   * \param attn_score_scaling_factor The additional attention scaling factor.
    * \sa AttentionKVCache::Attention
    */
   virtual void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data, Optional<NDArray> mask,
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index 6bf3dc7ce609..591187ab5fe7 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -848,6 +848,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
   const double rotary_scale_;
   /*! \brief The RoPE theta. */
   const double rotary_theta_;
+  /*! \brief The optional RoPE extension factors for RoPE scaling. */
+  const Optional<NDArray> rope_ext_factors_;
   /*! \brief We fix int32 to be the index dtype of auxiliary data. */
   const DLDataType dtype_aux_ = DLDataType(DataType::Int(32, 1));
 
@@ -988,7 +990,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       int64_t page_size, int64_t num_layers, int64_t layer_id_begin_offset,  //
       int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim, int64_t reserved_num_seqs,
       int64_t num_total_pages, int64_t prefill_chunk_size, bool support_sliding_window,
-      RoPEMode rope_mode, double rotary_scale, double rotary_theta, DLDataType dtype, Device device,
+      RoPEMode rope_mode, double rotary_scale, double rotary_theta,
+      Optional<NDArray> rope_ext_factors, DLDataType dtype, Device device,
       PackedFunc f_transpose_append, PackedFunc f_compact_copy, PackedFunc f_attention_prefill,
       PackedFunc f_attention_decode, PackedFunc f_attention_prefill_sliding_window,
       PackedFunc f_attention_decode_sliding_window, PackedFunc f_attention_prefill_ragged,
@@ -1013,6 +1016,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
                                                                            : rope_mode),
         rotary_scale_(rotary_scale),
         rotary_theta_(rotary_theta),
+        rope_ext_factors_(std::move(rope_ext_factors)),
         f_transpose_append_(std::move(f_transpose_append)),
         f_compact_copy_(std::move(f_compact_copy)),
         f_attention_prefill_(std::move(f_attention_prefill)),
@@ -1132,6 +1136,12 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
           reserved_num_seqs, num_total_pages, prefill_chunk_size, dtype_aux_, device,
           preferred_host_device, copy_stream_);
     }
+
+    // Right now only the "normal" RoPE mode supports the RoPE extension factors.
+    if (rope_ext_factors_.defined()) {
+      CHECK(rope_mode_ == RoPEMode::kNormal)
+          << "The RoPE mode must be normal to support RoPE extension factors.";
+    }
   }
 
   ~PagedAttentionKVCacheObj() {
@@ -1726,8 +1736,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     NDArray v_data = temp_attn_v_device_.CreateView({total_seq_length, num_kv_heads_, head_dim_},
                                                     qkv_data->dtype);
     // Part 2. Split fused qkv and apply rotary embedding to q/k data.
-    f_split_rotary_(qkv_data, q_rope_position_map_view_, q_data, k_data, v_data,
-                    static_cast<int>(rope_mode_ == RoPEMode::kNormal));
+    if (!rope_ext_factors_.defined()) {
+      f_split_rotary_(qkv_data, q_rope_position_map_view_, q_data, k_data, v_data,
+                      static_cast<int>(rope_mode_ == RoPEMode::kNormal));
+    } else {
+      f_split_rotary_(qkv_data, q_rope_position_map_view_, q_data, k_data, v_data,
+                      rope_ext_factors_.value());
+    }
 
     // Part 3. Append k/v data to kv-cache if flag "append_before_attn" is set.
     if (append_before_attn_) {
@@ -2462,7 +2477,7 @@ TVM_REGISTER_OBJECT_TYPE(PagedAttentionKVCacheObj);
 
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
     .set_body([](TVMArgs args, TVMRetValue* rv) {
-      CHECK(args.size() == 25 || args.size() == 26 || args.size() == 27)
+      CHECK(args.size() == 27 || args.size() == 28)
           << "Invalid number of KV cache constructor args.";
       ShapeTuple cache_config = args[0];
       ShapeTuple layer_indptr_tuple = args[1];
@@ -2499,14 +2514,12 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
       PackedFunc f_split_rotary = args[22];
       PackedFunc f_copy_single_page = args[23];
       Optional<PackedFunc> f_debug_get_kv = args[24];
-      PackedFunc f_compact_copy{nullptr};
-      PackedFunc f_attention_prefill_with_tree_mask{nullptr};
+      PackedFunc f_compact_copy = args[25];
+      PackedFunc f_attention_prefill_with_tree_mask = args[26];
+      Optional<NDArray> rope_ext_factors = NullOpt;
 
-      if (args.size() >= 26) {
-        f_compact_copy = args[25].AsObjectRef<PackedFunc>();
-      }
-      if (args.size() >= 27) {
-        f_attention_prefill_with_tree_mask = args[26].AsObjectRef<PackedFunc>();
+      if (args.size() >= 28 && args[27].IsObjectRef<NDArray>()) {
+        rope_ext_factors = args[27].AsObjectRef<NDArray>();
       }
 
       CHECK_EQ(cache_config.size(), 5);
@@ -2523,9 +2536,10 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
       ObjectPtr<PagedAttentionKVCacheObj> n = make_object<PagedAttentionKVCacheObj>(
           page_size, num_layers, layer_id_begin_offset, num_qo_heads, num_kv_heads, head_dim,
           reserved_num_seqs, num_total_pages, prefill_chunk_size, support_sliding_window,
-          RoPEMode(rope_mode), rotary_scale, rotary_theta, init->dtype, init->device,
-          std::move(f_transpose_append), std::move(f_compact_copy), std::move(f_attention_prefill),
-          std::move(f_attention_decode), std::move(f_attention_prefill_sliding_window),
+          RoPEMode(rope_mode), rotary_scale, rotary_theta, std::move(rope_ext_factors),  //
+          init->dtype, init->device, std::move(f_transpose_append), std::move(f_compact_copy),
+          std::move(f_attention_prefill), std::move(f_attention_decode),
+          std::move(f_attention_prefill_sliding_window),
           std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged),
           std::move(f_attention_prefill_with_tree_mask),
           std::move(f_attention_prefill_ragged_begin_forward),
@@ -2539,7 +2553,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
 
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
     .set_body([](TVMArgs args, TVMRetValue* rv) {
-      CHECK(args.size() == 19 || args.size() == 20 || args.size() == 21)
+      CHECK(args.size() == 21 || args.size() == 22)
           << "Invalid number of KV cache constructor args.";
       ShapeTuple cache_config = args[0];
       ShapeTuple layer_indptr_tuple = args[1];
@@ -2570,14 +2584,12 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
       PackedFunc f_split_rotary = args[16];
       PackedFunc f_copy_single_page = args[17];
       Optional<PackedFunc> f_debug_get_kv = args[18];
-      PackedFunc f_compact_copy{nullptr};
-      PackedFunc f_attention_prefill_with_tree_mask{nullptr};
+      PackedFunc f_compact_copy = args[19];
+      PackedFunc f_attention_prefill_with_tree_mask = args[20];
+      Optional<NDArray> rope_ext_factors = NullOpt;
 
-      if (args.size() >= 20) {
-        f_compact_copy = args[19].AsObjectRef<PackedFunc>();
-      }
-      if (args.size() >= 21) {
-        f_attention_prefill_with_tree_mask = args[20].AsObjectRef<PackedFunc>();
+      if (args.size() >= 22 && args[21].IsObjectRef<NDArray>()) {
+        rope_ext_factors = args[21].AsObjectRef<NDArray>();
       }
 
       CHECK_EQ(cache_config.size(), 5);
@@ -2594,9 +2606,10 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
      ObjectPtr<PagedAttentionKVCacheObj> n = make_object<PagedAttentionKVCacheObj>(
           page_size, num_layers, layer_id_begin_offset, num_qo_heads, num_kv_heads, head_dim,
           reserved_num_seqs, num_total_pages, prefill_chunk_size, support_sliding_window,
-          RoPEMode(rope_mode), rotary_scale, rotary_theta, init->dtype, init->device,
-          std::move(f_transpose_append), std::move(f_compact_copy), std::move(f_attention_prefill),
-          std::move(f_attention_decode), std::move(f_attention_prefill_sliding_window),
+          RoPEMode(rope_mode), rotary_scale, rotary_theta, std::move(rope_ext_factors),  //
+          init->dtype, init->device, std::move(f_transpose_append), std::move(f_compact_copy),
+          std::move(f_attention_prefill), std::move(f_attention_decode),
+          std::move(f_attention_prefill_sliding_window),
           std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged),
           std::move(f_attention_prefill_with_tree_mask),  //
           NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt,  //
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
index cab10f84cddf..2252cb8d9c09 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
@@ -379,6 +379,9 @@ def create_kv_cache(rope_mode):
         fsplit_rotary,
         fcopy_single_page,
         fcopy_cache,
+        None,
+        None,
+        None,
     )
     return cache
 
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
index 96a2438505b2..ff655e141b96 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
@@ -180,6 +180,7 @@ def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window):
         fcopy_cache,
         fcompact_copy,
         fattn_prefill_with_tree_mask,
+        None,
     )
     return cache
 
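Note: the sketch below is not part of the patch. It is a minimal NumPy illustration of what a RoPE extension factor array does to the rotary embedding that f_split_rotary_ applies when rope_ext_factors_ is defined. It assumes the LongRoPE-style formulation used by Phi-3.5-class models, where each rotary dimension's inverse frequency is divided by its per-dimension factor; the helper name rope and the sample values are illustrative only, not TVM APIs.

import numpy as np

def rope(x, positions, theta=10000.0, ext_factors=None):
    """Apply rotary embedding to x of shape (seq_len, head_dim)."""
    _, head_dim = x.shape
    half = head_dim // 2
    # Standard RoPE inverse frequencies: theta^(-i / (head_dim / 2)).
    inv_freq = theta ** (-np.arange(half) / half)
    if ext_factors is not None:
        # Per-dimension extension factors (length head_dim // 2, e.g. the
        # "long_factor" array shipped with Phi-3.5) divide the frequencies,
        # slowing selected dimensions down for long-context attention.
        inv_freq = inv_freq / np.asarray(ext_factors)
    angles = positions[:, None] * inv_freq[None, :]  # (seq_len, half)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, :half], x[:, half:]
    return np.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)

# Without ext_factors this reduces to ordinary RoPE, which is what the kernel
# computes when the cache is created with a None extension factor array
# (as in the updated tests above).
q = np.random.randn(4, 8).astype("float32")
pos = np.arange(4, dtype="float32")
q_plain = rope(q, pos)
q_ext = rope(q, pos, ext_factors=np.full(4, 1.5, dtype="float32"))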