From 040a12f73d5cc9c26f8495179f9fc72b49b88a6a Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Mon, 21 Oct 2024 14:20:07 -0400
Subject: [PATCH] [Model][Bugfix] Fix batching with multi-image in PixtralHF
 (#9518)

Signed-off-by: Tyler Michael Smith
---
 vllm/model_executor/models/llava.py   | 60 +++++++++++++++++++++------
 vllm/model_executor/models/pixtral.py | 11 ++---
 2 files changed, 54 insertions(+), 17 deletions(-)

diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index a83b7d05df7aa..a666dcba290f2 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -287,6 +287,34 @@ def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
 
         return data
 
+    def _validate_image_sizes(self, images: List[torch.Tensor],
+                              sizes: List[torch.Tensor]) -> List[torch.Tensor]:
+        if not isinstance(sizes, list):
+            sizes = [sizes]
+
+        total_images = sum(size.numel() // 2 for size in sizes)
+        if total_images != len(images):
+            raise ValueError("Mismatch in number of images. "
+                             f"Expected {total_images}, got {len(images)}")
+        img_idx = 0
+        for size in sizes:
+            # Flatten the size tensor to a list of (height, width) pairs
+            size = size.view(-1, 2).tolist()
+            for expected_h, expected_w in size:
+                if img_idx >= len(images):
+                    raise ValueError("Ran out of images before sizes. "
+                                     f"{img_idx} >= {len(images)}")
+                img = images[img_idx]
+                if img.shape[-2:] != (expected_h, expected_w):
+                    raise ValueError(
+                        "Image size mismatch. Expected "
+                        f"{(expected_h, expected_w)}, got {img.shape[-2:]}")
+                if img.shape[-3] != 3:
+                    raise ValueError("Image channel mismatch. Expected 3, "
+                                     f"got {img.shape[-3]}")
+                img_idx += 1
+        return images
+
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[LlavaImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -305,20 +333,28 @@ def _parse_and_validate_image_input(
             # so we need to produce a list of tensors
             if image_sizes is not None:
                 images = pixel_values
-                if isinstance(images, torch.Tensor):
-                    # if passed as batch take all images
-                    NN, N, B, C, W, H = images.shape
-                    images = images.reshape(NN * N * B, C, W, H)
-                    images = [images[i] for i in range(images.size(0))]
-                elif isinstance(images, list):
-                    # if passed as list flatten lists of tensors
-                    while isinstance(images, list) and len(images) == 1:
-                        images = images[0]
-
-                # TODO: Add validation based on image_sizes
+
+                def flatten_to_3d_tensors(item):
+                    if isinstance(item, torch.Tensor):
+                        if item.dim() >= 3:
+                            return [t for t in item.view(-1, *item.shape[-3:])]
+                        else:
+                            raise ValueError(
+                                f"Unexpected tensor dimension: {item.dim()}")
+                    elif isinstance(item, list):
+                        return [
+                            t for subitem in item
+                            for t in flatten_to_3d_tensors(subitem)
+                        ]
+                    else:
+                        raise ValueError(f"Unexpected type: {type(item)}")
+
+                # Flatten the batched images into a flat list of CHW tensors
+                images = flatten_to_3d_tensors(pixel_values)
+
                 return LlavaImagePixelInputs(
                     type="pixel_values",
-                    data=images,
+                    data=self._validate_image_sizes(images, image_sizes),
                 )
 
             return LlavaImagePixelInputs(
diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index 13c5149a63919..f33871c0d5acc 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -907,17 +907,18 @@ def forward(
     ) -> torch.Tensor:
         """
         Args:
-            pixel_values: tensor of token features for
-                all tokens of all images of shape (N_toks, D)
+            pixel_values: Each image to be processed is passed as its own
+                tensor, so pixel_values is a list of tensors: a batch can
+                contain multiple requests, each with multiple images, and
+                each image may have a different shape.
+
         Returns:
             image_features: tensor of token features for
                 all tokens of all images of shape (N_toks, D)
         """
         # pass images through initial convolution independently
         patch_embeds_list = [
-            self.patch_conv(
-                img.reshape(-1, img.shape[-3], img.shape[-2],
-                            img.shape[-1]).to(self.dtype))
+            self.patch_conv(img.unsqueeze(0).to(self.dtype))
             for img in pixel_values
         ]
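
As a quick illustration (not part of the patch itself), the flattening behavior introduced in the llava.py hunk can be exercised standalone; the input shapes below are invented for the example, and only flatten_to_3d_tensors mirrors the patched code:

import torch

def flatten_to_3d_tensors(item):
    # Normalize a batched tensor or a nested list of tensors into a
    # flat list of (C, H, W) image tensors, as done in the hunk above.
    if isinstance(item, torch.Tensor):
        if item.dim() >= 3:
            return [t for t in item.view(-1, *item.shape[-3:])]
        raise ValueError(f"Unexpected tensor dimension: {item.dim()}")
    if isinstance(item, list):
        return [t for sub in item for t in flatten_to_3d_tensors(sub)]
    raise ValueError(f"Unexpected type: {type(item)}")

# Batched layout (batch, num_images, C, H, W) -> 6 CHW tensors
assert len(flatten_to_3d_tensors(torch.zeros(2, 3, 3, 16, 16))) == 6

# Nested-list layout with differing per-image sizes -> 2 CHW tensors
nested = [[torch.zeros(3, 16, 16)], [torch.zeros(3, 24, 32)]]
assert len(flatten_to_3d_tensors(nested)) == 2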