Custom attention bias #617

Open · wants to merge 26 commits into base: main

Conversation

@b-albar commented Oct 19, 2023

This PR is an attempt to add custom (additive) attention biases. It is still very much a work in progress, to say the least, but I thought I would make my code available since there seems to be a lot of interest in this feature (#332, #342, #534, #606, #661).

For now, there are no gradients for the bias matrix (I'm working on this) and it has been through very limited testing, but the forward pass as well as dq, dk, dv look OK (tested only on an A100). Also, no time has been spent on optimization yet.

I'm not very familiar with kernel work, so any comments or help with this PR would be very much welcome!

Todo

  • Return gradients for the bias matrix
  • Support for linear biases like (1,1,1,N) ?
  • Test extensively (causal, head size, ...)
  • Benchmark the performance compared to vanilla FA
  • Optimize
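
As a rough illustration of the intended interface: a minimal usage sketch, where the `flash_attn_func` entry point exists upstream but the `attn_bias` keyword is an assumption based on this PR and may differ from the actual code.

```python
# Hypothetical usage sketch: the `attn_bias` keyword is an assumption based on
# this PR's Python interface and is not part of upstream flash-attn.
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 1024, 8, 64
device, dtype = "cuda", torch.float16

q = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype)
k = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype)
v = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype)

# Additive bias, broadcast over batch and heads: shape (1, 1, seqlen_q, seqlen_k).
bias = torch.randn(1, 1, seqlen, seqlen, device=device, dtype=dtype)

out = flash_attn_func(q, k, v, attn_bias=bias)  # keyword name assumed from this PR
```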

@jacob-crux

An error occurred while experimenting with this code.
The sequence lengths used in test_attn_bias.py (512, 1024, 2048, 4096) work fine. However, when I tested various sequence lengths that are not multiples of 512, I got RuntimeError: CUDA error: misaligned address.
Can this code be used for inference?

@b-albar (Author) commented Oct 27, 2023

@khj94 Thanks for reporting this. As I said, testing has been very limited so far. Correct me if I'm wrong, but I think the error occurs when the sequence length is not a multiple of 8. For q, k, v it looks like the tensors are padded (here) to avoid this problem. I guess I need to do the same for the bias tensor. I'll look into it.
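
A minimal sketch of the kind of padding meant here, assuming the bias's last (key-length) dimension is what needs rounding up to a multiple of 8; the helper name is hypothetical.

```python
import torch
import torch.nn.functional as F

def pad_bias_to_multiple_of_8(bias: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: round the key-length dimension up to a multiple of 8,
    # mirroring what the interface already does for q/k/v.
    seqlen_k = bias.shape[-1]
    pad = (8 - seqlen_k % 8) % 8
    if pad == 0:
        return bias
    # F.pad pads the last dimension first: (left, right)
    return F.pad(bias, (0, pad), value=0.0)

bias = torch.randn(1, 1, 1021, 1021, device="cuda", dtype=torch.float16)
padded = pad_bias_to_multiple_of_8(bias)  # -> shape (1, 1, 1021, 1024)
```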

@jacob-crux commented Oct 31, 2023

@b-albar Thank you for responding quickly.
When I checked the newly committed code, I confirmed that the forward pass works correctly.
However, the same RuntimeError: CUDA error: misaligned address now occurs in the backward pass.
To fix this, the bias tensor needs the same padding in the backward pass as in the forward pass.

@b-albar (Author) commented Nov 2, 2023

Actually, the way it's implemented for q, k, v is that the padded tensors are returned from the forward pass and reinjected through the Python interface in the backward pass, so there is no need to pad twice. I added this in the last commit. I also encountered some NaNs in dq, dk, dv and had to add boundary conditions on rows and columns, similar to what was done in the ALiBi patch. But it looks like I ran into some race conditions (due to branching?), and I had to add a __syncthreads(). It looks OK now (hopefully), but I'll have to investigate this further later.
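
An illustrative sketch of the pattern described (not the PR's actual code): the padded tensor is produced once in the forward pass, and the backward pass only trims the gradient back to the original key length instead of padding a second time.

```python
import torch
import torch.nn.functional as F

class PadBiasTo8(torch.autograd.Function):  # hypothetical helper, not the PR's code
    """Pad the bias's key-length dim once in the forward; the backward only trims."""

    @staticmethod
    def forward(ctx, bias):
        ctx.orig_seqlen_k = bias.shape[-1]
        pad = (8 - ctx.orig_seqlen_k % 8) % 8
        # The padded tensor is what the kernel sees; remembering the original
        # length is enough for the backward pass, so nothing is padded twice.
        return F.pad(bias, (0, pad)) if pad else bias

    @staticmethod
    def backward(ctx, dbias_padded):
        # Trim the gradient back to the caller's original key length.
        return dbias_padded[..., : ctx.orig_seqlen_k]
```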

@b-albar (Author) commented Nov 3, 2023

I added the gradients for the bias. The gradient is computed only if the bias has requires_grad set. This avoids copying a large matrix when it's not needed (e.g. for ALiBi-style biases or masking), since one of the main features of FA is not having to materialize this matrix.
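
A hedged sketch of that idea: only materialize dbias when the bias actually requires a gradient, and reduce the full-shape gradient over any broadcast dimensions so it matches the original bias shape. The helper name is hypothetical.

```python
import torch

def reduce_dbias(dbias_full: torch.Tensor, bias_shape: torch.Size) -> torch.Tensor:
    # dbias_full has the full (batch, heads, seqlen_q, seqlen_k) shape produced by
    # the kernel; sum over every dim that was broadcast in the original bias so the
    # returned gradient matches e.g. a (1, 1, seqlen_q, seqlen_k) bias.
    for dim, size in enumerate(bias_shape):
        if size == 1 and dbias_full.shape[dim] != 1:
            dbias_full = dbias_full.sum(dim=dim, keepdim=True)
    return dbias_full

# Usage idea: skip the (potentially large) allocation entirely when no gradient is
# requested, e.g. for ALiBi-style biases or plain masking.
# dbias = reduce_dbias(dbias_full, bias.shape) if bias.requires_grad else None
```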

b-albar marked this pull request as ready for review on November 3, 2023.
b-albar changed the title from Custom attention bias (WIP) to Custom attention bias on Nov 3, 2023.
@jacob-crux commented Nov 8, 2023

@b-albar I want to check the last commit, but an error occurs during the build process.
After checking the code, I found two undefined identifiers in the places below.
Could you please take a look?

flash-attention/csrc/flash_attn/src/flash_bwd_kernel.h(1618): error: identifier "Is_even_MN" is undefined
                  if (Is_even_MN || get<0>(tdScdS(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM) {
                      ^
flash-attention/csrc/flash_attn/src/flash_bwd_kernel.h(1627): error: identifier "m_block_min" is undefined
              if (m_block > m_block_min) {
                            ^

LyricZhao mentioned this pull request on Nov 8, 2023.
@b-albar (Author) commented Nov 8, 2023

@khj94 My bad, should be good now!

@Taytay commented Feb 8, 2024

Just dropping in as an interested observer @b-albar to say: Thank you for your continued work on this. This PR is going to unlock some serious efficiency gains for T5 models, and I think it will lead to some great stuff. ❤️

* add support for shape (1,1,q,k)
* move bias and ds on the same smem
@c0nn3r commented Feb 15, 2024

Just echoing what others have said - extremely excited for this to land - let me know if there is anything I can do to help!

@c0nn3r commented Feb 18, 2024

A note for anyone wanting to give this a try right now: all tests are currently failing.

@b-albar (Author) commented Feb 20, 2024

@c0nn3r Thanks for noticing this. There is clearly a memory error somewhere. It's very odd that it didn't show up in my local version. I usually compile with ALiBi and some other options disabled since it's faster for debugging, but this error shows up only when all options are enabled! I'll investigate what's going on. In the meantime, specifying "-DFLASHATTENTION_DISABLE_ALIBI" at compile time seems to do the trick (at least for me) as a temporary workaround.

@c0nn3r commented Feb 21, 2024

Thanks for looking into it.

Hm, I just compiled it (7fe76b6) and I'm still seeing RuntimeError: CUDA error: misaligned address

If I just run pytest -q -s tests/test_attn_bias.py, then 18 tests pass and 270 fail, all with RuntimeError: CUDA error: an illegal memory access was encountered.

I also attempted to bisect the issue, but found that it appears even before the refactor (and it takes a while to build, even on 40 CPU cores).

@b-albar (Author) commented Feb 21, 2024

I definitely didn't have a misaligned address error previously. Could you post the pytest logs here? I'd like to see which seqlens are failing. Also, which GPU and CUDA version are you using?

@b-albar (Author) commented Feb 21, 2024

OK, after some experiments I think we are dealing with a compiler issue here. That would explain why it works after disabling some options and why I didn't find it before. It looks like the kernel is quite big and the compiler doesn't like it very much (just compiling the kernel with debug info crashes the compilation). I was using CUDA 12.3 update 1 previously and could reproduce the memory error even with ALiBi disabled, yet disabling alibi, uneven_k and local together worked fine without any change in the code! Updating to CUDA 12.3 update 2 works better (the changelog mentions a register spilling issue and potential miscompilation; I don't know whether that is related), and now disabling only ALiBi works again. I'll try recompiling the code with all options enabled to see. For now, I'll disable all optional features on this branch until it's fully clear what is happening; at least that will speed up compilation.

@c0nn3r Did you use CUDA 12.3.1 to compile?

@c0nn3r commented Feb 21, 2024

@b-albar I used CUDA 12.1 to compile. Let me update to CUDA 12.3 and try again.

@c0nn3r commented Feb 21, 2024

I get the same result as before using:

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Wed_Nov_22_10:17:15_PST_2023
Cuda compilation tools, release 12.3, V12.3.107
Build cuda_12.3.r12.3/compiler.33567101_0

I don't think it's the compiler: both the flash attention tests and the attention bias tests fail. My GPUs are all RTX A5000s.

@wehos commented Feb 29, 2024

Hi there! I'm willing to help with the testing.


RTX 4090
I tried a 4090 (built with CUDA 12.1 and your latest branch). It ends with 111 failed, 177 passed in 46.93s, from two errors: RuntimeError: CUDA error: invalid argument and RuntimeError: FlashAttention backward for head dim > 192 requires A100/A800 or H100/H800. With a previous commit, it ends with 216 failed, 72 passed in 75.39s; the errors are still RuntimeError: CUDA error: invalid argument.

When I tried it in my application, it kept throwing RuntimeError: CUDA error: invalid argument.

More specifically:

File "/opt/conda/lib/python3.10/site-packages/flash_attn-2.3.6-py3.10-linux-x86_64.egg/flash_attn/flash_attn_interface.py", line 51, in _flash_attn_forward
   out, q, k, v, out_padded, attn_bias, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
RuntimeError: CUDA error: invalid argument

Here is a warning that appeared during compilation which might be relevant (sorry, I'm not sure):

csrc/flash_attn/flash_api.cpp:52:11: warning: ‘void* memset(void*, int, size_t)’ clearing an object of non-trivial type ‘struct Flash_fwd_params’; use assignment or value-initialization instead [-Wclass-memaccess]
   52 |     memset(&params, 0, sizeof(params));
      |     ~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~
In file included from csrc/flash_attn/flash_api.cpp:13:
/flash-attention/csrc/flash_attn/src/flash.h:51:8: note: ‘struct Flash_fwd_params’ declared here
   51 | struct Flash_fwd_params : public Qkv_params {

A100 PCIE
On an A100 PCIe 40G (built with CUDA 12.1, both the latest branch and this commit), it ends with 3 failed, 285 passed. All failures are AssertionErrors.

However, when I deployed the model in my application (with the latest branch), it fails with RuntimeError: CUDA error: an illegal memory access was encountered. It seems clear that the issue happens during continuous inference: it does not occur in the first few inference steps but is raised somewhere in the middle.

@sentialx commented Mar 6, 2024

Finally someone did it! Does it also support negative biases (i.e. masking)?

@jlamprou commented Mar 9, 2024

I am currently testing this. I have implemented it for SwitchTransformers on an NVIDIA A100 40GB, training on the MLM task with LR=1.5e-4. Loss scaling is very unstable. Lowering the LR and increasing the warmup_ratio seems to help, but at some point something explodes and destroys the loss scaling.
[W&B loss chart, 3/9/2024]

@b-albar (Author) commented Mar 15, 2024

Thank you all for testing.
I've fixed an annoying race condition in the latest commit that appeared in the backward pass when using fp16 (apparently it was fine with bf16). @jlamprou, this may fix the instabilities you observed.

As for the CUDA error: an illegal memory access was encountered, I'm still puzzled by this. Memcheck points to this part of the code. From printing the memory addresses, the pointer of gdK is correct but the address of the partition (tdKgdK) is off. For now, I cannot make sense of it.

@wehos invalid argument is a classic error when you request a shared memory size larger than the GPU supports. This is to be expected when testing on a 4090.

@sentialx Yes, it should work for masking by using a large negative bias.
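
A small sketch of that masking trick, assuming the bias is simply added to the attention scores before the softmax; the mask construction here is only an example.

```python
import torch

seqlen_q = seqlen_k = 1024
device, dtype = "cuda", torch.float16

# Example mask: True where attention is allowed (here, a lower-triangular pattern).
keep = torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=device).tril()

bias = torch.zeros(1, 1, seqlen_q, seqlen_k, device=device, dtype=dtype)
# Use a large finite negative value rather than -inf to avoid NaN from inf - inf.
bias.masked_fill_(~keep, torch.finfo(dtype).min / 2)
# Pass `bias` as the additive attention bias; masked positions end up with ~zero weight.
```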

@jlamprou

@b-albar Sorry, I didn't clarify: I'm using bf16, but the Switch architecture is also tricky with mixed precision (it uses selective mixed precision, so some layers are bf16 and others are fp32). I tried again with the new commit on a subset of StarCoder data (7B tokens) and reached 0.17 loss at epoch 0.32, which is a bit ridiculous for an MLM task.

@b-albar (Author) commented Mar 18, 2024

> @b-albar Sorry, I didn't clarify: I'm using bf16, but the Switch architecture is also tricky with mixed precision (it uses selective mixed precision, so some layers are bf16 and others are fp32). I tried again with the new commit on a subset of StarCoder data (7B tokens) and reached 0.17 loss at epoch 0.32, which is a bit ridiculous for an MLM task.

OK, I think I see: you mean the loss goes to zero (too) quickly? In that case it may be a causality issue; it's better to force causality in the decoder using causal=True rather than encoding it in the bias matrix. It looks like the SwitchTransformers architecture is similar to T5. I'm working on a T5 with flash-attention, maybe it can help; the code is available here.
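
A sketch of that suggestion, assuming this PR's `attn_bias` keyword on `flash_attn_func`: keep the bias for the relative-position term only and let `causal=True` handle causality.

```python
import torch
from flash_attn import flash_attn_func  # the attn_bias keyword is assumed from this PR

b, s, h, d = 2, 512, 8, 64
dev, dt = "cuda", torch.bfloat16
q = torch.randn(b, s, h, d, device=dev, dtype=dt)
k = torch.randn(b, s, h, d, device=dev, dtype=dt)
v = torch.randn(b, s, h, d, device=dev, dtype=dt)

# T5-style relative-position bias, broadcast over the batch dimension.
rel_pos_bias = torch.randn(1, h, s, s, device=dev, dtype=dt)

# Let the kernel enforce causality; the bias carries only the relative-position term.
out = flash_attn_func(q, k, v, attn_bias=rel_pos_bias, causal=True)
```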

@jlamprou commented Mar 18, 2024

> @b-albar Sorry, I didn't clarify: I'm using bf16, but the Switch architecture is also tricky with mixed precision (it uses selective mixed precision, so some layers are bf16 and others are fp32). I tried again with the new commit on a subset of StarCoder data (7B tokens) and reached 0.17 loss at epoch 0.32, which is a bit ridiculous for an MLM task.

> OK, I think I see: you mean the loss goes to zero (too) quickly? In that case it may be a causality issue; it's better to force causality in the decoder using causal=True rather than encoding it in the bias matrix. It looks like the SwitchTransformers architecture is similar to T5. I'm working on a T5 with flash-attention, maybe it can help; the code is available here.

Thanks a lot, I was actually looking for a T5-like implementation for Flash; Switch is basically T5-MoE. I've also replaced the T5 RPE with FIRE for the bias, which gave me better performance across different lengths.

@FSSRepo commented Mar 21, 2024

@b-albar Could this work with infinitely negative values (a custom attention mask)? It seems like I have to reshape the array to (batch_size, num_heads, seq_len, seq_len). It would be good if broadcasting were performed along the num_heads dimension.
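
One possible workaround sketch (names are illustrative, and whether the kernel accepts a non-contiguous expanded view is untested here): expand a per-batch bias along the head dimension without copying.

```python
import torch

batch, nheads, seqlen = 2, 8, 1024
dtype = torch.float16

# Example boolean mask shared by all heads: True where attention is allowed.
keep = torch.rand(batch, 1, seqlen, seqlen, device="cuda") < 0.5

bias = torch.zeros(batch, 1, seqlen, seqlen, device="cuda", dtype=dtype)
bias.masked_fill_(~keep, torch.finfo(dtype).min / 2)

bias_per_head = bias.expand(batch, nheads, seqlen, seqlen)   # view, no extra memory
# bias_per_head = bias_per_head.contiguous()                 # fallback if the kernel needs it
```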
