handle requires_grad in shard_moe
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
fabianlim committed Aug 29, 2024
1 parent feaeaa5 commit 1b6d7a7
Showing 1 changed file with 4 additions and 1 deletion.
@@ -215,6 +215,8 @@ def load_sharded_experts_onto_device(
     else:
         mod_dtype = getattr(mod, name).dtype
 
+    requires_grad = getattr(mod, name).requires_grad
+
     # the megablocks dmoe expects the expert features to be on DIM_EXPERT.
     # - concat on dim 0 and distribute
     # - cast to the correct dtype for the module
@@ -227,7 +229,8 @@ def load_sharded_experts_onto_device(
         _placements = [Replicate() for _ in range(len(placements))]
 
     param = torch.nn.Parameter(
-        distribute_tensor(param, device_mesh, _placements)
+        distribute_tensor(param, device_mesh, _placements),
+        requires_grad=requires_grad
     )
 
     # register the sharded parameter onto the megablocks.dmoe
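For context, a minimal standalone sketch of the behaviour this commit guards against is shown below. torch.nn.Parameter defaults to requires_grad=True, so re-wrapping a sharded tensor without forwarding the original flag would silently mark frozen expert weights as trainable. The sketch uses plain tensors and made-up names; it does not involve a device mesh or distribute_tensor.

    import torch

    # A "frozen" expert weight, e.g. one excluded from fine-tuning.
    frozen = torch.nn.Parameter(torch.randn(4, 4), requires_grad=False)

    # Stand-in for the sharding step (distribute_tensor returns a DTensor
    # in the real code path; here we just produce a new plain tensor).
    new_data = frozen.detach().clone()

    # Before the fix: torch.nn.Parameter defaults to requires_grad=True,
    # so the frozen weight silently becomes trainable again.
    rewrapped_bad = torch.nn.Parameter(new_data)
    assert rewrapped_bad.requires_grad is True

    # After the fix: read the flag off the original module attribute and
    # forward it, preserving frozen/trainable status across sharding.
    requires_grad = frozen.requires_grad
    rewrapped_good = torch.nn.Parameter(new_data, requires_grad=requires_grad)
    assert rewrapped_good.requires_grad is False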
