
[feat] InplaceNorm, affine-less LayerNorm #53

Merged: 2 commits into facebookresearch:main on Oct 29, 2021

Conversation

@ClashLuke (Contributor)

This should resolve #50

  • InplaceNorm
  • affine-less LayerNorm
  • GLU
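
For context, a rough eager-mode sketch of what "InplaceNorm" in the first item refers to: normalizing a tensor without allocating a separate output buffer. This is only an illustrative reference under that reading of the feature name, not the Triton kernel added in this PR, and the function name is made up:

```python
import torch

def inplace_layer_norm(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Hypothetical reference: normalize the last dimension in place
    # (no affine weight/bias), reusing x's storage for the output.
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    x.sub_(mean).div_(torch.sqrt(var + eps))  # in-place update of x
    return x
```

In-place updates like this interact with autograd, which is presumably why the diff quoted later in the review thread warns about "an inplace operation after the layernorm".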

@facebook-github-bot added the "CLA Signed" label on Oct 29, 2021. (This label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed.)
@ClashLuke (Contributor, Author)

Considering that the fused matmul kernel already supports a joint compilation with the activation function, I'm not sure how useful it'd be to add the activation part of GLU to normalization.
I could implement something like norm_and_glu(f(x), gelu_and_g(x)) where f() and gelu_and_g() are already handled in k_fused_matmul.py.
Does that sound good?
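
To illustrate the composition being proposed (the names norm_and_glu, f, and gelu_and_g come from the comment above; everything else is a hypothetical eager-mode sketch, not the fused Triton code, and one plausible reading of the ordering):

```python
import torch
import torch.nn.functional as F

def norm_and_glu(a: torch.Tensor, b_gelu: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # GLU-style gating: `a` plays the role of f(x), `b_gelu` the role of
    # gelu_and_g(x), i.e. the branch already fused with GELU in k_fused_matmul.py.
    gated = a * b_gelu
    # Affine-less layer norm over the gated result (no weight/bias).
    return F.layer_norm(gated, gated.shape[-1:], weight=None, bias=None, eps=eps)
```

The point of fusing would be to do the gating and normalization while the data is still in shared memory, as discussed further down the thread.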

@ClashLuke (Contributor, Author)

Here are the benchmarks comparing the upstream LayerNorm, the proposed LayerNorm, and the proposed affine-less LayerNorm; a short note on what "affine-less" changes follows the tables.

Upstream:

 --- Type: torch.float16 --- 
| Units: GB/s                                    |B=8, M=256, K=512   |B=8, M=512, K=1024  |B=4, M=1024, K=1024 |B=2, M=2048, K=2048 |B=2, M=4096, K=4096 |B=1, M=2048, K=12288|
|------------------------------------------------|--------------------|--------------------|--------------------|--------------------|--------------------|--------------------|
|pytorch - fw                                    |25.0                |40.1                |40.0                |55.1                |67.6                |74.4                |
|triton - fw                                     |143.9               |158.3               |158.8               |161.9               |163.4               |162.8               |


 --- Type: torch.float32 --- 
| Units: GB/s                                    |B=8, M=256, K=512   |B=8, M=512, K=1024  |B=4, M=1024, K=1024 |B=2, M=2048, K=2048 |B=2, M=4096, K=4096 |B=1, M=2048, K=12288|
|------------------------------------------------|--------------------|--------------------|--------------------|--------------------|--------------------|--------------------|
|pytorch - fw                                    |44.5                |65.4                |65.3                |83.0                |93.7                |95.9                |
|triton - fw                                     |156.7               |162.2               |162.2               |162.7               |162.3               |146.3               |


 --- Type: torch.float16 --- 
| Units: GB/s                                    |B=8, M=256, K=512   |B=8, M=512, K=1024  |B=4, M=1024, K=1024 |B=2, M=2048, K=2048 |B=2, M=4096, K=4096 |B=1, M=2048, K=12288|
|------------------------------------------------|--------------------|--------------------|--------------------|--------------------|--------------------|--------------------|
|pytorch - fw+bw                                 |10.1                |14.1                |14.1                |16.4                |17.5                |17.3                |
|triton - fw+bw                                  |22.5                |21.2                |21.2                |21.6                |21.9                |21.8                |


 --- Type: torch.float32 --- 
| Units: GB/s                                    |B=8, M=256, K=512   |B=8, M=512, K=1024  |B=4, M=1024, K=1024 |B=2, M=2048, K=2048 |B=2, M=4096, K=4096 |B=1, M=2048, K=12288|
|------------------------------------------------|--------------------|--------------------|--------------------|--------------------|--------------------|--------------------|
|pytorch - fw+bw                                 |15.2                |18.3                |18.3                |19.6                |20.3                |20.1                |
|triton - fw+bw                                  |21.1                |21.6                |21.7                |21.8                |22.0                |20.6                |


=================

Proposed LayerNorm:
 
 --- Type: torch.float16 --- 
| Units: GB/s                                    |B=8, M=256, K=512   |B=8, M=512, K=1024  |B=4, M=1024, K=1024 |B=2, M=2048, K=2048 |B=2, M=4096, K=4096 |B=1, M=2048, K=12288|
|------------------------------------------------|--------------------|--------------------|--------------------|--------------------|--------------------|--------------------|
|pytorch - fw                                    |24.9                |39.8                |39.9                |55.2                |67.8                |74.4                |
|triton - fw                                     |143.6               |158.3               |158.8               |161.9               |163.4               |162.8               |


 --- Type: torch.float32 --- 
| Units: GB/s                                    |B=8, M=256, K=512   |B=8, M=512, K=1024  |B=4, M=1024, K=1024 |B=2, M=2048, K=2048 |B=2, M=4096, K=4096 |B=1, M=2048, K=12288|
|------------------------------------------------|--------------------|--------------------|--------------------|--------------------|--------------------|--------------------|
|pytorch - fw                                    |44.5                |65.5                |65.4                |83.2                |93.9                |95.9                |
|triton - fw                                     |156.6               |162.2               |162.2               |162.8               |162.3               |145.6               |


 --- Type: torch.float16 --- 
| Units: GB/s                                    |B=8, M=256, K=512   |B=8, M=512, K=1024  |B=4, M=1024, K=1024 |B=2, M=2048, K=2048 |B=2, M=4096, K=4096 |B=1, M=2048, K=12288|
|------------------------------------------------|--------------------|--------------------|--------------------|--------------------|--------------------|--------------------|
|pytorch - fw+bw                                 |10.1                |14.1                |14.1                |16.4                |17.5                |17.3                |
|triton - fw+bw                                  |22.5                |21.2                |21.2                |21.6                |22.0                |17.3                |


 --- Type: torch.float32 --- 
| Units: GB/s                                    |B=8, M=256, K=512   |B=8, M=512, K=1024  |B=4, M=1024, K=1024 |B=2, M=2048, K=2048 |B=2, M=4096, K=4096 |B=1, M=2048, K=12288|
|------------------------------------------------|--------------------|--------------------|--------------------|--------------------|--------------------|--------------------|
|pytorch - fw+bw                                 |15.2                |18.3                |18.3                |19.6                |20.3                |20.0                |
|triton - fw+bw                                  |21.1                |21.7                |21.7                |21.9                |22.0                |18.2                |


=================

Proposed affine-less LayerNorm:


 --- Type: torch.float16 --- 
| Units: GB/s                                    |B=8, M=256, K=512   |B=8, M=512, K=1024  |B=4, M=1024, K=1024 |B=2, M=2048, K=2048 |B=2, M=4096, K=4096 |B=1, M=2048, K=12288|
|------------------------------------------------|--------------------|--------------------|--------------------|--------------------|--------------------|--------------------|
|pytorch - fw                                    |24.9                |39.8                |39.9                |55.1                |67.8                |74.4                |
|triton - fw                                     |146.3               |159.9               |160.2               |162.3               |164.3               |163.6               |


 --- Type: torch.float32 --- 
| Units: GB/s                                    |B=8, M=256, K=512   |B=8, M=512, K=1024  |B=4, M=1024, K=1024 |B=2, M=2048, K=2048 |B=2, M=4096, K=4096 |B=1, M=2048, K=12288|
|------------------------------------------------|--------------------|--------------------|--------------------|--------------------|--------------------|--------------------|
|pytorch - fw                                    |44.5                |65.5                |65.4                |83.1                |94.0                |95.9                |
|triton - fw                                     |157.7               |162.7               |162.8               |163.8               |163.8               |163.7               |


 --- Type: torch.float16 --- 
| Units: GB/s                                    |B=8, M=256, K=512   |B=8, M=512, K=1024  |B=4, M=1024, K=1024 |B=2, M=2048, K=2048 |B=2, M=4096, K=4096 |B=1, M=2048, K=12288|
|------------------------------------------------|--------------------|--------------------|--------------------|--------------------|--------------------|--------------------|
|pytorch - fw+bw                                 |10.1                |14.1                |14.1                |16.4                |17.5                |17.3                |
|triton - fw+bw                                  |25.9                |29.0                |29.0                |29.6                |30.2                |30.0                |


 --- Type: torch.float32 --- 
| Units: GB/s                                    |B=8, M=256, K=512   |B=8, M=512, K=1024  |B=4, M=1024, K=1024 |B=2, M=2048, K=2048 |B=2, M=4096, K=4096 |B=1, M=2048, K=12288|
|------------------------------------------------|--------------------|--------------------|--------------------|--------------------|--------------------|--------------------|
|pytorch - fw+bw                                 |15.2                |18.3                |18.3                |19.6                |20.3                |20.0                |
|triton - fw+bw                                  |28.1                |29.7                |29.7                |30.0                |30.2                |30.1                |
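
A note on the comparison: "affine-less" means the normalization has no learnable weight or bias, so the backward pass can skip the DW/DB partial sums mentioned in the kernel comments, which is plausibly where the fw+bw bandwidth gain in the Triton rows (roughly 22 GB/s up to about 30 GB/s) comes from; the forward numbers are nearly unchanged. In eager PyTorch the two variants correspond to the following (illustrative sketch only, with the shape taken from the first benchmark column):

```python
import torch

K = 512  # hidden size, matching the B=8, M=256, K=512 column
# Standard LayerNorm with learnable affine parameters (weight and bias).
ln_affine = torch.nn.LayerNorm(K)
# Affine-less LayerNorm: same normalization, but no weight/bias to train
# and therefore no DW/DB gradients to accumulate in the backward pass.
ln_plain = torch.nn.LayerNorm(K, elementwise_affine=False)

x = torch.randn(8, 256, K)
# At initialization (weight=1, bias=0) the two produce the same output.
assert torch.allclose(ln_affine(x), ln_plain(x), atol=1e-6)
```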

@codecov-commenter

Codecov Report

❗ No coverage uploaded for the pull request base (main@aeeedeb). The diff coverage is n/a.

@@           Coverage Diff           @@
##             main      #53   +/-   ##
=======================================
  Coverage        ?   54.48%           
=======================================
  Files           ?       70           
  Lines           ?     4076           
  Branches        ?        0           
=======================================
  Hits            ?     2221           
  Misses          ?     1855           
  Partials        ?        0           
| Flag   | Coverage Δ         |
|--------|--------------------|
| Python | 54.48% <0.00%> (?) |

Flags with carried forward coverage won't be shown.

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update aeeedeb...d1f95c4.

@ClashLuke (Contributor, Author)

I just noticed that this PR contains parts of #45. Is that going to be an issue? Should I pick it apart?

@blefaudeux (Contributor)

> I just noticed that this PR contains parts of #45. Is that going to be an issue? Should I pick it apart?

Two options:

  • either on GitHub you edit the PR so that it's against the other branch rather than main -> you'll see only your changes
  • or on your side you can rebase this onto main to remove the dependency

It's easier for the review and landing if we keep things orthogonal, if that's ok with you.

@blefaudeux (Contributor)

Otherwise the results are pretty great @ClashLuke, and composing the transforms while the data is in shared memory is the right thing to do, I think.

The inline review thread below is attached to this excerpt from the diff:

"Something is wrong in the backward graph, possibly because of an inplace operation after the layernorm"

# enqueue kernel using forward pass heuristics
# also compute partial sums for DW and DB

# fmt: off
meta = {"BLOCK_SIZE_N": ctx.BLOCK_SIZE_N,
Contributor (inline review comment):

you probably saw that, but this is the same as passing it directly down when calling the kernel

@ClashLuke (Contributor, Author) replied:

I could integrate it into the two calls below. It just felt cleaner to not repeat shared arguments.
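
To make the two calling styles concrete, here is a tiny stand-alone Python illustration; the launch function, kernel names, and values are hypothetical stand-ins for the actual Triton backward launches, and only the calling convention is the point:

```python
def launch(kernel_name, *tensors, BLOCK_SIZE_N, num_warps):
    # Stand-in for a Triton kernel launch such as kernel[grid](...).
    print(f"{kernel_name}: {len(tensors)} tensors, "
          f"BLOCK_SIZE_N={BLOCK_SIZE_N}, num_warps={num_warps}")

# Style in the diff: build the shared meta-parameters once...
meta = {"BLOCK_SIZE_N": 128, "num_warps": 4}
# ...and reuse the dict for both backward launches (DX, then DW/DB).
launch("bwd_dx", "dy", "x", **meta)
launch("bwd_dwdb", "partial_dw", "partial_db", **meta)

# Reviewer's alternative: pass the same values inline at each call,
# which drops the intermediate dict but repeats the shared arguments.
launch("bwd_dx", "dy", "x", BLOCK_SIZE_N=128, num_warps=4)
launch("bwd_dwdb", "partial_dw", "partial_db", BLOCK_SIZE_N=128, num_warps=4)
```

Either way the kernel sees the same keyword arguments; the difference is purely about readability at the call sites.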

@blefaudeux (Contributor) left a review comment:

Provided you rebase out of the other PR, this looks great to me, thanks a lot @ClashLuke. I'd like to homogenize the coding style on all the kernels so that this part is easier to grasp, but I'll do that after this lands

@blefaudeux changed the title from "WIP: InplaceNorm, affine-less LayerNorm, GLU" to "[feat] InplaceNorm, affine-less LayerNorm" on Oct 29, 2021
@blefaudeux (Contributor) commented on Oct 29, 2021

@ClashLuke, are you ok to:

  • rebase out of the reversible branch? (you can 'git rebase --onto main {hash of the commit just before the stack you want to move}')
  • implement the GELU path in another PR? It could be that I can homogenize the kernels before that.

@ClashLuke (Contributor, Author)

It doesn't look like rebasing solved the problem, since I had merged inside the graph, so I just squashed the changes into a single commit.

@blefaudeux (Contributor)

@ClashLuke alright, looks good to me, you would just need to rebase on main now :) (sorry about all that, the doc build problem was fixed yesterday)

@blefaudeux (Contributor)

Looks good, I'll land as soon as the final test is done. Thank you so much @ClashLuke! I'll add you to the contributors in a future README update.

@blefaudeux blefaudeux merged commit 23389bd into facebookresearch:main Oct 29, 2021
xwhan pushed a commit to xwhan/xformers that referenced this pull request Feb 8, 2022
Linked issue closed by this pull request: feat(triton): InplaceNorm + InstanceNorm (#50)