diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py
index 63dcf38240f9..b4308f4f7767 100644
--- a/src/transformers/models/gemma2/modeling_gemma2.py
+++ b/src/transformers/models/gemma2/modeling_gemma2.py
@@ -1076,10 +1076,14 @@ def prepare_inputs_for_generation(
         # The clone here is for the same reason as for `position_ids`.
         model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
-        if isinstance(past_key_values, HybridCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+        if (
+            isinstance(past_key_values, HybridCache)
+            and attention_mask.ndim == 2
+            and not self.config._attn_implementation == "flash_attention_2"
+        ):
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
                 batch_size, sequence_length = model_inputs["input_ids"].shape
                 device = model_inputs["input_ids"].device
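
A minimal, self-contained sketch of what the new guard does, not the library's actual code: `HybridCache` and the model config are replaced by the placeholder `using_hybrid_cache` flag and a `DummyConfig` class. It shows the two behavioral changes in the hunk: the 4D-mask branch is now skipped when `flash_attention_2` is configured, and `inputs_embeds` is unpacked as the 3D `(batch, seq_len, hidden)` tensor it actually is.

```python
import torch

# Placeholder stand-ins for the real objects (assumptions, for illustration only).
class DummyConfig:
    _attn_implementation = "flash_attention_2"

config = DummyConfig()
attention_mask = torch.ones(2, 7)       # 2D padding mask, shape (batch, seq_len)
inputs_embeds = torch.randn(2, 7, 16)   # 3D embeddings, shape (batch, seq_len, hidden)
using_hybrid_cache = True               # pretend past_key_values is a HybridCache

# Old condition: this branch was entered even under flash_attention_2, and the
# former 2-way unpack `batch_size, sequence_length = inputs_embeds.shape` would
# raise "too many values to unpack" on a 3D inputs_embeds.
# New condition: leave the 2D mask untouched when flash_attention_2 is in use.
if (
    using_hybrid_cache
    and attention_mask.ndim == 2
    and not config._attn_implementation == "flash_attention_2"
):
    # 3-way unpack matches the (batch, seq_len, hidden) shape of inputs_embeds.
    batch_size, sequence_length, _ = inputs_embeds.shape
    device = inputs_embeds.device
    print("building a 4D causal mask for", batch_size, sequence_length, device)
else:
    print("keeping the 2D attention mask as-is for", config._attn_implementation)
```

The added `flash_attention_2` check presumably reflects that the flash-attention path consumes the 2D padding mask directly, so converting it to a 4D mask here is unnecessary for that implementation.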