From 5a2ee0e89239af48005636ce9bec9cccc476976a Mon Sep 17 00:00:00 2001
From: Akihiro Takahashi
Date: Wed, 23 Oct 2024 11:42:50 -0700
Subject: [PATCH] Enable flash attention and reuse_cache for gemma

Add missing flag handling to gemma
  --reuse_cache
  --use_flash_attention
  --flash_attention_recompute
  --flash_attention_causal_mask
---
 .../models/gemma/modeling_gemma.py | 62 +++++++++++--------
 1 file changed, 35 insertions(+), 27 deletions(-)

diff --git a/optimum/habana/transformers/models/gemma/modeling_gemma.py b/optimum/habana/transformers/models/gemma/modeling_gemma.py
index 1c270b62f6..b66df5c2a0 100755
--- a/optimum/habana/transformers/models/gemma/modeling_gemma.py
+++ b/optimum/habana/transformers/models/gemma/modeling_gemma.py
@@ -20,6 +20,7 @@
 """PyTorch Gemma model."""
 
 import math
+import os
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -214,7 +215,7 @@ def pre_attn_forward(
         use_flash_attention: Optional[bool] = False,
         flash_attention_recompute: Optional[bool] = False,
         flash_attention_causal_mask: Optional[bool] = False,
-        cache_idx: int = None,
+        cache_idx: Optional[int] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """
@@ -289,7 +290,8 @@ def pre_attn_forward(
 
             if q_len == 1:
                 # next token
-                with ht.sdp_kernel(enable_recompute=False):
+                use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
+                with ht.sdp_kernel(enable_recompute=use_recompute):
                     attn_output = FusedSDPA.apply(
                         query_states, key_states, value_states, attention_mask, 0.0, False, None
                     )
@@ -407,23 +409,23 @@ def pre_attn(
         use_flash_attention: Optional[bool] = False,
         flash_attention_recompute: Optional[bool] = False,
         flash_attention_causal_mask: Optional[bool] = False,
-        cache_idx: int = None,
+        cache_idx: Optional[int] = None,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         hidden_states = self.input_layernorm(hidden_states)
         hidden_states, attn_weights, present_key_value = self.self_attn.pre_attn_forward(
-            hidden_states,
-            attention_mask,
-            position_ids,
-            past_key_value,
-            output_attentions,
-            use_cache,
-            cache_position,
-            token_idx,
-            attn_softmax_bf16,
-            reuse_cache,
-            use_flash_attention,
-            flash_attention_recompute,
-            flash_attention_causal_mask,
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            token_idx=token_idx,
+            attn_softmax_bf16=attn_softmax_bf16,
+            reuse_cache=reuse_cache,
+            use_flash_attention=use_flash_attention,
+            flash_attention_recompute=flash_attention_recompute,
+            flash_attention_causal_mask=flash_attention_causal_mask,
             cache_idx=cache_idx,
         )
         return hidden_states, attn_weights, present_key_value
@@ -443,7 +445,7 @@ def forward(
         use_flash_attention: Optional[bool] = False,
         flash_attention_recompute: Optional[bool] = False,
         flash_attention_causal_mask: Optional[bool] = False,
-        cache_idx: int = None,
+        cache_idx: Optional[int] = None,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Copied from GemmaDecoderLayer.forward: https://github.com/huggingface/transformers/blob/v4.38.1/src/transformers/models/gemma/modeling_gemma.py
@@ -453,16 +455,16 @@ def forward(
         residual = hidden_states
 
         hidden_states, self_attn_weights, present_key_value = self.pre_attn(
-            hidden_states,
-            attention_mask,
-            position_ids,
-            past_key_value,
-            output_attentions,
-            use_cache,
-            cache_position,
-            token_idx,
-            attn_softmax_bf16,
-            reuse_cache,
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            token_idx=token_idx,
+            attn_softmax_bf16=attn_softmax_bf16,
+            reuse_cache=reuse_cache,
             use_flash_attention=use_flash_attention,
             flash_attention_recompute=flash_attention_recompute,
             flash_attention_causal_mask=flash_attention_causal_mask,
@@ -717,6 +719,7 @@ def forward(
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
+        reuse_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
@@ -746,6 +749,7 @@ def forward(
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
+            reuse_cache=reuse_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
@@ -859,9 +863,13 @@ def prepare_inputs_for_generation(
                 "cache_position": cache_position,
                 "past_key_values": past_key_values,
                 "use_cache": use_cache,
+                "reuse_cache": kwargs.get("reuse_cache"),
                 "attention_mask": attention_mask,
                 "num_logits_to_keep": num_logits_to_keep,
                 "token_idx": token_idx,
+                "use_flash_attention": kwargs.get("use_flash_attention"),
+                "flash_attention_recompute": kwargs.get("flash_attention_recompute"),
+                "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"),
             }
         )
         return model_inputs
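
Not part of the patch: a minimal usage sketch of the options wired up above. It assumes a Gaudi/HPU environment with optimum-habana installed and uses the illustrative "google/gemma-2b" checkpoint; the flags reach the attention code through the model kwargs consumed by prepare_inputs_for_generation(), so they can be supplied via a GaudiGenerationConfig, as in the optimum-habana text-generation examples. Names and settings here are examples, not part of this change.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from optimum.habana.transformers.generation import GaudiGenerationConfig
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

adapt_transformers_to_gaudi()  # swap in the Gaudi-optimized model classes

# Illustrative checkpoint; replace with the Gemma model you actually use.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", torch_dtype=torch.bfloat16)
model = model.eval().to("hpu")

# The four options handled by this patch; they are forwarded through
# prepare_inputs_for_generation() into the attention layers patched above.
generation_config = GaudiGenerationConfig(
    max_new_tokens=32,
    use_cache=True,
    reuse_cache=True,
    use_flash_attention=True,
    flash_attention_recompute=True,
    flash_attention_causal_mask=True,
)

inputs = tokenizer("Deep learning is", return_tensors="pt").to("hpu")
outputs = model.generate(**inputs, generation_config=generation_config)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

When the options are not supplied, kwargs.get(...) returns None, which the attention path treats as disabled, so default behavior is unchanged. Separately, setting the QUANT_CONFIG environment variable now enables SDPA recompute for single-token decode via the os.getenv("QUANT_CONFIG", "") check added above.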