Fix DeepSeek-V2 expert-parallelism crash due to indexing error (#1765)
skavulya authored Feb 18, 2025
1 parent 2dbbb46 commit d80283e
Showing 1 changed file with 16 additions and 23 deletions.
@@ -38,6 +38,7 @@
 from transformers import PretrainedConfig
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache, StaticCache
+from transformers.generation import GenerationMixin
 from transformers.integrations.deepspeed import is_deepspeed_available
 from transformers.modeling_attn_mask_utils import (
     _prepare_4d_causal_attention_mask,
@@ -58,6 +59,7 @@
     logging,
 )
 
+from ....distributed.tensorparallel import _all_reduce
 from ...modeling_attn_mask_utils import _gaudi_prepare_4d_causal_attention_mask
 from .configuration_deepseek_v2 import DeepseekV2Config

@@ -626,8 +628,8 @@ def __init__(self, config):
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
             self.shared_experts = DeepseekV2MLP(config=config, intermediate_size=intermediate_size)
 
-        self.expert_slice = math.ceil(config.n_routed_experts / SLICE_MAX_EXPERT)
-        self.expert_chunk = self.config.n_routed_experts // self.expert_slice
+        self.expert_slice = math.ceil(self.experts_per_rank / SLICE_MAX_EXPERT)
+        self.expert_chunk = math.ceil(self.experts_per_rank / self.expert_slice)
 
     def forward(self, hidden_states):
         identity = hidden_states
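The crux of the fix is visible in the hunk above: the old code sized its expert slices from the global config.n_routed_experts, but under expert parallelism each rank only materializes self.experts_per_rank experts, so the slice indices ran past the rank's local experts. A minimal sketch of the corrected arithmetic follows; the concrete numbers are illustrative assumptions, not values from the commit.

import math

# Illustrative assumptions (not from the commit):
n_routed_experts = 64   # global expert count
ep_size = 4             # expert-parallel world size
SLICE_MAX_EXPERT = 8    # cap on experts handed to one fused-MoE kernel call

experts_per_rank = n_routed_experts // ep_size                 # 16 local experts
expert_slice = math.ceil(experts_per_rank / SLICE_MAX_EXPERT)  # 2 slices per rank
expert_chunk = math.ceil(experts_per_rank / expert_slice)      # 8 experts per slice

for ep_rank in range(ep_size):
    for idx in range(expert_slice):
        # Global expert ids covered by this rank's idx-th slice,
        # mirroring the patched forward():
        experts_min = ep_rank * experts_per_rank + expert_chunk * idx
        experts_max = min(experts_min + expert_chunk, (ep_rank + 1) * experts_per_rank)
        print(f"rank {ep_rank}, slice {idx}: experts {experts_min}..{experts_max - 1}")

# rank 0 covers experts 0..15, rank 1 covers 16..31, and so on: the slices
# partition the global expert range with no gaps and no overlap.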
@@ -667,23 +669,12 @@ def forward(self, hidden_states):
             (batch * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
         )
         for idx in range(self.expert_slice):
-            experts_range = range(self.expert_chunk)
-            gate_proj_list = [
-                self.experts[idx * self.expert_chunk + i].gate_proj.weight.squeeze()
-                for i in experts_range
-                if self.experts[idx * self.expert_chunk + i] is not None
-            ]
-            down_proj_list = [
-                self.experts[idx * self.expert_chunk + i].down_proj.weight.squeeze()
-                for i in experts_range
-                if self.experts[idx * self.expert_chunk + i] is not None
-            ]
-            up_proj_list = [
-                self.experts[idx * self.expert_chunk + i].up_proj.weight.squeeze()
-                for i in experts_range
-                if self.experts[idx * self.expert_chunk + i] is not None
-            ]
-
+            experts_min = (self.ep_rank * self.experts_per_rank) + (self.expert_chunk * idx)
+            experts_max = min((experts_min + self.expert_chunk), (self.ep_rank + 1) * self.experts_per_rank)
+            experts_range = range(experts_min, experts_max)
+            gate_proj_list = [self.experts[i].gate_proj.weight.squeeze() for i in experts_range]
+            down_proj_list = [self.experts[i].down_proj.weight.squeeze() for i in experts_range]
+            up_proj_list = [self.experts[i].up_proj.weight.squeeze() for i in experts_range]
             hidden_states_slice = torch.ops.hpu.mixture_of_experts(
                 hidden_states=hidden_states,
                 expert_routing_table=topk_idx,
@@ -693,13 +684,15 @@ def forward(self, hidden_states):
                 w3=down_proj_list,
                 permuted_weights=True,
                 activation="silu",
-                experts_min=(self.expert_chunk * idx),
-                experts_max=(self.expert_chunk * (idx + 1) - 1),
+                experts_min=experts_min,
+                experts_max=experts_max - 1,
             )
             final_hidden_states = final_hidden_states + hidden_states_slice
             htcore.mark_step()
 
-        if is_deepspeed_available():
+        if self.ep_size > 1:
+            final_hidden_states = _all_reduce(final_hidden_states)
+        elif is_deepspeed_available():
             from deepspeed import comm as dist
 
             if dist.is_initialized():
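Because each expert-parallel rank now computes only the partial output of its local experts, the per-rank results have to be summed across ranks, which is what the new _all_reduce call does when ep_size > 1. A minimal stand-in for that helper, assuming an initialized torch.distributed process group; the actual implementation lives in optimum-habana's distributed.tensorparallel module and may differ.

import torch
import torch.distributed as dist

def all_reduce_sum(partial: torch.Tensor) -> torch.Tensor:
    # Sum the per-rank partial MoE outputs; afterwards every rank holds
    # the full mixture over all routed experts.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(partial, op=dist.ReduceOp.SUM)
    return partial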
@@ -1774,7 +1767,7 @@ def forward(
         )
 
 
-class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
+class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
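The GenerationMixin change tracks newer transformers releases, in which generate() is no longer inherited implicitly through PreTrainedModel. A hedged usage sketch; the import path is assumed from this repository's layout and the checkpoint id is illustrative.

from transformers import AutoTokenizer

# Import path assumed from the repository layout:
from optimum.habana.transformers.models.deepseek_v2.modeling_deepseek_v2 import (
    DeepseekV2ForCausalLM,
)

model_id = "deepseek-ai/DeepSeek-V2-Lite"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = DeepseekV2ForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hello", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=8)  # provided by GenerationMixin
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))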
