diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 5a474043078db..ca894819f2c26 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -538,7 +538,7 @@ Text Generation - ✅︎ * - :code:`Qwen2VLForConditionalGeneration` - Qwen2-VL - - T + I\ :sup:`E+` + V\ :sup:`+` + - T + I\ :sup:`E+` + V\ :sup:`E+` - :code:`Qwen/Qwen2-VL-2B-Instruct`, :code:`Qwen/Qwen2-VL-7B-Instruct`, :code:`Qwen/Qwen2-VL-72B-Instruct`, etc. - ✅︎ - ✅︎ diff --git a/tests/models/decoder_only/vision_language/test_qwen2_vl.py b/tests/models/decoder_only/vision_language/test_qwen2_vl.py new file mode 100644 index 0000000000000..718c675b86fb4 --- /dev/null +++ b/tests/models/decoder_only/vision_language/test_qwen2_vl.py @@ -0,0 +1,428 @@ +from typing import Any, List, Optional, Tuple, Type, TypedDict, Union + +import numpy.typing as npt +import pytest +import torch +from PIL import Image + +from vllm.entrypoints.llm import LLM +from vllm.multimodal.utils import (rescale_image_size, rescale_video_size, + sample_frames_from_video) + +from ....conftest import (IMAGE_ASSETS, VIDEO_ASSETS, PromptImageInput, + PromptVideoInput, VllmRunner) +from ...utils import check_logprobs_close + +models = ["Qwen/Qwen2-VL-2B-Instruct"] +target_dtype = "half" + +IMAGE_PLACEHOLDER = "<|vision_start|><|image_pad|><|vision_end|>" +VIDEO_PLACEHOLDER = "<|vision_start|><|video_pad|><|vision_end|>" + + +def qwen2_vl_chat_template(*query): + return f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{''.join(query)}<|im_end|><|im_start|>assistant\n" # noqa: E501 + + +IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ + "stop_sign": + qwen2_vl_chat_template( + IMAGE_PLACEHOLDER, + "What is the biggest text's content in this image?", + ), + "cherry_blossom": + qwen2_vl_chat_template( + IMAGE_PLACEHOLDER, + "What is the season shown in this image? ", + "Reply with a short sentence (no more than 20 words)", + ), +}) + +VIDEO_PROMPTS = VIDEO_ASSETS.prompts({ + "sample_demo_1": + qwen2_vl_chat_template( + VIDEO_PLACEHOLDER, + "Describe this video with a short sentence ", + "(no more than 20 words)", + ), +}) + +MULTIIMAGE_PROMPT = qwen2_vl_chat_template( + IMAGE_PLACEHOLDER, + IMAGE_PLACEHOLDER, + "Describe these two images separately. ", + "For each image, reply with a short sentence ", + "(no more than 10 words).", +) + + +class Qwen2VLPromptImageEmbeddingInput(TypedDict): + image_embeds: torch.Tensor + image_grid_thw: torch.Tensor + + +class Qwen2VLPromptVideoEmbeddingInput(TypedDict): + video_embeds: torch.Tensor + video_grid_thw: torch.Tensor + + +def batch_make_image_embeddings( + image_batches: List[Union[Image.Image, List[Image.Image]]], processor, + llm: LLM) -> List[Qwen2VLPromptImageEmbeddingInput]: + """batched image embeddings for Qwen2-VL + + This will infer all images' embeddings in a single batch, + and split the result according to input batches. + + image_batches: + - Single-image batches: `List[Image.Image]` + - Multiple-image batches: `List[List[Image.Image]]]` + + returns: `List[Qwen2VLPromptImageEmbeddingInput]` + """ + + image_batches_: List[Any] = image_batches[:] + + # convert single-image batches to multiple-image batches + for idx in range(len(image_batches_)): + if not isinstance(image_batches_[idx], list): + image_batches_[idx] = [image_batches_[idx]] + + assert isinstance(image_batches_[idx], list) + + # append all images into a list (as a batch) + images: List[Image.Image] = [] + for image_batch in image_batches_: + images += image_batch + + # image to pixel values + image_processor = processor.image_processor + + preprocess_result = image_processor \ + .preprocess(images=images, return_tensors="pt") \ + .data + pixel_values = preprocess_result["pixel_values"] + image_grid_thw = preprocess_result["image_grid_thw"] + + # pixel values to embeddinds & grid_thws + with torch.no_grad(): + visual = llm.llm_engine.model_executor.driver_worker. \ + model_runner.model.visual + + pixel_values_on_device = pixel_values.to(visual.device, + dtype=visual.dtype) + image_grid_thw_on_device = image_grid_thw.to(visual.device, + dtype=torch.int64) + image_embeds = visual(pixel_values_on_device, + grid_thw=image_grid_thw_on_device) + + # split into original batches + result: List[Qwen2VLPromptImageEmbeddingInput] = [] + image_counter = 0 + embed_counter = 0 + for image_batch in image_batches_: + cur_batch_image_count = len(image_batch) + merge_size = image_processor.merge_size + cur_batch_embed_len = sum([ + grid_thw.prod() // merge_size // merge_size + for grid_thw in image_grid_thw[image_counter:image_counter + + cur_batch_image_count] + ]) + + result.append({ + "image_embeds": + image_embeds[embed_counter:embed_counter + cur_batch_embed_len], + "image_grid_thw": + image_grid_thw[image_counter:image_counter + + cur_batch_image_count], + }) + + embed_counter += cur_batch_embed_len + image_counter += cur_batch_image_count + + # ensure we don't lost any images or embeddings + assert embed_counter == image_embeds.size(0) + assert image_counter == image_grid_thw.size(0) + assert len(image_batches) == len(result) + + return result + + +def batch_make_video_embeddings( + video_batches: PromptVideoInput, processor, + llm: LLM) -> List[Qwen2VLPromptVideoEmbeddingInput]: + """batched video embeddings for Qwen2-VL + + A NDArray represents a single video's all frames. + + This will infer all videos' embeddings in a single batch, + and split the result according to input batches. + + video_batches: + - Single-video batches: `List[NDArray]` + - Multiple-video batches: `List[List[NDArray]]` + """ + + video_batches_: List[Any] = video_batches[:] + + for idx in range(len(video_batches_)): + if not isinstance(video_batches_[idx], list): + single_video_batch: List[npt.NDArray] = [video_batches_[idx]] + video_batches_[idx] = single_video_batch + + assert isinstance(video_batches_[idx], list) + + # append all videos into a list (as a batch) + videos: List[npt.NDArray] = [] + for video_batch in video_batches_: + videos += video_batch + + # video to pixel values + image_processor = processor.image_processor + + preprocess_result = image_processor \ + .preprocess(images=None, videos=videos, return_tensors="pt") \ + .data + pixel_values = preprocess_result["pixel_values_videos"] + video_grid_thw = preprocess_result["video_grid_thw"] + + # pixel values to embeddinds & grid_thws + with torch.no_grad(): + visual = llm.llm_engine.model_executor.driver_worker.\ + model_runner.model.visual + + pixel_values_on_device = pixel_values.to(visual.device, + dtype=visual.dtype) + video_grid_thw_on_device = video_grid_thw.to(visual.device, + dtype=torch.int64) + video_embeds = visual(pixel_values_on_device, + grid_thw=video_grid_thw_on_device) + + # split into original batches + result: List[Qwen2VLPromptVideoEmbeddingInput] = [] + video_counter = 0 + embed_counter = 0 + for video_batch in video_batches_: + cur_batch_video_count = len(video_batch) + merge_size = image_processor.merge_size + cur_batch_embed_len = sum([ + grid_thw.prod() // merge_size // merge_size + for grid_thw in video_grid_thw[video_counter:video_counter + + cur_batch_video_count] + ]) + + result.append({ + "video_embeds": + video_embeds[embed_counter:embed_counter + cur_batch_embed_len], + "video_grid_thw": + video_grid_thw[video_counter:video_counter + + cur_batch_video_count], + }) + + embed_counter += cur_batch_embed_len + video_counter += cur_batch_video_count + + # ensure we don't lost any videos or embeddings + assert embed_counter == video_embeds.size(0) + assert video_counter == video_grid_thw.size(0) + assert len(video_batches) == len(result) + + return result + + +def run_test( + vllm_runner: Type[VllmRunner], + inputs: List[Tuple[List[str], PromptImageInput, PromptVideoInput]], + model: str, + *, + dtype: str, + max_tokens: int, + num_logprobs: int, + mm_limit: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + """Inference result should be the same between + original image/video input and image/video embeddings input. + """ + from transformers import AutoProcessor # noqa: F401 + + processor = AutoProcessor.from_pretrained(model) + + # NOTE: + # max_model_len should be greater than image_feature_size + with vllm_runner(model, + task="generate", + max_model_len=4000, + max_num_seqs=3, + dtype=dtype, + limit_mm_per_prompt={ + "image": mm_limit, + "video": mm_limit + }, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend + ) as vllm_model: + + outputs_per_case_for_original_input = [ + vllm_model.generate_greedy_logprobs(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images or None, + videos=videos or None) + for prompts, images, videos in inputs + ] + + outputs_per_case_for_embeddings_input = [ + vllm_model.generate_greedy_logprobs( + prompts, + max_tokens, + num_logprobs=num_logprobs, + images=batch_make_image_embeddings( + images, processor, vllm_model.model) if images else None, + videos=batch_make_video_embeddings( + videos, processor, vllm_model.model) if videos else None) + for prompts, images, videos in inputs + ] + + for outputs_for_original_input, \ + outputs_for_embeddings_input \ + in zip(outputs_per_case_for_original_input, + outputs_per_case_for_embeddings_input): + check_logprobs_close( + outputs_0_lst=outputs_for_original_input, + outputs_1_lst=outputs_for_embeddings_input, + name_0="original_input", + name_1="embeddings_input", + ) + + +@pytest.mark.core_model +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize( + "size_factors", + [ + # Single-scale + [0.5], + # Single-scale, batched + [0.5, 0.5], + # Multi-scale + [0.25, 0.5, 0.5], + ], +) +@pytest.mark.parametrize("dtype", [target_dtype]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [10]) +def test_qwen2_vl_image_embeddings_input(vllm_runner, image_assets, model, + size_factors, dtype: str, + max_tokens: int, + num_logprobs: int) -> None: + images = [asset.pil_image for asset in image_assets] + + inputs_per_case: List[Tuple[ + List[str], PromptImageInput, PromptVideoInput]] = [( + [prompt for _ in size_factors], + [rescale_image_size(image, factor) for factor in size_factors], + [], + ) for image, prompt in zip(images, IMAGE_PROMPTS)] + + run_test( + vllm_runner, + inputs_per_case, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + mm_limit=1, + tensor_parallel_size=1, + ) + + +@pytest.mark.core_model +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize( + "size_factors", + [ + [], + # Single-scale + [0.5], + # Single-scale, batched + [0.5, 0.5], + # Multi-scale + [0.25, 0.5, 0.5], + ], +) +@pytest.mark.parametrize("dtype", [target_dtype]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [10]) +def test_qwen2_vl_multiple_image_embeddings_input(vllm_runner, image_assets, + model, size_factors, + dtype: str, max_tokens: int, + num_logprobs: int) -> None: + images = [asset.pil_image for asset in image_assets] + + inputs_per_case: List[Tuple[List[str], PromptImageInput, + PromptVideoInput]] = [( + [MULTIIMAGE_PROMPT for _ in size_factors], + [[ + rescale_image_size(image, factor) + for image in images + ] for factor in size_factors], + [], + )] + + run_test( + vllm_runner, + inputs_per_case, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + mm_limit=2, + tensor_parallel_size=1, + ) + + +@pytest.mark.core_model +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize( + "size_factors", + [ + # Single-scale + [0.5], + # Single-scale, batched + [0.5, 0.5], + # Multi-scale + [0.25, 0.25, 0.5], + ], +) +@pytest.mark.parametrize("dtype", [target_dtype]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [10]) +def test_qwen2_vl_video_embeddings_input(vllm_runner, video_assets, model, + size_factors, dtype: str, + max_tokens: int, + num_logprobs: int) -> None: + num_frames = 4 + sampled_vids = [ + sample_frames_from_video(asset.np_ndarrays, num_frames) + for asset in video_assets + ] + + inputs_per_case: List[Tuple[ + List[str], PromptImageInput, PromptVideoInput]] = [( + [prompt for _ in size_factors], + [], + [rescale_video_size(video, factor) for factor in size_factors], + ) for video, prompt in zip(sampled_vids, VIDEO_PROMPTS)] + + run_test( + vllm_runner, + inputs_per_case, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + mm_limit=1, + tensor_parallel_size=1, + ) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 13109758767df..1b162e7df8578 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -79,7 +79,7 @@ class Qwen2VLImagePixelInputs(TypedDict): type: Literal["pixel_values"] - data: torch.Tensor + pixel_values: torch.Tensor """Shape: `(num_patches, num_channels * patch_size * patch_size)` """ @@ -92,9 +92,22 @@ class Qwen2VLImagePixelInputs(TypedDict): class Qwen2VLImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] - data: torch.Tensor - """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` - `hidden_size` must match the hidden size of language model backbone. + image_embeds: torch.Tensor + """Supported types: + - List[`torch.Tensor`]: A list of tensors holding all images' features. + Each tensor holds an image's features. + - `torch.Tensor`: A tensor holding all images' features + (concatenation of all images' feature tensors). + + Tensor shape: `(num_image_features, hidden_size)` + - `num_image_features` varies based on + the number and resolution of the images. + - `hidden_size` must match the hidden size of language model backbone. + """ + + image_grid_thw: torch.Tensor + """Shape: `(num_images, 3)` + This should be in `(grid_t, grid_h, grid_w)` format. """ @@ -102,7 +115,8 @@ class Qwen2VLImageEmbeddingInputs(TypedDict): Qwen2VLImageEmbeddingInputs] -class Qwen2VLVideoInputs(TypedDict): +class Qwen2VLVideoPixelInputs(TypedDict): + type: Literal["pixel_values_videos"] pixel_values_videos: torch.Tensor """Shape: `(num_patches, @@ -116,6 +130,30 @@ class Qwen2VLVideoInputs(TypedDict): """ +class Qwen2VLVideoEmbeddingInputs(TypedDict): + type: Literal["video_embeds"] + video_embeds: torch.Tensor + """Supported types: + - List[`torch.Tensor`]: A list of tensors holding all videos' features. + Each tensor holds an video's features. + - `torch.Tensor`: A tensor holding all videos' features + (concatenation of all videos' feature tensors). + + Tensor shape: `(num_image_features, hidden_size)` + - `num_image_features` varies based on + the number and resolution of the videos. + - `hidden_size` must match the hidden size of language model backbone. + """ + + video_grid_thw: torch.Tensor + """Shape: `(num_videos, 3)` + This should be in `(grid_t, grid_h, grid_w)` format. + """ + + +Qwen2VLVideoInputs = Union[Qwen2VLVideoPixelInputs, + Qwen2VLVideoEmbeddingInputs] + # === Vision Encoder === # @@ -585,6 +623,12 @@ def mm_input_mapper_for_qwen2_vl( "image_embeds": data.get("image_embeds"), "image_grid_thw": data.get("image_grid_thw"), }) + if data_type_key == "video" and isinstance(data, dict): + return MultiModalKwargs({ + "video_embeds": data.get("video_embeds"), + "video_grid_thw": data.get("video_grid_thw"), + }) + model_config = ctx.model_config # Handle mm processor kwargs; we pass these at creation time # because preprocess() in transformers doesn't expose them @@ -890,16 +934,33 @@ def input_processor_for_qwen2_vl( idx for idx, token in enumerate(prompt_token_ids) if token == hf_config.image_token_id ] - image_cnt = len(image_indices) - embed_dim = image_inputs.get('image_embeds').size(0) - assert embed_dim % image_cnt == 0 - num_pad_tokens = embed_dim // image_cnt + + # ensure all image tokens have grid_thw + assert \ + len(image_indices) == image_inputs["image_grid_thw"].size(0), \ + "image token num does not match image_grid_thw.shape" + + image_counter = 0 + pad_token_counter = 0 for idx, token in enumerate(prompt_token_ids): if idx in image_indices: + grid_thw = image_inputs["image_grid_thw"][image_counter] + grid_t, grid_h, grid_w = grid_thw + num_pad_tokens = (grid_t * grid_h * grid_w // + image_processor.merge_size // + image_processor.merge_size) prompt_token_ids_with_image.extend([token] * num_pad_tokens) + image_counter += 1 + pad_token_counter += num_pad_tokens else: prompt_token_ids_with_image.append(token) + + # ensure all embeddings are used + assert \ + pad_token_counter == image_inputs["image_embeds"].size(0), \ + "image_embeds.shape does not match image_grid_thw" + prompt_token_ids = prompt_token_ids_with_image else: prompt_token_ids = _expand_pad_tokens(image_inputs, @@ -912,14 +973,49 @@ def input_processor_for_qwen2_vl( max_pixels=max_pixels) if video_inputs is not None: - prompt_token_ids = _expand_pad_tokens(video_inputs, - hf_config.video_token_id, - make_batched_videos, - "video", - image_processor, - prompt_token_ids, - min_pixels=min_pixels, - max_pixels=max_pixels) + if isinstance(video_inputs, dict): + prompt_token_ids_with_video = [] + video_indices = [ + idx for idx, token in enumerate(prompt_token_ids) + if token == hf_config.video_token_id + ] + + # ensure all video tokens have grid_thw + assert \ + len(video_indices) == video_inputs["video_grid_thw"].size(0), \ + "video token num does not match video_grid_thw.shape" + + video_counter = 0 + pad_token_counter = 0 + for idx, token in enumerate(prompt_token_ids): + if idx in video_indices: + grid_thw = video_inputs["video_grid_thw"][video_counter] + grid_t, grid_h, grid_w = grid_thw + num_pad_tokens = (grid_t * grid_h * grid_w // + image_processor.merge_size // + image_processor.merge_size) + prompt_token_ids_with_video.extend([token] * + num_pad_tokens) + video_counter += 1 + pad_token_counter += num_pad_tokens + else: + prompt_token_ids_with_video.append(token) + + # ensure all embeddings are used + assert \ + pad_token_counter == video_inputs["video_embeds"].size(0), \ + "video_embeds.shape does not match video_grid_thw" + + prompt_token_ids = prompt_token_ids_with_video + else: + prompt_token_ids = _expand_pad_tokens(video_inputs, + hf_config.video_token_id, + make_batched_videos, + "video", + image_processor, + prompt_token_ids, + min_pixels=min_pixels, + max_pixels=max_pixels) prompt = inputs.get("prompt") if prompt is None: @@ -1051,49 +1147,71 @@ def _parse_and_validate_image_input( f"Got type: {type(pixel_values)}") return Qwen2VLImagePixelInputs(type="pixel_values", - data=pixel_values, + pixel_values=pixel_values, image_grid_thw=image_grid_thw) if image_embeds is not None: image_embeds = self._validate_and_reshape_mm_tensor( image_embeds, "image embeds") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw") if not isinstance(image_embeds, torch.Tensor): raise ValueError("Incorrect type of image embeddings. " f"Got type: {type(image_embeds)}") return Qwen2VLImageEmbeddingInputs(type="image_embeds", - data=image_embeds) + image_embeds=image_embeds, + image_grid_thw=image_grid_thw) def _parse_and_validate_video_input( self, **kwargs: object) -> Optional[Qwen2VLVideoInputs]: pixel_values_videos = kwargs.pop("pixel_values_videos", None) + video_embeds = kwargs.pop("video_embeds", None) video_grid_thw = kwargs.pop("video_grid_thw", None) - if pixel_values_videos is None: + if pixel_values_videos is None and video_embeds is None: return None - pixel_values_videos = self._validate_and_reshape_mm_tensor( - pixel_values_videos, "video pixel values") - video_grid_thw = self._validate_and_reshape_mm_tensor( - video_grid_thw, "video grid_thw") - - return Qwen2VLVideoInputs( - pixel_values_videos=pixel_values_videos, - video_grid_thw=video_grid_thw, - ) + if pixel_values_videos is not None: + pixel_values_videos = self._validate_and_reshape_mm_tensor( + pixel_values_videos, "video pixel values") + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, "video grid_thw") + + return Qwen2VLVideoPixelInputs( + type="pixel_values_videos", + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + ) + + if video_embeds is not None: + video_embeds = self._validate_and_reshape_mm_tensor( + video_embeds, "video embeds") + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, "video grid_thw") + + if not isinstance(video_embeds, torch.Tensor): + raise ValueError("Incorrect type of video embeddings. " + f"Got type: {type(video_embeds)}") + return Qwen2VLVideoEmbeddingInputs(type="video_embeds", + video_embeds=video_embeds, + video_grid_thw=video_grid_thw) def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor: if image_input["type"] == "image_embeds": - return image_input["data"].type(self.visual.dtype) + return image_input["image_embeds"].type(self.visual.dtype) - pixel_values = image_input["data"].type(self.visual.dtype) + pixel_values = image_input["pixel_values"].type(self.visual.dtype) image_embeds = self.visual(pixel_values, grid_thw=image_input["image_grid_thw"]) return image_embeds def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor: + if video_input["type"] == "video_embeds": + return video_input["video_embeds"].type(self.visual.dtype) + pixel_values_videos = video_input["pixel_values_videos"].type( self.visual.dtype) video_embeds = self.visual(pixel_values_videos,