[Model] SiglipVisionModel ported from transformers #6942

Merged
24 commits merged on Aug 5, 2024
Changes from 1 commit

Commits (24)
9222552 feat: initial siglip implementation (ChristopherCho, Jul 30, 2024)
5e09410 fix: typo fixed (ChristopherCho, Jul 30, 2024)
8af6456 fix: change paligemma to use ported siglip (ChristopherCho, Jul 30, 2024)
c00edeb fix: style fixed (ChristopherCho, Jul 30, 2024)
db99a08 feat: modify paligemma to fully utilize siglip (ChristopherCho, Jul 30, 2024)
3e3b032 feat: sync model methods for paligemma (ChristopherCho, Jul 30, 2024)
f04da2b fix: style fix (ChristopherCho, Jul 30, 2024)
b3ccec5 fix: sync with transformers siglip (ChristopherCho, Jul 31, 2024)
106e193 fix: style fix (ChristopherCho, Jul 31, 2024)
5b9242f fix: faulty weight loading logic for vision model (ChristopherCho, Jul 31, 2024)
3dc8ea0 feat: add various attention mechanisms (ChristopherCho, Jul 31, 2024)
5afa010 fix: style update (ChristopherCho, Jul 31, 2024)
7fdb13d fix: remove unnecessary comments (ChristopherCho, Jul 31, 2024)
c47e54a fix: remove unrequired docstring (ChristopherCho, Aug 5, 2024)
cac1933 fix: remove unrequired docstring (ChristopherCho, Aug 5, 2024)
2d1aeec fix: detach vllm attention (ChristopherCho, Aug 5, 2024)
bb570c3 fix: remove vllm attention (ChristopherCho, Aug 5, 2024)
dee55d0 fix: revert vision tower weight loading (ChristopherCho, Aug 5, 2024)
bffc385 fix: use basic SiglipAttention for now (ChristopherCho, Aug 5, 2024)
681b36d fix: remove unnecessary weight loading logic (ChristopherCho, Aug 5, 2024)
fb4972d cleanup (ywang96, Aug 5, 2024)
d15a299 typing (ywang96, Aug 5, 2024)
9ef79b9 update (ywang96, Aug 5, 2024)
54faf5d Merge remote-tracking branch 'upstream/main' into siglip-support (ywang96, Aug 5, 2024)
fix: sync with transformers siglip
ChristopherCho committed Jul 31, 2024
commit b3ccec5dbb860cc75843bb7ebf00a1f0873bb13d
64 changes: 22 additions & 42 deletions vllm/model_executor/models/siglip.py
@@ -9,7 +9,11 @@
from PIL import Image
from torch import nn
from transformers import SiglipConfig, SiglipVisionConfig
from vllm_flash_attn import flash_attn_varlen_func
from transformers.modeling_flash_attention_utils import (
_upad_input,
pad_input,
)
from vllm_flash_attn import flash_attn_func

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, ModelConfig
@@ -365,6 +369,8 @@ def _basic_attention_forward(self, q, k, v, batch_size, q_len, *args,

return attn_output

# Ported from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L449
# and https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/modeling_flash_attention_utils.py#L133
def _flash_attention_forward(self, q, k, v, batch_size, q_len, *args,
**kwargs):
"""Implements the multihead softmax attention.
@@ -377,43 +383,18 @@ def _flash_attention_forward(self, q, k, v, batch_size, q_len, *args,
q = q.view(batch_size, q_len, self.num_heads, self.head_dim)
k = k.view(batch_size, q_len, self.num_heads, self.head_dim)
v = v.view(batch_size, q_len, self.num_heads, self.head_dim)

seqlen_k = k.shape[1]

# goes for cuda device
q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
cu_seqlens_q = torch.arange(
0,
(batch_size + 1) * q_len,
step=q_len,
dtype=torch.int32,
device=q.device,
)

# during training q,k,v always have same seqlen
assert seqlen_k == q_len

cu_seqlens_k = cu_seqlens_q
dropout_p = self.dropout if self.training else 0.0

output = flash_attn_varlen_func(

attn_output = flash_attn_func(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
q_len,
seqlen_k,
dropout_p,
softmax_scale=None,
dropout_p=0.0,
causal=False,
)

attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous()

output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
output = output.reshape(batch_size, q_len, self.embed_dim).contiguous()

return output

return attn_output

class SiglipMLP(nn.Module):
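
For readability, here is a minimal sketch of what the new `_flash_attention_forward` body amounts to after this commit, reconstructed from the added lines above. It is an illustration, not the verbatim PR code: instance attributes such as `self.num_heads`, `self.head_dim`, and `self.embed_dim` are passed as explicit parameters so the snippet is self-contained.

```python
import torch
from vllm_flash_attn import flash_attn_func


def flash_attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                            batch_size: int, q_len: int,
                            num_heads: int, head_dim: int) -> torch.Tensor:
    # Reshape to the (batch, seqlen, nheads, headdim) layout flash_attn_func expects;
    # the cu_seqlens bookkeeping required by flash_attn_varlen_func is no longer needed.
    q = q.view(batch_size, q_len, num_heads, head_dim)
    k = k.view(batch_size, q_len, num_heads, head_dim)
    v = v.view(batch_size, q_len, num_heads, head_dim)

    # Non-causal attention over the full patch sequence; no dropout at inference time.
    attn_output = flash_attn_func(q, k, v, dropout_p=0.0, causal=False)

    # Merge heads back into the embedding dimension.
    return attn_output.reshape(batch_size, q_len, num_heads * head_dim).contiguous()
```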

@@ -534,19 +515,19 @@ def forward(
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> Tuple:
last_hidden_state = inputs_embeds
hidden_states = (last_hidden_state, )
hidden_states = inputs_embeds
for encoder_layer in self.layers:
encoder_states = (hidden_states, )
layer_outputs = encoder_layer(
hidden_states=last_hidden_state,
hidden_states=hidden_states,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
)
hidden_states = layer_outputs[0]

encoder_states = encoder_states + (hidden_states, )

last_hidden_state = layer_outputs[0]
hidden_states = hidden_states + (last_hidden_state, )

return (last_hidden_state, hidden_states)
return (hidden_states, encoder_states)


class SiglipVisionTransformer(nn.Module):
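
Because removed and added lines are interleaved in the `SiglipEncoder.forward` hunk above, the resulting loop is easier to read assembled. Roughly, based on the added lines, and assuming `encoder_states` is initialized before the loop as in the transformers encoder (indentation is not preserved in this view):

```python
# Sketch of SiglipEncoder.forward after this commit (not verbatim PR code).
hidden_states = inputs_embeds
encoder_states = (hidden_states, )          # collect the input plus each layer's output
for encoder_layer in self.layers:
    layer_outputs = encoder_layer(
        hidden_states=hidden_states,
        kv_caches=kv_caches,
        attn_metadata=attn_metadata,
    )
    hidden_states = layer_outputs[0]
    encoder_states = encoder_states + (hidden_states, )

# Final hidden state plus the tuple of per-layer hidden states.
return (hidden_states, encoder_states)
```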
@@ -616,8 +597,7 @@ def __init__(
super().__init__()

self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
self.head_size = config.hidden_size // config.num_attention_heads
self.scaling = self.head_size**-0.5
# TODO(ChristopherCho): Implement vLLM version of MultiheadAttention
self.attention = torch.nn.MultiheadAttention(
config.hidden_size, config.num_attention_heads, batch_first=True)
self.layernorm = nn.LayerNorm(config.hidden_size,
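
The pooling head now delegates to `torch.nn.MultiheadAttention` instead of computing its own head size and scaling. A minimal, self-contained sketch of the probe-based pooling pattern, modeled on the transformers `SiglipMultiheadAttentionPoolingHead` (the residual MLP that follows in the real model is omitted, and the class name here is illustrative):

```python
import torch
from torch import nn


class ProbePoolingHead(nn.Module):
    """Pools a sequence of patch embeddings into one vector via a learned probe query."""

    def __init__(self, hidden_size: int, num_heads: int, layer_norm_eps: float = 1e-6):
        super().__init__()
        self.probe = nn.Parameter(torch.randn(1, 1, hidden_size))
        # batch_first=True -> inputs and outputs are (batch, seq, hidden)
        self.attention = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size = hidden_states.shape[0]
        probe = self.probe.repeat(batch_size, 1, 1)    # one probe query per batch item
        # Cross-attention: the probe attends over all patch embeddings.
        pooled, _ = self.attention(probe, hidden_states, hidden_states)
        pooled = self.layernorm(pooled)
        return pooled[:, 0]                            # (batch, hidden_size)
```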
@@ -662,7 +642,7 @@ def forward(
pixel_values,
kv_caches: List[torch.Tensor] = None,
attn_metadata: AttentionMetadata = None,
interpolate_pos_encoding: Optional[bool] = False, # added by eric
interpolate_pos_encoding: Optional[bool] = False,
) -> Tuple:
return self.vision_model(
pixel_values=pixel_values,