[Runtime][KVCache] Initial interface setup for MLA (#17616)
This PR introduces the initial KV cache interface setup for multi-head
latent attention (MLA) in DeepSeek models.

Some interface implementations are marked as TODO and will be implemented
in the near future.
MasterJH5574 authored Jan 31, 2025
1 parent cf9a3e1 commit 8b4df72
Showing 2 changed files with 330 additions and 46 deletions.
src/runtime/relax_vm/kv_state.h (63 additions, 0 deletions)
@@ -181,6 +181,69 @@ class AttentionKVCacheObj : public KVStateObj {
virtual void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data, Optional<NDArray> mask,
NDArray o_data, double attn_score_scaling_factor) = 0;

/*!
* \brief Compute attention with Q/K/V data.
* \param layer_id The model layer where the attention compute happens.
* \param q_data The input Q data, in layout `(total_length, num_qo_heads, head_dim)`
* \param k_data The input K data, in layout `(total_length, num_kv_heads, head_dim)`
* \param v_data The input V data, in layout `(total_length, num_kv_heads, head_dim)`
* \param mask The input mask data, in layout `(total_sqr_length)`.
* \param o_data The output O data, in layout `(total_length, num_qo_heads, head_dim)`.
* \param attn_score_scaling_factor The additional attention scaling factor.
*/
virtual void AttentionWithSeparateQKV(int64_t layer_id, NDArray q_data, NDArray k_data,
NDArray v_data, Optional<NDArray> mask, NDArray o_data,
double attn_score_scaling_factor) = 0;
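
As a caller-side illustration (not part of this commit), here is a minimal sketch of invoking the separate-QKV entry point, assuming the `tvm::runtime::relax_vm` namespace from this header, an already constructed `AttentionKVCacheObj* kv_cache`, and purely illustrative shape values.

// Hypothetical caller-side sketch, assumed to live in the same tvm::runtime::relax_vm
// namespace as this header (so NDArray, DataType, and NullOpt resolve).
// Every shape value below is illustrative, not taken from this commit.
void SketchSeparateQKV(AttentionKVCacheObj* kv_cache) {
  DLDevice dev{kDLCUDA, 0};
  int64_t total_length = 16, num_qo_heads = 32, num_kv_heads = 8, head_dim = 128;
  NDArray q = NDArray::Empty({total_length, num_qo_heads, head_dim}, DataType::Float(16), dev);
  NDArray k = NDArray::Empty({total_length, num_kv_heads, head_dim}, DataType::Float(16), dev);
  NDArray v = NDArray::Empty({total_length, num_kv_heads, head_dim}, DataType::Float(16), dev);
  NDArray o = NDArray::Empty({total_length, num_qo_heads, head_dim}, DataType::Float(16), dev);
  // No mask; a scaling factor of 1.0 keeps the implementation's default scaling.
  kv_cache->AttentionWithSeparateQKV(/*layer_id=*/0, q, k, v, /*mask=*/NullOpt, o,
                                     /*attn_score_scaling_factor=*/1.0);
}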

/*!
* \brief Compute multi-head latent attention after applying weight absorption.
* \param layer_id The model layer where the attention compute happens.
* \param q_data The input Q data, in layout `(total_length, num_qo_heads, qk_head_dim)`
* \param compressed_kv_data The compressed latent KV data, in layout
* `(total_length, num_kv_heads, kv_lora_rank)`
* \param k_pe_data The positional embedding part of K data, in layout
* `(total_length, num_kv_heads, qk_rope_head_dim)`, where `kv_lora_rank + qk_rope_head_dim`
* equals qk_head_dim
* \param o_data The output O data, in layout `(total_length, num_qo_heads, v_head_dim)`.
* \param attn_score_scaling_factor The additional attention scaling factor.
*/
virtual void MLAAbsorbed(int64_t layer_id, NDArray q_data, NDArray compressed_kv_data,
NDArray k_pe_data, NDArray o_data, double attn_score_scaling_factor) = 0;
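
A matching sketch for the absorbed path, under the same assumptions as above and with assumed DeepSeek-V2-style dimensions (`kv_lora_rank = 512`, `qk_rope_head_dim = 64`, hence `qk_head_dim = 576`, and a single shared latent KV head); taking the output width `v_head_dim` to equal `kv_lora_rank` in absorbed mode is also an assumption, not something this diff states.

// Hypothetical sketch for MLAAbsorbed; all dimensions are assumed DeepSeek-V2-style values.
void SketchMLAAbsorbed(AttentionKVCacheObj* kv_cache) {
  DLDevice dev{kDLCUDA, 0};
  int64_t total_length = 16, num_qo_heads = 128, num_kv_heads = 1;  // latent KV assumed shared
  int64_t kv_lora_rank = 512, qk_rope_head_dim = 64;
  int64_t qk_head_dim = kv_lora_rank + qk_rope_head_dim;  // 576, matching the constraint above
  int64_t v_head_dim = kv_lora_rank;  // assumption: absorbed-path value width is the latent rank
  NDArray q = NDArray::Empty({total_length, num_qo_heads, qk_head_dim}, DataType::Float(16), dev);
  NDArray compressed_kv =
      NDArray::Empty({total_length, num_kv_heads, kv_lora_rank}, DataType::Float(16), dev);
  NDArray k_pe =
      NDArray::Empty({total_length, num_kv_heads, qk_rope_head_dim}, DataType::Float(16), dev);
  NDArray o = NDArray::Empty({total_length, num_qo_heads, v_head_dim}, DataType::Float(16), dev);
  kv_cache->MLAAbsorbed(/*layer_id=*/0, q, compressed_kv, k_pe, o,
                        /*attn_score_scaling_factor=*/1.0);
}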

/*!
* \brief Compute multi-head latent attention in normal style.
* \param layer_id The model layer where the attention compute happens.
* \param q_data The input Q data, in layout
* `(total_length, num_qo_heads, qk_nope_head_dim + qk_rope_head_dim)`
* \param k_data The input K data, in layout
* `(total_length, num_qo_heads, qk_nope_head_dim + qk_rope_head_dim)`
* \param v_data The input V data, in layout
* `(total_length, num_qo_heads, v_head_dim)`
* \param compressed_kv_data The compressed latent KV data, in layout
* `(total_length, num_kv_heads, kv_lora_rank)`
* \param k_pe_data The positional embedding part of K data, in layout
* `(total_length, num_kv_heads, qk_rope_head_dim)`, where `kv_lora_rank + qk_rope_head_dim`
* equals qk_head_dim
* \param o_data The output O data, in layout `(total_length, num_qo_heads, v_head_dim)`.
* \param attn_score_scaling_factor The additional attention scaling factor.
*/
virtual void MLANormal(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data,
NDArray compressed_kv_data, NDArray k_pe_data, NDArray o_data,
double attn_score_scaling_factor) = 0;
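
And a sketch for the normal (non-absorbed) path, again with assumed DeepSeek-V2-style widths (`qk_nope_head_dim = 128`, `v_head_dim = 128`); the latent tensors are passed alongside the expanded Q/K/V, which I read as letting the cache append the compressed form while attention runs on the expanded tensors, an interpretation rather than something stated in this diff.

// Hypothetical sketch for MLANormal; widths are assumed DeepSeek-V2-style values.
void SketchMLANormal(AttentionKVCacheObj* kv_cache) {
  DLDevice dev{kDLCUDA, 0};
  int64_t total_length = 16, num_qo_heads = 128, num_kv_heads = 1;
  int64_t qk_nope_head_dim = 128, qk_rope_head_dim = 64, v_head_dim = 128;
  int64_t kv_lora_rank = 512;
  NDArray q = NDArray::Empty({total_length, num_qo_heads, qk_nope_head_dim + qk_rope_head_dim},
                             DataType::Float(16), dev);
  NDArray k = NDArray::Empty({total_length, num_qo_heads, qk_nope_head_dim + qk_rope_head_dim},
                             DataType::Float(16), dev);
  NDArray v = NDArray::Empty({total_length, num_qo_heads, v_head_dim}, DataType::Float(16), dev);
  NDArray compressed_kv =
      NDArray::Empty({total_length, num_kv_heads, kv_lora_rank}, DataType::Float(16), dev);
  NDArray k_pe =
      NDArray::Empty({total_length, num_kv_heads, qk_rope_head_dim}, DataType::Float(16), dev);
  NDArray o = NDArray::Empty({total_length, num_qo_heads, v_head_dim}, DataType::Float(16), dev);
  kv_cache->MLANormal(/*layer_id=*/0, q, k, v, compressed_kv, k_pe, o,
                      /*attn_score_scaling_factor=*/1.0);
}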

/*!
* \brief Compute linear attention with Q/K/V data.
* \param layer_id The model layer where the attention compute happens.
* \param q_data The input Q data, in layout `(total_length, num_qo_heads, head_dim)`.
* \param k_data The input K data, in layout `(total_length, num_kv_heads, head_dim)`.
* \param v_data The input V data, in layout `(total_length, num_kv_heads, head_dim)`.
* \param o_data The output O data, in layout `(total_length, num_qo_heads, head_dim)`.
* \param attn_score_scaling_factor The additional attention scaling factor.
* \sa AttentionKVCache::Attention
*/
virtual void LinearAttention(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data,
double attn_score_scaling_factor) = 0;
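
For completeness, a call sketch against the linear-attention declaration as written; note that, unlike the entry points above, the declared signature takes no output tensor even though the comment lists `o_data`, so how results are surfaced is left to the concrete implementation. Shapes remain illustrative and assume the same setup as the earlier sketches.

// Hypothetical sketch for LinearAttention; the declaration above takes no output tensor,
// so result handling is left to the concrete implementation.
void SketchLinearAttention(AttentionKVCacheObj* kv_cache) {
  DLDevice dev{kDLCUDA, 0};
  int64_t total_length = 16, num_qo_heads = 32, num_kv_heads = 8, head_dim = 128;
  NDArray q = NDArray::Empty({total_length, num_qo_heads, head_dim}, DataType::Float(16), dev);
  NDArray k = NDArray::Empty({total_length, num_kv_heads, head_dim}, DataType::Float(16), dev);
  NDArray v = NDArray::Empty({total_length, num_kv_heads, head_dim}, DataType::Float(16), dev);
  kv_cache->LinearAttention(/*layer_id=*/0, q, k, v, /*attn_score_scaling_factor=*/1.0);
}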

/************** Positions **************/

/*!
