Wrong output of Gemma-2 models using flash_attention_2 #32309

Closed
tanliboy opened this issue Jul 30, 2024 · 19 comments

@tanliboy

tanliboy commented Jul 30, 2024

I remember that the soft-capping issue was resolved for the forward pass in flash_attn. However, I am still seeing poor model outputs when I enable use_flash_attention_2 in Transformers, even for inference.

Did I miss something? Or is it a recent regression?

Who can help?

@ArthurZucker

Reproduction

  1. Turn on use_flash_attention_2 to load the Gemma-2 9B IT model


from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b-it",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    use_flash_attention_2=True     # generates nonsense when set to True
)
input_text = "Write me a poem about Machine Learning."
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
outputs = model.generate(**input_ids)
print(tokenizer.decode(outputs[0]))
  2. Observe the nonsensical output and compare it with the output when use_flash_attention_2=False. This can be consistently reproduced (see the comparison sketch below).
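
For reference, a minimal comparison sketch (not part of the original report; it assumes enough GPU memory for the 9B model in bfloat16 and that flash-attn is installed). It loads the model once per backend via attn_implementation, the newer spelling of the use_flash_attention_2 flag, and prints both generations for the same prompt:

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_id = "google/gemma-2-9b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Write me a poem about Machine Learning.", return_tensors="pt").to("cuda")

for impl in ("eager", "flash_attention_2"):
    # Load one copy at a time to keep memory usage down.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        attn_implementation=impl,  # newer equivalent of use_flash_attention_2=True
    )
    out = model.generate(**inputs, max_new_tokens=64, do_sample=False)
    print(f"--- {impl} ---")
    print(tokenizer.decode(out[0], skip_special_tokens=True))
    del model
    torch.cuda.empty_cache()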

Expected behavior

See the difference below:
[Screenshots comparing the generated output with and without flash_attention_2]

@tanliboy tanliboy added the bug label Jul 30, 2024
@tanliboy tanliboy changed the title from "Wrong model output of Gemma-2 models using flash_attention_2" to "Wrong output of Gemma-2 models using flash_attention_2" Jul 30, 2024
@zucchini-nlp
Member

Yes, there's a PR to fix it (#32188)

@tanliboy
Author

Thank you, @zucchini-nlp !

After the fix, will we be able to use flash_attention_2 for both the forward (inference) and backward (training) passes of Gemma2 models in transformers?

Since FlashAttention currently doesn't support a static cache, do you think this issue will also impact other libraries (e.g., vLLM and other frameworks) when using flash_attention_2 with the Gemma2 model? If so, do you think we can address this within the FlashAttention library?

@zucchini-nlp
Member

@tanliboy FA2 should now work in transformers, for both forward and backward.

For other libraries, I am not super familiar with all of them, but for vLLM, Gemma2 should work the same way as other models because they do not use the same `StaticCache` we do. Also note that vLLM currently doesn't apply the sliding window in every second attention block, as per the comment I see here.

@tanliboy
Author

Thank you for the details, @zucchini-nlp !

@xenova
Contributor

xenova commented Aug 1, 2024

@tanliboy Glad to see it's fixed! Let me know if I can close the issue 😇

@zucchini-nlp
Member

sure, the PR is merged already, closing the issue :)

@HuangBugWei

@zucchini-nlp, thank you very much for fixing this issue.
Since PR #32188 was merged 5 days ago, I guess the latest release, v4.43.3 (Patch deepspeed), does not contain this update, right?
We should install from source via pip install git+https://github.com/huggingface/transformers to pick up the fix, right?

@zucchini-nlp
Member

@HuangBugWei correct! We might have a release soon, but until then it should be installed from source
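
A quick, illustrative way to confirm which build is active: a source install reports a ".dev0" version string, whereas the 4.43.x patch releases predate the fix.

import transformers
print(transformers.__version__)  # e.g. a ".dev0" suffix for a source install, "4.43.3" for the patch release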

@tanliboy
Author

tanliboy commented Aug 9, 2024

@zucchini-nlp thanks for the fix!

I installed the latest release but ran into the below error while using flash_attention_2 (it is fine without using flash_attention_2).

../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [12,0,0], thread: [31,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[2], line 12
      5 streamer = TextStreamer(tokenizer)
      6 terminators = (
      7     [
      8         tokenizer.eos_token_id,
      9         tokenizer.convert_tokens_to_ids("<end_of_turn>"),
     10     ]
     11 )
---> 12 _ = model.generate(**input_ids, streamer=streamer, eos_token_id=terminators, max_new_tokens=2048)

File /opt/conda/envs/handbook/lib/python3.10/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File /opt/conda/envs/handbook/lib/python3.10/site-packages/transformers/generation/utils.py:2024, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   2016     input_ids, model_kwargs = self._expand_inputs_for_generation(
   2017         input_ids=input_ids,
   2018         expand_size=generation_config.num_return_sequences,
   2019         is_encoder_decoder=self.config.is_encoder_decoder,
   2020         **model_kwargs,
   2021     )
   2023     # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-> 2024     result = self._sample(
   2025         input_ids,
   2026         logits_processor=prepared_logits_processor,
   2027         logits_warper=prepared_logits_warper,
   2028         stopping_criteria=prepared_stopping_criteria,
   2029         generation_config=generation_config,
   2030         synced_gpus=synced_gpus,
   2031         streamer=streamer,
   2032         **model_kwargs,
   2033     )
   2035 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
   2036     # 11. prepare logits warper
   2037     prepared_logits_warper = (
   2038         self._get_logits_warper(generation_config, device=input_ids.device)
   2039         if generation_config.do_sample
   2040         else None
   2041     )

File /opt/conda/envs/handbook/lib/python3.10/site-packages/transformers/generation/utils.py:2982, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, logits_warper, **model_kwargs)
   2979 model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
   2981 # forward pass to get next token
-> 2982 outputs = self(**model_inputs, return_dict=True)
   2984 if synced_gpus and this_peer_finished:
   2985     continue  # don't waste resources running the code we don't need

File /opt/conda/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File /opt/conda/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File /opt/conda/envs/handbook/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

File /opt/conda/envs/handbook/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py:999, in Gemma2ForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
    996 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    998 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
--> 999 outputs = self.model(
   1000     input_ids=input_ids,
   1001     attention_mask=attention_mask,
   1002     position_ids=position_ids,
   1003     past_key_values=past_key_values,
   1004     inputs_embeds=inputs_embeds,
   1005     use_cache=use_cache,
   1006     output_attentions=output_attentions,
   1007     output_hidden_states=output_hidden_states,
   1008     return_dict=return_dict,
   1009     cache_position=cache_position,
   1010 )
   1012 hidden_states = outputs[0]
   1013 logits = self.lm_head(hidden_states)

File /opt/conda/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File /opt/conda/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File /opt/conda/envs/handbook/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py:847, in Gemma2Model.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
    836     layer_outputs = self._gradient_checkpointing_func(
    837         decoder_layer.__call__,
    838         hidden_states,
   (...)
    844         cache_position,
    845     )
    846 else:
--> 847     layer_outputs = decoder_layer(
    848         hidden_states,
    849         attention_mask=causal_mask,
    850         position_ids=position_ids,
    851         past_key_value=past_key_values,
    852         output_attentions=output_attentions,
    853         use_cache=use_cache,
    854         cache_position=cache_position,
    855     )
    857 hidden_states = layer_outputs[0]
    859 if output_attentions:

File /opt/conda/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File /opt/conda/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File /opt/conda/envs/handbook/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

File /opt/conda/envs/handbook/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py:590, in Gemma2DecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
    587 hidden_states = self.input_layernorm(hidden_states)
    589 # Self Attention
--> 590 hidden_states, self_attn_weights, present_key_value = self.self_attn(
    591     hidden_states=hidden_states,
    592     attention_mask=attention_mask,
    593     position_ids=position_ids,
    594     past_key_value=past_key_value,
    595     output_attentions=output_attentions,
    596     use_cache=use_cache,
    597     cache_position=cache_position,
    598 )
    599 hidden_states = self.post_attention_layernorm(hidden_states)
    600 hidden_states = residual + hidden_states

File /opt/conda/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File /opt/conda/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File /opt/conda/envs/handbook/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

File /opt/conda/envs/handbook/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py:423, in Gemma2FlashAttention2.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
    420     key_states = key_states.to(target_dtype)
    421     value_states = value_states.to(target_dtype)
--> 423 attn_output = _flash_attention_forward(
    424     query_states,
    425     key_states,
    426     value_states,
    427     attention_mask,
    428     q_len,
    429     dropout=dropout_rate,
    430     softmax_scale=self.scaling,
    431     is_causal=self.is_causal,
    432     use_top_left_mask=self._flash_attn_uses_top_left_mask,
    433     softcap=self.config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None,
    434 )
    436 attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
    437 attn_output = self.o_proj(attn_output)

File /opt/conda/envs/handbook/lib/python3.10/site-packages/transformers/modeling_flash_attention_utils.py:246, in _flash_attention_forward(query_states, key_states, value_states, attention_mask, query_length, is_causal, dropout, position_ids, softmax_scale, sliding_window, use_top_left_mask, softcap, deterministic)
    244 if attention_mask is not None:
    245     batch_size = query_states.shape[0]
--> 246     query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
    247         query_states, key_states, value_states, attention_mask, query_length
    248     )
    249     cu_seqlens_q, cu_seqlens_k = cu_seq_lens
    250     max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

File /opt/conda/envs/handbook/lib/python3.10/site-packages/transformers/modeling_flash_attention_utils.py:121, in _upad_input(query_layer, key_layer, value_layer, attention_mask, query_length)
    118 else:
    119     # The -q_len: slice assumes left padding.
    120     attention_mask = attention_mask[:, -query_length:]
--> 121     query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
    123 return (
    124     query_layer,
    125     key_layer,
   (...)
    129     (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
    130 )

File /opt/conda/envs/handbook/lib/python3.10/site-packages/flash_attn/bert_padding.py:110, in unpad_input(hidden_states, attention_mask)
     99 """
    100 Arguments:
    101     hidden_states: (batch, seqlen, ...)
   (...)
    107     max_seqlen_in_batch: int
    108 """
    109 seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
--> 110 indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    111 max_seqlen_in_batch = seqlens_in_batch.max().item()
    112 cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Here is the test code to reproduce:

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b-it",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    use_flash_attention_2=True
)
input_text = "Write me a poem about Machine Learning."
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

from transformers import TextStreamer
streamer = TextStreamer(tokenizer)
terminators = (
    [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<end_of_turn>"),
    ]
)
_ = model.generate(**input_ids, streamer=streamer, eos_token_id=terminators, max_new_tokens=2048)

Did I miss something?

@zucchini-nlp
Member

@tanliboy yeah, seems like there were some other changes in how attn mask is prepared, which broke FA2 again... Will open a new PR

@tanliboy
Author

tanliboy commented Aug 9, 2024

Thank you, @zucchini-nlp !

@tanliboy
Author

I tested the fix, and it worked well. Thank you!

I also had a side-by-side comparison during fine-tuning with and without flash_attention_2. Surprisingly, the fine-tuning with flash_attention_2 showed only a marginal improvement over the eager mode on my A100x8 setup.

The "GPU Time Spent Accessing Memory" was around 40%, which is lower than the ~47% observed with eager, but still higher than other models during fine-tuning (~32%). The "Process GPU Memory" is ~91% with flash_attention_2, compared with ~97% with eager.

With flash_attention_2:
[Screenshot: training metrics with flash_attention_2]

With eager:
[Screenshot: training metrics with eager]

@tanliboy
Author

@zucchini-nlp, is this warning still accurate? Or should we remove it given the fix?

It is strongly recommended to train Gemma2 models with the eager attention implementation instead of flash_attention_2. Use eager with AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager').
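
For reference, a minimal sketch (not from the thread) of selecting the attention backend explicitly at load time; with a transformers build that includes the merged fix and flash-attn >= 2.6.0 (the soft-capping version gate visible in the traceback above), either option should produce sensible output:

from transformers import AutoModelForCausalLM
import torch

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b-it",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # or "eager", as the warning recommends
)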

@zucchini-nlp
Member

Yes, I believe it still holds true, as it wasn't related to FA2 not being supported, but rather to small numerical precision differences between eager and non-eager attention.

@ArthurZucker
Collaborator

No, it's no longer true, as flash attention soft-capping is supported. Will remove.

@zucchini-nlp
Member

I guess SDPA is not yet supported?

@ArthurZucker
Collaborator

Yes, we need to integrate flex attention for that!

@51616

51616 commented Jan 26, 2025

> No, it's no longer true, as flash attention soft-capping is supported. Will remove.

The warning is still there in 4.48.1. Can you confirm that we can safely ignore this warning?

@ArthurZucker
Collaborator

If you don't have the correct version of flash-attn then it's expected, but otherwise yes, it can be ignored!
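
For reference, a small illustrative check of the flash-attn version gate mentioned above (a sketch, not the exact check transformers performs internally):

import flash_attn
from packaging.version import Version

# Soft-capping support in the FA2 path is gated on flash-attn >= 2.6.0
# (see the is_flash_attn_greater_or_equal("2.6.0") call in the traceback above).
print(flash_attn.__version__)
assert Version(flash_attn.__version__) >= Version("2.6.0"), "flash-attn too old for soft-capping"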
