diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
index bc9c3323426a..b05aab08d817 100644
--- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -308,7 +308,7 @@ def forward(
         hidden_states: torch.Tensor,
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: Optional[torch.Tensor] = None,
-        position_embeddings: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
@@ -357,7 +357,7 @@ def forward(
         hidden_states: torch.Tensor,
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: Optional[torch.Tensor] = None,
-        position_embeddings: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
@@ -395,7 +395,7 @@ def forward(
         hidden_states: torch.Tensor,
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: Optional[torch.Tensor] = None,
-        position_embeddings: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
@@ -446,7 +446,11 @@ def __init__(self, config, attn_implementation: str = "sdpa") -> None:
         self.mlp = VisionMlp(dim=config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act)

     def forward(
-        self, hidden_states, cu_seqlens, rotary_pos_emb, position_embeddings: Optional[torch.Tensor] = None
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         hidden_states = hidden_states + self.attn(
             self.norm1(hidden_states),
@@ -1035,9 +1039,7 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.
                     blk.__call__, hidden_states, cu_seqlens, None, position_embeddings
                 )
             else:
-                hidden_states = blk(
-                    hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=None, position_embeddings=position_embeddings
-                )
+                hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)

         return self.merger(hidden_states)
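
Note (illustration, not part of the diff): the tightened annotation documents that position_embeddings is a (cos, sin) pair of tensors rather than a single tensor. The sketch below shows how such a pair is typically consumed when applying rotary position embeddings to the q/k projections of a vision attention block; the helper name apply_rotary_cos_sin, the shapes, and the broadcasting are illustrative assumptions, not the functions actually defined in modeling_qwen2_vl.py.

import torch
from typing import Tuple


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half and swap the halves with a sign flip,
    # the standard building block of rotary position embeddings.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_cos_sin(
    q: torch.Tensor,
    k: torch.Tensor,
    position_embeddings: Tuple[torch.Tensor, torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Hypothetical helper: consumes the (cos, sin) pair that the annotated
    # position_embeddings argument now documents.
    cos, sin = position_embeddings
    # Assumed shapes: q/k are (seq_len, num_heads, head_dim),
    # cos/sin are (seq_len, head_dim); broadcast over the head dimension.
    cos = cos.unsqueeze(1)
    sin = sin.unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


# Toy usage with shapes matching the qkv projection in the vision attention.
seq_len, num_heads, head_dim = 16, 4, 8
q = torch.randn(seq_len, num_heads, head_dim)
k = torch.randn(seq_len, num_heads, head_dim)
cos = torch.randn(seq_len, head_dim)
sin = torch.randn(seq_len, head_dim)
q_rot, k_rot = apply_rotary_cos_sin(q, k, (cos, sin))
print(q_rot.shape, k_rot.shape)  # torch.Size([16, 4, 8]) torch.Size([16, 4, 8])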