From 506c9324e69596dbd9d08ab565b6fe0ed9a216e8 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 30 Jan 2025 02:12:29 +0000 Subject: [PATCH] more cleanup Signed-off-by: Lucas Wilkinson --- ...ion.py => test_triton_decode_attention.py} | 62 +++++++++---------- vllm/attention/backends/mla/utils.py | 20 +++--- vllm/attention/backends/triton_mla.py | 4 +- vllm/attention/ops/triton_decode_attention.py | 28 ++++----- vllm/model_executor/models/deepseek_v2.py | 12 ++-- 5 files changed, 57 insertions(+), 69 deletions(-) rename tests/kernels/{test_decode_attention.py => test_triton_decode_attention.py} (51%) diff --git a/tests/kernels/test_decode_attention.py b/tests/kernels/test_triton_decode_attention.py similarity index 51% rename from tests/kernels/test_decode_attention.py rename to tests/kernels/test_triton_decode_attention.py index 55db6b1d1e8e5..14f5a3b770b69 100644 --- a/tests/kernels/test_decode_attention.py +++ b/tests/kernels/test_triton_decode_attention.py @@ -1,3 +1,4 @@ +import pytest import torch from vllm.attention.ops.triton_decode_attention import decode_attention_fwd @@ -7,37 +8,48 @@ def cdiv(a, b): return (a + b - 1) // b -def test_decode_attention(B, L, H_Q, H_KV, D, CACHE_SIZE, PAGE_SIZE): +@pytest.mark.parametrize("B", [3, 5]) +@pytest.mark.parametrize("L", [1027, 1025]) +@pytest.mark.parametrize("H_Q", [32]) +@pytest.mark.parametrize("H_KV", [32, 8]) +@pytest.mark.parametrize("D_QK", [128, 192, 576]) +@pytest.mark.parametrize("D_V", [128, 512]) +@pytest.mark.parametrize("CACHE_SIZE", [16384]) +@pytest.mark.parametrize("PAGE_SIZE", [1, 16]) +def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE): assert CACHE_SIZE % PAGE_SIZE == 0 dtype = torch.bfloat16 seq_len = L # This represents the number of tokens already in the sequence - sm_scale = 1.0 / (D**0.5) + sm_scale = 1.0 / (D_QK**0.5) num_kv_splits = 8 num_pages_per_batch = cdiv(seq_len, PAGE_SIZE) - req_to_page = torch.randint(0, CACHE_SIZE // PAGE_SIZE, (B, num_pages_per_batch, 1), device="cuda") + req_to_page = torch.randint(0, + CACHE_SIZE // PAGE_SIZE, + (B, num_pages_per_batch, 1), + device="cuda") req_to_token = req_to_page * PAGE_SIZE req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE) - req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(1, 1, -1) + req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view( + 1, 1, -1) req_to_token = req_to_token.view(B, -1) req_to_token = req_to_token[:, :seq_len].contiguous() # q represents the new token being generated, one per batch - q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda") + q = torch.randn(B, H_Q, D_QK, dtype=dtype, device="cuda") # k_buffer and v_buffer represent all previous tokens # Page size is 1. 
- k_buffer = torch.randn(CACHE_SIZE, H_KV, D, dtype=dtype, device="cuda") - v_buffer = torch.randn(CACHE_SIZE, H_KV, D, dtype=dtype, device="cuda") + k_buffer = torch.randn(CACHE_SIZE, H_KV, D_QK, dtype=dtype, device="cuda") + v_buffer = torch.randn(CACHE_SIZE, H_KV, D_V, dtype=dtype, device="cuda") # o will have the same shape as q - o = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda") + o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda") - b_req_idx = torch.arange(B, device="cuda") - b_seq_len = torch.full((B,), seq_len, device="cuda") + b_seq_len = torch.full((B, ), seq_len, device="cuda") attn_logits = torch.empty( - (B, H_Q, num_kv_splits, D + 1), + (B, H_Q, num_kv_splits, D_V + 1), dtype=torch.float32, device="cuda", ) @@ -56,24 +68,15 @@ def test_decode_attention(B, L, H_Q, H_KV, D, CACHE_SIZE, PAGE_SIZE): ) # Page size can be larger than 1. - k_buffer = k_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D) - v_buffer = v_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D) + k_buffer = k_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_QK) + v_buffer = v_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V) o1 = torch.zeros_like(o) - b_seq_len = torch.full((B,), seq_len, device="cuda") - - attn_logits = torch.empty( - (B, H_Q, num_kv_splits, D + 1), - dtype=torch.float32, - device="cuda", - ) - - # Trick: Flatten the KV cache so that we use page_size = 1 inside the kernel. decode_attention_fwd( q, - k_buffer.flatten(0, 1), - v_buffer.flatten(0, 1), + k_buffer, + v_buffer, o1, req_to_page, b_seq_len, @@ -82,14 +85,5 @@ def test_decode_attention(B, L, H_Q, H_KV, D, CACHE_SIZE, PAGE_SIZE): sm_scale, PAGE_SIZE, ) - print(torch.allclose(o, o1)) - assert torch.allclose(o, o1) - -if __name__ == "__main__": - # GQA - test_decode_attention(B=5, L=1027, H_Q=32, H_KV=8, D=128, CACHE_SIZE=16384, PAGE_SIZE=1) - test_decode_attention(B=5, L=1027, H_Q=32, H_KV=8, D=128, CACHE_SIZE=16384, PAGE_SIZE=16) - # MHA - test_decode_attention(B=3, L=1025, H_Q=32, H_KV=32, D=128, CACHE_SIZE=16384, PAGE_SIZE=1) - test_decode_attention(B=3, L=1025, H_Q=32, H_KV=32, D=128, CACHE_SIZE=16384, PAGE_SIZE=16) \ No newline at end of file + assert torch.allclose(o, o1) diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py index a82638b1b80db..e3203aca1880f 100644 --- a/vllm/attention/backends/mla/utils.py +++ b/vllm/attention/backends/mla/utils.py @@ -280,8 +280,8 @@ def _forward_decode( def forward( self, layer: AttentionLayer, - hidden_states_or_cq: torch.Tensor, # query in unified attn - ckv_normed: torch.Tensor, # key in unified attn + hidden_states_or_q_c: torch.Tensor, # query in unified attn + k_c_normed: torch.Tensor, # key in unified attn k_pe: torch.Tensor, # value in unified attn kv_cache: torch.Tensor, attn_metadata: MLAMetadataCommon, @@ -289,7 +289,7 @@ def forward( ) -> torch.Tensor: if output is not None: raise NotImplementedError( - "output is not yet supported for FlashInferMLAImpl") + "output is not yet supported for TritonMLAImpl") is_decode = attn_metadata.decode_metadata is not None is_prefill = attn_metadata.prefill_metadata is not None @@ -302,14 +302,14 @@ def forward( k_pe = k_pe.unsqueeze(1) if is_decode: - q_nope = self._q_proj_and_k_up_proj(hidden_states_or_cq) - q_pe = torch.matmul(hidden_states_or_cq, self.W_QR)\ + q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c) + q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\ .view(-1, self.num_heads, self.qk_rope_head_dim) q_pe, k_pe = \ 
self.rotary_emb(attn_metadata.input_positions, q_pe, k_pe) else: assert is_prefill - q = self.q_proj(hidden_states_or_cq)[0]\ + q = self.q_proj(hidden_states_or_q_c)[0]\ .view(-1, self.num_heads, self.qk_head_dim) # TODO(lucas): there must be a nicer way to write this line @@ -321,7 +321,7 @@ def forward( # write the latent and rope to kv cache if kv_cache.numel() > 0: ops.concat_and_cache_mla( - ckv_normed, + k_c_normed, k_pe.squeeze(1), kv_cache, attn_metadata.slot_mapping.flatten(), @@ -330,7 +330,7 @@ def forward( ) if attn_metadata.prefill_metadata is not None: - return self._forward_prefill(q, ckv_normed, k_pe, attn_metadata) + return self._forward_prefill(q, k_c_normed, k_pe, attn_metadata) if attn_metadata.decode_metadata is not None: return self._forward_decode(q_nope, q_pe, kv_cache, attn_metadata) @@ -339,13 +339,13 @@ def forward( def _forward_prefill_flash( self, q: torch.Tensor, - ckv_normed: torch.Tensor, + k_c_normed: torch.Tensor, k_pe: torch.Tensor, seq_start_loc: torch.Tensor, max_prefill_seq_len: int, ) -> torch.Tensor: - kv_nope = self.kv_b_proj(ckv_normed)[0]\ + kv_nope = self.kv_b_proj(k_c_normed)[0]\ .view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = kv_nope\ .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py index 2b44d6e152e8e..f52edea9dd9d3 100644 --- a/vllm/attention/backends/triton_mla.py +++ b/vllm/attention/backends/triton_mla.py @@ -624,11 +624,11 @@ def __init__( def _forward_prefill( self, q: torch.Tensor, - ckv_normed: torch.Tensor, + kv_c_normed: torch.Tensor, k_pe: torch.Tensor, attn_metadata: TritonMLAMetadata, ) -> torch.Tensor: - return self._forward_prefill_flash(q, ckv_normed, k_pe, + return self._forward_prefill_flash(q, kv_c_normed, k_pe, attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len) diff --git a/vllm/attention/ops/triton_decode_attention.py b/vllm/attention/ops/triton_decode_attention.py index 01a3ed6391bd3..675df109b6c0e 100644 --- a/vllm/attention/ops/triton_decode_attention.py +++ b/vllm/attention/ops/triton_decode_attention.py @@ -184,7 +184,7 @@ def _decode_att_m_fwd( batch, head_num = q.shape[0], q.shape[1] grid = (batch, head_num, NUM_KV_SPLITS) - kv_group_num = q.shape[1] // k_buffer.shape[1] + kv_group_num = q.shape[1] // k_buffer.shape[-2] num_warps = 4 if kv_group_num == 1 else 2 @@ -202,10 +202,10 @@ def _decode_att_m_fwd( Req_to_tokens.stride(0), q.stride(0), q.stride(1), - k_buffer.stride(0), - k_buffer.stride(1), - v_buffer.stride(0), - v_buffer.stride(1), + k_buffer.stride(-2), + k_buffer.stride(-1), + v_buffer.stride(-2), + v_buffer.stride(-1), att_out.stride(0), att_out.stride(1), att_out.stride(2), @@ -405,7 +405,7 @@ def _decode_grouped_att_m_fwd( BLOCK_DV = triton.next_power_of_2(Lv) batch, head_num = q.shape[0], q.shape[1] - kv_group_num = q.shape[1] // k_buffer.shape[1] + kv_group_num = q.shape[1] // k_buffer.shape[-2] BLOCK_H = 16 NUM_KV_SPLITS = num_kv_splits @@ -436,10 +436,10 @@ def _decode_grouped_att_m_fwd( Req_to_tokens.stride(0), q.stride(0), q.stride(1), - k_buffer.stride(0), - k_buffer.stride(1), - v_buffer.stride(0), - v_buffer.stride(1), + k_buffer.stride(-2), + k_buffer.stride(-1), + v_buffer.stride(-2), + v_buffer.stride(-1), att_out.stride(0), att_out.stride(1), att_out.stride(2), @@ -633,13 +633,7 @@ def decode_attention_fwd( logit_cap=0.0, ): assert num_kv_splits == attn_logits.shape[2] - kv_group_num = q.shape[1] // v_buffer.shape[1] - - if page_size > 1: - # Make 
the buffers look like page_size 1 since the original kernel only - # supported page size 1 - k_buffer = k_buffer.flatten(0, 1) - v_buffer = v_buffer.flatten(0, 1) + kv_group_num = q.shape[1] // v_buffer.shape[-2] if kv_group_num == 1: # MHA diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index db67e82378318..538668927b72c 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -459,14 +459,14 @@ def forward( ) -> torch.Tensor: if self.q_lora_rank is not None: ckq = self.q_a_proj(hidden_states)[0] - hidden_states_or_ckq = self.q_a_layernorm(ckq) + hidden_states_or_q_c = self.q_a_layernorm(ckq) else: - hidden_states_or_ckq = hidden_states - ckv_nope, ck_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split( + hidden_states_or_q_c = hidden_states + kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split( [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - ckv_nope_normed = self.kv_a_layernorm(ckv_nope.contiguous()) - return self.mla_attn(hidden_states_or_ckq, ckv_nope_normed, ck_pe, - kv_cache, attn_metadata) + kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) + return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, kv_cache, + attn_metadata) class DeepseekV2DecoderLayer(nn.Module):
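
Usage sketch (not part of the patch): the snippet below shows how the updated decode_attention_fwd entry point is driven with a paged KV cache whose query/key head dim differs from its value head dim, mirroring one point of the pytest parametrization added above (B=3, L=1025, H_Q=32, H_KV=8, D_QK=576, D_V=512, PAGE_SIZE=16). It assumes a CUDA device and a vLLM build that includes this patch; all tensor names are local to the example.

    import torch

    from vllm.attention.ops.triton_decode_attention import decode_attention_fwd

    # Shapes taken from the test matrix above (assumes CUDA is available).
    B, L, H_Q, H_KV = 3, 1025, 32, 8
    D_QK, D_V = 576, 512          # query/key head dim != value head dim (MLA-style)
    CACHE_SIZE, PAGE_SIZE = 16384, 16
    num_kv_splits = 8
    dtype, device = torch.bfloat16, "cuda"

    num_pages_per_batch = (L + PAGE_SIZE - 1) // PAGE_SIZE

    q = torch.randn(B, H_Q, D_QK, dtype=dtype, device=device)
    # Paged layout: [num_pages_total, page_size, num_kv_heads, head_dim]
    k_buffer = torch.randn(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_QK,
                           dtype=dtype, device=device)
    v_buffer = torch.randn(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V,
                           dtype=dtype, device=device)
    o = torch.zeros(B, H_Q, D_V, dtype=dtype, device=device)

    # Per-request page table and sequence lengths, built as in the test.
    req_to_page = torch.randint(0, CACHE_SIZE // PAGE_SIZE,
                                (B, num_pages_per_batch, 1), device=device)
    b_seq_len = torch.full((B, ), L, device=device)

    # Split-KV scratch space: value head dim plus one slot for the log-sum-exp.
    attn_logits = torch.empty((B, H_Q, num_kv_splits, D_V + 1),
                              dtype=torch.float32, device=device)

    decode_attention_fwd(q, k_buffer, v_buffer, o, req_to_page, b_seq_len,
                         attn_logits, num_kv_splits, 1.0 / (D_QK**0.5),
                         PAGE_SIZE)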