fix replication on router and checkpointing
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
fabianlim committed Aug 23, 2024
1 parent 6ec816e commit 75bded5
Showing 3 changed files with 47 additions and 9 deletions.
@@ -3,6 +3,7 @@
 
 # Third Party
 import torch
+import torch.nn.functional as F
 
 
 # this function ensures that the megablocks packaged is configured to use
@@ -17,8 +18,9 @@ def update_mlp_registry(
     # Third Party
     # pylint: disable=import-error,import-outside-toplevel
     from megablocks.layers.dmlp_registry import _REGISTRY
-    from megablocks.layers.mlp import SparseMLP
+    from megablocks.layers.mlp import SparseMLP, resolve_dtensor
     from megablocks.layers.moe import ParallelMLP
+    from megablocks.layers.router import LearnedRouter, _uniform_expert_assignment
 
     # Local
     from .sparse_mlp2 import SparseMLPv2
@@ -65,3 +67,37 @@ def forward(self, x, scores, expert_weights, top_experts):
     # a hardcoded modification to the megablocks package more than a
     # patch.
     ParallelMLP.forward = forward
+
+    # for the router
+    # - need to resolve the dtensor since we had replicated the router
+    #   weights
+    def forward_router(self, x):
+        if self.training and self.args.moe_jitter_eps is not None:
+            x = x * self.jitter(x)
+
+        _weight = resolve_dtensor(self.layer.weight)
+        _bias = None if self.layer.bias is None else resolve_dtensor(self.layer.bias)
+        # pylint: disable=not-callable
+        scores = F.linear(x.view(-1, x.shape[-1]), _weight, _bias).softmax(dim=-1)
+        expert_weights, expert_indices = self._top_k(scores)
+        if self.args.moe_normalize_expert_weights:
+            expert_weights = expert_weights / torch.norm(
+                expert_weights,
+                p=self.args.moe_normalize_expert_weights,
+                dim=-1,
+                keepdim=True,
+            )
+
+        expert_indices = (
+            _uniform_expert_assignment(
+                expert_indices,
+                self.args.moe_num_experts,
+            )
+            if self.args.uniform_expert_assignment
+            else expert_indices
+        )
+        return scores, expert_weights, expert_indices
+
+    # replace the forward function in the router
+    # - same as above
+    LearnedRouter.forward = forward_router
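
The patched forward_router resolves the router weight before the matmul because, after the checkpoint-loading change below, that weight is a replicated DTensor rather than a plain tensor. A minimal sketch of what such a helper typically does is shown here; the function name and behaviour are assumptions for illustration, not the megablocks implementation of resolve_dtensor.

# Illustrative sketch only: turn a (replicated) DTensor parameter into a plain
# local tensor so it can be fed to torch.nn.functional.linear.
import torch
from torch.distributed._tensor import DTensor  # exposed as torch.distributed.tensor in newer releases


def resolve_dtensor_sketch(weight: torch.Tensor) -> torch.Tensor:
    # A Replicate()-placed DTensor carries the full copy on every rank, so the
    # local tensor is the whole router weight.
    if isinstance(weight, DTensor):
        return weight.to_local()
    return weight

With the weight resolved to a local tensor, F.linear followed by softmax reproduces the stock router scoring, which is why only the weight handling changes in the patch.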
@@ -196,13 +196,15 @@ def load_sharded_experts_onto_device(
         # - concat on dim 0 and distribute
         # - cast to the correct dtype for the module
         param = torch.concat(data, dim=DIM_EXPERT).to(mod_dtype)
-        if KEY_DMOE_ROUTER not in weight_name:
-            param = torch.nn.Parameter(
-                distribute_tensor(param, device_mesh, placements)
-            )
-        else:
-            # - do not shard the router but load onto device as well
-            param = torch.nn.Parameter(param.to(torch.cuda.current_device()))
+
+        _placements = placements
+        if KEY_DMOE_ROUTER in weight_name:
+            # - the router needs to be replicated
+            _placements = [Replicate() for _ in range(len(placements))]
+
+        param = torch.nn.Parameter(
+            distribute_tensor(param, device_mesh, _placements)
+        )
 
         # register the sharded parameter onto the megablocks.dmoe
         mod.register_parameter(name, param)
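
The hunk above switches the router from a plain CUDA parameter to a DTensor with Replicate() placements, so it now flows through the same distribute_tensor path as the sharded expert weights. A minimal, hypothetical sketch of that placement pattern follows (toy shapes, illustrative names, launched with torchrun); it is not code from this repository.

# Illustrative sketch: shard expert weights along the expert dimension while
# replicating the router weight on every rank.
# Run with: torchrun --nproc-per-node=<gpus> this_file.py
import torch
import torch.distributed as dist
from torch.distributed._tensor import Replicate, Shard, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

num_experts = 8  # toy value; assume the world size divides it
expert_weight = torch.randn(num_experts, 4096, 1024)  # concat of per-expert weights
router_weight = torch.randn(num_experts, 1024)

# experts: each rank keeps only its slice along dim 0 (the expert dimension)
expert_param = torch.nn.Parameter(distribute_tensor(expert_weight, mesh, [Shard(0)]))
# router: every rank keeps the full weight, matching the Replicate() placements above
router_param = torch.nn.Parameter(distribute_tensor(router_weight, mesh, [Replicate()]))

dist.destroy_process_group()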
scripts/benchmarks/accelerate.yaml (1 addition, 1 deletion)
@@ -30,7 +30,7 @@ fsdp_config:
  # 3 is NO_SHARD, effectively disabling FSDP
  # 4, 5 are HYBRID_ modes for multi-node training only.
 
-  fsdp_state_dict_type: FULL_STATE_DICT # set to FULL_STATE_DICT (1), SHARDED_STATE_DICT (3)
+  fsdp_state_dict_type: SHARDED_STATE_DICT # set to FULL_STATE_DICT (1), SHARDED_STATE_DICT (3)
  # 2 is LOCAL_STATE_DICT where parameters are still flattened
  # 3 is efficient, but requires know-how to use the shared checkpoint.
 
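
With the router now replicated as a DTensor, flipping the benchmark config to SHARDED_STATE_DICT means each rank writes only its own shards at checkpoint time instead of gathering a full state dict on rank 0. A minimal, hypothetical sketch of saving such a checkpoint with the stock PyTorch APIs follows; it is an assumed usage pattern, not code from this commit.

# Illustrative sketch: save a sharded FSDP checkpoint, which is what
# fsdp_state_dict_type: SHARDED_STATE_DICT selects. Run under torchrun with
# `model` already wrapped in FSDP.
import torch.distributed.checkpoint as dcp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType


def save_sharded_checkpoint(model: FSDP, path: str) -> None:
    # each rank produces only its local shards instead of an all-gathered copy
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        state_dict = {"model": model.state_dict()}
    # torch.distributed.checkpoint writes per-rank files under `path`
    dcp.save(state_dict, checkpoint_id=path)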
