VLMs: even more clean-up (#36249)
* squash

* style
zucchini-nlp authored Feb 21, 2025
1 parent e18f233 commit 14552cb
Showing 12 changed files with 277 additions and 449 deletions.
59 changes: 14 additions & 45 deletions src/transformers/models/chameleon/modeling_chameleon.py
@@ -1651,54 +1651,23 @@ def prepare_inputs_for_generation(
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model

# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
# (we can't check exception 3 while compiling)
# Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and
# generate the first token for each sequence. Later use the generated Input ids for continuation.
if past_key_values is not None:
if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
elif (
inputs_embeds is not None # Exception 1
or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3
):
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]

if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
if inputs_embeds is not None and input_ids.shape[1] == 0:
position_ids = position_ids[:, -inputs_embeds.shape[1] :]
else:
position_ids = position_ids[:, -input_ids.shape[1] :]

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
model_inputs = super().prepare_inputs_for_generation(
input_ids,
pixel_values=pixel_values,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
position_ids=position_ids,
use_cache=use_cache,
**kwargs,
)

if cache_position[0] == 0:
if cache_position[0] != 0:
# If we're in cached decoding stage, pixel values should be `None` because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model
model_inputs["pixel_values"] = pixel_values

model_inputs.update(
{
"position_ids": position_ids,
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
}
)
model_inputs["pixel_values"] = None

return model_inputs


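The Chameleon override above sets the pattern repeated for the other models below: delegate input slicing to the parent class, then forward pixel_values only on the prefill step. A minimal sketch of the check it relies on, using assumed tensor values rather than anything from this diff:

    import torch

    # During prefill, cache_position spans the whole prompt, so cache_position[0] == 0.
    # During cached decoding it only covers the newly generated token(s), so it is != 0.
    prompt_len = 5
    prefill_cache_position = torch.arange(prompt_len)   # tensor([0, 1, 2, 3, 4])
    decode_cache_position = torch.tensor([prompt_len])  # tensor([5])

    def forwards_pixel_values(cache_position):
        # Mirrors the new check: image inputs are only passed on the first (prefill) step,
        # because later steps read the merged image tokens back from the KV cache.
        return bool(cache_position[0] == 0)

    assert forwards_pixel_values(prefill_cache_position)      # prefill: keep pixel_values
    assert not forwards_pixel_values(decode_cache_position)   # decoding: pixel_values -> None
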
31 changes: 31 additions & 0 deletions src/transformers/models/emu3/modeling_emu3.py
@@ -1967,5 +1967,36 @@ def forward(

return outputs

def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
position_ids=None,
use_cache=True,
pixel_values=None,
**kwargs,
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model

model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
position_ids=position_ids,
pixel_values=pixel_values,
use_cache=use_cache,
**kwargs,
)

if cache_position[0] != 0:
model_inputs["pixel_values"] = None

return model_inputs


__all__ = ["Emu3ForConditionalGeneration", "Emu3ForCausalLM", "Emu3TextModel", "Emu3PreTrainedModel", "Emu3VQVAE"]
31 changes: 31 additions & 0 deletions src/transformers/models/emu3/modular_emu3.py
@@ -1275,6 +1275,37 @@ def forward(

return outputs

def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
position_ids=None,
use_cache=True,
pixel_values=None,
**kwargs,
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model

model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
position_ids=position_ids,
pixel_values=pixel_values,
use_cache=use_cache,
**kwargs,
)

if cache_position[0] != 0:
model_inputs["pixel_values"] = None

return model_inputs


__all__ = [
"Emu3ForConditionalGeneration",
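The same override is added to both the generated modeling file and the modular source for Emu3 above. A hedged usage sketch of how it is exercised during generation; the checkpoint id, prompt, and image path are assumptions, not taken from this diff:

    import torch
    from PIL import Image
    from transformers import Emu3ForConditionalGeneration, Emu3Processor

    # Assumed checkpoint id, for illustration only.
    model = Emu3ForConditionalGeneration.from_pretrained("BAAI/Emu3-Chat-hf", torch_dtype=torch.bfloat16)
    processor = Emu3Processor.from_pretrained("BAAI/Emu3-Chat-hf")

    inputs = processor(text="Describe this image.", images=Image.open("example.jpg"), return_tensors="pt")
    # `generate` calls `prepare_inputs_for_generation` at every step: pixel_values are
    # forwarded for the prefill step only; cached decoding steps receive pixel_values=None.
    output_ids = model.generate(**inputs, max_new_tokens=32)
    print(processor.batch_decode(output_ids, skip_special_tokens=True)[0])
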
42 changes: 13 additions & 29 deletions src/transformers/models/fuyu/modeling_fuyu.py
@@ -345,36 +345,20 @@ def prepare_inputs_for_generation(
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model

if past_key_values is not None:
input_ids = input_ids[:, -1:]

position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1:]

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}

if image_patches_indices is not None:
model_inputs["image_patches_indices"] = image_patches_indices

model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"image_patches_indices": image_patches_indices if past_key_values is None else None,
"image_patches": image_patches if past_key_values is None else None,
}
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
image_patches=image_patches,
image_patches_indices=image_patches_indices,
**kwargs,
)

if past_key_values is not None:
model_inputs["image_patches_indices"] = None
model_inputs["image_patches"] = None

return model_inputs

@staticmethod
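The per-model slicing deleted in the Fuyu hunk above (and in the other files in this commit) is what the parent class already does. A simplified sketch of that behaviour, reconstructed from the removed code rather than copied from the upstream GenerationMixin implementation:

    import torch

    def base_prepare(input_ids, attention_mask, cache_position, past_key_values=None):
        # Keep only the tokens that have not been processed yet.
        if past_key_values is not None and input_ids.shape[1] != cache_position.shape[0]:
            input_ids = input_ids[:, cache_position]
        # Create position_ids on the fly for batch generation, then keep the trailing slice.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        position_ids = position_ids[:, -input_ids.shape[1] :]
        return {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "cache_position": cache_position,
        }
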
70 changes: 20 additions & 50 deletions src/transformers/models/idefics/modeling_idefics.py
@@ -1667,63 +1667,33 @@ def prepare_inputs_for_generation(
):
# Overwritten -- custom processing based on `config.use_resampler`

model_inputs = {}
images_kwargs = {}
if image_hidden_states is not None:
if self.config.use_resampler:
model_inputs["perceiver_embeddings"] = image_hidden_states
images_kwargs["perceiver_embeddings"] = image_hidden_states
else:
model_inputs["image_encoder_embeddings"] = image_hidden_states
images_kwargs["image_encoder_embeddings"] = image_hidden_states
else:
model_inputs["pixel_values"] = pixel_values

# If we have cache: let's slice `input_ids` or `input embeds` through `cache_position`, to keep only the unprocessed tokens
if past_key_values is not None:
if inputs_embeds is not None:
if input_ids.shape[1] == 0:
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
else:
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]:
input_ids = input_ids[:, cache_position]
if image_attention_mask is not None:
image_attention_mask = image_attention_mask[:, -input_ids.shape[1] :]
images_kwargs["pixel_values"] = pixel_values
images_kwargs["interpolate_pos_encoding"] = kwargs.pop("interpolate_pos_encoding", False)

if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)

# If past_key_values are present then slice the postion ids for only only the unprocessed tokens.
if past_key_values:
if inputs_embeds is not None and input_ids.shape[1] == 0:
position_ids = position_ids[:, -inputs_embeds.shape[1] :]
else:
position_ids = position_ids[:, -input_ids.shape[1] :]

# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
model_inputs.update({"inputs_embeds": inputs_embeds, "input_ids": None})
else:
# The clone here is for the same reason as for `position_ids`.
model_inputs.update(
{"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
)

model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": use_cache,
"cache_position": cache_position,
"position_ids": position_ids,
"attention_mask": attention_mask,
"image_attention_mask": image_attention_mask,
"interpolate_pos_encoding": kwargs.get("interpolate_pos_encoding", False),
}
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
position_ids=position_ids,
use_cache=use_cache,
image_attention_mask=image_attention_mask,
**images_kwargs,
**kwargs,
)

if image_attention_mask is not None and inputs_embeds is None:
seq_length = model_inputs["input_ids"].shape[1]
model_inputs["image_attention_mask"] = image_attention_mask[:, -seq_length:]

return model_inputs

def _update_model_kwargs_for_generation(
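The IDEFICS override above also has to keep image_attention_mask aligned with whatever slice of input_ids is forwarded at each step. A small sketch with assumed shapes (batch, text sequence length, number of images):

    import torch

    image_attention_mask = torch.ones(2, 7, 3, dtype=torch.long)  # assumed (batch, seq_len, num_images)
    input_ids_step = torch.tensor([[42], [43]])                    # one new token per sequence while decoding

    seq_length = input_ids_step.shape[1]
    sliced_mask = image_attention_mask[:, -seq_length:]            # matches model_inputs["input_ids"].shape[1]
    print(sliced_mask.shape)                                       # torch.Size([2, 1, 3])
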
63 changes: 25 additions & 38 deletions src/transformers/models/idefics2/modeling_idefics2.py
@@ -1226,6 +1226,10 @@ def forward(self, image_hidden_states, attention_mask):
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""


@@ -1334,6 +1338,7 @@ def forward(
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, Idefics2BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -1443,6 +1448,7 @@ def forward(
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
return_dict=return_dict,
)

@@ -1527,6 +1533,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
) -> Union[Tuple, Idefics2CausalLMOutputWithPast]:
r"""
@@ -1603,6 +1610,7 @@ def forward(
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
return_dict=return_dict,
)

@@ -1659,49 +1667,28 @@ def prepare_inputs_for_generation(
# Overwritten -- there are mutually exclusive inputs (if the logic to make `image_hidden_states` take
# precedence is moved to the model, we can remove this fn)

# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
if past_key_values is not None:
if inputs_embeds is not None: # Exception 1
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]:
input_ids = input_ids[:, cache_position]

position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
pixel_values=pixel_values,
pixel_attention_mask=pixel_attention_mask,
image_hidden_states=image_hidden_states,
logits_to_keep=logits_to_keep,
**kwargs,
)

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
# but IDEFICS requires noth ids and embeds to be present
# but IDEFICS requires both ids and embeds to be present
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": input_ids}
else:
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

if logits_to_keep is not None:
model_inputs["logits_to_keep"] = logits_to_keep
model_inputs["input_ids"] = input_ids

if image_hidden_states is not None:
pixel_values = None
pixel_attention_mask = None
else:
pixel_values = pixel_values
pixel_attention_mask = pixel_attention_mask
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"pixel_values": pixel_values,
"pixel_attention_mask": pixel_attention_mask,
"image_hidden_states": image_hidden_states,
}
)
model_inputs["pixel_values"] = None
model_inputs["pixel_attention_mask"] = None

return model_inputs

def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs):
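For Idefics2 the override keeps raw pixel inputs and precomputed image_hidden_states mutually exclusive, as the comment at the top of the hunk notes. A standalone sketch of that routing rule (not the library function itself):

    def route_image_inputs(pixel_values, pixel_attention_mask, image_hidden_states):
        # Once image_hidden_states exist (e.g. computed on a previous step), the raw pixel
        # inputs are dropped so the vision tower is not run again.
        if image_hidden_states is not None:
            return {"image_hidden_states": image_hidden_states, "pixel_values": None, "pixel_attention_mask": None}
        return {"image_hidden_states": None, "pixel_values": pixel_values, "pixel_attention_mask": pixel_attention_mask}
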
(Diffs for the remaining 6 changed files are not shown.)
