Expand inputs in processors for VLMs #30962

Merged: 44 commits, Aug 13, 2024

Changes from 3 commits

Commits (44)
050657f
let it be
zucchini-nlp May 20, 2024
a67087e
draft
zucchini-nlp May 22, 2024
1e2b873
should not have changed
zucchini-nlp May 22, 2024
70145d4
add warnings
zucchini-nlp May 29, 2024
16a6787
Merge remote-tracking branch 'upstream/main' into vlm_processors
zucchini-nlp May 29, 2024
8472035
fix & add tests
zucchini-nlp May 29, 2024
13af9e8
fix tests
zucchini-nlp May 29, 2024
41d086f
ipnuts embeds cannot be passed with pixels
zucchini-nlp May 29, 2024
bf59ed6
more updates
zucchini-nlp Jun 7, 2024
020e7ed
paligemma ready!
zucchini-nlp Jun 10, 2024
3e0455c
minor typos
zucchini-nlp Jun 10, 2024
674f16e
update blip-2
zucchini-nlp Jun 10, 2024
42ae646
fix tests & raise error
zucchini-nlp Jun 10, 2024
b5259f2
Merge branch 'main' into vlm_processors
zucchini-nlp Jun 10, 2024
a6c50de
docstring
zucchini-nlp Jun 10, 2024
4766e2e
add blip2 test
zucchini-nlp Jun 10, 2024
d46df90
Merge branch 'main' into vlm_processors
zucchini-nlp Jun 10, 2024
f74297b
tmp
zucchini-nlp Jun 17, 2024
5fc8565
add image seq length to config
zucchini-nlp Jun 18, 2024
1b4674a
update docstring
zucchini-nlp Jun 18, 2024
c3c130b
Merge branch 'main' into vlm_processors
zucchini-nlp Jun 18, 2024
8438875
delete
zucchini-nlp Jun 18, 2024
bf9e637
fix tests
zucchini-nlp Jun 18, 2024
db1fa4f
fix blip
zucchini-nlp Jun 18, 2024
246b06a
fix paligemma
zucchini-nlp Jun 21, 2024
222bf9a
merge `main`
zucchini-nlp Jul 18, 2024
5486215
out-of-place scatter
zucchini-nlp Jul 18, 2024
78c4484
add llava-next-video
zucchini-nlp Jul 18, 2024
d60624e
Update src/transformers/models/blip_2/modeling_blip_2.py
zucchini-nlp Aug 5, 2024
1973b39
remove tmp
zucchini-nlp Aug 5, 2024
a6e380f
merge `main`
zucchini-nlp Aug 5, 2024
8e88d8b
codestyle
zucchini-nlp Aug 5, 2024
689eed9
nits
zucchini-nlp Aug 6, 2024
28e8054
more nits
zucchini-nlp Aug 6, 2024
637e514
remove overriding in tests
zucchini-nlp Aug 6, 2024
be939d8
comprehension when merging video
zucchini-nlp Aug 6, 2024
232eb7c
fix-copies
zucchini-nlp Aug 6, 2024
385a617
revert changes for embeds test
zucchini-nlp Aug 6, 2024
4831a7e
fix tests after making comprehension
zucchini-nlp Aug 6, 2024
85fbff9
Update src/transformers/models/blip_2/processing_blip_2.py
zucchini-nlp Aug 8, 2024
119178f
Update src/transformers/models/blip_2/processing_blip_2.py
zucchini-nlp Aug 8, 2024
2451911
more updates
zucchini-nlp Aug 8, 2024
414031e
fix tests
zucchini-nlp Aug 8, 2024
8cfad20
Merge remote-tracking branch 'upstream/main' into vlm_processors
zucchini-nlp Aug 9, 2024
224 changes: 43 additions & 181 deletions src/transformers/models/llava/modeling_llava.py
@@ -23,7 +23,6 @@

from ... import PreTrainedModel
from ...activations import ACT2FN
from ...cache_utils import Cache
from ...modeling_outputs import ModelOutput
from ...utils import (
add_start_docstrings,
@@ -274,84 +273,6 @@ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_m
self.vocab_size = model_embeds.num_embeddings
return model_embeds

def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
num_images, num_image_patches, embed_dim = image_features.shape
batch_size, sequence_length = input_ids.shape
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
# 1. Create a mask to know where special image tokens are
special_image_token_mask = input_ids == self.config.image_token_index
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
# Compute the maximum embed dimension
max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)

# 2. Compute the positions where text should be written
# Calculate new positions for text tokens in merged image-text sequence.
# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
# `torch.cumsum` computes how each image token shifts subsequent text token positions.
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
if left_padding:
new_token_positions += nb_image_pad[:, None] # offset for left padding
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]

# 3. Create the full embedding, already padded to the maximum position
final_embedding = torch.zeros(
batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
)
final_attention_mask = torch.zeros(
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
)
if labels is not None:
final_labels = torch.full(
(batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
)
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
# set the corresponding tensors into their correct target device.
target_device = inputs_embeds.device
batch_indices, non_image_indices, text_to_overwrite = (
batch_indices.to(target_device),
non_image_indices.to(target_device),
text_to_overwrite.to(target_device),
)
attention_mask = attention_mask.to(target_device)

# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
if labels is not None:
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]

# 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
image_to_overwrite = torch.full(
(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
)
image_to_overwrite[batch_indices, text_to_overwrite] = False
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)

if image_to_overwrite.sum() != image_features.shape[:-1].numel():
raise ValueError(
f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
)

final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
final_attention_mask |= image_to_overwrite
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
indices_to_mask = new_token_positions[batch_indices, pad_indices]

final_embedding[batch_indices, indices_to_mask] = 0

if labels is None:
final_labels = None

return final_embedding, final_attention_mask, final_labels, position_ids
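
For context on what is being removed above: the old merge routine computed, for every text token, where it should land once each image placeholder is blown up into num_image_patches embedding slots. A small worked example of that position arithmetic, using toy token ids and an assumed num_image_patches of 3 and placeholder id of 32000 (both hypothetical values for illustration only):

import torch

num_image_patches = 3
image_token_index = 32000                                  # hypothetical placeholder id
input_ids = torch.tensor([[101, 32000, 7, 8]])             # one <image> followed by two text tokens

special_image_token_mask = input_ids == image_token_index  # [[False, True, False, False]]
new_token_positions = torch.cumsum(special_image_token_mask * (num_image_patches - 1) + 1, -1) - 1
print(new_token_positions)                                 # tensor([[0, 3, 4, 5]])
# The <image> token's slot stretches from index 1 to 3, so the text after it shifts right by 2.
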

@add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
@@ -369,6 +290,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
r"""
Args:
@@ -406,6 +328,7 @@ def forward(
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
use_cache = use_cache if use_cache is not None else self.config.use_cache
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
@@ -415,63 +338,28 @@
else self.config.vision_feature_select_strategy
)

if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)

if inputs_embeds is None:
# 1. Extract the input embeddings
inputs_embeds = self.get_input_embeddings()(input_ids)

# 2. Merge text and images
if pixel_values is not None and input_ids.shape[1] != 1:
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
# this is not memory efficient at all (output_hidden_states=True) will save all the hidden states.
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]

if vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
else:
raise ValueError(
f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
)

image_features = self.multi_modal_projector(selected_image_feature)
inputs_embeds = inputs_embeds.to(image_features.dtype)
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
image_features, inputs_embeds, input_ids, attention_mask, labels
)

# In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
# generation with cache
elif past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
# Retrieve the first layer to inspect the logits and mask out the hidden states
# that are set to 0
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

# Get the target length
target_length = input_ids.shape[1]
past_length = first_layer_past_key_value.shape[-1]

extended_attention_mask = torch.ones(
(attention_mask.shape[0], past_length),
dtype=attention_mask.dtype,
device=attention_mask.device,
)

# Filter out only the tokens that can be un-attended, this can happen
# if one uses Llava + Fused modules where the cache on the
# first iteration is already big enough, or if one passes custom cache
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
new_batch_index = batch_index[valid_indices]
new_non_attended_tokens = non_attended_tokens[valid_indices]

# Zero-out the places where we don't need to attend
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
if pixel_values is not None:
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
# this is not memory efficient at all (output_hidden_states=True) will save all the hidden states.
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
if vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
else:
raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")

image_features = self.multi_modal_projector(selected_image_feature)
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds[special_image_mask] = image_features.flatten()

outputs = self.language_model(
attention_mask=attention_mask,
@@ -482,6 +370,7 @@
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)

logits = outputs[0]
@@ -515,56 +404,29 @@ def forward(
)

def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs
self,
input_ids,
past_key_values=None,
inputs_embeds=None,
pixel_values=None,
attention_mask=None,
cache_position=None,
**kwargs,
):
if past_key_values is not None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
else:
cache_length = past_length = past_key_values[0][0].shape[2]

# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
# input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
elif self.config.image_token_index in input_ids:
input_ids = input_ids[:, input_ids.shape[1] - 1 :]
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
# older attention values, as their corresponding values are not part of the input.
if cache_length < past_length and attention_mask is not None:
attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]

position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}

model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"pixel_values": pixel_values,
}
model_inputs = self.language_model.prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
**kwargs,
)

# If we're in cached decoding stage, pixel values is None because input ids do not contain special image token anymore
# Otherwise we need pixel values passed by the user
if past_key_values is None:
model_inputs["pixel_values"] = pixel_values

return model_inputs

def _reorder_cache(self, *args, **kwargs):
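The key simplification in the modeling file above: because the processor now emits one <image> placeholder per image patch, the model no longer needs _merge_input_ids_with_image_features and can scatter the projected image features directly into the text embeddings. A minimal standalone sketch of that scatter, using toy shapes and a made-up image_token_index of 32000 (the real value comes from the model config):

import torch

batch_size, seq_len, embed_dim = 1, 8, 4
image_token_index = 32000                               # hypothetical placeholder id
num_patches = 3                                         # pretend the processor expanded <image> into 3 tokens

input_ids = torch.tensor([[1, 32000, 32000, 32000, 5, 6, 7, 2]])
inputs_embeds = torch.randn(batch_size, seq_len, embed_dim)
image_features = torch.randn(num_patches, embed_dim)    # output of the multi-modal projector

# Every embedding slot belonging to an image placeholder gets overwritten in place.
special_image_mask = (input_ids == image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds[special_image_mask] = image_features.flatten()

In the same spirit, prepare_inputs_for_generation now defers to the language model's implementation and only attaches pixel_values on the first, uncached step, since later steps no longer carry image placeholder tokens that would need merging.
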
27 changes: 25 additions & 2 deletions src/transformers/models/llava/processing_llava.py
@@ -46,6 +46,10 @@ class LlavaProcessor(ProcessorMixin):

def __init__(self, image_processor=None, tokenizer=None):
super().__init__(image_processor, tokenizer)
self.image_size = self.image_processor.size["shortest_edge"]
self.patch_size = 14  # self.image_processor.patch_size
self.image_token = "<image>"
self.vision_feature_select_strategy = "default" # self.image_processor.vision_feature_select_strategy

def __call__(
self,
@@ -105,10 +109,29 @@ def __call__(
pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"]
else:
pixel_values = None

if isinstance(text, str):
text = [text]
elif not isinstance(text, list) and not isinstance(text[0], str):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")

# Replace the image token with the expanded image token sequence
num_image_tokens = (self.image_size // self.patch_size) ** 2 + 1
if self.vision_feature_select_strategy == "default":
num_image_tokens -= 1

prompt_strings = []
for sample in text:
sample = sample.replace(self.image_token, self.image_token * num_image_tokens)
prompt_strings.append(sample)

text_inputs = self.tokenizer(
text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
prompt_strings,
return_tensors=return_tensors,
padding=padding,
truncation=truncation,
max_length=max_length,
)

return BatchFeature(data={**text_inputs, "pixel_values": pixel_values})

# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
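To illustrate the processor-side counterpart, here is a minimal sketch of the expansion logic added to LlavaProcessor.__call__. The 336-pixel image size is an assumption for the example (the real value is read from image_processor.size["shortest_edge"]); the patch size of 14 and the "default" feature-select strategy match the hard-coded values in the diff above.

# Sketch of the <image> expansion performed in LlavaProcessor.__call__ (assumed values).
image_size, patch_size = 336, 14
image_token = "<image>"

num_image_tokens = (image_size // patch_size) ** 2 + 1   # 24 * 24 patches + CLS = 577
vision_feature_select_strategy = "default"
if vision_feature_select_strategy == "default":
    num_image_tokens -= 1                                # CLS dropped -> 576

prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"
expanded = prompt.replace(image_token, image_token * num_image_tokens)
# The tokenizer now produces 576 image placeholder ids, one per visual feature,
# so input_ids and the projected image features line up token for token.

This keeps the text/vision alignment entirely inside the processor, so the modeling code only needs the masked scatter shown after the modeling_llava.py diff.
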