Optimize Qwen2VL vision model by precomputing cos/sin embeds before ViT
li-plus committed Jan 22, 2025
1 parent a7738f5 commit 2e912f3
Showing 1 changed file with 32 additions and 25 deletions.
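
Context for the change: before this commit, every vision attention layer received the raw rotary frequencies and rebuilt the cos/sin tables (float casts, repeat, unsqueeze) on each forward call of each block; the commit moves that work into rot_pos_emb, so cos/sin are computed once per vision forward pass and handed to every block as position_embeddings. Below is a minimal, self-contained sketch of the idea, with toy shapes and a hypothetical block loop rather than the actual transformers classes:

import torch

seq_len, head_dim, depth = 16, 8, 32
freqs = torch.randn(seq_len, head_dim // 2)  # raw rotary frequencies for the ViT grid

# Old path: each of the depth blocks re-derives cos/sin from the frequencies.
for _ in range(depth):
    cos = freqs.cos().unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    sin = freqs.sin().unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    # ... rotary embedding applied to q/k here ...

# New path: cos/sin are precomputed once before the block loop and reused.
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos().float(), emb.sin().float()
for _ in range(depth):
    # ... each block only broadcasts the cached cos/sin over the head axis ...
    pass
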
57 changes: 32 additions & 25 deletions src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -240,16 +240,20 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim
     return q_embed, k_embed


-def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
-    orig_dtype = tensor.dtype
-    tensor = tensor.float()
-    cos = freqs.cos()
-    sin = freqs.sin()
-    cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
-    sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
-    output = (tensor * cos) + (rotate_half(tensor) * sin)
-    output = output.to(orig_dtype)
-    return output
+def apply_rotary_pos_emb_vision(
+    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+) -> torch.Tensor:
+    orig_q_dtype = q.dtype
+    orig_k_dtype = k.dtype
+    q = q.float()
+    k = k.float()
+    cos = cos.unsqueeze(-2)
+    sin = sin.unsqueeze(-2)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    q_embed = q_embed.to(orig_q_dtype)
+    k_embed = k_embed.to(orig_k_dtype)
+    return q_embed, k_embed


 class VisionRotaryEmbedding(nn.Module):
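
For reference, a standalone usage sketch of the new helper before the diff continues (rotate_half and the helper are re-declared here so the snippet runs on its own; the shapes are illustrative): q and k arrive as (seq_len, num_heads, head_dim) straight from the qkv projection, cos and sin as (seq_len, head_dim), and the unsqueeze(-2) broadcasts them across the head axis, which is why the old unsqueeze(0)/squeeze(0) wrapping of q and k is no longer needed.

import torch


def rotate_half(x):
    # Same convention as the model code: split the last dim in half and swap with a sign flip.
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb_vision(q, k, cos, sin):
    orig_q_dtype, orig_k_dtype = q.dtype, k.dtype
    q, k = q.float(), k.float()
    cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2)  # (seq_len, 1, head_dim), broadcast over heads
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed.to(orig_q_dtype), k_embed.to(orig_k_dtype)


seq_len, num_heads, head_dim = 16, 4, 8
q = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16)
k = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16)
cos = torch.randn(seq_len, head_dim)  # precomputed once per vision forward pass
sin = torch.randn(seq_len, head_dim)
q_rot, k_rot = apply_rotary_pos_emb_vision(q, k, cos, sin)
assert q_rot.shape == (seq_len, num_heads, head_dim) and q_rot.dtype == torch.float16
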
@@ -326,12 +330,12 @@ def __init__(self, dim: int, num_heads: int = 16) -> None:
         self.proj = nn.Linear(dim, dim)

     def forward(
-        self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
+        self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, position_embeddings: torch.Tensor
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
-        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
+        cos, sin = position_embeddings
+        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)

         attention_mask = torch.full(
             [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
@@ -360,12 +364,12 @@ def __init__(self, dim: int, num_heads: int = 16) -> None:
         self.proj = nn.Linear(dim, dim)

     def forward(
-        self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
+        self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, position_embeddings: torch.Tensor
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
-        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
+        cos, sin = position_embeddings
+        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)

         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
@@ -383,12 +387,12 @@ def __init__(self, dim: int, num_heads: int = 16) -> None:
         self.proj = nn.Linear(dim, dim)

     def forward(
-        self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
+        self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, position_embeddings: torch.Tensor
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
-        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
+        cos, sin = position_embeddings
+        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)

         attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
         for i in range(1, len(cu_seqlens)):
@@ -422,9 +426,9 @@ def __init__(self, config, attn_implementation: str = "sdpa") -> None:
         )
         self.mlp = VisionMlp(dim=config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act)

-    def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
+    def forward(self, hidden_states, cu_seqlens, position_embeddings) -> torch.Tensor:
         hidden_states = hidden_states + self.attn(
-            self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
+            self.norm1(hidden_states), cu_seqlens=cu_seqlens, position_embeddings=position_embeddings
         )
         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
         return hidden_states
@@ -988,11 +992,14 @@ def rot_pos_emb(self, grid_thw):
         max_grid_size = grid_thw[:, 1:].max()
         rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
-        return rotary_pos_emb
+        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+        cos = emb.cos()
+        sin = emb.sin()
+        return cos.float(), sin.float()

     def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
         hidden_states = self.patch_embed(hidden_states)
-        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        position_embeddings = self.rot_pos_emb(grid_thw)

         cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
             dim=0,
@@ -1007,10 +1014,10 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.
         for blk in self.blocks:
             if self.gradient_checkpointing and self.training:
                 hidden_states = self._gradient_checkpointing_func(
-                    blk.__call__, hidden_states, cu_seqlens, rotary_pos_emb
+                    blk.__call__, hidden_states, cu_seqlens, position_embeddings
                 )
             else:
-                hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
+                hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)

         return self.merger(hidden_states)
