FIRE sampling added. #58

Open · wants to merge 6 commits into base: main · changes from 5 commits
46 changes: 46 additions & 0 deletions .github/workflows/e2e_digit_completion_fire.yml
@@ -0,0 +1,46 @@
name: e2e_digit_completion_fire

on:
# Trigger the workflow on push or pull request,
# but only for the main branch
push:
branches:
- main
paths:
- "**/*.py"
- .github/workflows/e2e_digit_completion_fire.yml
pull_request:
branches:
- main
paths:
- "**/*.py"
- .github/workflows/e2e_digit_completion_fire.yml
- "tests/e2e/*.sh"

# Declare permissions: read contents only.
permissions:
contents: read

jobs:
e2e_digit_completion:
runs-on: [self-hosted, l20-0]
env:
HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
NO_PROXY: "localhost,127.0.0.1"
HF_HUB_ENABLE_HF_TRANSFER: 1
container:
image: verlai/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te1.7-v0.0.3
options: --gpus all --shm-size=10g
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
fetch-depth: 0
- name: Install the current repository
run: |
pip3 install hf_transfer
pip3 install -e .[test]
- name: Running digit completion e2e training tests on 8 L20 GPUs
run: |
ray stop --force
bash tests/e2e/run_ray_trainer_fire_sampling.sh
40 changes: 40 additions & 0 deletions tests/e2e/run_ray_trainer_fire_sampling.sh
@@ -0,0 +1,40 @@
#!/usr/bin/env bash

set -e -x

OUTPUT_FILE="/tmp/output_ray_trainer.txt"

export PATH=$PATH:~/.local/bin

rm -rf $OUTPUT_FILE
python3 tests/e2e/arithmetic_sequence/rl/main_trainer.py \
data.train_files=tests/e2e/arithmetic_sequence/data/train.parquet \
data.val_files=tests/e2e/arithmetic_sequence/data/test.parquet \
data.train_batch_size=800 \
data.val_batch_size=200 \
data.max_prompt_length=16 \
data.max_response_length=32 \
data.return_raw_input_ids=True \
actor_rollout_ref.model.path=tests/e2e/arithmetic_sequence/model \
actor_rollout_ref.model.external_lib=tests.e2e.envs.digit_completion \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=200 \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.optim.lr=1e-4 \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=200 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=200 \
actor_rollout_ref.rollout.name=hf \
actor_rollout_ref.rollout.use_fire_sampling=True \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
critic.ppo_micro_batch_size_per_gpu=200 \
critic.model.path=tests/e2e/arithmetic_sequence/model \
critic.optim.lr=1e-3 \
algorithm.kl_ctrl.kl_coef=0.005 \
trainer.total_epochs=200 \
trainer.experiment_name=arithmetic_sequences \
trainer.logger=['console'] \
trainer.n_gpus_per_node=1 \
trainer.test_freq=1 \
trainer.save_freq=110 | tee $OUTPUT_FILE;

python3 tests/e2e/check_results.py --output_file=$OUTPUT_FILE
rm -rf $OUTPUT_FILE
1 change: 1 addition & 0 deletions verl/trainer/config/ppo_trainer.yaml
@@ -64,6 +64,7 @@ actor_rollout_ref:
temperature: 1.0
top_k: -1 # 0 for hf rollout, -1 for vllm rollout
top_p: 1
use_fire_sampling: False # https://arxiv.org/abs/2410.21236
prompt_length: ${data.max_prompt_length} # not used for opensource
response_length: ${data.max_response_length}
# for vllm rollout
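For reference, the new key defaults to False, so existing configs are unaffected; the e2e script above enables it with the override actor_rollout_ref.rollout.use_fire_sampling=True. A minimal sketch of how the flag sits in the rollout config, using OmegaConf directly (illustrative only, not part of the patch):

```python
from omegaconf import OmegaConf

# Illustrative rollout sub-config mirroring the ppo_trainer.yaml entries above.
rollout = OmegaConf.create({
    "temperature": 1.0,
    "top_k": -1,                  # 0 for hf rollout, -1 for vllm rollout
    "top_p": 1.0,
    "use_fire_sampling": False,   # https://arxiv.org/abs/2410.21236
})

# A Hydra-style command-line override, as used in run_ray_trainer_fire_sampling.sh,
#   actor_rollout_ref.rollout.use_fire_sampling=True
# has the same effect as flipping the flag here:
rollout.use_fire_sampling = True
assert rollout.use_fire_sampling
```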
6 changes: 5 additions & 1 deletion verl/workers/fsdp_workers.py
@@ -298,7 +298,11 @@ def _build_rollout(self):
rollout_sharding_manager = BaseShardingManager()
# TODO: a sharding manager that does nothing?
elif self.config.rollout.name == 'vllm':
from verl.workers.rollout.vllm_rollout import vLLMRollout, vllm_mode
if self.config.rollout.use_fire_sampling:
from verl.workers.rollout.vllm_rollout import FIREvLLMRollout as vLLMRollout
from verl.workers.rollout.vllm_rollout import vllm_mode
else:
from verl.workers.rollout.vllm_rollout import vLLMRollout, vllm_mode
from verl.workers.sharding_manager import FSDPVLLMShardingManager
log_gpu_memory_usage('Before building vllm rollout', logger=None)
local_path = copy_local_path_from_hdfs(self.config.model.path)
1 change: 1 addition & 0 deletions verl/workers/rollout/vllm_rollout/__init__.py
@@ -28,6 +28,7 @@ def get_version(pkg):
if package_version <= '0.6.3':
vllm_mode = 'customized'
from .vllm_rollout import vLLMRollout
from .fire_vllm_rollout import FIREvLLMRollout
else:
vllm_mode = 'spmd'
from .vllm_rollout_spmd import vLLMRollout
214 changes: 214 additions & 0 deletions verl/workers/rollout/vllm_rollout/fire_vllm_rollout.py
@@ -0,0 +1,214 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The vllm_rollout that can be applied in different backends
When working with FSDP:
- Use DTensor weight loader (recommended) or HF weight loader
- Utilize state_dict from the FSDP to synchronize the weights among tp ranks in vLLM
When working with Megatron:
- Use Megatron weight loader
- During training, only the current pp stage holds the parameters
- Before inference, broadcast the parameters of the current pp rank to all other pp ranks (all pp ranks hold all the parameters)
- Bind the parameters to the inference engine
- Do inference in tp. pp is treated as additional dp
- After inference, all the parameters that don't belong to this pp rank are freed.
"""
from typing import List
from contextlib import contextmanager
from omegaconf import DictConfig
import torch
import torch.distributed
from tensordict import TensorDict
from torch import nn

from verl import DataProto
from verl.utils.torch_functional import get_eos_mask, pad_sequence_to_length
from verl.workers.rollout.base import BaseRollout
from verl.workers.rollout.vllm_rollout.vllm_rollout import vLLMRollout
from verl.third_party.vllm import LLM, vllm_version
from verl.third_party.vllm import parallel_state as vllm_ps
from vllm import SamplingParams

# TODO
# 1. support pp in vllm
# 2. passing tokenizer is not necessary? no encoding/decoding is happening here
# 3. simplify init logics


# NOTE(sgm): added for verl. We can optimize it by making the dataloader yield List[int] without padding.
def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> List[int]:
# remove the left padding in the prompt token_id
# pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id
non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]
token_ids = prompt_token_ids[non_pad_index:].tolist()
return token_ids


class FIREvLLMRollout(vLLMRollout):

def __init__(self, actor_module: nn.Module, config: DictConfig, tokenizer, model_hf_config, **kwargs):
"""A vLLM rollout. It requires the module is supported by the vllm.

Args:
module: module here follows huggingface APIs
config: DictConfig
tokenizer: the task/model tokenizer
model_hf_config: the huggingface config to initialize the generating model in vllm
**kwargs: train_tp, for Megatron Backend to initialize hybrid engine (zero redundancy) process group
"""
super().__init__(actor_module, config, tokenizer, model_hf_config, **kwargs)

self.use_fire_sampling = config.get('use_fire_sampling', False)
if self.use_fire_sampling:
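# FIRE setup: sampling_params_0 governs only the first generated token. It uses a
# very high temperature (30) and a small top-k (16, unless a positive top_k is
# already configured) so the initial token is drawn from a flattened, truncated
# distribution; the regular sampling_params then generate the remaining
# max_tokens - 1 tokens of the response.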
kwargs_0 = kwargs.copy()
kwargs_0['temperature'] = 30
kwargs_0['max_tokens'] = 1
if 'top_k' not in kwargs_0 or kwargs_0['top_k'] <= 0:
kwargs_0['top_k'] = 16
kwargs['max_tokens'] -= 1
self.sampling_params_0 = SamplingParams(**kwargs_0)

@contextmanager
def update_sampling_params(self, **kwargs):
# update sampling params
old_sampling_params_args = {}
if kwargs:
for key, value in kwargs.items():
if hasattr(self.sampling_params, key):
old_value = getattr(self.sampling_params, key)
old_sampling_params_args[key] = old_value
setattr(self.sampling_params, key, value)

if self.use_fire_sampling:
old_sampling_params_args_0 = {}
if kwargs:
for key, value in kwargs.items():
if hasattr(self.sampling_params_0, key):
old_value = getattr(self.sampling_params_0, key)
old_sampling_params_args_0[key] = old_value
setattr(self.sampling_params_0, key, value)
yield
# roll back to previous sampling params
# if len(old_sampling_params_args):
for key, value in old_sampling_params_args.items():
setattr(self.sampling_params, key, value)

if self.use_fire_sampling:
for key, value in old_sampling_params_args_0.items():
setattr(self.sampling_params_0, key, value)

@torch.no_grad()
def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
# rebuild vllm cache engine
if self.config.free_cache_engine:
self.inference_engine.init_cache_engine()

idx = prompts.batch['input_ids'] # (bs, prompt_length)
# left-padded attention_mask
attention_mask = prompts.batch['attention_mask']
position_ids = prompts.batch['position_ids']

# used to construct attention_mask
eos_token_id = prompts.meta_info['eos_token_id']

batch_size = idx.size(0)

idx_list = []
# parse idx from torch.Tensor to List[List[int]]
for i in range(batch_size):
idx_list.append(_pre_process_inputs(self.pad_token_id, idx[i]))

do_sample = prompts.meta_info.get('do_sample', True)
if not do_sample:
kwargs = {
'best_of': 1,
'top_p': 1.0,
'top_k': -1,
'min_p': 0.0,
'temperature': 0,
'n': 1 # if greedy, only 1 response
}

if not self.use_fire_sampling:
# users can customize different sampling_params at different run
with self.update_sampling_params(**kwargs):
output = self.inference_engine.generate(
prompts=None, # because we have already converted it to prompt token ids
sampling_params=self.sampling_params,
prompt_token_ids=idx_list,
use_tqdm=False)

response = output[0].to(idx.device) # (bs, response_length)
log_probs = output[1].to(idx.device) # (bs, response_length)
else:
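# FIRE generation runs in two passes: pass 1 draws exactly one token per prompt
# with the high-temperature sampling_params_0, that token is appended to each
# prompt in new_idx_list, and pass 2 generates the rest of the response with the
# regular sampling_params; responses and log probs from both passes are
# concatenated below.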
with self.update_sampling_params(**kwargs):
output_0 = self.inference_engine.generate(
prompts=None, # because we have already converted it to prompt token ids
sampling_params=self.sampling_params_0,
prompt_token_ids=idx_list,
use_tqdm=False)
new_idx_list = []
for i in range(batch_size):
new_idx_list.append(idx_list[i] + output_0[0][i].tolist())
output = self.inference_engine.generate(
prompts=None, # because we have already converted it to prompt token ids
sampling_params=self.sampling_params,
prompt_token_ids=new_idx_list,
use_tqdm=False)

response = torch.cat([output_0[0], output[0]], dim=1).to(idx.device) # (bs, response_length)
log_probs = torch.cat([output_0[1], output[1]], dim=1).to(idx.device) # (bs, response_length)


if response.shape[1] < self.config.response_length:
response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id)
log_probs = pad_sequence_to_length(log_probs, self.config.response_length, self.pad_token_id)

if self.config.n > 1 and do_sample:
idx = idx.repeat_interleave(self.config.n, dim=0)
attention_mask = attention_mask.repeat_interleave(self.config.n, dim=0)
position_ids = position_ids.repeat_interleave(self.config.n, dim=0)
batch_size = batch_size * self.config.n
seq = torch.cat([idx, response], dim=-1)

response_length = response.size(1)
delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1)

# TODO(sgm): fix position_ids on right_pad
# prompt: left pad + response: right pad
# attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]
# position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]
response_position_ids = position_ids[:, -1:] + delta_position_id
position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
response_attention_mask = get_eos_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype)
attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)

# all the tp ranks should contain the same data here. data in all ranks are valid
batch = TensorDict(
{
'prompts': idx,
'responses': response,
'input_ids': seq, # here input_ids become the whole sentences
# 'old_log_probs': log_probs, # we will recompute old log prob with actor
'attention_mask': attention_mask,
'position_ids': position_ids
},
batch_size=batch_size)

# free vllm cache engine
if self.config.free_cache_engine:
self.inference_engine.free_cache_engine()

return DataProto(batch=batch)
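For context, the class above realizes FIRE sampling (Flaming-hot Initiation with Regular Execution, https://arxiv.org/abs/2410.21236) with two SamplingParams objects: the first response token is drawn at a very high temperature with a small top-k, and the remaining tokens are drawn with the regular parameters. The sketch below illustrates the same idea in plain PyTorch, independent of vLLM; logits_fn, the default values, and all names are illustrative assumptions rather than part of this PR:

```python
import torch
import torch.nn.functional as F


def fire_sample(logits_fn, prompt_ids, max_new_tokens, eos_id,
                first_temp=30.0, first_top_k=16, temp=1.0):
    """Decode with FIRE-style sampling (illustrative sketch).

    logits_fn(ids) is assumed to return next-token logits of shape
    (1, seq_len, vocab_size) for a (1, seq_len) tensor of token ids.
    """
    ids = list(prompt_ids)
    for step in range(max_new_tokens):
        logits = logits_fn(torch.tensor([ids]))[0, -1]
        if step == 0:
            # "Flaming-hot" initiation: flatten the distribution with a very
            # high temperature, but keep only the first_top_k candidates.
            topk_vals, topk_idx = torch.topk(logits, first_top_k)
            probs = F.softmax(topk_vals / first_temp, dim=-1)
            next_id = topk_idx[torch.multinomial(probs, 1)].item()
        else:
            # Regular execution: ordinary temperature sampling.
            probs = F.softmax(logits / temp, dim=-1)
            next_id = torch.multinomial(probs, 1).item()
        ids.append(next_id)
        if next_id == eos_id:
            break
    return ids
```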
2 changes: 1 addition & 1 deletion verl/workers/rollout/vllm_rollout/vllm_rollout.py
@@ -229,4 +229,4 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
if self.config.free_cache_engine:
self.inference_engine.free_cache_engine()

return DataProto(batch=batch)
return DataProto(batch=batch)