diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/checkpoint_utils.py
new file mode 100644
index 00000000..036e083f
--- /dev/null
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/checkpoint_utils.py
@@ -0,0 +1,102 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard
+import os
+
+# Third Party
+from accelerate.logging import get_logger
+from accelerate.utils.constants import (
+    FSDP_MODEL_NAME,
+    OPTIMIZER_NAME,
+)
+from torch.distributed.checkpoint.state_dict import (
+    get_state_dict,
+    set_state_dict,
+)
+import torch.distributed.checkpoint as dcp
+
+logger = get_logger(__name__)
+
+MODEL_INDEX = None
+
+def save_fsdp_model(
+    fsdp_plugin, accelerator, model, output_dir, model_index=0, adapter_only=False
+):
+
+    # pylint: disable=global-statement
+    global MODEL_INDEX
+    MODEL_INDEX = model_index
+
+def save_fsdp_optimizer(
+    fsdp_plugin, accelerator, optimizer, model, output_dir, optimizer_index=0
+):
+    (model_state_dict, optimizer_state_dict) = get_state_dict(model, optimizer)
+
+    ckpt_model = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}")
+    os.makedirs(ckpt_model, exist_ok=True)
+    logger.info(f"Saving model to {ckpt_model}")
+    dcp.save({"model": model_state_dict}, checkpoint_id=ckpt_model)
+    logger.info(f"Model saved to {ckpt_model}")
+
+    ckpt_opt = os.path.join(output_dir, f"{OPTIMIZER_NAME}_{optimizer_index}")
+    os.makedirs(ckpt_opt, exist_ok=True)
+    logger.info(f"Saving Optimizer state to {ckpt_opt}")
+    dcp.save({"optimizer": optimizer_state_dict}, checkpoint_id=ckpt_opt)
+    logger.info(f"Optimizer state saved in {ckpt_opt}")
+
+
+# accelerate.utils.fsdp_utils.py
+def load_fsdp_model(
+    fsdp_plugin, accelerator, model, input_dir, model_index=0, adapter_only=False
+):
+    # pylint: disable=global-statement
+    global MODEL_INDEX
+    MODEL_INDEX = model_index
+
+
+# accelerate.utils.fsdp_utils.py
+def load_fsdp_optimizer(
+    fsdp_plugin,
+    accelerator,
+    optimizer,
+    model,
+    input_dir,
+    optimizer_index=0,
+    adapter_only=False,
+):
+
+    model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
+    ckpt_model = os.path.join(input_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}")
+    dcp.load({"model": model_state_dict}, checkpoint_id=ckpt_model)
+    ckpt_opt = os.path.join(input_dir, f"{OPTIMIZER_NAME}_{optimizer_index}")
+    dcp.load({"optimizer": optimizer_state_dict}, checkpoint_id=ckpt_opt)
+    set_state_dict(
+        model,
+        optimizer,
+        model_state_dict=model_state_dict,
+        optim_state_dict=optimizer_state_dict,
+    )
+
+    # HACK for now
+    # - it seems that if params is empty, then the loading has some
+    #   problems
+    # - so for now, we just dump some random defaults
+    for group in optimizer.param_groups:
+        if len(group["params"]) == 0:
+            group["betas"] = (0.9, 0.999)
+            group["lr"] = 0.0
+            group["initial_lr"] = 0.0
+            group["eps"] = 1e-8
+            group["weight_decay"] = 0.0
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/config_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/config_utils.py
index 31e73489..1234d383 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/config_utils.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/config_utils.py
@@ -1,8 +1,23 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 # utilities to update megablocks to register various things
 # e.g, the MLP_v2 that handles gate, up, down projections
 
 # Third Party
 import torch
+import torch.nn.functional as F
 
 
 # this function ensures that the megablocks packaged is configured to use
