
[Bugfix] Qwen2.5_VL fix from Qwen Team #2

Merged 5 commits into ywang96:qwen2_5_vl on Feb 3, 2025

Conversation

@wulipc wulipc commented Feb 3, 2025

This update comes from the [Qwen2.5-VL](https://github.com/QwenLM/Qwen2.5-VL) team, and the update details are as follows:

Update Contents:

  • Fixed the MLP activation function in Qwen2_5_VLPatchMerger from nn.SiLU() to nn.GELU(); nn.GELU() is the correct choice, matching the reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L146 (a minimal sketch follows this list).
  • Enlarged the cos and sin cache in MRotaryEmbedding to 4 times max_position_embeddings to accommodate long video requests; the factor of 4 is an empirical value.
  • Video support: added an fps parameter, which can be passed per request via mm_processor_kwargs.
  • Added offline_inference code for qwen2_5_vl.
  • Removed the invalid packed_modules_mapping configuration.
  • Added demo code for Qwen2.5-VL image and video inference.
  • Code formatting and other readability improvements.
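
For illustration, here is a minimal sketch of a patch merger with the corrected activation. The class name, layer names, and dimensions are assumptions made for readability, not the exact vLLM or transformers code; the only point it demonstrates is that the hidden activation is nn.GELU() rather than nn.SiLU().

import torch
import torch.nn as nn


class PatchMergerSketch(nn.Module):
    """Illustrative only: merges groups of visual patches and projects them through an MLP."""

    def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2):
        super().__init__()
        hidden_size = context_dim * (spatial_merge_size ** 2)
        self.ln_q = nn.LayerNorm(context_dim, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),  # the corrected activation (previously nn.SiLU())
            nn.Linear(hidden_size, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (num_patches, context_dim); adjacent patches are grouped before the MLP
        x = self.ln_q(x)
        return self.mlp(x.view(-1, self.mlp[0].in_features))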

Demo code:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
import argparse
from pprint import pprint
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor
from vllm import LLM, SamplingParams


def prepare_inputs_for_vllm(messages, processor):
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs, video_kwargs = process_vision_info(messages, return_video_kwargs=True)

    # fps will be returned in video_kwargs
    print(f"video_kwargs: {video_kwargs}")

    mm_data = {}
    if image_inputs is not None:
        mm_data['image'] = image_inputs
    if video_inputs is not None:
        mm_data['video'] = video_inputs

    return {
        'prompt': text,
        'multi_modal_data': mm_data,
        'mm_processor_kwargs': video_kwargs  # note: the fps value is passed here
    }


if __name__ == '__main__':
    # TODO: set your local path of video to load
    video_url = "file:///your/absolute/path/to/local/video.mp4"
    video_message = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": [
                {"type": "text", "text": "Could you go into detail about the content of this video?"},
                {"type": "video", "video": video_url, "total_pixels":  20480 * 28 * 28, "min_pixels":  16 * 28 * 28},
            ]
        },
    ]

    image_message = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-VL/qwen2.5vl_logo.png"},
                {"type": "text", "text": "Describe this image."},
            ],
        }
    ]

    # TODO: set your path of weight to load
    checkpoint_path = "/your/weight/path/to/load"
    processor = AutoProcessor.from_pretrained(checkpoint_path)
    inputs = [prepare_inputs_for_vllm(message, processor) for message in [image_message, video_message]]

    llm = LLM(
        model=checkpoint_path, trust_remote_code=True, gpu_memory_utilization=0.90, enforce_eager=False,
        tensor_parallel_size=torch.cuda.device_count(),
        limit_mm_per_prompt={'image': 10, 'video': 10},
        seed=0
    )

    sampling_params = SamplingParams(
        temperature=0.1, top_p=0.001, repetition_penalty=1.05, max_tokens=2048,
        top_k=-1,
        stop_token_ids=[],
    )
    
    for i, input_ in enumerate(inputs):
        print()
        print('=' * 40)
        print(f"Inputs[{i}]: {input_['prompt']=!r}")
    print('\n' + '>' * 40)

    outputs = llm.generate(inputs, sampling_params=sampling_params)
    for i, output in enumerate(outputs):
        generated_text = output.outputs[0].text
        print()
        print('=' * 40)
        print(f"Generated text: {generated_text!r}")

Thanks

We thank the vLLM team for their outstanding work. If you have any questions, please feel free to contact us.

@wulipc wulipc requested a review from ywang96 as a code owner February 3, 2025 03:41

github-actions bot commented Feb 3, 2025

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default; only the fastcheck CI runs, covering a small, essential subset of CI tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@ywang96 (Owner) commented Feb 3, 2025

The fix is much appreciated! Going to take a look now. FYI @yixqiao

@ywang96 (Owner) left a comment

Overall LGTM! I left a few questions so please take a look!

Comment on lines -863 to +867
-    second_per_grid_ts: Optional[List[float]] = None,
+    video_second_per_grid_ts: Optional[List[float]] = None,

@ywang96 (Owner):
Is it okay if we keep it as second_per_grid_ts? Generally speaking, I think it's better to use the same names as the output of the Processor class unless there's a strong reason to use a different one. What do you think?

@wulipc (Author):
second_per_grid_ts is a video-related parameter, so video_second_per_grid_ts would be more descriptive from that perspective. However, if you need to keep the name consistent with transformers, that is also acceptable.

@ywang96 (Owner):
Sounds good! I'll merge this PR and rename it afterwards!

Comment on lines -873 to +878
-    tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", None)
+    video_tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0)

@ywang96 (Owner):
Ditto

Comment on lines +366 to +382
class Qwen2RMSNorm(nn.Module):

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance +
                                                    self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"

@ywang96 (Owner) commented Feb 3, 2025:
Just so I understand - are we using torch native rmsnorm because you observed precision issue of our RMSNorm kernel?

@wulipc (Author):
I noticed that the original code uses vLLM's RMSNorm, but our native implementation, Qwen2RMSNorm, has already been validated on multiple datasets, and the metrics in our report were produced with it. To avoid precision issues, I kept the previous implementation.
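
For context, a hypothetical check of why the float32 upcast inside Qwen2RMSNorm can matter for precision; this snippet is not part of the PR, and rmsnorm_sketch is a made-up helper that mirrors the module shown above.

import torch

def rmsnorm_sketch(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6,
                   upcast: bool = True) -> torch.Tensor:
    # Same structure as Qwen2RMSNorm; `upcast` toggles the float32 accumulation.
    dtype = x.dtype
    if upcast:
        x = x.to(torch.float32)
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return weight * x.to(dtype)

x = torch.randn(4, 1280, dtype=torch.bfloat16)
w = torch.ones(1280, dtype=torch.bfloat16)
gap = (rmsnorm_sketch(x, w, upcast=True) - rmsnorm_sketch(x, w, upcast=False)).abs().max()
print(gap)  # typically a small but nonzero difference from low-precision accumulation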

Comment on lines -983 to +1156
-    second_per_grid_ts=MultiModalFieldConfig.flat(
-        "video", video_slices),
+    second_per_grid_ts=MultiModalFieldConfig.batched("video"),

@ywang96 (Owner):
Good catch!

@ywang96 ywang96 merged commit e765e1e into ywang96:qwen2_5_vl Feb 3, 2025
2 checks passed
@ywang96 ywang96 mentioned this pull request Feb 4, 2025