Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow BFloat16 for blocksparse #354

Closed

Conversation

SeanNaren
Copy link
Contributor

@SeanNaren SeanNaren commented Jul 21, 2022

What does this PR do?

Fixes #353

Before submitting

  • Did you have fun?
    • Make sure you had fun coding 🙃
  • Did you read the contributor guideline?
  • Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
    • N/A
  • Did you make sure to update the docs?
    • N/A
  • Did you write any new necessary tests?
    • N/A
  • Did you update the changelog? (if needed)
    • N/A

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jul 21, 2022
Copy link
Contributor

@fmassa fmassa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR!

I'm ok with the direction and would be fine merging the PR as is, but what do you think if we instead don't perform any casting whatsoever in the code? I understand that Triton won't work with fp32 (or at least it won't be efficient), but should we implicitly perform this cast here, instead of raising an error?

cc @blefaudeux for thoughts

@blefaudeux
Copy link
Contributor

Thanks for the PR!

I'm ok with the direction and would be fine merging the PR as is, but what do you think if we instead don't perform any casting whatsoever in the code? I understand that Triton won't work with fp32 (or at least it won't be efficient), but should we implicitly perform this cast here, instead of raising an error?

cc @blefaudeux for thoughts

thanks for the PR @SeanNaren to begin with ! The cast (ours) is not super nice to begin with, it's mentioned in the doc in that originally blocksparse would crash on fp32, so we officially just supported fp16 and added that. I think that the fp32 case is fixed by now, and this nuked bf16 altogether which is just nicer than fp16 most of the time, so we should probably revisit that ! I can submit a PR on your branch @SeanNaren to try to remove the cast altogether, and adjust the unit tests to make sure that we cover parity testing for all types, what do you think ?

@blefaudeux
Copy link
Contributor

I've pushed a small addition to https://github.com/blefaudeux/xformers/tree/fix/blocksparse_bfloat16 (can be PRd onto your branch @SeanNaren ), the issue right now is that some parity tests are not passing anymore and I didn´t have time to investigate. I think that it could be related to different defaults nowadays with respect to the accumulators, triton defaults to TF32 and I think that pytorch reverted that to fp32

@SeanNaren
Copy link
Contributor Author

Thanks @fmassa @blefaudeux! getting rid of the check would be the optimal solution, @blefaudeux you should be able to push directly to my branch, feel free to do so!

@blefaudeux blefaudeux force-pushed the fix/blocksparse_bfloat16 branch from 1a7e082 to 14dd397 Compare July 30, 2022 13:05
@blefaudeux
Copy link
Contributor

I just pushed @SeanNaren, underwater these days with a new job, really sorry for the delay ! I just had to adapt a few tests (some pytorch codepaths are not bfl16 compatible), but nothing super big. I took the liberty to rebase your branch also, to remove a merge conflict, you would have to git reset --hard origin/yourbranch if pulling on your end

@blefaudeux blefaudeux force-pushed the fix/blocksparse_bfloat16 branch from 14dd397 to 54de7bd Compare July 30, 2022 13:07
@SeanNaren
Copy link
Contributor Author

I just pushed @SeanNaren, underwater these days with a new job, really sorry for the delay ! I just had to adapt a few tests (some pytorch codepaths are not bfl16 compatible), but nothing super big. I took the liberty to rebase your branch also, to remove a merge conflict, you would have to git reset --hard origin/yourbranch if pulling on your end

Appreciate it man! Without a dev machine (for now) so can't make progress here, so appreciate the assist!

@blefaudeux
Copy link
Contributor

Seems good except for two tests where the results are markedly different in between triton and torch, not too nice.. I'll try to update triton pip package, could be that this was fixed in betwee

@blefaudeux
Copy link
Contributor

I tried a couple of obvious takes, to no avail, the TL: DR is that parity with pytorch does not seem to be respected with the triton operator with bf16, with or without the tf32 accumulation and with or without the triton.testing tools. I'm not sure how much of this is xformers business, since this part is just a helper to use Triton's blocksparse attention, but since it's exposed here it makes sense that parity is tested. @ptillet sorry for the ping, but is that expected ? Are there some specific pip packages which we should stick to ?

fmassa pushed a commit that referenced this pull request Aug 10, 2022
* Add support for f16 with tensorcores

* sm75 minimum for tensorcores

* Run tests with CUDA_LAUNCH_BLOCKING=1

* Support sm70 properly

* Disable tensorcore when not correctly aligned - and use 32bit accessors

