Mixture of Experts Training with Acceleration Library Plugin #69

Closed · wants to merge 21 commits

Conversation

fabianlim (Contributor)

@fabianlim commented Aug 19, 2024

This PR adds a plug-in for mixture-of-experts (MoE) training, combining FSDP with expert parallel, where the latter is borrowed from Databricks' megablocks.

It implements the FSDP1 version of expert parallel from https://github.com/foundation-model-stack/moe-distributed-training.

What is Expert Parallel?

Expert parallel is a form of model parallelism that applies to mixture-of-experts models.

  • It shards the experts (typically onto different devices) so that they run in parallel, improving throughput.
  • It requires a method to shuffle routed tokens to the devices where the corresponding experts reside (the all-to-all primitive); see the sketch after this list.
  • It achieves significant speedups when the experts are balanced.
  • Within-GPU expert parallel: on A100, async mode is not available within a GPU, so using expert parallel within a single GPU requires a dropless implementation to avoid padding or dropping tokens when managing multiple experts on one device. On H100, this can be avoided by using a new kernel for GEMM.
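
For illustration only, here is a minimal sketch of the all-to-all token dispatch; this is not the plugin's actual implementation, and the function name `dispatch_tokens` and the tensor layout are assumptions. Each rank sends the tokens routed to remote experts and receives the tokens destined for its local experts:

```python
# Minimal sketch of expert-parallel token dispatch using the all-to-all
# primitive (torch.distributed.all_to_all_single). NOT the plugin's code;
# function name, tensor layout, and shapes are illustrative assumptions.
import torch
import torch.distributed as dist

def dispatch_tokens(tokens: torch.Tensor, send_counts: torch.Tensor) -> torch.Tensor:
    """tokens: (num_tokens, hidden), already sorted by destination rank.
    send_counts: int64 tensor of length world_size, tokens per destination rank."""
    assert send_counts.numel() == dist.get_world_size()

    # 1. Exchange split sizes so every rank knows how many tokens it will receive.
    recv_counts = torch.empty_like(send_counts)
    dist.all_to_all_single(recv_counts, send_counts)

    # 2. Exchange the token payloads according to those splits.
    received = tokens.new_empty((int(recv_counts.sum()), tokens.shape[-1]))
    dist.all_to_all_single(
        received,
        tokens,
        output_split_sizes=recv_counts.tolist(),
        input_split_sizes=send_counts.tolist(),
    )
    # The received tokens belong to experts that live on this rank; after the
    # expert computation, a second all-to-all returns the results.
    return received
```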

[Diagram: Data Parallel (e.g., FSDP) vs Expert Parallel]

Performance

Benchmark Results

Full-Finetuning

  • single node of A100-80GB GPUs (8 devices)
  • effective batch size of 128 (8 devices × per-device batch size of 1 × gradient accumulation of 16)
  • Alpaca instruction-formatted dataset, no packing
  • torch.bfloat16 training, no mixed precision
| Model | GPUs | Type | mem_peak | mem_alloc | train_runtime | Mem improvement | Runtime improvement |
| --- | --- | --- | --- | --- | --- | --- | --- |
| Mixtral-8x7B-Instruct-v0.1 | 8 | FSDP-only | 54.8 GB | 44.0 GB | 4019 s | baseline | baseline |
| Mixtral-8x7B-Instruct-v0.1 | 8 | our plugin | 45.5 GB | 33.5 GB | 996 s | 33 % | 4.00 x |

NOTE: the train runtimes were collected with --skip_memory_metrics=True (the Hugging Face default); setting it to False was only used to collect the memory numbers, as it is known to result in worse runtime measurements.

NOTE: throughput numbers were 83 and 337 tokens per second, respectively.

Checkpoint Resumption

Checkpointing works, as evidenced by correct training resumption behavior:

  • This is made to work by fixing issues in accelerate.utils.fsdp_utils.
  • These utilities have to be updated to the newer versions of the API to properly support DTensor sharding (see the sketch below).
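
For reference, a minimal sketch of the DTensor-aware save/load pattern provided by the newer torch.distributed.checkpoint (dcp) APIs; this is not the actual patch to accelerate.utils.fsdp_utils, and it assumes a recent PyTorch release:

```python
# Minimal sketch of DTensor-aware checkpointing with torch.distributed.checkpoint
# (dcp); NOT the actual fix applied to accelerate.utils.fsdp_utils.
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    set_model_state_dict,
)

def save_sharded(model, path: str) -> None:
    # Each rank writes only its own (DTensor) shards into the checkpoint dir.
    state_dict = get_model_state_dict(model)
    dcp.save(state_dict, checkpoint_id=path)

def load_sharded(model, path: str) -> None:
    # Re-populate the sharded state dict in place, then load it into the
    # already-wrapped model so training resumes from the saved step.
    state_dict = get_model_state_dict(model)
    dcp.load(state_dict, checkpoint_id=path)
    set_model_state_dict(model, state_dict)
```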

Next steps

  • Extend the plug-in to enable multi-GPU MoE training using PEFT.
  • Enable within-GPU expert parallel (for small MoEs and/or single-GPU MoE training using PEFT) using a dropless implementation / sparse GEMM.

Implementation Details

Comparison with DeepSpeed MoE (DS-MoE)

DeepSpeed also supports mixture-of-experts sharding. Some points of comparison:

  • DS-MoE does not support ZeRO Stage 3 when also sharding MoE; this means that the non-MoE parameters cannot be parameter-sharded.
  • DS-MoE uses a similar all-to-all primitive to distribute tokens to experts.
  • DS-MoE requires custom parameter preparation by calling the DeepSpeed function split_params_into_different_moe_groups_for_optimizer (see below). This call is not integrated into accelerate's _prepare_deepspeed function.
```python
def create_moe_param_groups(model):
    # DeepSpeed requires MoE parameters to be placed in separate optimizer
    # groups; this helper is provided by DeepSpeed itself.
    from deepspeed.moe.utils import (
        split_params_into_different_moe_groups_for_optimizer,
    )

    parameters = {"params": list(model.parameters()), "name": "parameters"}
    return split_params_into_different_moe_groups_for_optimizer(parameters)


optimizer_grouped_parameters = create_moe_param_groups(opt_model)
```

Updates to benchmark.py

We now also:

  • allow an empty framework_config entry, so that a scenario can include the "no acceleration" case in the matrix (see the example below);
  • add a slow tag; if true, the scenario is ignored in unfiltered runs;
  • add an accelerator-config.json to pass arguments to Accelerator, for example to set the sync_each_batch flag (a sketch follows the scenario example below).
```yaml
  - name: accelerated-moe-megablocks
    framework_config:
        - # without acceleration. <- NEW
        - moe-megablocks
    slow: True # <- NEW: will be ignored in unfiltered runs
    arguments:
        learning_rate: 5e-5
        torch_dtype: bfloat16
        accelerator_config: scripts/benchmarks/accelerator-config.json
        gradient_accumulation_steps: 16
        logging_steps: 1
        packing: False
        adam_epsilon: 1e-8
        model_name_or_path:
            - 'mistralai/Mixtral-8x7B-Instruct-v0.1'
```
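
The contents of accelerator-config.json are not shown in this description. A plausible sketch, assuming the keys accepted by Hugging Face's AcceleratorConfig (so treat the key names below as an assumption), that sets sync_each_batch through the gradient-accumulation kwargs:

```json
{
    "gradient_accumulation_kwargs": {
        "sync_each_batch": true
    }
}
```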

Checklist of items covered

  • refactored and made it easy to incorporate other HF MoE models
  • updated the configuration to incorporate different expert-parallel dimensions
  • updated the scenarios bench
  • generally works for single node; provisions put in for multi-node (but not tested)
  • handling loss balancing correctly
  • formatting and linting
  • checkpointing, ensuring that torch.distributed.checkpoint (dcp) is operating correctly

Known Issues

The torch.concat operation dominates the load_sharded_experts_onto_device function.

  • It is very inefficient that, on every device, we concatenate all the expert weights and then pass the result to torch.distributed to shard it.
  • This may not scale well to larger numbers of experts in the future; see the profile and the sketch below.
```
Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    3.171    3.171  588.308  588.308 fms-acceleration/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py:204(shard_moe)
       32    0.219    0.007  584.905   18.278 fms-acceleration/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py:155(load_sharded_experts_onto_device)
      128  567.666    4.435  567.666    4.435 {built-in method torch.concat}
       96    0.013    0.000   14.993    0.156 /workspace/mb/lib/python3.10/site-packages/torch/distributed/_tensor/api.py:507(distribute_tensor)
```
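
A simplified sketch of the pattern the profile points at (this is not the actual shard_moe_utils code; the function and parameter names are illustrative): every rank materialises the full concatenated expert weight before distribute_tensor immediately shards it, so the concat cost is paid on every device.

```python
# Illustrative sketch of the hot path flagged by the profile above;
# NOT the actual load_sharded_experts_onto_device implementation.
import torch
from torch.distributed._tensor import Shard, distribute_tensor

def shard_expert_weights(expert_weights, device_mesh):
    """expert_weights: list of per-expert weight tensors, identical on all ranks."""
    # Every rank concatenates the full set of expert weights ...
    full = torch.concat([w.unsqueeze(0) for w in expert_weights], dim=0)
    # ... only for distribute_tensor to shard it along the expert dimension,
    # discarding most of what was just built on each device.
    return distribute_tensor(full, device_mesh, placements=[Shard(0)])
```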

@fabianlim fabianlim marked this pull request as draft August 19, 2024 06:30
@fabianlim fabianlim force-pushed the moe branch 2 times, most recently from 736d3c1 to 078e737 Compare August 19, 2024 06:47
@fabianlim fabianlim changed the title Mixture-of-Experts for Megablocks Mixture-of-Experts with Expert-Parallel using Megablocks Aug 19, 2024
@fabianlim fabianlim force-pushed the moe branch 12 times, most recently from 51499ff to b036263 Compare August 21, 2024 13:31
@fabianlim fabianlim requested a review from achew010 August 21, 2024 13:36
@fabianlim fabianlim marked this pull request as ready for review August 22, 2024 00:58
@fabianlim fabianlim force-pushed the moe branch 4 times, most recently from 75bded5 to f945f59 Compare August 23, 2024 03:54
@fabianlim fabianlim changed the title Mixture-of-Experts with Expert-Parallel using Megablocks Mixture of Experts Training with Acceleration Library Plugin Aug 27, 2024
@fabianlim fabianlim force-pushed the moe branch 2 times, most recently from 1b6d7a7 to 01fa69a Compare August 29, 2024 07:41
@fabianlim (Contributor, Author) commented:

Superseded by #99

@fabianlim fabianlim closed this Nov 13, 2024