From 5c0c132341974540d1e94c94adcbbc3320d4063a Mon Sep 17 00:00:00 2001 From: Zafir Stojanovski Date: Wed, 8 May 2024 16:13:57 +0300 Subject: [PATCH 1/3] blip with interpolated pos encoding --- src/transformers/models/blip/modeling_blip.py | 77 +++++++++++++++++-- tests/models/blip/test_modeling_blip.py | 11 +++ 2 files changed, 80 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/blip/modeling_blip.py b/src/transformers/models/blip/modeling_blip.py index bd61a1cbd781..d2155dcaedb9 100644 --- a/src/transformers/models/blip/modeling_blip.py +++ b/src/transformers/models/blip/modeling_blip.py @@ -14,6 +14,7 @@ # limitations under the License. """ PyTorch BLIP model.""" +import math import warnings from dataclasses import dataclass from typing import Any, Optional, Tuple, Union @@ -234,15 +235,51 @@ def __init__(self, config: BlipVisionConfig): self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim)) - def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: - batch_size = pixel_values.shape[0] + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embedding.shape[1] - 1 + + if num_patches == num_positions and height == width: + return self.position_embedding + + class_pos_embed = self.position_embedding[:, 0, :] + patch_pos_embed = self.position_embedding[:, 1:, :] + dim = embeddings.shape[-1] + h0 = height // self.config.patch_size + w0 = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + h0, w0 = h0 + 0.1, w0 + 0.1 + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), + mode="bicubic", + align_corners=False, + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape target_dtype = self.patch_embedding.weight.dtype patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] patch_embeds = patch_embeds.flatten(2).transpose(1, 2) - class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype) embeddings = torch.cat([class_embeds, patch_embeds], dim=1) - embeddings = embeddings + self.position_embedding[:, : embeddings.size(1), :].to(target_dtype) + if interpolate_pos_encoding: + position_embedding = self.interpolate_pos_encoding(embeddings, height, width) + else: + position_embedding = self.position_embedding + embeddings = embeddings + position_embedding[:, : embeddings.size(1), :].to(target_dtype) return embeddings @@ -660,6 +697,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" Returns: @@ -674,7 +712,7 @@ def forward( if pixel_values is None: raise ValueError("You have to specify pixel_values") - hidden_states = self.embeddings(pixel_values) + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) encoder_outputs = self.encoder( inputs_embeds=hidden_states, @@ -782,6 +820,7 @@ def get_image_features( self, pixel_values: Optional[torch.FloatTensor] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ) -> torch.FloatTensor: r""" Returns: @@ -807,7 +846,11 @@ def get_image_features( ```""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict - vision_outputs = self.vision_model(pixel_values=pixel_values, return_dict=return_dict) + vision_outputs = self.vision_model( + pixel_values=pixel_values, + return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, + ) pooled_output = vision_outputs[1] # pooled_output image_features = self.visual_projection(pooled_output) @@ -821,6 +864,7 @@ def get_multimodal_features( pixel_values: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ) -> torch.FloatTensor: r""" Returns: @@ -849,6 +893,7 @@ def get_multimodal_features( output_attentions=True, output_hidden_states=True, return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, ) image_embeds = vision_outputs[0] @@ -879,6 +924,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple, BlipOutput]: r""" Returns: @@ -916,6 +962,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, ) text_outputs = self.text_model( @@ -1002,6 +1049,7 @@ def forward( output_hidden_states: Optional[bool] = None, labels: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple, BlipForConditionalGenerationModelOutput]: r""" Returns: @@ -1036,6 +1084,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, ) image_embeds = vision_outputs[0] @@ -1068,6 +1117,7 @@ def generate( pixel_values: torch.FloatTensor, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, + interpolate_pos_encoding: bool = False, **generate_kwargs, ) -> torch.LongTensor: r""" @@ -1103,7 +1153,10 @@ def generate( """ batch_size = pixel_values.shape[0] - vision_outputs = self.vision_model(pixel_values=pixel_values) + vision_outputs = self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + ) image_embeds = vision_outputs[0] @@ -1177,6 +1230,7 @@ def forward( output_hidden_states: Optional[bool] = None, labels: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple, BlipTextVisionModelOutput]: r""" Returns: @@ -1230,6 +1284,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, ) image_embeds = vision_outputs[0] @@ -1282,6 +1337,7 @@ def generate( input_ids: torch.LongTensor, pixel_values: torch.FloatTensor, attention_mask: Optional[torch.LongTensor] = None, + interpolate_pos_encoding: bool = False, **generate_kwargs, ) -> torch.LongTensor: r""" @@ -1319,7 +1375,10 @@ def generate( 2 ``` """ - vision_outputs = self.vision_model(pixel_values=pixel_values) + vision_outputs = self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + ) image_embeds = vision_outputs[0] @@ -1411,6 +1470,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple, BlipTextVisionModelOutput]: r""" Returns: @@ -1444,6 +1504,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, ) image_embeds = vision_outputs[0] diff --git a/tests/models/blip/test_modeling_blip.py b/tests/models/blip/test_modeling_blip.py index 4caba63a3104..f6e4f6bd8902 100644 --- a/tests/models/blip/test_modeling_blip.py +++ b/tests/models/blip/test_modeling_blip.py @@ -1381,6 +1381,17 @@ def test_inference_image_captioning_fp16(self): [30522, 1037, 3861, 1997, 1037, 2450, 1998, 2014, 3899, 2006, 1996, 3509, 102], ) + def test_inference_interpolate_pos_encoding(self): + model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(torch_device) + processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base") + processor.image_processor.size = {"height": 500, "width": 500} + + image = prepare_img() + inputs = processor(images=image, return_tensors="pt").to(torch_device) + + predictions = model.generate(**inputs, interpolate_pos_encoding=True) + self.assertEqual(predictions[0].tolist(), [30522, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 1037, 3899, 102]) + def test_inference_vqa(self): model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base").to(torch_device) processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base") From 4a2d62787f4fd9b49507a7fc4600178aba9e7adf Mon Sep 17 00:00:00 2001 From: Zafir Stojanovski Date: Wed, 8 May 2024 23:56:08 +0300 Subject: [PATCH 2/3] feat: Add interpolate_pos_encoding option to other models from `BLIP` family. --- src/transformers/models/blip/modeling_blip.py | 4 ++ .../models/blip_2/modeling_blip_2.py | 66 +++++++++++++++++-- .../instructblip/modeling_instructblip.py | 62 +++++++++++++++-- tests/models/blip_2/test_modeling_blip_2.py | 13 ++++ .../test_modeling_instructblip.py | 18 +++++ 5 files changed, 151 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/blip/modeling_blip.py b/src/transformers/models/blip/modeling_blip.py index d2155dcaedb9..c858048c3876 100644 --- a/src/transformers/models/blip/modeling_blip.py +++ b/src/transformers/models/blip/modeling_blip.py @@ -549,6 +549,8 @@ def _init_weights(self, module): more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. """ BLIP_INPUTS_DOCSTRING = r""" @@ -585,6 +587,8 @@ def _init_weights(self, module): more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. """ diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index edd0d9a6d761..ad2437171357 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -104,15 +104,51 @@ def __init__(self, config: Blip2VisionConfig): self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim)) - def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: - batch_size = pixel_values.shape[0] + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embedding.shape[1] - 1 + + if num_patches == num_positions and height == width: + return self.position_embedding + + class_pos_embed = self.position_embedding[:, 0, :] + patch_pos_embed = self.position_embedding[:, 1:, :] + dim = embeddings.shape[-1] + h0 = height // self.config.patch_size + w0 = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + h0, w0 = h0 + 0.1, w0 + 0.1 + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), + mode="bicubic", + align_corners=False, + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape target_dtype = self.patch_embedding.weight.dtype patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] patch_embeds = patch_embeds.flatten(2).transpose(1, 2) - class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype) embeddings = torch.cat([class_embeds, patch_embeds], dim=1) - embeddings = embeddings + self.position_embedding[:, : embeddings.size(1), :].to(target_dtype) + if interpolate_pos_encoding: + position_embedding = self.interpolate_pos_encoding(embeddings, height, width) + else: + position_embedding = self.position_embedding + embeddings = embeddings + position_embedding[:, : embeddings.size(1), :].to(target_dtype) return embeddings @@ -324,6 +360,8 @@ def _init_weights(self, module): more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. """ BLIP_2_TEXT_INPUTS_DOCSTRING = r""" @@ -405,6 +443,8 @@ def _init_weights(self, module): more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. """ @@ -519,6 +559,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" Returns: @@ -533,7 +574,7 @@ def forward( if pixel_values is None: raise ValueError("You have to specify pixel_values") - hidden_states = self.embeddings(pixel_values) + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) encoder_outputs = self.encoder( inputs_embeds=hidden_states, @@ -1300,6 +1341,7 @@ def get_image_features( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ): r""" Returns: @@ -1333,6 +1375,7 @@ def get_image_features( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, ) return vision_outputs @@ -1344,6 +1387,7 @@ def get_qformer_features( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ): r""" Returns: @@ -1377,6 +1421,7 @@ def get_qformer_features( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, ) image_embeds = vision_outputs[0] @@ -1409,6 +1454,7 @@ def forward( output_hidden_states: Optional[bool] = None, labels: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple, Blip2ForConditionalGenerationModelOutput]: r""" Returns: @@ -1444,6 +1490,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, ) image_embeds = vision_outputs[0] @@ -1626,6 +1673,7 @@ def forward( output_hidden_states: Optional[bool] = None, labels: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple, Blip2ForConditionalGenerationModelOutput]: r""" Returns: @@ -1698,6 +1746,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, ) image_embeds = vision_outputs[0] @@ -1782,6 +1831,7 @@ def generate( pixel_values: torch.FloatTensor, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, + interpolate_pos_encoding: bool = False, **generate_kwargs, ) -> torch.LongTensor: """ @@ -1803,7 +1853,11 @@ def generate( self._preprocess_accelerate() batch_size = pixel_values.shape[0] - image_embeds = self.vision_model(pixel_values, return_dict=True).last_hidden_state + image_embeds = self.vision_model( + pixel_values, + return_dict=True, + interpolate_pos_encoding=interpolate_pos_encoding, + ).last_hidden_state image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index 52f8fa610a94..45b60bec2c6e 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -105,15 +105,51 @@ def __init__(self, config: InstructBlipVisionConfig): self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim)) - def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: - batch_size = pixel_values.shape[0] + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embedding.shape[1] - 1 + + if num_patches == num_positions and height == width: + return self.position_embedding + + class_pos_embed = self.position_embedding[:, 0, :] + patch_pos_embed = self.position_embedding[:, 1:, :] + dim = embeddings.shape[-1] + h0 = height // self.config.patch_size + w0 = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + h0, w0 = h0 + 0.1, w0 + 0.1 + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), + mode="bicubic", + align_corners=False, + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape target_dtype = self.patch_embedding.weight.dtype patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] patch_embeds = patch_embeds.flatten(2).transpose(1, 2) - class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype) embeddings = torch.cat([class_embeds, patch_embeds], dim=1) - embeddings = embeddings + self.position_embedding[:, : embeddings.size(1), :].to(target_dtype) + if interpolate_pos_encoding: + position_embedding = self.interpolate_pos_encoding(embeddings, height, width) + else: + position_embedding = self.position_embedding + embeddings = embeddings + position_embedding[:, : embeddings.size(1), :].to(target_dtype) return embeddings @@ -331,6 +367,8 @@ def _init_weights(self, module): more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. """ INSTRUCTBLIP_INPUTS_DOCSTRING = r""" @@ -394,6 +432,8 @@ def _init_weights(self, module): more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. """ @@ -508,6 +548,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" Returns: @@ -522,7 +563,7 @@ def forward( if pixel_values is None: raise ValueError("You have to specify pixel_values") - hidden_states = self.embeddings(pixel_values) + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) encoder_outputs = self.encoder( inputs_embeds=hidden_states, @@ -1330,6 +1371,7 @@ def forward( output_hidden_states: Optional[bool] = None, labels: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple, InstructBlipForConditionalGenerationModelOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -1382,6 +1424,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, ) image_embeds = vision_outputs[0] @@ -1476,6 +1519,7 @@ def generate( qformer_attention_mask: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, + interpolate_pos_encoding: bool = False, **generate_kwargs, ) -> torch.LongTensor: """ @@ -1492,6 +1536,8 @@ def generate( The sequence used as a prompt for the generation. attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*): Mask to avoid performing attention on padding token indices. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the positional encoding of the image embeddings. Returns: captions (list): A list of strings of length batch_size * num_captions. @@ -1501,7 +1547,11 @@ def generate( self._preprocess_accelerate() batch_size = pixel_values.shape[0] - image_embeds = self.vision_model(pixel_values, return_dict=True).last_hidden_state + image_embeds = self.vision_model( + pixel_values, + return_dict=True, + interpolate_pos_encoding=interpolate_pos_encoding, + ).last_hidden_state image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py index 927f5341272f..ce789f6ac99d 100644 --- a/tests/models/blip_2/test_modeling_blip_2.py +++ b/tests/models/blip_2/test_modeling_blip_2.py @@ -882,6 +882,19 @@ def test_inference_opt(self): ) self.assertEqual(generated_text, "it's not a city, it's a beach") + def test_inference_interpolate_pos_encoding(self): + processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b") + model = Blip2ForConditionalGeneration.from_pretrained( + "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16 + ).to(torch_device) + processor.image_processor.size = {"height": 500, "width": 500} + + image = prepare_img() + inputs = processor(images=image, return_tensors="pt").to(torch_device) + + predictions = model.generate(**inputs, interpolate_pos_encoding=True) + self.assertEqual(predictions[0].tolist(), [2, 102, 693, 8, 2335, 15, 5, 4105, 50118]) + def test_inference_opt_batched_beam_search(self): processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b") model = Blip2ForConditionalGeneration.from_pretrained( diff --git a/tests/models/instructblip/test_modeling_instructblip.py b/tests/models/instructblip/test_modeling_instructblip.py index dcb8040bfcf9..febee81ff7a9 100644 --- a/tests/models/instructblip/test_modeling_instructblip.py +++ b/tests/models/instructblip/test_modeling_instructblip.py @@ -612,3 +612,21 @@ def test_inference_flant5_xl(self): generated_text, "The image depicts a man ironing clothes on the back of a yellow van in the middle of a busy city street. The man is wearing a yellow shirt with a bright yellow tie, and he is using an ironing board to complete his task. The image is unusual due to the fact that it shows a man ironing clothes on the back of a van in the middle of a busy city street. It is possible that the man is trying to save money by doing his laundry on the back of the van, but it is also possible that he is trying to save time by doing his laundry on the back of the van in the middle of a busy city street. Regardless of the reason for the man's actions, it is clear that he is trying to save time by doing his laundry on the back of the van in the middle of a busy city street.", ) + + def test_inference_interpolate_pos_encoding(self): + processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-flan-t5-xl") + model = InstructBlipForConditionalGeneration.from_pretrained( + "Salesforce/instructblip-flan-t5-xl", + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + ).to(torch_device) + processor.image_processor.size = {"height": 500, "width": 500} + + image = prepare_img() + prompt = "What's in the image?" + inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device) + + predictions = model.generate(**inputs, interpolate_pos_encoding=True) + self.assertEqual( + predictions[0].tolist(), [0, 37, 1023, 753, 3, 9, 2335, 3823, 30, 8, 2608, 28, 3, 9, 1782, 5, 1] + ) From 6e211839e05c6a5a0fc152156f9bb7d701d36951 Mon Sep 17 00:00:00 2001 From: Zafir Stojanovski Date: Sat, 11 May 2024 11:05:28 +0300 Subject: [PATCH 3/3] include check for textual generated content in tests --- tests/models/blip/test_modeling_blip.py | 3 +++ tests/models/blip_2/test_modeling_blip_2.py | 3 +++ tests/models/instructblip/test_modeling_instructblip.py | 3 +++ 3 files changed, 9 insertions(+) diff --git a/tests/models/blip/test_modeling_blip.py b/tests/models/blip/test_modeling_blip.py index f6e4f6bd8902..89404342f0b0 100644 --- a/tests/models/blip/test_modeling_blip.py +++ b/tests/models/blip/test_modeling_blip.py @@ -1390,7 +1390,10 @@ def test_inference_interpolate_pos_encoding(self): inputs = processor(images=image, return_tensors="pt").to(torch_device) predictions = model.generate(**inputs, interpolate_pos_encoding=True) + generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip() + self.assertEqual(predictions[0].tolist(), [30522, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 1037, 3899, 102]) + self.assertEqual(generated_text, "a woman sitting on the beach with a dog") def test_inference_vqa(self): model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base").to(torch_device) diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py index ce789f6ac99d..d2f3b2b719f2 100644 --- a/tests/models/blip_2/test_modeling_blip_2.py +++ b/tests/models/blip_2/test_modeling_blip_2.py @@ -893,7 +893,10 @@ def test_inference_interpolate_pos_encoding(self): inputs = processor(images=image, return_tensors="pt").to(torch_device) predictions = model.generate(**inputs, interpolate_pos_encoding=True) + generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip() + self.assertEqual(predictions[0].tolist(), [2, 102, 693, 8, 2335, 15, 5, 4105, 50118]) + self.assertEqual(generated_text, "a woman and dog on the beach") def test_inference_opt_batched_beam_search(self): processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b") diff --git a/tests/models/instructblip/test_modeling_instructblip.py b/tests/models/instructblip/test_modeling_instructblip.py index febee81ff7a9..86aea876fa50 100644 --- a/tests/models/instructblip/test_modeling_instructblip.py +++ b/tests/models/instructblip/test_modeling_instructblip.py @@ -627,6 +627,9 @@ def test_inference_interpolate_pos_encoding(self): inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device) predictions = model.generate(**inputs, interpolate_pos_encoding=True) + generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip() + self.assertEqual( predictions[0].tolist(), [0, 37, 1023, 753, 3, 9, 2335, 3823, 30, 8, 2608, 28, 3, 9, 1782, 5, 1] ) + self.assertEqual(generated_text, "The image features a woman sitting on the beach with a dog.")