🚨🚨 Fix beam score calculation issue for decoder-only models #27351
Conversation
@gante This commit is only for fixing …
@gante I think the current tests regarding …
LGTM, thank you for jumping in to fix it!
Regarding the GPT-2 tests: agreed, they should be updated.
@VsonicV the "setup and quality" CI can be fixed by running …
(force-pushed 6d8bde6 → 2ca1e4e)
@gante Thanks for the suggestion! I have fixed the code quality issue using …
@VsonicV perfect! The failing CI is indeed unrelated (fix: #27353); the tests should pass after it gets merged. To keep this fix consistent across all beam methods, I'd like to ask you to: …
After 1. and 2. are done, I'll tag a core maintainer to greenlight the merge!
@gante Sure! I will work on 1 and 2 in the next 2 days. Will try to do 3 after that.
The PR causing the CI to fail was merged, and I was informed that current PRs will need to be rebased to pass CI 🤗
(force-pushed 2ca1e4e → c6e48a7)
@gante Items 1 and 2 are done! I have applied the fix to all beam-related methods (beam_sample, group_beam_search and constrained_beam_search). Regarding the remaining check failures, the recent merge only fixes the check failure caused by …
@VsonicV yes, we are still having some CI failures (unrelated to this PR) 😭
(force-pushed c6e48a7 → 464624c)
@gante Tried rebasing once more, all the previous check failures are gone, but got one new CI failure caused by …
(force-pushed cd126fe → 32c8293)
(force-pushed 32c8293 → d0b64ad)
@ArthurZucker I have rebased this PR with all your recently added test skips, the CI failures caused by …
Yeah, I'll skip this test for everyone, this is getting annoying! 😅 #27511 was merged.
(force-pushed d0b64ad → a4a6da9)
@@ -633,7 +633,7 @@ def test_eos_token_id_int_and_list_beam_search(self):
     "do_sample": False,
     "num_beams": 3,
 }
-expectation = 13
+expectation = 20
Suggested change:
-expectation = 20
+if is_pt:
+    expectation = 20
+else:
+    # TODO (joao): fix me
+    expectation = 13
This test will likely fail on TF, since we haven't applied this upgrade there. Let's add a TODO for now
Good catch
Tagging @amyeroberts for a final check
Thanks for adding!
cur_len = (
    input_ids.shape[-1] - decoder_prompt_len + 1
)  # add up to the length which the next_scores is calculated on
nit
Suggested change:
-cur_len = (
-    input_ids.shape[-1] - decoder_prompt_len + 1
-)  # add up to the length which the next_scores is calculated on
+# add up to the length which the next_scores is calculated on
+cur_len = input_ids.shape[-1] - decoder_prompt_len + 1
cur_len = (
    input_ids.shape[-1] - decoder_prompt_len + 1
)  # add up to the length which the next_scores is calculated on
nit
Suggested change:
-cur_len = (
-    input_ids.shape[-1] - decoder_prompt_len + 1
-)  # add up to the length which the next_scores is calculated on
+# add up to the length which the next_scores is calculated on
+cur_len = input_ids.shape[-1] - decoder_prompt_len + 1
@@ -511,6 +520,7 @@ def process(
     pad_token_id: Optional[int] = None,
     eos_token_id: Optional[Union[int, List[int]]] = None,
     beam_indices: Optional[torch.LongTensor] = None,
+    decoder_prompt_len: Optional[int] = 0,
Can you add this arg to the docstring below?
good catch, will do
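A plausible docstring entry for the new argument (the wording here is illustrative, following the style of the surrounding parameter docs; the exact text landed in the PR diff):

    decoder_prompt_len (`int`, *optional*, defaults to 0):
        # Assumption: phrased to match the neighboring parameter descriptions
        The length of the prompt that is included in the input to the decoder, so that beam
        scores can be normalized over generated tokens only.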
@gante @amyeroberts All your suggested changes have been added and committed, and all the tests pass now (finally!). This should be ready to merge.
@VsonicV Thank you for iterating with us and making … And sorry for all the failing CI; you've caught an unfortunate series of failures 😬
@gante No problem! Regarding the TF and JAX versions of the fix, I have looked at the relevant code briefly, and I think I can fix them. I will try to submit another PR fixing both TF and JAX later this week.
Fix beam score calculation issue for decoder-only models (huggingface#27351)
* Fix beam score calculation issue for decoder-only models
* Update beam search test and fix code quality issue
* Fix beam_sample, group_beam_search and constrained_beam_search
* Split test for pytorch and TF, add documentation
Co-authored-by: Xin Qiu <xin.qiu@sentient.ai>
@gante Sorry about the delay in the next steps. I had a severe flu last week and just recovered. Will start working on the remaining fixes.
@gante @amyeroberts All follow-up tasks have been completed, in three new PRs: #27808, #27814 and #27816.
All three PRs have passed the CI checks. Ready for your review, @gante.
@gante Hi, I noticed that in the release notes for v4.36.0, only this PR is listed in the "Beam score calculation for decoder-only models" section under "Breaking changes". Should we also add the three follow-up PRs (#27808, #27814, #27816) under that section? It would make it clearer for people to find all the changes relevant to this breaking change. Thanks.
@VsonicV updated the release notes for future reference 👍 Thank you for your suggestion
What does this PR do?
This PR fixes issue #26624. In the original implementation of beam search, the beam score for decoder-only models is normalized by the combined length of the prompt and the generated sequence. However, the prompt length should not be included in this normalization; including it biases generation towards shorter sequences.
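A minimal sketch of why this biases beam search towards shorter outputs (the numbers are illustrative, not from the PR, and assume length_penalty = 1.0):

    # Beam score = sum of token log-probs / (length ** length_penalty)
    prompt_len = 20
    short_logprobs, short_gen_len = -4.0, 2   # short continuation
    long_logprobs, long_gen_len = -8.0, 10    # longer continuation, better per token

    # Old normalization: prompt length included in the denominator
    old_short = short_logprobs / (prompt_len + short_gen_len)  # ~ -0.18
    old_long = long_logprobs / (prompt_len + long_gen_len)     # ~ -0.27 -> short one wins

    # Fixed normalization: generated length only
    new_short = short_logprobs / short_gen_len                 # -2.0
    new_long = long_logprobs / long_gen_len                    # -0.8  -> long one wins

The constant prompt length inflates both denominators, washing out the per-token length penalty and letting the shorter hypothesis win even when the longer one is better per generated token.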
This is a simple, quick fix: it adds an optional parameter `decoder_prompt_len`, which stores the length of the prompt fed to the decoder, to `BeamSearchScorer.process()`, `BeamSearchScorer.finalize()` and `BeamHypotheses.add()`. Since the new parameter is optional with a default value of 0, any existing call to these functions that does not specify `decoder_prompt_len` still behaves exactly as before, avoiding any unexpected incompatibility. The corner case in which the very first generated token happens to be the eos_token (empty generation) is considered and handled.

Fixes #26624
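A rough sketch of the core change inside `BeamHypotheses.add()` (simplified; the exact signature and surrounding bookkeeping are in the PR diff):

    def add(self, hyp, sum_logprobs, beam_indices=None, decoder_prompt_len=0):
        # Before: score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
        # After: subtract the prompt length so only generated tokens count
        # towards the length penalty
        score = sum_logprobs / ((hyp.shape[-1] - decoder_prompt_len) ** self.length_penalty)
        ...

With the default `decoder_prompt_len=0`, the expression reduces to the old behavior, which is what keeps existing callers unaffected.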
Note: there are three follow-up PRs that complement this fix: #27808, #27814 and #27816.
Who can review?
@gante