From 5e92b92237732baf02c0561381f4a5cb3fe621e1 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Mon, 18 Nov 2024 14:46:29 +0800 Subject: [PATCH 1/8] modify clip Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/clip.py | 61 +++++++++++++++-------------- vllm/model_executor/models/utils.py | 3 ++ 2 files changed, 35 insertions(+), 29 deletions(-) diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 184758f4a8a45..56ec1d0c19e39 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -5,10 +5,11 @@ import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F from PIL import Image from transformers import CLIPVisionConfig -from transformers.models.clip.modeling_clip import CLIPSdpaAttention +from vllm.attention.selector import _Backend from vllm.config import ModelConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.inputs import DecoderOnlyInputs, token_inputs @@ -23,11 +24,7 @@ repeat_and_pad_placeholder_tokens) from vllm.sequence import SequenceData -try: - from xformers import ops as xops - USE_XFORMERS_OPS = True -except ImportError: - USE_XFORMERS_OPS = False +from .utils import get_vit_attn_backend def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int: @@ -197,7 +194,7 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: return embeddings -class CLIPParallelAttention(nn.Module): +class CLIPAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__( @@ -237,6 +234,9 @@ def __init__( self.tp_size = get_tensor_model_parallel_world_size() self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + # Detect attention implementation. 
+ self.attn_backend = get_vit_attn_backend() + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -261,11 +261,25 @@ def forward( self.num_heads_per_partition, self.head_dim) - out = xops.memory_efficient_attention_forward(query_states, - key_states, - value_states, - p=self.dropout, - scale=self.scale) + if self.attn_backend in (_Backend.XFORMERS, _Backend.FLASH_ATTN): + from xformers import ops as xops + + out = xops.memory_efficient_attention_forward(query_states, + key_states, + value_states, + p=self.dropout, + scale=self.scale) + elif self.attn_backend == _Backend.TORCH_SDPA: + query_states, key_states, value_states = (x.transpose(1, 2) + for x in (query_states, + key_states, + value_states)) + out = F.scaled_dot_product_attention(query_states, + key_states, + value_states, + dropout_p=0.0) + out = out.transpose(1, 2) + out = out.view(bsz, tgt_len, -1) attn_output, _ = self.out_proj(out) @@ -311,17 +325,11 @@ def __init__( prefix: str = "", ) -> None: super().__init__() - - num_heads = config.num_attention_heads - tp_size = get_tensor_model_parallel_world_size() - if USE_XFORMERS_OPS and num_heads % tp_size == 0: - self.self_attn = CLIPParallelAttention( - config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn", - ) - else: - self.self_attn = CLIPSdpaAttention(config) + self.self_attn = CLIPAttention( + config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.mlp = CLIPMLP(config, @@ -461,11 +469,6 @@ def __init__( prefix: str = "", ) -> None: super().__init__() - - tp_size = get_tensor_model_parallel_world_size() - num_heads = config.num_attention_heads - self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0 - self.vision_model = CLIPVisionTransformer( config=config, quant_config=quant_config, @@ -490,7 +493,7 @@ def load_weights(self, weights: Iterable[Tuple[str, ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), - ] if self.shard_weight else [] + ] params_dict = dict(self.named_parameters()) loaded_params: Set[str] = set() layer_count = len(self.vision_model.encoder.layers) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 7a4fcce95603d..68f5383879942 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -588,6 +588,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: def get_vit_attn_backend() -> _Backend: + """ + Get the available attention backend for Vision Transformer. 
+ """ selected_backend: Optional[_Backend] = get_global_forced_attn_backend() if selected_backend is None: backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND From 622c7aae0830c89dbc97c15614a5606b7658c835 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Mon, 18 Nov 2024 15:21:50 +0800 Subject: [PATCH 2/8] modify blip attention Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/blip.py | 62 +++++++++++++++--------------- 1 file changed, 32 insertions(+), 30 deletions(-) diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py index 6db6462e97f3f..bc132ca6d7076 100644 --- a/vllm/model_executor/models/blip.py +++ b/vllm/model_executor/models/blip.py @@ -4,10 +4,11 @@ import torch import torch.nn as nn +import torch.nn.functional as F from PIL import Image from transformers import Blip2VisionConfig, BlipVisionConfig -from transformers.models.blip.modeling_blip import BlipAttention +from vllm.attention.selector import _Backend from vllm.config import ModelConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.inputs import DecoderOnlyInputs, token_inputs @@ -21,11 +22,7 @@ repeat_and_pad_placeholder_tokens) from vllm.sequence import SequenceData -try: - from xformers import ops as xops - USE_XFORMERS_OPS = True -except ImportError: - USE_XFORMERS_OPS = False +from .utils import get_vit_attn_backend def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int: @@ -168,7 +165,7 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: return embeddings -class BlipParallelAttention(nn.Module): +class BlipAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__( @@ -208,6 +205,9 @@ def __init__( self.tp_size = get_tensor_model_parallel_world_size() self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + # Detect attention implementation. 
+ self.attn_backend = get_vit_attn_backend() + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -231,11 +231,25 @@ def forward( self.num_heads_per_partition, self.head_dim) - out = xops.memory_efficient_attention_forward(query_states, - key_states, - value_states, - p=self.dropout, - scale=self.scale) + if self.attn_backend in (_Backend.XFORMERS, _Backend.FLASH_ATTN): + from xformers import ops as xops + + out = xops.memory_efficient_attention_forward(query_states, + key_states, + value_states, + p=self.dropout, + scale=self.scale) + elif self.attn_backend == _Backend.TORCH_SDPA: + query_states, key_states, value_states = (x.transpose(1, 2) + for x in (query_states, + key_states, + value_states)) + out = F.scaled_dot_product_attention(query_states, + key_states, + value_states, + dropout_p=0.0) + out = out.transpose(1, 2) + out = out.view(bsz, tgt_len, -1) attn_output, _ = self.projection(out) @@ -285,18 +299,11 @@ def __init__( super().__init__() # fallback to sdpa attention if tp unavailable - num_heads = config.num_attention_heads - tp_size = get_tensor_model_parallel_world_size() - if USE_XFORMERS_OPS and num_heads % tp_size == 0: - self.self_attn = BlipParallelAttention( - config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn", - ) - else: - # Blip doesn't have SDPA attention implemented in transformers - # use eager attention instead for cpu backend - self.self_attn = BlipAttention(config) + self.self_attn = BlipAttention( + config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.mlp = BlipMLP(config, @@ -374,11 +381,6 @@ def __init__( prefix: str = "", ) -> None: super().__init__() - - tp_size = get_tensor_model_parallel_world_size() - num_heads = config.num_attention_heads - self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0 - self.config = config self.embeddings = BlipVisionEmbeddings(config) @@ -422,7 +424,7 @@ def load_weights(self, weights: Iterable[Tuple[str, ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), - ] if self.shard_weight else [] + ] params_dict = dict(self.named_parameters()) loaded_params: Set[str] = set() layer_count = len(self.encoder.layers) From 301d21cd85631e3d1bfebc0c0b1fac7d8efdc44b Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Mon, 18 Nov 2024 15:29:53 +0800 Subject: [PATCH 3/8] modify siglip Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/siglip.py | 59 +++++++++++++++------------- 1 file changed, 31 insertions(+), 28 deletions(-) diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index c9e09b879843a..1724a76eea8be 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -6,11 +6,12 @@ import numpy as np import torch +import torch.nn.functional as F from PIL import Image from torch import nn from transformers import SiglipVisionConfig -from transformers.models.siglip.modeling_siglip import SiglipSdpaAttention +from vllm.attention.selector import _Backend from vllm.config import ModelConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.inputs import DecoderOnlyInputs, token_inputs @@ -27,11 +28,7 @@ repeat_and_pad_placeholder_tokens) from vllm.sequence import SequenceData -try: - from xformers import ops as xops - 
USE_XFORMERS_OPS = True -except ImportError: - USE_XFORMERS_OPS = False +from .utils import get_vit_attn_backend def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int: @@ -254,7 +251,7 @@ def forward(self, return embeddings -class SiglipParallelAttention(nn.Module): +class SiglipAttention(nn.Module): def __init__( self, @@ -293,6 +290,8 @@ def __init__( self.tp_size = get_tensor_model_parallel_world_size() self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + self.attn_backend = get_vit_attn_backend() + def forward( self, hidden_states: torch.Tensor, @@ -313,11 +312,25 @@ def forward( self.num_heads_per_partition, self.head_dim) - out = xops.memory_efficient_attention_forward(query_states, - key_states, - value_states, - p=self.dropout, - scale=self.scale) + if self.attn_backend in (_Backend.XFORMERS, _Backend.FLASH_ATTN): + from xformers import ops as xops + + out = xops.memory_efficient_attention_forward(query_states, + key_states, + value_states, + p=self.dropout, + scale=self.scale) + elif self.attn_backend == _Backend.TORCH_SDPA: + query_states, key_states, value_states = (x.transpose(1, 2) + for x in (query_states, + key_states, + value_states)) + out = F.scaled_dot_product_attention(query_states, + key_states, + value_states, + dropout_p=0.0) + out = out.transpose(1, 2) + out = out.view(batch_size, q_len, -1) attn_output, _ = self.out_proj(out) @@ -372,17 +385,11 @@ def __init__( self.embed_dim = config.hidden_size - num_heads = config.num_attention_heads - tp_size = get_tensor_model_parallel_world_size() - if USE_XFORMERS_OPS and num_heads % tp_size == 0: - self.self_attn = SiglipParallelAttention( - config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn", - ) - else: - self.self_attn = SiglipSdpaAttention(config) - + self.self_attn = SiglipAttention( + config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = SiglipMLP( @@ -569,10 +576,6 @@ def __init__( ) -> None: super().__init__() - num_heads = config.num_attention_heads - tp_size = get_tensor_model_parallel_world_size() - self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0 - self.vision_model = SiglipVisionTransformer( config, quant_config, @@ -601,7 +604,7 @@ def load_weights(self, weights: Iterable[Tuple[str, ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), - ] if self.shard_weight else [] + ] params_dict = dict(self.named_parameters()) loaded_params: Set[str] = set() layer_count = len(self.vision_model.encoder.layers) From 07cb82a5cc68336a8df0ab6bc93162eed610c348 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Mon, 18 Nov 2024 15:50:16 +0800 Subject: [PATCH 4/8] modify intern vit Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/blip.py | 3 ++- vllm/model_executor/models/clip.py | 3 ++- vllm/model_executor/models/intern_vit.py | 25 ++++++++++++++++-------- vllm/model_executor/models/siglip.py | 3 ++- 4 files changed, 23 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py index bc132ca6d7076..fa90ad8f42f42 100644 --- a/vllm/model_executor/models/blip.py +++ b/vllm/model_executor/models/blip.py @@ -247,7 +247,8 @@ def forward( out = F.scaled_dot_product_attention(query_states, key_states, value_states, - dropout_p=0.0) + dropout_p=self.dropout, + scale=self.scale) out = out.transpose(1, 2) out = out.view(bsz, tgt_len, 
-1) diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 56ec1d0c19e39..94e5bdce46cd3 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -277,7 +277,8 @@ def forward( out = F.scaled_dot_product_attention(query_states, key_states, value_states, - dropout_p=0.0) + dropout_p=self.dropout, + scale=self.scale) out = out.transpose(1, 2) out = out.view(bsz, tgt_len, -1) diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py index bd91a0806ae5c..f2bab381cbf10 100644 --- a/vllm/model_executor/models/intern_vit.py +++ b/vllm/model_executor/models/intern_vit.py @@ -12,6 +12,7 @@ import torch.nn.functional as F from transformers import PretrainedConfig +from vllm.attention.selector import _Backend from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, split_tensor_along_last_dim, @@ -24,11 +25,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -try: - from xformers import ops as xops - USE_XFORMERS_OPS = True -except ImportError: - USE_XFORMERS_OPS = False +from .utils import get_vit_attn_backend NORM2FN = { 'rms_norm': RMSNorm, @@ -186,6 +183,8 @@ def __init__( prefix=f"{prefix}.proj", ) + self.attn_backend = get_vit_attn_backend() + def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor): if self.tp_size > 1: q = tensor_model_parallel_all_gather(q.contiguous()) @@ -211,9 +210,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: k = k.view(B, N, self.num_heads_per_partition, self.head_dim) v = v.view(B, N, self.num_heads_per_partition, self.head_dim) - x = xops.memory_efficient_attention_forward(q, k, v, scale=self.scale) - x = x.view(B, N, -1) + if self.attn_backend in (_Backend.XFORMERS, _Backend.FLASH_ATTN): + from xformers import ops as xops + out = xops.memory_efficient_attention_forward(q, + k, + v, + scale=self.scale) + elif self.attn_backend == _Backend.TORCH_SDPA: + q, k, v = (x.transpose(1, 2) for x in (q, k, v)) + out = F.scaled_dot_product_attention(q, k, v, scale=self.scale) + out = out.transpose(1, 2) + + x = x.view(B, N, -1) x, _ = self.proj(x) return x @@ -362,7 +371,7 @@ def _init_attn( tp_size = get_tensor_model_parallel_world_size() num_heads = config.num_attention_heads - if USE_XFORMERS_OPS and (num_heads + num_dummy_heads) % tp_size == 0: + if (num_heads + num_dummy_heads) % tp_size == 0: return InternParallelAttention(config, quant_config=quant_config, num_dummy_heads=num_dummy_heads, diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 1724a76eea8be..0888f3c2d2c0f 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -328,7 +328,8 @@ def forward( out = F.scaled_dot_product_attention(query_states, key_states, value_states, - dropout_p=0.0) + dropout_p=self.dropout, + scale=self.scale) out = out.transpose(1, 2) out = out.view(batch_size, q_len, -1) From 9d34464dbcb3039796ecd367ee98c1b9b091ff09 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Mon, 18 Nov 2024 17:20:37 +0800 Subject: [PATCH 5/8] fix typo in intern vit Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/intern_vit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py index f2bab381cbf10..e6c3e499fcd7a 100644 
--- a/vllm/model_executor/models/intern_vit.py +++ b/vllm/model_executor/models/intern_vit.py @@ -222,8 +222,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: out = F.scaled_dot_product_attention(q, k, v, scale=self.scale) out = out.transpose(1, 2) - x = x.view(B, N, -1) - x, _ = self.proj(x) + out = out.view(B, N, -1) + out, _ = self.proj(out) return x From 61d9d0759257777d439d115927dd1ecbb4427052 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Mon, 18 Nov 2024 17:23:54 +0800 Subject: [PATCH 6/8] fix typo in intern vit Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/intern_vit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py index e6c3e499fcd7a..22421d124c3fb 100644 --- a/vllm/model_executor/models/intern_vit.py +++ b/vllm/model_executor/models/intern_vit.py @@ -224,7 +224,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: out = out.view(B, N, -1) out, _ = self.proj(out) - return x + return out class InternSdpaAttention(nn.Module): From 94d14e6198352b37ace1520137baf029a5998d96 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Mon, 18 Nov 2024 19:12:12 +0800 Subject: [PATCH 7/8] use SDPA for ROCM Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 68f5383879942..fcea9225abb00 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -609,7 +609,8 @@ def get_vit_attn_backend() -> _Backend: "so we use xformers backend instead. You can run " "`pip install flash-attn` to use flash-attention backend.") selected_backend = _Backend.XFORMERS - elif current_platform.is_cpu(): + elif current_platform.is_cpu() or current_platform.is_rocm(): + # ROCM doesn't support xformers selected_backend = _Backend.TORCH_SDPA else: selected_backend = _Backend.XFORMERS From b62850bae0b1618736a2b530784dd4d7ff3a7b9e Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Mon, 18 Nov 2024 19:57:23 +0800 Subject: [PATCH 8/8] fix FA on ROCM Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/blip.py | 7 +++++-- vllm/model_executor/models/clip.py | 7 +++++-- vllm/model_executor/models/intern_vit.py | 7 +++++-- vllm/model_executor/models/molmo.py | 2 +- vllm/model_executor/models/qwen2_vl.py | 2 +- vllm/model_executor/models/siglip.py | 7 +++++-- vllm/model_executor/models/utils.py | 5 +++-- 7 files changed, 25 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py index fa90ad8f42f42..6af59697160a0 100644 --- a/vllm/model_executor/models/blip.py +++ b/vllm/model_executor/models/blip.py @@ -206,7 +206,10 @@ def __init__( self.num_heads_per_partition = divide(self.num_heads, self.tp_size) # Detect attention implementation. 
- self.attn_backend = get_vit_attn_backend() + self.attn_backend = get_vit_attn_backend(support_fa=False) + if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}: + raise RuntimeError( + f"BLIP does not support {self.attn_backend} backend now.") def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, @@ -231,7 +234,7 @@ def forward( self.num_heads_per_partition, self.head_dim) - if self.attn_backend in (_Backend.XFORMERS, _Backend.FLASH_ATTN): + if self.attn_backend == _Backend.XFORMERS: from xformers import ops as xops out = xops.memory_efficient_attention_forward(query_states, diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 94e5bdce46cd3..7f638506f9fb2 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -235,7 +235,10 @@ def __init__( self.num_heads_per_partition = divide(self.num_heads, self.tp_size) # Detect attention implementation. - self.attn_backend = get_vit_attn_backend() + self.attn_backend = get_vit_attn_backend(support_fa=False) + if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}: + raise RuntimeError( + f"CLIP does not support {self.attn_backend} backend now.") def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, @@ -261,7 +264,7 @@ def forward( self.num_heads_per_partition, self.head_dim) - if self.attn_backend in (_Backend.XFORMERS, _Backend.FLASH_ATTN): + if self.attn_backend == _Backend.XFORMERS: from xformers import ops as xops out = xops.memory_efficient_attention_forward(query_states, diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py index 22421d124c3fb..c4346fcb3bd2a 100644 --- a/vllm/model_executor/models/intern_vit.py +++ b/vllm/model_executor/models/intern_vit.py @@ -183,7 +183,10 @@ def __init__( prefix=f"{prefix}.proj", ) - self.attn_backend = get_vit_attn_backend() + self.attn_backend = get_vit_attn_backend(support_fa=False) + if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}: + raise RuntimeError( + f"InternViT does not support {self.attn_backend} backend now.") def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor): if self.tp_size > 1: @@ -210,7 +213,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: k = k.view(B, N, self.num_heads_per_partition, self.head_dim) v = v.view(B, N, self.num_heads_per_partition, self.head_dim) - if self.attn_backend in (_Backend.XFORMERS, _Backend.FLASH_ATTN): + if self.attn_backend == _Backend.XFORMERS: from xformers import ops as xops out = xops.memory_efficient_attention_forward(q, diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 035a1e2ab7b02..a7c90a3f5031b 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -187,7 +187,7 @@ def __init__( ) # Detect attention implementation. - self.attn_backend: _Backend = get_vit_attn_backend() + self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) if self.attn_backend not in { _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS }: diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index ef6b52db6e17d..a929b9323b245 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -260,7 +260,7 @@ def __init__( prefix=f"{prefix}.proj") # Detect attention implementation. 
- self.attn_backend: _Backend = get_vit_attn_backend() + self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) if self.attn_backend not in { _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS }: diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 0888f3c2d2c0f..c58ad99692900 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -290,7 +290,10 @@ def __init__( self.tp_size = get_tensor_model_parallel_world_size() self.num_heads_per_partition = divide(self.num_heads, self.tp_size) - self.attn_backend = get_vit_attn_backend() + self.attn_backend = get_vit_attn_backend(support_fa=False) + if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}: + raise RuntimeError( + f"SIGLIP does not support {self.attn_backend} backend now.") def forward( self, @@ -312,7 +315,7 @@ def forward( self.num_heads_per_partition, self.head_dim) - if self.attn_backend in (_Backend.XFORMERS, _Backend.FLASH_ATTN): + if self.attn_backend == _Backend.XFORMERS: from xformers import ops as xops out = xops.memory_efficient_attention_forward(query_states, diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index fcea9225abb00..03226f42ee053 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -587,10 +587,11 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: return llm(*args, **kwargs) -def get_vit_attn_backend() -> _Backend: +def get_vit_attn_backend(support_fa: bool = False) -> _Backend: """ Get the available attention backend for Vision Transformer. """ + # TODO(Isotr0py): Remove `support_fa` after support FA for all ViTs attn. selected_backend: Optional[_Backend] = get_global_forced_attn_backend() if selected_backend is None: backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND @@ -599,7 +600,7 @@ def get_vit_attn_backend() -> _Backend: if selected_backend is None: # For Volta and Turing GPUs, use xformers instead. device_available = current_platform.has_device_capability(80) - if device_available: + if device_available and support_fa: from transformers.utils import is_flash_attn_2_available if is_flash_attn_2_available(): selected_backend = _Backend.FLASH_ATTN
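
Taken together, the series drops the module-level xFormers try/except in the
CLIP, BLIP, SigLIP and InternViT towers and instead resolves a backend once at
construction time via get_vit_attn_backend(), then branches between the
xFormers kernel and torch's scaled_dot_product_attention in forward(). Below
is a minimal standalone sketch of that pattern, not taken from the patches
themselves: it assumes vLLM at the state of this series is importable, and the
class name DemoViTAttention plus the plain nn.Linear projections are
illustrative stand-ins for the sharded QKVParallelLinear/RowParallelLinear
layers the real models use.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    from vllm.attention.selector import _Backend
    from vllm.model_executor.models.utils import get_vit_attn_backend


    class DemoViTAttention(nn.Module):
        """Hypothetical ViT attention mirroring the CLIP/BLIP/SigLIP pattern."""

        def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0):
            super().__init__()
            self.num_heads = num_heads
            self.head_dim = embed_dim // num_heads
            self.scale = self.head_dim**-0.5
            self.dropout = dropout
            self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
            self.out_proj = nn.Linear(embed_dim, embed_dim)

            # support_fa=False mirrors the CLIP/BLIP/SigLIP/InternViT call
            # sites in PATCH 8/8; Molmo and Qwen2-VL pass support_fa=True so
            # the selector may also return _Backend.FLASH_ATTN for them.
            self.attn_backend = get_vit_attn_backend(support_fa=False)
            if self.attn_backend not in {_Backend.TORCH_SDPA,
                                         _Backend.XFORMERS}:
                raise RuntimeError(
                    f"Unsupported ViT attention backend: {self.attn_backend}")

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            bsz, seq_len, _ = hidden_states.shape
            q, k, v = self.qkv(hidden_states).chunk(3, dim=-1)
            # xFormers expects (batch, seq_len, num_heads, head_dim).
            q, k, v = (x.view(bsz, seq_len, self.num_heads, self.head_dim)
                       for x in (q, k, v))

            if self.attn_backend == _Backend.XFORMERS:
                from xformers import ops as xops
                out = xops.memory_efficient_attention_forward(
                    q, k, v, p=self.dropout, scale=self.scale)
            elif self.attn_backend == _Backend.TORCH_SDPA:
                # SDPA expects (batch, num_heads, seq_len, head_dim).
                q, k, v = (x.transpose(1, 2) for x in (q, k, v))
                out = F.scaled_dot_product_attention(
                    q, k, v, dropout_p=self.dropout, scale=self.scale)
                out = out.transpose(1, 2)

            out = out.reshape(bsz, seq_len, -1)
            return self.out_proj(out)

With the selector changes in PATCH 7/8 and PATCH 8/8, CPU and ROCm resolve to
_Backend.TORCH_SDPA (the added comment notes ROCm does not support xFormers),
GPUs with compute capability 80 or newer resolve to _Backend.FLASH_ATTN only
when support_fa=True and flash-attn is importable, and otherwise the selector
falls back to _Backend.XFORMERS unless a backend is forced globally or via
VLLM_ATTENTION_BACKEND.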