Fix length related warnings in speculative decoding #29585
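For context on the scenario being fixed: the warnings in question come from calling generate with an assistant model (assisted/speculative decoding) together with length controls such as min_new_tokens and max_length. Below is a minimal sketch of that call pattern outside the test suite; the checkpoint name and the length values simply mirror the tests in this diff and are illustrative only.

import warnings

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

# Tiny checkpoints, used here only to keep the sketch lightweight (same ones the tests use).
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(device)
assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(device)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model.config.pad_token_id = tokenizer.eos_token_id
assistant.config.pad_token_id = tokenizer.eos_token_id

input_ids = tokenizer(["Hello world"], return_tensors="pt").input_ids.to(device)

# With this fix, combining an assistant model with min_new_tokens/max_length
# should no longer emit spurious "min length not feasible" warnings during
# candidate generation.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    out = model.generate(
        input_ids,
        assistant_model=assistant,
        min_new_tokens=10,
        max_length=20,
    )

print(f"warnings: {len(caught)}, output length: {out.shape[-1]}")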
@@ -1977,6 +1977,20 @@ def test_max_length_if_input_embeds(self):
        out_gen_embeds = model.generate(inputs_embeds=inputs_embeds, max_length=max_length)
        self.assertEqual(out_gen.shape[-1], input_len + out_gen_embeds.shape[-1])

    def test_min_length_if_input_embeds(self):
        # PT-only test: TF doesn't have StoppingCriteria
        article = "Today a dragon flew over Paris."
        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
        input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
        inputs_embeds = model.get_input_embeddings()(input_ids)

        min_length = 10
        input_len = input_ids.shape[-1]
        out_gen = model.generate(input_ids=input_ids, min_length=min_length)
        out_gen_embeds = model.generate(inputs_embeds=inputs_embeds, min_length=min_length)
        self.assertEqual(out_gen.shape[-1], input_len + out_gen_embeds.shape[-1])

    def test_custom_stopping_criteria_overload_error(self):
        # PT-only test: TF doesn't have StoppingCriteria
        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
@@ -2539,6 +2553,56 @@ def test_default_max_length_warning(self):
            model.generate(input_ids)
            self.assertEqual(len(warning_list), 0)

Review comment: There should also be a test that …
Reply: Added one more test checking that the generated length is in the range between the min and max lengths.

    def test_length_warning_assisted_generation(self):
        # PT-only test: TF doesn't support assisted decoding yet.
        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
        assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
        model.config.pad_token_id = tokenizer.eos_token_id
        assistant.config.pad_token_id = tokenizer.eos_token_id

        text = "Hello world"
        tokenized_inputs = tokenizer([text], return_tensors="pt")
        input_ids = tokenized_inputs.input_ids.to(torch_device)

        # This should not raise any warning that min length is not feasible in candidate generation
        with warnings.catch_warnings(record=True) as warning_list:
            model.generate(
                input_ids,
                assistant_model=assistant,
                min_new_tokens=10,
                max_length=20,
            )
            self.assertEqual(len(warning_list), 0)

    def test_generated_length_assisted_generation(self):
        # PT-only test: TF doesn't support assisted decoding yet.
        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
        assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
        model.config.pad_token_id = tokenizer.eos_token_id
        assistant.config.pad_token_id = tokenizer.eos_token_id

        text = "Hello world"
        tokenized_inputs = tokenizer([text], return_tensors="pt")
        input_ids = tokenized_inputs.input_ids.to(torch_device)
        input_length = input_ids.shape[-1]

        # Both bounds are set explicitly: the output length must fall between them.
        out = model.generate(
            input_ids,
            assistant_model=assistant,
            min_new_tokens=10,
            max_new_tokens=20,
        )
        self.assertTrue((10 + input_length) <= out.shape[-1] <= (20 + input_length))

        # Only the lower bound is set here; the upper bound of 20 comes from the
        # default max_length in the generation config.
        out = model.generate(
            input_ids,
            assistant_model=assistant,
            min_new_tokens=10,
        )
        self.assertTrue((input_length + 10) <= out.shape[-1] <= 20)

    def test_model_kwarg_assisted_decoding_decoder_only(self):
        # PT-only test: TF doesn't support assisted decoding yet.
        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)

Review comment: good catch, this was surely throwing a deprecation warning 👍
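To run only the new assisted-generation length tests locally (assuming they live in tests/generation/test_utils.py, which this diff view does not show explicitly), something like the following should work:

python -m pytest tests/generation/test_utils.py -k "assisted_generation" -x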