From 3eb5ad6711bc15a336fa2ffe676abaa2b0544a93 Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Wed, 5 Feb 2025 13:48:00 -0500
Subject: [PATCH] [KVCache] TIR attention kernel support for MLA (#17618)

This PR introduces the MLA attention kernels written in TIR. It also
implements the KV cache MLA computation logic. A new unit test file is
added to ensure the correctness of the TIR kernels. This PR also fixes
the tile size initialization of a few TIR prefill kernels.
---
 python/tvm/relax/frontend/nn/llm/kv_cache.py  | 1895 ++++++++++++-----
 python/tvm/relax/frontend/nn/llm/tree_attn.py |   24 +-
 src/runtime/relax_vm/kv_state.cc              |    9 +
 src/runtime/relax_vm/kv_state.h               |   24 +-
 src/runtime/relax_vm/paged_kv_cache.cc        |  226 +-
 ...tin_paged_attention_kv_cache_flashinfer.py |  229 +-
 ...uiltin_paged_attention_kv_cache_mla_tir.py |  456 ++++
 ...me_builtin_paged_attention_kv_cache_tir.py |   30 +-
 8 files changed, 2024 insertions(+), 869 deletions(-)
 create mode 100644 tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py

diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py
index 844b237381a0..f5ff0105d0f2 100644
--- a/python/tvm/relax/frontend/nn/llm/kv_cache.py
+++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py
@@ -77,6 +77,16 @@ def check_thread_limits(target: Target, bdx: int, bdy: int, bdz: int, gdz: int):
         assert gdz == 1, f"webgpu's blockIdx.z should be 1, but got gdz={gdz}"
 
 
+class AttnKind(enum.IntEnum):
+    """The attention kind class.
+    MHA denotes multi-head attention, multi-query attention or grouped query attention.
+    MLA denotes multi-head latent attention.
+    """
+
+    MHA = 0
+    MLA = 1
+
+
 class RopeMode(enum.IntEnum):
     """The RoPE mode of the Paged KV cache.
     If it is none, the KV cache will not apply RoPE to q and k.
@@ -129,6 +139,47 @@ def attention_with_fused_qkv(
             )
         ).reshape(b, s, num_qo_heads, d)
 
+    def mla_absorbed(
+        self,
+        layer_id: int,
+        q: Tensor,
+        compressed_kv: Tensor,
+        k_pe: Tensor,
+        attn_score_scaling_factor: float = 1.0,
+    ) -> Tensor:
+        """Compute multi-head latent attention with the given data
+        on the specified layer with the weight absorption optimization.
+
+        - For prefill, the input q/kv and output tensor have shape
+        (1, total_seq_len) for the first two dimensions.
+        - For decode, the input q/kv and output tensor have shape
+        (batch_size, 1) for the first two dimensions.
+        """
+        # pylint: disable=protected-access
+        b, s, h_qo, d_qk = q._expr.struct_info.shape
+        kv_lora_rank = compressed_kv._expr.struct_info.shape[3]
+        qk_rope_head_dim = k_pe._expr.struct_info.shape[3]
+        q = q.reshape(b * s, h_qo, d_qk)
+        compressed_kv = compressed_kv.reshape(b * s, kv_lora_rank)
+        k_pe = k_pe.reshape(b * s, qk_rope_head_dim)
+
+        return Tensor(
+            _expr=rx.BlockBuilder.current().emit(
+                rx.call_dps_packed(
+                    "vm.builtin.attention_kv_cache_mla_absorbed",
+                    [
+                        self._expr,
+                        rx.PrimValue(layer_id),  # type: ignore[arg-type]
+                        rx.PrimValue(attn_score_scaling_factor),
+                        q._expr,
+                        compressed_kv._expr,
+                        k_pe._expr,
+                    ],
+                    out_sinfo=rx.TensorStructInfo((b * s, h_qo, kv_lora_rank), q.dtype),
+                )
+            )
+        ).reshape(b, s, h_qo, kv_lora_rank)
+
     def get_query_positions(self, total_length: tir.PrimExpr) -> Tensor:
         """Get the in-sequence positions of each slot in the query,
         which are needed for applying positional embeddings in some models.
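The sketch below is an illustrative usage note rather than part of the patch: it shows how a model-side attention layer could call the new mla_absorbed method, assuming the caller has already produced the absorbed queries, the compressed KV, and the RoPE key component in the layouts described in the docstring above. The names kv_cache_mla_sketch and kv_cache, as well as the shape comments, are assumptions for illustration only.

from tvm.relax.frontend.nn import Tensor
from tvm.relax.frontend.nn.llm.kv_cache import PagedKVCache


def kv_cache_mla_sketch(
    kv_cache: PagedKVCache,
    layer_id: int,
    q: Tensor,              # (b, s, h_qo, kv_lora_rank + qk_rope_head_dim)
    compressed_kv: Tensor,  # (b, s, 1, kv_lora_rank)
    k_pe: Tensor,           # (b, s, 1, qk_rope_head_dim)
) -> Tensor:
    # The cache appends (compressed_kv, k_pe) for this layer and runs the
    # absorbed-MLA attention kernels that the rest of this patch registers.
    out_latent = kv_cache.mla_absorbed(layer_id, q, compressed_kv, k_pe)
    # The result stays in the latent space, shape (b, s, h_qo, kv_lora_rank);
    # the caller is expected to project it back to v_head_dim afterwards.
    return out_latent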
@@ -385,224 +436,51 @@ def __init__( # pylint: disable=too-many-locals ] if str(target.kind) == "llvm": + # pylint: disable=line-too-long + # fmt: off args.extend( [ - bb.add_func( - _attention_prefill_cpu( - num_key_value_heads, - num_attention_heads, - head_dim, - dtype, - False, - rope_scaling, - ), - "tir_attention_prefill_cpu", - ), - bb.add_func( - _attention_decode_cpu( - num_key_value_heads, - num_attention_heads, - head_dim, - dtype, - False, - rope_scaling, - ), - "tir_attention_decode_cpu", - ), - bb.add_func( - _attention_prefill_cpu( - num_key_value_heads, - num_attention_heads, - head_dim, - dtype, - True, - rope_scaling, - ), - "tir_attention_prefill_cpu_sliding_window", - ), - bb.add_func( - _attention_decode_cpu( - num_key_value_heads, - num_attention_heads, - head_dim, - dtype, - True, - rope_scaling, - ), - "tir_attention_decode_cpu_sliding_window", - ), - bb.add_func( - _attention_prefill_ragged_cpu( - num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling - ), - "tir_attention_prefill_ragged_cpu", - ), - bb.add_func( - _merge_state_inplace_cpu(dtype), - "tir_attention_merge_state_cpu", - ), - bb.add_func( - llama_rope_with_position_map( - rope_theta, - rope_scale, - head_dim, - num_attention_heads, - num_key_value_heads, - dtype, - rope_scaling, - rotary_dim, - ), - "tir_split_rotary", - ), - bb.add_func( - _copy_single_page_cpu(num_key_value_heads, page_size, head_dim, dtype), - "kv_cache_copy_single_page_cpu", - ), - bb.add_func( - _kv_cache_debug_get_kv( - num_hidden_layers, num_key_value_heads, head_dim, dtype - ), - "kv_cache_debug_get_kv", - ), - bb.add_func( - _compact_kv_copy_cpu(num_key_value_heads, head_dim, dtype), - "kv_cache_compact_kv_copy_cpu", - ), - bb.add_func( - tree_attn_cpu( - num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling - ), - "tir_attention_prefill_with_tree_mask_cpu", - ), - bb.add_func( - tree_attn_with_paged_kv_cache_cpu( - num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling - ), - "tir_attention_prefill_with_tree_mask_with_paged_kv_cache_cpu", - ), + bb.add_func(_attention_prefill_cpu(num_key_value_heads, num_attention_heads, head_dim, dtype, False, rope_scaling), "tir_attention_prefill_cpu"), + bb.add_func(_attention_decode_cpu(num_key_value_heads, num_attention_heads, head_dim, dtype, False, rope_scaling), "tir_attention_decode_cpu"), + bb.add_func(_attention_prefill_cpu(num_key_value_heads, num_attention_heads, head_dim, dtype, True, rope_scaling), "tir_attention_prefill_cpu_sliding_window"), + bb.add_func(_attention_decode_cpu(num_key_value_heads, num_attention_heads, head_dim, dtype, True, rope_scaling), "tir_attention_decode_cpu_sliding_window"), + bb.add_func(_attention_prefill_ragged_cpu(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling), "tir_attention_prefill_ragged_cpu"), + bb.add_func(_merge_state_inplace_cpu(dtype), "tir_attention_merge_state_cpu"), + bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, rope_scaling, rotary_dim), "tir_split_rotary"), + bb.add_func(_copy_single_page_cpu(num_key_value_heads, page_size, head_dim, dtype), "kv_cache_copy_single_page_cpu"), + bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"), + bb.add_func(_compact_kv_copy_cpu(num_key_value_heads, head_dim, dtype), "kv_cache_compact_kv_copy_cpu"), + bb.add_func(tree_attn_cpu(num_key_value_heads, num_attention_heads, head_dim, 
dtype, rope_scaling), "tir_attention_prefill_with_tree_mask_cpu"), + bb.add_func(tree_attn_with_paged_kv_cache_cpu(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache_cpu"), rope_ext_factors, rx.PrimValue(enable_disaggregation), ] ) + # fmt: on + # pylint: enable=line-too-long else: + # pylint: disable=line-too-long + # fmt: off args.extend( [ - bb.add_func( - _attention_prefill( - num_key_value_heads, - num_attention_heads, - head_dim, - dtype, - False, - rope_scaling, - target, - ), - "tir_attention_prefill", - ), - bb.add_func( - _attention_decode( - num_key_value_heads, - num_attention_heads, - head_dim, - dtype, - False, - rope_scaling, - target, - ), - "tir_attention_decode", - ), - bb.add_func( - _attention_prefill( - num_key_value_heads, - num_attention_heads, - head_dim, - dtype, - True, - rope_scaling, - target, - ), - "tir_attention_prefill_sliding_window", - ), - bb.add_func( - _attention_decode( - num_key_value_heads, - num_attention_heads, - head_dim, - dtype, - True, - rope_scaling, - target, - ), - "tir_attention_decode_sliding_window", - ), - bb.add_func( - _attention_prefill_ragged( - num_key_value_heads, - num_attention_heads, - head_dim, - dtype, - rope_scaling, - target, - ), - "tir_attention_prefill_ragged", - ), - bb.add_func( - _merge_state_inplace(num_attention_heads, head_dim, dtype, target), - "tir_attention_merge_state", - ), - bb.add_func( - llama_rope_with_position_map( - rope_theta, - rope_scale, - head_dim, - num_attention_heads, - num_key_value_heads, - dtype, - rope_scaling, - rotary_dim, - ), - "tir_split_rotary", - ), - bb.add_func( - _copy_single_page(num_key_value_heads, page_size, head_dim, dtype, target), - "kv_cache_copy_single_page", - ), - bb.add_func( - _kv_cache_debug_get_kv( - num_hidden_layers, num_key_value_heads, head_dim, dtype - ), - "kv_cache_debug_get_kv", - ), - bb.add_func( - _compact_kv_copy(num_key_value_heads, head_dim, dtype, target), - "kv_cache_compact_kv_copy", - ), - bb.add_func( - tree_attn( - num_key_value_heads, - num_attention_heads, - head_dim, - dtype, - rope_scaling, - target, - ), - "tir_attention_prefill_with_tree_mask", - ), - bb.add_func( - tree_attn_with_paged_kv_cache( - num_key_value_heads, - num_attention_heads, - head_dim, - dtype, - rope_scaling, - target, - ), - "tir_attention_prefill_with_tree_mask_with_paged_kv_cache", - ), + bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, head_dim, dtype, False, rope_scaling, target), "tir_attention_prefill"), + bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, head_dim, dtype, False, rope_scaling, target), "tir_attention_decode"), + bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, head_dim, dtype, True, rope_scaling, target), "tir_attention_prefill_sliding_window"), + bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, head_dim, dtype, True, rope_scaling, target), "tir_attention_decode_sliding_window"), + bb.add_func(_attention_prefill_ragged(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_ragged"), + bb.add_func(_merge_state_inplace(num_attention_heads, head_dim, dtype, target), "tir_attention_merge_state"), + bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, rope_scaling, rotary_dim), "tir_split_rotary"), + bb.add_func(_copy_single_page(num_key_value_heads, page_size, 
head_dim, dtype, target), "kv_cache_copy_single_page"), + bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"), + bb.add_func(_compact_kv_copy(num_key_value_heads, head_dim, dtype, target), "kv_cache_compact_kv_copy"), + bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask"), + bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache"), rope_ext_factors, rx.PrimValue(enable_disaggregation), ] ) + # fmt: on + # pylint: enable=line-too-long super().__init__( _expr=rx.call_pure_packed( @@ -613,6 +491,129 @@ def __init__( # pylint: disable=too-many-locals _name=name, ) + @staticmethod + def create_mla_kv_cache( # pylint: disable=too-many-locals + max_batch_size: tir.Var, + max_total_seq_len: tir.Var, + prefill_chunk_size: tir.Var, + page_size: tir.Var, + support_sliding_window: tir.Var, + layer_partition: rx.ShapeExpr, + num_hidden_layers: int, + num_attention_heads: int, + num_key_value_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + kv_lora_rank: int, + enable_disaggregation: bool, + dtype: str, + target: Target, + name: str = "paged_kv_cache", + ) -> PagedKVCache: + """Create a paged KV cache object with TIR kernels with multi-head latent attention. + + Parameters + ---------- + max_batch_size : tir.Var + The maximum allowed batch size of the KV cache. + It is a symbolic variable whose concrete value is specified + at runtime. + max_total_seq_len : tir.Var + The maximum allowed total sequence length of the KV cache. + It is a symbolic variable whose concrete value is specified + at runtime. + prefill_chunk_size : tir.Var + The maximum total sequence length in a prefill. + It is a symbolic variable whose concrete value is specified + at runtime. + page_size : tir.Var + The size (a.k.a. number of tokens) of each page. + It is a symbolic variable whose concrete value is specified + at runtime. + support_sliding_window : tir.Var + 0 or 1, denoting whether the KV cache supports sliding window. + It is a symbolic variable whose concrete value is specified + at runtime. + layer_partition : rx.ShapeExpr + The KV cache layer partition for pipeline stages. + It is an indptr array, denoting the starting layer of each pipeline stage. + qk_nope_head_dim : int + The head dim size (RoPE excluded) for queries and keys in MLA. + qk_rope_head_dim : int + The head dim size (RoPE included) for queries and keys in MLA. + v_head_dim : int + The head dim size for values in MLA. + kv_lora_rank : int + The LoRA rank for keys and values in MLA. + enable_disaggregation : bool + Whether to enable disaggregation in the KV cache. + target : Target + The target to build the model to. 
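        For illustration only (the values below are assumptions in the style
        of a DeepSeek-V2-like model, not taken from this patch), the factory
        is typically invoked from inside a model's Relax build, for example:

            cache = PagedKVCache.create_mla_kv_cache(
                max_batch_size, max_total_seq_len, prefill_chunk_size,
                page_size, support_sliding_window,
                layer_partition=rx.ShapeExpr([0, num_hidden_layers]),
                num_hidden_layers=num_hidden_layers,
                num_attention_heads=128,
                num_key_value_heads=128,
                qk_nope_head_dim=128,
                qk_rope_head_dim=64,
                v_head_dim=128,
                kv_lora_rank=512,
                enable_disaggregation=False,
                dtype="float16",
                target=target,
            )

        With these illustrative numbers, each cached token stores
        kv_lora_rank + qk_rope_head_dim = 576 values shared across all query
        heads, instead of per-head keys and values.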
+ """ + + bb = rx.BlockBuilder.current() + args = [ + rx.ShapeExpr( + [ + max_batch_size, + max_total_seq_len, + prefill_chunk_size, + page_size, + support_sliding_window, + ] + ), + layer_partition, + rx.PrimValue(num_attention_heads), + rx.PrimValue(1), + rx.PrimValue(kv_lora_rank + qk_rope_head_dim), + rx.PrimValue(kv_lora_rank), + rx.PrimValue(qk_rope_head_dim), + rx.ShapeExpr([int(AttnKind.MLA) for _ in range(num_hidden_layers)]), + rx.PrimValue(RopeMode.NONE), + rx.PrimValue(1), + rx.PrimValue(10000), + rx.op.zeros((), dtype), + # pylint: disable=line-too-long + # fmt: off + bb.add_func(_kv_cache_transpose_append(num_key_value_heads, v_head_dim, dtype), "kv_cache_transpose_append"), + bb.add_func(_kv_cache_transpose_append_mla(kv_lora_rank, qk_rope_head_dim, dtype), "kv_cache_transpose_append_mla"), + bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, qk_nope_head_dim, dtype, False, {}, target), "tir_attention_prefill"), + bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, qk_nope_head_dim, dtype, False, {}, target), "tir_attention_decode"), + bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, qk_nope_head_dim, dtype, True, {}, target), "tir_attention_prefill_sliding_window"), + bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, qk_nope_head_dim, dtype, True, {}, target), "tir_attention_decode_sliding_window"), + bb.add_func(_attention_prefill_ragged(num_key_value_heads, num_attention_heads, qk_nope_head_dim, dtype, {}, target), "tir_attention_prefill_ragged"), + rx.PrimValue(0), + rx.PrimValue(0), + rx.PrimValue(0), + rx.PrimValue(0), + rx.PrimValue(0), + rx.PrimValue(0), + bb.add_func(_attention_prefill_mla(num_attention_heads, kv_lora_rank, qk_rope_head_dim, dtype, False, target), "tir_attention_prefill_mla"), + bb.add_func(_attention_decode_mla(num_attention_heads, kv_lora_rank, qk_rope_head_dim, dtype, False, target), "tir_attention_decode_mla"), + bb.add_func(_attention_prefill_ragged(num_key_value_heads, num_attention_heads, v_head_dim, dtype, {}, target), "tir_attention_prefill_ragged_mla_normal"), + bb.add_func(_attention_prefill_ragged_mla_absorbed(num_attention_heads, kv_lora_rank, qk_rope_head_dim, dtype, target), "tir_attention_prefill_ragged_mla_absorbed"), + bb.add_func(_merge_state_inplace(num_attention_heads, kv_lora_rank, dtype, target), "tir_attention_merge_state"), + bb.add_func(llama_rope_with_position_map(10000, 1, qk_rope_head_dim, num_attention_heads, num_key_value_heads, dtype, {}, None), "tir_split_rotary"), + bb.add_func(_copy_single_page_mla(page_size, kv_lora_rank + qk_rope_head_dim, dtype, target), "kv_cache_copy_single_page_mla"), + bb.add_func(_kv_cache_debug_get_kv_mla(num_hidden_layers, kv_lora_rank + qk_rope_head_dim, dtype), "kv_cache_debug_get_kv_mla"), + bb.add_func(_compact_kv_copy(num_key_value_heads, qk_nope_head_dim, dtype, target), "kv_cache_compact_kv_copy"), + bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, qk_nope_head_dim, dtype, {}, target), "tir_attention_prefill_with_tree_mask"), + bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, qk_nope_head_dim, dtype, {}, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache"), + rx.PrimValue(0), + rx.PrimValue(enable_disaggregation), + # fmt: on + # pylint: enable=line-too-long + ] + return PagedKVCache( + _expr=rx.call_pure_packed( + "vm.builtin.paged_attention_kv_cache_create_reduced_mla", + *args, + sinfo_args=rx.ObjectStructInfo(), + ), + 
_name=name, + ) + # mypy: disable-error-code="attr-defined,valid-type,no-redef" # pylint: disable=too-many-locals @@ -661,6 +662,43 @@ def tir_kv_cache_transpose_append( return tir_kv_cache_transpose_append +def _kv_cache_transpose_append_mla(kv_lora_rank: int, qk_rope_head_dim: int, dtype): + """Return the TIR function that appends new compressed KV data to PagedKVCache for MLA.""" + + # pylint: disable=line-too-long + # fmt: off + @T.prim_func + def tir_kv_cache_transpose_append_mla( + var_pages: T.handle, + var_compressed_kv_data: T.handle, + var_k_pe_data: T.handle, + var_position_map: T.handle, + ): + T.func_attr({"tir.noalias": T.bool(True)}) + ntoken = T.SizeVar("num_tokens_excluding_cache", "int64") + num_pages = T.int64() + pages_elem_offset = T.int64() + position_map_elem_offset = T.int32() + pages = T.match_buffer(var_pages, (num_pages, 16, kv_lora_rank + qk_rope_head_dim), dtype, elem_offset=pages_elem_offset) + compressed_kv_data = T.match_buffer(var_compressed_kv_data, (ntoken, kv_lora_rank), dtype) + k_pe_data = T.match_buffer(var_k_pe_data, (ntoken, qk_rope_head_dim), dtype) + position_map = T.match_buffer( + var_position_map, (ntoken,), "int32", elem_offset=position_map_elem_offset + ) + for global_pos, f in T.grid(ntoken, kv_lora_rank + qk_rope_head_dim): + if position_map[global_pos] != T.int32(-1): + with T.block("k_transpose_append"): + vgpos, vf = T.axis.remap("SS", [global_pos, f]) + T.reads(position_map[vgpos], compressed_kv_data[vgpos, vf], k_pe_data[vgpos, vf - kv_lora_rank]) + T.writes(pages[position_map[vgpos] // 16, position_map[vgpos] % 16, vf]) + position: T.int32 = position_map[vgpos] # type: ignore + pages[T.floordiv(position, 16), T.floormod(position, 16), vf] = T.if_then_else(vf < kv_lora_rank, compressed_kv_data[vgpos, vf], k_pe_data[vgpos, vf - kv_lora_rank]) + # fmt: on + # pylint: enable=line-too-long + + return tir_kv_cache_transpose_append_mla + + def _kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype): """Return the TIR function that fetches the k/v data on given positions and layer.""" @@ -700,6 +738,42 @@ def tir_kv_cache_debug_get_kv( return tir_kv_cache_debug_get_kv +def _kv_cache_debug_get_kv_mla(num_hidden_layers, d_qk, dtype): + """Return the TIR function that fetches the k/v data on given positions and layer.""" + + # pylint: disable=line-too-long + # fmt: off + @T.prim_func + def tir_kv_cache_debug_get_kv_mla( + var_pages: T.handle, + var_position_map: T.handle, + var_compressed_kv_with_k_pe_data: T.handle, + layer_id: T.int64, + ): + T.func_attr({"tir.noalias": T.bool(True)}) + seqlen = T.SizeVar("num_tokens_including_cache", "int64") + page_size = T.SizeVar("page_size", "int64") + num_pages = T.int64() + pages_elem_offset = T.int64() + position_map_elem_offset = T.int64() + pages = T.match_buffer(var_pages, (num_pages, page_size, d_qk), dtype, elem_offset=pages_elem_offset) + position_map = T.match_buffer( + var_position_map, (seqlen,), "int32", elem_offset=position_map_elem_offset + ) + compressed_kv_with_k_pe_data = T.match_buffer(var_compressed_kv_with_k_pe_data, (num_hidden_layers, seqlen, d_qk), dtype) + for p, d in T.grid(seqlen, d_qk): + with T.block("copy0"): + vp, vd = T.axis.remap("SS", [p, d]) + T.reads(position_map[vp], pages[position_map[vp] // page_size, position_map[vp] % page_size, vd]) + T.writes(compressed_kv_with_k_pe_data[layer_id, vp, vd]) + position: T.int32 = position_map[vp] # type: ignore[name-defined] + compressed_kv_with_k_pe_data[layer_id, vp, vd] = pages[T.floordiv(position, 
page_size), T.floormod(position, page_size), vd] + # fmt: on + # pylint: enable=line-too-long + + return tir_kv_cache_debug_get_kv_mla + + def _rope( buffer: T.Buffer, offset: tir.Var, @@ -924,9 +998,7 @@ def batch_prefill_paged_kv_cpu( return batch_prefill_paged_kv_cpu -def _attention_prefill( - h_kv, h_q, d, dtype, sliding_window: bool, rope_scaling: Dict[str, Any], target: Target -): +def _get_prefill_kernel_config(h_kv, h_q, d, dtype, target: Target): NUM_BLKS = 16 LOAD_VEC = 8 // ((DataType(dtype).bits + 7) // 8) # 8 bytes group_size = h_q // h_kv @@ -934,7 +1006,17 @@ def _attention_prefill( bdx = 32 num_warps = 4 - tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16 + tile_x, tile_y, tile_z = ( + 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), + d, + 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), + ) + original_tile_y = tile_y + original_tile_z = tile_z + while (tile_x * tile_z) % (bdx * num_warps) != 0: + tile_z += original_tile_z + while (tile_x * tile_y) % (bdx * num_warps) != 0: + tile_y += original_tile_y # Otherwise we would exceed maxComputeWorkgroupStorageSize if ( @@ -943,51 +1025,182 @@ def _attention_prefill( ): tile_z = 8 num_warps = 2 + if target.kind.name == "opencl" and ( + ("android" in str(target.host)) or ("adreno" in str(target.attrs)) + ): + LOAD_VEC = 16 // ((DataType(dtype).bits + 7) // 8) # 16 bytes + NUM_BLKS = group_size * 8 + check_thread_limits(target, bdx=bdx, bdy=num_warps, bdz=1, gdz=1) - global_symbol = "batch_prefill_paged_kv" - if sliding_window: - global_symbol += "_sliding_window" + return NUM_BLKS, LOAD_VEC, group_size, sm_scale, bdx, num_warps, tile_x, tile_y, tile_z - # pylint: disable=line-too-long,too-many-branches - # fmt: off - @T.prim_func - def batch_prefill_paged_kv( - _0: T.int32, # pylint: disable=unused-argument - var_q: T.handle, # [total_len, h_q, d] - var_q_indptr: T.handle, # [batch_size + 1] - var_pages: T.handle, # [max_num_pages, 2, h_kv, page_size, d] - var_page_indptr: T.handle, # [batch_size + 1] - var_page_values: T.handle, # [nnz_pages] - var_length_info: T.handle, # [b] when sliding window = False, or otherwise [3, b] - var_k_rope_pos_offset: T.handle, # [b] - var_q_rope_position: T.handle, # [total_len] - var_output: T.handle, # [total_len, h_q, d] - var_lse: T.handle, # [total_len, h_q] - causal: T.int32, - rotary_mode: T.int32, - rope_scale: T.float32, - rope_theta: T.float32, - attn_score_scaling_factor: T.float32, - ): - T.func_attr({"global_symbol": global_symbol}) - batch_size = T.int32(is_size_var=True) - total_len = T.int32(is_size_var=True) - nnz_pages = T.int32(is_size_var=True) - max_num_pages = T.int32(is_size_var=True) - pages_elem_offset = T.int64(is_size_var=True) - q_indptr_elem_offset = T.int32(is_size_var=True) - page_indptr_elem_offset = T.int32(is_size_var=True) - page_values_elem_offset = T.int32(is_size_var=True) - k_rope_pos_offset_elem_offset = T.int32(is_size_var=True) - q_rope_position_elem_offset = T.int32(is_size_var=True) - length_info_elem_offset = T.int32(is_size_var=True) - q = T.match_buffer(var_q, (total_len, h_q, d), dtype) - q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset) - pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, 16, d), dtype, elem_offset=pages_elem_offset) - page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", elem_offset=page_indptr_elem_offset) - page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", 
elem_offset=page_values_elem_offset) +def _schedule_prefill_kernel( + sch: tir.Schedule, + load_vec, + bdx, + num_warps, + tile_x, + tile_y, + tile_z, + transform_k_load: bool, + merged_qk_load: bool, +) -> tir.Schedule: + get_extent = lambda *lps: [int(sch.get(lp).extent) for lp in lps] + + def get_vecsize(extent): + return min(load_vec, (extent & ~(extent - 1))) + + def getxy_vecsize(x, y, t): + assert (x * y) % t == 0 + return min(get_vecsize(y), get_vecsize(x * y // t)) + + def get_tile_size(x, y, t): + cnt = (x * y) // t + assert (x * y) % t == 0 + tile_y = (int)(math.ceil(math.sqrt(cnt))) + while (cnt % tile_y != 0 or y % tile_y != 0 or x % (cnt // tile_y) != 0) and tile_y <= cnt: + tile_y += 1 + assert tile_y <= cnt + tile_x = cnt // tile_y + return tile_x, tile_y + + def apply_to_qkv_load(sch: tir.Schedule, block): + loop_x, loop_y = sch.get_loops(block)[-2:] + x_extent, y_extent = get_extent(loop_x, loop_y) + vec_size = getxy_vecsize(x_extent, y_extent, bdx * num_warps) + yo, yv = sch.split(loop_y, [None, vec_size]) + yo_extent = y_extent // vec_size + tile_x, tile_y = get_tile_size(x_extent, yo_extent, (bdx * num_warps)) + xo, xi = sch.split(loop_x, [tile_x, None]) + yo, yi = sch.split(yo, [tile_y, None]) + sch.reorder(xi, yi, xo, yo) + t = sch.fuse(xi, yi) + ty, tx = sch.split(t, [num_warps, bdx]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + sch.vectorize(yv) + + def apply_to_so_ewise(sch: tir.Schedule, block, tile): + loop_x, loop_y = sch.get_loops(block)[-2:] + xo, xi = sch.split(loop_x, factors=[None, tile[0]]) + yo, yi = sch.split(loop_y, factors=[None, tile[1]]) + sch.reorder(xo, yo, xi, yi) + yiv_extent = get_vecsize(tile[1]) + yio, yiv = sch.split(yi, [None, yiv_extent]) + sch.unroll(yio) + sch.vectorize(yiv) + t = sch.fuse(xo, yo) + ty, tx = sch.split(t, factors=[None, bdx]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + def apply_to_gemm(sch: tir.Schedule, block, tile, r_len=16, k_major=False): + loop_x, loop_y, loop_z = sch.get_loops(block)[-3:] + xo, xi = sch.split(loop_x, factors=[None, tile[0]]) + yo, yi = sch.split(loop_y, factors=[None, tile[1]]) + sch.reorder(xo, yo, xi, yi) + t = sch.fuse(xo, yo) + ty, tx = sch.split(t, factors=[None, bdx]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + ko, ki = sch.split(loop_z, factors=[None, r_len]) + if k_major: + sch.reorder(ko, xi, yi, ki) + else: + sch.reorder(ko, ki, xi, yi) + yiv_extent = get_vecsize(tile[1]) + yio, yiv = sch.split(yi, [None, yiv_extent]) + sch.unroll(yio) + sch.vectorize(yiv) + sch.unroll(xi) + sch.decompose_reduction(block, ty) + + def apply_to_md(sch, block): + loop = sch.get_loops(block)[-1] + _, ty, tx = sch.split(loop, factors=[None, num_warps, bdx]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + if transform_k_load and not merged_qk_load: + sch.transform_layout("K_load", ("write", 0), lambda i, j: (j, i)) + tile_s = get_tile_size(tile_x, tile_z, bdx * num_warps) + tile_o = get_tile_size(tile_x, tile_y, bdx * num_warps) + apply_to_gemm(sch, sch.get_block("S_gemm"), tile_s, k_major=True) + apply_to_gemm(sch, sch.get_block("O_gemm"), tile_o, k_major=False) + apply_to_so_ewise(sch, sch.get_block("S_store"), tile_s) + apply_to_so_ewise(sch, sch.get_block("O_init"), tile_o) + apply_to_so_ewise(sch, sch.get_block("O_store"), tile_o) + apply_to_qkv_load(sch, sch.get_block("Q_load")) + if not merged_qk_load: + apply_to_qkv_load(sch, sch.get_block("K_load")) + apply_to_qkv_load(sch, sch.get_block("V_load")) + else: + apply_to_qkv_load(sch, 
sch.get_block("KV_load")) + apply_to_md(sch, sch.get_block("lse_store")) + return sch + + +def _attention_prefill( + h_kv, h_q, d, dtype, sliding_window: bool, rope_scaling: Dict[str, Any], target: Target +): + ( + NUM_BLKS, + LOAD_VEC, + group_size, + sm_scale, + bdx, + num_warps, + tile_x, + tile_y, + tile_z, + ) = _get_prefill_kernel_config(h_kv, h_q, d, dtype, target) + + global_symbol = "batch_prefill_paged_kv" + if sliding_window: + global_symbol += "_sliding_window" + + # pylint: disable=line-too-long,too-many-branches + # fmt: off + @T.prim_func + def batch_prefill_paged_kv( + _0: T.int32, # pylint: disable=unused-argument + var_q: T.handle, # [total_len, h_q, d] + var_q_indptr: T.handle, # [batch_size + 1] + var_pages: T.handle, # [max_num_pages, 2, h_kv, page_size, d] + var_page_indptr: T.handle, # [batch_size + 1] + var_page_values: T.handle, # [nnz_pages] + var_length_info: T.handle, # [b] when sliding window = False, or otherwise [3, b] + var_k_rope_pos_offset: T.handle, # [b] + var_q_rope_position: T.handle, # [total_len] + var_output: T.handle, # [total_len, h_q, d] + var_lse: T.handle, # [total_len, h_q] + causal: T.int32, + rotary_mode: T.int32, + rope_scale: T.float32, + rope_theta: T.float32, + attn_score_scaling_factor: T.float32, + ): + T.func_attr({"global_symbol": global_symbol}) + batch_size = T.int32(is_size_var=True) + total_len = T.int32(is_size_var=True) + nnz_pages = T.int32(is_size_var=True) + max_num_pages = T.int32(is_size_var=True) + pages_elem_offset = T.int64(is_size_var=True) + q_indptr_elem_offset = T.int32(is_size_var=True) + page_indptr_elem_offset = T.int32(is_size_var=True) + page_values_elem_offset = T.int32(is_size_var=True) + k_rope_pos_offset_elem_offset = T.int32(is_size_var=True) + q_rope_position_elem_offset = T.int32(is_size_var=True) + length_info_elem_offset = T.int32(is_size_var=True) + + q = T.match_buffer(var_q, (total_len, h_q, d), dtype) + q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset) + pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, 16, d), dtype, elem_offset=pages_elem_offset) + page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", elem_offset=page_indptr_elem_offset) + page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", elem_offset=page_values_elem_offset) k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset) q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", elem_offset=q_rope_position_elem_offset) output = T.match_buffer(var_output, (total_len, h_q, d), dtype) @@ -1221,73 +1434,9 @@ def batch_prefill_paged_kv( # fmt: on # pylint: enable=line-too-long,too-many-branches sch = tir.Schedule(batch_prefill_paged_kv) - - def get_tile_size(x, y, t): - cnt = (x * y) // t - assert (x * y) % t == 0 - tile_y = (int)(math.ceil(math.sqrt(cnt))) - while (cnt % tile_y != 0 or y % tile_y != 0 or x % (cnt // tile_y) != 0) and tile_y <= cnt: - tile_y += 1 - assert tile_y <= cnt - tile_x = cnt // tile_y - return tile_x, tile_y - - def apply_to_qkv_load(sch: tir.Schedule, block): - loop_x, loop_y = sch.get_loops(block)[-2:] - loop = sch.fuse(loop_x, loop_y) - _, ty, tx, vec = sch.split( - loop, factors=[None, num_warps, bdx, LOAD_VEC], preserve_unit_iters=True - ) - sch.bind(ty, "threadIdx.y") - sch.bind(tx, "threadIdx.x") - sch.vectorize(vec) - - def apply_to_so_ewise(sch: tir.Schedule, block, tile): - loop_x, loop_y = sch.get_loops(block)[-2:] - xo, 
xi = sch.split(loop_x, factors=[None, tile[0]]) - yo, yi = sch.split(loop_y, factors=[None, tile[1]]) - sch.reorder(xo, yo, xi, yi) - t = sch.fuse(xo, yo) - ty, tx = sch.split(t, factors=[None, bdx]) - sch.bind(ty, "threadIdx.y") - sch.bind(tx, "threadIdx.x") - - def apply_to_gemm( # pylint: disable=unused-argument - sch: tir.Schedule, block, tile, read_0, read_1, r_len=8, k_major=False - ): - loop_x, loop_y, loop_z = sch.get_loops(block)[-3:] - xo, xi = sch.split(loop_x, factors=[None, tile[0]]) - yo, yi = sch.split(loop_y, factors=[None, tile[1]]) - sch.reorder(xo, yo, xi, yi) - t = sch.fuse(xo, yo) - ty, tx = sch.split(t, factors=[None, bdx]) - sch.bind(ty, "threadIdx.y") - sch.bind(tx, "threadIdx.x") - - ko, ki = sch.split(loop_z, factors=[None, r_len]) - if k_major: - sch.reorder(ko, xi, yi, ki) - else: - sch.reorder(ko, ki, xi, yi) - sch.decompose_reduction(block, ty) - - def apply_to_md(sch, block): - loop = sch.get_loops(block)[-1] - _, ty, tx = sch.split(loop, factors=[None, num_warps, bdx]) - sch.bind(ty, "threadIdx.y") - sch.bind(tx, "threadIdx.x") - - tile_s = get_tile_size(tile_x, tile_z, bdx * num_warps) - tile_o = get_tile_size(tile_x, tile_y, bdx * num_warps) - apply_to_gemm(sch, sch.get_block("S_gemm"), tile_s, 0, 1, k_major=True) - apply_to_gemm(sch, sch.get_block("O_gemm"), tile_o, 2, 3, k_major=False) - apply_to_so_ewise(sch, sch.get_block("S_store"), tile_s) - apply_to_so_ewise(sch, sch.get_block("O_init"), tile_o) - apply_to_so_ewise(sch, sch.get_block("O_store"), tile_o) - apply_to_qkv_load(sch, sch.get_block("Q_load")) - apply_to_qkv_load(sch, sch.get_block("K_load")) - apply_to_qkv_load(sch, sch.get_block("V_load")) - apply_to_md(sch, sch.get_block("lse_store")) + sch = _schedule_prefill_kernel( + sch, LOAD_VEC, bdx, num_warps, tile_x, tile_y, tile_z, False, False + ) return sch.mod["main"].with_attr("tir.is_scheduled", 1) @@ -1850,21 +1999,17 @@ def merge_state_inplace( def _attention_sequence_prefill( h_kv, h_q, d, dtype, target: Target, causal=0, attn_score_scaling_factor=1.0 ): # pylint: disable=line-too-long - LOAD_VEC = 8 // ((DataType(dtype).bits + 7) // 8) # 8 bytes - group_size = h_q // h_kv - sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1)) - - bdx = 32 - num_warps = 4 - tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16 - - # Otherwise we would exceed maxComputeWorkgroupStorageSize - if ( - str(target.kind) == "webgpu" - and ((d + 127) // 128) * ((DataType(dtype).bits + 15) // 16) >= 4 - ): - tile_z = 8 - num_warps = 2 + ( + _, + LOAD_VEC, + group_size, + sm_scale, + bdx, + num_warps, + tile_x, + tile_y, + tile_z, + ) = _get_prefill_kernel_config(h_kv, h_q, d, dtype, target) # fmt: off @T.prim_func @@ -2099,99 +2244,32 @@ def batch_sequence_prefill_kv( # pylint: disable=too-many-branches # fmt: on # pylint: enable=line-too-long,too-many-branches sch = tir.Schedule(batch_sequence_prefill_kv) + sch = _schedule_prefill_kernel( + sch, LOAD_VEC, bdx, num_warps, tile_x, tile_y, tile_z, False, False + ) + return sch.mod["main"].with_attr("tir.is_scheduled", 1) - def get_tile_size(x, y, t): - cnt = (x * y) // t - assert (x * y) % t == 0 - tile_y = (int)(math.ceil(math.sqrt(cnt))) - while (cnt % tile_y != 0 or y % tile_y != 0 or x % (cnt // tile_y) != 0) and tile_y <= cnt: - tile_y += 1 - assert tile_y <= cnt - tile_x = cnt // tile_y - return tile_x, tile_y - - def apply_to_qkv_load(sch: tir.Schedule, block): - loop_x, loop_y = sch.get_loops(block)[-2:] - loop = sch.fuse(loop_x, loop_y) - _, ty, tx, vec = 
sch.split( - loop, factors=[None, num_warps, bdx, LOAD_VEC], preserve_unit_iters=True - ) - sch.bind(ty, "threadIdx.y") - sch.bind(tx, "threadIdx.x") - sch.vectorize(vec) - def apply_to_so_ewise(sch: tir.Schedule, block, tile): - loop_x, loop_y = sch.get_loops(block)[-2:] - xo, xi = sch.split(loop_x, factors=[None, tile[0]]) - yo, yi = sch.split(loop_y, factors=[None, tile[1]]) - sch.reorder(xo, yo, xi, yi) - t = sch.fuse(xo, yo) - ty, tx = sch.split(t, factors=[None, bdx]) - sch.bind(ty, "threadIdx.y") - sch.bind(tx, "threadIdx.x") +def _attention_prefill_ragged_cpu(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any]): + group_size = h_q // h_kv + sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1)) - def apply_to_gemm( # pylint: disable=unused-argument - sch: tir.Schedule, block, tile, read_0, read_1, r_len=8, k_major=False - ): - loop_x, loop_y, loop_z = sch.get_loops(block)[-3:] - xo, xi = sch.split(loop_x, factors=[None, tile[0]]) - yo, yi = sch.split(loop_y, factors=[None, tile[1]]) - sch.reorder(xo, yo, xi, yi) - t = sch.fuse(xo, yo) - ty, tx = sch.split(t, factors=[None, bdx]) - sch.bind(ty, "threadIdx.y") - sch.bind(tx, "threadIdx.x") - - ko, ki = sch.split(loop_z, factors=[None, r_len]) - if k_major: - sch.reorder(ko, xi, yi, ki) - else: - sch.reorder(ko, ki, xi, yi) - sch.decompose_reduction(block, ty) - - def apply_to_md(sch, block): - loop = sch.get_loops(block)[-1] - _, ty, tx = sch.split(loop, factors=[None, num_warps, bdx]) - sch.bind(ty, "threadIdx.y") - sch.bind(tx, "threadIdx.x") - - def apply_schedule(sch): - tile_s = get_tile_size(tile_x, tile_z, bdx * num_warps) - tile_o = get_tile_size(tile_x, tile_y, bdx * num_warps) - apply_to_gemm(sch, sch.get_block("S_gemm"), tile_s, 0, 1, k_major=True) - apply_to_gemm(sch, sch.get_block("O_gemm"), tile_o, 2, 3, k_major=False) - apply_to_so_ewise(sch, sch.get_block("S_store"), tile_s) - apply_to_so_ewise(sch, sch.get_block("O_init"), tile_o) - apply_to_so_ewise(sch, sch.get_block("O_store"), tile_o) - apply_to_qkv_load(sch, sch.get_block("Q_load")) - apply_to_qkv_load(sch, sch.get_block("K_load")) - apply_to_qkv_load(sch, sch.get_block("V_load")) - - apply_schedule(sch) - apply_to_md(sch, sch.get_block("lse_store")) - return sch.mod["main"].with_attr("tir.is_scheduled", 1) - - -def _attention_prefill_ragged_cpu(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any]): - group_size = h_q // h_kv - sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1)) - - @T.prim_func - def batch_prefill_ragged_kv( # pylint: disable=too-many-branches - var_q: T.handle, # [total_len, h_q, d] - var_q_indptr: T.handle, # [batch_size + 1] - var_k: T.handle, # [total_len, h_kv, d] - var_v: T.handle, # [total_len, h_kv, d] - var_kv_indptr: T.handle, # [batch_size + 1] - var_q_rope_position: T.handle, # [total_q_len] - var_k_rope_pos_offset: T.handle, # [b] - var_output: T.handle, # [total_len, h_q, d] - var_lse: T.handle, # [total_len, h_q] - causal: T.int32, - rotary_mode: T.int32, - rope_scale: T.float32, - rope_theta: T.float32, - attn_score_scaling_factor: T.float32, + @T.prim_func + def batch_prefill_ragged_kv( # pylint: disable=too-many-branches + var_q: T.handle, # [total_len, h_q, d] + var_q_indptr: T.handle, # [batch_size + 1] + var_k: T.handle, # [total_len, h_kv, d] + var_v: T.handle, # [total_len, h_kv, d] + var_kv_indptr: T.handle, # [batch_size + 1] + var_q_rope_position: T.handle, # [total_q_len] + var_k_rope_pos_offset: T.handle, # [b] + var_output: T.handle, # [total_len, h_q, d] + var_lse: T.handle, # [total_len, h_q] + causal: 
T.int32, + rotary_mode: T.int32, + rope_scale: T.float32, + rope_theta: T.float32, + attn_score_scaling_factor: T.float32, ): batch_size = T.int32(is_size_var=True) qo_len = T.int32(is_size_var=True) @@ -2343,38 +2421,17 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches def _attention_prefill_ragged(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any], target: Target): # pylint: disable=line-too-long - NUM_BLKS = 16 - LOAD_VEC = 8 // ((DataType(dtype).bits + 7) // 8) # 8 bytes - group_size = h_q // h_kv - sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1)) - - bdx = 32 - num_warps = 4 - tile_x, tile_y, tile_z = ( - 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), - d, - 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), - ) - original_tile_y = tile_y - original_tile_z = tile_z - while (tile_x * tile_z) % (bdx * num_warps) != 0: - tile_z += original_tile_z - while (tile_x * tile_y) % (bdx * num_warps) != 0: - tile_y += original_tile_y - - # Otherwise we would exceed maxComputeWorkgroupStorageSize - if ( - str(target.kind) == "webgpu" - and ((d + 127) // 128) * ((DataType(dtype).bits + 15) // 16) >= 4 - ): - tile_z = 8 - num_warps = 2 - - if target.kind.name == "opencl" and ( - ("android" in str(target.host)) or ("adreno" in str(target.attrs)) - ): - LOAD_VEC = 16 // ((DataType(dtype).bits + 7) // 8) # 16 bytes - NUM_BLKS = group_size * 8 + ( + NUM_BLKS, + LOAD_VEC, + group_size, + sm_scale, + bdx, + num_warps, + tile_x, + tile_y, + tile_z, + ) = _get_prefill_kernel_config(h_kv, h_q, d, dtype, target) # fmt: off @T.prim_func @@ -2618,117 +2675,779 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches # fmt: on # pylint: enable=line-too-long,too-many-branches sch = tir.Schedule(batch_prefill_ragged_kv) - get_extent = lambda *lps: [int(sch.get(lp).extent) for lp in lps] + sch = _schedule_prefill_kernel( + sch, LOAD_VEC, bdx, num_warps, tile_x, tile_y, tile_z, True, False + ) + return sch.mod["main"].with_attr("tir.is_scheduled", 1) - def get_vecsize(extent): - return min(LOAD_VEC, (extent & ~(extent - 1))) - def getxy_vecsize(x, y, t): - assert (x * y) % t == 0 - return min(get_vecsize(y), get_vecsize(x * y // t)) +def _attention_prefill_mla( + h_q, + d_latent, + d_rope, + dtype, + sliding_window: bool, + target: Target, +): + d_qk = d_latent + d_rope + ( + NUM_BLKS, + LOAD_VEC, + group_size, + _, + bdx, + num_warps, + tile_x, + tile_y, + tile_z, + ) = _get_prefill_kernel_config(1, h_q, d_qk, dtype, target) + + global_symbol = "batch_prefill_paged_kv_mla" + if sliding_window: + global_symbol += "_sliding_window" - def get_tile_size(x, y, t): - cnt = (x * y) // t - assert (x * y) % t == 0 - tile_y = (int)(math.ceil(math.sqrt(cnt))) - while (cnt % tile_y != 0 or y % tile_y != 0 or x % (cnt // tile_y) != 0) and tile_y <= cnt: - tile_y += 1 - assert tile_y <= cnt - tile_x = cnt // tile_y - return tile_x, tile_y + # pylint: disable=line-too-long,too-many-branches + # fmt: off + @T.prim_func + def batch_prefill_paged_kv_mla( + _0: T.int32, # pylint: disable=unused-argument + var_q: T.handle, # [total_len, h_q, d_qk] + var_q_indptr: T.handle, # [batch_size + 1] + var_pages: T.handle, # [max_num_pages, page_size, d_qk] + var_page_indptr: T.handle, # [batch_size + 1] + var_page_values: T.handle, # [nnz_pages] + var_length_info: T.handle, # [b] when sliding window = False, or otherwise [3, b] + var_output: T.handle, # [total_len, h_q, d_latent] + var_lse: T.handle, # [total_len, h_q] + causal: T.int32, + attn_score_scaling_factor: T.float32, + 
): + T.func_attr({"global_symbol": global_symbol}) + batch_size = T.int32(is_size_var=True) + total_len = T.int32(is_size_var=True) + nnz_pages = T.int32(is_size_var=True) + max_num_pages = T.int32(is_size_var=True) + pages_elem_offset = T.int64(is_size_var=True) + q_indptr_elem_offset = T.int32(is_size_var=True) + page_indptr_elem_offset = T.int32(is_size_var=True) + page_values_elem_offset = T.int32(is_size_var=True) + length_info_elem_offset = T.int32(is_size_var=True) - def apply_to_qkv_load(sch: tir.Schedule, block): - loop_x, loop_y = sch.get_loops(block)[-2:] - x_extent, y_extent = get_extent(loop_x, loop_y) - vec_size = getxy_vecsize(x_extent, y_extent, bdx * num_warps) - yo, yv = sch.split(loop_y, [None, vec_size]) - yo_extent = y_extent // vec_size - tile_x, tile_y = get_tile_size(x_extent, yo_extent, (bdx * num_warps)) - xo, xi = sch.split(loop_x, [tile_x, None]) - yo, yi = sch.split(yo, [tile_y, None]) - sch.reorder(xi, yi, xo, yo) - t = sch.fuse(xi, yi) - ty, tx = sch.split(t, [num_warps, bdx]) - sch.bind(ty, "threadIdx.y") - sch.bind(tx, "threadIdx.x") - sch.vectorize(yv) + q = T.match_buffer(var_q, (total_len, h_q, d_qk), dtype) + q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset) + pages = T.match_buffer(var_pages, (max_num_pages, 16, d_qk), dtype, elem_offset=pages_elem_offset) + page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", elem_offset=page_indptr_elem_offset) + page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", elem_offset=page_values_elem_offset) + output = T.match_buffer(var_output, (total_len, h_q, d_latent), dtype) + lse = T.match_buffer(var_lse, (total_len, h_q), "float32") # pylint: disable=unused-variable + # The length information of the sequences. + # - It is in shape `(3, batch_size)` when sliding window is enabled. + # For a sequence "i", location + # - "(0, i)" is the number of KV slots used in the last page of the seq ("last_page_len"), + # - "(1, i)" is the starting offset of the sliding window in the seq, + # - "(2, i)" is the attn sink length of the sequence. + # - It is in shape `(batch_size,)` when sliding window is disabled, + # denoting the "last_page_len". 
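        # For illustration (an assumed worked example, not text from the original
        # source): with sliding window disabled, 3 pages of page size 16, and
        # last_page_len = 5, the visible KV length computed by _get_kv_chunk_len
        # below is (3 - 1) * 16 + 5 = 37; with sliding window enabled it is further
        # adjusted by the window start and the attention-sink length recorded above.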
+ length_info = _declare_length_info(var_length_info, batch_size, sliding_window, length_info_elem_offset) - def apply_to_so_ewise(sch: tir.Schedule, block, tile): - loop_x, loop_y = sch.get_loops(block)[-2:] - xo, xi = sch.split(loop_x, factors=[None, tile[0]]) - yo, yi = sch.split(loop_y, factors=[None, tile[1]]) - sch.reorder(xo, yo, xi, yi) - yiv_extent = get_vecsize(tile[1]) - yio, yiv = sch.split(yi, [None, yiv_extent]) - sch.unroll(yio) - sch.vectorize(yiv) - t = sch.fuse(xo, yo) - ty, tx = sch.split(t, factors=[None, bdx]) - sch.bind(ty, "threadIdx.y") - sch.bind(tx, "threadIdx.x") + # kernel code + for lbx in T.thread_binding(NUM_BLKS, thread="blockIdx.x"): + for lty in T.thread_binding(num_warps, thread="threadIdx.y"): + for ltx in T.thread_binding(bdx, thread="threadIdx.x"): + with T.block("attn"): + bx, ty, tx = T.axis.remap("SSS", [lbx, lty, ltx]) + T.reads() + T.writes() + tile_id = _var("int32") + batch_idx = _var("int32") + batch_tiles = _var("int32") + batch_rows = _var("int32") + iterator = _var("int32") + kv_chunk_len = _var("int32") + + Q_smem = T.alloc_buffer((tile_x, d_qk), dtype, scope="shared") + KV_smem = T.alloc_buffer((tile_z, d_qk), dtype, scope="shared") + S_smem = T.alloc_buffer((tile_x, tile_z), "float32", scope="shared") + + S_local = T.alloc_buffer((tile_x, tile_z), "float32", scope="local") + O_local = T.alloc_buffer((tile_x, d_latent), "float32", scope="local") + + m_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") + m_prev_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") + d_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") + + m_new = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") + m_prev = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") + d_new = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") + + ## get tile_no, batch_idx, batch_tiles, batch_rows + tile_id[0] = bx + batch_idx[0] = 0 + batch_rows[0] = (q_indptr[1] - q_indptr[0]) * group_size + batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x) + while T.tvm_thread_invariant(batch_idx[0] < batch_size): + # advance to next tile + while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size: + tile_id[0] -= batch_tiles[0] + batch_idx[0] += 1 + if batch_idx[0] < batch_size: + b_idx: T.int32 = batch_idx[0] + batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * group_size + batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x) - def apply_to_gemm( # pylint: disable=unused-argument - sch: tir.Schedule, block, tile, read_0, read_1, r_len=16, k_major=False - ): - loop_x, loop_y, loop_z = sch.get_loops(block)[-3:] - xo, xi = sch.split(loop_x, factors=[None, tile[0]]) - yo, yi = sch.split(loop_y, factors=[None, tile[1]]) - sch.reorder(xo, yo, xi, yi) - t = sch.fuse(xo, yo) - ty, tx = sch.split(t, factors=[None, bdx]) - sch.bind(ty, "threadIdx.y") - sch.bind(tx, "threadIdx.x") + if T.tvm_thread_invariant(batch_idx[0] < batch_size): + b_idx: T.int32 = batch_idx[0] + LH_start: T.int32 = tile_id[0] * tile_x + q_indptr_val: T.int32 = q_indptr[b_idx] - ko, ki = sch.split(loop_z, factors=[None, r_len]) - if k_major: - sch.reorder(ko, xi, yi, ki) - else: - sch.reorder(ko, ki, xi, yi) - yiv_extent = get_vecsize(tile[1]) - yio, yiv = sch.split(yi, [None, yiv_extent]) - sch.unroll(yio) - sch.vectorize(yiv) - sch.unroll(xi) - sch.decompose_reduction(block, ty) + cur_page_indptr_begin: T.int32 = page_indptr[b_idx] + cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1] + 
kv_chunk_len[0] = T.if_then_else( + cur_page_indptr_begin != cur_page_indptr_end, + _get_kv_chunk_len(cur_page_indptr_end - cur_page_indptr_begin, 16, b_idx, length_info, sliding_window), + 0 + ) + T.tvm_storage_sync("shared") - def apply_to_md(sch, block): - loop = sch.get_loops(block)[-1] - _, ty, tx = sch.split(loop, factors=[None, num_warps, bdx]) - sch.bind(ty, "threadIdx.y") - sch.bind(tx, "threadIdx.x") + # init states + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + if row < tile_x: + m_smem[row] = -5e4 + d_smem[row] = 1.0 - sch.transform_layout("K_load", ("write", 0), lambda i, j: (j, i)) - tile_s = get_tile_size(tile_x, tile_z, bdx * num_warps) - tile_o = get_tile_size(tile_x, tile_y, bdx * num_warps) - apply_to_gemm(sch, sch.get_block("S_gemm"), tile_s, 0, 1, k_major=True) - apply_to_gemm(sch, sch.get_block("O_gemm"), tile_o, 2, 3, k_major=False) - apply_to_so_ewise(sch, sch.get_block("S_store"), tile_s) - apply_to_so_ewise(sch, sch.get_block("O_init"), tile_o) - apply_to_so_ewise(sch, sch.get_block("O_store"), tile_o) - apply_to_qkv_load(sch, sch.get_block("Q_load")) - apply_to_qkv_load(sch, sch.get_block("K_load")) - apply_to_qkv_load(sch, sch.get_block("V_load")) + for li, lj in T.grid(tile_x, d_latent): + with T.block("O_init"): + i, j = T.axis.remap("SS", [li, lj]) + O_local[i, j] = 0.0 + T.tvm_storage_sync("shared") - apply_to_md(sch, sch.get_block("lse_store")) + # Load Q from gmem to smem + for li, lj in T.grid(tile_x, tile_y): + with T.block("Q_load"): + i, j = T.axis.remap("SS", [li, lj]) + T.reads() + T.writes() + cur_L = q_indptr_val + (LH_start + i) // group_size + cur_H_qo = (LH_start + i) % group_size + if cur_L < q_indptr[b_idx + 1]: + Q_smem[i, j] = q[cur_L, cur_H_qo, j] + else: + Q_smem[i, j] = 0.0 + T.tvm_storage_sync("shared") + + for iterator in T.serial(T.ceildiv(kv_chunk_len[0], tile_z)): + L_kv_start: T.int32 = iterator * tile_z + for lz, ly in T.grid(tile_z, tile_y): + with T.block("KV_load"): + i, j = T.axis.remap("SS", [lz, ly]) + T.reads() + T.writes() + cur_L = L_kv_start + i + if cur_L < kv_chunk_len[0]: + seq_offset: T.int32(is_size_var=True) = _get_seq_offset(cur_L, b_idx, length_info, sliding_window) # type: ignore + page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + T.floordiv(seq_offset, 16)] # type: ignore + page_offset: T.int32(is_size_var=True) = T.floormod(seq_offset, 16) # type: ignore + KV_smem[i, j] = pages[page_no, page_offset, j] + else: + KV_smem[i, j] = 0.0 + T.tvm_storage_sync("shared") + + # Compute S + with T.block(): + for li, lj, lk in T.grid(tile_x, tile_z, tile_y): + with T.block("S_gemm"): + i, j, k = T.axis.remap("SSR", [li, lj, lk]) + with T.init(): + S_local[i, j] = 0.0 + S_local[i, j] += T.cast(Q_smem[i, k], "float32") * T.cast(KV_smem[j, k], "float32") * attn_score_scaling_factor * math.log2(math.exp(1)) + T.tvm_storage_sync("shared") + for li, lj in T.grid(tile_x, tile_z): + with T.block("S_store"): + i, j = T.axis.remap("SS", [li, lj]) + S_smem[i, j] = S_local[i, j] + T.tvm_storage_sync("shared") + + # Update S, m, d + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + if row < tile_x: + with T.block("update1"): + m_prev[i] = m_smem[row] + m_new[i] = m_smem[row] + # mask out of kv_chunk_len S + row_: T.int32 = (LH_start + row) // group_size + for j in T.serial(tile_z): + if _causal_mask(causal, + row=row_, + col=L_kv_start + j, + kv_len=kv_chunk_len[0], + qo_len=q_indptr[b_idx 
+ 1] - q_indptr[b_idx]): + m_new[i] = T.max(m_new[i], S_smem[row, j]) + d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i]) + + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + with T.block("update"): + for j in T.serial(tile_z): + # this is to avoid sync inside condition branch + if row < tile_x: + row_: T.int32 = (LH_start + row) // group_size + if _causal_mask(causal, + row=row_, + col=L_kv_start + j, + kv_len=kv_chunk_len[0], + qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]): + S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i]) + else: + S_smem[row, j] = T.exp2(-5e4 - m_new[i]) + + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + if row < tile_x: + with T.block("update"): + for j in T.serial(tile_z): + d_new[i] += S_smem[row, j] + m_smem[row] = m_new[i] + d_smem[row] = d_new[i] + m_prev_smem[row] = m_prev[i] + T.tvm_storage_sync("shared") + + # Update O + with T.block(): + for li, lj, lk in T.grid(tile_x, d_latent, tile_z): + with T.block("O_gemm"): + i, j, k = T.axis.remap("SSR", [li, lj, lk]) + with T.init(): + O_local[i, j] *= T.exp2(m_prev_smem[i] - m_smem[i]) + O_local[i, j] += S_smem[i, k] * T.cast(KV_smem[k, j], "float32") + + # Store O from smem to gmem + for li, lj in T.grid(tile_x, d_latent): + with T.block("O_store"): + i, j = T.axis.remap("SS", [li, lj]) + cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size + cur_H_qo: T.int32 = (LH_start + i) % group_size + if cur_L < q_indptr[b_idx + 1]: + output[cur_L, cur_H_qo, j] = O_local[i, j] / d_smem[i] + + # Store LSE to gmem + for li in T.grid(tile_x): + with T.block("lse_store"): + i = T.axis.remap("S", [li]) + cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size + cur_H_qo: T.int32 = (LH_start + i) % group_size + if cur_L < q_indptr[b_idx + 1]: + lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i]) + + # move to next tile + tile_id[0] += NUM_BLKS + # fmt: on + # pylint: enable=line-too-long,too-many-branches + sch = tir.Schedule(batch_prefill_paged_kv_mla) + sch = _schedule_prefill_kernel( + sch, LOAD_VEC, bdx, num_warps, tile_x, d_latent, tile_z, False, True + ) return sch.mod["main"].with_attr("tir.is_scheduled", 1) -def _copy_single_page_cpu(num_heads, page_size, head_dim, dtype): - tx = 1 +def _attention_prefill_ragged_mla_absorbed(h_q, d_latent, d_rope, dtype, target: Target): + d_qk = d_latent + d_rope + ( + NUM_BLKS, + LOAD_VEC, + group_size, + _, + bdx, + num_warps, + tile_x, + tile_y, + tile_z, + ) = _get_prefill_kernel_config(1, h_q, d_qk, dtype, target) + # pylint: disable=line-too-long,too-many-branches + # fmt: off @T.prim_func - def copy_single_page_cpu( - var_pages: T.handle, - src_page_id: T.int64, - tgt_page_id: T.int64, - copy_length: T.int64, - ): - T.func_attr({"tir.is_scheduled": 1}) + def batch_prefill_ragged_kv_mla_absorbed( # pylint: disable=too-many-branches + var_q: T.handle, # [total_len, h_q, d_qk] + var_q_indptr: T.handle, # [batch_size + 1] + var_compressed_kv: T.handle, # [total_len, d_latent] + var_k_pe: T.handle, # [total_len, d_rope] + var_kv_indptr: T.handle, # [batch_size + 1] + var_output: T.handle, # [total_len, h_q, d_latent] + var_lse: T.handle, # [total_len, h_q] + causal: T.int32, + attn_score_scaling_factor: T.float32 + ): + batch_size_plus_1 = T.int32(is_size_var=True) + qo_len = T.int32(is_size_var=True) + kv_len = T.int32(is_size_var=True) + q_indptr_elem_offset = T.int32(is_size_var=True) + kv_indptr_elem_offset = T.int32(is_size_var=True) 
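        # Illustrative reading of the buffers declared just below (comments added
        # for exposition, not part of the original patch):
        #   q             packs the absorbed queries, (qo_len, h_q, d_latent + d_rope)
        #   compressed_kv is the per-token latent KV shared by all heads, (kv_len, d_latent)
        #   k_pe          is the RoPE-carrying key component, (kv_len, d_rope)
        #   output        stays in the latent space, (qo_len, h_q, d_latent)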
+ + q = T.match_buffer(var_q, (qo_len, h_q, d_qk), dtype) + q_indptr = T.match_buffer(var_q_indptr, (batch_size_plus_1,), "int32", elem_offset=q_indptr_elem_offset) + compressed_kv = T.match_buffer(var_compressed_kv, (kv_len, d_latent), dtype) + k_pe = T.match_buffer(var_k_pe, (kv_len, d_rope), dtype) + kv_indptr = T.match_buffer(var_kv_indptr, (batch_size_plus_1,), "int32", elem_offset=kv_indptr_elem_offset) + output = T.match_buffer(var_output, (qo_len, h_q, d_latent), dtype) + lse = T.match_buffer(var_lse, (qo_len, h_q), "float32") # pylint: disable=unused-variable + + # kernel code + for lbx in T.thread_binding(NUM_BLKS, thread="blockIdx.x"): + for lty in T.thread_binding(num_warps, thread="threadIdx.y"): + for ltx in T.thread_binding(bdx, thread="threadIdx.x"): + with T.block("attn"): + bx, ty, tx = T.axis.remap("SSS", [lbx, lty, ltx]) + T.reads() + T.writes() + tile_id = _var("int32") + batch_idx = _var("int32") + batch_tiles = _var("int32") + batch_rows = _var("int32") + iterator = _var("int32") + kv_chunk_len = _var("int32") + + Q_smem = T.alloc_buffer((tile_x, d_qk), dtype, scope="shared") + KV_smem = T.alloc_buffer((tile_z, d_qk), dtype, scope="shared") + S_smem = T.alloc_buffer((tile_x, tile_z), "float32", scope="shared") + + S_local = T.alloc_buffer((tile_x, tile_z), "float32", scope="local") + O_local = T.alloc_buffer((tile_x, d_latent), "float32", scope="local") + + m_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") + m_prev_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") + d_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") + + m_new = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") + m_prev = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") + d_new = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") + + ## get tile_no, batch_idx, batch_tiles, batch_rows + tile_id[0] = bx + batch_idx[0] = 0 + batch_rows[0] = (q_indptr[1] - q_indptr[0]) * group_size + batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x) + while T.tvm_thread_invariant(batch_idx[0] < batch_size_plus_1 - 1): + # advance to next tile + while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size_plus_1 - 1: + tile_id[0] -= batch_tiles[0] + batch_idx[0] += 1 + if batch_idx[0] < batch_size_plus_1 - 1: + b_idx: T.int32 = batch_idx[0] + batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * group_size + batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x) + + if T.tvm_thread_invariant(batch_idx[0] < batch_size_plus_1 - 1): + b_idx: T.int32 = batch_idx[0] + q_indptr_val: T.int32 = q_indptr[b_idx] + LH_start: T.int32 = tile_id[0] * tile_x + + kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx] + T.tvm_storage_sync("shared") + + # init states + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + if row < tile_x: + m_smem[row] = -5e4 + d_smem[row] = 1.0 + + for li, lj in T.grid(tile_x, d_latent): + with T.block("O_init"): + i, j = T.axis.remap("SS", [li, lj]) + O_local[i, j] = 0.0 + T.tvm_storage_sync("shared") + + # Load Q from gmem to smem + for li, lj in T.grid(tile_x, tile_y): + with T.block("Q_load"): + i, j = T.axis.remap("SS", [li, lj]) + T.reads() + T.writes() + cur_L = q_indptr_val + (LH_start + i) // group_size + cur_H_qo = (LH_start + i) % group_size + if cur_L < q_indptr[b_idx + 1]: + Q_smem[i, j] = q[cur_L, cur_H_qo, j] + else: + Q_smem[i, j] = 0.0 + T.tvm_storage_sync("shared") + + for iterator in 
T.serial(T.ceildiv(kv_chunk_len[0], tile_z)): + L_kv_start: T.int32 = iterator * tile_z + L_kv_base: T.int32 = kv_indptr[b_idx] + for lz, ly in T.grid(tile_z, d_latent): + with T.block("V_load"): + i, j = T.axis.remap("SS", [lz, ly]) + cur_L = L_kv_start + i + if cur_L < kv_chunk_len[0]: + KV_smem[i, j] = compressed_kv[L_kv_base + cur_L, j] + else: + KV_smem[i, j] = 0.0 + T.tvm_storage_sync("shared") + for lz, ly in T.grid(tile_z, d_rope): + with T.block("K_load"): + i, j = T.axis.remap("SS", [lz, ly]) + T.reads() + T.writes() + cur_L = L_kv_start + i + if cur_L < kv_chunk_len[0]: + KV_smem[i, d_latent + j] = k_pe[L_kv_base + cur_L, j] + else: + KV_smem[i, d_latent + j] = 0.0 + T.tvm_storage_sync("shared") + + # Compute S + with T.block(): + for li, lj, lk in T.grid(tile_x, tile_z, tile_y): + with T.block("S_gemm"): + i, j, k = T.axis.remap("SSR", [li, lj, lk]) + with T.init(): + S_local[i, j] = 0.0 + S_local[i, j] += T.cast(Q_smem[i, k], "float32") * T.cast(KV_smem[j, k], "float32") * attn_score_scaling_factor * math.log2(math.exp(1)) + T.tvm_storage_sync("shared") + for li, lj in T.grid(tile_x, tile_z): + with T.block("S_store"): + i, j = T.axis.remap("SS", [li, lj]) + S_smem[i, j] = S_local[i, j] + T.tvm_storage_sync("shared") + + # Update S, m, d + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + if row < tile_x: + with T.block("update1"): + m_prev[i] = m_smem[row] + m_new[i] = m_smem[row] + # mask out of kv_chunk_len S + row_: T.int32 = (LH_start + row) // group_size + for j in T.serial(tile_z): + if _causal_mask(causal, + row=row_, + col=L_kv_start + j, + kv_len=kv_chunk_len[0], + qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]): + m_new[i] = T.max(m_new[i], S_smem[row, j]) + d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i]) + + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + with T.block("update"): + for j in T.serial(tile_z): + # this is to avoid sync inside condition branch + if row < tile_x: + row_: T.int32 = (LH_start + row) // group_size + if _causal_mask(causal, + row=row_, + col=L_kv_start + j, + kv_len=kv_chunk_len[0], + qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]): + S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i]) + else: + S_smem[row, j] = T.exp2(-5e4 - m_new[i]) + + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + if row < tile_x: + with T.block("update"): + for j in T.serial(tile_z): + d_new[i] += S_smem[row, j] + m_smem[row] = m_new[i] + d_smem[row] = d_new[i] + m_prev_smem[row] = m_prev[i] + T.tvm_storage_sync("shared") + + # Update O + with T.block(): + for li, lj, lk in T.grid(tile_x, d_latent, tile_z): + with T.block("O_gemm"): + i, j, k = T.axis.remap("SSR", [li, lj, lk]) + with T.init(): + O_local[i, j] *= T.exp2(m_prev_smem[i] - m_smem[i]) + O_local[i, j] += S_smem[i, k] * T.cast(KV_smem[k, j], "float32") + + # Store O from smem to gmem + for li, lj in T.grid(tile_x, d_latent): + with T.block("O_store"): + i, j = T.axis.remap("SS", [li, lj]) + cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size + cur_H_qo: T.int32 = (LH_start + i) % group_size + if cur_L < q_indptr[b_idx + 1]: + output[cur_L, cur_H_qo, j] = O_local[i, j] / d_smem[i] + + # Store LSE to gmem + for li in T.grid(tile_x): + with T.block("lse_store"): + i = T.axis.remap("S", [li]) + cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size + cur_H_qo: T.int32 = (LH_start + i) % group_size + if cur_L 
< q_indptr[b_idx + 1]: + lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i]) + + # move to next tile + tile_id[0] += NUM_BLKS + # fmt: on + # pylint: enable=line-too-long,too-many-branches + sch = tir.Schedule(batch_prefill_ragged_kv_mla_absorbed) + sch = _schedule_prefill_kernel( + sch, LOAD_VEC, bdx, num_warps, tile_x, tile_y, tile_z, False, False + ) + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + +def _attention_decode_mla(h_q, d_latent, d_rope, qkv_dtype, sliding_window: bool, target: Target): + d_qk = d_latent + d_rope + qkv_dtype_bytes = 2 + + THREAD_LIMIT = 512 + TILE_SIZE_PER_BDX = 2 + if target.kind.name == "opencl" and ( + ("android" in str(target.host)) or ("adreno" in str(target.attrs)) + ): + # Keeping lower thread limit for this kernel on adreno target + # to avoid register spill + THREAD_LIMIT = 256 + TILE_SIZE_PER_BDX = 1 + max_num_threads_per_block = get_max_num_threads_per_block(target) + thread_limit = min(max_num_threads_per_block, THREAD_LIMIT) + + GROUP_SIZE = h_q + VEC_SIZE = min(max(8 // qkv_dtype_bytes, d_qk // 32), 4) + bdx = d_qk // VEC_SIZE + bdy = GROUP_SIZE + while bdx * bdy > thread_limit and bdy > 1: + bdy //= 2 + gdy = GROUP_SIZE // bdy + threads_per_CTA = max(thread_limit, bdx * bdy) + bdz = threads_per_CTA // (bdx * bdy) + tile_size_per_bdx = TILE_SIZE_PER_BDX if GROUP_SIZE == 1 else 1 + check_thread_limits(target, bdx=bdx, bdy=bdy, bdz=bdz, gdz=1) + + global_symbol = "batch_decode_paged_kv_mla" + if sliding_window: + global_symbol += "_sliding_window" + + # pylint: disable=line-too-long,too-many-branches + # fmt: off + @T.prim_func + def batch_decode_paged_kv_mla( + _0: T.int32, # pylint: disable=unused-argument + Q_handle: T.handle, + pages_handle: T.handle, + page_table_indptr_handle: T.handle, + page_table_values_handle: T.handle, + var_length_info: T.handle, # [b] when sliding window = False, or otherwise [3, b] + output_handle: T.handle, + lse_handle: T.handle, + attn_score_scaling_factor: T.float32, + ): + T.func_attr({"tir.is_scheduled": 1, "global_symbol": global_symbol}) + B = T.int32(is_size_var=True) + nnz_pages = T.int32(is_size_var=True) + max_num_pages = T.int32(is_size_var=True) + pages_elem_offset = T.int64(is_size_var=True) + page_indptr_elem_offset = T.int32(is_size_var=True) + page_values_elem_offset = T.int32(is_size_var=True) + length_info_elem_offset = T.int32(is_size_var=True) + + Q = T.match_buffer(Q_handle, (B, h_q, d_qk), qkv_dtype) + pages = T.match_buffer( + pages_handle, (max_num_pages, 16, d_qk), qkv_dtype, elem_offset=pages_elem_offset + ) + page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), "int32", elem_offset=page_indptr_elem_offset) + page_table_values = T.match_buffer(page_table_values_handle, (nnz_pages,), "int32", elem_offset=page_values_elem_offset) + output = T.match_buffer(output_handle, (B, h_q, d_latent), qkv_dtype) + lse = T.match_buffer(lse_handle, (B, h_q), "float32") # pylint: disable=unused-variable + # The length information of the sequences. + # - It is in shape `(3, batch_size)` when sliding window is enabled. + # For a sequence "i", location + # - "(0, i)" is the number of KV slots used in the last page of the seq ("last_page_len"), + # - "(1, i)" is the starting offset of the sliding window in the seq, + # - "(2, i)" is the attn sink length of the sequence. + # - It is in shape `(batch_size,)` when sliding window is disabled, + # denoting the "last_page_len". 
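+ # The MLA cache in this PR rejects sliding window (see the CHECK in the paged KV cache
+ # constructor), so this decode kernel is currently only exercised with the 1-D `(batch_size,)`
+ # "last_page_len" form.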
+ length_info = _declare_length_info(var_length_info, B, sliding_window, length_info_elem_offset) + + for bx in T.thread_binding(B, thread="blockIdx.x"): + for by in T.thread_binding(gdy, thread="blockIdx.y"): + for ty in T.thread_binding(bdy, thread="threadIdx.y"): + for tx in T.thread_binding(bdx, thread="threadIdx.x"): + for tz in T.thread_binding(bdz, thread="threadIdx.z"): + with T.block("attn"): + Q_local = T.alloc_buffer((VEC_SIZE,), qkv_dtype, scope="local") + kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local") + KV_smem = T.alloc_buffer((bdz * bdy * tile_size_per_bdx, d_qk), qkv_dtype, scope="shared") + O_allreduce = T.alloc_buffer((bdz, bdy, d_qk), "float32", scope="shared") + md_allreduce = T.alloc_buffer((bdz, bdy, 2), "float32", scope="shared") + S_reduce_local = T.alloc_buffer((1,), "float32", scope="local") + t0 = T.alloc_buffer((1,), "float32", scope="local") + + S_local = T.alloc_buffer((bdy * tile_size_per_bdx), "float32", scope="local") + QK_local = T.alloc_buffer((VEC_SIZE,), "float32", scope="local") + V_local = T.alloc_buffer((VEC_SIZE,), qkv_dtype, scope="local") + m_prev = T.alloc_buffer((1,), "float32", scope="local") + d_prev = T.alloc_buffer((1,), "float32", scope="local") + other_m = T.alloc_buffer((1,), "float32", scope="local") + other_d = T.alloc_buffer((1,), "float32", scope="local") + exp_mprev = T.alloc_buffer((1,), "float32", scope="local") + exp_otherm = T.alloc_buffer((1,), "float32", scope="local") + other_o = T.alloc_buffer((VEC_SIZE,), "float32", scope="local") + st_m = T.alloc_buffer((1,), "float32", scope="local") + st_d = T.alloc_buffer((1,), "float32", scope="local") + O_local = T.alloc_buffer((VEC_SIZE,), "float32", scope="local") + + batch_idx: T.int32 = bx + cur_page_indptr_begin: T.int32 = page_table_indptr[batch_idx] + cur_page_indptr_end: T.int32 = page_table_indptr[batch_idx + 1] + kv_chunk_len[0] = T.if_then_else( + cur_page_indptr_begin != cur_page_indptr_end, + _get_kv_chunk_len(cur_page_indptr_end - cur_page_indptr_begin, 16, batch_idx, length_info, sliding_window), + 0 + ) + + # init states + st_m[0] = -5e4 + st_d[0] = 1.0 + for vec in T.vectorized(VEC_SIZE): + O_local[vec] = 0.0 + + # load q + for vec in T.vectorized(VEC_SIZE): + Q_local[vec] = Q[bx, by * bdy + ty, tx * VEC_SIZE + vec] + + for iterator in T.serial(T.ceildiv(kv_chunk_len[0], tile_size_per_bdx * bdy * bdz)): + tile_start_s: T.int32(is_size_var=True) = (tz * bdy + ty) * tile_size_per_bdx # type: ignore + tile_start_g: T.int32(is_size_var=True) = ((iterator * bdz + tz) * bdy + ty) * tile_size_per_bdx # type: ignore + # load KV from global memory to shared memory + for j in T.serial(tile_size_per_bdx): + with T.block("KV_load"): + T.reads() + T.writes() + row_g: T.int32(is_size_var=True) = tile_start_g + j # type: ignore + if row_g < kv_chunk_len[0]: + seq_offset: T.int32(is_size_var=True) = _get_seq_offset(row_g, batch_idx, length_info, sliding_window) # type: ignore + page_no: T.int32(is_size_var=True) = page_table_values[cur_page_indptr_begin + T.floordiv(seq_offset, 16)] # type: ignore + page_offset: T.int32(is_size_var=True) = T.floormod(seq_offset, 16) # type: ignore + for vec in T.vectorized(VEC_SIZE): + KV_smem[tile_start_s + j, tx * VEC_SIZE + vec] = pages[page_no, page_offset, tx * VEC_SIZE + vec] + else: + for vec in T.vectorized(VEC_SIZE): + KV_smem[tile_start_s + j, tx * VEC_SIZE + vec] = 0.0 + T.tvm_storage_sync("shared") + # compute QK + m_prev[0] = st_m[0] + for j in T.serial(bdy * tile_size_per_bdx): + # compute S = Q * K * sm_scale + for vec in 
T.vectorized(VEC_SIZE): + QK_local[vec] = T.cast(Q_local[vec], "float32") * T.cast(KV_smem[tz * bdy * tile_size_per_bdx + j, tx * VEC_SIZE + vec], "float32") * attn_score_scaling_factor * math.log2(math.exp(1)) + S_reduce_local[0] = 0 + for vec in T.unroll(VEC_SIZE): + S_reduce_local[0] += QK_local[vec] + + with T.block("block_cross_thread"): + T.reads(S_reduce_local[0]) + T.writes(t0[0]) + T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ) + T.tvm_thread_allreduce(T.uint32(1), S_reduce_local[0], True, t0[0], tx, dtype="handle") + + S_local[j] = -5e4 + if (iterator * bdz + tz) * bdy * tile_size_per_bdx + j < kv_chunk_len[0]: + S_local[j] = t0[0] + # update st_m + st_m[0] = T.max(st_m[0], S_local[j]) + + # update st_d, st_O + o_scale: T.float32 = T.exp2(m_prev[0] - st_m[0]) + st_d[0] *= o_scale + for j in T.serial(bdy * tile_size_per_bdx): + S_local[j] = T.exp2(S_local[j] - st_m[0]) + st_d[0] += S_local[j] + for j in T.vectorized(VEC_SIZE): + O_local[j] *= o_scale + + # load V from shared memory to local memory + # compute O + for j in T.serial(bdy * tile_size_per_bdx): + if tx * VEC_SIZE < d_latent: + for vec in T.vectorized(VEC_SIZE): + V_local[vec] = KV_smem[tz * bdy * tile_size_per_bdx + j, tx * VEC_SIZE + vec] + else: + for vec in T.vectorized(VEC_SIZE): + V_local[vec] = 0.0 + for vec in T.vectorized(VEC_SIZE): + O_local[vec] += T.cast(V_local[vec], "float32") * S_local[j] + + if bdz > 1: + # allreduce over bdz + for vec in T.vectorized(VEC_SIZE): + O_allreduce[tz, ty, tx * VEC_SIZE + vec] = O_local[vec] + md_allreduce[tz, ty, 0] = st_m[0] + md_allreduce[tz, ty, 1] = st_d[0] + T.tvm_storage_sync("shared") + + st_m[0] = -5e4 + st_d[0] = 1.0 + for vec in T.vectorized(VEC_SIZE): + O_local[vec] = 0.0 + + for j in T.serial(bdz): + m_prev[0] = st_m[0] + d_prev[0] = st_d[0] + other_m[0] = md_allreduce[j, ty, 0] + other_d[0] = md_allreduce[j, ty, 1] + for vec in T.vectorized(VEC_SIZE): + other_o[vec] = O_allreduce[j, ty, tx * VEC_SIZE + vec] + st_m[0] = T.max(st_m[0], other_m[0]) + st_d[0] = d_prev[0] * T.exp2(m_prev[0] - st_m[0]) + other_d[0] * T.exp2(other_m[0] - st_m[0]) + exp_mprev[0] = T.exp2(m_prev[0] - st_m[0]) + exp_otherm[0] = T.exp2(other_m[0] - st_m[0]) + for vec in T.vectorized(VEC_SIZE): + O_local[vec] = O_local[vec] * exp_mprev[0] + other_o[vec] * exp_otherm[0] + + # normalize O + for vec in T.vectorized(VEC_SIZE): + O_local[vec] /= st_d[0] + + # store O to global memory + if tx * VEC_SIZE < d_latent: + for vec in T.vectorized(VEC_SIZE): + output[batch_idx, by * bdy + ty, tx * VEC_SIZE + vec] = O_local[vec] + + # store lse to global memory + lse[batch_idx, by * bdy + ty] = st_m[0] + T.log2(st_d[0]) + # fmt: on + # pylint: enable=line-too-long,too-many-branches + return batch_decode_paged_kv_mla + + +def _copy_single_page(num_heads, page_size, head_dim, dtype, target: Target): + tx = get_max_num_threads_per_block(target) + + @T.prim_func + def copy_single_page( + var_pages: T.handle, + src_page_id: T.int64, + tgt_page_id: T.int64, + copy_length: T.int64, + ): + T.func_attr({"tir.is_scheduled": 1}) num_pages = T.int32() - pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, page_size, head_dim), dtype) + pages_elem_offset = T.int64() + pages = T.match_buffer( + var_pages, + (num_pages, 2, num_heads, page_size, head_dim), + dtype, + elem_offset=pages_elem_offset, + ) - for b in T.serial((copy_length * num_heads * head_dim + tx - 1) // tx): - for t in T.serial(tx): + for b in T.thread_binding( + 
(copy_length * num_heads * head_dim + tx - 1) // tx, thread="blockIdx.x" + ): + for t in T.thread_binding(tx, thread="threadIdx.x"): with T.block("copy"): T.where(b * tx + t < copy_length * num_heads * head_dim) vh = T.axis.spatial( @@ -2749,14 +3468,14 @@ def copy_single_page_cpu( pages[tgt_page_id, 0, vh, vp, vd] = pages[src_page_id, 0, vh, vp, vd] pages[tgt_page_id, 1, vh, vp, vd] = pages[src_page_id, 1, vh, vp, vd] - return copy_single_page_cpu + return copy_single_page -def _copy_single_page(num_heads, page_size, head_dim, dtype, target: Target): +def _copy_single_page_mla(page_size, head_dim, dtype, target: Target): tx = get_max_num_threads_per_block(target) @T.prim_func - def copy_single_page( + def copy_single_page_mla( var_pages: T.handle, src_page_id: T.int64, tgt_page_id: T.int64, @@ -2766,16 +3485,36 @@ def copy_single_page( num_pages = T.int32() pages_elem_offset = T.int64() pages = T.match_buffer( - var_pages, - (num_pages, 2, num_heads, page_size, head_dim), - dtype, - elem_offset=pages_elem_offset, + var_pages, (num_pages, page_size, head_dim), dtype, elem_offset=pages_elem_offset ) - for b in T.thread_binding( - (copy_length * num_heads * head_dim + tx - 1) // tx, thread="blockIdx.x" - ): + for b in T.thread_binding((copy_length * head_dim + tx - 1) // tx, thread="blockIdx.x"): for t in T.thread_binding(tx, thread="threadIdx.x"): + with T.block("copy"): + T.where(b * tx + t < copy_length * head_dim) + vp = T.axis.spatial(copy_length, (b * tx + t) // head_dim) + vd = T.axis.spatial(head_dim, T.Cast("int32", (b * tx + t) % head_dim)) + pages[tgt_page_id, vp, vd] = pages[src_page_id, vp, vd] + + return copy_single_page_mla + + +def _copy_single_page_cpu(num_heads, page_size, head_dim, dtype): + tx = 1 + + @T.prim_func + def copy_single_page_cpu( + var_pages: T.handle, + src_page_id: T.int64, + tgt_page_id: T.int64, + copy_length: T.int64, + ): + T.func_attr({"tir.is_scheduled": 1}) + num_pages = T.int32() + pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, page_size, head_dim), dtype) + + for b in T.serial((copy_length * num_heads * head_dim + tx - 1) // tx): + for t in T.serial(tx): with T.block("copy"): T.where(b * tx + t < copy_length * num_heads * head_dim) vh = T.axis.spatial( @@ -2796,14 +3535,14 @@ def copy_single_page( pages[tgt_page_id, 0, vh, vp, vd] = pages[src_page_id, 0, vh, vp, vd] pages[tgt_page_id, 1, vh, vp, vd] = pages[src_page_id, 1, vh, vp, vd] - return copy_single_page + return copy_single_page_cpu -def _compact_kv_copy_cpu(num_heads, head_dim, dtype): - tx = 8 +def _compact_kv_copy(num_heads, head_dim, dtype, target: Target): + tx = get_max_num_threads_per_block(target) @T.prim_func - def compact_kv_copy_cpu( + def compact_kv_copy( var_pages: T.handle, var_copy_length_indptr: T.handle, var_copy_src_dst_pos: T.handle, @@ -2814,7 +3553,10 @@ def compact_kv_copy_cpu( total_copy_length = T.int32() copy_length_indptr_elem_offset = T.int32() copy_src_dst_pos_elem_offset = T.int32() - pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, 16, head_dim), dtype) + pages_elem_offset = T.int64() + pages = T.match_buffer( + var_pages, (num_pages, 2, num_heads, 16, head_dim), dtype, elem_offset=pages_elem_offset + ) copy_length_indptr = T.match_buffer( var_copy_length_indptr, (batch_size + 1,), @@ -2829,8 +3571,10 @@ def compact_kv_copy_cpu( ) with T.block("root"): - for bhd_o in T.serial((batch_size * num_heads * head_dim + tx - 1) // tx): - for bhd_i in T.serial(tx): + for bhd_o in T.thread_binding( + (batch_size * num_heads * head_dim + tx - 1) // 
tx, thread="blockIdx.x" + ): + for bhd_i in T.thread_binding(tx, thread="threadIdx.x"): b: T.int32 = (bhd_o * tx + bhd_i) // (num_heads * head_dim) h: T.int32 = (bhd_o * tx + bhd_i) // head_dim % num_heads d: T.int32 = (bhd_o * tx + bhd_i) % head_dim @@ -2845,14 +3589,14 @@ def compact_kv_copy_cpu( src_pos // 16, 1, h, src_pos % 16, d ] - return compact_kv_copy_cpu + return compact_kv_copy -def _compact_kv_copy(num_heads, head_dim, dtype, target: Target): - tx = get_max_num_threads_per_block(target) +def _compact_kv_copy_cpu(num_heads, head_dim, dtype): + tx = 8 @T.prim_func - def compact_kv_copy( + def compact_kv_copy_cpu( var_pages: T.handle, var_copy_length_indptr: T.handle, var_copy_src_dst_pos: T.handle, @@ -2863,10 +3607,7 @@ def compact_kv_copy( total_copy_length = T.int32() copy_length_indptr_elem_offset = T.int32() copy_src_dst_pos_elem_offset = T.int32() - pages_elem_offset = T.int64() - pages = T.match_buffer( - var_pages, (num_pages, 2, num_heads, 16, head_dim), dtype, elem_offset=pages_elem_offset - ) + pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, 16, head_dim), dtype) copy_length_indptr = T.match_buffer( var_copy_length_indptr, (batch_size + 1,), @@ -2881,10 +3622,8 @@ def compact_kv_copy( ) with T.block("root"): - for bhd_o in T.thread_binding( - (batch_size * num_heads * head_dim + tx - 1) // tx, thread="blockIdx.x" - ): - for bhd_i in T.thread_binding(tx, thread="threadIdx.x"): + for bhd_o in T.serial((batch_size * num_heads * head_dim + tx - 1) // tx): + for bhd_i in T.serial(tx): b: T.int32 = (bhd_o * tx + bhd_i) // (num_heads * head_dim) h: T.int32 = (bhd_o * tx + bhd_i) // head_dim % num_heads d: T.int32 = (bhd_o * tx + bhd_i) % head_dim @@ -2899,4 +3638,4 @@ def compact_kv_copy( src_pos // 16, 1, h, src_pos % 16, d ] - return compact_kv_copy + return compact_kv_copy_cpu diff --git a/python/tvm/relax/frontend/nn/llm/tree_attn.py b/python/tvm/relax/frontend/nn/llm/tree_attn.py index fa0146afb618..33614633fc77 100644 --- a/python/tvm/relax/frontend/nn/llm/tree_attn.py +++ b/python/tvm/relax/frontend/nn/llm/tree_attn.py @@ -320,7 +320,17 @@ def tree_attn( bdx = 32 num_warps = 4 - tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16 + tile_x, tile_y, tile_z = ( + 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), + d, + 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), + ) + original_tile_y = tile_y + original_tile_z = tile_z + while (tile_x * tile_z) % (bdx * num_warps) != 0: + tile_z += original_tile_z + while (tile_x * tile_y) % (bdx * num_warps) != 0: + tile_y += original_tile_y # Otherwise we would exceed maxComputeWorkgroupStorageSize if ( @@ -881,7 +891,17 @@ def tree_attn_with_paged_kv_cache( bdx = 32 num_warps = 4 - tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16 + tile_x, tile_y, tile_z = ( + 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), + d, + 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), + ) + original_tile_y = tile_y + original_tile_z = tile_z + while (tile_x * tile_z) % (bdx * num_warps) != 0: + tile_z += original_tile_z + while (tile_x * tile_y) % (bdx * num_warps) != 0: + tile_y += original_tile_y # Otherwise we would exceed maxComputeWorkgroupStorageSize if ( diff --git a/src/runtime/relax_vm/kv_state.cc b/src/runtime/relax_vm/kv_state.cc index 67afda3bfda1..c78ada58e6d6 100644 --- a/src/runtime/relax_vm/kv_state.cc +++ b/src/runtime/relax_vm/kv_state.cc @@ -74,12 +74,21 @@ 
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_query_positions") .set_body_method(&AttentionKVCacheObj::GetQueryPositions); TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_debug_get_kv") .set_body_method(&AttentionKVCacheObj::DebugGetKV); +TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_debug_get_kv_mla") + .set_body_method(&AttentionKVCacheObj::DebugGetKVMLA); TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_attention_with_fused_qkv") .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id, double attn_score_scaling_factor, NDArray qkv_data, NDArray o_data) { kv_cache->AttentionWithFusedQKV(layer_id, std::move(qkv_data), NullOpt, std::move(o_data), attn_score_scaling_factor); }); +TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_mla_absorbed") + .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id, + double attn_score_scaling_factor, NDArray q_data, NDArray compressed_kv_data, + NDArray k_pe_data, NDArray o_data) { + kv_cache->MLAAbsorbed(layer_id, std::move(q_data), std::move(compressed_kv_data), + std::move(k_pe_data), std::move(o_data), attn_score_scaling_factor); + }); // RNN State methods TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_get").set_body_method(&RNNStateObj::Get); diff --git a/src/runtime/relax_vm/kv_state.h b/src/runtime/relax_vm/kv_state.h index 77c17d1c555f..300d22b85909 100644 --- a/src/runtime/relax_vm/kv_state.h +++ b/src/runtime/relax_vm/kv_state.h @@ -181,20 +181,6 @@ class AttentionKVCacheObj : public KVStateObj { virtual void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data, Optional mask, NDArray o_data, double attn_score_scaling_factor) = 0; - /*! - * \brief Compute attention with Q/K/V data. - * \param layer_id The model layer where the attention compute happens. - * \param q_data The input Q data, in layout `(total_length, num_qo_heads, head_dim)` - * \param k_data The input K data, in layout `(total_length, num_kv_heads, head_dim)` - * \param v_data The input V data, in layout `(total_length, num_kv_heads, head_dim)` - * \param mask The input mask data, in layout `(total_sqr_length)`. - * \param o_data The output O data, in layout `(total_length, num_qo_heads, head_dim)`. - * \param attn_score_scaling_factor The additional attention scaling factor. - */ - virtual void AttentionWithSeparateQKV(int64_t layer_id, NDArray q_data, NDArray k_data, - NDArray v_data, Optional mask, NDArray o_data, - double attn_score_scaling_factor) = 0; - /*! * \brief Compute multi-head latent attention after applying weight absorption. * \param layer_id The model layer where the attention compute happens. @@ -275,6 +261,16 @@ class AttentionKVCacheObj : public KVStateObj { virtual void DebugGetKV(int64_t seq_id, // int64_t start_pos, int64_t end_pos, NDArray k_data, NDArray v_data) = 0; + /*! + * \brief Fetch the compact K/V data of the given sequence for MLA cache. + * \param seq_id The sequence whose K/V data is to be fetched. + * \param start_pos The start position (inclusive) of the K/V data to fetch. + * \param end_pos The end position (exclusive) of the K/V data to fetch. + * \param kv_data The output KV data of the given sequence in layout elaborated above. + */ + virtual void DebugGetKVMLA(int64_t seq_id, int64_t start_pos, int64_t end_pos, + NDArray kv_data) = 0; + /*! * \brief Set the K/V data of the given sequence from input K/V data. 
* `start_pos` (inclusive) controls starting position of K/V data diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc index 0b83bb426d18..075ff0b94471 100644 --- a/src/runtime/relax_vm/paged_kv_cache.cc +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -83,7 +83,7 @@ ShapeTuple GetKVCacheShape(AttnKind attn_kind, int64_t num_total_pages, int num_ // Ignore v_head_dim since multi-head attention requires K/V to have the same head dim. return {num_total_pages, 2, num_kv_heads, page_size, qk_head_dim}; } else if (attn_kind == AttnKind::kMLA) { - return {num_total_pages, num_kv_heads, page_size, qk_head_dim + qk_rope_head_dim}; + return {num_total_pages, page_size, qk_head_dim}; } else if (attn_kind == AttnKind::kLinearAttn) { return {num_sequence, num_kv_heads, qk_head_dim, v_head_dim}; } @@ -1148,7 +1148,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { PackedFunc f_mla_prefill_ragged_absorbed_; PackedFunc f_merge_inplace_; PackedFunc f_split_rotary_; - PackedFunc f_separate_rotary_; PackedFunc f_copy_single_page_; Optional f_debug_get_kv_; @@ -1183,8 +1182,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { Optional f_attention_decode_end_forward, PackedFunc f_mla_prefill, PackedFunc f_mla_decode, PackedFunc f_mla_prefill_ragged_normal, PackedFunc f_mla_prefill_ragged_absorbed, PackedFunc f_merge_inplace, - PackedFunc f_split_rotary, PackedFunc f_separate_rotary, PackedFunc f_copy_single_page, - Optional f_debug_get_kv) + PackedFunc f_split_rotary, PackedFunc f_copy_single_page, Optional f_debug_get_kv) : page_size_(page_size), num_layers_(num_layers), layer_id_begin_offset_(layer_id_begin_offset), @@ -1226,10 +1224,15 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { f_mla_prefill_ragged_absorbed_(std::move(f_mla_prefill_ragged_absorbed)), f_merge_inplace_(std::move(f_merge_inplace)), f_split_rotary_(std::move(f_split_rotary)), - f_separate_rotary_(std::move(f_separate_rotary)), f_copy_single_page_(std::move(f_copy_single_page)), f_debug_get_kv_(std::move(f_debug_get_kv)), device_(device) { + // Note: For MLA, sliding window and disaggregation are disabled for now. + if (std::find(attn_kinds_.begin(), attn_kinds_.end(), AttnKind::kMLA) != attn_kinds_.end()) { + CHECK(!support_sliding_window_) << "Sliding window not supported yet for MLA"; + CHECK(!enable_kv_transfer) << "KV transfer not supported yet for MLA"; + } + pages_.reserve(num_layers); if (enable_kv_transfer) { // For now, KV transfer only supports MHA. 
@@ -1337,14 +1340,16 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { NDArray::Empty({kFloatAttnWorkspaceByte / 4}, DataType::Float(32), device); } - temp_attn_q_device_ = - NDArray::Empty({prefill_chunk_size_, num_qo_heads, qk_head_dim}, dtype, device); - temp_attn_k_device_ = - NDArray::Empty({prefill_chunk_size_, num_kv_heads, qk_head_dim}, dtype, device); - temp_attn_v_device_ = - NDArray::Empty({prefill_chunk_size_, num_kv_heads, v_head_dim}, dtype, device); + if (std::find(attn_kinds_.begin(), attn_kinds_.end(), AttnKind::kMHA) != attn_kinds_.end()) { + temp_attn_q_device_ = + NDArray::Empty({prefill_chunk_size_, num_qo_heads, qk_head_dim}, dtype, device); + temp_attn_k_device_ = + NDArray::Empty({prefill_chunk_size_, num_kv_heads, qk_head_dim}, dtype, device); + temp_attn_v_device_ = + NDArray::Empty({prefill_chunk_size_, num_kv_heads, v_head_dim}, dtype, device); + } temp_attn_output_device_ = - NDArray::Empty({prefill_chunk_size_, num_qo_heads, qk_head_dim}, dtype, device); + NDArray::Empty({prefill_chunk_size_, num_qo_heads, v_head_dim}, dtype, device); temp_attn_scores_device_ = NDArray::Empty({prefill_chunk_size_, num_qo_heads}, DataType::Float(32), device); merged_attn_scores_device_ = @@ -1714,6 +1719,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { void BeginForward(const IntTuple& seq_ids, const IntTuple& append_lengths, const Optional& opt_token_tree_parent_ptr) final { + // Note: MLA does not support tree attention for now. + if (attn_kinds_[0] == AttnKind::kMLA) { + CHECK(!opt_token_tree_parent_ptr.defined()) << "Tree attention is not supported yet for MLA"; + } + CHECK_EQ(seq_ids.size(), append_lengths.size()) << "The seq_ids size (" << seq_ids.size() << ") and append_lengths size (" << append_lengths.size() << ") mismatch."; @@ -2083,8 +2093,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { NDArray pages = pages_[local_layer_id]; CHECK(qkv_data.DataType() == pages.DataType()); CHECK(o_data.DataType() == pages.DataType()); - CHECK(attn_kinds_[layer_id] == AttnKind::kMHA); + // qkv_data: (num_total_length, num_qo_heads + 2 * num_kv_heads, qk_head_dim) // o_data: (num_total_length, num_qo_heads, qk_head_dim) @@ -2171,15 +2181,61 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } } - void AttentionWithSeparateQKV(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data, - Optional mask, NDArray o_data, - double attn_score_scaling_factor) final { - // Todo(ruihang): implement it - } - void MLAAbsorbed(int64_t layer_id, NDArray q_data, NDArray compressed_kv_data, NDArray k_pe_data, NDArray o_data, double attn_score_scaling_factor) { - // Todo(ruihang): implement it + // Part 1. Shape and dtype check.
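+ // For the MLA layers handled here (as configured in the new unit test), qk_head_dim_ equals
+ // kv_lora_rank + qk_rope_head_dim and v_head_dim_ equals kv_lora_rank, so compressed_kv is
+ // expected to carry qk_head_dim_ - qk_rope_head_dim_ features per token.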
+ int64_t local_layer_id = layer_id - layer_id_begin_offset_; + CHECK_GE(local_layer_id, 0); + CHECK_LT(local_layer_id, num_layers_); + NDArray pages = pages_[local_layer_id]; + CHECK(q_data.DataType() == pages.DataType()); + CHECK(compressed_kv_data.DataType() == pages.DataType()); + CHECK(k_pe_data.DataType() == pages.DataType()); + CHECK(o_data.DataType() == pages.DataType()); + CHECK(attn_kinds_[layer_id] == AttnKind::kMLA); + + // q_data: (num_total_length, num_qo_heads, qk_head_dim) + // compressed_kv_data: (num_total_length, qk_head_dim - qk_rope_head_dim) + // k_pe_data: (num_total_length, qk_rope_head_dim) + // o_data: (num_total_length, num_qo_heads, v_head_dim) + CHECK_EQ(q_data->ndim, 3); + CHECK_EQ(compressed_kv_data->ndim, 2); + CHECK_EQ(k_pe_data->ndim, 2); + CHECK_EQ(o_data->ndim, 3); + + int64_t total_seq_length = 0; + for (int64_t seq_id = 0; seq_id < cur_batch_size_; ++seq_id) { + total_seq_length += cur_append_lengths_[seq_id]; + } + CHECK_LE(q_data->shape[0], total_seq_length); + CHECK_LE(compressed_kv_data->shape[0], total_seq_length); + CHECK_LE(k_pe_data->shape[0], total_seq_length); + CHECK_LE(o_data->shape[0], total_seq_length); + CHECK_EQ(q_data->shape[1], num_qo_heads_); + CHECK_EQ(o_data->shape[1], num_qo_heads_); + CHECK_EQ(q_data->shape[2], qk_head_dim_); + CHECK_EQ(compressed_kv_data->shape[1], qk_head_dim_ - qk_rope_head_dim_); + CHECK_EQ(k_pe_data->shape[1], qk_rope_head_dim_); + CHECK_EQ(o_data->shape[2], v_head_dim_); + + // Sync the copy stream and the compute stream. + ComputeStreamWaitForCopyStream(); + // The auxiliary data structure on device must have been synchronized. + ICHECK(!dirty_aux_data_device_); + + // Append k/v data to kv-cache if flag "append_before_attn" is set. + if (append_before_attn_) { + f_transpose_append_mla_(pages_[local_layer_id], compressed_kv_data, k_pe_data, + append_position_map_view_); + } + // Perform MLA with weight absorption. + MLAAbsorbedInternal(layer_id, q_data, compressed_kv_data, k_pe_data, o_data, + attn_score_scaling_factor); + // Append k/v data to kv-cache if flag "append_before_attn" is not set. + if (!append_before_attn_) { + f_transpose_append_mla_(pages_[local_layer_id], compressed_kv_data, k_pe_data, + append_position_map_view_); + } } void MLANormal(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data, @@ -2344,6 +2400,49 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } } + void DebugGetKVMLA(int64_t seq_id, int64_t start_pos, int64_t end_pos, NDArray kv_data) final { + CHECK(f_debug_get_kv_.defined()) + << "PageAttentionKVCache requires the `f_debug_get_kv` to be explicitly passed in when " + "initialization. 
Please construct the KV cache with `f_debug_get_kv`."; + + const Sequence& seq = seq_map_.at(seq_id); + CHECK_GE(start_pos, 0) << "DebugGetKV does not accept negative start_pos " << start_pos; + CHECK_LE(end_pos, seq.seq_length) << "DebugGetKV does not accept out-of-range end_pos"; + CHECK_LT(start_pos, end_pos) << "DebugGetKV does not accept \"start_pos >= end_pos\""; + + // kv_data: (num_layers, seq_length, qk_head_dim) + static constexpr const char* error_msg = + "DebugGetKV expects the kv_data in layout (num_layers, seq_length, qk_head_dim)."; + CHECK_EQ(kv_data->ndim, 3) << error_msg; + CHECK_EQ(kv_data->shape[0], num_layers_) << error_msg << " The number of layers mismatches."; + CHECK_EQ(kv_data->shape[1], end_pos - start_pos) + << error_msg << " The sequence length mismatches."; + CHECK_EQ(kv_data->shape[2], qk_head_dim_) + << error_msg << " The number of head features mismatches."; + + std::vector trace = seq.GetBlockTrace(global_block_pool_); + std::vector append_position_map; + append_position_map.reserve(seq.seq_length); + for (int32_t block_id : trace) { + const Block& block = global_block_pool_[block_id]; + for (int i = 0; i < block.seq_length; ++i) { + int32_t offset = + i < block.sink_length ? i : i - block.sink_length + block.sliding_window_offset; + int page_id = block.page_ids[offset / page_size_]; + int page_offset = offset % page_size_; + append_position_map.push_back(page_id * page_size_ + page_offset); + } + } + NDArray position_map_device = NDArray::Empty({end_pos - start_pos}, dtype_aux_, device_); + position_map_device.CopyFromBytes( + append_position_map.data() + start_pos, + (end_pos - start_pos) * ((dtype_aux_.bits * dtype_aux_.lanes + 7) / 8)); + for (int64_t layer_id = 0; layer_id < num_layers_; ++layer_id) { + CHECK(attn_kinds_[layer_id] == AttnKind::kMLA) << "Only MLA is supported for DebugGetKVMLA"; + f_debug_get_kv_.value()(pages_[layer_id], position_map_device, kv_data, layer_id); + } + } + void DebugSetKV(int64_t seq_id, int64_t start_pos, NDArray k_data, NDArray v_data) final { ICHECK(false) << "DebugSetKV for PageAttentionKVCache not implemented yet."; } @@ -2853,6 +2952,63 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } } + void MLAAbsorbedInternal(int64_t layer_id, NDArray q_data, NDArray compressed_kv_data, + NDArray k_pe_data, NDArray output, double attn_score_scaling_factor) { + int64_t local_layer_id = layer_id - layer_id_begin_offset_; + CHECK_GE(local_layer_id, 0); + CHECK_LT(local_layer_id, num_layers_); + PackedFunc f_prefill = f_mla_prefill_; + PackedFunc f_decode = f_mla_decode_; + CHECK_GE(num_depths_, 1) << "The number of effective depths must be greater or equal to 1."; + + bool is_first_kernel = true; + if (!append_before_attn_) { + // The first part of attention, which only involves the q and the newly appended k/v. + is_first_kernel = false; + CHECK(is_chain_on_depths_[0]) << "Tree attn not able for MLA for now."; + // If the batch does not form a tree, use raggedness prefill kernel.
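+ // The ragged kernel below attends the new queries against only the newly appended
+ // compressed KV (both indptr arguments are the per-request append-length indptr). The
+ // per-depth paged kernels afterwards attend against the cached history, and
+ // f_merge_inplace_ folds the partial outputs together using the accumulated softmax
+ // statistics (LSE) kept in merged_attn_scores_view_.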
+ f_mla_prefill_ragged_absorbed_(q_data, cur_append_length_indptr_view_, compressed_kv_data, + k_pe_data, cur_append_length_indptr_view_, output, + merged_attn_scores_view_, + /*causal=*/1, attn_score_scaling_factor); + } + + for (int d = 0; d < num_depths_; ++d) { + if (page_indices_on_depths_view_[d]->shape[0] == 0) { + continue; + } + NDArray attn_output; + NDArray attn_scores; + if (is_first_kernel) { + attn_output = output; + attn_scores = merged_attn_scores_view_; + } else { + attn_output = temp_attn_output_view_; + attn_scores = temp_attn_scores_view_; + } + CHECK(is_chain_on_depths_[d]) << "Tree attn not able for MLA for now."; + if (use_decode_kernel_[d]) { + // Use decode kernel for depth d + f_decode(/*depth=*/d, q_data, pages_[local_layer_id], page_indptr_on_depths_view_[d], + page_indices_on_depths_view_[d], length_info_on_depths_view_[d], attn_output, + attn_scores, attn_score_scaling_factor); + } else { + // Use prefill kernel for depth d + f_prefill(/*depth=*/d, q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], + page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d], + length_info_on_depths_view_[d], attn_output, attn_scores, /*causal=*/0, + attn_score_scaling_factor); + } + + if (!is_first_kernel) { + f_merge_inplace_(output, merged_attn_scores_view_, temp_attn_output_view_, + temp_attn_scores_view_); + } else { + is_first_kernel = false; + } + } + } + /*! \brief Synchronize the copy stream and the compute stream. */ void ComputeStreamWaitForCopyStream() { if (!dirty_aux_data_device_) { @@ -2983,7 +3139,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } // 16. Create view for temporary arrays for attention computation. temp_attn_output_view_ = temp_attn_output_device_.CreateView( - {total_append_length, num_qo_heads_, qk_head_dim_}, temp_attn_output_device_->dtype); + {total_append_length, num_qo_heads_, v_head_dim_}, temp_attn_output_device_->dtype); temp_attn_scores_view_ = temp_attn_scores_device_.CreateView( {total_append_length, num_qo_heads_}, temp_attn_scores_device_->dtype); merged_attn_scores_view_ = merged_attn_scores_device_.CreateView( @@ -3088,8 +3244,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") std::move(f_attention_prefill_begin_forward), std::move(f_attention_prefill_end_forward), std::move(f_attention_decode_begin_forward), std::move(f_attention_decode_end_forward), PackedFunc(), PackedFunc(), PackedFunc(), PackedFunc(), std::move(f_merge_inplace), - std::move(f_split_rotary), PackedFunc(), std::move(f_copy_single_page), - std::move(f_debug_get_kv)); + std::move(f_split_rotary), std::move(f_copy_single_page), std::move(f_debug_get_kv)); *rv = AttentionKVCache(std::move(n)); }); @@ -3170,14 +3325,13 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced") std::move(f_attention_prefill_with_tree_mask_paged_kv), // NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, // PackedFunc(), PackedFunc(), PackedFunc(), PackedFunc(), std::move(f_merge_inplace), - std::move(f_split_rotary), PackedFunc(), std::move(f_copy_single_page), - std::move(f_debug_get_kv)); + std::move(f_split_rotary), std::move(f_copy_single_page), std::move(f_debug_get_kv)); *rv = AttentionKVCache(std::move(n)); }); TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced_mla") .set_body([](TVMArgs args, TVMRetValue* rv) { - CHECK(args.size() == 39) << "Invalid number of KV cache constructor args."; + CHECK(args.size() == 38) << "Invalid number of KV cache constructor args."; ShapeTuple 
cache_config = args[0]; ShapeTuple layer_indptr_tuple = args[1]; int num_groups = 1; @@ -3219,19 +3373,18 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced_mla") PackedFunc f_mla_prefill_ragged_absorbed = args[28]; PackedFunc f_merge_inplace = args[29]; PackedFunc f_split_rotary = args[30]; - PackedFunc f_separate_rotary = args[31]; - PackedFunc f_copy_single_page = args[32]; - Optional f_debug_get_kv = args[33]; - PackedFunc f_compact_copy = args[34]; - PackedFunc f_attention_prefill_with_tree_mask = args[35]; - PackedFunc f_attention_prefill_with_tree_mask_paged_kv = args[36]; + PackedFunc f_copy_single_page = args[31]; + Optional f_debug_get_kv = args[32]; + PackedFunc f_compact_copy = args[33]; + PackedFunc f_attention_prefill_with_tree_mask = args[34]; + PackedFunc f_attention_prefill_with_tree_mask_paged_kv = args[35]; Optional rope_ext_factors = NullOpt; bool enable_kv_transfer = false; - if (args[37].IsObjectRef()) { - rope_ext_factors = args[37].AsObjectRef(); + if (args[36].IsObjectRef()) { + rope_ext_factors = args[36].AsObjectRef(); } - enable_kv_transfer = args[38]; + enable_kv_transfer = args[37]; auto f_convert_optional_packed_func = [&args](int arg_idx) -> Optional { if (args[arg_idx].IsObjectRef()) { @@ -3284,8 +3437,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced_mla") std::move(f_attention_decode_begin_forward), std::move(f_attention_decode_end_forward), std::move(f_mla_prefill), std::move(f_mla_decode), std::move(f_mla_prefill_ragged_normal), std::move(f_mla_prefill_ragged_absorbed), std::move(f_merge_inplace), - std::move(f_split_rotary), std::move(f_separate_rotary), std::move(f_copy_single_page), - std::move(f_debug_get_kv)); + std::move(f_split_rotary), std::move(f_copy_single_page), std::move(f_debug_get_kv)); *rv = AttentionKVCache(std::move(n)); }); diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py index 2422b16f5a04..fe4da50cd9bf 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-import enum from typing import Dict, List, Tuple, Union import numpy as np @@ -24,9 +23,14 @@ import tvm import tvm.testing from tvm import dlight as dl -from tvm import tir +from tvm.relax.frontend.nn.llm.kv_cache import ( + RopeMode, + _copy_single_page, + _kv_cache_debug_get_kv, + _kv_cache_transpose_append, + llama_rope_with_position_map, +) from tvm.runtime import ShapeTuple -from tvm.script import tir as T reserved_nseq = 32 maximum_total_seq_length = 2048 @@ -70,207 +74,6 @@ fcopy_cache = None -@T.prim_func -def kv_cache_transpose_append( - var_pages: T.handle, - var_k_data: T.handle, - var_v_data: T.handle, - var_position_map: T.handle, -): - ntoken = T.SizeVar("ntoken", "int64") - page_size = T.SizeVar("page_size", "int64") - num_pages = T.int64() - position_map_elem_offset = T.int32() - pages = T.match_buffer(var_pages, (num_pages, 2, num_kv_heads, page_size, head_dim), dtype) - k_data = T.match_buffer(var_k_data, (ntoken, num_kv_heads, head_dim), dtype) - v_data = T.match_buffer(var_v_data, (ntoken, num_kv_heads, head_dim), dtype) - position_map = T.match_buffer( - var_position_map, (ntoken,), "int32", elem_offset=position_map_elem_offset - ) - - for global_pos, h, f in T.grid(ntoken, num_kv_heads, head_dim): - with T.block("k_transpose_append"): - vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) - T.reads(position_map[vgpos], k_data[vgpos, vh, vf]) - T.writes( - pages[position_map[vgpos] // page_size, 0, vh, position_map[vgpos] % page_size, vf] - ) - position: T.int64 = T.Cast("int64", position_map[vgpos]) - pages[ - T.floordiv(position, page_size), 0, vh, T.floormod(position, page_size), vf - ] = k_data[vgpos, vh, vf] - with T.block("v_transpose_append"): - vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) - T.reads(position_map[vgpos], k_data[vgpos, vh, vf]) - T.writes( - pages[position_map[vgpos] // page_size, 1, vh, position_map[vgpos] % page_size, vf] - ) - position: T.int64 = T.Cast("int64", position_map[vgpos]) - pages[ - T.floordiv(position, page_size), 1, vh, T.floormod(position, page_size), vf - ] = v_data[vgpos, vh, vf] - - -def llama_rope_with_position_map( # pylint: disable=too-many-arguments - theta: float, - scale: float, - head_dim: int, - num_q_heads: int, - num_kv_heads: int, - dtype: float = "float16", - rotary_dim: int = None, -): - fused_heads = num_q_heads + num_kv_heads * 2 - if rotary_dim is None: - rotary_dim = head_dim - scale = tir.const(scale, dtype) - - def _rope_freq(s: tir.Var, d: tir.Var, d_range: int, theta: float, dtype: str): - freq = s / tir.power(theta, d * 2 % d_range / tir.const(d_range, "float32")) - cos_freq = tir.cos(freq).astype(dtype) - sin_freq = tir.sin(freq).astype(dtype) - return cos_freq, sin_freq - - def _rope( # pylint: disable=too-many-arguments - x: T.Buffer, - s: tir.Var, - h: tir.Var, - d: tir.Var, - pos: tir.Var, - ): - cos_freq, sin_freq = _rope_freq(pos * scale, d, rotary_dim, theta, dtype) - cos = cos_freq * x[s, h, d] - sin = sin_freq * tir.if_then_else( - d < rotary_dim // 2, - -x[s, h, d + rotary_dim // 2], - x[s, h, d - rotary_dim // 2], - ) - return cos + sin - - @T.prim_func(private=True) - def fused_rope( # pylint: disable=too-many-locals - var_qkv: T.handle, - var_position_map: T.handle, - var_q: T.handle, - var_k: T.handle, - var_v: T.handle, - apply_rope: T.int32, - ): - T.func_attr( - { - "op_pattern": 8, # 2 means injective, 8 means opaque - "tir.noalias": T.bool(True), - } - ) - seq_len = T.int64() - position_map_elem_offset = T.int64() - qkv = T.match_buffer(var_qkv, (seq_len, fused_heads, 
head_dim), dtype) - q = T.match_buffer(var_q, (seq_len, num_q_heads, head_dim), dtype) - k = T.match_buffer(var_k, (seq_len, num_kv_heads, head_dim), dtype) - v = T.match_buffer(var_v, (seq_len, num_kv_heads, head_dim), dtype) - position_map = T.match_buffer( - var_position_map, (seq_len,), "int32", elem_offset=position_map_elem_offset - ) - for iters in T.grid(seq_len, fused_heads, head_dim): - with T.block("llama_fused_rope"): - s, h, d = T.axis.remap("SSS", iters) - if h < num_q_heads: - q[s, h, d] = T.if_then_else( - apply_rope > 0 and d < rotary_dim, - _rope(qkv, s, h, d, position_map[s]), - qkv[s, h, d], - ) - elif h < num_q_heads + num_kv_heads: - k[s, h - num_q_heads, d] = T.if_then_else( - apply_rope > 0 and d < rotary_dim, - _rope(qkv, s, h, d, position_map[s]), - qkv[s, h, d], - ) - else: - v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d] - - return fused_rope - - -@T.prim_func -def copy_cache( - var_pages: T.handle, - var_position_map: T.handle, - var_k_data: T.handle, - var_v_data: T.handle, - layer_id: T.int64, -): - num_kv_heads = T.int64() - head_dim = T.int64() - seqlen = T.SizeVar("seqlen", "int64") - page_size = T.int64() - num_pages = T.int64() - position_map_elem_offset = T.int64() - pages = T.match_buffer(var_pages, (num_pages, 2, num_kv_heads, page_size, head_dim), "float16") - position_map = T.match_buffer( - var_position_map, (seqlen,), "int32", elem_offset=position_map_elem_offset - ) - k_data = T.match_buffer(var_k_data, (num_layers, seqlen, num_kv_heads, head_dim), "float16") - v_data = T.match_buffer(var_v_data, (num_layers, seqlen, num_kv_heads, head_dim), "float16") - - for p, h, d in T.grid(seqlen, num_kv_heads, head_dim): - with T.block("copy0"): - vp, vh, vd = T.axis.remap("SSS", [p, h, d]) - T.reads( - position_map[vp], - pages[position_map[vp] // page_size, 0:2, vh, position_map[vp] % page_size, vd], - ) - T.writes(k_data[layer_id, vp, vh, vd], v_data[layer_id, vp, vh, vd]) - position: T.int64 = T.Cast("int64", position_map[vp]) - k_data[layer_id, vp, vh, vd] = pages[ - T.floordiv(position, page_size), 0, vh, T.floormod(position, page_size), vd - ] - v_data[layer_id, vp, vh, vd] = pages[ - T.floordiv(position, page_size), 1, vh, T.floormod(position, page_size), vd - ] - - -def _copy_single_page(num_heads, page_size, head_dim, dtype, target): - tx = 256 if str(target.kind) == "webgpu" else 1024 - - @T.prim_func - def copy_single_page( - pages: T.handle, - src_page_id: T.int64, - tgt_page_id: T.int64, - copy_length: T.int64, - ): - T.func_attr({"tir.is_scheduled": 1}) - num_pages = T.int32() - P = T.match_buffer(pages, (num_pages, 2, num_heads, page_size, head_dim), dtype) - - for b in T.thread_binding( - (copy_length * num_heads * head_dim + tx - 1) // tx, thread="blockIdx.x" - ): - for t in T.thread_binding(tx, thread="threadIdx.x"): - with T.block("copy"): - T.where(b * tx + t < copy_length * num_heads * head_dim) - vh = T.axis.spatial( - num_heads, - T.Cast("int32", (b * tx + t) // (copy_length * head_dim)), - ) - vp = T.axis.spatial( - copy_length, - (b * tx + t) % (copy_length * head_dim) // head_dim, - ) - vd = T.axis.spatial( - head_dim, - T.Cast( - "int32", - (b * tx + t) % head_dim, - ), - ) - P[tgt_page_id, 0, vh, vp, vd] = P[src_page_id, 0, vh, vp, vd] - P[tgt_page_id, 1, vh, vp, vd] = P[src_page_id, 1, vh, vp, vd] - - return copy_single_page - - def set_global_func(): global fclear, fcreate, fadd_sequence, fremove_sequence, ffork_sequence, fpopn global fbegin_forward, fend_forward, fattention, fattention_with_fuse_qkv, fdebug_get_kv @@ 
-327,12 +130,12 @@ def set_global_func(): target = tvm.target.Target.from_device(device) builts = [] for tir_func in [ - kv_cache_transpose_append, + _kv_cache_transpose_append(num_kv_heads, head_dim, dtype), llama_rope_with_position_map( - rope_theta, rope_scale, head_dim, num_qo_heads, num_kv_heads, dtype + rope_theta, rope_scale, head_dim, num_qo_heads, num_kv_heads, dtype, {} ), _copy_single_page(num_kv_heads, page_size, head_dim, dtype, target), - copy_cache, + _kv_cache_debug_get_kv(num_layers, num_kv_heads, head_dim, dtype), ]: mod = tvm.IRModule({"main": tir_func}) with target: @@ -388,18 +191,6 @@ def create_kv_cache(rope_mode): return cache -class RopeMode(enum.IntEnum): - """The RoPE mode of the Paged KV cache. - If it is none, the KV cache will not apply RoPE to q and k. - If it is normal, RoPE will be applied to k before adding k to cache. - Otherwise, RoPE will be applied to q/k in attention kernel on-the-fly. - """ - - NONE = 0 - NORMAL = 1 - INLINE = 2 - - @pytest.fixture(params=[RopeMode.NONE, RopeMode.NORMAL, RopeMode.INLINE]) def kv_cache_and_rope_mode(request): set_global_func() diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py new file mode 100644 index 000000000000..72a45b8a4cf3 --- /dev/null +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py @@ -0,0 +1,456 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import itertools +from typing import Dict, List, Tuple, Union + +import numpy as np +import pytest +import scipy.special + +import tvm +import tvm.testing +from tvm import dlight as dl +from tvm.relax.frontend.nn.llm.kv_cache import ( + AttnKind, + RopeMode, + _attention_decode_mla, + _attention_prefill_mla, + _attention_prefill_ragged_mla_absorbed, + _copy_single_page_mla, + _kv_cache_debug_get_kv_mla, + _kv_cache_transpose_append_mla, + _merge_state_inplace, +) +from tvm.runtime import ShapeTuple + +reserved_nseq = 32 +maximum_total_seq_length = 2048 +prefill_chunk_size = 512 +page_size = 16 +num_layers = 4 +num_attention_heads = 128 +qk_nope_head_dim = 128 +qk_rope_head_dim = 64 +kv_lora_rank = 512 +dtype = None +device = tvm.cuda() + +fclear = None +fadd_sequence = None +fremove_sequence = None +ffork_sequence = None +fpopn = None +fbegin_forward = None +fend_forward = None +fmla_absorbed = None +fis_empty = None +fdebug_get_kv = None + +ftranspose_append = None +fcopy_cache = None +fattn_prefill = None +fattn_decode = None +fattn_prefill_ragged_absorbed = None +fmerge_state = None +fcopy_single_page = None + + +# Register a dumb function for testing purpose. 
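+# The MLA cache constructor still expects the full set of kernel slots; this placeholder is
+# passed for the slots that the MLA-only test never exercises (e.g. the MHA attention and
+# tree-attention entries).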
+@tvm.register_func("test.dumb_function") +def _dumb_function(): + pass + + +def set_global_func(dtype): + global fclear, fadd_sequence, fremove_sequence, ffork_sequence + global fpopn, fbegin_forward, fend_forward + global fmla_absorbed, fis_empty, fdebug_get_kv + global ftranspose_append, fcopy_cache, fattn_prefill, fattn_decode + global fattn_prefill_ragged_absorbed + global fmerge_state, fcopy_single_page + + fclear = tvm.get_global_func("vm.builtin.kv_state_clear") + fadd_sequence = tvm.get_global_func("vm.builtin.kv_state_add_sequence") + fremove_sequence = tvm.get_global_func("vm.builtin.kv_state_remove_sequence") + ffork_sequence = tvm.get_global_func("vm.builtin.kv_state_fork_sequence") + fpopn = tvm.get_global_func("vm.builtin.kv_state_popn") + fbegin_forward = tvm.get_global_func("vm.builtin.kv_state_begin_forward") + fend_forward = tvm.get_global_func("vm.builtin.kv_state_end_forward") + fmla_absorbed = tvm.get_global_func("vm.builtin.attention_kv_cache_mla_absorbed") + fis_empty = tvm.get_global_func("vm.builtin.attention_kv_cache_empty") + fdebug_get_kv = tvm.get_global_func("vm.builtin.attention_kv_cache_debug_get_kv_mla") + + target = tvm.target.Target.from_device(device) + builts = [] + for tir_func in [ + _kv_cache_transpose_append_mla(kv_lora_rank, qk_rope_head_dim, dtype), + _kv_cache_debug_get_kv_mla(num_layers, kv_lora_rank + qk_rope_head_dim, dtype), + _attention_prefill_mla( + num_attention_heads, kv_lora_rank, qk_rope_head_dim, dtype, False, target + ), + _attention_decode_mla( + num_attention_heads, kv_lora_rank, qk_rope_head_dim, dtype, False, target + ), + _attention_prefill_ragged_mla_absorbed( + num_attention_heads, kv_lora_rank, qk_rope_head_dim, dtype, target + ), + _merge_state_inplace(num_attention_heads, kv_lora_rank, dtype, target), + _copy_single_page_mla(page_size, kv_lora_rank + qk_rope_head_dim, dtype, target), + ]: + mod = tvm.IRModule({"main": tir_func}) + with target: + mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod) + f = tvm.build(mod["main"], target=target) + builts.append(f.entry_func) + + ( + ftranspose_append, + fcopy_cache, + fattn_prefill, + fattn_decode, + fattn_prefill_ragged_absorbed, + fmerge_state, + fcopy_single_page, + ) = builts + + +def create_kv_cache(dtype): + fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create_reduced_mla") + fdumb = tvm.get_global_func("test.dumb_function") + cache = fcreate( + tvm.runtime.ShapeTuple( + [ + reserved_nseq, + maximum_total_seq_length, + prefill_chunk_size, + page_size, + 0, + ] + ), + tvm.runtime.ShapeTuple([0, num_layers]), + num_attention_heads, + 1, + kv_lora_rank + qk_rope_head_dim, + kv_lora_rank, + qk_rope_head_dim, + tvm.runtime.ShapeTuple([int(AttnKind.MLA) for _ in range(num_layers)]), + RopeMode.NONE, + 1, + 10000, + tvm.nd.empty((), dtype, device=device), + fdumb, + ftranspose_append, + fdumb, + fdumb, + fdumb, + fdumb, + fdumb, + 0, + 0, + 0, + 0, + 0, + 0, + fattn_prefill, + fattn_decode, + fdumb, + fattn_prefill_ragged_absorbed, + fmerge_state, + fdumb, + fcopy_single_page, + fcopy_cache, + fdumb, + fdumb, + fdumb, + None, + False, + ) + return cache + + +@pytest.fixture(params=itertools.product(["float16"])) +def kv_cache_and_config(request): + global dtype + (dtype,) = request.param + set_global_func(dtype) + return (create_kv_cache(dtype),) + + +def verify_cached_kv(kv_cache, seq_ids, expected_kv): + for seq_id in seq_ids: + kv_expected = expected_kv[seq_id] + seq_length = expected_kv[seq_id].shape[1] + kv_actual = tvm.nd.empty(kv_expected.shape, 
dtype=dtype, device=device) + fdebug_get_kv(kv_cache, seq_id, 0, seq_length, kv_actual) + tvm.testing.assert_allclose(kv_actual.numpy(), kv_expected, rtol=1e-3, atol=1e-3) + + +def apply_attention( + kv_cache, + batch: List[Tuple[Union[int, Tuple[int, int, int]], int]], + cached_kv: Dict[int, np.ndarray], +) -> None: + seq_ids = [] + append_lengths = [] + for i, (seq_id, append_length) in enumerate(batch): + fork_parent_id = None + if isinstance(seq_id, tuple): + # Fork sequence + seq_id, fork_parent_id, fork_pos = seq_id + batch[i] = (seq_id, append_length) + seq_ids.append(seq_id) + append_lengths.append(append_length) + if fork_parent_id is not None: + assert fork_parent_id in cached_kv + assert seq_id not in cached_kv + ffork_sequence(kv_cache, fork_parent_id, seq_id, fork_pos) + if fork_pos == -1: + cached_kv[seq_id] = cached_kv[fork_parent_id] + else: + cached_kv[seq_id] = cached_kv[fork_parent_id][::, :fork_pos] + elif seq_id not in cached_kv: + fadd_sequence(kv_cache, seq_id) + cached_kv[seq_id] = np.zeros((num_layers, 0, kv_lora_rank + qk_rope_head_dim), dtype) + + fbegin_forward(kv_cache, ShapeTuple(seq_ids), ShapeTuple(append_lengths), None) + + global_new_q = np.zeros( + (num_layers, 0, num_attention_heads, kv_lora_rank + qk_rope_head_dim), dtype + ) + global_new_kv = np.zeros((num_layers, 0, kv_lora_rank + qk_rope_head_dim), dtype) + + q_array = [] + for i, (seq_id, append_length) in enumerate(batch): + new_q = np.random.rand( + num_layers, append_length, num_attention_heads, kv_lora_rank + qk_rope_head_dim + ).astype(dtype) + new_kv = np.random.rand(num_layers, append_length, kv_lora_rank + qk_rope_head_dim).astype( + dtype + ) + q_array.append(new_q) + + cached_kv[seq_id] = np.concatenate([cached_kv[seq_id], new_kv], axis=1) + global_new_q = np.concatenate([global_new_q, new_q], axis=1) + global_new_kv = np.concatenate([global_new_kv, new_kv], axis=1) + + for layer_id in range(num_layers): + queries_np = global_new_q[layer_id] + queries = tvm.nd.array(queries_np, device) + compressed_kv = tvm.nd.array(global_new_kv[layer_id][:, :kv_lora_rank], device) + k_pe = tvm.nd.array(global_new_kv[layer_id][:, kv_lora_rank:], device) + outputs = tvm.nd.empty( + (queries_np.shape[0], queries_np.shape[1], kv_lora_rank), dtype, device=device + ) + fmla_absorbed(kv_cache, layer_id, 1.0, queries, compressed_kv, k_pe, outputs) + + # Compute attention expected results. 
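+ # Reference for absorbed MLA: each cached row (compressed_kv concatenated with k_pe) serves
+ # as the per-token key shared by every query head (hence the np.repeat below), the value is
+ # just the first kv_lora_rank entries of the same row, and the tril/triu mask enforces
+ # causality with the offset between cached and newly appended tokens.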
+ outputs = np.expand_dims(outputs.numpy(), axis=0) + sum_length = 0 + for i, (seq_id, append_length) in enumerate(batch): + assert cached_kv[seq_id].shape[1] >= append_length + + q_seq = q_array[i][layer_id].transpose(1, 0, 2) + k_seq = np.expand_dims(cached_kv[seq_id][layer_id], axis=1).transpose(1, 2, 0) + v_seq = np.expand_dims(cached_kv[seq_id][layer_id], axis=1).transpose(1, 0, 2)[ + :, :, :kv_lora_rank + ] + + k_seq = np.repeat(k_seq, num_attention_heads, axis=0) + v_seq = np.repeat(v_seq, num_attention_heads, axis=0) + softmax_input = q_seq.astype("float32") @ k_seq.astype("float32") + softmax_shape = softmax_input.shape + assert softmax_shape[-2] == append_length + length_diff = softmax_shape[-1] - softmax_shape[-2] + assert length_diff >= 0 + mask = np.tril( + np.full_like(softmax_input, np.finfo("float32").max), k=length_diff + ) + np.triu(np.full_like(softmax_input, np.finfo("float32").min), k=length_diff + 1) + + softmax_input = np.minimum(softmax_input, mask) + + results = np.expand_dims( + (scipy.special.softmax(softmax_input, axis=-1) @ v_seq.astype("float32")).transpose( + 1, 0, 2 + ), + axis=0, + ).astype(dtype) + + tvm.testing.assert_allclose( + outputs[:, sum_length : sum_length + append_length, ...], + results, + rtol=1e-3, + atol=1e-3, + ) + sum_length += append_length + fend_forward(kv_cache) + + # Verify + verify_cached_kv(kv_cache, seq_ids, cached_kv) + + +@tvm.testing.requires_gpu +@tvm.testing.requires_cuda +def test_paged_attention_kv_cache_prefill_and_decode(kv_cache_and_config): + (kv_cache,) = kv_cache_and_config + fclear(kv_cache) + + # Prefill. + operation_seq = [[(0, 6)], [(1, 8)], [(2, 11)], [(3, 16)], [(4, 19), (5, 20)]] + operation_seq += [[(6, 21), (7, 24)], [(2, 5), (4, 7), (8, 24)]] + operation_seq += [[(6, 13)], [(8, 19)], [(0, 1)], [(1, 3), (3, 8), (5, 12), (7, 11)]] + # Decode + operation_seq += [[(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1), (7, 1), (8, 1)]] + operation_seq += [[(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1), (7, 1), (8, 1)]] + operation_seq += [[(0, 1), (2, 1), (4, 1), (6, 1), (8, 1)]] + operation_seq += [[(4, 1), (5, 1), (6, 1), (7, 1), (8, 1)]] + + cached_kv = {} + for batch in operation_seq: + apply_attention(kv_cache, batch, cached_kv) + + +@tvm.testing.requires_gpu +@tvm.testing.requires_cuda +def test_paged_attention_kv_cache_remove_sequence(kv_cache_and_config): + (kv_cache,) = kv_cache_and_config + fclear(kv_cache) + + num_sequences = 5 + batch = [(seq_id, 1) for seq_id in range(num_sequences)] + cached_kv = {} + for seq_id_to_remove in range(num_sequences): + apply_attention(kv_cache, batch, cached_kv) + # Remove sequence. + fremove_sequence(kv_cache, seq_id_to_remove) + cached_kv.pop(seq_id_to_remove) + verify_cached_kv( + kv_cache, + seq_ids=[seq_id for seq_id in range(num_sequences) if seq_id != seq_id_to_remove], + expected_kv=cached_kv, + ) + + +@tvm.testing.requires_gpu +@tvm.testing.requires_cuda +def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_config): + (kv_cache,) = kv_cache_and_config + fclear(kv_cache) + + cached_kv = {} + batch = [(0, 60), (1, 88), (2, 17), (3, 4)] + apply_attention(kv_cache, batch, cached_kv) + # Fork existing sequences. 
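+    # Each tuple below is (child_seq_id, parent_seq_id, fork_pos):
+    # apply_attention forks the child from the parent before appending.
+    # fork_pos == -1 forks from the end of the parent sequence, so the child
+    # inherits the parent's full cached KV.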
+ apply_attention(kv_cache, [((4, 3, -1), 35)], cached_kv) + apply_attention(kv_cache, [((5, 0, -1), 20)], cached_kv) + apply_attention(kv_cache, [((6, 5, -1), 102)], cached_kv) + apply_attention(kv_cache, [((7, 0, -1), 3)], cached_kv) + apply_attention(kv_cache, [((8, 5, -1), 71), ((9, 5, -1), 20)], cached_kv) + # 0 <- 5 <- 6,8,9 + # 0 <- 7 + # 3 <- 4 + # Mixture of decode and prefill. + operation_seq = [ + [(2, 1), (4, 1), (7, 1), (6, 1), (8, 1), (9, 1)], + [(7, 1), (6, 1), (8, 1), (9, 1)], + [(7, 1), (1, 1), (6, 1), (2, 1), (8, 1), (4, 1), (9, 1)], + [(7, 10), (6, 2), (8, 3), (9, 4)], + ] + for batch in operation_seq: + apply_attention(kv_cache, batch, cached_kv) + + apply_attention(kv_cache, [((10, 1, 33), 11)], cached_kv) + apply_attention(kv_cache, [((11, 0, 60), 45), ((12, 0, 15), 14)], cached_kv) + apply_attention(kv_cache, [((13, 0, 16), 19), ((14, 0, 17), 19)], cached_kv) + apply_attention(kv_cache, [((15, 5, 60), 8), ((16, 5, 80), 10)], cached_kv) + apply_attention( + kv_cache, + [((17, 5, 75), 11), ((18, 5, 76), 45), ((19, 5, 77), 14)], + cached_kv, + ) + + operation_seq = [ + [(6, 1), (11, 1), (13, 1), (9, 1)], + [(10, 1), (16, 1), (18, 1), (19, 1)], + [(8, 1), (15, 1), (17, 1), (12, 1), (14, 1)], + [(10, 10), (6, 2), (8, 3), (19, 4)], + ] + for batch in operation_seq: + apply_attention(kv_cache, batch, cached_kv) + + num_sequence = 20 + for i in range(num_sequence): + fremove_sequence(kv_cache, i) + cached_kv.pop(i) + verify_cached_kv( + kv_cache, + seq_ids=list(range(i + 1, num_sequence)), + expected_kv=cached_kv, + ) + + assert fis_empty(kv_cache), "The KV cache is not empty after removing all sequences" + + # Test fork after page recycle + apply_attention(kv_cache, [(0, 7), (1, 24)], cached_kv) + apply_attention(kv_cache, [((2, 1, -1), 10)], cached_kv) + apply_attention(kv_cache, [((3, 0, -1), 20)], cached_kv) + apply_attention(kv_cache, [(2, 1), (3, 1)], cached_kv) + + apply_attention(kv_cache, [(10, 7), (11, 24)], cached_kv) + apply_attention(kv_cache, [((12, 11, -1), 200)], cached_kv) + apply_attention(kv_cache, [(10, 1), (12, 1)], cached_kv) + + +@tvm.testing.requires_gpu +@tvm.testing.requires_cuda +def test_paged_attention_kv_cache_popn(kv_cache_and_config): + (kv_cache,) = kv_cache_and_config + fclear(kv_cache) + + cached_kv = {} + batch = [(0, 35), (1, 88), (2, 17), (3, 4)] + apply_attention(kv_cache, batch, cached_kv) + apply_attention(kv_cache, [((4, 3, -1), 35)], cached_kv) + + popn_operations = [(0, 17), (1, 57), (2, 16), (3, 0), (4, 37)] + for seq_id, pop_length in popn_operations: + fpopn(kv_cache, seq_id, pop_length) + if pop_length != 0: + cached_kv[seq_id] = cached_kv[seq_id][:, :-pop_length, ...] 
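+        # pop_length == 0 is skipped above because [:, :-0, ...] would drop
+        # the whole sequence; after each pop the cached KV must still match
+        # the truncated reference.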
+ verify_cached_kv(kv_cache, seq_ids=list(range(4)), expected_kv=cached_kv) + + num_sequence = 5 + for seq_id in range(num_sequence): + fremove_sequence(kv_cache, seq_id) + verify_cached_kv( + kv_cache, + seq_ids=list(range(seq_id + 1, num_sequence)), + expected_kv=cached_kv, + ) + + assert fis_empty(kv_cache), "The KV cache is not empty after removing all sequences" + + +if __name__ == "__main__": + DTYPES = ["float16"] + for (dtype,) in itertools.product(DTYPES): + set_global_func(dtype) + cache = create_kv_cache(dtype) + cache_and_config = (cache,) + test_paged_attention_kv_cache_prefill_and_decode(cache_and_config) + test_paged_attention_kv_cache_remove_sequence(cache_and_config) + test_paged_attention_kv_cache_fork_sequence(cache_and_config) + test_paged_attention_kv_cache_popn(cache_and_config) diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py index 172eb20c26cf..e30debabfede 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import enum import itertools from typing import Dict, List, Optional, Tuple, Union @@ -26,6 +25,7 @@ import tvm.testing from tvm import dlight as dl from tvm.relax.frontend.nn.llm.kv_cache import ( + RopeMode, _attention_decode, _attention_prefill, _attention_prefill_ragged, @@ -198,18 +198,6 @@ def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window): return cache -class RopeMode(enum.IntEnum): - """The RoPE mode of the Paged KV cache. - If it is none, the KV cache will not apply RoPE to q and k. - If it is normal, RoPE will be applied to k before adding k to cache. - Otherwise, RoPE will be applied to q/k in attention kernel on-the-fly. - """ - - NONE = 0 - NORMAL = 1 - INLINE = 2 - - @pytest.fixture( params=itertools.chain( itertools.product( @@ -366,9 +354,11 @@ def apply_attention( rope_offset, rope_scale, rope_theta, - token_tree_node_depths_list[i][-append_length:] - if token_tree_node_depths_list[i] is not None - else None, + ( + token_tree_node_depths_list[i][-append_length:] + if token_tree_node_depths_list[i] is not None + else None + ), ) ) for l in range(num_layers) @@ -410,9 +400,11 @@ def apply_attention( rope_offset, rope_scale, rope_theta, - token_tree_node_depths_list[i][-append_length:] - if token_tree_node_depths_list[i] is not None - else None, + ( + token_tree_node_depths_list[i][-append_length:] + if token_tree_node_depths_list[i] is not None + else None + ), ) ).transpose(1, 0, 2) k_seq = (