From fd279b8d29d8f4ada8396f9fa8ef4b9250306c84 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 14 Feb 2025 20:20:46 +0800 Subject: [PATCH] [VLM] Keep track of whether prompt replacements have been applied (#13215) --- vllm/model_executor/models/glm4v.py | 8 + vllm/model_executor/models/llava.py | 3 +- vllm/model_executor/models/llava_onevision.py | 57 ++++- vllm/model_executor/models/minicpmo.py | 90 +++---- vllm/model_executor/models/minicpmv.py | 221 ++++++++---------- vllm/model_executor/models/qwen2_audio.py | 10 - vllm/model_executor/models/qwen2_vl.py | 100 +++----- vllm/model_executor/models/qwen_vl.py | 13 +- vllm/multimodal/parse.py | 58 ++++- vllm/multimodal/processing.py | 142 ++++++----- 10 files changed, 373 insertions(+), 329 deletions(-) diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py index 67f19841f4aa7..450421302a190 100644 --- a/vllm/model_executor/models/glm4v.py +++ b/vllm/model_executor/models/glm4v.py @@ -484,6 +484,14 @@ def get_dummy_processor_inputs( class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]): + def _hf_processor_applies_repl( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> bool: + return False + def _get_mm_fields_config( self, hf_inputs: BatchFeature, diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index b1fee3eeb542f..dcd90474e9364 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -294,7 +294,7 @@ def _call_hf_processor( pixel_values = processed_outputs.get("pixel_values") if pixel_values is not None: # Before/after https://github.com/huggingface/transformers/pull/35122 - if Version(TRANSFORMERS_VERSION) <= Version("4.48.2"): + if Version(TRANSFORMERS_VERSION) <= Version("4.48.3"): images = mm_data["images"] assert isinstance(images, list) @@ -819,7 +819,6 @@ def get_replacement_mantis(item_idx: int): prompt_ids, mm_item_counts, ) - self._validate_mm_placeholders(mm_placeholders, mm_item_counts) mm_placeholder_ranges = { diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index 2889426283f84..084d4d51ad236 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -299,36 +299,69 @@ def _call_hf_processor( mm_kwargs=mm_kwargs, ) + # LLaVA-OneVision processor doesn't support multiple videos + # with different sizes when converting back to tensors + # So, we process each component separately + # NOTE: No prompt replacement is applied in this case processor = self.info.get_hf_processor() + image_token = processor.image_token video_token = processor.video_token - # LLaVA-OneVision processor doesn't support multiple videos - # with different sizes when converting back to tensors - text_image_outputs = super()._call_hf_processor( + text_outputs = super()._call_hf_processor( prompt=prompt, - mm_data=mm_data, + mm_data={}, mm_kwargs=mm_kwargs, ) + images = mm_data.pop("images", []) + assert isinstance(images, list) + if images: + processor_outputs = super()._call_hf_processor( + prompt=image_token * len(images), + mm_data={"images": images}, + mm_kwargs=mm_kwargs, + ) + image_outputs = { + k: v + for k, v in processor_outputs.items() + if k in ("pixel_values", "image_sizes") + } + else: + image_outputs = {} + pixel_values_videos = [] for video in videos: - item_processor_data = dict(prompt=video_token, videos=video) - item_outputs 
= super()._call_hf_processor( - prompt=prompt, - mm_data=item_processor_data, + prompt=video_token, + mm_data={"videos": video}, mm_kwargs=mm_kwargs, ) - pixel_values_videos.append( - item_outputs.pop("pixel_values_videos")[0]) + pixel_values_videos.append(item_outputs["pixel_values_videos"][0]) + + video_outputs = {"pixel_values_videos": pixel_values_videos} combined_outputs = dict( - **text_image_outputs, - pixel_values_videos=pixel_values_videos, + text_outputs, + **image_outputs, + **video_outputs, ) return BatchFeature(combined_outputs) + def _hf_processor_applies_repl( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> bool: + base_result = super()._hf_processor_applies_repl( + prompt_text=prompt_text, + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) + + return base_result and mm_items.get_count("video", strict=False) == 0 + def _get_prompt_replacements( self, mm_items: MultiModalDataItems, diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py index ab697fb8cc645..473881f955465 100644 --- a/vllm/model_executor/models/minicpmo.py +++ b/vllm/model_executor/models/minicpmo.py @@ -27,8 +27,8 @@ Tuple, TypedDict, Union) import torch -import torch.types from torch import nn +from transformers import BatchFeature from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.models.whisper.modeling_whisper import ( ACT2FN, WHISPER_ATTENTION_CLASSES, WhisperConfig, WhisperEncoder) @@ -37,23 +37,21 @@ from vllm.config import VllmConfig from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.inputs import MultiModalFieldConfig -from vllm.multimodal.parse import (ModalityData, ModalityDataItems, - MultiModalDataItems, MultiModalDataParser, - VideoItem) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - PromptReplacement) +from vllm.multimodal.parse import (AudioItem, DictEmbeddingItems, ModalityData, + ModalityDataItems, MultiModalDataItems, + MultiModalDataParser) +from vllm.multimodal.processing import PromptReplacement from vllm.multimodal.profiling import ProcessorInputs from vllm.sequence import IntermediateTensors from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder, - MiniCPMVEmbeddingItems, MiniCPMVMultiModalDataParser, - MiniCPMVMultiModalProcessor, MiniCPMVProcessingInfo) + MiniCPMVMultiModalDataParser, + MiniCPMVMultiModalProcessor, MiniCPMVProcessingInfo, + _minicpmv_field_config) from .utils import AutoWeightsLoader, maybe_prefix CPU_DEVICE = torch.device("cpu") -MiniCPMOEmbeddingItems = MiniCPMVEmbeddingItems - class MiniCPMOAudioFeatureInputs(TypedDict): type: Literal["audio_features"] @@ -103,28 +101,49 @@ class MiniCPMOAudioEmbeddingInputs(TypedDict): MiniCPMOAudioEmbeddingInputs] -class MiniCPMOAudioEmbeddingItems(MiniCPMOEmbeddingItems): +def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]): + audio_num_slices = hf_inputs.get("audio_num_slices", torch.empty(0)) + + return dict( + **_minicpmv_field_config(hf_inputs), + audio_features=MultiModalFieldConfig.flat_from_sizes( + "audio", audio_num_slices), + audio_feature_lens=MultiModalFieldConfig.flat_from_sizes( + "audio", audio_num_slices), + audio_num_slices=MultiModalFieldConfig.batched("audio"), + audio_orders_in_mm_data=MultiModalFieldConfig.batched("audio"), + audio_embeds=MultiModalFieldConfig.flat_from_sizes( + "audio", audio_num_slices), + ) + - def __init__(self, data: Dict) -> None: - 
super().__init__(data, "audio")
-        audio_embeds = self.data.get("audio_embeds", None)
-        if audio_embeds is None:
-            raise ValueError("Incorrect type of video_embeds",
-                             "Got type: None")
-        self.data["audio_embeds"] = audio_embeds
+class MiniCPMOAudioEmbeddingItems(DictEmbeddingItems):
 
