From 38a4c7dc2ea6ab1c92d41d3e491a2c8d655e0b8f Mon Sep 17 00:00:00 2001
From: Tyler Michael Smith
Date: Sun, 26 Jan 2025 17:38:27 +0000
Subject: [PATCH 1/3] [Bugfix/CI] Fix test_mha_attn

Signed-off-by: Tyler Michael Smith
---
 tests/kernels/test_mha_attn.py | 4 ++--
 vllm/attention/layer.py        | 9 +++++++++
 2 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/tests/kernels/test_mha_attn.py b/tests/kernels/test_mha_attn.py
index 22d434f5e40ef..eab874e9e02bb 100644
--- a/tests/kernels/test_mha_attn.py
+++ b/tests/kernels/test_mha_attn.py
@@ -26,7 +26,7 @@ def clear_cache():
 @pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
 def test_mha_attn_platform(device: str):
     """
-    Test that the attention selector between different platform and device.
+    Test the attention selector between different platform and device.
     """
     torch.set_default_dtype(torch.float16)
 
@@ -41,7 +41,7 @@ def test_mha_attn_platform(device: str):
     else:
         with patch("vllm.attention.selector.current_platform", CudaPlatform()):
             attn = MultiHeadAttention(16, 64, scale=1)
-            assert attn.attn_backend == _Backend.FLASH_ATTN
+            assert attn.attn_backend == _Backend.XFORMERS
 
         with patch("vllm.attention.selector.current_platform", CudaPlatform()):
             attn = MultiHeadAttention(16, 72, scale=1)
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index da663d894aeb3..1e066edb3ac38 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -243,6 +243,15 @@ def forward(
         if self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
 
+            # Expand key and value to match number of query heads
+            if self.num_kv_heads != self.num_heads:
+                key = key.repeat_interleave(self.num_heads //
+                                            self.num_kv_heads,
+                                            dim=2)
+                value = value.repeat_interleave(self.num_heads //
+                                                self.num_kv_heads,
+                                                dim=2)
+
             out = xops.memory_efficient_attention_forward(query,
                                                           key,
                                                           value,

From 8c47e2fa29cef132f3f962e476ab6f997e538d73 Mon Sep 17 00:00:00 2001
From: Tyler Michael Smith
Date: Sun, 26 Jan 2025 17:42:16 +0000
Subject: [PATCH 2/3] assert

Signed-off-by: Tyler Michael Smith
---
 vllm/attention/layer.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 1e066edb3ac38..5908e3a15b5ca 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -245,6 +245,7 @@ def forward(
 
             # Expand key and value to match number of query heads
             if self.num_kv_heads != self.num_heads:
+                assert self.num_heads % self.num_kv_heads == 0
                 key = key.repeat_interleave(self.num_heads //
                                             self.num_kv_heads,
                                             dim=2)

From dabdae56083e82ccf6e5d530b35bd7d30b376097 Mon Sep 17 00:00:00 2001
From: Tyler Michael Smith
Date: Sun, 26 Jan 2025 17:50:40 +0000
Subject: [PATCH 3/3] improve

Signed-off-by: Tyler Michael Smith
---
 vllm/attention/layer.py | 18 ++++++++----------
 1 file changed, 8 insertions(+), 10 deletions(-)

diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 5908e3a15b5ca..962c45a65ae23 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -210,6 +210,9 @@ def __init__(
         self.scale = scale
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
 
+        assert self.num_heads % self.num_kv_heads == 0
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+
         dtype = torch.get_default_dtype()
         attn_backend = get_attn_backend(head_size,
                                         dtype,
@@ -240,19 +243,14 @@ def forward(
         key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
         value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)
 
+        if (num_repeat := self.num_queries_per_kv) > 1:
+            # Handle MQA and GQA
+            key = torch.repeat_interleave(key, num_repeat, dim=2)
+            value = torch.repeat_interleave(value, num_repeat, dim=2)
+
         if self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
 
-            # Expand key and value to match number of query heads
-            if self.num_kv_heads != self.num_heads:
-                assert self.num_heads % self.num_kv_heads == 0
-                key = key.repeat_interleave(self.num_heads //
-                                            self.num_kv_heads,
-                                            dim=2)
-                value = value.repeat_interleave(self.num_heads //
-                                                self.num_kv_heads,
-                                                dim=2)
-
             out = xops.memory_efficient_attention_forward(query,
                                                           key,
                                                           value,
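
For reference, below is a minimal, self-contained sketch of the MQA/GQA key/value expansion that this series converges on. The helper name expand_kv_for_gqa and the toy tensor shapes are illustrative assumptions, not part of the patches; the assumed layout is [batch, seq_len, num_heads, head_size], matching the view calls in MultiHeadAttention.forward.

import torch


def expand_kv_for_gqa(key: torch.Tensor, value: torch.Tensor,
                      num_heads: int, num_kv_heads: int):
    """Repeat KV heads along the head dim so they match the query heads.

    Assumes [batch, seq_len, num_kv_heads, head_size] inputs. Illustrative
    helper only; not part of the patch series above.
    """
    assert num_heads % num_kv_heads == 0
    num_queries_per_kv = num_heads // num_kv_heads
    if num_queries_per_kv > 1:  # MQA/GQA: fewer KV heads than query heads
        key = torch.repeat_interleave(key, num_queries_per_kv, dim=2)
        value = torch.repeat_interleave(value, num_queries_per_kv, dim=2)
    return key, value


# Toy check: 16 query heads sharing 4 KV heads.
q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 8, 4, 64)
v = torch.randn(2, 8, 4, 64)
k, v = expand_kv_for_gqa(k, v, num_heads=16, num_kv_heads=4)
assert k.shape == v.shape == q.shape  # (2, 8, 16, 64)

This mirrors what the final commit does: the expansion is hoisted above the "if self.attn_backend == _Backend.XFORMERS:" check, so the repeated key/value tensors are in place before any backend-specific attention call runs.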