Fix length related warnings in speculative decoding #29585

Merged · 16 commits · Apr 10, 2024
Changes from 5 commits

src/transformers/generation/candidate_generator.py (11 additions, 0 deletions)

@@ -157,6 +157,15 @@ def __init__(
        self.generation_config.return_dict_in_generate = True
        self.generation_config.output_scores = True

+        # avoid unnecessary warnings that min_length is larger than max_new_tokens
+        input_length = input_ids.shape[-1]
+        min_new_tokens = (
+            self.generation_config.min_new_tokens if self.generation_config.min_new_tokens is not None else 0
+        )
+        self.min_length = max(self.generation_config.min_length, input_length + min_new_tokens)
+        self.generation_config.min_length = 0
+        self.generation_config.min_new_tokens = None
+
    def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        """
        Fetches the candidates to be tried for the current input.
@@ -175,6 +184,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        # Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
        new_cur_len = input_ids.shape[-1]
        max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
+        min_new_tokens = min(max_new_tokens, max(self.min_length - new_cur_len, 0))
        if max_new_tokens == 0:
            return input_ids, None

@@ -195,6 +205,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        # 2. Forecast next N tokens using the assistant model.
        assistant_generation_kwargs = {
            self.input_ids_key: input_ids,
+            "min_new_tokens": min_new_tokens,
            "max_new_tokens": max_new_tokens,
            "generation_config": self.generation_config,
            "logits_processor": self.logits_processor,
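
To make the new bounds concrete, here is a minimal, self-contained sketch of the arithmetic above (not the library code), assuming a 3-token prompt, a user-supplied min_new_tokens=10, max_length=20, the default min_length=0, and 5 assistant tokens:

input_length = 3                       # prompt length seen in __init__
requested_min_new_tokens = 10          # user-supplied min_new_tokens
config_min_length = 0                  # generation_config.min_length default (assumed)
max_length = 20
num_assistant_tokens = 5

# __init__ caches the absolute minimum length, then zeroes the config fields so the
# assistant's generate() call no longer warns about an unreachable min_length.
min_length = max(config_min_length, input_length + requested_min_new_tokens)     # 13

# get_candidates() then bounds the assistant's generation on both sides.
new_cur_len = input_length             # grows as candidate tokens are accepted
max_new_tokens = min(int(num_assistant_tokens), max_length - new_cur_len - 1)    # min(5, 16) = 5
min_new_tokens = min(max_new_tokens, max(min_length - new_cur_len, 0))           # min(5, 10) = 5

assert (min_new_tokens, max_new_tokens) == (5, 5)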

src/transformers/generation/utils.py (1 addition, 1 deletion)

@@ -1501,7 +1501,7 @@ def generate(
            )

            # 12. run assisted generate
-            result = self.assisted_decoding(
+            result = self._assisted_decoding(
                input_ids,
                candidate_generator=candidate_generator,
                do_sample=generation_config.do_sample,

Reviewer comment (Member), on the `self._assisted_decoding(` line: good catch, this was surely throwing a deprecation warning 👍
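
The warning the reviewer refers to most likely comes from a deprecation shim on the old public method. The following is a hypothetical sketch (not the transformers source; the class name and message are illustrative only) of such a shim, showing why routing generate() through the private _assisted_decoding avoids the warning:

import warnings

class GenerationMixinSketch:
    # Hypothetical shim: the old public entry point warns, then forwards to the
    # private implementation that generate() now calls directly.
    def assisted_decoding(self, *args, **kwargs):
        warnings.warn(
            "Calling `assisted_decoding` directly is deprecated; use `generate()` "
            "with `assistant_model=...` instead.",
            FutureWarning,
        )
        return self._assisted_decoding(*args, **kwargs)

    def _assisted_decoding(self, *args, **kwargs):
        ...  # the actual assisted-decoding loop would live here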

tests/generation/test_utils.py (50 additions, 0 deletions)

@@ -3252,6 +3252,56 @@ def test_default_max_length_warning(self):
            model.generate(input_ids)
        self.assertEqual(len(warning_list), 0)

+    def test_length_warning_assisted_generation(self):

Reviewer comment (Collaborator), on `test_length_warning_assisted_generation`: There should also be a test that the `min_new_tokens` parameter behaves as expected, especially when `max_new_tokens` is also set and when it is not set at all, i.e. falls back to its default value.

Author reply (Member Author): Added one more test below, checking that the generated length falls within the range between the min and max lengths.

+        # PT-only test: TF doesn't support assisted decoding yet.
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+        assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        model.config.pad_token_id = tokenizer.eos_token_id
+        assistant.config.pad_token_id = tokenizer.eos_token_id
+
+        text = "Hello world"
+        tokenized_inputs = tokenizer([text], return_tensors="pt")
+        input_ids = tokenized_inputs.input_ids.to(torch_device)
+
+        # This should not raise any warning that min length is not feasible in candidate generation
+        with warnings.catch_warnings(record=True) as warning_list:
+            model.generate(
+                input_ids,
+                assistant_model=assistant,
+                min_new_tokens=10,
+                max_length=20,
+            )
+        self.assertEqual(len(warning_list), 0)
+
+    def test_generated_length_assisted_generation(self):
+        # PT-only test: TF doesn't support assisted decoding yet.
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+        assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        model.config.pad_token_id = tokenizer.eos_token_id
+        assistant.config.pad_token_id = tokenizer.eos_token_id
+
+        text = "Hello world"
+        tokenized_inputs = tokenizer([text], return_tensors="pt")
+        input_ids = tokenized_inputs.input_ids.to(torch_device)
+        input_length = input_ids.shape[-1]
+
+        out = model.generate(
+            input_ids,
+            assistant_model=assistant,
+            min_new_tokens=10,
+            max_new_tokens=20,
+        )
+        self.assertTrue((10 + input_length) <= out.shape[-1] <= (20 + input_length))
+
+        out = model.generate(
+            input_ids,
+            assistant_model=assistant,
+            min_new_tokens=10,
+        )
+        self.assertTrue((input_length + 10) <= out.shape[-1] <= 20)
+
    def test_model_kwarg_assisted_decoding_decoder_only(self):
        # PT-only test: TF doesn't support assisted decoding yet.
        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)