Skip to content

Commit

Permalink
simple optimization for moonlight moe decoding forward (#12891)
Browse files Browse the repository at this point in the history
  • Loading branch information
MeouSker77 authored Feb 25, 2025
1 parent ae9f532 commit 5faba06
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/llm/src/ipex_llm/transformers/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -2031,9 +2031,11 @@ def _optimize_post(model):
from ipex_llm.transformers.models.common import rms_norm_forward
from ipex_llm.transformers.models.deepseek import deepseek_model_forward
from ipex_llm.transformers.models.deepseek import deepseek_attention_forward
from ipex_llm.transformers.models.deepseek import deepseek_moe_forward
convert_forward(model, module.DeepseekV3RMSNorm, rms_norm_forward)
convert_forward(model, module.DeepseekV3Model, deepseek_model_forward)
convert_forward(model, module.DeepseekV3Attention, deepseek_attention_forward)
convert_forward(model, module.DeepseekV3MoE, deepseek_moe_forward)

return model

Expand Down
32 changes: 32 additions & 0 deletions python/llm/src/ipex_llm/transformers/models/deepseek.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,3 +269,35 @@ def deepseek_attention_forward(
attn_weights = None

return attn_output, attn_weights, past_key_value


def moe_infer_decode(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor):
idxs = topk_ids.flatten().tolist()
outputs = []
for i in idxs:
expert = self.experts[i]
expert_out = expert(x)
outputs.append(expert_out)
outs = torch.cat(outputs, dim=0)
reshaped_topk_weight = topk_weight.squeeze(0).unsqueeze(-1).to(outs.dtype)
final_out = (outs * reshaped_topk_weight).sum(dim=0, keepdim=True)
return final_out


def deepseek_moe_forward(self, hidden_states: torch.Tensor):
identity = hidden_states
orig_shape = hidden_states.shape
topk_idx, topk_weight = self.gate(hidden_states)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
flat_topk_idx = topk_idx.view(-1)
if not self.training:
# IPEX-LLM OPT start : add special moe_infer implementation for decoding
if topk_idx.size(0) == 1:
y = moe_infer_decode(self, hidden_states, topk_idx, topk_weight)
else:
y = self.moe_infer(hidden_states, topk_idx, topk_weight)
y = y.view(*orig_shape)
# IPEX-LLM OPT end
if self.config.n_shared_experts is not None:
y = y + self.shared_experts(identity)
return y

0 comments on commit 5faba06

Please sign in to comment.