Perceiver interpolate position embedding (#30979)
* add test that currently fails

* test passed

* all perceiver passed

* fixup, style, quality, repo-consistency, all passed

* Apply suggestions from code review: default to False + compute sqrt once only

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* fix a minor bracket

* replace dim with self._num_channels

* add arguments to the rest of the preprocessors

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
g1y5x3 and amyeroberts authored May 24, 2024
1 parent 5855afd commit 42d8dd8
Showing 2 changed files with 93 additions and 10 deletions.
83 changes: 73 additions & 10 deletions src/transformers/models/perceiver/modeling_perceiver.py
@@ -699,13 +699,24 @@ def _init_weights(self, module):
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
"""The Perceiver: a scalable, fully attentional architecture.""",
"""The Perceiver: a scalable, fully attentional architecture.
<Tip>
Note that it's possible to fine-tune Perceiver on higher resolution images than the ones it has been trained on, by
setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
position embeddings to the higher resolution.
</Tip>
""",
PERCEIVER_MODEL_START_DOCSTRING,
)
class PerceiverModel(PerceiverPreTrainedModel):
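
For context, the new flag is used the same way as in the test added at the bottom of this diff. A minimal usage sketch (the checkpoint name and 384×384 size are taken from that test; the image URL is just an illustrative sample, not part of this commit):

import requests
import torch
from PIL import Image
from transformers import PerceiverForImageClassificationLearned, PerceiverImageProcessor

# Process an image at a higher resolution than the checkpoint was pre-trained on.
image_processor = PerceiverImageProcessor(size={"height": 384, "width": 384})
model = PerceiverForImageClassificationLearned.from_pretrained("deepmind/vision-perceiver-learned")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # any RGB image works
image = Image.open(requests.get(url, stream=True).raw)
inputs = image_processor(image, return_tensors="pt").pixel_values

with torch.no_grad():
    # interpolate_pos_encoding=True rescales the learned position grid to the new resolution
    logits = model(inputs=inputs, interpolate_pos_encoding=True).logits

print(logits.shape)  # torch.Size([1, model.config.num_labels])
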
@@ -754,6 +765,7 @@ def forward(
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[Tuple, PerceiverModelOutput]:
r"""
@@ -857,7 +869,9 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if self.input_preprocessor is not None:
inputs, modality_sizes, inputs_without_pos = self.input_preprocessor(inputs)
inputs, modality_sizes, inputs_without_pos = self.input_preprocessor(
inputs, interpolate_pos_encoding=interpolate_pos_encoding
)
else:
modality_sizes = None
inputs_without_pos = None
@@ -1247,6 +1261,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
labels: Optional[torch.Tensor] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
pixel_values: Optional[torch.Tensor] = None,
) -> Union[Tuple, PerceiverClassifierOutput]:
@@ -1295,6 +1310,7 @@ def forward(
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)
logits = outputs.logits if return_dict else outputs[0]
@@ -2749,9 +2765,31 @@ def num_dimensions(self) -> int:
def output_size(self, *args, **kwargs) -> int:
return self._num_channels

def forward(self, batch_size: int) -> torch.Tensor:
def interpolate_pos_encoding(self, position_embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
num_positions = position_embeddings.shape[0]
new_height = new_width = math.sqrt(num_positions)
position_embeddings = position_embeddings.reshape(
1, int(new_height), int(new_width), self._num_channels
).permute(0, 3, 1, 2)
position_embeddings = nn.functional.interpolate(
position_embeddings,
scale_factor=(height / new_height, width / new_width),
mode="bicubic",
align_corners=False,
)
position_embeddings = position_embeddings.reshape(1, self._num_channels, -1).permute(0, 2, 1).squeeze(0)
return position_embeddings

def forward(
self, batch_size: int, interpolate_pos_encoding: bool = False, input_size: torch.Size = None
) -> torch.Tensor:
position_embeddings = self.position_embeddings

if interpolate_pos_encoding:
height, width = input_size
height, width = height + 0.1, width + 0.1
position_embeddings = self.interpolate_pos_encoding(position_embeddings, height, width)

if batch_size is not None:
position_embeddings = position_embeddings.expand(batch_size, -1, -1)
return position_embeddings
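
Stepping back, the helper above boils down to a grid reshape plus a bicubic resize. A self-contained sketch of the same idea, assuming the learned embeddings form a square grid (the 224→384 sizes and 256 channels below are illustrative only, not taken from any specific config):

import math
import torch
import torch.nn as nn

def interpolate_grid_embeddings(pos_emb: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """Resize a flat (num_positions, channels) grid of learned position embeddings."""
    num_positions, channels = pos_emb.shape
    old_size = int(math.sqrt(num_positions))  # embeddings are assumed to form a square grid
    grid = pos_emb.reshape(1, old_size, old_size, channels).permute(0, 3, 1, 2)
    # The small 0.1 offset mirrors the trick in the diff above: it keeps floor()
    # from rounding the interpolated size down due to floating-point error.
    grid = nn.functional.interpolate(
        grid,
        scale_factor=((height + 0.1) / old_size, (width + 0.1) / old_size),
        mode="bicubic",
        align_corners=False,
    )
    return grid.reshape(1, channels, -1).permute(0, 2, 1).squeeze(0)

pos_emb = torch.randn(224 * 224, 256)                # e.g. a 224x224 grid with 256 channels
resized = interpolate_grid_embeddings(pos_emb, 384, 384)
print(resized.shape)                                 # torch.Size([147456, 256]) == 384*384 positions
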
@@ -2859,7 +2897,13 @@ def __init__(self, config: PerceiverConfig) -> None:
def num_channels(self) -> int:
return self.config.d_model

def forward(self, inputs: torch.LongTensor, pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True):
def forward(
self,
inputs: torch.LongTensor,
pos: Optional[torch.Tensor] = None,
network_input_is_1d: bool = True,
interpolate_pos_encoding: bool = False,
):
embeddings_without_pos = self.embeddings(inputs)

seq_length = inputs.shape[1]
@@ -3139,14 +3183,17 @@ def num_channels(self) -> int:

return inp_dim + pos_dim

def _build_network_inputs(self, inputs: torch.Tensor, network_input_is_1d: bool = True):
def _build_network_inputs(
self, inputs: torch.Tensor, network_input_is_1d: bool = True, interpolate_pos_encoding: bool = False
):
"""
Construct the final input, including position encoding.
This method expects the inputs to always have channels as last dimension.
"""
batch_size = inputs.shape[0]
input_size = inputs.shape[1:3]
index_dims = inputs.shape[1:-1]
indices = np.prod(index_dims)

@@ -3156,7 +3203,7 @@ def _build_network_inputs(self, inputs: torch.Tensor, network_input_is_1d: bool

# Construct the position encoding.
if self.position_encoding_type == "trainable":
pos_enc = self.position_embeddings(batch_size)
pos_enc = self.position_embeddings(batch_size, interpolate_pos_encoding, input_size)
elif self.position_encoding_type == "fourier":
pos_enc = self.position_embeddings(index_dims, batch_size, device=inputs.device, dtype=inputs.dtype)

@@ -3174,7 +3221,13 @@ def _build_network_inputs(self, inputs: torch.Tensor, network_input_is_1d: bool
inputs_with_pos = inputs + pos_enc
return inputs_with_pos, inputs

def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True):
def forward(
self,
inputs: torch.Tensor,
pos: Optional[torch.Tensor] = None,
network_input_is_1d: bool = True,
interpolate_pos_encoding: bool = False,
):
if self.prep_type == "conv":
# Convnet image featurization.
# Downsamples spatially by a factor of 4
@@ -3218,7 +3271,7 @@ def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, netw
else:
raise ValueError("Unsupported data format for conv1x1.")

inputs, inputs_without_pos = self._build_network_inputs(inputs, network_input_is_1d)
inputs, inputs_without_pos = self._build_network_inputs(inputs, network_input_is_1d, interpolate_pos_encoding)
modality_sizes = None # Size for each modality, only needed for multimodal

return inputs, modality_sizes, inputs_without_pos
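
A rough shape trace for this image-preprocessor path, assuming a 384×384 input, conv1x1 prep with no spatial downsampling, 256 output channels, and a checkpoint whose position grid was trained at 224×224 (all numbers illustrative):

# inputs after conv1x1 preprocessing (channels last):      (batch, 384, 384, 256)
# input_size = inputs.shape[1:3]                        -> (384, 384)
# index_dims = inputs.shape[1:-1] -> (384, 384); indices = 384 * 384 = 147456
# pre-trained trainable position grid:                     (224 * 224, 256) = (50176, 256)
# after interpolate_pos_encoding(pos, 384, 384):           (147456, 256)
# expanded over the batch, then added to or concatenated with the flattened
# inputs depending on concat_or_add_pos:                   (batch, 147456, ...)
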
@@ -3338,7 +3391,13 @@ def _build_network_inputs(self, inputs):

return inputs_with_pos, inputs

def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True):
def forward(
self,
inputs: torch.Tensor,
pos: Optional[torch.Tensor] = None,
network_input_is_1d: bool = True,
interpolate_pos_encoding: bool = False,
):
inputs = torch.reshape(inputs, [inputs.shape[0], -1, self.samples_per_patch])

inputs, inputs_without_pos = self._build_network_inputs(inputs)
@@ -3391,7 +3450,11 @@ def num_channels(self) -> int:
return common_channel_size

def forward(
self, inputs: Mapping[str, torch.Tensor], pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True
self,
inputs: Mapping[str, torch.Tensor],
pos: Optional[torch.Tensor] = None,
network_input_is_1d: bool = True,
interpolate_pos_encoding: bool = False,
) -> PreprocessorOutputType:
padded = {}
modality_sizes = {}
20 changes: 20 additions & 0 deletions tests/models/perceiver/test_modeling_perceiver.py
@@ -1031,3 +1031,23 @@ def test_inference_optical_flow(self):
)

self.assertTrue(torch.allclose(logits[0, :3, :3, :3], expected_slice, atol=1e-4))

@slow
def test_inference_interpolate_pos_encoding(self):
image_processor = PerceiverImageProcessor(size={"height": 384, "width": 384})
model = PerceiverForImageClassificationLearned.from_pretrained("deepmind/vision-perceiver-learned")
model.to(torch_device)

# prepare inputs
image = prepare_img()
inputs = image_processor(image, return_tensors="pt").pixel_values.to(torch_device)
input_mask = None

# forward pass
with torch.no_grad():
outputs = model(inputs=inputs, attention_mask=input_mask, interpolate_pos_encoding=True)
logits = outputs.logits

# verify logits
expected_shape = torch.Size((1, model.config.num_labels))
self.assertEqual(logits.shape, expected_shape)
