[AttentionMaskConverter] Fix-mask-inf #27114

Merged
merged 8 commits into main from fix-mask-inf
Nov 10, 2023
Conversation

@ArthurZucker (Collaborator) commented Oct 27, 2023

What does this PR do?

Fixes the -inf values appearing in the padding mask, which come from the way we create the masks. Also adds the @dataclass decorator to AttentionMaskConverter, as well as an example in the docs.

In some specific cases the pad tokens are still attended to, which produces different outputs for flash attention vs. non-flash attention. Might fix #27050, but this is also related to other Llama issues.

FYI @gante for visibility 😉

Basically, instead of computing

  -inf, 0, 0, 0 .... 0                             -inf, 0, 0, 0 .... 0           
  -inf, 0, 0, 0 .... 0            +                -inf, -inf, 0, 0, ....
  -inf, 0, 0, 0 .... 0                             -inf, -inf, -inf, 0 .... 0           
  -inf, 0, 0, 0 .... 0                             -inf, -inf, -inf, -inf .... 0                

we just masked_fill the second mask with the first. This way we are sure that the mask does not overflow.
Before:

>>> import torch
>>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter
>>> converter =  AttentionMaskConverter(True)
>>> converter.to_4d(torch.tensor([[0,0,0,1,1]]), 5, 5)
tensor([[[[-3.4028e+38,        -inf,        -inf, -3.4028e+38, -3.4028e+38],
          [-3.4028e+38, -3.4028e+38,        -inf, -3.4028e+38, -3.4028e+38],
          [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
          [-3.4028e+38, -3.4028e+38, -3.4028e+38,  0.0000e+00, -3.4028e+38],
          [-3.4028e+38, -3.4028e+38, -3.4028e+38,  0.0000e+00,  0.0000e+00]]]])

After:

>>> import torch
>>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter
>>> converter =  AttentionMaskConverter(True)
>>> converter.to_4d(torch.tensor([[0,0,0,1,1]]), 5, 5)
tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
          [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
          [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
          [-3.4028e+38, -3.4028e+38, -3.4028e+38,  0.0000e+00, -3.4028e+38],
          [-3.4028e+38, -3.4028e+38, -3.4028e+38,  0.0000e+00,  0.0000e+00]]]])
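
For illustration, a minimal toy sketch of the difference (not the actual implementation, just small tensors showing why adding the two masks overflows while masked_fill does not):

import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min

# Toy setup: sequence length 4, the first two tokens are padding.
padding_mask = torch.tensor([[0, 0, 1, 1]])                           # 1 = attend, 0 = pad
causal_mask = torch.triu(torch.full((4, 4), min_dtype), diagonal=1)   # 0 on/below the diagonal
expanded_pad = (1 - padding_mask) * min_dtype                         # min_dtype on pad positions

# Adding two masks that both hold dtype.min overflows to -inf.
added = causal_mask + expanded_pad
print(torch.isinf(added).any())     # tensor(True)

# masked_fill on the padded positions keeps every entry at or above dtype.min.
filled = causal_mask.masked_fill(padding_mask == 0, min_dtype)
print(torch.isinf(filled).any())    # tensor(False)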

@HuggingFaceDocBuilderDev commented Oct 27, 2023

The documentation is not available anymore as the PR was closed or merged.

@ArthurZucker marked this pull request as ready for review November 8, 2023 16:48

@LysandreJik (Member) left a comment

Looks good to me, thanks for the offline explanation

@ArthurZucker merged commit 68afca3 into main Nov 10, 2023
3 checks passed
@ArthurZucker deleted the fix-mask-inf branch November 10, 2023 14:22
@toritospartan commented

Hello!

We pulled this PR from the main branch yesterday. We were having NaN issues with ppo_trainer and llama2-7b-chat. After investigating, we found that the NaN can be reproduced just by generating the first token for a batch of 4 sentences, and it depends on how we form the batch (i.e. which sentences fail depends on the size of the batch and which sentences we include in it). Including different sentences in the batch changes the padding structure, and it seems that the moment you get padding, the risk of NaN increases. That said, we have also seen NaN with batch size 1 (no padding) and float16, so padding does not seem to be the only root of the problem.

We have observed that the NaNs appear in the 31st layer and subsequently in the logits, not in earlier layers. The input_ids and attention mask that generate receives look correct. The example code uses bfloat16 because it seems to alleviate the issue, which is more frequent with float16.

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )

n_gpus = torch.cuda.device_count()
max_memory = f'{40960}MB' # TODO: what if we have more memory?


sft_folder = "/raid/general/llm/llama2/hf_versions/7b-chat-withpad/"

sft_model = AutoModelForCausalLM.from_pretrained(
    sft_folder,
    quantization_config=bnb_config,
    device_map="auto", # dispatch the model efficiently on the available resources
    max_memory = {i: max_memory for i in range(n_gpus)},
    torch_dtype=torch.bfloat16, 
)   
tokenizer = AutoTokenizer.from_pretrained(sft_folder, model_max_length=2048)

# Depending on how we select the batch here, different sentences fail
# (`tokenized_dataset` is our pre-tokenized dataset, not shown here)
bb = tokenized_dataset['train']['input_ids'][:2]
batch_mask = [torch.ones_like(element) for element in bb]
inputs = {"input_ids": bb, "attention_mask": batch_mask}

tokenizer.padding_side = "left"
padded_inputs = tokenizer.pad(
    inputs,
    padding=True,
    max_length=None,
    pad_to_multiple_of=None,
    return_tensors="pt",
)

generation_kwargs = {
    "top_k": 0.0,
    "top_p": 0.0,
    "temperature": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.pad_token_id,
    "max_new_tokens": 1
}

response_tensors = sft_model.generate(**padded_inputs, **generation_kwargs)

@gante (Member) commented Nov 15, 2023

Hi @toritospartan -- any chance you could reproduce the issue with an open-access model OR privately share your model with us? Otherwise, it will be very challenging for us to nail the cause :)

@toritospartan commented

We have been able to reproduce the issue with just public data. The reason seems to be that, due to the lack of a padding token in llama2, we added our own pad token (we added 128 tokens to keep the model efficient, as the warning said), thinking this token would be ignored anyway. However, this seems to produce the NaNs on some occasions. We checked the range of the embedding of token 0 (maybe this is the pad token Meta used, even if it is not clear in their code or in the HF export script?): the std of this embedding is 0, with mean 0. Our pad token's embedding instead had the std set by the model's _init_weights (which is expected), and it is this range that seems to make llama overflow. The script below makes this happen very often by generating that weight with an exaggerated std. Clear advice on how to manage this situation would make people less confused, because a further question is: are we creating a bias in the model because of which pad token we use?

from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer
import torch
import pandas as pd
from datasets import Dataset

# Model to load
model_folder = "/raid/general/llm/llama2/hf_versions/7b-chat/"

# Quantization
bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )


# Load model
n_gpus = torch.cuda.device_count()
max_memory = f'{40960}MB'
model = AutoModelForCausalLM.from_pretrained(
    model_folder,
    quantization_config=bnb_config,
    device_map="auto", # dispatch the model efficiently on the available resources
    max_memory = {i: max_memory for i in range(n_gpus)},
    torch_dtype=torch.float16, 
)   
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_folder, model_max_length=2048)

# Add a custom [PAD] token plus 127 unused special tokens (128 new tokens in total)
tokenizer.add_special_tokens({'pad_token': '[PAD]'}, replace_additional_special_tokens=False)
tokenizer.add_special_tokens({'additional_special_tokens': [f'[unused_{i}]' for i in range(0,127)]}, replace_additional_special_tokens=False)
tokenizer.pad_token = '[PAD]'
tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.special_tokens_map['pad_token'])

# Intentionally exaggerate the initializer std so the overflow/NaN shows up very often
old_range = model.config.initializer_range
model.config.initializer_range = 10000
model.resize_token_embeddings(len(tokenizer))
model.config.initializer_range = old_range

# Generate dataset
prompts = [
    "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.",
    "Ut enim ad minim veniam",
]

prompt_df = pd.DataFrame(prompts, columns=["prompt"])
prompt_dataset = Dataset.from_pandas(prompt_df)

# Tokenize dataset
def tokenize_dataset(dataset, tokenizer):
    def tokenize(sample):
        sample["input_ids"] = tokenizer.encode(sample["prompt"])
        return sample
    dataset = dataset.map(tokenize, batched=False)
    dataset.set_format(type="torch")
    return dataset
tokenized_dataset = tokenize_dataset(prompt_dataset, tokenizer)

batch_input_ids = tokenized_dataset['input_ids']
batch_mask = [torch.ones_like(element) for element in batch_input_ids]
inputs = {"input_ids": batch_input_ids, "attention_mask": batch_mask}

tokenizer.padding_side = "left"
padded_inputs = tokenizer.pad(
    inputs,
    padding=True,
    max_length=None,
    pad_to_multiple_of=None,
    return_tensors="pt",
)

generation_kwargs = {
    "top_k": 0.0,
    "top_p": 0.0,
    "temperature": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.pad_token_id,
    "max_new_tokens": 1
}

response_tensors = model.generate(**padded_inputs, **generation_kwargs)
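
For reference, a quick way to check the embedding statistics described above (a sketch that reuses model and tokenizer from the script and assumes the input embedding matrix is reachable via get_input_embeddings()):

# Inspect the embedding row shipped for token id 0 vs. our newly added [PAD] row.
emb = model.get_input_embeddings().weight

token_zero = emb[0]
pad_vector = emb[tokenizer.pad_token_id]

print("token 0   mean/std:", token_zero.float().mean().item(), token_zero.float().std().item())
print("pad token mean/std:", pad_vector.float().mean().item(), pad_vector.float().std().item())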

@gante (Member) commented Nov 17, 2023

(cc @ArthurZucker as I have no idea how adding extra tokens works internally :D)

EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 19, 2023
* fix?

* actual fix

* fixups

* add dataclass to the attention mask converter

* refine testing suite

* make sure there are no overflows

* update the test
@ArthurZucker mentioned this pull request Nov 20, 2023
@ArthurZucker (Collaborator, Author) commented

Hey! Thanks both. When adding a new token it is recommended to initialize its embedding to the average of all the embeddings in the embedding layer. This explains it best: https://nlp.stanford.edu/~johnhew/vocab-expansion.html.
Would you mind trying this? 🤗
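
A rough sketch of that mean initialization (an assumption-laden example, not the exact recipe from the post): it reuses model and tokenizer from the script above and assumes the 128 added tokens are the last rows of the embedding matrices.

import torch

num_new_tokens = 128  # assumption: [PAD] plus the 127 unused tokens added above

model.resize_token_embeddings(len(tokenizer))

with torch.no_grad():
    input_emb = model.get_input_embeddings().weight
    output_emb = model.get_output_embeddings().weight

    # Use the average of the pre-existing rows for every new row, instead of the
    # normal(0, initializer_range) draw that _init_weights would give them.
    input_emb[-num_new_tokens:] = input_emb[:-num_new_tokens].mean(dim=0, keepdim=True)
    output_emb[-num_new_tokens:] = output_emb[:-num_new_tokens].mean(dim=0, keepdim=True)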

@artsobolev commented Dec 9, 2023

@toritospartan, the LLaMA models are unaffected by this PR as they do masking by hand instead of relying on AttentionMaskConverter.to_4d. So do mistral, gpt2, falcon, t5 and probably many others.

As an alternative solution while waiting for the existing models to be fixed, I'd suggest adding the following after line 425 of modeling_llama.py (before the masking):

attn_weights = attn_weights - attn_weights.max(dim=-1, keepdim=True)[0]

@xuyaoxun commented Dec 14, 2023

> @toritospartan, the LLaMA models are unaffected by this PR as they do masking by hand instead of relying on AttentionMaskConverter.to_4d. So do mistral, gpt2, falcon, t5 and probably many others.
>
> As an alternative solution while waiting for the existing models to be fixed, I'd suggest adding the following after line 425 of modeling_llama.py (before the masking):
>
> attn_weights = attn_weights - attn_weights.max(dim=-1, keepdim=True)[0]

@artsobolev Hi, since the code of modeling_llama.py has changed and I'm not sure exactly where you are referring to, I've put it in here:

    if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
        raise ValueError(
            f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
        )
    attn_weights = attn_weights + attention_mask

i.e. right before attn_weights = attn_weights + attention_mask, and it didn't work, I still got NaN. I'm not sure if this is the right place.

@ArthurZucker (Collaborator, Author) commented Dec 14, 2023

Llama, Mistral and so on do use _prepare_4d_causal_attention_mask, which uses to_4d if the mask is provided and to_causal_4d otherwise. No, the NaN values do not arise from the mask anymore: Llama has always had instabilities, and this PR fixes the ones related to attention mask overflow. Not sure what code you are looking at @artsobolev?
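
(For context, a small sketch of that call path, assuming the helper keeps the signature it had at the time of this PR:)

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

inputs_embeds = torch.zeros(1, 5, 8)              # dummy (batch, seq_len, hidden) activations

# With a 2D padding mask the helper goes through to_4d ...
mask = torch.tensor([[0, 0, 0, 1, 1]])
with_padding = _prepare_4d_causal_attention_mask(mask, (1, 5), inputs_embeds, 0)

# ... and with attention_mask=None it goes through to_causal_4d instead.
causal_only = _prepare_4d_causal_attention_mask(None, (1, 5), inputs_embeds, 0)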

Successfully merging this pull request may close these issues:

Difference in LlamaAttention & LlamaFlashAttention2 attn_output