Blip dynamic input resolution (#30722)
* blip with interpolated pos encoding

* feat: Add interpolate_pos_encoding option to other models from `BLIP` family.

* include check for generated text content in tests
zafstojano authored May 13, 2024
1 parent a4e530e commit f63d822
Showing 6 changed files with 240 additions and 20 deletions.
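
As a quick end-to-end illustration of what this commit enables (a hedged sketch, not part of the diff: the checkpoint name, image URL, and 608x608 size override are assumptions), captioning can now run on inputs larger than the 384x384 pretraining resolution by passing `interpolate_pos_encoding=True` to `generate`:

```python
import requests
import torch
from PIL import Image
from transformers import BlipForConditionalGeneration, BlipProcessor

# Assumed checkpoint; any BLIP captioning checkpoint should behave the same way.
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Preprocess above the 384x384 pretraining resolution (608 is divisible by the 16px patch size).
pixel_values = processor.image_processor(
    image, size={"height": 608, "width": 608}, return_tensors="pt"
).pixel_values

with torch.no_grad():
    # The new flag resizes the pretrained position embeddings to the larger 38x38 patch grid.
    output_ids = model.generate(pixel_values=pixel_values, interpolate_pos_encoding=True)
print(processor.decode(output_ids[0], skip_special_tokens=True))
```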
81 changes: 73 additions & 8 deletions src/transformers/models/blip/modeling_blip.py
@@ -14,6 +14,7 @@
# limitations under the License.
""" PyTorch BLIP model."""

import math
import warnings
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union
@@ -231,15 +232,51 @@ def __init__(self, config: BlipVisionConfig):

        self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        batch_size = pixel_values.shape[0]
    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
        resolution images.
        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """
        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embedding.shape[1] - 1

        if num_patches == num_positions and height == width:
            return self.position_embedding

        class_pos_embed = self.position_embedding[:, 0, :]
        patch_pos_embed = self.position_embedding[:, 1:, :]
        dim = embeddings.shape[-1]
        h0 = height // self.config.patch_size
        w0 = width // self.config.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        h0, w0 = h0 + 0.1, w0 + 0.1
        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
            mode="bicubic",
            align_corners=False,
        )
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        batch_size, _, height, width = pixel_values.shape
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        embeddings = embeddings + self.position_embedding[:, : embeddings.size(1), :].to(target_dtype)
        if interpolate_pos_encoding:
            position_embedding = self.interpolate_pos_encoding(embeddings, height, width)
        else:
            position_embedding = self.position_embedding
        embeddings = embeddings + position_embedding[:, : embeddings.size(1), :].to(target_dtype)
        return embeddings
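
To make the shape bookkeeping in `interpolate_pos_encoding` concrete, here is a standalone sketch of the same resizing on dummy tensors (toy BLIP-base-like numbers; the 768/16/384 values and the 512x512 target are illustrative assumptions, not taken from the diff):

```python
import math

import torch
import torch.nn as nn

# Toy BLIP-base-like numbers: image_size=384, patch_size=16 -> 24x24 = 576 patches (+1 class token).
dim, patch_size = 768, 16
position_embedding = torch.randn(1, 576 + 1, dim)

# A 512x512 input produces a 32x32 patch grid, so 1024 patch positions are needed.
height = width = 512
h0 = height // patch_size + 0.1  # small offset to avoid floating point rounding in interpolate()
w0 = width // patch_size + 0.1

num_positions = position_embedding.shape[1] - 1  # 576
class_pos = position_embedding[:, :1, :]
patch_pos = position_embedding[:, 1:, :].reshape(1, 24, 24, dim).permute(0, 3, 1, 2)
patch_pos = nn.functional.interpolate(
    patch_pos,
    scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
    mode="bicubic",
    align_corners=False,
)
patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, -1, dim)
resized = torch.cat((class_pos, patch_pos), dim=1)
print(resized.shape)  # torch.Size([1, 1025, 768]) -> 32*32 patch positions + 1 class position
```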


@@ -509,6 +546,8 @@ def _init_weights(self, module):
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
            Whether to interpolate the pre-trained position encodings.
"""

BLIP_INPUTS_DOCSTRING = r"""
@@ -545,6 +584,8 @@ def _init_weights(self, module):
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
            Whether to interpolate the pre-trained position encodings.
"""


@@ -657,6 +698,7 @@ def forward(
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:
@@ -671,7 +713,7 @@
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        hidden_states = self.embeddings(pixel_values)
        hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
@@ -779,6 +821,7 @@ def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.FloatTensor:
        r"""
        Returns:
@@ -804,7 +847,11 @@
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(pixel_values=pixel_values, return_dict=return_dict)
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        pooled_output = vision_outputs[1]  # pooled_output
        image_features = self.visual_projection(pooled_output)
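
`get_image_features` forwards the same flag to the vision tower; a minimal sketch of the standalone vision path (the checkpoint name and the random 512x512 input are assumptions):

```python
import torch
from transformers import BlipModel

# Assumed checkpoint; the projection head size depends on the loaded config.
model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base")

# Random pixel values stand in for a preprocessed 512x512 image,
# above the 384x384 resolution the position embeddings were trained for.
pixel_values = torch.rand(1, 3, 512, 512)

with torch.no_grad():
    image_features = model.get_image_features(pixel_values, interpolate_pos_encoding=True)
print(image_features.shape)  # torch.Size([1, config.projection_dim])
```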
@@ -818,6 +865,7 @@ def get_multimodal_features(
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.FloatTensor:
        r"""
        Returns:
@@ -846,6 +894,7 @@
            output_attentions=True,
            output_hidden_states=True,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        image_embeds = vision_outputs[0]
@@ -876,6 +925,7 @@ def forward(
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Union[Tuple, BlipOutput]:
        r"""
        Returns:
@@ -913,6 +963,7 @@ def forward(
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        text_outputs = self.text_model(
@@ -999,6 +1050,7 @@ def forward(
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Union[Tuple, BlipForConditionalGenerationModelOutput]:
        r"""
        Returns:
@@ -1033,6 +1085,7 @@
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        image_embeds = vision_outputs[0]
@@ -1065,6 +1118,7 @@ def generate(
        pixel_values: torch.FloatTensor,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        interpolate_pos_encoding: bool = False,
        **generate_kwargs,
    ) -> torch.LongTensor:
        r"""
@@ -1100,7 +1154,10 @@
        """

        batch_size = pixel_values.shape[0]
        vision_outputs = self.vision_model(pixel_values=pixel_values)
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        image_embeds = vision_outputs[0]

@@ -1174,6 +1231,7 @@ def forward(
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Union[Tuple, BlipTextVisionModelOutput]:
        r"""
        Returns:
@@ -1227,6 +1285,7 @@
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        image_embeds = vision_outputs[0]
@@ -1279,6 +1338,7 @@ def generate(
        input_ids: torch.LongTensor,
        pixel_values: torch.FloatTensor,
        attention_mask: Optional[torch.LongTensor] = None,
        interpolate_pos_encoding: bool = False,
        **generate_kwargs,
    ) -> torch.LongTensor:
        r"""
@@ -1316,7 +1376,10 @@
        2
        ```
        """
        vision_outputs = self.vision_model(pixel_values=pixel_values)
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        image_embeds = vision_outputs[0]
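
The same flag is threaded through `BlipForQuestionAnswering.generate`; a short VQA sketch (the checkpoint, question, and random 512x512 input are assumptions):

```python
import torch
from transformers import BlipForQuestionAnswering, BlipProcessor

# Assumed checkpoint for the VQA head; the flag is passed straight through generate() to the vision tower.
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")

question = processor.tokenizer("how many cats are there?", return_tensors="pt")
pixel_values = torch.rand(1, 3, 512, 512)  # stands in for a preprocessed 512x512 image

with torch.no_grad():
    answer_ids = model.generate(
        input_ids=question.input_ids,
        attention_mask=question.attention_mask,
        pixel_values=pixel_values,
        interpolate_pos_encoding=True,
    )
print(processor.decode(answer_ids[0], skip_special_tokens=True))
```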

@@ -1408,6 +1471,7 @@ def forward(
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Union[Tuple, BlipTextVisionModelOutput]:
        r"""
        Returns:
@@ -1441,6 +1505,7 @@
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        image_embeds = vision_outputs[0]
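
`BlipForImageTextRetrieval.forward` accepts the flag as well; a minimal matching sketch (the checkpoint, caption, and random 512x512 input are assumptions):

```python
import torch
from transformers import BlipForImageTextRetrieval, BlipProcessor

# Assumed checkpoint for the image-text matching head.
model = BlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco")
processor = BlipProcessor.from_pretrained("Salesforce/blip-itm-base-coco")

text = processor.tokenizer("a photo of two cats on a couch", return_tensors="pt")
pixel_values = torch.rand(1, 3, 512, 512)  # stands in for a preprocessed 512x512 image

with torch.no_grad():
    outputs = model(
        input_ids=text.input_ids,
        attention_mask=text.attention_mask,
        pixel_values=pixel_values,
        interpolate_pos_encoding=True,
    )
# With the default ITM head, the output carries match/no-match logits.
print(outputs.itm_score.shape)  # expected torch.Size([1, 2])
```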