diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 3ecb495b2a1d..bc9c3323426a 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -214,16 +214,20 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim return q_embed, k_embed -def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: - orig_dtype = tensor.dtype - tensor = tensor.float() - cos = freqs.cos() - sin = freqs.sin() - cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() - sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() - output = (tensor * cos) + (rotate_half(tensor) * sin) - output = output.to(orig_dtype) - return output +def apply_rotary_pos_emb_vision( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + orig_q_dtype = q.dtype + orig_k_dtype = k.dtype + q = q.float() + k = k.float() + cos = cos.unsqueeze(-2) + sin = sin.unsqueeze(-2) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + q_embed = q_embed.to(orig_q_dtype) + k_embed = k_embed.to(orig_k_dtype) + return q_embed, k_embed class VisionRotaryEmbedding(nn.Module): @@ -300,12 +304,27 @@ def __init__(self, dim: int, num_heads: int = 16) -> None: self.proj = nn.Linear(dim, dim) def forward( - self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[torch.Tensor] = None, ) -> torch.Tensor: seq_length = hidden_states.shape[0] q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) - q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) - k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " + "removed and `position_embeddings` will be mandatory." + ) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + cos = emb.cos().float() + sin = emb.sin().float() + else: + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) attention_mask = torch.full( [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype @@ -334,12 +353,27 @@ def __init__(self, dim: int, num_heads: int = 16) -> None: self.proj = nn.Linear(dim, dim) def forward( - self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[torch.Tensor] = None, ) -> torch.Tensor: seq_length = hidden_states.shape[0] q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) - q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) - k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " + "removed and `position_embeddings` will be mandatory." + ) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + cos = emb.cos().float() + sin = emb.sin().float() + else: + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape( @@ -357,12 +391,27 @@ def __init__(self, dim: int, num_heads: int = 16) -> None: self.proj = nn.Linear(dim, dim) def forward( - self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[torch.Tensor] = None, ) -> torch.Tensor: seq_length = hidden_states.shape[0] q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) - q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) - k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " + "removed and `position_embeddings` will be mandatory." + ) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + cos = emb.cos().float() + sin = emb.sin().float() + else: + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool) for i in range(1, len(cu_seqlens)): @@ -396,9 +445,14 @@ def __init__(self, config, attn_implementation: str = "sdpa") -> None: ) self.mlp = VisionMlp(dim=config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act) - def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor: + def forward( + self, hidden_states, cu_seqlens, rotary_pos_emb, position_embeddings: Optional[torch.Tensor] = None + ) -> torch.Tensor: hidden_states = hidden_states + self.attn( - self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + position_embeddings=position_embeddings, ) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) return hidden_states @@ -956,11 +1010,14 @@ def rot_pos_emb(self, grid_thw): max_grid_size = grid_thw[:, 1:].max() rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) - return rotary_pos_emb + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.float(), sin.float() def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: hidden_states = self.patch_embed(hidden_states) - rotary_pos_emb = self.rot_pos_emb(grid_thw) + position_embeddings = self.rot_pos_emb(grid_thw) cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( dim=0, @@ -975,10 +1032,12 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch. for blk in self.blocks: if self.gradient_checkpointing and self.training: hidden_states = self._gradient_checkpointing_func( - blk.__call__, hidden_states, cu_seqlens, rotary_pos_emb + blk.__call__, hidden_states, cu_seqlens, None, position_embeddings ) else: - hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb) + hidden_states = blk( + hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=None, position_embeddings=position_embeddings + ) return self.merger(hidden_states)