Fix length related warnings in speculative decoding #29585

Merged: 16 commits into huggingface:main on Apr 10, 2024

Conversation

@zucchini-nlp (Member) commented on Mar 11, 2024

What does this PR do?

Currently, if we pass a min_length or min_new_tokens to speculative decoding, we get a bunch of warnings like:

UserWarning: Unfeasible length constraints: min_new_tokens (34), when added to the prompt length (66), is larger than the maximum possible length (75)....

This PR adds a min_new_tokens argument to the candidate generator's generate call, which defaults to 0 if no min_length is passed by the user.
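
For illustration, a minimal sketch of the kind of call that triggered the warning (models, prompt, and lengths are illustrative, not taken from the PR):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
assistant = AutoModelForCausalLM.from_pretrained("distilgpt2")  # shares gpt2's tokenizer

inputs = tokenizer("The quick brown fox jumps over", return_tensors="pt")

# Before this fix, the assistant inherited the user's length constraints
# verbatim; since it only proposes a handful of candidate tokens per call,
# the "Unfeasible length constraints" UserWarning fired on every step.
outputs = model.generate(
    **inputs,
    assistant_model=assistant,
    min_new_tokens=34,
    max_new_tokens=40,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```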

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@gante

fixes #29860


@gante (Member) left a comment

Thank you for fixing! 👍

@@ -1501,7 +1501,7 @@ def generate(
        )

        # 12. run assisted generate
-       result = self.assisted_decoding(
+       result = self._assisted_decoding(
Member

good catch, this was surely throwing a deprecation warning 👍

@gante requested a review from amyeroberts on March 11, 2024 at 15:51
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
@amyeroberts (Collaborator) left a comment

Thanks for working on this!

A few questions about the intended behaviour - in particular, why the config values are forcibly reset and an instance attribute is used instead.

Some tests to make sure min_new_tokens has the intended behaviour would also be good.

@@ -3252,6 +3252,28 @@ def test_default_max_length_warning(self):
        model.generate(input_ids)
        self.assertEqual(len(warning_list), 0)

+   def test_length_warning_assisted_generation(self):
Collaborator

There should also be a test that the min_new_tokens parameter behaves as expected, especially when max_new_tokens is also set and when it's not set at all, i.e. goes to its default value.

Member Author

Added one more test checking that the generated length falls between the min and max lengths.
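
A rough sketch of that check (models and bounds are illustrative, not the exact test added in this PR):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
assistant = AutoModelForCausalLM.from_pretrained("distilgpt2")

inputs = tokenizer("Hello world", return_tensors="pt")
input_length = inputs["input_ids"].shape[-1]

out = model.generate(
    **inputs, assistant_model=assistant, min_new_tokens=5, max_new_tokens=20
)

# The number of freshly generated tokens must respect both bounds.
new_tokens = out.shape[-1] - input_length
assert 5 <= new_tokens <= 20
```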

@@ -157,6 +157,13 @@ def __init__(
        self.generation_config.return_dict_in_generate = True
        self.generation_config.output_scores = True

+       # avoid unnecessary warnings that min_length is larger than max_new_tokens
+       input_length = input_ids.shape[-1]
+       min_new_tokens = self.generation_config.min_new_tokens if self.generation_config.min_new_tokens else 0
Collaborator

Are we guaranteed that self.generation_config has this attribute?

Member Author

Yes, the GenerationConfig usually initializes those to 0 or None if not indicated by the user. So we check that min_new_tokens is not None, and then set min_length to the maximum (the code says min, I'll fix it) of the user-defined value and the default 0.

Member

@amyeroberts AFAIK only Whisper (and perhaps other audio models?) uses attributes that may not exist in a generation_config; it is a fairly regular object with everything initialized in __init__ :D

Collaborator

OK, and can we ever expect to have Whisper generation configs used here, or does the model always just use its own custom generation code?

+       min_new_tokens = self.generation_config.min_new_tokens if self.generation_config.min_new_tokens else 0
Collaborator

Is this checking for None-ness? Or can it be False? Otherwise, defaulting to 0 if it's 0 is a superfluous check.

Member Author

It's checking for None-ness; I'll make it explicit with is not None.

@@ -175,6 +182,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
        # Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
        new_cur_len = input_ids.shape[-1]
        max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
+       min_new_tokens = min(max_new_tokens, self.min_length - new_cur_len)
Collaborator

Won't this result in negative values?

If I've generated more tokens than self.min_length, i.e. new_cur_len > self.min_length, then self.min_length - new_cur_len is negative.

Member Author

Good point, I'd missed it.
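
For concreteness, a sketch of the clamping that addresses this; the variable names follow the diff above, but the exact committed line may differ:

```python
def candidate_min_new_tokens(min_length: int, max_new_tokens: int, new_cur_len: int) -> int:
    # Clamp at zero: once generation is already past `min_length`,
    # the assistant has no minimum left to satisfy.
    return max(min(max_new_tokens, min_length - new_cur_len), 0)

# Already past the minimum: no constraint, rather than a negative value.
assert candidate_min_new_tokens(min_length=20, max_new_tokens=5, new_cur_len=25) == 0
# Two tokens still needed to reach the minimum.
assert candidate_min_new_tokens(min_length=20, max_new_tokens=5, new_cur_len=18) == 2
```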

+       min_new_tokens = min(max_new_tokens, self.min_length - new_cur_len)
Collaborator

Why use a class attribute and not the generation config, as is done for max_length?

@zucchini-nlp (Member Author) commented on Mar 11, 2024

Oh, this is because for maximum length we deprecated generation_config.max_new_tokens, so we can use the only possible attribute for max length. Yet for minimum length we have two attributes, both of which are equally valid. That's why in __init__ we manually set min_length by checking both attributes, if they are set by the user.

EDIT: I just remembered why I did it that way. We have to set generation_config's min_length to 0 in __init__; that's required to avoid unnecessary warnings. That's why I saved it as a class attribute. Otherwise, generation would receive kwargs like the following and throw warnings:
{"min_length": 20, "min_new_tokens": 5, "max_new_tokens": 5}

@gante, btw, don't you think we can also deprecate min_new_tokens?

Member

@zucchini-nlp it's the other way around; if anything, we would want to deprecate the min_length argument/config option :) max_new_tokens and min_new_tokens are much more predictable from a user's point of view, as the user doesn't need to be concerned with the input length. In the past, before max_new_tokens and min_new_tokens were introduced, we would often get issues from confused users.

Inside generate, however, it is much easier to use the total length for control (no need to track the input length). We set max_length from max_new_tokens when needed (here), perhaps we should do the same with min_length to simplify the procedure in this PR :)
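
A sketch of the conversion being described, done once near the top of generate (illustrative, not the exact library code):

```python
def resolve_length_controls(generation_config, input_length: int) -> None:
    """Translate user-facing `*_new_tokens` controls into total-length controls,
    so the rest of `generate` only has to track the total length."""
    if generation_config.max_new_tokens is not None:
        generation_config.max_length = generation_config.max_new_tokens + input_length
    if generation_config.min_new_tokens is not None:
        generation_config.min_length = generation_config.min_new_tokens + input_length
```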

@zucchini-nlp (Member Author) commented on Mar 13, 2024

I decided to make a separate method for all length-related corrections. Added one more test for min length, the same as we have for max length.
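
Roughly, such a consolidated helper on the candidate generator might look like this (name and exact logic are illustrative, not copied from the merged commit):

```python
def _calculate_candidate_lengths(self, input_ids):
    """Per-step candidate length bounds derived from the generation config."""
    cur_len = input_ids.shape[-1]
    # Leave room for the one extra token the target model appends.
    max_new_tokens = min(
        int(self.num_assistant_tokens), self.generation_config.max_length - cur_len - 1
    )
    # Clamped at zero so an already-satisfied minimum never goes negative.
    min_new_tokens = max(
        min(max_new_tokens, self.generation_config.min_length - cur_len), 0
    )
    return min_new_tokens, max_new_tokens
```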

@gante (Member) left a comment

The recent changes look good to me. Added a nit :)

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
@amyeroberts (Collaborator)

@zucchini-nlp It looks like a lot of tests are failing at the moment because of the missing min_length attribute.

@zucchini-nlp (Member Author)

Oops, I accepted the last suggestion and did not fix the naming in other places in the code. Now it should work; at least locally, the tests in "generation" were passing.

@zucchini-nlp (Member Author)

@amyeroberts this one is ready to re-review 😃

@amyeroberts (Collaborator) left a comment

Thanks for all the work iterating on this and improving our warnings!

@zucchini-nlp (Member Author)

@amyeroberts can you merge, please? The failing TF tests seem to be unrelated.

@amyeroberts (Collaborator)

As the failing tests are for generation, there's a very small chance of some interaction between the changes here and those tests (mainly because the test implementation isn't TF-specific).

They should now be resolved on main and will go away with a quick rebase :)

@zucchini-nlp merged commit 4157976 into huggingface:main on Apr 10, 2024
21 checks passed
Successfully merging this pull request may close these issues:

assisted_decoding called directly inside generate triggering warning to use when it shouldn't