[feat] Triton fused dropout/bias #58
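Since the PR title is the only summary here, a rough picture of what fusing bias and dropout in a single Triton kernel looks like may help: the bias add, the random keep-mask and the rescaling all happen in one pass over memory instead of three separate kernel launches. This is a hedged, illustrative sketch, not the code in this PR; the kernel name, the flat 1D indexing and the `BLOCK_SIZE` choice are all assumptions.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _bias_dropout_kernel(
    x_ptr, bias_ptr, out_ptr,
    n_elements, n_cols,
    p, seed,
    BLOCK_SIZE: tl.constexpr,
):
    # one program handles BLOCK_SIZE contiguous elements of the flattened input
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    # the bias is broadcast along rows, so index it by column (assumes a row-major layout)
    b = tl.load(bias_ptr + (offsets % n_cols), mask=mask, other=0.0)

    # draw the per-element uniforms in-register: bias add + dropout in a single memory pass
    r = tl.rand(seed, offsets)
    keep = r > p
    y = tl.where(keep, (x + b) / (1.0 - p), 0.0)

    tl.store(out_ptr + offsets, y, mask=mask)


def fused_bias_dropout(x: torch.Tensor, bias: torch.Tensor, p: float = 0.1) -> torch.Tensor:
    # hypothetical wrapper, forward only; x is assumed contiguous with bias matching its last dim
    assert x.is_contiguous() and bias.shape[-1] == x.shape[-1]
    out = torch.empty_like(x)
    n_elements = x.numel()
    seed = int(torch.randint(0, 2**31 - 1, (1,)).item())
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _bias_dropout_kernel[grid](x, bias, out, n_elements, x.shape[-1], p, seed, BLOCK_SIZE=1024)
    return out
```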
Conversation
force-pushed from 9205e28 to 8aadd22
force-pushed from 43f3ad5 to 7011c29
Codecov Report

@@            Coverage Diff             @@
##             main      #58      +/-   ##
==========================================
- Coverage   86.93%   86.65%   -0.29%
==========================================
  Files          49       49
  Lines        2442     2473      +31
==========================================
+ Hits         2123     2143      +20
- Misses        319      330      +11
force-pushed from 7011c29 to 1c4f08a
force-pushed from 1c4f08a to 159e3b2
force-pushed from 535d0bd to 5031876
- adding a proper backward (sketched just below)
- adding the plots
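For context on what "a proper backward" means for a fused dropout: the forward only needs to remember the RNG seed, and the backward can replay that seed to regenerate the exact same keep-mask instead of storing a mask tensor. Below is a hedged sketch of that pattern in plain PyTorch (the real kernels are Triton, and all names here are made up for illustration):

```python
import torch


class _BiasDropoutSeedReplay(torch.autograd.Function):
    # Illustration of seed-replay dropout: no mask is saved for backward,
    # only the integer seed, which deterministically reproduces the mask.

    @staticmethod
    def forward(ctx, x, bias, p):
        seed = int(torch.randint(0, 2**31 - 1, (1,)).item())
        gen = torch.Generator(device=x.device).manual_seed(seed)
        keep = torch.rand(x.shape, generator=gen, device=x.device) > p
        ctx.p, ctx.seed = p, seed
        # inverted dropout: kept activations are rescaled by 1 / (1 - p)
        return (x + bias) * keep / (1.0 - p)

    @staticmethod
    def backward(ctx, grad_out):
        # replaying the same seed yields the identical keep-mask
        gen = torch.Generator(device=grad_out.device).manual_seed(ctx.seed)
        keep = torch.rand(grad_out.shape, generator=gen, device=grad_out.device) > ctx.p
        grad_in = grad_out * keep / (1.0 - ctx.p)
        grad_bias = grad_in.sum(dim=0)  # assumes a 2D (tokens, features) layout
        return grad_in, grad_bias, None


# usage sketch:
# y = _BiasDropoutSeedReplay.apply(x, bias, 0.1)
```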
force-pushed from 5031876 to cf69993
force-pushed from e67c7c2 to 043459d
- [x] adjust benchmarks
- [x] adjust unit test
force-pushed from 043459d to 7928275
@@ -42,52 +42,65 @@ Some examples, generated with `python3 xformers/benchmarks/benchmark_encoder.py
You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_softmax.py`. The units are GB/s. These results are for an NVIDIA V100, Triton 1.1 and PyTorch 1.9.
Note that in the Triton case the slowdowns at extreme sizes are because of register spilling; A100s get much better performance.
(benchmark plots embedded here)
moving all the plots to subfolders to clean that up a little
force-pushed from 0b1a990 to 1a31785
force-pushed from a178fda to 17f715d
Full training checked on the microViT example, all good (slightly better results for some reason, and 15-20% better speed).
Apples-to-apples comparison: we're now even a tiny bit faster with the layernorm change which landed in between.
Nice improvement in GB/s! :)
    units="GB/s",
    dash_key="pytorch",
)

for bw in [False, True]:
    bench_dropout(bw)
for activation in [Activation.GeLU, None]:
Should others be benchmarked as well, or is it pretty similar for all of them?
GeLU is kind of the worst case for activations and Triton (it uses the real exp function and not the fastmath equivalent), and "None" negates the fusion (well, we fuse with nothingness :D), so these are kind of hard goals to test truthfully. It's a little slow to test with everything, but it could be extended on an ad-hoc basis if people would like?
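To make the "real exp" point concrete, here is roughly what a tanh-approximation GeLU looks like when written against Triton's `tl.exp` (a sketch, not necessarily the device function xformers actually ships; overflow handling for very large |x| is omitted):

```python
import triton
import triton.language as tl


@triton.jit
def gelu(x):
    # tanh approximation of GeLU; the tanh itself is built from tl.exp,
    # i.e. the accurate exp rather than a fast-math shortcut
    k = 0.7978845608028654  # sqrt(2 / pi)
    t = k * (x + 0.044715 * x * x * x)
    e = tl.exp(2.0 * t)
    return 0.5 * x * (1.0 + (e - 1.0) / (e + 1.0))
```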
Ah okay, that makes sense! Guess it's pretty straightforward to modify and add an activation if someone needs to.
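For the record, extending the sweep would look something like the snippet below. This is hypothetical: it assumes the `Activation` enum also exposes `ReLU` and `LeakyReLU`, that `bench_dropout` can take the activation as an argument, and that the import path is `xformers.components`; the real signature may differ.

```python
from xformers.components import Activation  # import path is an assumption

# hypothetical ad-hoc extension of the benchmark sweep from the diff above
for activation in [Activation.GeLU, Activation.ReLU, Activation.LeakyReLU, None]:
    for bw in [False, True]:
        bench_dropout(bw, activation)  # passing the activation explicitly is an assumption
```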
…ut (facebookresearch#58)
* Baseline plots
* Avoid materializing dense matrix in dropout
* Add new version of SparseBMM. Previous version was materializing a matrix which exceeded int32 on each dimension
* Use matmul_with_mask in SparseBMM backward
* Fix lint
* Protect against in-place dropout
* Add autograd tests for sparse_bmm
What does this PR do?
TODO:
(Benchmark tables for torch.float16 and torch.float32 followed here.)
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.