From d8998048773beaa48de0a65298a147c70fdfa470 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Sun, 12 Jan 2025 16:17:24 +0800
Subject: [PATCH] [Model] Initialize support for Deepseek-VL2 models (#11578)

Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Cyrus Leung
---
 .buildkite/test-pipeline.yaml                 |   1 +
 docs/source/models/supported_models.md        |  20 +-
 examples/offline_inference/vision_language.py |  18 +
 .../vision_language_multi_image.py            |  23 +
 .../vision_language/test_models.py            |  27 +
 .../vision_language/vlm_utils/model_utils.py  |  36 +
 tests/models/registry.py                      |   2 +
 tests/models/test_initialization.py           |   3 +
 vllm/entrypoints/chat_utils.py                |   4 +-
 vllm/model_executor/models/deepseek_v2.py     |  18 +-
 vllm/model_executor/models/deepseek_v3.py     |  20 +-
 vllm/model_executor/models/deepseek_vl2.py    | 662 ++++++++++++++++++
 vllm/model_executor/models/minicpmv.py        |   2 +-
 vllm/model_executor/models/registry.py        |   1 +
 vllm/transformers_utils/config.py             |   6 +-
 vllm/transformers_utils/configs/__init__.py   |   2 +
 .../configs/deepseek_vl2.py                   | 214 ++++++
 17 files changed, 1050 insertions(+), 9 deletions(-)
 create mode 100644 vllm/model_executor/models/deepseek_vl2.py
 create mode 100644 vllm/transformers_utils/configs/deepseek_vl2.py

diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index cf82210f96ee3..393912881bca3 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -52,6 +52,7 @@ steps:
   - tests/worker
   - tests/standalone_tests/lazy_torch_compile.py
   commands:
+  - pip install git+https://github.com/Isotr0py/DeepSeek-VL2.git # Used by multimodal processing test
   - python3 standalone_tests/lazy_torch_compile.py
   - pytest -v -s mq_llm_engine # MQLLMEngine
   - pytest -v -s async_engine # AsyncLLMEngine
diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index 5c96dfdad25f7..642ef3c9655b8 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -610,6 +610,13 @@ See [this page](#generative-models) for more information on how to use generativ
   -
   - ✅︎
   - ✅︎
+* - `DeepseekVLV2ForCausalLM`
+  - DeepSeek-VL2
+  - T + I+
+  - `deepseek-ai/deepseek-vl2-tiny` (WIP), `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2` etc. (see note)
+  -
+  - ✅︎
+  - ✅︎
 * - `FuyuForCausalLM`
   - Fuyu
   - T + I
@@ -755,8 +762,19 @@ See [this page](#generative-models) for more information on how to use generativ
 E Pre-computed embeddings can be inputted for this modality.
 + Multiple items can be inputted per text prompt for this modality.
 
+````{note}
+`deepseek-ai/deepseek-vl2-tiny` is not supported yet.
+
+To use the `DeepSeek-VL2` series of models, you need to install a forked version of the `deepseek_vl2` package:
+```shell
+pip install git+https://github.com/Isotr0py/DeepSeek-VL2.git
+```
+
+In addition, to run the `DeepSeek-VL2` series of models, you have to pass `--hf_overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` when running vLLM.
+````
+
 ```{note}
-To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
+To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
 ```
 
 ```{note}
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index b51bfae455267..ad32b9fe242e9 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -66,6 +66,23 @@ def run_chameleon(question: str, modality: str):
     return llm, prompt, stop_token_ids
 
 
+# Deepseek-VL2
+def run_deepseek_vl2(question: str, modality: str):
+    assert modality == "image"
+
+    model_name = "deepseek-ai/deepseek-vl2-small"
+
+    llm = LLM(model=model_name,
+              max_model_len=4096,
+              max_num_seqs=2,
+              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+              hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]})
+
+    prompt = f"<|User|>: <image>\n{question}\n\n<|Assistant|>:"
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids
+
+
 # Fuyu
 def run_fuyu(question: str, modality: str):
     assert modality == "image"
@@ -498,6 +515,7 @@ def run_qwen2_vl(question: str, modality: str):
     "aria": run_aria,
     "blip-2": run_blip2,
     "chameleon": run_chameleon,
+    "deepseek_vl_v2": run_deepseek_vl2,
     "fuyu": run_fuyu,
     "glm4v": run_glm4v,
     "h2ovl_chat": run_h2ovl,
diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py
index cf2e90a325c6a..c6cf3f30c31cb 100644
--- a/examples/offline_inference/vision_language_multi_image.py
+++ b/examples/offline_inference/vision_language_multi_image.py
@@ -54,6 +54,28 @@ def load_aria(question, image_urls: List[str]) -> ModelRequestData:
     )
 
 
+def load_deepseek_vl2(question: str, image_urls: List[str]):
+    model_name = "deepseek-ai/deepseek-vl2-small"
+
+    llm = LLM(model=model_name,
+              max_model_len=4096,
+              max_num_seqs=2,
+              hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]},
+              limit_mm_per_prompt={"image": len(image_urls)})
+
+    placeholder = "".join(f"image_{i}:<image>\n"
+                          for i, _ in enumerate(image_urls, start=1))
+    prompt = f"<|User|>: {placeholder}{question}\n\n<|Assistant|>:"
+
+    return ModelRequestData(
+        llm=llm,
+        prompt=prompt,
+        stop_token_ids=None,
+        image_data=[fetch_image(url) for url in image_urls],
+        chat_template=None,
+    )
+
+
 def load_h2onvl(question: str, image_urls: List[str]) -> ModelRequestData:
     model_name = "h2oai/h2ovl-mississippi-2b"
 
@@ -372,6 +394,7 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
 model_example_map = {
     "aria": load_aria,
+    "deepseek_vl2": load_deepseek_vl2,
     "h2ovl_chat": load_h2onvl,
     "idefics3": load_idefics3,
     "internvl_chat": load_internvl,
diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py
index 146685738a1d0..7620ed1107e8f 100644
--- a/tests/models/decoder_only/vision_language/test_models.py
+++ b/tests/models/decoder_only/vision_language/test_models.py
@@ -188,6 +188,33 @@
         max_tokens=8,
         dtype="bfloat16",
     ),
+    "deepseek_vl_v2": VLMTestInfo(
+        models=["deepseek-ai/deepseek-vl2-small"],
+        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+        dtype="bfloat16",
+        prompt_formatter=lambda img_prompt: f"<|User|>: {img_prompt}\n\n<|Assistant|>: ",  # noqa: E501
+        max_model_len=4096,
+        max_num_seqs=2,
+        single_image_prompts=IMAGE_ASSETS.prompts({
+            "stop_sign": "<image>\nWhat's the color of the stop sign and car?",
+            "cherry_blossom": "<image>\nWhat's the color of the tower?",
+        }),
+        multi_image_prompt="image_1:<image>\nimage_2:<image>\nDescribe the two images shortly.",  # noqa: E501
+        vllm_runner_kwargs={"hf_overrides": {"architectures": ["DeepseekVLV2ForCausalLM"]}},  # noqa: E501
+        image_size_factors=[(0.10, 0.15)],
+        patch_hf_runner=model_utils.deepseekvl2_patch_hf_runner,
+        postprocess_inputs=model_utils.cast_dtype_post_processor("images"),
+        hf_output_post_proc=model_utils.deepseekvl2_trunc_hf_output,
+        stop_str=["<|end▁of▁sentence|>", "<|begin▁of▁sentence|>"],  # noqa: E501
+        num_logprobs=5,
+        marks=[
+            pytest.mark.skipif(
+                not is_flash_attn_2_available(),
+                reason="Model needs flash-attn for numeric convergence.",
+            ),
+            large_gpu_mark(min_gb=48),
+        ],
+    ),
     "fuyu": VLMTestInfo(
         models=["adept/fuyu-8b"],
         test_type=VLMTestType.IMAGE,
diff --git a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
index 6c7a753af787e..1ca85c7bb2056 100644
--- a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
+++ b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
@@ -183,6 +183,14 @@ def paligemma_vllm_to_hf_output(vllm_output: RunnerOutput,
 
 
 ####### Post-processors for HF outputs
+def deepseekvl2_trunc_hf_output(hf_output: RunnerOutput,
+                                model: str) -> RunnerOutput:
+    output_ids, output_str, out_logprobs = hf_output
+    if output_str.endswith("<|end▁of▁sentence|>"):
+        output_str = output_str.split("<|end▁of▁sentence|>")[0]
+    return output_ids, output_str, out_logprobs
+
+
 def minicpmv_trunc_hf_output(hf_output: RunnerOutput,
                              model: str) -> RunnerOutput:
     output_ids, output_str, out_logprobs = hf_output
@@ -261,6 +269,34 @@ def qwen_prompt_path_encoder(
 
 
 ####### Model-specific HuggingFace runner patchers
+def deepseekvl2_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
+    """Patches and returns an instance of the HfRunner for DeepSeek-VL2."""
+    hf_processor = hf_model.processor
+
+    def processor(*args, text="", images=None, **kwargs):
+        if isinstance(images, Image):
+            images = [images]
+        # inputs is a custom class instead of a dict or BatchFeature
+        inputs = hf_processor(
+            *args,
+            prompt=text,
+            images=images,
+            **kwargs,
+        )
+        inputs = {
+            k: inputs[k]
+            for k in inputs.keys()  # noqa
+            if k not in ("seq_lens", "sft_format")
+        }
+        inputs = BatchEncoding(data=inputs, tensor_type="pt")
+        return inputs
+
+    hf_model.processor = processor
+    hf_model.model.get_output_embeddings = lambda: \
+        hf_model.model.language.model.embed_tokens
+    return hf_model
+
+
 def glm_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
     """Patches and returns an instance of the HfRunner to use for GLM4."""
     hf_processor = hf_model.processor
diff --git a/tests/models/registry.py b/tests/models/registry.py
index f5aaa8eb071f9..d079725b2f78d 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -179,6 +179,8 @@ class _HfExamplesInfo:
                                         trust_remote_code=True),
     "ChatGLMForConditionalGeneration": _HfExamplesInfo("chatglm2-6b",
                                                        is_available_online=False),
+    # TODO(Isotr0py): Use deepseek-vl2-tiny for test after it's supported
+    "DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-small"),  # noqa: E501
     "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
     "H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m"),
     "InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B",
diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py
index 7a564c1f4a1d0..daece7c93c0ef 100644
--- a/tests/models/test_initialization.py
+++ b/tests/models/test_initialization.py
@@ -26,6 +26,9 @@ def test_can_initialize(model_arch):
 
     # Avoid OOM
     def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
+        if hf_config.model_type == "deepseek_vl_v2":
+            
hf_config.update({"architectures": ["DeepseekVLV2ForCausalLM"]}) + if hasattr(hf_config, "text_config"): text_config: PretrainedConfig = hf_config.text_config else: diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 923c7459f6948..beedf5d16ab86 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -403,8 +403,8 @@ def _placeholder_str(self, modality: ModalityStr, if model_type.startswith("llava"): return self._cached_token_str(self._tokenizer, hf_config.image_token_index) - if model_type in ("chameleon", "internvl_chat", "NVLM_D", - "h2ovl_chat"): + if model_type in ("chameleon", "deepseek_vl_v2", "internvl_chat", + "NVLM_D", "h2ovl_chat"): return "" if model_type == "mllama": return "<|image|>" diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 4cf4e6c358bf2..9132040545863 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -243,7 +243,11 @@ def __init__( bias=False, quant_config=quant_config, prefix=f"{prefix}.o_proj") - rope_scaling["rope_type"] = 'deepseek_yarn' + if rope_scaling: + rope_scaling["rope_type"] = 'deepseek_yarn' + self.use_normal_rope = False + else: + self.use_normal_rope = True self.rotary_emb = get_rope(qk_rope_head_dim, rotary_dim=qk_rope_head_dim, max_position=max_position_embeddings, @@ -298,7 +302,18 @@ def forward( self.qk_nope_head_dim + self.v_head_dim) k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) k_pe = latent_cache[:, :, self.kv_lora_rank:] + + if self.use_normal_rope: + seq_len = positions.size(0) + ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape + q_pe = q_pe.reshape(seq_len, -1) + k_pe = k_pe.reshape(seq_len, -1) + q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) + + if self.use_normal_rope: + q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape) + q[..., self.qk_nope_head_dim:] = q_pe k = torch.empty_like(q) k[..., :self.qk_nope_head_dim] = k_nope @@ -355,6 +370,7 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.self_attn", ) + if (config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace and layer_idx % config.moe_layer_freq == 0): diff --git a/vllm/model_executor/models/deepseek_v3.py b/vllm/model_executor/models/deepseek_v3.py index d4710622681b5..ca79b14c55fea 100644 --- a/vllm/model_executor/models/deepseek_v3.py +++ b/vllm/model_executor/models/deepseek_v3.py @@ -251,7 +251,11 @@ def __init__( bias=False, quant_config=quant_config, prefix=f"{prefix}.o_proj") - rope_scaling["rope_type"] = 'deepseek_yarn' + if rope_scaling: + rope_scaling["rope_type"] = 'deepseek_yarn' + self.use_normal_rope = False + else: + self.use_normal_rope = True self.rotary_emb = get_rope(qk_rope_head_dim, rotary_dim=qk_rope_head_dim, max_position=max_position_embeddings, @@ -306,7 +310,18 @@ def forward( self.qk_nope_head_dim + self.v_head_dim) k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) k_pe = latent_cache[:, :, self.kv_lora_rank:] + + if self.use_normal_rope: + seq_len = positions.size(0) + ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape + q_pe = q_pe.reshape(seq_len, -1) + k_pe = k_pe.reshape(seq_len, -1) + q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) + + if self.use_normal_rope: + q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape) + q[..., self.qk_nope_head_dim:] = q_pe k = torch.empty_like(q) k[..., :self.qk_nope_head_dim] = k_nope @@ -583,7 +598,8 @@ def 
load_weights(self, weights: Iterable[Tuple[str,
                 continue
 
             # TODO(simon): support nextn predict layers
-            if self.config.num_nextn_predict_layers > 0:
+            if hasattr(self.config, "num_nextn_predict_layers"
+                       ) and self.config.num_nextn_predict_layers > 0:
                 assert self.config.num_nextn_predict_layers == 1
                 layer_idx = self.config.num_hidden_layers
                 if name.startswith(f"model.layers.{layer_idx}"):
diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py
new file mode 100644
index 0000000000000..99fa941c055d2
--- /dev/null
+++ b/vllm/model_executor/models/deepseek_vl2.py
@@ -0,0 +1,662 @@
+# adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/faf18023f24b962b32d9f0a2d89e402a8d383a78/deepseek_vl2/models/modeling_deepseek_vl_v2.py
+"""Inference-only Deepseek-VL2 model compatible with HuggingFace weights."""
+import math
+from functools import cached_property, partial
+from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
+                    TypedDict, Union)
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from transformers import AutoProcessor, BatchFeature, ProcessorMixin
+
+from vllm.attention import AttentionMetadata
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm.model_executor import SamplingMetadata
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
+from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
+                                    NestedTensors)
+from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
+                                   ImageSize, MultiModalDataItems)
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+                                        BaseProcessingInfo, PromptReplacement)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
+from vllm.multimodal.utils import cached_get_tokenizer
+from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config,
+                                                          MlpProjectorConfig,
+                                                          VisionEncoderConfig)
+from vllm.utils import is_list_of
+
+from .interfaces import SupportsMultiModal, SupportsPP
+from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+                    init_vllm_registered_model, maybe_prefix,
+                    merge_multimodal_embeddings)
+
+logger = init_logger(__name__)
+
+# The image token id may vary
+_IMAGE_TOKEN = "<image>"
+
+
+class DeepseekVL2ImagePixelInputs(TypedDict):
+    type: Literal["pixel_values"]
+    data: Union[torch.Tensor, List[torch.Tensor]]
+    """
+    Shape: `(batch_size * num_images, num_channels, height, width)`
+    """
+    images_spatial_crop: torch.Tensor
+    """
+    Shape: `(batch_size * num_images, 2)`
+    """
+
+
+class DeepseekVL2VImageEmbeddingInputs(TypedDict):
+    type: Literal["image_embeds"]
+    data: Union[torch.Tensor, List[torch.Tensor]]
+    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
+
+    `hidden_size` must match the hidden size of language model backbone.
+ """ + + +DeepseekVL2ImageInputs = Union[DeepseekVL2ImagePixelInputs, + DeepseekVL2VImageEmbeddingInputs] + + +class MlpProjector(nn.Module): + + def __init__(self, cfg: MlpProjectorConfig): + + super().__init__() + + self.cfg = cfg + assert not cfg.token_pooling, ( + "Token pooling is not supported currently.") + + if cfg.projector_type == "downsample_mlp_gelu": + mlp_depth = cfg.depth + mlp_ratio = cfg.mlp_ratio + modules = [ + nn.Linear( + cfg.input_dim * cfg.downsample_ratio * + cfg.downsample_ratio, cfg.n_embed * mlp_ratio) + ] + for _ in range(1, mlp_depth - 1): + modules.append(nn.GELU()) + modules.append( + nn.Linear(cfg.n_embed * mlp_ratio, + cfg.n_embed * mlp_ratio)) + modules.append(nn.GELU()) + modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed)) + modules = nn.Sequential(*modules) + + else: + raise NotImplementedError( + f"Unsupported projector type: {cfg.projector_type}") + + self.layers = modules + + def forward(self, x): + bs, hw, input_dim = x.shape + h = w = int((hw)**0.5) + """compute padding""" + if h % self.cfg.downsample_ratio: + pad = self.cfg.downsample_ratio - h % self.cfg.downsample_ratio + else: + pad = 0 + x = x.reshape(bs, h, w, input_dim) + if pad > 0: + x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0) + """4 to 1 concat""" + x = x.permute(0, 3, 1, 2) # B, C, H, W + x = F.unfold(x, + kernel_size=self.cfg.downsample_ratio, + stride=self.cfg.downsample_ratio, + padding=0) # B, C*4, HW // 4 + x = x.permute(0, 2, 1) + + return self.layers(x) + + +class DeepseekVL2ProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(DeepseekVLV2Config) + + def get_hf_processor(self) -> ProcessorMixin: + # TODO(Isotr0py): we should get rid of dependency on deepseek_vl2 + # in the future, because it's flasky and lack of maintenance. 
+ try: + from deepseek_vl2.models.processing_deepseek_vl_v2 import ( + DeepseekVLV2Processor, select_best_resolution) + AutoProcessor.register("DeepseekVLV2Processor", + DeepseekVLV2Processor) + except ModuleNotFoundError as exc: + raise ModuleNotFoundError( + "You need to `pip install " + "git+https://github.com/deepseek-ai/DeepSeek-VL2.git` " + "to use this model") from exc + + processor = self.ctx.get_hf_processor(DeepseekVLV2Processor) + processor.select_best_resolution = partial( + select_best_resolution, + candidate_resolutions=processor.candidate_resolutions) + return processor + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_num_image_tokens(self, *, image_width: int, + image_height: int) -> int: + hf_processor = self.get_hf_processor() + image_size = hf_processor.image_size + patch_size = hf_processor.patch_size + downsample_ratio = hf_processor.downsample_ratio + + best_width, best_height = hf_processor.select_best_resolution( + (image_width, image_height)) + + num_width_tiles, num_height_tiles = (best_width // image_size, + best_height // image_size) + h = w = math.ceil((image_size // patch_size) / downsample_ratio) + + global_views_tokens = h * (w + 1) + local_views_tokens = (num_height_tiles * h) * (num_width_tiles * w + 1) + return global_views_tokens + local_views_tokens + 1 + + def get_image_size_with_most_features(self) -> ImageSize: + hf_config = self.get_hf_config() + candidate_resolutions = hf_config.candidate_resolutions + height, width = max(candidate_resolutions, + key=lambda x: self.get_num_image_tokens( + image_width=x[1], image_height=x[0])) + return ImageSize(width=width, height=height) + + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: + max_image_size = self.get_image_size_with_most_features() + max_image_tokens = self.get_num_image_tokens( + image_height=max_image_size.height, + image_width=max_image_size.width) + + return {"image": max_image_tokens} + + +class DeepseekVL2DummyInputsBuilder( + BaseDummyInputsBuilder[DeepseekVL2ProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + num_images = mm_counts.get("image", 0) + hf_processor = self.info.get_hf_processor() + image_token: str = hf_processor.image_token + + max_image_size = self.info.get_image_size_with_most_features() + + mm_data = { + "image": + self._get_dummy_images(width=max_image_size.width, + height=max_image_size.height, + num_images=num_images) + } + + return ProcessorInputs( + prompt_text=image_token * num_images, + mm_data=mm_data, + ) + + +class DeepseekVL2MultiModalProcessor( + BaseMultiModalProcessor[DeepseekVL2ProcessingInfo]): + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + if mm_data: + outputs = self.info.ctx.call_hf_processor( + self.info.get_hf_processor(**mm_kwargs), + dict(prompt=prompt, **mm_data), + mm_kwargs, + ) + + # Deepseek-vl2 processor don't return BatchFeature, + # we need to manually create it + processed_outputs = dict(input_ids=outputs["input_ids"]) + processed_outputs = BatchFeature(data=dict(processed_outputs), + tensor_type="pt") + + # Remove batch dimension from processor outputs, + # because we will try batch to create NestedTensors + target_dtype = self.info.ctx.model_config.dtype + pixel_values = outputs["images"].to(target_dtype).squeeze(0) + images_spatial_crop = 
outputs["images_spatial_crop"].squeeze(0) + patches_per_image = [ + x.prod().item() + 1 for x in images_spatial_crop + ] + + # Rename `images` -> `pixel_values` to avoid confusion + processed_outputs["pixel_values"] = list( + pixel_values.split(patches_per_image)) + processed_outputs["images_spatial_crop"] = images_spatial_crop + else: + tokenizer = self.info.get_tokenizer() + processed_outputs = tokenizer(prompt, + add_special_tokens=True, + return_tensors="pt") + + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + images_spatial_crop=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), + ) + + def _get_prompt_replacements( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> list[PromptReplacement]: + hf_processor = self.info.get_hf_processor() + image_token_id: int = hf_processor.image_token_id + + def get_replacement_deepseek_vl2(item_idx: int): + images = mm_items.get_items( + "image", (ImageEmbeddingItems, ImageProcessorItems)) + + if isinstance(images, ImageEmbeddingItems): + num_image_tokens = images.get_feature_size(item_idx) + else: + image_size = images.get_image_size(item_idx) + + num_image_tokens = self.info.get_num_image_tokens( + image_width=image_size.width, + image_height=image_size.height, + ) + return [image_token_id] * num_image_tokens + + return [ + PromptReplacement( + modality="image", + target=[image_token_id], + replacement=get_replacement_deepseek_vl2, + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor( + DeepseekVL2MultiModalProcessor, + info=DeepseekVL2ProcessingInfo, + dummy_inputs=DeepseekVL2DummyInputsBuilder) +class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): + + hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={ + "language.": "language_model.", + }) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config: DeepseekVLV2Config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + + self.vision_config = config.vision_config + self.projector_config = config.projector_config + self.text_config = config.text_config + + model_config = vllm_config.model_config + tokenizer = cached_get_tokenizer( + model_config.tokenizer, + tokenizer_mode=model_config.tokenizer_mode, + tokenizer_revision=model_config.tokenizer_revision, + trust_remote_code=model_config.trust_remote_code, + ) + self.image_token_id = tokenizer.vocab.get(_IMAGE_TOKEN) + + self.vision = self._init_vision_module(self.vision_config, + quant_config, + maybe_prefix(prefix, "vision")) + + self.projector = MlpProjector(self.projector_config) + self.tile_tag = config.tile_tag + self.global_view_pos = config.global_view_pos + + # special token for image token sequence format + embed_std = 1 / torch.sqrt( + torch.tensor(self.projector_config.n_embed, dtype=torch.float32)) + if self.tile_tag == "2D": + # <|view_separator|>, <|\n|> + self.image_newline = nn.Parameter( + torch.randn(self.projector_config.n_embed) * embed_std) + # This is a typo in original implementation + self.view_seperator = nn.Parameter( + torch.randn(self.projector_config.n_embed) * embed_std) 
+ else: + raise ValueError( + f"Only 2D tile_tag is supported currently, got: {self.tile_tag}" + ) + + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=self.text_config, + prefix=maybe_prefix(prefix, "language"), + architectures=["DeepseekV3ForCausalLM"] + if self.text_config.topk_method == "noaux_tc" else + ["DeepseekV2ForCausalLM"], + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + def _init_vision_module( + self, + vision_config: VisionEncoderConfig, + quant_config: Optional[QuantizationConfig], + prefix: str = "", + ) -> nn.Module: + # TODO: refactor vision model through timm wrapper from transformers + try: + import timm + except ImportError: + raise ImportError("Please install timm") from ImportError + + with set_default_torch_dtype(torch.float16): + model = timm.create_model( + "vit_so400m_patch14_siglip_384.webli", + pretrained=False, + num_classes=0, + dynamic_img_size=True, + dynamic_img_pad=True, + ) + + model = model.to(dtype=torch.get_default_dtype()) + return model + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return get_sampler() + + def _validate_pixel_values( + self, data: Union[torch.Tensor, List[torch.Tensor]] + ) -> Union[torch.Tensor, List[torch.Tensor]]: + + h = w = self.vision_config.image_size + expected_dims = (3, h, w) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape[1:]) + + if actual_dims != expected_dims: + expected_expr = ("num_patches", *map(str, expected_dims)) + raise ValueError( + "The expected shape of pixel values per image per batch " + f"is {expected_expr}. You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _validate_images_spatial_crop( + self, data: Union[torch.Tensor, List[torch.Tensor]] + ) -> Union[torch.Tensor, List[torch.Tensor]]: + expected_dims = 2 + + def _validate_shape(d: torch.Tensor): + actual_dims = d.size(-1) + + if actual_dims != expected_dims: + expected_expr = str(expected_dims) + raise ValueError( + f"The expected shape of image sizes per image per batch " + f"is {expected_expr}. You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[DeepseekVL2ImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + images_spatial_crop = kwargs.pop("images_spatial_crop", None) + image_embeds = kwargs.pop("image_embeds", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + if not isinstance(images_spatial_crop, (torch.Tensor, list)): + raise ValueError("Incorrect type of image sizes. " + f"Got type: {type(images_spatial_crop)}") + + return DeepseekVL2ImagePixelInputs( + type="pixel_values", + data=self._validate_pixel_values(flatten_bn(pixel_values)), + images_spatial_crop=self._validate_images_spatial_crop( + flatten_bn(images_spatial_crop, concat=True))) + + if image_embeds is not None: + if not isinstance(image_embeds, torch.Tensor): + raise ValueError("Incorrect type of image embeddings. 
" + f"Got type: {type(image_embeds)}") + + return DeepseekVL2VImageEmbeddingInputs( + type="image_embeds", + data=flatten_bn(image_embeds), + ) + + raise AssertionError("This line should be unreachable.") + + def _pixel_values_to_embedding( + self, + pixel_values: NestedTensors, + images_spatial_crop: torch.Tensor, + ) -> NestedTensors: + # Pixel_values: n_image * batch_size * [patch_per_img, 3, height, width] + total_tiles = [x for x in pixel_values] + + # [batch_all_tiles, 3, height, width] + total_tiles = torch.cat(total_tiles, dim=0) + + # [batch_all_tiles, vit_seq_len, c] + images_feature = self.vision.forward_features(total_tiles) + + # [batch_all_tiles, hw, D] + images_embeds = self.projector(images_feature) + + _, hw, n_dim = images_embeds.shape + h = w = int(hw**0.5) + + # 根据self.tile_tag & self.global_view_pos填充image token sequence + tile_index = 0 + vision_embeddings = [] + for jdx in range(images_spatial_crop.size(0)): + # extra global & local features + num_width_tiles, num_height_tiles = images_spatial_crop[jdx] + if num_width_tiles == 0 or num_height_tiles == 0: + break + num_tiles_in_image = num_width_tiles * num_height_tiles + + # [hw, D] + global_features = images_embeds[tile_index] + + # [num_height_tiles * num_width_tiles, hw, D] + local_features = images_embeds[tile_index + 1:tile_index + 1 + + num_tiles_in_image] + tile_index += num_tiles_in_image + 1 + + # format global and local features + # ----------------- global view add newline ----------------- + # [hw, D] -> [h, w, D] + global_features = global_features.view(h, w, n_dim) + + # [D] -> [h, 1, D] + new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h) + + # cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D] + global_features = torch.cat([global_features, new_lines_in_global], + dim=1) + + # [h, w + 1, D] -> [h * (w + 1), D] + global_features = global_features.view(-1, n_dim) + + # ----------------- local view add newline ----------------- + # [num_height_tiles * num_width_tiles, h * w, D] -> + # [num_height_tiles * h, num_width_tiles * w, D] + local_features = rearrange(local_features, + "(th tw) (h w) d -> (th h) (tw w) d", + th=num_height_tiles, + tw=num_width_tiles, + h=h, + w=w) + + # [D] -> [num_height_tiles * h, 1, D] + new_lines_in_local = repeat(self.image_newline, + "d -> (th h) 1 d", + th=num_height_tiles, + h=h) + + # [num_height_tiles * h, num_width_tiles * w + 1, D] + local_features = torch.cat([local_features, new_lines_in_local], + dim=1) + + # [num_height_tiles * h, num_width_tiles * w + 1, D] + # --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D] + local_features = local_features.view(-1, n_dim) + + # merge global and local tiles + if self.global_view_pos == "head": + global_local_features = torch.cat([ + global_features, + self.view_seperator[None, :], + local_features, + ]) + else: + global_local_features = torch.cat([ + local_features, + self.view_seperator[None, :], + global_features, + ]) + + vision_embeddings.append(global_local_features) + return vision_embeddings + + def _process_image_input( + self, image_input: DeepseekVL2ImageInputs) -> torch.Tensor: + if image_input["type"] == "image_embeds": + image_data = image_input["data"] + if is_list_of(image_data, torch.Tensor): + # it's already a list of tensors + return image_data + if len(image_data.shape) == 3: + # 3D tensor + return list(torch.unbind(image_data, dim=0)) + raise ValueError( + "We expect batched 2D tensors;" + "this can be either a list of 2D tensors or a single 3D tensor." 
+ ) + + pixel_values = image_input["data"] + images_spatial_crop = image_input["images_spatial_crop"] + + return self._pixel_values_to_embedding( + pixel_values=pixel_values, images_spatial_crop=images_spatial_crop) + + def get_multimodal_embeddings(self, **kwargs: object) -> torch.Tensor: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.image_token_id) + return inputs_embeds + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object): + + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + hidden_states = self.language_model(input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors, + inputs_embeds=inputs_embeds) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + + loader = AutoWeightsLoader(self) + autoloaded_weights = loader.load_weights(weights, + mapper=self.hf_to_vllm_mapper) + return autoloaded_weights diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 8f36437d47d9e..ff7dab89e4da8 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -657,7 +657,7 @@ def init_vision_module( quant_config: Optional[QuantizationConfig], prefix: str = "", ) -> nn.Module: - # TODO: refactor this vision model + # TODO: refactor vision model through timm wrapper from transformers try: import timm except ImportError: diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 62840b8c1bcda..a7286a9203f67 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -149,6 +149,7 @@ "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501 "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"), + "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"), "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"), "InternVLChatModel": ("internvl", "InternVLChatModel"), diff --git 
a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 58417980e7b47..c97acffa1a719 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -23,8 +23,9 @@ # yapf conflicts with isort for this block # yapf: disable from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config, - DbrxConfig, EAGLEConfig, - ExaoneConfig, H2OVLChatConfig, + DbrxConfig, DeepseekVLV2Config, + EAGLEConfig, ExaoneConfig, + H2OVLChatConfig, InternVLChatConfig, JAISConfig, MedusaConfig, MllamaConfig, MLPSpeculatorConfig, MPTConfig, @@ -54,6 +55,7 @@ "chatglm": ChatGLMConfig, "cohere2": Cohere2Config, "dbrx": DbrxConfig, + "deepseek_vl_v2": DeepseekVLV2Config, "mpt": MPTConfig, "RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct) "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct) diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index a41a35c88b3a1..f065c56124605 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -1,6 +1,7 @@ from vllm.transformers_utils.configs.chatglm import ChatGLMConfig from vllm.transformers_utils.configs.cohere2 import Cohere2Config from vllm.transformers_utils.configs.dbrx import DbrxConfig +from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekVLV2Config from vllm.transformers_utils.configs.eagle import EAGLEConfig from vllm.transformers_utils.configs.exaone import ExaoneConfig # RWConfig is for the original tiiuae/falcon-40b(-instruct) and @@ -25,6 +26,7 @@ "ChatGLMConfig", "Cohere2Config", "DbrxConfig", + "DeepseekVLV2Config", "MPTConfig", "RWConfig", "H2OVLChatConfig", diff --git a/vllm/transformers_utils/configs/deepseek_vl2.py b/vllm/transformers_utils/configs/deepseek_vl2.py new file mode 100644 index 0000000000000..681528c3c0116 --- /dev/null +++ b/vllm/transformers_utils/configs/deepseek_vl2.py @@ -0,0 +1,214 @@ +# adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/faf18023f24b962b32d9f0a2d89e402a8d383a78/deepseek_vl2/models/modeling_deepseek_vl_v2.py#L115-L268 +from typing import Tuple + +from transformers.configuration_utils import PretrainedConfig + + +class VisionEncoderConfig(PretrainedConfig): + model_type: str = "vision" + + model_name: str = "vit_so400m_patch14_siglip_384.webli" + image_size: int = 384 + patch_size: int = 16 + width: int = 1024 + layers: int = 24 + heads: int = 16 + mlp_ratio: int = 4 + global_pool: str = "map" + ignore_head: bool = True + class_token: bool = False + num_classes: int = 0 + use_checkpoint: bool = False + weight_init: str = "skip" + deterministic: bool = False + num_recomputing_layers: int = 0 + + def __init__(self, + model_name: str = "vit_so400m_patch14_siglip_384.webli", + image_size: int = 384, + patch_size: int = 16, + width: int = 1024, + layers: int = 24, + heads: int = 16, + mlp_ratio: int = 4, + global_pool: str = "map", + ignore_head: bool = True, + class_token: bool = False, + num_classes: int = 0, + use_checkpoint: bool = False, + **kwargs): + self.model_name = model_name + self.image_size = image_size + self.patch_size = patch_size + self.width = width + self.layers = layers + self.heads = heads + self.mlp_ratio = mlp_ratio + self.global_pool = global_pool + self.ignore_head = ignore_head + self.class_token = class_token + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + + super().__init__(**kwargs) + + +class MlpProjectorConfig(PretrainedConfig): + model_type = "mlp_projector" + 
projector_type: str = "downsample_mlp_gelu" + input_dim: int = 1152 + n_embed: int = 2048 + depth: int = 2 + mlp_ratio: int = 1 + downsample_ratio: int = 2 + token_pooling: bool = False + + def __init__(self, + projector_type: str = "downsample_mlp_gelu", + input_dim: int = 1152, + n_embed: int = 2048, + depth: int = 2, + mlp_ratio: int = 1, + downsample_ratio: int = 2, + **kwargs): + self.projector_type = projector_type + self.input_dim = input_dim + self.n_embed = n_embed + self.depth = depth + self.mlp_ratio = mlp_ratio + self.downsample_ratio = downsample_ratio + + super().__init__(**kwargs) + + +class DeepseekV2Config(PretrainedConfig): + + model_type = "deepseek_v2" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=102400, + hidden_size=4096, + intermediate_size=11008, + moe_intermediate_size=1407, + num_hidden_layers=30, + num_attention_heads=32, + num_key_value_heads=32, + n_shared_experts=None, + n_routed_experts=None, + ep_size=1, + routed_scaling_factor=1.0, + kv_lora_rank=512, + q_lora_rank=1536, + qk_rope_head_dim=64, + v_head_dim=128, + qk_nope_head_dim=128, + topk_method='gready', + n_group=None, + topk_group=None, + num_experts_per_tok=None, + moe_layer_freq=1, + first_k_dense_replace=0, + norm_topk_prob=False, + scoring_func='softmax', + aux_loss_alpha=0.001, + seq_aux=True, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=100000, + eos_token_id=100001, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + use_mla=True, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.ep_size = ep_size + self.routed_scaling_factor = routed_scaling_factor + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.topk_method = topk_method + self.n_group = n_group + self.topk_group = topk_group + self.num_experts_per_tok = num_experts_per_tok + self.moe_layer_freq = moe_layer_freq + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + self.scoring_func = scoring_func + self.aux_loss_alpha = aux_loss_alpha + self.seq_aux = seq_aux + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = float(rms_norm_eps) + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.use_mla = use_mla + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class DeepseekVLV2Config(PretrainedConfig): + model_type = "deepseek_vl_v2" + vision_config: VisionEncoderConfig + 
projector_config: MlpProjectorConfig + + tile_tag: str = "2D" + global_view_pos: str = "head" + candidate_resolutions: Tuple[Tuple[int, int]] = ((384, 384), ) + + def __init__(self, + tile_tag: str = "tile_tag", + global_view_pos: str = "head", + candidate_resolutions: Tuple[Tuple[int, + int]] = ((384, 384), ), + **kwargs): + super().__init__(**kwargs) + + vision_config = kwargs.get("vision_config", {}) + self.vision_config = VisionEncoderConfig(**vision_config) + + projector_config = kwargs.get("projector_config", {}) + self.projector_config = MlpProjectorConfig(**projector_config) + + language_config = kwargs.get("language_config", {}) + self.text_config = DeepseekV2Config(**language_config) + + self.tile_tag = tile_tag + self.global_view_pos = global_view_pos + self.candidate_resolutions = candidate_resolutions + self.vocab_size = self.text_config.vocab_size
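
For reference, a minimal end-to-end sketch of running DeepSeek-VL2 offline with vLLM, condensed from the `run_deepseek_vl2` example added in this patch. The image path, question text, and sampling settings are illustrative placeholders, not part of the PR.

```python
from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/deepseek-vl2-small",
    max_model_len=4096,
    max_num_seqs=2,
    # As noted in the docs change above, the architecture override is required.
    hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]},
)

question = "What is shown in this image?"
# DeepSeek-VL2 chat format with one <image> placeholder per image.
prompt = f"<|User|>: <image>\n{question}\n\n<|Assistant|>:"
image = Image.open("example.jpg").convert("RGB")  # placeholder path

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```

The multi-image path in `load_deepseek_vl2` follows the same pattern, joining one `image_{i}:<image>\n` placeholder per image and setting `limit_mm_per_prompt={"image": len(image_urls)}`.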
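A worked check of the token accounting in `DeepseekVL2ProcessingInfo.get_num_image_tokens`. It assumes the HF processor reports `image_size=384` and `patch_size=14` (matching the `vit_so400m_patch14_siglip_384` backbone) and `downsample_ratio=2` from `MlpProjectorConfig`, and that the best-fit resolution is an exact multiple of the tile size; the real code delegates that choice to `select_best_resolution`.

```python
import math


def num_image_tokens(image_width: int, image_height: int,
                     image_size: int = 384, patch_size: int = 14,
                     downsample_ratio: int = 2) -> int:
    """Simplified mirror of get_num_image_tokens for exact tile multiples."""
    num_width_tiles = image_width // image_size
    num_height_tiles = image_height // image_size
    # Patches per tile side after the 2x downsample: ceil(27 / 2) = 14.
    h = w = math.ceil((image_size // patch_size) / downsample_ratio)

    global_views_tokens = h * (w + 1)  # 14 * 15 = 210, incl. newline tokens
    local_views_tokens = (num_height_tiles * h) * (num_width_tiles * w + 1)
    return global_views_tokens + local_views_tokens + 1  # +1 view separator


# One 384x384 image: 210 + 14 * 15 + 1 = 421 tokens.
print(num_image_tokens(384, 384))
# A 768x384 best-fit resolution (2x1 tile grid): 210 + 14 * 29 + 1 = 617.
print(num_image_tokens(768, 384))
```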
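A small shape check for the "4 to 1 concat" step in `MlpProjector.forward` with `projector_type="downsample_mlp_gelu"`. It assumes the ViT backbone pads a 384-px tile to a 28 x 28 patch grid (`dynamic_img_pad=True` on the so400m/14 model) and only demonstrates the `F.unfold` reshaping, not the full MLP.

```python
import torch
import torch.nn.functional as F

# One tile's worth of ViT features: 28 x 28 = 784 patch tokens of width 1152.
bs, h, w, dim, ratio = 1, 28, 28, 1152, 2
x = torch.randn(bs, h * w, dim)

# Same steps as MlpProjector.forward (no padding needed since 28 % 2 == 0):
x = x.reshape(bs, h, w, dim)                       # [1, 28, 28, 1152]
x = x.permute(0, 3, 1, 2)                          # [1, 1152, 28, 28]
x = F.unfold(x, kernel_size=ratio, stride=ratio)   # [1, 1152 * 4, 196]
x = x.permute(0, 2, 1)                             # [1, 196, 4608]

# The MLP then maps 4608 -> n_embed, so each tile contributes 196 visual
# tokens, i.e. the 14 x 14 grid used by the token-count sketch above.
print(x.shape)  # torch.Size([1, 196, 4608])
```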