[KVCache] TIR attention kernel support for MLA (#17618)
This PR introduces multi-head latent attention (MLA) kernels written in TIR.
It also implements the KV cache MLA computation logic.

A new unit test file is added to ensure the correctness of the
TIR kernels.

This PR also fixes the tile size initialization in a few TIR prefill kernels.
MasterJH5574 authored Feb 5, 2025
1 parent 9898039 commit 3eb5ad6
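
As background for the "KV cache MLA computation logic" mentioned in the message: under weight absorption the cache stores, per token, a compressed latent vector plus a shared RoPE key component, and attention runs directly in that latent space. The following NumPy reference is a hedged sketch of that computation, matching the (q, compressed_kv, k_pe, o) argument order registered in kv_state.cc below; the shape names (kv_lora_rank, rope_dim) and the omission of causal masking are illustrative assumptions, not the exact TIR kernels.

import numpy as np

def mla_absorbed_reference(q, compressed_kv, k_pe, scale):
    """Hedged reference for weight-absorbed MLA attention (no masking).

    Assumed shapes:
      q:             (qo_len, num_heads, kv_lora_rank + rope_dim)
      compressed_kv: (kv_len, kv_lora_rank)   # per-token latent KV
      k_pe:          (kv_len, rope_dim)       # shared RoPE key part
    Returns o of shape (qo_len, num_heads, kv_lora_rank), i.e. the
    attention output still in the latent space.
    """
    keys = np.concatenate([compressed_kv, k_pe], axis=-1)   # keys: latent ++ RoPE part
    scores = np.einsum("qhd,kd->qhk", q, keys) * scale      # (qo_len, heads, kv_len)
    scores -= scores.max(axis=-1, keepdims=True)            # numerically stable softmax
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)
    # Under absorption the values are the latent vectors themselves.
    return np.einsum("qhk,kr->qhr", probs, compressed_kv)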
Showing 8 changed files with 2,024 additions and 869 deletions.
1,895 changes: 1,317 additions & 578 deletions python/tvm/relax/frontend/nn/llm/kv_cache.py

Large diffs are not rendered by default.

24 changes: 22 additions & 2 deletions python/tvm/relax/frontend/nn/llm/tree_attn.py
@@ -320,7 +320,17 @@ def tree_attn(
 
     bdx = 32
     num_warps = 4
-    tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16
+    tile_x, tile_y, tile_z = (
+        64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1),
+        d,
+        64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1),
+    )
+    original_tile_y = tile_y
+    original_tile_z = tile_z
+    while (tile_x * tile_z) % (bdx * num_warps) != 0:
+        tile_z += original_tile_z
+    while (tile_x * tile_y) % (bdx * num_warps) != 0:
+        tile_y += original_tile_y
 
     # Otherwise we would exceed maxComputeWorkgroupStorageSize
     if (
@@ -881,7 +891,17 @@ def tree_attn_with_paged_kv_cache(
 
     bdx = 32
     num_warps = 4
-    tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16
+    tile_x, tile_y, tile_z = (
+        64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1),
+        d,
+        64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1),
+    )
+    original_tile_y = tile_y
+    original_tile_z = tile_z
+    while (tile_x * tile_z) % (bdx * num_warps) != 0:
+        tile_z += original_tile_z
+    while (tile_x * tile_y) % (bdx * num_warps) != 0:
+        tile_y += original_tile_y
 
     # Otherwise we would exceed maxComputeWorkgroupStorageSize
     if (
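
The tree_attn changes above replace the hard-coded tile_z = 16 with a value derived from the element width and head dimension, then pad tile_y and tile_z until tile_x * tile_y and tile_x * tile_z are multiples of bdx * num_warps, presumably so the per-tile work divides evenly over the cooperating threads. A minimal standalone sketch of that initialization (mirroring the diff; for illustration only):

from tvm.runtime import DataType

def init_prefill_tile_sizes(dtype: str, d: int, bdx: int = 32, num_warps: int = 4):
    # Base tile extent: 64 bytes' worth of elements, reduced further for head dims >= 256.
    elem_bytes = (DataType(dtype).bits + 7) // 8
    base = 64 // elem_bytes // max(d // 128, 1)
    tile_x, tile_y, tile_z = base, d, base
    # Pad tile_y / tile_z to the next multiple that makes the tile area
    # divisible by the number of threads (bdx * num_warps).
    original_tile_y, original_tile_z = tile_y, tile_z
    while (tile_x * tile_z) % (bdx * num_warps) != 0:
        tile_z += original_tile_z
    while (tile_x * tile_y) % (bdx * num_warps) != 0:
        tile_y += original_tile_y
    return tile_x, tile_y, tile_z

For example, assuming float16 and d = 576 (the absorbed MLA head dim, 512 latent + 64 RoPE), base is 8 and tile_z gets padded from 8 to 16, while float16 with d = 128 keeps the base values (32, 128, 32) unchanged.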
9 changes: 9 additions & 0 deletions src/runtime/relax_vm/kv_state.cc
@@ -74,12 +74,21 @@ TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_query_positions")
     .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::GetQueryPositions);
 TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_debug_get_kv")
     .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::DebugGetKV);
+TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_debug_get_kv_mla")
+    .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::DebugGetKVMLA);
 TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_attention_with_fused_qkv")
     .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id,
                        double attn_score_scaling_factor, NDArray qkv_data, NDArray o_data) {
       kv_cache->AttentionWithFusedQKV(layer_id, std::move(qkv_data), NullOpt, std::move(o_data),
                                       attn_score_scaling_factor);
     });
+TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_mla_absorbed")
+    .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id,
+                       double attn_score_scaling_factor, NDArray q_data, NDArray compressed_kv_data,
+                       NDArray k_pe_data, NDArray o_data) {
+      kv_cache->MLAAbsorbed(layer_id, std::move(q_data), std::move(compressed_kv_data),
+                            std::move(k_pe_data), std::move(o_data), attn_score_scaling_factor);
+    });
 
 // RNN State methods
 TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_get").set_body_method<RNNState>(&RNNStateObj::Get);
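
The two registrations above expose the MLA entry points as global PackedFuncs. A hedged sketch of reaching the absorbed-attention one from Python (the kv_cache handle, layer id, scaling factor, and NDArray buffers must come from the surrounding KV-cache setup and are placeholders here):

import tvm

def mla_absorbed_attention(kv_cache, layer_id, attn_score_scaling_factor,
                           q_data, compressed_kv_data, k_pe_data, o_data):
    # Forwards to AttentionKVCacheObj::MLAAbsorbed via the PackedFunc registered
    # under "vm.builtin.attention_kv_cache_mla_absorbed"; the result is written
    # into o_data in place.
    f = tvm.get_global_func("vm.builtin.attention_kv_cache_mla_absorbed")
    f(kv_cache, layer_id, attn_score_scaling_factor,
      q_data, compressed_kv_data, k_pe_data, o_data)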
24 changes: 10 additions & 14 deletions src/runtime/relax_vm/kv_state.h
@@ -181,20 +181,6 @@ class AttentionKVCacheObj : public KVStateObj {
   virtual void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data, Optional<NDArray> mask,
                                      NDArray o_data, double attn_score_scaling_factor) = 0;
 
-  /*!
-   * \brief Compute attention with Q/K/V data.
-   * \param layer_id The model layer where the attention compute happens.
-   * \param q_data The input Q data, in layout `(total_length, num_qo_heads, head_dim)`
-   * \param k_data The input K data, in layout `(total_length, num_kv_heads, head_dim)`
-   * \param v_data The input V data, in layout `(total_length, num_kv_heads, head_dim)`
-   * \param mask The input mask data, in layout `(total_sqr_length)`.
-   * \param o_data The output O data, in layout `(total_length, num_qo_heads, head_dim)`.
-   * \param attn_score_scaling_factor The additional attention scaling factor.
-   */
-  virtual void AttentionWithSeparateQKV(int64_t layer_id, NDArray q_data, NDArray k_data,
-                                        NDArray v_data, Optional<NDArray> mask, NDArray o_data,
-                                        double attn_score_scaling_factor) = 0;
-
   /*!
    * \brief Compute multi-head latent attention after applying weight absorption.
    * \param layer_id The model layer where the attention compute happens.
@@ -275,6 +261,16 @@ virtual void DebugGetKV(int64_t seq_id, //
   virtual void DebugGetKV(int64_t seq_id,  //
                           int64_t start_pos, int64_t end_pos, NDArray k_data, NDArray v_data) = 0;
 
+  /*!
+   * \brief Fetch the compact K/V data of the given sequence for MLA cache.
+   * \param seq_id The sequence whose K/V data is to be fetched.
+   * \param start_pos The start position (inclusive) of the K/V data to fetch.
+   * \param end_pos The end position (exclusive) of the K/V data to fetch.
+   * \param kv_data The output KV data of the given sequence in layout elaborated above.
+   */
+  virtual void DebugGetKVMLA(int64_t seq_id, int64_t start_pos, int64_t end_pos,
+                             NDArray kv_data) = 0;
+
   /*!
    * \brief Set the K/V data of the given sequence from input K/V data.
    * `start_pos` (inclusive) controls starting position of K/V data
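
DebugGetKVMLA is the hook the new unit tests can use to read back a sequence's compact KV and compare it with a host-side reference. A hedged sketch of such a check; the (seq_len, kv_lora_rank + rope_dim) layout of kv_data and the helper name are assumptions for illustration:

import numpy as np
import tvm

def check_mla_cache(kv_cache, seq_id, expected_kv, device):
    # expected_kv: host-side reference with assumed shape (seq_len, kv_lora_rank + rope_dim).
    seq_len = expected_kv.shape[0]
    kv_data = tvm.nd.empty(expected_kv.shape, str(expected_kv.dtype), device)
    f = tvm.get_global_func("vm.builtin.attention_kv_cache_debug_get_kv_mla")
    f(kv_cache, seq_id, 0, seq_len, kv_data)  # fetch positions [0, seq_len)
    np.testing.assert_allclose(kv_data.numpy(), expected_kv, rtol=1e-3, atol=1e-3)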