[Model] Expose Phi3v num_crops as a mm_processor_kwarg #8658

Merged · 33 commits · Sep 24, 2024
Changes from 32 commits
Commits (33)
550378b
Allow for processor kwarg overrides
alex-jw-brooks Sep 16, 2024
190606f
Pass processor through to partial
alex-jw-brooks Sep 17, 2024
b1ca041
Add default & processor kwarg override tests
alex-jw-brooks Sep 17, 2024
195e31c
Don't allow ctx or inputs as kwargs
alex-jw-brooks Sep 17, 2024
1472d04
Add kwarg override for processor to dummy data factories
alex-jw-brooks Sep 17, 2024
f10601f
Add kwarg override for processor to max token calc
alex-jw-brooks Sep 19, 2024
429097a
Move kwarg only override func to utils
alex-jw-brooks Sep 19, 2024
159cfc2
Force processor kwargs to be keyword-only
alex-jw-brooks Sep 19, 2024
af91930
Pass unfiltered processor kwargs to default mapper
alex-jw-brooks Sep 19, 2024
9adad10
Add hack for mapper preprocessor kwargs
alex-jw-brooks Sep 19, 2024
9f7aed8
Simplify dummy data processor kwarg & add tests
alex-jw-brooks Sep 19, 2024
ff59e44
Add tests for max multimodal token kwarg overrides
alex-jw-brooks Sep 19, 2024
6b26454
Format registry
alex-jw-brooks Sep 20, 2024
0e2d53d
Fix default mapper comparison
alex-jw-brooks Sep 20, 2024
5a3341b
Move kwarg filtering into hf processor getter
alex-jw-brooks Sep 20, 2024
3e1fe54
Enable processor_kwargs in video processor
alex-jw-brooks Sep 20, 2024
feccfd7
Add tests for mapper processor_kwargs
alex-jw-brooks Sep 20, 2024
3ada64d
Update mapper note on multimodal processor kwargs
alex-jw-brooks Sep 20, 2024
58dcc63
processor kwarg test cleanup
alex-jw-brooks Sep 20, 2024
1cee215
Move context builder to test utils
alex-jw-brooks Sep 19, 2024
d5f9efa
Use common context builder in processor kwarg tests
alex-jw-brooks Sep 20, 2024
b5d434b
Update vllm/entrypoints/llm.py
alex-jw-brooks Sep 22, 2024
a096301
Update vllm/inputs/registry.py
alex-jw-brooks Sep 22, 2024
79962e0
Update vllm/inputs/registry.py
alex-jw-brooks Sep 22, 2024
2cb1f72
Update vllm/inputs/registry.py
alex-jw-brooks Sep 22, 2024
37eb532
Update vllm/inputs/registry.py
alex-jw-brooks Sep 22, 2024
a4c7c3d
Update vllm/inputs/registry.py
alex-jw-brooks Sep 22, 2024
36dd2cb
Fix formatting
alex-jw-brooks Sep 22, 2024
f95c86f
Rename processor kwargs to mm processor kwargs
alex-jw-brooks Sep 22, 2024
632dac1
Expose phi3v num crops processor override
alex-jw-brooks Sep 19, 2024
9eca61a
Merge branch 'main' into phi3v_num_crops
DarkLight1337 Sep 23, 2024
a3ab6cb
Merge branch 'main' into phi3v_num_crops
DarkLight1337 Sep 23, 2024
4a9ccae
Update phi3v examples with num crops overrides
alex-jw-brooks Sep 23, 2024
1 change: 1 addition & 0 deletions examples/offline_inference_vision_language.py
@@ -87,6 +87,7 @@ def run_phi3v(question, modality):
model="microsoft/Phi-3-vision-128k-instruct",
trust_remote_code=True,
max_num_seqs=5,
+        processor_kwargs={"num_crops": 16},
Review comment (Member):
Can you link to the HF repo explaining how to use num_crops?
Please update the multi-image input example as well.

)
stop_token_ids = None
return llm, prompt, stop_token_ids
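For reference, a minimal sketch of the end-user call this example enables. It assumes the mm_processor_kwargs name that this PR's rename commit settles on (the snapshot above still shows the older processor_kwargs spelling), and elides prompt handling:

from vllm import LLM

# Sketch: forward num_crops to the HF Phi-3-vision image processor.
# More crops yield more image tiles, and therefore more image tokens.
llm = LLM(
    model="microsoft/Phi-3-vision-128k-instruct",
    trust_remote_code=True,
    max_num_seqs=5,
    mm_processor_kwargs={"num_crops": 16},
)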
186 changes: 181 additions & 5 deletions tests/models/decoder_only/vision_language/test_phi3v.py
@@ -1,16 +1,21 @@
import os
import re
-from typing import List, Optional, Tuple, Type
+from typing import Callable, List, Optional, Tuple, Type

import pytest
-from transformers import AutoTokenizer
+import torch
+from transformers import AutoImageProcessor, AutoTokenizer

from vllm.inputs import InputContext, LLMInputs
from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
from vllm.multimodal import MultiModalRegistry
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
from vllm.utils import is_cpu, is_hip

-from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
-from ...utils import check_logprobs_close
+from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
+                          _ImageAssets)
+from ...utils import build_model_context, check_logprobs_close

HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
@@ -71,7 +76,7 @@ def run_test(
All the image fixtures for the test are from IMAGE_ASSETS.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
@@ -230,3 +235,174 @@ def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
mm_limit=2,
tensor_parallel_size=1,
)


### Fast tests for correctness in processor_kwarg override handling


# Wrap lazy imports to avoid initializing CUDA during test collection
@pytest.fixture()
def input_processor_for_phi3v():
from vllm.model_executor.models.phi3v import input_processor_for_phi3v
return input_processor_for_phi3v


@pytest.fixture()
def dummy_data_for_phi3v():
from vllm.model_executor.models.phi3v import dummy_data_for_phi3v
return dummy_data_for_phi3v


@pytest.fixture()
def get_max_phi3v_image_tokens():
from vllm.model_executor.models.phi3v import get_max_phi3v_image_tokens
return get_max_phi3v_image_tokens


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("num_crops", [4, 16, None])
def test_input_mapper_override(model: str, image_assets: _ImageAssets,
num_crops: Optional[int]):
"""Ensure that the [default] input mapper handles num_crops properly."""
# We pass the processor kwargs here since, for this model, we fall back to
# the default mapper, which wraps the HF processor and forwards
# mm_processor_kwargs to it.
mm_processor_kwargs = {
"num_crops": num_crops
} if num_crops is not None else {}
ctx = build_model_context(
model_name=model,
tokenizer_name=model,
trust_remote_code=True,
mm_processor_kwargs=mm_processor_kwargs,
)

hf_processor = AutoImageProcessor.from_pretrained(model,
trust_remote_code=True,
**mm_processor_kwargs)

mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)

image = image_assets[0].pil_image
hf_result = hf_processor.preprocess(
image,
return_tensors="pt",
)

vllm_result = mm_registry.map_input(
ctx.model_config,
{"image": image},
)

assert torch.all(hf_result["image_sizes"] == vllm_result["image_sizes"])
assert torch.all(
hf_result["num_img_tokens"] == vllm_result["num_img_tokens"])

# For pixel values, the second axis should be num_crops + 1, where the
# extra entry is the rescaled original image. The default value in vLLM
# falls back to the HF config, which is why we compare against the
# processor's num_crops.
assert torch.all(hf_result["pixel_values"] == vllm_result["pixel_values"])
assert vllm_result["pixel_values"].shape[1] == hf_processor.num_crops + 1


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("num_crops,expected_max_tokens", [
(4, 781),
(16, 2653),
])
def test_max_tokens_override(get_max_phi3v_image_tokens: Callable, model: str,
num_crops: int, expected_max_tokens: int):
"""Ensure get_max_phi3v_image_tokens handles num_crops properly."""
# NOTE: mm_processor_kwargs on the context in this test is unused, since
# this test calls the max tokens function directly. In practice, the
# processor kwargs are wrapped in a closure when calling the max tokens
# func. We explicitly do NOT set mm_processor_kwargs on the model context
# here, so this test would catch an implementation that mixes the kwargs
# passed to the function with stale values from the context.
ctx = build_model_context(
model_name=model,
tokenizer_name=model,
trust_remote_code=True,
mm_processor_kwargs=None,
)

actual_max_tokens = get_max_phi3v_image_tokens(
InputContext(ctx.model_config),
num_crops=num_crops,
)

assert expected_max_tokens == actual_max_tokens


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("num_crops,toks_per_img,num_imgs", [
(4, 781, 1),
(4, 781, 2),
(16, 2653, 1),
(16, 2653, 2),
])
def test_dummy_data_override(dummy_data_for_phi3v: Callable, model: str,
num_crops: int, toks_per_img: int, num_imgs: int):
"""Ensure dummy_data_for_phi3v handles num_crops properly."""
# Same as the previous test - don't initialize mm_processor_kwargs
# in this test and assume that the kwargs will be correctly expanded by
# the partial when calling the dummy data func.
ctx = build_model_context(
model_name=model,
tokenizer_name=model,
trust_remote_code=True,
mm_processor_kwargs=None,
)

sequence_data, _ = dummy_data_for_phi3v(
ctx=ctx,
seq_len=8192, # Should be bigger than num_imgs * toks_per_img
mm_counts={"image": num_imgs},
num_crops=num_crops,
)
# Ensure we have the right number of placeholders per num_crops size
img_tok_count = sequence_data.get_token_ids().count(_IMAGE_TOKEN_ID)
assert img_tok_count == toks_per_img * num_imgs


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("num_crops,expected_toks_per_img,num_imgs", [
(4, 757, 1),
(4, 757, 2),
(16, 1921, 1),
(16, 1921, 2),
])
def test_input_processor_override(input_processor_for_phi3v: Callable,
image_assets: _ImageAssets, model: str,
num_crops: int, expected_toks_per_img: int,
num_imgs: int):
"""Ensure input_processor_for_phi3v handles num_crops properly."""
# Same as the previous test - don't initialize mm_processor_kwargs
# in this test and assume that the kwargs will be correctly expanded by
# the partial when calling the custom input processor.
ctx = build_model_context(
model_name=model,
tokenizer_name=model,
trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model)
# Build the image str / prompt based on the number of images we pass
img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)])
prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n"
images = [image_assets[0].pil_image] * num_imgs

llm_inputs = LLMInputs(prompt_token_ids=tokenizer.encode(prompt),
prompt=prompt,
multi_modal_data={"image": images})

proc_llm_inputs = input_processor_for_phi3v(
ctx=ctx,
llm_inputs=llm_inputs,
num_crops=num_crops,
)

# Ensure we have the right number of placeholders per num_crops size
img_tok_count = proc_llm_inputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID)
assert img_tok_count == expected_toks_per_img * num_imgs
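The tests above rely on mm_processor_kwargs being expanded into the keyword-only parameters of the per-model callbacks. As background, here is a minimal, self-contained sketch of that pattern under assumed names (filter_kwarg_overrides and get_max_image_tokens are illustrations, not vLLM's actual utilities):

import inspect
from functools import partial
from typing import Any, Callable, Dict

def filter_kwarg_overrides(fn: Callable, overrides: Dict[str, Any]) -> Dict[str, Any]:
    # Keep only overrides that name keyword-only parameters of fn, so
    # unrelated kwargs are silently dropped instead of raising TypeError.
    params = inspect.signature(fn).parameters
    return {
        name: value
        for name, value in overrides.items()
        if name in params and params[name].kind == inspect.Parameter.KEYWORD_ONLY
    }

def get_max_image_tokens(ctx: object, *, num_crops: int = 16) -> int:
    # Stand-in for a per-model max-token callback; the real one differs.
    return (num_crops + 1) * 144

mm_processor_kwargs = {"num_crops": 4, "unrelated": True}
overrides = filter_kwarg_overrides(get_max_image_tokens, mm_processor_kwargs)
bound = partial(get_max_image_tokens, **overrides)
print(bound(None))  # 720: num_crops=4 applied, "unrelated" dropped

Binding the filtered overrides with functools.partial is what lets the registry call the callback later with only positional arguments while the user's num_crops still takes effect.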
31 changes: 22 additions & 9 deletions vllm/model_executor/models/phi3v.py
@@ -307,7 +307,7 @@ def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):


# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L90
-def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16):
+def _calc_hd_transform_size(*, width: int, height: int, hd_num: int):
transposed = False
if width < height:
width, height = height, width
@@ -337,8 +337,10 @@ def get_phi3v_image_feature_size(
*,
input_height: int,
input_width: int,
num_crops: int,
) -> int:
-    num_crops = hf_config.get("num_crops", 16)
+    if num_crops is None:
+        num_crops = hf_config.get("num_crops", 16)
new_width, new_height = _calc_hd_transform_size(width=input_width,
height=input_height,
hd_num=num_crops)
@@ -347,20 +349,26 @@
+ (new_height // 336 + 1) * 12


-def get_max_phi3v_image_tokens(ctx: InputContext):
+def get_max_phi3v_image_tokens(ctx: InputContext,
+                               *,
+                               num_crops: Optional[int] = None):

return get_phi3v_image_feature_size(
ctx.get_hf_image_processor_config(),
input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
num_crops=num_crops,
)


-def dummy_data_for_phi3v(ctx: InputContext, seq_len: int,
-                         mm_counts: Mapping[str, int]):
+def dummy_data_for_phi3v(ctx: InputContext,
+                         seq_len: int,
+                         mm_counts: Mapping[str, int],
+                         *,
+                         num_crops: Optional[int] = None):
num_images = mm_counts["image"]

-    image_feature_size = get_max_phi3v_image_tokens(ctx)
+    image_feature_size = get_max_phi3v_image_tokens(ctx, num_crops=num_crops)

seq_data = dummy_seq_data_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
@@ -398,7 +406,10 @@ def _get_image_placeholder_token_ids(model_config: ModelConfig,
return image_placeholder_token_ids


-def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
+def input_processor_for_phi3v(ctx: InputContext,
+                              llm_inputs: LLMInputs,
+                              *,
+                              num_crops: Optional[int] = None):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs
@@ -412,7 +423,8 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
image_feature_size = [
get_phi3v_image_feature_size(hf_config,
input_width=w,
-                                 input_height=h)
+                                 input_height=h,
+                                 num_crops=num_crops)
]
image_data = [image_data]
elif is_list_of(image_data, Image.Image):
@@ -422,7 +434,8 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
image_feature_size.append(
get_phi3v_image_feature_size(hf_config,
input_width=w,
-                                 input_height=h))
+                                 input_height=h,
+                                 num_crops=num_crops))
elif isinstance(image_data, torch.Tensor):
num_images, image_feature_size, hidden_size = image_data.shape
elif is_list_of(image_data, torch.Tensor):
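To make the test expectations above concrete, here is a small worked sketch of the Phi-3-vision token arithmetic. The full return expression of get_phi3v_image_feature_size is truncated in this diff, so the formula below is an assumption, reconstructed from the visible (new_height // 336 + 1) * 12 term and the expected test values (781 tokens for num_crops=4, 2653 for num_crops=16); verify against the HF image_processing_phi3_v.py before relying on it:

# Hypothetical reconstruction of the Phi-3-vision feature-size arithmetic.
def phi3v_feature_size(new_width: int, new_height: int) -> int:
    h_tiles = new_height // 336
    w_tiles = new_width // 336
    # 144 tokens per 336x336 tile plus the global image, 1 separator token,
    # and 12 newline tokens per tile row of the sub-image grid.
    return (h_tiles * w_tiles + 1) * 144 + 1 + (h_tiles + 1) * 12

# Worst-case shapes implied by the tests: num_crops tiles in one column.
print(phi3v_feature_size(336, 4 * 336))   # 781, matches num_crops=4
print(phi3v_feature_size(336, 16 * 336))  # 2653, matches num_crops=16

Under this reading, raising num_crops grows the tile grid that _calc_hd_transform_size targets, which is why the max-token, dummy-data, and input-processor tests all scale their expected counts with the override.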