
Fix length related warnings in speculative decoding #29585

Merged · 16 commits · Apr 10, 2024
Changes from 2 commits
9 changes: 9 additions & 0 deletions src/transformers/generation/candidate_generator.py
@@ -157,6 +157,13 @@ def __init__(
self.generation_config.return_dict_in_generate = True
self.generation_config.output_scores = True

# avoid unnecessary warnings that min_length is more than max_new_tokens
input_length = input_ids.shape[-1]
min_new_tokens = self.generation_config.min_new_tokens if self.generation_config.min_new_tokens else 0
Collaborator:

Are we guaranteed that self.generation_config has this attribute?

Member Author:

Yes, GenerationConfig initializes those to 0 or None when they are not set by the user. So we check that min_new_tokens is not None, and then set min_length to the maximum (the code says min, I'll fix it) of the user-defined min_length and input_length plus min_new_tokens (which defaults to 0).
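
For reference, those defaults can be checked directly (a minimal sketch; the printed values follow from the GenerationConfig defaults described above):

from transformers import GenerationConfig

config = GenerationConfig()
print(config.min_length)      # 0 when not set by the user
print(config.min_new_tokens)  # None when not set by the user

And a small worked example of the min/max mixup, with hypothetical numbers (a 15-token prompt, min_length=20, min_new_tokens=10):

input_length = 15
min(20, input_length + 10)  # -> 20, what the current code computes
max(20, input_length + 10)  # -> 25, the stricter of the two user constraints, i.e. the intended value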

Member:

@amyeroberts AFAIK only Whisper (and perhaps other audio models?) uses attributes that may not exist in a generation_config; otherwise it is a fairly regular object with everything initialized in __init__ :D

Collaborator:

OK, and can we ever expect Whisper generation configs to be used here, or does the model always just use its own custom generation code?

Collaborator:

Is this checking for None-ness? Or can it be False? Otherwise, defaulting to 0 if it's 0 is a superfluous check.

Member Author:

For None-ness; I'll spell it out as is not None.
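
A sketch of the more explicit check (illustrative, not the merged diff; the behaviour is the same here since 0 is the only falsy value involved, but it is unambiguous about None):

min_new_tokens = (
    self.generation_config.min_new_tokens
    if self.generation_config.min_new_tokens is not None
    else 0
)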

self.min_length = min(self.generation_config.min_length, input_length + min_new_tokens)
self.generation_config.min_length = 0
self.generation_config.min_new_tokens = None

def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
"""
Fetches the candidates to be tried for the current input.
@@ -175,6 +182,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
# Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
new_cur_len = input_ids.shape[-1]
max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
min_new_tokens = min(max_new_tokens, self.min_length - new_cur_len)
Collaborator:

Won't this result in negative values if we've already generated more tokens than self.min_length, i.e. new_cur_len > self.min_length?

Member Author:

Good point, I'd missed it.
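
A sketch of a clamped version that avoids going negative (assuming the fix takes this shape):

min_new_tokens = max(min(max_new_tokens, self.min_length - new_cur_len), 0)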

Collaborator:

Why use a class attribute and not the generation config, as for max_length?

Member Author (@zucchini-nlp) · Mar 11, 2024:

Oh, this is because for maximum length we deprecated generation_config.max_new_tokens, so we can use the only possible attribute for max length. Yet for minimum length we have two attributes, both of which are equally valid. That's why in __init__ we manually set min_length by checking both attributes, if they are set by the user.

EDIT: I just remembered why I did it that way. We have to set the generation_config's min_length to 0 in __init__; that's required to avoid unnecessary warnings. That's why I saved it as a class attribute. Otherwise generate would receive kwargs like the below and throw warnings:
{"min_length": 20, "min_new_tokens": 5, "max_new_tokens": 5}

@gante, btw, don't you think we can also deprecate min_new_tokens?

Member:

@zucchini-nlp it's the other way around: if anything, we would want to deprecate the min_length argument/config option :) max_new_tokens and min_new_tokens are much more predictable from a user point of view, as the user doesn't need to be concerned with the input length. In the past, before max_new_tokens and min_new_tokens were introduced, we would often get issues from confused users.

Inside generate, however, it is much easier to use the total length for control (no need to track the input length). We set max_length from max_new_tokens when needed (here), perhaps we should do the same with min_length to simplify the procedure in this PR :)
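
A sketch of the conversion being suggested, mirroring how max_length is derived from max_new_tokens (illustrative, not the exact library code):

if generation_config.min_new_tokens is not None:
    generation_config.min_length = generation_config.min_new_tokens + input_ids.shape[-1]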

Member Author (@zucchini-nlp) · Mar 13, 2024:

I decided to make a separate method for all length-related corrections, and added one more test for min length, same as we have for max length.
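
A hypothetical sketch of what such a helper could look like (the name and exact bounds here are assumptions, not necessarily the merged code):

def _calculate_new_tokens(self, input_ids: torch.LongTensor) -> Tuple[int, int]:
    """Compute how many new tokens the assistant must and may generate."""
    new_cur_len = input_ids.shape[-1]
    # Leave room for the one extra token the target model generates.
    max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
    # Clamp at zero so an already-satisfied minimum never goes negative.
    min_new_tokens = max(min(max_new_tokens, self.min_length - new_cur_len), 0)
    return min_new_tokens, max_new_tokens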

if max_new_tokens == 0:
return input_ids, None

@@ -195,6 +203,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
# 2. Forecast next N tokens using the assistant model.
assistant_generation_kwargs = {
self.input_ids_key: input_ids,
"min_new_tokens": min_new_tokens,
"max_new_tokens": max_new_tokens,
"generation_config": self.generation_config,
"logits_processor": self.logits_processor,
2 changes: 1 addition & 1 deletion src/transformers/generation/utils.py
@@ -1501,7 +1501,7 @@ def generate(
)

# 12. run assisted generate
- result = self.assisted_decoding(
+ result = self._assisted_decoding(
Member:

good catch, this was surely throwing a deprecation warning 👍

input_ids,
candidate_generator=candidate_generator,
do_sample=generation_config.do_sample,
22 changes: 22 additions & 0 deletions tests/generation/test_utils.py
@@ -3252,6 +3252,28 @@ def test_default_max_length_warning(self):
model.generate(input_ids)
self.assertEqual(len(warning_list), 0)

def test_length_warning_assisted_generation(self):
Collaborator:

There should also be a test that the min_new_tokens parameter behaves as expected, especially when max_new_tokens is also set and when it's not set at all, i.e. falls back to its default value.

Member Author:

Added one more test checking that the output length falls in the range between the min and max lengths.
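
Roughly the kind of assertion such a test would make (a sketch; the actual test in the PR may differ):

out = model.generate(input_ids, assistant_model=assistant, min_new_tokens=2, max_new_tokens=20)
new_tokens = out.shape[-1] - input_ids.shape[-1]
self.assertTrue(2 <= new_tokens <= 20)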

# PT-only test: TF doesn't support assisted decoding yet.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model.config.pad_token_id = tokenizer.eos_token_id
assistant.config.pad_token_id = tokenizer.eos_token_id

text = "Hello world"
tokenized_inputs = tokenizer([text], return_tensors="pt")
input_ids = tokenized_inputs.input_ids.to(torch_device)

# This should not raise any warning that min length is not feasible in candidate generation
with warnings.catch_warnings(record=True) as warning_list:
model.generate(
input_ids,
assistant_model=assistant,
min_new_tokens=10,
max_length=20,
)
self.assertEqual(len(warning_list), 0)

def test_model_kwarg_assisted_decoding_decoder_only(self):
# PT-only test: TF doesn't support assisted decoding yet.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)