Batch Decoding of LMs will cause different outputs with different batch size #25921

Open
2 of 4 tasks
wenhuchen opened this issue Sep 2, 2023 · 7 comments
Labels
WIP Label your PR/Issue with WIP for some long outstanding Issues/PRs that are work in progress

Comments

@wenhuchen

System Info

Transformers=4.31
Torch=2.0.1
CUDA=11.8
Python=3.10

A100 GPU 80GB

Who can help?

@ArthurZucker , @younesbelkada , @gante

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Running the following example produces a different output for the first prompt when the batch size is reduced.

from transformers import LlamaForCausalLM, LlamaTokenizer
from transformers import GenerationConfig
import torch

if __name__ == '__main__':
    name = 'yahma/llama-7b-hf'
    tokenizer = LlamaTokenizer.from_pretrained(
        name, 
        padding_side="left", 
        trust_remote_code=True)
    tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id

    model = LlamaForCausalLM.from_pretrained(
        name, 
        device_map="auto", 
        torch_dtype=torch.bfloat16,
        trust_remote_code=True)

    question = [
        'Can you explain to me what is the concept of deep learning and how it can be applied to NLP?',
        'Where am I supposed to eat dinner',
        'How hard is it to find a doctor in Canada',
        'What is the best price of vegatables',
        'How can somehow be so mean',
        'How can we get the severance pay',
        'What type of president is this?',
        'How is the weather today?'
    ]

    batch = tokenizer(
        question,
        padding=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        output_ids = model.generate(
            batch.input_ids.to(model.device),
            attention_mask=batch.attention_mask.to(model.device),
            pad_token_id=tokenizer.pad_token_id,
            generation_config=GenerationConfig(do_sample=False, max_new_tokens=50, trust_remote_code=True)
        )

    output_strs = []
    for output_id in output_ids.tolist()[:4]:
        tmp = tokenizer.decode(output_id[batch.input_ids.shape[-1]:], skip_special_tokens=True)
        output_strs.append(tmp)
        print(tmp)
        print('----------------------------------------------------')


    print('############### Now we decrease the batch size #############################')

    question = [
        'Can you explain to me what is the concept of deep learning and how it can be applied to NLP?',
        'Where am I supposed to eat dinner',
        'How hard is it to find a doctor in Canada',
        'What is the best price of vegatables',
    ]

    batch = tokenizer(
        question,
        padding=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        output_ids = model.generate(
            batch.input_ids.to(model.device),
            attention_mask=batch.attention_mask.to(model.device),
            pad_token_id=tokenizer.pad_token_id,
            generation_config=GenerationConfig(do_sample=False, max_new_tokens=50, trust_remote_code=True)
        )

    output_strs = []
    for output_id in output_ids.tolist():
        tmp = tokenizer.decode(output_id[batch.input_ids.shape[-1]:], skip_special_tokens=True)
        output_strs.append(tmp)
        print(tmp)
        print('----------------------------------------------------')

Expected behavior

With greedy decoding (do_sample=False), the outputs for the first four prompts should be identical and should not depend on the batch size.
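The expected behavior can be phrased as a check: decode each prompt alone, then in batches of several sizes, and report which prompts change. A minimal sketch, assuming a hypothetical `generate` callable that maps a list of prompts to a list of decoded strings (for example a thin wrapper around the `model.generate` and `tokenizer.decode` calls above); the helper name is illustrative and not a transformers API.

```python
from typing import Callable, Dict, List, Sequence


def batch_invariance_mismatches(
    generate: Callable[[List[str]], List[str]],
    prompts: Sequence[str],
    batch_sizes: Sequence[int],
) -> Dict[int, List[int]]:
    """Return, per batch size, the indices of prompts whose output differs
    from the output obtained when that prompt is decoded alone."""
    # Reference: every prompt decoded in a batch of size 1.
    reference = [generate([p])[0] for p in prompts]
    mismatches: Dict[int, List[int]] = {}
    for bs in batch_sizes:
        outputs: List[str] = []
        # Split the prompts into consecutive batches of size `bs`.
        for start in range(0, len(prompts), bs):
            outputs.extend(generate(list(prompts[start:start + bs])))
        mismatches[bs] = [
            i for i, (ref, out) in enumerate(zip(reference, outputs)) if ref != out
        ]
    return mismatches


if __name__ == "__main__":
    # Toy stand-ins: one generator ignores batch size, one leaks it into its output.
    invariant = lambda batch: [p.upper() for p in batch]
    leaky = lambda batch: [f"{p}-{len(batch)}" for p in batch]
    prompts = ["a", "b", "c"]
    print(batch_invariance_mismatches(invariant, prompts, [1, 2, 3]))  # {1: [], 2: [], 3: []}
    print(batch_invariance_mismatches(leaky, prompts, [1, 2]))         # {1: [], 2: [0, 1]}
```

With a real model, a non-empty list for any batch size is exactly the symptom reported in this issue.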

@da03

da03 commented Sep 3, 2023

I can confirm that it's not due to left padding: even when every input in the batch is the same prompt (so no padding is added), the issue persists:

from transformers import LlamaForCausalLM, LlamaTokenizer
from transformers import GenerationConfig
import torch

if __name__ == '__main__':
    name = 'yahma/llama-7b-hf'
    tokenizer = LlamaTokenizer.from_pretrained(
        name, 
        padding_side="left", 
        trust_remote_code=True)
    tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id

    model = LlamaForCausalLM.from_pretrained(
        name, 
        device_map="auto", 
        torch_dtype=torch.bfloat16,
        trust_remote_code=True)

    question = [
        'Can you explain to me what is the concept of deep learning and how it can be applied to NLP?',
        'Can you explain to me what is the concept of deep learning and how it can be applied to NLP?',
        'Can you explain to me what is the concept of deep learning and how it can be applied to NLP?',
        'Can you explain to me what is the concept of deep learning and how it can be applied to NLP?',
        'Can you explain to me what is the concept of deep learning and how it can be applied to NLP?',
        'Can you explain to me what is the concept of deep learning and how it can be applied to NLP?',
        'Can you explain to me what is the concept of deep learning and how it can be applied to NLP?',
        'Can you explain to me what is the concept of deep learning and how it can be applied to NLP?',
        #'Where am I supposed to eat dinner',
        #'How hard is it to find a doctor in Canada',
        #'What is the best price of vegatables',
        #'How can somehow be so mean',
        #'How can we get the severance pay',
        #'What type of president is this?'
        #'How is the weather today?'
    ]

    batch = tokenizer(
        question,
        padding=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        output_ids = model.generate(
            batch.input_ids.to(model.device),
            attention_mask=batch.attention_mask.to(model.device),
            pad_token_id=tokenizer.pad_token_id,
            generation_config=GenerationConfig(do_sample=False, max_new_tokens=50, trust_remote_code=True)
        )

    output_strs = []
    for output_id in output_ids.tolist()[:4]:
        tmp = tokenizer.decode(output_id[batch.input_ids.shape[-1]:], skip_special_tokens=True)
        output_strs.append(tmp)
        print(tmp)
        print('----------------------------------------------------')


    print('############### Now we decrease the batch size #############################')

    question = [
        'Can you explain to me what is the concept of deep learning and how it can be applied to NLP?',
        'Can you explain to me what is the concept of deep learning and how it can be applied to NLP?',
        'Can you explain to me what is the concept of deep learning and how it can be applied to NLP?',
        'Can you explain to me what is the concept of deep learning and how it can be applied to NLP?',
        #'Where am I supposed to eat dinner',
        #'How hard is it to find a doctor in Canada',
        #'What is the best price of vegatables',
    ]

    batch = tokenizer(
        question,
        padding=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        output_ids = model.generate(
            batch.input_ids.to(model.device),
            attention_mask=batch.attention_mask.to(model.device),
            pad_token_id=tokenizer.pad_token_id,
            generation_config=GenerationConfig(do_sample=False, max_new_tokens=50, trust_remote_code=True)
        )

    output_strs = []
    for output_id in output_ids.tolist():
        tmp = tokenizer.decode(output_id[batch.input_ids.shape[-1]:], skip_special_tokens=True)
        output_strs.append(tmp)
        print(tmp)
        print('----------------------------------------------------')

The output I got is:

Deep learning is a machine learning technique that uses multiple layers of artificial neural networks to learn representations of data.
Deep learning is a machine learning technique that uses multiple layers of artificial neural networks to learn representations of data. The idea is that
----------------------------------------------------

Deep learning is a machine learning technique that uses multiple layers of artificial neural networks to learn representations of data.
Deep learning is a machine learning technique that uses multiple layers of artificial neural networks to learn representations of data. The idea is that
----------------------------------------------------

Deep learning is a machine learning technique that uses multiple layers of artificial neural networks to learn representations of data.
Deep learning is a machine learning technique that uses multiple layers of artificial neural networks to learn representations of data. The idea is that
----------------------------------------------------

Deep learning is a machine learning technique that uses multiple layers of artificial neural networks to learn representations of data.
Deep learning is a machine learning technique that uses multiple layers of artificial neural networks to learn representations of data. The idea is that
----------------------------------------------------
############### Now we decrease the batch size #############################

Deep learning is a machine learning technique that is based on the idea of neural networks. Neural networks are a type of machine learning algorithm that is inspired by the human brain. The human brain is a very complex system that is able to learn
----------------------------------------------------

Deep learning is a machine learning technique that is based on the idea of neural networks. Neural networks are a type of machine learning algorithm that is inspired by the human brain. The human brain is a very complex system that is able to learn
----------------------------------------------------

Deep learning is a machine learning technique that is based on the idea of neural networks. Neural networks are a type of machine learning algorithm that is inspired by the human brain. The human brain is a very complex system that is able to learn
----------------------------------------------------

Deep learning is a machine learning technique that is based on the idea of neural networks. Neural networks are a type of machine learning algorithm that is inspired by the human brain. The human brain is a very complex system that is able to learn
----------------------------------------------------
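The two runs above share the prefix "Deep learning is a machine learning technique that" and then diverge completely. A toy model of why greedy decoding amplifies tiny numeric differences: when two candidate tokens are nearly tied, a very small perturbation can change the argmax, and every later step is conditioned on the new token. The logits table and helper below are hypothetical and invented for illustration; they do not come from Llama.

```python
from typing import Dict, List

# Hypothetical next-token logits keyed by the previous token.
Table = Dict[str, Dict[str, float]]


def greedy_decode(table: Table, start: str, steps: int) -> List[str]:
    """Greedy decoding: always append the highest-scoring next token."""
    seq = [start]
    for _ in range(steps):
        logits = table.get(seq[-1])
        if not logits:
            break  # no continuation defined for this token
        seq.append(max(logits, key=logits.get))
    return seq


TABLE: Table = {
    "technique": {"that": 5.0},
    "that": {"uses": 3.0001, "is": 3.0},  # near-tie: margin of 1e-4
    "uses": {"multiple": 4.0},
    "multiple": {"layers": 4.0},
    "is": {"based": 4.0},
    "based": {"on": 4.0},
}

if __name__ == "__main__":
    print(greedy_decode(TABLE, "technique", 4))
    # Nudge one logit by 2e-4, far smaller than a bfloat16 rounding step
    # at this magnitude (2**-6 for values in [2, 4)).
    perturbed = {k: dict(v) for k, v in TABLE.items()}
    perturbed["that"]["uses"] -= 2e-4
    print(greedy_decode(perturbed, "technique", 4))
```

The first call yields technique, that, uses, multiple, layers; the perturbed call yields technique, that, is, based, on. One flipped step is enough to change the whole continuation.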

@csarron
Contributor

csarron commented Sep 3, 2023

In my environment, even identical examples within a single batch sometimes give different outputs with bfloat16 models. I'm not completely sure yet, but I suspect the precision conversion is non-deterministic; see RMSNorm. When a bfloat16 number is converted to fp32, the fraction part of the converted fp32 number might not always be the same. The same goes for the softmax operation, and there may be other places where precision conversion happens.
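One way to test the conversion part of this hypothesis without a GPU. Assuming the standard bfloat16 layout (the upper 16 bits of an IEEE-754 float32), widening bf16 to fp32 is a bit shift with no rounding, so it should be exact and deterministic; rounding only happens when narrowing fp32 back to bf16. The stdlib-only sketch below emulates both directions; the helper names are hypothetical and NaN handling is omitted.

```python
import struct


def fp32_bits(x: float) -> int:
    """IEEE-754 float32 bit pattern of x (x is first rounded to float32)."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_to_fp32(bits: int) -> float:
    """Interpret a 32-bit pattern as a float32 value."""
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def fp32_to_bf16(x: float) -> int:
    """Narrow float32 -> bfloat16 bits with round-to-nearest-even (NaN not handled)."""
    bits = fp32_bits(x)
    lsb = (bits >> 16) & 1  # lowest bit that survives the narrowing
    return ((bits + 0x7FFF + lsb) >> 16) & 0xFFFF


def bf16_to_fp32(h: int) -> float:
    """Widen bfloat16 bits -> float32: a plain 16-bit shift, so nothing is rounded."""
    return bits_to_fp32(h << 16)


if __name__ == "__main__":
    # Widening then narrowing returns the original bf16 bits exactly.
    for h in (0x3F80, 0x4049, 0xBF81, 0x0080):
        assert fp32_to_bf16(bf16_to_fp32(h)) == h
    # Narrowing is where precision is lost: bf16 keeps 8 significant bits.
    print(bf16_to_fp32(fp32_to_bf16(1.0 + 2**-10)))  # 1.0
    print(bf16_to_fp32(fp32_to_bf16(1.0 + 2**-7)))   # 1.0078125
```

If this holds, the casts themselves are unlikely to be the source of run-to-run differences; a more likely candidate (still an assumption) is the order in which values feeding RMSNorm, softmax, and matmuls are accumulated.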

FYI, this might also be related to #25420

@gante
Member

gante commented Sep 5, 2023

Hi @wenhuchen @da03 @csarron 👋 Thank you for raising this issue.

We are aware of this phenomenon on all (or nearly all) models that contain rotary position embeddings (Llama, Llama2, Falcon, GPTNeoX, ...). Running things in fp32 helps avoid this problem, but that is far from a good solution.
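A rough illustration of why fp32 helps: bfloat16 keeps only 8 significant bits (about 2 to 3 decimal digits), so logits that fp32 can still tell apart may round to the same value, and which token wins then depends on tie-breaking or on tiny upstream differences. The sketch below approximates bf16 rounding in plain Python; `round_to_bf16` and the logit values are hypothetical.

```python
import math
from typing import List


def round_to_bf16(x: float) -> float:
    """Approximate bfloat16 rounding: keep 8 significant bits, ties to even.
    Ignores bf16's exponent limits (overflow, subnormals); assumed fine for typical logits."""
    if x == 0.0 or not math.isfinite(x):
        return x
    mantissa, exponent = math.frexp(x)  # x = mantissa * 2**exponent, 0.5 <= |mantissa| < 1
    return math.ldexp(round(mantissa * 256) / 256, exponent)


def argmax(xs: List[float]) -> int:
    """Index of the largest value; the first index wins on ties."""
    return max(range(len(xs)), key=xs.__getitem__)


if __name__ == "__main__":
    logits = [10.01, 10.02]                # hypothetical near-tied logits
    print(argmax(logits))                  # 1 at full precision
    bf16 = [round_to_bf16(v) for v in logits]
    print(bf16, argmax(bf16))              # [10.0, 10.0] 0: the tie is broken by index
```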

We have to dive deep to find the root cause, but our bandwidth is limited and we can't provide a time estimate. I'll keep this issue open -- however, if there are volunteers to explore the issue, let me know!

@gante added the WIP label ("Label your PR/Issue with WIP for some long outstanding Issues/PRs that are work in progress") on Sep 5, 2023
@gante self-assigned this on Sep 5, 2023
@wenhuchen
Author

@xiangyue9607, please take a look at this.

@wenhuchen
Author

@gante, thanks for letting us know. We are using fp32 for now, but we notice that fp32 generally gives worse results than bf16. Anyway, we look forward to your PR fixing this issue.

@prompteus

prompteus commented Mar 2, 2024

Any update on this issue?

My T5 model produces different outputs (with greedy decoding) for the same prompt depending on batch size, even when the batch is made by copying the same prompt. It happens even on CPU with float32, but it is more common on CUDA with bfloat16.

A self-contained example is below. Seeding everything and making torch use deterministic algorithms do not help, but I'm including them here for completeness.

# make torch deterministic
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
import torch
torch.use_deterministic_algorithms(True)


import random
import transformers
import numpy as np

# seed everything
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)
transformers.set_seed(0)


model_id = "MU-NLPC/calcformer-instruct-flan-xl_step-128k"
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, use_fast=False)
model = transformers.AutoModelForSeq2SeqLM.from_pretrained(model_id).to("cuda").to(torch.bfloat16).eval()

question = 'In order to help the victims of the earthquake in Sichuan, the factory rushed to make a batch of disaster relief tents. The first workshop completed (1/5) of this batch of tents, the second workshop completed (1/4) of this batch of tents, and the remaining batch of tents What percentage of it is not completed?'

inputs = tokenizer([question], return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=512, do_sample=False)
decoded_bs1 = tokenizer.decode(outputs[0], skip_special_tokens=True, spaces_between_special_tokens=False)

inputs = tokenizer([question] * 4, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=512, do_sample=False)
decoded_bs4 = tokenizer.decode(outputs[0], skip_special_tokens=True, spaces_between_special_tokens=False)

decoded_bs1 and decoded_bs4 contain different output sequences.

decoded_bs1 == The first workshop completed (1/5) of the tents, so the remaining tents are 1 - (1/5) - (1/4) = 3/5 of the tents. The second workshop completed (1/4) of the tents, so the remaining tents are 3/5 - (1/4) = 7/10 of the tents. Since 7/10 is equal to 70%, then the remaining batch of tents is not completed for a percentage of (7/10)*100% = 75%.<result>75</result>

decoded_bs4 == The first workshop completed (1/5) of the tents, so the remaining tents are 1 - (1/5) - (1/4) = 3/5 of the tents. The second workshop completed (1/4) of the tents, so the remaining tents are 3/5 - (1/4) = 7/10 of the tents. Since 7/10 is equal to 70%, then the remaining batch of tents is not completed at all.<result>70</result>
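Since the two decodes share a long prefix, comparing token ids rather than decoded strings shows exactly which generation step first differs. A minimal sketch; `first_divergence` is a hypothetical helper. With the script above one would keep both `outputs` tensors (say, under hypothetical names `outputs_bs1` and `outputs_bs4`) and pass `outputs_bs1[0].tolist()` and `outputs_bs4[0].tolist()`.

```python
from typing import Optional, Sequence


def first_divergence(a: Sequence[int], b: Sequence[int]) -> Optional[int]:
    """Return the first index at which two token-id sequences differ,
    or None if they are identical."""
    for i, (x, y) in enumerate(zip(a, b)):
        if x != y:
            return i
    if len(a) != len(b):
        return min(len(a), len(b))  # one sequence is a strict prefix of the other
    return None


if __name__ == "__main__":
    print(first_divergence([5, 8, 13, 21], [5, 8, 34, 21]))  # 2
    print(first_divergence([5, 8], [5, 8]))                  # None
    print(first_divergence([5, 8], [5, 8, 13]))              # 2
```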

My environment:

torch==2.2.0
transformers==4.36.2

torch.version.cuda == '12.1'; the GPU is an NVIDIA A40. The same issue also occurs (for some inputs) on CPU with float32.
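A plausible reason why seeding and torch.use_deterministic_algorithms(True) do not help (an assumption, not something confirmed in this thread): determinism settings make repeated runs with the same shapes reproducible, but a different batch size can lead the backend to pick a different kernel or tiling, which changes the order in which floating-point values are summed. Floating-point addition is not associative, so the reduction order alone can change a result. A toy sketch with Python floats; `chunked_sum` and its `chunk` parameter are hypothetical stand-ins for a kernel's tile size.

```python
from typing import List, Sequence


def chunked_sum(values: Sequence[float], chunk: int) -> float:
    """Sum each chunk left to right, then sum the per-chunk partials.
    `chunk` stands in for a tile size a kernel might pick based on tensor shape."""
    partials: List[float] = []
    for start in range(0, len(values), chunk):
        acc = 0.0
        for v in values[start:start + chunk]:
            acc += v
        partials.append(acc)
    total = 0.0
    for p in partials:
        total += p
    return total


if __name__ == "__main__":
    values = [1.0, 1e16, -1e16, 1.0]
    # Each configuration is deterministic: repeating it always gives the same answer...
    assert len({chunked_sum(values, 2) for _ in range(10)}) == 1
    # ...but a different chunking (reduction order) gives a different answer.
    print(chunked_sum(values, 4))  # 1.0
    print(chunked_sum(values, 2))  # 0.0
```

Here 1e16 + 1.0 rounds back to 1e16 (the spacing between doubles near 1e16 is 2), so whether the two 1.0 terms survive depends only on how the values are grouped.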

@gante
Member

gante commented Mar 4, 2024

Hi @prompteus 👋 Have a look at this comment -- #25420 (comment)


5 participants