Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix length related warnings in speculative decoding #29585

Merged
merged 16 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/transformers/generation/candidate_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,11 @@ def __init__(
self.generation_config.return_dict_in_generate = True
self.generation_config.output_scores = True

# avoid unnecessary warnings that min_length is larger than max_new_tokens
self.main_model_min_length = self.generation_config.min_length
self.generation_config.min_length = 0
self.generation_config.min_new_tokens = None

def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
"""
Fetches the candidates to be tried for the current input.
Expand All @@ -166,6 +171,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
# Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
new_cur_len = input_ids.shape[-1]
max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)
if max_new_tokens == 0:
return input_ids, None

Expand All @@ -186,6 +192,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
# 2. Forecast next N tokens using the assistant model.
assistant_generation_kwargs = {
self.input_ids_key: input_ids,
"min_new_tokens": min_new_tokens,
"max_new_tokens": max_new_tokens,
"generation_config": self.generation_config,
"logits_processor": self.logits_processor,
Expand Down
79 changes: 60 additions & 19 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1173,6 +1173,56 @@ def _validate_generated_length(self, generation_config, input_ids_length, has_de
UserWarning,
)

def _prepare_generated_length(
self,
generation_config,
has_default_max_length,
has_default_min_length,
model_input_name,
input_ids_length,
inputs_tensor,
):
"""Prepared max and min length in generaion configs to avoid clashes between similar attributes"""

if generation_config.max_new_tokens is not None:
if not has_default_max_length and generation_config.max_length is not None:
logger.warning(
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
"Please refer to the documentation for more information. "
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
)
generation_config.max_length = generation_config.max_new_tokens + input_ids_length

# if both `inputs_embeds` and `input_ids` are passed, we do not correct the length
# otherwise we need total length [inputs-embeds-len + new-tokens-len] to not go beyond indicated `max_length``
elif (
model_input_name == "inputs_embeds"
and input_ids_length != inputs_tensor.shape[1]
and not self.config.is_encoder_decoder
):
generation_config.max_length -= inputs_tensor.shape[1]

# same for min length
if generation_config.min_new_tokens is not None:
if not has_default_min_length:
logger.warning(
f"Both `min_new_tokens` (={generation_config.min_new_tokens}) and `min_length`(="
f"{generation_config.min_length}) seem to have been set. `min_new_tokens` will take precedence. "
"Please refer to the documentation for more information. "
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
)
generation_config.min_length = generation_config.min_new_tokens + input_ids_length

elif (
model_input_name == "inputs_embeds"
and input_ids_length != inputs_tensor.shape[1]
and not self.config.is_encoder_decoder
):
generation_config.min_length = max(generation_config.min_length - inputs_tensor.shape[1], 0)

return generation_config

def _prepare_generation_config(
self, generation_config: GenerationConfig, **kwargs: Dict
) -> Tuple[GenerationConfig, Dict]:
Expand Down Expand Up @@ -1418,24 +1468,15 @@ def generate(
# 6. Prepare `max_length` depending on other stopping criteria.
input_ids_length = input_ids.shape[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
if generation_config.max_new_tokens is not None:
if not has_default_max_length and generation_config.max_length is not None:
logger.warning(
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
"Please refer to the documentation for more information. "
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
)
generation_config.max_length = generation_config.max_new_tokens + input_ids_length

# otherwise the total length [inputs-embeds-len + new-tokens-len] will go beyond indicated `max_length``
elif (
model_input_name == "inputs_embeds"
and inputs_tensor.shape[:-1] != input_ids.shape
and not self.config.is_encoder_decoder
):
generation_config.max_length -= inputs_tensor.shape[1]
generation_config.min_length = max(generation_config.min_length - inputs_tensor.shape[1], 0)
has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
generation_config = self._prepare_generated_length(
generation_config=generation_config,
has_default_max_length=has_default_max_length,
has_default_min_length=has_default_min_length,
model_input_name=model_input_name,
inputs_tensor=inputs_tensor,
input_ids_length=input_ids_length,
)

if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
if generation_config.cache_implementation == "static":
Expand Down Expand Up @@ -1511,7 +1552,7 @@ def generate(
)

# 12. run assisted generate
result = self.assisted_decoding(
result = self._assisted_decoding(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch, this was surely throwing a deprecation warning 👍

input_ids,
candidate_generator=candidate_generator,
do_sample=generation_config.do_sample,
Expand Down
64 changes: 64 additions & 0 deletions tests/generation/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1977,6 +1977,20 @@ def test_max_length_if_input_embeds(self):
out_gen_embeds = model.generate(inputs_embeds=inputs_embeds, max_length=max_length)
self.assertEqual(out_gen.shape[-1], input_len + out_gen_embeds.shape[-1])

def test_min_length_if_input_embeds(self):
# PT-only test: TF doesn't have StoppingCriteria
article = "Today a dragon flew over Paris."
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
inputs_embeds = model.get_input_embeddings()(input_ids)

min_length = 10
input_len = input_ids.shape[-1]
out_gen = model.generate(input_ids=input_ids, min_length=min_length)
out_gen_embeds = model.generate(inputs_embeds=inputs_embeds, min_length=min_length)
self.assertEqual(out_gen.shape[-1], input_len + out_gen_embeds.shape[-1])

def test_custom_stopping_criteria_overload_error(self):
# PT-only test: TF doesn't have StoppingCriteria
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
Expand Down Expand Up @@ -2539,6 +2553,56 @@ def test_default_max_length_warning(self):
model.generate(input_ids)
self.assertEqual(len(warning_list), 0)

def test_length_warning_assisted_generation(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There should also be a test that min_new_tokens parameter behaves as expected, especially when max_new_tokens is also set and when it's not set at all i.e. goes to default value

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added one more test checking if the length is in range between min and max lengths

# PT-only test: TF doesn't support assisted decoding yet.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model.config.pad_token_id = tokenizer.eos_token_id
assistant.config.pad_token_id = tokenizer.eos_token_id

text = "Hello world"
tokenized_inputs = tokenizer([text], return_tensors="pt")
input_ids = tokenized_inputs.input_ids.to(torch_device)

# This should not raise any warning that min length is not feasible in candidate generation
with warnings.catch_warnings(record=True) as warning_list:
model.generate(
input_ids,
assistant_model=assistant,
min_new_tokens=10,
max_length=20,
)
self.assertEqual(len(warning_list), 0)

def test_generated_length_assisted_generation(self):
# PT-only test: TF doesn't support assisted decoding yet.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model.config.pad_token_id = tokenizer.eos_token_id
assistant.config.pad_token_id = tokenizer.eos_token_id

text = "Hello world"
tokenized_inputs = tokenizer([text], return_tensors="pt")
input_ids = tokenized_inputs.input_ids.to(torch_device)
input_length = input_ids.shape[-1]

out = model.generate(
input_ids,
assistant_model=assistant,
min_new_tokens=10,
max_new_tokens=20,
)
self.assertTrue((10 + input_length) <= out.shape[-1] <= (20 + input_length))

out = model.generate(
input_ids,
assistant_model=assistant,
min_new_tokens=10,
)
self.assertTrue((input_length + 10) <= out.shape[-1] <= 20)

def test_model_kwarg_assisted_decoding_decoder_only(self):
# PT-only test: TF doesn't support assisted decoding yet.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
Expand Down
Loading