Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] Initial support for VLMs, add Qwen2.5VL GRPO example #386

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

hiyouga
Copy link
Contributor

@hiyouga hiyouga commented Feb 26, 2025

What does this PR do?

This PR migrates the feature of RL on VLMs in our implementation in EasyR1 fork back to veRL. We have validated this feature using Qwen2.5-VL 7B model on 8*H100 GPUs. The configuration and data processing script are provided along this PR for easy reproducing.

How to reproduce?

  1. Download and preprocess the dataset
python3 examples/data_preprocess/geo3k.py --local_dir ~/data/geo3k
  1. Start GRPO training
bash examples/grpo_trainer/run_qwen2_5_vl-7b.sh

Dependencies

Major Changes

New dataflow for multimodal RL

In this PR, we introduce two new concepts in the dataflow, multi_modal_data and multi_modal_inputs. The former means the multi-modal features required by the rollout worker (such as vLLM), while the latter means the multi-modal features required by the actor/critic worker (such as an HF model). They are different because the rollout and actor workers have their own data format requirements.

Taking Qwen2-VL + huggingface + vLLM as an example, the data structure should be:

  • multi_modal_data: {"image": [PIL.Image, PIL.Image, ...]}
  • multi_modal_inputs: {"pixel_values": torch.Tensor, "image_grid_thw": torch.Tensor}

Both of them are converted to numpy objects and placed in the non-tensor batch in DataProto.

This design can be extended to other modalities/VLMs easily due to the agnostic of models.

Other changes

  • Data

    • Support pre-processing the Geometry3k dataset.
    • Support config.data.image_key, which should be a list of Pillow images.
  • Actor/Ref/Critic

    • Support multi_modal_inputs.
    • Process position ids to adapt to the m-rope .
  • Rollout

    • Update dtensor weight loader to adapt to the Qwen2-VL architecture in vLLM 0.7+.
    • Support multi_modal_data.
    • Use raw_prompt_ids as the vLLM inputs to avoid unpadding the input ids.
  • Reward Manager

    • Add mathruler for more accurate math scores on the Geometry 3k dataset
  • Models

    • Support calculating the position ids for the m-rope in Qwen2-VL.
    • Support removing padding in flash attention2 for m-rope (transformers itself does not support it).
  • Sharding Manager

    • Support all-gathering the non-tensor batch.
  • FSDP Workers / Checkpoint Merger

    • Support AutoModelForVision2Seq at model initialization.

Note: The Ulysses parallelism is not completed yet. We will support it in the next update.

Performance

We provide the estimated MFU of the language model part for H100 GPUs. These values are lower than the actual ones because we did not compute the FLOPs of the vision tower part.

  • remove_padding=False: MFU ~7%
  • remove_padding=True: MFU ~20%

The training and test reward score curves are presented as follows.

image

Who can review?

@vermouth1992 @PeterSH6

@CLAassistant
Copy link

CLAassistant commented Feb 26, 2025

CLA assistant check
All committers have signed the CLA.

@khazic
Copy link

khazic commented Feb 26, 2025

good job

@vermouth1992
Copy link
Collaborator

Item TODO: I guess it's better to provide a documentation on how to extend to other VLMs like ulysses/rmpad, etc,. so that other contributors can add their interested VLMs.

@hiyouga
Copy link
Contributor Author

hiyouga commented Feb 26, 2025

@vermouth1992 Sure, but it requires non-trivial effort to build a unified guideline for all the VLMs, so we are considering improving the document next PR.

@hiyouga hiyouga force-pushed the vlm branch 5 times, most recently from 5b87997 to 6d08978 Compare February 26, 2025 12:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants