From 0cf1ceed49f643b08d87737d2313f4603eade26f Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Sun, 26 Jan 2025 13:39:03 -0500 Subject: [PATCH] [Bugfix/CI] Fix broken kernels/test_mha.py (#12450) --- tests/kernels/test_mha_attn.py | 4 ++-- vllm/attention/layer.py | 8 ++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/kernels/test_mha_attn.py b/tests/kernels/test_mha_attn.py index 22d434f5e40ef..eab874e9e02bb 100644 --- a/tests/kernels/test_mha_attn.py +++ b/tests/kernels/test_mha_attn.py @@ -26,7 +26,7 @@ def clear_cache(): @pytest.mark.parametrize("device", ["cpu", "hip", "cuda"]) def test_mha_attn_platform(device: str): """ - Test that the attention selector between different platform and device. + Test the attention selector between different platform and device. """ torch.set_default_dtype(torch.float16) @@ -41,7 +41,7 @@ def test_mha_attn_platform(device: str): else: with patch("vllm.attention.selector.current_platform", CudaPlatform()): attn = MultiHeadAttention(16, 64, scale=1) - assert attn.attn_backend == _Backend.FLASH_ATTN + assert attn.attn_backend == _Backend.XFORMERS with patch("vllm.attention.selector.current_platform", CudaPlatform()): attn = MultiHeadAttention(16, 72, scale=1) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index da663d894aeb3..962c45a65ae23 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -210,6 +210,9 @@ def __init__( self.scale = scale self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + dtype = torch.get_default_dtype() attn_backend = get_attn_backend(head_size, dtype, @@ -240,6 +243,11 @@ def forward( key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size) value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size) + if (num_repeat := self.num_queries_per_kv) > 1: + # Handle MQA and GQA + key = torch.repeat_interleave(key, num_repeat, dim=2) + value = torch.repeat_interleave(value, num_repeat, dim=2) + if self.attn_backend == _Backend.XFORMERS: from xformers import ops as xops