diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py
index 3e3c0668198ad..124d5d297a574 100644
--- a/tests/kernels/test_attention.py
+++ b/tests/kernels/test_attention.py
@@ -31,9 +31,9 @@
 NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
 NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing
 
-# FlashAttention forward only supports head dimension at most 128
-# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
-HEAD_SIZES = [64, 80, 120, 256]
+# This should be kept in sync with get_supported_head_sizes() in
+# vllm.attention.ops.paged_attn.PagedAttention
+HEAD_SIZES = [32, 64, 80, 96, 112, 120, 128, 192, 256]
 
 BLOCK_SIZES = [16, 32]
 USE_ALIBI = [False, True]
diff --git a/vllm/config.py b/vllm/config.py
index e64883368a751..2fe674b857e16 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -733,9 +733,12 @@ def get_head_size(self) -> int:
         if hasattr(self.hf_text_config,
                    "model_type") and (self.hf_text_config.model_type
                                       in ('deepseek_v2', 'deepseek_v3')):
-            # FlashAttention supports only head_size 32, 64, 128, 256,
-            # we need to pad head_size 192 to 256
-            return 256
+            qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim",
+                                       0)
+            qk_nope_head_dim = getattr(self.hf_text_config, "qk_nope_head_dim",
+                                       0)
+            if qk_rope_head_dim and qk_nope_head_dim:
+                return qk_rope_head_dim + qk_nope_head_dim
 
         if self.is_attention_free:
             return 0
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index d83cafaf998ab..af6810a140b43 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -262,14 +262,8 @@ def __init__(
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale
 
-        # self.attn = Attention(self.num_heads,
-        #                       self.qk_head_dim,
-        #                       self.scaling,
-        #                       num_kv_heads=self.num_heads)
-
-        # TODO, support head_size 192
         self.attn = Attention(self.num_local_heads,
-                              256,
+                              self.qk_head_dim,
                               self.scaling,
                               num_kv_heads=self.num_local_heads,
                               cache_config=cache_config,
@@ -319,18 +313,14 @@ def forward(
         k = torch.empty_like(q)
         k[..., :self.qk_nope_head_dim] = k_nope
         k[..., self.qk_nope_head_dim:] = k_pe
-        q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim],
-                                    value=0).view(-1,
-                                                  self.num_local_heads * 256)
-        k = torch.nn.functional.pad(k, [0, 256 - self.qk_head_dim],
-                                    value=0).view(-1,
-                                                  self.num_local_heads * 256)
-        v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim],
-                                    value=0).view(-1,
-                                                  self.num_local_heads * 256)
+        # pad the value tensor to qk_head_dim for alignment
+        v = torch.nn.functional.pad(
+            v, [0, self.qk_head_dim - self.v_head_dim],
+            value=0).view(-1, self.num_local_heads * self.qk_head_dim)
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         attn_output = attn_output.view(
-            -1, self.num_local_heads, 256)[..., :self.v_head_dim].reshape(
+            -1, self.num_local_heads,
+            self.qk_head_dim)[..., :self.v_head_dim].reshape(
                 -1, self.num_local_heads * self.v_head_dim)
         output, _ = self.o_proj(attn_output)
         return output
diff --git a/vllm/model_executor/models/deepseek_v3.py b/vllm/model_executor/models/deepseek_v3.py
index ca79b14c55fea..0b44f0d062c40 100644
--- a/vllm/model_executor/models/deepseek_v3.py
+++ b/vllm/model_executor/models/deepseek_v3.py
@@ -269,14 +269,8 @@ def __init__(
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale
 
-        # self.attn = Attention(self.num_heads,
-        #                       self.qk_head_dim,
-        #                       self.scaling,
-        #                       num_kv_heads=self.num_heads)
-
-        # TODO, support head_size 192
         self.attn = Attention(self.num_local_heads,
-                              256,
+                              self.qk_head_dim,
                               self.scaling,
                               num_kv_heads=self.num_local_heads,
                               cache_config=cache_config,
@@ -326,18 +320,14 @@ def forward(
         k = torch.empty_like(q)
         k[..., :self.qk_nope_head_dim] = k_nope
         k[..., self.qk_nope_head_dim:] = k_pe
-        q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim],
-                                    value=0).view(-1,
-                                                  self.num_local_heads * 256)
-        k = torch.nn.functional.pad(k, [0, 256 - self.qk_head_dim],
-                                    value=0).view(-1,
-                                                  self.num_local_heads * 256)
-        v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim],
-                                    value=0).view(-1,
-                                                  self.num_local_heads * 256)
+        # pad the value tensor to qk_head_dim for alignment
+        v = torch.nn.functional.pad(
+            v, [0, self.qk_head_dim - self.v_head_dim],
+            value=0).view(-1, self.num_local_heads * self.qk_head_dim)
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         attn_output = attn_output.view(
-            -1, self.num_local_heads, 256)[..., :self.v_head_dim].reshape(
+            -1, self.num_local_heads,
+            self.qk_head_dim)[..., :self.v_head_dim].reshape(
                 -1, self.num_local_heads * self.v_head_dim)
         output, _ = self.o_proj(attn_output)
         return output
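For reference, below is a minimal standalone sketch of the pad/slice round trip the diff introduces, assuming DeepSeek-V2 MLA dimensions (qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, so head_size becomes 192 rather than the old padded 256); num_local_heads and num_tokens are arbitrary toy values, and a plain tensor stands in for the Attention call, so this exercises only the tensor reshaping, not the kernel.

```python
import torch

# Assumed DeepSeek-V2 MLA dims (for illustration only).
qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 128, 64, 128
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim  # 192: the new head_size
num_local_heads, num_tokens = 4, 3                 # arbitrary toy values

v = torch.randn(num_tokens, num_local_heads, v_head_dim)

# Only v needs padding now: pad its last dim from v_head_dim (128) up to
# qk_head_dim (192) so q, k and v share one head size, then flatten heads.
v_padded = torch.nn.functional.pad(
    v, [0, qk_head_dim - v_head_dim],
    value=0).view(-1, num_local_heads * qk_head_dim)
assert v_padded.shape == (num_tokens, num_local_heads * qk_head_dim)

# Stand-in for attn_output = self.attn(q, k, v, kv_cache, attn_metadata);
# the attention output also has qk_head_dim elements per head.
attn_output = v_padded

# Slice each head back down to v_head_dim before the o_proj matmul.
attn_output = attn_output.view(
    -1, num_local_heads, qk_head_dim)[..., :v_head_dim].reshape(
        -1, num_local_heads * v_head_dim)
assert attn_output.shape == (num_tokens, num_local_heads * v_head_dim)

# Matches the config change: head_size is derived from the rope/nope dims
# (128 + 64 = 192) instead of being padded to 256.
assert qk_head_dim == 192
```

The point of the new code path is that q and k already have qk_head_dim, so only v is padded, and the paged-attention head-size list gains 192 so no query/key padding to 256 is needed at all.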