optimize attention part of moonlight-14B-A3B (#12886)
MeouSker77 authored Feb 25, 2025
1 parent dd30d12 commit ab3fc66
Showing 4 changed files with 335 additions and 4 deletions.
13 changes: 12 additions & 1 deletion python/llm/src/ipex_llm/transformers/convert.py
@@ -1070,7 +1070,9 @@ def _optimize_pre(model, qtype=None):
model.apply(pre_register_inv_freq)
elif model.config.model_type == "multi_modality":
_optimize_pre(model.language_model)

elif model.config.model_type == "deepseek_v3" and model.config.hidden_size == 2048:
from ipex_llm.transformers.models.deepseek import padding_mla_v_hd
model.apply(padding_mla_v_hd)
return model


@@ -2023,6 +2025,15 @@ def _optimize_post(model):

# llm
_optimize_post(model.language_model)
elif model.config.model_type == "deepseek_v3" and model.config.hidden_size == 2048:
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
from ipex_llm.transformers.models.common import rms_norm_forward
from ipex_llm.transformers.models.deepseek import deepseek_model_forward
from ipex_llm.transformers.models.deepseek import deepseek_attention_forward
convert_forward(model, module.DeepseekV3RMSNorm, rms_norm_forward)
convert_forward(model, module.DeepseekV3Model, deepseek_model_forward)
convert_forward(model, module.DeepseekV3Attention, deepseek_attention_forward)

return model
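
The three convert_forward calls above swap the stock DeepseekV3 forward implementations for the IPEX-LLM ones; convert_forward itself is not shown in this diff. Roughly, it walks the module tree and rebinds forward on every instance of the target class. A minimal sketch of that idea (hypothetical helper name, not the actual ipex_llm implementation):

import types
import torch

def convert_forward_sketch(model: torch.nn.Module, target_cls, new_forward):
    # Rebind `forward` on every submodule that is an instance of target_cls.
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = types.MethodType(new_forward, module)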

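For context, these branches are reached when a Moonlight-14B-A3B checkpoint (DeepSeek-V3 architecture with hidden_size 2048) is loaded through ipex-llm's transformers API. A minimal usage sketch, with the checkpoint path and runtime setup left as placeholders (not part of this commit):

from ipex_llm.transformers import AutoModelForCausalLM

model_path = "path/to/moonlight-14B-A3B"   # placeholder; point at a local checkpoint
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    load_in_4bit=True,        # low-bit weight conversion
    trust_remote_code=True,   # DeepSeek-V3 models ship custom modeling code
)
model = model.half().to("xpu")  # the fused-rope branch in deepseek.py only runs on XPU devices
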
27 changes: 27 additions & 0 deletions python/llm/src/ipex_llm/transformers/models/common.py
@@ -95,6 +95,33 @@ def padding_attention_hd_base(module: torch.nn.Module, attention_class,
module.old_head_dim = old_head_dim


def padding_mla_v_hd_base(module: torch.nn.Module, attention_class):
if (
isinstance(attention_class, str) and module.__class__.__name__ == attention_class
or not isinstance(attention_class, str) and isinstance(module, attention_class)
):
k_head_dim = module.q_head_dim
v_head_dim = module.v_head_dim
if v_head_dim < k_head_dim:
kv_b_proj = module.kv_b_proj
w = kv_b_proj.weight.data.view(module.num_heads,
module.qk_nope_head_dim + module.v_head_dim,
module.kv_lora_rank)
k_w, v_w = w.split([module.qk_nope_head_dim, module.v_head_dim], dim=1)
new_v_w = torch.zeros([module.num_heads, k_head_dim, module.kv_lora_rank],
dtype=v_w.dtype, device=v_w.device)
new_v_w[:, :v_head_dim, :] = v_w
new_w = torch.cat([k_w, new_v_w], dim=1).view(-1, module.kv_lora_rank)

new_kv_b_proj = torch.nn.Linear(0, 0, bias=False,
dtype=new_w.dtype, device=new_w.device)
new_kv_b_proj.in_features = new_w.size(1)
new_kv_b_proj.out_features = new_w.size(0)
new_kv_b_proj.weight = torch.nn.Parameter(new_w, False)

module.kv_b_proj = new_kv_b_proj


def padding_states_hd(states: torch.Tensor, old_head_dim: int, new_head_dim: int):
bsz, num_heads, seq_len, head_dim = states.size()
if head_dim == old_head_dim and old_head_dim < new_head_dim:
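
The zero-padding done by padding_mla_v_hd_base is lossless for attention: padding the value head dim with zeros only appends zero channels to the attention output, which deepseek_attention_forward later slices off again with [..., :v_head_dim]. A small self-contained check of that property (illustration only; head dims chosen to mirror MLA, where q/k heads are wider than v heads):

import torch
import torch.nn.functional as F

bsz, heads, seq = 1, 2, 4
k_head_dim, v_head_dim = 192, 128            # q/k head dim > v head dim, as in MLA

q = torch.randn(bsz, heads, seq, k_head_dim)
k = torch.randn(bsz, heads, seq, k_head_dim)
v = torch.randn(bsz, heads, seq, v_head_dim)

# Pad value states with zeros so q/k/v share one head dim
# (the same effect the kv_b_proj weight padding has on value states).
v_padded = torch.zeros(bsz, heads, seq, k_head_dim)
v_padded[..., :v_head_dim] = v

out_ref = F.scaled_dot_product_attention(q, k, v)
out_pad = F.scaled_dot_product_attention(q, k, v_padded)[..., :v_head_dim]
print(torch.allclose(out_ref, out_pad, atol=1e-5))   # True
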
271 changes: 271 additions & 0 deletions python/llm/src/ipex_llm/transformers/models/deepseek.py
@@ -0,0 +1,271 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://huggingface.co/deepseek-ai/DeepSeek-R1/blob/main/modeling_deepseek.py
# which is licensed under Apache License 2.0:
#
# https://github.com/OpenBMB/MiniCPM/blob/main/LICENSE
#

import torch
import warnings

from typing import Optional, Tuple, List, Union
from transformers.cache_utils import Cache
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

from ipex_llm.utils.common.log4Error import invalidInputError
from ipex_llm.transformers.kv import DynamicNormalCache
from ipex_llm.transformers.models.common import padding_mla_v_hd_base
from ipex_llm.transformers.models.common import scaled_dot_product_attention
from ipex_llm.transformers.models.utils import rotate_half


def padding_mla_v_hd(module: torch.nn.Module):
padding_mla_v_hd_base(module, "DeepseekV3Attention")


def deepseek_model_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None
else self.config.output_hidden_states
)

use_cache = use_cache if use_cache is not None else self.config.use_cache

return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)

# retrieve input_ids and inputs_embeds
invalidInputError((input_ids is None) ^ (inputs_embeds is None),
"You cannot specify both input_ids and inputs_embeds at the same time, "
"and must specify either one")

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

batch_size, seq_length = inputs_embeds.shape[:2]

# IPEX-LLM OPT start: kv cache
past_key_values_length = 0
use_cache = True if inputs_embeds.device.type == "xpu" else use_cache
if use_cache:
if not isinstance(past_key_values, DynamicNormalCache):
past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_usable_length(seq_length)
# IPEX-LLM OPT end: kv cache

if position_ids is None:
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=inputs_embeds.device,
)
position_ids = position_ids.unsqueeze(0)

# IPEX-LLM OPT start: fuse rope
if inputs_embeds.device.type == "xpu" and position_ids is not None:
cos, sin = self.layers[0].self_attn.rotary_emb(inputs_embeds,
seq_length + past_key_values_length)
cos = cos[position_ids[0]].contiguous()
sin = sin[position_ids[0]].contiguous()
position_embeddings = (cos, sin)
else:
position_embeddings = None
# IPEX-LLM OPT end: fuse rope

# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)

# embed positions
hidden_states = inputs_embeds

# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None

for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)

layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
position_embeddings=position_embeddings,
)

hidden_states = layer_outputs[0]

if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]

if output_attentions:
all_self_attns += (layer_outputs[1],)

hidden_states = self.norm(hidden_states)

# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)

next_cache = next_decoder_cache
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
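
def _rope_gather_demo():
    # Illustration only (hypothetical helper, not used by the model): the fused-rope
    # branch above gathers per-position rows from the rotary cos/sin tables once per
    # forward pass and shares the result with every decoder layer via position_embeddings.
    max_seq_len, rope_dim, seq_len, past = 16, 4, 3, 5
    cos_table = torch.randn(max_seq_len, rope_dim)                   # stands in for rotary_emb's cos
    position_ids = torch.arange(past, past + seq_len).unsqueeze(0)   # (1, seq_len)
    cos = cos_table[position_ids[0]].contiguous()                    # (seq_len, rope_dim)
    return cos.shape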


def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
sin = sin[position_ids].unsqueeze(unsqueeze_dim)

b, h, s, d = q.shape
q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

b, h, s, d = k.shape
k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
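
def _rotary_layout_demo():
    # Illustration only (hypothetical helper, not used by the model): the
    # view/transpose/reshape above converts DeepSeek's interleaved rotary layout
    # [x0, y0, x1, y1, ...] into the half-split layout [x0, x1, ..., y0, y1, ...]
    # that rotate_half-style RoPE expects.
    x = torch.arange(8).view(1, 1, 1, 8)                  # [0, 1, 2, 3, 4, 5, 6, 7]
    b, h, s, d = x.shape
    y = x.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
    return y.flatten().tolist()                           # [0, 2, 4, 6, 1, 3, 5, 7]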


def deepseek_attention_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. "
"Please make sure use `attention_mask` instead.`"
)

bsz, q_len, _ = hidden_states.size()

if self.q_lora_rank is None:
q = self.q_proj(hidden_states)
else:
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)

compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
compressed_kv, k_pe = torch.split(
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
kv = (
self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.q_head_dim)
.transpose(1, 2)
)

k_nope, value_states = torch.split(
kv, [self.qk_nope_head_dim, self.q_head_dim], dim=-1
)
kv_seq_len = value_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

position_embeddings = kwargs.get("position_embeddings", None)
if position_embeddings is not None:
query_states = q
key_states = torch.cat(
[k_nope, k_pe.expand([-1, self.num_heads, -1, -1])],
dim=-1
)
import xe_addons
cos, sin = position_embeddings
xe_addons.rotary_two_with_cache_inplaced(query_states[:, :, :, self.qk_nope_head_dim:],
key_states[:, :, :, self.qk_nope_head_dim:],
cos, sin, True)
else:
q_nope, q_pe = torch.split(
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
query_states[:, :, :, self.qk_nope_head_dim:] = q_pe

key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
key_states[:, :, :, self.qk_nope_head_dim:] = k_pe

if past_key_value is not None:
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, None)

attn_weights = None
attn_output = scaled_dot_product_attention(
query_states, key_states, value_states,
attention_mask, q_len == kv_seq_len, self.softmax_scale
)
attn_output = attn_output[:, :, :, :self.v_head_dim]

attn_output = attn_output.transpose(1, 2).contiguous()

attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)

attn_output = self.o_proj(attn_output)

if not output_attentions:
attn_weights = None

return attn_output, attn_weights, past_key_value
28 changes: 25 additions & 3 deletions python/llm/src/ipex_llm/transformers/models/minicpm3.py
@@ -1,3 +1,25 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://hf-mirror.com/openbmb/MiniCPM3-4B/blob/main/modeling_minicpm.py
# which is licensed under Apache License 2.0:
#
# https://github.com/OpenBMB/MiniCPM/blob/main/LICENSE
#

import torch
import warnings

@@ -122,9 +144,6 @@ def minicpm3_attention_forward(

q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
q_nope, q_pe = torch.split(
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)

compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
compressed_kv, k_pe = torch.split(
@@ -169,6 +188,9 @@ def minicpm3_attention_forward(
else:
invalidInputError(False, f"unknown rope method: {self.rotary_emb.__class__.__name__}")
else:
q_nope, q_pe = torch.split(
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
