Enable dynamic resolution input for Beit #31053

Merged
merged 5 commits on Jun 6, 2024
Changes from 4 commits
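
This PR threads an interpolate_pos_encoding flag through BEiT so the model accepts inputs at resolutions other than its pre-training size. A minimal usage sketch of the new flag (the checkpoint name and the 480x480 input size are illustrative assumptions, not taken from the PR):

import torch
from transformers import BeitForImageClassification

# Illustrative checkpoint choice; any BEiT checkpoint trained at 224x224 would do.
model = BeitForImageClassification.from_pretrained("microsoft/beit-base-patch16-224")
model.eval()

# A 480x480 input, larger than the 224x224 pre-training resolution.
pixel_values = torch.randn(1, 3, 480, 480)

# Without interpolate_pos_encoding=True the embeddings layer raises a ValueError
# for non-224x224 inputs; with it, position encodings are resized to the new grid.
with torch.no_grad():
    outputs = model(pixel_values, interpolate_pos_encoding=True)

print(outputs.logits.shape)  # (1, num_labels); 1000 for an ImageNet-1k head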
115 changes: 105 additions & 10 deletions src/transformers/models/beit/modeling_beit.py
@@ -137,14 +137,68 @@ def __init__(self, config: BeitConfig) -> None:
else:
self.mask_token = None
self.patch_embeddings = BeitPatchEmbeddings(config)
self.patch_size = config.patch_size
self.image_size = (
config.image_size
if isinstance(config.image_size, collections.abc.Iterable)
else (config.image_size, config.image_size)
)
num_patches = self.patch_embeddings.num_patches
if config.use_absolute_position_embeddings:
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
else:
self.position_embeddings = None
self.dropout = nn.Dropout(config.hidden_dropout_prob)

- def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None) -> torch.Tensor:
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows the model to interpolate the pre-trained position encodings so that it can be used on
higher-resolution images.

Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
num_patches = embeddings.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:
return self.position_embeddings

class_pos_embed = self.position_embeddings[:, 0]
patch_pos_embed = self.position_embeddings[:, 1:]
dim = embeddings.shape[-1]
h = height // self.patch_size
w = width // self.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h, w = h + 0.1, w + 0.1

patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h / math.sqrt(num_positions), w / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
if int(h) != patch_pos_embed.shape[-2] or int(w) != patch_pos_embed.shape[-1]:
raise ValueError("Width or height does not match with the interpolated position embeddings")

patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

def forward(
self,
pixel_values: torch.Tensor,
bool_masked_pos: Optional[torch.BoolTensor] = None,
interpolate_pos_encoding: bool = False,
) -> torch.Tensor:
_, _, height, width = pixel_values.shape
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model"
f" ({self.image_size[0]}*{self.image_size[1]})."
)

embeddings, (patch_height, patch_width) = self.patch_embeddings(
pixel_values, self.position_embeddings[:, 1:, :] if self.position_embeddings is not None else None
)
@@ -158,7 +212,10 @@ def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None) -> torch.Tensor:

cls_tokens = self.cls_token.expand(batch_size, -1, -1)
if self.position_embeddings is not None:
- cls_tokens = cls_tokens + self.position_embeddings[:, :1, :]
if interpolate_pos_encoding:
cls_tokens = cls_tokens + self.interpolate_pos_encoding(embeddings, height, width)
else:
cls_tokens = cls_tokens + self.position_embeddings[:, :1, :]

embeddings = torch.cat((cls_tokens, embeddings), dim=1)

@@ -191,7 +248,11 @@ def __init__(self, config):

self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

- def forward(self, pixel_values: torch.Tensor, position_embedding: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(
self,
pixel_values: torch.Tensor,
position_embedding: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
@@ -251,6 +312,7 @@ def forward(
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
relative_position_bias: Optional["BeitRelativePositionBias"] = None,
interpolate_pos_encoding: bool = False,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)

@@ -265,7 +327,9 @@

# Add relative position bias if present.
if self.relative_position_bias is not None:
- attention_scores = attention_scores + self.relative_position_bias().unsqueeze(0)
attention_scores = attention_scores + self.relative_position_bias(
interpolate_pos_encoding, attention_scores.shape[2]
).unsqueeze(0)

# Add shared relative position bias if provided.
if relative_position_bias is not None:
@@ -342,8 +406,11 @@ def forward(
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
relative_position_bias: Optional["BeitRelativePositionBias"] = None,
interpolate_pos_encoding: bool = False,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
- self_outputs = self.attention(hidden_states, head_mask, output_attentions, relative_position_bias)
self_outputs = self.attention(
hidden_states, head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding
)

attention_output = self.output(self_outputs[0], hidden_states)

@@ -407,12 +474,14 @@ def forward(
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
relative_position_bias: Optional["BeitRelativePositionBias"] = None,
interpolate_pos_encoding: bool = False,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
self_attention_outputs = self.attention(
self.layernorm_before(hidden_states), # in BEiT, layernorm is applied before self-attention
head_mask,
output_attentions=output_attentions,
relative_position_bias=relative_position_bias,
interpolate_pos_encoding=interpolate_pos_encoding,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
@@ -471,12 +540,21 @@ def __init__(self, config: BeitConfig, window_size: tuple) -> None:

self.register_buffer("relative_position_index", relative_position_index, persistent=False)

- def forward(self) -> torch.Tensor:
def forward(self, interpolate_pos_encoding: bool = False, dim_size: Optional[int] = None) -> torch.Tensor:
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1
) # Wh*Ww,Wh*Ww,nH

- return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
if interpolate_pos_encoding:
relative_position_bias = nn.functional.interpolate(
relative_position_bias.unsqueeze(1),
size=(dim_size, dim_size),
mode="bilinear",
align_corners=False,
).squeeze(1)

return relative_position_bias


class BeitEncoder(nn.Module):
@@ -508,6 +586,7 @@ def forward(
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
interpolate_pos_encoding: bool = False,
return_dict: bool = True,
) -> Union[tuple, BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
@@ -528,9 +607,13 @@
)
else:
relative_position_bias = (
- self.relative_position_bias() if self.relative_position_bias is not None else None
self.relative_position_bias(interpolate_pos_encoding, hidden_states.shape[1])
if self.relative_position_bias is not None
else None
)
layer_outputs = layer_module(
hidden_states, layer_head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding
)
- layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions, relative_position_bias)

hidden_states = layer_outputs[0]

@@ -607,6 +690,8 @@ def _init_weights(self, module):
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@@ -658,6 +743,7 @@ def forward(
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[tuple, BeitModelOutputWithPooling]:
r"""
@@ -680,14 +766,17 @@
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

- embedding_output, (patch_height, patch_width) = self.embeddings(pixel_values, bool_masked_pos)
embedding_output, (patch_height, patch_width) = self.embeddings(
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
)

encoder_outputs = self.encoder(
embedding_output,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)
sequence_output = encoder_outputs[0]
sequence_output = self.layernorm(sequence_output)
@@ -755,6 +844,7 @@ def forward(
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[tuple, MaskedLMOutput]:
r"""
@@ -800,6 +890,7 @@ def forward(
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)

@@ -858,6 +949,7 @@ def forward(
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[tuple, ImageClassifierOutput]:
r"""
@@ -872,6 +964,7 @@
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)

@@ -1215,6 +1308,7 @@ def forward(
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[tuple, SemanticSegmenterOutput]:
r"""
@@ -1252,6 +1346,7 @@ def forward(
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=True, # we need the intermediate hidden states
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)

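For reference, the core of the new interpolate_pos_encoding method is the standard ViT/DINO trick: treat the learned patch position embeddings as a 2-D grid and resample it bicubically to the new patch grid (the relative position bias above gets an analogous bilinear resize). Below is a self-contained sketch of that idea with illustrative shapes; it is not the exact PR code, which uses a scale factor with the +0.1 rounding workaround instead of a fixed target size.

import math
import torch
import torch.nn as nn

def resize_pos_embed(pos_embed: torch.Tensor, new_h: int, new_w: int) -> torch.Tensor:
    # pos_embed: (1, 1 + N, dim) with a leading [CLS] position; assumes N is a square grid.
    cls_pos, patch_pos = pos_embed[:, :1], pos_embed[:, 1:]
    dim = pos_embed.shape[-1]
    grid = int(math.sqrt(patch_pos.shape[1]))
    patch_pos = patch_pos.reshape(1, grid, grid, dim).permute(0, 3, 1, 2)  # (1, dim, grid, grid)
    patch_pos = nn.functional.interpolate(
        patch_pos, size=(new_h, new_w), mode="bicubic", align_corners=False
    )
    patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, new_h * new_w, dim)  # (1, new_h*new_w, dim)
    return torch.cat((cls_pos, patch_pos), dim=1)

# 14x14 grid from 224/16 pre-training, resampled for a 480x480 input -> 30x30 patches.
pos = torch.randn(1, 1 + 14 * 14, 768)
print(resize_pos_embed(pos, 30, 30).shape)  # torch.Size([1, 901, 768])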