Fix length related warnings in speculative decoding #29585
Changes from 3 commits
@@ -157,6 +157,13 @@ def __init__(
        self.generation_config.return_dict_in_generate = True
        self.generation_config.output_scores = True

        # avoid unnecessary warnings that min_length is larger than max_new_tokens
        input_length = input_ids.shape[-1]
        min_new_tokens = self.generation_config.min_new_tokens if self.generation_config.min_new_tokens else 0
Review comment: Is this checking for None-ness? Or can it be […]?
Reply: For None-ness, I'll specify it as […].
        self.min_length = min(self.generation_config.min_length, input_length + min_new_tokens)
        self.generation_config.min_length = 0
        self.generation_config.min_new_tokens = None

    def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        """
        Fetches the candidates to be tried for the current input.
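As an aside on the None-check thread above, a minimal, self-contained sketch of the explicit None-aware form being discussed (the helper name is hypothetical, and the exact line the author settles on may differ):

def resolve_min_new_tokens(min_new_tokens):
    # Hypothetical helper: treat only None as "unset", so an explicit
    # min_new_tokens=0 from the user is kept as 0 rather than falling through.
    return min_new_tokens if min_new_tokens is not None else 0

assert resolve_min_new_tokens(None) == 0
assert resolve_min_new_tokens(0) == 0
assert resolve_min_new_tokens(5) == 5

For this particular fallback, the truthiness check and the None check produce the same value for an explicit 0; the `is not None` spelling simply states the intent directly.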
@@ -175,6 +182,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
        # Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
        new_cur_len = input_ids.shape[-1]
        max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
        min_new_tokens = min(max_new_tokens, self.min_length - new_cur_len)
Review comment: Won't this result in negative values? If I've generated more tokens than […]
Reply: Good point, I've missed it.
Review comment: Why use a class attribute and not the generation config, as for […]?
Reply: Oh, this is because for maximum length we deprecated […]. EDIT: I just remembered why I did it that way. We have to set the generation_config's […]. @gante, btw, don't you think we can also deprecate […]?
Reply: @zucchini-nlp it's the other way around; if anything we would want to deprecate the […]. Inside generate, however, it is much easier to use the total length for control (no need to track the input length). We set […]
Reply: I decided to make a separate method for all length-related corrections. Added one more test for min length, same as we have for max length.
        if max_new_tokens == 0:
            return input_ids, None
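A minimal sketch of the clamping concern raised in the thread above (illustrative only, not the final code in this PR): flooring at zero keeps the candidate budget non-negative once generation has already passed the requested minimum length.

def candidate_min_new_tokens(cur_len, min_length, max_new_tokens):
    # Hypothetical helper: how many tokens the assistant must still produce.
    # max(0, ...) prevents a negative value once cur_len already exceeds min_length.
    return max(0, min(max_new_tokens, min_length - cur_len))

assert candidate_min_new_tokens(cur_len=5, min_length=10, max_new_tokens=3) == 3
assert candidate_min_new_tokens(cur_len=9, min_length=10, max_new_tokens=3) == 1
assert candidate_min_new_tokens(cur_len=12, min_length=10, max_new_tokens=3) == 0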
@@ -195,6 +203,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
        # 2. Forecast next N tokens using the assistant model.
        assistant_generation_kwargs = {
            self.input_ids_key: input_ids,
            "min_new_tokens": min_new_tokens,
            "max_new_tokens": max_new_tokens,
            "generation_config": self.generation_config,
            "logits_processor": self.logits_processor,
@@ -1501,7 +1501,7 @@ def generate(
            )

            # 12. run assisted generate
-           result = self.assisted_decoding(
+           result = self._assisted_decoding(
Review comment: Good catch, this was surely throwing a deprecation warning 👍
                input_ids,
                candidate_generator=candidate_generator,
                do_sample=generation_config.do_sample,
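As background for the comment above, a sketch of the kind of deprecation shim that makes the old public name warn on every call (an illustrative pattern, not the actual transformers implementation), which is why routing the internal call through _assisted_decoding silences the warning:

import warnings

class _GenerationMixinSketch:
    # Hypothetical stand-in class used only to illustrate the pattern.
    def _assisted_decoding(self, input_ids, **kwargs):
        return input_ids  # real decoding logic elided

    def assisted_decoding(self, input_ids, **kwargs):
        # Old public entry point kept as a thin warning wrapper; calling it
        # internally from generate() would emit this warning on every run.
        warnings.warn(
            "assisted_decoding is deprecated; call generate() instead.",
            FutureWarning,
        )
        return self._assisted_decoding(input_ids, **kwargs)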
@@ -3252,6 +3252,28 @@ def test_default_max_length_warning(self):
            model.generate(input_ids)
        self.assertEqual(len(warning_list), 0)

    def test_length_warning_assisted_generation(self):
Review comment: There should also be a test that […]
Reply: Added one more test checking that the length is in the range between the min and max lengths.
        # PT-only test: TF doesn't support assisted decoding yet.
        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
        assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
        model.config.pad_token_id = tokenizer.eos_token_id
        assistant.config.pad_token_id = tokenizer.eos_token_id

        text = "Hello world"
        tokenized_inputs = tokenizer([text], return_tensors="pt")
        input_ids = tokenized_inputs.input_ids.to(torch_device)

        # This should not raise any warning that min length is not feasible in candidate generation
        with warnings.catch_warnings(record=True) as warning_list:
            model.generate(
                input_ids,
                assistant_model=assistant,
                min_new_tokens=10,
                max_length=20,
            )
        self.assertEqual(len(warning_list), 0)
    def test_model_kwarg_assisted_decoding_decoder_only(self):
        # PT-only test: TF doesn't support assisted decoding yet.
        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
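The extra test mentioned in the thread above (checking that the generated length lands between the requested minimum and maximum) is not visible in this snapshot. A sketch of what such a check could look like, assuming the same tiny models, tokenizer, and torch_device fixtures used by the test added in this diff (method name hypothetical):

    def test_generated_length_assisted_generation(self):
        # Sketch only: the output length should respect min_new_tokens and max_length.
        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
        assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
        model.config.pad_token_id = tokenizer.eos_token_id
        assistant.config.pad_token_id = tokenizer.eos_token_id

        input_ids = tokenizer(["Hello world"], return_tensors="pt").input_ids.to(torch_device)
        input_length = input_ids.shape[-1]

        out = model.generate(
            input_ids,
            assistant_model=assistant,
            min_new_tokens=5,
            max_length=20,
        )
        self.assertTrue(input_length + 5 <= out.shape[-1] <= 20)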
Review comment: Are we guaranteed the self.generation_config has this attribute?
Reply: Yes, the GenerationConfig usually initializes those to 0 or None if not indicated by the user. So we check that min_new_tokens is not None, and then set min_length to the maximum (the code says min, I'll fix it) of the user-defined value and the default 0.
Reply: @amyeroberts AFAIK only Whisper (and perhaps other audio models?) uses attributes that may not exist in a generation_config; it is a fairly regular object with everything initialized in __init__ :D
Reply: OK, and can we ever expect to have Whisper generation configs being used here, or does the model always just use its own custom generation code?
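To ground the point about defaults, a quick check against a freshly constructed GenerationConfig (defaults as I understand them; worth verifying against the installed transformers version):

from transformers import GenerationConfig

config = GenerationConfig()
# Both attributes exist on every instance; to my knowledge the defaults
# are 0 for min_length and None for min_new_tokens.
print(config.min_length)      # 0
print(config.min_new_tokens)  # None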