diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index cf82210f96ee3..393912881bca3 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -52,6 +52,7 @@ steps:
- tests/worker
- tests/standalone_tests/lazy_torch_compile.py
commands:
+ - pip install git+https://github.com/Isotr0py/DeepSeek-VL2.git # Used by multimoda processing test
- python3 standalone_tests/lazy_torch_compile.py
- pytest -v -s mq_llm_engine # MQLLMEngine
- pytest -v -s async_engine # AsyncLLMEngine
diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index 5c96dfdad25f7..642ef3c9655b8 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -610,6 +610,13 @@ See [this page](#generative-models) for more information on how to use generativ
-
- ✅︎
- ✅︎
+* - `DeepseekVLV2ForCausalLM`
+ - DeepSeek-VL2
+ - T + I+
+ - `deepseek-ai/deepseek-vl2-tiny`(WIP), `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2` etc. (see note)
+ -
+ - ✅︎
+ - ✅︎
* - `FuyuForCausalLM`
- Fuyu
- T + I
@@ -755,8 +762,19 @@ See [this page](#generative-models) for more information on how to use generativ
E Pre-computed embeddings can be inputted for this modality.
+ Multiple items can be inputted per text prompt for this modality.
+````{note}
+The `deepseek-ai/deepseek-vl2-tiny` is not supported yet.
+
+To use `DeepSeek-VL2` series models, you need to install a fork version `deepseek_vl2` package:
+```shell
+pip install git+https://github.com/Isotr0py/DeepSeek-VL2.git
+```
+
+Besides, to run `DeepSeek-VL2` series models, you have to pass `--hf_overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` when running vLLM.
+````
+
```{note}
-To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
+To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
```
```{note}
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index b51bfae455267..ad32b9fe242e9 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -66,6 +66,23 @@ def run_chameleon(question: str, modality: str):
return llm, prompt, stop_token_ids
+# Deepseek-VL2
+def run_deepseek_vl2(question: str, modality: str):
+ assert modality == "image"
+
+ model_name = "deepseek-ai/deepseek-vl2-small"
+
+ llm = LLM(model=model_name,
+ max_model_len=4096,
+ max_num_seqs=2,
+ disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+ hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]})
+
+ prompt = f"<|User|>: \n{question}\n\n<|Assistant|>:"
+ stop_token_ids = None
+ return llm, prompt, stop_token_ids
+
+
# Fuyu
def run_fuyu(question: str, modality: str):
assert modality == "image"
@@ -498,6 +515,7 @@ def run_qwen2_vl(question: str, modality: str):
"aria": run_aria,
"blip-2": run_blip2,
"chameleon": run_chameleon,
+ "deepseek_vl_v2": run_deepseek_vl2,
"fuyu": run_fuyu,
"glm4v": run_glm4v,
"h2ovl_chat": run_h2ovl,
diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py
index cf2e90a325c6a..c6cf3f30c31cb 100644
--- a/examples/offline_inference/vision_language_multi_image.py
+++ b/examples/offline_inference/vision_language_multi_image.py
@@ -54,6 +54,28 @@ def load_aria(question, image_urls: List[str]) -> ModelRequestData:
)
+def load_deepseek_vl2(question: str, image_urls: List[str]):
+ model_name = "deepseek-ai/deepseek-vl2-small"
+
+ llm = LLM(model=model_name,
+ max_model_len=4096,
+ max_num_seqs=2,
+ hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]},
+ limit_mm_per_prompt={"image": len(image_urls)})
+
+ placeholder = "".join(f"image_{i}:\n"
+ for i, _ in enumerate(image_urls, start=1))
+ prompt = f"<|User|>: {placeholder}{question}\n\n<|Assistant|>:"
+
+ return ModelRequestData(
+ llm=llm,
+ prompt=prompt,
+ stop_token_ids=None,
+ image_data=[fetch_image(url) for url in image_urls],
+ chat_template=None,
+ )
+
+
def load_h2onvl(question: str, image_urls: List[str]) -> ModelRequestData:
model_name = "h2oai/h2ovl-mississippi-2b"
@@ -372,6 +394,7 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
model_example_map = {
"aria": load_aria,
+ "deepseek_vl2": load_deepseek_vl2,
"h2ovl_chat": load_h2onvl,
"idefics3": load_idefics3,
"internvl_chat": load_internvl,
diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py
index 146685738a1d0..7620ed1107e8f 100644
--- a/tests/models/decoder_only/vision_language/test_models.py
+++ b/tests/models/decoder_only/vision_language/test_models.py
@@ -188,6 +188,33 @@
max_tokens=8,
dtype="bfloat16",
),
+ "deepseek_vl_v2": VLMTestInfo(
+ models=["deepseek-ai/deepseek-vl2-small"],
+ test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+ dtype="bfloat16",
+ prompt_formatter=lambda img_prompt: f"<|User|>: {img_prompt}\n\n<|Assistant|>: ", # noqa: E501
+ max_model_len=4096,
+ max_num_seqs=2,
+ single_image_prompts=IMAGE_ASSETS.prompts({
+ "stop_sign": "\nWhat's the color of the stop sign and car?",
+ "cherry_blossom": "\nWhat's the color of the tower?",
+ }),
+ multi_image_prompt="image_1:\nimage_2:\nDescribe the two images shortly.", # noqa: E501
+ vllm_runner_kwargs={"hf_overrides": {"architectures": ["DeepseekVLV2ForCausalLM"]}}, # noqa: E501
+ image_size_factors=[(0.10, 0.15)],
+ patch_hf_runner=model_utils.deepseekvl2_patch_hf_runner,
+ postprocess_inputs=model_utils.cast_dtype_post_processor("images"),
+ hf_output_post_proc=model_utils.deepseekvl2_trunc_hf_output,
+ stop_str=["<|end▁of▁sentence|>", "<|begin▁of▁sentence|>"], # noqa: E501
+ num_logprobs=5,
+ marks=[
+ pytest.mark.skipif(
+ not is_flash_attn_2_available(),
+ reason="Model needs flash-attn for numeric convergence.",
+ ),
+ large_gpu_mark(min_gb=48),
+ ],
+ ),
"fuyu": VLMTestInfo(
models=["adept/fuyu-8b"],
test_type=VLMTestType.IMAGE,
diff --git a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
index 6c7a753af787e..1ca85c7bb2056 100644
--- a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
+++ b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
@@ -183,6 +183,14 @@ def paligemma_vllm_to_hf_output(vllm_output: RunnerOutput,
####### Post-processors for HF outputs
+def deepseekvl2_trunc_hf_output(hf_output: RunnerOutput,
+ model: str) -> RunnerOutput:
+ output_ids, output_str, out_logprobs = hf_output
+ if output_str.endswith("<|end▁of▁sentence|>"):
+ output_str = output_str.split("<|end▁of▁sentence|>")[0]
+ return output_ids, output_str, out_logprobs
+
+
def minicpmv_trunc_hf_output(hf_output: RunnerOutput,
model: str) -> RunnerOutput:
output_ids, output_str, out_logprobs = hf_output
@@ -261,6 +269,34 @@ def qwen_prompt_path_encoder(
####### Model-specific HuggingFace runner patchers
+def deepseekvl2_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
+ """Patches and returns an instance of the HfRunner to use for GLM4."""
+ hf_processor = hf_model.processor
+
+ def processor(*args, text="", images=None, **kwargs):
+ if isinstance(images, Image):
+ images = [images]
+ # inputs is a custom class instead of dict or BatchFeature
+ inputs = hf_processor(
+ *args,
+ prompt=text,
+ images=images,
+ **kwargs,
+ )
+ inputs = {
+ k: inputs[k]
+ for k in inputs.keys() # noqa
+ if k not in ("seq_lens", "sft_format")
+ }
+ inputs = BatchEncoding(data=inputs, tensor_type="pt")
+ return inputs
+
+ hf_model.processor = processor
+ hf_model.model.get_output_embeddings = lambda: \
+ hf_model.model.language.model.embed_tokens
+ return hf_model
+
+
def glm_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner to use for GLM4."""
hf_processor = hf_model.processor
diff --git a/tests/models/registry.py b/tests/models/registry.py
index f5aaa8eb071f9..d079725b2f78d 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -179,6 +179,8 @@ class _HfExamplesInfo:
trust_remote_code=True),
"ChatGLMForConditionalGeneration": _HfExamplesInfo("chatglm2-6b",
is_available_online=False),
+ # TODO(Isotr0py): Use deepseek-vl2-tiny for test after it's supported
+ "DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-small"), # noqa: E501
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
"H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m"),
"InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B",
diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py
index 7a564c1f4a1d0..daece7c93c0ef 100644
--- a/tests/models/test_initialization.py
+++ b/tests/models/test_initialization.py
@@ -26,6 +26,9 @@ def test_can_initialize(model_arch):
# Avoid OOM
def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
+ if hf_config.model_type == "deepseek_vl_v2":
+ hf_config.update({"architectures": ["DeepseekVLV2ForCausalLM"]})
+
if hasattr(hf_config, "text_config"):
text_config: PretrainedConfig = hf_config.text_config
else:
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index 923c7459f6948..beedf5d16ab86 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -403,8 +403,8 @@ def _placeholder_str(self, modality: ModalityStr,
if model_type.startswith("llava"):
return self._cached_token_str(self._tokenizer,
hf_config.image_token_index)
- if model_type in ("chameleon", "internvl_chat", "NVLM_D",
- "h2ovl_chat"):
+ if model_type in ("chameleon", "deepseek_vl_v2", "internvl_chat",
+ "NVLM_D", "h2ovl_chat"):
return ""
if model_type == "mllama":
return "<|image|>"
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 4cf4e6c358bf2..9132040545863 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -243,7 +243,11 @@ def __init__(
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj")
- rope_scaling["rope_type"] = 'deepseek_yarn'
+ if rope_scaling:
+ rope_scaling["rope_type"] = 'deepseek_yarn'
+ self.use_normal_rope = False
+ else:
+ self.use_normal_rope = True
self.rotary_emb = get_rope(qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
@@ -298,7 +302,18 @@ def forward(
self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_pe = latent_cache[:, :, self.kv_lora_rank:]
+
+ if self.use_normal_rope:
+ seq_len = positions.size(0)
+ ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
+ q_pe = q_pe.reshape(seq_len, -1)
+ k_pe = k_pe.reshape(seq_len, -1)
+
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+
+ if self.use_normal_rope:
+ q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape)
+
q[..., self.qk_nope_head_dim:] = q_pe
k = torch.empty_like(q)
k[..., :self.qk_nope_head_dim] = k_nope
@@ -355,6 +370,7 @@ def __init__(
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
+
if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0):
diff --git a/vllm/model_executor/models/deepseek_v3.py b/vllm/model_executor/models/deepseek_v3.py
index d4710622681b5..ca79b14c55fea 100644
--- a/vllm/model_executor/models/deepseek_v3.py
+++ b/vllm/model_executor/models/deepseek_v3.py
@@ -251,7 +251,11 @@ def __init__(
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj")
- rope_scaling["rope_type"] = 'deepseek_yarn'
+ if rope_scaling:
+ rope_scaling["rope_type"] = 'deepseek_yarn'
+ self.use_normal_rope = False
+ else:
+ self.use_normal_rope = True
self.rotary_emb = get_rope(qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
@@ -306,7 +310,18 @@ def forward(
self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_pe = latent_cache[:, :, self.kv_lora_rank:]
+
+ if self.use_normal_rope:
+ seq_len = positions.size(0)
+ ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
+ q_pe = q_pe.reshape(seq_len, -1)
+ k_pe = k_pe.reshape(seq_len, -1)
+
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+
+ if self.use_normal_rope:
+ q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape)
+
q[..., self.qk_nope_head_dim:] = q_pe
k = torch.empty_like(q)
k[..., :self.qk_nope_head_dim] = k_nope
@@ -583,7 +598,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
continue
# TODO(simon): support nextn predict layers
- if self.config.num_nextn_predict_layers > 0:
+ if hasattr(self.config, "num_nextn_predict_layers"
+ ) and self.config.num_nextn_predict_layers > 0:
assert self.config.num_nextn_predict_layers == 1
layer_idx = self.config.num_hidden_layers
if name.startswith(f"model.layers.{layer_idx}"):
diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py
new file mode 100644
index 0000000000000..99fa941c055d2
--- /dev/null
+++ b/vllm/model_executor/models/deepseek_vl2.py
@@ -0,0 +1,662 @@
+# adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/faf18023f24b962b32d9f0a2d89e402a8d383a78/deepseek_vl2/models/modeling_deepseek_vl_v2.py
+"""Inference-only Deepseek-VL2 model compatible with HuggingFace weights."""
+import math
+from functools import cached_property, partial
+from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
+ TypedDict, Union)
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from transformers import AutoProcessor, BatchFeature, ProcessorMixin
+
+from vllm.attention import AttentionMetadata
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm.model_executor import SamplingMetadata
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
+from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
+ NestedTensors)
+from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
+ ImageSize, MultiModalDataItems)
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+ BaseProcessingInfo, PromptReplacement)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
+from vllm.multimodal.utils import cached_get_tokenizer
+from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config,
+ MlpProjectorConfig,
+ VisionEncoderConfig)
+from vllm.utils import is_list_of
+
+from .interfaces import SupportsMultiModal, SupportsPP
+from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+ init_vllm_registered_model, maybe_prefix,
+ merge_multimodal_embeddings)
+
+logger = init_logger(__name__)
+
+# The image token id may be various
+_IMAGE_TOKEN = ""
+
+
+class DeepseekVL2ImagePixelInputs(TypedDict):
+ type: Literal["pixel_values"]
+ data: Union[torch.Tensor, List[torch.Tensor]]
+ """
+ Shape: `(batch_size * num_images, num_channels, height, width)`
+ """
+ images_spatial_crop: torch.Tensor
+ """
+ Shape: `(batch_size * num_images, 2)`
+ """
+
+
+class DeepseekVL2VImageEmbeddingInputs(TypedDict):
+ type: Literal["image_embeds"]
+ data: Union[torch.Tensor, List[torch.Tensor]]
+ """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
+
+ `hidden_size` must match the hidden size of language model backbone.
+ """
+
+
+DeepseekVL2ImageInputs = Union[DeepseekVL2ImagePixelInputs,
+ DeepseekVL2VImageEmbeddingInputs]
+
+
+class MlpProjector(nn.Module):
+
+ def __init__(self, cfg: MlpProjectorConfig):
+
+ super().__init__()
+
+ self.cfg = cfg
+ assert not cfg.token_pooling, (
+ "Token pooling is not supported currently.")
+
+ if cfg.projector_type == "downsample_mlp_gelu":
+ mlp_depth = cfg.depth
+ mlp_ratio = cfg.mlp_ratio
+ modules = [
+ nn.Linear(
+ cfg.input_dim * cfg.downsample_ratio *
+ cfg.downsample_ratio, cfg.n_embed * mlp_ratio)
+ ]
+ for _ in range(1, mlp_depth - 1):
+ modules.append(nn.GELU())
+ modules.append(
+ nn.Linear(cfg.n_embed * mlp_ratio,
+ cfg.n_embed * mlp_ratio))
+ modules.append(nn.GELU())
+ modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed))
+ modules = nn.Sequential(*modules)
+
+ else:
+ raise NotImplementedError(
+ f"Unsupported projector type: {cfg.projector_type}")
+
+ self.layers = modules
+
+ def forward(self, x):
+ bs, hw, input_dim = x.shape
+ h = w = int((hw)**0.5)
+ """compute padding"""
+ if h % self.cfg.downsample_ratio:
+ pad = self.cfg.downsample_ratio - h % self.cfg.downsample_ratio
+ else:
+ pad = 0
+ x = x.reshape(bs, h, w, input_dim)
+ if pad > 0:
+ x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)
+ """4 to 1 concat"""
+ x = x.permute(0, 3, 1, 2) # B, C, H, W
+ x = F.unfold(x,
+ kernel_size=self.cfg.downsample_ratio,
+ stride=self.cfg.downsample_ratio,
+ padding=0) # B, C*4, HW // 4
+ x = x.permute(0, 2, 1)
+
+ return self.layers(x)
+
+
+class DeepseekVL2ProcessingInfo(BaseProcessingInfo):
+
+ def get_hf_config(self):
+ return self.ctx.get_hf_config(DeepseekVLV2Config)
+
+ def get_hf_processor(self) -> ProcessorMixin:
+ # TODO(Isotr0py): we should get rid of dependency on deepseek_vl2
+ # in the future, because it's flasky and lack of maintenance.
+ try:
+ from deepseek_vl2.models.processing_deepseek_vl_v2 import (
+ DeepseekVLV2Processor, select_best_resolution)
+ AutoProcessor.register("DeepseekVLV2Processor",
+ DeepseekVLV2Processor)
+ except ModuleNotFoundError as exc:
+ raise ModuleNotFoundError(
+ "You need to `pip install "
+ "git+https://github.com/deepseek-ai/DeepSeek-VL2.git` "
+ "to use this model") from exc
+
+ processor = self.ctx.get_hf_processor(DeepseekVLV2Processor)
+ processor.select_best_resolution = partial(
+ select_best_resolution,
+ candidate_resolutions=processor.candidate_resolutions)
+ return processor
+
+ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+ return {"image": None}
+
+ def get_num_image_tokens(self, *, image_width: int,
+ image_height: int) -> int:
+ hf_processor = self.get_hf_processor()
+ image_size = hf_processor.image_size
+ patch_size = hf_processor.patch_size
+ downsample_ratio = hf_processor.downsample_ratio
+
+ best_width, best_height = hf_processor.select_best_resolution(
+ (image_width, image_height))
+
+ num_width_tiles, num_height_tiles = (best_width // image_size,
+ best_height // image_size)
+ h = w = math.ceil((image_size // patch_size) / downsample_ratio)
+
+ global_views_tokens = h * (w + 1)
+ local_views_tokens = (num_height_tiles * h) * (num_width_tiles * w + 1)
+ return global_views_tokens + local_views_tokens + 1
+
+ def get_image_size_with_most_features(self) -> ImageSize:
+ hf_config = self.get_hf_config()
+ candidate_resolutions = hf_config.candidate_resolutions
+ height, width = max(candidate_resolutions,
+ key=lambda x: self.get_num_image_tokens(
+ image_width=x[1], image_height=x[0]))
+ return ImageSize(width=width, height=height)
+
+ def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+ max_image_size = self.get_image_size_with_most_features()
+ max_image_tokens = self.get_num_image_tokens(
+ image_height=max_image_size.height,
+ image_width=max_image_size.width)
+
+ return {"image": max_image_tokens}
+
+
+class DeepseekVL2DummyInputsBuilder(
+ BaseDummyInputsBuilder[DeepseekVL2ProcessingInfo]):
+
+ def get_dummy_processor_inputs(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> ProcessorInputs:
+ num_images = mm_counts.get("image", 0)
+ hf_processor = self.info.get_hf_processor()
+ image_token: str = hf_processor.image_token
+
+ max_image_size = self.info.get_image_size_with_most_features()
+
+ mm_data = {
+ "image":
+ self._get_dummy_images(width=max_image_size.width,
+ height=max_image_size.height,
+ num_images=num_images)
+ }
+
+ return ProcessorInputs(
+ prompt_text=image_token * num_images,
+ mm_data=mm_data,
+ )
+
+
+class DeepseekVL2MultiModalProcessor(
+ BaseMultiModalProcessor[DeepseekVL2ProcessingInfo]):
+
+ def _call_hf_processor(
+ self,
+ prompt: str,
+ mm_data: Mapping[str, object],
+ mm_kwargs: Mapping[str, object],
+ ) -> BatchFeature:
+ if mm_data:
+ outputs = self.info.ctx.call_hf_processor(
+ self.info.get_hf_processor(**mm_kwargs),
+ dict(prompt=prompt, **mm_data),
+ mm_kwargs,
+ )
+
+ # Deepseek-vl2 processor don't return BatchFeature,
+ # we need to manually create it
+ processed_outputs = dict(input_ids=outputs["input_ids"])
+ processed_outputs = BatchFeature(data=dict(processed_outputs),
+ tensor_type="pt")
+
+ # Remove batch dimension from processor outputs,
+ # because we will try batch to create NestedTensors
+ target_dtype = self.info.ctx.model_config.dtype
+ pixel_values = outputs["images"].to(target_dtype).squeeze(0)
+ images_spatial_crop = outputs["images_spatial_crop"].squeeze(0)
+ patches_per_image = [
+ x.prod().item() + 1 for x in images_spatial_crop
+ ]
+
+ # Rename `images` -> `pixel_values` to avoid confusion
+ processed_outputs["pixel_values"] = list(
+ pixel_values.split(patches_per_image))
+ processed_outputs["images_spatial_crop"] = images_spatial_crop
+ else:
+ tokenizer = self.info.get_tokenizer()
+ processed_outputs = tokenizer(prompt,
+ add_special_tokens=True,
+ return_tensors="pt")
+
+ return processed_outputs
+
+ def _get_mm_fields_config(
+ self,
+ hf_inputs: BatchFeature,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+ return dict(
+ pixel_values=MultiModalFieldConfig.batched("image"),
+ images_spatial_crop=MultiModalFieldConfig.batched("image"),
+ image_embeds=MultiModalFieldConfig.batched("image"),
+ )
+
+ def _get_prompt_replacements(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ out_mm_kwargs: MultiModalKwargs,
+ ) -> list[PromptReplacement]:
+ hf_processor = self.info.get_hf_processor()
+ image_token_id: int = hf_processor.image_token_id
+
+ def get_replacement_deepseek_vl2(item_idx: int):
+ images = mm_items.get_items(
+ "image", (ImageEmbeddingItems, ImageProcessorItems))
+
+ if isinstance(images, ImageEmbeddingItems):
+ num_image_tokens = images.get_feature_size(item_idx)
+ else:
+ image_size = images.get_image_size(item_idx)
+
+ num_image_tokens = self.info.get_num_image_tokens(
+ image_width=image_size.width,
+ image_height=image_size.height,
+ )
+ return [image_token_id] * num_image_tokens
+
+ return [
+ PromptReplacement(
+ modality="image",
+ target=[image_token_id],
+ replacement=get_replacement_deepseek_vl2,
+ )
+ ]
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+ DeepseekVL2MultiModalProcessor,
+ info=DeepseekVL2ProcessingInfo,
+ dummy_inputs=DeepseekVL2DummyInputsBuilder)
+class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
+
+ hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
+ "language.": "language_model.",
+ })
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ config: DeepseekVLV2Config = vllm_config.model_config.hf_config
+ quant_config = vllm_config.quant_config
+ multimodal_config = vllm_config.model_config.multimodal_config
+
+ self.config = config
+ self.multimodal_config = multimodal_config
+
+ self.vision_config = config.vision_config
+ self.projector_config = config.projector_config
+ self.text_config = config.text_config
+
+ model_config = vllm_config.model_config
+ tokenizer = cached_get_tokenizer(
+ model_config.tokenizer,
+ tokenizer_mode=model_config.tokenizer_mode,
+ tokenizer_revision=model_config.tokenizer_revision,
+ trust_remote_code=model_config.trust_remote_code,
+ )
+ self.image_token_id = tokenizer.vocab.get(_IMAGE_TOKEN)
+
+ self.vision = self._init_vision_module(self.vision_config,
+ quant_config,
+ maybe_prefix(prefix, "vision"))
+
+ self.projector = MlpProjector(self.projector_config)
+ self.tile_tag = config.tile_tag
+ self.global_view_pos = config.global_view_pos
+
+ # special token for image token sequence format
+ embed_std = 1 / torch.sqrt(
+ torch.tensor(self.projector_config.n_embed, dtype=torch.float32))
+ if self.tile_tag == "2D":
+ # <|view_separator|>, <|\n|>
+ self.image_newline = nn.Parameter(
+ torch.randn(self.projector_config.n_embed) * embed_std)
+ # This is a typo in original implementation
+ self.view_seperator = nn.Parameter(
+ torch.randn(self.projector_config.n_embed) * embed_std)
+ else:
+ raise ValueError(
+ f"Only 2D tile_tag is supported currently, got: {self.tile_tag}"
+ )
+
+ self.language_model = init_vllm_registered_model(
+ vllm_config=vllm_config,
+ hf_config=self.text_config,
+ prefix=maybe_prefix(prefix, "language"),
+ architectures=["DeepseekV3ForCausalLM"]
+ if self.text_config.topk_method == "noaux_tc" else
+ ["DeepseekV2ForCausalLM"],
+ )
+
+ self.make_empty_intermediate_tensors = (
+ self.language_model.make_empty_intermediate_tensors)
+
+ def _init_vision_module(
+ self,
+ vision_config: VisionEncoderConfig,
+ quant_config: Optional[QuantizationConfig],
+ prefix: str = "",
+ ) -> nn.Module:
+ # TODO: refactor vision model through timm wrapper from transformers
+ try:
+ import timm
+ except ImportError:
+ raise ImportError("Please install timm") from ImportError
+
+ with set_default_torch_dtype(torch.float16):
+ model = timm.create_model(
+ "vit_so400m_patch14_siglip_384.webli",
+ pretrained=False,
+ num_classes=0,
+ dynamic_img_size=True,
+ dynamic_img_pad=True,
+ )
+
+ model = model.to(dtype=torch.get_default_dtype())
+ return model
+
+ @cached_property
+ def sampler(self):
+ if hasattr(self.language_model, "sampler"):
+ return self.language_model.sampler
+
+ return get_sampler()
+
+ def _validate_pixel_values(
+ self, data: Union[torch.Tensor, List[torch.Tensor]]
+ ) -> Union[torch.Tensor, List[torch.Tensor]]:
+
+ h = w = self.vision_config.image_size
+ expected_dims = (3, h, w)
+
+ def _validate_shape(d: torch.Tensor):
+ actual_dims = tuple(d.shape[1:])
+
+ if actual_dims != expected_dims:
+ expected_expr = ("num_patches", *map(str, expected_dims))
+ raise ValueError(
+ "The expected shape of pixel values per image per batch "
+ f"is {expected_expr}. You supplied {tuple(d.shape)}.")
+
+ for d in data:
+ _validate_shape(d)
+
+ return data
+
+ def _validate_images_spatial_crop(
+ self, data: Union[torch.Tensor, List[torch.Tensor]]
+ ) -> Union[torch.Tensor, List[torch.Tensor]]:
+ expected_dims = 2
+
+ def _validate_shape(d: torch.Tensor):
+ actual_dims = d.size(-1)
+
+ if actual_dims != expected_dims:
+ expected_expr = str(expected_dims)
+ raise ValueError(
+ f"The expected shape of image sizes per image per batch "
+ f"is {expected_expr}. You supplied {tuple(d.shape)}.")
+
+ for d in data:
+ _validate_shape(d)
+
+ return data
+
+ def _parse_and_validate_image_input(
+ self, **kwargs: object) -> Optional[DeepseekVL2ImageInputs]:
+ pixel_values = kwargs.pop("pixel_values", None)
+ images_spatial_crop = kwargs.pop("images_spatial_crop", None)
+ image_embeds = kwargs.pop("image_embeds", None)
+
+ if pixel_values is None and image_embeds is None:
+ return None
+
+ if pixel_values is not None:
+ if not isinstance(pixel_values, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of pixel values. "
+ f"Got type: {type(pixel_values)}")
+
+ if not isinstance(images_spatial_crop, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of image sizes. "
+ f"Got type: {type(images_spatial_crop)}")
+
+ return DeepseekVL2ImagePixelInputs(
+ type="pixel_values",
+ data=self._validate_pixel_values(flatten_bn(pixel_values)),
+ images_spatial_crop=self._validate_images_spatial_crop(
+ flatten_bn(images_spatial_crop, concat=True)))
+
+ if image_embeds is not None:
+ if not isinstance(image_embeds, torch.Tensor):
+ raise ValueError("Incorrect type of image embeddings. "
+ f"Got type: {type(image_embeds)}")
+
+ return DeepseekVL2VImageEmbeddingInputs(
+ type="image_embeds",
+ data=flatten_bn(image_embeds),
+ )
+
+ raise AssertionError("This line should be unreachable.")
+
+ def _pixel_values_to_embedding(
+ self,
+ pixel_values: NestedTensors,
+ images_spatial_crop: torch.Tensor,
+ ) -> NestedTensors:
+ # Pixel_values: n_image * batch_size * [patch_per_img, 3, height, width]
+ total_tiles = [x for x in pixel_values]
+
+ # [batch_all_tiles, 3, height, width]
+ total_tiles = torch.cat(total_tiles, dim=0)
+
+ # [batch_all_tiles, vit_seq_len, c]
+ images_feature = self.vision.forward_features(total_tiles)
+
+ # [batch_all_tiles, hw, D]
+ images_embeds = self.projector(images_feature)
+
+ _, hw, n_dim = images_embeds.shape
+ h = w = int(hw**0.5)
+
+ # 根据self.tile_tag & self.global_view_pos填充image token sequence
+ tile_index = 0
+ vision_embeddings = []
+ for jdx in range(images_spatial_crop.size(0)):
+ # extra global & local features
+ num_width_tiles, num_height_tiles = images_spatial_crop[jdx]
+ if num_width_tiles == 0 or num_height_tiles == 0:
+ break
+ num_tiles_in_image = num_width_tiles * num_height_tiles
+
+ # [hw, D]
+ global_features = images_embeds[tile_index]
+
+ # [num_height_tiles * num_width_tiles, hw, D]
+ local_features = images_embeds[tile_index + 1:tile_index + 1 +
+ num_tiles_in_image]
+ tile_index += num_tiles_in_image + 1
+
+ # format global and local features
+ # ----------------- global view add newline -----------------
+ # [hw, D] -> [h, w, D]
+ global_features = global_features.view(h, w, n_dim)
+
+ # [D] -> [h, 1, D]
+ new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)
+
+ # cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
+ global_features = torch.cat([global_features, new_lines_in_global],
+ dim=1)
+
+ # [h, w + 1, D] -> [h * (w + 1), D]
+ global_features = global_features.view(-1, n_dim)
+
+ # ----------------- local view add newline -----------------
+ # [num_height_tiles * num_width_tiles, h * w, D] ->
+ # [num_height_tiles * h, num_width_tiles * w, D]
+ local_features = rearrange(local_features,
+ "(th tw) (h w) d -> (th h) (tw w) d",
+ th=num_height_tiles,
+ tw=num_width_tiles,
+ h=h,
+ w=w)
+
+ # [D] -> [num_height_tiles * h, 1, D]
+ new_lines_in_local = repeat(self.image_newline,
+ "d -> (th h) 1 d",
+ th=num_height_tiles,
+ h=h)
+
+ # [num_height_tiles * h, num_width_tiles * w + 1, D]
+ local_features = torch.cat([local_features, new_lines_in_local],
+ dim=1)
+
+ # [num_height_tiles * h, num_width_tiles * w + 1, D]
+ # --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
+ local_features = local_features.view(-1, n_dim)
+
+ # merge global and local tiles
+ if self.global_view_pos == "head":
+ global_local_features = torch.cat([
+ global_features,
+ self.view_seperator[None, :],
+ local_features,
+ ])
+ else:
+ global_local_features = torch.cat([
+ local_features,
+ self.view_seperator[None, :],
+ global_features,
+ ])
+
+ vision_embeddings.append(global_local_features)
+ return vision_embeddings
+
+ def _process_image_input(
+ self, image_input: DeepseekVL2ImageInputs) -> torch.Tensor:
+ if image_input["type"] == "image_embeds":
+ image_data = image_input["data"]
+ if is_list_of(image_data, torch.Tensor):
+ # it's already a list of tensors
+ return image_data
+ if len(image_data.shape) == 3:
+ # 3D tensor
+ return list(torch.unbind(image_data, dim=0))
+ raise ValueError(
+ "We expect batched 2D tensors;"
+ "this can be either a list of 2D tensors or a single 3D tensor."
+ )
+
+ pixel_values = image_input["data"]
+ images_spatial_crop = image_input["images_spatial_crop"]
+
+ return self._pixel_values_to_embedding(
+ pixel_values=pixel_values, images_spatial_crop=images_spatial_crop)
+
+ def get_multimodal_embeddings(self, **kwargs: object) -> torch.Tensor:
+ image_input = self._parse_and_validate_image_input(**kwargs)
+ if image_input is None:
+ return None
+ vision_embeddings = self._process_image_input(image_input)
+ return vision_embeddings
+
+ def get_input_embeddings(
+ self,
+ input_ids: torch.Tensor,
+ multimodal_embeddings: Optional[NestedTensors] = None,
+ ) -> torch.Tensor:
+ inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+ if multimodal_embeddings is not None:
+ inputs_embeds = merge_multimodal_embeddings(
+ input_ids, inputs_embeds, multimodal_embeddings,
+ self.image_token_id)
+ return inputs_embeds
+
+ def forward(self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ kv_caches: List[torch.Tensor],
+ attn_metadata: AttentionMetadata,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ **kwargs: object):
+
+ if intermediate_tensors is not None:
+ inputs_embeds = None
+
+ # NOTE: In v1, inputs_embeds is always generated at model runner, this
+ # condition is for v0 compatibility
+ elif inputs_embeds is None:
+ vision_embeddings = self.get_multimodal_embeddings(**kwargs)
+ inputs_embeds = self.get_input_embeddings(input_ids,
+ vision_embeddings)
+ input_ids = None
+
+ hidden_states = self.language_model(input_ids,
+ positions,
+ kv_caches,
+ attn_metadata,
+ intermediate_tensors,
+ inputs_embeds=inputs_embeds)
+
+ return hidden_states
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[torch.Tensor]:
+ return self.language_model.compute_logits(hidden_states,
+ sampling_metadata)
+
+ def sample(
+ self,
+ logits: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[SamplerOutput]:
+ return self.language_model.sample(logits, sampling_metadata)
+
+ def load_weights(self, weights: Iterable[Tuple[str,
+ torch.Tensor]]) -> Set[str]:
+
+ loader = AutoWeightsLoader(self)
+ autoloaded_weights = loader.load_weights(weights,
+ mapper=self.hf_to_vllm_mapper)
+ return autoloaded_weights
diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
index 8f36437d47d9e..ff7dab89e4da8 100644
--- a/vllm/model_executor/models/minicpmv.py
+++ b/vllm/model_executor/models/minicpmv.py
@@ -657,7 +657,7 @@ def init_vision_module(
quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> nn.Module:
- # TODO: refactor this vision model
+ # TODO: refactor vision model through timm wrapper from transformers
try:
import timm
except ImportError:
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 62840b8c1bcda..a7286a9203f67 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -149,6 +149,7 @@
"ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501
"ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
"ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
+ "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
"H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
"InternVLChatModel": ("internvl", "InternVLChatModel"),
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index 58417980e7b47..c97acffa1a719 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -23,8 +23,9 @@
# yapf conflicts with isort for this block
# yapf: disable
from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
- DbrxConfig, EAGLEConfig,
- ExaoneConfig, H2OVLChatConfig,
+ DbrxConfig, DeepseekVLV2Config,
+ EAGLEConfig, ExaoneConfig,
+ H2OVLChatConfig,
InternVLChatConfig, JAISConfig,
MedusaConfig, MllamaConfig,
MLPSpeculatorConfig, MPTConfig,
@@ -54,6 +55,7 @@
"chatglm": ChatGLMConfig,
"cohere2": Cohere2Config,
"dbrx": DbrxConfig,
+ "deepseek_vl_v2": DeepseekVLV2Config,
"mpt": MPTConfig,
"RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct)
"RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct)
diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py
index a41a35c88b3a1..f065c56124605 100644
--- a/vllm/transformers_utils/configs/__init__.py
+++ b/vllm/transformers_utils/configs/__init__.py
@@ -1,6 +1,7 @@
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
from vllm.transformers_utils.configs.cohere2 import Cohere2Config
from vllm.transformers_utils.configs.dbrx import DbrxConfig
+from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekVLV2Config
from vllm.transformers_utils.configs.eagle import EAGLEConfig
from vllm.transformers_utils.configs.exaone import ExaoneConfig
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
@@ -25,6 +26,7 @@
"ChatGLMConfig",
"Cohere2Config",
"DbrxConfig",
+ "DeepseekVLV2Config",
"MPTConfig",
"RWConfig",
"H2OVLChatConfig",
diff --git a/vllm/transformers_utils/configs/deepseek_vl2.py b/vllm/transformers_utils/configs/deepseek_vl2.py
new file mode 100644
index 0000000000000..681528c3c0116
--- /dev/null
+++ b/vllm/transformers_utils/configs/deepseek_vl2.py
@@ -0,0 +1,214 @@
+# adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/faf18023f24b962b32d9f0a2d89e402a8d383a78/deepseek_vl2/models/modeling_deepseek_vl_v2.py#L115-L268
+from typing import Tuple
+
+from transformers.configuration_utils import PretrainedConfig
+
+
+class VisionEncoderConfig(PretrainedConfig):
+ model_type: str = "vision"
+
+ model_name: str = "vit_so400m_patch14_siglip_384.webli"
+ image_size: int = 384
+ patch_size: int = 16
+ width: int = 1024
+ layers: int = 24
+ heads: int = 16
+ mlp_ratio: int = 4
+ global_pool: str = "map"
+ ignore_head: bool = True
+ class_token: bool = False
+ num_classes: int = 0
+ use_checkpoint: bool = False
+ weight_init: str = "skip"
+ deterministic: bool = False
+ num_recomputing_layers: int = 0
+
+ def __init__(self,
+ model_name: str = "vit_so400m_patch14_siglip_384.webli",
+ image_size: int = 384,
+ patch_size: int = 16,
+ width: int = 1024,
+ layers: int = 24,
+ heads: int = 16,
+ mlp_ratio: int = 4,
+ global_pool: str = "map",
+ ignore_head: bool = True,
+ class_token: bool = False,
+ num_classes: int = 0,
+ use_checkpoint: bool = False,
+ **kwargs):
+ self.model_name = model_name
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.width = width
+ self.layers = layers
+ self.heads = heads
+ self.mlp_ratio = mlp_ratio
+ self.global_pool = global_pool
+ self.ignore_head = ignore_head
+ self.class_token = class_token
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+
+ super().__init__(**kwargs)
+
+
+class MlpProjectorConfig(PretrainedConfig):
+ model_type = "mlp_projector"
+ projector_type: str = "downsample_mlp_gelu"
+ input_dim: int = 1152
+ n_embed: int = 2048
+ depth: int = 2
+ mlp_ratio: int = 1
+ downsample_ratio: int = 2
+ token_pooling: bool = False
+
+ def __init__(self,
+ projector_type: str = "downsample_mlp_gelu",
+ input_dim: int = 1152,
+ n_embed: int = 2048,
+ depth: int = 2,
+ mlp_ratio: int = 1,
+ downsample_ratio: int = 2,
+ **kwargs):
+ self.projector_type = projector_type
+ self.input_dim = input_dim
+ self.n_embed = n_embed
+ self.depth = depth
+ self.mlp_ratio = mlp_ratio
+ self.downsample_ratio = downsample_ratio
+
+ super().__init__(**kwargs)
+
+
+class DeepseekV2Config(PretrainedConfig):
+
+ model_type = "deepseek_v2"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=102400,
+ hidden_size=4096,
+ intermediate_size=11008,
+ moe_intermediate_size=1407,
+ num_hidden_layers=30,
+ num_attention_heads=32,
+ num_key_value_heads=32,
+ n_shared_experts=None,
+ n_routed_experts=None,
+ ep_size=1,
+ routed_scaling_factor=1.0,
+ kv_lora_rank=512,
+ q_lora_rank=1536,
+ qk_rope_head_dim=64,
+ v_head_dim=128,
+ qk_nope_head_dim=128,
+ topk_method='gready',
+ n_group=None,
+ topk_group=None,
+ num_experts_per_tok=None,
+ moe_layer_freq=1,
+ first_k_dense_replace=0,
+ norm_topk_prob=False,
+ scoring_func='softmax',
+ aux_loss_alpha=0.001,
+ seq_aux=True,
+ hidden_act="silu",
+ max_position_embeddings=2048,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ pad_token_id=None,
+ bos_token_id=100000,
+ eos_token_id=100001,
+ pretraining_tp=1,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ use_mla=True,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.moe_intermediate_size = moe_intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.n_shared_experts = n_shared_experts
+ self.n_routed_experts = n_routed_experts
+ self.ep_size = ep_size
+ self.routed_scaling_factor = routed_scaling_factor
+ self.kv_lora_rank = kv_lora_rank
+ self.q_lora_rank = q_lora_rank
+ self.qk_rope_head_dim = qk_rope_head_dim
+ self.v_head_dim = v_head_dim
+ self.qk_nope_head_dim = qk_nope_head_dim
+ self.topk_method = topk_method
+ self.n_group = n_group
+ self.topk_group = topk_group
+ self.num_experts_per_tok = num_experts_per_tok
+ self.moe_layer_freq = moe_layer_freq
+ self.first_k_dense_replace = first_k_dense_replace
+ self.norm_topk_prob = norm_topk_prob
+ self.scoring_func = scoring_func
+ self.aux_loss_alpha = aux_loss_alpha
+ self.seq_aux = seq_aux
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = float(rms_norm_eps)
+ self.pretraining_tp = pretraining_tp
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.use_mla = use_mla
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+
+class DeepseekVLV2Config(PretrainedConfig):
+ model_type = "deepseek_vl_v2"
+ vision_config: VisionEncoderConfig
+ projector_config: MlpProjectorConfig
+
+ tile_tag: str = "2D"
+ global_view_pos: str = "head"
+ candidate_resolutions: Tuple[Tuple[int, int]] = ((384, 384), )
+
+ def __init__(self,
+ tile_tag: str = "tile_tag",
+ global_view_pos: str = "head",
+ candidate_resolutions: Tuple[Tuple[int,
+ int]] = ((384, 384), ),
+ **kwargs):
+ super().__init__(**kwargs)
+
+ vision_config = kwargs.get("vision_config", {})
+ self.vision_config = VisionEncoderConfig(**vision_config)
+
+ projector_config = kwargs.get("projector_config", {})
+ self.projector_config = MlpProjectorConfig(**projector_config)
+
+ language_config = kwargs.get("language_config", {})
+ self.text_config = DeepseekV2Config(**language_config)
+
+ self.tile_tag = tile_tag
+ self.global_view_pos = global_view_pos
+ self.candidate_resolutions = candidate_resolutions
+ self.vocab_size = self.text_config.vocab_size