Skip to content

Commit

Permalink
add fuse moe optimization for moonlight
Browse files Browse the repository at this point in the history
  • Loading branch information
MeouSker77 committed Feb 26, 2025
1 parent 5faba06 commit 3cf2cbe
Showing 1 changed file with 32 additions and 10 deletions.
42 changes: 32 additions & 10 deletions python/llm/src/ipex_llm/transformers/models/deepseek.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,15 +272,37 @@ def deepseek_attention_forward(


def moe_infer_decode(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor):
idxs = topk_ids.flatten().tolist()
outputs = []
for i in idxs:
expert = self.experts[i]
expert_out = expert(x)
outputs.append(expert_out)
outs = torch.cat(outputs, dim=0)
reshaped_topk_weight = topk_weight.squeeze(0).unsqueeze(-1).to(outs.dtype)
final_out = (outs * reshaped_topk_weight).sum(dim=0, keepdim=True)
if (
x.device.type == "xpu"
and x.dtype in [torch.float, torch.half]
and self.experts[0].down_proj.qtype == 2
):
if getattr(self, "gates", None) is None:
gate_addrs = [expert.gate_proj.weight.data_ptr() for expert in self.experts]
up_addrs = [expert.up_proj.weight.data_ptr() for expert in self.experts]
down_addrs = [expert.down_proj.weight.data_ptr() for expert in self.experts]
gates = torch.tensor(gate_addrs, dtype=torch.uint64, device=x.device)
ups = torch.tensor(up_addrs, dtype=torch.uint64, device=x.device)
downs = torch.tensor(down_addrs, dtype=torch.uint64, device=x.device)
self.register_buffer("gates", gates, persistent=False)
self.register_buffer("ups", ups, persistent=False)
self.register_buffer("downs", downs, persistent=False)

import xe_linear
final_out = xe_linear.moe_forward_vec(
x, topk_ids, topk_weight, self.gates, self.ups, self.downs,
x.size(-1), self.experts[0].intermediate_size, 2
)
else:
idxs = topk_ids.flatten().tolist()
outputs = []
for i in idxs:
expert = self.experts[i]
expert_out = expert(x)
outputs.append(expert_out)
outs = torch.cat(outputs, dim=0)
reshaped_topk_weight = topk_weight.squeeze(0).unsqueeze(-1).to(outs.dtype)
final_out = (outs * reshaped_topk_weight).sum(dim=0, keepdim=True)
return final_out


Expand All @@ -292,7 +314,7 @@ def deepseek_moe_forward(self, hidden_states: torch.Tensor):
flat_topk_idx = topk_idx.view(-1)
if not self.training:
# IPEX-LLM OPT start : add special moe_infer implementation for decoding
if topk_idx.size(0) == 1:
if topk_idx.size(0) == 1 and self.ep_size == 1:
y = moe_infer_decode(self, hidden_states, topk_idx, topk_weight)
else:
y = self.moe_infer(hidden_states, topk_idx, topk_weight)
Expand Down

0 comments on commit 3cf2cbe

Please sign in to comment.