Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Perceiver interpolate position embedding #30979

Merged
merged 8 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 73 additions & 10 deletions src/transformers/models/perceiver/modeling_perceiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,13 +699,24 @@ def _init_weights(self, module):
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
"""The Perceiver: a scalable, fully attentional architecture.""",
"""The Perceiver: a scalable, fully attentional architecture.

<Tip>

Note that it's possible to fine-tune Perceiver on higher resolution images than the ones it has been trained on, by
setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
position embeddings to the higher resolution.

</Tip>
""",
PERCEIVER_MODEL_START_DOCSTRING,
)
class PerceiverModel(PerceiverPreTrainedModel):
Expand Down Expand Up @@ -754,6 +765,7 @@ def forward(
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[Tuple, PerceiverModelOutput]:
r"""
Expand Down Expand Up @@ -857,7 +869,9 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if self.input_preprocessor is not None:
inputs, modality_sizes, inputs_without_pos = self.input_preprocessor(inputs)
inputs, modality_sizes, inputs_without_pos = self.input_preprocessor(
inputs, interpolate_pos_encoding=interpolate_pos_encoding
)
Comment on lines 871 to +874
Copy link
Contributor Author

@g1y5x3 g1y5x3 May 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@amyeroberts since self.input_processor here could also be either TextPreprocessor or MultimodalPreprosessor, it would complain got unexpected keyword argument if it was not added to the forward() of those two corresponding methods. Alternatively, maybe we could add another condition here to call self.input_processor with just two arguments unless it is ImagePreprocessor?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, that's OK, I just wanted to make sure that this was deliberate

else:
modality_sizes = None
inputs_without_pos = None
Expand Down Expand Up @@ -1247,6 +1261,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
labels: Optional[torch.Tensor] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
pixel_values: Optional[torch.Tensor] = None,
) -> Union[Tuple, PerceiverClassifierOutput]:
Expand Down Expand Up @@ -1295,6 +1310,7 @@ def forward(
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)
logits = outputs.logits if return_dict else outputs[0]
Expand Down Expand Up @@ -2749,9 +2765,31 @@ def num_dimensions(self) -> int:
def output_size(self, *args, **kwargs) -> int:
return self._num_channels

def forward(self, batch_size: int) -> torch.Tensor:
def interpolate_pos_encoding(self, position_embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
num_positions = position_embeddings.shape[0]
new_height = new_width = math.sqrt(num_positions)
position_embeddings = position_embeddings.reshape(
1, int(new_height), int(new_width), self._num_channels
).permute(0, 3, 1, 2)
position_embeddings = nn.functional.interpolate(
position_embeddings,
scale_factor=(height / new_height, width / new_width),
mode="bicubic",
align_corners=False,
)
position_embeddings = position_embeddings.reshape(1, self._num_channels, -1).permute(0, 2, 1).squeeze(0)
return position_embeddings

def forward(
self, batch_size: int, interpolate_pos_encoding: bool = False, input_size: torch.Size = None
) -> torch.Tensor:
position_embeddings = self.position_embeddings

if interpolate_pos_encoding:
height, width = input_size
height, width = height + 0.1, width + 0.1
position_embeddings = self.interpolate_pos_encoding(position_embeddings, height, width)

if batch_size is not None:
position_embeddings = position_embeddings.expand(batch_size, -1, -1)
return position_embeddings
Expand Down Expand Up @@ -2859,7 +2897,13 @@ def __init__(self, config: PerceiverConfig) -> None:
def num_channels(self) -> int:
return self.config.d_model

def forward(self, inputs: torch.LongTensor, pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True):
def forward(
self,
inputs: torch.LongTensor,
pos: Optional[torch.Tensor] = None,
network_input_is_1d: bool = True,
interpolate_pos_encoding: bool = False,
):
embeddings_without_pos = self.embeddings(inputs)

seq_length = inputs.shape[1]
Expand Down Expand Up @@ -3139,14 +3183,17 @@ def num_channels(self) -> int:

return inp_dim + pos_dim

def _build_network_inputs(self, inputs: torch.Tensor, network_input_is_1d: bool = True):
def _build_network_inputs(
self, inputs: torch.Tensor, network_input_is_1d: bool = True, interpolate_pos_encoding: bool = False
):
"""
Construct the final input, including position encoding.

This method expects the inputs to always have channels as last dimension.

"""
batch_size = inputs.shape[0]
input_size = inputs.shape[1:3]
index_dims = inputs.shape[1:-1]
indices = np.prod(index_dims)

Expand All @@ -3156,7 +3203,7 @@ def _build_network_inputs(self, inputs: torch.Tensor, network_input_is_1d: bool

# Construct the position encoding.
if self.position_encoding_type == "trainable":
pos_enc = self.position_embeddings(batch_size)
pos_enc = self.position_embeddings(batch_size, interpolate_pos_encoding, input_size)
elif self.position_encoding_type == "fourier":
pos_enc = self.position_embeddings(index_dims, batch_size, device=inputs.device, dtype=inputs.dtype)

Expand All @@ -3174,7 +3221,13 @@ def _build_network_inputs(self, inputs: torch.Tensor, network_input_is_1d: bool
inputs_with_pos = inputs + pos_enc
return inputs_with_pos, inputs

def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True):
def forward(
self,
inputs: torch.Tensor,
pos: Optional[torch.Tensor] = None,
network_input_is_1d: bool = True,
interpolate_pos_encoding: bool = False,
):
if self.prep_type == "conv":
# Convnet image featurization.
# Downsamples spatially by a factor of 4
Expand Down Expand Up @@ -3218,7 +3271,7 @@ def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, netw
else:
raise ValueError("Unsupported data format for conv1x1.")

inputs, inputs_without_pos = self._build_network_inputs(inputs, network_input_is_1d)
inputs, inputs_without_pos = self._build_network_inputs(inputs, network_input_is_1d, interpolate_pos_encoding)
modality_sizes = None # Size for each modality, only needed for multimodal

return inputs, modality_sizes, inputs_without_pos
Expand Down Expand Up @@ -3338,7 +3391,13 @@ def _build_network_inputs(self, inputs):

return inputs_with_pos, inputs

def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True):
def forward(
self,
inputs: torch.Tensor,
pos: Optional[torch.Tensor] = None,
network_input_is_1d: bool = True,
interpolate_pos_encoding: bool = False,
):
inputs = torch.reshape(inputs, [inputs.shape[0], -1, self.samples_per_patch])

inputs, inputs_without_pos = self._build_network_inputs(inputs)
Expand Down Expand Up @@ -3391,7 +3450,11 @@ def num_channels(self) -> int:
return common_channel_size

def forward(
self, inputs: Mapping[str, torch.Tensor], pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True
self,
inputs: Mapping[str, torch.Tensor],
pos: Optional[torch.Tensor] = None,
network_input_is_1d: bool = True,
interpolate_pos_encoding: bool = False,
) -> PreprocessorOutputType:
padded = {}
modality_sizes = {}
Expand Down
20 changes: 20 additions & 0 deletions tests/models/perceiver/test_modeling_perceiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -1031,3 +1031,23 @@ def test_inference_optical_flow(self):
)

self.assertTrue(torch.allclose(logits[0, :3, :3, :3], expected_slice, atol=1e-4))

@slow
def test_inference_interpolate_pos_encoding(self):
image_processor = PerceiverImageProcessor(size={"height": 384, "width": 384})
model = PerceiverForImageClassificationLearned.from_pretrained("deepmind/vision-perceiver-learned")
model.to(torch_device)

# prepare inputs
image = prepare_img()
inputs = image_processor(image, return_tensors="pt").pixel_values.to(torch_device)
input_mask = None

# forward pass
with torch.no_grad():
outputs = model(inputs=inputs, attention_mask=input_mask, interpolate_pos_encoding=True)
logits = outputs.logits

# verify logits
expected_shape = torch.Size((1, model.config.num_labels))
self.assertEqual(logits.shape, expected_shape)