diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py new file mode 100644 index 000000000000..7cceac3364af --- /dev/null +++ b/src/transformers/generation/candidate_generator.py @@ -0,0 +1,330 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import warnings +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +import torch + + +if TYPE_CHECKING: + from ..modeling_utils import PreTrainedModel + from .logits_process import LogitsProcessorList + + +class CandidateGenerator: + """Abstract base class for all candidate generators that can be applied during assisted generation.""" + + def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor: + """ + Fetches the candidates to be tried for the current input. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) + + Return: + `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be assessed by + the model. + """ + raise NotImplementedError( + f"{self.__class__} is an abstract class. Only classes inheriting this class can call `get_candidates`." + ) + + def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int): + """ + Updates the candidate generation strategy based on the outcomes. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) + scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`): + Prediction scores of a language modeling head. These can be logits for each vocabulary when not using + beam search or log softmax for each vocabulary token when using beam search + num_matches (`int`): + The number of matches between the candidate sequences and the model predictions. + """ + raise NotImplementedError( + f"{self.__class__} is an abstract class. Only classes inheriting this class can call " + "`update_candidate_strategy`." + ) + + +class AssistedCandidateGenerator(CandidateGenerator): + """ + `CandidateGenerator` class to be used for assisted generation. This class generates candidates through the use of + a smaller model. Read the following blog post for more information: https://huggingface.co/blog/assisted-generation + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) + assistant_model (`PreTrainedModel`): + The model to be used for generating candidates. This model should be smaller than the main model. + logits_processor (`LogitsProcessorList`): + An instance of [`LogitsProcessorList`]. 
List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + model_kwargs (`Dict`): + The keyword arguments that will be passed to the main model, and are used as base inputs for the assistant + model as well. + inputs_tensor (`torch.Tensor`, *optional*): + The model input tensor. In encoder-decoder models, this is the encoder input. + eos_token_id (`Union[int, List[int]]`, *optional*): + The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. + """ + + def __init__( + self, + input_ids: torch.LongTensor, + assistant_model: "PreTrainedModel", + logits_processor: "LogitsProcessorList", + model_kwargs: Dict, + inputs_tensor: Optional[torch.Tensor] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, + ): + self.assistant_model = assistant_model + + # Prepare the number of candidate tokens + if hasattr(assistant_model, "num_assistant_tokens"): + warnings.warn( + "Setting `num_assistant_tokens` via `assistant_model.num_assistant_tokens` is deprecated and will be " + "removed in v4.37. Make sure to set `num_assistant_tokens` via the generation_config instead.", + FutureWarning, + ) + self.num_assistant_tokens = assistant_model.num_assistant_tokens + else: + self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens + + # Prepare the kwargs for the assistant model + assistant_kwargs = {} + for key, value in model_kwargs.items(): # deepcopy crashes if we attempt to copy encoder outputs with grads + if key not in ("encoder_outputs", "assistant_encoder_outputs"): + assistant_kwargs[key] = value.detach() if isinstance(value, torch.Tensor) else copy.deepcopy(value) + + if "assistant_encoder_outputs" in model_kwargs: + assistant_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"] + elif assistant_model.config.is_encoder_decoder: + inputs_tensor, model_input_name, assistant_kwargs = assistant_model._prepare_model_inputs( + inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_kwargs + ) + assistant_kwargs = assistant_model._prepare_encoder_decoder_kwargs_for_generation( + inputs_tensor, assistant_kwargs, model_input_name + ) + elif "encoder_outputs" in model_kwargs: + assistant_kwargs["encoder_outputs"] = model_kwargs["encoder_outputs"] + self.assistant_kwargs = assistant_kwargs + + # Prepare assistant model's keys of inputs + if assistant_model.config.is_encoder_decoder: + # both are encoder-decoder + self.input_ids_key = "decoder_input_ids" + self.attention_key = "decoder_attention_mask" + elif "encoder_outputs" in assistant_kwargs: + # special case for encoder-decoder with decoder-only assistant (like DistilWhisper) + self.input_ids_key = "input_ids" + self.attention_key = "attention_mask" + self.assistant_kwargs["attention_mask"] = self.assistant_kwargs.get( + "decoder_attention_mask", + torch.ones((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.long), + ) + else: + # both are decoder-only + self.input_ids_key = "input_ids" + self.attention_key = "attention_mask" + + # Prepare other attributes + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + self.eos_token_id_tensor = ( + torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None + ) + self.logits_processor = logits_processor + + def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor: + """ + Fetches the candidates to be tried for the current input. 
+ + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) + + Return: + `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried. + """ + # 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length + # (which implicitly contains the number of accepted candidates from the previous round) + has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None + if has_past_key_values: + new_cur_len = input_ids.shape[-1] + + new_cache_size = new_cur_len - 1 + self.assistant_kwargs["past_key_values"] = _crop_past_key_values( + self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1 + ) # the assistant does not have the token after the last match, hence the -1 + + self.assistant_kwargs = _prepare_attention_mask( + self.assistant_kwargs, new_cur_len, self.assistant_model.config.is_encoder_decoder + ) + self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len) + + # 2. Forecast next N tokens using the assistant model. This `for` block can be replaced with a `.generate()` + # call if we decide to add `past_key_values` as a possible output of generate, as we need access to the + # assistant cache to secure strong speedups. + candidate_input_ids = input_ids + for _ in range(int(self.num_assistant_tokens)): + # 2.1 prepare assistant model inputs + assistant_inputs = self.assistant_model.prepare_inputs_for_generation( + candidate_input_ids, + **self.assistant_kwargs, + ) + + # 2.2. check if the input ids length is correct + has_past_key_values = assistant_inputs.get("past_key_values", None) is not None + if has_past_key_values and assistant_inputs[self.input_ids_key].shape[-1] not in (1, 2): + raise ValueError("The length of the input ids in assistant inputs should be 1 or 2") + + # 2.3. use the assistant model to obtain the next candidate logits + assistant_model_outputs = self.assistant_model(**assistant_inputs) + + # 2.4. greedily select the next candidate token + if len(self.logits_processor) > 0: + assistant_model_outputs.logits[:, -1, :] = self.logits_processor( + candidate_input_ids, assistant_model_outputs.logits[:, -1, :] + ) + new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1) + candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1) + + # 2.5. update assistant model inputs + if self.assistant_kwargs.get(self.attention_key, None) is not None: + mask = self.assistant_kwargs[self.attention_key] + self.assistant_kwargs[self.attention_key] = torch.cat( + [mask, mask.new_ones((mask.shape[0], 1))], dim=-1 + ) + self.assistant_kwargs["past_key_values"] = assistant_model_outputs.past_key_values + + # 2.6. stop assistant generation on EOS + if self.eos_token_id_tensor is not None: + last_assistant_token_is_eos = new_token.tile(self.eos_token_id_tensor.shape[0], 1) + last_assistant_token_is_eos = ( + ~last_assistant_token_is_eos.ne(self.eos_token_id_tensor.unsqueeze(1)).prod(dim=0).bool() + ) + if last_assistant_token_is_eos: + break + + return candidate_input_ids + + def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int): + """ + Updates the candidate generation strategy based on the outcomes. 
+ + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) + scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`): + Prediction scores of a language modeling head. These can be logits for each vocabulary when not using + beam search or log softmax for each vocabulary token when using beam search + num_matches (`int`): + The number of matches between the candidate sequences and the model predictions. + """ + # Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic, + # probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the + # cost of forecasting incorrect assistant tokens. + if self.assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic": + if num_matches == int(self.num_assistant_tokens): + self.num_assistant_tokens += 2.0 + else: + self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0) + + +def _crop_past_key_values(model, past_key_values, maximum_length): + """Crops the past key values up to a certain maximum length.""" + new_past = [] + if model.config.is_encoder_decoder: + for idx in range(len(past_key_values)): + new_past.append( + ( + past_key_values[idx][0][:, :, :maximum_length, :], + past_key_values[idx][1][:, :, :maximum_length, :], + past_key_values[idx][2], + past_key_values[idx][3], + ) + ) + past_key_values = tuple(new_past) + # bloom is special + elif "bloom" in model.__class__.__name__.lower() or ( + model.config.architectures is not None and "bloom" in model.config.architectures[0].lower() + ): + for idx in range(len(past_key_values)): + new_past.append( + ( + past_key_values[idx][0][:, :, :maximum_length], + past_key_values[idx][1][:, :maximum_length, :], + ) + ) + past_key_values = tuple(new_past) + # gptbigcode is too + elif "gptbigcode" in model.__class__.__name__.lower() or ( + model.config.architectures is not None and "gptbigcode" in model.config.architectures[0].lower() + ): + if model.config.multi_query: + for idx in range(len(past_key_values)): + past_key_values[idx] = past_key_values[idx][:, :maximum_length, :] + else: + for idx in range(len(past_key_values)): + past_key_values[idx] = past_key_values[idx][:, :, :maximum_length, :] + else: + for idx in range(len(past_key_values)): + new_past.append( + ( + past_key_values[idx][0][:, :, :maximum_length, :], + past_key_values[idx][1][:, :, :maximum_length, :], + ) + ) + past_key_values = tuple(new_past) + return past_key_values + + +def _prepare_attention_mask(model_kwargs: Dict[str, Any], new_length: int, is_encoder_decoder: bool) -> Dict[str, Any]: + """Expands or crops the model's mask for decoding purposes, to the defined length""" + + mask_key = "decoder_attention_mask" if is_encoder_decoder else "attention_mask" + if mask_key not in model_kwargs: + return model_kwargs + + mask = model_kwargs[mask_key] + mask_length_diff = new_length - mask.shape[1] + + if mask_length_diff < 0: + model_kwargs[mask_key] = mask[:, :mask_length_diff] + elif mask_length_diff > 0: + model_kwargs[mask_key] = torch.cat([mask, mask.new_ones((mask.shape[0], mask_length_diff))], dim=-1) + return model_kwargs + + +def _prepare_token_type_ids(model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]: + """Expands or crops the model's token_type_ids for decoding purposes, to the defined length""" + if "token_type_ids" not in model_kwargs or 
model_kwargs["token_type_ids"] is None: + return model_kwargs + + token_type_ids = model_kwargs["token_type_ids"] + final_token_type = token_type_ids[:, -1].unsqueeze(-1) + type_length_diff = new_length - token_type_ids.shape[1] + + if type_length_diff < 0: + token_type_ids = token_type_ids[:, :type_length_diff] + elif type_length_diff > 0: + token_type_copies = final_token_type.repeat(1, type_length_diff) + model_kwargs["token_type_ids"] = torch.cat([model_kwargs["token_type_ids"], token_type_copies], dim=-1) + return model_kwargs diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 1d413b3ab443..d7510951b116 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -37,6 +37,13 @@ from ..utils import ExplicitEnum, ModelOutput, is_accelerate_available, logging from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer +from .candidate_generator import ( + AssistedCandidateGenerator, + CandidateGenerator, + _crop_past_key_values, + _prepare_attention_mask, + _prepare_token_type_ids, +) from .configuration_utils import GenerationConfig from .logits_process import ( EncoderNoRepeatNGramLogitsProcessor, @@ -889,6 +896,28 @@ def _reorder_cache(self, past_key_values, beam_idx): f" enable beam search for {self.__class__}" ) + def _get_candidate_generator( + self, + generation_config: GenerationConfig, + input_ids: torch.LongTensor, + inputs_tensor: torch.Tensor, + assistant_model: "PreTrainedModel", + logits_processor: LogitsProcessorList, + model_kwargs: Dict, + ) -> CandidateGenerator: + """ + Returns the candidate generator to be used in `assisted_generation` + """ + candidate_generator = AssistedCandidateGenerator( + input_ids=input_ids, + assistant_model=assistant_model, + logits_processor=logits_processor, + model_kwargs=model_kwargs, + inputs_tensor=inputs_tensor, + eos_token_id=generation_config.eos_token_id, + ) + return candidate_generator + def _get_logits_warper( self, generation_config: GenerationConfig, @@ -1671,36 +1700,20 @@ def generate( if not model_kwargs["use_cache"]: raise ValueError("assisted generate requires `use_cache=True`") - assistant_accepts_encoder_outputs = "encoder_outputs" in set( - inspect.signature(assistant_model.forward).parameters.keys() + # 11. Get the candidate generator, given the parameterization + candidate_generator = self._get_candidate_generator( + generation_config=generation_config, + input_ids=input_ids, + inputs_tensor=inputs_tensor, + assistant_model=assistant_model, + logits_processor=logits_processor, + model_kwargs=model_kwargs, ) - # 11. 
If the assistant model is an encoder-decoder, prepare its encoder outputs - if assistant_model.config.is_encoder_decoder and "assistant_encoder_outputs" not in model_kwargs: - assistant_model_kwargs = copy.deepcopy(model_kwargs) - inputs_tensor, model_input_name, assistant_model_kwargs = assistant_model._prepare_model_inputs( - inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_model_kwargs - ) - assistant_model_kwargs = assistant_model._prepare_encoder_decoder_kwargs_for_generation( - inputs_tensor, assistant_model_kwargs, model_input_name - ) - model_kwargs["assistant_encoder_outputs"] = assistant_model_kwargs["encoder_outputs"] - - if ( - not assistant_model.config.is_encoder_decoder - and assistant_accepts_encoder_outputs - and "encoder_outputs" in model_kwargs - ): - # some assistants might be assymetric (many more enc layers than dec layers) - # encoder-decoder models that share the exact same encoder as the teacher - # in this case the assistant only needs to load the light-weight decoder, - # but still requires `encoder_outputs` to be passed - model_kwargs["assistant_encoder_outputs"] = model_kwargs["encoder_outputs"] - # 12. run assisted generate return self.assisted_decoding( input_ids, - assistant_model=assistant_model, + candidate_generator=candidate_generator, do_sample=generation_config.do_sample, logits_processor=logits_processor, logits_warper=self._get_logits_warper(generation_config) if generation_config.do_sample else None, @@ -4378,7 +4391,8 @@ def constrained_beam_search( def assisted_decoding( self, input_ids: torch.LongTensor, - assistant_model: "PreTrainedModel", + assistant_model: Optional["PreTrainedModel"] = None, + candidate_generator: Optional["CandidateGenerator"] = None, do_sample: bool = False, logits_processor: Optional[LogitsProcessorList] = None, logits_warper: Optional[LogitsProcessorList] = None, @@ -4395,12 +4409,13 @@ def assisted_decoding( ): r""" Generates sequences of token ids for models with a language modeling head using **greedy decoding** or - **sample** (depending on `do_sample`), assisted by a smaller model. Can be used for text-decoder, text-to-text, - speech-to-text, and vision-to-text models. + **sample** (depending on `do_sample`), assisted by candidate sequences. Assisted generation is an example of a + candidate decoding strategy. Can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text + models. In most cases, you do not need to call [`~generation.GenerationMixin.assisted_decoding`] directly. Use generate() instead. For an overview of generation strategies and code examples, check the [following guide](../generation_strategies). @@ -4409,6 +4424,9 @@ Parameters: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. + candidate_generator (`CandidateGenerator`, *optional*): + A derived instance of [`CandidateGenerator`] that defines how candidate sequences are generated. For + more information, the documentation of [`CandidateGenerator`] should be read. Only one of `assistant_model` or `candidate_generator` should be passed as input to this function. assistant_model (`PreTrainedModel`, *optional*): An assistant model that can be used to accelerate generation. The assistant model must have the exact same tokenizer.
The acceleration is achieved when forecasting candidate tokens with the assistent model @@ -4491,15 +4509,23 @@ def assisted_decoding( >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) ["It might be possible to get a better understanding of the nature of the problem, but it's not"] ```""" - # Assistant: initialize assistant-related variables - if hasattr(assistant_model, "num_assistant_tokens"): + # handling deprecated arguments + if (assistant_model is None) == (candidate_generator is None): + raise ValueError("One (and only one) of `assistant_model` and `candidate_generator` should be defined.") + + if assistant_model is not None: + candidate_generator = AssistedCandidateGenerator( + input_ids=input_ids, + assistant_model=assistant_model, + logits_processor=logits_processor, + model_kwargs=model_kwargs, + eos_token_id=eos_token_id, + ) warnings.warn( - "Setting `num_assistant_tokens` via `assistant_model.num_assistant_tokens` is deprecated and will be removed in v.37. Make sure to set `num_assistant_tokens` via the generation_config instead.", + "Passing `assistant_model` to `assisted_decoding` is deprecated and will be removed in v4.38. " + "Pass the `candidate_generator` argument instead.", FutureWarning, ) - num_assistant_tokens = assistant_model.num_assistant_tokens - else: - num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() @@ -4538,27 +4564,6 @@ def assisted_decoding( model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ) - # prepare assistant model's keys of inputs - assistant_kwargs = copy.copy(model_kwargs) - if assistant_model.config.is_encoder_decoder: - # both are encoder-decoder - input_ids_key = "decoder_input_ids" - attention_key = "decoder_attention_mask" - assistant_kwargs["encoder_outputs"] = assistant_kwargs.pop("assistant_encoder_outputs") - elif "assistant_encoder_outputs" in assistant_kwargs: - # special case for encoder-decoder with decoder-only assistant (like DistilWhisper) - input_ids_key = "input_ids" - attention_key = "attention_mask" - assistant_kwargs["attention_mask"] = assistant_kwargs.get( - "decoder_attention_mask", - torch.ones((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.long), - ) - assistant_kwargs["encoder_outputs"] = assistant_kwargs.pop("assistant_encoder_outputs") - else: - # both are decoder-only - input_ids_key = "input_ids" - attention_key = "attention_mask" - # keep track of which sequences are already finished unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) @@ -4577,54 +4582,18 @@ def assisted_decoding( if this_peer_finished_flag.item() == 0.0: break - # Assistant: main logic start cur_len = input_ids.shape[-1] - # 1. Forecast next N tokens using the assistant model. This `for` block can be replaced with a - # `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we - # need access to the assistant cache to secure strong speedups. - candidate_input_ids = input_ids - for _ in range(int(num_assistant_tokens)): - # 1.1 prepare assistant model inputs - assistant_inputs = assistant_model.prepare_inputs_for_generation( - candidate_input_ids, - **assistant_kwargs, - ) - - # 1.2. 
check if the input ids length is correct - has_past_key_values = assistant_inputs.get("past_key_values", None) is not None - if has_past_key_values and assistant_inputs[input_ids_key].shape[-1] not in (1, 2): - raise ValueError("The length of the input ids in assistant inputs should be 1 or 2") - - # 1.3. use the assistant model to obtain the next candidate logits - assistant_model_outputs = assistant_model(**assistant_inputs) - - # 1.4. greedily select the next candidate token - if len(logits_processor) > 0: - assistant_model_outputs.logits[:, -1, :] = logits_processor( - candidate_input_ids, assistant_model_outputs.logits[:, -1, :] - ) - new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1) - candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1) - - # 1.5. update assistant model inputs - if assistant_kwargs.get(attention_key, None) is not None: - mask = assistant_kwargs[attention_key] - assistant_kwargs[attention_key] = torch.cat([mask, mask.new_ones((mask.shape[0], 1))], dim=-1) - assistant_kwargs["past_key_values"] = assistant_model_outputs.past_key_values - - # 1.6. stop assistant generation on EOS - if eos_token_id_tensor is not None: - last_assistant_token_is_eos = new_token.tile(eos_token_id_tensor.shape[0], 1) - last_assistant_token_is_eos = ( - ~last_assistant_token_is_eos.ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0).bool() - ) - if last_assistant_token_is_eos: - break - else: - last_assistant_token_is_eos = False - + # 1. Fetch candidate sequences from a `CandidateGenerator` + candidate_input_ids = candidate_generator.get_candidates(input_ids) candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1] + last_assistant_token_is_eos = ( + ~candidate_input_ids[:, -1] + .tile(eos_token_id_tensor.shape[0], 1) + .ne(eos_token_id_tensor.unsqueeze(1)) + .prod(dim=0) + .bool() + ) # 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain # `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct, @@ -4687,20 +4656,10 @@ def assisted_decoding( # 5.3. Discard past key values relative to unused assistant tokens new_cache_size = new_cur_len - 1 outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size) - assistant_kwargs["past_key_values"] = _crop_past_key_values( - assistant_model, assistant_kwargs["past_key_values"], new_cache_size - 1 - ) # the assistant does not have the token after the last match, hence the -1 - - # 6. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic, - # probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the - # cost of forecasting incorrect assistant tokens. - if assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic": - if n_matches == int(num_assistant_tokens): - num_assistant_tokens += 2.0 - else: - num_assistant_tokens = max(1.0, num_assistant_tokens - 1.0) - # Assistant: main logic end + # 6. 
Update the candidate generation strategy if needed + candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches) + if synced_gpus and this_peer_finished: continue # don't waste resources running the code we don't need @@ -4749,12 +4708,6 @@ def assisted_decoding( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) - # Update assistant_kwargs for the assistant's next round of generations - assistant_kwargs = _prepare_attention_mask( - assistant_kwargs, new_cur_len, assistant_model.config.is_encoder_decoder - ) - assistant_kwargs = _prepare_token_type_ids(assistant_kwargs, new_cur_len) - # if eos_token was found in one sentence, set sentence to finished if eos_token_id_tensor is not None: unfinished_sequences = unfinished_sequences.mul( @@ -4802,54 +4755,6 @@ def assisted_decoding( return input_ids -def _crop_past_key_values(model, past_key_values, maximum_length): - """Crops the past key values up to a certain maximum length.""" - new_past = [] - if model.config.is_encoder_decoder: - for idx in range(len(past_key_values)): - new_past.append( - ( - past_key_values[idx][0][:, :, :maximum_length, :], - past_key_values[idx][1][:, :, :maximum_length, :], - past_key_values[idx][2], - past_key_values[idx][3], - ) - ) - past_key_values = tuple(new_past) - # bloom is special - elif "bloom" in model.__class__.__name__.lower() or ( - model.config.architectures is not None and "bloom" in model.config.architectures[0].lower() - ): - for idx in range(len(past_key_values)): - new_past.append( - ( - past_key_values[idx][0][:, :, :maximum_length], - past_key_values[idx][1][:, :maximum_length, :], - ) - ) - past_key_values = tuple(new_past) - # gptbigcode is too - elif "gptbigcode" in model.__class__.__name__.lower() or ( - model.config.architectures is not None and "gptbigcode" in model.config.architectures[0].lower() - ): - if model.config.multi_query: - for idx in range(len(past_key_values)): - past_key_values[idx] = past_key_values[idx][:, :maximum_length, :] - else: - for idx in range(len(past_key_values)): - past_key_values[idx] = past_key_values[idx][:, :, :maximum_length, :] - else: - for idx in range(len(past_key_values)): - new_past.append( - ( - past_key_values[idx][0][:, :, :maximum_length, :], - past_key_values[idx][1][:, :, :maximum_length, :], - ) - ) - past_key_values = tuple(new_past) - return past_key_values - - def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_attention=False): """ Given the (decoder/cross attentions)/(decoder hidden states) for multiple generated tokens, splits it into a tuple @@ -4932,37 +4837,3 @@ def _ranking_fast( contrastive_score = torch.stack(torch.split(contrastive_score, beam_width)) # [B, K] _, selected_idx = contrastive_score.max(dim=-1) # [B] return selected_idx - - -def _prepare_attention_mask(model_kwargs: Dict[str, Any], new_length: int, is_encoder_decoder: bool) -> Dict[str, Any]: - """Expands or crops the model's mask for decoding purposes, to the defined length""" - - mask_key = "decoder_attention_mask" if is_encoder_decoder else "attention_mask" - if mask_key not in model_kwargs: - return model_kwargs - - mask = model_kwargs[mask_key] - mask_length_diff = new_length - mask.shape[1] - - if mask_length_diff < 0: - model_kwargs[mask_key] = mask[:, :mask_length_diff] - elif mask_length_diff > 0: - model_kwargs[mask_key] = torch.cat([mask, mask.new_ones((mask.shape[0], mask_length_diff))], dim=-1) - return model_kwargs - - -def _prepare_token_type_ids(model_kwargs: Dict[str, Any], 
new_length: int) -> Dict[str, Any]: - """Expands or crops the model's token_type_ids for decoding purposes, to the defined length""" - if "token_type_ids" not in model_kwargs or model_kwargs["token_type_ids"] is None: - return model_kwargs - - token_type_ids = model_kwargs["token_type_ids"] - final_token_type = token_type_ids[:, -1].unsqueeze(-1) - type_length_diff = new_length - token_type_ids.shape[1] - - if type_length_diff < 0: - token_type_ids = token_type_ids[:, :type_length_diff] - elif type_length_diff > 0: - token_type_copies = final_token_type.repeat(1, type_length_diff) - model_kwargs["token_type_ids"] = torch.cat([model_kwargs["token_type_ids"], token_type_copies], dim=-1) - return model_kwargs
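
Usage note on the new abstraction: any object implementing `get_candidates` and `update_candidate_strategy` can be passed to `assisted_decoding` via the new `candidate_generator` argument. Below is a minimal, hypothetical sketch (not part of this patch) of a custom generator that proposes tokens by copying an n-gram from the prompt; it assumes the patch is installed so that `transformers.generation.candidate_generator.CandidateGenerator` is importable, and the class name is invented for illustration.

```python
# Hypothetical CandidateGenerator that proposes the next tokens by copying the
# n-gram that follows the most recent occurrence of the last token elsewhere in
# the prompt (a crude prompt-lookup scheme). It only illustrates the interface
# contract of this PR: `get_candidates` returns `input_ids` extended with the
# candidate tokens, and `update_candidate_strategy` may adapt the proposal
# length based on `num_matches`.
import torch

from transformers.generation.candidate_generator import CandidateGenerator


class NgramCopyCandidateGenerator(CandidateGenerator):
    def __init__(self, num_candidate_tokens: int = 5):
        self.num_candidate_tokens = num_candidate_tokens

    def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor:
        last_token = input_ids[0, -1]
        matches = (input_ids[0, :-1] == last_token).nonzero(as_tuple=True)[0]
        if len(matches) == 0:
            # no candidates found -> return the input unchanged (candidate_length == 0),
            # so the caller effectively falls back to a regular decoding step
            return input_ids
        start = matches[-1].item() + 1
        end = min(start + self.num_candidate_tokens, input_ids.shape[1])
        candidates = input_ids[:, start:end]
        return torch.cat((input_ids, candidates), dim=-1)

    def update_candidate_strategy(self, input_ids, scores, num_matches):
        # mimic the heuristic used by AssistedCandidateGenerator:
        # grow on a full match, shrink otherwise
        if num_matches == self.num_candidate_tokens:
            self.num_candidate_tokens += 2
        else:
            self.num_candidate_tokens = max(1, self.num_candidate_tokens - 1)


# quick check with dummy token ids
gen = NgramCopyCandidateGenerator(num_candidate_tokens=3)
prompt = torch.tensor([[5, 8, 9, 2, 5]])
print(gen.get_candidates(prompt))  # tensor([[5, 8, 9, 2, 5, 8, 9, 2]])
```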
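The `_crop_past_key_values` helper, now shared between `candidate_generator.py` and `utils.py`, trims every cached key/value tensor along the sequence axis. A quick sanity sketch with a fake decoder-only cache follows; `fake_model` is a stand-in that only carries the attributes the helper actually inspects, and the import path assumes this patch is merged.

```python
# _crop_past_key_values crops each (key, value) pair to `maximum_length` along
# the sequence dimension (layout here: batch, heads, seq, head_dim).
import torch
from types import SimpleNamespace

from transformers.generation.candidate_generator import _crop_past_key_values

fake_model = SimpleNamespace(config=SimpleNamespace(is_encoder_decoder=False, architectures=None))
past = tuple(
    (torch.zeros(1, 2, 10, 4), torch.zeros(1, 2, 10, 4))  # one layer: key, value
    for _ in range(3)
)
cropped = _crop_past_key_values(fake_model, past, maximum_length=6)
print(cropped[0][0].shape)  # torch.Size([1, 2, 6, 4])
```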
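Similarly, `_prepare_attention_mask` and `_prepare_token_type_ids` crop or right-pad the assistant's masks to the new sequence length between rounds: the attention mask is padded with ones, and token type ids repeat the final type. A small check of the expected behaviour, again assuming the new import path:

```python
import torch

from transformers.generation.candidate_generator import (
    _prepare_attention_mask,
    _prepare_token_type_ids,
)

kwargs = {
    "attention_mask": torch.ones((1, 5), dtype=torch.long),
    "token_type_ids": torch.tensor([[0, 0, 0, 1, 1]]),
}

# crop the attention mask down to the target length
shorter = _prepare_attention_mask(dict(kwargs), new_length=3, is_encoder_decoder=False)
print(shorter["attention_mask"].shape)  # torch.Size([1, 3])

# expand token_type_ids by repeating the final token type
longer = _prepare_token_type_ids(dict(kwargs), new_length=7)
print(longer["token_type_ids"])  # tensor([[0, 0, 0, 1, 1, 1, 1]])
```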