From 2c56a4806e30bc9b5753b142fa04b913306c54ff Mon Sep 17 00:00:00 2001
From: sangbumchoi
Date: Mon, 16 Dec 2024 08:56:13 +0000
Subject: [PATCH] fix pos_embed

---
 .../models/vitpose_backbone/modeling_vitpose_backbone.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py b/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py
index c445f2ee21f7..26e80cad64cc 100644
--- a/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py
+++ b/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py
@@ -89,7 +89,11 @@ def __init__(self, config: ViTPoseBackboneConfig) -> None:
 
         self.patch_embeddings = ViTPoseBackbonePatchEmbeddings(config)
         num_patches = self.patch_embeddings.num_patches
-        self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
+        position_embeddings = torch.zeros(1, num_patches + 1, config.hidden_size)
+        # Pre-compute the modified position embeddings
+        self.position_embeddings = nn.Parameter(
+            position_embeddings[:, 1:] + position_embeddings[:, :1]
+        )
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.config = config
 
@@ -97,7 +101,7 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
         embeddings = self.patch_embeddings(pixel_values)
 
         # add positional encoding to each token
-        embeddings = embeddings + self.position_embeddings[:, 1:] + self.position_embeddings[:, :1]
+        embeddings = embeddings + self.position_embeddings
 
         embeddings = self.dropout(embeddings)
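
For reference, a minimal standalone sketch (not part of the patch itself) of why the refactor is safe at the tensor level: adding position_embeddings[:, 1:] and position_embeddings[:, :1] to the patch embeddings on every forward() call is numerically equivalent to folding that sum into the parameter once in __init__(). The sizes below are illustrative assumptions (196 patches, hidden size 768, batch of 2), not values read from the real model config.

import torch

# Illustrative sizes (assumed, not taken from ViTPoseBackboneConfig).
num_patches, hidden_size, batch = 196, 768, 2
position_embeddings = torch.randn(1, num_patches + 1, hidden_size)
patch_embeddings = torch.randn(batch, num_patches, hidden_size)

# Old forward(): slice off the patch positions and broadcast the first
# (extra) position onto every patch, recomputed on every call.
old = patch_embeddings + position_embeddings[:, 1:] + position_embeddings[:, :1]

# Patched version: compute that sum once at init, then a single add per call.
folded = position_embeddings[:, 1:] + position_embeddings[:, :1]
new = patch_embeddings + folded

assert torch.allclose(old, new)

One consequence worth noting: the stored parameter shrinks from (1, num_patches + 1, hidden_size) to (1, num_patches, hidden_size), so checkpoints saved before this change carry a position_embeddings tensor of the old shape.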