
Commit

Video-LLaVa: handle any number of frames (#31221)
Video-LLaVA can now handle videos with any number of frames instead of requiring exactly 8.
zucchini-nlp authored Jun 4, 2024
1 parent 36ade4a commit d64e4da
Showing 2 changed files with 8 additions and 6 deletions.
11 changes: 5 additions & 6 deletions src/transformers/models/video_llava/modeling_video_llava.py
@@ -287,7 +287,7 @@ def _merge_input_ids_with_visual_features(
         num_images, num_image_patches, embed_dim = visual_features.shape
         batch_size, sequence_length = input_ids.shape
         left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
-        special_vision_token = self.config.video_token_index if num_frames == 8 else self.config.image_token_index
+        special_vision_token = self.config.video_token_index if num_frames > 1 else self.config.image_token_index

         # 1. Create a mask to know where special image tokens are
         special_image_token_mask = input_ids == special_vision_token
@@ -375,14 +375,13 @@ def _get_vision_features(
         # videos do not need to select features and it's always "full" (as it is done in the orig implementation)
         if pixel_values_videos is not None:
             batch_size_vid, num_frames, channels, height, width = pixel_values_videos.shape
-            if num_frames != 8:
-                raise ValueError(f"Video pixel values should have exactly `8` frames but foung `{num_frames}`")

             pixel_values = pixel_values_videos.reshape(batch_size_vid * num_frames, channels, height, width)
             video_outputs = self.video_tower(pixel_values, output_hidden_states=True)
             video_outputs = video_outputs.hidden_states[vision_feature_layer].squeeze(1)
         else:
             video_outputs = None
+            num_frames = 0

         if pixel_values_images is not None:
             image_outputs = self.image_tower(pixel_values_images, output_hidden_states=True)
@@ -397,7 +396,7 @@ def _get_vision_features(
         else:
             image_outputs = None

-        return image_outputs, video_outputs
+        return image_outputs, video_outputs, num_frames

     @add_start_docstrings_to_model_forward(VIDEO_LLAVA_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=VideoLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
@@ -513,7 +512,7 @@ def forward(

             # 2. Merge text and images
             if (pixel_values_images is not None or pixel_values_videos is not None) and input_ids.shape[1] != 1:
-                image_outputs, video_outputs = self._get_vision_features(
+                image_outputs, video_outputs, num_frames = self._get_vision_features(
                     pixel_values_images=pixel_values_images,
                     pixel_values_videos=pixel_values_videos,
                     vision_feature_layer=vision_feature_layer,
@@ -546,7 +545,7 @@ def forward(
                         input_ids,
                         attention_mask,
                         labels,
-                        num_frames=8,
+                        num_frames=num_frames,
                     )
             else:
                 # In case input_ids.shape[1] == 1 & past_key_values != None, we are in the case of
3 changes: 3 additions & 0 deletions tests/models/video_llava/test_modeling_video_llava.py
@@ -487,6 +487,9 @@ def test_video_llava_index_error_bug(self):
             repo_id="raushan-testing-hf/videos-test", filename="video_demo.npy", repo_type="dataset"
         )
         video_file = np.load(video_file)
+
+        # let's expand it for 16 frames, to check model can handle any number of frames
+        video_file = video_file.repeat(2, 0)
         inputs = self.processor(prompt, videos=video_file, return_tensors="pt").to(torch_device, torch.float16)

         # Make sure that `generate` works
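For context, here is a minimal usage sketch (not part of this commit) of the behavior the change enables: running generation on a video whose frame count is not 8. The checkpoint id, prompt format, and the 16-frame dummy clip are illustrative assumptions based on the Video-LLaVA model card, not something introduced by this diff.

    # Illustrative sketch only; checkpoint id and prompt format are assumptions, not part of this commit.
    import numpy as np
    import torch
    from transformers import VideoLlavaForConditionalGeneration, VideoLlavaProcessor

    model_id = "LanguageBind/Video-LLaVA-7B-hf"  # assumed checkpoint id
    processor = VideoLlavaProcessor.from_pretrained(model_id)
    model = VideoLlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

    # A dummy 16-frame clip; before this change, _get_vision_features raised a ValueError
    # for any frame count other than 8.
    video = np.random.randint(0, 255, size=(16, 224, 224, 3), dtype=np.uint8)
    prompt = "USER: <video>\nWhat is happening in this video? ASSISTANT:"

    inputs = processor(prompt, videos=video, return_tensors="pt").to("cuda", torch.float16)
    output = model.generate(**inputs, max_new_tokens=40)
    print(processor.batch_decode(output, skip_special_tokens=True)[0])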
