From 050657f75119c92e6081b7f774a185a7d98e79cf Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 20 May 2024 12:54:09 +0200 Subject: [PATCH 01/37] let it be --- .../models/llava/processing_llava.py | 22 +++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/llava/processing_llava.py b/src/transformers/models/llava/processing_llava.py index ff010f74428a..61d06d3ad251 100644 --- a/src/transformers/models/llava/processing_llava.py +++ b/src/transformers/models/llava/processing_llava.py @@ -41,12 +41,13 @@ class LlavaProcessor(ProcessorMixin): """ attributes = ["image_processor", "tokenizer"] - image_processor_class = "AutoImageProcessor" - tokenizer_class = "AutoTokenizer" + image_processor_class = "CLIPImageProcessor" + tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast") def __init__(self, image_processor=None, tokenizer=None): super().__init__(image_processor, tokenizer) + def __call__( self, text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, @@ -105,12 +106,29 @@ def __call__( pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"] else: pixel_values = None + + # Replace the image token with the expanded image token sequence + image_str = self.image_token.content + num_image_tokens = self._get_number_of_features() + prompt_strings = [] + for sample in text: + sample = sample.replace(image_str, image_str * num_image_tokens) + prompt_strings.append(sample) + text_inputs = self.tokenizer( text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length ) return BatchFeature(data={**text_inputs, "pixel_values": pixel_values}) + def _get_number_of_features(self) -> int: + image_size = self.config.vision_config.image_size + patch_size = self.config.vision_config.patch_size + + num_patches = (image_size // patch_size) ** 2 + num_features = num_patches + 1 + return num_features + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama def batch_decode(self, *args, **kwargs): """ From a67087e41cf95fb2f272e08d9185c1745a0479af Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 22 May 2024 13:36:45 +0200 Subject: [PATCH 02/37] draft --- .../models/llava/modeling_llava.py | 224 +++------------- .../models/llava/processing_llava.py | 37 +-- .../models/llava_next/modeling_llava_next.py | 208 +++++---------- .../llava_next/processing_llava_next.py | 70 ++++- .../video_llava/modeling_video_llava.py | 245 +++--------------- .../video_llava/processing_video_llava.py | 28 +- .../models/vipllava/modeling_vipllava.py | 211 +++------------ 7 files changed, 305 insertions(+), 718 deletions(-) diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index 0426776beed1..e0df904896b5 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -23,7 +23,6 @@ from ... 
import PreTrainedModel from ...activations import ACT2FN -from ...cache_utils import Cache from ...modeling_outputs import ModelOutput from ...utils import ( add_start_docstrings, @@ -274,84 +273,6 @@ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_m self.vocab_size = model_embeds.num_embeddings return model_embeds - def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels): - num_images, num_image_patches, embed_dim = image_features.shape - batch_size, sequence_length = input_ids.shape - left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id)) - # 1. Create a mask to know where special image tokens are - special_image_token_mask = input_ids == self.config.image_token_index - num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) - # Compute the maximum embed dimension - max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length - batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index) - - # 2. Compute the positions where text should be written - # Calculate new positions for text tokens in merged image-text sequence. - # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens. - # `torch.cumsum` computes how each image token shifts subsequent text token positions. - # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. - new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1 - nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1] - if left_padding: - new_token_positions += nb_image_pad[:, None] # offset for left padding - text_to_overwrite = new_token_positions[batch_indices, non_image_indices] - - # 3. Create the full embedding, already padded to the maximum position - final_embedding = torch.zeros( - batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device - ) - final_attention_mask = torch.zeros( - batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device - ) - if labels is not None: - final_labels = torch.full( - (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device - ) - # In case the Vision model or the Language model has been offloaded to CPU, we need to manually - # set the corresponding tensors into their correct target device. - target_device = inputs_embeds.device - batch_indices, non_image_indices, text_to_overwrite = ( - batch_indices.to(target_device), - non_image_indices.to(target_device), - text_to_overwrite.to(target_device), - ) - attention_mask = attention_mask.to(target_device) - - # 4. Fill the embeddings based on the mask. If we have ["hey" "", "how", "are"] - # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features - final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices] - final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices] - if labels is not None: - final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices] - - # 5. Fill the embeddings corresponding to the images. 
Anything that is not `text_positions` needs filling (#29835) - image_to_overwrite = torch.full( - (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device - ) - image_to_overwrite[batch_indices, text_to_overwrite] = False - image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device) - - if image_to_overwrite.sum() != image_features.shape[:-1].numel(): - raise ValueError( - f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while" - f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation." - ) - - final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device) - final_attention_mask |= image_to_overwrite - position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) - - # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens. - batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id) - indices_to_mask = new_token_positions[batch_indices, pad_indices] - - final_embedding[batch_indices, indices_to_mask] = 0 - - if labels is None: - final_labels = None - - return final_embedding, final_attention_mask, final_labels, position_ids - @add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -369,6 +290,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, LlavaCausalLMOutputWithPast]: r""" Args: @@ -406,6 +328,7 @@ def forward( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict + use_cache = use_cache if use_cache is not None else self.config.use_cache vision_feature_layer = ( vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer ) @@ -415,63 +338,28 @@ def forward( else self.config.vision_feature_select_strategy ) + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + if inputs_embeds is None: - # 1. Extra the input embeddings inputs_embeds = self.get_input_embeddings()(input_ids) - # 2. Merge text and images - if pixel_values is not None and input_ids.shape[1] != 1: - image_outputs = self.vision_tower(pixel_values, output_hidden_states=True) - # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated. 
- selected_image_feature = image_outputs.hidden_states[vision_feature_layer] - - if vision_feature_select_strategy == "default": - selected_image_feature = selected_image_feature[:, 1:] - elif vision_feature_select_strategy == "full": - selected_image_feature = selected_image_feature - else: - raise ValueError( - f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}" - ) - - image_features = self.multi_modal_projector(selected_image_feature) - inputs_embeds = inputs_embeds.to(image_features.dtype) - inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features( - image_features, inputs_embeds, input_ids, attention_mask, labels - ) - - # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of - # generation with cache - elif past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1: - # Retrieve the first layer to inspect the logits and mask out the hidden states - # that are set to 0 - first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] - - # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 - batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) - - # Get the target length - target_length = input_ids.shape[1] - past_length = first_layer_past_key_value.shape[-1] - - extended_attention_mask = torch.ones( - (attention_mask.shape[0], past_length), - dtype=attention_mask.dtype, - device=attention_mask.device, - ) - - # Filter out only the tokens that can be un-attended, this can happen - # if one uses Llava + Fused modules where the cache on the - # first iteration is already big enough, or if one passes custom cache - valid_indices = non_attended_tokens < extended_attention_mask.size(-1) - new_batch_index = batch_index[valid_indices] - new_non_attended_tokens = non_attended_tokens[valid_indices] - - # Zero-out the places where we don't need to attend - extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 - - attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) - position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 + if pixel_values is not None: + image_outputs = self.vision_tower(pixel_values, output_hidden_states=True) + # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated. 
+ selected_image_feature = image_outputs.hidden_states[vision_feature_layer] + if vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + else: + raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}") + + image_features = self.multi_modal_projector(selected_image_feature) + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) + inputs_embeds[special_image_mask] = image_features.flatten() outputs = self.language_model( attention_mask=attention_mask, @@ -482,6 +370,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) logits = outputs[0] @@ -515,56 +404,29 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + pixel_values=None, + attention_mask=None, + cache_position=None, + **kwargs, ): - if past_key_values is not None: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - else: - cache_length = past_length = past_key_values[0][0].shape[2] - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - elif self.config.image_token_index in input_ids: - input_ids = input_ids[:, input_ids.shape[1] - 1 :] - # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the - # older attention values, as their corresponding values are not part of the input. 
-            if cache_length < past_length and attention_mask is not None:
-                attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
-
-        position_ids = kwargs.get("position_ids", None)
-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
-
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
-            model_inputs = {"inputs_embeds": inputs_embeds}
-        else:
-            model_inputs = {"input_ids": input_ids}
-
-        model_inputs.update(
-            {
-                "position_ids": position_ids,
-                "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
-                "attention_mask": attention_mask,
-                "pixel_values": pixel_values,
-            }
+        model_inputs = self.language_model.prepare_inputs_for_generation(
+            input_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            **kwargs,
         )
+
+        # If we're in the cached decoding stage, pixel_values is None because input_ids no longer contain the special image token
+        # Otherwise we need the pixel_values passed by the user
+        if past_key_values is None:
+            model_inputs["pixel_values"] = pixel_values
+
         return model_inputs

     def _reorder_cache(self, *args, **kwargs):
diff --git a/src/transformers/models/llava/processing_llava.py b/src/transformers/models/llava/processing_llava.py
index 61d06d3ad251..57dd0b69479f 100644
--- a/src/transformers/models/llava/processing_llava.py
+++ b/src/transformers/models/llava/processing_llava.py
@@ -46,7 +46,10 @@ class LlavaProcessor(ProcessorMixin):

     def __init__(self, image_processor=None, tokenizer=None):
         super().__init__(image_processor, tokenizer)
-
+        self.image_size = self.image_processor.size["shortest_edge"]
+        self.patch_size = 14  # self.image_processor.patch_size
+        self.image_token = "<image>"
+        self.vision_feature_select_strategy = "default"  # self.image_processor.vision_feature_select_strategy

     def __call__(
         self,
@@ -106,29 +109,31 @@ def __call__(
             pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"]
         else:
             pixel_values = None
-
+
+        if isinstance(text, str):
+            text = [text]
+        elif not isinstance(text, list) and not isinstance(text[0], str):
+            raise ValueError("Invalid input text.
Please provide a string, or a list of strings") + # Replace the image token with the expanded image token sequence - image_str = self.image_token.content - num_image_tokens = self._get_number_of_features() + num_image_tokens = (self.image_size // self.patch_size) ** 2 + 1 + if self.vision_feature_select_strategy == "default": + num_image_tokens -= 1 + prompt_strings = [] for sample in text: - sample = sample.replace(image_str, image_str * num_image_tokens) + sample = sample.replace(self.image_token, self.image_token * num_image_tokens) prompt_strings.append(sample) - + text_inputs = self.tokenizer( - text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length + prompt_strings, + return_tensors=return_tensors, + padding=padding, + truncation=truncation, + max_length=max_length, ) - return BatchFeature(data={**text_inputs, "pixel_values": pixel_values}) - def _get_number_of_features(self) -> int: - image_size = self.config.vision_config.image_size - patch_size = self.config.vision_config.patch_size - - num_patches = (image_size // patch_size) ** 2 - num_features = num_patches + 1 - return num_features - # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama def batch_decode(self, *args, **kwargs): """ diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index c052af3b3c8a..e5e0675e962a 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -25,7 +25,6 @@ from ... import PreTrainedModel from ...activations import ACT2FN -from ...cache_utils import Cache from ...image_processing_utils import select_best_resolution from ...modeling_outputs import ModelOutput from ...utils import ( @@ -129,6 +128,7 @@ def unpad_image(tensor, original_size): original_aspect_ratio = original_width / original_height current_aspect_ratio = current_width / current_height + print(original_height, original_width, current_height, current_width) if original_aspect_ratio > current_aspect_ratio: scale_factor = current_width / original_width @@ -700,6 +700,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, LlavaNextCausalLMOutputWithPast]: r""" Args: @@ -737,6 +738,7 @@ def forward( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict + use_cache = use_cache if use_cache is not None else self.config.use_cache vision_feature_layer = ( vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer ) @@ -746,104 +748,52 @@ def forward( else self.config.vision_feature_select_strategy ) - if inputs_embeds is None: - # 1. Extract the input embeddings - # In case image_token_index is not in the embeddings (extra token but embedding don't have it) - for_inputs_embeds_ids = input_ids.clone() - for_inputs_embeds_ids[(input_ids == self.config.image_token_index)] = 0 - inputs_embeds = self.get_input_embeddings()(for_inputs_embeds_ids) - - # 2. Merge text and images - if pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) > 0: - # ! 
infer image_num_patches from image_sizes - image_num_patches = [ - image_size_to_num_patches( - image_size=imsize, - grid_pinpoints=self.config.image_grid_pinpoints, - patch_size=self.config.vision_config.image_size, - ) - for imsize in image_sizes - ] - # figure out if pixel_values is concatenated or stacked - if pixel_values.dim() == 5: - # stacking when input is (batch_size, num_patches, num_channels, height, width) - _pixel_values_list = [ - pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches) - ] - pixel_values = torch.cat(_pixel_values_list, dim=0) - elif pixel_values.dim() != 4: - # otherwise has to be stacked from list of (num_patches, num_channels, height, width) - raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions") - - image_features = self.vision_tower(pixel_values, output_hidden_states=True) - selected_image_feature = image_features.hidden_states[vision_feature_layer] - - if vision_feature_select_strategy == "default": - selected_image_feature = selected_image_feature[:, 1:] - elif vision_feature_select_strategy == "full": - selected_image_feature = selected_image_feature - - image_features = self.multi_modal_projector(selected_image_feature) - - image_features = torch.split(image_features, image_num_patches, dim=0) - - # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" - - image_features, feature_lens = self.pack_image_features( - image_features, - image_sizes, - image_newline=self.image_newline, - ) - - inputs_embeds = inputs_embeds.to(image_features.dtype) - inputs_embeds, attention_mask, position_ids, labels = self._merge_input_ids_with_image_features( - image_features, - feature_lens, - inputs_embeds, - input_ids, - attention_mask, - position_ids, - labels=labels, - ) + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) - # pixel_values is not None but is empty ---> text only cases - elif pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) == 0: - # there are no images - pass - - # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of - # generation with cache - elif past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1: - # Retrieve the first layer to inspect the logits and mask out the hidden states - # that are set to 0 - first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] - - # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 - batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) - - # Get the target length - target_length = input_ids.shape[1] - past_length = first_layer_past_key_value.shape[-1] - - extended_attention_mask = torch.ones( - (attention_mask.shape[0], past_length), - dtype=attention_mask.dtype, - device=attention_mask.device, + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None and pixel_values.size(0) > 0: + # ! 
infer image_num_patches from image_sizes + image_num_patches = [ + image_size_to_num_patches( + image_size=imsize, + grid_pinpoints=self.config.image_grid_pinpoints, + patch_size=self.config.vision_config.image_size, ) - - # Filter out only the tokens that can be un-attended, this can happen - # if one uses Llava + Fused modules where the cache on the - # first iteration is already big enough, or if one passes custom cache - valid_indices = non_attended_tokens < extended_attention_mask.size(-1) - new_batch_index = batch_index[valid_indices] - new_non_attended_tokens = non_attended_tokens[valid_indices] - - # Zero-out the places where we don't need to attend - extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 - - attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) - - position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 + for imsize in image_sizes + ] + # figure out if pixel_values is concatenated or stacked + if pixel_values.dim() == 5: + # stacking when input is (batch_size, num_patches, num_channels, height, width) + _pixel_values_list = [ + pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches) + ] + pixel_values = torch.cat(_pixel_values_list, dim=0) + elif pixel_values.dim() != 4: + # otherwise has to be stacked from list of (num_patches, num_channels, height, width) + raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions") + + image_features = self.vision_tower(pixel_values, output_hidden_states=True) + selected_image_feature = image_features.hidden_states[vision_feature_layer] + if vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + image_features = self.multi_modal_projector(selected_image_feature) + image_features = torch.split(image_features, image_num_patches, dim=0) + + # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" + image_features, feature_lens = self.pack_image_features( + image_features, + image_sizes, + image_newline=self.image_newline, + ) + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) + inputs_embeds[special_image_mask] = image_features.flatten() outputs = self.language_model( attention_mask=attention_mask, @@ -854,6 +804,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) logits = outputs[0] @@ -894,57 +845,24 @@ def prepare_inputs_for_generation( pixel_values=None, image_sizes=None, attention_mask=None, + cache_position=None, **kwargs, ): - if past_key_values is not None: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - else: - cache_length = past_length = past_key_values[0][0].shape[2] - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. 
We can discard
-            # input_ids based on the past_length.
-            elif past_length < input_ids.shape[1]:
-                input_ids = input_ids[:, past_length:]
-            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
-            elif self.config.image_token_index in input_ids:
-                input_ids = input_ids[:, input_ids.shape[1] - 1 :]
-            # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
-            # older attention values, as their corresponding values are not part of the input.
-            if cache_length < past_length and attention_mask is not None:
-                attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
-
-        position_ids = kwargs.get("position_ids", None)
-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
-
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
-            model_inputs = {"inputs_embeds": inputs_embeds}
-        else:
-            model_inputs = {"input_ids": input_ids}
-
-        model_inputs.update(
-            {
-                "position_ids": position_ids,
-                "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
-                "attention_mask": attention_mask,
-                "pixel_values": pixel_values,
-                "image_sizes": image_sizes,
-            }
+        model_inputs = self.language_model.prepare_inputs_for_generation(
+            input_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            **kwargs,
         )
+
+        # If we're in the cached decoding stage, pixel_values is None because input_ids no longer contain the special image token
+        # Otherwise we need the pixel_values passed by the user
+        if past_key_values is None:
+            model_inputs["pixel_values"] = pixel_values
+            model_inputs["image_sizes"] = image_sizes
+
         return model_inputs

     # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration._reorder_cache
diff --git a/src/transformers/models/llava_next/processing_llava_next.py b/src/transformers/models/llava_next/processing_llava_next.py
index 8a4b76e9c68a..d60615fa0ed2 100644
--- a/src/transformers/models/llava_next/processing_llava_next.py
+++ b/src/transformers/models/llava_next/processing_llava_next.py
@@ -19,6 +19,7 @@
 from typing import List, Optional, Union

 from ...feature_extraction_utils import BatchFeature
+from ...image_processing_utils import select_best_resolution
 from ...image_utils import ImageInput
 from ...processing_utils import ProcessorMixin
 from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
@@ -46,6 +47,11 @@ class LlavaNextProcessor(ProcessorMixin):

     def __init__(self, image_processor=None, tokenizer=None):
         super().__init__(image_processor, tokenizer)
+        self.image_size = self.image_processor.size["shortest_edge"]
+        self.patch_size = 14  # self.image_processor.patch_size
+        self.image_token = "<image>"
+        self.vision_feature_select_strategy = "default"  # self.image_processor.vision_feature_select_strategy
+
     def __call__(
         self,
         text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
@@ -108,12 +114,74 @@ def __call__(
             image_inputs = self.image_processor(images, do_pad=do_pad, return_tensors=return_tensors)
         else:
             image_inputs = {}
+
+        if isinstance(text, str):
+            text = [text]
+
elif not isinstance(text, list) and not isinstance(text[0], str): + raise ValueError("Invalid input text. Please provide a string, or a list of strings") + + if not image_inputs: + prompt_strings = text + else: + image_sizes = image_inputs["image_sizes"] + prompt_strings = [] + for image_size, sample in zip(image_sizes, text): + # Replace the image token with the expanded image token sequence + height, width = image_size + num_image_tokens = self._get_number_of_features(height, width) + if self.vision_feature_select_strategy == "default": + num_image_tokens -= 1 + + sample = sample.replace(self.image_token, self.image_token * num_image_tokens) + prompt_strings.append(sample) + text_inputs = self.tokenizer( - text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length + prompt_strings, + return_tensors=return_tensors, + padding=padding, + truncation=truncation, + max_length=max_length, ) return BatchFeature(data={**text_inputs, **image_inputs}) + def _get_number_of_features(self, height: int, width: int) -> int: + image_grid_pinpoints = self.image_processor.image_grid_pinpoints + image_size = self.image_size + patch_size = self.patch_size + + npatches = image_size // patch_size + + height_best_resolution, width_best_resolution = select_best_resolution([height, width], image_grid_pinpoints) + num_patch_height, num_patch_width = height_best_resolution // image_size, width_best_resolution // image_size + + unpadded_features, newline_features = self._get_unpadded_features( + height, width, npatches, num_patch_height, num_patch_width + ) + # The base patch covers the entire image (+1 for the CLS) + base_features = npatches**2 + 1 + num_image_tokens = unpadded_features + newline_features + base_features + return num_image_tokens + + def _get_unpadded_features(self, height, width, npatches, num_patch_height, num_patch_width): + current_width = npatches * num_patch_height + current_height = npatches * num_patch_width + + original_aspect_ratio = width / height + current_aspect_ratio = current_width / current_height + if original_aspect_ratio > current_aspect_ratio: + new_height = (height * current_width) // width + padding = (current_height - new_height) // 2 + current_height -= padding * 2 + else: + new_width = (width * current_height) // height + padding = (current_width - new_width) // 2 + current_width -= padding * 2 + + unpadded_features = current_height * current_width + newline_features = current_height + return (unpadded_features, newline_features) + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama def batch_decode(self, *args, **kwargs): """ diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index 7fbd142fbe85..f45bc814fcf2 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -22,7 +22,6 @@ from ... 
import PreTrainedModel from ...activations import ACT2FN -from ...cache_utils import Cache from ...modeling_outputs import BaseModelOutputWithPooling, ModelOutput from ...utils import ( add_start_docstrings, @@ -279,87 +278,6 @@ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_m self.vocab_size = model_embeds.num_embeddings return model_embeds - def _merge_input_ids_with_visual_features( - self, visual_features, inputs_embeds, input_ids, attention_mask, labels, num_frames=1 - ): - num_images, num_image_patches, embed_dim = visual_features.shape - batch_size, sequence_length = input_ids.shape - left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id)) - special_vision_token = self.config.video_token_index if num_frames == 8 else self.config.image_token_index - - # 1. Create a mask to know where special image tokens are - special_image_token_mask = input_ids == special_vision_token - num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) - # Compute the maximum embed dimension - max_seq_len = (num_special_image_tokens.max() * (num_image_patches * num_frames - 1)) + sequence_length - batch_indices, non_image_indices = torch.where(input_ids != special_vision_token) - - # 2. Compute the positions where text should be written - # Calculate new positions for text tokens in merged image-text sequence. - # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens. - # `torch.cumsum` computes how each image token shifts subsequent text token positions. - # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. - new_token_positions = ( - torch.cumsum((special_image_token_mask * (num_image_patches * num_frames - 1) + 1), dim=-1) - 1 - ) - nb_image_pad = max_seq_len - 1 - new_token_positions[:, -1] - if left_padding: - new_token_positions += nb_image_pad[:, None] # offset for left padding - text_to_overwrite = new_token_positions[batch_indices, non_image_indices] - - # 3. Create the full embedding, already padded to the maximum position - # expand input ids so that the second "merge" with videos does not fail - final_embedding = torch.zeros( - batch_size, max_seq_len, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device - ) - final_attention_mask = torch.zeros( - batch_size, max_seq_len, dtype=attention_mask.dtype, device=inputs_embeds.device - ) - final_input_ids = torch.full( - (batch_size, max_seq_len), self.pad_token_id, dtype=input_ids.dtype, device=inputs_embeds.device - ) - # In case the Vision model or the Language model has been offloaded to CPU, we need to manually - # set the corresponding tensors into their correct target device. - target_device = inputs_embeds.device - batch_indices, non_image_indices, text_to_overwrite = ( - batch_indices.to(target_device), - non_image_indices.to(target_device), - text_to_overwrite.to(target_device), - ) - attention_mask = attention_mask.to(target_device) - - # 4. Fill the embeddings based on the mask. 
If we have ["hey" "", "how", "are"] - # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features - final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices] - final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices] - final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, non_image_indices] - if labels is not None: - final_labels = torch.full( - (batch_size, max_seq_len), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device - ) - final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices] - else: - final_labels = None - - # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling - image_to_overwrite = torch.full((batch_size, max_seq_len), True, dtype=torch.bool, device=inputs_embeds.device) - image_to_overwrite[batch_indices, text_to_overwrite] = False - image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device) - - if image_to_overwrite.sum() != visual_features.shape[:-1].numel(): - visual_type = "videos" if num_frames == 8 else "images" - num_images //= num_frames - raise ValueError( - f"The input provided to the model are wrong. The number of {visual_type} tokens is {torch.sum(special_image_token_mask)} while" - f" the number of {visual_type} given to the model is {num_images}. This prevents correct indexing and breaks batch generation." - ) - - final_embedding[image_to_overwrite] = visual_features.contiguous().reshape(-1, embed_dim).to(target_device) - final_attention_mask |= image_to_overwrite - position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) - - return final_embedding, final_attention_mask, final_labels, position_ids, final_input_ids - def _get_vision_features( self, pixel_values_images: Optional[torch.FloatTensor] = None, @@ -415,6 +333,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, VideoLlavaCausalLMOutputWithPast]: r""" Args: @@ -495,6 +414,7 @@ def forward( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict + use_cache = use_cache if use_cache is not None else self.config.use_cache vision_feature_layer = ( vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer ) @@ -504,79 +424,35 @@ def forward( else self.config.vision_feature_select_strategy ) + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + if inputs_embeds is None: - # 1. Extra the input embeddings inputs_embeds = self.get_input_embeddings()(input_ids) - # 2. 
Merge text and images - if (pixel_values_images is not None or pixel_values_videos is not None) and input_ids.shape[1] != 1: - image_outputs, video_outputs = self._get_vision_features( - pixel_values_images=pixel_values_images, - pixel_values_videos=pixel_values_videos, - vision_feature_layer=vision_feature_layer, - vision_feature_select_strategy=vision_feature_select_strategy, - ) + if pixel_values_images is not None or pixel_values_videos is not None: + image_outputs, video_outputs = self._get_vision_features( + pixel_values_images=pixel_values_images, + pixel_values_videos=pixel_values_videos, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + ) - # first add image embeds where possible, then expand again and add video embeds - if image_outputs is not None: - visual_features = self.multi_modal_projector(image_outputs) - ( - inputs_embeds, - attention_mask, - labels, - position_ids, - input_ids, - ) = self._merge_input_ids_with_visual_features( - visual_features, inputs_embeds, input_ids, attention_mask, labels - ) - if video_outputs is not None: - visual_features = self.multi_modal_projector(video_outputs) - ( - inputs_embeds, - attention_mask, - labels, - position_ids, - _, - ) = self._merge_input_ids_with_visual_features( - visual_features, - inputs_embeds, - input_ids, - attention_mask, - labels, - num_frames=8, - ) - else: - # In case input_ids.shape[1] == 1 & past_key_values != None, we are in the case of - # generation with cache - if past_key_values is not None and input_ids.shape[1] == 1: - # Retrieve the first layer to inspect the logits and mask out the hidden states - # that are set to 0 - first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] - - # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 - batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) - - target_length = input_ids.shape[1] - past_length = first_layer_past_key_value.shape[-1] - - extended_attention_mask = torch.ones( - (attention_mask.shape[0], past_length), - dtype=attention_mask.dtype, - device=attention_mask.device, - ) - - # Filter out only the tokens that can be un-attended, this can happen - # if one uses Llava + Fused modules where the cache on the - # first iteration is already big enough, or if one passes custom cache - valid_indices = non_attended_tokens < extended_attention_mask.size(-1) - new_batch_index = batch_index[valid_indices] - new_non_attended_tokens = non_attended_tokens[valid_indices] - - # Zero-out the places where we don't need to attend - extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 - - attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) - position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 + if image_outputs is not None: + image_features = self.multi_modal_projector(image_outputs) + special_image_mask = ( + (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) + ) + inputs_embeds[special_image_mask] = image_features.flatten() + if video_outputs is not None: + video_features = self.multi_modal_projector(video_outputs) + print(video_features.shape) + special_image_mask = ( + (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds) + ) + inputs_embeds[special_image_mask] = video_features.flatten() outputs = self.language_model( 
attention_mask=attention_mask, @@ -587,6 +463,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) logits = outputs[0] @@ -627,60 +504,24 @@ def prepare_inputs_for_generation( pixel_values_images=None, pixel_values_videos=None, attention_mask=None, + cache_position=None, **kwargs, ): - if past_key_values is not None: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - else: - cache_length = past_length = past_key_values[0][0].shape[2] - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - else: - input_ids = input_ids[:, input_ids.shape[1] - 1 :] - # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the - # older attention values, as their corresponding values are not part of the input. - if cache_length < past_length and attention_mask is not None: - attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :] - - pixel_values_videos = None - pixel_values_images = None - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - "pixel_values_videos": pixel_values_videos, - "pixel_values_images": pixel_values_images, - } + model_inputs = self.language_model.prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + **kwargs, ) + + # If we're in cached decoding stage, pixel values is None because input ids do not contain special image token anymore + # Otherwise we need pixel values passed by the user + if past_key_values is None: + model_inputs["pixel_values_images"] = pixel_values_images + model_inputs["pixel_values_videos"] = pixel_values_videos + return model_inputs def _reorder_cache(self, *args, **kwargs): diff --git a/src/transformers/models/video_llava/processing_video_llava.py b/src/transformers/models/video_llava/processing_video_llava.py index 0355d756ce27..51e2ad55a665 100644 --- 
a/src/transformers/models/video_llava/processing_video_llava.py
+++ b/src/transformers/models/video_llava/processing_video_llava.py
@@ -46,6 +46,11 @@ class VideoLlavaProcessor(ProcessorMixin):

     def __init__(self, image_processor=None, tokenizer=None):
         super().__init__(image_processor, tokenizer)
+        self.image_size = self.image_processor.size["shortest_edge"]
+        self.patch_size = 14  # self.image_processor.patch_size
+        self.image_token = "<image>"
+        self.video_token = "
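For reference, a minimal standalone sketch of the expansion logic the processors above converge on, assuming CLIP-ViT-L/14 at 336px (image_size=336, patch_size=14) and the "default" feature-select strategy that drops the CLS feature; the constant and helper below are illustrative and not part of the patch:

# Illustrative sketch only: mirrors the per-image token count computed in LlavaProcessor.__call__,
# assuming image_size=336 and patch_size=14.
IMAGE_TOKEN = "<image>"

def expand_image_tokens(prompt: str, image_size: int = 336, patch_size: int = 14,
                        vision_feature_select_strategy: str = "default") -> str:
    # One projected feature per vision patch, plus one for the CLS token
    num_image_tokens = (image_size // patch_size) ** 2 + 1
    # The "default" strategy slices off the CLS feature, so reserve one placeholder fewer
    if vision_feature_select_strategy == "default":
        num_image_tokens -= 1
    # Repeat the placeholder so the tokenized prompt holds one position per image feature
    return prompt.replace(IMAGE_TOKEN, IMAGE_TOKEN * num_image_tokens)

expanded = expand_image_tokens("USER: <image>\nWhat is shown? ASSISTANT:")
print(expanded.count(IMAGE_TOKEN))  # 576

With the placeholders pre-expanded on the processor side, the modeling code can drop _merge_input_ids_with_image_features and simply write the projected vision features into the positions where input_ids equals the image token index, as the new special_image_mask assignments above do.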