-    def get(self, index: int) -> object:
-        return self.data["audio_embeds"][index]
+    def __init__(
+        self,
+        data: Mapping[str, torch.Tensor],
+        fields_config: Mapping[str, MultiModalFieldConfig],
+    ) -> None:
+        super().__init__(
+            data,
+            modality="audio",
+            fields_config=fields_config,
+            required_fields={"audio_embeds"},
+        )
 
 
 class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser):
 
     def _parse_audio_data(
         self,
-        data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
+        data: Union[dict[str, torch.Tensor], ModalityData[AudioItem]],
     ) -> ModalityDataItems[Any, Any]:
         if isinstance(data, dict):
-            return MiniCPMOAudioEmbeddingItems(data)
+            return MiniCPMOAudioEmbeddingItems(
+                data,
+                fields_config=_minicpmo_field_config(data),
+            )
+
         return super()._parse_audio_data(data)
 
 
@@ -167,6 +186,10 @@ def get_max_audio_tokens_per_chunk(self) -> int:
     def get_max_audio_chunks_with_most_features(self) -> int:
         return 30
 
+    def get_max_audio_tokens(self) -> int:
+        return self.get_max_audio_tokens_per_chunk(
+        ) * self.get_max_audio_chunks_with_most_features()
+
     def get_audio_len_by_num_chunks(self, num_chunks: int) -> int:
         sampling_rate = self.get_default_audio_sampling_rate()
         # exclude
@@ -194,7 +217,8 @@ def get_num_frames_with_most_features(self, seq_len: int) -> int:
 
         return num_frames
 
