fix replication on router and checkpointing
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
fabianlim committed Aug 23, 2024
1 parent 6ec816e commit f945f59
Showing 5 changed files with 207 additions and 9 deletions.
@@ -0,0 +1,102 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Standard
import os

# Third Party
from accelerate.logging import get_logger
from accelerate.utils.constants import (
    FSDP_MODEL_NAME,
    OPTIMIZER_NAME,
)
from torch.distributed.checkpoint.state_dict import (
    get_state_dict,
    set_state_dict,
)
import torch.distributed.checkpoint as dcp

logger = get_logger(__name__)

MODEL_INDEX = None

def save_fsdp_model(
    fsdp_plugin, accelerator, model, output_dir, model_index=0, adapter_only=False
):
    # the model is actually saved together with the optimizer in
    # save_fsdp_optimizer below; here we only record the model index
    # used to build the checkpoint path
    # pylint: disable=global-statement
    global MODEL_INDEX
    MODEL_INDEX = model_index

def save_fsdp_optimizer(
    fsdp_plugin, accelerator, optimizer, model, output_dir, optimizer_index=0
):
    (model_state_dict, optimizer_state_dict) = get_state_dict(model, optimizer)

    ckpt_model = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}")
    os.makedirs(ckpt_model, exist_ok=True)
    logger.info(f"Saving model to {ckpt_model}")
    dcp.save({"model": model_state_dict}, checkpoint_id=ckpt_model)
    logger.info(f"Model saved to {ckpt_model}")

    ckpt_opt = os.path.join(output_dir, f"{OPTIMIZER_NAME}_{optimizer_index}")
    os.makedirs(ckpt_opt, exist_ok=True)
    logger.info(f"Saving Optimizer state to {ckpt_opt}")
    dcp.save({"optimizer": optimizer_state_dict}, checkpoint_id=ckpt_opt)
    logger.info(f"Optimizer state saved in {ckpt_opt}")


# accelerate.utils.fsdp_utils.py
def load_fsdp_model(
    fsdp_plugin, accelerator, model, input_dir, model_index=0, adapter_only=False
):
    # loading is deferred to load_fsdp_optimizer below, which restores both
    # model and optimizer state; here we only record the model index
    # pylint: disable=global-statement
    global MODEL_INDEX
    MODEL_INDEX = model_index


# accelerate.utils.fsdp_utils.py
def load_fsdp_optimizer(
    fsdp_plugin,
    accelerator,
    optimizer,
    model,
    input_dir,
    optimizer_index=0,
    adapter_only=False,
):

    model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
    ckpt_model = os.path.join(input_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}")
    dcp.load({"model": model_state_dict}, checkpoint_id=ckpt_model)
    ckpt_opt = os.path.join(input_dir, f"{OPTIMIZER_NAME}_{optimizer_index}")
    dcp.load({"optimizer": optimizer_state_dict}, checkpoint_id=ckpt_opt)
    set_state_dict(
        model,
        optimizer,
        model_state_dict=model_state_dict,
        optim_state_dict=optimizer_state_dict,
    )

    # HACK for now
    # - it seems that if "params" is empty, then the loading has some
    #   problems
    # - so for now, we just dump some random defaults
    for group in optimizer.param_groups:
        if len(group["params"]) == 0:
            group["betas"] = (0.9, 0.999)
            group["lr"] = 0.0
            group["initial_lr"] = 0.0
            group["eps"] = 1e-8
            group["weight_decay"] = 0.0
@@ -1,8 +1,23 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# utilities to update megablocks to register various things
# e.g., the MLP_v2 that handles gate, up, down projections

# Third Party
import torch
import torch.nn.functional as F


# this function ensures that the megablocks package is configured to use
@@ -17,8 +32,9 @@ def update_mlp_registry(
# Third Party
# pylint: disable=import-error,import-outside-toplevel
from megablocks.layers.dmlp_registry import _REGISTRY
from megablocks.layers.mlp import SparseMLP
from megablocks.layers.mlp import SparseMLP, resolve_dtensor
from megablocks.layers.moe import ParallelMLP
from megablocks.layers.router import LearnedRouter, _uniform_expert_assignment

# Local
from .sparse_mlp2 import SparseMLPv2
@@ -65,3 +81,53 @@ def forward(self, x, scores, expert_weights, top_experts):
# a hardcoded modification to the megablocks package more than a
# patch.
ParallelMLP.forward = forward

# for the router
# - need to resolve the dtensor since we had replicated the router
# weights
def forward_router(self, x):
    if self.training and self.args.moe_jitter_eps is not None:
        x = x * self.jitter(x)

    _weight = resolve_dtensor(self.layer.weight)
    _bias = None if self.layer.bias is None else resolve_dtensor(self.layer.bias)
    # pylint: disable=not-callable
    scores = F.linear(x.view(-1, x.shape[-1]), _weight, _bias).softmax(dim=-1)
    expert_weights, expert_indices = self._top_k(scores)
    if self.args.moe_normalize_expert_weights:
        expert_weights = expert_weights / torch.norm(
            expert_weights,
            p=self.args.moe_normalize_expert_weights,
            dim=-1,
            keepdim=True,
        )

    expert_indices = (
        _uniform_expert_assignment(
            expert_indices,
            self.args.moe_num_experts,
        )
        if self.args.uniform_expert_assignment
        else expert_indices
    )
    return scores, expert_weights, expert_indices

# replace the forward function in the router
# - same as above
LearnedRouter.forward = forward_router
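
The resolve_dtensor call above is needed because the router weight is now a replicated DTensor, while F.linear expects a plain torch.Tensor. A rough sketch of the idea (not megablocks' actual implementation):

# illustration: turn a (possibly) replicated DTensor parameter back into a
# plain local tensor before calling ops that do not understand DTensor
from torch.distributed._tensor import DTensor

def resolve_weight(weight):
    # with a Replicate() placement, the local shard already holds the full tensor
    return weight.to_local() if isinstance(weight, DTensor) else weight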

# Third Party
from fms_acceleration.model_patcher import patch_target_module

# Local
from .checkpoint_utils import (
    load_fsdp_model,
    load_fsdp_optimizer,
    save_fsdp_model,
    save_fsdp_optimizer,
)

patch_target_module("transformers.trainer.save_fsdp_model", save_fsdp_model)
patch_target_module("transformers.trainer.save_fsdp_optimizer", save_fsdp_optimizer)
patch_target_module("transformers.trainer.load_fsdp_model", load_fsdp_model)
patch_target_module("transformers.trainer.load_fsdp_optimizer", load_fsdp_optimizer)
@@ -1,3 +1,17 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Standard
from collections import defaultdict
from contextlib import ExitStack
@@ -196,13 +210,15 @@ def load_sharded_experts_onto_device(
# - concat on dim 0 and distribute
# - cast to the correct dtype for the module
param = torch.concat(data, dim=DIM_EXPERT).to(mod_dtype)
if KEY_DMOE_ROUTER not in weight_name:
    param = torch.nn.Parameter(
        distribute_tensor(param, device_mesh, placements)
    )
else:
    # - do not shard the router but load onto device as well
    param = torch.nn.Parameter(param.to(torch.cuda.current_device()))

_placements = placements
if KEY_DMOE_ROUTER in weight_name:
    # - the router needs to be replicated
    _placements = [Replicate() for _ in range(len(placements))]

param = torch.nn.Parameter(
    distribute_tensor(param, device_mesh, _placements)
)

# register the sharded parameter onto the megablocks.dmoe
mod.register_parameter(name, param)
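
With this change, expert weights keep their sharded placements while the router weight becomes a DTensor with Replicate() placements on the same device mesh, so every rank holds a full copy. A small illustration of the two placements (assumes a process group already initialized, e.g. via torchrun; shapes are arbitrary):

# illustration only: shard expert weights across the mesh, replicate the router
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import distribute_tensor, Replicate, Shard

mesh = init_device_mesh("cuda", (dist.get_world_size(),))
expert_w = distribute_tensor(torch.randn(8, 1024), mesh, [Shard(0)])      # each rank holds a slice of the experts
router_w = distribute_tensor(torch.randn(8, 1024), mesh, [Replicate()])   # each rank holds the full router weight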
@@ -1,3 +1,17 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Third Party
import torch

2 changes: 1 addition & 1 deletion scripts/benchmarks/accelerate.yaml
@@ -30,7 +30,7 @@ fsdp_config:
  # 3 is NO_SHARD, effectively disabling FSDP
  # 4, 5 are HYBRID_ modes for multi-node training only.

  fsdp_state_dict_type: FULL_STATE_DICT # set to FULL_STATE_DICT (1), SHARDED_STATE_DICT (3)
  fsdp_state_dict_type: SHARDED_STATE_DICT # set to FULL_STATE_DICT (1), SHARDED_STATE_DICT (3)
  # 2 is LOCAL_STATE_DICT where parameters are still flattened
  # 3 is efficient, but requires know-how to use the sharded checkpoint.
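
One practical note on the switch to SHARDED_STATE_DICT: the resulting checkpoint is a DCP directory of shards rather than a single file. Recent PyTorch releases ship a converter for offline consumption if a consolidated torch.save file is needed; a hedged example (paths below are placeholders, not produced by this commit):

# convert a sharded DCP checkpoint directory into a single torch.save file
# (paths are placeholders; requires a recent PyTorch release)
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

dcp_to_torch_save("outputs/pytorch_model_fsdp_0", "outputs/consolidated.pt")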
