diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index bff557d7fc92f..d2b140e718501 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -52,7 +52,6 @@ steps:
- tests/worker
- tests/standalone_tests/lazy_torch_compile.py
commands:
- - pip install git+https://github.com/Isotr0py/DeepSeek-VL2.git # Used by multimoda processing test
- python3 standalone_tests/lazy_torch_compile.py
- pytest -v -s mq_llm_engine # MQLLMEngine
- pytest -v -s async_engine # AsyncLLMEngine
diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index d07cde3db5c6e..2edb610ddf959 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -767,16 +767,10 @@ See [this page](#generative-models) for more information on how to use generativ
E Pre-computed embeddings can be inputted for this modality.
+ Multiple items can be inputted per text prompt for this modality.
-````{note}
-To use `DeepSeek-VL2` series models, you need to install a fork version `deepseek_vl2` package:
-
-```shell
-pip install git+https://github.com/Isotr0py/DeepSeek-VL2.git
+```{note}
+To use `DeepSeek-VL2` series models, you have to pass `--hf_overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` when running vLLM.
```
-Besides, to run `DeepSeek-VL2` series models, you have to pass `--hf_overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` when running vLLM.
-````
-
```{note}
To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
```
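For context, a minimal sketch of the override described in the note above, using the offline `LLM` entry point (assuming the `hf_overrides` engine argument accepts the same dictionary as the `--hf_overrides` CLI flag):

```python
# Hedged sketch: route a DeepSeek-VL2 checkpoint to DeepseekVLV2ForCausalLM by
# overriding the architectures reported by the HF config.
from vllm import LLM

llm = LLM(
    model="deepseek-ai/deepseek-vl2-tiny",
    hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]},
)
```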
diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py
index cf3c5dd4e0a2c..43c44fa867e0a 100644
--- a/examples/offline_inference/vision_language_multi_image.py
+++ b/examples/offline_inference/vision_language_multi_image.py
@@ -393,7 +393,7 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
model_example_map = {
"aria": load_aria,
- "deepseek_vl2": load_deepseek_vl2,
+ "deepseek_vl_v2": load_deepseek_vl2,
"h2ovl_chat": load_h2onvl,
"idefics3": load_idefics3,
"internvl_chat": load_internvl,
diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py
index 5710303548c34..ca572cc39e538 100644
--- a/tests/models/decoder_only/vision_language/test_models.py
+++ b/tests/models/decoder_only/vision_language/test_models.py
@@ -190,7 +190,7 @@
dtype="bfloat16",
),
"deepseek_vl_v2": VLMTestInfo(
- models=["deepseek-ai/deepseek-vl2-tiny"],
+ models=["Isotr0py/deepseek-vl2-tiny"], # model repo using dynamic module
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|User|>: {img_prompt}\n\n<|Assistant|>: ", # noqa: E501
max_model_len=4096,
diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py
index 0a38779e0e4f0..1e3e7ea50b122 100644
--- a/tests/models/multimodal/processing/test_common.py
+++ b/tests/models/multimodal/processing/test_common.py
@@ -22,6 +22,8 @@ def _test_processing_correctness(
):
if model_id == "TIGER-Lab/Mantis-8B-siglip-llama3":
hf_overrides = {"architectures": ["MantisForConditionalGeneration"]}
+ elif model_id == "deepseek-ai/deepseek-vl2-tiny":
+ hf_overrides = {"architectures": ["DeepseekVLV2ForCausalLM"]}
else:
hf_overrides = {}
@@ -139,6 +141,7 @@ def _test_processing_correctness(
("rhymes-ai/Aria", {"image": True}),
("Salesforce/blip2-opt-2.7b", {"image": False}),
("facebook/chameleon-7b", {"image": False}),
+ ("deepseek-ai/deepseek-vl2-tiny", {"image": True}),
("adept/fuyu-8b", {"image": False}),
("llava-hf/llava-1.5-7b-hf", {"image": True}),
("llava-hf/llava-v1.6-mistral-7b-hf", {"image": True}),
diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py
index 4553695022169..4d3d1c329a2c0 100644
--- a/vllm/model_executor/models/deepseek_vl2.py
+++ b/vllm/model_executor/models/deepseek_vl2.py
@@ -1,7 +1,7 @@
# adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/faf18023f24b962b32d9f0a2d89e402a8d383a78/deepseek_vl2/models/modeling_deepseek_vl_v2.py
"""Inference-only Deepseek-VL2 model compatible with HuggingFace weights."""
import math
-from functools import cached_property, partial
+from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)
@@ -9,7 +9,7 @@
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
-from transformers import AutoProcessor, BatchFeature, ProcessorMixin
+from transformers import BatchFeature
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
@@ -31,6 +31,8 @@
from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config,
MlpProjectorConfig,
VisionEncoderConfig)
+from vllm.transformers_utils.processors.deepseek_vl2 import (
+ DeepseekVLV2Processor)
from vllm.utils import is_list_of
from .interfaces import SupportsMultiModal, SupportsPP
@@ -129,25 +131,8 @@ class DeepseekVL2ProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(DeepseekVLV2Config)
- def get_hf_processor(self) -> ProcessorMixin:
- # TODO(Isotr0py): we should get rid of dependency on deepseek_vl2
- # in the future, because it's flasky and lack of maintenance.
- try:
- from deepseek_vl2.models.processing_deepseek_vl_v2 import (
- DeepseekVLV2Processor, select_best_resolution)
- AutoProcessor.register("DeepseekVLV2Processor",
- DeepseekVLV2Processor)
- except ModuleNotFoundError as exc:
- raise ModuleNotFoundError(
- "You need to `pip install "
- "git+https://github.com/deepseek-ai/DeepSeek-VL2.git` "
- "to use this model") from exc
-
- processor = self.ctx.get_hf_processor(DeepseekVLV2Processor)
- processor.select_best_resolution = partial(
- select_best_resolution,
- candidate_resolutions=processor.candidate_resolutions)
- return processor
+ def get_hf_processor(self) -> DeepseekVLV2Processor:
+ return self.ctx.get_hf_processor(DeepseekVLV2Processor)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
@@ -224,31 +209,21 @@ def _call_hf_processor(
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
if mm_data:
- outputs = self.info.ctx.call_hf_processor(
+ processed_outputs = self.info.ctx.call_hf_processor(
self.info.get_hf_processor(**mm_kwargs),
dict(prompt=prompt, **mm_data),
mm_kwargs,
)
-
- # Deepseek-vl2 processor don't return BatchFeature,
- # we need to manually create it
- processed_outputs = dict(input_ids=outputs["input_ids"])
- processed_outputs = BatchFeature(data=dict(processed_outputs),
- tensor_type="pt")
-
- # Remove batch dimension from processor outputs,
- # because we will try batch to create NestedTensors
target_dtype = self.info.ctx.model_config.dtype
- pixel_values = outputs["images"].to(target_dtype).squeeze(0)
- images_spatial_crop = outputs["images_spatial_crop"].squeeze(0)
+ pixel_values = processed_outputs.pop("pixel_values").to(
+ target_dtype)
+ # split pixel values into patches corresponding to each image
+ images_spatial_crop = processed_outputs["images_spatial_crop"]
patches_per_image = [
x.prod().item() + 1 for x in images_spatial_crop
]
-
- # Rename `images` -> `pixel_values` to avoid confusion
- processed_outputs["pixel_values"] = list(
- pixel_values.split(patches_per_image))
- processed_outputs["images_spatial_crop"] = images_spatial_crop
+ pixel_values = pixel_values.split(patches_per_image)
+ processed_outputs["pixel_values"] = pixel_values
else:
tokenizer = self.info.get_tokenizer()
processed_outputs = tokenizer(prompt,
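To illustrate the patch bookkeeping in the rewritten `_call_hf_processor` in isolation, here is a small sketch with made-up crop grids (tensor shapes are illustrative only):

```python
# Standalone sketch of the split above: each image contributes one global
# view plus (width_tiles * height_tiles) local tiles, so
# patches_per_image[i] = prod(images_spatial_crop[i]) + 1.
import torch

images_spatial_crop = torch.tensor([[2, 2], [1, 1]])  # two images
pixel_values = torch.randn(5 + 2, 3, 384, 384)        # 4+1 and 1+1 patches

patches_per_image = [x.prod().item() + 1 for x in images_spatial_crop]
per_image = pixel_values.split(patches_per_image)
assert [p.shape[0] for p in per_image] == [5, 2]
```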
diff --git a/vllm/transformers_utils/processors/__init__.py b/vllm/transformers_utils/processors/__init__.py
new file mode 100644
index 0000000000000..9c71b8cada32e
--- /dev/null
+++ b/vllm/transformers_utils/processors/__init__.py
@@ -0,0 +1,4 @@
+from vllm.transformers_utils.processors.deepseek_vl2 import (
+ DeepseekVLV2Processor)
+
+__all__ = ["DeepseekVLV2Processor"]
diff --git a/vllm/transformers_utils/processors/deepseek_vl2.py b/vllm/transformers_utils/processors/deepseek_vl2.py
new file mode 100644
index 0000000000000..27cdf6bc22d0e
--- /dev/null
+++ b/vllm/transformers_utils/processors/deepseek_vl2.py
@@ -0,0 +1,361 @@
+# yapf: disable
+# ruff: noqa: E501
+# coding=utf-8
+# adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/ff23960c5cf9e6874b44be38af930cfb0ccbb620/deepseek_vl2/models/processing_deepseek_vl_v2.py
+# Copyright (c) 2023-2024 DeepSeek.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of
+# this software and associated documentation files (the "Software"), to deal in
+# the Software without restriction, including without limitation the rights to
+# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+# the Software, and to permit persons to whom the Software is furnished to do so,
+# subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+import math
+from typing import List, Tuple
+
+import torch
+import torchvision.transforms as T
+from PIL import Image, ImageOps
+from transformers import AutoProcessor, BatchFeature, LlamaTokenizerFast
+from transformers.processing_utils import ProcessorMixin
+
+
+class ImageTransform:
+
+ def __init__(self,
+ mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
+ std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
+ normalize: bool = True):
+ self.mean = mean
+ self.std = std
+ self.normalize = normalize
+
+ transform_pipelines = [T.ToTensor()]
+
+ if normalize:
+ transform_pipelines.append(T.Normalize(mean, std))
+
+ self.transform = T.Compose(transform_pipelines)
+
+ def __call__(self, pil_img: Image.Image):
+ x = self.transform(pil_img)
+ return x
+
+
+class DeepseekVLV2Processor(ProcessorMixin):
+ tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
+ attributes = ["tokenizer"]
+
+ def __init__(
+ self,
+ tokenizer: LlamaTokenizerFast,
+ candidate_resolutions: Tuple[Tuple[int, int]],
+ patch_size: int,
+ downsample_ratio: int,
+ image_mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
+ image_std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
+ normalize: bool = True,
+        image_token: str = "<image>",
+ pad_token: str = "<|▁pad▁|>",
+ add_special_token: bool = False,
+ sft_format: str = "deepseek",
+ mask_prompt: bool = True,
+ ignore_id: int = -100,
+ **kwargs,
+ ):
+
+ self.candidate_resolutions = candidate_resolutions
+ self.image_size = candidate_resolutions[0][0]
+ self.patch_size = patch_size
+ self.image_mean = image_mean
+ self.image_std = image_std
+ self.normalize = normalize
+ self.downsample_ratio = downsample_ratio
+
+ self.image_transform = ImageTransform(mean=image_mean, std=image_std, normalize=normalize)
+ self.tokenizer = tokenizer
+        self.tokenizer.padding_side = 'left'  # must set this; padding side makes a difference in batch inference
+
+ # add the pad_token as special token to use 'tokenizer.pad_token' and 'tokenizer.pad_token_id'
+ if tokenizer.pad_token is None:
+ self.tokenizer.add_special_tokens({'pad_token': pad_token})
+
+ # add image token
+ image_token_id = self.tokenizer.vocab.get(image_token)
+ if image_token_id is None:
+ special_tokens = [image_token]
+ special_tokens_dict = {"additional_special_tokens": special_tokens}
+ self.tokenizer.add_special_tokens(special_tokens_dict)
+ self.image_token_id = self.tokenizer.vocab.get(image_token)
+
+ # add five special tokens for grounding-related tasks
+ # <|ref|>, <|/ref|>, <|det|>, <|/det|>, <|grounding|>
+ special_tokens = ['<|ref|>', '<|/ref|>', '<|det|>', '<|/det|>', '<|grounding|>']
+ special_tokens_dict = {"additional_special_tokens": special_tokens}
+ self.tokenizer.add_special_tokens(special_tokens_dict)
+
+ # add special tokens for SFT data
+ special_tokens = ["<|User|>", "<|Assistant|>"]
+ special_tokens_dict = {"additional_special_tokens": special_tokens}
+ self.tokenizer.add_special_tokens(special_tokens_dict)
+
+ self.image_token = image_token
+ self.pad_token = pad_token
+ self.add_special_token = add_special_token
+ self.sft_format = sft_format
+ self.mask_prompt = mask_prompt
+ self.ignore_id = ignore_id
+
+ super().__init__(
+ tokenizer,
+ **kwargs,
+ )
+
+ def select_best_resolution(self, image_size):
+ # used for cropping
+ original_width, original_height = image_size
+ best_fit = None
+ max_effective_resolution = 0
+ min_wasted_resolution = float("inf")
+
+ for width, height in self.candidate_resolutions:
+ scale = min(width / original_width, height / original_height)
+ downscaled_width, downscaled_height = int(
+ original_width * scale), int(original_height * scale)
+ effective_resolution = min(downscaled_width * downscaled_height,
+ original_width * original_height)
+ wasted_resolution = (width * height) - effective_resolution
+
+ if effective_resolution > max_effective_resolution or (
+ effective_resolution == max_effective_resolution
+ and wasted_resolution < min_wasted_resolution):
+ max_effective_resolution = effective_resolution
+ min_wasted_resolution = wasted_resolution
+ best_fit = (width, height)
+
+ return best_fit
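As a quick check of the selection rule above, here is a standalone restatement with made-up candidates: an 800x600 image should pick the 768x384 grid, since it preserves the largest downscaled area.

```python
# Hedged restatement of select_best_resolution for a quick numeric check;
# the candidate list and image size below are illustrative.
def pick(image_size, candidates):
    ow, oh = image_size
    best, best_eff, min_waste = None, 0, float("inf")
    for w, h in candidates:
        scale = min(w / ow, h / oh)
        eff = min(int(ow * scale) * int(oh * scale), ow * oh)
        waste = w * h - eff
        if eff > best_eff or (eff == best_eff and waste < min_waste):
            best, best_eff, min_waste = (w, h), eff, waste
    return best

assert pick((800, 600), [(384, 384), (768, 384), (384, 768)]) == (768, 384)
```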
+
+ @property
+ def bos_id(self):
+ return self.tokenizer.bos_token_id
+
+ @property
+ def eos_id(self):
+ return self.tokenizer.eos_token_id
+
+ @property
+ def pad_id(self):
+ return self.tokenizer.pad_token_id
+
+ def encode(self, text: str, bos: bool = True, eos: bool = False):
+ t = self.tokenizer.encode(text, add_special_tokens=False)
+
+ if bos:
+ t = [self.bos_id] + t
+ if eos:
+ t = t + [self.eos_id]
+
+ return t
+
+ def decode(self, t: List[int], **kwargs) -> str:
+ return self.tokenizer.decode(t, **kwargs)
+
+ def process_one(
+ self,
+ prompt: str,
+ images: List[Image.Image],
+ inference_mode: bool = True,
+ **kwargs,
+ ):
+ """
+
+ Args:
+ prompt (str): the formatted prompt;
+ images (List[ImageType]): the list of images;
+ inference_mode (bool): if True, then remove the last eos token;
+ **kwargs:
+
+ Returns:
+ outputs (BaseProcessorOutput): the output of the processor,
+ - input_ids (torch.LongTensor): [N + image tokens]
+ - target_ids (torch.LongTensor): [N + image tokens]
+ - pixel_values (torch.FloatTensor): [n_patches, 3, H, W]
+ - image_id (int): the id of the image token
+ - num_image_tokens (List[int]): the number of image tokens
+ """
+
+ assert (prompt is not None and images is not None
+ ), "prompt and images must be used at the same time."
+
+ sft_format = prompt
+ tokenized_str, images_list, images_seq_mask, images_spatial_crop, num_image_tokens = self.tokenize_with_images(
+ sft_format, images, bos=True, eos=True, cropping=len(images) <= 2)
+ masked_tokenized_str = []
+ for token_index in tokenized_str:
+ if token_index != self.image_token_id:
+ masked_tokenized_str.append(token_index)
+ else:
+ masked_tokenized_str.append(self.ignore_id)
+
+ assert len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str), \
+ (f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, "
+ f"imags_seq_mask's length {len(images_seq_mask)}, are not equal")
+
+ input_ids = torch.LongTensor(tokenized_str)
+ target_ids = torch.LongTensor(masked_tokenized_str)
+ images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
+
+ # set input_ids < 0 | input_ids == self.image_token_id as ignore_id
+ target_ids[(input_ids < 0) |
+ (input_ids == self.image_token_id)] = self.ignore_id
+ input_ids[input_ids < 0] = self.pad_id
+
+ if inference_mode:
+            # remove the trailing eos token
+ assert input_ids[-1] == self.eos_id
+ input_ids = input_ids[:-1]
+ target_ids = target_ids[:-1]
+ images_seq_mask = images_seq_mask[:-1]
+
+ if len(images_list) == 0:
+ pixel_values = torch.zeros((1, 3, self.image_size, self.image_size))
+ images_spatial_crop = torch.zeros((1, 2), dtype=torch.long)
+ else:
+ pixel_values = torch.stack(images_list, dim=0)
+ images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
+
+ input_ids = input_ids.unsqueeze(0)
+
+ prepare = BatchFeature(
+ data=dict(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ images_seq_mask=images_seq_mask,
+ images_spatial_crop=images_spatial_crop,
+ num_image_tokens=num_image_tokens,
+ ),
+ tensor_type="pt",
+ )
+ return prepare
+
+ def __call__(
+ self,
+ *,
+ prompt: str,
+ images: List[Image.Image],
+ inference_mode: bool = True,
+ **kwargs,
+ ):
+ """
+
+ Args:
+ prompt (str): the formatted prompt;
+ images (List[ImageType]): the list of images;
+ inference_mode (bool): if True, then remove the last eos token;
+ **kwargs:
+
+ Returns:
+ outputs (BaseProcessorOutput): the output of the processor,
+ - input_ids (torch.LongTensor): [N + image tokens]
+            - pixel_values (torch.FloatTensor): [n_patches, 3, H, W]
+ - image_id (int): the id of the image token
+ - num_image_tokens (List[int]): the number of image tokens
+ """
+
+ prepare = self.process_one(
+ prompt=prompt,
+ images=images,
+ inference_mode=inference_mode,
+ )
+
+ return prepare
+
+ def tokenize_with_images(
+ self,
+ conversation: str,
+ images: List[Image.Image],
+ bos: bool = True,
+ eos: bool = True,
+ cropping: bool = True,
+ ):
+ """Tokenize text with tags."""
+ assert conversation.count(self.image_token) == len(images)
+ text_splits = conversation.split(self.image_token)
+ images_list, images_seq_mask, images_spatial_crop = [], [], []
+ num_image_tokens = []
+ tokenized_str = []
+ for text_sep, image in zip(text_splits, images):
+ """encode text_sep"""
+ tokenized_sep = self.encode(text_sep, bos=False, eos=False)
+ tokenized_str += tokenized_sep
+ images_seq_mask += [False] * len(tokenized_sep)
+
+ """select best resolution for anyres"""
+ if cropping:
+ best_width, best_height = self.select_best_resolution(image.size)
+ else:
+ best_width, best_height = self.image_size, self.image_size
+
+ """process the global view"""
+ global_view = ImageOps.pad(image, (self.image_size, self.image_size),
+ color=tuple(int(x * 255) for x in self.image_transform.mean))
+ images_list.append(self.image_transform(global_view))
+
+ """process the local views"""
+ local_view = ImageOps.pad(image, (best_width, best_height),
+ color=tuple(int(x * 255) for x in self.image_transform.mean))
+ for i in range(0, best_height, self.image_size):
+ for j in range(0, best_width, self.image_size):
+ images_list.append(
+ self.image_transform(local_view.crop((j, i, j + self.image_size, i + self.image_size))))
+
+ """record height / width crop num"""
+ num_width_tiles, num_height_tiles = best_width // self.image_size, best_height // self.image_size
+ images_spatial_crop.append([num_width_tiles, num_height_tiles])
+
+ """add image tokens"""
+ h = w = math.ceil((self.image_size // self.patch_size) / self.downsample_ratio)
+ # global views tokens h * (w + 1), 1 is for line separator
+ tokenized_image = [self.image_token_id] * h * (w + 1)
+ # add a separator between global and local views
+ tokenized_image += [self.image_token_id]
+ # local views tokens, (num_height_tiles * h) * (num_width_tiles * w + 1)
+ tokenized_image += [self.image_token_id] * (num_height_tiles * h) * (num_width_tiles * w + 1)
+
+ tokenized_str += tokenized_image
+ images_seq_mask += [True] * len(tokenized_image)
+ num_image_tokens.append(len(tokenized_image))
+
+ """process the last text split"""
+ tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False)
+ tokenized_str += tokenized_sep
+ images_seq_mask += [False] * len(tokenized_sep)
+
+ """add the bos and eos tokens"""
+ if bos:
+ tokenized_str = [self.bos_id] + tokenized_str
+ images_seq_mask = [False] + images_seq_mask
+ if eos:
+ tokenized_str = tokenized_str + [self.eos_id]
+ images_seq_mask = images_seq_mask + [False]
+
+ assert len(tokenized_str) == len(
+ images_seq_mask), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"
+
+ return tokenized_str, images_list, images_seq_mask, images_spatial_crop, num_image_tokens
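A quick arithmetic check of the token accounting above (image size, patch size, and tile grid are illustrative):

```python
# Illustrative count for image_size=384, patch_size=14, downsample_ratio=2
# and a 2x1 tile grid (num_width_tiles=2, num_height_tiles=1).
import math

h = w = math.ceil((384 // 14) / 2)       # 14
global_tokens = h * (w + 1)              # 210 = 14 rows * (14 + 1 separator)
separator = 1                            # between global and local views
local_tokens = (1 * h) * (2 * w + 1)     # 406 = 14 * 29
assert global_tokens + separator + local_tokens == 617
```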
+
+
+AutoProcessor.register("DeepseekVLV2Processor", DeepseekVLV2Processor)
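Finally, a hedged end-to-end usage sketch of the new in-tree processor. The checkpoint name and the `patch_size` / `downsample_ratio` values are assumptions for illustration; in practice they come from the model's processor configuration.

```python
# Usage sketch only: build the processor directly from a tokenizer and run it
# on a prompt containing one <image> placeholder. Parameter values are
# illustrative, not taken from this diff.
from PIL import Image
from transformers import AutoTokenizer

from vllm.transformers_utils.processors import DeepseekVLV2Processor

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/deepseek-vl2-tiny")
processor = DeepseekVLV2Processor(
    tokenizer,
    candidate_resolutions=((384, 384),),
    patch_size=14,
    downsample_ratio=2,
)

out = processor(
    prompt="<|User|>: <image>\nDescribe this image.\n\n<|Assistant|>:",
    images=[Image.new("RGB", (640, 480))],
)
print(out["input_ids"].shape, out["pixel_values"].shape)
```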