Padding causes forward to produce different logits (Llama2-7b) #29029
Comments
Cc @younesbelkada too :)
I believe this comment is relevant to this issue: #25420 (comment)
On point @amyeroberts. TL;DR: it's expected.
Thanks @amyeroberts @ArthurZucker. I did a few more experiments based on the issue linked by @amyeroberts:
- KV Cache: I turned this off, but it did not lead to any change. My understanding is that KV caching is disabled by default here anyway.
- FP32: I loaded model weights in fp32. There is still a noticeable difference, but it is smaller.
- FP16: I loaded model weights in fp16. There is still a noticeable difference; it is smaller than for bf16 but larger than for fp32.
- CPU: I ran forward prop on CPU rather than GPU. The difference is now tiny.
- Right Padding: I changed padding to right padding on CPU. The error is now even smaller but still non-zero.

Thoughts?
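For reference, here is a minimal sketch of the kind of sweep described above. The original snippets were not preserved in this copy of the thread, so the model checkpoint, prompts, and comparison logic are illustrative assumptions rather than the poster's actual code:

```python
# Sketch (not the original snippet): compare padded vs. unpadded logits for one
# prompt while varying dtype, device, and padding side.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def max_logit_diff(model_id, dtype, device, padding_side):
    tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side=padding_side)
    tokenizer.pad_token = tokenizer.eos_token  # Llama has no pad token by default
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype).to(device).eval()
    model.config.use_cache = False  # make the "KV cache off" setting explicit

    short = "Hello world"
    longer = "This is a much longer prompt that forces the short one to be padded"

    # Batch the two prompts together so the short one is padded ...
    padded = tokenizer([short, longer], return_tensors="pt", padding=True).to(device)
    # ... and run the short one alone, with no padding at all.
    alone = tokenizer(short, return_tensors="pt").to(device)

    with torch.no_grad():
        padded_logits = model(**padded).logits[0]  # padded copy of `short`
        alone_logits = model(**alone).logits[0]    # unpadded copy of `short`

    # Compare only the positions that correspond to real (non-pad) tokens.
    n = alone_logits.shape[0]
    real = padded_logits[-n:] if padding_side == "left" else padded_logits[:n]
    return (real - alone_logits).abs().max().item()

for dtype in (torch.float32, torch.float16, torch.bfloat16):
    print(dtype, max_logit_diff("meta-llama/Llama-2-7b-hf", dtype, "cuda", "left"))
```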
I found this bug too. You can test whether enabling sdpa or flash attention changes it. When sdpa is used, the result seems to be correct. I don't know why this bug happens.
This was already answered: basically, eager attention still attends to padding tokens (because the output of the softmax is never exactly zero), whereas with exact implementations / kernels you get 0 for the padding tokens instead of a very tiny number. See #27050
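To compare the backends directly, the attention implementation can be selected when loading the model. A minimal sketch, assuming a Llama-2 checkpoint and the `attn_implementation` argument available in recent transformers versions; the prompts are illustrative:

```python
# Sketch: load the same checkpoint with different attention backends and
# compare their logits on a left-padded batch.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
batch = tokenizer(["hi", "a much longer prompt that forces padding"],
                  return_tensors="pt", padding=True)

logits = {}
for impl in ("eager", "sdpa"):  # "flash_attention_2" also works if flash-attn is installed
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, attn_implementation=impl
    ).eval()
    with torch.no_grad():
        logits[impl] = model(**batch).logits

print("max |eager - sdpa|:", (logits["eager"] - logits["sdpa"]).abs().max().item())
```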
I have done some experiments (version==4.37.2). If I use eager attention with the sdpa-style attention mask, the results are correct. However, with the eager-mode attention mask, the results are wrong. This happens when using left padding for inference. The generated 4d attention masks differ between eager mode and sdpa mode.
I don't understand; what is wrong?
Please take a look at the discussions above. If left padding is used, the output of the model is wrong. I found that the attention mask can be generated in eager mode or in sdpa mode; the difference is that if no element is attended, the sdpa mode sets the attention mask to zero. If the sdpa-mode attention mask generation is used, the output of the model is correct. I tested with both the eager attention module and the sdpa attention module, and I am wondering why this happens.
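As a toy illustration of that difference (plain PyTorch, not the transformers mask utilities): with left padding, the query rows belonging to pad positions have every key masked, and a finite large-negative additive mask then produces uniform softmax weights on those rows rather than zeros; the sdpa mask-preparation path handles such fully-masked rows specially, which is the "set to zero" behaviour mentioned above.

```python
# Toy illustration (plain PyTorch): a fully masked row, e.g. the query row of a
# left-pad position, still gets uniform, nonzero softmax weights when the mask
# is a finite large-negative additive bias, because every entry is identical.
import torch

dtype = torch.float32
scores = torch.zeros(1, 4)                                # pretend attention scores for one pad-query row
fully_masked = torch.full((1, 4), torch.finfo(dtype).min)

weights = torch.softmax(scores + fully_masked, dim=-1)
print(weights)  # tensor([[0.2500, 0.2500, 0.2500, 0.2500]]) -- uniform, not zero
```

Real (non-pad) query positions are unaffected by this, but the values produced at the pad positions differ between the two mask-generation paths.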
@SY-Xuan let's try to be clear about what we mean when we say the output is "wrong". Now, if you have a snippet with a reproducer, that will help.
Thanks for your kind reply. I think I made a mistake by using different dtypes. I have fixed it now. Sorry for wasting your time.
No worries! 🤗 I'll close this as completed.
System Info
`transformers` version: 4.36.2

Who can help?
@ArthurZucker @yun

Information
Tasks
An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)

Reproduction
I have noticed that running `forward` on a padded sequence and an unpadded sequence yields (slightly) different logits, even with an attention mask specified. Here is what I ran:

Notice that `inputs_1` is the same as `combined_inputs_1`, except with the left padding omitted and the attention mask altered to match. Upon close inspection, you'll see that this tensor is slightly different to `combined_logits_1`. We can show this more clearly:

Is this meaningful? Well, if we look at the probabilities:
That's a pretty non-trivial probability!
Expected behavior
I would expect the attention mask to mask out the left padding, making the two sequences `inputs_1` and `combined_inputs_1` identical during forward prop, which should in turn mean that the logits produced are equivalent. I realise that there may be small errors arising from batched GPU computations, but this error doesn't seem very small...