diff --git a/xfuser/model_executor/layers/embeddings.py b/xfuser/model_executor/layers/embeddings.py index 843e667..ca4d58b 100644 --- a/xfuser/model_executor/layers/embeddings.py +++ b/xfuser/model_executor/layers/embeddings.py @@ -193,6 +193,6 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): ] pos_embedding = torch.cat(pos_embed_list, dim=1) - embeds[:, self.max_text_seq_length :] += pos_embedding + embeds[:, text_embeds.shape[1] :] += pos_embedding return embeds