Fix length related warnings in speculative decoding #29585
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thank you for fixing! 👍
```diff
@@ -1501,7 +1501,7 @@ def generate(
             )

             # 12. run assisted generate
-            result = self.assisted_decoding(
+            result = self._assisted_decoding(
```
good catch, this was surely throwing a deprecation warning 👍
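For context, a minimal sketch of the deprecation pattern that produces such a warning when the public method is called directly. The class and message are illustrative, not the real transformers code:

```python
import warnings


class GenerationMixinSketch:
    """Illustrative only -- not the actual transformers GenerationMixin."""

    def assisted_decoding(self, *args, **kwargs):
        # Deprecated public entry point: warn, then delegate to the
        # private implementation.
        warnings.warn(
            "`assisted_decoding` is deprecated. Call `generate()` instead.",
            FutureWarning,
        )
        return self._assisted_decoding(*args, **kwargs)

    def _assisted_decoding(self, *args, **kwargs):
        ...  # the actual assisted-decoding loop lives here
```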
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Thanks for working on this!
A few questions about the intended behaviour - in particular why the config values are forcibly reset and an instance attribute is used instead. It would also be good to have some tests to make sure `min_new_tokens` has the intended behaviour.
```diff
@@ -3252,6 +3252,28 @@ def test_default_max_length_warning(self):
             model.generate(input_ids)
             self.assertEqual(len(warning_list), 0)

+    def test_length_warning_assisted_generation(self):
```
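A sketch of how such a warning test could be structured, following the `test_default_max_length_warning` pattern above. The setup helper is an assumption, not part of the PR's actual test:

```python
import warnings


def test_length_warning_assisted_generation(self):
    # Hypothetical sketch: the helper below is assumed, not part of the PR.
    model, assistant, input_ids = self._get_test_models_and_inputs()
    with warnings.catch_warnings(record=True) as warning_list:
        warnings.simplefilter("always")
        model.generate(
            input_ids,
            assistant_model=assistant,
            min_new_tokens=5,
            max_new_tokens=20,
        )
    # No "unfeasible length constraints" (or other) warnings expected.
    self.assertEqual(len(warning_list), 0)
```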
There should also be a test that the `min_new_tokens` parameter behaves as expected, especially when `max_new_tokens` is also set and when it's not set at all, i.e. goes to the default value.
Added one more test checking that the generated length falls between the min and max lengths.
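A sketch of what that range check could look like (the setup helper is an assumption, not the PR's exact test):

```python
def test_assisted_generation_min_max_length(self):
    # Hypothetical sketch: the helper below is assumed, not part of the PR.
    model, assistant, input_ids = self._get_test_models_and_inputs()
    output = model.generate(
        input_ids,
        assistant_model=assistant,
        min_new_tokens=10,
        max_new_tokens=20,
    )
    # The number of newly generated tokens should respect both bounds.
    num_new_tokens = output.shape[-1] - input_ids.shape[-1]
    self.assertGreaterEqual(num_new_tokens, 10)
    self.assertLessEqual(num_new_tokens, 20)
```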
```diff
@@ -157,6 +157,13 @@ def __init__(
         self.generation_config.return_dict_in_generate = True
         self.generation_config.output_scores = True

+        # avoid unnecessary warnings that min_length is larger than max_new_tokens
+        input_length = input_ids.shape[-1]
+        min_new_tokens = self.generation_config.min_new_tokens if self.generation_config.min_new_tokens else 0
```
Are we guaranteed that `self.generation_config` has this attribute?
Yes, `GenerationConfig` usually initializes those to 0 or None if not set by the user. So we check that `min_new_tokens` is not None, and then set `min_length` to the maximum (the code says min, I'll fix it) of the user-defined value and the default 0.
@amyeroberts AFAIK only Whisper (and perhaps other audio models?) uses attributes that may not exist in a `generation_config`; otherwise it is a fairly regular object with everything initialized in `__init__` :D
OK, and can we ever expect Whisper generation configs to be used here, or does the model always just use its own custom generation code?
```diff
@@ -157,6 +157,13 @@ def __init__(
         self.generation_config.return_dict_in_generate = True
         self.generation_config.output_scores = True

+        # avoid unnecessary warnings that min_length is larger than max_new_tokens
+        input_length = input_ids.shape[-1]
+        min_new_tokens = self.generation_config.min_new_tokens if self.generation_config.min_new_tokens else 0
```
Is this checking for None-ness? Or can it be `False`? Otherwise, defaulting to 0 if it's 0 is a superfluous check.
For None-ness; I'll make it explicit with `is not None`.
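A minimal sketch of the explicit check being discussed (the helper name is illustrative):

```python
from typing import Optional


def resolve_min_new_tokens(config_value: Optional[int]) -> int:
    # Only None means "unset by the user"; an explicit 0 is kept as 0.
    # The result is the same either way here, but `is not None` states
    # the intent and avoids surprises with falsy-but-valid values.
    return config_value if config_value is not None else 0
```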
```diff
@@ -175,6 +182,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
         # Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
         new_cur_len = input_ids.shape[-1]
         max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
+        min_new_tokens = min(max_new_tokens, self.min_length - new_cur_len)
```
Won't this result in negative values if we've already generated more tokens than `self.min_length`, i.e. `new_cur_len > self.min_length`?
Good point, I missed it.
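A sketch of the clamped computation (a hypothetical standalone helper mirroring the diff above, not the PR's exact code):

```python
from typing import Tuple


def candidate_length_bounds(
    num_assistant_tokens: int, max_length: int, min_length: int, new_cur_len: int
) -> Tuple[int, int]:
    # Cap candidates at max_length - 1, since the target model adds one
    # extra token; clamp min_new_tokens at 0 so it cannot go negative
    # once new_cur_len exceeds min_length.
    max_new_tokens = min(int(num_assistant_tokens), max_length - new_cur_len - 1)
    min_new_tokens = max(min(max_new_tokens, min_length - new_cur_len), 0)
    return min_new_tokens, max_new_tokens
```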
```diff
@@ -175,6 +182,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
         # Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
         new_cur_len = input_ids.shape[-1]
         max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
+        min_new_tokens = min(max_new_tokens, self.min_length - new_cur_len)
```
Why use a class attribute and not the generation config, as for `max_length`?
Oh, this is because for maximum length we deprecated `generation_config.max_new_tokens`, so we can use the only possible attribute for max length. Yet for minimum length we have two attributes, both of which are equally valid. That's why in `__init__` we manually set the `min_length` by checking both attributes, if they are set by the user.

EDIT: I just remembered why I did it that way. We have to set the generation_config's `min_length` to 0 in `__init__`; that's required to avoid unnecessary warnings. That's why I saved it as a class attribute. Otherwise, the generation would receive kwargs like the ones below and throw warnings:

{"min_length": 20, "min_new_tokens": 5, "max_new_tokens": 5}

@gante, btw, don't you think we can also deprecate `min_new_tokens`?
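A simplified sketch of the init-time normalization described above (class shape and attribute handling are assumptions, not the PR's exact code):

```python
class CandidateGeneratorSketch:
    """Simplified sketch of the described normalization, not the real code."""

    def __init__(self, input_ids, generation_config):
        input_length = input_ids.shape[-1]
        min_new_tokens = (
            generation_config.min_new_tokens
            if generation_config.min_new_tokens is not None
            else 0
        )
        # Keep the effective total minimum length as an instance attribute...
        self.min_length = max(generation_config.min_length, input_length + min_new_tokens)
        # ...then neutralize the config values so the assistant's internal
        # generate() call never sees conflicting min_length/min_new_tokens
        # kwargs and never emits "unfeasible length constraints" warnings.
        generation_config.min_length = 0
        generation_config.min_new_tokens = None
        self.generation_config = generation_config
```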
@zucchini-nlp it's the other way around: if anything we would want to deprecate the `min_length` argument/config option :) `max_new_tokens` and `min_new_tokens` are much more predictable from a user's point of view, as the user doesn't need to be concerned with the input length. In the past, before `max_new_tokens` and `min_new_tokens` were introduced, we would often get issues from confused users.

Inside generate, however, it is much easier to use the total length for control (no need to track the input length). We set `max_length` from `max_new_tokens` when needed (here); perhaps we should do the same with `min_length` to simplify the procedure in this PR :)
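A sketch of that suggestion, mirroring how `max_length` is derived from `max_new_tokens` (a hypothetical simplification, not the actual transformers helper):

```python
def fold_new_token_limits(generation_config, input_ids_length: int):
    # Inside generate(), convert the user-facing *_new_tokens arguments
    # into total lengths, so the rest of the code only ever deals with
    # max_length / min_length.
    if generation_config.max_new_tokens is not None:
        generation_config.max_length = generation_config.max_new_tokens + input_ids_length
    if generation_config.min_new_tokens is not None:
        generation_config.min_length = generation_config.min_new_tokens + input_ids_length
    return generation_config
```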
I decided to make a separate method for all length-related corrections. Added one more test for min length, same as we have for max length.
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
The recent changes look good to me. Added a nit :)
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
@zucchini-nlp It looks like a lot of tests are failing at the moment because of the lack of
Oops, I accepted the last suggestion and did not fix the naming in other places in the code. Now it should work; at least locally it was passing the tests in "generation".
@amyeroberts this one is ready for re-review 😃
Thanks for all the work iterating on this and improving our warnings!
@amyeroberts can you merge pls? Failing tests in TF seem to be unrelated
As the failing tests are for generation, there's a very small chance there would be some interaction between the changes here and those tests (mainly because the test implementation isn't TF specific). They should now be resolved on
What does this PR do?

Currently, if we pass a `min_length` or `min_new_tokens` to speculative decoding, we get a bunch of warnings like: `UserWarning: Unfeasible length constraints: min_new_tokens (34), when added to the prompt length (66), is larger than the maximum possible length (75)....`

This PR adds a `min_new_tokens` argument in the candidate's generate, which will default to 0 if no `min_length` is passed by the user.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@gante
fixes #29860
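For reference, a minimal usage sketch of the scenario this PR fixes. Checkpoint names and the prompt are illustrative; any compatible target/assistant pair works:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b")
assistant = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

inputs = tokenizer("The weather today is", return_tensors="pt")
# Before this fix, passing min_new_tokens here triggered repeated
# "Unfeasible length constraints" UserWarnings from the assistant's
# internal generate calls.
outputs = model.generate(
    **inputs,
    assistant_model=assistant,
    min_new_tokens=34,
    max_new_tokens=64,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```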