From 365801feddaf5c4448704a1f55269dd992f5a4b1 Mon Sep 17 00:00:00 2001
From: Cyrus Leung <tlleungac@connect.ust.hk>
Date: Wed, 1 Jan 2025 14:15:21 +0800
Subject: [PATCH] [VLM] Add max-count checking in data parser for single-image
 models (#11661)

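MultiModalDataParser gains a `max_mm_counts` argument that puts a hard
per-prompt cap on the number of items of each modality, on top of
whatever `--limit-mm-per-prompt` allows. The BLIP-2, Chameleon and Fuyu
processors override `_get_data_parser()` to build the parser with
`max_mm_counts={"image": 1}`, so prompts (or limits) that exceed one
image now fail fast with a clear ValueError during parsing. Along the
way, Fuyu's `FuyuImagePatchInputs.data` field is renamed to `flat_data`
to make its flattened shape explicit, the processing-cache test marks
Chameleon as a single-image model, and the supported-models table marks
Aria as accepting multiple images per prompt.
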
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Roger Wang <ywang@roblox.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
---
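Reviewer note (below the fold, not part of the commit): the snippet is a
minimal standalone sketch of the max-count check that parse_mm_data now
performs, assuming a parser built with max_mm_counts={"image": 1} as the
BLIP-2, Chameleon and Fuyu processors do in this patch. The
check_mm_counts helper and its arguments are hypothetical and exist only
for illustration; they are not part of vLLM.

    from typing import Mapping, Sequence

    def check_mm_counts(
        mm_items: Mapping[str, Sequence[object]],
        max_mm_counts: Mapping[str, int],
    ) -> None:
        # Reject any modality whose item count exceeds its hard per-prompt cap.
        for modality, items in mm_items.items():
            max_count = max_mm_counts.get(modality)
            if max_count is not None and len(items) > max_count:
                raise ValueError(
                    f"This model supports at most {max_count} {modality} "
                    f"items per prompt, but {len(items)} were given.")

    # A single-image model rejects a two-image prompt at parse time:
    try:
        check_mm_counts({"image": ["img_a", "img_b"]}, {"image": 1})
    except ValueError as exc:
        print(exc)
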
 docs/source/models/supported_models.md  |  2 +-
 tests/multimodal/test_processing.py     |  3 ++-
 vllm/model_executor/models/blip2.py     |  4 ++++
 vllm/model_executor/models/chameleon.py |  4 ++++
 vllm/model_executor/models/fuyu.py      | 18 +++++++++-------
 vllm/multimodal/parse.py                | 28 +++++++++++++++++++++++--
 6 files changed, 48 insertions(+), 11 deletions(-)

diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index f74c201bdff6b..7682ed104b8c5 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -566,7 +566,7 @@ See [this page](#generative-models) for more information on how to use generativ
   - [V1](gh-issue:8779)
 * - `AriaForConditionalGeneration`
   - Aria
-  - T + I
+  - T + I<sup>+</sup>
   - `rhymes-ai/Aria`
   -
   - ✅︎
diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py
index 81278cde264ff..1850ca46ccc8f 100644
--- a/tests/multimodal/test_processing.py
+++ b/tests/multimodal/test_processing.py
@@ -622,10 +622,11 @@ def _test_processing_cache_correctness(
 
 
 # yapf: disable
+# True if the model supports multiple data items of that modality per request
 @pytest.mark.parametrize(("model_id", "modalities"), [
     ("rhymes-ai/Aria", {"image": True}),
     ("Salesforce/blip2-opt-2.7b", {"image": False}),
-    ("facebook/chameleon-7b", {"image": True}),
+    ("facebook/chameleon-7b", {"image": False}),
     ("adept/fuyu-8b", {"image": False}),
     ("llava-hf/llava-1.5-7b-hf", {"image": True}),
     ("TIGER-Lab/Mantis-8B-siglip-llama3", {"image": True}),
diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py
index bf70f5d904f5b..50680fadc4aa3 100644
--- a/vllm/model_executor/models/blip2.py
+++ b/vllm/model_executor/models/blip2.py
@@ -18,6 +18,7 @@
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalInputsV2, MultiModalKwargs,
                                     NestedTensors, PlaceholderRange)
+from vllm.multimodal.parse import MultiModalDataParser
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         MultiModalDataItems, ProcessorInputs,
                                         PromptReplacement)
@@ -404,6 +405,9 @@ def get_max_blip2_image_tokens(ctx: InputContext):
 
 class Blip2MultiModalProcessor(BaseMultiModalProcessor):
 
+    def _get_data_parser(self) -> MultiModalDataParser:
+        return MultiModalDataParser(max_mm_counts={"image": 1})
+
     def _get_hf_processor(self) -> Blip2Processor:
         return self.ctx.get_hf_processor(Blip2Processor)
 
diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py
index 85fca23b05746..c731934e792fc 100644
--- a/vllm/model_executor/models/chameleon.py
+++ b/vllm/model_executor/models/chameleon.py
@@ -31,6 +31,7 @@
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalInputsV2, MultiModalKwargs,
                                     NestedTensors, PlaceholderRange)
+from vllm.multimodal.parse import MultiModalDataParser
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         MultiModalDataItems, ProcessorInputs,
                                         PromptReplacement)
@@ -60,6 +61,9 @@ def get_max_chameleon_image_tokens(ctx: InputContext):
 
 class ChameleonMultiModalProcessor(BaseMultiModalProcessor):
 
+    def _get_data_parser(self) -> MultiModalDataParser:
+        return MultiModalDataParser(max_mm_counts={"image": 1})
+
     def _get_hf_processor(self) -> ChameleonProcessor:
         return self.ctx.get_hf_processor(ChameleonProcessor)
 
diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py
index 8c14866f20b92..0a48fa3fe11c0 100644
--- a/vllm/model_executor/models/fuyu.py
+++ b/vllm/model_executor/models/fuyu.py
@@ -34,7 +34,7 @@
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalInputsV2, MultiModalKwargs,
                                     NestedTensors, PlaceholderRange)
-from vllm.multimodal.parse import ImageProcessorItems
+from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataParser
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         MultiModalDataItems, ProcessorInputs,
                                         PromptReplacement)
@@ -54,7 +54,7 @@
 
 class FuyuImagePatchInputs(TypedDict):
     type: Literal["image_patches"]
-    data: torch.Tensor
+    flat_data: torch.Tensor
     """
     Shape: 
     `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
@@ -63,7 +63,7 @@ class FuyuImagePatchInputs(TypedDict):
     patches_per_image: List[int]
     """
     List of number of total patches for each image in the batch.
-    This is used to restore the first two dimensions of `data`.
+    This is used to restore the first two dimensions of `flat_data`.
     """
 
 
@@ -102,6 +102,9 @@ def get_max_fuyu_image_tokens(ctx: InputContext):
 
 class FuyuMultiModalProcessor(BaseMultiModalProcessor):
 
+    def _get_data_parser(self) -> MultiModalDataParser:
+        return MultiModalDataParser(max_mm_counts={"image": 1})
+
     def _get_hf_processor(self) -> FuyuProcessor:
         return self.ctx.get_hf_processor(FuyuProcessor)
 
@@ -304,7 +307,7 @@ def _parse_and_validate_image_input(
 
             return FuyuImagePatchInputs(
                 type="image_patches",
-                data=self._validate_pixel_values(
+                flat_data=self._validate_pixel_values(
                     flatten_bn(image_patches_flat, concat=True)),
                 patches_per_image=[x.size(0) for x in image_patches_flat],
             )
@@ -313,12 +316,13 @@ def _parse_and_validate_image_input(
 
     def _process_image_input(
             self, image_input: FuyuImagePatchInputs) -> NestedTensors:
-        image_patches = image_input["data"]
+        image_patches_flat = image_input["flat_data"]
         patches_per_image = image_input["patches_per_image"]
 
         assert self.vision_embed_tokens is not None
-        vision_embeddings, _ = self.vision_embed_tokens(image_patches)
-        return vision_embeddings.split(patches_per_image, dim=0)
+        vision_embeddings_flat, _ = self.vision_embed_tokens(
+            image_patches_flat)
+        return vision_embeddings_flat.split(patches_per_image, dim=0)
 
     def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
         image_input = self._parse_and_validate_image_input(**kwargs)
diff --git a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py
index 17a795247372e..da111e999ebb8 100644
--- a/vllm/multimodal/parse.py
+++ b/vllm/multimodal/parse.py
@@ -220,11 +220,24 @@ def get_items(
 class MultiModalDataParser:
     """
     Parses :class:`MultiModalDataDict` into :class:`MultiModalDataItems`.
+
+    Args:
+        max_mm_counts (Mapping[str, int]): The maximum allowed number of items
+            for each modality. This acts as a hard upper bound regardless of
+            `--limit-mm-per-prompt`.
+        target_sr (float, optional): Enables automatic resampling of audio
+            items to the model's expected sampling rate.
     """
 
-    def __init__(self, *, target_sr: Optional[float] = None) -> None:
+    def __init__(
+        self,
+        *,
+        max_mm_counts: Mapping[str, int] = {},
+        target_sr: Optional[float] = None,
+    ) -> None:
         super().__init__()
 
+        self.max_mm_counts = max_mm_counts
         self.target_sr = target_sr
 
     def _is_embeddings(self, data: object) -> TypeGuard[NestedTensors]:
@@ -332,6 +345,7 @@ def _get_subparsers(self) -> Mapping[str, ModalityDataParser]:
 
     def parse_mm_data(self,
                       mm_data: MultiModalDataDict) -> MultiModalDataItems:
+        max_mm_counts = self.max_mm_counts
         subparsers = self._get_subparsers()
 
         mm_items = MultiModalDataItems()
@@ -339,6 +353,16 @@ def parse_mm_data(self,
             if k not in subparsers:
                 raise ValueError(f"Unsupported modality: {k}")
 
-            mm_items[k] = subparsers[k](v)
+            modality_items = subparsers[k](v)
+
+            if k in max_mm_counts:
+                max_count = max_mm_counts[k]
+                if len(modality_items) > max_count:
+                    raise ValueError(
+                        f"This model supports at most {max_count} {k} items "
+                        f"per prompt, but {len(modality_items)} {k} items "
+                        "were given or set as its limit_mm_per_prompt.")
+
+            mm_items[k] = modality_items
 
         return mm_items