
[feat] Triton fused dropout/bias #58

Merged: 9 commits merged into main on Nov 12, 2021

Conversation

@blefaudeux (Contributor) commented Oct 29, 2021

What does this PR do?

TODO:

  • Implement fused FW
  • Implement BW
  • Fuse an activation (including BW)
  • Add a dedicated layer and hook it up to FusedMLP (a reference sketch of the unfused computation follows below)
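
For orientation, here is a minimal PyTorch sketch of the unfused computation that this kernel fuses into a single pass (bias add, optional activation, then dropout). The function name and shapes are illustrative only; this is not the API added by the PR:

```python
import torch
import torch.nn.functional as F

def unfused_bias_gelu_dropout(x: torch.Tensor, bias: torch.Tensor, p: float = 0.1) -> torch.Tensor:
    # Reference, unfused version of what the fused Triton kernel computes:
    # each of the three steps reads and writes the full activation tensor,
    # which is the memory traffic that fusion removes.
    y = x + bias                               # 1) bias add
    y = F.gelu(y)                              # 2) activation (GeLU is the case benchmarked below)
    return F.dropout(y, p=p, training=True)    # 3) dropout

# One of the benchmark shapes below (B=8, M=256, K=512):
x = torch.randn(8, 256, 512)
bias = torch.randn(512)
out = unfused_bias_gelu_dropout(x, bias, p=0.1)
```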

--- Type: torch.float16 ---

| Units: GB/s | B=8, M=256, K=512 | B=8, M=512, K=1024 | B=4, M=1024, K=1024 | B=2, M=2048, K=2048 | B=2, M=4096, K=4096 | B=1, M=2048, K=12288 |
| --- | --- | --- | --- | --- | --- | --- |
| pytorch - bias: True - fw+bw - act: gelu | 37.9 | 60.2 | 60.2 | 67.4 | 75.4 | 74.4 |
| triton - bias: True - fw+bw - act: gelu | 45.5 | 70.6 | 70.6 | 80.5 | 92.1 | 90.5 |

--- Type: torch.float32 ---

| Units: GB/s | B=8, M=256, K=512 | B=8, M=512, K=1024 | B=4, M=1024, K=1024 | B=2, M=2048, K=2048 | B=2, M=4096, K=4096 | B=1, M=2048, K=12288 |
| --- | --- | --- | --- | --- | --- | --- |
| pytorch - bias: True - fw+bw - act: gelu | 67.7 | 77.5 | 77.5 | 82.9 | 88.3 | 87.7 |
| triton - bias: True - fw+bw - act: gelu | 76.6 | 106.7 | 106.7 | 119.8 | 130.0 | 130.8 |

--- Type: torch.float16 ---

| Units: GB/s | B=8, M=256, K=512 | B=8, M=512, K=1024 | B=4, M=1024, K=1024 | B=2, M=2048, K=2048 | B=2, M=4096, K=4096 | B=1, M=2048, K=12288 |
| --- | --- | --- | --- | --- | --- | --- |
| pytorch - bias: False - fw+bw - act: gelu | 54.6 | 74.5 | 74.1 | 82.3 | 90.5 | 89.4 |
| triton - bias: False - fw+bw - act: gelu | 52.5 | 78.0 | 78.0 | 88.3 | 98.8 | 97.4 |

--- Type: torch.float32 ---

| Units: GB/s | B=8, M=256, K=512 | B=8, M=512, K=1024 | B=4, M=1024, K=1024 | B=2, M=2048, K=2048 | B=2, M=4096, K=4096 | B=1, M=2048, K=12288 |
| --- | --- | --- | --- | --- | --- | --- |
| pytorch - bias: False - fw+bw - act: gelu | 81.1 | 94.2 | 94.4 | 100.4 | 105.5 | 104.9 |
| triton - bias: False - fw+bw - act: gelu | 91.0 | 121.8 | 122.3 | 134.8 | 142.4 | 143.6 |

--- Type: torch.float16 ---

| Units: GB/s | B=8, M=256, K=512 | B=8, M=512, K=1024 | B=4, M=1024, K=1024 | B=2, M=2048, K=2048 | B=2, M=4096, K=4096 | B=1, M=2048, K=12288 |
| --- | --- | --- | --- | --- | --- | --- |
| pytorch - bias: True - fw - act: gelu | 146.3 | 174.3 | 174.3 | 189.4 | 201.0 | 199.9 |
| triton - bias: True - fw - act: gelu | 178.1 | 292.6 | 292.6 | 334.4 | 369.2 | 364.2 |

--- Type: torch.float32 ---

| Units: GB/s | B=8, M=256, K=512 | B=8, M=512, K=1024 | B=4, M=1024, K=1024 | B=2, M=2048, K=2048 | B=2, M=4096, K=4096 | B=1, M=2048, K=12288 |
| --- | --- | --- | --- | --- | --- | --- |
| pytorch - bias: True - fw - act: gelu | 210.1 | 232.4 | 232.4 | 245.5 | 254.3 | 253.4 |
| triton - bias: True - fw - act: gelu | 341.4 | 565.0 | 565.0 | 649.0 | 725.2 | 717.7 |

--- Type: torch.float16 ---

| Units: GB/s | B=8, M=256, K=512 | B=8, M=512, K=1024 | B=4, M=1024, K=1024 | B=2, M=2048, K=2048 | B=2, M=4096, K=4096 | B=1, M=2048, K=12288 |
| --- | --- | --- | --- | --- | --- | --- |
| pytorch - bias: False - fw - act: gelu | 195.0 | 252.1 | 252.1 | 273.1 | 288.7 | 287.4 |
| triton - bias: False - fw - act: gelu | 178.1 | 292.6 | 297.9 | 334.4 | 371.3 | 366.8 |

--- Type: torch.float32 ---

| Units: GB/s | B=8, M=256, K=512 | B=8, M=512, K=1024 | B=4, M=1024, K=1024 | B=2, M=2048, K=2048 | B=2, M=4096, K=4096 | B=1, M=2048, K=12288 |
| --- | --- | --- | --- | --- | --- | --- |
| pytorch - bias: False - fw - act: gelu | 282.5 | 334.4 | 337.8 | 356.2 | 369.7 | 367.5 |
| triton - bias: False - fw - act: gelu | 341.3 | 574.9 | 574.9 | 662.0 | 736.4 | 728.2 |
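
The GB/s figures above are memory throughput numbers. As a rough sanity check only (the benchmark's exact byte accounting for the bias, the dropout mask, and the backward pass may differ), throughput can be estimated from tensor size and measured runtime; the helper below is an illustrative sketch, not the benchmark script itself, and needs a CUDA GPU:

```python
import time
import torch
import torch.nn.functional as F

def rough_bandwidth_gbs(fn, x: torch.Tensor, repeat: int = 10) -> float:
    # Estimate GB/s as (bytes read + bytes written) / average runtime.
    # Only one read and one write of the activation tensor are counted here;
    # the real benchmark may also count the bias, dropout mask and gradients.
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(repeat):
        fn(x)
    torch.cuda.synchronize()
    elapsed = (time.time() - start) / repeat
    bytes_moved = 2 * x.numel() * x.element_size()
    return bytes_moved / elapsed / 1e9

# Example on one of the benchmarked shapes (uncomment on a GPU machine):
# x = torch.randn(1, 2048, 12288, device="cuda", dtype=torch.float16)
# print(rough_bandwidth_gbs(F.gelu, x))
```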

Before submitting

  • Did you have fun?
    • Make sure you had fun coding 🙃
  • Did you read the contributor guideline?
  • Was this discussed/approved via a GitHub issue? (not needed for typos or doc improvements)
    • N/A
  • Did you make sure to update the docs?
    • N/A
  • Did you write any new necessary tests?
    • N/A
  • Did you update the changelog? (if needed)
    • N/A

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

@facebook-github-bot facebook-github-bot added the CLA Signed label Oct 29, 2021
@blefaudeux blefaudeux marked this pull request as draft October 29, 2021 22:36
@blefaudeux blefaudeux force-pushed the triton_dropout_fused_bias branch 2 times, most recently from 9205e28 to 8aadd22 on October 29, 2021 23:20
@blefaudeux blefaudeux changed the title [DRAFT][feat] Triton fused dropout/bias [feat] Triton fused dropout/bias Oct 29, 2021
@blefaudeux blefaudeux force-pushed the triton_dropout_fused_bias branch 2 times, most recently from 43f3ad5 to 7011c29 on October 30, 2021 02:58
@codecov-commenter commented Oct 30, 2021

Codecov Report

Merging #58 (e633db9) into main (cc7c8ed) will decrease coverage by 0.28%.
The diff coverage is 66.03%.

Impacted file tree graph

@@            Coverage Diff             @@
##             main      #58      +/-   ##
==========================================
- Coverage   86.93%   86.65%   -0.29%     
==========================================
  Files          49       49              
  Lines        2442     2473      +31     
==========================================
+ Hits         2123     2143      +20     
- Misses        319      330      +11     
| Flag | Coverage Δ |
| --- | --- |
| Python | 86.65% <66.03%> (-0.29%) ⬇️ |

Flags with carried forward coverage won't be shown.

| Impacted Files | Coverage Δ |
| --- | --- |
| xformers/triton/activations.py | 54.54% <0.00%> (-4.28%) ⬇️ |
| xformers/triton/fused_linear_layer.py | 92.85% <ø> (-0.48%) ⬇️ |
| xformers/triton/dropout.py | 75.00% <72.72%> (+0.71%) ⬆️ |
| xformers/components/feedforward/fused_mlp.py | 91.30% <100.00%> (+0.39%) ⬆️ |
| xformers/triton/__init__.py | 69.23% <100.00%> (ø) |

Continue to review full report at Codecov.

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update cc7c8ed...e633db9. Read the comment docs.

@blefaudeux blefaudeux force-pushed the triton_dropout_fused_bias branch from 7011c29 to 1c4f08a on November 1, 2021 18:17
@blefaudeux blefaudeux requested a review from fmassa November 1, 2021 18:21
@blefaudeux blefaudeux marked this pull request as ready for review November 1, 2021 18:21
@blefaudeux blefaudeux force-pushed the triton_dropout_fused_bias branch from 1c4f08a to 159e3b2 on November 1, 2021 19:30
@blefaudeux blefaudeux marked this pull request as draft November 1, 2021 20:39
@blefaudeux blefaudeux force-pushed the triton_dropout_fused_bias branch 3 times, most recently from 535d0bd to 5031876 on November 5, 2021 22:18
- adding a proper backward

- adding the plots
@blefaudeux blefaudeux force-pushed the triton_dropout_fused_bias branch from 5031876 to cf69993 on November 9, 2021 17:42
@blefaudeux blefaudeux changed the title [feat] Triton fused dropout/bias [DRAFT][feat] Triton fused dropout/bias Nov 9, 2021
@blefaudeux blefaudeux force-pushed the triton_dropout_fused_bias branch 5 times, most recently from e67c7c2 to 043459d on November 10, 2021 00:01
- [x] adjust benchmarks
- [x] adjust unit test
@blefaudeux blefaudeux force-pushed the triton_dropout_fused_bias branch from 043459d to 7928275 on November 10, 2021 00:17
@blefaudeux blefaudeux changed the title [DRAFT][feat] Triton fused dropout/bias [feat] Triton fused dropout/bias Nov 10, 2021
@blefaudeux blefaudeux marked this pull request as ready for review November 10, 2021 00:20
@blefaudeux blefaudeux marked this pull request as draft November 10, 2021 04:03
@blefaudeux blefaudeux marked this pull request as ready for review November 10, 2021 18:15
@@ -42,52 +42,65 @@ Some examples, generated with `python3 xformers/benchmarks/benchmark_encoder.py
You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_softmax.py`. The units are GB/s. These results are for a nVidia V100, Triton 1.1 and PyTorch 1.9.
Note that in the Triton case the slowdowns at extreme sizes are because of register spilling, A100s get much better performance.

-![Softmax throughput in fp16 - inference](docs/plots/Softmax_Bandwidth_FW_fp16.png)
+![Softmax throughput in fp16 - inference](docs/plots/fused_softmax/Softmax_Bandwidth_FW_fp16.png)
@blefaudeux (Contributor Author) commented: moving all the plots to subfolders to clean that up a little

@blefaudeux blefaudeux force-pushed the triton_dropout_fused_bias branch from 0b1a990 to 1a31785 on November 10, 2021 18:16
@blefaudeux blefaudeux force-pushed the triton_dropout_fused_bias branch from a178fda to 17f715d on November 10, 2021 19:52
@blefaudeux (Contributor Author) commented:

Full training checked on the microViT example, all good (slightly better results for some reason, and 15-20% better speed).

@blefaudeux (Contributor Author) commented:

Ping for review @jieru-hu @dianaml0, if you don't mind? This should be ready.

@dianaml0 (Contributor) left a comment:

Nice improvement in GB/s! :)

units="GB/s",
dash_key="pytorch",
)


for bw in [False, True]:
bench_dropout(bw)
for activation in [Activation.GeLU, None]:
@dianaml0 (Contributor): Should others be benchmarked as well, or is it pretty similar for all of them?

@blefaudeux (Contributor Author): GeLU is kind of the worst case for activations with Triton (it uses the real exp function, not the fastmath equivalent), and "None" negates the fusion (well, we fuse with nothingness :D), so these are intentionally hard cases to test truthfully. It's a little slow to test with everything, but it could be extended on an ad-hoc basis if people would like.

@dianaml0 (Contributor): Ah okay, that makes sense! I guess it's pretty straightforward to modify and add an activation if someone needs to.
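
For readers unfamiliar with the distinction discussed above, here is a small PyTorch illustration (not the PR's Triton kernel) of the exact, erf-based GeLU versus the common tanh "fast math" approximation:

```python
import torch

def gelu_exact(x: torch.Tensor) -> torch.Tensor:
    # Exact GeLU: 0.5 * x * (1 + erf(x / sqrt(2))) -- relies on the "real" erf/exp path.
    return 0.5 * x * (1.0 + torch.erf(x / 2 ** 0.5))

def gelu_tanh(x: torch.Tensor) -> torch.Tensor:
    # Tanh approximation: cheaper transcendentals, slightly different numerics.
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608 * (x + 0.044715 * x ** 3)))

x = torch.randn(4096)
print((gelu_exact(x) - gelu_tanh(x)).abs().max())  # small but nonzero difference
```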

@blefaudeux blefaudeux merged commit eb80810 into main Nov 12, 2021
@blefaudeux blefaudeux deleted the triton_dropout_fused_bias branch November 12, 2021 22:54
xwhan pushed a commit to xwhan/xformers that referenced this pull request Feb 8, 2022
…ut (facebookresearch#58)

* Baseline plots

* Avoid materializing dense matrix in dropout

* Add new version of SparseBMM

Previous version was materializing a matrix which exceeded int32 on each dimension

* Use matmul_with_mask in SparseBMM backward

* Fix lint

* Protect against in-place dropout

* Add autograd tests for sparse_bmm
Labels
CLA Signed: This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
5 participants