
Generate: improve assisted generation tests #27540

Merged: 3 commits into huggingface:main on Nov 16, 2023

Conversation

@gante (Member) commented on Nov 16, 2023:

What does this PR do?

Strengthens the test suite for assisted generation. With these modifications, API problems like the ones found previously will be caught in advance.

Post mortem

Why weren't API problems caught before?

Assisted generation has two loops: an inner loop that obtains candidate tokens from the assistant model, and an outer loop that generates the final tokens with the main model. Both loops take slightly different code paths depending on whether the main model accepts any of the candidates, i.e. whether n_matches > 0 or not.
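For intuition, here is a minimal, hedged sketch of that control flow in PyTorch. All names are illustrative (the real implementation lives in transformers' GenerationMixin and differs in detail), and batch size 1 is assumed:

```python
import torch

def assisted_generate_sketch(model, assistant_model, input_ids, max_length, candidate_length=5):
    # Illustrative sketch only; assumes batch size 1 and greedy decoding.
    while input_ids.shape[-1] < max_length:  # outer loop: main model
        # Inner loop: the assistant greedily proposes `candidate_length` tokens.
        candidate_ids = input_ids
        for _ in range(candidate_length):
            logits = assistant_model(candidate_ids).logits[:, -1, :]
            next_token = logits.argmax(dim=-1, keepdim=True)
            candidate_ids = torch.cat([candidate_ids, next_token], dim=-1)

        # Single forward pass of the main model over the full candidate sequence.
        prompt_len = input_ids.shape[-1]
        new_tokens = candidate_ids[:, prompt_len:]
        main_logits = model(candidate_ids).logits[:, prompt_len - 1 : -1, :]
        main_choice = main_logits.argmax(dim=-1)

        # Longest accepted prefix; n_matches == 0 vs > 0 are distinct code paths.
        n_matches = int((main_choice == new_tokens).long().cumprod(dim=-1).sum())

        # Keep the matched prefix, then one token chosen by the main model.
        input_ids = torch.cat(
            [input_ids, new_tokens[:, :n_matches], main_choice[:, n_matches : n_matches + 1]],
            dim=-1,
        )
    return input_ids
```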

The following cases were being tested and had no API issues:

  1. n_matches == 0
  2. n_matches > 0, but we only run 1 iteration of the outer loop

👉 We weren't explicitly testing the case where n_matches > 0 AND we ran more than 1 outer loop iteration.
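One way to exercise that missing path (a sketch of the idea, not necessarily the exact test code in this PR): use the main model as its own assistant so the candidates always match, and request more new tokens than one candidate window covers.

```python
# Hedged sketch: forcing n_matches > 0 across several outer-loop iterations.
# `model` is assumed to be a Hugging Face causal LM.
output = model.generate(
    input_ids,
    assistant_model=model,  # candidates always match -> n_matches > 0
    max_new_tokens=20,      # >> one candidate window -> several outer iterations
    do_sample=False,
)
```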

If we weren't testing that case, why was the CI randomly red?

Each individual test had a ~97% chance of being green. The (random) assistant model built its candidate sequence from the most likely tokens in its vocabulary (size = 99), while the main model compared the candidate sequence against samples from its own logits. Most of the time, n_matches == 0, so the test passed. Occasionally, however, we had n_matches > 0, yet not enough matches to complete assisted generation in a single outer loop iteration.

👉 There was a low chance (per test) of hitting the failing case, resulting in inconsistent CI failures
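For intuition on why this surfaces as intermittent CI failures, here are illustrative numbers (not measured CI data): a ~97% per-test pass rate compounds quickly across many model tests.

```python
# Illustrative only: probability that n independent ~97%-green tests all pass.
p_green = 0.97
for n in (1, 10, 50):
    print(f"{n:>3} tests all green: {p_green ** n:.1%}")
# 1 -> 97.0%, 10 -> ~73.7%, 50 -> ~21.8%
```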

@gante requested a review from amyeroberts on November 16, 2023 at 14:52
```diff
@@ -1524,62 +1529,49 @@ def test_assisted_decoding_matches_greedy_search(self):
         ):
             self.skipTest("May fix in the future: need model-specific fixes")

-        # This for loop is a naive and temporary effort to make the test less flaky.
-        failed = 0
-        for i in range(10):
```
@gante (Member Author) commented on Nov 16, 2023:

This was essentially the same as @is_flaky, but (IMO) less elegant.

Now that we understand the cause of the mismatch (matmul with different shapes) and know that there is no workaround, it is safe to confirm that this test is indeed flaky :)
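For reference, transformers ships a retry decorator in testing_utils; applying it looks roughly like this (exact arguments may differ across versions):

```python
from transformers.testing_utils import is_flaky


class GenerationTesterMixin:
    @is_flaky()  # reruns the test a few times before reporting a failure
    def test_assisted_decoding_matches_greedy_search(self):
        ...
```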

```diff
@@ -1520,66 +1525,53 @@ def test_assisted_decoding_matches_greedy_search(self):
             self.skipTest("Won't fix: old model with different cache format")
         if any(
             model_name in model_class.__name__.lower()
-            for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"]
+            for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet", "seamlessm4t"]
```
@gante (Member Author) commented:

Note: seamlessm4t was already in the skip list of test_assisted_decoding_sample, probably for the same post mortem reasons

@amyeroberts (Collaborator) left a comment:

Thanks for adding!

Great comments to provide context in the tests 🙏 My only comment is about having config.is_decoder set for all these tests. Is the config.is_encoder_decoder case fully covered?

```diff
@@ -1599,18 +1609,27 @@ def test_assisted_decoding_sample(self):
         config.use_cache = True
         config.is_decoder = True
```
@amyeroberts (Collaborator) commented:

Can we also have a test for when config.is_encoder_decoder is set, to make sure any relevant logic is handled there?

@gante (Member Author) commented on Nov 16, 2023:

@amyeroberts is_decoder is a poorly named flag 😅 Unlike is_encoder_decoder, which controls many aspects of generation, is_decoder only controls one thing AFAIK: whether to enable use_cache (example) and pipe the cache around in encoders with an LM head.

It is also not mutually exclusive with is_encoder_decoder (it should be, IMO 👀).

All tests that require caching, such as the assisted generation ones, have to set model.config.is_decoder = True. Otherwise, the tests fail on encoders with LM heads (see the screenshot referenced below).
[Screenshot, 2023-11-16: test failure on an encoder with an LM head when is_decoder is not set]
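A minimal sketch of the setup pattern these tests use (helper names here are hypothetical, not the exact GenerationTesterMixin code):

```python
# Hedged sketch of the cache-related test setup described above.
# `prepare_config_and_inputs` and `model_class` are hypothetical stand-ins.
import torch

config, input_ids = prepare_config_and_inputs()  # hypothetical helper
config.use_cache = True   # assisted generation relies on the KV cache
config.is_decoder = True  # enables cache plumbing in encoders with an LM head

model = model_class(config).eval()
with torch.no_grad():
    output = model.generate(input_ids, assistant_model=model)
```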

@amyeroberts (Collaborator) commented:

@gante Thanks for explaining! I thought they were mutually exclusive

@amyeroberts (Collaborator) left a comment:

Thank you 🙏

@gante merged commit 12b50c6 into huggingface:main on Nov 16, 2023
3 checks passed
@gante deleted the assisted_harder_tests branch on November 16, 2023 at 18:54
EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 19, 2023