Generate: improve assisted generation tests #27540

Merged · 3 commits · Nov 16, 2023
153 changes: 86 additions & 67 deletions tests/generation/test_utils.py
@@ -23,6 +23,7 @@

from transformers import is_torch_available, pipeline
from transformers.testing_utils import (
is_flaky,
require_accelerate,
require_torch,
require_torch_multi_accelerator,
@@ -1506,10 +1507,14 @@ def test_contrastive_generate_low_memory(self):
)
self.assertListEqual(low_output.tolist(), high_output.tolist())

@slow # TODO(Joao): remove this. Some models (e.g. data2vec, xcom, roberta) have an error rate between 1 and 10%.
@is_flaky() # Read NOTE (1) below. If there are API issues, all attempts will fail.
def test_assisted_decoding_matches_greedy_search(self):
# This test ensures that the assisted generation does not introduce output changes over greedy search.
# It breaks the pattern in the tests above, for multiple reasons:
# NOTE (1): The sentence above is true most of the time; there is a tiny difference in the logits due to matmul
# shape differences -- and it may result in a different output. The input shape difference happens in the
# main model, which runs the forward pass with several candidates at once (as opposed to generating one token at
# a time). See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 for more info.
# NOTE (2): It breaks the pattern in the tests above, for multiple reasons:
# - assisted_decoding, unlike the other methods, can't be called on its own (e.g. it needs to
# prepare the assistant encoder outputs in the main generate body);
# - assisted_decoding does not support `use_cache = False`
@@ -1520,77 +1525,82 @@ def test_assisted_decoding_matches_greedy_search(self):
self.skipTest("Won't fix: old model with different cache format")
if any(
model_name in model_class.__name__.lower()
for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"]
for model_name in [
"bigbirdpegasus",
"led",
"mega",
"speech2text",
"git",
"prophetnet",
"seamlessm4t",
"clvp",
]
):
self.skipTest("May fix in the future: need model-specific fixes")

# This for loop is a naive and temporary effort to make the test less flaky.
failed = 0
for i in range(10):
@gante (Member, Author) commented on Nov 16, 2023:

This was essentially the same as @is_flaky, but (IMO) less elegant.

Now that we understand the cause for the mismatch (matmul with different shapes), and know that there is no workaround, it is safe to confirm that this test is indeed flaky :)
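
As a rough illustration of the effect (not part of this PR): matmuls over different input shapes may dispatch to different kernels and yield tiny numeric differences. A minimal PyTorch sketch, with hypothetical sizes:

```python
import torch

torch.manual_seed(0)
hidden = torch.randn(5, 256)      # hypothetical hidden states for 5 candidate tokens
lm_head = torch.randn(256, 1000)  # hypothetical LM head weight

# One forward pass over all candidates at once (the assisted-generation path)...
batched = hidden @ lm_head
# ...vs. one token at a time (the regular greedy-decoding path).
one_by_one = torch.stack([hidden[i] @ lm_head for i in range(hidden.shape[0])])

# Often exactly zero on CPU, but not guaranteed bit-identical everywhere:
# different shapes can hit different matmul kernels, especially on GPU.
print(torch.max(torch.abs(batched - one_by_one)))
```

If the greedy-token logits are near a tie, such a sub-ulp difference can flip the argmax, which is why the test is marked flaky rather than retried by hand.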

# enable cache
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)

# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
self.skipTest("This model doesn't support caching")
# enable cache
config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1)

config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
output_greedy = model.generate(
input_ids,
attention_mask=attention_mask,
max_length=max_length,
num_beams=1,
do_sample=False,
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
# Note: with assisted generate, if the same model is used as assistant, then all assistant tokens will
# be correct
output_assisted = model.generate(
input_ids,
attention_mask=attention_mask,
max_length=max_length,
num_beams=1,
do_sample=False,
assistant_model=model,
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
self.skipTest("This model doesn't support caching")

try:
self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist())
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
# Sets assisted generation arguments such that:
# a) no EOS is generated, to ensure generation doesn't break early
# b) the assistant model always generates two tokens when it is called, to ensure the input preparation of
# the assistant model is correct
# c) there are at least two forward passes in the main model, to ensure the input preparation of
# the main model is correct
generation_kwargs = {
"eos_token_id": -1, # see a)
"max_new_tokens": 4, # see c)
"num_beams": 1,
"do_sample": False,
"output_scores": True,
"output_hidden_states": True,
"output_attentions": True,
"return_dict_in_generate": True,
}
output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)

for output in (output_greedy, output_assisted):
self._check_outputs(output, input_ids, model.config, use_cache=True)
except AssertionError:
failed += 1
if failed > 1:
self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist())
assistant_model = model
assistant_model.generation_config.num_assistant_tokens = 2 # see b)
assistant_model.generation_config.num_assistant_tokens_schedule = "constant" # see b)
generation_kwargs.update({"assistant_model": assistant_model})
output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)

for output in (output_greedy, output_assisted):
self._check_outputs(output, input_ids, model.config, use_cache=True)
# The two outputs must match and their shape must be as expected
self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist())
for output in (output_greedy, output_assisted):
self._check_outputs(output, input_ids, model.config, use_cache=True)

@unittest.skip("Failing for a lot of models due to attention mask size mismatch. Works well when standalone.")
def test_assisted_decoding_sample(self):
# Seeded assisted decoding will not match sample for the same seed, as the forward pass does not return the
# exact same logits (the forward pass of the main model, now with several tokens at once, has causal masking).
# In this test we don't check assisted vs non-assisted output -- seeded assisted decoding with sample will not
# match sample for the same seed, as the forward pass does not return the exact same logits (due to matmul with
# different shapes, see https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535).
for model_class in self.all_generative_model_classes:
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
self.skipTest("Won't fix: old model with different cache format")
if any(
model_name in model_class.__name__.lower()
for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet", "seamlessm4t"]
for model_name in [
"bigbirdpegasus",
"led",
"mega",
"speech2text",
"git",
"prophetnet",
"seamlessm4t",
"clvp",
]
):
self.skipTest("May fix in the future: need model-specific fixes")

# enable cache
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)
config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1)

# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
@@ -1599,18 +1609,27 @@ def test_assisted_decoding_sample(self):
config.use_cache = True
config.is_decoder = True
A collaborator commented:

Can we also have a test for when config.is_encoder_decoder is set, to make sure any relevant logic is handled there?
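
A rough sketch of what such a variant could look like (not in this PR; it assumes the same test-mixin context and helpers as the diff below, and that the mixin can hand back an encoder-decoder config):

```python
# Hypothetical encoder-decoder variant, reusing the helpers from this file.
config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1)
if not getattr(config, "is_encoder_decoder", False):
    self.skipTest("This model is not an encoder-decoder")

config.use_cache = True
model = model_class(config).to(torch_device).eval()

# Same a)/b)/c) arguments as the decoder-only path; this additionally
# exercises the preparation of the assistant's encoder outputs.
assistant_model = model
assistant_model.generation_config.num_assistant_tokens = 2
assistant_model.generation_config.num_assistant_tokens_schedule = "constant"
output_assisted = model.generate(
    input_ids,
    attention_mask=attention_mask,
    eos_token_id=-1,
    max_new_tokens=4,
    num_beams=1,
    do_sample=True,
    assistant_model=assistant_model,
    return_dict_in_generate=True,
)
self._check_outputs(output_assisted, input_ids, model.config, use_cache=True)
```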

model = model_class(config).to(torch_device).eval()
output_assisted = model.generate(
input_ids,
attention_mask=attention_mask,
max_length=max_length,
num_beams=1,
do_sample=True,
assistant_model=model, # triggers assisted decoding
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
# Sets assisted generation arguments such that:
# a) no EOS is generated, to ensure generation doesn't break early
# b) the assistant model always generates two tokens when it is called, to ensure the input preparation of
# the assistant model is correct
# c) there are at least two forward passes in the main model, to ensure the input preparation of
# the main model is correct
assistant_model = model
assistant_model.generation_config.num_assistant_tokens = 2 # see b)
assistant_model.generation_config.num_assistant_tokens_schedule = "constant" # see b)
generation_kwargs = {
"eos_token_id": -1, # see a)
"max_new_tokens": 4, # see c)
"num_beams": 1,
"do_sample": True,
"assistant_model": assistant_model,
"output_scores": True,
"output_hidden_states": True,
"output_attentions": True,
"return_dict_in_generate": True,
}
output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)

self._check_outputs(output_assisted, input_ids, model.config, use_cache=True)
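
For context on the decorator this PR leans on: `is_flaky` reruns a failing test a few times and only reports a failure once every attempt has failed. A minimal sketch of that retry pattern (not the actual `transformers.testing_utils` implementation):

```python
import functools

def flaky(max_attempts: int = 5):
    """Minimal retry decorator in the spirit of `is_flaky`: rerun a failing
    test and only raise once every attempt has failed."""
    def decorator(test_func):
        @functools.wraps(test_func)
        def wrapper(*args, **kwargs):
            for attempt in range(max_attempts):
                try:
                    return test_func(*args, **kwargs)
                except Exception:
                    if attempt == max_attempts - 1:
                        raise  # all attempts failed -- surface the last error
        return wrapper
    return decorator
```

Compared with the hand-rolled `failed > 1` loop the PR removes, a decorator keeps the retry policy out of the test body and makes the known flakiness explicit at the definition site.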
