handle requires_grad in shard_moe
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
fabianlim committed Aug 29, 2024
1 parent feaeaa5 commit 7a6cdd8
Showing 2 changed files with 8 additions and 4 deletions.
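
Background for the change below, not part of the commit: torch.nn.Parameter defaults to requires_grad=True, so re-wrapping a sharded expert weight without forwarding the original flag silently turns a frozen parameter back into a trainable one, which appears to be the case this commit handles. A minimal sketch of that default, using plain tensors only:

import torch

# A frozen tensor, e.g. an expert weight that should stay untrained.
w = torch.randn(4, 4, requires_grad=False)

# Re-wrapping without forwarding the flag: nn.Parameter defaults to
# requires_grad=True, so the "frozen" weight becomes trainable again.
p_bad = torch.nn.Parameter(w.clone())
print(p_bad.requires_grad)   # True

# Forwarding the original flag keeps the parameter frozen, which is the
# behaviour the diff below adds for sharded experts.
p_good = torch.nn.Parameter(w.clone(), requires_grad=w.requires_grad)
print(p_good.requires_grad)  # False
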
@@ -151,7 +151,7 @@ def model_loader(self, model_name: str, **kwargs):
shared_mesh_dim=self._shard_along_dp,
router_name=self._gate_module_name,
expert_name=self._experts_module_name,
mixed_precision=False, # Currently this is hardcoded to OFF
mixed_precision=False, # Currently this is hardcoded to OFF
)

# NOTE: there is currently no good way to get the mixed precision
@@ -215,6 +215,8 @@ def load_sharded_experts_onto_device(
else:
mod_dtype = getattr(mod, name).dtype

requires_grad = getattr(mod, name).requires_grad

# the megablocks dmoe expects the expert features to be on DIM_EXPERT.
# - concat on dim 0 and distribute
# - cast to the correct dtype for the module
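
A self-contained sketch of the lookup pattern in this hunk, reading dtype and requires_grad off the existing submodule attribute by name before it is replaced; the TinyExpert module and the w1 parameter name are hypothetical stand-ins for the real megablocks expert modules:

import torch

# Hypothetical stand-in for a megablocks expert module; the real module
# and parameter names come from the model being sharded.
class TinyExpert(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w1 = torch.nn.Parameter(torch.randn(8, 8), requires_grad=False)

mod, name = TinyExpert(), "w1"

# Same pattern as in the diff: capture dtype and requires_grad before the
# parameter is rebuilt as a distributed tensor.
mod_dtype = getattr(mod, name).dtype
requires_grad = getattr(mod, name).requires_grad
print(mod_dtype, requires_grad)  # torch.float32 False
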
@@ -227,14 +229,16 @@ def load_sharded_experts_onto_device(
_placements = [Replicate() for _ in range(len(placements))]

param = torch.nn.Parameter(
distribute_tensor(param, device_mesh, _placements)
distribute_tensor(param, device_mesh, _placements),
requires_grad=requires_grad,
)

# register the sharded parameter onto the megablocks.dmoe
mod.register_parameter(name, param)

upcasted = ", ".join(sorted(upcasted))
warnings.warn(f"Mixed precision turned on, upcasted MoE parameters: {upcasted}")
if mixed_precision:
upcasted = ", ".join(sorted(upcasted))
warnings.warn(f"Mixed precision turned on, upcasted MoE parameters: {upcasted}")


def shard_moe(
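
To tie the pieces together, a hedged single-process sketch of the re-registration path shown above: distribute a tensor over a device mesh, wrap it in nn.Parameter with the captured requires_grad, register it back onto the module, and warn about upcasts only when mixed precision is on. The sketch replicates on a one-device CPU mesh, whereas the real code shards experts across a multi-device mesh and may use Shard placements instead of Replicate; the torch.distributed._tensor import path may differ across torch versions, and everything not visible in the diff (the Linear stand-in, the gloo process group, the mixed_precision and upcasted placeholders) is assumed.

import os
import warnings

import torch
import torch.distributed as dist
from torch.distributed._tensor import Replicate, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh

# Single-process stand-in for the sharded setup; the real code runs on a
# multi-device mesh and may use Shard placements instead of Replicate.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)
device_mesh = init_device_mesh("cpu", (1,))

# Hypothetical module standing in for a megablocks expert; freeze it to
# show that requires_grad survives the re-registration.
mod = torch.nn.Linear(8, 8, bias=False)
name = "weight"
getattr(mod, name).requires_grad_(False)
requires_grad = getattr(mod, name).requires_grad

mixed_precision = False   # mirrors the hardcoded OFF in the loader above
upcasted = set()          # names of parameters upcasted under mixed precision

# Rebuild the parameter as a distributed tensor, forwarding requires_grad
# so a frozen parameter stays frozen after register_parameter.
_placements = [Replicate() for _ in range(device_mesh.ndim)]
param = torch.nn.Parameter(
    distribute_tensor(getattr(mod, name).data, device_mesh, _placements),
    requires_grad=requires_grad,
)
mod.register_parameter(name, param)

# Warn about upcasted parameters only when mixed precision is actually on,
# as the last hunk of the diff does.
if mixed_precision:
    warnings.warn(
        "Mixed precision turned on, upcasted MoE parameters: "
        + ", ".join(sorted(upcasted))
    )

print(isinstance(mod.weight, torch.nn.Parameter), mod.weight.requires_grad)  # True False
dist.destroy_process_group()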
