Generate: improve assisted generation tests #27540
Conversation
```diff
@@ -1524,62 +1529,49 @@ def test_assisted_decoding_matches_greedy_search(self):
        ):
            self.skipTest("May fix in the future: need model-specific fixes")

        # This for loop is a naive and temporary effort to make the test less flaky.
        failed = 0
        for i in range(10):
```
This was essentially the same as `@is_flaky`, but (IMO) less elegant.

Now that we understand the cause for the mismatch (matmul with different shapes), and know that there is no workaround, it is safe to confirm that this test is indeed flaky :)
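For context, a minimal sketch contrasting the two patterns being compared here: the manual retry loop that was removed versus the `is_flaky` decorator from `transformers.testing_utils`. The test bodies and the helper method are illustrative stand-ins, not the actual test code:

```python
# Sketch only: contrasts a manual retry loop with the @is_flaky decorator.
# The test bodies and `_outputs_match` helper are illustrative, not real test code.
import unittest

from transformers.testing_utils import is_flaky


class AssistedDecodingTests(unittest.TestCase):
    def _outputs_match(self) -> bool:
        """Hypothetical helper standing in for the greedy-vs-assisted comparison."""
        return True

    # Before: a manual retry loop that tolerates occasional mismatches.
    def test_with_manual_retry(self):
        failed = 0
        for _ in range(10):
            if not self._outputs_match():
                failed += 1
        self.assertLessEqual(failed, 1)

    # After: the same intent expressed with @is_flaky, which re-runs the test
    # a few times before reporting a failure.
    @is_flaky()
    def test_with_is_flaky(self):
        self.assertTrue(self._outputs_match())
```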
tests/generation/test_utils.py
```diff
@@ -1520,66 +1525,53 @@ def test_assisted_decoding_matches_greedy_search(self):
            self.skipTest("Won't fix: old model with different cache format")
        if any(
            model_name in model_class.__name__.lower()
-           for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"]
+           for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet", "seamlessm4t"]
```
Note: `seamlessm4t` was already in the skip list of `test_assisted_decoding_sample`, probably for the same post mortem reasons.
Thanks for adding!
Great comments to provide context in the tests 🙏 Only comment is about having `config.is_decoder` set for all these tests. Is the case when `config.is_encoder_decoder` fully covered?
```diff
@@ -1599,18 +1609,27 @@ def test_assisted_decoding_sample(self):
        config.use_cache = True
        config.is_decoder = True
```
Can we also have a test for when `config.is_encoder_decoder` is set, to make sure any relevant logic is handled there?
@amyeroberts It is also not mutually exclusive with `config.is_encoder_decoder`. All tests that require caching, such as the assisted generation ones, have to set `config.is_decoder`.
@gante Thanks for explaining! I thought they were mutually exclusive.
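As a side note, a minimal sketch of the flag combination discussed in this thread. The choice of `T5Config` and the exact values are illustrative (based on the diff above), not a prescribed setup:

```python
# Sketch only: the flags below mirror the diff above; the model choice is illustrative.
from transformers import T5Config

config = T5Config()                # an encoder-decoder architecture
config.use_cache = True            # assisted generation relies on the cache
config.is_decoder = True           # not mutually exclusive with is_encoder_decoder;
                                   # tests that require caching set it as well

# T5Config already has is_encoder_decoder=True by default, so both flags coexist.
print(config.is_encoder_decoder, config.is_decoder, config.use_cache)
```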
Thank you 🙏
What does this PR do?
Strengthens the test suite for assisted generation. With these modifications, API problems like the ones previously found will be caught in advance.
Post mortem
Why weren't API problems caught before?
Assisted generation has two loops: the loop to obtain the candidate tokens from the assistant model (inner loop), and the loop to generate the final tokens from the main model (outer loop). Both loops are slightly different depending on whether the main model accepts the matches or not -- there are different code paths depending on whether `n_matches > 0` or not (a simplified sketch of this two-loop structure is included at the end of this description).

The following cases were being tested and had no API issues:

- `n_matches == 0`
- `n_matches > 0`, but we only run 1 iteration of the outer loop

👉 We weren't explicitly testing the case where `n_matches > 0` AND we ran more than 1 outer loop iteration.

If we weren't testing that case, why was the CI randomly red?
Each individual test had a ~97% chance of being green. The (random) assistant model was building the candidate sequence from the most likely tokens from its vocabulary (size = 99), and the main model was comparing the candidate sequence against sampling from its logits. Most of the time, `n_matches == 0`, so the test passed. However, sometimes we had `n_matches > 0`, but not enough to complete assisted generation in a single outer loop iteration.

👉 There was a low chance (per test) of hitting the failing case, resulting in inconsistent CI failures.
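For readers less familiar with the code paths mentioned above, here is a heavily simplified, self-contained sketch of the two-loop structure of assisted generation. The toy models, helper names, and numbers are illustrative; this is not the actual `generate` implementation:

```python
# Toy sketch of assisted generation's two loops; not the real transformers code.
import random


def toy_main_model(prefix):
    """Stand-in for the main model's greedy next-token choice."""
    random.seed(sum(prefix))        # deterministic per prefix, for reproducibility
    return random.randrange(99)     # vocabulary of size 99, as in the tests


def toy_assistant_model(prefix):
    """Stand-in for the smaller, faster assistant model."""
    random.seed(sum(prefix) + 1)
    return random.randrange(99)


def assisted_generation(prompt, max_new_tokens=20, num_candidates=5):
    sequence = list(prompt)
    new_tokens = 0
    while new_tokens < max_new_tokens:                      # outer loop (main model)
        # Inner loop: the assistant proposes a short candidate continuation.
        candidates = []
        for _ in range(num_candidates):
            candidates.append(toy_assistant_model(sequence + candidates))

        # The main model checks the candidates and accepts the longest matching prefix.
        n_matches = 0
        for token in candidates:
            if toy_main_model(sequence + candidates[:n_matches]) == token:
                n_matches += 1
            else:
                break

        if n_matches > 0:
            # Code path 1: accept the matched candidate tokens.
            accepted = candidates[:n_matches]
        else:
            # Code path 2: fall back to a single token from the main model.
            accepted = [toy_main_model(sequence)]

        sequence.extend(accepted)
        new_tokens += len(accepted)
    return sequence


print(assisted_generation([1, 2, 3]))
```

With random toy models, most outer loop iterations land in the `n_matches == 0` path, which is exactly why the untested `n_matches > 0` multi-iteration case was rarely exercised and the CI only failed intermittently.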