Generate: improve assisted generation tests #27540
Merged
@@ -23,6 +23,7 @@
from transformers import is_torch_available, pipeline
from transformers.testing_utils import (
is_flaky,
require_accelerate,
require_torch,
require_torch_multi_accelerator,
@@ -1506,10 +1507,14 @@ def test_contrastive_generate_low_memory(self):
)
self.assertListEqual(low_output.tolist(), high_output.tolist())

@slow  # TODO(Joao): remove this. Some models (e.g. data2vec, xcom, roberta) have an error rate between 1 and 10%.
@is_flaky()  # Read NOTE (1) below. If there are API issues, all attempts will fail.
def test_assisted_decoding_matches_greedy_search(self):
# This test ensures that the assisted generation does not introduce output changes over greedy search.
# It breaks the pattern in the tests above, for multiple reasons:
# NOTE (1): The sentence above is true most of the time, there is a tiny difference in the logits due to matmul
# shape differences -- and it may result in a different output. The input shape difference happens in the
# main model, that runs the forward pass with several candidates at once (as opposed to generating one token at
# a time). See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 for more info.
# NOTE (2): It breaks the pattern in the tests above, for multiple reasons:
# - assisted_decoding, contrarily to the other methods, can't be called on its own (e.g. needs to
# prepare the assistant encoder outputs in the main generate body);
# - assisted_decoding does not support `use_cache = False`
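As an aside (not part of this diff), the snippet below is a minimal, self-contained sketch of the effect NOTE (1) describes: projecting several candidate positions in one matmul versus projecting them one at a time dispatches matmuls with different shapes, and floating-point reductions are not guaranteed to be bit-identical across shapes or kernels. The tensor sizes here are arbitrary illustrative choices, not values from the test suite.

import torch

torch.manual_seed(0)

# Stand-in for a language-model head: hidden states projected to vocabulary logits.
hidden = torch.randn(8, 256)    # 8 candidate positions, hidden_size 256 (arbitrary sizes)
weight = torch.randn(256, 512)  # hidden_size x vocab_size

# "Assisted" style: all candidate positions go through a single matmul.
logits_batched = hidden @ weight

# "One token at a time" style: each position is projected separately.
logits_stepwise = torch.cat([hidden[i : i + 1] @ weight for i in range(hidden.shape[0])])

# Depending on backend, dtype and kernel selection, the two results may or may not be
# bit-identical; even a last-bit difference can occasionally flip an argmax, which is
# why assisted generation only *almost* always matches greedy search.
print("max abs difference:", (logits_batched - logits_stepwise).abs().max().item())
print("bit-identical:", torch.equal(logits_batched, logits_stepwise))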
@@ -1520,77 +1525,82 @@ def test_assisted_decoding_matches_greedy_search(self):
self.skipTest("Won't fix: old model with different cache format")
if any(
model_name in model_class.__name__.lower()
for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"]
for model_name in [
"bigbirdpegasus",
"led",
"mega",
"speech2text",
"git",
"prophetnet",
"seamlessm4t",
"clvp",
]
):
self.skipTest("May fix in the future: need model-specific fixes")

# This for loop is a naive and temporary effort to make the test less flaky.
failed = 0
for i in range(10):
# enable cache
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)

# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
self.skipTest("This model doesn't support caching")
# enable cache
config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1)

config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
output_greedy = model.generate(
input_ids,
attention_mask=attention_mask,
max_length=max_length,
num_beams=1,
do_sample=False,
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
# Note: with assisted generate, if the same model is used as assistant, then all assistant tokens will
# be correct
output_assisted = model.generate(
input_ids,
attention_mask=attention_mask,
max_length=max_length,
num_beams=1,
do_sample=False,
assistant_model=model,
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
self.skipTest("This model doesn't support caching")

try:
self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist())
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
# Sets assisted generation arguments such that:
# a) no EOS is generated, to ensure generation doesn't break early
# b) the assistant model always generates two tokens when it is called, to ensure the input preparation of
# the assistant model is correct
# c) there are at least two forward passes in the main model, to ensure the input preparation of
# the main model is correct
generation_kwargs = {
"eos_token_id": -1,  # see a)
"max_new_tokens": 4,  # see c)
"num_beams": 1,
"do_sample": False,
"output_scores": True,
"output_hidden_states": True,
"output_attentions": True,
"return_dict_in_generate": True,
}
output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)

for output in (output_greedy, output_assisted):
self._check_outputs(output, input_ids, model.config, use_cache=True)
except AssertionError:
failed += 1
if failed > 1:
self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist())
assistant_model = model
assistant_model.generation_config.num_assistant_tokens = 2  # see b)
assistant_model.generation_config.num_assistant_tokens_schedule = "constant"  # see b)
generation_kwargs.update({"assistant_model": assistant_model})
output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)

for output in (output_greedy, output_assisted):
self._check_outputs(output, input_ids, model.config, use_cache=True)
# The two outputs must match and their shape must be as expected
self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist())
for output in (output_greedy, output_assisted):
self._check_outputs(output, input_ids, model.config, use_cache=True)
@unittest.skip("Failing for a lot of models due to attention mask size mismatch. Works well when standalone.")
def test_assisted_decoding_sample(self):
# Seeded assisted decoding will not match sample for the same seed, as the forward pass does not return the
# exact same logits (the forward pass of the main model, now with several tokens at once, has causal masking).
# In this test we don't check assisted vs non-assisted output -- seeded assisted decoding with sample will not
# match sample for the same seed, as the forward pass does not return the exact same logits (due to matmul with
# different shapes, see https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535).
for model_class in self.all_generative_model_classes:
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
self.skipTest("Won't fix: old model with different cache format")
if any(
model_name in model_class.__name__.lower()
for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet", "seamlessm4t"]
for model_name in [
"bigbirdpegasus",
"led",
"mega",
"speech2text",
"git",
"prophetnet",
"seamlessm4t",
"clvp",
]
):
self.skipTest("May fix in the future: need model-specific fixes")

# enable cache
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)
config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1)

# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
@@ -1599,18 +1609,27 @@ def test_assisted_decoding_sample(self):
config.use_cache = True
config.is_decoder = True
[Inline review comment on this line: "Can we also have a test for when ..."]
model = model_class(config).to(torch_device).eval()
output_assisted = model.generate(
input_ids,
attention_mask=attention_mask,
max_length=max_length,
num_beams=1,
do_sample=True,
assistant_model=model,  # triggers assisted decoding
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
# Sets assisted generation arguments such that:
# a) no EOS is generated, to ensure generation doesn't break early
# b) the assistant model always generates two tokens when it is called, to ensure the input preparation of
# the assistant model is correct
# c) there are at least two forward passes in the main model, to ensure the input preparation of
# the main model is correct
assistant_model = model
assistant_model.generation_config.num_assistant_tokens = 2  # see b)
assistant_model.generation_config.num_assistant_tokens_schedule = "constant"  # see b)
generation_kwargs = {
"eos_token_id": -1,  # see a)
"max_new_tokens": 4,  # see c)
"num_beams": 1,
"do_sample": True,
"assistant_model": assistant_model,
"output_scores": True,
"output_hidden_states": True,
"output_attentions": True,
"return_dict_in_generate": True,
}
output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)

self._check_outputs(output_assisted, input_ids, model.config, use_cache=True)
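Outside the test suite, the assistant-model knobs exercised above map onto the public generate() API roughly as sketched below. This is a usage sketch based on the kwargs shown in this diff, not code from the PR; the gpt2 / distilgpt2 checkpoint pair is an illustrative choice (the assistant must share the main model's tokenizer), and num_assistant_tokens / num_assistant_tokens_schedule are set only to mirror the test setup.

from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoints: a main model and a smaller assistant with the same tokenizer.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
assistant = AutoModelForCausalLM.from_pretrained("distilgpt2")

# Mirror the test setup: a fixed number of assistant tokens per call, constant schedule.
assistant.generation_config.num_assistant_tokens = 2
assistant.generation_config.num_assistant_tokens_schedule = "constant"

inputs = tokenizer("Assisted generation lets a small model draft tokens that", return_tensors="pt")

# Greedy decoding with and without the assistant should (nearly always) produce the same text.
baseline = model.generate(**inputs, do_sample=False, max_new_tokens=20)
assisted = model.generate(**inputs, do_sample=False, max_new_tokens=20, assistant_model=assistant)

print(tokenizer.decode(baseline[0], skip_special_tokens=True))
print(tokenizer.decode(assisted[0], skip_special_tokens=True))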
Review comment: This was essentially the same as @is_flaky, but (IMO) less elegant. Now that we understand the cause for the mismatch (matmul with different shapes), and know that there is no workaround, it is safe to confirm that this test is indeed flaky :)
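For context on what the decorator buys over the hand-rolled retry loop removed in this PR, a retry-style test decorator could look like the sketch below. This is an illustrative approximation only, not the actual transformers.testing_utils.is_flaky implementation; the max_attempts parameter name is assumed.

import functools

def is_flaky(max_attempts: int = 5):
    """Rerun a test a few times and only fail if every attempt fails (illustrative sketch)."""
    def decorator(test_func):
        @functools.wraps(test_func)
        def wrapper(*args, **kwargs):
            last_error = None
            for _ in range(max_attempts):
                try:
                    return test_func(*args, **kwargs)  # pass on the first successful attempt
                except AssertionError as exc:          # a flaky failure: try again
                    last_error = exc
            raise last_error                           # every attempt failed: a real failure
        return wrapper
    return decorator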