Fix: StaticCache & inputs_embeds #32932

Merged
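For context, a hedged usage sketch of the scenario this PR fixes (the checkpoint and prompt are assumptions for illustration, not taken from the PR): generating from `inputs_embeds` with the static cache previously sized the cache from `generation_config.max_length` alone, without accounting for the prompt embeddings, which could cause out-of-bounds indexing when slicing the cache.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed Llama-family checkpoint (supports the static cache); any such model works.
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto").eval()

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
# Pass the prompt as embeddings instead of token ids.
inputs_embeds = model.get_input_embeddings()(inputs.input_ids)

# With this fix, the static cache is sized to hold the prompt embeddings plus the new tokens.
outputs = model.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=inputs.attention_mask,
    max_new_tokens=20,
    cache_implementation="static",
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))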
20 changes: 15 additions & 5 deletions src/transformers/generation/utils.py
@@ -1481,6 +1481,7 @@ def _prepare_cache_for_generation(
model_kwargs: Dict,
assistant_model: "PreTrainedModel",
batch_size: int,
max_cache_length: int,
device: torch.device,
) -> bool:
"""
@@ -1547,8 +1548,8 @@ def _prepare_cache_for_generation(
)
model_kwargs[cache_name] = self._get_cache(
cache_implementation=generation_config.cache_implementation,
batch_size=generation_config.num_beams * generation_config.num_return_sequences * batch_size,
max_cache_len=generation_config.max_length,
batch_size=max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size,
max_cache_len=max_cache_length,
device=device,
model_kwargs=model_kwargs,
)
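As a toy restatement (not the library code) of the batch-size change above: beams and `num_return_sequences` now share cache rows instead of multiplying together, so only the larger of the two scales the static cache.

def static_cache_batch_size(num_beams: int, num_return_sequences: int, batch_size: int) -> int:
    # previously: num_beams * num_return_sequences * batch_size
    return max(num_beams, num_return_sequences) * batch_size

# beam search with 4 beams over a batch of 2 allocates 8 cache rows
assert static_cache_batch_size(4, 1, 2) == 8
# sampling 3 return sequences per prompt over a batch of 2 allocates 6
assert static_cache_batch_size(1, 3, 2) == 6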
@@ -1888,7 +1889,16 @@ def generate(
# TODO (joao): remove `user_defined_cache` after v4.47 (remove default conversion to legacy format)
cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params"
user_defined_cache = model_kwargs.get(cache_name)
self._prepare_cache_for_generation(generation_config, model_kwargs, assistant_model, batch_size, device)
max_cache_length = generation_config.max_length
if (
inputs_tensor.shape[1] != input_ids_length
and model_input_name == "inputs_embeds"
and not self.config.is_encoder_decoder
):
max_cache_length += inputs_tensor.shape[1]
self._prepare_cache_for_generation(
generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, device
)

# 8. determine generation mode
generation_mode = generation_config.get_generation_mode(assistant_model)
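An illustrative helper (an assumption-labelled sketch, not the library code) for the `max_cache_length` adjustment above: when a decoder-only model is called with `inputs_embeds`, `max_length` no longer covers the prompt, so the prompt length is added back to size the static cache.

def compute_max_cache_length(max_length: int, prompt_len: int, input_ids_len: int,
                             model_input_name: str, is_encoder_decoder: bool) -> int:
    # mirrors the branch above: grow the cache by the prompt length when generating
    # from inputs_embeds with a decoder-only model
    max_cache_length = max_length
    if prompt_len != input_ids_len and model_input_name == "inputs_embeds" and not is_encoder_decoder:
        max_cache_length += prompt_len
    return max_cache_length

# a 7-token prompt passed as embeddings with max_length=23 needs a 30-slot cache
assert compute_max_cache_length(23, 7, 0, "inputs_embeds", False) == 30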
@@ -1936,8 +1946,8 @@ def generate(
raise ValueError("assisted generate is only supported for batch_size = 1")
if not model_kwargs["use_cache"]:
raise ValueError("assisted generate requires `use_cache=True`")
if generation_config.cache_implementation == "static":
raise ValueError("assisted generate is not supported with `static_cache`")
if generation_config.cache_implementation in ["static", "hybrid", "sliding_window"]:
raise ValueError("assisted generate is not supported with Static cache classes`")
if self._is_stateful:
# In assisted generation we need the ability to confirm whether the model would pick certain tokens,
# which is not possible with stateful models (they can't reset to a previous subset of generated text)
56 changes: 56 additions & 0 deletions tests/generation/test_utils.py
@@ -1453,6 +1453,9 @@ def _prepare_model_kwargs(input_ids, attention_mask, signature):
model = model_class(config).to(torch_device).eval()
signature = inspect.signature(model.forward).parameters.keys()

# no cache, as some models require special cache classes to be initialized outside of forward
model.generation_config.use_cache = False

# Without padding
model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature)
next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :]
@@ -1593,6 +1596,59 @@ def test_generate_from_inputs_embeds_decoder_only(self):
outputs_from_embeds_wo_ids.tolist(),
)

@pytest.mark.generate
def test_generate_from_inputs_embeds_with_static_cache(self):
"""
Test that StaticCache can generate from `inputs_embeds` and that `generate()` computes `max_cache_length`
correctly. We force the model not to stop until `max_length` is reached, to verify that the cache length is
set correctly and that we don't index out of bounds when slicing the cache.
"""
for model_class in self.all_generative_model_classes:
if not model_class._supports_static_cache:
self.skipTest(reason="This model does not support the static cache format")

config, input_ids, attention_mask = self._get_input_ids_and_config()
if config.is_encoder_decoder:
self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")

model = model_class(config).to(torch_device).eval()
if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys():
self.skipTest(reason="This model does not support `inputs_embeds` in generation")

model.config.use_cache = True
model.config.is_decoder = True
batch_size, seq_length = input_ids.shape
max_cache_len = 30

# here we force the model not to stop at EOS and to generate until max_length
model.generation_config.eos_token_id = model.config.eos_token_id = -1
generation_kwargs = {
"max_length": max_cache_len,
"cache_implementation": "static",
"return_dict_in_generate": True, # Required to return `past_key_values`
}

head_dim = (
model.config.head_dim
if hasattr(model.config, "head_dim")
else model.config.hidden_size // model.config.num_attention_heads
)
num_key_value_heads = (
model.config.num_attention_heads
if getattr(config, "num_key_value_heads", None) is None
else model.config.num_key_value_heads
)
num_hidden_layers = config.num_hidden_layers

inputs_embeds = model.get_input_embeddings()(input_ids)
outputs = model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **generation_kwargs)

# we should get `max_length` in shape, not `max_length - embeds_length`
cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
self.assertTrue(isinstance(outputs.past_key_values, StaticCache))
self.assertTrue(len(outputs.past_key_values.key_cache) == num_hidden_layers)
self.assertTrue(outputs.past_key_values.key_cache[0].shape == cache_shape)

@pytest.mark.generate
def test_generate_continue_from_past_key_values(self):
# Tests that we can continue generating from past key values, returned from a previous `generate` call
101 changes: 98 additions & 3 deletions tests/models/gemma2/test_modeling_gemma2.py
@@ -16,9 +16,10 @@

import unittest

from parameterized import parameterized
from pytest import mark

from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, is_torch_available, pipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, HybridCache, is_torch_available, pipeline
from transformers.testing_utils import (
require_flash_attn,
require_read_token,
@@ -59,7 +60,7 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
if is_torch_available()
else ()
)
all_generative_model_classes = ()
all_generative_model_classes = (Gemma2ForCausalLM,) if is_torch_available() else ()
Member: 😱 good spot!

Collaborator: This was removed because it was failing too many tests.

Member (Author): Yes, I skipped the ones that shouldn't be triggered due to the model-specific cache and fixed the other failing ones.

pipeline_model_mapping = (
{
"feature-extraction": Gemma2Model,
@@ -89,6 +90,101 @@ def test_model_outputs_equivalence(self, **kwargs):
def test_eager_matches_sdpa_inference(self):
pass

@parameterized.expand([("random",), ("same",)])
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass

@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass

@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass

@unittest.skip("Gemma2 has HybridCache which is not compatible with dola decoding")
def test_dola_decoding_sample(self):
pass

@parameterized.expand([(1, False), (1, True), (4, False)])
@unittest.skip("Gemma2 has HybridCache and doesn't support old tuple format at all")
def test_new_cache_format(self, num_beams, do_sample):
pass

@unittest.skip("Gemma2 has HybridCache and doesn't support continue from past kv")
def test_generate_continue_from_past_key_values(self):
pass

@unittest.skip("Gemma2 has HybridCache and doesn't support low_memory generation")
def test_beam_search_low_memory(self):
pass

@unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate(self):
pass

@unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate_dict_outputs_use_cache(self):
pass

@unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate_low_memory(self):
pass

@unittest.skip("Gemma2 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
def test_generate_with_static_cache(self):
pass

@unittest.skip("Gemma2 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
def test_generate_from_inputs_embeds_with_static_cache(self):
pass

# overwrite because HybridCache has fixed length for key/values
def _check_attentions_for_generate(
Member: Let's add the reason for the overwrite at the top of the fn as a comment, here and on the other functions that need an overwrite! That way, we immediately know why the function needs to exist :)
(I see that you added a few comments below, like "HybridCache has fixed length for key/values"; moving them to the top suffices.)

self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
):
self.assertIsInstance(attentions, tuple)
self.assertListEqual(
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
)
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)

for idx, iter_attentions in enumerate(attentions):
tgt_len = min_length + idx if not use_cache else 1
src_len = min_length + idx if not use_cache else max_length

expected_shape = (
batch_size * num_beam_groups,
config.num_attention_heads,
tgt_len,
src_len,
)
# check attn size
self.assertListEqual(
[layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
)

# overwrite because HybridCache has fixed length for key/values
def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config, num_beam_groups=1):
self.assertIsInstance(past_key_values, HybridCache)

# check shape key, value (batch, head, max_seq_length, head_features)
head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
num_key_value_heads = (
config.num_attention_heads
if getattr(config, "num_key_value_heads", None) is None
else config.num_key_value_heads
)
num_hidden_layers = config.num_hidden_layers

# we should get `max_length` in shape, not `max_length - embeds_length`
# `+1` because the test in Mixin subtracts 1 which is needed for tuple cache
static_cache_shape = (batch_size, num_key_value_heads, seq_length + 1, head_dim)
static_layers = [layer_idx for layer_idx, boolean in enumerate(past_key_values.is_sliding) if not boolean]
self.assertTrue(len(past_key_values.key_cache) == num_hidden_layers)
self.assertTrue(past_key_values.key_cache[static_layers[0]].shape == static_cache_shape)

@unittest.skip("Gemma2's eager attn/sdpa attn outputs are expected to be different")
def test_sdpa_equivalence(self):
pass
@@ -203,6 +299,5 @@ def test_model_9b_flash_attn(self):

output = model.generate(**inputs, max_new_tokens=100, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=False)
print(output_text)

self.assertEqual(output_text, EXPECTED_TEXTS)