diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index 65322e236ca0..538a9aaf7fcb 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -1651,54 +1651,23 @@ def prepare_inputs_for_generation( ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model - # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens - # Exception 1: when passing input_embeds, input_ids may be missing entries - # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here - # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. - # (we can't check exception 3 while compiling) - # Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and - # generate the first token for each sequence. Later use the generated Input ids for continuation. - if past_key_values is not None: - if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4 - inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :] - elif ( - inputs_embeds is not None # Exception 1 - or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3 - ): - input_ids = input_ids[:, -cache_position.shape[0] :] - elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) - input_ids = input_ids[:, cache_position] - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - if inputs_embeds is not None and input_ids.shape[1] == 0: - position_ids = position_ids[:, -inputs_embeds.shape[1] :] - else: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + model_inputs = super().prepare_inputs_for_generation( + input_ids, + pixel_values=pixel_values, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + use_cache=use_cache, + **kwargs, + ) - if cache_position[0] == 0: + if cache_position[0] != 0: # If we're in cached decoding stage, pixel values should be `None` because input ids do not contain special image token anymore # Otherwise we need pixel values to be passed to model - model_inputs["pixel_values"] = pixel_values - - model_inputs.update( - { - "position_ids": position_ids, - "cache_position": cache_position, - "past_key_values": past_key_values, - "use_cache": use_cache, - "attention_mask": attention_mask, - } - ) + model_inputs["pixel_values"] = None + return model_inputs diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index 7d31b8d3d323..b3331cf1293a 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -1967,5 +1967,36 @@ def forward( return 
outputs + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + pixel_values=pixel_values, + use_cache=use_cache, + **kwargs, + ) + + if cache_position[0] != 0: + model_inputs["pixel_values"] = None + + return model_inputs + __all__ = ["Emu3ForConditionalGeneration", "Emu3ForCausalLM", "Emu3TextModel", "Emu3PreTrainedModel", "Emu3VQVAE"] diff --git a/src/transformers/models/emu3/modular_emu3.py b/src/transformers/models/emu3/modular_emu3.py index d645a88baf38..cdb4ee5d6fa9 100644 --- a/src/transformers/models/emu3/modular_emu3.py +++ b/src/transformers/models/emu3/modular_emu3.py @@ -1275,6 +1275,37 @@ def forward( return outputs + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + pixel_values=pixel_values, + use_cache=use_cache, + **kwargs, + ) + + if cache_position[0] != 0: + model_inputs["pixel_values"] = None + + return model_inputs + __all__ = [ "Emu3ForConditionalGeneration", diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py index d19c48b76293..3f6d31aab388 100644 --- a/src/transformers/models/fuyu/modeling_fuyu.py +++ b/src/transformers/models/fuyu/modeling_fuyu.py @@ -345,36 +345,20 @@ def prepare_inputs_for_generation( ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model - if past_key_values is not None: - input_ids = input_ids[:, -1:] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -1:] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - if image_patches_indices is not None: - model_inputs["image_patches_indices"] = image_patches_indices - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - "image_patches_indices": image_patches_indices if past_key_values is None else None, - "image_patches": image_patches if past_key_values is None else None, - } + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + 
inputs_embeds=inputs_embeds, + image_patches=image_patches, + image_patches_indices=image_patches_indices, + **kwargs, ) + + if past_key_values is not None: + model_inputs["image_patches_indices"] = None + model_inputs["image_patches"] = None + return model_inputs @staticmethod diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 8c6f1f059bfc..7a476a13103c 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -1667,63 +1667,33 @@ def prepare_inputs_for_generation( ): # Overwritten -- custom processing based on `config.use_resampler` - model_inputs = {} + images_kwargs = {} if image_hidden_states is not None: if self.config.use_resampler: - model_inputs["perceiver_embeddings"] = image_hidden_states + images_kwargs["perceiver_embeddings"] = image_hidden_states else: - model_inputs["image_encoder_embeddings"] = image_hidden_states + images_kwargs["image_encoder_embeddings"] = image_hidden_states else: - model_inputs["pixel_values"] = pixel_values - - # If we have cache: let's slice `input_ids` or `input embeds` through `cache_position`, to keep only the unprocessed tokens - if past_key_values is not None: - if inputs_embeds is not None: - if input_ids.shape[1] == 0: - inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :] - else: - input_ids = input_ids[:, -cache_position.shape[0] :] - elif input_ids.shape[1] != cache_position.shape[0]: - input_ids = input_ids[:, cache_position] - if image_attention_mask is not None: - image_attention_mask = image_attention_mask[:, -input_ids.shape[1] :] + images_kwargs["pixel_values"] = pixel_values + images_kwargs["interpolate_pos_encoding"] = kwargs.pop("interpolate_pos_encoding", False) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - - # If past_key_values are present then slice the postion ids for only only the unprocessed tokens. - if past_key_values: - if inputs_embeds is not None and input_ids.shape[1] == 0: - position_ids = position_ids[:, -inputs_embeds.shape[1] :] - else: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. - position_ids = position_ids.clone(memory_format=torch.contiguous_format) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]: - model_inputs.update({"inputs_embeds": inputs_embeds, "input_ids": None}) - else: - # The clone here is for the same reason as for `position_ids`. 
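Several of the blocks removed above (here in Idefics, and again in the overrides further down) re-implemented the same bookkeeping that the base `prepare_inputs_for_generation` now performs, in particular building `position_ids` from the attention mask. The sketch below documents what those deleted lines did; it is standalone torch code illustrating the computation, not the actual `GenerationMixin` implementation.

```python
import torch

def position_ids_from_attention_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # Cumulative sum over the mask yields 0-based positions for real tokens;
    # padded slots are filled with 1, exactly as in the removed per-model code.
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    return position_ids

mask = torch.tensor([[0, 0, 1, 1, 1],   # left-padded sequence
                     [1, 1, 1, 1, 1]])  # full-length sequence
print(position_ids_from_attention_mask(mask))
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```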
- model_inputs.update( - {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} - ) - - model_inputs.update( - { - "past_key_values": past_key_values, - "use_cache": use_cache, - "cache_position": cache_position, - "position_ids": position_ids, - "attention_mask": attention_mask, - "image_attention_mask": image_attention_mask, - "interpolate_pos_encoding": kwargs.get("interpolate_pos_encoding", False), - } + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + use_cache=use_cache, + image_attention_mask=image_attention_mask, + **images_kwargs, + **kwargs, ) + if image_attention_mask is not None and inputs_embeds is None: + seq_length = model_inputs["input_ids"].shape[1] + model_inputs["image_attention_mask"] = image_attention_mask[:, -seq_length:] + return model_inputs def _update_model_kwargs_for_generation( diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index eb676c295a4f..2c44f998574d 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -1226,6 +1226,10 @@ def forward(self, image_hidden_states, attention_mask): more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. 
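The `cache_position` argument now threaded through the Idefics2 forward passes is also what the rewritten overrides use to tell prefill from cached decoding: it holds the absolute indices of the tokens fed in the current step, so `cache_position[0] == 0` only on the first (prompt-processing) call. A small illustration, assuming a loop that feeds one new token per decode step:

```python
import torch

past_len, prompt_len = 0, 6

# Prefill: the whole prompt is processed at once.
cache_position = torch.arange(past_len, past_len + prompt_len)
print(cache_position, bool(cache_position[0] == 0))  # tensor([0, 1, 2, 3, 4, 5]) True -> keep pixel inputs

# Decode step: only the single newly generated token is fed.
past_len += prompt_len
cache_position = torch.arange(past_len, past_len + 1)
print(cache_position, bool(cache_position[0] == 0))  # tensor([6]) False -> pixel inputs set to None
```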
""" @@ -1334,6 +1338,7 @@ def forward( use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, Idefics2BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -1443,6 +1448,7 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + cache_position=cache_position, return_dict=return_dict, ) @@ -1527,6 +1533,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple, Idefics2CausalLMOutputWithPast]: r""" @@ -1603,6 +1610,7 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + cache_position=cache_position, return_dict=return_dict, ) @@ -1659,49 +1667,28 @@ def prepare_inputs_for_generation( # Overwritten -- there are mutually exclusive inputs (if the logic to make `image_hidden_states` take # precedence is moved to the model, we can remove this fn) - # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens - if past_key_values is not None: - if inputs_embeds is not None: # Exception 1 - input_ids = input_ids[:, -cache_position.shape[0] :] - elif input_ids.shape[1] != cache_position.shape[0]: - input_ids = input_ids[:, cache_position] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + pixel_values=pixel_values, + pixel_attention_mask=pixel_attention_mask, + image_hidden_states=image_hidden_states, + logits_to_keep=logits_to_keep, + **kwargs, + ) # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - # but IDEFICS requires noth ids and embeds to be present + # but IDEFICS requires both ids and embeds to be present if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": input_ids} - else: - # The clone here is for the same reason as for `position_ids`. 
- model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} - - if logits_to_keep is not None: - model_inputs["logits_to_keep"] = logits_to_keep + model_inputs["input_ids"] = input_ids if image_hidden_states is not None: - pixel_values = None - pixel_attention_mask = None - else: - pixel_values = pixel_values - pixel_attention_mask = pixel_attention_mask - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - "pixel_values": pixel_values, - "pixel_attention_mask": pixel_attention_mask, - "image_hidden_states": image_hidden_states, - } - ) + model_inputs["pixel_values"] = None + model_inputs["pixel_attention_mask"] = None + return model_inputs def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs): diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index 31d97ffa70ec..251e11067fff 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -812,6 +812,10 @@ def forward( more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. """ @@ -928,6 +932,7 @@ def forward( use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, Idefics3BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -1024,6 +1029,7 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + cache_position=cache_position, return_dict=return_dict, ) @@ -1110,7 +1116,9 @@ def forward( use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple, Idefics3CausalLMOutputWithPast]: r""" Args: @@ -1119,6 +1127,13 @@ def forward( config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `Idefics3ForConditionalGeneration`). Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). 
Returns: Example: @@ -1193,11 +1208,14 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + cache_position=cache_position, return_dict=return_dict, ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: @@ -1248,49 +1266,28 @@ def prepare_inputs_for_generation( # Overwritten -- there are mutually exclusive inputs (if the logic to make `image_hidden_states` take # precedence is moved to the model, we can remove this fn) - # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens - if past_key_values is not None: - if inputs_embeds is not None: # Exception 1 - input_ids = input_ids[:, -cache_position.shape[0] :] - elif input_ids.shape[1] != cache_position.shape[0]: - input_ids = input_ids[:, cache_position] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + pixel_values=pixel_values, + pixel_attention_mask=pixel_attention_mask, + image_hidden_states=image_hidden_states, + logits_to_keep=logits_to_keep, + **kwargs, + ) # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - # but IDEFICS requires noth ids and embeds to be present + # but IDEFICS requires both ids and embeds to be present if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": input_ids} - else: - # The clone here is for the same reason as for `position_ids`. 
- model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} - - if logits_to_keep is not None: - model_inputs["logits_to_keep"] = logits_to_keep + model_inputs["input_ids"] = input_ids if image_hidden_states is not None: - pixel_values = None - pixel_attention_mask = None - else: - pixel_values = pixel_values - pixel_attention_mask = pixel_attention_mask - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - "pixel_values": pixel_values, - "pixel_attention_mask": pixel_attention_mask, - "image_hidden_states": image_hidden_states, - } - ) + model_inputs["pixel_values"] = None + model_inputs["pixel_attention_mask"] = None + return model_inputs # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration._update_model_kwargs_for_generation diff --git a/src/transformers/models/kosmos2/modeling_kosmos2.py b/src/transformers/models/kosmos2/modeling_kosmos2.py index 55277cd5a193..13c0273b1724 100644 --- a/src/transformers/models/kosmos2/modeling_kosmos2.py +++ b/src/transformers/models/kosmos2/modeling_kosmos2.py @@ -1699,31 +1699,18 @@ def prepare_inputs_for_generation( ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model - input_shape = input_ids.shape - # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly - if attention_mask is None: - attention_mask = input_ids.new_ones(input_shape) - - position_ids = None - if cache_position is None: - past_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - cache_position = torch.arange(past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device) + # Kosmos2 has offset for position ids, so we need to create them correctly + position_ids = create_position_ids_from_input_ids( + input_ids, + padding_idx=self.config.pad_token_id, + past_key_values_length=0, + ) if past_key_values is not None: - position_ids = create_position_ids_from_input_ids( - input_ids, - padding_idx=self.config.pad_token_id, - past_key_values_length=0, - ) - - if input_ids.shape[1] != cache_position.shape[0]: - input_ids = input_ids[:, cache_position] - position_ids = position_ids[:, -input_ids.shape[1] :] - image_embeds = None image_embeds_position_mask = None + # appending `False` to `image_embeds_position_mask` (because `input_ids` grows during generation) elif image_embeds_position_mask is not None: - # appending `False` to `image_embeds_position_mask` (because `input_ids` grows during generation) batch_size, seq_len = input_ids.size() mask_len = image_embeds_position_mask.size()[-1] image_embeds_position_mask = torch.cat( @@ -1734,15 +1721,19 @@ def prepare_inputs_for_generation( dim=1, ) - return { - "input_ids": input_ids, - "image_embeds": image_embeds, - "image_embeds_position_mask": image_embeds_position_mask, - "past_key_values": past_key_values, - "attention_mask": attention_mask, - "position_ids": position_ids, - "use_cache": use_cache, - } + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + image_embeds=image_embeds, + image_embeds_position_mask=image_embeds_position_mask, + use_cache=use_cache, + position_ids=position_ids, + cache_position=cache_position, + **model_kwargs, + ) + + return model_inputs @staticmethod # Copied from 
transformers.models.umt5.modeling_umt5.UMT5ForConditionalGeneration._reorder_cache diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index c21264a39804..e392634e5972 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -45,7 +45,6 @@ add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, - is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -1922,68 +1921,29 @@ def prepare_inputs_for_generation( ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model - # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens - # Exception 1: when passing input_embeds, input_ids may be missing entries - # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here - # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. - # (we can't check exception 3 while compiling) - # Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and - # generate the first token for each sequence. Later use the generated Input ids for continuation. - if past_key_values is not None: - if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4 - inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :] - elif ( - inputs_embeds is not None # Exception 1 - or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3 - ): - input_ids = input_ids[:, -cache_position.shape[0] :] - elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) - input_ids = input_ids[:, cache_position] - - if cache_position[0] != 0: - pixel_values = None - pixel_values_videos = None - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]: - model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} - else: - model_inputs = {"input_ids": input_ids, "inputs_embeds": None} + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + use_cache=use_cache, + **kwargs, + ) - if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if model_inputs["inputs_embeds"] is not None: - batch_size, sequence_length, _ = inputs_embeds.shape - device = inputs_embeds.device - else: - batch_size, sequence_length = input_ids.shape - device = input_ids.device + # Qwen2-5-VL position_ids are prepareed with rope_deltas in forward + model_inputs["position_ids"] = None - attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=past_key_values.get_max_cache_shape(), - dtype=self.lm_head.weight.dtype, - device=device, - cache_position=cache_position, - batch_size=batch_size, - config=self.config, - past_key_values=past_key_values, - ) + if 
cache_position[0] != 0: + model_inputs["pixel_values"] = None + model_inputs["pixel_values_videos"] = None - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": use_cache, - "attention_mask": attention_mask, - "pixel_values": pixel_values, - "pixel_values_videos": pixel_values_videos, - "image_grid_thw": image_grid_thw, - "video_grid_thw": video_grid_thw, - "cache_position": cache_position, - "second_per_grid_ts": second_per_grid_ts, - } - ) return model_inputs def _get_image_nums_and_video_nums( diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 7dd8a91a2028..2d8695b5a407 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -44,13 +44,12 @@ from transformers.models.qwen2_vl.processing_qwen2_vl import Qwen2VLProcessor from ...activations import ACT2FN -from ...cache_utils import StaticCache from ...configuration_utils import PretrainedConfig from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput, VideoInput from ...processing_utils import ProcessingKwargs, Unpack, VideosKwargs from ...tokenization_utils_base import PreTokenizedInput, TextInput -from ...utils import is_flash_attn_2_available, is_torchdynamo_compiling, logging +from ...utils import is_flash_attn_2_available, logging if is_flash_attn_2_available(): @@ -788,68 +787,29 @@ def prepare_inputs_for_generation( ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model - # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens - # Exception 1: when passing input_embeds, input_ids may be missing entries - # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here - # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. - # (we can't check exception 3 while compiling) - # Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and - # generate the first token for each sequence. Later use the generated Input ids for continuation. 
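For the Qwen2.5-VL override being rewritten here (and Qwen2-VL further below), the logic reduces to two post-processing steps on top of the base method: clear `position_ids` so the model recomputes its multimodal rope positions from `rope_deltas` inside `forward`, and drop the pixel inputs once prefill is done. A hedged sketch of that post-processing, operating on a plain dict rather than the real `model_inputs`:

```python
import torch

def postprocess_qwen_vl_inputs(model_inputs: dict) -> dict:
    # position_ids are derived from rope_deltas inside forward, so never forward them here.
    model_inputs["position_ids"] = None
    # After the first step the image/video tokens already live in the KV cache.
    if model_inputs["cache_position"][0] != 0:
        model_inputs["pixel_values"] = None
        model_inputs["pixel_values_videos"] = None
    return model_inputs

decode_step = {"cache_position": torch.tensor([12]), "position_ids": torch.arange(12),
               "pixel_values": "patches", "pixel_values_videos": None}
print(postprocess_qwen_vl_inputs(decode_step)["pixel_values"])  # None
```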
- if past_key_values is not None: - if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4 - inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :] - elif ( - inputs_embeds is not None # Exception 1 - or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3 - ): - input_ids = input_ids[:, -cache_position.shape[0] :] - elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) - input_ids = input_ids[:, cache_position] - - if cache_position[0] != 0: - pixel_values = None - pixel_values_videos = None + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + use_cache=use_cache, + **kwargs, + ) - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]: - model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} - else: - model_inputs = {"input_ids": input_ids, "inputs_embeds": None} + # Qwen2-5-VL position_ids are prepareed with rope_deltas in forward + model_inputs["position_ids"] = None - if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if model_inputs["inputs_embeds"] is not None: - batch_size, sequence_length, _ = inputs_embeds.shape - device = inputs_embeds.device - else: - batch_size, sequence_length = input_ids.shape - device = input_ids.device - - attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=past_key_values.get_max_cache_shape(), - dtype=self.lm_head.weight.dtype, - device=device, - cache_position=cache_position, - batch_size=batch_size, - config=self.config, - past_key_values=past_key_values, - ) + if cache_position[0] != 0: + model_inputs["pixel_values"] = None + model_inputs["pixel_values_videos"] = None - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": use_cache, - "attention_mask": attention_mask, - "pixel_values": pixel_values, - "pixel_values_videos": pixel_values_videos, - "image_grid_thw": image_grid_thw, - "video_grid_thw": video_grid_thw, - "cache_position": cache_position, - "second_per_grid_ts": second_per_grid_ts, - } - ) return model_inputs diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 7a22d75b6ae6..fd494f04782f 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -41,7 +41,6 @@ add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, - is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -1791,67 +1790,28 @@ def prepare_inputs_for_generation( ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model - # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens - # Exception 1: when passing input_embeds, input_ids may be missing entries - # Exception 2: some generation methods do special slicing of input_ids, so we 
don't need to do it here - # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. - # (we can't check exception 3 while compiling) - # Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and - # generate the first token for each sequence. Later use the generated Input ids for continuation. - if past_key_values is not None: - if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4 - inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :] - elif ( - inputs_embeds is not None # Exception 1 - or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3 - ): - input_ids = input_ids[:, -cache_position.shape[0] :] - elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) - input_ids = input_ids[:, cache_position] - - if cache_position[0] != 0: - pixel_values = None - pixel_values_videos = None - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]: - model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} - else: - model_inputs = {"input_ids": input_ids, "inputs_embeds": None} + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + use_cache=use_cache, + **kwargs, + ) - if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if model_inputs["inputs_embeds"] is not None: - batch_size, sequence_length, _ = inputs_embeds.shape - device = inputs_embeds.device - else: - batch_size, sequence_length = input_ids.shape - device = input_ids.device + # Qwen2-VL position_ids are prepareed with rope_deltas in forward + model_inputs["position_ids"] = None - attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=past_key_values.get_max_cache_shape(), - dtype=self.lm_head.weight.dtype, - device=device, - cache_position=cache_position, - batch_size=batch_size, - config=self.config, - past_key_values=past_key_values, - ) + if model_inputs["cache_position"][0] != 0: + model_inputs["pixel_values"] = None + model_inputs["pixel_values_videos"] = None - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": use_cache, - "attention_mask": attention_mask, - "pixel_values": pixel_values, - "pixel_values_videos": pixel_values_videos, - "image_grid_thw": image_grid_thw, - "video_grid_thw": video_grid_thw, - "cache_position": cache_position, - } - ) return model_inputs def _get_image_nums_and_video_nums( diff --git a/src/transformers/models/smolvlm/modeling_smolvlm.py b/src/transformers/models/smolvlm/modeling_smolvlm.py index 53c0093c5c35..c34aaaa62c99 100644 --- a/src/transformers/models/smolvlm/modeling_smolvlm.py +++ b/src/transformers/models/smolvlm/modeling_smolvlm.py @@ -733,6 +733,10 @@ def forward(self, image_hidden_states): more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. """ @@ -1083,7 +1087,9 @@ def forward( use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, ) -> Union[Tuple, SmolVLMCausalLMOutputWithPast]: r""" Args: @@ -1151,11 +1157,14 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + cache_position=cache_position, return_dict=return_dict, ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: @@ -1205,49 +1214,28 @@ def prepare_inputs_for_generation( # Overwritten -- there are mutually exclusive inputs (if the logic to make `image_hidden_states` take # precedence is moved to the model, we can remove this fn) - # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens - if past_key_values is not None: - if inputs_embeds is not None: # Exception 1 - input_ids = input_ids[:, -cache_position.shape[0] :] - elif input_ids.shape[1] != cache_position.shape[0]: - input_ids = input_ids[:, cache_position] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + pixel_values=pixel_values, + pixel_attention_mask=pixel_attention_mask, + image_hidden_states=image_hidden_states, + logits_to_keep=logits_to_keep, + **kwargs, + ) # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - # but IDEFICS requires noth ids and embeds to be present + # but IDEFICS requires both ids and embeds to be present if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": input_ids} - else: - # The clone here is for the same reason as for `position_ids`. 
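One subtlety preserved from the old SmolVLM/IDEFICS code above: when the caller passes `inputs_embeds`, the base method would normally drop `input_ids` for the first step, but these models need both, since the image-token positions are located through `input_ids` even when embeddings are supplied. The override therefore puts `input_ids` back after delegating, but only on the prefill call. A minimal sketch of that fix-up with toy tensors (32001 is a made-up stand-in for the image token id):

```python
import torch

def keep_ids_with_embeds(model_inputs: dict, input_ids, inputs_embeds, cache_position) -> dict:
    # Only relevant on the very first step; afterwards generation continues from ids alone.
    if inputs_embeds is not None and cache_position[0] == 0:
        model_inputs["input_ids"] = input_ids
    return model_inputs

input_ids = torch.tensor([[101, 32001, 32001, 17]])   # 32001: hypothetical image token id
inputs_embeds = torch.randn(1, 4, 8)
prepared = {"inputs_embeds": inputs_embeds, "input_ids": None}
prepared = keep_ids_with_embeds(prepared, input_ids, inputs_embeds, torch.tensor([0, 1, 2, 3]))
print(prepared["input_ids"] is input_ids)             # True
```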
- model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} - - if logits_to_keep is not None: - model_inputs["logits_to_keep"] = logits_to_keep + model_inputs["input_ids"] = input_ids if image_hidden_states is not None: - pixel_values = None - pixel_attention_mask = None - else: - pixel_values = pixel_values - pixel_attention_mask = pixel_attention_mask - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - "pixel_values": pixel_values, - "pixel_attention_mask": pixel_attention_mask, - "image_hidden_states": image_hidden_states, - } - ) + model_inputs["pixel_values"] = None + model_inputs["pixel_attention_mask"] = None + return model_inputs def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs):