Commit

fix pos_embed
SangbumChoi committed Dec 16, 2024
1 parent f064009 · commit 2c56a48
Showing 1 changed file with 6 additions and 2 deletions.
```diff
@@ -89,15 +89,19 @@ def __init__(self, config: ViTPoseBackboneConfig) -> None:
 
         self.patch_embeddings = ViTPoseBackbonePatchEmbeddings(config)
         num_patches = self.patch_embeddings.num_patches
-        self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
+        position_embeddings = torch.zeros(1, num_patches + 1, config.hidden_size)
+        # Pre-compute the modified position embeddings
+        self.position_embeddings = nn.Parameter(
+            position_embeddings[:, 1:] + position_embeddings[:, :1]
+        )
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.config = config
 
     def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
         embeddings = self.patch_embeddings(pixel_values)
 
         # add positional encoding to each token
-        embeddings = embeddings + self.position_embeddings[:, 1:] + self.position_embeddings[:, :1]
+        embeddings = embeddings + self.position_embeddings
 
         embeddings = self.dropout(embeddings)
```
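What the change does: the position table has shape `(1, num_patches + 1, hidden_size)`, where the extra slot at index 0 is conventionally the CLS token's position. Before this commit, `forward` folded that slot into every patch position (`[:, 1:] + [:, :1]`) on each call; after it, the fold is precomputed once in `__init__`, so `self.position_embeddings` is stored with shape `(1, num_patches, hidden_size)` and `forward` reduces to a single add. A minimal sketch of the equivalence (the toy sizes and random `table` below are illustrative, not from the commit):

```python
import torch
from torch import nn

# Toy sizes, purely illustrative.
num_patches, hidden_size = 4, 8

# A position table with the extra (CLS) slot at index 0.
table = torch.randn(1, num_patches + 1, hidden_size)
patch_embeddings = torch.randn(1, num_patches, hidden_size)

# Before the fix: fold the CLS-slot embedding into every patch position
# on each forward pass.
old = patch_embeddings + table[:, 1:] + table[:, :1]

# After the fix: fold once up front; the stored parameter already has
# shape (1, num_patches, hidden_size), so forward is a plain add.
folded = nn.Parameter(table[:, 1:] + table[:, :1])
new = patch_embeddings + folded

assert torch.allclose(old, new)
```

The assert passes because tensor addition is associative, so folding the two slices into the parameter at init changes nothing numerically while removing two slice-and-add operations from every forward call.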
