Video-LLaVa: handle any number of frames #31221

Merged · 1 commit · Jun 4, 2024
11 changes: 5 additions & 6 deletions src/transformers/models/video_llava/modeling_video_llava.py
```diff
@@ -287,7 +287,7 @@ def _merge_input_ids_with_visual_features(
         num_images, num_image_patches, embed_dim = visual_features.shape
         batch_size, sequence_length = input_ids.shape
         left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
-        special_vision_token = self.config.video_token_index if num_frames == 8 else self.config.image_token_index
+        special_vision_token = self.config.video_token_index if num_frames > 1 else self.config.image_token_index

         # 1. Create a mask to know where special image tokens are
         special_image_token_mask = input_ids == special_vision_token
```
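In effect, a clip is now routed to the video token whenever it has more than one frame, instead of only when it has exactly 8. A minimal standalone sketch of that dispatch (the token index values here are hypothetical stand-ins for `config.video_token_index` and `config.image_token_index`):

```python
# Hypothetical token indices; the real values come from the model config.
def pick_special_vision_token(num_frames, video_token_index=32001, image_token_index=32000):
    # Old check: num_frames == 8 (hard-coded to the default sampling length).
    # New check: any multi-frame clip counts as video.
    return video_token_index if num_frames > 1 else image_token_index

assert pick_special_vision_token(16) == 32001  # 16-frame clip -> video token
assert pick_special_vision_token(1) == 32000   # single frame -> image token
```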
```diff
@@ -375,14 +375,13 @@ def _get_vision_features(
         # videos do not need to select features and it's always "full" (as it is done in the orig implementation)
         if pixel_values_videos is not None:
             batch_size_vid, num_frames, channels, height, width = pixel_values_videos.shape
-            if num_frames != 8:
-                raise ValueError(f"Video pixel values should have exactly `8` frames but foung `{num_frames}`")

             pixel_values = pixel_values_videos.reshape(batch_size_vid * num_frames, channels, height, width)
             video_outputs = self.video_tower(pixel_values, output_hidden_states=True)
             video_outputs = video_outputs.hidden_states[vision_feature_layer].squeeze(1)
         else:
             video_outputs = None
+            num_frames = 0

         if pixel_values_images is not None:
             image_outputs = self.image_tower(pixel_values_images, output_hidden_states=True)
```
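For context on why the hard check was safe to drop: the video branch simply folds the frame axis into the batch axis before calling the vision tower, so any frame count works. A shape-only sketch (sizes are illustrative):

```python
import torch

# Illustrative sizes: a batch of 2 clips with 16 frames each.
batch, num_frames, channels, height, width = 2, 16, 3, 224, 224
pixel_values_videos = torch.randn(batch, num_frames, channels, height, width)

# Same reshape as in _get_vision_features: one tower input per frame.
pixel_values = pixel_values_videos.reshape(batch * num_frames, channels, height, width)
print(pixel_values.shape)  # torch.Size([32, 3, 224, 224])
```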
```diff
@@ -397,7 +396,7 @@
         else:
             image_outputs = None

-        return image_outputs, video_outputs
+        return image_outputs, video_outputs, num_frames

     @add_start_docstrings_to_model_forward(VIDEO_LLAVA_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=VideoLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
@@ -513,7 +512,7 @@ def forward(

         # 2. Merge text and images
         if (pixel_values_images is not None or pixel_values_videos is not None) and input_ids.shape[1] != 1:
-            image_outputs, video_outputs = self._get_vision_features(
+            image_outputs, video_outputs, num_frames = self._get_vision_features(
                 pixel_values_images=pixel_values_images,
                 pixel_values_videos=pixel_values_videos,
                 vision_feature_layer=vision_feature_layer,
@@ -546,7 +545,7 @@ def forward(
                     input_ids,
                     attention_mask,
                     labels,
-                    num_frames=8,
+                    num_frames=num_frames,
                 )
             else:
                 # In case input_ids.shape[1] == 1 & past_key_values != None, we are in the case of
```
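With `num_frames` threaded from `_get_vision_features` into the merge step, `generate` should accept clips of any length. A hedged end-to-end sketch (the model id and prompt format follow the Video-LLaVA docs; the 16-frame dummy clip is illustrative):

```python
import numpy as np
import torch
from transformers import VideoLlavaForConditionalGeneration, VideoLlavaProcessor

model_id = "LanguageBind/Video-LLaVA-7B-hf"
processor = VideoLlavaProcessor.from_pretrained(model_id)
model = VideoLlavaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto"
)

# A dummy 16-frame clip; before this PR, anything other than 8 frames raised a ValueError.
video = np.random.randint(0, 255, (16, 224, 224, 3), dtype=np.uint8)
prompt = "USER: <video>\nWhat is happening in this video? ASSISTANT:"

inputs = processor(text=prompt, videos=video, return_tensors="pt").to(model.device, torch.float16)
output = model.generate(**inputs, max_new_tokens=20)
print(processor.batch_decode(output, skip_special_tokens=True)[0])
```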
3 changes: 3 additions & 0 deletions tests/models/video_llava/test_modeling_video_llava.py
```diff
@@ -487,6 +487,9 @@ def test_video_llava_index_error_bug(self):
             repo_id="raushan-testing-hf/videos-test", filename="video_demo.npy", repo_type="dataset"
         )
         video_file = np.load(video_file)
+
+        # let's expand it to 16 frames, to check that the model can handle any number of frames
+        video_file = video_file.repeat(2, 0)
         inputs = self.processor(prompt, videos=video_file, return_tensors="pt").to(torch_device, torch.float16)

         # Make sure that `generate` works
```
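The `repeat(2, 0)` call duplicates each frame along the time axis (frame order becomes 0, 0, 1, 1, ...), turning the 8-frame test clip into a 16-frame one:

```python
import numpy as np

video = np.zeros((8, 224, 224, 3), dtype=np.uint8)  # 8-frame clip
doubled = video.repeat(2, 0)  # repeat each frame twice along axis 0
print(doubled.shape)  # (16, 224, 224, 3)
```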