From 6833cc6b22d27d777fe52e0bbb36b305e9a9e0f1 Mon Sep 17 00:00:00 2001
From: Yu Chin Fabian Lim
Date: Fri, 16 Aug 2024 09:18:16 +0000
Subject: [PATCH 01/21] drafting megablocks plugin

Signed-off-by: Yu Chin Fabian Lim
---
 plugins/accelerated-moe/README.md               |  13 +
 .../accelerated-moe/configs/megablocks.yaml     |   9 +
 plugins/accelerated-moe/pyproject.toml          |  29 ++
 .../src/fms_acceleration_moe/__init__.py        |  16 ++
 .../framework_plugin_megablocks.py              | 117 ++++++++
 .../megablocks_utils/__init__.py                |  13 +
 .../megablocks_utils/config_utils.py            |  37 +++
 .../megablocks_utils/shard_moe_utils.py         | 270 ++++++++++++++++++
 .../megablocks_utils/sparse_mlp2.py             |  88 ++++++
 9 files changed, 592 insertions(+)
 create mode 100644 plugins/accelerated-moe/README.md
 create mode 100644 plugins/accelerated-moe/configs/megablocks.yaml
 create mode 100644 plugins/accelerated-moe/pyproject.toml
 create mode 100644 plugins/accelerated-moe/src/fms_acceleration_moe/__init__.py
 create mode 100644 plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py
 create mode 100644 plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/__init__.py
 create mode 100644 plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/config_utils.py
 create mode 100644 plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py
 create mode 100644 plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/sparse_mlp2.py

diff --git a/plugins/accelerated-moe/README.md b/plugins/accelerated-moe/README.md
new file mode 100644
index 00000000..8ad79f9a
--- /dev/null
+++ b/plugins/accelerated-moe/README.md
@@ -0,0 +1,13 @@
+# FMS Accelerattion for Mixture-of-Experts
+
+This library contains plugins to accelerate finetuning with the following optimizations:
+1. Expert-Parallel MoE with Megablocks
+
+
+## Megablocks Dependencies
+
+Currently databricks megablocks does not have a PyPi repository and does not have a proper release, so we have to install from the github repository as below. Please note that installing from github will require CUDA Toolkit to build.
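Because the install command that follows builds megablocks from source (compiling CUDA extensions), it can help to confirm the build environment first. A minimal, illustrative pre-flight check, not part of the plugin itself, might look like this:

```python
# Illustrative pre-install sanity check (not part of the plugin): confirms that
# the local PyTorch build has CUDA support and that nvcc is on PATH, both of
# which a source build of megablocks typically needs.
import shutil
import torch

assert torch.version.cuda is not None, "PyTorch was not built with CUDA support"
assert shutil.which("nvcc") is not None, "CUDA toolkit (nvcc) not found on PATH"
print("CUDA runtime version:", torch.version.cuda)
```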
+ +``` +pip install git+https://github.com/databricks/megablocks.git +``` \ No newline at end of file diff --git a/plugins/accelerated-moe/configs/megablocks.yaml b/plugins/accelerated-moe/configs/megablocks.yaml new file mode 100644 index 00000000..14465d96 --- /dev/null +++ b/plugins/accelerated-moe/configs/megablocks.yaml @@ -0,0 +1,9 @@ +training: + + # mixture-of-experts configurations + moe: + + # expert-parallel for MoE + megablocks: + + dummy: 1 diff --git a/plugins/accelerated-moe/pyproject.toml b/plugins/accelerated-moe/pyproject.toml new file mode 100644 index 00000000..b100bc1e --- /dev/null +++ b/plugins/accelerated-moe/pyproject.toml @@ -0,0 +1,29 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "fms-acceleration-moe" +version = '0.0.1' +description = "FMS Acceleration Plugin for Mixture-of-Experts" +authors = [ + {name = "Fabian Lim", email = "flim@sg.ibm.com"}, +] +license = {text = "Apache-2.0"} +readme = "README.md" +requires-python = "~=3.9" +keywords = ['fms-hf-tuning', 'acceleration', 'mixture-of-experts', 'megablocks'] +classifiers=[ + "License :: OSI Approved :: Apache Software License", + "Development Status :: 4 - Beta", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", +] + +[tool.hatch.build.targets.wheel] +only-include = ["src/fms_acceleration_moe"] + +[tool.hatch.build.targets.wheel.sources] +"src" = "" diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/__init__.py b/plugins/accelerated-moe/src/fms_acceleration_moe/__init__.py new file mode 100644 index 00000000..eca7a2c5 --- /dev/null +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/__init__.py @@ -0,0 +1,16 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .framework_plugin_megablocks import MegablocksMoEAccelerationPlugin \ No newline at end of file diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py new file mode 100644 index 00000000..aa9dda80 --- /dev/null +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py @@ -0,0 +1,117 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Standard +from typing import Dict, Tuple +import torch +import warnings + +# Third Party +from fms_acceleration import AccelerationPlugin +from peft import LoraConfig +from transformers import TrainingArguments, AutoModelForCausalLM + + +class MegablocksMoEAccelerationPlugin(AccelerationPlugin): + + restricted_model_archs = {"MixtralForCausalLM"} + require_packages = {"megablocks"} + + def __init__(self, configurations: Dict[str, Dict]): + super().__init__(configurations) + + # args + self._dummy = self._check_config_and_maybe_check_values( + key="training.moe.megablocks", + values=["dummy"], + ) + + @property + def requires_custom_loading(self): + return True + + def model_loader(self, model_name: str, **kwargs): + # guarded + from .megablocks_utils.config_utils import update_mlp_registry + from megablocks_utils.shard_moe_utils import shard_moe, get_moe_kwargs + + # this one does a forward patching on MLP, but needs to be fixed + # properly as the load balancing loss is currently not properly + # handled + update_mlp_registry() + + # get additional parameters + torch_dtype = kwargs.get("torch_dtype", torch.float32) + + # load the model + model = AutoModelForCausalLM.from_pretrained( + model_name, **kwargs + ) + + rank, world_size = 0, 1 + if torch.distributed.is_initialized(): + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + else: + # NOTE: or should we do a silent fallback + raise AssertionError( + "Megablocks expert parallel only works for distributed training." + ) + + # FIXME: have some way to search out the MOE block + from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock + + # FIXME: the dtype checks below are too brittle + dp_mesh = shard_moe( + model, + MixtralSparseMoeBlock, + checkpoint_name_or_path=model_name, + rank=rank, + world_size=world_size, + ep_size=world_size, # FIXME: this can be passed in? + moe_kwargs=get_moe_kwargs( + model.config, + has_bias=False, # FIXME: is this true in general? + fp16=torch_dtype == torch.float16, + bf16=torch_dtype == torch.bfloat16, + ), + ) + + def get_callbacks_and_ready_for_train( + self, model: torch.nn.Module = None, accelerator=None + ): + + callbacks = [] + if ( + accelerator is not None + and getattr(accelerator.state, "fsdp_plugin", None) is not None + ): + # lora_adapters_switch_ddp_from_fsdp( + # [mod for mod in model.modules() if isinstance(mod, LoraLayer)], + # accelerator.state.fsdp_plugin, + # ) + # FIXME: should be + accelerator.state.fsdp_plugin.ignored_modules = [ + layer.block_sparse_moe for layer in model.model.layers + ] + + return callbacks + +# register +AccelerationPlugin.register_plugin( + MegablocksMoEAccelerationPlugin, + configuration_and_paths=[ + "training.moe.megablocks", + ], +) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/__init__.py b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/__init__.py new file mode 100644 index 00000000..38a9531e --- /dev/null +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/config_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/config_utils.py new file mode 100644 index 00000000..9ef19eb6 --- /dev/null +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/config_utils.py @@ -0,0 +1,37 @@ +# utilities to update megablocks to register the MLP_v2 that +# handles gate, up, down projections + +from megablocks.layers.dmlp_registry import _REGISTRY + + +# from megablocks.layers import mlp +from .sparse_mlp2 import SparseMLPv2 +from megablocks.layers.moe import ParallelMLP +import torch + +def update_mlp_registry(): + # patch the registry to point to our v2 + _REGISTRY['mlp']['sparse'] = SparseMLPv2 + + def forward(self, x, scores, expert_weights, top_experts): + in_shape = x.size() + + # Compute the experts. + x, _ = self.forward_fn(x, expert_weights, top_experts) + + x = x.view(in_shape) + if self.bias is not None: + if self.args.return_bias: + return x, self.bias + return x + self.bias + + # in this case we should be returning the router + # logits out of the MoeE forward. However, since + # the way the code is written now, it si difficult + # to extract these logits out, so at the moment, + # we return None as the placeholder. + return x, None + + # patch the forward function + ParallelMLP.forward = forward + \ No newline at end of file diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py new file mode 100644 index 00000000..f845c80f --- /dev/null +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py @@ -0,0 +1,270 @@ + +import torch +from transformers import TrainingArguments, PretrainedConfig +from typing import Union, Dict, List, Type +from torch.distributed._tensor import Placement, Replicate, Shard, distribute_tensor +from torch.distributed._tensor.device_mesh import init_device_mesh +import os +from tqdm import tqdm + +from safetensors import safe_open +import json, re +from collections import defaultdict + +from accelerate import init_empty_weights + +from contextlib import ExitStack + +FILE_SAFETENSOR_INDEX = 'model.safetensors.index.json' +KEY_DATA_PARALLEL = 'data_parallel' +KEY_EXPERT_PARALLEL = 'expert_parallel' +DIM_EXPERT = 0 + +KEY_ROUTER = 'router.layer.weight' +KEY_EXPERTS = 'experts.mlp' + +def get_moe_kwargs( + config: PretrainedConfig, + has_bias: bool = False, # if the MOE has bias + fp16: bool = False, + bf16: bool = False, +): + return { + "hidden_size": config.hidden_size, + "ffn_hidden_size": config.intermediate_size, + "moe_num_experts": config.num_local_experts, + "moe_top_k": config.num_experts_per_tok, + "moe_expert_model_parallelism": True, + "memory_optimized_mlp": False, + "bias": has_bias, + "moe_normalize_expert_weights": True, + "fp16": fp16, + "bf16": bf16, + } + +# trick to get the resolved cache file to acccess the safetensor +# NOTE: this does not work if _dict_from_json_file, like GGUF files +def 
get_resolved_checkpoint_location(model_name_or_path: str): + + result = None + _old_func = PretrainedConfig._dict_from_json_file + def _dict_from_json_file(resolved_config_file): + nonlocal result + result = resolved_config_file + return _old_func(resolved_config_file) + + # make a hook and restrive + PretrainedConfig._dict_from_json_file = _dict_from_json_file + PretrainedConfig.from_pretrained(model_name_or_path) + PretrainedConfig._dict_from_json_file = _old_func + return os.path.dirname(result) + +# see https://github.com/mosaicml/llm-foundry/blob/main/tests/models/layers/test_dmoe.py +# for a basic example + +# this one is called for one layer +# e.g., 'model.layers.0, block_sparse_moe +def get_router_experts_sharded_safetensor( + weight_map: Dict, + prefix: str, # e.g., 'model.layers.0, + instance_name: str, # e.g., block_sparse_moe + router_name: str = 'gate', + expert_name: str = 'experts' +): + # insert in order + def _insert(L: List, i: int, v): + n = len(L) + if i < n: + L[i] = v + return + + n = i - n + 1 + while n > 0: + L.append(None) + n -= 1 + L[i] = v + + # state dict -> weights + # 'router.layer.weight': [(k, file),...] + # `experts.mlp.w1`: [...] + _map = defaultdict(list) + prefix = f"{prefix}.{instance_name}." + for k, stfile in weight_map.items(): + if not k.startswith(prefix): + continue + + # e.g. after replacement we get + # - gate.weight + # - experts.0.w1.weight + rel_k = k.replace(prefix, "") + m = re.match( + f'({router_name}|{expert_name})\.?(\d+)?\.?(\w+)?\.weight', + rel_k + ) + if m is None: + raise ValueError( + f"Unable to handle key '{k}' with provided router_name " + f"'{router_name}' or expert_name '{expert_name}'" + ) + if m.group(1) == router_name: + _map[KEY_ROUTER].append((k, stfile)) + elif m.group(1) == expert_name: + index = int(m.group(2)) + mod = m.group(3) + # expert_map[stfile].append((mod, index, k)) + _insert(_map[f'{KEY_EXPERTS}.{mod}'], index, (k, stfile)) + + if len(_map) == 0: + raise ValueError( + f"Could not get safetensor map for '{prefix}' and '{instance_name}'" + ) + + return _map + +# for megablocks.SparseMLPv2 +# assign dmoe with mlp_v2 +# settings is: +# experts.mlp.w1: [(k, file)] +def assign_mlp_v2_weights( + dmoe: torch.nn.Module, + directory: str, + settings: Dict, + device_mesh, + placements, +): + # typically they all should be same file + with ExitStack() as stack: + files = {} + for _, vs in settings.items(): + for _, fi in vs: + if fi not in files: + files[fi] = stack.enter_context( + safe_open(os.path.join(directory, fi), framework='pt', device='cpu') + ) + + # go by one weight + for weight_name, vs in settings.items(): + data = [] + for k, fi in vs: + T = files[fi].get_tensor(k) + if 'experts' in k: + if T.shape[1] > T.shape[0]: + T = T.t() + data.append(T) + + # concat on dim 0 and distribute + param = torch.concat(data, dim=DIM_EXPERT) + if KEY_ROUTER not in weight_name: + param = torch.nn.Parameter( + distribute_tensor(param, device_mesh, placements) + ) + else: + param = torch.nn.Parameter( + param.to(torch.cuda.current_device()) + ) + name = weight_name.split('.') + path, name = ".".join(name[:-1]), name[-1] + mod = dmoe.get_submodule(path) + mod.register_parameter(name, param) + +def shard_moe( + model: torch.nn.Module, + moe_cls: Union[str, Type], + checkpoint_name_or_path: str, + rank: int, + world_size: int, + ep_size: int, + moe_kwargs: Dict, + device_type: str = 'cuda', + key_dp: str = KEY_DATA_PARALLEL, + key_ep: str = KEY_EXPERT_PARALLEL, +): + # guarded import + from megablocks.layers import dmoe, 
arguments, mpu + + assert ep_size > 1, "this function is used for sharding moe" + + # this function will shard the MOE on this rank + device = torch.device(f'cuda:{rank}') + dp_size = world_size // ep_size + + if dp_size == 1: + # in this case we will have a 1D mesh and collapse the + # expert parallel with data_parallel + + device_mesh = init_device_mesh( + device_type, + (ep_size,), + mesh_dim_names=(key_dp,), + ) + key_ep = key_dp + placements: List[Placement] = [Shard(DIM_EXPERT)] + else: + # in this case it will be a 2D mesh + device_mesh = init_device_mesh( + device_type, + (dp_size, ep_size), + mesh_dim_names=(key_dp, key_ep), + ) + placements: List[Placement] = [Replicate(), Shard(DIM_EXPERT)] + + mp_dmoe_args = arguments.Arguments( + **moe_kwargs, device=device, + expert_parallel_group=device_mesh[key_ep].get_group(0) + ) + + assert mp_dmoe_args.moe_num_experts % world_size == 0, \ + "number of moe experts not divisible by world_size" + + # for all the MoE related params, e.g., gate, experts + # get a dictc + # parent_mod: (child_instance_name, [list of fqdn keys]) + found = {} + for name, mod in model.named_modules(): + name = name.split('.') + parent, child = ".".join(name[:-1]), name[-1] + if isinstance(mod, moe_cls): + found[parent] = ( + child, + [ # all params, including childs' + f'{parent}.{child}.{n}' + for n, _ in mod.named_parameters() + ] + ) + + # NOTE: for now we only support sharded safetensors + # - most MOE models should be used using this checkpoint format + try: + loc = get_resolved_checkpoint_location(checkpoint_name_or_path) + with open(os.path.join(loc, FILE_SAFETENSOR_INDEX)) as f: + index = json.load(f) + + # e.g., prefix: 'model.layers.0', + # module_name: 'block_sparse_moe' + for prefix, (module_name, relevant_keys) in tqdm( + found.items(), + disable=torch.distributed.get_rank() > 0, + desc='Sharding MoE' + ): + settings = get_router_experts_sharded_safetensor( + index['weight_map'], prefix, module_name, + ) + with init_empty_weights(): + mp_dmoe = dmoe.dMoE(mp_dmoe_args) # drop in replacement for now + + assign_mlp_v2_weights( + mp_dmoe, loc, settings, + device_mesh, placements + ) + parent = model.get_submodule(prefix) + setattr(parent, module_name, mp_dmoe) + + except ValueError as e: + raise ValueError( + f"Unable to load checkpoint_path '{checkpoint_name_or_path}'. " + "Currently only support safetensor checkpoints. 
" + f": {e}" + ) + + + return device_mesh[key_dp] \ No newline at end of file diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/sparse_mlp2.py b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/sparse_mlp2.py new file mode 100644 index 00000000..ac36245e --- /dev/null +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/sparse_mlp2.py @@ -0,0 +1,88 @@ +import torch +from megablocks.layers import common +from megablocks.layers.arguments import Arguments +from megablocks.layers import mpu +from megablocks.layers.mlp import ( + create_dmoe_expert_weights, scale_gradient, + resolve_dtensor +) +import torch.nn.functional as F +import stk + +class SparseMLPv2(torch.nn.Module): + + def __init__(self, args : Arguments): + super().__init__() + self.args = args + self._num_rows_per_rank = ( + (mpu.experts_per_rank(args) * mpu.features_per_rank(args)) // + mpu.get_weight_parallel_world_size(args) + ) + + self.w1 = torch.nn.Parameter(torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args))) + self.w2 = torch.nn.Parameter(torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args))) + self.w3 = torch.nn.Parameter(torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args))) + + with torch.no_grad(): + self.w1.copy_(create_dmoe_expert_weights( + args, args.moe_num_experts, args.ffn_hidden_size, + args.hidden_size, args.init_method)) + self.w2.copy_(create_dmoe_expert_weights( + args, args.moe_num_experts, args.ffn_hidden_size, + args.hidden_size, args.output_layer_init_method)) + self.w3.copy_(create_dmoe_expert_weights( + args, args.moe_num_experts, args.ffn_hidden_size, + args.hidden_size, args.output_layer_init_method)) + + self._should_set_parallelism_attribute = ( + args.moe_expert_model_parallelism or args.moe_weight_parallelism) + mpu.set_expert_model_parallel_attributes( + self.w1, self._should_set_parallelism_attribute) + mpu.set_expert_model_parallel_attributes( + self.w2, self._should_set_parallelism_attribute) + mpu.set_expert_model_parallel_attributes( + self.w3, self._should_set_parallelism_attribute) + + self.gradient_scale = None + if self.args.moe_expert_model_parallelism: + self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args) + + def scale_grad(self, w): + if self.gradient_scale is None: + return w + return scale_gradient(w, self.gradient_scale) + + def forward(self, hidden_states, topo): + w1, w2, w3 = ( + self.scale_grad(self.w1), self.scale_grad(self.w2), + self.scale_grad(self.w3) + ) + w1, w2, w3 = ( + resolve_dtensor(w1), resolve_dtensor(w2), resolve_dtensor(w3) + ) + + # Perform the expert computation + hidden_states = stk.Matrix( # type: ignore + topo.size(), + F.silu(stk.ops.sdd(hidden_states, w1.t(), topo).data) + * stk.ops.sdd(hidden_states, w3.t(), topo).data, + topo.row_indices, + topo.column_indices, + topo.offsets, + topo.column_indices_t, + topo.offsets_t, + topo.block_offsets_t, + ) + return stk.ops.dsd(hidden_states, w2) \ No newline at end of file From 8729eae835b2f64c28ff1b64bca33d5e9422feec Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Mon, 19 Aug 2024 06:16:30 +0000 Subject: [PATCH 02/21] add configurations and minor updates Signed-off-by: Yu Chin Fabian Lim --- plugins/accelerated-moe/README.md | 4 ++-- .../framework_plugin_megablocks.py | 5 +++- .../megablocks_utils/config_utils.py | 14 +++++++++-- 
.../megablocks_utils/shard_moe_utils.py | 23 ++++++++++++++----- .../megablocks_utils/sparse_mlp2.py | 9 +++----- .../src/fms_acceleration/constants.py | 2 +- sample-configurations/CONTENTS.yaml | 7 +++++- .../moe-megablocks-sample-configuration.yaml | 14 +++++++++++ scripts/benchmarks/scenarios.yaml | 18 ++++----------- scripts/generate_sample_configurations.py | 3 +++ 10 files changed, 67 insertions(+), 32 deletions(-) create mode 100644 sample-configurations/moe-megablocks-sample-configuration.yaml diff --git a/plugins/accelerated-moe/README.md b/plugins/accelerated-moe/README.md index 8ad79f9a..b093931f 100644 --- a/plugins/accelerated-moe/README.md +++ b/plugins/accelerated-moe/README.md @@ -1,4 +1,4 @@ -# FMS Accelerattion for Mixture-of-Experts +# FMS Acceleration for Mixture-of-Experts This library contains plugins to accelerate finetuning with the following optimizations: 1. Expert-Parallel MoE with Megablocks @@ -9,5 +9,5 @@ This library contains plugins to accelerate finetuning with the following optimi Currently databricks megablocks does not have a PyPi repository and does not have a proper release, so we have to install from the github repository as below. Please note that installing from github will require CUDA Toolkit to build. ``` -pip install git+https://github.com/databricks/megablocks.git +pip install git+https://github.com/databricks/megablocks.git@bce5d7b2aaf5038bc93b36f76c2baf51c2939bd2 ``` \ No newline at end of file diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py index aa9dda80..83202dc9 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py @@ -44,7 +44,7 @@ def requires_custom_loading(self): def model_loader(self, model_name: str, **kwargs): # guarded from .megablocks_utils.config_utils import update_mlp_registry - from megablocks_utils.shard_moe_utils import shard_moe, get_moe_kwargs + from .megablocks_utils.shard_moe_utils import shard_moe, get_moe_kwargs # this one does a forward patching on MLP, but needs to be fixed # properly as the load balancing loss is currently not properly @@ -86,8 +86,11 @@ def model_loader(self, model_name: str, **kwargs): fp16=torch_dtype == torch.float16, bf16=torch_dtype == torch.bfloat16, ), + shared_mesh_dim=True, # FIXME: this can be passed in? 
) + return model + def get_callbacks_and_ready_for_train( self, model: torch.nn.Module = None, accelerator=None ): diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/config_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/config_utils.py index 9ef19eb6..e72eda4e 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/config_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/config_utils.py @@ -9,8 +9,12 @@ from megablocks.layers.moe import ParallelMLP import torch +# this function ensures that the megablocks packaged is configured to use +# the correct SparseMLP implementation +# - at the moment not considering the GroupedMLP implementations def update_mlp_registry(): - # patch the registry to point to our v2 + + # replace the registry to point to the _REGISTRY['mlp']['sparse'] = SparseMLPv2 def forward(self, x, scores, expert_weights, top_experts): @@ -32,6 +36,12 @@ def forward(self, x, scores, expert_weights, top_experts): # we return None as the placeholder. return x, None - # patch the forward function + # patch the forward function. Willing to do this because ParallelMLP + # is only used here and not anywhere else, hence: + # 1. we do not care about reversing the patch + # 2. we have control on where this is called, and we know to call it + # before our code accesses this function. Hence, we view this as + # a hardcoded modification to the megablocks package more than a + # patch. ParallelMLP.forward = forward \ No newline at end of file diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py index f845c80f..cbf5662a 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py @@ -178,17 +178,20 @@ def shard_moe( device_type: str = 'cuda', key_dp: str = KEY_DATA_PARALLEL, key_ep: str = KEY_EXPERT_PARALLEL, + shared_mesh_dim: bool = True, ): # guarded import from megablocks.layers import dmoe, arguments, mpu - assert ep_size > 1, "this function is used for sharding moe" + assert ep_size > 1, "expert_parallel dimension must be set larger than 1" + assert world_size % ep_size == 0, ( + f"world_size ({world_size}) not divisible by ep_size ({ep_size})." + ) # this function will shard the MOE on this rank device = torch.device(f'cuda:{rank}') - dp_size = world_size // ep_size - if dp_size == 1: + if shared_mesh_dim: # in this case we will have a 1D mesh and collapse the # expert parallel with data_parallel @@ -200,12 +203,18 @@ def shard_moe( key_ep = key_dp placements: List[Placement] = [Shard(DIM_EXPERT)] else: - # in this case it will be a 2D mesh + # in this case it will distribute experts on a different + # mesh dimension than dp. 
+ # - this will achieve the effect that the expert sharding can be + # hierachical (e.g., can be over a slower network plane since + # the communication overhead is less + dp_size = world_size // ep_size device_mesh = init_device_mesh( device_type, (dp_size, ep_size), mesh_dim_names=(key_dp, key_ep), ) + # - experts will replicate over the first dimension placements: List[Placement] = [Replicate(), Shard(DIM_EXPERT)] mp_dmoe_args = arguments.Arguments( @@ -213,8 +222,10 @@ def shard_moe( expert_parallel_group=device_mesh[key_ep].get_group(0) ) - assert mp_dmoe_args.moe_num_experts % world_size == 0, \ - "number of moe experts not divisible by world_size" + assert mp_dmoe_args.moe_num_experts % ep_size == 0, ( + f"number of moe experts ({mp_dmoe_args.moe_num_experts}) " + f"not divisible by ep_size ({ep_size})." + ) # for all the MoE related params, e.g., gate, experts # get a dictc diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/sparse_mlp2.py b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/sparse_mlp2.py index ac36245e..1b8fa128 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/sparse_mlp2.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/sparse_mlp2.py @@ -14,10 +14,8 @@ class SparseMLPv2(torch.nn.Module): def __init__(self, args : Arguments): super().__init__() self.args = args - self._num_rows_per_rank = ( - (mpu.experts_per_rank(args) * mpu.features_per_rank(args)) // - mpu.get_weight_parallel_world_size(args) - ) + self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args) + self.w1 = torch.nn.Parameter(torch.empty( self._num_rows_per_rank, @@ -46,8 +44,7 @@ def __init__(self, args : Arguments): args, args.moe_num_experts, args.ffn_hidden_size, args.hidden_size, args.output_layer_init_method)) - self._should_set_parallelism_attribute = ( - args.moe_expert_model_parallelism or args.moe_weight_parallelism) + self._should_set_parallelism_attribute = args.moe_expert_model_parallelism mpu.set_expert_model_parallel_attributes( self.w1, self._should_set_parallelism_attribute) mpu.set_expert_model_parallel_attributes( diff --git a/plugins/framework/src/fms_acceleration/constants.py b/plugins/framework/src/fms_acceleration/constants.py index 6a81d977..d568ec13 100644 --- a/plugins/framework/src/fms_acceleration/constants.py +++ b/plugins/framework/src/fms_acceleration/constants.py @@ -21,4 +21,4 @@ # and activated. 
# - hence the plugins that have model loaders should be on top of this list -PLUGINS = ["peft", "foak", "aadp"] +PLUGINS = ["peft", "foak", "aadp", "moe"] diff --git a/sample-configurations/CONTENTS.yaml b/sample-configurations/CONTENTS.yaml index 09301193..5b7c1e65 100644 --- a/sample-configurations/CONTENTS.yaml +++ b/sample-configurations/CONTENTS.yaml @@ -67,4 +67,9 @@ framework_configs: - accelerated-peft - attention-and-distributed-packing - fused-ops-and-kernels - filename: accelerated-peft-autogptq-foak-padding-free-sample-configuration.yaml \ No newline at end of file + filename: accelerated-peft-autogptq-foak-padding-free-sample-configuration.yaml + + - shortname: moe-megablocks + plugins: + - accelerated-moe + filename: moe-megablocks-sample-configuration.yaml diff --git a/sample-configurations/moe-megablocks-sample-configuration.yaml b/sample-configurations/moe-megablocks-sample-configuration.yaml new file mode 100644 index 00000000..86ba08a0 --- /dev/null +++ b/sample-configurations/moe-megablocks-sample-configuration.yaml @@ -0,0 +1,14 @@ +# FMS Acceleration Plugin Configuration. +# +# Each stanza incorporates various configurations for +# different fine-tuning / training tasks. +plugins: + training: + + # mixture-of-experts configurations + moe: + + # expert-parallel for MoE + megablocks: + + dummy: 1 diff --git a/scripts/benchmarks/scenarios.yaml b/scripts/benchmarks/scenarios.yaml index 2eb22872..b3a9c300 100644 --- a/scripts/benchmarks/scenarios.yaml +++ b/scripts/benchmarks/scenarios.yaml @@ -94,20 +94,12 @@ scenarios: - 'mistralai/Mixtral-8x7B-Instruct-v0.1' - 'NousResearch/Llama-2-70b-hf' - - name: accelerated-peft-gptq + - name: accelerated-moe-megablocks framework_config: - - accelerated-peft-autogptq - - accelerated-peft-autogptq-foak + - moe-megablocks arguments: learning_rate: 2e-4 - fp16: True - torch_dtype: float16 - peft_method: lora - r: 16 - lora_alpha: 16 - lora_dropout: 0.1 - target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"] + bf16: True + torch_dtype: bfloat16 model_name_or_path: - - 'TheBloke/Mistral-7B-v0.1-GPTQ' - - 'TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ' - - 'TheBloke/Llama-2-70B-GPTQ' + - 'mistralai/Mixtral-8x7B-Instruct-v0.1' diff --git a/scripts/generate_sample_configurations.py b/scripts/generate_sample_configurations.py index c72c62eb..2a740342 100644 --- a/scripts/generate_sample_configurations.py +++ b/scripts/generate_sample_configurations.py @@ -147,6 +147,7 @@ def read_configuration(path: str) -> Dict: KEY_BNB_NF4_FOAK = "bnb-nf4-foak" KEY_AADP_PADDING_FREE = "aadp-padding-free" KEY_AADP_MULTIPACK = "aadp-multipack" +KEY_MEGABLOCKS = "moe-megablocks" CONFIGURATIONS = { KEY_AUTO_GPTQ: "plugins/accelerated-peft/configs/autogptq.yaml", @@ -171,6 +172,7 @@ def read_configuration(path: str) -> Dict: ), KEY_AADP_PADDING_FREE: "plugins/attention-and-distributed-packing/configs/padding_free.yaml", KEY_AADP_MULTIPACK: "plugins/attention-and-distributed-packing/configs/multipack.yaml", + KEY_MEGABLOCKS: "plugins/accelerated-moe/configs/megablocks.yaml", } # list of (tag, combi) tuples @@ -190,6 +192,7 @@ def read_configuration(path: str) -> Dict: ("accelerated-peft-autogptq-foak-padding-free", (KEY_AADP_PADDING_FREE,KEY_AUTO_GPTQ, KEY_AUTO_GPTQ_FOAK)), ("accelerated-peft-bnb-nf4-foak-padding-free", (KEY_AADP_PADDING_FREE,KEY_BNB_NF4, KEY_BNB_NF4_FOAK)), ("aadp-padding-free-multipack", (KEY_AADP_PADDING_FREE, KEY_AADP_MULTIPACK)), + ("moe-megablocks", (KEY_MEGABLOCKS,)), ] From ff04022a95cf389b2a333390f2f8aa49d5fbb103 Mon Sep 17 00:00:00 2001 
From: Yu Chin Fabian Lim Date: Tue, 20 Aug 2024 07:08:54 +0000 Subject: [PATCH 03/21] address configurations Signed-off-by: Yu Chin Fabian Lim --- plugins/accelerated-moe/README.md | 5 ++++ .../framework_plugin_megablocks.py | 29 ++++++++++++++----- .../megablocks_utils/config_utils.py | 26 ++++++++++++----- .../megablocks_utils/shard_moe_utils.py | 15 +++++++++- .../moe-megablocks-sample-configuration.yaml | 13 ++++++++- 5 files changed, 71 insertions(+), 17 deletions(-) diff --git a/plugins/accelerated-moe/README.md b/plugins/accelerated-moe/README.md index b093931f..04cdc5a0 100644 --- a/plugins/accelerated-moe/README.md +++ b/plugins/accelerated-moe/README.md @@ -3,6 +3,11 @@ This library contains plugins to accelerate finetuning with the following optimizations: 1. Expert-Parallel MoE with Megablocks +## Known Issues with Megablocks + +Known Issues +- Currently we do not pass the data parallel `dp_mesh` to the `FSDP` constructor, so `FSDP` will always shard over the default process group (over world_size). + ## Megablocks Dependencies diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py index 83202dc9..ace0bdb6 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py @@ -13,14 +13,13 @@ # limitations under the License. # Standard -from typing import Dict, Tuple +from typing import Dict import torch import warnings # Third Party from fms_acceleration import AccelerationPlugin -from peft import LoraConfig -from transformers import TrainingArguments, AutoModelForCausalLM +from transformers import AutoModelForCausalLM class MegablocksMoEAccelerationPlugin(AccelerationPlugin): @@ -32,11 +31,21 @@ def __init__(self, configurations: Dict[str, Dict]): super().__init__(configurations) # args - self._dummy = self._check_config_and_maybe_check_values( - key="training.moe.megablocks", - values=["dummy"], + self._shard_along_dp = self._check_config_and_maybe_check_values( + key="training.moe.megablocks.shard_along_dp", + values=[True, False], + default=True, ) + # ep_size determines the expert parallel sharding + # - ep_size is ignored if _shard_along_dp=True + self._ep_size = None + if not self._shard_along_dp: + self._ep_size = self._check_config_and_maybe_check_values( + key="training.moe.megablocks.ep_size", + default=1, + ) + @property def requires_custom_loading(self): return True @@ -79,15 +88,19 @@ def model_loader(self, model_name: str, **kwargs): checkpoint_name_or_path=model_name, rank=rank, world_size=world_size, - ep_size=world_size, # FIXME: this can be passed in? + ep_size=self._ep_size, moe_kwargs=get_moe_kwargs( model.config, has_bias=False, # FIXME: is this true in general? fp16=torch_dtype == torch.float16, bf16=torch_dtype == torch.bfloat16, ), - shared_mesh_dim=True, # FIXME: this can be passed in? + shared_mesh_dim=self._shard_along_dp, ) + # NOTE: Currently, it is a bit troublesome to pass the device_mesh to + # the FSDP constructor, so we do not do that. 
+ # - therefore FSDP will always shard on world_size over the default process + # group return model diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/config_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/config_utils.py index e72eda4e..5ff93b69 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/config_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/config_utils.py @@ -4,18 +4,30 @@ from megablocks.layers.dmlp_registry import _REGISTRY -# from megablocks.layers import mlp +from megablocks.layers.mlp import SparseMLP from .sparse_mlp2 import SparseMLPv2 from megablocks.layers.moe import ParallelMLP -import torch + +SPARSE_MLP_IMPL = { + "v1": SparseMLP, + "v2": SparseMLPv2, +} # this function ensures that the megablocks packaged is configured to use # the correct SparseMLP implementation # - at the moment not considering the GroupedMLP implementations -def update_mlp_registry(): - - # replace the registry to point to the - _REGISTRY['mlp']['sparse'] = SparseMLPv2 +def update_mlp_registry( + mlp_type: str = 'sparse', + mlp_version: str = 'v2', +): + + # replace the registry to point to the the correct sparse implementation + if mlp_type == 'sparse': + assert mlp_version in SPARSE_MLP_IMPL, \ + f"Megablocks only support sparse mlp versions: {','.join(SPARSE_MLP_IMPL.keys())}" + _REGISTRY['mlp']['sparse'] = SPARSE_MLP_IMPL[mlp_version] + else: + raise NotImplementedError("Currently only supports sparse MLP implementations.") def forward(self, x, scores, expert_weights, top_experts): in_shape = x.size() @@ -36,7 +48,7 @@ def forward(self, x, scores, expert_weights, top_experts): # we return None as the placeholder. return x, None - # patch the forward function. Willing to do this because ParallelMLP + # replace the forward function. Willing to do this because ParallelMLP # is only used here and not anywhere else, hence: # 1. we do not care about reversing the patch # 2. we have control on where this is called, and we know to call it diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py index cbf5662a..5c4a2835 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py @@ -173,16 +173,29 @@ def shard_moe( checkpoint_name_or_path: str, rank: int, world_size: int, - ep_size: int, moe_kwargs: Dict, device_type: str = 'cuda', key_dp: str = KEY_DATA_PARALLEL, key_ep: str = KEY_EXPERT_PARALLEL, shared_mesh_dim: bool = True, + ep_size: int = 1, ): # guarded import from megablocks.layers import dmoe, arguments, mpu + if shared_mesh_dim: + # if sharing mesh with dp, then the ep_size must be the world_size + # - in this case ep_shard_factor is ignored + ep_size = world_size + else: + + # - moe_kwargs is the constructed by get_moe_kwargs above + _num_experts = moe_kwargs['moe_num_experts'] + assert _num_experts % ep_size == 0, ( + f"ep_shard factor '{ep_size}' does not divide " + f"number of experts '{_num_experts}'." + ) + assert ep_size > 1, "expert_parallel dimension must be set larger than 1" assert world_size % ep_size == 0, ( f"world_size ({world_size}) not divisible by ep_size ({ep_size})." 
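The two sharding modes configured above (`shard_along_dp: true` versus an explicit `ep_size`) are easier to follow as a condensed, self-contained sketch of the device-mesh and placement selection. The sketch below mirrors the patched `shard_moe` logic rather than calling it, and assumes `torch.distributed` has already been initialized (e.g. under `torchrun`); the function and mesh-dimension names are illustrative.

```python
# Condensed sketch of the two expert-sharding modes (assumes an initialized
# torch.distributed environment, e.g. launched via torchrun).
from typing import List, Tuple

from torch.distributed._tensor import Placement, Replicate, Shard
from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh


def build_expert_mesh(
    world_size: int, ep_size: int, shared_mesh_dim: bool
) -> Tuple[DeviceMesh, List[Placement]]:
    if shared_mesh_dim:
        # experts are sharded along the same (1D) mesh used for data parallel,
        # which requires world_size to divide the number of experts
        mesh = init_device_mesh(
            "cuda", (world_size,), mesh_dim_names=("data_parallel",)
        )
        return mesh, [Shard(0)]  # shard experts on the expert dimension (dim 0)

    # otherwise experts get their own mesh dimension next to data parallel
    assert ep_size > 1 and world_size % ep_size == 0
    mesh = init_device_mesh(
        "cuda",
        (world_size // ep_size, ep_size),
        mesh_dim_names=("data_parallel", "expert_parallel"),
    )
    # replicate over data parallel, shard experts over expert parallel
    return mesh, [Replicate(), Shard(0)]
```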
diff --git a/sample-configurations/moe-megablocks-sample-configuration.yaml b/sample-configurations/moe-megablocks-sample-configuration.yaml index 86ba08a0..c3c882e8 100644 --- a/sample-configurations/moe-megablocks-sample-configuration.yaml +++ b/sample-configurations/moe-megablocks-sample-configuration.yaml @@ -10,5 +10,16 @@ plugins: # expert-parallel for MoE megablocks: + + # if True, then we shard experts across data parallel dimension + # - only feasible if world_size divides the number of experts + shard_along_dp: True + + # to be specified only if shard_along_dp == False. This will influence + # the level of sharding, which indicates how many experts per device + # - the number of experts per device will be num_experts / ep_size + # - we disable the ability to set ep_size=1 since this means no sharding + # - NOTE: ep_size=1 does not mean shard_along_dp=True, which would otherwise + # be contradictory since ep_size suggests no expert parallel. + # ep_size: 2 - dummy: 1 From 03db9f577424c9a1c4f318dc64041640d22fef3c Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Tue, 20 Aug 2024 09:54:12 +0000 Subject: [PATCH 04/21] first attempt at abstract out model and add comments Signed-off-by: Yu Chin Fabian Lim --- plugins/accelerated-moe/README.md | 3 +- .../accelerated-moe/configs/megablocks.yaml | 14 +- .../framework_plugin_megablocks.py | 24 ++- .../megablocks_utils/shard_moe_utils.py | 151 +++++++++++++----- .../moe-megablocks-sample-configuration.yaml | 4 +- 5 files changed, 145 insertions(+), 51 deletions(-) diff --git a/plugins/accelerated-moe/README.md b/plugins/accelerated-moe/README.md index 04cdc5a0..fe4627c9 100644 --- a/plugins/accelerated-moe/README.md +++ b/plugins/accelerated-moe/README.md @@ -3,10 +3,11 @@ This library contains plugins to accelerate finetuning with the following optimizations: 1. Expert-Parallel MoE with Megablocks -## Known Issues with Megablocks +## Known Issues with Megablocks Plugin Implementation Known Issues - Currently we do not pass the data parallel `dp_mesh` to the `FSDP` constructor, so `FSDP` will always shard over the default process group (over world_size). +- Currently only supports loading `safetensor` MoE checkpoints. ## Megablocks Dependencies diff --git a/plugins/accelerated-moe/configs/megablocks.yaml b/plugins/accelerated-moe/configs/megablocks.yaml index 14465d96..5ca459bb 100644 --- a/plugins/accelerated-moe/configs/megablocks.yaml +++ b/plugins/accelerated-moe/configs/megablocks.yaml @@ -1,9 +1,19 @@ training: # mixture-of-experts configurations - moe: + moe: # expert-parallel for MoE megablocks: + + # if True, then we shard experts across data parallel dimension + # - only feasible if world_size divides the number of experts + shard_along_dp: true - dummy: 1 + # to be specified only if shard_along_dp == False. This will influence + # the level of sharding, which indicates how many experts per device + # - the number of experts per device will be num_experts / ep_size + # - we disable the ability to set ep_size=1 since this means no sharding + # - NOTE: ep_size=1 does not mean shard_along_dp=True, which would otherwise + # be contradictory since ep_size suggests no expert parallel. 
+ # ep_size: 2 diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py index ace0bdb6..8d9fccb1 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py @@ -21,6 +21,17 @@ from fms_acceleration import AccelerationPlugin from transformers import AutoModelForCausalLM +from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock + +# Different Models +# - MoE Class +# - has_bias +# - gate module name +MODEL_MEGABLOCKS = { + "MixtralForCausalLM": ( + MixtralSparseMoeBlock, False, "gate", "experts" + ) +} class MegablocksMoEAccelerationPlugin(AccelerationPlugin): @@ -78,24 +89,29 @@ def model_loader(self, model_name: str, **kwargs): "Megablocks expert parallel only works for distributed training." ) - # FIXME: have some way to search out the MOE block - from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock + # - get model specific items + ( + moe_cls, has_bias, router_name, expert_name, + ) = MODEL_MEGABLOCKS[model.__class__.__name__] # FIXME: the dtype checks below are too brittle dp_mesh = shard_moe( model, - MixtralSparseMoeBlock, + moe_cls, checkpoint_name_or_path=model_name, rank=rank, world_size=world_size, ep_size=self._ep_size, moe_kwargs=get_moe_kwargs( model.config, - has_bias=False, # FIXME: is this true in general? + has_bias=has_bias, fp16=torch_dtype == torch.float16, bf16=torch_dtype == torch.bfloat16, ), shared_mesh_dim=self._shard_along_dp, + router_name=router_name, + expert_name=expert_name, + ) # NOTE: Currently, it is a bit troublesome to pass the device_mesh to # the FSDP constructor, so we do not do that. 
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py index 5c4a2835..6561d8f2 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py @@ -1,9 +1,9 @@ import torch -from transformers import TrainingArguments, PretrainedConfig -from typing import Union, Dict, List, Type +from transformers import PretrainedConfig +from typing import Tuple, Dict, List, Type from torch.distributed._tensor import Placement, Replicate, Shard, distribute_tensor -from torch.distributed._tensor.device_mesh import init_device_mesh +from torch.distributed._tensor.device_mesh import init_device_mesh, DeviceMesh import os from tqdm import tqdm @@ -20,8 +20,9 @@ KEY_EXPERT_PARALLEL = 'expert_parallel' DIM_EXPERT = 0 -KEY_ROUTER = 'router.layer.weight' -KEY_EXPERTS = 'experts.mlp' +# these depend on the namings in the dMOE +KEY_DMOE_ROUTER = 'router.layer.weight' +KEY_DMOE_EXPERTS = 'experts.mlp' def get_moe_kwargs( config: PretrainedConfig, @@ -59,18 +60,44 @@ def _dict_from_json_file(resolved_config_file): PretrainedConfig._dict_from_json_file = _old_func return os.path.dirname(result) -# see https://github.com/mosaicml/llm-foundry/blob/main/tests/models/layers/test_dmoe.py -# for a basic example -# this one is called for one layer -# e.g., 'model.layers.0, block_sparse_moe -def get_router_experts_sharded_safetensor( +# This function creates a dictionary of keys and paths into the the sharded +# safetensors checkpoint file, that are relevant to the "prefix" and "instance_name" +# being pased in. +# - the keys point to modules found in megablocks.layers.dmoe.dMoE, the distributed +# expert module provided by megablocks. +# - the values are tuples pointing to the keys within the checkpoint file. +# +# Example: if prefix="module.layers.0" and instance_name="block_sparse_moe", then a dictionary +# of the following will be returned: +# { +# 'experts.mlp.w1': [ +# ( +# 'model.layers.0.block_sparse_moe.experts.0.w1.weight', +# 'model-00001-of-00019.safetensors' +# ), +# ( +# 'model.layers.0.block_sparse_moe.experts.1.w1.weight', +# 'model-00001-of-00019.safetensors' +# ), +# ... 
+# ] +# 'experts.mlp.w2': [...], +# 'experts.mlp.w3': [...], +# 'router.layer.weight': [ +# ( +# 'model.layers.0.block_sparse_moe.gate.weight', +# 'model-00001-of-00019.safetensors' +# ) +# ] +# } +def get_checkpoint_meta_from_sharded_safetensor( weight_map: Dict, prefix: str, # e.g., 'model.layers.0, instance_name: str, # e.g., block_sparse_moe - router_name: str = 'gate', - expert_name: str = 'experts' -): + router_name: str = 'gate', # e.g., named "gate" within block_sparse_moe + expert_name: str = 'experts' # e.g., named "experts" within block_sparse_moe +) -> Dict[str, List[Tuple]]: # insert in order def _insert(L: List, i: int, v): n = len(L) @@ -107,12 +134,11 @@ def _insert(L: List, i: int, v): f"'{router_name}' or expert_name '{expert_name}'" ) if m.group(1) == router_name: - _map[KEY_ROUTER].append((k, stfile)) + _map[KEY_DMOE_ROUTER].append((k, stfile)) elif m.group(1) == expert_name: index = int(m.group(2)) mod = m.group(3) - # expert_map[stfile].append((mod, index, k)) - _insert(_map[f'{KEY_EXPERTS}.{mod}'], index, (k, stfile)) + _insert(_map[f'{KEY_DMOE_EXPERTS}.{mod}'], index, (k, stfile)) if len(_map) == 0: raise ValueError( @@ -121,47 +147,53 @@ def _insert(L: List, i: int, v): return _map -# for megablocks.SparseMLPv2 -# assign dmoe with mlp_v2 -# settings is: -# experts.mlp.w1: [(k, file)] -def assign_mlp_v2_weights( +# this function will load the sharded experts onto the device. +# - this assumes that the "dmoe" module is the megablocks.layers.dmoe.dMoE distributed +# implementation of the mixture of experts. +def load_sharded_experts_onto_device( dmoe: torch.nn.Module, directory: str, - settings: Dict, - device_mesh, - placements, + checkpoint_metadata: Dict[str, List[Tuple]], + device_mesh: DeviceMesh, + placements: Placement, + expert_name: str = 'experts' # e.g., named "experts" within block_sparse_moe ): - # typically they all should be same file + # typically they all should be same file, but to play safe, load the checkpoint file onto + # cpu first since we may not need all weights in that file. with ExitStack() as stack: files = {} - for _, vs in settings.items(): + for _, vs in checkpoint_metadata.items(): for _, fi in vs: if fi not in files: files[fi] = stack.enter_context( safe_open(os.path.join(directory, fi), framework='pt', device='cpu') ) - # go by one weight - for weight_name, vs in settings.items(): + # go by one weight at a time. + # - weight_name: points to megablocks.dmoe + for weight_name, vs in checkpoint_metadata.items(): data = [] for k, fi in vs: T = files[fi].get_tensor(k) - if 'experts' in k: + if expert_name in k: if T.shape[1] > T.shape[0]: T = T.t() data.append(T) - # concat on dim 0 and distribute + # the megablocks dmoe experts the expert features to be on DIM_EXPERT. 
+ # - concat on dim 0 and distribute param = torch.concat(data, dim=DIM_EXPERT) - if KEY_ROUTER not in weight_name: + if KEY_DMOE_ROUTER not in weight_name: param = torch.nn.Parameter( distribute_tensor(param, device_mesh, placements) ) else: + # - do not shard the router but load onto device as well param = torch.nn.Parameter( param.to(torch.cuda.current_device()) ) + + # register the sharded parameter onto the megablocks.dmoe name = weight_name.split('.') path, name = ".".join(name[:-1]), name[-1] mod = dmoe.get_submodule(path) @@ -169,7 +201,7 @@ def assign_mlp_v2_weights( def shard_moe( model: torch.nn.Module, - moe_cls: Union[str, Type], + moe_cls: Type, checkpoint_name_or_path: str, rank: int, world_size: int, @@ -177,11 +209,47 @@ def shard_moe( device_type: str = 'cuda', key_dp: str = KEY_DATA_PARALLEL, key_ep: str = KEY_EXPERT_PARALLEL, + router_name: str = 'gate', + expert_name: str = 'experts', shared_mesh_dim: bool = True, ep_size: int = 1, ): + """shard_moe takes a mixture-of-experts huggingface model and shards the experts + on the current device. All layers layers that have a MoE module will be sharded. + + The function requires "checkpoint_name_or_path" to point to the checkpoint that + the model has been loaded from, because model could have been loaded on the meta + device, and in which case would be missing the weights. This function will + instialize the sharded weights onto the device. + + The sharding has two modes, and depends on world_size and number_of_experts the model + has. This depends on the setting "shared_mesh_dim" to True or False: + - if True: then dp and ep will happen on the same device_mesh dimension. This is only possible + if world_size divides number_of_experts (which requires world_size < num_of_experts). + - if False: then dp and ep will be seperate device_mesh dimensions. The ep_size will be determined + by the argument passed in (which needs to be properly set ep_size > 1; the default + value will raise an assertion). + + Parameters: + + model (module): A valid mixture-of-experts Huggingface model. + moe_cls (type): A module class used to identify the MoE components. + checkpoint_name_or_path (str): name or path pointing to the weight checkpoint. + rank (int): rank of the current device. + world_size (int): total world size. + moe_kwargs (dict): kwargs to be passed to construct megablocks.layers.arguments for + constructing the megablocks.layer.dmoe.dMOE. + device_type (str): the current device to load the sharded model into. + key_dp (str): name of the data parallel mesh + key_ep (str): name of the expert parallel mesh (if initialized). + router_name (str): module name of the router in moe_cls (e.g., "gate"). + expert_name (str): module name of the experts in moe_cls (e.g., "experts"). + shared_mesh_dim (bool): for the sharding mode, see explanation above. + ep_size (int): for shard_mesh_dim=False only, see explanation above. 
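    Example (illustrative only; assumes a Mixtral-style safetensors checkpoint,
    an already-initialized `torch.distributed` process group, and the
    `get_moe_kwargs` helper from this module):

        model = AutoModelForCausalLM.from_pretrained(
            "mistralai/Mixtral-8x7B-Instruct-v0.1", torch_dtype=torch.bfloat16
        )
        shard_moe(
            model,
            MixtralSparseMoeBlock,
            checkpoint_name_or_path="mistralai/Mixtral-8x7B-Instruct-v0.1",
            rank=torch.distributed.get_rank(),
            world_size=torch.distributed.get_world_size(),
            moe_kwargs=get_moe_kwargs(model.config, bf16=True),
            router_name="gate",
            expert_name="experts",
            shared_mesh_dim=True,
        )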
+ + """ # guarded import - from megablocks.layers import dmoe, arguments, mpu + from megablocks.layers import dmoe, arguments if shared_mesh_dim: # if sharing mesh with dp, then the ep_size must be the world_size @@ -266,19 +334,20 @@ def shard_moe( # e.g., prefix: 'model.layers.0', # module_name: 'block_sparse_moe' for prefix, (module_name, relevant_keys) in tqdm( - found.items(), - disable=torch.distributed.get_rank() > 0, - desc='Sharding MoE' + found.items(), disable=(rank > 0), desc='Sharding MoE' ): - settings = get_router_experts_sharded_safetensor( + checkpoint_metadata = get_checkpoint_meta_from_sharded_safetensor( index['weight_map'], prefix, module_name, + router_name, expert_name ) + + # - will replace the MoE module with the megablocks sharded dMoE with init_empty_weights(): mp_dmoe = dmoe.dMoE(mp_dmoe_args) # drop in replacement for now - assign_mlp_v2_weights( - mp_dmoe, loc, settings, - device_mesh, placements + load_sharded_experts_onto_device( + mp_dmoe, loc, checkpoint_metadata, + device_mesh, placements, expert_name ) parent = model.get_submodule(prefix) setattr(parent, module_name, mp_dmoe) @@ -286,7 +355,7 @@ def shard_moe( except ValueError as e: raise ValueError( f"Unable to load checkpoint_path '{checkpoint_name_or_path}'. " - "Currently only support safetensor checkpoints. " + "Currently only support non-GGUF safetensor checkpoints. " f": {e}" ) diff --git a/sample-configurations/moe-megablocks-sample-configuration.yaml b/sample-configurations/moe-megablocks-sample-configuration.yaml index c3c882e8..bb0e234b 100644 --- a/sample-configurations/moe-megablocks-sample-configuration.yaml +++ b/sample-configurations/moe-megablocks-sample-configuration.yaml @@ -10,10 +10,9 @@ plugins: # expert-parallel for MoE megablocks: - # if True, then we shard experts across data parallel dimension # - only feasible if world_size divides the number of experts - shard_along_dp: True + shard_along_dp: true # to be specified only if shard_along_dp == False. This will influence # the level of sharding, which indicates how many experts per device @@ -22,4 +21,3 @@ plugins: # - NOTE: ep_size=1 does not mean shard_along_dp=True, which would otherwise # be contradictory since ep_size suggests no expert parallel. 
# ep_size: 2 - From 2ce539ce3140d68f7c7258579066462596bf8998 Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Wed, 21 Aug 2024 03:52:38 +0000 Subject: [PATCH 05/21] add model configuration Signed-off-by: Yu Chin Fabian Lim --- .../accelerated-moe/configs/megablocks.yaml | 18 +++++ .../framework_plugin_megablocks.py | 70 ++++++++++++------- .../megablocks_utils/shard_moe_utils.py | 58 ++++++++++----- .../megablocks_utils/sparse_mlp2.py | 10 ++- .../moe-megablocks-sample-configuration.yaml | 19 +++++ 5 files changed, 128 insertions(+), 47 deletions(-) diff --git a/plugins/accelerated-moe/configs/megablocks.yaml b/plugins/accelerated-moe/configs/megablocks.yaml index 5ca459bb..e52670d2 100644 --- a/plugins/accelerated-moe/configs/megablocks.yaml +++ b/plugins/accelerated-moe/configs/megablocks.yaml @@ -5,6 +5,20 @@ training: # expert-parallel for MoE megablocks: + + # The name of the mixture-of-experts class + moe_component_class: MixtralSparseMoeBlock + + # The module name of the router in moe_component_class above + moe_gate_module_name: gate + + # The module name of the experts in moe_component_class above + moe_experts_module_name: experts + + # the mlp version + # - for those with only up and down projs, use "v1" + # - for those with only up, down and gate projs, use "v2" + moe_mlp_impl: v2 # if True, then we shard experts across data parallel dimension # - only feasible if world_size divides the number of experts @@ -17,3 +31,7 @@ training: # - NOTE: ep_size=1 does not mean shard_along_dp=True, which would otherwise # be contradictory since ep_size suggests no expert parallel. # ep_size: 2 + + # the MoE dropless implementation. Currently we only support "dropless_sparse", but + # in the future we may support others + moe_implementation: dropless_sparse diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py index 8d9fccb1..eaa4e47b 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py @@ -23,25 +23,39 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock -# Different Models -# - MoE Class -# - has_bias -# - gate module name -MODEL_MEGABLOCKS = { - "MixtralForCausalLM": ( - MixtralSparseMoeBlock, False, "gate", "experts" - ) -} class MegablocksMoEAccelerationPlugin(AccelerationPlugin): - restricted_model_archs = {"MixtralForCausalLM"} require_packages = {"megablocks"} def __init__(self, configurations: Dict[str, Dict]): super().__init__(configurations) - # args + # arguments for configuring the mixture-of-experts model with defaults + # shown below for Mixtral 7x8b + # - 1. component class + self._moe_component_cls = self._check_config_and_maybe_check_values( + key="training.moe.megablocks.moe_component_class", + default="MixtralSparseMoeBlock", + ) + # - 2. gate_module_name + self._gate_module_name = self._check_config_and_maybe_check_values( + key="training.moe.megablocks.moe_gate_module_name", + default="gate" + ) + # - 3. experts_module_name + self._experts_module_name = self._check_config_and_maybe_check_values( + key="training.moe.megablocks.moe_experts_module_name", + default="experts" + ) + # - 4. 
mlp version + self._mlp_version = self._check_config_and_maybe_check_values( + key="training.moe.megablocks.moe_mlp_impl", + values=["v1", "v2"], + default="v2" + ) + + # for controlling the type of sharding self._shard_along_dp = self._check_config_and_maybe_check_values( key="training.moe.megablocks.shard_along_dp", values=[True, False], @@ -57,6 +71,14 @@ def __init__(self, configurations: Dict[str, Dict]): default=1, ) + # for the moe_implementation, currently we only use the megablocks + # dropless sparse implementation + self._shard_along_dp = self._check_config_and_maybe_check_values( + key="training.moe.megablocks.moe_implementation", + values=["dropless_sparse"], + default="dropless_sparse", + ) + @property def requires_custom_loading(self): return True @@ -89,28 +111,25 @@ def model_loader(self, model_name: str, **kwargs): "Megablocks expert parallel only works for distributed training." ) - # - get model specific items + # shard the MOE, and store products required for + # FSDP configuration ( - moe_cls, has_bias, router_name, expert_name, - ) = MODEL_MEGABLOCKS[model.__class__.__name__] - - # FIXME: the dtype checks below are too brittle - dp_mesh = shard_moe( + dp_mesh, self._moe_component_module_names + ) = shard_moe( model, - moe_cls, + self._moe_component_cls, checkpoint_name_or_path=model_name, rank=rank, world_size=world_size, ep_size=self._ep_size, moe_kwargs=get_moe_kwargs( model.config, - has_bias=has_bias, fp16=torch_dtype == torch.float16, bf16=torch_dtype == torch.bfloat16, ), shared_mesh_dim=self._shard_along_dp, - router_name=router_name, - expert_name=expert_name, + router_name=self._gate_module_name, + expert_name=self._experts_module_name, ) # NOTE: Currently, it is a bit troublesome to pass the device_mesh to @@ -129,13 +148,10 @@ def get_callbacks_and_ready_for_train( accelerator is not None and getattr(accelerator.state, "fsdp_plugin", None) is not None ): - # lora_adapters_switch_ddp_from_fsdp( - # [mod for mod in model.modules() if isinstance(mod, LoraLayer)], - # accelerator.state.fsdp_plugin, - # ) - # FIXME: should be accelerator.state.fsdp_plugin.ignored_modules = [ - layer.block_sparse_moe for layer in model.model.layers + getattr(layer, name) + for name in self._moe_component_module_names + for layer in model.model.layers ] return callbacks diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py index 6561d8f2..e5b6b9a0 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py @@ -1,10 +1,12 @@ import torch from transformers import PretrainedConfig -from typing import Tuple, Dict, List, Type +from transformers.activations import ACT2FN +from typing import Tuple, Dict, List, Type, Union from torch.distributed._tensor import Placement, Replicate, Shard, distribute_tensor from torch.distributed._tensor.device_mesh import init_device_mesh, DeviceMesh import os +from copy import copy from tqdm import tqdm from safetensors import safe_open @@ -26,7 +28,6 @@ def get_moe_kwargs( config: PretrainedConfig, - has_bias: bool = False, # if the MOE has bias fp16: bool = False, bf16: bool = False, ): @@ -37,8 +38,9 @@ def get_moe_kwargs( "moe_top_k": config.num_experts_per_tok, "moe_expert_model_parallelism": True, "memory_optimized_mlp": False, - "bias": has_bias, + "activation_fn": ACT2FN[config.hidden_act], 
"moe_normalize_expert_weights": True, + "return_bias": False, "fp16": fp16, "bf16": bf16, } @@ -175,7 +177,7 @@ def load_sharded_experts_onto_device( data = [] for k, fi in vs: T = files[fi].get_tensor(k) - if expert_name in k: + if expert_name in k and k.endswith("weight"): if T.shape[1] > T.shape[0]: T = T.t() data.append(T) @@ -201,7 +203,7 @@ def load_sharded_experts_onto_device( def shard_moe( model: torch.nn.Module, - moe_cls: Type, + moe_cls: Union[str,Type], checkpoint_name_or_path: str, rank: int, world_size: int, @@ -233,7 +235,7 @@ def shard_moe( Parameters: model (module): A valid mixture-of-experts Huggingface model. - moe_cls (type): A module class used to identify the MoE components. + moe_cls (str,type): A module class used to identify the MoE components. checkpoint_name_or_path (str): name or path pointing to the weight checkpoint. rank (int): rank of the current device. world_size (int): total world size. @@ -309,21 +311,37 @@ def shard_moe( ) # for all the MoE related params, e.g., gate, experts - # get a dictc + # get a dictionary # parent_mod: (child_instance_name, [list of fqdn keys]) found = {} for name, mod in model.named_modules(): name = name.split('.') parent, child = ".".join(name[:-1]), name[-1] - if isinstance(mod, moe_cls): - found[parent] = ( - child, - [ # all params, including childs' - f'{parent}.{child}.{n}' - for n, _ in mod.named_parameters() - ] + + # check the module depending if moe_cls is a str or class + if ( + mod.__class__.__name__ == moe_cls if + isinstance(moe_cls, str) else + isinstance(mod, moe_cls) + ): + fqdn_keys = [ # all params, including childs' + f'{parent}.{child}.{n}' + for n, _ in mod.named_parameters() + ] + + # check if there are any biases in any of the experts + # if there are biases + # Assumption: assume that if one expert has bias,then the others + # will have it to + has_bias = any( + expert_name in k and k.endswith('bias') + for k in fqdn_keys ) + found[parent] = (child, fqdn_keys, has_bias) + + moe_module_names = set() + # NOTE: for now we only support sharded safetensors # - most MOE models should be used using this checkpoint format try: @@ -333,7 +351,7 @@ def shard_moe( # e.g., prefix: 'model.layers.0', # module_name: 'block_sparse_moe' - for prefix, (module_name, relevant_keys) in tqdm( + for prefix, (module_name, relevant_keys, has_bias) in tqdm( found.items(), disable=(rank > 0), desc='Sharding MoE' ): checkpoint_metadata = get_checkpoint_meta_from_sharded_safetensor( @@ -341,9 +359,12 @@ def shard_moe( router_name, expert_name ) + _args = copy(mp_dmoe_args) + _args.bias = has_bias + # - will replace the MoE module with the megablocks sharded dMoE with init_empty_weights(): - mp_dmoe = dmoe.dMoE(mp_dmoe_args) # drop in replacement for now + mp_dmoe = dmoe.dMoE(_args) # drop in replacement for now load_sharded_experts_onto_device( mp_dmoe, loc, checkpoint_metadata, @@ -352,6 +373,9 @@ def shard_moe( parent = model.get_submodule(prefix) setattr(parent, module_name, mp_dmoe) + # - keep track of the name for returning + moe_module_names.add(module_name) + except ValueError as e: raise ValueError( f"Unable to load checkpoint_path '{checkpoint_name_or_path}'. 
" @@ -360,4 +384,4 @@ def shard_moe( ) - return device_mesh[key_dp] \ No newline at end of file + return device_mesh[key_dp], moe_module_names \ No newline at end of file diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/sparse_mlp2.py b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/sparse_mlp2.py index 1b8fa128..77b2e42c 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/sparse_mlp2.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/sparse_mlp2.py @@ -2,13 +2,15 @@ from megablocks.layers import common from megablocks.layers.arguments import Arguments from megablocks.layers import mpu +from megablocks.layers.activation_fn import act_fn from megablocks.layers.mlp import ( create_dmoe_expert_weights, scale_gradient, resolve_dtensor ) -import torch.nn.functional as F import stk +# This is the different MLP class used for models that have up_proj, down_proj +# and gate_proj like Mixtral class SparseMLPv2(torch.nn.Module): def __init__(self, args : Arguments): @@ -73,8 +75,10 @@ def forward(self, hidden_states, topo): # Perform the expert computation hidden_states = stk.Matrix( # type: ignore topo.size(), - F.silu(stk.ops.sdd(hidden_states, w1.t(), topo).data) - * stk.ops.sdd(hidden_states, w3.t(), topo).data, + act_fn( + stk.ops.sdd(hidden_states, w1.t(), topo), + self.args.activation_fn + ).data * stk.ops.sdd(hidden_states, w3.t(), topo).data, topo.row_indices, topo.column_indices, topo.offsets, diff --git a/sample-configurations/moe-megablocks-sample-configuration.yaml b/sample-configurations/moe-megablocks-sample-configuration.yaml index bb0e234b..12815d98 100644 --- a/sample-configurations/moe-megablocks-sample-configuration.yaml +++ b/sample-configurations/moe-megablocks-sample-configuration.yaml @@ -10,6 +10,21 @@ plugins: # expert-parallel for MoE megablocks: + + # The name of the mixture-of-experts class + moe_component_class: MixtralSparseMoeBlock + + # The module name of the router in moe_component_class above + moe_gate_module_name: gate + + # The module name of the experts in moe_component_class above + moe_experts_module_name: experts + + # the mlp version + # - for those with only up and down projs, use "v1" + # - for those with only up, down and gate projs, use "v2" + moe_mlp_impl: v2 + # if True, then we shard experts across data parallel dimension # - only feasible if world_size divides the number of experts shard_along_dp: true @@ -21,3 +36,7 @@ plugins: # - NOTE: ep_size=1 does not mean shard_along_dp=True, which would otherwise # be contradictory since ep_size suggests no expert parallel. # ep_size: 2 + + # the MoE dropless implementation. 
Currently we only support "dropless_sparse", but + # in the future we may support others + moe_implementation: dropless_sparse From 135bf5e292c7c10d1b267ef5714427af5cd38816 Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Wed, 21 Aug 2024 07:21:10 +0000 Subject: [PATCH 06/21] fixes to bias and enabled load balancing loss Signed-off-by: Yu Chin Fabian Lim --- .../accelerated-moe/configs/megablocks.yaml | 3 ++ .../framework_plugin_megablocks.py | 36 ++++++++++++++++--- .../megablocks_utils/config_utils.py | 11 +++--- .../moe-megablocks-sample-configuration.yaml | 3 ++ 4 files changed, 44 insertions(+), 9 deletions(-) diff --git a/plugins/accelerated-moe/configs/megablocks.yaml b/plugins/accelerated-moe/configs/megablocks.yaml index e52670d2..cd63a21c 100644 --- a/plugins/accelerated-moe/configs/megablocks.yaml +++ b/plugins/accelerated-moe/configs/megablocks.yaml @@ -35,3 +35,6 @@ training: # the MoE dropless implementation. Currently we only support "dropless_sparse", but # in the future we may support others moe_implementation: dropless_sparse + + # for load_balancing_loss + load_balancing_loss: false diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py index eaa4e47b..56cdc3de 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py @@ -19,9 +19,7 @@ # Third Party from fms_acceleration import AccelerationPlugin -from transformers import AutoModelForCausalLM - -from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock +from transformers import AutoConfig, AutoModelForCausalLM class MegablocksMoEAccelerationPlugin(AccelerationPlugin): @@ -73,11 +71,18 @@ def __init__(self, configurations: Dict[str, Dict]): # for the moe_implementation, currently we only use the megablocks # dropless sparse implementation - self._shard_along_dp = self._check_config_and_maybe_check_values( + self._moe_implementation = self._check_config_and_maybe_check_values( key="training.moe.megablocks.moe_implementation", values=["dropless_sparse"], default="dropless_sparse", ) + self._moe_implementation = self._moe_implementation.split("_")[1] + + self._load_balancing_loss = self._check_config_and_maybe_check_values( + key="training.moe.megablocks.load_balancing_loss", + values=[True, False], + default=False, + ) @property def requires_custom_loading(self): @@ -88,10 +93,27 @@ def model_loader(self, model_name: str, **kwargs): from .megablocks_utils.config_utils import update_mlp_registry from .megablocks_utils.shard_moe_utils import shard_moe, get_moe_kwargs + # - check the config + if self._load_balancing_loss and not hasattr( + AutoConfig.from_pretrained(model_name), + "output_router_logits" + ): + warnings.warn( + "load_balancing_loss=True but " + "the model '{model_name}' config not have 'output_router_logits' " + "in its config, hence it might not support load balancing and " + "fallback to load_balancing_loss=False." 
+ ) + self._load_balancing_loss = False + # this one does a forward patching on MLP, but needs to be fixed # properly as the load balancing loss is currently not properly # handled - update_mlp_registry() + update_mlp_registry( + self._moe_implementation, + self._mlp_version, + self._load_balancing_loss + ) # get additional parameters torch_dtype = kwargs.get("torch_dtype", torch.float32) @@ -101,6 +123,10 @@ def model_loader(self, model_name: str, **kwargs): model_name, **kwargs ) + # set this in the config, which will be picked up by the forward + # function to go into the load_balancing loss + model.config.output_router_logits = self._load_balancing_loss + rank, world_size = 0, 1 if torch.distributed.is_initialized(): world_size = torch.distributed.get_world_size() diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/config_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/config_utils.py index 5ff93b69..72cae3ad 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/config_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/config_utils.py @@ -7,6 +7,7 @@ from megablocks.layers.mlp import SparseMLP from .sparse_mlp2 import SparseMLPv2 from megablocks.layers.moe import ParallelMLP +import torch SPARSE_MLP_IMPL = { "v1": SparseMLP, @@ -19,6 +20,7 @@ def update_mlp_registry( mlp_type: str = 'sparse', mlp_version: str = 'v2', + load_balancing_loss: bool = False, ): # replace the registry to point to the the correct sparse implementation @@ -42,10 +44,11 @@ def forward(self, x, scores, expert_weights, top_experts): return x + self.bias # in this case we should be returning the router - # logits out of the MoeE forward. However, since - # the way the code is written now, it si difficult - # to extract these logits out, so at the moment, - # we return None as the placeholder. + # logits out of the MoE forward. + if load_balancing_loss: + return x, torch.log(scores) + + # otherwise just return None return x, None # replace the forward function. Willing to do this because ParallelMLP diff --git a/sample-configurations/moe-megablocks-sample-configuration.yaml b/sample-configurations/moe-megablocks-sample-configuration.yaml index 12815d98..b1ee14f9 100644 --- a/sample-configurations/moe-megablocks-sample-configuration.yaml +++ b/sample-configurations/moe-megablocks-sample-configuration.yaml @@ -40,3 +40,6 @@ plugins: # the MoE dropless implementation. 
Currently we only support "dropless_sparse", but
       # in the future we may support others
       moe_implementation: dropless_sparse
+
+      # for load_balancing_loss
+      load_balancing_loss: false

From 0cb273af45063235713c9f850fe1374f767716a4 Mon Sep 17 00:00:00 2001
From: Yu Chin Fabian Lim
Date: Wed, 21 Aug 2024 10:06:26 +0000
Subject: [PATCH 07/21] fix the benches

Signed-off-by: Yu Chin Fabian Lim
---
 plugins/accelerated-moe/README.md                | 2 +-
 .../src/fms_acceleration_moe/requirements-mb.txt | 1 +
 scripts/benchmarks/accelerator-config.json       | 5 +++++
 scripts/benchmarks/scenarios.yaml                | 9 +++++++--
 tox.ini                                          | 5 +++++
 5 files changed, 19 insertions(+), 3 deletions(-)
 create mode 100644 plugins/accelerated-moe/src/fms_acceleration_moe/requirements-mb.txt
 create mode 100644 scripts/benchmarks/accelerator-config.json

diff --git a/plugins/accelerated-moe/README.md b/plugins/accelerated-moe/README.md
index fe4627c9..9ca607a1 100644
--- a/plugins/accelerated-moe/README.md
+++ b/plugins/accelerated-moe/README.md
@@ -15,5 +15,5 @@ Known Issues
 Currently databricks megablocks does not have a PyPi repository and does not have a proper release, so we have to install from the github repository as below. Please note that installing from github will require CUDA Toolkit to build.
 
 ```
-pip install git+https://github.com/databricks/megablocks.git@bce5d7b2aaf5038bc93b36f76c2baf51c2939bd2
+pip install -r requirements-mb.txt
 ```
\ No newline at end of file
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/requirements-mb.txt b/plugins/accelerated-moe/src/fms_acceleration_moe/requirements-mb.txt
new file mode 100644
index 00000000..fbe690ff
--- /dev/null
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/requirements-mb.txt
@@ -0,0 +1 @@
+git+https://github.com/databricks/megablocks.git@bce5d7b2aaf5038bc93b36f76c2baf51c2939bd2
\ No newline at end of file
diff --git a/scripts/benchmarks/accelerator-config.json b/scripts/benchmarks/accelerator-config.json
new file mode 100644
index 00000000..7f736f97
--- /dev/null
+++ b/scripts/benchmarks/accelerator-config.json
@@ -0,0 +1,5 @@
+{
+    "gradient_accumulation_kwargs": {
+        "sync_each_batch": true
+    }
+}
diff --git a/scripts/benchmarks/scenarios.yaml b/scripts/benchmarks/scenarios.yaml
index b3a9c300..9cb59fef 100644
--- a/scripts/benchmarks/scenarios.yaml
+++ b/scripts/benchmarks/scenarios.yaml
@@ -98,8 +98,13 @@ scenarios:
     framework_config: 
       - moe-megablocks
     arguments:
-      learning_rate: 2e-4
-      bf16: True
+      learning_rate: 5e-5
       torch_dtype: bfloat16
+      accelerator_config: scripts/benchmarks/accelerator-config.json
+      gradient_accumulation_steps: 16
+      logging_steps: 1
+      packing: False
+      adam_epsilon: 1e-8
+
       model_name_or_path:
         - 'mistralai/Mixtral-8x7B-Instruct-v0.1'
diff --git a/tox.ini b/tox.ini
index 52f9bdb3..785f8271 100644
--- a/tox.ini
+++ b/tox.ini
@@ -39,6 +39,11 @@ commands = 
     python -m fms_acceleration.cli install -e {toxinidir}/plugins/accelerated-peft
     python -m fms_acceleration.cli install -e {toxinidir}/plugins/fused-ops-and-kernels
     python -m fms_acceleration.cli install -e {toxinidir}/plugins/attention_and_distributed_packing
+    python -m fms_acceleration.cli install -e {toxinidir}/plugins/accelerated-moe
+
+    # need to install some optional dependencies
+    # - the megablocks dependency
+    pip install -r {toxinidir}/plugins/accelerated-moe/requirements-mb.txt
 
     # run the benchmark script
     bash scripts/run_benchmarks.sh {posargs:"1 2" "4 8" benchmark_outputs}
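For context on the benchmark changes in the patch above: the `accelerator-config.json` file is handed to the trainer through the `accelerator_config` scenario argument, and it requests gradient accumulation with `sync_each_batch=True`. The block below is only a minimal sketch of the equivalent programmatic setup; it assumes a `transformers`/`accelerate` version that supports `accelerator_config` with `gradient_accumulation_kwargs`, and the `output_dir` value is purely illustrative.

```
# Sketch only: mirrors the benchmark scenario arguments added in this patch.
# Assumes transformers/accelerate versions that support accelerator_config
# with gradient_accumulation_kwargs (sync_each_batch); output_dir is illustrative.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="benchmark_outputs",
    learning_rate=5e-5,
    gradient_accumulation_steps=16,
    logging_steps=1,
    adam_epsilon=1e-8,
    accelerator_config="scripts/benchmarks/accelerator-config.json",
)
```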
From 9c463cb88525194d9b0554c8ec1db45978c4c186 Mon Sep 17 00:00:00 2001
From: Yu Chin Fabian Lim
Date: Thu, 22 Aug 2024 01:20:29 +0000
Subject: [PATCH 08/21] update benchmark logic to have empty framework_config

Signed-off-by: Yu Chin Fabian Lim
---
 scripts/benchmarks/benchmark.py   | 12 ++++++++++--
 scripts/benchmarks/scenarios.yaml |  2 +-
 2 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/scripts/benchmarks/benchmark.py b/scripts/benchmarks/benchmark.py
index d21b2fbe..ce66381a 100644
--- a/scripts/benchmarks/benchmark.py
+++ b/scripts/benchmarks/benchmark.py
@@ -360,10 +360,18 @@ def __init__(self, scenario: Dict, acceleration_config_map: Dict = None) -> None
         if key == "framework_config":
             # if acceleration_config_map is None, then do not do mapping
             if acceleration_config_map:
+
+                # - we allow k to be None to indicate we do not wish to
+                #   set a config for that matrix entry. However, we do not
+                #   check for multiple None's, so be careful.
                 val = [
-                    acceleration_config_map[k]
+                    (
+                        acceleration_config_map[k]
+                        if k is not None
+                        else None
+                    )
                     for k in val
-                    if k in acceleration_config_map
+                    if k in acceleration_config_map or k is None
                 ]
         setattr(self, key, val)
diff --git a/scripts/benchmarks/scenarios.yaml b/scripts/benchmarks/scenarios.yaml
index 9cb59fef..3dd25e4d 100644
--- a/scripts/benchmarks/scenarios.yaml
+++ b/scripts/benchmarks/scenarios.yaml
@@ -96,6 +96,7 @@ scenarios:
   - name: accelerated-moe-megablocks
     framework_config: 
+      - # without acceleration
       - moe-megablocks
     arguments:
       learning_rate: 5e-5
       torch_dtype: bfloat16
@@ -105,6 +106,5 @@ scenarios:
       logging_steps: 1
       packing: False
       adam_epsilon: 1e-8
-
       model_name_or_path:
         - 'mistralai/Mixtral-8x7B-Instruct-v0.1'
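The next patch ([PATCH 09/21]) casts the concatenated expert weights to the dtype of the target dMoE parameter before distributing them over the device mesh. Below is only a minimal sketch of that concat-cast-distribute pattern as used by `load_sharded_experts_onto_device`; the function and argument names are illustrative and are not part of the plugin.

```
# Sketch only: the "concat experts, cast to module dtype, then shard" pattern.
# Names here are illustrative, not the plugin's actual API.
from typing import List

import torch
from torch.distributed._tensor import Shard, distribute_tensor
from torch.distributed._tensor.device_mesh import DeviceMesh


def concat_cast_and_shard(
    expert_weights: List[torch.Tensor],  # one weight tensor per expert
    target_dtype: torch.dtype,           # dtype of the (meta-initialised) dMoE parameter
    mesh: DeviceMesh,                    # 1-D expert-parallel device mesh
) -> torch.nn.Parameter:
    # every rank materialises the full stack of experts along dim 0 ...
    full = torch.concat(expert_weights, dim=0).to(target_dtype)
    # ... and then keeps only its own shard along the expert dimension
    return torch.nn.Parameter(distribute_tensor(full, mesh, [Shard(0)]))
```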
From 26ca16ef54e7a8e144cda67190de050b5f20378a Mon Sep 17 00:00:00 2001
From: Yu Chin Fabian Lim
Date: Thu, 22 Aug 2024 04:23:23 +0000
Subject: [PATCH 09/21] properly handle dtype when sharding models

Signed-off-by: Yu Chin Fabian Lim
---
 plugins/accelerated-moe/README.md       | 18 ++++++++++++------
 .../megablocks_utils/shard_moe_utils.py | 12 ++++++++----
 2 files changed, 20 insertions(+), 10 deletions(-)

diff --git a/plugins/accelerated-moe/README.md b/plugins/accelerated-moe/README.md
index 9ca607a1..a27d3880 100644
--- a/plugins/accelerated-moe/README.md
+++ b/plugins/accelerated-moe/README.md
@@ -3,17 +3,23 @@
 This library contains plugins to accelerate finetuning with the following optimizations:
 1. Expert-Parallel MoE with Megablocks
 
-## Known Issues with Megablocks Plugin Implementation
+## Expert-Parallel MoE with Megablocks
 
-Known Issues
-- Currently we do not pass the data parallel `dp_mesh` to the `FSDP` constructor, so `FSDP` will always shard over the default process group (over world_size).
-- Currently only supports loading `safetensor` MoE checkpoints.
+Not all of the features of `megablocks` are incorporated; the current integration has the following restrictions:
+- currently we do not pass the data parallel `dp_mesh` to the `FSDP` constructor, so `FSDP` will always shard over the default process group (over world_size).
+- only *sharded*, non-GGUF `safetensor` MoE checkpoints are supported for loading. This is a reasonable assumption, since MoE checkpoints are typically above the size limit that prevents them from being saved into a single checkpoint file.
+- only the *dropless sparse* MLPs in the megablocks package are supported; the other variations, like non-dropless and grouped computes, are not currently integrated.
+- `shard_moe` may not scale well to larger models, as the current implementation uses `torch.concat` to gather all the expert weights before passing them to `torch.distributed` to be sharded. This is redundantly done on all devices, so it is inefficient.
 
-## Megablocks Dependencies
+### Megablocks Dependencies
 
-Currently databricks megablocks does not have a PyPi repository and does not have a proper release, so we have to install from the github repository as below. Please note that installing from github will require CUDA Toolkit to build.
+Currently databricks megablocks does not have a PyPi repository and no proper release, so we have to install it directly from Github; refer to the instructions below.
+- This has to be a manual install, as PyPI will complain if it is included as an official plugin dependency.
+- Since this is not a binary install, please note that CUDA Toolkit will be required to build some of the kernels used by megablocks.
 
 ```
+# this will install megablocks from Github
+# megablocks requires CUDA Toolkit to build.
 pip install -r requirements-mb.txt
 ```
\ No newline at end of file
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py
index e5b6b9a0..2cb5dc66 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py
@@ -182,9 +182,16 @@ def load_sharded_experts_onto_device(
                         T = T.t()
                 data.append(T)
 
+            # get the module we want to shard
+            name = weight_name.split('.')
+            path, name = ".".join(name[:-1]), name[-1]
+            mod = dmoe.get_submodule(path)
+            mod_dtype = getattr(mod, name).dtype
+
             # the megablocks dmoe experts the expert features to be on DIM_EXPERT.
             # - concat on dim 0 and distribute
+            # - cast to the correct dtype for the module
-            param = torch.concat(data, dim=DIM_EXPERT)
+            param = torch.concat(data, dim=DIM_EXPERT).to(mod_dtype)
             if KEY_DMOE_ROUTER not in weight_name:
                 param = torch.nn.Parameter(
                     distribute_tensor(param, device_mesh, placements)
@@ -196,9 +203,6 @@ def load_sharded_experts_onto_device(
                 )
 
             # register the sharded parameter onto the megablocks.dmoe
-            name = weight_name.split('.')
-            path, name = ".".join(name[:-1]), name[-1]
-            mod = dmoe.get_submodule(path)
             mod.register_parameter(name, param)
 
 def shard_moe(
From 9da5c48458a23edba4aa07ef6612c7a4dcb2c549 Mon Sep 17 00:00:00 2001
From: Yu Chin Fabian Lim
Date: Thu, 22 Aug 2024 04:33:10 +0000
Subject: [PATCH 10/21] add fmt + lint

Signed-off-by: Yu Chin Fabian Lim
---
 .github/workflows/build-and-publish.yml  |   1 +
 .github/workflows/format.yml             |   1 +
 README.md                                |   2 +-
 plugins/accelerated-moe/.isort.cfg       |  10 +
 plugins/accelerated-moe/.pylintrc        | 649 ++++++++++++++++++
 .../src/fms_acceleration_moe/__init__.py  |   3 +-
 .../framework_plugin_megablocks.py        |  49 +-
 .../megablocks_utils/config_utils.py      |  49 +-
 .../megablocks_utils/shard_moe_utils.py   | 183 +++--
 .../megablocks_utils/sparse_mlp2.py       | 197 +++---
 plugins/accelerated-moe/tests/__init__.py |  13 +
 .../tests/test_megablocks_plugin.py       |  34 +
 plugins/accelerated-moe/tox.ini           |  48 ++
 13 files changed, 1018 insertions(+), 221 deletions(-)
 create mode 100644 plugins/accelerated-moe/.isort.cfg
 create mode 100644 plugins/accelerated-moe/.pylintrc
 create mode 100644 plugins/accelerated-moe/tests/__init__.py
 create mode 100644 plugins/accelerated-moe/tests/test_megablocks_plugin.py
 create mode 100644 plugins/accelerated-moe/tox.ini

diff --git a/.github/workflows/build-and-publish.yml b/.github/workflows/build-and-publish.yml
index 307ade0e..9592fcfb 100644
--- a/.github/workflows/build-and-publish.yml
+++ 
b/.github/workflows/build-and-publish.yml @@ -15,6 +15,7 @@ jobs: - "accelerated-peft" - "fused-ops-and-kernels" - "attention-and-distributed-packing" + - "accelerated-moe" permissions: id-token: write # IMPORTANT: this permission is mandatory for trusted publishing diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml index 90f7210a..441a84cd 100644 --- a/.github/workflows/format.yml +++ b/.github/workflows/format.yml @@ -30,6 +30,7 @@ jobs: - "accelerated-peft" - "fused-ops-and-kernels" - "attention-and-distributed-packing" + - "accelerated-moe" steps: - uses: actions/checkout@v4 diff --git a/README.md b/README.md index 1158550c..8bc4b974 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ Plugin | Description | Depends | License | Status [accelerated-peft](./plugins/accelerated-peft/README.md) | For PEFT-training, e.g., 4bit QLoRA. | Huggingface
AutoGPTQ | Apache 2.0
MIT | Alpha [fused-op-and-kernels](./plugins/fused-ops-and-kernels/README.md) | Fused LoRA and triton kernels (e.g., fast cross-entropy, rms, rope) | -- | Apache 2.0 [(contains extracted code)](./plugins/fused-ops-and-kernels/README.md#code-extracted-from-unsloth)| Beta [attention-and-distributed-packing](./plugins/attention-and-distributed-packing/README.md) | Padding-Free Flash Attention Computation | flash-attn | Apache 2.0 | Beta - MOE-training-acceleration | [MegaBlocks](https://github.com/databricks/megablocks) inspired triton Kernels and acclerations for Mixture-of-Expert models | | Apache 2.0 | Coming Soon +[accelerated-moe](./plugins/accelerated-moe/README.md) | [MegaBlocks](https://github.com/databricks/megablocks) inspired triton Kernels and acclerations for Mixture-of-Expert models | | Apache 2.0 | Beta ## Usage with FMS HF Tuning diff --git a/plugins/accelerated-moe/.isort.cfg b/plugins/accelerated-moe/.isort.cfg new file mode 100644 index 00000000..7d3762ec --- /dev/null +++ b/plugins/accelerated-moe/.isort.cfg @@ -0,0 +1,10 @@ +[settings] +profile=black +from_first=true +import_heading_future=Future +import_heading_stdlib=Standard +import_heading_thirdparty=Third Party +import_heading_firstparty=First Party +import_heading_localfolder=Local +known_firstparty= +known_localfolder=tuning \ No newline at end of file diff --git a/plugins/accelerated-moe/.pylintrc b/plugins/accelerated-moe/.pylintrc new file mode 100644 index 00000000..14a7a572 --- /dev/null +++ b/plugins/accelerated-moe/.pylintrc @@ -0,0 +1,649 @@ +[MAIN] + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Clear in-memory caches upon conclusion of linting. Useful if running pylint +# in a server-like mode. +clear-cache-post-run=no + +# Load and enable all available extensions. Use --list-extensions to see a list +# all available extensions. +#enable-all-extensions= + +# In error mode, messages with a category besides ERROR or FATAL are +# suppressed, and no reports are done by default. Error mode is compatible with +# disabling specific errors. +#errors-only= + +# Always return a 0 (non-error) status code, even if lint errors are found. +# This is primarily useful in continuous integration scripts. +#exit-zero= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. +extension-pkg-allow-list= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. (This is an alternative name to extension-pkg-allow-list +# for backward compatibility.) +extension-pkg-whitelist= + +# Return non-zero exit code if any of these messages/categories are detected, +# even if score is above --fail-under value. Syntax same as enable. Messages +# specified are enabled, while categories only check already-enabled messages. +fail-on= + +# Specify a score threshold under which the program will exit with error. +fail-under=10 + +# Interpret the stdin as a python script, whose filename needs to be passed as +# the module_or_package argument. +#from-stdin= + +# Files or directories to be skipped. They should be base names, not paths. 
+ignore=CVS,protobufs + +# Add files or directories matching the regular expressions patterns to the +# ignore-list. The regex matches against paths and can be in Posix or Windows +# format. Because '\\' represents the directory delimiter on Windows systems, +# it can't be used as an escape character. +# ignore-paths= + +# Files or directories matching the regular expression patterns are skipped. +# The regex matches against base names, not paths. The default value ignores +# Emacs file locks +ignore-patterns=^\.# + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis). It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use, and will cap the count on Windows to +# avoid hangs. +jobs=1 + +# Control the amount of potential inferred values when inferring a single +# object. This can help the performance when dealing with large functions or +# complex, nested conditions. +limit-inference-results=100 + +# List of plugins (as comma separated values of python module names) to load, +# usually to register additional checkers. +load-plugins= + +# Pickle collected data for later comparisons. +persistent=yes + +# Minimum Python version to use for version dependent checks. Will default to +# the version used to run pylint. +py-version=3.9 + +# Discover python modules and packages in the file system subtree. +recursive=no + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode=yes + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + +# In verbose mode, extra non-checker-related info will be displayed. +#verbose= + + +[BASIC] + +# Naming style matching correct argument names. +argument-naming-style=snake_case + +# Regular expression matching correct argument names. Overrides argument- +# naming-style. If left empty, argument names will be checked with the set +# naming style. +#argument-rgx= + +# Naming style matching correct attribute names. +attr-naming-style=snake_case + +# Regular expression matching correct attribute names. Overrides attr-naming- +# style. If left empty, attribute names will be checked with the set naming +# style. +#attr-rgx= + +# Bad variable names which should always be refused, separated by a comma. +bad-names=foo, + bar, + baz, + toto, + tutu, + tata + +# Bad variable names regexes, separated by a comma. If names match any regex, +# they will always be refused +bad-names-rgxs= + +# Naming style matching correct class attribute names. +class-attribute-naming-style=any + +# Regular expression matching correct class attribute names. Overrides class- +# attribute-naming-style. If left empty, class attribute names will be checked +# with the set naming style. +#class-attribute-rgx= + +# Naming style matching correct class constant names. +class-const-naming-style=UPPER_CASE + +# Regular expression matching correct class constant names. Overrides class- +# const-naming-style. 
If left empty, class constant names will be checked with +# the set naming style. +#class-const-rgx= + +# Naming style matching correct class names. +class-naming-style=PascalCase + +# Regular expression matching correct class names. Overrides class-naming- +# style. If left empty, class names will be checked with the set naming style. +#class-rgx= + +# Naming style matching correct constant names. +const-naming-style=UPPER_CASE + +# Regular expression matching correct constant names. Overrides const-naming- +# style. If left empty, constant names will be checked with the set naming +# style. +#const-rgx= + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + +# Naming style matching correct function names. +function-naming-style=snake_case + +# Regular expression matching correct function names. Overrides function- +# naming-style. If left empty, function names will be checked with the set +# naming style. +#function-rgx= + +# Good variable names which should always be accepted, separated by a comma. +good-names=i, + j, + k, + ex, + Run, + _ + +# Good variable names regexes, separated by a comma. If names match any regex, +# they will always be accepted +good-names-rgxs= + +# Include a hint for the correct naming format with invalid-name. +include-naming-hint=no + +# Naming style matching correct inline iteration names. +inlinevar-naming-style=any + +# Regular expression matching correct inline iteration names. Overrides +# inlinevar-naming-style. If left empty, inline iteration names will be checked +# with the set naming style. +#inlinevar-rgx= + +# Naming style matching correct method names. +method-naming-style=snake_case + +# Regular expression matching correct method names. Overrides method-naming- +# style. If left empty, method names will be checked with the set naming style. +#method-rgx= + +# Naming style matching correct module names. +module-naming-style=snake_case + +# Regular expression matching correct module names. Overrides module-naming- +# style. If left empty, module names will be checked with the set naming style. +#module-rgx= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=^_ + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +# These decorators are taken in consideration only for invalid-name. +property-classes=abc.abstractproperty + +# Regular expression matching correct type variable names. If left empty, type +# variable names will be checked with the set naming style. +#typevar-rgx= + +# Naming style matching correct variable names. +variable-naming-style=snake_case + +# Regular expression matching correct variable names. Overrides variable- +# naming-style. If left empty, variable names will be checked with the set +# naming style. +#variable-rgx= + + +[CLASSES] + +# Warn about protected attribute access inside special methods +check-protected-access-in-special-methods=no + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp, + __post_init__ + +# List of member names, which should be excluded from the protected access +# warning. 
+exclude-protected=_asdict, + _fields, + _replace, + _source, + _make + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + + +[DESIGN] + +# List of regular expressions of class ancestor names to ignore when counting +# public methods (see R0903) +exclude-too-few-public-methods= + +# List of qualified class names to ignore when counting class parents (see +# R0901) +ignored-parents= + +# Maximum number of arguments for function / method. +max-args=5 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Maximum number of boolean expressions in an if statement (see R0916). +max-bool-expr=5 + +# Maximum number of branch for function / method body. +max-branches=12 + +# Maximum number of locals for function / method body. +max-locals=15 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + +# Maximum number of return / yield for function / method body. +max-returns=6 + +# Maximum number of statements in function / method body. +max-statements=50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when caught. +overgeneral-exceptions=builtins.BaseException,builtins.Exception + + +[FORMAT] + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Maximum number of characters on a single line. +max-line-length=100 + +# Maximum number of lines in a module. +max-module-lines=1100 + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +single-line-class-stmt=no + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + + +[IMPORTS] + +# List of modules that can be imported at any level, not just the top level +# one. +allow-any-import-level= + +# Allow explicit reexports by alias from a package __init__. +allow-reexport-from-package=no + +# Allow wildcard imports from modules that define __all__. +allow-wildcard-with-all=no + +# Deprecated modules which should not be used, separated by a comma. +deprecated-modules= + +# Output a graph (.gv or any supported image format) of external dependencies +# to the given file (report RP0402 must not be disabled). +ext-import-graph= + +# Output a graph (.gv or any supported image format) of all (i.e. internal and +# external) dependencies to the given file (report RP0402 must not be +# disabled). +import-graph= + +# Output a graph (.gv or any supported image format) of internal dependencies +# to the given file (report RP0402 must not be disabled). +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant + +# Couples of modules and preferred modules, separated by a comma. 
+preferred-modules= + + +[LOGGING] + +# The type of string formatting that logging methods do. `old` means using % +# formatting, `new` is for `{}` formatting. +logging-format-style=old + +# Logging modules to check that the string format arguments are in logging +# function parameter format. +logging-modules=logging + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, +# UNDEFINED. +confidence=HIGH, + CONTROL_FLOW, + INFERENCE, + INFERENCE_FAILURE, + UNDEFINED + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once). You can also use "--disable=all" to +# disable everything first and then re-enable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". +disable=raw-checker-failed, + bad-inline-option, + locally-disabled, + file-ignored, + suppressed-message, + useless-suppression, + deprecated-pragma, + # Added messages + use-symbolic-message-instead, + invalid-name, + missing-class-docstring, + missing-module-docstring, + missing-function-docstring, + consider-using-f-string, + inconsistent-return-statements, + no-member, + too-many-arguments, + too-many-locals, + too-many-branches, + too-many-statements, + cyclic-import, + too-few-public-methods, + protected-access, + fixme, + logging-format-interpolation, + logging-too-many-args, + attribute-defined-outside-init, + abstract-method, + pointless-statement, + wrong-import-order, + duplicate-code, + unbalanced-tuple-unpacking, + unused-argument + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +enable=c-extension-no-member + + +[METHOD_ARGS] + +# List of qualified names (i.e., library.method) which require a timeout +# parameter e.g. 'requests.api.get,requests.api.post' +timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME, + XXX, + TODO + +# Regular expression of note tags to take in consideration. +notes-rgx= + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + +# Complete name of functions that never returns. When checking for +# inconsistent-return-statements if a never returning function is called then +# it will be considered as an explicit return statement and no message will be +# printed. +never-returning-functions=sys.exit,argparse.parse_error + + +[REPORTS] + +# Python expression which should return a score less than or equal to 10. 
You +# have access to the variables 'fatal', 'error', 'warning', 'refactor', +# 'convention', and 'info' which contain the number of messages in each +# category, as well as 'statement' which is the total number of statements +# analyzed. This score is used by the global evaluation report (RP0004). +evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details. +msg-template= + +# Set the output format. Available formats are text, parseable, colorized, json +# and msvs (visual studio). You can also give a reporter class, e.g. +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Tells whether to display a full report or only the messages. +reports=yes + +# Activate the evaluation score. +score=yes + + +[SIMILARITIES] + +# Comments are removed from the similarity computation +ignore-comments=yes + +# Docstrings are removed from the similarity computation +ignore-docstrings=yes + +# Imports are removed from the similarity computation +ignore-imports=yes + +# Signatures are removed from the similarity computation +ignore-signatures=yes + +# Minimum lines number of a similarity. +min-similarity-lines=4 + + +[SPELLING] + +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions=4 + +# Spelling dictionary name. Available dictionaries: none. To make it work, +# install the 'python-enchant' package. +spelling-dict= + +# List of comma separated words that should be considered directives if they +# appear at the beginning of a comment and should not be checked. +spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains the private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to the private dictionary (see the +# --spelling-private-dict-file option) instead of raising a message. +spelling-store-unknown-words=no + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=no + +# This flag controls whether the implicit-str-concat should generate a warning +# on implicit string concatenation in sequences defined over several lines. +check-str-concat-over-line-jumps=no + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + +# Tells whether to warn about missing members when the owner of the attribute +# is inferred to be None. +ignore-none=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. 
In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. +ignore-on-opaque-inference=yes + +# List of symbolic message names to ignore for Mixin members. +ignored-checks-for-mixins=no-member, + not-async-context-manager, + not-context-manager, + attribute-defined-outside-init + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. +missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + +# Regex pattern to define which classes are considered mixins. +mixin-class-rgx=.*[Mm]ixin + +# List of decorators that change the signature of a decorated function. +signature-mutators= + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid defining new builtins when possible. +additional-builtins= + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables=yes + +# List of names allowed to shadow builtins +allowed-redefined-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_, + _cb + +# A regular expression matching the name of dummy variables (i.e. expected to +# not be used). +dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. +ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io \ No newline at end of file diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/__init__.py b/plugins/accelerated-moe/src/fms_acceleration_moe/__init__.py index eca7a2c5..7a459ffc 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/__init__.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/__init__.py @@ -13,4 +13,5 @@ # limitations under the License. 
-from .framework_plugin_megablocks import MegablocksMoEAccelerationPlugin \ No newline at end of file +# Local +from .framework_plugin_megablocks import MegablocksMoEAccelerationPlugin diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py index 56cdc3de..4c5c85cb 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py @@ -14,14 +14,15 @@ # Standard from typing import Dict -import torch import warnings # Third Party from fms_acceleration import AccelerationPlugin from transformers import AutoConfig, AutoModelForCausalLM +import torch +# pylint: disable=too-many-instance-attributes class MegablocksMoEAccelerationPlugin(AccelerationPlugin): require_packages = {"megablocks"} @@ -31,26 +32,24 @@ def __init__(self, configurations: Dict[str, Dict]): # arguments for configuring the mixture-of-experts model with defaults # shown below for Mixtral 7x8b - # - 1. component class + # - 1. component class self._moe_component_cls = self._check_config_and_maybe_check_values( key="training.moe.megablocks.moe_component_class", default="MixtralSparseMoeBlock", ) # - 2. gate_module_name self._gate_module_name = self._check_config_and_maybe_check_values( - key="training.moe.megablocks.moe_gate_module_name", - default="gate" + key="training.moe.megablocks.moe_gate_module_name", default="gate" ) # - 3. experts_module_name self._experts_module_name = self._check_config_and_maybe_check_values( - key="training.moe.megablocks.moe_experts_module_name", - default="experts" + key="training.moe.megablocks.moe_experts_module_name", default="experts" ) # - 4. mlp version self._mlp_version = self._check_config_and_maybe_check_values( key="training.moe.megablocks.moe_mlp_impl", values=["v1", "v2"], - default="v2" + default="v2", ) # for controlling the type of sharding @@ -90,13 +89,14 @@ def requires_custom_loading(self): def model_loader(self, model_name: str, **kwargs): # guarded + # Local + # pylint: disable=import-outside-toplevel from .megablocks_utils.config_utils import update_mlp_registry - from .megablocks_utils.shard_moe_utils import shard_moe, get_moe_kwargs + from .megablocks_utils.shard_moe_utils import get_moe_kwargs, shard_moe # - check the config if self._load_balancing_loss and not hasattr( - AutoConfig.from_pretrained(model_name), - "output_router_logits" + AutoConfig.from_pretrained(model_name), "output_router_logits" ): warnings.warn( "load_balancing_loss=True but " @@ -110,18 +110,14 @@ def model_loader(self, model_name: str, **kwargs): # properly as the load balancing loss is currently not properly # handled update_mlp_registry( - self._moe_implementation, - self._mlp_version, - self._load_balancing_loss + self._moe_implementation, self._mlp_version, self._load_balancing_loss ) # get additional parameters torch_dtype = kwargs.get("torch_dtype", torch.float32) # load the model - model = AutoModelForCausalLM.from_pretrained( - model_name, **kwargs - ) + model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs) # set this in the config, which will be picked up by the forward # function to go into the load_balancing loss @@ -137,28 +133,26 @@ def model_loader(self, model_name: str, **kwargs): "Megablocks expert parallel only works for distributed training." 
) - # shard the MOE, and store products required for + # shard the MOE, and store products required for # FSDP configuration - ( - dp_mesh, self._moe_component_module_names - ) = shard_moe( - model, - self._moe_component_cls, + # pylint: disable=unused-variable + (dp_mesh, self._moe_component_module_names) = shard_moe( + model, + self._moe_component_cls, checkpoint_name_or_path=model_name, rank=rank, world_size=world_size, - ep_size=self._ep_size, + ep_size=self._ep_size, moe_kwargs=get_moe_kwargs( - model.config, + model.config, fp16=torch_dtype == torch.float16, bf16=torch_dtype == torch.bfloat16, ), shared_mesh_dim=self._shard_along_dp, - router_name=self._gate_module_name, + router_name=self._gate_module_name, expert_name=self._experts_module_name, - ) - # NOTE: Currently, it is a bit troublesome to pass the device_mesh to + # NOTE: Currently, it is a bit troublesome to pass the device_mesh to # the FSDP constructor, so we do not do that. # - therefore FSDP will always shard on world_size over the default process # group @@ -182,6 +176,7 @@ def get_callbacks_and_ready_for_train( return callbacks + # register AccelerationPlugin.register_plugin( MegablocksMoEAccelerationPlugin, diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/config_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/config_utils.py index 72cae3ad..31e73489 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/config_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/config_utils.py @@ -1,33 +1,39 @@ -# utilities to update megablocks to register the MLP_v2 that -# handles gate, up, down projections +# utilities to update megablocks to register various things +# e.g, the MLP_v2 that handles gate, up, down projections -from megablocks.layers.dmlp_registry import _REGISTRY - - -from megablocks.layers.mlp import SparseMLP -from .sparse_mlp2 import SparseMLPv2 -from megablocks.layers.moe import ParallelMLP +# Third Party import torch -SPARSE_MLP_IMPL = { - "v1": SparseMLP, - "v2": SparseMLPv2, -} # this function ensures that the megablocks packaged is configured to use # the correct SparseMLP implementation # - at the moment not considering the GroupedMLP implementations def update_mlp_registry( - mlp_type: str = 'sparse', - mlp_version: str = 'v2', + mlp_type: str = "sparse", + mlp_version: str = "v2", load_balancing_loss: bool = False, ): + # guarded + # Third Party + # pylint: disable=import-error,import-outside-toplevel + from megablocks.layers.dmlp_registry import _REGISTRY + from megablocks.layers.mlp import SparseMLP + from megablocks.layers.moe import ParallelMLP + + # Local + from .sparse_mlp2 import SparseMLPv2 + + SPARSE_MLP_IMPL = { + "v1": SparseMLP, + "v2": SparseMLPv2, + } # replace the registry to point to the the correct sparse implementation - if mlp_type == 'sparse': - assert mlp_version in SPARSE_MLP_IMPL, \ - f"Megablocks only support sparse mlp versions: {','.join(SPARSE_MLP_IMPL.keys())}" - _REGISTRY['mlp']['sparse'] = SPARSE_MLP_IMPL[mlp_version] + if mlp_type == "sparse": + assert ( + mlp_version in SPARSE_MLP_IMPL + ), f"Megablocks only support sparse mlp versions: {','.join(SPARSE_MLP_IMPL.keys())}" + _REGISTRY["mlp"]["sparse"] = SPARSE_MLP_IMPL[mlp_version] else: raise NotImplementedError("Currently only supports sparse MLP implementations.") @@ -44,10 +50,10 @@ def forward(self, x, scores, expert_weights, top_experts): return x + self.bias # in this case we should be returning the router - 
# logits out of the MoE forward. + # logits out of the MoE forward. if load_balancing_loss: return x, torch.log(scores) - + # otherwise just return None return x, None @@ -56,7 +62,6 @@ def forward(self, x, scores, expert_weights, top_experts): # 1. we do not care about reversing the patch # 2. we have control on where this is called, and we know to call it # before our code accesses this function. Hence, we view this as - # a hardcoded modification to the megablocks package more than a + # a hardcoded modification to the megablocks package more than a # patch. ParallelMLP.forward = forward - \ No newline at end of file diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py index 2cb5dc66..677bdabc 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py @@ -1,30 +1,31 @@ - -import torch -from transformers import PretrainedConfig -from transformers.activations import ACT2FN -from typing import Tuple, Dict, List, Type, Union -from torch.distributed._tensor import Placement, Replicate, Shard, distribute_tensor -from torch.distributed._tensor.device_mesh import init_device_mesh, DeviceMesh -import os -from copy import copy -from tqdm import tqdm - -from safetensors import safe_open -import json, re +# Standard from collections import defaultdict +from contextlib import ExitStack +from copy import copy +from typing import Dict, List, Tuple, Type, Union +import json +import os +import re +# Third Party from accelerate import init_empty_weights +from safetensors import safe_open +from torch.distributed._tensor import Placement, Replicate, Shard, distribute_tensor +from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh +from tqdm import tqdm +from transformers import PretrainedConfig +from transformers.activations import ACT2FN +import torch -from contextlib import ExitStack - -FILE_SAFETENSOR_INDEX = 'model.safetensors.index.json' -KEY_DATA_PARALLEL = 'data_parallel' -KEY_EXPERT_PARALLEL = 'expert_parallel' +FILE_SAFETENSOR_INDEX = "model.safetensors.index.json" +KEY_DATA_PARALLEL = "data_parallel" +KEY_EXPERT_PARALLEL = "expert_parallel" DIM_EXPERT = 0 # these depend on the namings in the dMOE -KEY_DMOE_ROUTER = 'router.layer.weight' -KEY_DMOE_EXPERTS = 'experts.mlp' +KEY_DMOE_ROUTER = "router.layer.weight" +KEY_DMOE_EXPERTS = "experts.mlp" + def get_moe_kwargs( config: PretrainedConfig, @@ -45,12 +46,14 @@ def get_moe_kwargs( "bf16": bf16, } + # trick to get the resolved cache file to acccess the safetensor # NOTE: this does not work if _dict_from_json_file, like GGUF files def get_resolved_checkpoint_location(model_name_or_path: str): result = None _old_func = PretrainedConfig._dict_from_json_file + def _dict_from_json_file(resolved_config_file): nonlocal result result = resolved_config_file @@ -69,43 +72,43 @@ def _dict_from_json_file(resolved_config_file): # - the keys point to modules found in megablocks.layers.dmoe.dMoE, the distributed # expert module provided by megablocks. # - the values are tuples pointing to the keys within the checkpoint file. 
-# +# # Example: if prefix="module.layers.0" and instance_name="block_sparse_moe", then a dictionary # of the following will be returned: # { # 'experts.mlp.w1': [ # ( -# 'model.layers.0.block_sparse_moe.experts.0.w1.weight', +# 'model.layers.0.block_sparse_moe.experts.0.w1.weight', # 'model-00001-of-00019.safetensors' -# ), +# ), # ( -# 'model.layers.0.block_sparse_moe.experts.1.w1.weight', +# 'model.layers.0.block_sparse_moe.experts.1.w1.weight', # 'model-00001-of-00019.safetensors' -# ), +# ), # ... # ] -# 'experts.mlp.w2': [...], +# 'experts.mlp.w2': [...], # 'experts.mlp.w3': [...], # 'router.layer.weight': [ # ( -# 'model.layers.0.block_sparse_moe.gate.weight', +# 'model.layers.0.block_sparse_moe.gate.weight', # 'model-00001-of-00019.safetensors' # ) # ] # } def get_checkpoint_meta_from_sharded_safetensor( weight_map: Dict, - prefix: str, # e.g., 'model.layers.0, - instance_name: str, # e.g., block_sparse_moe - router_name: str = 'gate', # e.g., named "gate" within block_sparse_moe - expert_name: str = 'experts' # e.g., named "experts" within block_sparse_moe + prefix: str, # e.g., 'model.layers.0, + instance_name: str, # e.g., block_sparse_moe + router_name: str = "gate", # e.g., named "gate" within block_sparse_moe + expert_name: str = "experts", # e.g., named "experts" within block_sparse_moe ) -> Dict[str, List[Tuple]]: # insert in order def _insert(L: List, i: int, v): n = len(L) if i < n: L[i] = v - return + return n = i - n + 1 while n > 0: @@ -126,10 +129,8 @@ def _insert(L: List, i: int, v): # - gate.weight # - experts.0.w1.weight rel_k = k.replace(prefix, "") - m = re.match( - f'({router_name}|{expert_name})\.?(\d+)?\.?(\w+)?\.weight', - rel_k - ) + # pylint: disable=anomalous-backslash-in-string + m = re.match(f"({router_name}|{expert_name})\.?(\d+)?\.?(\w+)?\.weight", rel_k) if m is None: raise ValueError( f"Unable to handle key '{k}' with provided router_name " @@ -140,7 +141,7 @@ def _insert(L: List, i: int, v): elif m.group(1) == expert_name: index = int(m.group(2)) mod = m.group(3) - _insert(_map[f'{KEY_DMOE_EXPERTS}.{mod}'], index, (k, stfile)) + _insert(_map[f"{KEY_DMOE_EXPERTS}.{mod}"], index, (k, stfile)) if len(_map) == 0: raise ValueError( @@ -149,16 +150,17 @@ def _insert(L: List, i: int, v): return _map -# this function will load the sharded experts onto the device. + +# this function will load the sharded experts onto the device. # - this assumes that the "dmoe" module is the megablocks.layers.dmoe.dMoE distributed # implementation of the mixture of experts. def load_sharded_experts_onto_device( dmoe: torch.nn.Module, directory: str, checkpoint_metadata: Dict[str, List[Tuple]], - device_mesh: DeviceMesh, + device_mesh: DeviceMesh, placements: Placement, - expert_name: str = 'experts' # e.g., named "experts" within block_sparse_moe + expert_name: str = "experts", # e.g., named "experts" within block_sparse_moe ): # typically they all should be same file, but to play safe, load the checkpoint file onto # cpu first since we may not need all weights in that file. @@ -168,9 +170,11 @@ def load_sharded_experts_onto_device( for _, fi in vs: if fi not in files: files[fi] = stack.enter_context( - safe_open(os.path.join(directory, fi), framework='pt', device='cpu') + safe_open( + os.path.join(directory, fi), framework="pt", device="cpu" + ) ) - + # go by one weight at a time. 
# - weight_name: points to megablocks.dmoe for weight_name, vs in checkpoint_metadata.items(): @@ -183,7 +187,7 @@ def load_sharded_experts_onto_device( data.append(T) # get the module we want to shard - name = weight_name.split('.') + name = weight_name.split(".") path, name = ".".join(name[:-1]), name[-1] mod = dmoe.get_submodule(path) mod_dtype = getattr(mod, name).dtype @@ -198,25 +202,24 @@ def load_sharded_experts_onto_device( ) else: # - do not shard the router but load onto device as well - param = torch.nn.Parameter( - param.to(torch.cuda.current_device()) - ) - + param = torch.nn.Parameter(param.to(torch.cuda.current_device())) + # register the sharded parameter onto the megablocks.dmoe mod.register_parameter(name, param) + def shard_moe( model: torch.nn.Module, - moe_cls: Union[str,Type], + moe_cls: Union[str, Type], checkpoint_name_or_path: str, - rank: int, + rank: int, world_size: int, moe_kwargs: Dict, - device_type: str = 'cuda', + device_type: str = "cuda", key_dp: str = KEY_DATA_PARALLEL, key_ep: str = KEY_EXPERT_PARALLEL, - router_name: str = 'gate', - expert_name: str = 'experts', + router_name: str = "gate", + expert_name: str = "experts", shared_mesh_dim: bool = True, ep_size: int = 1, ): @@ -230,11 +233,12 @@ def shard_moe( The sharding has two modes, and depends on world_size and number_of_experts the model has. This depends on the setting "shared_mesh_dim" to True or False: - - if True: then dp and ep will happen on the same device_mesh dimension. This is only possible - if world_size divides number_of_experts (which requires world_size < num_of_experts). - - if False: then dp and ep will be seperate device_mesh dimensions. The ep_size will be determined - by the argument passed in (which needs to be properly set ep_size > 1; the default - value will raise an assertion). + - if True: then dp and ep will happen on the same device_mesh dimension. + This is only possible if world_size divides number_of_experts + (which requires world_size < num_of_experts). + - if False: then dp and ep will be seperate device_mesh dimensions. The ep_size will be + determined by the argument passed in (which needs to be properly set ep_size > 1; + the default value will raise an assertion). Parameters: @@ -243,7 +247,7 @@ def shard_moe( checkpoint_name_or_path (str): name or path pointing to the weight checkpoint. rank (int): rank of the current device. world_size (int): total world size. - moe_kwargs (dict): kwargs to be passed to construct megablocks.layers.arguments for + moe_kwargs (dict): kwargs to be passed to construct megablocks.layers.arguments for constructing the megablocks.layer.dmoe.dMOE. device_type (str): the current device to load the sharded model into. key_dp (str): name of the data parallel mesh @@ -255,7 +259,9 @@ def shard_moe( """ # guarded import - from megablocks.layers import dmoe, arguments + # Third Party + # pylint: disable=import-error, import-outside-toplevel + from megablocks.layers import arguments, dmoe if shared_mesh_dim: # if sharing mesh with dp, then the ep_size must be the world_size @@ -264,22 +270,22 @@ def shard_moe( else: # - moe_kwargs is the constructed by get_moe_kwargs above - _num_experts = moe_kwargs['moe_num_experts'] + _num_experts = moe_kwargs["moe_num_experts"] assert _num_experts % ep_size == 0, ( f"ep_shard factor '{ep_size}' does not divide " f"number of experts '{_num_experts}'." 
) - assert ep_size > 1, "expert_parallel dimension must be set larger than 1" - assert world_size % ep_size == 0, ( - f"world_size ({world_size}) not divisible by ep_size ({ep_size})." - ) + assert ep_size > 1, "expert_parallel dimension must be set larger than 1" + assert ( + world_size % ep_size == 0 + ), f"world_size ({world_size}) not divisible by ep_size ({ep_size})." # this function will shard the MOE on this rank - device = torch.device(f'cuda:{rank}') + device = torch.device(f"cuda:{rank}") if shared_mesh_dim: - # in this case we will have a 1D mesh and collapse the + # in this case we will have a 1D mesh and collapse the # expert parallel with data_parallel device_mesh = init_device_mesh( @@ -291,7 +297,7 @@ def shard_moe( placements: List[Placement] = [Shard(DIM_EXPERT)] else: # in this case it will distribute experts on a different - # mesh dimension than dp. + # mesh dimension than dp. # - this will achieve the effect that the expert sharding can be # hierachical (e.g., can be over a slower network plane since # the communication overhead is less @@ -305,8 +311,9 @@ def shard_moe( placements: List[Placement] = [Replicate(), Shard(DIM_EXPERT)] mp_dmoe_args = arguments.Arguments( - **moe_kwargs, device=device, - expert_parallel_group=device_mesh[key_ep].get_group(0) + **moe_kwargs, + device=device, + expert_parallel_group=device_mesh[key_ep].get_group(0), ) assert mp_dmoe_args.moe_num_experts % ep_size == 0, ( @@ -319,28 +326,24 @@ def shard_moe( # parent_mod: (child_instance_name, [list of fqdn keys]) found = {} for name, mod in model.named_modules(): - name = name.split('.') + name = name.split(".") parent, child = ".".join(name[:-1]), name[-1] # check the module depending if moe_cls is a str or class if ( - mod.__class__.__name__ == moe_cls if - isinstance(moe_cls, str) else - isinstance(mod, moe_cls) + mod.__class__.__name__ == moe_cls + if isinstance(moe_cls, str) + else isinstance(mod, moe_cls) ): - fqdn_keys = [ # all params, including childs' - f'{parent}.{child}.{n}' - for n, _ in mod.named_parameters() + fqdn_keys = [ # all params, including childs' + f"{parent}.{child}.{n}" for n, _ in mod.named_parameters() ] # check if there are any biases in any of the experts # if there are biases # Assumption: assume that if one expert has bias,then the others # will have it to - has_bias = any( - expert_name in k and k.endswith('bias') - for k in fqdn_keys - ) + has_bias = any(expert_name in k and k.endswith("bias") for k in fqdn_keys) found[parent] = (child, fqdn_keys, has_bias) @@ -350,17 +353,16 @@ def shard_moe( # - most MOE models should be used using this checkpoint format try: loc = get_resolved_checkpoint_location(checkpoint_name_or_path) - with open(os.path.join(loc, FILE_SAFETENSOR_INDEX)) as f: + with open(os.path.join(loc, FILE_SAFETENSOR_INDEX), encoding="utf-8") as f: index = json.load(f) # e.g., prefix: 'model.layers.0', # module_name: 'block_sparse_moe' - for prefix, (module_name, relevant_keys, has_bias) in tqdm( - found.items(), disable=(rank > 0), desc='Sharding MoE' + for prefix, (module_name, _, has_bias) in tqdm( + found.items(), disable=(rank > 0), desc="Sharding MoE" ): checkpoint_metadata = get_checkpoint_meta_from_sharded_safetensor( - index['weight_map'], prefix, module_name, - router_name, expert_name + index["weight_map"], prefix, module_name, router_name, expert_name ) _args = copy(mp_dmoe_args) @@ -368,11 +370,10 @@ def shard_moe( # - will replace the MoE module with the megablocks sharded dMoE with init_empty_weights(): - mp_dmoe = dmoe.dMoE(_args) 
# drop in replacement for now + mp_dmoe = dmoe.dMoE(_args) # drop in replacement for now load_sharded_experts_onto_device( - mp_dmoe, loc, checkpoint_metadata, - device_mesh, placements, expert_name + mp_dmoe, loc, checkpoint_metadata, device_mesh, placements, expert_name ) parent = model.get_submodule(prefix) setattr(parent, module_name, mp_dmoe) @@ -384,8 +385,6 @@ def shard_moe( raise ValueError( f"Unable to load checkpoint_path '{checkpoint_name_or_path}'. " "Currently only support non-GGUF safetensor checkpoints. " - f": {e}" - ) - + ) from e - return device_mesh[key_dp], moe_module_names \ No newline at end of file + return device_mesh[key_dp], moe_module_names diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/sparse_mlp2.py b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/sparse_mlp2.py index 77b2e42c..8d3871a5 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/sparse_mlp2.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/sparse_mlp2.py @@ -1,89 +1,130 @@ +# Third Party import torch -from megablocks.layers import common -from megablocks.layers.arguments import Arguments -from megablocks.layers import mpu -from megablocks.layers.activation_fn import act_fn -from megablocks.layers.mlp import ( - create_dmoe_expert_weights, scale_gradient, - resolve_dtensor -) -import stk -# This is the different MLP class used for models that have up_proj, down_proj -# and gate_proj like Mixtral -class SparseMLPv2(torch.nn.Module): +try: + # definition is guarded, intended only when + # megablocks is available - def __init__(self, args : Arguments): - super().__init__() - self.args = args - self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args) + # Third Party + # pylint: disable=import-error + from megablocks.layers import common, mpu + from megablocks.layers.activation_fn import act_fn + from megablocks.layers.arguments import Arguments + from megablocks.layers.mlp import ( + create_dmoe_expert_weights, + resolve_dtensor, + scale_gradient, + ) + import stk + # This is the different MLP class used for models that have up_proj, down_proj + # and gate_proj like Mixtral + class SparseMLPv2(torch.nn.Module): - self.w1 = torch.nn.Parameter(torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args))) - self.w2 = torch.nn.Parameter(torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args))) - self.w3 = torch.nn.Parameter(torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args))) + def __init__(self, args: Arguments): + super().__init__() + self.args = args + self._num_rows_per_rank = mpu.experts_per_rank( + args + ) * mpu.features_per_rank(args) - with torch.no_grad(): - self.w1.copy_(create_dmoe_expert_weights( - args, args.moe_num_experts, args.ffn_hidden_size, - args.hidden_size, args.init_method)) - self.w2.copy_(create_dmoe_expert_weights( - args, args.moe_num_experts, args.ffn_hidden_size, - args.hidden_size, args.output_layer_init_method)) - self.w3.copy_(create_dmoe_expert_weights( - args, args.moe_num_experts, args.ffn_hidden_size, - args.hidden_size, args.output_layer_init_method)) + self.w1 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ) + ) + self.w2 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + 
args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ) + ) + self.w3 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ) + ) - self._should_set_parallelism_attribute = args.moe_expert_model_parallelism - mpu.set_expert_model_parallel_attributes( - self.w1, self._should_set_parallelism_attribute) - mpu.set_expert_model_parallel_attributes( - self.w2, self._should_set_parallelism_attribute) - mpu.set_expert_model_parallel_attributes( - self.w3, self._should_set_parallelism_attribute) + with torch.no_grad(): + self.w1.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ) + ) + self.w2.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.output_layer_init_method, + ) + ) + self.w3.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.output_layer_init_method, + ) + ) - self.gradient_scale = None - if self.args.moe_expert_model_parallelism: - self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args) + self._should_set_parallelism_attribute = args.moe_expert_model_parallelism + mpu.set_expert_model_parallel_attributes( + self.w1, self._should_set_parallelism_attribute + ) + mpu.set_expert_model_parallel_attributes( + self.w2, self._should_set_parallelism_attribute + ) + mpu.set_expert_model_parallel_attributes( + self.w3, self._should_set_parallelism_attribute + ) - def scale_grad(self, w): - if self.gradient_scale is None: - return w - return scale_gradient(w, self.gradient_scale) + self.gradient_scale = None + if self.args.moe_expert_model_parallelism: + self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args) - def forward(self, hidden_states, topo): - w1, w2, w3 = ( - self.scale_grad(self.w1), self.scale_grad(self.w2), - self.scale_grad(self.w3) - ) - w1, w2, w3 = ( - resolve_dtensor(w1), resolve_dtensor(w2), resolve_dtensor(w3) - ) + def scale_grad(self, w): + if self.gradient_scale is None: + return w + return scale_gradient(w, self.gradient_scale) - # Perform the expert computation - hidden_states = stk.Matrix( # type: ignore - topo.size(), - act_fn( - stk.ops.sdd(hidden_states, w1.t(), topo), - self.args.activation_fn - ).data * stk.ops.sdd(hidden_states, w3.t(), topo).data, - topo.row_indices, - topo.column_indices, - topo.offsets, - topo.column_indices_t, - topo.offsets_t, - topo.block_offsets_t, - ) - return stk.ops.dsd(hidden_states, w2) \ No newline at end of file + def forward(self, hidden_states, topo): + w1, w2, w3 = ( + self.scale_grad(self.w1), + self.scale_grad(self.w2), + self.scale_grad(self.w3), + ) + w1, w2, w3 = (resolve_dtensor(w1), resolve_dtensor(w2), resolve_dtensor(w3)) + + # Perform the expert computation + hidden_states = stk.Matrix( # type: ignore + topo.size(), + act_fn( + stk.ops.sdd(hidden_states, w1.t(), topo), self.args.activation_fn + ).data + * stk.ops.sdd(hidden_states, w3.t(), topo).data, + topo.row_indices, + topo.column_indices, + topo.offsets, + topo.column_indices_t, + topo.offsets_t, + topo.block_offsets_t, + ) + return stk.ops.dsd(hidden_states, w2) + +except ImportError: + pass diff --git a/plugins/accelerated-moe/tests/__init__.py b/plugins/accelerated-moe/tests/__init__.py new file mode 100644 index 00000000..38a9531e --- /dev/null +++ b/plugins/accelerated-moe/tests/__init__.py @@ -0,0 +1,13 @@ +# 
Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/plugins/accelerated-moe/tests/test_megablocks_plugin.py b/plugins/accelerated-moe/tests/test_megablocks_plugin.py new file mode 100644 index 00000000..646e0a2b --- /dev/null +++ b/plugins/accelerated-moe/tests/test_megablocks_plugin.py @@ -0,0 +1,34 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +import os + +# Third Party +from fms_acceleration.utils import instantiate_framework, read_configuration + +# First Party +from fms_acceleration_moe import MegablocksMoEAccelerationPlugin + +# configuration +DIRNAME = os.path.dirname(__file__) +CONFIG_PATH_MEGABLOCKS = os.path.join(DIRNAME, "../configs/megablocks.yaml") + + +def test_framework_installs_aadp_padding_free_plugin(): + with instantiate_framework( + read_configuration(CONFIG_PATH_MEGABLOCKS), require_packages_check=False + ) as framework: + for plugin in framework.active_plugins: + assert isinstance(plugin[1], MegablocksMoEAccelerationPlugin) diff --git a/plugins/accelerated-moe/tox.ini b/plugins/accelerated-moe/tox.ini new file mode 100644 index 00000000..811f1329 --- /dev/null +++ b/plugins/accelerated-moe/tox.ini @@ -0,0 +1,48 @@ +[tox] +envlist = py, lint + +[testenv] +deps = + pytest>=7 + -e {toxinidir} +skip_install = true +commands = + + # install the dependencies here to ensure + # the order + pip install -e {toxinidir}/../framework + pytest {posargs:tests} + +[testenv:lint] +description = run linters +skip_install = false +deps = + -e {toxinidir}/../framework + pylint>=2.16.2,<=3.1.0 +commands = + pylint src tests +allowlist_externals = pylint + +[testenv:fmt] +description = format +skip_install = true +deps = + black>=22.12 + isort>=5.11 +commands = + black {posargs:.} + isort {posargs:.} + +[testenv:build] +description = build wheel +deps = + build +commands = python -m build -w +skip_install = True + +[testenv:twinecheck] +description = check wheel +deps = + twine +commands = twine check dist/* +skip_install = True \ No newline at end of file From 36b698798a09ada965a4f2c604198cc51aa6a057 Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Thu, 22 Aug 2024 14:24:51 +0000 Subject: [PATCH 11/21] fix replication on router and checkpointing Signed-off-by: Yu Chin Fabian Lim --- .../megablocks_utils/checkpoint_utils.py | 102 ++++++++++++++++++ .../megablocks_utils/config_utils.py | 68 +++++++++++- .../megablocks_utils/shard_moe_utils.py | 30 ++++-- 
.../megablocks_utils/sparse_mlp2.py | 14 +++ scripts/benchmarks/accelerate.yaml | 2 +- 5 files changed, 207 insertions(+), 9 deletions(-) create mode 100644 plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/checkpoint_utils.py diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/checkpoint_utils.py new file mode 100644 index 00000000..036e083f --- /dev/null +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/checkpoint_utils.py @@ -0,0 +1,102 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +import os + +# Third Party +from accelerate.logging import get_logger +from accelerate.utils.constants import ( + FSDP_MODEL_NAME, + OPTIMIZER_NAME, +) +from torch.distributed.checkpoint.state_dict import ( + get_state_dict, + set_state_dict, +) +import torch.distributed.checkpoint as dcp + +logger = get_logger(__name__) + +MODEL_INDEX = None + +def save_fsdp_model( + fsdp_plugin, accelerator, model, output_dir, model_index=0, adapter_only=False +): + + # pylint: disable=global-statement + global MODEL_INDEX + MODEL_INDEX = model_index + +def save_fsdp_optimizer( + fsdp_plugin, accelerator, optimizer, model, output_dir, optimizer_index=0 +): + (model_state_dict, optimizer_state_dict) = get_state_dict(model, optimizer) + + ckpt_model = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}") + os.makedirs(ckpt_model, exist_ok=True) + logger.info(f"Saving model to {ckpt_model}") + dcp.save({"model": model_state_dict}, checkpoint_id=ckpt_model) + logger.info(f"Model saved to {ckpt_model}") + + ckpt_opt = os.path.join(output_dir, f"{OPTIMIZER_NAME}_{optimizer_index}") + os.makedirs(ckpt_opt, exist_ok=True) + logger.info(f"Saving Optimizer state to {ckpt_opt}") + dcp.save({"optimizer": optimizer_state_dict}, checkpoint_id=ckpt_opt) + logger.info(f"Optimizer state saved in {ckpt_opt}") + + +# accelerate.utils.fsdp_utils.py +def load_fsdp_model( + fsdp_plugin, accelerator, model, input_dir, model_index=0, adapter_only=False +): + # pylint: disable=global-statement + global MODEL_INDEX + MODEL_INDEX = model_index + + +# accelerate.utils.fsdp_utils.py +def load_fsdp_optimizer( + fsdp_plugin, + accelerator, + optimizer, + model, + input_dir, + optimizer_index=0, + adapter_only=False, +): + + model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer) + ckpt_model = os.path.join(input_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}") + dcp.load({"model": model_state_dict}, checkpoint_id=ckpt_model) + ckpt_opt = os.path.join(input_dir, f"{OPTIMIZER_NAME}_{optimizer_index}") + dcp.load({"optimizer": optimizer_state_dict}, checkpoint_id=ckpt_opt) + set_state_dict( + model, + optimizer, + model_state_dict=model_state_dict, + optim_state_dict=optimizer_state_dict, + ) + + # HACK for now + # - if seems that if params is empty, then the loading has someo + # problems + # - so for now, we just dump some random 
defaults + for group in optimizer.param_groups: + if len(group["params"]) == 0: + group["betas"] = (0.9, 0.999) + group["lr"] = 0.0 + group["initial_lr"] = 0.0 + group["eps"] = 1e-8 + group["weight_decay"] = 0.0 diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/config_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/config_utils.py index 31e73489..1234d383 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/config_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/config_utils.py @@ -1,8 +1,23 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # utilities to update megablocks to register various things # e.g, the MLP_v2 that handles gate, up, down projections # Third Party import torch +import torch.nn.functional as F # this function ensures that the megablocks packaged is configured to use @@ -17,8 +32,9 @@ def update_mlp_registry( # Third Party # pylint: disable=import-error,import-outside-toplevel from megablocks.layers.dmlp_registry import _REGISTRY - from megablocks.layers.mlp import SparseMLP + from megablocks.layers.mlp import SparseMLP, resolve_dtensor from megablocks.layers.moe import ParallelMLP + from megablocks.layers.router import LearnedRouter, _uniform_expert_assignment # Local from .sparse_mlp2 import SparseMLPv2 @@ -65,3 +81,53 @@ def forward(self, x, scores, expert_weights, top_experts): # a hardcoded modification to the megablocks package more than a # patch. 
ParallelMLP.forward = forward + + # for the router + # - need to resolve the dtensor since we had replicated the router + # weights + def forward_router(self, x): + if self.training and self.args.moe_jitter_eps is not None: + x = x * self.jitter(x) + + _weight = resolve_dtensor(self.layer.weight) + _bias = None if self.layer.bias is None else resolve_dtensor(self.layer.bias) + # pylint: disable=not-callable + scores = F.linear(x.view(-1, x.shape[-1]), _weight, _bias).softmax(dim=-1) + expert_weights, expert_indices = self._top_k(scores) + if self.args.moe_normalize_expert_weights: + expert_weights = expert_weights / torch.norm( + expert_weights, + p=self.args.moe_normalize_expert_weights, + dim=-1, + keepdim=True, + ) + + expert_indices = ( + _uniform_expert_assignment( + expert_indices, + self.args.moe_num_experts, + ) + if self.args.uniform_expert_assignment + else expert_indices + ) + return scores, expert_weights, expert_indices + + # replace the forward function in the router + # - same as above + LearnedRouter.forward = forward_router + + # Third Party + from fms_acceleration.model_patcher import patch_target_module + + # Local + from .checkpoint_utils import ( + load_fsdp_model, + load_fsdp_optimizer, + save_fsdp_model, + save_fsdp_optimizer, + ) + + patch_target_module("transformers.trainer.save_fsdp_model", save_fsdp_model) + patch_target_module("transformers.trainer.save_fsdp_optimizer", save_fsdp_optimizer) + patch_target_module("transformers.trainer.load_fsdp_model", load_fsdp_model) + patch_target_module("transformers.trainer.load_fsdp_optimizer", load_fsdp_optimizer) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py index 677bdabc..07971fc7 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py @@ -1,3 +1,17 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ # Standard from collections import defaultdict from contextlib import ExitStack @@ -196,13 +210,15 @@ def load_sharded_experts_onto_device( # - concat on dim 0 and distribute # - cast to the correct dtype for the module param = torch.concat(data, dim=DIM_EXPERT).to(mod_dtype) - if KEY_DMOE_ROUTER not in weight_name: - param = torch.nn.Parameter( - distribute_tensor(param, device_mesh, placements) - ) - else: - # - do not shard the router but load onto device as well - param = torch.nn.Parameter(param.to(torch.cuda.current_device())) + + _placements = placements + if KEY_DMOE_ROUTER in weight_name: + # - the router needs to be replicated + _placements = [Replicate() for _ in range(len(placements))] + + param = torch.nn.Parameter( + distribute_tensor(param, device_mesh, _placements) + ) # register the sharded parameter onto the megablocks.dmoe mod.register_parameter(name, param) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/sparse_mlp2.py b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/sparse_mlp2.py index 8d3871a5..439eaaf0 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/sparse_mlp2.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/sparse_mlp2.py @@ -1,3 +1,17 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # Third Party import torch diff --git a/scripts/benchmarks/accelerate.yaml b/scripts/benchmarks/accelerate.yaml index f70d74fa..f3908470 100644 --- a/scripts/benchmarks/accelerate.yaml +++ b/scripts/benchmarks/accelerate.yaml @@ -30,7 +30,7 @@ fsdp_config: # 3 is NO_SHARD, effectively disabling FSDP # 4, 5 are HYBRID_ modes for multi-node training only. - fsdp_state_dict_type: FULL_STATE_DICT # set to FULL_STATE_DICT (1), SHARDED_STATE_DICT (3) + fsdp_state_dict_type: SHARDED_STATE_DICT # set to FULL_STATE_DICT (1), SHARDED_STATE_DICT (3) # 2 is LOCAL_STATE_DICT where parameters are still flattened # 3 is efficient, but requires know-how to use the shared checkpoint. From 5be0948a92080298a064a6f5a11edf78795e832d Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Fri, 23 Aug 2024 09:54:40 +0000 Subject: [PATCH 12/21] more fixes on checkpointing Signed-off-by: Yu Chin Fabian Lim --- plugins/accelerated-moe/README.md | 1 + .../megablocks_utils/checkpoint_utils.py | 78 +++++++++++++++---- 2 files changed, 65 insertions(+), 14 deletions(-) diff --git a/plugins/accelerated-moe/README.md b/plugins/accelerated-moe/README.md index a27d3880..1836b7d0 100644 --- a/plugins/accelerated-moe/README.md +++ b/plugins/accelerated-moe/README.md @@ -10,6 +10,7 @@ Not all of the features of `megablocks` are being incorporated; listing down som - now support only loading *sharded* `safetensor` non-GGUF MoE checkpoints. This is a reasonable assumption since MoE checkpoints are typically above the size limit that prevents it being saved into a single checkpoint filed. 
- only supports the *dropless sparse* MLPs in the megablocks package; the other variations like non-dropless and grouped computes are not currently integrated. - the `shard_moe` may not scale well with larger models as the current implementation `torch.concat` all the expert weights together before passing to `torch.distributed` to be sharded. This is redundently done in all devices, so it is inefficient. +- currently only supports `StateDictType.SHARDED_STATE_DICT` because the implementation uses `DTensors` which have limited support for full state dicts. However for efficiency considerations, sharded state dicts are the most efficient. ### Megablocks Dependencies diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/checkpoint_utils.py index 036e083f..d4363e7a 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/checkpoint_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/checkpoint_utils.py @@ -17,47 +17,75 @@ # Third Party from accelerate.logging import get_logger -from accelerate.utils.constants import ( - FSDP_MODEL_NAME, - OPTIMIZER_NAME, -) -from torch.distributed.checkpoint.state_dict import ( - get_state_dict, - set_state_dict, +from accelerate.utils.constants import FSDP_MODEL_NAME, OPTIMIZER_NAME +from torch.distributed.checkpoint.default_planner import ( + DefaultLoadPlanner, + DefaultSavePlanner, ) +from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict +from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType import torch.distributed.checkpoint as dcp logger = get_logger(__name__) +# - variable to capture the model variable +# in the save/load model calls MODEL_INDEX = None +# Below are rewrite of functions for megablocks + + +# rewrite of func from accelerate.utils.fsdp_utils.py +# - empty function, as main logic is in the optimizer call +# save_fsdp_optimizer (see below). def save_fsdp_model( fsdp_plugin, accelerator, model, output_dir, model_index=0, adapter_only=False ): - # pylint: disable=global-statement global MODEL_INDEX MODEL_INDEX = model_index + +# rewrite of func from accelerate.utils.fsdp_utils.py +# - saves both model and optimizer def save_fsdp_optimizer( fsdp_plugin, accelerator, optimizer, model, output_dir, optimizer_index=0 ): + + if fsdp_plugin.state_dict_type != StateDictType.SHARDED_STATE_DICT: + raise NotImplementedError( + "Checkpointing for megablocks only enabled for sharded state dict." 
+ ) + + # get the state dicts for model and optimize (model_state_dict, optimizer_state_dict) = get_state_dict(model, optimizer) + # - save model ckpt_model = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}") os.makedirs(ckpt_model, exist_ok=True) logger.info(f"Saving model to {ckpt_model}") - dcp.save({"model": model_state_dict}, checkpoint_id=ckpt_model) + dcp.save( + state_dict={"model": model_state_dict}, + storage_writer=dcp.FileSystemWriter(ckpt_model), + planner=DefaultSavePlanner(), + ) logger.info(f"Model saved to {ckpt_model}") + # - save optimizer ckpt_opt = os.path.join(output_dir, f"{OPTIMIZER_NAME}_{optimizer_index}") os.makedirs(ckpt_opt, exist_ok=True) logger.info(f"Saving Optimizer state to {ckpt_opt}") - dcp.save({"optimizer": optimizer_state_dict}, checkpoint_id=ckpt_opt) + dcp.save( + state_dict={"optimizer": optimizer_state_dict}, + storage_writer=dcp.FileSystemWriter(ckpt_opt), + planner=DefaultSavePlanner(), + ) logger.info(f"Optimizer state saved in {ckpt_opt}") -# accelerate.utils.fsdp_utils.py +# rewrite of func from accelerate.utils.fsdp_utils.py +# - empty function, as main logic is in the optimizer call +# load_fsdp_optimizer (see below). def load_fsdp_model( fsdp_plugin, accelerator, model, input_dir, model_index=0, adapter_only=False ): @@ -66,7 +94,8 @@ def load_fsdp_model( MODEL_INDEX = model_index -# accelerate.utils.fsdp_utils.py +# rewrite of func from accelerate.utils.fsdp_utils.py +# - loads both model and optimizer def load_fsdp_optimizer( fsdp_plugin, accelerator, @@ -77,11 +106,32 @@ def load_fsdp_optimizer( adapter_only=False, ): + accelerator.wait_for_everyone() + if fsdp_plugin.state_dict_type != StateDictType.SHARDED_STATE_DICT: + raise NotImplementedError( + "Checkpointing for megablocks only enabled for sharded state dict." + ) + + # - get the state dicts model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer) + + # - load the model state dict ckpt_model = os.path.join(input_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}") - dcp.load({"model": model_state_dict}, checkpoint_id=ckpt_model) + dcp.load( + state_dict={"model": model_state_dict}, + storage_reader=dcp.FileSystemReader(ckpt_model), + planner=DefaultLoadPlanner(), + ) + + # - load the optimizer state dict ckpt_opt = os.path.join(input_dir, f"{OPTIMIZER_NAME}_{optimizer_index}") - dcp.load({"optimizer": optimizer_state_dict}, checkpoint_id=ckpt_opt) + dcp.load( + state_dict={"optimizer": optimizer_state_dict}, + storage_reader=dcp.FileSystemReader(ckpt_opt), + planner=DefaultLoadPlanner(), + ) + + # - set the state dicts set_state_dict( model, optimizer, From e00fcd06cd1a3cfac2663715adfd467cf0abd29e Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Fri, 23 Aug 2024 10:05:19 +0000 Subject: [PATCH 13/21] update readme Signed-off-by: Yu Chin Fabian Lim --- plugins/accelerated-moe/README.md | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/plugins/accelerated-moe/README.md b/plugins/accelerated-moe/README.md index 1836b7d0..c187cf20 100644 --- a/plugins/accelerated-moe/README.md +++ b/plugins/accelerated-moe/README.md @@ -3,6 +3,19 @@ This library contains plugins to accelerate finetuning with the following optimizations: 1. 
Expert-Parallel MoE with Megablocks +## Plugins + +Plugin | Description | Depends | Loading | Augmentation | Callbacks +--|--|--|--|--|-- +[megablocks](./src/fms_acceleration_moe/framework_plugin_megablocks.py) | MoE Expert Parallel with megablocks | megablocks | ✅ | | ✅ + + +## Running Benchmarks + +``` +tox -e run-benches -- 8 8 scenarios.yaml accelerated-moe-megablocks +``` + ## Expert-Parallel MoE with Megablocks Not all of the features of `megablocks` are being incorporated; listing down some of the restrictions of the current integration: @@ -12,7 +25,6 @@ Not all of the features of `megablocks` are being incorporated; listing down som - the `shard_moe` may not scale well with larger models as the current implementation `torch.concat` all the expert weights together before passing to `torch.distributed` to be sharded. This is redundently done in all devices, so it is inefficient. - currently only supports `StateDictType.SHARDED_STATE_DICT` because the implementation uses `DTensors` which have limited support for full state dicts. However for efficiency considerations, sharded state dicts are the most efficient. - ### Megablocks Dependencies Currently databricks megablocks does not have a PyPi repository and no proper release, so we have to install directly from Github, refer to instructions below. From cd9db22ea699b2278eb7069d95b546c3a1c4feef Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Fri, 23 Aug 2024 10:33:49 +0000 Subject: [PATCH 14/21] fixes to readme and tox Signed-off-by: Yu Chin Fabian Lim --- plugins/accelerated-moe/README.md | 29 +++++++++++++++++-- plugins/accelerated-moe/requirements-mb.txt | 3 ++ .../fms_acceleration_moe/requirements-mb.txt | 1 - tox.ini | 4 --- 4 files changed, 29 insertions(+), 8 deletions(-) create mode 100644 plugins/accelerated-moe/requirements-mb.txt delete mode 100644 plugins/accelerated-moe/src/fms_acceleration_moe/requirements-mb.txt diff --git a/plugins/accelerated-moe/README.md b/plugins/accelerated-moe/README.md index c187cf20..1e683c54 100644 --- a/plugins/accelerated-moe/README.md +++ b/plugins/accelerated-moe/README.md @@ -12,14 +12,37 @@ Plugin | Description | Depends | Loading | Augmentation | Callbacks ## Running Benchmarks +Run the below in the top-level directory of this repo: +- the `megablocks` dep is not included by default, so the `-x` switch installs it. + +``` +tox -e run-benches \ + -x testenv:run-benches.deps+="-r plugins/accelerated-moe/requirements-mb.txt" \ + -- \ + 8 8 benchmark_outputs scenarios.yaml accelerated-moe-megablocks + +``` + +NOTE: if `FileNotFoundError` is observed on the *triton cache*, similar to issues like these: +- https://github.com/triton-lang/triton/issues/2688 + +then somehow `tox` is causing problems with triton and multiprocessing (there is some race condition). +But the workaound is to first *activate the tox env* and +running in `bash`: ``` -tox -e run-benches -- 8 8 scenarios.yaml accelerated-moe-megablocks +# if FileNotFoundError in the triton cache is observed +# - then activate the env and run the script manually + +source .tox/run-benches/bin/activate +bash scripts/run_benchmarks.sh \ + 8 8 benchmark_outputs scenarios.yaml accelerated-moe-megablocks ``` + ## Expert-Parallel MoE with Megablocks Not all of the features of `megablocks` are being incorporated; listing down some of the restrictions of the current integration: -- curretnly not passing the data parallel `dp_mesh` to the `FSDP` constructor, so `FSDP` will always shard over the default process group (over world_size). 
+- currently not passing the data parallel `dp_mesh` to the `FSDP` constructor, so `FSDP` will always shard over the default process group (over world_size). - now support only loading *sharded* `safetensor` non-GGUF MoE checkpoints. This is a reasonable assumption since MoE checkpoints are typically above the size limit that prevents it being saved into a single checkpoint filed. - only supports the *dropless sparse* MLPs in the megablocks package; the other variations like non-dropless and grouped computes are not currently integrated. - the `shard_moe` may not scale well with larger models as the current implementation `torch.concat` all the expert weights together before passing to `torch.distributed` to be sharded. This is redundently done in all devices, so it is inefficient. @@ -34,5 +57,5 @@ Currently databricks megablocks does not have a PyPi repository and no proper re ``` # this will install the megablocks from Github # megablocks requires CUDA Toolkit to build. -pip install -r requirements_mb.txt +pip install -r requirements-mb.txt ``` \ No newline at end of file diff --git a/plugins/accelerated-moe/requirements-mb.txt b/plugins/accelerated-moe/requirements-mb.txt new file mode 100644 index 00000000..7875fc98 --- /dev/null +++ b/plugins/accelerated-moe/requirements-mb.txt @@ -0,0 +1,3 @@ +megablocks @ git+https://github.com/databricks/megablocks.git@bce5d7b2aaf5038bc93b36f76c2baf51c2939bd2 + +# auto_gptq @ git+https://github.com/AutoGPTQ/AutoGPTQ.git@ea829c7bbe83561c2b1de26795b6592992373ef7 diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/requirements-mb.txt b/plugins/accelerated-moe/src/fms_acceleration_moe/requirements-mb.txt deleted file mode 100644 index fbe690ff..00000000 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/requirements-mb.txt +++ /dev/null @@ -1 +0,0 @@ -pip install git+https://github.com/databricks/megablocks.git@bce5d7b2aaf5038bc93b36f76c2baf51c2939bd2 \ No newline at end of file diff --git a/tox.ini b/tox.ini index 785f8271..cad75c25 100644 --- a/tox.ini +++ b/tox.ini @@ -41,10 +41,6 @@ commands = python -m fms_acceleration.cli install -e {toxinidir}/plugins/attention_and_distributed_packing python -m fms_acceleration.cli install -e {toxinidir}/plugins/accelerated-moe - # need to install some optional dependencies - # - the megablocks dependency - pip install -r {toxinidir}/plugins/accelerated-moe/requirements-mb.txt - # run the benchmark script bash scripts/run_benchmarks.sh {posargs:"1 2" "4 8" benchmark_outputs} From 0c6631a37221ef86c4c0001fe62694380a326643 Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Fri, 23 Aug 2024 15:45:04 +0000 Subject: [PATCH 15/21] add benches Signed-off-by: Yu Chin Fabian Lim --- plugins/accelerated-moe/README.md | 3 + scripts/benchmarks/refs/a100_80gb_mb.csv | 3 + scripts/benchmarks/refs/requirements_mb.txt | 88 +++++++++++++++++++++ 3 files changed, 94 insertions(+) create mode 100644 scripts/benchmarks/refs/a100_80gb_mb.csv create mode 100644 scripts/benchmarks/refs/requirements_mb.txt diff --git a/plugins/accelerated-moe/README.md b/plugins/accelerated-moe/README.md index 1e683c54..0d1a3bc4 100644 --- a/plugins/accelerated-moe/README.md +++ b/plugins/accelerated-moe/README.md @@ -12,6 +12,9 @@ Plugin | Description | Depends | Loading | Augmentation | Callbacks ## Running Benchmarks +See the benchmarks [a100_80gb_mb.csv](../../scripts/benchmarks/refs/a100_80gb_mb.csv) + + Run the below in the top-level directory of this repo: - the `megablocks` dep is not included by default, so the `-x` switch 
installs it. diff --git a/scripts/benchmarks/refs/a100_80gb_mb.csv b/scripts/benchmarks/refs/a100_80gb_mb.csv new file mode 100644 index 00000000..b5bdc6c5 --- /dev/null +++ b/scripts/benchmarks/refs/a100_80gb_mb.csv @@ -0,0 +1,3 @@ +framework_config,mem_nvidia_mem_reserved,mem_peak_torch_mem_alloc_in_bytes,mem_torch_mem_alloc_in_bytes,torch_dtype,train_loss,train_runtime,train_samples_per_second,train_steps_per_second,train_tokens_per_second +none,65598.5,58936741888,47259259904,bfloat16,0.859970542192459,4170.0391,3.07,0.024,80.575 +moe-megablocks,52284.0,48874301952,35987686400,bfloat16,0.8570401281118393,1404.3938,9.114,0.071,239.249 diff --git a/scripts/benchmarks/refs/requirements_mb.txt b/scripts/benchmarks/refs/requirements_mb.txt new file mode 100644 index 00000000..679b20f8 --- /dev/null +++ b/scripts/benchmarks/refs/requirements_mb.txt @@ -0,0 +1,88 @@ +accelerate==0.33.0 +aiohappyeyeballs==2.4.0 +aiohttp==3.10.5 +aiosignal==1.3.1 +async-timeout==4.0.3 +attrs==24.2.0 +bitsandbytes==0.43.3 +certifi==2024.7.4 +charset-normalizer==3.3.2 +contourpy==1.2.1 +cycler==0.12.1 +datasets==2.21.0 +dill==0.3.8 +docstring_parser==0.16 +einops==0.8.0 +filelock==3.15.4 +fire==0.6.0 +flash-attn==2.6.3 +-e git+https://github.com/foundation-model-stack/fms-acceleration.git@9cf70081f3bfc00e84331102e7d13b333f17ee26#egg=fms_acceleration&subdirectory=plugins/framework +-e git+https://github.com/foundation-model-stack/fms-acceleration.git@9cf70081f3bfc00e84331102e7d13b333f17ee26#egg=fms_acceleration_foak&subdirectory=plugins/fused-ops-and-kernels +-e git+https://github.com/foundation-model-stack/fms-acceleration.git@9cf70081f3bfc00e84331102e7d13b333f17ee26#egg=fms_acceleration_moe&subdirectory=plugins/accelerated-moe +-e git+https://github.com/foundation-model-stack/fms-acceleration.git@9cf70081f3bfc00e84331102e7d13b333f17ee26#egg=fms_acceleration_peft&subdirectory=plugins/accelerated-peft +fms-hf-tuning @ git+https://github.com/foundation-model-stack/fms-hf-tuning.git@daca5510ab76cc8ecf0283fd31fc220697a75040 +fonttools==4.53.1 +frozenlist==1.4.1 +fsspec==2024.6.1 +huggingface-hub==0.24.6 +idna==3.7 +Jinja2==3.1.4 +kiwisolver==1.4.5 +markdown-it-py==3.0.0 +MarkupSafe==2.1.5 +matplotlib==3.9.2 +mdurl==0.1.2 +megablocks @ git+https://github.com/databricks/megablocks.git@bce5d7b2aaf5038bc93b36f76c2baf51c2939bd2 +mpmath==1.3.0 +multidict==6.0.5 +multiprocess==0.70.16 +networkx==3.3 +numpy==1.26.4 +nvidia-cublas-cu12==12.1.3.1 +nvidia-cuda-cupti-cu12==12.1.105 +nvidia-cuda-nvrtc-cu12==12.1.105 +nvidia-cuda-runtime-cu12==12.1.105 +nvidia-cudnn-cu12==8.9.2.26 +nvidia-cufft-cu12==11.0.2.54 +nvidia-curand-cu12==10.3.2.106 +nvidia-cusolver-cu12==11.4.5.107 +nvidia-cusparse-cu12==12.1.0.106 +nvidia-nccl-cu12==2.20.5 +nvidia-nvjitlink-cu12==12.6.20 +nvidia-nvtx-cu12==12.1.105 +packaging==24.1 +pandas==2.2.2 +peft==0.12.0 +pillow==10.4.0 +protobuf==5.27.3 +psutil==6.0.0 +pyarrow==17.0.0 +Pygments==2.18.0 +pyparsing==3.1.2 +python-dateutil==2.9.0.post0 +pytz==2024.1 +PyYAML==6.0.2 +regex==2024.7.24 +requests==2.32.3 +rich==13.7.1 +safetensors==0.4.4 +sentencepiece==0.2.0 +shtab==1.7.1 +simpleeval==0.9.13 +six==1.16.0 +stanford-stk @ git+https://git@github.com/stanford-futuredata/stk.git@a1ddf98466730b88a2988860a9d8000fd1833301 +sympy==1.13.2 +termcolor==2.4.0 +threadpoolctl==3.5.0 +tokenizers==0.19.1 +torch==2.3.1 +tqdm==4.66.5 +transformers==4.44.2 +triton==2.3.1 +trl==0.9.6 +typing_extensions==4.12.2 +tyro==0.8.8 +tzdata==2024.1 +urllib3==2.2.2 +xxhash==3.5.0 +yarl==1.9.4 From 
d686e960facafce87bf056eca42e7c436b574f56 Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Sat, 24 Aug 2024 07:59:54 +0000 Subject: [PATCH 16/21] fix ignore modules Signed-off-by: Yu Chin Fabian Lim --- .../framework_plugin_megablocks.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py index 4c5c85cb..9a1c0a57 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py @@ -168,12 +168,21 @@ def get_callbacks_and_ready_for_train( accelerator is not None and getattr(accelerator.state, "fsdp_plugin", None) is not None ): + # - use an internal function call to get the no split + # module names, which are typically layers + _layers = model._get_no_split_modules('') accelerator.state.fsdp_plugin.ignored_modules = [ getattr(layer, name) - for name in self._moe_component_module_names - for layer in model.model.layers + for name in moe_component_module_names + for layer in model.modules() + if layer.__class__.__name__ in _layers ] +FSDP( + model, + ignored_modules=ignored_modules, +) + return callbacks From 329c4d040f50890a695f55dc5fa5f47a58807248 Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Sat, 24 Aug 2024 08:10:36 +0000 Subject: [PATCH 17/21] skip slow scenarios in unflitered runs Signed-off-by: Yu Chin Fabian Lim --- scripts/benchmarks/benchmark.py | 17 +++++++++++++++++ scripts/benchmarks/scenarios.yaml | 1 + 2 files changed, 18 insertions(+) diff --git a/scripts/benchmarks/benchmark.py b/scripts/benchmarks/benchmark.py index ce66381a..3a85adf3 100644 --- a/scripts/benchmarks/benchmark.py +++ b/scripts/benchmarks/benchmark.py @@ -356,6 +356,12 @@ class ScenarioMatrix: def __init__(self, scenario: Dict, acceleration_config_map: Dict = None) -> None: assert "arguments" in scenario.keys(), "Missing `arguments` key in `scenario`" + + # "slow" is a special key that indicates this scenario + # takes resources to run + # - "slow" scenarios are not run if not specified by a filter + self.slow = False + for key, val in scenario.items(): if key == "framework_config": # if acceleration_config_map is None, then do not do mapping @@ -687,7 +693,18 @@ def prepare_arguments(args, benchmark_dataset: BenchmarkDataset): if args.run_only_scenarios and _scn_name not in args.run_only_scenarios: print(f"Skipping scenario '{_scn_name}'") continue + + # build scenario matrix scenario = ScenarioMatrix(scenario_config, acceleration_config_map) + + if ( + not args.run_only_scenarios and + and scenarios.slow + ): + # unfiltered runs omit all "slow" marked scenarios + print(f"Skipping slow scenario '{_scn_name}' beacuse run_only_scenarios=None.") + continue + scenario_matrices, scenario_constants = ( scenario.get_scenario_matrices_and_defaults() ) diff --git a/scripts/benchmarks/scenarios.yaml b/scripts/benchmarks/scenarios.yaml index 3dd25e4d..0e8b6954 100644 --- a/scripts/benchmarks/scenarios.yaml +++ b/scripts/benchmarks/scenarios.yaml @@ -98,6 +98,7 @@ scenarios: framework_config: - # without acceleration - moe-megablocks + slow: True arguments: learning_rate: 5e-5 torch_dtype: bfloat16 From 59907e321e50612df6057857694dffd5087c13cd Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Sun, 25 Aug 2024 12:02:21 +0000 Subject: [PATCH 18/21] fix Signed-off-by: Yu Chin Fabian Lim --- 
.../fms_acceleration_moe/framework_plugin_megablocks.py | 9 ++------- scripts/benchmarks/benchmark.py | 2 +- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py index 9a1c0a57..ea9f527e 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py @@ -168,21 +168,16 @@ def get_callbacks_and_ready_for_train( accelerator is not None and getattr(accelerator.state, "fsdp_plugin", None) is not None ): - # - use an internal function call to get the no split + # - use an internal function call to get the no split # module names, which are typically layers _layers = model._get_no_split_modules('') accelerator.state.fsdp_plugin.ignored_modules = [ getattr(layer, name) - for name in moe_component_module_names + for name in self._moe_component_module_names for layer in model.modules() if layer.__class__.__name__ in _layers ] -FSDP( - model, - ignored_modules=ignored_modules, -) - return callbacks diff --git a/scripts/benchmarks/benchmark.py b/scripts/benchmarks/benchmark.py index 3a85adf3..19487652 100644 --- a/scripts/benchmarks/benchmark.py +++ b/scripts/benchmarks/benchmark.py @@ -698,7 +698,7 @@ def prepare_arguments(args, benchmark_dataset: BenchmarkDataset): scenario = ScenarioMatrix(scenario_config, acceleration_config_map) if ( - not args.run_only_scenarios and + not args.run_only_scenarios and scenarios.slow ): # unfiltered runs omit all "slow" marked scenarios From 12d8619808fbc12d4bfbcd01908723e5d156d500 Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Tue, 27 Aug 2024 13:41:19 +0000 Subject: [PATCH 19/21] update readme with note on mixed precision Signed-off-by: Yu Chin Fabian Lim --- plugins/accelerated-moe/README.md | 1 + .../src/fms_acceleration_moe/framework_plugin_megablocks.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/plugins/accelerated-moe/README.md b/plugins/accelerated-moe/README.md index 0d1a3bc4..54064381 100644 --- a/plugins/accelerated-moe/README.md +++ b/plugins/accelerated-moe/README.md @@ -50,6 +50,7 @@ Not all of the features of `megablocks` are being incorporated; listing down som - only supports the *dropless sparse* MLPs in the megablocks package; the other variations like non-dropless and grouped computes are not currently integrated. - the `shard_moe` may not scale well with larger models as the current implementation `torch.concat` all the expert weights together before passing to `torch.distributed` to be sharded. This is redundently done in all devices, so it is inefficient. - currently only supports `StateDictType.SHARDED_STATE_DICT` because the implementation uses `DTensors` which have limited support for full state dicts. However for efficiency considerations, sharded state dicts are the most efficient. +- currently may not support *mixed precision* properly; need to ascertain more clearly how the sharded `DTensors` are upcasted in the optimizer (if at all). 
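Since the upcasting behaviour is still being ascertained, one way to check it empirically is to inspect the sharded parameters after `shard_moe` has run. Below is a minimal sketch (the helper name is ours and not part of the plugin); it assumes it is called inside the distributed training process, after the MoE modules have been swapped in:

```
# sketch: print dtype and DTensor placements of the sharded MoE parameters,
# so the mixed-precision / upcasting question can be checked empirically.
def inspect_moe_params(model, filter_key: str = "experts"):
    for name, param in model.named_parameters():
        if filter_key not in name:
            continue
        # DTensor parameters expose a `placements` attribute; plain tensors do not
        placements = getattr(
            param, "placements", getattr(param.data, "placements", None)
        )
        print(f"{name}: dtype={param.dtype}, placements={placements}")
```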
### Megablocks Dependencies diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py index ea9f527e..d7204335 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py @@ -170,7 +170,7 @@ def get_callbacks_and_ready_for_train( ): # - use an internal function call to get the no split # module names, which are typically layers - _layers = model._get_no_split_modules('') + _layers = model._get_no_split_modules("") accelerator.state.fsdp_plugin.ignored_modules = [ getattr(layer, name) for name in self._moe_component_module_names From feaeaa53795e9014b748051ad33d36ca39851843 Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Wed, 28 Aug 2024 11:20:01 +0000 Subject: [PATCH 20/21] partially address mixed precision Signed-off-by: Yu Chin Fabian Lim --- plugins/accelerated-moe/README.md | 5 +++- .../framework_plugin_megablocks.py | 6 +++++ .../megablocks_utils/shard_moe_utils.py | 25 +++++++++++++++++-- 3 files changed, 33 insertions(+), 3 deletions(-) diff --git a/plugins/accelerated-moe/README.md b/plugins/accelerated-moe/README.md index 54064381..3f938096 100644 --- a/plugins/accelerated-moe/README.md +++ b/plugins/accelerated-moe/README.md @@ -44,13 +44,16 @@ bash scripts/run_benchmarks.sh \ ## Expert-Parallel MoE with Megablocks +Currently supports *mixed precision*. Will upcast the router and the sharded experts if turned on. +- However this is hard-coded to off at the moment. +- The FSDP mixed precision works independenly of the MoE one. + Not all of the features of `megablocks` are being incorporated; listing down some of the restrictions of the current integration: - currently not passing the data parallel `dp_mesh` to the `FSDP` constructor, so `FSDP` will always shard over the default process group (over world_size). - now support only loading *sharded* `safetensor` non-GGUF MoE checkpoints. This is a reasonable assumption since MoE checkpoints are typically above the size limit that prevents it being saved into a single checkpoint filed. - only supports the *dropless sparse* MLPs in the megablocks package; the other variations like non-dropless and grouped computes are not currently integrated. - the `shard_moe` may not scale well with larger models as the current implementation `torch.concat` all the expert weights together before passing to `torch.distributed` to be sharded. This is redundently done in all devices, so it is inefficient. - currently only supports `StateDictType.SHARDED_STATE_DICT` because the implementation uses `DTensors` which have limited support for full state dicts. However for efficiency considerations, sharded state dicts are the most efficient. -- currently may not support *mixed precision* properly; need to ascertain more clearly how the sharded `DTensors` are upcasted in the optimizer (if at all). 
### Megablocks Dependencies diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py index d7204335..046efb73 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py @@ -151,7 +151,13 @@ def model_loader(self, model_name: str, **kwargs): shared_mesh_dim=self._shard_along_dp, router_name=self._gate_module_name, expert_name=self._experts_module_name, + mixed_precision=False, # Currently this is hardcoded to OFF ) + + # NOTE: there is currently no good way to get the mixed precision + # flag from train_args. It will be better to handle this if + # when we move the sharding to augmentation. + # NOTE: Currently, it is a bit troublesome to pass the device_mesh to # the FSDP constructor, so we do not do that. # - therefore FSDP will always shard on world_size over the default process diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py index 07971fc7..a5f9fd0a 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py @@ -20,6 +20,7 @@ import json import os import re +import warnings # Third Party from accelerate import init_empty_weights @@ -175,6 +176,7 @@ def load_sharded_experts_onto_device( device_mesh: DeviceMesh, placements: Placement, expert_name: str = "experts", # e.g., named "experts" within block_sparse_moe + mixed_precision: bool = False, ): # typically they all should be same file, but to play safe, load the checkpoint file onto # cpu first since we may not need all weights in that file. @@ -191,6 +193,7 @@ def load_sharded_experts_onto_device( # go by one weight at a time. # - weight_name: points to megablocks.dmoe + upcasted = set() for weight_name, vs in checkpoint_metadata.items(): data = [] for k, fi in vs: @@ -204,11 +207,18 @@ def load_sharded_experts_onto_device( name = weight_name.split(".") path, name = ".".join(name[:-1]), name[-1] mod = dmoe.get_submodule(path) - mod_dtype = getattr(mod, name).dtype + + # if mixed_precision and KEY_DMOE_ROUTER not in weight_name: + if mixed_precision: + mod_dtype = torch.float32 + upcasted.add(weight_name) + else: + mod_dtype = getattr(mod, name).dtype # the megablocks dmoe experts the expert features to be on DIM_EXPERT. # - concat on dim 0 and distribute # - cast to the correct dtype for the module + # - if mixed precision is enabled, then sharded params are cased param = torch.concat(data, dim=DIM_EXPERT).to(mod_dtype) _placements = placements @@ -223,6 +233,9 @@ def load_sharded_experts_onto_device( # register the sharded parameter onto the megablocks.dmoe mod.register_parameter(name, param) + upcasted = ", ".join(sorted(upcasted)) + warnings.warn(f"Mixed precision turned on, upcasted MoE parameters: {upcasted}") + def shard_moe( model: torch.nn.Module, @@ -238,6 +251,7 @@ def shard_moe( expert_name: str = "experts", shared_mesh_dim: bool = True, ep_size: int = 1, + mixed_precision: bool = False, ): """shard_moe takes a mixture-of-experts huggingface model and shards the experts on the current device. All layers layers that have a MoE module will be sharded. 
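A condensed sketch of the per-weight dtype decision that PATCH 20 adds to `load_sharded_experts_onto_device`: under mixed precision the concatenated expert weight is upcast to fp32 (and remembered for the warning), otherwise it keeps the module's dtype, and either way it is distributed along the expert dimension. The function below is a simplification with assumed arguments, and DTensor import paths vary slightly across torch versions:

```
# Sketch only: the checkpoint walking and placement handling of the real
# helper are stubbed out; `data` is the list of checkpoint shards for one
# expert weight and `mod`/`name` point at the target attribute on the dMoE.
import torch
from torch.distributed._tensor import Shard, distribute_tensor

DIM_EXPERT = 0  # experts are concatenated and sharded along dim 0


def shard_one_expert_weight(
    mod: torch.nn.Module,
    name: str,
    data,                 # list of torch.Tensor shards loaded from the checkpoint
    device_mesh,          # DeviceMesh built by shard_moe
    mixed_precision: bool,
    upcasted: set,
    weight_name: str,
):
    if mixed_precision:
        # keep an fp32 copy of the sharded expert weight, tracked for the warning
        mod_dtype = torch.float32
        upcasted.add(weight_name)
    else:
        mod_dtype = getattr(mod, name).dtype

    # concatenate the per-file shards, cast, then shard along the expert dim
    full = torch.concat(data, dim=DIM_EXPERT).to(mod_dtype)
    param = torch.nn.Parameter(distribute_tensor(full, device_mesh, [Shard(DIM_EXPERT)]))
    mod.register_parameter(name, param)
```

PATCH 21 below additionally threads the original `requires_grad` flag through the re-created parameter.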
@@ -272,6 +286,7 @@ def shard_moe( expert_name (str): module name of the experts in moe_cls (e.g., "experts"). shared_mesh_dim (bool): for the sharding mode, see explanation above. ep_size (int): for shard_mesh_dim=False only, see explanation above. + mixed_precision (bool): activate mixed precision and upcasts sharded params """ # guarded import @@ -389,7 +404,13 @@ def shard_moe( mp_dmoe = dmoe.dMoE(_args) # drop in replacement for now load_sharded_experts_onto_device( - mp_dmoe, loc, checkpoint_metadata, device_mesh, placements, expert_name + mp_dmoe, + loc, + checkpoint_metadata, + device_mesh, + placements, + expert_name, + mixed_precision, ) parent = model.get_submodule(prefix) setattr(parent, module_name, mp_dmoe) From 7a6cdd85da9879caca048b76b18554527d7b371d Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Thu, 29 Aug 2024 07:18:28 +0000 Subject: [PATCH 21/21] handle requires_grad in shard_moe Signed-off-by: Yu Chin Fabian Lim --- .../framework_plugin_megablocks.py | 2 +- .../megablocks_utils/shard_moe_utils.py | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py index 046efb73..f9dcc60a 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_megablocks.py @@ -151,7 +151,7 @@ def model_loader(self, model_name: str, **kwargs): shared_mesh_dim=self._shard_along_dp, router_name=self._gate_module_name, expert_name=self._experts_module_name, - mixed_precision=False, # Currently this is hardcoded to OFF + mixed_precision=False, # Currently this is hardcoded to OFF ) # NOTE: there is currently no good way to get the mixed precision diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py index a5f9fd0a..8ef95fa2 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py @@ -215,6 +215,8 @@ def load_sharded_experts_onto_device( else: mod_dtype = getattr(mod, name).dtype + requires_grad = getattr(mod, name).requires_grad + # the megablocks dmoe experts the expert features to be on DIM_EXPERT. # - concat on dim 0 and distribute # - cast to the correct dtype for the module @@ -227,14 +229,16 @@ def load_sharded_experts_onto_device( _placements = [Replicate() for _ in range(len(placements))] param = torch.nn.Parameter( - distribute_tensor(param, device_mesh, _placements) + distribute_tensor(param, device_mesh, _placements), + requires_grad=requires_grad, ) # register the sharded parameter onto the megablocks.dmoe mod.register_parameter(name, param) - upcasted = ", ".join(sorted(upcasted)) - warnings.warn(f"Mixed precision turned on, upcasted MoE parameters: {upcasted}") + if mixed_precision: + upcasted = ", ".join(sorted(upcasted)) + warnings.warn(f"Mixed precision turned on, upcasted MoE parameters: {upcasted}") def shard_moe(
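The last patch re-wraps the sharded `DTensor` with the original parameter's `requires_grad`, since `torch.nn.Parameter` defaults to `requires_grad=True` and would otherwise silently unfreeze weights meant to stay frozen; it also emits the upcast warning only when mixed precision is actually on. A small sketch of the wrapping step, with the same assumed arguments as above:

```
# Sketch: preserve trainability when swapping a plain parameter for its
# DTensor-sharded counterpart.
import torch
from torch.distributed._tensor import distribute_tensor


def register_sharded_param(mod, name, full_weight, device_mesh, placements):
    original = getattr(mod, name)
    param = torch.nn.Parameter(
        distribute_tensor(full_weight, device_mesh, placements),
        # nn.Parameter defaults to requires_grad=True; keep the original flag
        # so frozen expert weights stay frozen after sharding
        requires_grad=original.requires_grad,
    )
    mod.register_parameter(name, param)
```

In the plugin itself the upcast warning is then raised once, after all weights have been processed, and only inside an `if mixed_precision:` guard.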