Skip to content

Commit

Permalink
Generate: improve assisted generation tests (#27540)
Browse files Browse the repository at this point in the history
  • Loading branch information
gante authored Nov 16, 2023
1 parent 651408a commit 12b50c6
Showing 1 changed file with 86 additions and 67 deletions.
153 changes: 86 additions & 67 deletions tests/generation/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from transformers import is_torch_available, pipeline
from transformers.testing_utils import (
is_flaky,
require_accelerate,
require_torch,
require_torch_multi_accelerator,
Expand Down Expand Up @@ -1506,10 +1507,14 @@ def test_contrastive_generate_low_memory(self):
)
self.assertListEqual(low_output.tolist(), high_output.tolist())

@slow # TODO(Joao): remove this. Some models (e.g. data2vec, xcom, roberta) have an error rate between 1 and 10%.
@is_flaky() # Read NOTE (1) below. If there are API issues, all attempts will fail.
def test_assisted_decoding_matches_greedy_search(self):
# This test ensures that the assisted generation does not introduce output changes over greedy search.
# It breaks the pattern in the tests above, for multiple reasons:
# NOTE (1): The sentence above is true most of the time, there is a tiny difference in the logits due to matmul
# shape differences -- and it may result in a different output. The input shape difference happens in the
# main model, that runs the forward pass with several candidates at once (as opposed to generating one token at
# a time). See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 for more info.
# NOTE (2): It breaks the pattern in the tests above, for multiple reasons:
# - assisted_decoding, contrarily to the other methods, can't be called on its own (e.g. needs to
# prepare the assistant encoder outputs in the main generate body);
# - assisted_decoding does not support `use_cache = False`
Expand All @@ -1520,77 +1525,82 @@ def test_assisted_decoding_matches_greedy_search(self):
self.skipTest("Won't fix: old model with different cache format")
if any(
model_name in model_class.__name__.lower()
for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"]
for model_name in [
"bigbirdpegasus",
"led",
"mega",
"speech2text",
"git",
"prophetnet",
"seamlessm4t",
"clvp",
]
):
self.skipTest("May fix in the future: need model-specific fixes")

# This for loop is a naive and temporary effort to make the test less flaky.
failed = 0
for i in range(10):
# enable cache
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)

# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
self.skipTest("This model doesn't support caching")
# enable cache
config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1)

config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
output_greedy = model.generate(
input_ids,
attention_mask=attention_mask,
max_length=max_length,
num_beams=1,
do_sample=False,
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
# Note: with assisted generate, if the same model is used as assistant, then all assistant tokens will
# be correct
output_assisted = model.generate(
input_ids,
attention_mask=attention_mask,
max_length=max_length,
num_beams=1,
do_sample=False,
assistant_model=model,
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
self.skipTest("This model doesn't support caching")

try:
self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist())
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
# Sets assisted generation arguments such that:
# a) no EOS is generated, to ensure generation doesn't break early
# b) the assistant model always generates two tokens when it is called, to ensure the input preparation of
# the assistant model is correct
# c) there are at least two forward passes in the main model, to ensure the input preparation of
# the main model is correct
generation_kwargs = {
"eos_token_id": -1, # see a)
"max_new_tokens": 4, # see c)
"num_beams": 1,
"do_sample": False,
"output_scores": True,
"output_hidden_states": True,
"output_attentions": True,
"return_dict_in_generate": True,
}
output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)

for output in (output_greedy, output_assisted):
self._check_outputs(output, input_ids, model.config, use_cache=True)
except AssertionError:
failed += 1
if failed > 1:
self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist())
assistant_model = model
assistant_model.generation_config.num_assistant_tokens = 2 # see b)
assistant_model.generation_config.num_assistant_tokens_schedule = "constant" # see b)
generation_kwargs.update({"assistant_model": assistant_model})
output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)

for output in (output_greedy, output_assisted):
self._check_outputs(output, input_ids, model.config, use_cache=True)
# The two outputs must match and their shape must be as expected
self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist())
for output in (output_greedy, output_assisted):
self._check_outputs(output, input_ids, model.config, use_cache=True)

@unittest.skip("Failing for a lot of models du to attention mask size missmatch. Works well when standalone.")
def test_assisted_decoding_sample(self):
# Seeded assisted decoding will not match sample for the same seed, as the forward pass does not return the
# exact same logits (the forward pass of the main model, now with several tokens at once, has causal masking).
# In this test we don't check assisted vs non-assisted output -- seeded assisted decoding with sample will not
# match sample for the same seed, as the forward pass does not return the exact same logits (due to matmul with
# different shapes, see https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535).
for model_class in self.all_generative_model_classes:
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
self.skipTest("Won't fix: old model with different cache format")
if any(
model_name in model_class.__name__.lower()
for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet", "seamlessm4t"]
for model_name in [
"bigbirdpegasus",
"led",
"mega",
"speech2text",
"git",
"prophetnet",
"seamlessm4t",
"clvp",
]
):
self.skipTest("May fix in the future: need model-specific fixes")

# enable cache
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)
config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1)

# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
Expand All @@ -1599,18 +1609,27 @@ def test_assisted_decoding_sample(self):
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
output_assisted = model.generate(
input_ids,
attention_mask=attention_mask,
max_length=max_length,
num_beams=1,
do_sample=True,
assistant_model=model, # triggers assisted decoding
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
# Sets assisted generation arguments such that:
# a) no EOS is generated, to ensure generation doesn't break early
# b) the assistant model always generates two tokens when it is called, to ensure the input preparation of
# the assistant model is correct
# c) there are at least two forward passes in the main model, to ensure the input preparation of
# the main model is correct
assistant_model = model
assistant_model.generation_config.num_assistant_tokens = 2 # see b)
assistant_model.generation_config.num_assistant_tokens_schedule = "constant" # see b)
generation_kwargs = {
"eos_token_id": -1, # see a)
"max_new_tokens": 4, # see c)
"num_beams": 1,
"do_sample": True,
"assistant_model": assistant_model,
"output_scores": True,
"output_hidden_states": True,
"output_attentions": True,
"return_dict_in_generate": True,
}
output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)

self._check_outputs(output_assisted, input_ids, model.config, use_cache=True)

Expand Down

0 comments on commit 12b50c6

Please sign in to comment.