Return image hidden states #33426

Merged: 10 commits, Sep 13, 2024
11 changes: 5 additions & 6 deletions src/transformers/models/llava/modeling_llava.py
@@ -70,19 +70,17 @@ class LlavaCausalLMOutputWithPast(ModelOutput):

Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
sequence_length, hidden_size)`.

image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""

loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[torch.FloatTensor] = None


class LlavaMultiModalProjector(nn.Module):
@@ -560,6 +558,7 @@ def forward(
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)

def prepare_inputs_for_generation(
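To illustrate what this change enables, below is a minimal sketch of reading the newly populated field from a LLaVA forward pass. The checkpoint id, prompt template, and image URL are illustrative assumptions and are not part of this PR:

```python
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "llava-hf/llava-1.5-7b-hf"  # assumed checkpoint, for illustration only
processor = AutoProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto"
)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"

inputs = processor(images=image, text=prompt, return_tensors="pt").to(
    model.device, torch.float16
)
with torch.no_grad():
    outputs = model(**inputs)

# Before this PR the field was always None; with this change it carries the
# projected vision features whenever `pixel_values` is passed.
print(outputs.image_hidden_states.shape)
```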
11 changes: 5 additions & 6 deletions src/transformers/models/llava_next/modeling_llava_next.py
@@ -171,19 +171,17 @@ class LlavaNextCausalLMOutputWithPast(ModelOutput):

Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
sequence_length, hidden_size)`.

image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size * num_patches, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""

loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[torch.FloatTensor] = None


# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext
@@ -931,6 +929,7 @@ def forward(
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)

def prepare_inputs_for_generation(
src/transformers/models/llava_next_video/modeling_llava_next_video.py
@@ -177,19 +177,21 @@ class LlavaNextVideoCausalLMOutputWithPast(ModelOutput):

Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
sequence_length, hidden_size)`.

image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size * num_patches, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
video_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`.
video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""

loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[torch.FloatTensor] = None
video_hidden_states: Optional[torch.FloatTensor] = None


class LlavaNextVideoPooler(nn.Module):
@@ -1015,6 +1017,8 @@ def forward(
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
video_hidden_states=video_features if pixel_values_videos is not None else None,
)

def prepare_inputs_for_generation(
src/transformers/models/llava_onevision/modeling_llava_onevision.py
@@ -172,19 +172,20 @@ class LlavaOnevisionCausalLMOutputWithPast(ModelOutput):

Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
sequence_length, hidden_size)`.

image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size * num_patches, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
video_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`.
video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""

loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[torch.FloatTensor] = None


# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaOnevision
@@ -690,6 +691,8 @@ def forward(
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
video_hidden_states=video_features if pixel_values_videos is not None else None,
)

def prepare_inputs_for_generation(
11 changes: 5 additions & 6 deletions src/transformers/models/paligemma/modeling_paligemma.py
@@ -72,19 +72,17 @@ class PaliGemmaCausalLMOutputWithPast(ModelOutput):

Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
sequence_length, hidden_size)`.

image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder after projecting the last hidden state.
"""

loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[torch.FloatTensor] = None


class PaliGemmaMultiModalProjector(nn.Module):
@@ -497,6 +495,7 @@ def forward(
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)

def prepare_inputs_for_generation(
16 changes: 10 additions & 6 deletions src/transformers/models/video_llava/modeling_video_llava.py
@@ -67,19 +67,21 @@ class VideoLlavaCausalLMOutputWithPast(ModelOutput):

Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
sequence_length, hidden_size)`.

image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
video_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`.
video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""

loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[torch.FloatTensor] = None
video_hidden_states: Optional[torch.FloatTensor] = None


# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->VideoLlava
@@ -672,6 +674,8 @@ def forward(
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values_images is not None else None,
video_hidden_states=video_features if pixel_values_videos is not None else None,
)

def prepare_inputs_for_generation(
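And a similar sketch for the two-field case, using Video-LLaVA: when only a video is passed, `video_hidden_states` is populated while `image_hidden_states` stays `None`, mirroring the conditionals added above. The checkpoint id, prompt template, and the synthetic clip are illustrative assumptions:

```python
import numpy as np
import torch
from transformers import VideoLlavaProcessor, VideoLlavaForConditionalGeneration

model_id = "LanguageBind/Video-LLaVA-7B-hf"  # assumed checkpoint, for illustration only
processor = VideoLlavaProcessor.from_pretrained(model_id)
model = VideoLlavaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto"
)

# A dummy 8-frame clip stands in for real decoded frames (e.g. read with PyAV).
clip = np.random.randint(0, 255, size=(8, 224, 224, 3), dtype=np.uint8)
prompt = "USER: <video>\nWhat is happening in the video? ASSISTANT:"

inputs = processor(text=prompt, videos=clip, return_tensors="pt").to(
    model.device, torch.float16
)
with torch.no_grad():
    outputs = model(**inputs)

# Each field is populated only when the corresponding pixel inputs are provided:
# here no images are passed, so `image_hidden_states` remains None.
print(outputs.video_hidden_states.shape)
print(outputs.image_hidden_states)  # None
```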
11 changes: 5 additions & 6 deletions src/transformers/models/vipllava/modeling_vipllava.py
@@ -67,19 +67,17 @@ class VipLlavaCausalLMOutputWithPast(ModelOutput):

Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
sequence_length, hidden_size)`.

image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""

loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[torch.FloatTensor] = None


class VipLlavaMultiModalProjector(nn.Module):
@@ -554,6 +552,7 @@ def forward(
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)

def prepare_inputs_for_generation(