[Benchmark] HF Trainer on RTX-3090 #14608

Open
stas00 opened this issue Dec 3, 2021 · 13 comments
stas00 commented Dec 3, 2021

🖥 Benchmarking transformers w/ HF Trainer on RTX-3090

We are going to use a special benchmarking tool that will do all the work for us. #14934

This is the index post; the specific benchmarks are in their own posts below:

  1. fp16 vs bf16 vs tf32 vs fp32
  2. gradient accumulation steps
  3. gradient checkpointing
  4. batch size
  5. optimizers
  6. combining winning strategies (~2x speed improvement!)
  7. RTX-3090 vs A100

See also the same benchmarks for A100

TODO:

  • other suggestions?

Note that each benchmark was run only once, so multiple runs with averaging would probably give slightly different results. The purpose here is to get a rough sense of the relative differences, not exact numbers.

stas00 changed the title from [Benchmark] tf32 vs fp16 vs fp32 to [Benchmark] tf32 vs fp16 vs bf16 vs fp32 on Dec 3, 2021
stas00 commented Dec 26, 2021

Diagnostics of T5 slowness under TF32

By Eddie Yan (re-pasted from torch slack)

Here is update/summary after investigating the T5 model further:

The lack of speedup in TF32 can be attributed to bottlenecks in non-GEMM ops (e.g., pointwise ops in custom unfused LayerNorm and custom AdamW optimizer). Without optimization, we see that the custom LayerNorm is comparable in wall clock time to GEMMs due to this bottleneck (first image attachment).

Zoomed in profile on A6000 showing custom “T5LayerNorm” being very expensive compared to cuBLAS GEMM (in green):

T5-A6000-T5LayerNorm

When we zoom out in the profile, the ratio of pointwise ops in the optimizer to compute is further exacerbated by the small batch size of the large model on the 3090; this small batch size means that the GEMM compute intensity is low relative to the number of pointwise ops incurred by the optimizer, which updates every parameter in the model along with running statistics for a relatively small number of training examples. The second image attachment shows this issue, where the optimizer step wall clock time is comparable to an entire backward step (220+ ms vs. 260 ms)!

Zoomed out profile on A6000 showing expensive optimizer step (AdamW) relative to backward pass:

T5-A6000-AdamW

On A6000, a GPU comparable to 3090 in terms of architecture, we've done a study to incrementally gauge the performance of optimizations for TF32 vs. fp32. First, we replaced the custom LayerNorm implementation with PyTorch's native implementation (which despite being different should be good for a rough estimate). While the native implementation is far from optimal, this change yields 38.3 samples/s with TF32 vs. 34.3 with fp32, a ~10% speedup. Turning on gradient accumulation improves performance dramatically as the optimizer to forward-backward compute ratio is abated, but more importantly TF32 is now ~20% faster than fp32 at 90.5 samples/s to 75.1 samples/s for fp32.

Additionally, replacing the custom AdamW with a fused implementation from apex (thanks to Kevin Stephano for suggesting this) yields another small improvement: 92.4 samples/s with TF32 vs. 75.3 for fp32.

As we've seen, the lack of improvement can be attributed to a low ratio of GEMMs to pointwise ops (especially in the optimizer), so another way to improve this ratio is to increase the batch size instead of relying on gradient accumulation. This approach really shines on A100, which with 40GiB of memory allows the batch size to be increased to 64. Perhaps due to higher TF32 throughput, the speedup here is dramatic: 207.4 samples/s with TF32 vs. 73.1 samples/s for fp32, over a 2x speedup.


For reference, here is the custom "T5LayerNorm", which also ignores the input dtype and casts to float32:

def forward(self, hidden_states):
    ...

(This is an RMS norm rather than a normal layer norm, hence we don't have any fast kernels for it at the moment.)
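A rough sketch of what such an RMS-norm forward amounts to, as an unfused chain of pointwise ops (class and attribute names here are made up for illustration, this is not the verbatim transformers implementation):

```python
import torch
from torch import nn

class RMSNormSketch(nn.Module):
    """RMS norm: no mean subtraction, no bias; statistics computed in fp32."""
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # cast to fp32 regardless of the input dtype, then normalize by the RMS
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # scale; each of the ops above launches its own elementwise kernel,
        # which is why this shows up so prominently in the profile
        return self.weight * hidden_states.to(self.weight.dtype)
```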


Here is another analysis of T5, but for fp16: NVIDIA/apex#1271


There is also an attempt to address this by implementing a fused RMSNorm kernel here: NVIDIA/apex#1271

stas00 commented Dec 27, 2021

Will probably move the following elsewhere, but will save here for now:

Notes on Ampere cards performance comparison

From Christian Sarofeen (re-pasted from torch slack)

Whitepapers on Ampere chips:

  • GA100 (A100)
  • GA102 (RTX-3080, RTX-3090) (consumer grade)
  • GA104 (RTX-3070)

The following numbers are TFLOPs

RTX-3080:

  • FP32: 29.8
  • TF32: 29.8
  • FP16: 59.5 (ignoring sparsity for the moment)
  • BF16: 59.5

RTX-3090 (~1.17x more powerful than the 3080):

  • FP32: 34.87 (29.8*1.17)
  • TF32: 34.87 (29.8*1.17)
  • FP16: 69.61 (59.5*1.17)
  • BF16: 69.61 (59.5*1.17)

A100:

  • FP32: 19.5
  • TF32: 156
  • FP16: 312
  • BF16: 312

Performance comparison:

SM = Streaming multi-processors on the GPU.

GeForce 30 series

Looking at the SM count is probably the easiest way to scale. The RTX 3080 has 68 SMs and the 3090 has 82 SMs. For clock speeds it's 1440 MHz (1710 with boost) vs. 1395 MHz (1695 with boost). So the ratio of their compute is 68 * 1440 : 82 * 1395 if we just use the base clocks.
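Working that out: 68 * 1440 = 97,920 and 82 * 1395 = 114,390, a ratio of roughly 1 : 1.17, which is where the ~1.17x scaling factor applied to the 3090 numbers above comes from.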

So the 3090 has more SMs but a slightly slower clock. When you have a significantly bigger chip it's common to reduce the clock speed so that the overall power consumption stays within a set budget.

They're both GA102 chips, so more SMs equates pretty trivially to more compute. The 3070, on the other hand, is a GA104, so a comparison to it isn't as straightforward.

You can't straightforwardly compare RTX cards with the A100 the way you can within the same chip family. The Wikipedia numbers were fine for going from the 3080 to the 3090 because they're based on the same GA102 chip; the A100 is a GA100, so you can't do a simple comparison like that.

stas00 changed the title from [Benchmark] tf32 vs fp16 vs bf16 vs fp32 to [Benchmark] [Ampere] tf32 vs fp16 vs bf16 vs fp32 on Dec 27, 2021
stas00 changed the title from [Benchmark] [Ampere] tf32 vs fp16 vs bf16 vs fp32 to [Benchmark] [Ampere] fp16 vs bf16 vs tf32 vs fp32 on Dec 30, 2021
stas00 commented Jan 3, 2022

precision: fp16 vs bf16 vs tf32 vs fp32

Main interest: benchmarking the new --bf16 and --tf32 modes on Ampere/RTX-3090, compared to the fp16 and fp32 modes.

  • bf16 is autocast(dtype=torch.bfloat16) (see the sketch below)
  • tf32 is torch.backends.cuda.matmul.allow_tf32 = True
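For context, a minimal sketch of what these two switches map to in plain PyTorch (the HF Trainer sets these internally; the snippet below is just an illustration with a toy model):

```python
import torch

# --tf32 1: allow TF32 tensor cores for fp32 matmuls/convolutions (Ampere+)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# --bf16: run the forward pass under bf16 autocast
model = torch.nn.Linear(512, 512).cuda()
x = torch.randn(8, 512, device="cuda")
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    y = model(x)  # matmuls run in bf16
```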
Datetime    : 2021-12-29 16:37:16

Software:
transformers: 4.16.0.dev0
torch       : 1.10.1
cuda        : 11.3
python      : 3.8.11

Hardware:
1 GPUs      : NVIDIA GeForce RTX 3090, 23.70GB

Note: to get the best performance make sure you have 2 independent 12V PCIe rails plugged into the card and not 2 splits of the same rail.

Benchmark

The benchmark uses 3 different t5 models, and at the end of the section also gpt2. For t5 the main script is:

CUDA_VISIBLE_DEVICES=0 python \
examples/pytorch/translation/run_translation.py --model_name_or_path t5-small \
--output_dir output_dir --do_train --label_smoothing 0.1 --logging_strategy no \
--save_strategy no --per_device_train_batch_size 32 --max_source_length 512 \
--max_target_length 512 --num_train_epochs 1 --overwrite_output_dir \
--source_lang en --target_lang ro --dataset_name wmt16 --dataset_config "ro-en" \
--source_prefix "translate English to Romanian: " --warmup_steps 50 \
--max_train_samples 20000 --dataloader_num_workers 2

and now adding one of:

--tf32 0 # fp32
--tf32 0 --fp16
--tf32 0 --bf16
--tf32 1
--tf32 1 --fp16
--tf32 1 --bf16

But we are going to use a special benchmarking tool that will do all the work for us. #14934

Important notes:

  1. --tf32 0 --fp16 0 combo is just fp32 (which is the default mode - we don't have this option per se)
  2. I changed --per_device_train_batch_size in the base command from 32 (t5-small) to 16 (t5-base) to 8 (t5-large) to be able to fit into the GPU memory while keeping it as occupied as possible.
  3. I changed --max_train_samples in the base command from 20k (t5-small) to 10k (t5-base) to 5k (t5-large) to give each run about 1-3min of run time, so that the benchmark doesn't take too long but is still long enough to put strain on the card.

Benchmark 1: t5-small

CUDA_VISIBLE_DEVICES=0 python ./scripts/benchmark/trainer-benchmark.py \
--base-cmd \
' examples/pytorch/translation/run_translation.py --model_name_or_path t5-small \
--output_dir output_dir --do_train --label_smoothing 0.1 --logging_strategy no \
--save_strategy no --per_device_train_batch_size 32 --max_source_length 512 \
--max_target_length 512 --num_train_epochs 1 --overwrite_output_dir \
--source_lang en --target_lang ro --dataset_name wmt16 --dataset_config "ro-en" \
--source_prefix "translate English to Romanian: " --warmup_steps 50 \
--max_train_samples 20000 --dataloader_num_workers 2 ' \
--target-metric-key train_samples_per_second --repeat-times 1 --variations \
'|--fp16|--bf16' '--tf32 0|--tf32 1' --report-metric-keys train_loss \
--repeat-times 1 --base-variation '--tf32 0'
| Variation | Train samples per second | Diff % | Train loss |
|-----------|--------------------------|--------|------------|
| --tf32 0 | 286.07 | 0 | 2.51 |
| --tf32 1 | 342.82 | 20 | 2.51 |
| --fp16 --tf32 0 | 422.07 | 48 | 2.51 |
| --fp16 --tf32 1 | 423.18 | 48 | 2.51 |
| --bf16 --tf32 0 | 415.93 | 45 | 2.52 |
| --bf16 --tf32 1 | 418.51 | 46 | 2.52 |

Conclusions:

  • bf16 is 2-3% slower than fp16
  • tf32 makes 0% impact on bf16 and fp16 modes
  • tf32 is 20% faster than fp32, but otherwise doesn't help much with performance

Benchmark 2: t5-base

CUDA_VISIBLE_DEVICES=0 python ./scripts/benchmark/trainer-benchmark.py \
--base-cmd \
' examples/pytorch/translation/run_translation.py --model_name_or_path t5-base \
--output_dir output_dir --do_train --label_smoothing 0.1 --logging_strategy no \
--save_strategy no --per_device_train_batch_size 16 --max_source_length 512 \
--max_target_length 512 --num_train_epochs 1 --overwrite_output_dir \
--source_lang en --target_lang ro --dataset_name wmt16 --dataset_config "ro-en" \
--source_prefix "translate English to Romanian: " --warmup_steps 50 \
--max_train_samples 10000 --dataloader_num_workers 2 ' \
--target-metric-key train_samples_per_second --repeat-times 1 --variations \
'|--fp16|--bf16' '--tf32 0|--tf32 1' --report-metric-keys train_loss \
--repeat-times 1 --base-variation '--tf32 0'
| Variation | Train samples per second | Diff % | Train loss |
|-----------|--------------------------|--------|------------|
| --tf32 0 | 95.69 | 0 | 2.20 |
| --tf32 1 | 116.58 | 22 | 2.20 |
| --fp16 --tf32 0 | 131.98 | 38 | 2.20 |
| --fp16 --tf32 1 | 132.84 | 39 | 2.20 |
| --bf16 --tf32 0 | 135.47 | 42 | 2.21 |
| --bf16 --tf32 1 | 135.86 | 42 | 2.21 |

Conclusions:

  • similar to t5-small
  • but bf16 is 2-3% faster than fp16!

Benchmark 3: t5-large

CUDA_VISIBLE_DEVICES=0 python ./scripts/benchmark/trainer-benchmark.py \
--base-cmd \
' examples/pytorch/translation/run_translation.py --model_name_or_path t5-large \
--output_dir output_dir --do_train --label_smoothing 0.1 --logging_strategy no \
--save_strategy no --per_device_train_batch_size 8 --max_source_length 512 \
--max_target_length 512 --num_train_epochs 1 --overwrite_output_dir \
--source_lang en --target_lang ro --dataset_name wmt16 --dataset_config "ro-en" \
--source_prefix "translate English to Romanian: " --warmup_steps 50 \
--max_train_samples 5000 --dataloader_num_workers 2 ' \
--target-metric-key train_samples_per_second --repeat-times 1 --variations \
'|--fp16|--bf16' '--tf32 0|--tf32 1' --report-metric-keys train_loss \
--repeat-times 1 --base-variation '--tf32 0'
| Variation | Train samples per second | Diff % | Train loss |
|-----------|--------------------------|--------|------------|
| --tf32 0 | 31.88 | 0 | 2.03 |
| --tf32 1 | 35.66 | 12 | 2.03 |
| --fp16 --tf32 0 | 47.34 | 49 | 0.00 |
| --fp16 --tf32 1 | 48.08 | 51 | 0.00 |
| --bf16 --tf32 0 | 35.07 | 10 | 2.04 |
| --bf16 --tf32 1 | 35.13 | 10 | 2.04 |

Conclusions:

  • fp16 overflows here (loss=0). I originally wasn't printing the loss, so I missed this and thought I was getting a much faster outcome under fp16, but it was totally wrong. (This is a very well known issue with many bf16-pretrained models when one attempts to finetune them in fp16.)
  • tf32 makes 0% impact on bf16 mode
  • tf32 is only 12% faster than fp32

Benchmark 4: gpt2

Let's try a different architecture.

CUDA_VISIBLE_DEVICES=0 python ./scripts/benchmark/trainer-benchmark.py \
--base-cmd \
' examples/pytorch/language-modeling/run_clm.py --model_name_or_path gpt2 \
--dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 \
--logging_strategy no --save_strategy no --do_train --max_train_samples 2500 \
--per_device_train_batch_size 8 --num_train_epochs 1 --warmup_steps 8 \
--block_size 512 --report_to none ' \
--target-metric-key train_samples_per_second --repeat-times 1 --variations \
'|--fp16|--bf16' '--tf32 0|--tf32 1' --report-metric-keys train_loss \
--repeat-times 1 --base-variation '--tf32 0'
| Variation | Train samples per second | Diff % | Train loss |
|-----------|--------------------------|--------|------------|
| --tf32 0 | 26.96 | 0 | 3.36 |
| --tf32 1 | 33.43 | 24 | 3.36 |
| --fp16 --tf32 0 | 42.46 | 58 | 3.36 |
| --fp16 --tf32 1 | 42.43 | 57 | 3.36 |
| --bf16 --tf32 0 | 42.43 | 57 | 3.37 |
| --bf16 --tf32 1 | 42.42 | 57 | 3.37 |

Conclusions:

  • tf32 is still far from the suggested huge speedups, at only 24%
  • as before tf32 makes no difference for fp16/bf16
  • fp16/bf16 perform on par here and are 57% faster than fp32

Benchmark 5: gpt2-medium

and now with gpt2-medium (~3x larger than gpt2):

CUDA_VISIBLE_DEVICES=0 python ./scripts/benchmark/trainer-benchmark.py \
--base-cmd \
' examples/pytorch/language-modeling/run_clm.py --model_name_or_path gpt2-medium \
--dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 \
--logging_strategy no --save_strategy no --do_train --max_train_samples 1200 \
--per_device_train_batch_size 4 --num_train_epochs 1 --warmup_steps 8 \
--block_size 512 --report_to none ' \
--target-metric-key train_samples_per_second --repeat-times 1 --variations \
'|--fp16|--bf16' '--tf32 0|--tf32 1' --report-metric-keys train_loss \
--repeat-times 1 --base-variation '--tf32 0'
| Variation | Train samples per second | Diff % | Train loss |
|-----------|--------------------------|--------|------------|
| --tf32 0 | 9.23 | 0 | 3.02 |
| --tf32 1 | 11.48 | 24 | 3.01 |
| --fp16 --tf32 0 | 14.50 | 57 | 3.02 |
| --fp16 --tf32 1 | 14.52 | 57 | 3.02 |
| --bf16 --tf32 0 | 14.56 | 58 | 3.02 |
| --bf16 --tf32 1 | 14.55 | 58 | 3.02 |

Conclusions:

  • the % diff is the same as for the smaller gpt2 model

stas00 commented Jan 3, 2022

gradient accumulation steps

Let's choose t5-base model to test with as it's pretty large yet doesn't overflow like t5-large.

Let's measure --gradient_accumulation_steps 1,2,4,8,16 with different precision configurations.
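For reference, gradient accumulation just delays the optimizer step so that several micro-batches contribute to a single parameter update. A minimal sketch of the pattern in plain PyTorch (synthetic model and data, purely illustrative of what --gradient_accumulation_steps does inside the Trainer):

```python
import torch

model = torch.nn.Linear(512, 2).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
loss_fn = torch.nn.CrossEntropyLoss()
accumulation_steps = 8  # corresponds to --gradient_accumulation_steps 8

optimizer.zero_grad()
for step in range(64):
    x = torch.randn(4, 512, device="cuda")            # micro-batch
    y = torch.randint(0, 2, (4,), device="cuda")
    loss = loss_fn(model(x), y) / accumulation_steps  # scale so accumulated grads average out
    loss.backward()                                   # grads accumulate across micro-batches
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()                              # one optimizer update per 8 micro-batches
        optimizer.zero_grad()
```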

*** Results:

| Variation | Train samples per second | Diff % | Train loss |
|-----------|--------------------------|--------|------------|
| --gradient_accumulation_steps 1 --tf32 0 | 96.17 | 0 | 2.20 |
| --gradient_accumulation_steps 1 --tf32 1 | 116.57 | 21 | 2.20 |
| --gradient_accumulation_steps 1 --tf32 0 --fp16 | 132.64 | 38 | 2.20 |
| --gradient_accumulation_steps 1 --tf32 0 --bf16 | 136.35 | 42 | 2.21 |
| --gradient_accumulation_steps 2 --tf32 0 | 103.83 | 8 | 2.28 |
| --gradient_accumulation_steps 2 --tf32 1 | 130.11 | 35 | 2.28 |
| --gradient_accumulation_steps 2 --tf32 0 --fp16 | 153.09 | 59 | 2.28 |
| --gradient_accumulation_steps 2 --tf32 0 --bf16 | 156.70 | 63 | 2.29 |
| --gradient_accumulation_steps 4 --tf32 0 | 108.48 | 13 | 2.39 |
| --gradient_accumulation_steps 4 --tf32 1 | 137.75 | 43 | 2.40 |
| --gradient_accumulation_steps 4 --tf32 0 --fp16 | 164.48 | 71 | 2.40 |
| --gradient_accumulation_steps 4 --tf32 0 --bf16 | 170.01 | 77 | 2.42 |
| --gradient_accumulation_steps 8 --tf32 0 | 111.14 | 16 | 2.57 |
| --gradient_accumulation_steps 8 --tf32 1 | 141.59 | 47 | 2.57 |
| --gradient_accumulation_steps 8 --tf32 0 --fp16 | 170.77 | 78 | 2.57 |
| --gradient_accumulation_steps 8 --tf32 0 --bf16 | 177.59 | 85 | 2.62 |
| --gradient_accumulation_steps 16 --tf32 0 | 112.65 | 17 | 2.81 |
| --gradient_accumulation_steps 16 --tf32 1 | 143.89 | 50 | 2.81 |
| --gradient_accumulation_steps 16 --tf32 0 --fp16 | 173.69 | 81 | 2.81 |
| --gradient_accumulation_steps 16 --tf32 0 --bf16 | 181.04 | 88 | 2.86 |

Let's filter out just one subset so that it's easier to compare the gradient accumulation differences alone, re-running with just bf16 enabled (--tf32 0 --bf16):

| Variation | Train samples per second | Diff % | Train loss |
|-----------|--------------------------|--------|------------|
| --gradient_accumulation_steps 1 | 135.85 | 0 | 2.21 |
| --gradient_accumulation_steps 2 | 156.95 | 16 | 2.29 |
| --gradient_accumulation_steps 4 | 167.65 | 23 | 2.42 |
| --gradient_accumulation_steps 8 | 175.02 | 29 | 2.62 |
| --gradient_accumulation_steps 16 | 179.15 | 32 | 2.86 |

Conclusions:

  • that's a significant speedup even at just 4 steps
  • notice that the loss gets much bigger with higher accumulation steps: my benchmark is very short, and with fewer optimizer steps to take when the effective batches are larger, the model simply doesn't have a chance to step down far enough. The same can be observed with plain batch size changes.
    The non-zero lr warmup also plays a role here since it's a very short run.
*** Setup:


Datetime    : 2022-01-03 14:53:02

Software:
transformers: 4.16.0.dev0
torch       : 1.10.1
cuda        : 11.3
python      : 3.8.11

Hardware:
1 GPUs      : NVIDIA GeForce RTX 3090, 23.70GB


*** The benchmark command line was:

CUDA_VISIBLE_DEVICES=0 python ./scripts/benchmark/trainer-benchmark.py \
--base-cmd \
' examples/pytorch/translation/run_translation.py --model_name_or_path t5-base \
--output_dir output_dir --do_train --label_smoothing 0.1 --logging_strategy no \
--save_strategy no --per_device_train_batch_size 16 --max_source_length 512 \
--max_target_length 512 --num_train_epochs 1 --overwrite_output_dir \
--source_lang en --target_lang ro --dataset_name wmt16 --dataset_config "ro-en" \
--source_prefix "translate English to Romanian: " --warmup_steps 50 \
--max_train_samples 10000 --dataloader_num_workers 2 ' \
--target-metric-key train_samples_per_second --repeat-times 1 --variations \
'--gradient_accumulation_steps 1|--gradient_accumulation_steps 2|--gradient_accumulation_steps 4|--gradient_accumulation_steps 8|--gradient_accumulation_steps 16' \
'--tf32 0|--tf32 1|--tf32 0 --fp16|--tf32 0 --bf16' --report-metric-keys \
train_loss --repeat-times 1

stas00 commented Jan 3, 2022

gradient checkpointing

Let's choose t5-base model to test with as it's pretty large yet doesn't overflow like t5-large.

Let's benchmark enabling --gradient_checkpointing
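For reference, outside the Trainer this roughly corresponds to calling the model's gradient_checkpointing_enable() hook, which in turn relies on torch.utils.checkpoint: activations are not stored during the forward pass and are recomputed during backward, trading memory for extra compute. A minimal sketch with a toy block (illustrative only, not the transformers internals):

```python
import torch
from torch.utils.checkpoint import checkpoint

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(512, 2048), torch.nn.GELU(), torch.nn.Linear(2048, 512)
        )

    def forward(self, x, use_checkpoint=False):
        if use_checkpoint:
            # activations inside self.ff are not kept; they are recomputed in backward,
            # saving memory at the cost of an extra forward pass (hence the slowdown)
            return checkpoint(self.ff, x)
        return self.ff(x)

x = torch.randn(8, 512, requires_grad=True)
Block()(x, use_checkpoint=True).sum().backward()
```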

| Variation | Train samples per second | Diff % | Train loss |
|-----------|--------------------------|--------|------------|
| --gradient_checkpointing 0 | 135.82 | 24 | 2.21 |
| --gradient_checkpointing 1 | 109.24 | 0 | 2.21 |

Conclusions:

  • as expected, since gradient checkpointing recalculates the forward activations, it is slower; we get a 24% slowdown here.

Let's look at memory:

| Variation | Train samples per second | Diff % | Train loss | Train mem gpu alloc delta | Train mem gpu peaked delta |
|-----------|--------------------------|--------|------------|---------------------------|----------------------------|
| --gradient_checkpointing 0 | 63.42 | 32 | 2.17 | 2684MB | 3340MB |
| --gradient_checkpointing 1 | 47.96 | 0 | 2.17 | 2676MB | 1245MB |

We can clearly see that peak GPU memory is ~2/3 less.

note: I had to halve the BS in the 2nd benchmark as I was getting OOM; also, the memory metrics slow things down.

*** The benchmark command lines were:

1.

CUDA_VISIBLE_DEVICES=0 python ./scripts/benchmark/trainer-benchmark.py \
--base-cmd \
' examples/pytorch/translation/run_translation.py --model_name_or_path t5-base \
--output_dir output_dir --do_train --label_smoothing 0.1 --logging_strategy no \
--save_strategy no --per_device_train_batch_size 16 --max_source_length 512 \
--max_target_length 512 --num_train_epochs 1 --overwrite_output_dir \
--source_lang en --target_lang ro --dataset_name wmt16 --dataset_config "ro-en" \
--source_prefix "translate English to Romanian: " --warmup_steps 50 \
--max_train_samples 10000 --dataloader_num_workers 2 --bf16' \
--target-metric-key train_samples_per_second --repeat-times 1 --variations \
'--gradient_checkpointing 0|--gradient_checkpointing 1' --report-metric-keys \
train_loss --repeat-times 1

2.

CUDA_VISIBLE_DEVICES=0 python ./scripts/benchmark/trainer-benchmark.py \
--base-cmd \
' examples/pytorch/translation/run_translation.py --model_name_or_path t5-base \
--output_dir output_dir --do_train --label_smoothing 0.1 --logging_strategy no \
--save_strategy no --per_device_train_batch_size 8 --max_source_length 512 \
--max_target_length 512 --num_train_epochs 1 --overwrite_output_dir \
--source_lang en --target_lang ro --dataset_name wmt16 --dataset_config "ro-en" \
--source_prefix "translate English to Romanian: " --warmup_steps 50 \
--max_train_samples 5000 --dataloader_num_workers 2 --bf16 --skip_memory_metrics 0' \
--target-metric-key train_samples_per_second --repeat-times 1 --variations \
'--gradient_checkpointing 0|--gradient_checkpointing 1' \
--report-metric-keys 'train_loss train_mem_gpu_alloc_delta train_mem_gpu_peaked_delta' \
--repeat-times 1

stas00 commented Jan 4, 2022

batch size

| Variation | Train samples per second | Diff % | Train loss |
|-----------|--------------------------|--------|------------|
| --per_device_train_batch_size 1 | 10.04 | 0 | 1.90 |
| --per_device_train_batch_size 2 | 19.39 | 93 | 2.01 |
| --per_device_train_batch_size 4 | 38.66 | 285 | 2.09 |
| --per_device_train_batch_size 8 | 77.52 | 672 | 2.17 |
| --per_device_train_batch_size 16 | 144.12 | 1335 | 2.26 |

Conclusions:

  • No surprise here: the speed is directly proportional to the GPU capacity utilization. In this particular configuration BS=16 is the highest BS we can fit, so with BS=1 we greatly underutilize the GPU. The speedup is nearly linear, almost directly proportional to the batch size.
  • as with gradient accumulation steps, the lm loss gets worse as the batch size increases, because my benchmark is very short and with fewer steps to take when the batches are larger, the model simply doesn't have a chance to step down far enough.

*** Setup:


Datetime    : 2022-01-03 17:10:28

Software:
transformers: 4.16.0.dev0
torch       : 1.10.1
cuda        : 11.3
python      : 3.8.11

Hardware:
1 GPUs      : NVIDIA GeForce RTX 3090, 23.70GB


*** The benchmark command line was:

CUDA_VISIBLE_DEVICES=0 python ./scripts/benchmark/trainer-benchmark.py \
--base-cmd \
' examples/pytorch/translation/run_translation.py --model_name_or_path t5-base \
--output_dir output_dir --do_train --label_smoothing 0.1 --logging_strategy no \
--save_strategy no --max_source_length 512 \
--max_target_length 512 --num_train_epochs 1 --overwrite_output_dir \
--source_lang en --target_lang ro --dataset_name wmt16 --dataset_config "ro-en" \
--source_prefix "translate English to Romanian: " --warmup_steps 50 \
--max_train_samples 5000 --dataloader_num_workers 2 --bf16' \
--target-metric-key train_samples_per_second --repeat-times 1 --variations \
'--per_device_train_batch_size 1|--per_device_train_batch_size 2|--per_device_train_batch_size 4|--per_device_train_batch_size 8|--per_device_train_batch_size 16' \
--report-metric-keys train_loss --repeat-times 1

stas00 changed the title from [Benchmark] [Ampere] fp16 vs bf16 vs tf32 vs fp32 to [Benchmark] HF Trainer on RTX-3090 on Jan 4, 2022
stas00 added the Benchmarks label on Jan 4, 2022
stas00 self-assigned this on Jan 4, 2022
stas00 added the WIP label on Jan 4, 2022
stas00 commented Jan 4, 2022

optimizers

Let's do fp32 first:

| Variation | Train samples per second | Diff % | Train loss |
|-----------|--------------------------|--------|------------|
| --optim adamw_hf | 116.95 | 4 | 2.20 |
| --optim adamw_torch | 112.60 | 0 | 2.20 |
| --optim adafactor | 90.55 | -20 | 2.20 |
| --optim adamw_apex_fused | 126.38 | 12 | 2.20 |

Observations:

  • apex's FusedAdam is the fastest.

fp16:

| Variation | Train samples per second | Diff % | Train loss |
|-----------|--------------------------|--------|------------|
| --optim adamw_hf | 132.49 | 4 | 2.20 |
| --optim adamw_torch | 126.84 | 0 | 2.20 |
| --optim adafactor | 101.91 | -20 | 2.20 |
| --optim adamw_apex_fused | 144.54 | 14 | 2.20 |

bf16:

| Variation | Train samples per second | Diff % | Train loss |
|-----------|--------------------------|--------|------------|
| --optim adamw_hf | 136.49 | 4 | 2.21 |
| --optim adamw_torch | 130.66 | 0 | 2.21 |
| --optim adafactor | 104.65 | -20 | 2.22 |
| --optim adamw_apex_fused | 148.51 | 14 | 2.21 |

Observations:

  • the relative speedups are the same as in fp32

# fp32
CUDA_VISIBLE_DEVICES=0 python \
/hf/transformers-trainer-benchmark/scripts/benchmark/trainer-benchmark.py \
--base-cmd \
' \
examples/pytorch/translation/run_translation.py --model_name_or_path t5-base --output_dir output_dir \
--do_train --label_smoothing 0.1 --logging_strategy no --save_strategy no --per_device_train_batch_size 16 \
--max_source_length 512 --max_target_length 512 --num_train_epochs 1 --overwrite_output_dir \
--source_lang en --target_lang ro --dataset_name wmt16 --dataset_config "ro-en" \
--source_prefix "translate English to Romanian: "  --warmup_steps 50 \
--max_train_samples 10000 --dataloader_num_workers 2 \
' \
--target-metric-key train_samples_per_second --repeat-times 1 --variations \
'--optim adamw_hf|--optim adamw_torch|--optim adafactor|--optim adamw_apex_fused' \
--report-metric-keys train_loss --base-variation '--optim adamw_torch'

# fp16 - just add --fp16 to base-cmd


# bf16 - just add --bf16 to base-cmd
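For reference, --optim adamw_apex_fused maps to apex's fused Adam optimizer; constructing it directly looks roughly like this (a sketch assuming NVIDIA apex is built and installed with its CUDA extensions):

```python
import torch
from apex.optimizers import FusedAdam  # requires https://github.com/NVIDIA/apex

model = torch.nn.Linear(512, 512).cuda()

# FusedAdam performs the elementwise Adam update in fused kernels instead of many
# separate pointwise kernels, which is where the ~12-14% speedup above comes from
optimizer = FusedAdam(model.parameters(), lr=1e-4, weight_decay=0.0)

loss = model(torch.randn(8, 512, device="cuda")).sum()
loss.backward()
optimizer.step()
```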

stas00 commented Jan 4, 2022

combining winning strategies

Now let's combine the winning strategies from each individual benchmark above and compare with the baseline:

| Variation | Train samples per second | Diff % | Train loss |
|-----------|--------------------------|--------|------------|
| --optim adamw_torch --gradient_accumulation_steps 1 --tf32 0 | 93.40 | 0 | 2.20 |
| --optim adamw_apex_fused --gradient_accumulation_steps 8 --tf32 --bf16 | 178.90 | 92 | 2.62 |

Getting an almost 2x improvement in speed!

CUDA_VISIBLE_DEVICES=0 python \
/hf/transformers-trainer-benchmark/scripts/benchmark/trainer-benchmark.py \
--base-cmd \
' \
examples/pytorch/translation/run_translation.py --model_name_or_path t5-base --output_dir output_dir \
--do_train --label_smoothing 0.1 --logging_strategy no --save_strategy no --per_device_train_batch_size 16 \
--max_source_length 512 --max_target_length 512 --num_train_epochs 1 --overwrite_output_dir \
--source_lang en --target_lang ro --dataset_name wmt16 --dataset_config "ro-en" \
--source_prefix "translate English to Romanian: "  --warmup_steps 50 \
--max_train_samples 10000 --dataloader_num_workers 2 \
' \
--target-metric-key train_samples_per_second --repeat-times 1 --variations \
'--optim adamw_torch --gradient_accumulation_steps 1 --tf32 0|--optim adamw_apex_fused --gradient_accumulation_steps 8 --tf32 --bf16' \
--report-metric-keys train_loss

@miteigi-nemoto

What CPU and DDR memory were used in the test with the RTX 3090?

stas00 commented May 30, 2024

It's a good point that I didn't publish that, but my desktop has been upgraded since then, so I don't have that information any longer.

It should be trivial for anybody to redo these benchmarks on the hardware they have and thus know the exact setup.

@miteigi-nemoto

I am accumulating evidence that CPU single core speed is critical for GPU performance. It would be helpful if you could post the CPU and memory specs when possible.

stas00 commented May 30, 2024

This will not happen as I explained above since I no longer have that hardware.

@miteigi-nemoto

I meant in future benchmarks. Thank you!
