Fix test_finetune_bert2bert #25984

Merged
1 commit merged into main on Sep 13, 2023
Conversation

@ydshieh (Collaborator) commented on Sep 5, 2023

What does this PR do?

Fixes the CI error below. See the inline comment on the changed lines in this PR for details.

tests/trainer/test_trainer_seq2seq.py::Seq2seqTrainerTester::test_finetune_bert2bert
(line 162)  ValueError: Make sure to set the pad_token_id attribute of the model's configuration.
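
For context, the error originates in the decoder-input preparation of the encoder-decoder model: when `labels` are passed without `decoder_input_ids`, the model builds the latter by shifting the former, which requires a padding token. Below is a minimal sketch of that failure path, loosely adapted from `shift_tokens_right` in transformers' encoder-decoder modeling code (simplified; the real function has additional checks):

```python
import torch

def shift_tokens_right(input_ids: torch.Tensor, pad_token_id, decoder_start_token_id):
    # Build `decoder_input_ids` by shifting the labels one position to the right.
    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[:, 1:] = input_ids[:, :-1].clone()
    shifted[:, 0] = decoder_start_token_id
    if pad_token_id is None:
        # The exception seen in the failing test.
        raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
    # Replace label-ignore positions (-100) with the padding token.
    shifted.masked_fill_(shifted == -100, pad_token_id)
    return shifted

labels = torch.tensor([[101, 2023, 2003, -100]])
shift_tokens_right(labels, pad_token_id=None, decoder_start_token_id=102)  # raises the ValueError above
```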

Comment on lines +274 to +283
+        generation_inputs = inputs.copy()
         # If the `decoder_input_ids` was created from `labels`, evict the former, so that the model can freely generate
         # (otherwise, it would continue generating from the padded `decoder_input_ids`)
         if (
-            "labels" in inputs
-            and "decoder_input_ids" in inputs
-            and inputs["labels"].shape == inputs["decoder_input_ids"].shape
+            "labels" in generation_inputs
+            and "decoder_input_ids" in generation_inputs
+            and generation_inputs["labels"].shape == generation_inputs["decoder_input_ids"].shape
         ):
-            inputs = {k: v for k, v in inputs.items() if k != "decoder_input_ids"}
-        generated_tokens = self.model.generate(**inputs, **gen_kwargs)
+            generation_inputs = {k: v for k, v in inputs.items() if k != "decoder_input_ids"}
+        generated_tokens = self.model.generate(**generation_inputs, **gen_kwargs)
@ydshieh (Collaborator, Author)

The test started failing with d979cf6e (months ago 😅).
Previously, `inputs` was modified here and then reused in the next block, which computes the training loss. Since `labels` is still in `inputs` while `decoder_input_ids` has been evicted, `EncoderDecoderModel` tries to recreate `decoder_input_ids` from `labels`, and this fails because no padding token is set.

This PR uses `generation_inputs` (without `decoder_input_ids`) for generation, but keeps `inputs` untouched for the training pass.
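
To make the before/after concrete, here is a hedged sketch of the resulting control flow in `Seq2SeqTrainer.prediction_step` (heavily simplified: the real method also handles generation defaults, loss computation under `no_grad`, and token padding, all elided here):

```python
def prediction_step_sketch(model, inputs, gen_kwargs):
    # Work on a copy so the original `inputs` stays intact for the loss pass below.
    generation_inputs = inputs.copy()
    if (
        "labels" in generation_inputs
        and "decoder_input_ids" in generation_inputs
        and generation_inputs["labels"].shape == generation_inputs["decoder_input_ids"].shape
    ):
        # Evict `decoder_input_ids` from the generation inputs only.
        generation_inputs = {k: v for k, v in generation_inputs.items() if k != "decoder_input_ids"}
    generated_tokens = model.generate(**generation_inputs, **gen_kwargs)

    # Loss pass: `inputs` still carries `decoder_input_ids`, so EncoderDecoderModel
    # never has to recreate them from `labels` (which would require `pad_token_id`).
    outputs = model(**inputs)
    return outputs.loss, generated_tokens
```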

@ydshieh requested a review from gante on September 5, 2023 at 09:32
@HuggingFaceDocBuilderDev commented Sep 5, 2023

The documentation is not available anymore as the PR was closed or merged.

@gante (Member) left a comment

Thank you for the fix @ydshieh 🙏

@gante merged commit 95a9041 into main on Sep 13, 2023
@gante deleted the fix_seq_2_seq_ci branch on September 13, 2023 at 15:53
parambharat pushed a commit to parambharat/transformers that referenced this pull request Sep 26, 2023
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>