
Padding causes forward to produce different logits (Llama2-7b) #29029

Closed

c3ianwu opened this issue Feb 15, 2024 · 12 comments

Comments

@c3ianwu

c3ianwu commented Feb 15, 2024

System Info

  • transformers version: 4.36.2
  • Platform: Linux-5.15.107+-x86_64-with-glibc2.31
  • Python version: 3.10.13
  • Huggingface_hub version: 0.20.1
  • Safetensors version: 0.4.1
  • Accelerate version: 0.22.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.1.2+cu118 (True)

Who can help?

@ArthurZucker @yun

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I have noticed that running forward on a padded sequence and an unpadded sequence yields (slightly) different logits, even with an attention mask specified.

Here is what I ran:

from transformers import AutoTokenizer, AutoModelForCausalLM
from torch import tensor
import torch

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16)
model.to(2)
model.eval()
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

prompt_1 = "I am so so confused."
prompt_2 = "I have never been more lost in my life!"
combined_prompt = [prompt_1, prompt_2]

combined_inputs = tokenizer(combined_prompt, padding=True, return_tensors="pt").to(2) # batch size 2
combined_inputs
>>> {'input_ids': tensor([[    2,     2,     2,     2,     1,   306,   626,   577,   577,  9613,
         29889], [    1,   306,   505,  2360,  1063,   901,  5714,   297,   590,  2834,
         29991]], device='cuda:2'), 'attention_mask': tensor([[0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:2')}
combined_inputs_1 = {"input_ids": combined_inputs["input_ids"][0].unsqueeze(0), "attention_mask": combined_inputs["attention_mask"][0].unsqueeze(0)} # extracting just the first item in the batch
combined_inputs_1
>>> {'input_ids': tensor([[    2,     2,     2,     2,     1,   306,   626,   577,   577,  9613,
          29889]], device='cuda:2'),
 'attention_mask': tensor([[0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]], device='cuda:2')}
# running forward prop and then visualising the last 7 logits
with torch.no_grad():
    combined_outputs_1 = model(**combined_inputs_1)
combined_logits_1 = combined_outputs_1.logits[0, 4:, :]
combined_logits_1
>>> tensor([[-12.5000,  -7.0625,  -0.6406,  ...,  -6.6250,  -7.9062,  -7.2812],
        [ -9.5000, -12.1875,  -1.1172,  ...,  -5.0625,  -8.9375,  -3.6250],
        [ -7.0312,  -4.4688,   2.1875,  ...,  -1.8438,  -5.6562,  -1.8984],
        ...,
        [ -6.9375,  -7.4062,   4.3438,  ...,  -2.8594,  -3.1875,  -3.1875],
        [ -2.4219,  -2.0000,  11.0625,  ...,  -0.6914,  -0.1133,  -1.4141],
        [-11.8750, -10.8750,   8.3750,  ...,  -4.8125,  -4.3750,  -3.6094]],
       device='cuda:2')
inputs_1 = tokenizer(prompt_1, padding=True, return_tensors="pt").to(2) # batch size 1
inputs_1
>>> {'input_ids': tensor([[    1,   306,   626,   577,   577,  9613, 29889]], device='cuda:2'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]], device='cuda:2')}

Notice that inputs_1 is the same as combined_inputs_1, except with the left padding omitted and the attention mask altered to match.

# running forward prop again
with torch.no_grad():
    outputs_1 = model(**inputs_1)
    logits_1 = outputs_1.logits[0, :, :]
logits_1
>>> tensor([[-12.5000,  -7.0625,  -0.6406,  ...,  -6.6250,  -7.9062,  -7.2812],
        [ -9.5000, -12.1875,  -1.1016,  ...,  -5.0312,  -8.9375,  -3.6250],
        [ -7.0625,  -4.4375,   2.2188,  ...,  -1.8750,  -5.7188,  -1.9219],
        ...,
        [ -6.9062,  -7.3125,   4.3438,  ...,  -2.8594,  -3.1875,  -3.1406],
        [ -2.4219,  -2.0000,  11.0625,  ...,  -0.6680,  -0.1445,  -1.4062],
        [-11.8125, -10.8125,   8.3750,  ...,  -4.7812,  -4.3125,  -3.5938]],
       device='cuda:2')

Upon close inspection, you'll see that this tensor is slightly different from combined_logits_1. We can show this more clearly:

torch.sum(torch.abs(logits_1 - combined_logits_1))
>>> tensor(3722.9448, device='cuda:2')

Is this meaningful? Well, if we look at the probabilities:

torch.max(torch.abs(torch.nn.Softmax()(logits_1) - torch.nn.Softmax()(combined_logits_1)))
>>> tensor(0.0053, device='cuda:2')

That's a pretty non-trivial difference in probability!
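
As a quick extra check (a sketch, reusing logits_1 and combined_logits_1 from the snippets above), one can also ask whether the drift ever flips the greedy next-token choice, which is arguably the more practical question:

probs_padded = torch.softmax(combined_logits_1.float(), dim=-1)   # upcast for a stable softmax
probs_unpadded = torch.softmax(logits_1.float(), dim=-1)
torch.max(torch.abs(probs_padded - probs_unpadded))                # largest per-token probability gap
torch.equal(probs_padded.argmax(dim=-1), probs_unpadded.argmax(dim=-1))  # do the greedy picks agree?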

Expected behavior

I would expect the attention mask to mask out the left padding, making the two sequences inputs_1 and combined_inputs_1 identical during forward prop, which should in turn mean that the logits produced are equivalent.

I realise that there may be small errors arising from batched GPU computations, but this error doesn't seem very small...
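
Phrased as a test, the expectation is roughly the following (a sketch reusing the tensors from the snippets above; the tolerances are placeholders, not an official threshold):

# the logits should agree to within ordinary floating-point noise for the chosen dtype
torch.testing.assert_close(combined_logits_1, logits_1, rtol=1e-2, atol=1e-2)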

@amyeroberts
Collaborator

Cc @younesbelkada too :)

@amyeroberts
Collaborator

I believe this comment is relevant to this issue: #25420 (comment)

@ArthurZucker
Collaborator

On point @amyeroberts. TL;DR: it's expected.

@c3ianwu
Author

c3ianwu commented Feb 15, 2024

Thanks @amyeroberts @ArthurZucker

I did a few more experiments based on the comment linked by @amyeroberts.

KV Cache

I turned these off:

with torch.no_grad():
    combined_outputs_1 = model(**combined_inputs_1, use_cache=False)

etc., but this did not lead to any change. My understanding is that the KV cache shouldn't affect a single forward pass anyway, so I'm not surprised.

FP32

I loaded the model weights in fp32. There is still a noticeable difference, but it is smaller.

torch.sum(torch.abs(logits_1 - combined_logits_1))
>>> tensor(169.4670, device='cuda:2')
torch.max(torch.abs(torch.nn.Softmax()(logits_1) - torch.nn.Softmax()(combined_logits_1)))
>>> tensor(0.0002, device='cuda:2')

FP16

I loaded the model weights in fp16. There is still a noticeable difference; it is smaller than for bf16 but larger than for fp32.

torch.sum(torch.abs(logits_1 - combined_logits_1))
>>> tensor(510.8704, device='cuda:2')
torch.max(torch.abs(torch.nn.Softmax()(logits_1) - torch.nn.Softmax()(combined_logits_1)))
>>> tensor(0.0006, device='cuda:2')
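
For reference, the fp32 / fp16 runs only change the dtype at load time; everything else in the original reproduction stays the same (a sketch):

import torch
from transformers import AutoModelForCausalLM

model_fp32 = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float32)
model_fp16 = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16)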

CPU

I ran forward prop on CPU rather than GPU. The difference is now tiny.

torch.sum(torch.abs(logits_1 - combined_logits_1))
>>> tensor(0.3935)
torch.max(torch.abs(torch.nn.Softmax()(logits_1) - torch.nn.Softmax()(combined_logits_1)))
>>> tensor(6.5565e-07)

Right Padding

I changed the padding to right padding, still on CPU. The summed logit error is now even smaller but still non-zero:

torch.sum(torch.abs(logits_1 - combined_logits_1))
>>> tensor(0.2899)
torch.max(torch.abs(torch.nn.Softmax()(logits_1) - torch.nn.Softmax()(combined_logits_1)))
>>> tensor(1.1325e-06)
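
For completeness, the CPU and right-padding runs only change the device placement, the padding side and the slicing; roughly (a sketch, reusing the names from the original snippet):

# CPU variant: simply drop the .to(2) calls on the model and the inputs
model_cpu = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float32)
inputs_cpu = tokenizer(combined_prompt, padding=True, return_tensors="pt")

# Right-padding variant: only the tokenizer setting and the slicing change
tokenizer.padding_side = "right"
combined_inputs = tokenizer(combined_prompt, padding=True, return_tensors="pt")
# with right padding, the logits of interest are the first 7 positions of row 0,
# i.e. logits[0, :7, :] instead of logits[0, 4:, :]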

Thoughts?

@SY-Xuan

SY-Xuan commented Feb 17, 2024

I found this bug too. You can test whether enabling SDPA or Flash Attention helps. When SDPA is used, the result seems to be correct. I don't know why this bug happens.
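
For anyone who wants to try that comparison: in recent transformers releases the attention backend can be chosen at load time via the attn_implementation argument (a sketch; "flash_attention_2" additionally requires the flash-attn package to be installed):

import torch
from transformers import AutoModelForCausalLM

# pick one of "eager", "sdpa", "flash_attention_2"
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
)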

@ArthurZucker
Collaborator

This was already answered: basically, eager attention still attends to the padding tokens (because the output of the softmax is never exactly zero), but with exact implementations / kernels you get 0 for the padding tokens instead of a very tiny number. See #27050
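
A toy illustration of that point, using nothing beyond plain PyTorch (the -30.0 bias is chosen only so the leftover weight is visible; the real eager mask uses the dtype minimum, e.g. -65504 in fp16, where the weight is far smaller and may even underflow, but the padded keys/values still take part in the computation):

import torch

scores = torch.tensor([2.0, -1.0, 0.5])
# eager-style additive mask: a large-but-finite negative bias on the padded position
torch.softmax(scores + torch.tensor([0.0, -30.0, 0.0]), dim=-1)   # middle weight is tiny but non-zero
# a true -inf (or a kernel that skips masked keys entirely) gives exactly zero
torch.softmax(scores + torch.tensor([0.0, float("-inf"), 0.0]), dim=-1)  # middle weight is exactly 0.0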

@SY-Xuan

SY-Xuan commented Feb 18, 2024

I have done some experiments. If I use eager attention with the SDPA attention mask (version==4.37.2), the results are correct. However, with the eager-mode attention mask, the results are wrong. This happens when using left padding for inference.

The generated 4D attention masks look like this:
Eager mode:

[[[[     0., -65504., -65504.,  ..., -65504., -65504., -65504.],
          [     0.,      0., -65504.,  ..., -65504., -65504., -65504.],
          [     0.,      0.,      0.,  ..., -65504., -65504., -65504.],
          ...,
          [     0.,      0.,      0.,  ...,      0., -65504., -65504.],
          [     0.,      0.,      0.,  ...,      0.,      0., -65504.],
          [     0.,      0.,      0.,  ...,      0.,      0.,      0.]]],
        [[[-65504., -65504., -65504.,  ..., -65504., -65504., -65504.],
          [-65504.,      0., -65504.,  ..., -65504., -65504., -65504.],
          [-65504.,      0.,      0.,  ..., -65504., -65504., -65504.],
          ...,
          [-65504.,      0.,      0.,  ...,      0., -65504., -65504.],
          [-65504.,      0.,      0.,  ...,      0.,      0., -65504.],
          [-65504.,      0.,      0.,  ...,      0.,      0.,      0.]]]]

SDPA mode:

tensor([[[[     0., -65504., -65504.,  ..., -65504., -65504., -65504.],
          [     0.,      0., -65504.,  ..., -65504., -65504., -65504.],
          [     0.,      0.,      0.,  ..., -65504., -65504., -65504.],
          ...,
          [     0.,      0.,      0.,  ...,      0., -65504., -65504.],
          [     0.,      0.,      0.,  ...,      0.,      0., -65504.],
          [     0.,      0.,      0.,  ...,      0.,      0.,      0.]]],


        [[[0, 0, 0,  ..., 0, 0, 0],
          [-65504.,      0., -65504.,  ..., -65504., -65504., -65504.],
          [-65504.,      0.,      0.,  ..., -65504., -65504., -65504.],
          ...,
          [-65504.,      0.,      0.,  ...,      0., -65504., -65504.],
          [-65504.,      0.,      0.,  ...,      0.,      0., -65504.],
          [-65504.,      0.,      0.,  ...,      0.,      0.,      0.]]]],
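
For anyone who wants to inspect these masks directly, something like the following should reproduce them (a sketch using private helpers from transformers.modeling_attn_mask_utils as of ~4.37; being internal, their names, signatures and exact behaviour may change between releases):

import torch
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)

attention_mask = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
                               [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
dummy_embeds = torch.zeros(2, 11, 8, dtype=torch.float16)  # only dtype/device are used

eager_mask = _prepare_4d_causal_attention_mask(attention_mask, (2, 11), dummy_embeds, 0)
sdpa_mask = _prepare_4d_causal_attention_mask_for_sdpa(attention_mask, (2, 11), dummy_embeds, 0)

# query rows 0-3 of the padded sequence attend to nothing under the eager mask;
# the SDPA variant typically "unmasks" such rows (sets them back to 0) to avoid
# NaNs in the fused kernel; compare the two printouts
print(eager_mask[0, 0, 0])
print(sdpa_mask[0, 0, 0])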

@ArthurZucker
Collaborator

I don't understand. What exactly is wrong?

@SY-Xuan

SY-Xuan commented Feb 19, 2024

Please take a look at the discussion above. If left padding is used, the output of the model is wrong. I found that the attention mask can be generated in either eager mode or SDPA mode. The difference is that if no element is attended to, the SDPA mode sets the attention mask row to zero. If the SDPA-mode attention mask generation is used, the output of the model is correct. I tested with both the eager attention module and the SDPA attention module. I am wondering why this happens.

@ArthurZucker
Collaborator

ArthurZucker commented Feb 20, 2024

@SY-Xuan let's try to be clear when we say:

  • results are correct / wrong: what is wrong for you? You did not share a generation, nor did you provide a snippet. Should I assume you are talking about @c3ianwu's results? Do you have the same setup as he does?
  • I used x and y: there are many different combinations (SDPA attention, eager attention, etc.). Providing a small snippet of what you tested will help us understand what you mean by "if no element is attended, the SDPA mode will set the attention mask to zero; if the SDPA-mode attention mask generation is used, the output of the model is correct."
  • outputs are wrong: are you talking about logits? About generation? Are you doing greedy decoding? Sampling? etc.

The reason SDPA uses 0 for those rows is that there is a bug: SDPA does not support fully un-attended rows. A 0 in the causal mask means the position will be attended.

Now if you have a snippet with a reproducer, that will help.
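
To make the un-attended-rows point concrete, here is a minimal plain-PyTorch sketch (nothing transformers-specific) of why a fully masked query row is a problem for torch.nn.functional.scaled_dot_product_attention:

import torch
import torch.nn.functional as F

q = torch.randn(1, 1, 2, 4)  # (batch, heads, seq, head_dim)
k = torch.randn(1, 1, 2, 4)
v = torch.randn(1, 1, 2, 4)

# query position 0 attends to nothing at all (e.g. a left-padded position)
attn_mask = torch.tensor([[[[float("-inf"), float("-inf")],
                            [0.0, 0.0]]]])

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out[0, 0, 0])  # NaN: the softmax over an all -inf row is 0/0
print(out[0, 0, 1])  # finite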

@SY-Xuan

SY-Xuan commented Feb 22, 2024

Thanks for your kind reply. I think I made a mistake by using different dtypes. I have fixed it now. Sorry for wasting your time.

@ArthurZucker
Collaborator

No worries! 🤗 I'll close this as completed
