
Commit 92148ff

[Bugfix] Fix input processor for InternVL2 model (vllm-project#7164)
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
2 people authored and kylesayrs committed Aug 17, 2024
1 parent 66574f0 commit 92148ff
Showing 2 changed files with 73 additions and 34 deletions.
23 changes: 19 additions & 4 deletions tests/models/test_internvl.py
@@ -5,6 +5,7 @@
 import torch
 from huggingface_hub import snapshot_download
 from PIL.Image import Image
+from transformers import AutoConfig
 
 from vllm.model_executor.models.internvl import (IMG_CONTEXT, IMG_END,
                                                  IMG_START,
@@ -26,10 +27,15 @@
 
 # we use snapshot_download to prevent conflicts between
 # dynamic_module and trust_remote_code for hf_runner
+DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"]
 models = [
-    snapshot_download("OpenGVLab/InternVL2-1B"),
-    snapshot_download("OpenGVLab/InternVL2-2B"),
-    # snapshot_download("OpenGVLab/InternVL2-4B"), # broken
+    snapshot_download("OpenGVLab/InternVL2-1B",
+                      allow_patterns=DOWNLOAD_PATTERN),
+    snapshot_download("OpenGVLab/InternVL2-2B",
+                      allow_patterns=DOWNLOAD_PATTERN),
+    # Broken due to outdated implementation of Phi-3
+    # See: https://huggingface.co/OpenGVLab/InternVL2-4B/discussions/3
+    # snapshot_download("OpenGVLab/InternVL2-4B"),
 ]
 
 
@@ -41,8 +47,17 @@ def __init__(self, hf_runner: HfRunner):
         self.tokenizer = hf_runner.tokenizer
         self.dtype = hf_runner.model.dtype
 
+        self.config = AutoConfig.from_pretrained(hf_runner.model_name)
+        self.vision_config = self.config.vision_config
+        self.use_thumbnail = self.config.use_thumbnail
+        self.min_num = self.config.min_dynamic_patch
+        self.max_num = self.config.max_dynamic_patch
+        self.image_size = self.vision_config.image_size
+
     def __call__(self, text: str, images: Image, **kwargs):
-        pixel_values = image_to_pixel_values(images).to(self.dtype)
+        pixel_values = image_to_pixel_values(images, self.image_size,
+                                             self.min_num, self.max_num,
+                                             self.use_thumbnail).to(self.dtype)
         num_patches_list = [pixel_values.shape[0]]
         for num_patches in num_patches_list:
             context_tokens = IMG_CONTEXT * self.num_image_token * num_patches
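Note on the test change: InternVLProcessor now reads the dynamic-patch settings (min_dynamic_patch, max_dynamic_patch, use_thumbnail) and the vision tower's image_size from the checkpoint config instead of relying on hard-coded defaults, so the HF reference path preprocesses images the same way as vLLM's input mapper below. A minimal sketch of reading those fields, assuming a cached InternVL2 checkpoint and that trust_remote_code is acceptable; the repo id and the commented values are illustrative assumptions, not part of this commit:

from transformers import AutoConfig

# Standalone illustration only; values differ per checkpoint.
config = AutoConfig.from_pretrained("OpenGVLab/InternVL2-2B",
                                    trust_remote_code=True)
print(config.min_dynamic_patch)         # e.g. 1
print(config.max_dynamic_patch)         # e.g. 6 or 12, checkpoint-dependent
print(config.use_thumbnail)             # e.g. True
print(config.vision_config.image_size)  # e.g. 448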
84 changes: 54 additions & 30 deletions vllm/model_executor/models/internvl.py
@@ -38,9 +38,6 @@
 IMAGENET_MEAN = (0.485, 0.456, 0.406)
 IMAGENET_STD = (0.229, 0.224, 0.225)
 
-MAX_IMAGE_FEATURE_SIZE_WIDTH = 3000
-MAX_IMAGE_FEATURE_SIZE_HEIGHT = 500
-
 
 class InternVLImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
@@ -84,11 +81,9 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
     return best_ratio
 
 
-def calculate_num_blocks(orig_width: int,
-                         orig_height: int,
-                         min_num=1,
-                         max_num=6,
-                         image_size=448):
+def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
+                         max_num: int,
+                         image_size: int) -> Tuple[int, int, int]:
     aspect_ratio = orig_width / orig_height
 
     # calculate the existing image aspect ratio
@@ -110,11 +105,9 @@
 
 
 # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
-def dynamic_preprocess(image,
-                       min_num=1,
-                       max_num=6,
-                       image_size=448,
-                       use_thumbnail=False):
+def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int,
+                       image_size: int,
+                       use_thumbnail: int) -> List[Image.Image]:
     orig_width, orig_height = image.size
 
     blocks, target_width, target_height = calculate_num_blocks(
@@ -138,12 +131,14 @@
 
 
 # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
-def image_to_pixel_values(image: Image.Image, input_size=448, max_num=6):
+def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int,
+                          max_num: int, use_thumbnail: bool) -> torch.Tensor:
     transform = build_transform(input_size=input_size)
     images = dynamic_preprocess(image,
+                                min_num=min_num,
+                                max_num=max_num,
                                 image_size=input_size,
-                                use_thumbnail=True,
-                                max_num=max_num)
+                                use_thumbnail=use_thumbnail)
     pixel_values = [transform(image) for image in images]
     pixel_values = torch.stack(pixel_values)
     return pixel_values
@@ -159,12 +154,18 @@ def get_internvl_num_patches(image_size: int, patch_size: int,
 def get_max_internvl_image_tokens(ctx: InputContext):
     hf_config = ctx.get_hf_config(PretrainedConfig)
     vision_config = hf_config.vision_config
+
+    use_thumbnail = hf_config.use_thumbnail
+    max_dynamic_patch = hf_config.max_dynamic_patch
+    if use_thumbnail:
+        max_dynamic_patch += 1
+    downsample_ratio = hf_config.downsample_ratio
+
     image_size = vision_config.image_size
     patch_size = vision_config.patch_size
-    downsample_ratio = hf_config.downsample_ratio
     num_patches = get_internvl_num_patches(image_size, patch_size,
                                            downsample_ratio)
-    return num_patches * 7
+    return num_patches * max_dynamic_patch
 
 
 def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
@@ -176,30 +177,35 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
     hf_config = ctx.get_hf_config(PretrainedConfig)
     vision_config = hf_config.vision_config
 
+    image_size = vision_config.image_size
+    patch_size = vision_config.patch_size
+    downsample_ratio = hf_config.downsample_ratio
+    num_patches = get_internvl_num_patches(image_size, patch_size,
+                                           downsample_ratio)
+
     image_data = multi_modal_data["image"]
     if isinstance(image_data, Image.Image):
         width, height = image_data.size
-        num_blocks, _, _ = calculate_num_blocks(width, height)
+        min_num = hf_config.min_dynamic_patch
+        max_num = hf_config.max_dynamic_patch
+        num_blocks, _, _ = calculate_num_blocks(width, height, min_num,
+                                                max_num, image_size)
+        # add thumbnail image if num_blocks > 1
+        if hf_config.use_thumbnail and num_blocks > 1:
+            num_blocks += 1
     elif isinstance(image_data, torch.Tensor):
         raise NotImplementedError("Embeddings input is not supported yet")
     else:
         raise TypeError(f"Invalid image type: {type(image_data)}")
 
-    image_size = vision_config.image_size
-    patch_size = vision_config.patch_size
-    downsample_ratio = hf_config.downsample_ratio
-    num_patches = get_internvl_num_patches(image_size, patch_size,
-                                           downsample_ratio)
-
     tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                      trust_remote_code=True)
 
     prompt = llm_inputs.get("prompt")
     prompt_token_ids = llm_inputs["prompt_token_ids"]
     if prompt is None:
         prompt = tokenizer.decode(prompt_token_ids)
-    image_prompt = IMG_START + IMG_CONTEXT * (num_blocks +
-                                              1) * num_patches + IMG_END
+    image_prompt = IMG_START + IMG_CONTEXT * num_blocks * num_patches + IMG_END
     new_prompt = prompt.replace('<image>', image_prompt, 1)
     new_prompt_token_ids = tokenizer.encode(new_prompt)
 
@@ -209,8 +215,19 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
 
 
 def input_mapper_for_internvl(ctx: InputContext, data: object):
+    hf_config = ctx.get_hf_config(PretrainedConfig)
+
+    use_thumbnail = hf_config.use_thumbnail
+    min_num = hf_config.min_dynamic_patch
+    max_num = hf_config.max_dynamic_patch
+    image_size = hf_config.vision_config.image_size
+
     if isinstance(data, Image.Image):
-        data = image_to_pixel_values(data)
+        data = image_to_pixel_values(data,
+                                     image_size,
+                                     min_num,
+                                     max_num,
+                                     use_thumbnail=use_thumbnail)
     model_config = ctx.model_config
     tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                      trust_remote_code=True)
@@ -240,10 +257,17 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int):
                                         add_special_tokens=False)[0],
         image_feature_size_override=image_feature_size,
     )
+
+    image_size = vision_config.image_size
+    min_num = hf_config.min_dynamic_patch
+    max_num = hf_config.max_dynamic_patch
+    max_image_width = max_num * image_size
+    max_image_height = min_num * image_size
+
     mm_data = dummy_image_for_clip(
         vision_config,
-        image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
-        image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
+        image_width_override=max_image_width,
+        image_height_override=max_image_height,
     )
 
     return seq_data, mm_data
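Note on the internvl.py changes: get_max_internvl_image_tokens and dummy_data_for_internvl previously assumed a fixed worst case (the hard-coded "num_patches * 7" and a 3000 x 500 dummy image); they now derive the budget from the checkpoint config. A rough worked example of the arithmetic, assuming typical InternVL2 settings (image_size=448, patch_size=14, downsample_ratio=0.5, min_dynamic_patch=1, max_dynamic_patch=6, use_thumbnail=True); these values are illustrative assumptions, not read from the diff:

# Sketch of the image-token budget implied by the logic above; the
# config values below are assumed, not guaranteed for every checkpoint.
image_size = 448
patch_size = 14
downsample_ratio = 0.5
min_dynamic_patch = 1
max_dynamic_patch = 6
use_thumbnail = True

# tokens contributed by a single 448x448 block after downsampling
num_patches = int((image_size // patch_size) ** 2 * downsample_ratio ** 2)
print(num_patches)               # 256

# worst case: every dynamic patch plus the optional thumbnail block
max_blocks = max_dynamic_patch + (1 if use_thumbnail else 0)
print(max_blocks * num_patches)  # 7 * 256 = 1792, matching the old "* 7"

# dummy image sized to hit that worst case: widest allowed aspect ratio
print(max_dynamic_patch * image_size)  # 2688 (width)
print(min_dynamic_patch * image_size)  # 448 (height)

A checkpoint that ships max_dynamic_patch=12 would instead budget 13 * 256 = 3328 tokens per image, which is the kind of case the hard-coded multiplier undercounted.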

