From b5289e2d5f7845cab55564339d91591faac9683b Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 28 Nov 2023 18:28:25 +0000 Subject: [PATCH 1/6] MVP --- src/transformers/generation/candidates.py | 319 ++++++++++++++++++++++ src/transformers/generation/utils.py | 273 +++++------------- 2 files changed, 390 insertions(+), 202 deletions(-) create mode 100644 src/transformers/generation/candidates.py diff --git a/src/transformers/generation/candidates.py b/src/transformers/generation/candidates.py new file mode 100644 index 000000000000..9cccc81fed06 --- /dev/null +++ b/src/transformers/generation/candidates.py @@ -0,0 +1,319 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import inspect +import warnings +from typing import TYPE_CHECKING, Any, Dict, Optional, Union, List + +import torch + +if TYPE_CHECKING: + from ..modeling_utils import PreTrainedModel + from .logits_process import LogitsProcessorList + + +class CandidateGenerator: + """Abstract base class for all candidate generators that can be applied during assisted generation.""" + + def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor: + """ + Fetches the candidates to be tried for the current input. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) + + Return: + `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be assessed by + the model. + """ + raise NotImplementedError( + f"{self.__class__} is an abstract class. Only classes inheriting this class can call `get_candidates`." + ) + + def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int): + """ + Updates the candidate generation strategy based on the outcomes. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) + scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`): + Prediction scores of a language modeling head. These can be logits for each vocabulary when not using + beam search or log softmax for each vocabulary token when using beam search + num_matches (`int`): + The number of matches between the candidate sequences and the model predictions. + """ + raise NotImplementedError( + f"{self.__class__} is an abstract class. Only classes inheriting this class can call " + "`update_candidate_strategy`." + ) + + +class AssistedCandidateGenerator(CandidateGenerator): + """ + `CandidateGenerator` class to be used for assisted generation. This class generates candidates through the use of + a smaller model. Read the following blog post for more information: https://huggingface.co/blog/assisted-generation + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) + assistant_model (`PreTrainedModel`): + The model to be used for generating candidates. This model should be smaller than the main model. + logits_processor (`LogitsProcessorList`): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + model_kwargs (`Dict`): + The keyword arguments that will be passed to the main model, and are used as base inputs for the assistant + model as well. + inputs_tensor (`torch.Tensor`, *optional*): + The model input tensor. In encoder-decoder models, this is the encoder input. + eos_token_id (`Union[int, List[int]]`, *optional*): + The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. + """ + + def __init__( + self, + input_ids: torch.LongTensor, + assistant_model: "PreTrainedModel", + logits_processor: "LogitsProcessorList", + model_kwargs: Dict, + inputs_tensor: Optional[torch.Tensor] = None, + eos_token_id: Optional[Union[int, List[int]]] = None + ): + + self.assistant_model = assistant_model + + # Prepare the number of candidate tokens + if hasattr(assistant_model, "num_assistant_tokens"): + warnings.warn( + "Setting `num_assistant_tokens` via `assistant_model.num_assistant_tokens` is deprecated and will be " + "removed in v4.37. Make sure to set `num_assistant_tokens` via the generation_config instead.", + FutureWarning, + ) + self.num_assistant_tokens = assistant_model.num_assistant_tokens + else: + self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens + + # Prepare the kwargs for the assistant model + assistant_kwargs = copy.deepcopy(model_kwargs) + if assistant_model.config.is_encoder_decoder and "assistant_encoder_outputs" not in model_kwargs: + inputs_tensor, model_input_name, assistant_kwargs = assistant_model._prepare_model_inputs( + inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_kwargs + ) + assistant_kwargs = assistant_model._prepare_encoder_decoder_kwargs_for_generation( + inputs_tensor, assistant_kwargs, model_input_name + ) + self.assistant_kwargs = assistant_kwargs + + # Prepare assistant model's keys of inputs + if assistant_model.config.is_encoder_decoder: + # both are encoder-decoder + self.input_ids_key = "decoder_input_ids" + self.attention_key = "decoder_attention_mask" + elif "encoder_outputs" in assistant_kwargs: + # special case for encoder-decoder with decoder-only assistant (like DistilWhisper) + self.input_ids_key = "input_ids" + self.attention_key = "attention_mask" + self.assistant_kwargs["attention_mask"] = self.assistant_kwargs.get( + "decoder_attention_mask", + torch.ones((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.long), + ) + else: + # both are decoder-only + self.input_ids_key = "input_ids" + self.attention_key = "attention_mask" + + # Prepare other attributes + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + self.eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None + self.logits_processor = logits_processor + + def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor: + """ + Fetches the candidates to be tried for the current input. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) + + Return: + `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried. + """ + # 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length + # (which implicitly contains the number of accepted candidates from the previous round) + has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None + if has_past_key_values: + new_cur_len = input_ids.shape[-1] + + new_cache_size = new_cur_len - 1 + self.assistant_kwargs["past_key_values"] = _crop_past_key_values( + self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1 + ) # the assistant does not have the token after the last match, hence the -1 + + self.assistant_kwargs = _prepare_attention_mask( + self.assistant_kwargs, new_cur_len, self.assistant_model.config.is_encoder_decoder + ) + self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len) + + # 2. Forecast next N tokens using the assistant model. This `for` block can be replaced with a `.generate()` + # call if we decide to add `past_key_values` as a possible output of generate, as we need access to the + # assistant cache to secure strong speedups. + candidate_input_ids = input_ids + for _ in range(int(self.num_assistant_tokens)): + # 2.1 prepare assistant model inputs + assistant_inputs = self.assistant_model.prepare_inputs_for_generation( + candidate_input_ids, + **self.assistant_kwargs, + ) + + # 2.2. check if the input ids length is correct + has_past_key_values = assistant_inputs.get("past_key_values", None) is not None + if has_past_key_values and assistant_inputs[self.input_ids_key].shape[-1] not in (1, 2): + raise ValueError("The length of the input ids in assistant inputs should be 1 or 2") + + # 2.3. use the assistant model to obtain the next candidate logits + assistant_model_outputs = self.assistant_model(**assistant_inputs) + + # 2.4. greedily select the next candidate token + if len(self.logits_processor) > 0: + assistant_model_outputs.logits[:, -1, :] = self.logits_processor( + candidate_input_ids, assistant_model_outputs.logits[:, -1, :] + ) + new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1) + candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1) + + # 2.5. update assistant model inputs + if self.assistant_kwargs.get(self.attention_key, None) is not None: + mask = self.assistant_kwargs[self.attention_key] + self.assistant_kwargs[self.attention_key] = torch.cat([mask, mask.new_ones((mask.shape[0], 1))], dim=-1) + self.assistant_kwargs["past_key_values"] = assistant_model_outputs.past_key_values + + # 2.6. stop assistant generation on EOS + if self.eos_token_id_tensor is not None: + last_assistant_token_is_eos = new_token.tile(self.eos_token_id_tensor.shape[0], 1) + last_assistant_token_is_eos = ( + ~last_assistant_token_is_eos.ne(self.eos_token_id_tensor.unsqueeze(1)).prod(dim=0).bool() + ) + if last_assistant_token_is_eos: + break + + return candidate_input_ids + + def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int): + """ + Updates the candidate generation strategy based on the outcomes. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) + scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`): + Prediction scores of a language modeling head. These can be logits for each vocabulary when not using + beam search or log softmax for each vocabulary token when using beam search + num_matches (`int`): + The number of matches between the candidate sequences and the model predictions. + """ + # Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic, + # probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the + # cost of forecasting incorrect assistant tokens. + if self.assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic": + if num_matches == int(self.num_assistant_tokens): + self.num_assistant_tokens += 2.0 + else: + self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0) + + +def _crop_past_key_values(model, past_key_values, maximum_length): + """Crops the past key values up to a certain maximum length.""" + new_past = [] + if model.config.is_encoder_decoder: + for idx in range(len(past_key_values)): + new_past.append( + ( + past_key_values[idx][0][:, :, :maximum_length, :], + past_key_values[idx][1][:, :, :maximum_length, :], + past_key_values[idx][2], + past_key_values[idx][3], + ) + ) + past_key_values = tuple(new_past) + # bloom is special + elif "bloom" in model.__class__.__name__.lower() or ( + model.config.architectures is not None and "bloom" in model.config.architectures[0].lower() + ): + for idx in range(len(past_key_values)): + new_past.append( + ( + past_key_values[idx][0][:, :, :maximum_length], + past_key_values[idx][1][:, :maximum_length, :], + ) + ) + past_key_values = tuple(new_past) + # gptbigcode is too + elif "gptbigcode" in model.__class__.__name__.lower() or ( + model.config.architectures is not None and "gptbigcode" in model.config.architectures[0].lower() + ): + if model.config.multi_query: + for idx in range(len(past_key_values)): + past_key_values[idx] = past_key_values[idx][:, :maximum_length, :] + else: + for idx in range(len(past_key_values)): + past_key_values[idx] = past_key_values[idx][:, :, :maximum_length, :] + else: + for idx in range(len(past_key_values)): + new_past.append( + ( + past_key_values[idx][0][:, :, :maximum_length, :], + past_key_values[idx][1][:, :, :maximum_length, :], + ) + ) + past_key_values = tuple(new_past) + return past_key_values + + +def _prepare_attention_mask(model_kwargs: Dict[str, Any], new_length: int, is_encoder_decoder: bool) -> Dict[str, Any]: + """Expands or crops the model's mask for decoding purposes, to the defined length""" + + mask_key = "decoder_attention_mask" if is_encoder_decoder else "attention_mask" + if mask_key not in model_kwargs: + return model_kwargs + + mask = model_kwargs[mask_key] + mask_length_diff = new_length - mask.shape[1] + + if mask_length_diff < 0: + model_kwargs[mask_key] = mask[:, :mask_length_diff] + elif mask_length_diff > 0: + model_kwargs[mask_key] = torch.cat([mask, mask.new_ones((mask.shape[0], mask_length_diff))], dim=-1) + return model_kwargs + + +def _prepare_token_type_ids(model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]: + """Expands or crops the model's token_type_ids for decoding purposes, to the defined length""" + if "token_type_ids" not in model_kwargs or model_kwargs["token_type_ids"] is None: + return model_kwargs + + token_type_ids = model_kwargs["token_type_ids"] + final_token_type = token_type_ids[:, -1].unsqueeze(-1) + type_length_diff = new_length - token_type_ids.shape[1] + + if type_length_diff < 0: + token_type_ids = token_type_ids[:, :type_length_diff] + elif type_length_diff > 0: + token_type_copies = final_token_type.repeat(1, type_length_diff) + model_kwargs["token_type_ids"] = torch.cat([model_kwargs["token_type_ids"], token_type_copies], dim=-1) + return model_kwargs diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 1d413b3ab443..02e0fbc7bc99 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -37,6 +37,13 @@ from ..utils import ExplicitEnum, ModelOutput, is_accelerate_available, logging from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer +from .candidates import ( + AssistedCandidateGenerator, + CandidateGenerator, + _crop_past_key_values, + _prepare_attention_mask, + _prepare_token_type_ids, +) from .configuration_utils import GenerationConfig from .logits_process import ( EncoderNoRepeatNGramLogitsProcessor, @@ -889,6 +896,29 @@ def _reorder_cache(self, past_key_values, beam_idx): f" enable beam search for {self.__class__}" ) + def _get_candidate_generator( + self, + generation_config: GenerationConfig, + input_ids: torch.LongTensor, + inputs_tensor: torch.Tensor, + assistant_model: "PreTrainedModel", + logits_processor: LogitsProcessorList, + model_kwargs: Dict, + eos_token_id: Union[int, List[int]], + ) -> CandidateGenerator: + """ + Returns the candidate generator to be used in `assisted_generation` + """ + candidate_generator = AssistedCandidateGenerator( + input_ids=input_ids, + assistant_model=assistant_model, + logits_processor=logits_processor, + model_kwargs=model_kwargs, + inputs_tensor=inputs_tensor, + eos_token_id=eos_token_id, + ) + return candidate_generator + def _get_logits_warper( self, generation_config: GenerationConfig, @@ -1671,36 +1701,21 @@ def generate( if not model_kwargs["use_cache"]: raise ValueError("assisted generate requires `use_cache=True`") - assistant_accepts_encoder_outputs = "encoder_outputs" in set( - inspect.signature(assistant_model.forward).parameters.keys() + # 11. Get the candidate generator, given the parameterization + candidate_generator = self._get_candidate_generator( + generation_config=generation_config, + input_ids=input_ids, + inputs_tensor=inputs_tensor, + assistant_model=assistant_model, + logits_processor=logits_processor, + model_kwargs=model_kwargs, + eos_token_id=generation_config.eos_token_id, ) - # 11. If the assistant model is an encoder-decoder, prepare its encoder outputs - if assistant_model.config.is_encoder_decoder and "assistant_encoder_outputs" not in model_kwargs: - assistant_model_kwargs = copy.deepcopy(model_kwargs) - inputs_tensor, model_input_name, assistant_model_kwargs = assistant_model._prepare_model_inputs( - inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_model_kwargs - ) - assistant_model_kwargs = assistant_model._prepare_encoder_decoder_kwargs_for_generation( - inputs_tensor, assistant_model_kwargs, model_input_name - ) - model_kwargs["assistant_encoder_outputs"] = assistant_model_kwargs["encoder_outputs"] - - if ( - not assistant_model.config.is_encoder_decoder - and assistant_accepts_encoder_outputs - and "encoder_outputs" in model_kwargs - ): - # some assistants might be assymetric (many more enc layers than dec layers) - # encoder-decoder models that share the exact same encoder as the teacher - # in this case the assistant only needs to load the light-weight decoder, - # but still requires `encoder_outputs` to be passed - model_kwargs["assistant_encoder_outputs"] = model_kwargs["encoder_outputs"] - # 12. run assisted generate return self.assisted_decoding( input_ids, - assistant_model=assistant_model, + candidate_generator=candidate_generator, do_sample=generation_config.do_sample, logits_processor=logits_processor, logits_warper=self._get_logits_warper(generation_config) if generation_config.do_sample else None, @@ -4378,7 +4393,8 @@ def constrained_beam_search( def assisted_decoding( self, input_ids: torch.LongTensor, - assistant_model: "PreTrainedModel", + assistant_model: Optional["PreTrainedModel"] = None, + candidate_generator: Optional["CandidateGenerator"] = None, do_sample: bool = False, logits_processor: Optional[LogitsProcessorList] = None, logits_warper: Optional[LogitsProcessorList] = None, @@ -4395,12 +4411,13 @@ def assisted_decoding( ): r""" Generates sequences of token ids for models with a language modeling head using **greedy decoding** or - **sample** (depending on `do_sample`), assisted by a smaller model. Can be used for text-decoder, text-to-text, - speech-to-text, and vision-to-text models. + **sample** (depending on `do_sample`), assisted by candidate sequences. Assisted generation is an example of a + candidate decoding strategy. Can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text + models. - In most cases, you do not need to call [`~generation.GenerationMixin.assisted_decoding`] directly. Use + In most cases, you do not need to call [`~generation.GenerationMixin.candidate_decoding`] directly. Use generate() instead. For an overview of generation strategies and code examples, check the [following guide](../generation_strategies). @@ -4409,6 +4426,9 @@ def assisted_decoding( Parameters: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. + candidate_generator (`CandidateGenerator`, *optional*): + A derived instance of [`CandidateGenerator`] that defines how candidate sequences are generated. For + more information, the documentation of [`CandidateGenerator`] should be read. assistant_model (`PreTrainedModel`, *optional*): An assistant model that can be used to accelerate generation. The assistant model must have the exact same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistent model @@ -4491,15 +4511,19 @@ def assisted_decoding( >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) ["It might be possible to get a better understanding of the nature of the problem, but it's not"] ```""" - # Assistant: initialize assistant-related variables - if hasattr(assistant_model, "num_assistant_tokens"): + # handling deprecated arguments + if (assistant_model is None) == (candidate_generator is None): + raise ValueError("One (and only one) of `assistant_model` and `candidate_generator` should be defined.") + + if assistant_model is not None: + candidate_generator = AssistedCandidateGenerator( + input_ids, assistant_model, logits_processor, model_kwargs, eos_token_id + ) warnings.warn( - "Setting `num_assistant_tokens` via `assistant_model.num_assistant_tokens` is deprecated and will be removed in v.37. Make sure to set `num_assistant_tokens` via the generation_config instead.", + "Passing `assistant_model` to `assisted_decoding` is deprecated and will be removed in v4.38. " + "Pass the `candidate_generator` argument instead.", FutureWarning, ) - num_assistant_tokens = assistant_model.num_assistant_tokens - else: - num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() @@ -4538,27 +4562,6 @@ def assisted_decoding( model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ) - # prepare assistant model's keys of inputs - assistant_kwargs = copy.copy(model_kwargs) - if assistant_model.config.is_encoder_decoder: - # both are encoder-decoder - input_ids_key = "decoder_input_ids" - attention_key = "decoder_attention_mask" - assistant_kwargs["encoder_outputs"] = assistant_kwargs.pop("assistant_encoder_outputs") - elif "assistant_encoder_outputs" in assistant_kwargs: - # special case for encoder-decoder with decoder-only assistant (like DistilWhisper) - input_ids_key = "input_ids" - attention_key = "attention_mask" - assistant_kwargs["attention_mask"] = assistant_kwargs.get( - "decoder_attention_mask", - torch.ones((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.long), - ) - assistant_kwargs["encoder_outputs"] = assistant_kwargs.pop("assistant_encoder_outputs") - else: - # both are decoder-only - input_ids_key = "input_ids" - attention_key = "attention_mask" - # keep track of which sequences are already finished unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) @@ -4577,54 +4580,18 @@ def assisted_decoding( if this_peer_finished_flag.item() == 0.0: break - # Assistant: main logic start cur_len = input_ids.shape[-1] - # 1. Forecast next N tokens using the assistant model. This `for` block can be replaced with a - # `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we - # need access to the assistant cache to secure strong speedups. - candidate_input_ids = input_ids - for _ in range(int(num_assistant_tokens)): - # 1.1 prepare assistant model inputs - assistant_inputs = assistant_model.prepare_inputs_for_generation( - candidate_input_ids, - **assistant_kwargs, - ) - - # 1.2. check if the input ids length is correct - has_past_key_values = assistant_inputs.get("past_key_values", None) is not None - if has_past_key_values and assistant_inputs[input_ids_key].shape[-1] not in (1, 2): - raise ValueError("The length of the input ids in assistant inputs should be 1 or 2") - - # 1.3. use the assistant model to obtain the next candidate logits - assistant_model_outputs = assistant_model(**assistant_inputs) - - # 1.4. greedily select the next candidate token - if len(logits_processor) > 0: - assistant_model_outputs.logits[:, -1, :] = logits_processor( - candidate_input_ids, assistant_model_outputs.logits[:, -1, :] - ) - new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1) - candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1) - - # 1.5. update assistant model inputs - if assistant_kwargs.get(attention_key, None) is not None: - mask = assistant_kwargs[attention_key] - assistant_kwargs[attention_key] = torch.cat([mask, mask.new_ones((mask.shape[0], 1))], dim=-1) - assistant_kwargs["past_key_values"] = assistant_model_outputs.past_key_values - - # 1.6. stop assistant generation on EOS - if eos_token_id_tensor is not None: - last_assistant_token_is_eos = new_token.tile(eos_token_id_tensor.shape[0], 1) - last_assistant_token_is_eos = ( - ~last_assistant_token_is_eos.ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0).bool() - ) - if last_assistant_token_is_eos: - break - else: - last_assistant_token_is_eos = False - + # 1. Fetch candidate sequences from a `CandidateGenerator` + candidate_input_ids = candidate_generator.get_candidates(input_ids) candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1] + last_assistant_token_is_eos = ( + ~candidate_input_ids[:, -1] + .tile(eos_token_id_tensor.shape[0], 1) + .ne(eos_token_id_tensor.unsqueeze(1)) + .prod(dim=0) + .bool() + ) # 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain # `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct, @@ -4687,20 +4654,10 @@ def assisted_decoding( # 5.3. Discard past key values relative to unused assistant tokens new_cache_size = new_cur_len - 1 outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size) - assistant_kwargs["past_key_values"] = _crop_past_key_values( - assistant_model, assistant_kwargs["past_key_values"], new_cache_size - 1 - ) # the assistant does not have the token after the last match, hence the -1 - - # 6. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic, - # probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the - # cost of forecasting incorrect assistant tokens. - if assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic": - if n_matches == int(num_assistant_tokens): - num_assistant_tokens += 2.0 - else: - num_assistant_tokens = max(1.0, num_assistant_tokens - 1.0) - # Assistant: main logic end + # 6. Update the candidate generation strategy if needed + candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches) + if synced_gpus and this_peer_finished: continue # don't waste resources running the code we don't need @@ -4749,12 +4706,6 @@ def assisted_decoding( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) - # Update assistant_kwargs for the assistant's next round of generations - assistant_kwargs = _prepare_attention_mask( - assistant_kwargs, new_cur_len, assistant_model.config.is_encoder_decoder - ) - assistant_kwargs = _prepare_token_type_ids(assistant_kwargs, new_cur_len) - # if eos_token was found in one sentence, set sentence to finished if eos_token_id_tensor is not None: unfinished_sequences = unfinished_sequences.mul( @@ -4802,54 +4753,6 @@ def assisted_decoding( return input_ids -def _crop_past_key_values(model, past_key_values, maximum_length): - """Crops the past key values up to a certain maximum length.""" - new_past = [] - if model.config.is_encoder_decoder: - for idx in range(len(past_key_values)): - new_past.append( - ( - past_key_values[idx][0][:, :, :maximum_length, :], - past_key_values[idx][1][:, :, :maximum_length, :], - past_key_values[idx][2], - past_key_values[idx][3], - ) - ) - past_key_values = tuple(new_past) - # bloom is special - elif "bloom" in model.__class__.__name__.lower() or ( - model.config.architectures is not None and "bloom" in model.config.architectures[0].lower() - ): - for idx in range(len(past_key_values)): - new_past.append( - ( - past_key_values[idx][0][:, :, :maximum_length], - past_key_values[idx][1][:, :maximum_length, :], - ) - ) - past_key_values = tuple(new_past) - # gptbigcode is too - elif "gptbigcode" in model.__class__.__name__.lower() or ( - model.config.architectures is not None and "gptbigcode" in model.config.architectures[0].lower() - ): - if model.config.multi_query: - for idx in range(len(past_key_values)): - past_key_values[idx] = past_key_values[idx][:, :maximum_length, :] - else: - for idx in range(len(past_key_values)): - past_key_values[idx] = past_key_values[idx][:, :, :maximum_length, :] - else: - for idx in range(len(past_key_values)): - new_past.append( - ( - past_key_values[idx][0][:, :, :maximum_length, :], - past_key_values[idx][1][:, :, :maximum_length, :], - ) - ) - past_key_values = tuple(new_past) - return past_key_values - - def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_attention=False): """ Given the (decoder/cross attentions)/(decoder hidden states) for multiple generated tokens, splits it into a tuple @@ -4932,37 +4835,3 @@ def _ranking_fast( contrastive_score = torch.stack(torch.split(contrastive_score, beam_width)) # [B, K] _, selected_idx = contrastive_score.max(dim=-1) # [B] return selected_idx - - -def _prepare_attention_mask(model_kwargs: Dict[str, Any], new_length: int, is_encoder_decoder: bool) -> Dict[str, Any]: - """Expands or crops the model's mask for decoding purposes, to the defined length""" - - mask_key = "decoder_attention_mask" if is_encoder_decoder else "attention_mask" - if mask_key not in model_kwargs: - return model_kwargs - - mask = model_kwargs[mask_key] - mask_length_diff = new_length - mask.shape[1] - - if mask_length_diff < 0: - model_kwargs[mask_key] = mask[:, :mask_length_diff] - elif mask_length_diff > 0: - model_kwargs[mask_key] = torch.cat([mask, mask.new_ones((mask.shape[0], mask_length_diff))], dim=-1) - return model_kwargs - - -def _prepare_token_type_ids(model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]: - """Expands or crops the model's token_type_ids for decoding purposes, to the defined length""" - if "token_type_ids" not in model_kwargs or model_kwargs["token_type_ids"] is None: - return model_kwargs - - token_type_ids = model_kwargs["token_type_ids"] - final_token_type = token_type_ids[:, -1].unsqueeze(-1) - type_length_diff = new_length - token_type_ids.shape[1] - - if type_length_diff < 0: - token_type_ids = token_type_ids[:, :type_length_diff] - elif type_length_diff > 0: - token_type_copies = final_token_type.repeat(1, type_length_diff) - model_kwargs["token_type_ids"] = torch.cat([model_kwargs["token_type_ids"], token_type_copies], dim=-1) - return model_kwargs From f676845b527c9f7194a91f25572b164c3d1dc14d Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 29 Nov 2023 09:53:07 +0000 Subject: [PATCH 2/6] fix ci --- src/transformers/generation/candidates.py | 25 ++++++++++++++++------- src/transformers/generation/utils.py | 6 +++++- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/src/transformers/generation/candidates.py b/src/transformers/generation/candidates.py index 9cccc81fed06..573cac3b6a2e 100644 --- a/src/transformers/generation/candidates.py +++ b/src/transformers/generation/candidates.py @@ -14,12 +14,12 @@ # limitations under the License. import copy -import inspect import warnings -from typing import TYPE_CHECKING, Any, Dict, Optional, Union, List +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import torch + if TYPE_CHECKING: from ..modeling_utils import PreTrainedModel from .logits_process import LogitsProcessorList @@ -92,9 +92,8 @@ def __init__( logits_processor: "LogitsProcessorList", model_kwargs: Dict, inputs_tensor: Optional[torch.Tensor] = None, - eos_token_id: Optional[Union[int, List[int]]] = None + eos_token_id: Optional[Union[int, List[int]]] = None, ): - self.assistant_model = assistant_model # Prepare the number of candidate tokens @@ -109,7 +108,15 @@ def __init__( self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens # Prepare the kwargs for the assistant model - assistant_kwargs = copy.deepcopy(model_kwargs) + assistant_kwargs = {} + for key, value in model_kwargs.items(): # deepcopy crashes if we attempt to copy encoder outputs with grads + if key != "encoder_outputs": + assistant_kwargs[key] = ( + value.clone().detach() if isinstance(value, torch.Tensor) else copy.deepcopy(value) + ) + if "encoder_outputs" in model_kwargs: + assistant_kwargs["encoder_outputs"] = model_kwargs["encoder_outputs"] + if assistant_model.config.is_encoder_decoder and "assistant_encoder_outputs" not in model_kwargs: inputs_tensor, model_input_name, assistant_kwargs = assistant_model._prepare_model_inputs( inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_kwargs @@ -140,7 +147,9 @@ def __init__( # Prepare other attributes if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] - self.eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None + self.eos_token_id_tensor = ( + torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None + ) self.logits_processor = logits_processor def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor: @@ -200,7 +209,9 @@ def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor: # 2.5. update assistant model inputs if self.assistant_kwargs.get(self.attention_key, None) is not None: mask = self.assistant_kwargs[self.attention_key] - self.assistant_kwargs[self.attention_key] = torch.cat([mask, mask.new_ones((mask.shape[0], 1))], dim=-1) + self.assistant_kwargs[self.attention_key] = torch.cat( + [mask, mask.new_ones((mask.shape[0], 1))], dim=-1 + ) self.assistant_kwargs["past_key_values"] = assistant_model_outputs.past_key_values # 2.6. stop assistant generation on EOS diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 02e0fbc7bc99..7768bf89eda6 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -4517,7 +4517,11 @@ def assisted_decoding( if assistant_model is not None: candidate_generator = AssistedCandidateGenerator( - input_ids, assistant_model, logits_processor, model_kwargs, eos_token_id + input_ids=input_ids, + assistant_model=assistant_model, + logits_processor=logits_processor, + model_kwargs=model_kwargs, + eos_token_id=eos_token_id, ) warnings.warn( "Passing `assistant_model` to `assisted_decoding` is deprecated and will be removed in v4.38. " From 44d08443145065de38e3a870a03e3928a27838a4 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 29 Nov 2023 10:16:49 +0000 Subject: [PATCH 3/6] more ci --- src/transformers/generation/candidates.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/transformers/generation/candidates.py b/src/transformers/generation/candidates.py index 573cac3b6a2e..7cceac3364af 100644 --- a/src/transformers/generation/candidates.py +++ b/src/transformers/generation/candidates.py @@ -110,20 +110,20 @@ def __init__( # Prepare the kwargs for the assistant model assistant_kwargs = {} for key, value in model_kwargs.items(): # deepcopy crashes if we attempt to copy encoder outputs with grads - if key != "encoder_outputs": - assistant_kwargs[key] = ( - value.clone().detach() if isinstance(value, torch.Tensor) else copy.deepcopy(value) - ) - if "encoder_outputs" in model_kwargs: - assistant_kwargs["encoder_outputs"] = model_kwargs["encoder_outputs"] + if key not in ("encoder_outputs", "assistant_encoder_outputs"): + assistant_kwargs[key] = value.detach() if isinstance(value, torch.Tensor) else copy.deepcopy(value) - if assistant_model.config.is_encoder_decoder and "assistant_encoder_outputs" not in model_kwargs: + if "assistant_encoder_outputs" in model_kwargs: + assistant_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"] + elif assistant_model.config.is_encoder_decoder: inputs_tensor, model_input_name, assistant_kwargs = assistant_model._prepare_model_inputs( inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_kwargs ) assistant_kwargs = assistant_model._prepare_encoder_decoder_kwargs_for_generation( inputs_tensor, assistant_kwargs, model_input_name ) + elif "encoder_outputs" in model_kwargs: + assistant_kwargs["encoder_outputs"] = model_kwargs["encoder_outputs"] self.assistant_kwargs = assistant_kwargs # Prepare assistant model's keys of inputs From be7266721a76f9e62284163a13101cac3021b502 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 29 Nov 2023 10:35:13 +0000 Subject: [PATCH 4/6] remove redundant kwarg --- src/transformers/generation/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 7768bf89eda6..0f2ac3323ecd 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -904,7 +904,6 @@ def _get_candidate_generator( assistant_model: "PreTrainedModel", logits_processor: LogitsProcessorList, model_kwargs: Dict, - eos_token_id: Union[int, List[int]], ) -> CandidateGenerator: """ Returns the candidate generator to be used in `assisted_generation` @@ -915,7 +914,7 @@ def _get_candidate_generator( logits_processor=logits_processor, model_kwargs=model_kwargs, inputs_tensor=inputs_tensor, - eos_token_id=eos_token_id, + eos_token_id=generation_config.eos_token_id, ) return candidate_generator @@ -1709,7 +1708,6 @@ def generate( assistant_model=assistant_model, logits_processor=logits_processor, model_kwargs=model_kwargs, - eos_token_id=generation_config.eos_token_id, ) # 12. run assisted generate From 2808cdf15294a8b6cf2217b53cd4970bd4801e7f Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 11 Dec 2023 18:39:15 +0000 Subject: [PATCH 5/6] Update src/transformers/generation/utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 0f2ac3323ecd..106cd6e1b329 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -4426,7 +4426,7 @@ def assisted_decoding( The sequence used as a prompt for the generation. candidate_generator (`CandidateGenerator`, *optional*): A derived instance of [`CandidateGenerator`] that defines how candidate sequences are generated. For - more information, the documentation of [`CandidateGenerator`] should be read. + more information, the documentation of [`CandidateGenerator`] should be read. Only one of `assistant_model` or `candidate_generator` should be passed as input to this function. assistant_model (`PreTrainedModel`, *optional*): An assistant model that can be used to accelerate generation. The assistant model must have the exact same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistent model From a57367b0dc8e0db834d9bb7b065e1db3c90dbe3d Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 11 Dec 2023 18:45:25 +0000 Subject: [PATCH 6/6] rename file --- .../generation/{candidates.py => candidate_generator.py} | 0 src/transformers/generation/utils.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename src/transformers/generation/{candidates.py => candidate_generator.py} (100%) diff --git a/src/transformers/generation/candidates.py b/src/transformers/generation/candidate_generator.py similarity index 100% rename from src/transformers/generation/candidates.py rename to src/transformers/generation/candidate_generator.py diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 106cd6e1b329..d7510951b116 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -37,7 +37,7 @@ from ..utils import ExplicitEnum, ModelOutput, is_accelerate_available, logging from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer -from .candidates import ( +from .candidate_generator import ( AssistedCandidateGenerator, CandidateGenerator, _crop_past_key_values,