[Benchmark] HF Trainer on RTX-3090 #14608
Diagnostics of T5 slowness under TF32

By Eddie Yan (re-pasted from torch slack)

Here is an update/summary after investigating the T5 model further:

The lack of speedup in TF32 can be attributed to bottlenecks in non-GEMM ops (e.g., pointwise ops in the custom unfused LayerNorm and the custom AdamW optimizer). Without optimization, the custom LayerNorm is comparable in wall clock time to the GEMMs due to this bottleneck (first image attachment).

Zoomed-in profile on A6000 showing the custom "T5LayerNorm" being very expensive compared to cuBLAS GEMM (in green):

When we zoom out in the profile, the ratio of pointwise ops in the optimizer to compute is further exacerbated by the small batch size of the large model on the 3090; this small batch size means that the GEMM compute intensity is low relative to the number of pointwise ops incurred by the optimizer, which must update every parameter in the model along with running statistics for a relatively small number of training examples. The second image attachment shows this issue, where the optimizer step wall clock time is comparable to an entire backward step (220+ ms vs. 260 ms)!

Zoomed-out profile on A6000 showing the expensive optimizer step (AdamW) relative to the backward pass:

On A6000, a GPU comparable to the 3090 in terms of architecture, we've done a study to incrementally gauge the performance of optimizations for TF32 vs. fp32. First, we replaced the custom LayerNorm implementation with PyTorch's native implementation (which, despite being different, should be good for a rough estimate). While the native implementation is far from optimal, this change yields 38.3 samples/s with TF32 vs. 34.3 with fp32, a ~10% speedup. Turning on gradient accumulation improves performance dramatically as the optimizer-to-forward-backward compute ratio is abated, but more importantly TF32 is now ~20% faster than fp32, at 90.5 samples/s vs. 75.1 samples/s for fp32. Additionally, replacing the custom AdamW with a fused implementation from apex (thanks to Kevin Stephano for suggesting this) yields another small improvement, to 92.4 samples/s with TF32 vs. 75.3 for fp32.

As we've seen that the lack of improvement can be attributed to a low ratio of GEMMs to pointwise ops (especially in the optimizer), another way to improve this ratio is to increase the batch size instead of using gradient accumulation. This approach really shines on A100, which even with 40GiB of memory allows the batch size to be increased to 64. Perhaps due to the higher TF32 throughput, the speedup here is dramatic: 207.4 samples/s vs. 73.1 samples/s for fp32, an over 2x speedup.

For reference, here is the custom "T5LayerNorm", which ignores the input data type and casts to float32:
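A sketch of that implementation, paraphrased from the T5 modeling code (exact details may vary across transformers versions):

```python
import torch
from torch import nn


class T5LayerNorm(nn.Module):
    """RMSNorm-style norm: scale only, no mean subtraction, statistics in float32."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # the variance is always computed in float32, regardless of the input dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # cast back down if the weights are half precision
        if self.weight.dtype in (torch.float16, torch.bfloat16):
            hidden_states = hidden_states.to(self.weight.dtype)
        return self.weight * hidden_states
```

Note how the statistics are always computed in float32 and there is no mean subtraction or bias, i.e. it is an RMSNorm rather than a standard LayerNorm, which is why PyTorch's native fused `nn.LayerNorm` is not an exact drop-in replacement.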
Here is another analysis of t5, but for fp16: NVIDIA/apex#1271. There is also an attempt to address it by implementing a fused RMSNorm kernel here: NVIDIA/apex#1271
Will probably move the following elsewhere, but will save here for now:

Notes on Ampere cards performance comparison

From Christian Sarofeen (re-pasted from torch slack)

Whitepapers on Ampere chips:

The following numbers are TFLOPs.

RTX-3080:

RTX-3090 (~1.17x more powerful than the 3080):

A100:
Performance comparison: SM = streaming multiprocessors on the GPU. Just looking at the SM count is probably the easiest way to scale: the RTX 3080 has 68 SMs, and the 3090 has 82 SMs. For clock speeds it's 1440 MHz (1710 with boost) vs. 1395 MHz (1695 with boost), so the ratio of their compute is 68 * 1440 : 82 * 1395 if we just use the base clocks. The 3090 is more SMs at a slightly slower clock: when you have a significantly bigger chip it's common to reduce the clock speed so the overall power consumption isn't over some set budget. They're both GA102 chips, so more SMs equates pretty trivially to more compute. The 3070, on the other hand, is a GA104, so comparison to that isn't as straightforward. You can't straightforwardly compare an RTX card with an A100 like you can within the same chip family. So Wikipedia was fine for going from the 3080 to the 3090, because they're based on the same chip (GA102); the A100 is a GA100, so you can't do a simple comparison like that.
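As a quick sanity check on the ~1.17x figure quoted above, the base-clock ratio works out as follows (a throwaway calculation, not part of the original notes):

```python
# SM count * base clock (MHz) as a crude proxy for relative compute
# within the same chip family (both cards are GA102).
rtx_3080 = 68 * 1440
rtx_3090 = 82 * 1395
print(rtx_3090 / rtx_3080)  # ~1.17
```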
precision: fp16 vs bf16 vs tf32 vs fp32

Main interest: benchmarking the new --bf16 and --tf32 modes on Ampere/RTX-3090, compared to the fp16 and fp32 modes.
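For reference, here is roughly how those modes map onto `TrainingArguments` (a sketch only; it assumes a transformers version recent enough to have the new `bf16`/`tf32` arguments, and `output_dir` is just a placeholder):

```python
from transformers import TrainingArguments

# One reduced-precision mode per run, plus the fp32 baseline (bf16/tf32 need an Ampere GPU).
fp32_args = TrainingArguments(output_dir="out")             # baseline: full fp32
fp16_args = TrainingArguments(output_dir="out", fp16=True)  # fp16 mixed precision (AMP)
bf16_args = TrainingArguments(output_dir="out", bf16=True)  # bf16 mixed precision
tf32_args = TrainingArguments(output_dir="out", tf32=True)  # fp32 with TF32 matmuls
```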
Note: to get the best performance, make sure you have 2 independent 12V PCIe rails plugged into the card and not 2 splits of the same rail.

Benchmark

The benchmark uses 3 different t5 models, and at the end of the section also gpt2. For t5 the main script is:
and now adding one of:
But we are going to use a special benchmarking tool that will do all the work for us: #14934

Important notes:
Benchmark 1: t5-small
Conclusions:
Benchmark 2: t5-base
Conclusions:
Benchmark 3: t5-large
Conclusions:
Benchmark 4: gpt2

Let's try a different architecture.
Conclusions:
Benchmark 5: gpt2-medium

and now with gpt2-medium:
Conclusions:
gradient accumulation steps

Let's choose … Let's measure …

Results:
Let's filter out just one subset so that it's easier to compare the gradient accumulation differences alone, so re-running with just bf16 enabled (--bf16):
Conclusions:
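For context, this is the control flow that gradient accumulation boils down to, and why it helps here: the memory-bound `optimizer.step()` runs once per `accum_steps` micro-batches, so its per-parameter pointwise work is amortized over more GEMM-heavy forward/backward passes. A minimal sketch with a toy model (names and sizes are placeholders):

```python
import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(512, 512).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
accum_steps = 4

optimizer.zero_grad()
for step in range(16):
    x = torch.randn(8, 512, device=device)
    # scale the loss so the accumulated gradients match one large batch
    loss = model(x).pow(2).mean() / accum_steps
    loss.backward()
    if (step + 1) % accum_steps == 0:
        optimizer.step()   # runs 4x less often than without accumulation
        optimizer.zero_grad()
```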
gradient checkpointing

Let's choose … Let's benchmark enabling gradient checkpointing:
Conclusions:
Let's look at memory:
We can clearly see that peak GPU memory is ~2/3 lower. Note: I had to halve the BS in the 2nd benchmark as I was getting OOM. Plus, the memory metrics slow things down.
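For reference, here is a sketch of the two usual ways gradient checkpointing is switched on (assumes a transformers version where `TrainingArguments` has a `gradient_checkpointing` flag; `t5-large` and `output_dir` are just examples):

```python
from transformers import T5ForConditionalGeneration, TrainingArguments

# Via the Trainer: activations are recomputed during the backward pass instead of
# being kept around, trading extra compute for a much lower activation-memory peak.
args = TrainingArguments(output_dir="out", gradient_checkpointing=True)

# Equivalent model-level switch, useful outside of the Trainer:
model = T5ForConditionalGeneration.from_pretrained("t5-large")
model.gradient_checkpointing_enable()
```

The lower activation-memory peak is what makes room for a larger batch size, at the cost of the extra recomputation.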
batch size
Conclusions:
optimizers

Let's do fp32 first:
Observations:
fp16:
bf16:
Observations:
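Since the analysis earlier in the thread points at the unfused AdamW as a major bottleneck, apex's fused Adam is one of the natural alternatives to compare. A sketch of swapping it in (assumes apex is installed; the model and hyperparameters are placeholders):

```python
import torch
from apex.optimizers import FusedAdam

model = torch.nn.Linear(512, 512).cuda()

# adam_w_mode=True gives AdamW-style decoupled weight decay; the fused kernel
# collapses the many small pointwise update ops into far fewer kernel launches.
optimizer = FusedAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999),
                      weight_decay=0.01, adam_w_mode=True)
```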
combining winning strategies

Now let's combine the winning strategies from each individual benchmark above and compare with the baseline:
Getting an almost 2x improvement in speed!
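Based on the individual benchmarks above, a combined configuration would look something along these lines (purely illustrative; the flag values are assumptions rather than the measured best setup, and `optim="adamw_apex_fused"` requires apex plus a Trainer version that supports the `optim` argument):

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    bf16=True,                         # or fp16/tf32, whichever won above
    per_device_train_batch_size=16,    # placeholder: as large as memory allows
    gradient_accumulation_steps=4,     # placeholder
    optim="adamw_apex_fused",          # fused AdamW instead of the unfused default
)
```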
What CPU and DDR memory was used in the test with RTX 3090?
It's a good point that I didn't publish that, but my desktop has been upgraded since then, so I don't have that information any longer. It should be trivial for anybody to redo these benchmarks on the hardware they have and know the exact setup.
I am accumulating evidence that CPU single-core speed is critical for GPU performance. It would be helpful if you could post the CPU and memory specs when possible.
This will not happen, as I explained above, since I no longer have that hardware.
I meant in future benchmarks. Thank you.
🖥 Benchmarking transformers w/ HF Trainer on RTX-3090

We are going to use a special benchmarking tool that will do all the work for us: #14934
This is the index post and specific benchmarks are in their own posts below:
See also the same benchmarks for A100
TODO:
Note that each benchmark was run only once, so doing multiple runs and averaging would probably give slightly different results. The purpose here, though, is to see rough relative differences, not to produce exact numbers.