Co-authored-by: danthe3rd <danthe3rd>
Co-authored-by: danthe3rd <danthe3rd@users.noreply.github.com>
fmassa added a commit that referenced this pull request Aug 25, 2022
* Enable masking in memory-efficient attention (#333)

* Add attention bias in memory-efficient attention

* Add gradient for attn_mask support

* Add CPU implementation

* clang-format

* Add benchmark scripts

* Add extra loop in benchmarks

* Move zeros array out of helper function

* clang-format

* Enable dropout in memory-efficient attention (#334)

* Merge compute_scaling_coeffs and update_scaling_coeffs into a single function

It wasn't needed to break it in two functions to begin with

* Add CUDA implementation for dropout

* clang-format

* Make p be drop probability

* Only CUDA supports dropout

* Add benchmarks

* Remove unused variables

* Fix test

* Cleanups and comments

* Fix masking corner case when full block is masked (#339)

* Add cutlass 2.9 - 858c735856a7f17bd33fe438ec76d3c9f0234e7f

* Option to load from shared memory for PredicatedTileIterator

* Add cutlass include dir

* Ignore files in third-party for flake8/coverage

* third-party -> third_party

* Address comments

* Revert some un-needed mods

* Add attention_forward_generic.cu

* Add tests

* Fix duplicate calculations on baseline for mem efficient transformers

* Always run all linters in CI

* clang-format attention_forward_generic.cu

* Benchmark: Add possibility to compare benchmarks

* [isort] Ignore third_party

* black autoformat

* Black again + ignore third_party properly

* black

* Fix memory leak between the 2 benchmarks in backward

* Exclude third_party/ without using pyproject.toml as it imposes isolated build which is a pain

* Remove progress bar when finished

* mypy

* flake8

* Save results to shared folder in home location

* run black

* clang-format with 'run-clang-format.py'

* Fix cutlass build for arch>=75

* Set tests precision for gradient more accurately

* Fix precision margin

* Revert changes to black

* [feat] Fix importing xformers when not built (#351)

authored-by: danthe3rd <danthe3rd@users.noreply.github.com>

* Update black to 22.3.0

* Tweak precision for mem_eff_attention test

* mem-efficient impl for f16 (#352)

Co-authored-by: danthe3rd <danthe3rd>

* Add support for f16 with tensorcores [sm70/sm75/sm80] (#354)

* Add support for f16 with tensorcores

* sm75 minimum for tensorcores

* Run tests with CUDA_LAUNCH_BLOCKING=1

* Support sm70 properly

* Disable tensorcore when not correctly aligned - and use 32bit accessors

Co-authored-by: danthe3rd <danthe3rd>
Co-authored-by: danthe3rd <danthe3rd@users.noreply.github.com>

* Optimize backward of memory-efficient attention by ~20% (#355)

* Optimize backward by 15% by using equivalent formulation

* Unify everything into single kernel

* Remove unused implementation

* clang-format

* Remove unused tensor

* Display results as we progress during benchmark (#357)

Co-authored-by: danthe3rd <danthe3rd>

* RFC: Ops dispatch (#356)

* Ops dispatch

* CI: Fix doc build

* memory_efficient_attention raises when no implementation is available

* type: ignore

* Fix torch.device/str comparison

* Make mypy happy

Co-authored-by: danthe3rd <danthe3rd@users.noreply.github.com>
Co-authored-by: danthe3rd <danthe3rd>

* [A100/f32] Use TensorCores for Q.K_t matmul with FastF32 (#358)

* Use TensorCores for MM0 on Float as well

* Use MultiStage MMA when available - change to FastF32 rather than FastF16

* Better alignment calculation

* Just use regular f32, no fastf32

* Hackfix to handle alignment

* HeuristicsMM0 -> GemmTypeQK

* No longer use f16 for matmul

* Add some doc

* Typo

* Fix build <sm80

* Alignment check based on current device compute capability

* Use TORCH_INTERNAL_ASSERT

Co-authored-by: danthe3rd <danthe3rd>

* FlashAttention implem and dispatch (#360)

* FlashAttention implem WIP

* Fix flashattention forward+backward

* Fix forward/backward for FlashAttention

* Enable tests (more permissive) for f16 backward

* Fix CI

* flashattn only supports Sm75 and above

* Fix CI2

* Disable K=128 when below sm80 for flashattn

Co-authored-by: danthe3rd <danthe3rd>

* Misc performance improvements for generic mem-efficient attention (#361)

* 3% speedup by calculating mi from registers

* Also compute m_prime/s_prime and exponentiate from registers

* Support for Simt tiles

* Fix TensorOp for V100

* Fix for A100

* Fix Simt alignment calculation

* clang-format

* WarpReduction before atomic call for Simt

Co-authored-by: danthe3rd <danthe3rd>
Co-authored-by: danthe3rd <danthe3rd@users.noreply.github.com>

* Update flashattention to support bf16 (#363)

* Update flashattention to support bf16

* bfloat16 only on sm80 and above

Co-authored-by: danthe3rd <danthe3rd>

* Flashattn causal (#364)

* Implement causal memory-efficient attention with FlashAttention

* Update benchmarks

* Fix mypy

Co-authored-by: danthe3rd <danthe3rd>

* Option to disable flashattention (long to build) (#362)

* Option to disable flashattention (long to build)

* Update setup.py

Co-authored-by: danthe3rd <danthe3rd>

* Remove code duplicate in attention_scaling_coefs_updater.h (#367)

Co-authored-by: danthe3rd <danthe3rd>

* Update .gitmodules (#366)

* MemoryEff attention forward: Properly fuse matmul and enable TensorCores on the second matmul (#368)

* Generic backwards

* Guard backward to sm75 only

* bounds checking for gradV

* clang-format

* Fused gemm working for Sm80/Sm75 f16/f32

* WIP

* Volta TensorOp for f16

* Working on A100 again

* SIMT working

* Code cleanup 1

* Code cleanup2

* BUGFIX for shared memory limit

* Remove code

* clang-format

* Remove code again

* Remove draft of backward

* Enforce alignment for fp16

* Fix tests

* Fix constraint on seq length when not using tensorcores

* Fix alignment requirements for V100/tensorcores

* Clang-format

* Update xformers/components/attention/csrc/cuda/attention_forward_generic.cu

Co-authored-by: Francisco Massa <fvsmassa@gmail.com>

* Address comments from fmassa

Co-authored-by: danthe3rd <danthe3rd>
Co-authored-by: danthe3rd <danthe3rd@users.noreply.github.com>
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>

* Update install instructions with submodule (#365)

* Generic backward implem with cutlass (#371)

* Old bw code

* P100: gradV working

* gk/gq working (at least for small values of M, and on P100/f16)

* Further restrict supported values for bw

* Fix storage into smem for Simt

* More tooling for pruint/debug

* Remove tests we dont need for now

* Tests pass on P100 :D

* 4 warps per block

* Restraint on q length

* Use tensorcores on V100 for f16

* Support dynamic smem for bw

* Handle alignment and different dtype/arch

* Fix NaNS by initializing shared memory

* bw.py

* Fix launch bounds

* Faster 'computeDi'

* minus_lse can operate on arrays

* Output number of regs used etc...

* Code cleanup

* Hackfix for alignment check during forward

* zFill to avoid nans in Sm80 + fix launch bounds

* COde cleanup1

* clang-format

* Fix tests

* Add benchmark for K=64

Co-authored-by: danthe3rd <danthe3rd@users.noreply.github.com>
Co-authored-by: danthe3rd <danthe3rd>

* Cutlass as submodule (#375)

* Make cutlass be back at 858c735856a7f17bd33fe438ec76d3c9f0234e7f

* Remove cutlass

* Update submodules

* Add submodule (properly)

* spaces / tab

* Make submodule init be recursive

* Fix bad rebase

* Bump tolerance for backward (#377)

* Add verbose flag to CI builds (#376)

* Add verbose flag to CI builds

* Spurious change to rebuild cache

* Add ninja

* Ninja wasn't visible before, install through conda

* Debugging

* Source env

* One more try

* Forgot to uncomment a line

* Another try

* Cleanup

* Fix for FlashAttention dispatch

It requires device capability >= 7.5

* Remove generated file

* Address some reviewer feedback

Remove unused function and typo fix

* Perf improvement on backward (#378)

* Fast again on V100

* Fix correctness - missing syncthreads

* Get rid of AttentionInfo

Co-authored-by: danthe3rd <danthe3rd@users.noreply.github.com>

Co-authored-by: danthe3rd <danthe3rd@users.noreply.github.com>
Co-authored-by: dan_the_3rd <43445237+danthe3rd@users.noreply.github.com>
@blefaudeux
Copy link
Contributor

fixed by #528 ! Thanks @SeanNaren for putting this one up, and sorry that it was so slow to close the case, it required upgrading all the triton code here and I was a bit too busy for that..

@blefaudeux blefaudeux closed this Nov 24, 2022
@SeanNaren SeanNaren deleted the fix/blocksparse_bfloat16 branch November 24, 2022 23:00
bertmaher pushed a commit to bertmaher/xformers that referenced this pull request Dec 20, 2024
…ch#354)

* Add support for f16 with tensorcores

* sm75 minimum for tensorcores

* Run tests with CUDA_LAUNCH_BLOCKING=1

* Support sm70 properly

* Disable tensorcore when not correctly aligned - and use 32bit accessors

Co-authored-by: danthe3rd <danthe3rd>
Co-authored-by: danthe3rd <danthe3rd@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Blocksparse forced to use FP16 when BF16 is supported (for triton v2)
4 participants