Skip to content

Commit

Permalink
Add FSDP and DeepSpeed training example (#10981)
Browse files Browse the repository at this point in the history
  • Loading branch information
hhaAndroid authored Oct 11, 2023
1 parent 4d77feb commit 0230867
Show file tree
Hide file tree
Showing 6 changed files with 253 additions and 0 deletions.
75 changes: 75 additions & 0 deletions projects/example_largemodel/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Vision Large Model Example

The project is used to explore how to successfully train relatively large visual models on consumer-level graphics cards.

Although the visual model does not have such an exaggerated number of parameters as LLM, even the commonly used models with Swin Large as the backbone need to be trained successfully on A100, which undoubtedly hinders users' exploration and experiments on visual large models. Therefore, this project will explore how to train visual large models on 3090 and even smaller graphics cards with 24G or less memory.

The project mainly involves training technologies such as `FSDP`, `DeepSpeed` and `ColossalAI` commonly used in large model training.

The project will be continuously updated and improved. If you have better exploration and suggestions, you are also welcome to submit a PR

## requirements

```text
mmengine >=0.9.0 # Example 1
deepspeed # Example 2
fairscale # Example 2
```

## Example 1: Train `dino-5scale_swin-l_fsdp_8xb2-12e_coco.py` with 8 24G 3090 GPUs and FSDP

```bash
cd mmdetection
./tools/dist_train.sh projects/example_largemodel/dino-5scale_swin-l_fsdp_8xb2-12e_coco.py 8
./tools/dist_train.sh projects/example_largemodel/dino-5scale_swin-l_fsdp_8xb2-12e_coco.py 8 --amp
```

| ID | AMP | GC of Backbone | GC of Encoder | FSDP | Peak Mem (GB) | Iter Time (s) |
| :-: | :-: | :------------: | :-----------: | :--: | :-----------: | :-----------: |
| 1 | | | | | 49 (A100) | 0.9 |
| 2 || | | | 39 (A100) | 1.2 |
| 3 | || | | 33 (A100) | 1.1 |
| 4 ||| | | 25 (A100) | 1.3 |
| 5 | ||| | 18 | 2.2 |
| 6 |||| | 13 | 1.6 |
| 7 | |||| 14 | 2.9 |
| 8 ||||| 8.5 | 2.4 |

- AMP: Automatic Mixed Precision
- GC: Gradient/Activation checkpointing
- FSDP: ZeRO-3 with Activation Checkpointing ZeRO-3
- Iter Time: Total training time for one iteration

From the above analysis, it can be seen that:

1. By combining FSDP with AMP and GC techniques, the initial 49GB of GPU memory can be reduced to 8.5GB, but it comes at the cost of a 1.7x increase in training time.
2. In object detection visual models, the largest memory consumption is due to activation values, rather than optimizer states, which is different from LLM. Therefore, users should prefer gradient checkpoints over FSDP.
3. If gradient checkpoints are not enabled and only FSDP is used, out-of-memory (OOM) errors can still occur, even with more fine-grained parameter splitting strategies.
4. While AMP can significantly reduce memory usage, some algorithms may experience a decrease in precision when using AMP, whereas FSDP does not exhibit this issue.

## Example 2: Train `dino-5scale_swin-l_deepspeed_8xb2-12e_coco.py` with 8 24G 3090 GPUs and DeepSpeed

```bash
cd mmdetection
./tools/dist_train.sh projects/example_largemodel/dino-5scale_swin-l_deepspeed_8xb2-12e_coco.py 8
```

It is a pity that this is still a failed case so far, because the gradient will always overflow, resulting in very low accuracy.

| ID | AMP | GC of Backbone | GC of Encoder | DeepSpeed | Peak Mem (GB) | Iter Time (s) |
| :-: | :-: | :------------: | :-----------: | :-------: | :-----------: | :-----------: |
| 1 | | | | | 49 (A100) | 0.9 |
| 2 || | | | 39 (A100) | 1.2 |
| 3 ||| | | 25 (A100) | 1.3 |
| 4 ||| || 10.5 | 1.5 |
| 5 |||| | 13 | 1.6 |
| 6 ||||| 5.0 | 1.4 |

From the above analysis, it can be seen that:

1. DeepSpeed has greatly improved usability compared to FSDP. Gradient checkpointing can be done using the native torch functionality without the need for custom modifications, and there is no need for the `auto_wrap_policy` parameter that needs to be set by the user.
2. The DeepSpeed ZeRO series requires the use of FP16 mode and utilizes NVIDIA's Apex package. It uses Apex's AMP O2 mode, which requires code modifications. However, the O2 mode uses a significant amount of FP16 computation, which prevents DINO algorithm from training properly. But this mode can significantly save GPU memory and provides more thorough type conversion compared to torch's official AMP.

From the above analysis, it can be concluded that if DeepSpeed can successfully train the DINO model without reduce performance, it will have a significant advantage over FSDP. If you have a deep understanding of DeepSpeed and Apex and are interested in troubleshooting accuracy issues, your feedback or PR is welcome.

As mentioned earlier, due to the specific nature of Apex AMP O2, the current version of MMDetection cannot train the DINO model. Considering this as a failed case, the modified code has been placed in the [dino_deepspeed branch](https://github.com/hhaAndroid/mmdetection/tree/dino_deepspeed). The corresponding modifications can be seen in this [commit](https://github.com/hhaAndroid/mmdetection/commit/0c825ae38e2cee3d11a20c5c4adf24ee682d0a55). If you are interested, you can pull this branch and experiment with it.
75 changes: 75 additions & 0 deletions projects/example_largemodel/README_zh-CN.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# 视觉大模型实践案例

本工程用于探索如何在消费级显卡上成功训练相对大的视觉模型。

虽然视觉模型并没有像 LLM 那样有极其夸张的参数量,但是即使常用的以 Swin Large 为 backbone 的模型,都需要在 A100 上才能成功训练,这无疑阻碍了用户在视觉大模型上的探索和实验。因此本工程将探索在 3090 等 24G 甚至更小显存的消级显卡上如何训练视觉大模型。

本工程主要涉及到的训练技术有 `FSDP``DeepSpeed``ColossalAI` 等常用大模型训练技术。

本工程将不断更新完善,如果你有比较好的探索和意见,也非常欢迎提 PR

## 依赖

```text
mmengine >=0.9.0 # 案例 1
deepspeed # 案例 2
fairscale # 案例 2
```

## 案例 1: 采用 8 张 24G 3090 显卡结合 FSDP 训练 `dino-5scale_swin-l_fsdp_8xb2-12e_coco.py`

```bash
cd mmdetection
./tools/dist_train.sh projects/example_largemodel/dino-5scale_swin-l_fsdp_8xb2-12e_coco.py 8
./tools/dist_train.sh projects/example_largemodel/dino-5scale_swin-l_fsdp_8xb2-12e_coco.py 8 --amp
```

| ID | AMP | GC of Backbone | GC of Encoder | FSDP | Peak Mem (GB) | Iter Time (s) |
| :-: | :-: | :------------: | :-----------: | :--: | :-----------: | :-----------: |
| 1 | | | | | 49 (A100) | 0.9 |
| 2 || | | | 39 (A100) | 1.2 |
| 3 | || | | 33 (A100) | 1.1 |
| 4 ||| | | 25 (A100) | 1.3 |
| 5 | ||| | 18 | 2.2 |
| 6 |||| | 13 | 1.6 |
| 7 | |||| 14 | 2.9 |
| 8 ||||| 8.5 | 2.4 |

- AMP: 混合精度训练
- GC: 梯度/激活值检查点
- FSDP: ZeRO-3 结合梯度检查点
- Iter Time: 一次迭代训练总时间

从上表可以看出:

1. 采用 FSDP 结合 AMP 和 GC 技术,可以将最初的 49G 显存降低为 8.5G,但是会增加 1.7 倍训练时间
2. 在目标检测视觉模型中,占据最大显存的是激活值,而不是优化器状态,这和 LLM 不同,因此用户应该首选梯度检查点,而不是 FSDP
3. 如果不开启梯度检查点,仅开启 FSDP 的话依然会 OOM,即使尝试了更加细致的参数切分策略
4. 虽然 AMP 可以减少不少显存,但是有些算法使用 AMP 会导致精度下降而 FSDP 不会

## 案例 2: 采用 8 张 24G 3090 显卡结合 DeepSpeed 训练 `dino-5scale_swin-l_deepspeed_8xb2-12e_coco.py`

```bash
cd mmdetection
./tools/dist_train.sh projects/example_largemodel/dino-5scale_swin-l_deepspeed_8xb2-12e_coco.py 8
```

很遗憾,到目前为止这依然是一个失败的案例,因为梯度始终会溢出导致精度很低。

| ID | AMP | GC of Backbone | GC of Encoder | DeepSpeed | Peak Mem (GB) | Iter Time (s) |
| :-: | :-: | :------------: | :-----------: | :-------: | :-----------: | :-----------: |
| 1 | | | | | 49 (A100) | 0.9 |
| 2 || | | | 39 (A100) | 1.2 |
| 3 ||| | | 25 (A100) | 1.3 |
| 4 ||| || 10.5 | 1.5 |
| 5 |||| | 13 | 1.6 |
| 6 ||||| 5.0 | 1.4 |

从上表可以看出:

1. DeepSpeed 易用性上相比于 FSDP 有很大提升,因为梯度检查点可以用 torch 原生的而不需要修改特殊定制,同时也没有 `auto_wrap_policy` 这个需要用户自行设置的参数
2. DeepSpeed ZeRO 系列必须要采用 FP16 模式,其底层是采用了 NVIDIA’s Apex package, 其使用 Apex 的 AMP O2 模式,这导致需要修改代码,并且 O2 模式采用大量 FP16 计算导致 DINO 算法无法正常训练,但是它的这种模式可以显著节省显存,相比于 torch 官方的 AMP,类型转换更加彻底

从上述分析可知,如果 DeepSpeed 能够在不降低性能情况下成功训练 DINO 模型,那么其将比 FSDP 具备比较大的优势。如果您对 DeepSpeed 和 Apex 有比较深入的了解同时有兴趣排查精度问题,欢迎反馈或者提 PR

前面说过由于 Apex AMP O2 的特殊性,目前的 MMDetection 无法训练 DINO 模型,考虑到这是一个失败的案例,因此将修改的代码放在了 https://github.com/hhaAndroid/mmdetection/tree/dino_deepspeed 分支,其对应修改见 [commit](https://github.com/hhaAndroid/mmdetection/commit/0c825ae38e2cee3d11a20c5c4adf24ee682d0a55)。如果您有兴趣尝试,可以拉取该分支进行试验。
3 changes: 3 additions & 0 deletions projects/example_largemodel/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .fsdp_utils import checkpoint_check_fn, layer_auto_wrap_policy

__all__ = ['checkpoint_check_fn', 'layer_auto_wrap_policy']
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from mmengine.config import read_base

with read_base():
from mmdet.configs.dino.dino_5scale_swin_l_8xb2_12e_coco import * # noqa

model.update(dict(encoder=dict(num_cp=6))) # noqa

runner_type = 'FlexibleRunner'
strategy = dict(
type='DeepSpeedStrategy',
gradient_clipping=0.1,
fp16=dict(
enabled=True,
fp16_master_weights_and_grads=False,
loss_scale=0,
loss_scale_window=500,
hysteresis=2,
min_loss_scale=1,
initial_scale_power=15,
),
inputs_to_half=['inputs'],
zero_optimization=dict(
stage=3,
allgather_partitions=True,
reduce_scatter=True,
allgather_bucket_size=50000000,
reduce_bucket_size=50000000,
overlap_comm=True,
contiguous_gradients=True,
cpu_offload=False),
)

optim_wrapper = dict(
type='DeepSpeedOptimWrapper',
optimizer=dict(
type='AdamW',
lr=0.0001, # 0.0002 for DeformDETR
weight_decay=0.0001),
# clip_grad=dict(max_norm=0.1, norm_type=2),
paramwise_cfg=dict(custom_keys={'backbone': dict(lr_mult=0.1)}))

# To debug
default_hooks.update(dict(logger=dict(interval=1))) # noqa
log_processor.update(dict(window_size=1)) # noqa
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from mmengine.config import read_base

with read_base():
from mmdet.configs.dino.dino_5scale_swin_l_8xb2_12e_coco import * # noqa

from projects.example_largemodel import (checkpoint_check_fn,
layer_auto_wrap_policy)

# The checkpoint needs to be controlled by the checkpoint_check_fn.
model.update(dict(backbone=dict(with_cp=False))) # noqa

# TODO: The new version of configs does not support passing a module list,
# so for now, it can only be hard-coded. We will fix this issue in the future.
runner_type = 'FlexibleRunner'
strategy = dict(
type='FSDPStrategy',
activation_checkpointing=dict(check_fn=checkpoint_check_fn),
model_wrapper=dict(auto_wrap_policy=dict(type=layer_auto_wrap_policy)))
38 changes: 38 additions & 0 deletions projects/example_largemodel/fsdp_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing import Sequence, Union

import torch.nn as nn

from mmdet.models.backbones.swin import SwinBlock
from mmdet.models.layers.transformer.deformable_detr_layers import \
DeformableDetrTransformerEncoderLayer


# TODO: The new version of configs does not support passing a module list,
# so for now, it can only be hard-coded. We will fix this issue in the future.
def layer_auto_wrap_policy(
module,
recurse: bool,
nonwrapped_numel: int,
layer_cls: Union[nn.Module, Sequence[nn.Module]] = (
SwinBlock, DeformableDetrTransformerEncoderLayer),
) -> bool:
if recurse:
# always recurse
return True
else:
# if not recursing, decide whether we should wrap for
# the leaf node or reminder
return isinstance(module, tuple(layer_cls))


def checkpoint_check_fn(submodule,
layer_cls: Union[nn.Module, Sequence[nn.Module]] = (
SwinBlock, DeformableDetrTransformerEncoderLayer)):
return isinstance(submodule, tuple(layer_cls))


# non_reentrant_wrapper = partial(
# checkpoint_wrapper,
# offload_to_cpu=False,
# checkpoint_impl=CheckpointImpl.NO_REENTRANT,
# )

0 comments on commit 0230867

Please sign in to comment.