From 258da40efd0c148660a34821b05f995831cef9f9 Mon Sep 17 00:00:00 2001 From: Jonathan Mamou Date: Fri, 16 Feb 2024 13:44:58 +0200 Subject: [PATCH] fix num_assistant_tokens with heuristic schedule (#28759) * fix heuristic num_assistant_tokens_schedule * Update src/transformers/generation/configuration_utils.py Co-authored-by: Joao Gante * Update src/transformers/generation/candidate_generator.py Co-authored-by: Joao Gante * Update utils.py check that candidate_generator.assistant_model exists since some speculations (like ngram and PLD) don't have assistant_model attribute * Update src/transformers/generation/candidate_generator.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update tests/generation/test_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * make fixup * merge conflict * fix docstring * make fixup --------- Co-authored-by: Joao Gante Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- .../generation/candidate_generator.py | 5 +- .../generation/configuration_utils.py | 5 +- src/transformers/generation/utils.py | 7 +++ tests/generation/test_utils.py | 46 +++++++++++++++++++ 4 files changed, 60 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 3bdd88300469..616afa193176 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -225,7 +225,10 @@ def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.F # Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic, # probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the # cost of forecasting incorrect assistant tokens. 
- if self.assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic": + if self.assistant_model.generation_config.num_assistant_tokens_schedule in { + "heuristic", + "heuristic_transient", + }: if num_matches == int(self.num_assistant_tokens): self.num_assistant_tokens += 2.0 else: diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index ad8cfd796b4b..2af0232902bd 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -249,8 +249,9 @@ class GenerationConfig(PushToHubMixin): num_assistant_tokens_schedule (`str`, *optional*, defaults to `"heuristic"`): Defines the schedule at which max assistant tokens shall be changed during inference. - - `"_heuristic_`: When all _speculative_ tokens are correct, increase `num_assistant_tokens` by 2 else - reduce by 1 + - `"heuristic"`: When all speculative tokens are correct, increase `num_assistant_tokens` by 2 else + reduce by 1. `num_assistant_tokens` value is persistent over multiple generation calls with the same assistant model. + - `"heuristic_transient"`: Same as `"heuristic"` but `num_assistant_tokens` is reset to its initial value after each generation call. 
- `"constant"`: `num_assistant_tokens` stays unchanged during generation > Parameters specific to the caching mechanism: diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index f8fb086cba61..0c6740b32388 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -4561,6 +4561,13 @@ def assisted_decoding( if streamer is not None: streamer.end() + if ( + hasattr(candidate_generator, "assistant_model") + and candidate_generator.assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic" + ): + candidate_generator.assistant_model.generation_config.num_assistant_tokens = ( + candidate_generator.num_assistant_tokens + ) if return_dict_in_generate: if self.config.is_encoder_decoder: return GenerateEncoderDecoderOutput( diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 18e7eb481fdb..b4e1a218a928 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -3490,3 +3490,49 @@ def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, encoder_outputs=encoder_outputs, ) self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist()) + + def test_assisted_decoding_num_assistant_tokens_heuristic_schedule(self): + # This test ensures that the assisted generation num_assistant_tokens 'heuristic' schedule works properly. 
+ + prompt = "Alice and Bob" + checkpoint = "EleutherAI/pythia-160m-deduped" + tokenizer = AutoTokenizer.from_pretrained(checkpoint) + inputs = tokenizer(prompt, return_tensors="pt") + + model = AutoModelForCausalLM.from_pretrained(checkpoint) + + assistant_model = model + assistant_model.generation_config.num_assistant_tokens = 5 + assistant_model.generation_config.num_assistant_tokens_schedule = "heuristic" + generation_kwargs = { + "eos_token_id": -1, + "max_new_tokens": 5, + "do_sample": False, + "assistant_model": assistant_model, + } + model.generate(**inputs, **generation_kwargs) + # update_candidate_strategy is called only once and therefore, assistant_model.generation_config.num_assistant_tokens should be either 4 or 7 + self.assertTrue(assistant_model.generation_config.num_assistant_tokens in (4, 7)) + + def test_assisted_decoding_num_assistant_tokens_heuristic_transient_schedule(self): + # This test ensures that the assisted generation num_assistant_tokens 'heuristic_transient' schedule works properly. + + prompt = "Alice and Bob" + checkpoint = "EleutherAI/pythia-160m-deduped" + tokenizer = AutoTokenizer.from_pretrained(checkpoint) + inputs = tokenizer(prompt, return_tensors="pt") + + model = AutoModelForCausalLM.from_pretrained(checkpoint) + + assistant_model = model + assistant_model.generation_config.num_assistant_tokens = 5 + assistant_model.generation_config.num_assistant_tokens_schedule = "heuristic_transient" + generation_kwargs = { + "eos_token_id": -1, + "max_new_tokens": 5, + "do_sample": False, + "assistant_model": assistant_model, + } + model.generate(**inputs, **generation_kwargs) + # update_candidate_strategy is called once but assistant_model.generation_config.num_assistant_tokens should stay 5 + self.assertEqual(assistant_model.generation_config.num_assistant_tokens, 5)