
🚨🚨 Fix beam score calculation issue for decoder-only models #27351

Merged · gante merged 4 commits into huggingface:main on Nov 15, 2023

Conversation

@VsonicV (Contributor) commented Nov 7, 2023

What does this PR do?

This PR fixes issue #26624. In the original implementation of beam search, the beam score for decoder-only models is normalized by the total length of the prompt plus the generated sequence. However, the prompt length should not be included in the normalization step; including it biases beam search toward generating shorter sequences.

This is a simple, quick fix: an optional parameter decoder_prompt_len, which stores the length of the decoder prompt, is added to BeamSearchScorer.process(), BeamSearchScorer.finalize() and BeamHypotheses.add(). Since the new parameter is optional with a default value of 0, any existing calls to these functions that do not specify decoder_prompt_len behave exactly as before, avoiding any unexpected incompatibility. The corner case in which the very first generated token happens to be eos_token (an empty generation) is considered and handled.
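In spirit, the corrected scoring looks like the following simplified sketch (illustrative only; the actual patch threads decoder_prompt_len through the BeamSearchScorer methods rather than using a standalone helper):

```python
# Simplified sketch of the corrected length normalization (not the merged code).
def beam_score(sum_logprobs: float, seq_len: int,
               decoder_prompt_len: int = 0, length_penalty: float = 1.0) -> float:
    # Normalize by the generated length only, not prompt + generation.
    # The max(..., 1) guard illustrates the empty-generation corner case
    # (the first generated token is eos_token); the PR handles it explicitly.
    generated_len = max(seq_len - decoder_prompt_len, 1)
    return sum_logprobs / (generated_len ** length_penalty)
```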

Fixes #26624

Note: There are three follow-up PRs that complement this fix:

  1. Fix remaining issues in beam score calculation #27808 further fixes some remaining issues in the PyTorch version.
  2. Fix beam score calculation issue for Tensorflow version #27814 fixes the TensorFlow version.
  3. Fix beam score calculation issue for JAX version #27816 fixes the JAX version.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@gante

@VsonicV (Contributor, Author) commented Nov 7, 2023

@gante This commit is only for fixing beam_search. If you think the fix is good to go, I can also apply the same fix to beam_sample, group_beam_search and constrained_beam_search.

@VsonicV (Contributor, Author) commented Nov 7, 2023

@gante I think the current tests for beam_search use results generated by the previous "buggy" version, so the new beam_search cannot pass the test test_eos_token_id_int_and_list_beam_search, which uses the decoder-only GPT-2. We need to update the relevant tests as well.
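A toy calculation (invented numbers) shows how normalizing by the full length biased selection toward shorter candidates, which is why the expected outputs of the old tests change:

```python
# Prompt of 10 tokens, length_penalty = 1.0; scores are negative,
# so "closer to zero" wins.
# Candidate A: 3 generated tokens, total logprob -3.0
# Candidate B: 8 generated tokens, total logprob -6.0
old_a, old_b = -3.0 / (10 + 3), -6.0 / (10 + 8)  # -0.231 vs -0.333 -> A (shorter) wins
new_a, new_b = -3.0 / 3, -6.0 / 8                # -1.000 vs -0.750 -> B (longer) wins
```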

@gante (Member) left a comment

LGTM, thank you for jumping in to fix it!

Regarding the GPT-2 tests: agreed, they should be updated.

@gante (Member) commented Nov 7, 2023

@VsonicV the "setup and quality" CI can be fixed by running make fixup on your local transformers root folder and committing the changes!

@VsonicV (Contributor, Author) commented Nov 8, 2023

@gante Thanks for the suggestion! I have fixed the code quality issue using make fixup and updated the relevant test test_eos_token_id_int_and_list_beam_search with the new expected value. Both checks pass now. However, there is still one check failure caused by test_run_image_classification_no_trainer and test_run_ner_no_trainer, which should be unrelated to these beam search commits. Do you have any clue how to fix it?

@gante (Member) commented Nov 8, 2023

@VsonicV perfect! The failing CI is indeed unrelated (fix: #27353); the tests should pass after it gets merged.

To keep this fix consistent across the beam methods, I'd like to ask you to:

  1. Also apply this change to the other beam methods in PyTorch :)
  2. Add 🚨🚨 to the PR title, as this is a (correct, but also) breaking change
  3. (optional, only if you're comfortable with it, as the fix is slightly different) Apply this change to the beam methods in TF and JAX

After 1 and 2 are done, I'll tag a core maintainer to greenlight the merge!

@VsonicV (Contributor, Author) commented Nov 8, 2023

@gante Sure! I will work on 1 and 2 in the next two days, and will try to do 3 after that.

VsonicV changed the title from "Fix beam score calculation issue for decoder-only models" to "🚨🚨Fix beam score calculation issue for decoder-only models" on Nov 8, 2023
VsonicV changed the title from "🚨🚨Fix beam score calculation issue for decoder-only models" to "🚨🚨 Fix beam score calculation issue for decoder-only models" on Nov 8, 2023
@gante (Member) commented Nov 8, 2023

The PR causing the CI to fail was merged, and I was informed that current PRs will need to be rebased to pass CI 🤗

@VsonicV (Contributor, Author) commented Nov 9, 2023

@gante Items 1 and 2 are done! I have applied the fix to all beam-related methods: beam_sample, group_beam_search and constrained_beam_search. I have rebased the PR and all relevant tests have passed.

Regarding the remaining check failures: the recent merge only fixes the failure caused by test_run_image_classification_no_trainer, not the one caused by test_run_ner_no_trainer. According to the error message AssertionError: 0.5109489440917969 not less than 0.5, the threshold in self.assertLess(result["train_loss"], 0.5) in test_run_ner_no_trainer needs to be adjusted as well. Moreover, there is one new check failure caused by test_cached_model_has_minimum_calls_to_head and test_cached_tokenizer_has_minimum_calls_to_head, which is unrelated to the commits in this PR (it only appeared after the most recent rebase).

@gante (Member) commented Nov 9, 2023

@VsonicV yes, we are still having some CI failures (unrelated to this PR) 😭

@VsonicV (Contributor, Author) commented Nov 12, 2023

@gante I tried rebasing once more; all the previous check failures are gone, but there is one new CI failure caused by test_assisted_decoding_sample, which should again be unrelated to this PR.

@VsonicV (Contributor, Author) commented Nov 15, 2023

@ArthurZucker I have rebased this PR on top of all your recently added test skips, but the CI failures caused by test_assisted_decoding_sample still persist for blenderbot; the same failures also happened for pegasus and umt5 in my previous tries. Would you mind adding skips of test_assisted_decoding_sample for blenderbot, pegasus and umt5 as well? Thank you!

@ArthurZucker (Collaborator) commented Nov 15, 2023

Yeah, I'll skip this test for everyone; this is getting annoying! 😅 #27511 was merged

@@ -633,7 +633,7 @@ def test_eos_token_id_int_and_list_beam_search(self):
     "do_sample": False,
     "num_beams": 3,
 }
-expectation = 13
+expectation = 20
@gante (Member) commented on the diff:

Suggested change:
-expectation = 20
+if is_pt:
+    expectation = 20
+else:
+    # TODO (joao): fix me
+    expectation = 13

This test will likely fail on TF, since we haven't applied this upgrade there. Let's add a TODO for now

@VsonicV (Contributor, Author):

Good catch

gante requested a review from amyeroberts on November 15, 2023, 11:28
@gante (Member) commented Nov 15, 2023

Tagging @amyeroberts for a final check

@amyeroberts (Collaborator) left a comment

Thanks for adding!

Comment on lines 227 to 229:

cur_len = (
    input_ids.shape[-1] - decoder_prompt_len + 1
)  # add up to the length which the next_scores is calculated on
@amyeroberts (Collaborator):

nit

Suggested change:
-cur_len = (
-    input_ids.shape[-1] - decoder_prompt_len + 1
-)  # add up to the length which the next_scores is calculated on
+# add up to the length which the next_scores is calculated on
+cur_len = input_ids.shape[-1] - decoder_prompt_len + 1

Comment on lines 563 to 565:

cur_len = (
    input_ids.shape[-1] - decoder_prompt_len + 1
)  # add up to the length which the next_scores is calculated on
@amyeroberts (Collaborator):

nit

Suggested change:
-cur_len = (
-    input_ids.shape[-1] - decoder_prompt_len + 1
-)  # add up to the length which the next_scores is calculated on
+# add up to the length which the next_scores is calculated on
+cur_len = input_ids.shape[-1] - decoder_prompt_len + 1

@@ -511,6 +520,7 @@ def process(
     pad_token_id: Optional[int] = None,
     eos_token_id: Optional[Union[int, List[int]]] = None,
     beam_indices: Optional[torch.LongTensor] = None,
+    decoder_prompt_len: Optional[int] = 0,
@amyeroberts (Collaborator):

Can you add this arg to the docstring below?

@VsonicV (Contributor, Author):

good catch, will do
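For reference, the added docstring entry could look roughly like this, following the library's usual Args format (wording is a sketch, not the merged text):

```python
decoder_prompt_len (`int`, *optional*, defaults to 0):
    The length of the prompt that is part of the decoder input. Used to
    normalize the beam score by the number of generated tokens only.
```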

@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@VsonicV (Contributor, Author) commented Nov 15, 2023

@gante @amyeroberts All your suggested changes have been added and committed, and all the tests now pass (finally!). This should be ready to merge.

gante merged commit 453079c into huggingface:main on Nov 15, 2023 · 2 checks passed
@gante (Member) commented Nov 15, 2023

@VsonicV Thank you for iterating with us and making transformers better 💛

And sorry for all the failing CI; you've caught an unfortunate series of failures 😬

@VsonicV (Contributor, Author) commented Nov 15, 2023

@gante No problem! Regarding the fix for the TF and JAX versions, I have looked at the relevant code briefly and I think I can fix them. I will try to submit another PR fixing both TF and JAX later this week.

VsonicV deleted the fix_beam_score branch on November 15, 2023, 13:06
EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request on Nov 19, 2023:

… (huggingface#27351)

* Fix beam score calculation issue for decoder-only models
* Update beam search test and fix code quality issue
* Fix beam_sample, group_beam_search and constrained_beam_search
* Split test for pytorch and TF, add documentation

Co-authored-by: Xin Qiu <xin.qiu@sentient.ai>
@VsonicV (Contributor, Author) commented Dec 1, 2023

@gante Sorry about the delay in the next steps. I had a severe flu last week and just recovered. I will start working on the remaining fixes.

@VsonicV (Contributor, Author) commented Dec 4, 2023

@gante @amyeroberts All follow-up tasks have been completed in three new PRs:

  1. Fix remaining issues in beam score calculation #27808 further fixes some remaining issues in the PyTorch version.
  2. Fix beam score calculation issue for Tensorflow version #27814 fixes the TensorFlow version.
  3. Fix beam score calculation issue for JAX version #27816 fixes the JAX version.

All three PRs have passed the CI checks. Ready for your review, @gante.

@VsonicV (Contributor, Author) commented Dec 14, 2023

@gante Hi, I noticed that in the recent release notes for v4.36.0, only this PR is listed in the "Beam score calculation for decoder-only models" section under "Breaking changes". Should we also add the three follow-up PRs (#27808, #27814, #27816) under that section? That would make it clearer for people to see all the changes relevant to this breaking change. Thanks.

@gante (Member) commented Jan 9, 2024

@VsonicV updated the release notes for future reference 👍 Thank you for your suggestion

Linked issue: Beam search calculates mean logprobs wrong? (#26624)