From d784af5dfe592cc0709ef005402a11deba4eecd4 Mon Sep 17 00:00:00 2001
From: raushan
Date: Mon, 2 Sep 2024 16:43:16 +0200
Subject: [PATCH 1/4] fix

---
 src/transformers/trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index f62184c1904c..b5d5bf6bb57a 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1926,7 +1926,7 @@ def train(
 
         # do_train is not a reliable argument, as it might not be set and .train() still called, so
         # the following is a workaround:
-        if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train:
+        if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train and not self.is_model_parallel:
             self._move_model_to_device(self.model, args.device)
 
         if "model_path" in kwargs:

From 17aa9f1a98e53206063815e38686e4c18c808b57 Mon Sep 17 00:00:00 2001
From: raushan
Date: Wed, 11 Sep 2024 09:09:16 +0200
Subject: [PATCH 2/4] return image hidden states

---
 src/transformers/models/llava/modeling_llava.py | 11 +++++------
 .../models/llava_next/modeling_llava_next.py | 11 +++++------
 .../modeling_llava_next_video.py | 16 ++++++++++------
 .../llava_onevision/modeling_llava_onevision.py | 15 +++++++++------
 .../models/paligemma/modeling_paligemma.py | 11 +++++------
 .../models/video_llava/modeling_video_llava.py | 16 ++++++++++------
 .../models/vipllava/modeling_vipllava.py | 11 +++++------
 7 files changed, 49 insertions(+), 42 deletions(-)

diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py
index ae53156d9ba2..af9a18a6f155 100644
--- a/src/transformers/models/llava/modeling_llava.py
+++ b/src/transformers/models/llava/modeling_llava.py
@@ -70,11 +70,9 @@ class LlavaCausalLMOutputWithPast(ModelOutput):
 
             Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
             heads.
-        image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
-            Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
-            sequence_length, hidden_size)`.
-
-            image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
+        image_hidden_states (`torch.FloatTensor`, *optional*):
+            A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+            image_hidden_states of the model produced by the vision encoder after projecting the last hidden state.
""" loss: Optional[torch.FloatTensor] = None @@ -82,7 +80,7 @@ class LlavaCausalLMOutputWithPast(ModelOutput): past_key_values: Optional[List[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None - image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[torch.FloatTensor] = None class LlavaMultiModalProjector(nn.Module): @@ -556,6 +554,7 @@ def forward( past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, ) def prepare_inputs_for_generation( diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index 5fe029f13e73..7159801602f7 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -171,11 +171,9 @@ class LlavaNextCausalLMOutputWithPast(ModelOutput): Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. - image_hidden_states (`tuple(torch.FloatTensor)`, *optional*): - Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images, - sequence_length, hidden_size)`. - - image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size (batch_size * num_patches, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. """ loss: Optional[torch.FloatTensor] = None @@ -183,7 +181,7 @@ class LlavaNextCausalLMOutputWithPast(ModelOutput): past_key_values: Optional[List[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None - image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[torch.FloatTensor] = None # Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext @@ -927,6 +925,7 @@ def forward( past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, ) def prepare_inputs_for_generation( diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index 78e8c5a07723..63c093584753 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -178,11 +178,12 @@ class LlavaNextVideoCausalLMOutputWithPast(ModelOutput): Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. - image_hidden_states (`tuple(torch.FloatTensor)`, *optional*): - Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images, - sequence_length, hidden_size)`. - - image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size (batch_size * num_patches, num_images, sequence_length, hidden_size)`. 
+            image_hidden_states of the model produced by the vision encoder after projecting the last hidden state.
+        video_hidden_states (`torch.FloatTensor`, *optional*):
+            A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`.
+            video_hidden_states of the model produced by the vision encoder after projecting the last hidden state.
     """
 
     loss: Optional[torch.FloatTensor] = None
@@ -190,7 +191,8 @@ class LlavaNextVideoCausalLMOutputWithPast(ModelOutput):
     past_key_values: Optional[List[torch.FloatTensor]] = None
     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
     attentions: Optional[Tuple[torch.FloatTensor]] = None
-    image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    image_hidden_states: Optional[torch.FloatTensor] = None
+    video_hidden_states: Optional[torch.FloatTensor] = None
 
 
 class LlavaNextVideoPooler(nn.Module):
@@ -1009,6 +1011,8 @@ def forward(
             past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
+            image_hidden_states=image_features if pixel_values is not None else None,
+            video_hidden_states=video_features if pixel_values is not None else None,
         )
 
     def prepare_inputs_for_generation(
diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py
index 9496c3857027..3e9afa769dfe 100644
--- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py
+++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py
@@ -172,11 +172,12 @@ class LlavaOnevisionCausalLMOutputWithPast(ModelOutput):
 
             Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
             heads.
-        image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
-            Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
-            sequence_length, hidden_size)`.
-
-            image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
+        image_hidden_states (`torch.FloatTensor`, *optional*):
+            A `torch.FloatTensor` of size `(batch_size * num_patches, num_images, sequence_length, hidden_size)`.
+            image_hidden_states of the model produced by the vision encoder after projecting the last hidden state.
+        video_hidden_states (`torch.FloatTensor`, *optional*):
+            A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`.
+            video_hidden_states of the model produced by the vision encoder after projecting the last hidden state.
""" loss: Optional[torch.FloatTensor] = None @@ -184,7 +185,7 @@ class LlavaOnevisionCausalLMOutputWithPast(ModelOutput): past_key_values: Optional[List[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None - image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[torch.FloatTensor] = None # Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaOnevision @@ -690,6 +691,8 @@ def forward( past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + video_hidden_states=video_features if pixel_values_videos is not None else None, ) def prepare_inputs_for_generation( diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index 7c09177e27d7..f394029fb8b9 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -72,11 +72,9 @@ class PaliGemmaCausalLMOutputWithPast(ModelOutput): Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. - image_hidden_states (`tuple(torch.FloatTensor)`, *optional*): - Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images, - sequence_length, hidden_size)`. - - image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder after projecting last hidden state. """ loss: Optional[torch.FloatTensor] = None @@ -84,7 +82,7 @@ class PaliGemmaCausalLMOutputWithPast(ModelOutput): past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None - image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[torch.FloatTensor] = None class PaliGemmaMultiModalProjector(nn.Module): @@ -497,6 +495,7 @@ def forward( past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, ) def prepare_inputs_for_generation( diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index b9263ad15cbf..f1d95d28a880 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -67,11 +67,12 @@ class VideoLlavaCausalLMOutputWithPast(ModelOutput): Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. - image_hidden_states (`tuple(torch.FloatTensor)`, *optional*): - Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images, - sequence_length, hidden_size)`. - - image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`. 
+            image_hidden_states of the model produced by the vision encoder after projecting the last hidden state.
+        video_hidden_states (`torch.FloatTensor`, *optional*):
+            A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`.
+            video_hidden_states of the model produced by the vision encoder after projecting the last hidden state.
     """
 
     loss: Optional[torch.FloatTensor] = None
@@ -79,7 +80,8 @@ class VideoLlavaCausalLMOutputWithPast(ModelOutput):
     past_key_values: Optional[List[torch.FloatTensor]] = None
     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
     attentions: Optional[Tuple[torch.FloatTensor]] = None
-    image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    image_hidden_states: Optional[torch.FloatTensor] = None
+    video_hidden_states: Optional[torch.FloatTensor] = None
 
 
 # Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->VideoLlava
@@ -664,6 +666,8 @@ def forward(
             past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
+            image_hidden_states=image_features if pixel_values_images is not None else None,
+            video_hidden_states=video_features if pixel_values_videos is not None else None,
         )
 
     def prepare_inputs_for_generation(
diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py
index e036d6fb7667..af7dd2b8d23b 100644
--- a/src/transformers/models/vipllava/modeling_vipllava.py
+++ b/src/transformers/models/vipllava/modeling_vipllava.py
@@ -67,11 +67,9 @@ class VipLlavaCausalLMOutputWithPast(ModelOutput):
 
             Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
             heads.
-        image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
-            Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
-            sequence_length, hidden_size)`.
-
-            image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
+        image_hidden_states (`torch.FloatTensor`, *optional*):
+            A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+            image_hidden_states of the model produced by the vision encoder after projecting the last hidden state.
""" loss: Optional[torch.FloatTensor] = None @@ -79,7 +77,7 @@ class VipLlavaCausalLMOutputWithPast(ModelOutput): past_key_values: Optional[List[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None - image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[torch.FloatTensor] = None class VipLlavaMultiModalProjector(nn.Module): @@ -550,6 +548,7 @@ def forward( past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, ) def prepare_inputs_for_generation( From dc2b700202206ea157acb680dc71f09ef017310b Mon Sep 17 00:00:00 2001 From: raushan Date: Fri, 13 Sep 2024 09:25:03 +0200 Subject: [PATCH 3/4] fix copies --- src/transformers/models/llava/modeling_llava.py | 1 - src/transformers/models/llava_next/modeling_llava_next.py | 1 - .../models/llava_next_video/modeling_llava_next_video.py | 3 +-- .../models/llava_onevision/modeling_llava_onevision.py | 3 ++- src/transformers/models/video_llava/modeling_video_llava.py | 1 - src/transformers/models/vipllava/modeling_vipllava.py | 2 +- 6 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index 8367da1089d2..9ad19ccee722 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -43,7 +43,6 @@ @dataclass -# Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->Llava class LlavaCausalLMOutputWithPast(ModelOutput): """ Base class for Llava causal language model (or autoregressive) outputs. diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index 47f617474746..ebb4da3102da 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -144,7 +144,6 @@ def unpad_image(tensor, original_size): @dataclass -# Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->LlavaNext class LlavaNextCausalLMOutputWithPast(ModelOutput): """ Base class for LlavaNext causal language model (or autoregressive) outputs. diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index 1ecf88135310..589bf346ceeb 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -150,7 +150,6 @@ def unpad_image(tensor, original_size): @dataclass -# Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->LlavaNextVideo class LlavaNextVideoCausalLMOutputWithPast(ModelOutput): """ Base class for LlavaNextVideo causal language model (or autoregressive) outputs. 
@@ -1018,7 +1017,7 @@ def forward(
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
             image_hidden_states=image_features if pixel_values is not None else None,
-            video_hidden_states=video_features if pixel_values is not None else None,
+            video_hidden_states=video_features if pixel_values_videos is not None else None,
         )
 
     def prepare_inputs_for_generation(
diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py
index 3e9afa769dfe..697ea84fea50 100644
--- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py
+++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py
@@ -145,7 +145,7 @@ def unpad_image(tensor, original_size):
 
 
 @dataclass
-# Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->LlavaOnevision
+# Copied from transformers.models.llava_next_video.modeling_llava_next_video.LlavaNextVideoCausalLMOutputWithPast with LlavaNextVideo->LlavaOnevision
 class LlavaOnevisionCausalLMOutputWithPast(ModelOutput):
     """
     Base class for LlavaOnevision causal language model (or autoregressive) outputs.
@@ -186,6 +186,7 @@ class LlavaOnevisionCausalLMOutputWithPast(ModelOutput):
     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
     attentions: Optional[Tuple[torch.FloatTensor]] = None
     image_hidden_states: Optional[torch.FloatTensor] = None
+    video_hidden_states: Optional[torch.FloatTensor] = None
 
 
 # Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaOnevision
diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py
index 9a3e81dd3c70..9ae80be65ae4 100644
--- a/src/transformers/models/video_llava/modeling_video_llava.py
+++ b/src/transformers/models/video_llava/modeling_video_llava.py
@@ -40,7 +40,6 @@
 
 
 @dataclass
-# Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->VideoLlava
 class VideoLlavaCausalLMOutputWithPast(ModelOutput):
     """
     Base class for VideoLlava causal language model (or autoregressive) outputs.
diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py
index 9a2efdebbfdf..53a321369719 100644
--- a/src/transformers/models/vipllava/modeling_vipllava.py
+++ b/src/transformers/models/vipllava/modeling_vipllava.py
@@ -40,7 +40,7 @@
 
 
 @dataclass
-# Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->VipLlava
+# Copied from transformers.models.llava.modeling_llava.LlavaCausalLMOutputWithPast with Llava->VipLlava
 class VipLlavaCausalLMOutputWithPast(ModelOutput):
     """
     Base class for VipLlava causal language model (or autoregressive) outputs.

From 1dbdeb594441ed60a554f9817f6c5a7cb739e26c Mon Sep 17 00:00:00 2001
From: raushan
Date: Fri, 13 Sep 2024 09:49:39 +0200
Subject: [PATCH 4/4] fix test

---
 tests/models/video_llava/test_modeling_video_llava.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/tests/models/video_llava/test_modeling_video_llava.py b/tests/models/video_llava/test_modeling_video_llava.py
index a8b2229a02f5..1f8883453754 100644
--- a/tests/models/video_llava/test_modeling_video_llava.py
+++ b/tests/models/video_llava/test_modeling_video_llava.py
@@ -320,6 +320,10 @@ def recursive_check(batched_object, single_row_object, model_name, key):
                 model_row_output = model(**single_row_input)
 
             for key in model_batched_output:
+                # we can't test videos as their output shapes are linked to number of frames
+                # and we don't have to as it is a CLIP model and can be tested from `ClipModelTester` class
+                if key == "video_hidden_states":
+                    continue
                 recursive_check(model_batched_output[key], model_row_output[key], model_name, key)
 
     # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
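
Usage sketch (illustrative, not part of the patch series): with patch 2 applied, the projected
image features come back on the model output as `image_hidden_states` instead of always being
`None`. The checkpoint id, prompt format, and image URL below are assumptions for the example;
any LLaVA-style checkpoint should behave the same way.

    import requests
    import torch
    from PIL import Image
    from transformers import AutoProcessor, LlavaForConditionalGeneration

    # Assumed checkpoint and inputs -- purely illustrative.
    model_id = "llava-hf/llava-1.5-7b-hf"
    processor = AutoProcessor.from_pretrained(model_id)
    model = LlavaForConditionalGeneration.from_pretrained(model_id)

    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"
    inputs = processor(images=image, text=prompt, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    # Before these patches, image_hidden_states was always None on LlavaCausalLMOutputWithPast.
    # With patch 2 applied, it carries the multi-modal projector output whenever pixel_values
    # were passed (see the updated docstrings above for the documented shape).
    print(outputs.image_hidden_states is not None)  # True when an image was provided
    print(outputs.image_hidden_states.shape)

The video-capable models touched here (VideoLlavaForConditionalGeneration,
LlavaNextVideoForConditionalGeneration, LlavaOnevisionForConditionalGeneration) follow the same
pattern and additionally expose `video_hidden_states` when `pixel_values_videos` is provided.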