-class MiniCPMODummyInputsBuilder(MiniCPMVDummyInputsBuilder):
+class MiniCPMODummyInputsBuilder(
+        MiniCPMVDummyInputsBuilder[MiniCPMOProcessingInfo]):
 
     def get_dummy_processor_inputs(
             self, seq_len: int, mm_counts: Mapping[str,
@@ -222,8 +246,7 @@ def get_dummy_processor_inputs(
                               mm_data=mm_data)
 
 
 class MiniCPMOMultiModalProcessor(
-        MiniCPMVMultiModalProcessor,
-        BaseMultiModalProcessor[MiniCPMOProcessingInfo]):
+        MiniCPMVMultiModalProcessor[MiniCPMOProcessingInfo]):
 
     def _get_data_parser(self) -> MultiModalDataParser:
         return MiniCPMOMultiModalDataParser(
@@ -369,21 +392,10 @@ def get_replacement_minicpmv(item_idx: int, modality: str):
 
     def _get_mm_fields_config(
         self,
-        hf_inputs,
+        hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        audio_num_slices = hf_inputs.get("audio_num_slices", torch.empty(0))
-
-        return dict(
-            **super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs),
-            audio_features=MultiModalFieldConfig.flat_from_sizes(
-                "audio", audio_num_slices),
-            audio_feature_lens=MultiModalFieldConfig.flat_from_sizes(
-                "audio", audio_num_slices),
-            audio_num_slices=MultiModalFieldConfig.batched("audio"),
-            audio_orders_in_mm_data=MultiModalFieldConfig.batched("audio"),
-            audio_embeds=MultiModalFieldConfig.flat_from_sizes(
-                "audio", audio_num_slices))
+        return _minicpmo_field_config(hf_inputs)
 
 
 class MultiModalProjector(nn.Module):
@@ -406,7 +418,7 @@ def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
 
 class MiniCPMWhisperEncoderLayer(nn.Module):
 
-    def __init__(self, config: WhisperConfig, layer_idx: int = None):
+    def __init__(self, config: WhisperConfig, layer_idx: int):
         super().__init__()
         self.embed_dim = config.d_model
         self.self_attn = WHISPER_ATTENTION_CLASSES[
diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
index 58a4448d436aa..77ac9eb467be6 100644
--- a/vllm/model_executor/models/minicpmv.py
+++ 
b/vllm/model_executor/models/minicpmv.py @@ -35,6 +35,7 @@ from PIL import Image from torch import nn from transformers import BatchFeature, PretrainedConfig +from typing_extensions import TypeVar from vllm.attention import AttentionMetadata from vllm.config import VllmConfig @@ -51,9 +52,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalInputs, PlaceholderRange) -from vllm.multimodal.parse import (ImageItem, ImageSize, ModalityData, - ModalityDataItems, MultiModalDataItems, - MultiModalDataParser, VideoItem) +from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem, ImageSize, + ModalityData, ModalityDataItems, + MultiModalDataItems, MultiModalDataParser, + VideoItem) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs @@ -115,93 +117,6 @@ class MiniCPMVImageEmbeddingInputs(TypedDict): MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs, MiniCPMVImageEmbeddingInputs] - -class MiniCPMVEmbeddingItems(ModalityDataItems[dict[str, torch.Tensor], - dict[str, torch.Tensor]]): - - def __init__(self, data: Dict, modality: str) -> None: - super().__init__(data, modality) - - def get_processor_data(self) -> Mapping[str, object]: - return self.data - - def get_passthrough_data(self) -> Mapping[str, object]: - return {} - - def get_count(self) -> int: - return len(self.data[f"{self.modality}_embeds"]) - - def get(self, index: int) -> Dict[str, torch.Tensor]: - out = {} - for k, v in self.data.items(): - out[k] = v[index] - return out - - -class MiniCPMVImageEmbeddingItems(MiniCPMVEmbeddingItems): - - def __init__(self, data: Dict) -> None: - super().__init__(data, "image") - image_embeds = self.data.get("image_embeds", None) - image_sizes = self.data.get("image_sizes", None) - if image_embeds is None: - raise ValueError("In correct type of image_embeds", - "Got type: None") - if not isinstance(image_embeds[0], torch.Tensor): - raise ValueError("In correct type of image_embeds", - f"Got type: {type(image_embeds[0])}") - if image_sizes is None: - raise ValueError( - "In correct type of image_sizes", "Got type: None." - "If you're using `image_size_list`, " - "please rename it to `image_sizes`") - if len(image_embeds[0].shape) == 2: - image_embeds = [image_embeds] - image_sizes = [image_sizes] - self.data["image_embeds"] = image_embeds - self.data["image_sizes"] = image_sizes - - def get_image_size(self, index: int) -> ImageSize: - image_size = self.data["image_sizes"][index] - return ImageSize(width=image_size[0], height=image_size[1]) - - -class MiniCPMVVideoEmbeddingItems(MiniCPMVEmbeddingItems): - - def __init__(self, data: Dict) -> None: - super().__init__(data, "video") - video_embeds = self.data.get("video_embeds", None) - image_sizes = self.data.get("image_sizes", None) - num_frames = self.data.get("num_frames", None) - if video_embeds is None: - raise ValueError("In correct type of video_embeds", - "Got type: None") - if not isinstance(video_embeds[0], torch.Tensor): - raise ValueError("In correct type of video_embeds", - f"Got type: {type(video_embeds[0])}") - if image_sizes is None: - raise ValueError( - "In correct type of image_sizes", "Got type: None." 
- "If you're using `image_size_list`, " - "please rename it to `image_sizes`") - if num_frames is None: - raise ValueError("In correct type of numframes", "Got type: None") - if len(video_embeds[0].shape) == 2: - video_embeds = [video_embeds] - image_sizes = [image_sizes] - num_frames = [num_frames] - self.data["video_embeds"] = video_embeds - self.data["image_sizes"] = image_sizes - self.data["num_frames"] = num_frames - - def get_frame_size(self, index: int) -> ImageSize: - frame_size = self.data["image_sizes"][index] - return ImageSize(width=frame_size[0], height=frame_size[1]) - - def get_num_frames(self, index: int) -> int: - return self.data["num_frames"][index] - - DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6) @@ -311,6 +226,71 @@ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]: return tuple(int(x) for x in version_str.split(".")) +def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]): + image_num_slices = hf_inputs.get("image_num_slices", torch.empty(0)) + video_num_slices = hf_inputs.get("video_num_slices", torch.empty(0)) + + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", image_num_slices), + image_sizes=MultiModalFieldConfig.batched("image"), + tgt_sizes=MultiModalFieldConfig.flat_from_sizes( + "image", image_num_slices), + image_num_slices=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.flat_from_sizes( + "image", image_num_slices), + video_pixel_values=MultiModalFieldConfig.flat_from_sizes( + "video", video_num_slices), + video_image_sizes=MultiModalFieldConfig.batched("video"), + video_tgt_sizes=MultiModalFieldConfig.flat_from_sizes( + "video", video_num_slices), + video_embeds=MultiModalFieldConfig.flat_from_sizes( + "video", video_num_slices), + video_num_slices=MultiModalFieldConfig.batched("video"), + ) + + +class MiniCPMVImageEmbeddingItems(DictEmbeddingItems): + + def __init__( + self, + data: Mapping[str, torch.Tensor], + fields_config: Mapping[str, MultiModalFieldConfig], + ) -> None: + super().__init__( + data, + modality="image", + fields_config=fields_config, + required_fields={"image_embeds", "image_sizes"}, + ) + + def get_image_size(self, index: int) -> ImageSize: + image_size = self.get(index)["image_sizes"].tolist() + return ImageSize(width=image_size[0], height=image_size[1]) + + +class MiniCPMVVideoEmbeddingItems(DictEmbeddingItems): + + def __init__( + self, + data: Mapping[str, torch.Tensor], + fields_config: Mapping[str, MultiModalFieldConfig], + ) -> None: + super().__init__( + data, + modality="video", + fields_config=fields_config, + required_fields={"video_embeds", "video_image_sizes"}, + ) + + def get_frame_size(self, index: int) -> ImageSize: + frame_size = self.get(index)["video_image_sizes"].tolist() + return ImageSize(width=frame_size[0], height=frame_size[1]) + + def get_num_frames(self, index: int) -> int: + return len(self.get(index)["video_image_sizes"]) + + class MiniCPMVMultiModalDataParser(MultiModalDataParser): def _parse_image_data( @@ -318,7 +298,11 @@ def _parse_image_data( data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], ) -> ModalityDataItems[Any, Any]: if isinstance(data, dict): - return MiniCPMVImageEmbeddingItems(data) + return MiniCPMVImageEmbeddingItems( + data, + fields_config=_minicpmv_field_config(data), + ) + return super()._parse_image_data(data) def _parse_video_data( @@ -326,7 +310,11 @@ def _parse_video_data( data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]], ) -> ModalityDataItems[Any, Any]: if 
isinstance(data, dict): - return MiniCPMVVideoEmbeddingItems(data) + return MiniCPMVVideoEmbeddingItems( + data, + fields_config=_minicpmv_field_config(data), + ) + return super()._parse_video_data(data) @@ -392,10 +380,6 @@ def get_max_video_tokens(self, seq_len: int) -> int: return self.get_max_video_frame_tokens( ) * self.get_num_frames_with_most_features(seq_len) - def get_max_audio_tokens(self) -> int: - return self.get_max_audio_tokens_per_chunk( - ) * self.get_max_audio_chunks_with_most_features() - def get_slice_query_num(self) -> int: hf_config = self.get_hf_config() query_num = getattr(hf_config, "query_num", 64) @@ -476,8 +460,12 @@ def get_default_image_sizes(self, num_slices: int) -> ImageSize: return ImageSize(width=image_size, height=image_size * num_slices) -class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[MiniCPMVProcessingInfo] - ): +_I = TypeVar("_I", + bound=MiniCPMVProcessingInfo, + default=MiniCPMVProcessingInfo) + + +class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]): def get_dummy_processor_inputs( self, @@ -514,8 +502,7 @@ def get_dummy_processor_inputs( mm_data=mm_data) -class MiniCPMVMultiModalProcessor( - BaseMultiModalProcessor[MiniCPMVProcessingInfo]): +class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): def _get_data_parser(self) -> MultiModalDataParser: return MiniCPMVMultiModalDataParser() @@ -675,7 +662,7 @@ def get_num_slices_by_modality(self, inputs: Dict[str, object], self.info.get_video_max_slice_num() ) * inputs[modality]["num_frames"][index] else: - raise ValueError(f"UnExpected modality: {modality}") + raise ValueError(f"Unexpected modality: {modality}") def check_mm_inputs(self, inputs: Dict[str, object], matches: List[str]) -> None: @@ -700,7 +687,7 @@ def get_prompt_texts_by_modality(self, inputs: Dict[str, object], inputs["video"]["video_image_sizes"][index], inputs["video"]["num_frames"][index]) else: - raise ValueError(f"UnExpected modality: {modality}") + raise ValueError(f"Unexpected modality: {modality}") def call_base_hf_processor( self, @@ -742,6 +729,14 @@ def _call_hf_processor( } } + def _hf_processor_applies_repl( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> bool: + return False + def _get_prompt_replacements( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], @@ -770,28 +765,10 @@ def get_replacement_minicpmv(item_idx: int, modality: str): def _get_mm_fields_config( self, - hf_inputs, + hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - image_num_slices = hf_inputs.get("image_num_slices", torch.empty(0)) - video_num_slices = hf_inputs.get("video_num_slices", torch.empty(0)) - - return dict(pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", image_num_slices), - image_sizes=MultiModalFieldConfig.batched("image"), - tgt_sizes=MultiModalFieldConfig.flat_from_sizes( - "image", image_num_slices), - image_num_slices=MultiModalFieldConfig.batched("image"), - image_embeds=MultiModalFieldConfig.flat_from_sizes( - "image", image_num_slices), - video_pixel_values=MultiModalFieldConfig.flat_from_sizes( - "video", video_num_slices), - video_image_sizes=MultiModalFieldConfig.batched("video"), - video_tgt_sizes=MultiModalFieldConfig.flat_from_sizes( - "video", video_num_slices), - video_embeds=MultiModalFieldConfig.flat_from_sizes( - "video", video_num_slices), - video_num_slices=MultiModalFieldConfig.batched("video")) + return 
_minicpmv_field_config(hf_inputs) def apply( self, diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index f09529ca4bd1f..cf79544e60e87 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -243,16 +243,6 @@ def get_replacement_qwen2_audio(item_idx: int): ) ] - def _always_apply_prompt_replacements(self) -> bool: - # Qwen2-Audio processor will start inserting placeholder tokens - # in an upcoming release: - # https://github.com/huggingface/transformers/pull/35534 - # NOTE: `_find_placeholders_by_modality` may incorrectly think that HF - # has already performed processing for multi-audio input when the input - # audios are short (the corresponding placeholders may take up fewer - # tokens than the number of audio items) - return not hasattr(self.info.get_hf_processor(), "audio_token") - @MULTIMODAL_REGISTRY.register_processor( Qwen2AudioMultiModalProcessor, diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 961f53cef1379..ce927fbbf1232 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -58,8 +58,9 @@ from vllm.multimodal.inputs import (ImageItem, ModalityData, MultiModalFieldConfig, MultiModalKwargs, VideoItem) -from vllm.multimodal.parse import (ImageSize, ModalityDataItems, - MultiModalDataItems, MultiModalDataParser) +from vllm.multimodal.parse import (DictEmbeddingItems, ImageSize, + ModalityDataItems, MultiModalDataItems, + MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs @@ -657,49 +658,25 @@ def load_weights(self, weights: Iterable[Tuple[str, return loaded_params -class Qwen2VLEmbeddingItems(ModalityDataItems[dict[str, torch.Tensor], - dict[str, torch.Tensor]]): - - def __init__(self, data: dict, modality: str) -> None: - super().__init__(data, modality) - - grid_thw = data[f"{modality}_grid_thw"] - slice_idxs = [0] + grid_thw.prod(-1).cumsum_(0).tolist() - self._slices = [ - slice(slice_idxs[i], slice_idxs[i + 1]) - for i in range(len(grid_thw)) - ] - - def get_count(self) -> int: - return len(self.data[f"{self.modality}_grid_thw"]) - - def get(self, index: int) -> dict[str, torch.Tensor]: - out = {} - for k, v in self.data.items(): - if v != f"{self.modality}_grid_thw": - v = v[self._slices[index]] - - out[k] = v - - return out - - def get_processor_data(self) -> Mapping[str, object]: - return {} - - def get_passthrough_data(self) -> Mapping[str, object]: - return self.data - - -class Qwen2VLImageEmbeddingItems(Qwen2VLEmbeddingItems): - - def __init__(self, data: dict) -> None: - super().__init__(data, "image") - - -class Qwen2VLVideoEmbeddingItems(Qwen2VLEmbeddingItems): - - def __init__(self, data: dict) -> None: - super().__init__(data, "video") +def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]): + image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) + image_grid_sizes = image_grid_thw.prod(-1) + + video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) + video_grid_sizes = video_grid_thw.prod(-1) + + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", image_grid_sizes), + image_embeds=MultiModalFieldConfig.flat_from_sizes( + "image", image_grid_sizes), + image_grid_thw=MultiModalFieldConfig.batched("image"), + 
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( + "video", video_grid_sizes), + video_embeds=MultiModalFieldConfig.flat_from_sizes( + "video", video_grid_sizes), + video_grid_thw=MultiModalFieldConfig.batched("video"), + ) class Qwen2VLMultiModalDataParser(MultiModalDataParser): @@ -709,7 +686,12 @@ def _parse_image_data( data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], ) -> ModalityDataItems[Any, Any]: if isinstance(data, dict): - return Qwen2VLEmbeddingItems(data, modality="image") + return DictEmbeddingItems( + data, + modality="image", + fields_config=_qwen2vl_field_config(data), + required_fields={"image_embeds", "image_grid_thw"}, + ) return super()._parse_image_data(data) @@ -718,7 +700,12 @@ def _parse_video_data( data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]], ) -> ModalityDataItems[Any, Any]: if isinstance(data, dict): - return Qwen2VLEmbeddingItems(data, modality="video") + return DictEmbeddingItems( + data, + modality="video", + fields_config=_qwen2vl_field_config(data), + required_fields={"video_embeds", "video_grid_thw"}, + ) return super()._parse_video_data(data) @@ -999,24 +986,7 @@ def _get_mm_fields_config( hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) - image_grid_sizes = image_grid_thw.prod(-1) - - video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) - video_grid_sizes = video_grid_thw.prod(-1) - - return dict( - pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", image_grid_sizes), - image_embeds=MultiModalFieldConfig.flat_from_sizes( - "image", image_grid_sizes), - image_grid_thw=MultiModalFieldConfig.batched("image"), - pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( - "video", video_grid_sizes), - video_embeds=MultiModalFieldConfig.flat_from_sizes( - "video", video_grid_sizes), - video_grid_thw=MultiModalFieldConfig.batched("video"), - ) + return _qwen2vl_field_config(hf_inputs) @MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor, diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py index 5316eb7e002bc..0f4f5072fb2b4 100644 --- a/vllm/model_executor/models/qwen_vl.py +++ b/vllm/model_executor/models/qwen_vl.py @@ -520,10 +520,7 @@ def get_tokenizer(self) -> PreTrainedTokenizer: return _get_tokenizer_without_image_pad(tokenizer) def get_hf_processor(self) -> QwenVLProcessor: - tokenizer = self.ctx.tokenizer - assert isinstance(tokenizer, PreTrainedTokenizer) - - return QwenVLProcessor(self.get_hf_config(), tokenizer) + return QwenVLProcessor(self.get_hf_config(), self.get_tokenizer()) def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} @@ -605,6 +602,14 @@ def _call_hf_processor( mm_kwargs=mm_kwargs, ) + def _hf_processor_applies_repl( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> bool: + return False + def _get_mm_fields_config( self, hf_inputs: BatchFeature, diff --git a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py index 063f458b2c4d9..fb07c5c6a25d6 100644 --- a/vllm/multimodal/parse.py +++ b/vllm/multimodal/parse.py @@ -9,13 +9,15 @@ import numpy as np import torch from PIL.Image import Image +from transformers import BatchFeature from typing_extensions import TypeAlias, TypeGuard, assert_never from vllm.utils import is_list_of from .audio import resample_audio from .inputs import 
(AudioItem, HfAudioItem, HfImageItem, HfVideoItem, - ImageItem, ModalityData, MultiModalDataDict, VideoItem) + ImageItem, ModalityData, MultiModalDataDict, + MultiModalFieldConfig, MultiModalKwargs, VideoItem) _T = TypeVar("_T") _I = TypeVar("_I") @@ -111,6 +113,60 @@ def get_feature_size(self, item_idx: int) -> int: return len(self.get(item_idx)) +class DictEmbeddingItems(ModalityDataItems[Mapping[str, torch.Tensor], + Mapping[str, torch.Tensor]]): + """ + Base class for data items that are expressed as a dictionary of tensors. + + Usually, the dictionary keys correspond to the outputs of HF processor. + """ + + def __init__( + self, + data: Mapping[str, torch.Tensor], + modality: str, + fields_config: Mapping[str, MultiModalFieldConfig], + required_fields: set[str], + ) -> None: + super().__init__(data, modality) + + missing_required_fields = required_fields - fields_config.keys() + if missing_required_fields: + fields = set(fields_config.keys()) + msg = f"{required_fields=} should be a subset of {fields=}" + raise ValueError(msg) + + missing_required_data_keys = required_fields - data.keys() + if missing_required_data_keys: + data_keys = set(data.keys()) + msg = (f"The data should contain the fields: {required_fields}, " + f"but only found the following keys: {data_keys}") + raise ValueError(msg) + + self.fields_config = fields_config + self.required_fields = required_fields + + self._kwargs = MultiModalKwargs.from_hf_inputs( + BatchFeature(dict(data)), + fields_config, + ) + + def get_count(self) -> int: + return self._kwargs.get_item_count(self.modality) + + def get(self, index: int) -> Mapping[str, torch.Tensor]: + return { + k: v.data + for k, v in self._kwargs.get_item(self.modality, index).items() + } + + def get_processor_data(self) -> Mapping[str, object]: + return {} + + def get_passthrough_data(self) -> Mapping[str, object]: + return self.data + + class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]): def __init__(self, data: Sequence[HfAudioItem]) -> None: diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 74479f5ffad50..fcd02fbd5203c 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -23,7 +23,8 @@ from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs, MultiModalKwargsItem, PlaceholderRange) -from .parse import MultiModalDataItems, MultiModalDataParser +from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems, + MultiModalDataParser) if TYPE_CHECKING: from .profiling import BaseDummyInputsBuilder @@ -830,15 +831,34 @@ def _call_hf_processor( mm_kwargs, ) + def _hf_processor_applies_repl( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> bool: + """ + Return whether the HF processor applies prompt replacements. + + For most HF processors, this should be :code:`True` when multi-modal + data items are passed, but :code:`False` when multi-modal embeddings + are passed. + """ + return not any( + isinstance(items, (EmbeddingItems, DictEmbeddingItems)) + for items in mm_items.values()) + def _apply_hf_processor_text_mm( self, prompt_text: str, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - ) -> tuple[list[int], MultiModalKwargs]: + ) -> tuple[list[int], MultiModalKwargs, bool]: """ Apply the HF processor on the prompt text and multi-modal data together. + + In addition, return whether prompt replacements have been applied. 
""" processor_data, passthrough_data = self._get_hf_mm_data(mm_items) @@ -856,7 +876,13 @@ def _apply_hf_processor_text_mm( self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs), ) - return prompt_ids, mm_kwargs + is_repl_applied = self._hf_processor_applies_repl( + prompt_text=prompt_text, + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) + + return prompt_ids, mm_kwargs, is_repl_applied def _apply_hf_processor_text_only(self, prompt_text: str) -> list[int]: """ @@ -866,7 +892,7 @@ def _apply_hf_processor_text_only(self, prompt_text: str) -> list[int]: correspond to each other, we create dummy multi-modal items to go along with the text. """ - prompt_ids, _ = self._apply_hf_processor_text_mm( + prompt_ids, _, _ = self._apply_hf_processor_text_mm( prompt_text=prompt_text, mm_items=MultiModalDataItems({}), hf_processor_mm_kwargs={}, @@ -908,7 +934,7 @@ def _apply_hf_processor_mm_only( mm_counts, ) - _, mm_kwargs = self._apply_hf_processor_text_mm( + _, mm_kwargs, _ = self._apply_hf_processor_text_mm( prompt_text=dummy_inputs.prompt_text, mm_items=mm_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, @@ -923,13 +949,17 @@ def _apply_hf_processor_main( hf_processor_mm_kwargs: Mapping[str, object], *, enable_hf_prompt_replacement: bool, - ) -> tuple[list[int], MultiModalKwargs]: + ) -> tuple[list[int], MultiModalKwargs, bool]: """ Apply the HF processor on the prompt text and multi-modal data. + In addition, return whether prompt replacements have been applied + (for most HF processors, this should be :code:`True`). + Note: - If :code:`enable_hf_prompt_replacement=False`, the prompt should - correspond to the multi-modal items. + If :code:`enable_hf_prompt_replacement=False`, we use HF processor + to perform prompt replacement if available; HF processor requires + that the prompt corresponds to multi-modal items. """ if isinstance(prompt, str): if enable_hf_prompt_replacement: @@ -943,19 +973,19 @@ def _apply_hf_processor_main( else: prompt_ids = self._apply_hf_processor_tokens_only(prompt) - mm_missing_kwargs = self._apply_hf_processor_mm_only( + mm_kwargs = self._apply_hf_processor_mm_only( mm_items=mm_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, ) - return prompt_ids, mm_missing_kwargs + return prompt_ids, mm_kwargs, False def _cached_apply_hf_processor( self, prompt: Union[str, list[int]], mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - ) -> tuple[list[int], MultiModalKwargs]: + ) -> tuple[list[int], MultiModalKwargs, bool]: """ Apply the HF processor on the full prompt text, caching the results and reusing cached results. 
@@ -992,8 +1022,13 @@ def _cached_apply_hf_processor( mm_missing_data_items = self._to_mm_items(mm_missing_data) # NOTE: `prompt` does not correspond to `mm_missing_data_items`, - # so we need to pass `enable_hf_prompt_replacement=False` - prompt_ids, mm_missing_kwargs = self._apply_hf_processor_main( + # so we can't apply prompt replacements until the new multimodal + # items are combined with the cached multimodal items + ( + prompt_ids, + mm_missing_kwargs, + is_repl_applied, + ) = self._apply_hf_processor_main( prompt=prompt, mm_items=mm_missing_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, @@ -1036,7 +1071,7 @@ def _cached_apply_hf_processor( mm_kwargs = MultiModalKwargs.from_items(merged_kw_items) - return prompt_ids, mm_kwargs + return prompt_ids, mm_kwargs, is_repl_applied def _bind_and_group_repls( self, @@ -1047,18 +1082,6 @@ def _bind_and_group_repls( it = (prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls) return dict(full_groupby_modality(it)) - def _always_apply_prompt_replacements(self) -> bool: - """ - A flag which can be overridden so that - :meth:`_apply_prompt_replacements` is always called even if we - detect that HF has performed processing via - :meth:`_find_placeholders_by_modality`. - - This is useful in cases where :meth:`_find_placeholders_by_modality` - cannot be reliably used to detect whether HF has performed processing. - """ - return False - def _apply_prompt_replacements( self, token_ids: list[int], @@ -1155,29 +1178,21 @@ def _validate_mm_placeholders( self, mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]], mm_item_counts: Mapping[str, int], - *, - allow_missing: bool = False, - ) -> Mapping[str, int]: - missing_repl_counts = dict[str, int]() - + ) -> None: for modality, item_count in mm_item_counts.items(): placeholders = mm_placeholders.get(modality, []) - if len(placeholders) != item_count and not allow_missing: + if len(placeholders) != item_count: raise RuntimeError( f"Expected there to be {item_count} prompt replacements " - f"corresponding to {item_count} {modality} items, but only " - f"found {len(placeholders)} prompt replacements! Either " - "the prompt text has missing/incorrect tokens for " + f"corresponding to {item_count} {modality} items, but " + f"instead found {len(placeholders)} prompt replacements! 
" + "Either the prompt text has missing/incorrect tokens for " "multi-modal inputs, or there is a problem with your " "implementation of merged multi-modal processor for this " "model (usually arising from an inconsistency between " "`_call_hf_processor` and `_get_prompt_replacements`).") - missing_repl_counts[modality] = item_count - len(placeholders) - - return missing_repl_counts - def apply( self, prompt: Union[str, list[int]], @@ -1217,7 +1232,11 @@ def apply( else: mm_hashes = None - prompt_ids, mm_kwargs = self._cached_apply_hf_processor( + ( + prompt_ids, + mm_kwargs, + is_repl_applied, + ) = self._cached_apply_hf_processor( prompt, mm_items, hf_processor_mm_kwargs, @@ -1233,52 +1252,27 @@ def apply( mm_item_counts = mm_items.get_all_counts() self._validate_mm_kwargs(mm_kwargs, mm_item_counts) - hf_mm_placeholders = self._find_mm_placeholders( - mm_prompt_repls, - prompt_ids, - mm_item_counts, - ) - - if self._always_apply_prompt_replacements(): - mm_missing_repl_counts = mm_item_counts - mm_missing_repls = dict(mm_prompt_repls) - else: - mm_missing_repl_counts = self._validate_mm_placeholders( - hf_mm_placeholders, + if is_repl_applied: + mm_placeholders = self._find_mm_placeholders( + mm_prompt_repls, + prompt_ids, mm_item_counts, - allow_missing=True, ) + self._validate_mm_placeholders(mm_placeholders, mm_item_counts) - mm_missing_repls = dict[str, list[BoundPromptReplacement]]() - for modality, missing_repl_count in mm_missing_repl_counts.items(): - if missing_repl_count == 0: - mm_missing_repls[modality] = [] - elif missing_repl_count == mm_item_counts.get(modality, 0): - mm_missing_repls[modality] = mm_prompt_repls[modality] - else: - raise ValueError("Partial prompt replacement within " - f"{modality=} is not supported") - - # If HF processor already inserts placeholder tokens, - # there is no need for us to insert them - if all(len(repls) == 0 for repls in mm_missing_repls.values()): tokenizer = self.info.get_tokenizer() prompt = decode_tokens(tokenizer, prompt_ids) - mm_placeholders = hf_mm_placeholders else: ( prompt_ids, prompt, - missing_mm_placeholders, + mm_placeholders, ) = self._apply_prompt_replacements( prompt_ids, - mm_missing_repls, - mm_missing_repl_counts, + mm_prompt_repls, + mm_item_counts, ) - - mm_placeholders = {**hf_mm_placeholders, **missing_mm_placeholders} - - self._validate_mm_placeholders(mm_placeholders, mm_item_counts) + self._validate_mm_placeholders(mm_placeholders, mm_item_counts) mm_placeholder_ranges = { modality: [item.to_range() for item in placeholders]