[WIP] PRIME algorithm #362

Open · wants to merge 7 commits into base: main
13 changes: 13 additions & 0 deletions recipe/prime/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
69 changes: 69 additions & 0 deletions recipe/prime/config/prime_trainer.yaml
@@ -0,0 +1,69 @@
# the PRIME config overrides the default ppo_trainer.yaml

hydra:
searchpath:
- file://verl/trainer/config

defaults:
- ppo_trainer
- _self_

data:
filter_accuracy: True
accuracy_lower_bound: 0.2
accuracy_upper_bound: 0.8
  oversample_factor: 4.0 # sample more responses than the batch size; prompts that satisfy the accuracy filter are prioritized (see the sketch after this config)

actor_rollout_ref:
hybrid_engine: True
model:
use_remove_padding: True
rollout:
    # number of responses sampled per prompt (i.e. the number of samples n)
n: 4
actor:
entropy_coeff: 0.001

reward_model:
enable: True
strategy: fsdp
model:
use_remove_padding: True
tokenizer_path: ${actor_rollout_ref.model.path}
enable_gradient_checkpointing: ${actor_rollout_ref.model.enable_gradient_checkpointing}
ref_type: freeze
fsdp_config:
min_num_params: 0
param_offload: ${actor_rollout_ref.actor.fsdp_config.param_offload}
# grad_offload: ${actor_rollout_ref.actor.fsdp_config.grad_offload}
optimizer_offload: ${actor_rollout_ref.actor.fsdp_config.optimizer_offload}
update: before # ``before`` for double-forward, ``after`` for single-forward
optim:
lr: 1e-6
lr_warmup_steps_ratio: 0.
min_lr_ratio: null
warmup_style: constant
total_training_steps: -1 # must be overridden by program
weight_decay: 0.
grad_clip: 10.0
beta_train: 0.05
loss_type: ce # currently only supports ce loss
prime_granularity: token
  prime_norm: batch_norm # batch_norm or none; if set to none, the normalizer is beta_train
mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
reward_manager: prime

algorithm:
adv_estimator: rloo
  # currently supports rloo, which treats different sources of reward separately
kl_ctrl:
type: fixed
kl_coef: 0.000
reward_gt_coef: 5
reward_dpo_coef: 5

trainer:
project_name: prime
experiment_name: examples
val_before_train: False
balance_batch: False
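
A minimal sketch of one plausible reading of the data options above (filter_accuracy, accuracy_lower_bound, accuracy_upper_bound, oversample_factor): prompts whose mean accuracy over the n sampled responses falls outside the [0.2, 0.8] band are deprioritized, and only enough prompt groups are kept from the oversampled pool to fill the training batch. This is not the PR's implementation; select_prompt_groups and its arguments are hypothetical names.

import torch


def select_prompt_groups(acc, n, batch_size, lower=0.2, upper=0.8):
    """acc: (num_prompts * n,) tensor of 0/1 response accuracies; returns indices of the prompt groups to keep."""
    group_acc = acc.view(-1, n).float().mean(dim=-1)    # mean accuracy of the n responses per prompt
    keep = (group_acc >= lower) & (group_acc <= upper)  # drop prompts that are trivially easy or never solved
    # prioritize prompts that pass the filter, then fall back to the rest to fill the batch
    order = torch.cat([keep.nonzero(as_tuple=True)[0], (~keep).nonzero(as_tuple=True)[0]])
    return order[:batch_size]


# e.g. a batch of 2 prompt groups selected from 8 oversampled prompts with n = 4 responses each
kept = select_prompt_groups(torch.randint(0, 2, (8 * 4,)), n=4, batch_size=2)
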
130 changes: 130 additions & 0 deletions recipe/prime/main_prime.py
@@ -0,0 +1,130 @@
# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Note that we don't combine this main entry point with ray_trainer, since ray_trainer is also used by other main scripts.
"""
from .prime_ray_trainer import RayPRIMETrainer

import ray
import hydra


@hydra.main(config_path='config', config_name='prime_trainer', version_base=None)
def main(config):
run_prime(config)


def run_prime(config, compute_score=None):
if not ray.is_initialized():
# this is for local ray cluster
ray.init(runtime_env={'env_vars': {'TOKENIZERS_PARALLELISM': 'true', 'NCCL_DEBUG': 'WARN'}})

ray.get(main_task.remote(config, compute_score))


@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head
def main_task(config, compute_score=None):
from verl.utils.fs import copy_local_path_from_hdfs
# print initial config
from pprint import pprint
from omegaconf import OmegaConf
pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
OmegaConf.resolve(config)

# download the checkpoint from hdfs
local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)

# instantiate tokenizer
from verl.utils import hf_tokenizer
tokenizer = hf_tokenizer(local_path)

# define worker classes
if config.actor_rollout_ref.actor.strategy == 'fsdp':
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
from verl.single_controller.ray import RayWorkerGroup
ray_worker_group_cls = RayWorkerGroup

elif config.actor_rollout_ref.actor.strategy == 'megatron':
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
ray_worker_group_cls = NVMegatronRayWorkerGroup

else:
raise NotImplementedError

from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role

role_worker_mapping = {
Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
Role.RefPolicy: ray.remote(ActorRolloutRefWorker)
}

global_pool_id = 'global_pool'
resource_pool_spec = {
global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
}
mapping = {
Role.ActorRollout: global_pool_id,
Role.RefPolicy: global_pool_id,
}
if config.reward_model.enable:
from .prime_fsdp_workers import PRIMERewardModelWorker
role_worker_mapping[Role.RewardModel] = ray.remote(PRIMERewardModelWorker)
mapping[Role.RewardModel] = global_pool_id

reward_manager_name = config.reward_model.get("reward_manager", "naive")
if reward_manager_name == 'naive':
from verl.workers.reward_manager import NaiveRewardManager
reward_manager_cls = NaiveRewardManager
elif reward_manager_name == 'prime':
from verl.workers.reward_manager import PrimeRewardManager
reward_manager_cls = PrimeRewardManager
else:
raise NotImplementedError
reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=0, compute_score=compute_score)

# Note that we always use function-based RM for validation
val_reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=1, compute_score=compute_score)

resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)

trainer = RayPRIMETrainer(config=config,
tokenizer=tokenizer,
role_worker_mapping=role_worker_mapping,
resource_pool_manager=resource_pool_manager,
ray_worker_group_cls=ray_worker_group_cls,
reward_fn=reward_fn,
val_reward_fn=val_reward_fn)
trainer.init_workers()
trainer.fit()


if __name__ == '__main__':
main()
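
Since run_prime forwards an optional compute_score callable into the chosen reward manager, a custom rule-based verifier can be plugged in without touching the trainer. A minimal sketch, assuming the recipe package is importable (e.g. the repo root is on PYTHONPATH) and that the reward managers accept a compute_score with verl's usual (data_source, solution_str, ground_truth, extra_info) signature; launch_prime.py and my_compute_score are hypothetical names, and the exact signature should be checked against the installed verl version.

# launch_prime.py, placed next to main_prime.py so that config_path='config' resolves
import hydra

from recipe.prime.main_prime import run_prime


def my_compute_score(data_source, solution_str, ground_truth, extra_info=None):
    # hypothetical rule-based verifier: reward 1.0 when the ground truth appears verbatim in the solution
    return float(str(ground_truth).strip() in solution_str)


@hydra.main(config_path='config', config_name='prime_trainer', version_base=None)
def main(config):
    run_prime(config, compute_score=my_compute_score)


if __name__ == '__main__':
    main()
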
107 changes: 107 additions & 0 deletions recipe/prime/prime_core_algos.py
@@ -0,0 +1,107 @@
# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import verl
import verl.utils.torch_functional as verl_F


def compute_rloo_advantage_return(data: verl.DataProto, eos_mask: torch.Tensor, n_samples, config):
    # compute the RLOO advantage separately for each reward source, then sum the results

def masked_rloo(reward_tensor_original, mask_tensor):
reward_tensor = reward_tensor_original.clone()
reward_tensor[~mask_tensor] = 0
        for start_pos in range(0, reward_tensor.shape[0], n_samples):
            # scalar (masked-mean) reward of each of the n_samples responses for this prompt
            cur_rewards_mean = torch.cat([
                reward_tensor[pos:pos + 1][mask_tensor[pos:pos + 1]].mean(dim=0, keepdim=True)
                for pos in range(start_pos, start_pos + n_samples)
            ], dim=0)
            cur_rewards_sum = cur_rewards_mean.sum()
            cur_reward_baseline = cur_rewards_sum / (n_samples - 1)
            # leave-one-out baseline: r_i * n/(n-1) - sum(r)/(n-1) == r_i - mean of the other n-1 rewards
            reward_tensor[start_pos:start_pos + n_samples][
                mask_tensor[start_pos:start_pos + n_samples]] = \
                reward_tensor[start_pos:start_pos + n_samples][
                    mask_tensor[start_pos:start_pos + n_samples]] * (
                        n_samples / (n_samples - 1)) - cur_reward_baseline

return reward_tensor

reward_tensors = []

with torch.no_grad():

        # process reward from the implicit PRM: token-level rm_scores (weighted by reward_dpo_coef)
        if 'rm_scores' in data.batch and config.algorithm.reward_dpo_coef != 0.:
reward_tensor = data.batch['rm_scores']
reward_mask = eos_mask.bool()

reward_tensors.append(masked_rloo(reward_tensor, reward_mask) * config.algorithm.reward_dpo_coef)

        # ground-truth outcome reward: per-sample accuracy on the last valid response token (weighted by reward_gt_coef)
        if 'acc' in data.batch and config.algorithm.reward_gt_coef != 0.:
reward_tensor = torch.zeros_like(eos_mask, dtype=torch.float32)
reward_mask = torch.zeros_like(eos_mask, dtype=torch.bool)

prompt_ids = data.batch['prompts']
prompt_length = prompt_ids.shape[-1]
valid_response_length = data.batch['attention_mask'][:, prompt_length:].sum(-1)

reward_mask[
torch.arange(0, valid_response_length.shape[0], dtype=torch.long, device=valid_response_length.device),
valid_response_length - 1] = True
reward_tensor[
torch.arange(0, valid_response_length.shape[0], dtype=torch.long, device=valid_response_length.device),
valid_response_length - 1] = data.batch['acc']

reward_tensors.append(masked_rloo(reward_tensor, reward_mask) * config.algorithm.reward_gt_coef)

final_reward_tensor = sum(reward_tensors)

returns = (final_reward_tensor * eos_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])

advantages = returns.clone()
advantages = verl_F.masked_whiten(advantages, eos_mask)

return advantages, returns


def compute_ce_dpo_loss_rm(token_level_scores, acc, eos_mask, beta):
    # implicit-PRM loss: BCE between sigmoid(beta * summed token-level scores) and the 0/1 accuracy label
    cur_scores = ((token_level_scores * eos_mask).sum(dim=1) * beta).sigmoid()
cur_dpo_loss = torch.nn.functional.binary_cross_entropy(cur_scores, acc)
return cur_dpo_loss


def compute_dpo_accuracy(token_level_scores, acc, eos_mask, n_samples):
    # pairwise ranking accuracy of summed scores vs. accuracy labels, weighted by |accuracy difference|
    dpo_acc = []
for start_id in range(0, token_level_scores.shape[0], n_samples):
cur_scores = (token_level_scores[start_id:start_id + n_samples] *
eos_mask[start_id:start_id + n_samples]).sum(dim=1)

def get_upper_triangle(tensor_x):
diff_matrix = tensor_x.unsqueeze(1) - tensor_x.unsqueeze(0)
upper_tri_indices = torch.triu(torch.ones_like(diff_matrix).bool(), diagonal=1)
return diff_matrix[upper_tri_indices]

cur_acc_diff = get_upper_triangle(acc[start_id:start_id + n_samples]) # in range [-1,1]
cur_score_diff = get_upper_triangle(cur_scores) # in R
cur_score_prediction = (cur_score_diff > 0).float() # in [0,1]
if cur_acc_diff.abs().sum() == 0:
cur_acc = torch.zeros_like(cur_score_prediction[0]) + 0.5
else:
cur_acc = (((cur_score_diff > 0) == (cur_acc_diff > 0)).float() *
cur_acc_diff.abs()).sum() / cur_acc_diff.abs().sum()

dpo_acc.append(cur_acc.unsqueeze(0))

return torch.cat(dpo_acc, dim=0).mean()
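
Two standalone sanity checks for the functions above (not part of the PR). The first verifies that the closed-form update inside masked_rloo matches the leave-one-out baseline; the second calls compute_ce_dpo_loss_rm on toy tensors, assuming the recipe package is importable.

import torch

from recipe.prime.prime_core_algos import compute_ce_dpo_loss_rm

# 1) RLOO identity: r_i * n/(n-1) - sum(r)/(n-1) == r_i - mean of the other n-1 rewards
torch.manual_seed(0)
n = 4
r = torch.randn(n)  # scalar rewards of the n responses sampled for one prompt
closed_form = r * n / (n - 1) - r.sum() / (n - 1)
leave_one_out = torch.stack([r[i] - torch.cat([r[:i], r[i + 1:]]).mean() for i in range(n)])
assert torch.allclose(closed_form, leave_one_out)

# 2) toy implicit-PRM loss: 2 responses of length 3, beta_train as in the config (0.05)
token_level_scores = torch.tensor([[0.2, 0.1, 0.3], [-0.4, -0.2, -0.1]])
eos_mask = torch.ones_like(token_level_scores)
acc = torch.tensor([1.0, 0.0])  # response-level correctness labels
loss = compute_ce_dpo_loss_rm(token_level_scores, acc, eos_mask, beta=0.05)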