Blip dynamic input resolution #30722

Merged
81 changes: 73 additions & 8 deletions src/transformers/models/blip/modeling_blip.py
@@ -14,6 +14,7 @@
# limitations under the License.
""" PyTorch BLIP model."""

import math
import warnings
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union
@@ -234,15 +235,51 @@ def __init__(self, config: BlipVisionConfig):

self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))

def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method interpolates the pre-trained position encodings so the model can be used on higher-resolution
images.

Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
num_patches = embeddings.shape[1] - 1
num_positions = self.position_embedding.shape[1] - 1

if num_patches == num_positions and height == width:
return self.position_embedding

class_pos_embed = self.position_embedding[:, 0, :]
patch_pos_embed = self.position_embedding[:, 1:, :]
dim = embeddings.shape[-1]
h0 = height // self.config.patch_size
w0 = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
batch_size, _, height, width = pixel_values.shape
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.position_embedding[:, : embeddings.size(1), :].to(target_dtype)
if interpolate_pos_encoding:
position_embedding = self.interpolate_pos_encoding(embeddings, height, width)
else:
position_embedding = self.position_embedding
embeddings = embeddings + position_embedding[:, : embeddings.size(1), :].to(target_dtype)
return embeddings
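
For readers skimming the diff, the following standalone sketch (not part of the change) walks through the shape arithmetic of the new interpolation path, assuming the 384×384, patch-16, hidden-size-768 vision configuration of the BLIP-base checkpoints: the learned 24×24 grid of patch position embeddings is bicubically resized to the 38×38 grid implied by a hypothetical 608×608 input.

```python
import math

import torch
from torch import nn

# Assumed sizes: 384x384 pre-training resolution, patch size 16, hidden size 768.
dim, patch_size = 768, 16
num_positions = (384 // patch_size) ** 2            # 576 patch positions (CLS handled separately)
pos_embed = torch.randn(1, num_positions + 1, dim)  # stand-in for self.position_embedding

height, width = 608, 608                            # hypothetical higher-resolution input
# Small offset to avoid floating point issues in the scale factor (see the comment above).
h0, w0 = height // patch_size + 0.1, width // patch_size + 0.1

patch_pos = pos_embed[:, 1:, :]                     # drop the CLS position
grid = int(math.sqrt(num_positions))                # 24
patch_pos = patch_pos.reshape(1, grid, grid, dim).permute(0, 3, 1, 2)
patch_pos = nn.functional.interpolate(
    patch_pos, scale_factor=(h0 / grid, w0 / grid), mode="bicubic", align_corners=False
)
print(patch_pos.shape)  # torch.Size([1, 768, 38, 38]) -> 1444 patch positions after flattening
```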


@@ -512,6 +549,8 @@ def _init_weights(self, module):
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
"""

BLIP_INPUTS_DOCSTRING = r"""
@@ -548,6 +587,8 @@ def _init_weights(self, module):
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
"""


@@ -660,6 +701,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
@@ -674,7 +716,7 @@
if pixel_values is None:
raise ValueError("You have to specify pixel_values")

hidden_states = self.embeddings(pixel_values)
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
@@ -782,6 +824,7 @@ def get_image_features(
self,
pixel_values: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
) -> torch.FloatTensor:
r"""
Returns:
@@ -807,7 +850,11 @@
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

vision_outputs = self.vision_model(pixel_values=pixel_values, return_dict=return_dict)
vision_outputs = self.vision_model(
pixel_values=pixel_values,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)

pooled_output = vision_outputs[1] # pooled_output
image_features = self.visual_projection(pooled_output)
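
A hedged usage sketch of the new flag through `get_image_features`; the checkpoint name, image URL, and the 480×480 resolution below are illustrative choices, not something prescribed by this PR.

```python
import requests
import torch
from PIL import Image
from transformers import BlipModel, BlipProcessor

processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Preprocess at a resolution the checkpoint was not trained at (480 instead of 384).
pixel_values = processor.image_processor(
    image, size={"height": 480, "width": 480}, return_tensors="pt"
).pixel_values

with torch.no_grad():
    image_features = model.get_image_features(
        pixel_values=pixel_values, interpolate_pos_encoding=True
    )
print(image_features.shape)  # (1, projection_dim), e.g. (1, 512) for the base config
```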
@@ -821,6 +868,7 @@ def get_multimodal_features(
pixel_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
) -> torch.FloatTensor:
r"""
Returns:
@@ -849,6 +897,7 @@
output_attentions=True,
output_hidden_states=True,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)

image_embeds = vision_outputs[0]
@@ -879,6 +928,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
) -> Union[Tuple, BlipOutput]:
r"""
Returns:
@@ -916,6 +966,7 @@
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)

text_outputs = self.text_model(
@@ -1002,6 +1053,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
) -> Union[Tuple, BlipForConditionalGenerationModelOutput]:
r"""
Returns:
@@ -1036,6 +1088,7 @@
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)

image_embeds = vision_outputs[0]
@@ -1068,6 +1121,7 @@ def generate(
pixel_values: torch.FloatTensor,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
interpolate_pos_encoding: bool = False,
**generate_kwargs,
) -> torch.LongTensor:
r"""
@@ -1103,7 +1157,10 @@
"""

batch_size = pixel_values.shape[0]
vision_outputs = self.vision_model(pixel_values=pixel_values)
vision_outputs = self.vision_model(
pixel_values=pixel_values,
interpolate_pos_encoding=interpolate_pos_encoding,
)

image_embeds = vision_outputs[0]

@@ -1177,6 +1234,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
) -> Union[Tuple, BlipTextVisionModelOutput]:
r"""
Returns:
@@ -1230,6 +1288,7 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)

image_embeds = vision_outputs[0]
@@ -1282,6 +1341,7 @@ def generate(
input_ids: torch.LongTensor,
pixel_values: torch.FloatTensor,
attention_mask: Optional[torch.LongTensor] = None,
interpolate_pos_encoding: bool = False,
**generate_kwargs,
) -> torch.LongTensor:
r"""
@@ -1319,7 +1379,10 @@
2
```
"""
vision_outputs = self.vision_model(pixel_values=pixel_values)
vision_outputs = self.vision_model(
pixel_values=pixel_values,
interpolate_pos_encoding=interpolate_pos_encoding,
)

image_embeds = vision_outputs[0]
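
The VQA path accepts the flag the same way; a brief hedged sketch (the checkpoint and the 448×448 size are again illustrative assumptions):

```python
import requests
from PIL import Image
from transformers import BlipForQuestionAnswering, BlipProcessor

processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

input_ids = processor.tokenizer("How many cats are there?", return_tensors="pt").input_ids
pixel_values = processor.image_processor(
    image, size={"height": 448, "width": 448}, return_tensors="pt"
).pixel_values

out = model.generate(input_ids=input_ids, pixel_values=pixel_values, interpolate_pos_encoding=True)
print(processor.decode(out[0], skip_special_tokens=True))
```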

@@ -1411,6 +1474,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
) -> Union[Tuple, BlipTextVisionModelOutput]:
r"""
Returns:
@@ -1444,6 +1508,7 @@
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)

image_embeds = vision_outputs[0]