@@ -17,8 +32,9 @@ def update_mlp_registry(
     # Third Party
     # pylint: disable=import-error,import-outside-toplevel
     from megablocks.layers.dmlp_registry import _REGISTRY
-    from megablocks.layers.mlp import SparseMLP
+    from megablocks.layers.mlp import SparseMLP, resolve_dtensor
     from megablocks.layers.moe import ParallelMLP
+    from megablocks.layers.router import LearnedRouter, _uniform_expert_assignment
 
     # Local
     from .sparse_mlp2 import SparseMLPv2
@@ -65,3 +81,53 @@ def forward(self, x, scores, expert_weights, top_experts):
     # a hardcoded modification to the megablocks package more than a
     # patch.
     ParallelMLP.forward = forward
+
+    # for the router
+    # - need to resolve the dtensor since we had replicated the router
+    #   weights
+    def forward_router(self, x):
+        if self.training and self.args.moe_jitter_eps is not None:
+            x = x * self.jitter(x)
+
+        _weight = resolve_dtensor(self.layer.weight)
+        _bias = None if self.layer.bias is None else resolve_dtensor(self.layer.bias)
+        # pylint: disable=not-callable
+        scores = F.linear(x.view(-1, x.shape[-1]), _weight, _bias).softmax(dim=-1)
+        expert_weights, expert_indices = self._top_k(scores)
+        if self.args.moe_normalize_expert_weights:
+            expert_weights = expert_weights / torch.norm(
+                expert_weights,
+                p=self.args.moe_normalize_expert_weights,
+                dim=-1,
+                keepdim=True,
+            )
+
+        expert_indices = (
+            _uniform_expert_assignment(
+                expert_indices,
+                self.args.moe_num_experts,
+            )
+            if self.args.uniform_expert_assignment
+            else expert_indices
+        )
+        return scores, expert_weights, expert_indices
+
+    # replace the forward function in the router
+    # - same as above
+    LearnedRouter.forward = forward_router
+
+    # Third Party
+    from fms_acceleration.model_patcher import patch_target_module
+
+    # Local
+    from .checkpoint_utils import (
+        load_fsdp_model,
+        load_fsdp_optimizer,
+        save_fsdp_model,
+        save_fsdp_optimizer,
+    )
+
+    patch_target_module("transformers.trainer.save_fsdp_model", save_fsdp_model)
+    patch_target_module("transformers.trainer.save_fsdp_optimizer", save_fsdp_optimizer)
+    patch_target_module("transformers.trainer.load_fsdp_model", load_fsdp_model)
+    patch_target_module("transformers.trainer.load_fsdp_optimizer", load_fsdp_optimizer)
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py
index 677bdabc..07971fc7 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py
@@ -1,3 +1,17 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 # Standard
 from collections import defaultdict
 from contextlib import ExitStack
@@ -196,13 +210,15 @@ def load_sharded_experts_onto_device(
         # - concat on dim 0 and distribute
         # - cast to the correct dtype for the module
         param = torch.concat(data, dim=DIM_EXPERT).to(mod_dtype)
-        if KEY_DMOE_ROUTER not in weight_name:
-            param = torch.nn.Parameter(
-                distribute_tensor(param, device_mesh, placements)
-            )
-        else:
-            # - do not shard the router but load onto device as well
-            param = torch.nn.Parameter(param.to(torch.cuda.current_device()))
+
+        _placements = placements
+        if KEY_DMOE_ROUTER in weight_name:
+            # - the router needs to be replicated
+            _placements = [Replicate() for _ in range(len(placements))]
+
+        param = torch.nn.Parameter(
+            distribute_tensor(param, device_mesh, _placements)
+        )
 
         # register the sharded parameter onto the megablocks.dmoe
         mod.register_parameter(name, param)
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/sparse_mlp2.py b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/sparse_mlp2.py
index 8d3871a5..439eaaf0 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/sparse_mlp2.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/sparse_mlp2.py
@@ -1,3 +1,17 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 # Third Party
 import torch
 
diff --git a/scripts/benchmarks/accelerate.yaml b/scripts/benchmarks/accelerate.yaml
index f70d74fa..f3908470 100644
--- a/scripts/benchmarks/accelerate.yaml
+++ b/scripts/benchmarks/accelerate.yaml
@@ -30,7 +30,7 @@ fsdp_config:
   # 3 is NO_SHARD, effectively disabling FSDP
   # 4, 5 are HYBRID_ modes for multi-node training only.
 
-  fsdp_state_dict_type: FULL_STATE_DICT # set to FULL_STATE_DICT (1), SHARDED_STATE_DICT (3)
+  fsdp_state_dict_type: SHARDED_STATE_DICT # set to FULL_STATE_DICT (1), SHARDED_STATE_DICT (3)
 
   # 2 is LOCAL_STATE_DICT where parameters are still flattened
   # 3 is efficient, but requires know-how to use the shared checkpoint.