diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst
index c807617a2c10d..c41903f84910d 100644
--- a/docs/source/models/supported_models.rst
+++ b/docs/source/models/supported_models.rst
@@ -280,7 +280,7 @@ Multimodal Language Models
     - :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc.
     -
   * - :code:`Qwen2VLForConditionalGeneration`
-    - Qwen2-VL (see note)
+    - Qwen2-VL
     - Image\ :sup:`+` / Video\ :sup:`+`
     - :code:`Qwen/Qwen2-VL-2B-Instruct`, :code:`Qwen/Qwen2-VL-7B-Instruct`, :code:`Qwen/Qwen2-VL-72B-Instruct`, etc.
     -
@@ -297,15 +297,6 @@ Multimodal Language Models
   For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now. For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630
 
-.. note::
-  For :code:`Qwen2-VL`, the latest release of :code:`huggingface/transformers` doesn't work yet, so we need to use a developer version (:code:`21fac7abba2a37fae86106f87fcf9974fd1e3830`) for now.
-  This can be installed by running the following command:
-
-  .. code-block:: bash
-
-    pip install git+https://github.com/huggingface/transformers.git@21fac7abba2a37fae86106f87fcf9974fd1e3830
-
 -----
 
 If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
 Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` and :ref:`Enabling Multimodal Inputs <enabling_multimodal_inputs>`
diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py
index 5f365bbc30670..d4853fd790098 100644
--- a/vllm/model_executor/models/granite.py
+++ b/vllm/model_executor/models/granite.py
@@ -25,6 +25,7 @@
 import torch
 from torch import nn
+from transformers import GraniteConfig
 
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig
@@ -48,7 +49,6 @@
     default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
-from vllm.transformers_utils.configs.granite import GraniteConfig
 from vllm.utils import is_hip
 
 from .interfaces import SupportsLoRA
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 889ebc6c2e1ff..f895e693b7107 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -31,12 +31,9 @@
 import torch.nn.functional as F
 from einops import rearrange, repeat
 from PIL import Image
-from transformers import Qwen2VLConfig
 from transformers.image_utils import (get_image_size,
                                       infer_channel_dimension_format,
                                       to_numpy_array)
-from transformers.models.qwen2_vl.configuration_qwen2_vl import (
-    Qwen2VLVisionConfig)
 from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
     make_batched_images, make_batched_videos, smart_resize)
@@ -66,6 +63,8 @@
 from vllm.multimodal.image import cached_get_image_processor
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors, SequenceData
+from vllm.transformers_utils.configs.qwen2vl import (Qwen2VLConfig,
+                                                     Qwen2VLVisionConfig)
 from vllm.transformers_utils.processor import get_processor
 from vllm.utils import is_cpu
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index 3871c0cb8b819..0f20e8d0c8213 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -20,10 +20,10 @@
 # yapf: disable
 from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
                                              EAGLEConfig, ExaoneConfig,
-                                             GraniteConfig, InternVLChatConfig,
-                                             JAISConfig, MedusaConfig,
-                                             MllamaConfig, MLPSpeculatorConfig,
-                                             MPTConfig, NemotronConfig,
+                                             InternVLChatConfig, JAISConfig,
+                                             MedusaConfig, MllamaConfig,
+                                             MLPSpeculatorConfig, MPTConfig,
+                                             NemotronConfig, Qwen2VLConfig,
                                              RWConfig, SolarConfig,
                                              UltravoxConfig)
 # yapf: enable
@@ -57,9 +57,7 @@
     "nemotron": NemotronConfig,
     "solar": SolarConfig,
     "ultravox": UltravoxConfig,
-    # Granite can be removed from here once we have upgraded to
-    # transformers 4.45+
-    "granite": GraniteConfig,
+    "qwen2_vl": Qwen2VLConfig,
     **_CONFIG_REGISTRY_OVERRIDE_HF
 }
diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py
index d5b13adb58a0b..462cd964325d2 100644
--- a/vllm/transformers_utils/configs/__init__.py
+++ b/vllm/transformers_utils/configs/__init__.py
@@ -6,7 +6,6 @@
 # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
 # `FalconConfig` class from the official HuggingFace transformers library.
 from vllm.transformers_utils.configs.falcon import RWConfig
-from vllm.transformers_utils.configs.granite import GraniteConfig
 from vllm.transformers_utils.configs.internvl import InternVLChatConfig
 from vllm.transformers_utils.configs.jais import JAISConfig
 from vllm.transformers_utils.configs.medusa import MedusaConfig
@@ -14,6 +13,8 @@
 from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
 from vllm.transformers_utils.configs.mpt import MPTConfig
 from vllm.transformers_utils.configs.nemotron import NemotronConfig
+from vllm.transformers_utils.configs.qwen2vl import (Qwen2VLConfig,
+                                                     Qwen2VLVisionConfig)
 from vllm.transformers_utils.configs.solar import SolarConfig
 from vllm.transformers_utils.configs.ultravox import UltravoxConfig
 
@@ -32,7 +33,6 @@
     "NemotronConfig",
     "SolarConfig",
     "UltravoxConfig",
-    # Granite can be removed from here once we have upgraded to
-    # transformers 4.45+
-    "GraniteConfig",
+    "Qwen2VLConfig",
+    "Qwen2VLVisionConfig",
 ]
diff --git a/vllm/transformers_utils/configs/granite.py b/vllm/transformers_utils/configs/granite.py
deleted file mode 100644
index c12838be5d385..0000000000000
--- a/vllm/transformers_utils/configs/granite.py
+++ /dev/null
@@ -1,199 +0,0 @@
-# coding=utf-8
-# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
-#
-# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
-# and OPT implementations in this library. It has been modified from its
-# original forms to accommodate minor architectural differences compared
-# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Granite model configuration"""
-
-from transformers.configuration_utils import PretrainedConfig
-from transformers.modeling_rope_utils import rope_config_validation
-from transformers.utils import logging
-
-logger = logging.get_logger(__name__)
-
-
-class GraniteConfig(PretrainedConfig):
-    r"""
-    This is the configuration class to store the configuration of
-    a [`GraniteModel`]. It is used to instantiate an Granite
-    model according to the specified arguments, defining the model architecture.
-    Instantiating a configuration with the defaults will yield a similar
-    configuration to that of the Granite-3B.
-
-    Configuration objects inherit from [`PretrainedConfig`] and can be used to
-    control the model outputs. Read the documentation from [`PretrainedConfig`]
-    for more information.
-
-
-    Args:
-        vocab_size (`int`, *optional*, defaults to 32000):
-            Vocabulary size of the Granite model. Defines the number of
-            different tokens that can be represented by the `inputs_ids`
-            passed when calling [`GraniteModel`]
-        hidden_size (`int`, *optional*, defaults to 4096):
-            Dimension of the hidden representations.
-        intermediate_size (`int`, *optional*, defaults to 11008):
-            Dimension of the MLP representations.
-        num_hidden_layers (`int`, *optional*, defaults to 32):
-            Number of hidden layers in the Transformer decoder.
-        num_attention_heads (`int`, *optional*, defaults to 32):
-            Number of attention heads for each attention layer in the
-            Transformer decoder.
-        num_key_value_heads (`int`, *optional*):
-            This is the number of key_value heads that should be used to
-            implement Grouped Query Attention. If
-            `num_key_value_heads=num_attention_heads`, the model will use Multi
-            Head Attention (MHA), if `num_key_value_heads=1` the model will use
-            Multi Query Attention (MQA) otherwise GQA is used. When converting
-            a multi-head checkpoint to a GQA checkpoint, each group key and
-            value head should be constructed by meanpooling all the original
-            heads within that group. For more details checkout
-            [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not
-            specified, will default to `num_attention_heads`.
-        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
-            The non-linear activation function (function or string) in the
-            decoder.
-        max_position_embeddings (`int`, *optional*, defaults to 2048):
-            The maximum sequence length that this model might ever be used with.
-        initializer_range (`float`, *optional*, defaults to 0.02):
-            The standard deviation of the truncated_normal_initializer for
-            initializing all weight matrices.
-        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
-            The epsilon used by the rms normalization layers.
-        use_cache (`bool`, *optional*, defaults to `True`):
-            Whether or not the model should return the last key/values
-            attentions (not used by all models). Only relevant if
-            `config.is_decoder=True`.
-        pad_token_id (`int`, *optional*):
-            Padding token id.
-        bos_token_id (`int`, *optional*, defaults to 1):
-            Beginning of stream token id.
-        eos_token_id (`int`, *optional*, defaults to 2):
-            End of stream token id.
-        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
-            Whether to tie weight embeddings
-        rope_theta (`float`, *optional*, defaults to 10000.0):
-            The base period of the RoPE embeddings.
-        rope_scaling (`Dict`, *optional*):
-            Dictionary containing the scaling configuration for the RoPE
-            embeddings. Currently supports two scaling strategies: linear and
-            dynamic. Their scaling factor must be a float greater than 1. The
-            expected format is
-            `{"type": strategy name, "factor": scaling factor}`.
-            When using this flag, don't update `max_position_embeddings` to
-            the expected new maximum. See the following thread for more
-            information on how these scaling strategies behave:
-            https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/.
-            This is an experimental feature, subject to breaking API changes
-            in future versions.
-        attention_bias (`bool`, *optional*, defaults to `False`):
-            Whether to use a bias in the query, key, value and output
-            projection layers during self-attention.
-        attention_dropout (`float`, *optional*, defaults to 0.0):
-            The dropout ratio for the attention probabilities.
-        mlp_bias (`bool`, *optional*, defaults to `False`):
-            Whether to use a bias in up_proj, down_proj and gate_proj layers
-            in the MLP layers.
-        embedding_multiplier (`float`, *optional*, defaults to 1.0):
-            embedding multiplier
-        logits_scaling (`float`, *optional*, defaults to 1.0):
-            divisor for output logits
-        residual_multiplier (`float`, *optional*, defaults to 1.0):
-            residual multiplier
-        attention_multiplier (`float`, *optional*, defaults to 1.0):
-            attention multiplier
-
-    ```python
-    >>> from transformers import GraniteModel, GraniteConfig
-
-    >>> # Initializing a Granite granite-3b style configuration
-    >>> configuration = GraniteConfig()
-
-    >>> # Initializing a model from the granite-7b style configuration
-    >>> model = GraniteModel(configuration)
-
-    >>> # Accessing the model configuration
-    >>> configuration = model.config
-    ```"""
-
-    model_type = "granite"
-    keys_to_ignore_at_inference = ["past_key_values"]
-
-    def __init__(
-        self,
-        vocab_size=32000,
-        hidden_size=4096,
-        intermediate_size=11008,
-        num_hidden_layers=32,
-        num_attention_heads=32,
-        num_key_value_heads=None,
-        hidden_act="silu",
-        max_position_embeddings=2048,
-        initializer_range=0.02,
-        rms_norm_eps=1e-6,
-        use_cache=True,
-        pad_token_id=None,
-        bos_token_id=1,
-        eos_token_id=2,
-        tie_word_embeddings=False,
-        rope_theta=10000.0,
-        rope_scaling=None,
-        attention_bias=False,
-        attention_dropout=0.0,
-        mlp_bias=False,
-        embedding_multiplier=1.0,
-        logits_scaling=1.0,
-        residual_multiplier=1.0,
-        attention_multiplier=1.0,
-        **kwargs,
-    ):
-        self.vocab_size = vocab_size
-        self.max_position_embeddings = max_position_embeddings
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-
-        # for backward compatibility
-        if num_key_value_heads is None:
-            num_key_value_heads = num_attention_heads
-
-        self.num_key_value_heads = num_key_value_heads
-        self.hidden_act = hidden_act
-        self.initializer_range = initializer_range
-        self.rms_norm_eps = rms_norm_eps
-        self.use_cache = use_cache
-        self.rope_theta = rope_theta
-        self.rope_scaling = rope_scaling
-        self.attention_bias = attention_bias
-        self.attention_dropout = attention_dropout
-        self.mlp_bias = mlp_bias
-
-        self.embedding_multiplier = embedding_multiplier
-        self.logits_scaling = logits_scaling
-        self.residual_multiplier = residual_multiplier
-        self.attention_multiplier = attention_multiplier
-
-        super().__init__(
-            pad_token_id=pad_token_id,
-            bos_token_id=bos_token_id,
-            eos_token_id=eos_token_id,
-            tie_word_embeddings=tie_word_embeddings,
-            **kwargs,
-        )
-
-        rope_config_validation(self)
diff --git a/vllm/transformers_utils/configs/qwen2vl.py b/vllm/transformers_utils/configs/qwen2vl.py
new file mode 100644
index 0000000000000..92dd962790bc8
--- /dev/null
+++ b/vllm/transformers_utils/configs/qwen2vl.py
@@ -0,0 +1,131 @@
+# coding=utf-8
+# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Qwen2VL model configuration"""
+
+import os
+from typing import Union
+
+from transformers import PretrainedConfig
+
+
+class Qwen2VLVisionConfig(PretrainedConfig):
+    model_type = "qwen2_vl"
+
+    def __init__(
+        self,
+        depth=32,
+        embed_dim=1280,
+        hidden_size=3584,
+        hidden_act="quick_gelu",
+        mlp_ratio=4,
+        num_heads=16,
+        in_channels=3,
+        patch_size=14,
+        spatial_merge_size=2,
+        temporal_patch_size=2,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.depth = depth
+        self.embed_dim = embed_dim
+        self.hidden_size = hidden_size
+        self.hidden_act = hidden_act
+        self.mlp_ratio = mlp_ratio
+        self.num_heads = num_heads
+        self.in_channels = in_channels
+        self.patch_size = patch_size
+        self.spatial_merge_size = spatial_merge_size
+        self.temporal_patch_size = temporal_patch_size
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path: Union[str,
+                                                                  os.PathLike],
+                        **kwargs) -> "PretrainedConfig":
+        cls._set_token_in_kwargs(kwargs)
+
+        config_dict, kwargs = cls.get_config_dict(
+            pretrained_model_name_or_path, **kwargs)
+
+        if config_dict.get("model_type") == "qwen2_vl":
+            config_dict = config_dict["vision_config"]
+
+        return cls.from_dict(config_dict, **kwargs)
+
+
+class Qwen2VLConfig(PretrainedConfig):
+
+    def __init__(
+        self,
+        vocab_size=152064,
+        hidden_size=8192,
+        intermediate_size=29568,
+        num_hidden_layers=80,
+        num_attention_heads=64,
+        num_key_value_heads=8,
+        hidden_act="silu",
+        max_position_embeddings=32768,
+        initializer_range=0.02,
+        rms_norm_eps=1e-05,
+        use_cache=True,
+        tie_word_embeddings=False,
+        rope_theta=1000000.0,
+        use_sliding_window=False,
+        sliding_window=4096,
+        max_window_layers=80,
+        attention_dropout=0.0,
+        vision_config=None,
+        rope_scaling=None,
+        **kwargs,
+    ):
+        if isinstance(vision_config, dict):
+            self.vision_config = Qwen2VLVisionConfig(**vision_config)
+        elif vision_config is None:
+            self.vision_config = Qwen2VLVisionConfig()
+
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.use_sliding_window = use_sliding_window
+        self.sliding_window = sliding_window
+        self.max_window_layers = max_window_layers
+
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.attention_dropout = attention_dropout
+        self.rope_scaling = rope_scaling
+
+        # NOTE: the following section from original transformers config
+        # for Qwen2-VL is commented out to address rope config loading issue
+        #
+        # if self.rope_scaling is not None and "type" in self.rope_scaling:
+        #     if self.rope_scaling["type"] == "mrope":
+        #         self.rope_scaling["type"] = "default"
+        #     self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+        # rope_config_validation(self)
+
+        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
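
A quick way to sanity-check the vendored config added in vllm/transformers_utils/configs/qwen2vl.py is the minimal Python sketch below. It is an illustration, not part of the patch: it assumes vLLM is installed at this revision (so the new exports in vllm.transformers_utils.configs exist), and the rope_scaling values are made up for the example rather than taken from a real checkpoint. It shows that a dict vision_config is promoted to a Qwen2VLVisionConfig and that an "mrope" rope_scaling type is preserved, since the upstream rewrite to "default" is commented out above.

# Illustrative check of the vendored Qwen2VLConfig (not part of the patch).
# Assumes vLLM is installed at this revision; the rope_scaling values below
# are made up for the example rather than copied from a real checkpoint.
from vllm.transformers_utils.configs import Qwen2VLConfig, Qwen2VLVisionConfig

cfg = Qwen2VLConfig(
    vision_config={"depth": 32, "embed_dim": 1280},
    rope_scaling={"type": "mrope", "mrope_section": [16, 24, 24]},
)

# A dict vision_config is wrapped into a Qwen2VLVisionConfig instance.
assert isinstance(cfg.vision_config, Qwen2VLVisionConfig)

# Because the upstream mrope-to-"default" rewrite is commented out,
# the "mrope" type survives for vLLM's own RoPE handling to consume.
assert cfg.rope_scaling["type"] == "mrope"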