Skip to content

Commit

Permalink
fix num_assistant_tokens with heuristic schedule (#28759)
Browse files Browse the repository at this point in the history
* fix heuristic num_assistant_tokens_schedule

* Update src/transformers/generation/configuration_utils.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/candidate_generator.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update utils.py

check that candidate_generator.assistant_model exists since some speculations (like ngram and PLD) don't have an assistant_model attribute

* Update src/transformers/generation/candidate_generator.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update tests/generation/test_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* make fixup

* merge conflict

* fix docstring

* make fixup

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
4 people authored Feb 16, 2024
1 parent 0eb4085 commit 258da40
Showing 4 changed files with 60 additions and 3 deletions.
5 changes: 4 additions & 1 deletion src/transformers/generation/candidate_generator.py
Original file line number Diff line number Diff line change
@@ -225,7 +225,10 @@ def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.F
# Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
# probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
# cost of forecasting incorrect assistant tokens.
if self.assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic":
if self.assistant_model.generation_config.num_assistant_tokens_schedule in {
"heuristic",
"heuristic_transient",
}:
if num_matches == int(self.num_assistant_tokens):
self.num_assistant_tokens += 2.0
else:
5 changes: 3 additions & 2 deletions src/transformers/generation/configuration_utils.py
Original file line number Diff line number Diff line change
@@ -249,8 +249,9 @@ class GenerationConfig(PushToHubMixin):
num_assistant_tokens_schedule (`str`, *optional*, defaults to `"heuristic"`):
Defines the schedule at which max assistant tokens shall be changed during inference.
- `"_heuristic_`: When all _speculative_ tokens are correct, increase `num_assistant_tokens` by 2 else
reduce by 1
- `"heuristic"`: When all speculative tokens are correct, increase `num_assistant_tokens` by 2 else
reduce by 1. `num_assistant_tokens` value is persistent over multiple generation calls with the same assistant model.
- `"heuristic_transient"`: Same as `"heuristic"` but `num_assistant_tokens` is reset to its initial value after each generation call.
- `"constant"`: `num_assistant_tokens` stays unchanged during generation
> Parameters specific to the caching mechanism:
7 changes: 7 additions & 0 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
@@ -4561,6 +4561,13 @@ def assisted_decoding(
if streamer is not None:
streamer.end()

if (
hasattr(candidate_generator, "assistant_model")
and candidate_generator.assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic"
):
candidate_generator.assistant_model.generation_config.num_assistant_tokens = (
candidate_generator.num_assistant_tokens
)
if return_dict_in_generate:
if self.config.is_encoder_decoder:
return GenerateEncoderDecoderOutput(
46 changes: 46 additions & 0 deletions tests/generation/test_utils.py
Original file line number Diff line number Diff line change
@@ -3490,3 +3490,49 @@ def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None,
encoder_outputs=encoder_outputs,
)
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())

def test_assisted_decoding_num_assistant_tokens_heuristic_schedule(self):
    # This test ensures that the assisted generation num_assistant_tokens 'heuristic' schedule works properly.
    checkpoint = "EleutherAI/pythia-160m-deduped"
    prompt = "Alice and Bob"

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint)
    model_inputs = tokenizer(prompt, return_tensors="pt")

    # Use the model as its own assistant, starting from 5 candidate tokens per step.
    assistant_model = model
    assistant_model.generation_config.num_assistant_tokens = 5
    assistant_model.generation_config.num_assistant_tokens_schedule = "heuristic"

    model.generate(
        **model_inputs,
        eos_token_id=-1,
        max_new_tokens=5,
        do_sample=False,
        assistant_model=assistant_model,
    )
    # update_candidate_strategy is called only once, so the persistent value must have been
    # adjusted exactly once from 5: either -1 (mismatch) or +2 (all candidates matched).
    self.assertIn(assistant_model.generation_config.num_assistant_tokens, (4, 7))

def test_assisted_decoding_num_assistant_tokens_heuristic_transient_schedule(self):
    # This test ensures that the assisted generation num_assistant_tokens 'heuristic_transient'
    # schedule works properly: the heuristic adjustment is applied during generation but NOT
    # persisted back to the assistant model's generation_config afterwards.

    prompt = "Alice and Bob"
    checkpoint = "EleutherAI/pythia-160m-deduped"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    inputs = tokenizer(prompt, return_tensors="pt")

    model = AutoModelForCausalLM.from_pretrained(checkpoint)

    # The model serves as its own assistant, starting from 5 candidate tokens per step.
    assistant_model = model
    assistant_model.generation_config.num_assistant_tokens = 5
    assistant_model.generation_config.num_assistant_tokens_schedule = "heuristic_transient"
    generation_kwargs = {
        "eos_token_id": -1,
        "max_new_tokens": 5,
        "do_sample": False,
        "assistant_model": assistant_model,
    }
    model.generate(**inputs, **generation_kwargs)
    # update_candidate_strategy is called once, but with the transient schedule the adjusted
    # value is not written back, so assistant_model.generation_config.num_assistant_tokens
    # should stay 5.
    self.assertEqual(assistant_model.generation_config.num_assistant_tokens, 5)

0 comments on commit 258da40

Please sign in to comment.