
[Benchmark] Deepspeed +fp16/bf16 on a 8xA100 node #14913

Closed · aqred1 opened this issue Dec 23, 2021 · 4 comments

aqred1 commented Dec 23, 2021

🖥 Benchmarking transformers

Benchmark

Which part of transformers did you benchmark?
Deepspeed with template Zero 1, 2 and 3 configurations using fp16 and bf16.

  • I am by no means an expert on this; I'm just trying to find the fastest configuration for my setup. So if you see better ways to do this, please let me know.
  • I have access to more nodes, but when running multi-node, DeepSpeed somehow does not report completion percentages or time estimates. If there is a way to get those, please let me know and I'll extend the benchmark to 4 x (8xA100).

Set-up

What did you run your benchmarks on? Please include details, such as: CPU, GPU? If using multiple GPUs, which parallelization did you use?

My system:

torch: 1.10.0+cu113
transformers: 4.14.1
deepspeed: 0.5.8

The command is always:
deepspeed 5.run_clm-post.py \
    --model_name_or_path /path/to/gpt2-large/ \
    --train_file sample.txt \
    --tokenizer_name embeddings \
    --do_train --do_eval \
    --output_dir ./output \
    --evaluation_strategy steps --eval_steps 1000 --save_steps 1000 \
    --num_train_epochs 12 \
    --per_device_train_batch_size 8 \
    --cache_dir .cache2/ \
    --save_total_limit 2 \
    --dataloader_drop_last True \
    --learning_rate 1e-06

And then I add:
--deepspeed config1.json --fp16
--deepspeed config2.json --fp16
--deepspeed config3.json --fp16

--deepspeed config_2.json --fp16

Where the config files are:

config1.json:

{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto"}

config2.json:

{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },

    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto"
        }
    },

    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto"
        }
    },

    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "allgather_partitions": true,
        "allgather_bucket_size": 2e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 2e8,
        "contiguous_gradients": true
    },

    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 100,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}

config3.json:

{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },

    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto"
        }
    },

    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto"
        }
    },

    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": true
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_fp16_weights_on_model_save": true
    },

    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}

config_2.json is the same as config2.json above, but with the fp16 section replaced by:

    "bfloat16": {
         "enabled": true
    }

Results

              fp16        bf16
deepspeed 1   2.28 it/s   -
deepspeed 2   4.59 s/it   4.90 s/it
deepspeed 3   5.02 s/it   -

Somehow the units in the fp16 / deepspeed 1 case are reported in it/s; for the sake of comparison, 2.28 it/s translates to roughly 0.44 s/it. I am puzzled by the results, because I'd expect ZeRO 2 and 3 to be faster, but ZeRO 1 turned out to be around 10 times faster. So let me know if I am doing anything wrong. Also, let me know how I could extend this to multi-node, in case it is interesting for somebody else.
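As a rough back-of-the-envelope comparison in samples/sec (a sketch on my side, assuming gradient_accumulation_steps resolves to 1, since it is left at "auto" and not set on the command line):

    # Rough throughput comparison derived from the step times above.
    # Assumes gradient_accumulation_steps resolves to 1, so one optimizer step
    # processes 8 GPUs * 8 samples per device = 64 samples.
    gpus, per_device_bs = 8, 8
    samples_per_step = gpus * per_device_bs  # 64

    runs = {
        "zero1 fp16": 2.28,      # it/s
        "zero2 fp16": 1 / 4.59,  # 4.59 s/it -> it/s
        "zero2 bf16": 1 / 4.90,
        "zero3 fp16": 1 / 5.02,
    }
    for name, it_per_s in runs.items():
        print(f"{name}: {it_per_s * samples_per_step:.1f} samples/s")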

Thanks

aqred1 commented Dec 23, 2021

oh, and tagging @stas00 because this is a deepspeed "issue".

aqred1 commented Dec 23, 2021

Information about the cards:

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.29.05    Driver Version: 495.29.05    CUDA Version: 11.5     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-SXM...  On   | 00000000:0E:00.0 Off |                    0 |
| N/A   34C    P0    56W / 400W |      0MiB / 40536MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM...  On   | 00000000:13:00.0 Off |                    0 |
| N/A   33C    P0    54W / 400W |      0MiB / 40536MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   2  NVIDIA A100-SXM...  On   | 00000000:49:00.0 Off |                    0 |
| N/A   31C    P0    53W / 400W |      0MiB / 40536MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   3  NVIDIA A100-SXM...  On   | 00000000:4F:00.0 Off |                    0 |
| N/A   34C    P0    54W / 400W |      0MiB / 40536MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   4  NVIDIA A100-SXM...  On   | 00000000:90:00.0 Off |                    0 |
| N/A   34C    P0    57W / 400W |      0MiB / 40536MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   5  NVIDIA A100-SXM...  On   | 00000000:96:00.0 Off |                    0 |
| N/A   31C    P0    52W / 400W |      0MiB / 40536MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   6  NVIDIA A100-SXM...  On   | 00000000:CC:00.0 Off |                    0 |
| N/A   33C    P0    56W / 400W |      0MiB / 40536MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   7  NVIDIA A100-SXM...  On   | 00000000:D1:00.0 Off |                    0 |
| N/A   32C    P0    56W / 400W |      0MiB / 40536MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
nvidia-smi topo -m
	GPU0	GPU1	GPU2	GPU3	GPU4	GPU5	GPU6	GPU7	mlx5_0	mlx5_1	mlx5_2	mlx5_3	CPU Affinity	NUMA Affinity
GPU0	 X 	NV12	NV12	NV12	NV12	NV12	NV12	NV12	PXB	NODE	NODE	SYS	0-63	0
GPU1	NV12	 X 	NV12	NV12	NV12	NV12	NV12	NV12	PXB	NODE	NODE	SYS	0-63	0
GPU2	NV12	NV12	 X 	NV12	NV12	NV12	NV12	NV12	NODE	PXB	PXB	SYS	0-63	0
GPU3	NV12	NV12	NV12	 X 	NV12	NV12	NV12	NV12	NODE	PXB	PXB	SYS	0-63	0
GPU4	NV12	NV12	NV12	NV12	 X 	NV12	NV12	NV12	SYS	SYS	SYS	NODE	64-127	1
GPU5	NV12	NV12	NV12	NV12	NV12	 X 	NV12	NV12	SYS	SYS	SYS	NODE	64-127	1
GPU6	NV12	NV12	NV12	NV12	NV12	NV12	 X 	NV12	SYS	SYS	SYS	PXB	64-127	1
GPU7	NV12	NV12	NV12	NV12	NV12	NV12	NV12	 X 	SYS	SYS	SYS	PXB	64-127	1
mlx5_0	PXB	PXB	NODE	NODE	SYS	SYS	SYS	SYS	 X 	NODE	NODE	SYS		
mlx5_1	NODE	NODE	PXB	PXB	SYS	SYS	SYS	SYS	NODE	 X 	PIX	SYS		
mlx5_2	NODE	NODE	PXB	PXB	SYS	SYS	SYS	SYS	NODE	PIX	 X 	SYS		
mlx5_3	SYS	SYS	SYS	SYS	NODE	NODE	PXB	PXB	SYS	SYS	SYS	 X 		


stas00 commented Dec 23, 2021

You need to understand how the ZeRO stages work and how fast they are relative to each other:

Z1: fastest - only shards optimizer states
Z2: fast - shards optimizer states + gradients
Z3: slowest - has to shard optimizer states + gradients + params

i.e., the more sharding a stage has to do, the slower it becomes, since it has to communicate a lot more data between processes.

and of course:

Z0: super fast - no ZeRO, no sharding, the fastest of them all.

You choose which stage to use depending on your model's size. If you can fit it with a desirable batch size on Z0, use that; if you can't, try Z1 next, then Z2, and only if Z2 is not enough do you use Z3.

Again, Z0 is essentially no deepspeed (no sharding at all).

And going in reverse, Z3 -> Z2 -> Z1 -> Z0, your memory requirements grow; see:
https://deepspeed.readthedocs.io/en/stable/memory.html
for other options that save memory further, beyond Z3.

So it's a trade-off between memory and speed.
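If it helps to pick a stage up front, here is a minimal sketch of using the memory estimators documented at the link above (assuming your local gpt2-large path; the exact import paths may differ between deepspeed versions):

    # Estimate per-GPU / per-node memory needs of the model states for ZeRO-2 and ZeRO-3,
    # with and without CPU offload. Import paths may vary across deepspeed versions.
    from transformers import AutoModelForCausalLM
    from deepspeed.runtime.zero.stage_1_and_2 import estimate_zero2_model_states_mem_needs_all_live
    from deepspeed.runtime.zero.stage3 import estimate_zero3_model_states_mem_needs_all_live

    model = AutoModelForCausalLM.from_pretrained("/path/to/gpt2-large/")

    estimate_zero2_model_states_mem_needs_all_live(model, num_gpus_per_node=8, num_nodes=1)
    estimate_zero3_model_states_mem_needs_all_live(model, num_gpus_per_node=8, num_nodes=1)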


Somehow the units in the fp16 / deepspeed 1 case are reported in it/s

I'm not sure what you mean, perhaps paste the metrics you're referring to?

e.g. a sample output from HF Trainer:

***** train metrics *****
  epoch                    =        1.0
  train_loss               =      2.418
  train_runtime            = 0:01:20.80
  train_samples            =       2500
  train_samples_per_second =      30.94
  train_steps_per_second   =      3.874

For benchmark I think samples/sec is the most interesting and consistent, but of course others are fine as well.

e.g. see #14608

Also, let me know how I could extend this to multi-node

What do you mean by extending this to multi-node? It should just work. And if it doesn't, please let us know what specifically doesn't work.

Additionally, for multi-node benchmark reports please specify the type of interconnect - InfiniBand, OPA, etc. - as these make a big difference.
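For reference, a rough sketch of what a 4-node launch could look like with the stock deepspeed launcher (hostnames are placeholders; adjust to your cluster and check the launcher docs):

hostfile:

    node-0 slots=8
    node-1 slots=8
    node-2 slots=8
    node-3 slots=8

launch command (same training flags as in the single-node run above):

    deepspeed --hostfile=hostfile --num_nodes=4 --num_gpus=8 5.run_clm-post.py \
        --deepspeed config2.json --fp16 ...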

github-actions bot commented

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions bot closed this as completed Feb 1, 2022