Make rotary_pos_emb optional & fix type
li-plus committed Jan 25, 2025
1 parent ef6022e · commit 64c2827
Showing 1 changed file with 9 additions and 7 deletions.
src/transformers/models/qwen2_vl/modeling_qwen2_vl.py (16 changes: 9 additions & 7 deletions)
@@ -307,7 +307,7 @@ def forward(
         hidden_states: torch.Tensor,
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: Optional[torch.Tensor] = None,
-        position_embeddings: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
@@ -356,7 +356,7 @@ def forward(
         hidden_states: torch.Tensor,
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: Optional[torch.Tensor] = None,
-        position_embeddings: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
@@ -394,7 +394,7 @@ def forward(
         hidden_states: torch.Tensor,
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: Optional[torch.Tensor] = None,
-        position_embeddings: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
@@ -445,7 +445,11 @@ def __init__(self, config, attn_implementation: str = "sdpa") -> None:
         self.mlp = VisionMlp(dim=config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act)

     def forward(
-        self, hidden_states, cu_seqlens, rotary_pos_emb, position_embeddings: Optional[torch.Tensor] = None
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         hidden_states = hidden_states + self.attn(
             self.norm1(hidden_states),
@@ -1034,9 +1038,7 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
                     blk.__call__, hidden_states, cu_seqlens, None, position_embeddings
                 )
             else:
-                hidden_states = blk(
-                    hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=None, position_embeddings=position_embeddings
-                )
+                hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)

         return self.merger(hidden_states)
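For context, a minimal, hypothetical usage sketch (not part of the commit) of the updated Qwen2VLVisionBlock.forward signature: rotary_pos_emb can now be omitted, and position_embeddings is typed as a (cos, sin) tuple. The default Qwen2VLVisionConfig, the random inputs, and the assumption that cos/sin have shape (seq_len, head_dim) are illustrative only.

# Hypothetical sketch, assuming the attention layers consume position_embeddings
# as a (cos, sin) pair of shape (seq_len, head_dim); not taken from the commit.
import torch

from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLVisionBlock

config = Qwen2VLVisionConfig()   # default vision config (embed_dim, num_heads, mlp_ratio, ...)
block = Qwen2VLVisionBlock(config)  # attn_implementation defaults to "sdpa"

seq_len = 16
head_dim = config.embed_dim // config.num_heads
hidden_states = torch.randn(seq_len, config.embed_dim)
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32)  # one packed sequence

# After this change, rotary_pos_emb defaults to None, so only the (cos, sin)
# position embeddings need to be passed.
cos = torch.randn(seq_len, head_dim)
sin = torch.randn(seq_len, head_dim)
out = block(hidden_states, cu_seqlens=cu_seqlens, position_embeddings=(cos, sin))
print(out.shape)  # (seq_len, embed_dim)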

