Fix: StaticCache & inputs_embeds (#32932)
Changes from all commits: fce9e7e, 926eaa0, 0378197, 862fddc
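For context on the title: the scenario this fix targets is generating from inputs_embeds while requesting the static cache implementation. Below is a minimal sketch of that scenario, assuming a decoder-only checkpoint that supports StaticCache; the checkpoint name is illustrative (Gemma2 itself uses HybridCache, which is why the StaticCache tests are skipped in the diff that follows).

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; pick any decoder-only model that supports StaticCache.
model_id = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
# Pass embeddings explicitly instead of input_ids.
inputs_embeds = model.get_input_embeddings()(inputs.input_ids)

# With inputs_embeds, the pre-allocated static cache must still be sized by the
# full max length, not max_length minus the embeddings' length (this is the
# shape the new test helper in the diff below asserts).
output = model.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=inputs.attention_mask,
    max_new_tokens=20,
    cache_implementation="static",
)
print(tokenizer.batch_decode(output, skip_special_tokens=True)[0])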
@@ -16,9 +16,10 @@
 import unittest

 from parameterized import parameterized
 from pytest import mark

-from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, is_torch_available, pipeline
+from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, HybridCache, is_torch_available, pipeline
 from transformers.testing_utils import (
     require_flash_attn,
     require_read_token,

@@ -59,7 +60,7 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
         if is_torch_available()
         else ()
     )
-    all_generative_model_classes = ()
+    all_generative_model_classes = (Gemma2ForCausalLM,) if is_torch_available() else ()
     pipeline_model_mapping = (
         {
             "feature-extraction": Gemma2Model,
@@ -89,6 +90,101 @@ def test_model_outputs_equivalence(self, **kwargs):
     def test_eager_matches_sdpa_inference(self):
         pass

+    @parameterized.expand([("random",), ("same",)])
+    @unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
+    def test_assisted_decoding_matches_greedy_search(self, assistant_type):
+        pass
+
+    @unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
+    def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
+        pass
+
+    @unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
+    def test_assisted_decoding_sample(self):
+        pass
+
+    @unittest.skip("Gemma2 has HybridCache which is not compatible with dola decoding")
+    def test_dola_decoding_sample(self):
+        pass
+
+    @parameterized.expand([(1, False), (1, True), (4, False)])
+    @unittest.skip("Gemma2 has HybridCache and doesn't support old tuple format at all")
+    def test_new_cache_format(self, num_beams, do_sample):
+        pass
+
+    @unittest.skip("Gemma2 has HybridCache and doesn't support continue from past kv")
+    def test_generate_continue_from_past_key_values(self):
+        pass
+
+    @unittest.skip("Gemma2 has HybridCache and doesn't support low_memory generation")
+    def test_beam_search_low_memory(self):
+        pass
+
+    @unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation")
+    def test_contrastive_generate(self):
+        pass
+
+    @unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation")
+    def test_contrastive_generate_dict_outputs_use_cache(self):
+        pass
+
+    @unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation")
+    def test_contrastive_generate_low_memory(self):
+        pass
+
+    @unittest.skip("Gemma2 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
+    def test_generate_with_static_cache(self):
+        pass
+
+    @unittest.skip("Gemma2 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
+    def test_generate_from_inputs_embeds_with_static_cache(self):
+        pass
+
+    # overwrite because HybridCache has fixed length for key/values
+    def _check_attentions_for_generate(
[Review comment] Let's add the reason for the overwrite at the top of the fn as a comment, here and on the other functions that need an overwrite! That way, we immediately know why the function needs to exist :) (I see that you added a few comments below, like …)
+        self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
+    ):
+        self.assertIsInstance(attentions, tuple)
+        self.assertListEqual(
+            [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
+        )
+        self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
+
+        for idx, iter_attentions in enumerate(attentions):
+            tgt_len = min_length + idx if not use_cache else 1
+            src_len = min_length + idx if not use_cache else max_length
+
+            expected_shape = (
+                batch_size * num_beam_groups,
+                config.num_attention_heads,
+                tgt_len,
+                src_len,
+            )
+            # check attn size
+            self.assertListEqual(
+                [layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
+            )
+
+    # overwrite because HybridCache has fixed length for key/values
+    def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config, num_beam_groups=1):
+        self.assertIsInstance(past_key_values, HybridCache)
+
+        # check shape key, value (batch, head, max_seq_length, head_features)
+        head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
+        num_key_value_heads = (
+            config.num_attention_heads
+            if getattr(config, "num_key_value_heads", None) is None
+            else config.num_key_value_heads
+        )
+        num_hidden_layers = config.num_hidden_layers
+
+        # we should get `max_length` in shape, not `max_length - embeds_length`
+        # `+1` because the test in Mixin subtracts 1, which is needed for tuple cache
+        static_cache_shape = (batch_size, num_key_value_heads, seq_length + 1, head_dim)
+        static_layers = [layer_idx for layer_idx, boolean in enumerate(past_key_values.is_sliding) if not boolean]
+        self.assertTrue(len(past_key_values.key_cache) == num_hidden_layers)
+        self.assertTrue(past_key_values.key_cache[static_layers[0]].shape == static_cache_shape)
+
+    @unittest.skip("Gemma2's eager attn/sdpa attn outputs are expected to be different")
+    def test_sdpa_equivalence(self):
+        pass
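To make the override's expectations concrete, here is a small sketch of inspecting the HybridCache a Gemma2 model returns from generate. It relies only on the is_sliding and key_cache attributes the helper above already uses; the checkpoint name is illustrative.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2-9b"  # illustrative Gemma2 checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=8, return_dict_in_generate=True)
cache = out.past_key_values  # a HybridCache for Gemma2

# HybridCache mixes per-layer sliding-window caches (length capped by
# config.sliding_window) with global layers allocated to the full cache
# length; is_sliding records which kind each layer is. The helper above
# asserts the shape of the first non-sliding ("static") layer.
for layer_idx, sliding in enumerate(cache.is_sliding):
    key = cache.key_cache[layer_idx]
    kind = "sliding" if sliding else "static"
    # key shape: (batch, num_key_value_heads, allocated_len, head_dim)
    print(f"layer {layer_idx}: {kind}, key cache shape {tuple(key.shape)}")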
@@ -203,6 +299,5 @@ def test_model_9b_flash_attn(self):
         output = model.generate(**inputs, max_new_tokens=100, do_sample=False)
         output_text = tokenizer.batch_decode(output, skip_special_tokens=False)
-        print(output_text)

         self.assertEqual(output_text, EXPECTED_TEXTS)
[Review comment] 😱 good spot!
[Review comment] This was removed because it was failing too many tests.
[Reply] Yes, I skipped those that shouldn't be triggered due to the model-specific cache and fixed the other failing ones.