[Bugfix] Fix input processor for InternVL2 model #7164

Merged · 9 commits · Aug 7, 2024
vllm/model_executor/models/internvl.py (70 changes: 45 additions & 25 deletions)
@@ -84,11 +84,9 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
     return best_ratio


-def calculate_num_blocks(orig_width: int,
-                         orig_height: int,
-                         min_num=1,
-                         max_num=6,
-                         image_size=448):
+def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
+                         max_num: int,
+                         image_size: int) -> Tuple[int, int, int]:
     aspect_ratio = orig_width / orig_height

     # calculate the existing image aspect ratio
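For readers skimming the diff: the body of calculate_num_blocks is collapsed here, but the tiling logic it implements (adapted from the upstream InternVL repository) can be sketched as follows. This is an illustrative reconstruction rather than the exact vLLM code, and candidate_grids is a hypothetical name.

    from typing import List, Tuple

    def candidate_grids(min_num: int, max_num: int) -> List[Tuple[int, int]]:
        # enumerate every (cols, rows) grid whose tile count lies in
        # [min_num, max_num]; the real code then picks the grid whose
        # aspect ratio is closest to the input image's aspect ratio
        grids = {(i, j)
                 for n in range(min_num, max_num + 1)
                 for i in range(1, n + 1)
                 for j in range(1, n + 1)
                 if min_num <= i * j <= max_num}
        return sorted(grids, key=lambda g: g[0] * g[1])

    # e.g. min_num=1, max_num=6 allows 1x1 up to 2x3, 3x2, 1x6, 6x1 grids
    print(candidate_grids(1, 6))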
@@ -110,11 +108,9 @@ def calculate_num_blocks(orig_width: int,


 # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
-def dynamic_preprocess(image,
-                       min_num=1,
-                       max_num=6,
-                       image_size=448,
-                       use_thumbnail=False):
+def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int,
+                       image_size: int,
+                       use_thumbnail: bool) -> List[Image.Image]:
     orig_width, orig_height = image.size

     blocks, target_width, target_height = calculate_num_blocks(
@@ -138,12 +134,14 @@


 # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
-def image_to_pixel_values(image: Image.Image, input_size=448, max_num=6):
+def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int,
+                          max_num: int, use_thumbnail: bool) -> torch.Tensor:
     transform = build_transform(input_size=input_size)
     images = dynamic_preprocess(image,
+                                min_num=min_num,
+                                max_num=max_num,
                                 image_size=input_size,
-                                use_thumbnail=True,
-                                max_num=max_num)
+                                use_thumbnail=use_thumbnail)
     pixel_values = [transform(image) for image in images]
     pixel_values = torch.stack(pixel_values)
     return pixel_values
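A hypothetical call under typical InternVL2 defaults (448-pixel tiles, 1 to 6 dynamic patches, thumbnail enabled) shows what callers now pass explicitly; the real values are read from the HF config in the hunks below, and the image path here is purely illustrative.

    from PIL import Image

    image = Image.open("example.jpg").convert("RGB")  # illustrative path
    pixel_values = image_to_pixel_values(image,
                                         input_size=448,   # vision_config.image_size
                                         min_num=1,        # hf_config.min_dynamic_patch
                                         max_num=6,        # hf_config.max_dynamic_patch
                                         use_thumbnail=True)
    # shape: (num_tiles, 3, 448, 448), where num_tiles includes the extra
    # thumbnail tile whenever more than one tile was produced
    print(pixel_values.shape)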
@@ -159,12 +157,18 @@ def get_internvl_num_patches(image_size: int, patch_size: int,
 def get_max_internvl_image_tokens(ctx: InputContext):
     hf_config = ctx.get_hf_config(PretrainedConfig)
     vision_config = hf_config.vision_config

+    use_thumbnail = hf_config.use_thumbnail
+    max_dynamic_patch = hf_config.max_dynamic_patch
+    if use_thumbnail:
+        max_dynamic_patch += 1
+    downsample_ratio = hf_config.downsample_ratio
Comment on lines +158 to +162
Member:
Was this the root cause of the original bug?

Collaborator Author:
Yes, because we only append the thumbnail image when there is more than one processed image patch:

    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images

Collaborator Author (@Isotr0py, Aug 5, 2024):
Oops, the root cause is actually these lines (L196-L198):

        min_num = hf_config.min_dynamic_patch
        max_num = hf_config.max_dynamic_patch
        num_blocks, _, _ = calculate_num_blocks(width, height, min_num,
                                                max_num, image_size)
        # add thumbnail image if num_blocks > 1
        if hf_config.use_thumbnail and num_blocks > 1:
            num_blocks += 1

The if use_thumbnail: change commented on above should be fine, because there we are calculating the max image tokens for profiling, which means len(processed_images) (equal to "max_dynamic_patch" in hf_config) will be larger than 1.
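To make the two code paths concrete, here is a small illustration; effective_num_tiles is a hypothetical helper mirroring the guard this PR adds, and the config values are assumed.

    # Hypothetical helper mirroring the guard added in this PR.
    def effective_num_tiles(num_blocks: int, use_thumbnail: bool) -> int:
        # upstream preprocessing only appends the thumbnail when the
        # image was split into more than one tile
        if use_thumbnail and num_blocks > 1:
            return num_blocks + 1
        return num_blocks

    # profiling path: the worst case always hits max_dynamic_patch tiles,
    # so unconditionally adding 1 for the thumbnail is safe there
    assert effective_num_tiles(6, use_thumbnail=True) == 7

    # per-request path: a small image may produce a single tile, in which
    # case no thumbnail is appended -- hence the `num_blocks > 1` guard
    assert effective_num_tiles(1, use_thumbnail=True) == 1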


     image_size = vision_config.image_size
     patch_size = vision_config.patch_size
-    downsample_ratio = hf_config.downsample_ratio
     num_patches = get_internvl_num_patches(image_size, patch_size,
                                            downsample_ratio)
-    return num_patches * 7
+    return num_patches * max_dynamic_patch
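As a sanity check on the new return value, assume a stock InternVL2 vision config (image_size=448, patch_size=14, downsample_ratio=0.5, max_dynamic_patch=6, use_thumbnail=True) and assume get_internvl_num_patches computes (image_size / patch_size)^2 * downsample_ratio^2:

    # Worked example under the assumed stock InternVL2 settings.
    image_size, patch_size = 448, 14
    downsample_ratio = 0.5
    max_dynamic_patch, use_thumbnail = 6, True

    num_patches = int((image_size // patch_size) ** 2 * downsample_ratio ** 2)
    assert num_patches == 256  # tokens per tile

    if use_thumbnail:
        max_dynamic_patch += 1  # 6 -> 7 tiles in the worst case
    assert num_patches * max_dynamic_patch == 1792

    # The old hard-coded `return num_patches * 7` only matched configs
    # where max_dynamic_patch plus the thumbnail happened to equal 7.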


@@ -176,30 +180,35 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
     hf_config = ctx.get_hf_config(PretrainedConfig)
     vision_config = hf_config.vision_config

+    image_size = vision_config.image_size
+    patch_size = vision_config.patch_size
+    downsample_ratio = hf_config.downsample_ratio
+    num_patches = get_internvl_num_patches(image_size, patch_size,
+                                           downsample_ratio)
+
     image_data = multi_modal_data["image"]
     if isinstance(image_data, Image.Image):
         width, height = image_data.size
-        num_blocks, _, _ = calculate_num_blocks(width, height)
+        min_num = hf_config.min_dynamic_patch
+        max_num = hf_config.max_dynamic_patch
+        num_blocks, _, _ = calculate_num_blocks(width, height, min_num,
+                                                max_num, image_size)
+        # add thumbnail image if num_blocks > 1
+        if hf_config.use_thumbnail and num_blocks > 1:
+            num_blocks += 1
     elif isinstance(image_data, torch.Tensor):
         raise NotImplementedError("Embeddings input is not supported yet")
     else:
         raise TypeError(f"Invalid image type: {type(image_data)}")

-    image_size = vision_config.image_size
-    patch_size = vision_config.patch_size
-    downsample_ratio = hf_config.downsample_ratio
-    num_patches = get_internvl_num_patches(image_size, patch_size,
-                                           downsample_ratio)
-
     tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                      trust_remote_code=True)

     prompt = llm_inputs.get("prompt")
     prompt_token_ids = llm_inputs["prompt_token_ids"]
     if prompt is None:
         prompt = tokenizer.decode(prompt_token_ids)
-    image_prompt = IMG_START + IMG_CONTEXT * (num_blocks +
-                                              1) * num_patches + IMG_END
+    image_prompt = IMG_START + IMG_CONTEXT * num_blocks * num_patches + IMG_END
     new_prompt = prompt.replace('<image>', image_prompt, 1)
     new_prompt_token_ids = tokenizer.encode(new_prompt)
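For concreteness, a sketch of the resulting prompt expansion for an image that tiles into two blocks plus a thumbnail (num_blocks == 3 after the guard), with an assumed 256 patches per tile; the IMG_* values mirror the constants defined near the top of this file:

    # Sketch of the prompt expansion performed above.
    IMG_START, IMG_CONTEXT, IMG_END = "<img>", "<IMG_CONTEXT>", "</img>"

    num_blocks, num_patches = 3, 256  # assumed: 2 tiles + thumbnail, 256/tile
    prompt = "USER: <image>\nWhat is in this picture? ASSISTANT:"

    image_prompt = IMG_START + IMG_CONTEXT * num_blocks * num_patches + IMG_END
    new_prompt = prompt.replace('<image>', image_prompt, 1)
    # the placeholder becomes 768 IMG_CONTEXT tokens inside IMG_START/IMG_END
    assert new_prompt.count(IMG_CONTEXT) == 768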

@@ -209,8 +218,19 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):


 def input_mapper_for_internvl(ctx: InputContext, data: object):
+    hf_config = ctx.get_hf_config(PretrainedConfig)
+
+    use_thumbnail = hf_config.use_thumbnail
+    min_num = hf_config.min_dynamic_patch
+    max_num = hf_config.max_dynamic_patch
+    image_size = hf_config.vision_config.image_size
+
     if isinstance(data, Image.Image):
-        data = image_to_pixel_values(data)
+        data = image_to_pixel_values(data,
+                                     image_size,
+                                     min_num,
+                                     max_num,
+                                     use_thumbnail=use_thumbnail)
     model_config = ctx.model_config
     tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                      trust_remote_code=True)
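The essence of the fix is that the input processor and the input mapper now derive the tile count from the same config values. A sketch of the invariant this restores, assuming the helpers above and an hf_config carrying the attributes used in this file:

    # Sketch: the number of IMG_CONTEXT placeholders inserted by the input
    # processor must equal the number of tiles produced by the input mapper.
    def tiles_match(image, hf_config) -> bool:
        image_size = hf_config.vision_config.image_size
        min_num = hf_config.min_dynamic_patch
        max_num = hf_config.max_dynamic_patch

        width, height = image.size
        num_blocks, _, _ = calculate_num_blocks(width, height, min_num,
                                                max_num, image_size)
        if hf_config.use_thumbnail and num_blocks > 1:
            num_blocks += 1  # same guard as input_processor_for_internvl

        pixel_values = image_to_pixel_values(
            image, image_size, min_num, max_num,
            use_thumbnail=hf_config.use_thumbnail)
        return pixel_values.shape[0] == num_blocks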