diff --git a/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py b/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py
index c445f2ee21f7..26e80cad64cc 100644
--- a/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py
+++ b/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py
@@ -89,7 +89,11 @@ def __init__(self, config: ViTPoseBackboneConfig) -> None:
         self.patch_embeddings = ViTPoseBackbonePatchEmbeddings(config)
         num_patches = self.patch_embeddings.num_patches
-        self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
+        position_embeddings = torch.zeros(1, num_patches + 1, config.hidden_size)
+        # Pre-compute the modified position embeddings
+        self.position_embeddings = nn.Parameter(
+            position_embeddings[:, 1:] + position_embeddings[:, :1]
+        )
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.config = config

@@ -97,7 +101,7 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
         embeddings = self.patch_embeddings(pixel_values)

         # add positional encoding to each token
-        embeddings = embeddings + self.position_embeddings[:, 1:] + self.position_embeddings[:, :1]
+        embeddings = embeddings + self.position_embeddings

         embeddings = self.dropout(embeddings)