Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
sambhavnoobcoder committed Jan 22, 2025
1 parent 4d2eb40 commit df010bf
Showing 1 changed file with 7 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -196,17 +196,16 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple:
"""
# Create a copy for applying jitter noise
routing_states = hidden_states.clone()

if self.training and self.jitter_noise > 0:
# Apply jitter noise only to the routing copy
routing_states *= torch.empty_like(routing_states).uniform_(
1.0 - self.jitter_noise,
1.0 + self.jitter_noise
1.0 - self.jitter_noise, 1.0 + self.jitter_noise
)

# Use jittered states for routing decisions
router_logits = self.classifier(routing_states)

router_probs = nn.functional.softmax(router_logits, dim=-1)
expert_index = torch.argmax(router_probs, dim=-1)
expert_index = torch.nn.functional.one_hot(expert_index, num_classes=self.num_experts)
Expand Down Expand Up @@ -299,7 +298,7 @@ def forward(self, hidden_states):
"""
# Step 1: Get the router_mask from the router as well as the probabilities
expert_hidden_states = hidden_states.clone()

# Get router outputs using potentially jittered states
router_mask, router_probs, router_logits = self.router(hidden_states)
expert_index = torch.argmax(router_mask, dim=-1)
Expand All @@ -308,12 +307,12 @@ def forward(self, hidden_states):
# can be unchanged from one layer to another. That is why the hidden states are cloned before updating only the selected ones.
# Use original hidden states for expert processing
next_states = expert_hidden_states.clone()

router_mask = router_mask.bool()
batch_size, seq_len, num_experts = router_mask.shape
idx_mask = router_mask.transpose(1, 2).reshape(batch_size * seq_len, num_experts).sum(dim=0)
idx_mask = torch.nonzero(idx_mask, as_tuple=True)[0].tolist()

for idx in idx_mask:
next_states[router_mask[:, :, idx]] = getattr(self.experts, f"expert_{idx}")(
expert_hidden_states[router_mask[:, :, idx]]
Expand Down

0 comments on commit df010bf

Please sign in to comment.