support deepseek mla
yangw1234 committed Feb 2, 2025
1 parent fe27a80 commit 323d562
Showing 8 changed files with 272 additions and 24 deletions.
10 changes: 7 additions & 3 deletions scripts/run_example_tp.py
@@ -3,14 +3,17 @@
import argparse
import os

#model_path = "/software/data/DeepSeek-R1/"
model_path = "deepseek-ai/DeepSeek-V2-Lite"

# Parse the command-line arguments.
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="/software/data/DeepSeek-R1/", help="The model path.")
parser.add_argument("--model", type=str, default=model_path, help="The model path.")
#parser.add_argument("--model", type=str, default="/data/models/DeepSeek-R1/", help="The model path.")
parser.add_argument("--tokenizer", type=str, default="deepseek-ai/DeepSeek-R1", help="The model path.")
parser.add_argument("--tokenizer", type=str, default=model_path, help="The model path.")
#parser.add_argument("--model", type=str, default="/data/models/DeepSeek-R1-bf16-small/", help="The model path.")
#parser.add_argument("--tokenizer", type=str, default="opensourcerelease/DeepSeek-R1-bf16", help="The model path.")
parser.add_argument("--tp_size", type=int, default=8, help="The number of threads.")
parser.add_argument("--tp_size", type=int, default=1, help="The number of threads.")
args = parser.parse_args()

os.environ["VLLM_SKIP_WARMUP"] = "true"
@@ -37,6 +40,7 @@
tokenizer=args.tokenizer,
trust_remote_code=True,
dtype="bfloat16",
max_model_len=1024,
)
else:
llm = LLM(
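The script change above points the defaults at DeepSeek-V2-Lite, drops tp_size to 1, and caps max_model_len at 1024. For reference, a minimal sketch of an equivalent direct invocation, assuming the standard vllm.LLM entry point and the defaults shown in this diff:

    import os
    from vllm import LLM, SamplingParams

    os.environ["VLLM_SKIP_WARMUP"] = "true"   # same env var the script sets

    llm = LLM(
        model="deepseek-ai/DeepSeek-V2-Lite",
        tokenizer="deepseek-ai/DeepSeek-V2-Lite",
        trust_remote_code=True,
        dtype="bfloat16",
        max_model_len=1024,
    )
    outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)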
220 changes: 218 additions & 2 deletions vllm/attention/backends/hpu_attn.py
@@ -6,15 +6,15 @@
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
from vllm.attention.backends.mla.utils import MLACommonImpl
import vllm_hpu_extension.kernels as kernels
import vllm_hpu_extension.ops as ops
from vllm_hpu_extension.flags import enabled_flags
from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax,
VLLMKVCache)

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
AttentionMetadata, AttentionType, AttentionLayer)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention,
HPUPagedAttentionMetadata)
@@ -67,6 +67,26 @@ def copy_blocks(
HPUPagedAttention.copy_blocks(kv_caches, src_to_dsts)


class HPUMLAAttentionBackend(HPUAttentionBackend):
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
kv_lora_rank: int,
) -> Tuple[int, ...]:
k_pe_size = kv_lora_rank // 8
return (num_blocks, block_size, kv_lora_rank + k_pe_size), True

@staticmethod
def get_impl_cls() -> Type["HPUAttentionImpl"]:
return HPUMLAImpl

@staticmethod
def get_name() -> str:
return "HPU_MLA"
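
For reference, a small sketch of the cache-shape arithmetic above. The concrete numbers are assumptions based on DeepSeek-V2-style dimensions (kv_lora_rank of 512, a 64-wide rotary key slot), not values taken from this diff:

    # Only the kv_lora_rank // 8 rule comes from the code above; the sizes are hypothetical.
    kv_lora_rank = 512
    k_pe_size = kv_lora_rank // 8             # 64, room for the rotary (k_pe) slice of the key
    num_blocks, block_size = 128, 128         # example paged-KV geometry
    shape = (num_blocks, block_size, kv_lora_rank + k_pe_size)
    print(shape)                              # (128, 128, 576): one latent row per cached token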


@dataclass
class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
"""Metadata for HPUAttentionbackend."""
@@ -76,6 +96,7 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
attn_bias: Optional[torch.Tensor]
seq_lens_tensor: Optional[torch.Tensor]
context_lens_tensor: Optional[torch.Tensor]
input_positions: torch.Tensor
seq_lens: Optional[List[int]] = None
encoder_seq_lens: Optional[List[int]] = None
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
@@ -88,6 +109,201 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
cross_block_scales: Optional[torch.Tensor] = None
cross_block_usage: Optional[torch.Tensor] = None
cross_attn_bias: Optional[torch.Tensor] = None


class HPUMLAImpl(MLACommonImpl[HPUAttentionMetadata]):

def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
# MLA Specific Arguments
**kwargs) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type,
**kwargs)

self.matmul_qk = Matmul()
self.softmax = Softmax()
self.matmul_av = Matmul()
self.batch2block_matmul = Matmul()
self.block2batch_matmul = Matmul()
self.latent_cache = VLLMKVCache()
HPUFusedSDPA = kernels.fsdpa()
self.fused_scaled_dot_product_attention = None if HPUFusedSDPA is None \
else ModuleFusedSDPA(HPUFusedSDPA)

unsupported_features = [
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
]
if any(unsupported_features):
raise NotImplementedError(
"HPUMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"logits_soft_cap")

if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"HPUMLAImpl")

def forward(
self,
layer: AttentionLayer,
hidden_states_or_q_c: torch.Tensor, # query in unified attn
k_c_normed: torch.Tensor, # key in unified attn
k_pe: torch.Tensor, # value in unified attn
kv_cache: torch.Tensor,
attn_metadata: HPUAttentionMetadata,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if output is not None:
raise NotImplementedError(
"output is not yet supported for MLAImplBase")

batch_size = hidden_states_or_q_c.shape[0]

is_prefill = attn_metadata.is_prompt

k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim)

# Restore head dim (for rotary embedding)
# k_pe = k_pe.unsqueeze(1)
assert hasattr(attn_metadata, "input_positions"), f"attn meta: {attn_metadata}"

if not is_prefill:
q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\
.view(-1, self.num_heads, self.qk_rope_head_dim)
input_positions = attn_metadata.input_positions.view(-1)
print("q_pe", q_pe.shape)
print("k_pe", k_pe.shape)
print("input_positions", attn_metadata.input_positions.shape)
q_pe, k_pe = \
self.rotary_emb(input_positions, q_pe, k_pe)
else:
q = self.q_proj(hidden_states_or_q_c)[0]\
.view(-1, self.num_heads, self.qk_head_dim)

q_pe = q[..., self.qk_nope_head_dim:]

# print("q_pe shape", q_pe.shape)
# print("k_pe shape", k_pe.shape)
# print("input_positions shape", attn_metadata.input_positions.shape)
input_positions = attn_metadata.input_positions.view(-1)
# TODO(lucas): there must be a nicer way to write this line
q[..., self.qk_nope_head_dim:], k_pe = \
self.rotary_emb(input_positions, q_pe, k_pe)

block_indices = attn_metadata.block_indices
block_offsets = attn_metadata.block_offsets

latent_vec = torch.concat(
(k_c_normed, k_pe.view(batch_size, -1, self.qk_rope_head_dim)), dim=-1)
# assert layer._k_scale == 0, f"got _k_scale={layer._k_scale}"
# print(f"layer._k_scale={layer._k_scale}")

# write the latent and rope to kv cache
if kv_cache is not None:
kv_cache = self.latent_cache(latent_vec, kv_cache, block_indices,
block_offsets)

if is_prefill:
return self._forward_prefill(q, k_c_normed, k_pe, attn_metadata, batch_size)
else:
return self._forward_decode(q_nope, q_pe, kv_cache, attn_metadata, batch_size)

def _forward_prefill(
self,
q: torch.Tensor,
k_c_normed: torch.Tensor,
k_pe: torch.Tensor,
attn_metadata: HPUAttentionMetadata,
batch_size: int
) -> torch.Tensor:
kv_nope = self.kv_b_proj(k_c_normed)[0]\
.view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim
v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
value=0)
q = q.view(batch_size, -1, self.num_heads, self.qk_head_dim)
k = k.view(batch_size, -1, self.num_heads, self.qk_head_dim)
v_padded = v_padded.view(batch_size, -1, self.num_heads, self.qk_head_dim)
out = ops.prompt_attention(
q,
k,
v_padded,
attn_bias=None,
p=0.0,
scale=self.scale,
matmul_qk_op=self.matmul_qk,
softmax_op=self.softmax,
matmul_av_op=self.matmul_av,
valid_seq_lengths=attn_metadata.seq_lens_tensor,
fsdpa_op=self.fused_scaled_dot_product_attention,
)
attn_output = out\
.view(batch_size, -1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
.reshape(batch_size, -1, self.num_heads * v.shape[-1])

return self.o_proj(attn_output)[0]

def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: HPUAttentionMetadata,
batch_size: int
) -> torch.Tensor:
print(f"q_nope shape: {q_nope.shape}")
print(f"q_pe shape: {q_pe.shape}")

q = torch.cat([q_nope, q_pe], dim=-1)
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]

print(f"q shape: {q.shape}")
print(f"kv_c_and_k_pe_cache shape: {kv_c_and_k_pe_cache.shape}")
print(f"kv_c_cache shape: {kv_c_cache.shape}")
output = HPUPagedAttention.forward_decode(
query=q,
key_cache=kv_c_and_k_pe_cache,
value_cache=kv_c_cache,
block_list=attn_metadata.block_list,
block_mapping=attn_metadata.block_mapping,
block_bias=attn_metadata.attn_bias,
block_scales=attn_metadata.block_scales,
block_groups=attn_metadata.block_groups,
scale=self.scale,
matmul_qk_op=self.matmul_qk,
matmul_av_op=self.matmul_av,
batch2block_matmul_op=self.batch2block_matmul,
block2batch_matmul_op=self.block2batch_matmul,
keys_fetch_func=self.latent_cache.fetch_from_cache,
values_fetch_func=self.latent_cache.fetch_from_cache)
output = output.view(batch_size, 1, -1)
print("output", output.shape)
result = self._v_up_proj_and_o_proj(output)
result = result.view(batch_size, 1, -1)
print("result", result.shape)
return result


class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
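In the decode path above, the latent cache does double duty: each cached row (compressed KV concatenated with the rotary key) serves as the key, while only the compressed slice serves as the value, and the result is lifted back out through _v_up_proj_and_o_proj. A stripped-down sketch of that attention step, with hypothetical sizes and a placeholder scale (the real scale comes from the model config):

    import torch

    num_heads, kv_lora_rank, rope_dim, cached_tokens = 16, 512, 64, 8

    q_nope = torch.randn(num_heads, kv_lora_rank)    # query absorbed into the latent space
    q_pe = torch.randn(num_heads, rope_dim)          # rotary slice of the query
    cache = torch.randn(cached_tokens, kv_lora_rank + rope_dim)   # [kv_c | k_pe] per token

    q = torch.cat([q_nope, q_pe], dim=-1)
    scores = (q @ cache.T) * 0.1                     # placeholder scale, not the real one
    ctx = scores.softmax(dim=-1) @ cache[:, :kv_lora_rank]        # values = latent slice only
    print(ctx.shape)                                 # (16, 512), fed to the v-up/output projections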
2 changes: 1 addition & 1 deletion vllm/attention/backends/mla/utils.py
@@ -13,7 +13,7 @@
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.vllm_flash_attn import flash_attn_varlen_func
# from vllm.vllm_flash_attn import flash_attn_varlen_func


@dataclass
32 changes: 21 additions & 11 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -185,17 +185,27 @@ def forward_hpu(
assert len(x.shape) == 2
import habana_frameworks.torch as htorch
htorch.core.mark_step()
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
if use_grouped_topk:
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
else:
import torch.nn.functional as F
topk_weights = F.softmax(router_logits, dim=1, dtype=torch.float32)
topk_weights, topk_ids = torch.topk(topk_weights,
top_k,
dim=-1)
topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
topk_weights = topk_weights.to(x.dtype)

# final_hidden_states = layer.hpu_fused_moe.MoeOp(
# hidden_states=x,
# expert_routing_table=topk_ids,
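The fallback branch above is plain softmax top-k routing for the case where grouped top-k is not requested. A tiny worked example with made-up sizes:

    import torch
    import torch.nn.functional as F

    router_logits = torch.randn(4, 8)                    # 4 tokens, 8 experts (hypothetical)
    top_k = 2
    weights = F.softmax(router_logits, dim=1, dtype=torch.float32)
    weights, ids = torch.topk(weights, top_k, dim=-1)    # keep the strongest experts per token
    weights /= weights.sum(dim=-1, keepdim=True)         # renormalize over the selected experts
    print(ids.shape, weights.sum(dim=-1))                # torch.Size([4, 2]); each row sums to 1.0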
3 changes: 2 additions & 1 deletion vllm/model_executor/models/deepseek_v2.py
@@ -157,8 +157,9 @@ def __init__(
)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_dim = hidden_states.shape[-1]
hidden_states = hidden_states.view(-1, hidden_dim)
num_tokens = hidden_states.shape[0]
if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
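The reshape change above lets the MoE forward accept hidden states that still carry a leading batch dimension rather than an already-flattened token dimension (which, judging from this diff, is what the HPU runner hands it; an assumption, not stated in the code). A quick illustration:

    import torch

    hidden_states = torch.randn(2, 4, 16)                # hypothetical (batch, seq, hidden)
    hidden_dim = hidden_states.shape[-1]
    hidden_states = hidden_states.view(-1, hidden_dim)   # flatten to (tokens, hidden)
    num_tokens = hidden_states.shape[0]
    print(hidden_states.shape, num_tokens)               # torch.Size([8, 16]) 8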
5 changes: 4 additions & 1 deletion vllm/platforms/hpu.py
@@ -31,7 +31,10 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
block_size: int, use_v1: bool,
use_mla: bool) -> str:
logger.info("Using HPUAttention backend.")
return "vllm.attention.backends.hpu_attn.HPUAttentionBackend"
if use_mla:
return "vllm.attention.backends.hpu_attn.HPUMLAAttentionBackend"
else:
return "vllm.attention.backends.hpu_attn.HPUAttentionBackend"

@classmethod
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
3 changes: 3 additions & 0 deletions vllm/worker/hpu_model_runner.py
@@ -1073,6 +1073,7 @@ def _prepare_prompt(
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=False,
input_positions=input_positions,
)
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
for t in multi_modal_kwargs:
@@ -1347,6 +1348,7 @@ def _prepare_decode(
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
input_positions=input_positions
)
return PrepareDecodeMetadata(input_tokens=input_tokens,
input_positions=input_positions,
@@ -1553,6 +1555,7 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
'block_offsets',
'block_scales',
'block_groups',
'input_positions',
])
return attention_metadata
