[Model] Remove transformers attention porting in VITs #10414

Merged
merged 8 commits on Nov 18, 2024
66 changes: 36 additions & 30 deletions vllm/model_executor/models/blip.py
@@ -4,10 +4,11 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import Blip2VisionConfig, BlipVisionConfig
from transformers.models.blip.modeling_blip import BlipAttention

from vllm.attention.selector import _Backend
from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import DecoderOnlyInputs, token_inputs
@@ -21,11 +22,7 @@
repeat_and_pad_placeholder_tokens)
from vllm.sequence import SequenceData

try:
from xformers import ops as xops
USE_XFORMERS_OPS = True
except ImportError:
USE_XFORMERS_OPS = False
from .utils import get_vit_attn_backend


def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
@@ -168,7 +165,7 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
return embeddings


class BlipParallelAttention(nn.Module):
class BlipAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

def __init__(
@@ -208,6 +205,12 @@ def __init__(
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

# Detect attention implementation.
self.attn_backend = get_vit_attn_backend(support_fa=False)
if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
raise RuntimeError(
f"BLIP does not support {self.attn_backend} backend now.")

def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads,
self.head_dim).transpose(1, 2).contiguous()
@@ -231,11 +234,26 @@ def forward(
self.num_heads_per_partition,
self.head_dim)

out = xops.memory_efficient_attention_forward(query_states,
key_states,
value_states,
p=self.dropout,
scale=self.scale)
if self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops

out = xops.memory_efficient_attention_forward(query_states,
key_states,
value_states,
p=self.dropout,
scale=self.scale)
elif self.attn_backend == _Backend.TORCH_SDPA:
query_states, key_states, value_states = (x.transpose(1, 2)
for x in (query_states,
key_states,
value_states))
out = F.scaled_dot_product_attention(query_states,
key_states,
value_states,
dropout_p=self.dropout,
scale=self.scale)
out = out.transpose(1, 2)

out = out.view(bsz, tgt_len, -1)
attn_output, _ = self.projection(out)

@@ -285,18 +303,11 @@ def __init__(
super().__init__()

# fallback to sdpa attention if tp unavailable
num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.self_attn = BlipParallelAttention(
config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
else:
# Blip doesn't have SDPA attention implemented in transformers
# use eager attention instead for cpu backend
self.self_attn = BlipAttention(config)
self.self_attn = BlipAttention(
config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.mlp = BlipMLP(config,
@@ -374,11 +385,6 @@ def __init__(
prefix: str = "",
) -> None:
super().__init__()

tp_size = get_tensor_model_parallel_world_size()
num_heads = config.num_attention_heads
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0

self.config = config

self.embeddings = BlipVisionEmbeddings(config)
@@ -422,7 +428,7 @@ def load_weights(self, weights: Iterable[Tuple[str,
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
] if self.shard_weight else []
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
layer_count = len(self.encoder.layers)
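The pattern this diff introduces in BlipAttention (and mirrors below in CLIP and InternViT) is: resolve the ViT attention backend once at construction time via get_vit_attn_backend(support_fa=False), then branch in forward() between xformers' memory-efficient attention and torch.nn.functional.scaled_dot_product_attention. Below is a minimal standalone sketch of that dispatch; the Backend enum and pick_vit_attn_backend helper are simplified stand-ins for vllm's _Backend and get_vit_attn_backend, whose real selection logic also depends on the device and configuration.

```python
# Minimal sketch of construction-time backend selection plus forward-time
# dispatch, mirroring the rewritten ViT attention classes. `Backend` and
# `pick_vit_attn_backend` are illustrative stand-ins, not vllm's actual API.
import enum

import torch
import torch.nn.functional as F


class Backend(enum.Enum):
    TORCH_SDPA = enum.auto()
    XFORMERS = enum.auto()


def pick_vit_attn_backend() -> Backend:
    # Simplified: prefer xformers when it is importable, else torch SDPA.
    try:
        import xformers.ops  # noqa: F401
        return Backend.XFORMERS
    except ImportError:
        return Backend.TORCH_SDPA


def vit_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                  scale: float, backend: Backend) -> torch.Tensor:
    """q, k, v: (batch, seq_len, num_heads, head_dim), as in the BLIP forward."""
    if backend == Backend.XFORMERS:
        from xformers import ops as xops
        return xops.memory_efficient_attention_forward(q, k, v, scale=scale)
    # torch SDPA expects (batch, num_heads, seq_len, head_dim).
    q, k, v = (x.transpose(1, 2) for x in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, scale=scale)  # `scale` needs torch >= 2.1
    return out.transpose(1, 2)


backend = pick_vit_attn_backend()
q = k = v = torch.randn(1, 16, 8, 64)
print(vit_attention(q, k, v, scale=64 ** -0.5, backend=backend).shape)
# torch.Size([1, 16, 8, 64])
```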
65 changes: 36 additions & 29 deletions vllm/model_executor/models/clip.py
@@ -5,10 +5,11 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import CLIPVisionConfig
from transformers.models.clip.modeling_clip import CLIPSdpaAttention

from vllm.attention.selector import _Backend
from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import DecoderOnlyInputs, token_inputs
@@ -23,11 +24,7 @@
repeat_and_pad_placeholder_tokens)
from vllm.sequence import SequenceData

try:
from xformers import ops as xops
USE_XFORMERS_OPS = True
except ImportError:
USE_XFORMERS_OPS = False
from .utils import get_vit_attn_backend


def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
@@ -197,7 +194,7 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
return embeddings


class CLIPParallelAttention(nn.Module):
class CLIPAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

def __init__(
@@ -237,6 +234,12 @@ def __init__(
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

# Detect attention implementation.
self.attn_backend = get_vit_attn_backend(support_fa=False)
if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
raise RuntimeError(
f"CLIP does not support {self.attn_backend} backend now.")

def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads,
self.head_dim).transpose(1, 2).contiguous()
@@ -261,11 +264,26 @@ def forward(
self.num_heads_per_partition,
self.head_dim)

out = xops.memory_efficient_attention_forward(query_states,
key_states,
value_states,
p=self.dropout,
scale=self.scale)
if self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops

out = xops.memory_efficient_attention_forward(query_states,
key_states,
value_states,
p=self.dropout,
scale=self.scale)
elif self.attn_backend == _Backend.TORCH_SDPA:
query_states, key_states, value_states = (x.transpose(1, 2)
for x in (query_states,
key_states,
value_states))
out = F.scaled_dot_product_attention(query_states,
key_states,
value_states,
dropout_p=self.dropout,
scale=self.scale)
out = out.transpose(1, 2)

out = out.view(bsz, tgt_len, -1)
attn_output, _ = self.out_proj(out)

@@ -311,17 +329,11 @@ def __init__(
prefix: str = "",
) -> None:
super().__init__()

num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.self_attn = CLIPParallelAttention(
config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
else:
self.self_attn = CLIPSdpaAttention(config)
self.self_attn = CLIPAttention(
config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.mlp = CLIPMLP(config,
@@ -461,11 +473,6 @@ def __init__(
prefix: str = "",
) -> None:
super().__init__()

tp_size = get_tensor_model_parallel_world_size()
num_heads = config.num_attention_heads
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0

self.vision_model = CLIPVisionTransformer(
config=config,
quant_config=quant_config,
@@ -490,7 +497,7 @@ def load_weights(self, weights: Iterable[Tuple[str,
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
] if self.shard_weight else []
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
layer_count = len(self.vision_model.encoder.layers)
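Because the parallel attention with a fused qkv_proj is now used unconditionally, the stacked_params_mapping in load_weights loses its "if self.shard_weight else []" escape hatch. For context, here is a rough illustration of how such a mapping is typically consumed during weight loading; the remap loop below is illustrative only, not vllm's actual load_weights: checkpoint names containing a per-projection key are rewritten to the fused parameter name, and the shard id tells the fused weight loader which slice to fill.

```python
# Illustrative only: a simplified version of the name-remapping step driven by
# a stacked_params_mapping during weight loading. Not vllm's actual code.
from typing import Optional, Tuple

stacked_params_mapping = [
    # (param_name, weight_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]


def remap(checkpoint_name: str) -> Tuple[str, Optional[str]]:
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in checkpoint_name:
            return checkpoint_name.replace(weight_name, param_name), shard_id
    return checkpoint_name, None


print(remap("vision_model.encoder.layers.0.self_attn.q_proj.weight"))
# ('vision_model.encoder.layers.0.self_attn.qkv_proj.weight', 'q')
```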
32 changes: 22 additions & 10 deletions vllm/model_executor/models/intern_vit.py
@@ -12,6 +12,7 @@
import torch.nn.functional as F
from transformers import PretrainedConfig

from vllm.attention.selector import _Backend
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
@@ -24,11 +25,7 @@
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

try:
from xformers import ops as xops
USE_XFORMERS_OPS = True
except ImportError:
USE_XFORMERS_OPS = False
from .utils import get_vit_attn_backend

NORM2FN = {
'rms_norm': RMSNorm,
@@ -186,6 +183,11 @@ def __init__(
prefix=f"{prefix}.proj",
)

self.attn_backend = get_vit_attn_backend(support_fa=False)
if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
raise RuntimeError(
f"InternViT does not support {self.attn_backend} backend now.")

def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
if self.tp_size > 1:
q = tensor_model_parallel_all_gather(q.contiguous())
@@ -211,11 +213,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
k = k.view(B, N, self.num_heads_per_partition, self.head_dim)
v = v.view(B, N, self.num_heads_per_partition, self.head_dim)

x = xops.memory_efficient_attention_forward(q, k, v, scale=self.scale)
x = x.view(B, N, -1)
if self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops

x, _ = self.proj(x)
return x
out = xops.memory_efficient_attention_forward(q,
k,
v,
scale=self.scale)
elif self.attn_backend == _Backend.TORCH_SDPA:
q, k, v = (x.transpose(1, 2) for x in (q, k, v))
out = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
out = out.transpose(1, 2)

out = out.view(B, N, -1)
out, _ = self.proj(out)
return out


class InternSdpaAttention(nn.Module):
@@ -362,7 +374,7 @@ def _init_attn(
tp_size = get_tensor_model_parallel_world_size()
num_heads = config.num_attention_heads

if USE_XFORMERS_OPS and (num_heads + num_dummy_heads) % tp_size == 0:
if (num_heads + num_dummy_heads) % tp_size == 0:
return InternParallelAttention(config,
quant_config=quant_config,
num_dummy_heads=num_dummy_heads,
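InternViT's _init_attn keeps its tensor-parallel fallback, but the condition now depends only on head divisibility rather than on whether xformers imported. A small sketch of that selection criterion follows; the head counts are made-up examples, not values from any real config.

```python
# Sketch of the _init_attn selection criterion after this change: the parallel
# attention is used whenever the (possibly dummy-padded) head count divides
# evenly across tensor-parallel ranks. Head counts here are illustrative.
def use_parallel_attention(num_heads: int, num_dummy_heads: int,
                           tp_size: int) -> bool:
    return (num_heads + num_dummy_heads) % tp_size == 0


print(use_parallel_attention(25, 0, 2))  # False -> InternSdpaAttention
print(use_parallel_attention(25, 7, 2))  # True  -> InternParallelAttention
```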
2 changes: 1 addition & 1 deletion vllm/model_executor/models/molmo.py
@@ -187,7 +187,7 @@ def __init__(
)

# Detect attention implementation.
self.attn_backend: _Backend = get_vit_attn_backend()
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
if self.attn_backend not in {
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
}:
2 changes: 1 addition & 1 deletion vllm/model_executor/models/qwen2_vl.py
@@ -260,7 +260,7 @@ def __init__(
prefix=f"{prefix}.proj")

# Detect attention implementation.
self.attn_backend: _Backend = get_vit_attn_backend()
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
if self.attn_backend not in {
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
}: