From ee91c1520450555730ee98fb6c2ba018e4c55dfa Mon Sep 17 00:00:00 2001 From: Zefan Wang Date: Mon, 24 Feb 2025 15:56:30 +0800 Subject: [PATCH 1/6] draft --- recipe/prime/__init__.py | 0 recipe/prime/config/prime_config.yaml | 0 recipe/prime/config/prime_trainer.yaml | 67 ++ recipe/prime/main_prime.py | 116 ++++ recipe/prime/prime_core_algos.py | 0 recipe/prime/prime_dp_rm.py | 0 recipe/prime/prime_fsdp_workers.py | 297 +++++++++ recipe/prime/prime_ray_trainer.py | 867 +++++++++++++++++++++++++ recipe/prime/run_prime_qwen.sh | 55 ++ scripts/format.sh | 2 +- 10 files changed, 1403 insertions(+), 1 deletion(-) create mode 100644 recipe/prime/__init__.py create mode 100644 recipe/prime/config/prime_config.yaml create mode 100644 recipe/prime/config/prime_trainer.yaml create mode 100644 recipe/prime/main_prime.py create mode 100644 recipe/prime/prime_core_algos.py create mode 100644 recipe/prime/prime_dp_rm.py create mode 100644 recipe/prime/prime_fsdp_workers.py create mode 100644 recipe/prime/prime_ray_trainer.py create mode 100644 recipe/prime/run_prime_qwen.sh diff --git a/recipe/prime/__init__.py b/recipe/prime/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/recipe/prime/config/prime_config.yaml b/recipe/prime/config/prime_config.yaml new file mode 100644 index 00000000..e69de29b diff --git a/recipe/prime/config/prime_trainer.yaml b/recipe/prime/config/prime_trainer.yaml new file mode 100644 index 00000000..468862c6 --- /dev/null +++ b/recipe/prime/config/prime_trainer.yaml @@ -0,0 +1,67 @@ +# the prime config will override default ppo_trainer.yaml + +hydra: + searchpath: + - file://verl/trainer/config + +defaults: + - ppo_trainer + - _self_ + +data: + filter_accuracy: True + accuracy_lower_bound: 0.2 + accuracy_upper_bound: 0.8 + oversample_factor: 4.0 # Sample more responses than the batch size. prompts satisfying the filter will be prioritized. + +actor_rollout_ref: + hybrid_engine: True + model: + use_remove_padding: True + rollout: + # number of responses (i.e. num sample times) + n: 4 + actor: + entropy_coeff: 0.001 + +reward_model: + enable: True + strategy: fsdp + model: + use_remove_padding: True + enable_gradient_checkpointing: ${actor_rollout_ref.model.enable_gradient_checkpointing} + ref_type: freeze + fsdp_config: + min_num_params: 0 + param_offload: ${actor_rollout_ref.actor.fsdp_config.param_offload} +# grad_offload: ${actor_rollout_ref.actor.fsdp_config.grad_offload} + optimizer_offload: ${actor_rollout_ref.actor.fsdp_config.optimizer_offload} + update: before # ``before`` for double-forward, ``after`` for single-forward + optim: + lr: 1e-6 + lr_warmup_steps_ratio: 0. + min_lr_ratio: null + warmup_style: constant + total_training_steps: -1 # must be overridden by program + weight_decay: 0. + grad_clip: 1.0 + beta_train: 0.05 + loss_type: ce # currently only supports ce loss + prime_granularity: token + prime_norm: batch_norm # batch_norm or none. if set to none, the normalizer is beta_train + mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} + reward_manager: prime + +algorithm: + adv_estimator: rloo + # now supports rloo. it treats different source of reward separately. 
+ kl_ctrl: + type: fixed + kl_coef: 0.000 + reward_gt_coef: 5 + reward_dpo_coef: 5 + +trainer: + project_name: prime + experiment_name: examples + val_before_train: False \ No newline at end of file diff --git a/recipe/prime/main_prime.py b/recipe/prime/main_prime.py new file mode 100644 index 00000000..5821cc29 --- /dev/null +++ b/recipe/prime/main_prime.py @@ -0,0 +1,116 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. +""" +from .prime_ray_trainer import RayPRIMETrainer + +import ray +import hydra + + +@hydra.main(config_path='config', config_name='prime_trainer', version_base=None) +def main(config): + run_prime(config) + + +def run_prime(config, compute_score=None): + if not ray.is_initialized(): + # this is for local ray cluster + ray.init(runtime_env={'env_vars': {'TOKENIZERS_PARALLELISM': 'true', 'NCCL_DEBUG': 'WARN'}}) + + ray.get(main_task.remote(config, compute_score)) + + +@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head +def main_task(config, compute_score=None): + from verl.utils.fs import copy_local_path_from_hdfs + # print initial config + from pprint import pprint + from omegaconf import OmegaConf + pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values + OmegaConf.resolve(config) + + # download the checkpoint from hdfs + local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path) + + # instantiate tokenizer + from verl.utils import hf_tokenizer + tokenizer = hf_tokenizer(local_path) + + # define worker classes + if config.actor_rollout_ref.actor.strategy == 'fsdp': + assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker + from verl.single_controller.ray import RayWorkerGroup + ray_worker_group_cls = RayWorkerGroup + + elif config.actor_rollout_ref.actor.strategy == 'megatron': + assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker + from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup + ray_worker_group_cls = NVMegatronRayWorkerGroup + + else: + raise NotImplementedError + + from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role + + role_worker_mapping = { + Role.ActorRollout: ray.remote(ActorRolloutRefWorker), + Role.RefPolicy: ray.remote(ActorRolloutRefWorker) + } + + global_pool_id = 'global_pool' + resource_pool_spec = { + global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, + } + mapping = { + Role.ActorRollout: global_pool_id, + Role.RefPolicy: global_pool_id, + } + if config.reward_model.enable: + from .prime_fsdp_workers import PRIMERewardModelWorker + role_worker_mapping[Role.RewardModel] = ray.remote(PRIMERewardModelWorker) + mapping[Role.RewardModel] = ray.remote(PRIMERewardModelWorker) + + 
reward_manager_name = config.reward_model.get("reward_manager", "naive") + if reward_manager_name == 'naive': + from verl.workers.reward_manager import NaiveRewardManager + reward_manager_cls = NaiveRewardManager + elif reward_manager_name == 'prime': + from verl.workers.reward_manager import PrimeRewardManager + reward_manager_cls = PrimeRewardManager + else: + raise NotImplementedError + reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=0, compute_score=compute_score) + + # Note that we always use function-based RM for validation + val_reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=1, compute_score=compute_score) + + resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) + + trainer = RayPRIMETrainer(config=config, + tokenizer=tokenizer, + role_worker_mapping=role_worker_mapping, + resource_pool_manager=resource_pool_manager, + ray_worker_group_cls=ray_worker_group_cls, + reward_fn=reward_fn, + val_reward_fn=val_reward_fn) + trainer.init_workers() + trainer.fit() + + +if __name__ == '__main__': + main() diff --git a/recipe/prime/prime_core_algos.py b/recipe/prime/prime_core_algos.py new file mode 100644 index 00000000..e69de29b diff --git a/recipe/prime/prime_dp_rm.py b/recipe/prime/prime_dp_rm.py new file mode 100644 index 00000000..e69de29b diff --git a/recipe/prime/prime_fsdp_workers.py b/recipe/prime/prime_fsdp_workers.py new file mode 100644 index 00000000..18673927 --- /dev/null +++ b/recipe/prime/prime_fsdp_workers.py @@ -0,0 +1,297 @@ + +import logging +import os +import warnings + +import torch +import torch.distributed +from torch.distributed.device_mesh import init_device_mesh +import verl.utils.torch_functional as verl_F +from omegaconf import DictConfig, open_dict +from verl import DataProto +from verl.single_controller.base import Worker +from verl.single_controller.base.decorator import register, Dispatch +from verl.utils import hf_tokenizer +from verl.utils.debug import log_gpu_memory_usage +from verl.utils.fs import copy_local_path_from_hdfs +from verl.utils.fsdp_utils import get_fsdp_wrap_policy, init_fn, get_init_weight_context_manager +from verl.utils.fsdp_utils import offload_fsdp_optimizer, offload_fsdp_model_to_cpu, load_fsdp_optimizer, \ + load_fsdp_model_to_gpu +from verl.utils.import_utils import import_external_libs +from verl.utils.model import compute_position_id_with_mask +from verl.utils.flops_counter import FlopsCounter +from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager +from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager + +from codetiming import Timer +from verl.workers.fsdp_workers import create_device_mesh + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN')) +class PRIMERewardModelWorker(Worker): + + def __init__(self, config): + super().__init__() + import torch.distributed + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group(backend="nccl") + self.config = config + + # build device mesh for Ulysses Sequence Parallel + world_size = torch.distributed.get_world_size() + from torch.distributed.device_mesh import init_device_mesh + + fsdp_size = self.config.model.fsdp_config.fsdp_size + self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size) + + self.ulysses_device_mesh = None + self.ulysses_sequence_parallel_size = self.config.get('ulysses_sequence_parallel_size', 1) + dp = world_size // 
self.ulysses_sequence_parallel_size + if self.ulysses_sequence_parallel_size > 1: + self.ulysses_device_mesh = init_device_mesh('cuda', + mesh_shape=(dp, self.ulysses_sequence_parallel_size), + mesh_dim_names=['dp', 'sp']) + + self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) + + # set FSDP offload params + self._is_offload_param = self.config.model.fsdp_config.param_offload + self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload + + # normalize config + self.config.ppo_mini_batch_size //= (torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size) + if self.config.ppo_micro_batch_size is not None: + self.config.ppo_micro_batch_size //= (torch.distributed.get_world_size() // + self.ulysses_sequence_parallel_size) + self.config.forward_micro_batch_size //= (torch.distributed.get_world_size() // + self.ulysses_sequence_parallel_size) + self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size + self.config.forward_micro_batch_size_per_gpu = self.config.forward_micro_batch_size + assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0 + + def _build_critic_model_optimizer(self, config): + # the following line is necessary + from verl.utils.model import LambdaLayer, print_model_size, squeeze + from verl.utils.torch_dtypes import PrecisionType + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision + from torch import optim + + local_path = copy_local_path_from_hdfs(config.model.path) + # note that the tokenizer between actor and critic may be different. So override tokenizer info with actor info + # using random initialized model from any architecture. May not be the same as Actor. + + tokenizer_path = copy_local_path_from_hdfs(config.model.tokenizer_path) + self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get('trust_remote_code', False)) + + from omegaconf import OmegaConf + override_config = OmegaConf.to_container(self.config.model.get('override_config', OmegaConf.create())) + override_config_kwargs = { + 'bos_token_id': self.tokenizer.bos_token_id, + 'eos_token_id': self.tokenizer.eos_token_id, + 'pad_token_id': self.tokenizer.pad_token_id, + } + override_config_kwargs.update(override_config) + if self.rank == 0: + print(f'Critic overriding config {override_config_kwargs}') + + torch_dtype = self.config.model.fsdp_config.get('model_dtype', 'fp32') + torch_dtype = PrecisionType.to_dtype(torch_dtype) + + from transformers import AutoConfig, AutoModelForTokenClassification + from torch import nn + + trust_remote_code = False + critic_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code) + critic_model_config.num_labels = 1 + + use_remove_padding = config.model.get('use_remove_padding', False) + if use_remove_padding: + from verl.models.registry import check_model_support_rmpad + check_model_support_rmpad(critic_model_config.model_type) + + if use_remove_padding and self.ulysses_sequence_parallel_size > 1: + from verl.models.transformers.monkey_patch import apply_monkey_patch + apply_monkey_patch(critic_model_config, verbose=True) + + init_context = get_init_weight_context_manager() + with init_context(), warnings.catch_warnings(): + warnings.simplefilter("ignore") + setattr(critic_model_config, 'classifier_dropout', 0.) 
+ setattr(critic_model_config, 'hidden_dropout', '0') + critic_module = AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path=local_path, + torch_dtype=torch_dtype, + config=critic_model_config, + attn_implementation='flash_attention_2', + trust_remote_code=trust_remote_code) + + # some parameters may not in torch_dtype + critic_module.to(torch_dtype) + + if config.model.get('enable_gradient_checkpointing', False): + critic_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False}) + if self.rank == 0: + print_model_size(critic_module) + + self.critic_model_config = critic_model_config + + fsdp_config = self.config.model.fsdp_config + mixed_precision_config = fsdp_config.get('mixed_precision', None) + if mixed_precision_config is not None: + param_dtype = PrecisionType.to_dtype(mixed_precision_config.get('param_dtype', 'bf16')) + reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get('reduce_dtype', 'fp32')) + buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get('buffer_dtype', 'fp32')) + else: + param_dtype = torch.bfloat16 + reduce_dtype = torch.float32 + buffer_dtype = torch.float32 + + mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype) + + auto_wrap_policy = get_fsdp_wrap_policy(module=critic_module, config=self.config.model.fsdp_config.wrap_policy) + + log_gpu_memory_usage('Before critic FSDP', logger=None) + + fsdp_mesh = self.device_mesh + sharding_strategy = get_sharding_strategy(fsdp_mesh) + + # Note: We force turn off CPUOffload for critic because it causes incorrect results when using grad accumulation + critic_module = FSDP(critic_module, + param_init_fn=init_fn, + use_orig_params=False, + auto_wrap_policy=auto_wrap_policy, + device_id=torch.cuda.current_device(), + sharding_strategy=sharding_strategy, + mixed_precision=mixed_precision, + sync_module_states=True, + forward_prefetch=False, + device_mesh=self.device_mesh, + cpu_offload=None) + + log_gpu_memory_usage('After critic FSDP', logger=None) + + critic_optimizer = optim.AdamW(critic_module.parameters(), + lr=config.optim.lr, + betas=config.optim.get('betas', (0.9, 0.999)), + weight_decay=config.optim.get('weight_decay', 1e-2)) + + total_steps = config.optim.get('total_training_steps', 0) + num_warmup_steps_ratio = config.optim.get('lr_warmup_steps_ratio', 0.) 
+ num_warmup_steps = int(num_warmup_steps_ratio * total_steps) + + print(f'Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}') + + from verl.utils.torch_functional import get_constant_schedule_with_warmup + critic_lr_scheduler = get_constant_schedule_with_warmup(optimizer=critic_optimizer, + num_warmup_steps=num_warmup_steps) + + return critic_module, critic_optimizer, critic_lr_scheduler + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def init_model(self): + # This is used to import external_lib into the huggingface systems + import_external_libs(self.config.model.get('external_lib', None)) + + from verl.workers.critic import DataParallelPPOCritic + self.critic_module, self.critic_optimizer, self.critic_lr_scheduler = self._build_critic_model_optimizer( + self.config) + + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.critic_module) + if self._is_offload_optimizer: + offload_fsdp_optimizer(optimizer=self.critic_optimizer) + + self.critic = DataParallelPPOCritic(config=self.config, + critic_module=self.critic_module, + critic_optimizer=self.critic_optimizer) + + self.flops_counter = FlopsCounter(self.critic_model_config) + self.checkpoint_manager = FSDPCheckpointManager(model=self.critic_module, + optimizer=self.critic_optimizer, + lr_scheduler=self.critic_lr_scheduler, + tokenizer=self.tokenizer) + + torch.cuda.empty_cache() + + @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) + def compute_values(self, data: DataProto): + data = data.to('cuda') + + if self._is_offload_param: + load_fsdp_model_to_gpu(self.critic_module) + micro_batch_size = self.config.forward_micro_batch_size_per_gpu + data.meta_info['micro_batch_size'] = micro_batch_size + data.meta_info['max_token_len'] = self.config.forward_max_token_len_per_gpu + data.meta_info['use_dynamic_bsz'] = self.config.use_dynamic_bsz + # perform forward computation + with self.ulysses_sharding_manager: + data = self.ulysses_sharding_manager.preprocess_data(data=data) + values = self.critic.compute_values(data=data) + output = DataProto.from_dict(tensors={'values': values}) + output = self.ulysses_sharding_manager.postprocess_data(data=output) + + output = output.to('cpu') + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.critic_module) + return output + + @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) + def update_critic(self, data: DataProto): + data = data.to('cuda') + if self._is_offload_param: + load_fsdp_model_to_gpu(self.critic_module) + if self._is_offload_optimizer: + load_fsdp_optimizer(optimizer=self.critic_optimizer, device_id=torch.cuda.current_device()) + + # perform forward computation + with self.ulysses_sharding_manager: + data = self.ulysses_sharding_manager.preprocess_data(data=data) + + with Timer(name='update_critic', logger=None) as timer: + metrics = self.critic.update_critic(data=data) + delta_time = timer.last + + global_num_tokens = data.meta_info['global_token_num'] + estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) + metrics['mfu/critic'] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size + + self.critic_lr_scheduler.step() + lr = self.critic_lr_scheduler.get_last_lr()[0] + metrics['critic/lr'] = lr + + output = DataProto(batch=None, meta_info={'metrics': metrics}) + output = self.ulysses_sharding_manager.postprocess_data(data=output) + + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.critic_module) + if self._is_offload_optimizer: + 
offload_fsdp_optimizer(optimizer=self.critic_optimizer) + torch.cuda.empty_cache() + output = output.to('cpu') + return output + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, remove_previous_ckpt=False): + import torch + if self._is_offload_param: + load_fsdp_model_to_gpu(self.critic_module) + + self.checkpoint_manager.save_checkpoint(local_path=local_path, + hdfs_path=hdfs_path, + global_step=global_step, + remove_previous_ckpt=remove_previous_ckpt) + + torch.distributed.barrier() + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.critic_module) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def load_checkpoint(self, path, del_local_after_load=True): + import torch + if self._is_offload_param: + load_fsdp_model_to_gpu(self.critic_module) + + self.checkpoint_manager.load_checkpoint(path=path, del_local_after_load=del_local_after_load) + + torch.distributed.barrier() + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.critic_module) \ No newline at end of file diff --git a/recipe/prime/prime_ray_trainer.py b/recipe/prime/prime_ray_trainer.py new file mode 100644 index 00000000..bce5c0a8 --- /dev/null +++ b/recipe/prime/prime_ray_trainer.py @@ -0,0 +1,867 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +FSDP PPO Trainer with Ray-based single controller. 
+This trainer supports model-agonistic model initialization with huggingface +""" + +import os +import uuid +from contextlib import contextmanager +from dataclasses import dataclass, field +from enum import Enum +from pprint import pprint +from typing import Type, Dict +from copy import deepcopy + +import numpy as np +from codetiming import Timer +from omegaconf import OmegaConf, open_dict +from verl import DataProto +from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto +from verl.single_controller.base import Worker +from verl.single_controller.ray import RayResourcePool, RayWorkerGroup, RayClassWithInitArgs +from verl.single_controller.ray.base import create_colocated_worker_cls +from verl.trainer.ppo import core_algos +from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance +from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path +from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn +from verl.trainer.ppo.ray_trainer import Role, WorkerType, ResourcePoolManager, apply_kl_penalty, reduce_metrics, _compute_response_info, _timer + +import torch +from verl.utils.torch_functional import masked_mean + + +def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_repeat=1): + # prepare response group + # TODO: add other ways to estimate advantages + if adv_estimator == 'gae': + values = data.batch['values'] + responses = data.batch['responses'] + response_length = responses.size(-1) + attention_mask = data.batch['attention_mask'] + response_mask = attention_mask[:, -response_length:] + token_level_rewards = data.batch['token_level_rewards'] + advantages, returns = core_algos.compute_gae_advantage_return(token_level_rewards=token_level_rewards, + values=values, + eos_mask=response_mask, + gamma=gamma, + lam=lam) + data.batch['advantages'] = advantages + data.batch['returns'] = returns + elif adv_estimator == 'grpo': + token_level_rewards = data.batch['token_level_rewards'] + index = data.non_tensor_batch['uid'] + responses = data.batch['responses'] + response_length = responses.size(-1) + attention_mask = data.batch['attention_mask'] + response_mask = attention_mask[:, -response_length:] + advantages, returns = core_algos.compute_grpo_outcome_advantage(token_level_rewards=token_level_rewards, + eos_mask=response_mask, + index=index) + data.batch['advantages'] = advantages + data.batch['returns'] = returns + elif adv_estimator == 'reinforce_plus_plus': + token_level_rewards = data.batch['token_level_rewards'] + responses = data.batch['responses'] + response_length = responses.size(-1) + attention_mask = data.batch['attention_mask'] + response_mask = attention_mask[:, -response_length:] + advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage( + token_level_rewards=token_level_rewards, eos_mask=response_mask, gamma=gamma) + data.batch['advantages'] = advantages + data.batch['returns'] = returns + elif adv_estimator == 'remax': + token_level_rewards = data.batch['token_level_rewards'] + index = data.non_tensor_batch['uid'] + responses = data.batch['responses'] + response_length = responses.size(-1) + attention_mask = data.batch['attention_mask'] + response_mask = attention_mask[:, -response_length:] + + reward_baselines = data.batch['reward_baselines'] + + advantages, returns = core_algos.compute_remax_outcome_advantage(token_level_rewards=token_level_rewards, + reward_baselines=reward_baselines, + eos_mask=response_mask) + + data.batch['advantages'] = advantages + 
data.batch['returns'] = returns + else: + raise NotImplementedError + return data + +def compute_data_metrics(batch, use_critic=True): + # TODO: add response length + sequence_score = batch.batch['token_level_scores'].sum(-1) + sequence_reward = batch.batch['token_level_rewards'].sum(-1) + + advantages = batch.batch['advantages'] + returns = batch.batch['returns'] + + max_response_length = batch.batch['responses'].shape[-1] + + prompt_mask = batch.batch['attention_mask'][:, :-max_response_length].bool() + response_mask = batch.batch['attention_mask'][:, -max_response_length:].bool() + + max_prompt_length = prompt_mask.size(-1) + + response_info = _compute_response_info(batch) + prompt_length = response_info['prompt_length'] + response_length = response_info['response_length'] + + valid_adv = torch.masked_select(advantages, response_mask) + valid_returns = torch.masked_select(returns, response_mask) + + if use_critic: + values = batch.batch['values'] + valid_values = torch.masked_select(values, response_mask) + return_diff_var = torch.var(valid_returns - valid_values) + return_var = torch.var(valid_returns) + + metrics = { + # score + 'critic/score/mean': + torch.mean(sequence_score).detach().item(), + 'critic/score/max': + torch.max(sequence_score).detach().item(), + 'critic/score/min': + torch.min(sequence_score).detach().item(), + # reward + 'critic/rewards/mean': + torch.mean(sequence_reward).detach().item(), + 'critic/rewards/max': + torch.max(sequence_reward).detach().item(), + 'critic/rewards/min': + torch.min(sequence_reward).detach().item(), + # adv + 'critic/advantages/mean': + torch.mean(valid_adv).detach().item(), + 'critic/advantages/max': + torch.max(valid_adv).detach().item(), + 'critic/advantages/min': + torch.min(valid_adv).detach().item(), + # returns + 'critic/returns/mean': + torch.mean(valid_returns).detach().item(), + 'critic/returns/max': + torch.max(valid_returns).detach().item(), + 'critic/returns/min': + torch.min(valid_returns).detach().item(), + **({ + # values + 'critic/values/mean': torch.mean(valid_values).detach().item(), + 'critic/values/max': torch.max(valid_values).detach().item(), + 'critic/values/min': torch.min(valid_values).detach().item(), + # vf explained var + 'critic/vf_explained_var': (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(), + } if use_critic else {}), + + # response length + 'response_length/mean': + torch.mean(response_length).detach().item(), + 'response_length/max': + torch.max(response_length).detach().item(), + 'response_length/min': + torch.min(response_length).detach().item(), + 'response_length/clip_ratio': + torch.mean(torch.eq(response_length, max_response_length).float()).detach().item(), + # prompt length + 'prompt_length/mean': + torch.mean(prompt_length).detach().item(), + 'prompt_length/max': + torch.max(prompt_length).detach().item(), + 'prompt_length/min': + torch.min(prompt_length).detach().item(), + 'prompt_length/clip_ratio': + torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(), + } + return metrics + + +def compute_timing_metrics(batch, timing_raw): + response_info = _compute_response_info(batch) + num_prompt_tokens = torch.sum(response_info['prompt_length']).item() + num_response_tokens = torch.sum(response_info['response_length']).item() + num_overall_tokens = num_prompt_tokens + num_response_tokens + + num_tokens_of_section = { + 'gen': num_response_tokens, + **{ + name: num_overall_tokens for name in ['ref', 'values', 'adv', 'update_critic', 'update_actor'] + }, + } + + 
return { + **{ + f'timing_s/{name}': value for name, value in timing_raw.items() + }, + **{ + f'timing_per_token_ms/{name}': timing_raw[name] * 1000 / num_tokens_of_section[name] for name in set(num_tokens_of_section.keys( + )) & set(timing_raw.keys()) + }, + } + + +class RayPRIMETrainer(object): + """ + Note that this trainer runs on the driver process on a single CPU/GPU node. + """ + + # TODO: support each role have individual ray_worker_group_cls, + # i.e., support different backend of different role + def __init__(self, + config, + tokenizer, + role_worker_mapping: dict[Role, WorkerType], + resource_pool_manager: ResourcePoolManager, + ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup, + reward_fn=None, + val_reward_fn=None): + + # assert torch.cuda.is_available(), 'cuda must be available on driver' + + self.tokenizer = tokenizer + self.config = config + self.reward_fn = reward_fn + self.val_reward_fn = val_reward_fn + + self.hybrid_engine = config.actor_rollout_ref.hybrid_engine + assert self.hybrid_engine, 'Currently, only support hybrid engine' + + if self.hybrid_engine: + assert Role.ActorRollout in role_worker_mapping, f'{role_worker_mapping.keys()=}' + + self.role_worker_mapping = role_worker_mapping + self.resource_pool_manager = resource_pool_manager + self.use_reference_policy = Role.RefPolicy in role_worker_mapping + self.use_rm = Role.RewardModel in role_worker_mapping + self.ray_worker_group_cls = ray_worker_group_cls + + # define KL control + if self.use_reference_policy: + if config.algorithm.kl_ctrl.type == 'fixed': + self.kl_ctrl = core_algos.FixedKLController(kl_coef=config.algorithm.kl_ctrl.kl_coef) + elif config.algorithm.kl_ctrl.type == 'adaptive': + assert config.algorithm.kl_ctrl.horizon > 0, f'horizon must be larger than 0. Got {config.critic.kl_ctrl.horizon}' + self.kl_ctrl = core_algos.AdaptiveKLController(init_kl_coef=config.algorithm.kl_ctrl.kl_coef, + target_kl=config.algorithm.kl_ctrl.target_kl, + horizon=config.algorithm.kl_ctrl.horizon) + else: + raise NotImplementedError + else: + self.kl_ctrl = core_algos.FixedKLController(kl_coef=0.) + + self.use_critic = False + + self._validate_config() + self._create_dataloader() + + def _validate_config(self): + config = self.config + # number of GPUs total + n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes + + # 1. Check total batch size for data correctness + real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n + assert real_train_batch_size % n_gpus == 0, \ + f"real_train_batch_size ({real_train_batch_size}) must be divisible by total n_gpus ({n_gpus})." + + # A helper function to check "micro_batch_size" vs "micro_batch_size_per_gpu" + # We throw an error if the user sets both. The new convention is "..._micro_batch_size_per_gpu". + def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): + if mbs is None and mbs_per_gpu is None: + raise ValueError(f"[{name}] Please set at least one of '{name}.micro_batch_size' or " + f"'{name}.micro_batch_size_per_gpu'.") + + if mbs is not None and mbs_per_gpu is not None: + raise ValueError(f"[{name}] You have set both '{name}.micro_batch_size' AND " + f"'{name}.micro_batch_size_per_gpu'. Please remove '{name}.micro_batch_size' " + f"because only '*_micro_batch_size_per_gpu' is supported (the former is deprecated).") + + if not config.actor_rollout_ref.actor.use_dynamic_bsz: + # actor: ppo_micro_batch_size vs. 
ppo_micro_batch_size_per_gpu + check_mutually_exclusive(config.actor_rollout_ref.actor.ppo_micro_batch_size, + config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu, + "actor_rollout_ref.actor") + + # reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu + check_mutually_exclusive(config.actor_rollout_ref.ref.log_prob_micro_batch_size, + config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu, + "actor_rollout_ref.ref") + + # The rollout section also has log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu + check_mutually_exclusive(config.actor_rollout_ref.rollout.log_prob_micro_batch_size, + config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu, + "actor_rollout_ref.rollout") + + if self.use_critic and not config.critic.use_dynamic_bsz: + # Check for critic micro-batch size conflicts + check_mutually_exclusive(config.critic.ppo_micro_batch_size, config.critic.ppo_micro_batch_size_per_gpu, + "critic") + + # Check for reward model micro-batch size conflicts + if config.reward_model.enable and not config.reward_model.use_dynamic_bsz: + check_mutually_exclusive(config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, + "reward_model") + + # Actor + # if NOT dynamic_bsz, we must ensure: + # ppo_mini_batch_size is divisible by ppo_micro_batch_size + # ppo_micro_batch_size * sequence_parallel_size >= n_gpus + if not config.actor_rollout_ref.actor.use_dynamic_bsz: + sp_size = config.actor_rollout_ref.actor.get('ulysses_sequence_parallel_size', 1) + if config.actor_rollout_ref.actor.ppo_micro_batch_size is not None: + assert config.actor_rollout_ref.actor.ppo_mini_batch_size % config.actor_rollout_ref.actor.ppo_micro_batch_size == 0 + assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus + + # critic + if self.use_critic and not config.critic.use_dynamic_bsz: + sp_size = config.critic.get('ulysses_sequence_parallel_size', 1) + if config.critic.ppo_micro_batch_size is not None: + assert config.critic.ppo_mini_batch_size % config.critic.ppo_micro_batch_size == 0 + assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus + + # Check if use_remove_padding is enabled when using sequence parallelism for fsdp + if config.actor_rollout_ref.actor.strategy == 'fsdp': + if config.actor_rollout_ref.actor.get('ulysses_sequence_parallel_size', 1) > 1 or \ + config.actor_rollout_ref.ref.get('ulysses_sequence_parallel_size', 1) > 1: + assert config.actor_rollout_ref.model.use_remove_padding, \ + "When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`." + + if self.use_critic and config.critic.strategy == 'fsdp': + if config.critic.get('ulysses_sequence_parallel_size', 1) > 1: + assert config.critic.model.use_remove_padding, \ + "When using sequence parallelism for critic, you must enable `use_remove_padding`." 
+ + print("[validate_config] All configuration checks passed successfully!") + + def _create_dataloader(self): + from torch.utils.data import DataLoader, RandomSampler, SequentialSampler + # TODO: we have to make sure the batch size is divisible by the dp size + self.train_dataset = RLHFDataset(parquet_files=self.config.data.train_files, + tokenizer=self.tokenizer, + prompt_key=self.config.data.prompt_key, + max_prompt_length=self.config.data.max_prompt_length, + filter_prompts=True, + return_raw_chat=self.config.data.get('return_raw_chat', False), + truncation='error') + # use sampler for better ckpt resume + if self.config.data.shuffle: + train_dataloader_generator = torch.Generator() + train_dataloader_generator.manual_seed(self.config.data.get('seed', 1)) + sampler = RandomSampler(data_source=self.train_dataset, generator=train_dataloader_generator) + else: + sampler = SequentialSampler(data_source=self.train_dataset) + + self.train_dataloader = DataLoader(dataset=self.train_dataset, + batch_size=self.config.data.train_batch_size, + drop_last=True, + collate_fn=collate_fn, + sampler=sampler) + + self.val_dataset = RLHFDataset(parquet_files=self.config.data.val_files, + tokenizer=self.tokenizer, + prompt_key=self.config.data.prompt_key, + max_prompt_length=self.config.data.max_prompt_length, + filter_prompts=True, + return_raw_chat=self.config.data.get('return_raw_chat', False), + truncation='error') + self.val_dataloader = DataLoader(dataset=self.val_dataset, + batch_size=len(self.val_dataset), + shuffle=True, + drop_last=True, + collate_fn=collate_fn) + + assert len(self.train_dataloader) >= 1 + assert len(self.val_dataloader) >= 1 + + print(f'Size of train dataloader: {len(self.train_dataloader)}') + print(f'Size of val dataloader: {len(self.val_dataloader)}') + + # inject total_training_steps to actor/critic optim_config. This is hacky. + total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs + + if self.config.trainer.total_training_steps is not None: + total_training_steps = self.config.trainer.total_training_steps + + self.total_training_steps = total_training_steps + print(f'Total training steps: {self.total_training_steps}') + + OmegaConf.set_struct(self.config, True) + with open_dict(self.config): + self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps + self.config.critic.optim.total_training_steps = total_training_steps + + def _maybe_log_val_generations_to_wandb(self, inputs, outputs, scores): + """Log a table of validation samples to wandb""" + + generations_to_log = self.config.trainer.val_generations_to_log_to_wandb + + if generations_to_log == 0: + return + + if generations_to_log > 0 and 'wandb' not in self.config.trainer.logger: + print( + 'WARNING: `val_generations_to_log_to_wandb` is set to a positive value, but no wandb logger is found. 
') + return + + import wandb + import numpy as np + + # Create tuples of (input, output, score) and sort by input text + samples = list(zip(inputs, outputs, scores)) + samples.sort(key=lambda x: x[0]) # Sort by input text + + # Use fixed random seed for deterministic shuffling + rng = np.random.RandomState(42) + rng.shuffle(samples) + + # Take first N samples after shuffling + samples = samples[:generations_to_log] + + # Create column names for all samples + columns = ["step"] + sum([[f"input_{i+1}", f"output_{i+1}", f"score_{i+1}"] for i in range(len(samples))], []) + + if not hasattr(self, 'validation_table'): + # Initialize the table on first call + self.validation_table = wandb.Table(columns=columns) + + # Create a new table with same columns and existing data + # Workaround for https://github.com/wandb/wandb/issues/2981#issuecomment-1997445737 + new_table = wandb.Table(columns=columns, data=self.validation_table.data) + + # Add new row with all data + row_data = [] + row_data.append(self.global_steps) + for sample in samples: + row_data.extend(sample) + + new_table.add_data(*row_data) + + # Update reference and log + wandb.log({"val/generations": new_table}, step=self.global_steps) + self.validation_table = new_table + + def _validate(self): + reward_tensor_lst = [] + data_source_lst = [] + + # Lists to collect samples for the table + sample_inputs = [] + sample_outputs = [] + sample_scores = [] + + for test_data in self.val_dataloader: + test_batch = DataProto.from_single_dict(test_data) + + # we only do validation on rule-based rm + if self.config.reward_model.enable and test_batch[0].non_tensor_batch['reward_model']['style'] == 'model': + return {} + + # Store original inputs + input_ids = test_batch.batch['input_ids'] + input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids] + sample_inputs.extend(input_texts) + + test_gen_batch = test_batch.pop(['input_ids', 'attention_mask', 'position_ids']) + test_gen_batch.meta_info = { + 'eos_token_id': self.tokenizer.eos_token_id, + 'pad_token_id': self.tokenizer.pad_token_id, + 'recompute_log_prob': False, + 'do_sample': False, + 'validate': True, + } + + # pad to be divisible by dp_size + test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size) + test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded) + # unpad + test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size) + print('validation generation end') + + # Store generated outputs + output_ids = test_output_gen_batch.batch['responses'] + output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids] + sample_outputs.extend(output_texts) + + test_batch = test_batch.union(test_output_gen_batch) + + # evaluate using reward_function + reward_tensor = self.val_reward_fn(test_batch) + + # Store scores + scores = reward_tensor.sum(-1).cpu().tolist() + sample_scores.extend(scores) + + reward_tensor_lst.append(reward_tensor) + data_source_lst.append(test_batch.non_tensor_batch.get('data_source', ['unknown'] * reward_tensor.shape[0])) + + self._maybe_log_val_generations_to_wandb(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores) + + reward_tensor = torch.cat(reward_tensor_lst, dim=0).sum(-1).cpu() # (batch_size,) + data_sources = np.concatenate(data_source_lst, axis=0) + + # evaluate test_score based on data source + data_source_reward = {} + for i in range(reward_tensor.shape[0]): + data_source 
= data_sources[i] + if data_source not in data_source_reward: + data_source_reward[data_source] = [] + data_source_reward[data_source].append(reward_tensor[i].item()) + + metric_dict = {} + for data_source, rewards in data_source_reward.items(): + metric_dict[f'val/test_score/{data_source}'] = np.mean(rewards) + + return metric_dict + + def init_workers(self): + """Init resource pool and worker group""" + self.resource_pool_manager.create_resource_pool() + + self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()} + + # create actor and rollout + if self.hybrid_engine: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout) + actor_rollout_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.ActorRollout], + config=self.config.actor_rollout_ref, + role='actor_rollout') + self.resource_pool_to_cls[resource_pool]['actor_rollout'] = actor_rollout_cls + else: + raise NotImplementedError + + # create critic + if self.use_critic: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic) + critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=self.config.critic) + self.resource_pool_to_cls[resource_pool]['critic'] = critic_cls + + # create reference policy if needed + if self.use_reference_policy: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) + ref_policy_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RefPolicy], + config=self.config.actor_rollout_ref, + role='ref') + self.resource_pool_to_cls[resource_pool]['ref'] = ref_policy_cls + + # create a reward model if reward_fn is None + if self.use_rm: + # we create a RM here + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) + rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model) + self.resource_pool_to_cls[resource_pool]['rm'] = rm_cls + + # initialize WorkerGroup + # NOTE: if you want to use a different resource pool for each role, which can support different parallel size, + # you should not use `create_colocated_worker_cls`. Instead, directly pass different resource pool to different worker groups. + # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information. + all_wg = {} + self.wg_dicts = [] + for resource_pool, class_dict in self.resource_pool_to_cls.items(): + worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) + wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls) + spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) + all_wg.update(spawn_wg) + # keep the referece of WorkerDict to support ray >= 2.31. 
Ref: https://github.com/ray-project/ray/pull/45699 + self.wg_dicts.append(wg_dict) + + if self.use_critic: + self.critic_wg = all_wg['critic'] + self.critic_wg.init_model() + + if self.use_reference_policy: + self.ref_policy_wg = all_wg['ref'] + self.ref_policy_wg.init_model() + + if self.use_rm: + self.rm_wg = all_wg['rm'] + self.rm_wg.init_model() + + # we should create rollout at the end so that vllm can have a better estimation of kv cache memory + self.actor_rollout_wg = all_wg['actor_rollout'] + self.actor_rollout_wg.init_model() + + def _save_checkpoint(self): + # path: given_path + `/global_step_{global_steps}` + `/actor` + local_global_step_folder = os.path.join(self.config.trainer.default_local_dir, + f'global_step_{self.global_steps}') + actor_local_path = os.path.join(local_global_step_folder, 'actor') + + actor_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join( + self.config.trainer.default_hdfs_dir, f'global_step_{self.global_steps}', 'actor') + self.actor_rollout_wg.save_checkpoint(actor_local_path, + actor_remote_path, + self.global_steps, + remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save) + + if self.use_critic: + critic_local_path = os.path.join(local_global_step_folder, 'critic') + critic_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join( + self.config.trainer.default_hdfs_dir, f'global_step_{self.global_steps}', 'critic') + self.critic_wg.save_checkpoint(critic_local_path, + critic_remote_path, + self.global_steps, + remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save) + + # save dataloader + dataloader_local_path = os.path.join(local_global_step_folder, 'data.pt') + import dill + torch.save(self.train_dataloader, dataloader_local_path, pickle_module=dill) + + # latest checkpointed iteration tracker (for atomic usage) + local_latest_checkpointed_iteration = os.path.join(self.config.trainer.default_local_dir, + 'latest_checkpointed_iteration.txt') + with open(local_latest_checkpointed_iteration, 'w') as f: + f.write(str(self.global_steps)) + + def _load_checkpoint(self): + if self.config.trainer.resume_mode == 'disable': + return 0 + + # load from hdfs + if self.config.trainer.default_hdfs_dir is not None: + NotImplementedError('load from hdfs is not implemented yet') + else: + checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path + if not os.path.isabs(checkpoint_folder): + working_dir = os.getcwd() + checkpoint_folder = os.path.join(working_dir, checkpoint_folder) + global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest + + # find global_step_folder + if self.config.trainer.resume_mode == 'auto': + if global_step_folder is None: + print('Training from scratch') + return 0 + else: + if not (self.config.trainer.resume_from_path and global_step_folder is not None): + assert isinstance(self.config.trainer.resume_mode, str), "resume ckpt must be str type" + assert 'global_step_' in self.config.trainer.resume_mode, "resume ckpt must specify the global_steps" + global_step_folder = self.config.trainer.resume_mode + if not os.path.isabs(global_step_folder): + working_dir = os.getcwd() + global_step_folder = os.path.join(working_dir, global_step_folder) + print(f'Load from checkpoint folder: {global_step_folder}') + # set global step + self.global_steps = int(global_step_folder.split('global_step_')[-1]) + + print(f'Setting global step to {self.global_steps}') + print(f'Resuming from {global_step_folder}') + + actor_path 
= os.path.join(global_step_folder, 'actor') + critic_path = os.path.join(global_step_folder, 'critic') + # load actor + self.actor_rollout_wg.load_checkpoint(actor_path, + del_local_after_load=self.config.trainer.del_local_ckpt_after_load) + # load critic + if self.use_critic: + self.critic_wg.load_checkpoint(critic_path, + del_local_after_load=self.config.trainer.del_local_ckpt_after_load) + + # load dataloader, + # TODO: from remote not implemented yet + dataloader_local_path = os.path.join(global_step_folder, 'data.pt') + self.train_dataloader = torch.load(dataloader_local_path) + if isinstance(self.train_dataloader.dataset, RLHFDataset): + self.train_dataloader.dataset.resume_dataset_state() + + def _balance_batch(self, batch: DataProto, metrics, logging_prefix='global_seqlen'): + """Reorder the data on single controller such that each dp rank gets similar total tokens""" + attention_mask = batch.batch['attention_mask'] + batch_size = attention_mask.shape[0] + global_seqlen_lst = batch.batch['attention_mask'].view(batch_size, -1).sum(-1).tolist() # (train_batch_size,) + world_size = self.actor_rollout_wg.world_size + global_partition_lst = get_seqlen_balanced_partitions(global_seqlen_lst, + k_partitions=world_size, + equal_size=True) + # reorder based on index. The data will be automatically equally partitioned by dispatch function + global_idx = torch.tensor([j for partition in global_partition_lst for j in partition]) + batch.reorder(global_idx) + global_balance_stats = log_seqlen_unbalance(seqlen_list=global_seqlen_lst, + partitions=global_partition_lst, + prefix=logging_prefix) + metrics.update(global_balance_stats) + + def fit(self): + """ + The training loop of PPO. + The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow. + The light-weight advantage computation is done on the driver process. + """ + from verl.utils.tracking import Tracking + from omegaconf import OmegaConf + + logger = Tracking(project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True)) + + self.global_steps = 0 + + # load checkpoint before doing anything + self._load_checkpoint() + + # perform validation before training + # currently, we only support validation using the reward_function. 
+ if self.val_reward_fn is not None and self.config.trainer.get('val_before_train', True): + val_metrics = self._validate() + pprint(f'Initial validation metrics: {val_metrics}') + logger.log(data=val_metrics, step=self.global_steps) + if self.config.trainer.get('val_only', False): + return + + # we start from step 1 + self.global_steps += 1 + + for epoch in range(self.config.trainer.total_epochs): + for batch_dict in self.train_dataloader: + metrics = {} + timing_raw = {} + + batch: DataProto = DataProto.from_single_dict(batch_dict) + + # pop those keys for generation + gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids']) + + with _timer('step', timing_raw): + # generate a batch + with _timer('gen', timing_raw): + gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) + + if self.config.algorithm.adv_estimator == 'remax': + with _timer('gen_max', timing_raw): + gen_baseline_batch = deepcopy(gen_batch) + gen_baseline_batch.meta_info['do_sample'] = False + gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) + + batch = batch.union(gen_baseline_output) + reward_baseline_tensor = self.reward_fn(batch) + reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) + + batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) + + batch.batch['reward_baselines'] = reward_baseline_tensor + + del gen_baseline_batch, gen_baseline_output + + batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], + dtype=object) + # repeat to align with repeated responses in rollout + batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + batch = batch.union(gen_batch_output) + + # balance the number of valid tokens on each dp rank. + # Note that this breaks the order of data inside the batch. + # Please take care when you implement group based adv computation such as GRPO and rloo + self._balance_batch(batch, metrics=metrics) + + # compute global_valid tokens + batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist() + + # recompute old_log_probs + with _timer('old_log_prob', timing_raw): + old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) + batch = batch.union(old_log_prob) + + if self.use_reference_policy: + # compute reference log_prob + with _timer('ref', timing_raw): + ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + batch = batch.union(ref_log_prob) + + # compute values + if self.use_critic: + with _timer('values', timing_raw): + values = self.critic_wg.compute_values(batch) + batch = batch.union(values) + + with _timer('adv', timing_raw): + # compute scores. Support both model and function-based. + # We first compute the scores using reward model. Then, we call reward_fn to combine + # the results from reward model and rule-based results. + if self.use_rm: + # we first compute reward model score + reward_tensor = self.rm_wg.compute_rm_score(batch) + batch = batch.union(reward_tensor) + + # we combine with rule-based rm + reward_tensor = self.reward_fn(batch) + batch.batch['token_level_scores'] = reward_tensor + + # compute rewards. 
apply_kl_penalty if available + if not self.config.actor_rollout_ref.actor.get('use_kl_loss', False): + batch, kl_metrics = apply_kl_penalty(batch, + kl_ctrl=self.kl_ctrl, + kl_penalty=self.config.algorithm.kl_penalty) + metrics.update(kl_metrics) + else: + batch.batch['token_level_rewards'] = batch.batch['token_level_scores'] + + # compute advantages, executed on the driver process + batch = compute_advantage(batch, + adv_estimator=self.config.algorithm.adv_estimator, + gamma=self.config.algorithm.gamma, + lam=self.config.algorithm.lam, + num_repeat=self.config.actor_rollout_ref.rollout.n) + + # update critic + if self.use_critic: + with _timer('update_critic', timing_raw): + critic_output = self.critic_wg.update_critic(batch) + critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics']) + metrics.update(critic_output_metrics) + + # implement critic warmup + if self.config.trainer.critic_warmup <= self.global_steps: + # update actor + with _timer('update_actor', timing_raw): + actor_output = self.actor_rollout_wg.update_actor(batch) + actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics']) + metrics.update(actor_output_metrics) + + # validate + if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \ + self.global_steps % self.config.trainer.test_freq == 0: + with _timer('testing', timing_raw): + val_metrics: dict = self._validate() + metrics.update(val_metrics) + + if self.config.trainer.save_freq > 0 and \ + self.global_steps % self.config.trainer.save_freq == 0: + with _timer('save_checkpoint', timing_raw): + self._save_checkpoint() + + # collect metrics + metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) + metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) + + # TODO: make a canonical logger that supports various backend + logger.log(data=metrics, step=self.global_steps) + + self.global_steps += 1 + + if self.global_steps >= self.total_training_steps: + + # perform validation after training + if self.val_reward_fn is not None: + val_metrics = self._validate() + pprint(f'Final validation metrics: {val_metrics}') + logger.log(data=val_metrics, step=self.global_steps) + if self.config.trainer.save_freq > 0 and \ + (self.global_steps - 1) % self.config.trainer.save_freq != 0: + with _timer('save_checkpoint', timing_raw): + self._save_checkpoint() + return diff --git a/recipe/prime/run_prime_qwen.sh b/recipe/prime/run_prime_qwen.sh new file mode 100644 index 00000000..62d3fb1a --- /dev/null +++ b/recipe/prime/run_prime_qwen.sh @@ -0,0 +1,55 @@ +set -x + +gsm8k_train_path=$HOME/data/gsm8k/train.parquet +gsm8k_test_path=$HOME/data/gsm8k/test.parquet +math_train_path=$HOME/data/math/train.parquet +math_test_path=$HOME/data/math/test.parquet + +train_files="['$gsm8k_train_path', '$math_train_path']" +test_files="['$gsm8k_test_path', '$math_test_path']" + +model_path=PRIME-RL/Eurus-2-7B-SFT + +python3 -m recipe.prime.main_prime \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.train_batch_size=64 \ + data.val_batch_size=6312 \ + data.max_prompt_length=1024 \ + data.max_response_length=3072 \ + data.filter_accuracy=True \ + data.accuracy_lower_bound=0.2 \ + data.accuracy_upper_bound=0.8 \ + data.oversample_factor=4 \ + actor_rollout_ref.model.path=$model_path \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=64 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ + 
actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=4 \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ + algorithm.adv_estimator=rloo \ + reward_model.model.path=$model_path \ + reward_model.micro_batch_size=8 \ + reward_model.model.update=before \ + reward_model.model.beta_train=0.05 \ + reward_model.model.optim.lr=1e-6 \ + reward_model.model.optim.grad_clip=10.0 \ + reward_model.model.input_tokenizer=null \ + reward_model.mini_batch_size=64 \ + trainer.val_before_train=False \ + trainer.logger=['console','wandb'] \ + trainer.project_name='prime_example' \ + trainer.experiment_name='Eurus-2-7B-SFT' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=10 \ + trainer.total_epochs=15 $@ diff --git a/scripts/format.sh b/scripts/format.sh index ed49d6f1..d8562caf 100644 --- a/scripts/format.sh +++ b/scripts/format.sh @@ -1,3 +1,3 @@ #!/bin/bash pip3 install --upgrade yapf -python3 -m yapf -ir -vv --style ./.style.yapf verl tests single_controller examples +python3 -m yapf -ir -vv --style ./.style.yapf verl tests single_controller examples recipe From 876ea1a020214c28f3cccbbd195523ee61d6505d Mon Sep 17 00:00:00 2001 From: Zefan Wang Date: Mon, 24 Feb 2025 18:43:47 +0800 Subject: [PATCH 2/6] trainer, alg, manager --- recipe/prime/__init__.py | 13 + recipe/prime/config/prime_config.yaml | 0 recipe/prime/main_prime.py | 28 +- recipe/prime/prime_core_algos.py | 76 ++++ recipe/prime/prime_dp_rm.py | 13 + recipe/prime/prime_fsdp_workers.py | 17 +- recipe/prime/prime_ray_trainer.py | 512 ++++---------------------- verl/workers/reward_manager/prime.py | 41 ++- 8 files changed, 243 insertions(+), 457 deletions(-) delete mode 100644 recipe/prime/config/prime_config.yaml diff --git a/recipe/prime/__init__.py b/recipe/prime/__init__.py index e69de29b..b1697c70 100644 --- a/recipe/prime/__init__.py +++ b/recipe/prime/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 PRIME team and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/recipe/prime/config/prime_config.yaml b/recipe/prime/config/prime_config.yaml deleted file mode 100644 index e69de29b..00000000 diff --git a/recipe/prime/main_prime.py b/recipe/prime/main_prime.py index 5821cc29..b15e3873 100644 --- a/recipe/prime/main_prime.py +++ b/recipe/prime/main_prime.py @@ -1,3 +1,17 @@ +# Copyright 2024 PRIME team and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # Copyright 2024 Bytedance Ltd. and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -83,7 +97,7 @@ def main_task(config, compute_score=None): if config.reward_model.enable: from .prime_fsdp_workers import PRIMERewardModelWorker role_worker_mapping[Role.RewardModel] = ray.remote(PRIMERewardModelWorker) - mapping[Role.RewardModel] = ray.remote(PRIMERewardModelWorker) + mapping[Role.RewardModel] = global_pool_id reward_manager_name = config.reward_model.get("reward_manager", "naive") if reward_manager_name == 'naive': @@ -102,12 +116,12 @@ def main_task(config, compute_score=None): resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) trainer = RayPRIMETrainer(config=config, - tokenizer=tokenizer, - role_worker_mapping=role_worker_mapping, - resource_pool_manager=resource_pool_manager, - ray_worker_group_cls=ray_worker_group_cls, - reward_fn=reward_fn, - val_reward_fn=val_reward_fn) + tokenizer=tokenizer, + role_worker_mapping=role_worker_mapping, + resource_pool_manager=resource_pool_manager, + ray_worker_group_cls=ray_worker_group_cls, + reward_fn=reward_fn, + val_reward_fn=val_reward_fn) trainer.init_workers() trainer.fit() diff --git a/recipe/prime/prime_core_algos.py b/recipe/prime/prime_core_algos.py index e69de29b..19393e35 100644 --- a/recipe/prime/prime_core_algos.py +++ b/recipe/prime/prime_core_algos.py @@ -0,0 +1,76 @@ +# Copyright 2024 PRIME team and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
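For readers following the hunk above that changes `mapping[Role.RewardModel]` from a `ray.remote` class to `global_pool_id`: the two dictionaries answer different questions. `role_worker_mapping` maps a role to the Ray-wrapped worker class that implements it, while `mapping` maps a role to the name of the resource pool it is scheduled on. A minimal sketch, illustrative only and not part of the patch (pool size is an assumed example):

# Sketch: role -> worker class vs. role -> resource pool name.
import ray
from verl.trainer.ppo.ray_trainer import Role, ResourcePoolManager
from recipe.prime.prime_fsdp_workers import PRIMERewardModelWorker

global_pool_id = 'global_pool'
resource_pool_spec = {global_pool_id: [8]}  # e.g. one node with 8 GPUs
role_worker_mapping = {Role.RewardModel: ray.remote(PRIMERewardModelWorker)}  # worker classes
mapping = {Role.RewardModel: global_pool_id}  # pool names, never ray.remote classes
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)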
+ +import torch +import verl +import verl.utils.torch_functional as verl_F + + +def compute_rloo_advantage_return(data: verl.DataProto, eos_mask: torch.Tensor, n_samples, config): + # calculate rloo reward on different reward sources, and sum again + + def masked_rloo(reward_tensor_original, mask_tensor): + reward_tensor = reward_tensor_original.clone() + reward_tensor[~mask_tensor] = 0 + for start_pos in range(0, reward_tensor.shape[0], n_samples): + cur_rewards_mean = torch.cat([ + reward_tensor[pos:pos + 1][mask_tensor[pos:pos + 1]].mean(dim=0, keepdim=True) + for pos in range(start_pos, start_pos + n_samples) + ], + dim=0) + cur_rewards_sum = cur_rewards_mean.sum() + cur_reward_baseline = cur_rewards_sum / (n_samples - 1) + reward_tensor[start_pos:start_pos + n_samples][ + mask_tensor[start_pos:start_pos + n_samples]] = \ + reward_tensor[start_pos:start_pos + n_samples][ + mask_tensor[start_pos:start_pos + n_samples]] * ( + n_samples / (n_samples - 1)) - cur_reward_baseline + + return reward_tensor + + reward_tensors = [] + + with torch.no_grad(): + + if 'rm_scores' in data.batch and config.algorithm.dpo_coef != 0.: + reward_tensor = data.batch['rm_scores'] + reward_mask = eos_mask.bool() + + reward_tensors.append(masked_rloo(reward_tensor, reward_mask) * config.algorithm.dpo_coef) + + if 'acc' in data.batch and config.algorithm.gt_coef != 0.: + reward_tensor = torch.zeros_like(eos_mask, dtype=torch.float32) + reward_mask = torch.zeros_like(eos_mask, dtype=torch.bool) + + prompt_ids = data.batch['prompts'] + prompt_length = prompt_ids.shape[-1] + valid_response_length = data.batch['attention_mask'][:, prompt_length:].sum(-1) + + reward_mask[ + torch.arange(0, valid_response_length.shape[0], dtype=torch.long, device=valid_response_length.device), + valid_response_length - 1] = True + reward_tensor[ + torch.arange(0, valid_response_length.shape[0], dtype=torch.long, device=valid_response_length.device), + valid_response_length - 1] = data.batch['acc'] + + reward_tensors.append(masked_rloo(reward_tensor, reward_mask) * config.algorithm.gt_coef) + + final_reward_tensor = sum(reward_tensors) + + returns = (final_reward_tensor * eos_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) + + advantages = returns.clone() + advantages = verl_F.masked_whiten(advantages, eos_mask) + + return advantages, returns diff --git a/recipe/prime/prime_dp_rm.py b/recipe/prime/prime_dp_rm.py index e69de29b..b1697c70 100644 --- a/recipe/prime/prime_dp_rm.py +++ b/recipe/prime/prime_dp_rm.py @@ -0,0 +1,13 @@ +# Copyright 2024 PRIME team and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
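The `masked_rloo` helper in `compute_rloo_advantage_return` above implements the leave-one-out (RLOO) baseline. A toy check with made-up scalar rewards, standing in for the masked token-level means, confirms that scaling each reward by n/(n-1) and subtracting sum/(n-1) is the same as subtracting the mean of the other n-1 samples (illustrative only, not part of the patch):

# Toy check of the leave-one-out transform used in masked_rloo.
import torch

n_samples = 4
r = torch.tensor([1.0, 0.0, 0.0, 1.0])        # rewards for the n responses to one prompt
baseline = r.sum() / (n_samples - 1)
rloo = r * (n_samples / (n_samples - 1)) - baseline

# equivalent to subtracting, for each sample, the mean reward of the other n-1 samples
expected = torch.stack([r[i] - (r.sum() - r[i]) / (n_samples - 1) for i in range(n_samples)])
assert torch.allclose(rloo, expected)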
\ No newline at end of file diff --git a/recipe/prime/prime_fsdp_workers.py b/recipe/prime/prime_fsdp_workers.py index 18673927..0ebf5c84 100644 --- a/recipe/prime/prime_fsdp_workers.py +++ b/recipe/prime/prime_fsdp_workers.py @@ -1,3 +1,16 @@ +# Copyright 2024 PRIME team and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging import os @@ -28,6 +41,8 @@ logger = logging.getLogger(__file__) logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN')) + + class PRIMERewardModelWorker(Worker): def __init__(self, config): @@ -294,4 +309,4 @@ def load_checkpoint(self, path, del_local_after_load=True): torch.distributed.barrier() if self._is_offload_param: - offload_fsdp_model_to_cpu(self.critic_module) \ No newline at end of file + offload_fsdp_model_to_cpu(self.critic_module) diff --git a/recipe/prime/prime_ray_trainer.py b/recipe/prime/prime_ray_trainer.py index bce5c0a8..2ae171b4 100644 --- a/recipe/prime/prime_ray_trainer.py +++ b/recipe/prime/prime_ray_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2024 PRIME team and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ """ import os +import statistics import uuid from contextlib import contextmanager from dataclasses import dataclass, field @@ -34,78 +35,33 @@ from verl.single_controller.ray import RayResourcePool, RayWorkerGroup, RayClassWithInitArgs from verl.single_controller.ray.base import create_colocated_worker_cls from verl.trainer.ppo import core_algos +from . 
import prime_core_algos from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn from verl.trainer.ppo.ray_trainer import Role, WorkerType, ResourcePoolManager, apply_kl_penalty, reduce_metrics, _compute_response_info, _timer +from verl.trainer.ppo.ray_trainer import RayPPOTrainer import torch from verl.utils.torch_functional import masked_mean -def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_repeat=1): - # prepare response group - # TODO: add other ways to estimate advantages - if adv_estimator == 'gae': - values = data.batch['values'] +def compute_advantage(data: DataProto, adv_estimator, config): + if adv_estimator == 'rloo': responses = data.batch['responses'] response_length = responses.size(-1) attention_mask = data.batch['attention_mask'] response_mask = attention_mask[:, -response_length:] - token_level_rewards = data.batch['token_level_rewards'] - advantages, returns = core_algos.compute_gae_advantage_return(token_level_rewards=token_level_rewards, - values=values, - eos_mask=response_mask, - gamma=gamma, - lam=lam) - data.batch['advantages'] = advantages - data.batch['returns'] = returns - elif adv_estimator == 'grpo': - token_level_rewards = data.batch['token_level_rewards'] - index = data.non_tensor_batch['uid'] - responses = data.batch['responses'] - response_length = responses.size(-1) - attention_mask = data.batch['attention_mask'] - response_mask = attention_mask[:, -response_length:] - advantages, returns = core_algos.compute_grpo_outcome_advantage(token_level_rewards=token_level_rewards, - eos_mask=response_mask, - index=index) - data.batch['advantages'] = advantages - data.batch['returns'] = returns - elif adv_estimator == 'reinforce_plus_plus': - token_level_rewards = data.batch['token_level_rewards'] - responses = data.batch['responses'] - response_length = responses.size(-1) - attention_mask = data.batch['attention_mask'] - response_mask = attention_mask[:, -response_length:] - advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage( - token_level_rewards=token_level_rewards, eos_mask=response_mask, gamma=gamma) - data.batch['advantages'] = advantages - data.batch['returns'] = returns - elif adv_estimator == 'remax': - token_level_rewards = data.batch['token_level_rewards'] - index = data.non_tensor_batch['uid'] - responses = data.batch['responses'] - response_length = responses.size(-1) - attention_mask = data.batch['attention_mask'] - response_mask = attention_mask[:, -response_length:] - - reward_baselines = data.batch['reward_baselines'] - - advantages, returns = core_algos.compute_remax_outcome_advantage(token_level_rewards=token_level_rewards, - reward_baselines=reward_baselines, - eos_mask=response_mask) - + advantages, returns = prime_core_algos.compute_rloo_advantage_return(data, response_mask, + config.actor_rollout_ref.rollout.n, config) data.batch['advantages'] = advantages data.batch['returns'] = returns else: raise NotImplementedError return data + def compute_data_metrics(batch, use_critic=True): - # TODO: add response length - sequence_score = batch.batch['token_level_scores'].sum(-1) - sequence_reward = batch.batch['token_level_rewards'].sum(-1) advantages = batch.batch['advantages'] returns = batch.batch['returns'] @@ -131,20 +87,6 @@ def compute_data_metrics(batch, use_critic=True): return_var = torch.var(valid_returns) 
metrics = { - # score - 'critic/score/mean': - torch.mean(sequence_score).detach().item(), - 'critic/score/max': - torch.max(sequence_score).detach().item(), - 'critic/score/min': - torch.min(sequence_score).detach().item(), - # reward - 'critic/rewards/mean': - torch.mean(sequence_reward).detach().item(), - 'critic/rewards/max': - torch.max(sequence_reward).detach().item(), - 'critic/rewards/min': - torch.min(sequence_reward).detach().item(), # adv 'critic/advantages/mean': torch.mean(valid_adv).detach().item(), @@ -214,7 +156,7 @@ def compute_timing_metrics(batch, timing_raw): } -class RayPRIMETrainer(object): +class RayPRIMETrainer(RayPPOTrainer): """ Note that this trainer runs on the driver process on a single CPU/GPU node. """ @@ -232,36 +174,8 @@ def __init__(self, # assert torch.cuda.is_available(), 'cuda must be available on driver' - self.tokenizer = tokenizer - self.config = config - self.reward_fn = reward_fn - self.val_reward_fn = val_reward_fn - - self.hybrid_engine = config.actor_rollout_ref.hybrid_engine - assert self.hybrid_engine, 'Currently, only support hybrid engine' - - if self.hybrid_engine: - assert Role.ActorRollout in role_worker_mapping, f'{role_worker_mapping.keys()=}' - - self.role_worker_mapping = role_worker_mapping - self.resource_pool_manager = resource_pool_manager - self.use_reference_policy = Role.RefPolicy in role_worker_mapping - self.use_rm = Role.RewardModel in role_worker_mapping - self.ray_worker_group_cls = ray_worker_group_cls - - # define KL control - if self.use_reference_policy: - if config.algorithm.kl_ctrl.type == 'fixed': - self.kl_ctrl = core_algos.FixedKLController(kl_coef=config.algorithm.kl_ctrl.kl_coef) - elif config.algorithm.kl_ctrl.type == 'adaptive': - assert config.algorithm.kl_ctrl.horizon > 0, f'horizon must be larger than 0. Got {config.critic.kl_ctrl.horizon}' - self.kl_ctrl = core_algos.AdaptiveKLController(init_kl_coef=config.algorithm.kl_ctrl.kl_coef, - target_kl=config.algorithm.kl_ctrl.target_kl, - horizon=config.algorithm.kl_ctrl.horizon) - else: - raise NotImplementedError - else: - self.kl_ctrl = core_algos.FixedKLController(kl_coef=0.) + super().__init__(config, tokenizer, role_worker_mapping, resource_pool_manager, ray_worker_group_cls, reward_fn, + val_reward_fn) self.use_critic = False @@ -269,83 +183,9 @@ def __init__(self, self._create_dataloader() def _validate_config(self): + super()._validate() + # TODO: Additional config checks can be added here config = self.config - # number of GPUs total - n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes - - # 1. Check total batch size for data correctness - real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n - assert real_train_batch_size % n_gpus == 0, \ - f"real_train_batch_size ({real_train_batch_size}) must be divisible by total n_gpus ({n_gpus})." - - # A helper function to check "micro_batch_size" vs "micro_batch_size_per_gpu" - # We throw an error if the user sets both. The new convention is "..._micro_batch_size_per_gpu". - def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): - if mbs is None and mbs_per_gpu is None: - raise ValueError(f"[{name}] Please set at least one of '{name}.micro_batch_size' or " - f"'{name}.micro_batch_size_per_gpu'.") - - if mbs is not None and mbs_per_gpu is not None: - raise ValueError(f"[{name}] You have set both '{name}.micro_batch_size' AND " - f"'{name}.micro_batch_size_per_gpu'. 
Please remove '{name}.micro_batch_size' " - f"because only '*_micro_batch_size_per_gpu' is supported (the former is deprecated).") - - if not config.actor_rollout_ref.actor.use_dynamic_bsz: - # actor: ppo_micro_batch_size vs. ppo_micro_batch_size_per_gpu - check_mutually_exclusive(config.actor_rollout_ref.actor.ppo_micro_batch_size, - config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu, - "actor_rollout_ref.actor") - - # reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu - check_mutually_exclusive(config.actor_rollout_ref.ref.log_prob_micro_batch_size, - config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu, - "actor_rollout_ref.ref") - - # The rollout section also has log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu - check_mutually_exclusive(config.actor_rollout_ref.rollout.log_prob_micro_batch_size, - config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu, - "actor_rollout_ref.rollout") - - if self.use_critic and not config.critic.use_dynamic_bsz: - # Check for critic micro-batch size conflicts - check_mutually_exclusive(config.critic.ppo_micro_batch_size, config.critic.ppo_micro_batch_size_per_gpu, - "critic") - - # Check for reward model micro-batch size conflicts - if config.reward_model.enable and not config.reward_model.use_dynamic_bsz: - check_mutually_exclusive(config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, - "reward_model") - - # Actor - # if NOT dynamic_bsz, we must ensure: - # ppo_mini_batch_size is divisible by ppo_micro_batch_size - # ppo_micro_batch_size * sequence_parallel_size >= n_gpus - if not config.actor_rollout_ref.actor.use_dynamic_bsz: - sp_size = config.actor_rollout_ref.actor.get('ulysses_sequence_parallel_size', 1) - if config.actor_rollout_ref.actor.ppo_micro_batch_size is not None: - assert config.actor_rollout_ref.actor.ppo_mini_batch_size % config.actor_rollout_ref.actor.ppo_micro_batch_size == 0 - assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus - - # critic - if self.use_critic and not config.critic.use_dynamic_bsz: - sp_size = config.critic.get('ulysses_sequence_parallel_size', 1) - if config.critic.ppo_micro_batch_size is not None: - assert config.critic.ppo_mini_batch_size % config.critic.ppo_micro_batch_size == 0 - assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus - - # Check if use_remove_padding is enabled when using sequence parallelism for fsdp - if config.actor_rollout_ref.actor.strategy == 'fsdp': - if config.actor_rollout_ref.actor.get('ulysses_sequence_parallel_size', 1) > 1 or \ - config.actor_rollout_ref.ref.get('ulysses_sequence_parallel_size', 1) > 1: - assert config.actor_rollout_ref.model.use_remove_padding, \ - "When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`." - - if self.use_critic and config.critic.strategy == 'fsdp': - if config.critic.get('ulysses_sequence_parallel_size', 1) > 1: - assert config.critic.model.use_remove_padding, \ - "When using sequence parallelism for critic, you must enable `use_remove_padding`." 
- - print("[validate_config] All configuration checks passed successfully!") def _create_dataloader(self): from torch.utils.data import DataLoader, RandomSampler, SequentialSampler @@ -366,7 +206,8 @@ def _create_dataloader(self): sampler = SequentialSampler(data_source=self.train_dataset) self.train_dataloader = DataLoader(dataset=self.train_dataset, - batch_size=self.config.data.train_batch_size, + batch_size=int(self.config.data.train_batch_size * + self.config.data.oversample_factor), drop_last=True, collate_fn=collate_fn, sampler=sampler) @@ -404,196 +245,6 @@ def _create_dataloader(self): self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps self.config.critic.optim.total_training_steps = total_training_steps - def _maybe_log_val_generations_to_wandb(self, inputs, outputs, scores): - """Log a table of validation samples to wandb""" - - generations_to_log = self.config.trainer.val_generations_to_log_to_wandb - - if generations_to_log == 0: - return - - if generations_to_log > 0 and 'wandb' not in self.config.trainer.logger: - print( - 'WARNING: `val_generations_to_log_to_wandb` is set to a positive value, but no wandb logger is found. ') - return - - import wandb - import numpy as np - - # Create tuples of (input, output, score) and sort by input text - samples = list(zip(inputs, outputs, scores)) - samples.sort(key=lambda x: x[0]) # Sort by input text - - # Use fixed random seed for deterministic shuffling - rng = np.random.RandomState(42) - rng.shuffle(samples) - - # Take first N samples after shuffling - samples = samples[:generations_to_log] - - # Create column names for all samples - columns = ["step"] + sum([[f"input_{i+1}", f"output_{i+1}", f"score_{i+1}"] for i in range(len(samples))], []) - - if not hasattr(self, 'validation_table'): - # Initialize the table on first call - self.validation_table = wandb.Table(columns=columns) - - # Create a new table with same columns and existing data - # Workaround for https://github.com/wandb/wandb/issues/2981#issuecomment-1997445737 - new_table = wandb.Table(columns=columns, data=self.validation_table.data) - - # Add new row with all data - row_data = [] - row_data.append(self.global_steps) - for sample in samples: - row_data.extend(sample) - - new_table.add_data(*row_data) - - # Update reference and log - wandb.log({"val/generations": new_table}, step=self.global_steps) - self.validation_table = new_table - - def _validate(self): - reward_tensor_lst = [] - data_source_lst = [] - - # Lists to collect samples for the table - sample_inputs = [] - sample_outputs = [] - sample_scores = [] - - for test_data in self.val_dataloader: - test_batch = DataProto.from_single_dict(test_data) - - # we only do validation on rule-based rm - if self.config.reward_model.enable and test_batch[0].non_tensor_batch['reward_model']['style'] == 'model': - return {} - - # Store original inputs - input_ids = test_batch.batch['input_ids'] - input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids] - sample_inputs.extend(input_texts) - - test_gen_batch = test_batch.pop(['input_ids', 'attention_mask', 'position_ids']) - test_gen_batch.meta_info = { - 'eos_token_id': self.tokenizer.eos_token_id, - 'pad_token_id': self.tokenizer.pad_token_id, - 'recompute_log_prob': False, - 'do_sample': False, - 'validate': True, - } - - # pad to be divisible by dp_size - test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size) - test_output_gen_batch_padded = 
self.actor_rollout_wg.generate_sequences(test_gen_batch_padded) - # unpad - test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size) - print('validation generation end') - - # Store generated outputs - output_ids = test_output_gen_batch.batch['responses'] - output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids] - sample_outputs.extend(output_texts) - - test_batch = test_batch.union(test_output_gen_batch) - - # evaluate using reward_function - reward_tensor = self.val_reward_fn(test_batch) - - # Store scores - scores = reward_tensor.sum(-1).cpu().tolist() - sample_scores.extend(scores) - - reward_tensor_lst.append(reward_tensor) - data_source_lst.append(test_batch.non_tensor_batch.get('data_source', ['unknown'] * reward_tensor.shape[0])) - - self._maybe_log_val_generations_to_wandb(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores) - - reward_tensor = torch.cat(reward_tensor_lst, dim=0).sum(-1).cpu() # (batch_size,) - data_sources = np.concatenate(data_source_lst, axis=0) - - # evaluate test_score based on data source - data_source_reward = {} - for i in range(reward_tensor.shape[0]): - data_source = data_sources[i] - if data_source not in data_source_reward: - data_source_reward[data_source] = [] - data_source_reward[data_source].append(reward_tensor[i].item()) - - metric_dict = {} - for data_source, rewards in data_source_reward.items(): - metric_dict[f'val/test_score/{data_source}'] = np.mean(rewards) - - return metric_dict - - def init_workers(self): - """Init resource pool and worker group""" - self.resource_pool_manager.create_resource_pool() - - self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()} - - # create actor and rollout - if self.hybrid_engine: - resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout) - actor_rollout_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.ActorRollout], - config=self.config.actor_rollout_ref, - role='actor_rollout') - self.resource_pool_to_cls[resource_pool]['actor_rollout'] = actor_rollout_cls - else: - raise NotImplementedError - - # create critic - if self.use_critic: - resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic) - critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=self.config.critic) - self.resource_pool_to_cls[resource_pool]['critic'] = critic_cls - - # create reference policy if needed - if self.use_reference_policy: - resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) - ref_policy_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RefPolicy], - config=self.config.actor_rollout_ref, - role='ref') - self.resource_pool_to_cls[resource_pool]['ref'] = ref_policy_cls - - # create a reward model if reward_fn is None - if self.use_rm: - # we create a RM here - resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) - rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model) - self.resource_pool_to_cls[resource_pool]['rm'] = rm_cls - - # initialize WorkerGroup - # NOTE: if you want to use a different resource pool for each role, which can support different parallel size, - # you should not use `create_colocated_worker_cls`. Instead, directly pass different resource pool to different worker groups. - # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information. 
- all_wg = {} - self.wg_dicts = [] - for resource_pool, class_dict in self.resource_pool_to_cls.items(): - worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) - wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls) - spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) - all_wg.update(spawn_wg) - # keep the referece of WorkerDict to support ray >= 2.31. Ref: https://github.com/ray-project/ray/pull/45699 - self.wg_dicts.append(wg_dict) - - if self.use_critic: - self.critic_wg = all_wg['critic'] - self.critic_wg.init_model() - - if self.use_reference_policy: - self.ref_policy_wg = all_wg['ref'] - self.ref_policy_wg.init_model() - - if self.use_rm: - self.rm_wg = all_wg['rm'] - self.rm_wg.init_model() - - # we should create rollout at the end so that vllm can have a better estimation of kv cache memory - self.actor_rollout_wg = all_wg['actor_rollout'] - self.actor_rollout_wg.init_model() - def _save_checkpoint(self): # path: given_path + `/global_step_{global_steps}` + `/actor` local_global_step_folder = os.path.join(self.config.trainer.default_local_dir, @@ -607,14 +258,14 @@ def _save_checkpoint(self): self.global_steps, remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save) - if self.use_critic: - critic_local_path = os.path.join(local_global_step_folder, 'critic') - critic_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join( - self.config.trainer.default_hdfs_dir, f'global_step_{self.global_steps}', 'critic') - self.critic_wg.save_checkpoint(critic_local_path, - critic_remote_path, - self.global_steps, - remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save) + if self.use_rm: + reward_local_path = os.path.join(local_global_step_folder, 'reward') + reward_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join( + self.config.trainer.default_hdfs_dir, f'global_step_{self.global_steps}', 'reward') + self.rm_wg.save_checkpoint(reward_local_path, + reward_remote_path, + self.global_steps, + remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save) # save dataloader dataloader_local_path = os.path.join(local_global_step_folder, 'data.pt') @@ -662,13 +313,13 @@ def _load_checkpoint(self): print(f'Resuming from {global_step_folder}') actor_path = os.path.join(global_step_folder, 'actor') - critic_path = os.path.join(global_step_folder, 'critic') + reward_path = os.path.join(global_step_folder, 'reward') # load actor self.actor_rollout_wg.load_checkpoint(actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load) # load critic - if self.use_critic: - self.critic_wg.load_checkpoint(critic_path, + if self.use_rm: + self.critic_wg.load_checkpoint(reward_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load) # load dataloader, @@ -678,23 +329,6 @@ def _load_checkpoint(self): if isinstance(self.train_dataloader.dataset, RLHFDataset): self.train_dataloader.dataset.resume_dataset_state() - def _balance_batch(self, batch: DataProto, metrics, logging_prefix='global_seqlen'): - """Reorder the data on single controller such that each dp rank gets similar total tokens""" - attention_mask = batch.batch['attention_mask'] - batch_size = attention_mask.shape[0] - global_seqlen_lst = batch.batch['attention_mask'].view(batch_size, -1).sum(-1).tolist() # (train_batch_size,) - world_size = self.actor_rollout_wg.world_size - global_partition_lst = get_seqlen_balanced_partitions(global_seqlen_lst, - 
k_partitions=world_size, - equal_size=True) - # reorder based on index. The data will be automatically equally partitioned by dispatch function - global_idx = torch.tensor([j for partition in global_partition_lst for j in partition]) - batch.reorder(global_idx) - global_balance_stats = log_seqlen_unbalance(seqlen_list=global_seqlen_lst, - partitions=global_partition_lst, - prefix=logging_prefix) - metrics.update(global_balance_stats) - def fit(self): """ The training loop of PPO. @@ -766,11 +400,20 @@ def fit(self): # balance the number of valid tokens on each dp rank. # Note that this breaks the order of data inside the batch. # Please take care when you implement group based adv computation such as GRPO and rloo - self._balance_batch(batch, metrics=metrics) + # self._balance_batch(batch, metrics=metrics) # compute global_valid tokens batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist() + # verify + with _timer(name='verify', text="{name}: {seconds:.1f} seconds"): + scores = self.reward_fn.verify(batch) + metrics['acc'] = statistics.mean(scores) + + # filter the batch. 1/oversample_factor samples will be kept. If there is a filter, prompts passing it will be prioritized. + + batch = self.filter_and_downsample(scores, batch) + # recompute old_log_probs with _timer('old_log_prob', timing_raw): old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) @@ -782,55 +425,31 @@ def fit(self): ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) batch = batch.union(ref_log_prob) - # compute values - if self.use_critic: - with _timer('values', timing_raw): - values = self.critic_wg.compute_values(batch) - batch = batch.union(values) - with _timer('adv', timing_raw): - # compute scores. Support both model and function-based. - # We first compute the scores using reward model. Then, we call reward_fn to combine - # the results from reward model and rule-based results. + if self.use_rm: - # we first compute reward model score - reward_tensor = self.rm_wg.compute_rm_score(batch) + update_style = self.config.reward_model.model.update + if update_style == 'none': # only run forward + reward_tensor = self.rm_wg.compute_rm_score(batch) + elif update_style == 'after': # update and directly return the reward + reward_tensor = self.rm_wg.update_rm(batch) + elif update_style == 'before': # update reward model, and then run forward + reward_tensor = self.rm_wg.update_rm(batch) + reward_tensor = self.rm_wg.compute_rm_score(batch) + else: + raise NotImplementedError batch = batch.union(reward_tensor) - # we combine with rule-based rm - reward_tensor = self.reward_fn(batch) - batch.batch['token_level_scores'] = reward_tensor - - # compute rewards. 
apply_kl_penalty if available - if not self.config.actor_rollout_ref.actor.get('use_kl_loss', False): - batch, kl_metrics = apply_kl_penalty(batch, - kl_ctrl=self.kl_ctrl, - kl_penalty=self.config.algorithm.kl_penalty) - metrics.update(kl_metrics) - else: - batch.batch['token_level_rewards'] = batch.batch['token_level_scores'] - # compute advantages, executed on the driver process batch = compute_advantage(batch, adv_estimator=self.config.algorithm.adv_estimator, - gamma=self.config.algorithm.gamma, - lam=self.config.algorithm.lam, - num_repeat=self.config.actor_rollout_ref.rollout.n) - - # update critic - if self.use_critic: - with _timer('update_critic', timing_raw): - critic_output = self.critic_wg.update_critic(batch) - critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics']) - metrics.update(critic_output_metrics) - - # implement critic warmup - if self.config.trainer.critic_warmup <= self.global_steps: - # update actor - with _timer('update_actor', timing_raw): - actor_output = self.actor_rollout_wg.update_actor(batch) - actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics']) - metrics.update(actor_output_metrics) + config=self.config) + + # update actor + with _timer('update_actor', timing_raw): + actor_output = self.actor_rollout_wg.update_actor(batch) + actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics']) + metrics.update(actor_output_metrics) # validate if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \ @@ -865,3 +484,24 @@ def fit(self): with _timer('save_checkpoint', timing_raw): self._save_checkpoint() return + + def filter_and_downsample(self, scores, batch: DataProto): + """ + downsample the batch according to oversample_factor + samples passing the filters will be prioritized + """ + n_samples = int(self.config.actor_rollout_ref.rollout.n) + reward_matrix = torch.tensor(scores).reshape(-1, n_samples) + + filter_mask = torch.ones((reward_matrix.shape[0]), dtype=torch.bool) + + if self.config.data.filter_accuracy: + acc_tensor = torch.mean(reward_matrix, dim=-1) + filter_mask[(acc_tensor > self.config.data.accuracy_upper_bound) | + (acc_tensor < self.config.data.accuracy_lower_bound)] = False + + reorder_index = torch.argsort(filter_mask, descending=True) + reorder_index = (reorder_index.unsqueeze(-1) * n_samples + torch.arange(0, n_samples).unsqueeze(0)).view(-1) + batch.reorder(reorder_index[:int(len(batch) // self.config.data.oversample_factor)]) + + return batch diff --git a/verl/workers/reward_manager/prime.py index f2f8856f..14330c2b 100644 --- a/verl/workers/reward_manager/prime.py +++ b/verl/workers/reward_manager/prime.py @@ -84,23 +84,14 @@ def __init__(self, tokenizer, num_examine, compute_score=None) -> None: self.num_examine = num_examine # the number of batches of decoded responses to print to the console self.compute_score = compute_score or _default_compute_score - def __call__(self, data: DataProto): - """We will expand this function gradually based on the available datasets""" - - # If there is rm score, we directly return rm score.
Otherwise, we compute via rm_score_fn - if 'rm_scores' in data.batch.keys(): - return data.batch['rm_scores'] - - reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32) - - already_print_data_sources = {} - + def verify(self, data): + """ + verify the batch and save as ``acc`` tensor + """ # batched scoring prompt_ids = data.batch['prompts'] - prompt_length = prompt_ids.shape[-1] response_ids = data.batch['responses'] - valid_response_length = data.batch['attention_mask'][:, prompt_length:].sum(dim=-1) sequences_str = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True) ground_truth = [data_item.non_tensor_batch['reward_model']['ground_truth'] for data_item in data] data_sources = data.non_tensor_batch['data_source'] @@ -119,6 +110,30 @@ def __call__(self, data: DataProto): except Exception as e: print(f"Unexpected error in batched reward computing. Setting all as 0.: {e}") scores = [0. for _ in range(len(sequences_str))] + data.batch['acc'] = torch.tensor(scores, dtype=torch.float32, device=prompt_ids.device) + return scores + + def __call__(self, data: DataProto): + """We will expand this function gradually based on the available datasets""" + + # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn + if 'rm_scores' in data.batch.keys(): + return data.batch['rm_scores'] + + reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32) + + already_print_data_sources = {} + + # batched scoring + prompt_ids = data.batch['prompts'] + prompt_length = prompt_ids.shape[-1] + + response_ids = data.batch['responses'] + valid_response_length = data.batch['attention_mask'][:, prompt_length:].sum(dim=-1) + sequences_str = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True) + data_sources = data.non_tensor_batch['data_source'] + + scores = self.verify(data) for i in range(len(data)): data_source = data_sources[i] From 5ccd3a431fdb82e1c4f39d81e1fe754dc729f3ce Mon Sep 17 00:00:00 2001 From: Zefan Wang Date: Tue, 25 Feb 2025 16:23:35 +0800 Subject: [PATCH 3/6] model worker --- recipe/prime/config/prime_trainer.yaml | 3 +- recipe/prime/prime_dp_rm.py | 198 ++++++++++++++++++++++++- recipe/prime/prime_fsdp_workers.py | 137 ++++++++--------- recipe/prime/prime_ray_trainer.py | 39 ++--- verl/trainer/config/ppo_trainer.yaml | 1 + verl/trainer/ppo/ray_trainer.py | 3 +- 6 files changed, 287 insertions(+), 94 deletions(-) diff --git a/recipe/prime/config/prime_trainer.yaml b/recipe/prime/config/prime_trainer.yaml index 468862c6..0394e985 100644 --- a/recipe/prime/config/prime_trainer.yaml +++ b/recipe/prime/config/prime_trainer.yaml @@ -64,4 +64,5 @@ algorithm: trainer: project_name: prime experiment_name: examples - val_before_train: False \ No newline at end of file + val_before_train: False + balance_batch: False \ No newline at end of file diff --git a/recipe/prime/prime_dp_rm.py b/recipe/prime/prime_dp_rm.py index b1697c70..b63c81e5 100644 --- a/recipe/prime/prime_dp_rm.py +++ b/recipe/prime/prime_dp_rm.py @@ -10,4 +10,200 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. 
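The `verify` method added above writes per-response correctness into `data.batch['acc']` and returns the scores; in the trainer these are averaged for the `acc` metric and their group-wise means drive the accuracy filter bounded by `accuracy_lower_bound` and `accuracy_upper_bound`. A toy walk-through with made-up scores, using the bounds from the run script (illustrative only, not part of the patch):

# Group-wise accuracy filtering: keep prompts that are neither always solved nor never solved.
import torch

n_samples, lower, upper = 4, 0.2, 0.8
scores = [1., 1., 1., 1.,   # prompt 0: always correct -> dropped (too easy)
          1., 0., 0., 1.,   # prompt 1: mixed          -> kept
          0., 0., 0., 0.]   # prompt 2: always wrong   -> dropped (too hard)

acc_per_prompt = torch.tensor(scores).reshape(-1, n_samples).mean(dim=-1)   # [1.0, 0.5, 0.0]
keep = ~((acc_per_prompt > upper) | (acc_per_prompt < lower))               # [False, True, False]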
+ +""" +Implement a multiprocess PPOCritic +""" +import itertools +from typing import Iterable + +import torch +import torch.distributed +from torch import nn, optim + +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + +from verl import DataProto +from verl.trainer.ppo import core_algos +from verl.workers.critic import BasePPOCritic +from verl.utils.py_functional import append_to_dict +from verl.utils.torch_functional import masked_mean +from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad +from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx + +from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis + +__all__ = ['DataParallelPRIMERewardModel'] + + +class DataParallelPRIMERewardModel: + + def __init__(self, config, reward_module: nn.Module, ref_module: nn.Module, reward_optimizer: optim.Optimizer): + self.config=config + self.reward_module = reward_module + self.ref_module = ref_module + self.reward_optimizer = reward_optimizer + self.use_remove_padding = self.config.model.get('use_remove_padding', False) + print(f'Reward model use_remove_padding={self.use_remove_padding}') + + self.ulysses_sequence_parallel_size = self.config.get('ulysses_sequence_parallel_size', 1) + + def _forward_micro_batch(self, micro_batch): + response_length = micro_batch['responses'].size(-1) + with torch.autocast(device_type='cuda', dtype=torch.bfloat16): + input_ids = micro_batch['input_ids'] + batch, seqlen = input_ids.shape + attention_mask = micro_batch['attention_mask'] + position_ids = micro_batch['position_ids'] + + if self.use_remove_padding: + input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), + attention_mask) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) + + # unpad the position_ids to align the rotary + position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... 
-> (b s) ..."), + indices).transpose(0, 1) + + # pad and slice the inputs if sp > 1 + if self.ulysses_sequence_parallel_size > 1: + input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(input_ids_rmpad, \ + position_ids_rmpad, \ + sp_size=self.ulysses_sequence_parallel_size) + + # only pass input_ids and position_ids to enable flash_attn_varlen + output = self.critic_module(input_ids=input_ids_rmpad, + attention_mask=None, + position_ids=position_ids_rmpad, + use_cache=False) # prevent model thinks we are generating + values_rmpad = output.logits + values_rmpad = values_rmpad.squeeze(0) # (total_nnz) + + # gather output if sp > 1 + if self.ulysses_sequence_parallel_size > 1: + values_rmpad = gather_outpus_and_unpad(values_rmpad, + gather_dim=0, + unpad_dim=0, + padding_size=pad_size) + + # pad it back + values = pad_input(values_rmpad, indices=indices, batch=batch, seqlen=seqlen).squeeze(-1) + values = values[:, -response_length - 1:-1] + else: + output = self.critic_module(input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=False) # prevent model thinks we are generating + values = output.logits + values = values[:, -response_length - 1:-1].squeeze(-1) + return values + + def _optimizer_step(self): + assert self.config.grad_clip is not None + + if isinstance(self.critic_module, FSDP): + grad_norm = self.critic_module.clip_grad_norm_(self.config.grad_clip) + else: + grad_norm = torch.nn.utils.clip_grad_norm_(self.critic_module.parameters(), max_norm=self.config.grad_clip) + self.critic_optimizer.step() + return grad_norm + + def compute_values(self, data: DataProto) -> torch.Tensor: + self.critic_module.eval() + micro_batch_size = data.meta_info['micro_batch_size'] + select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids'] + batch = data.select(batch_keys=select_keys).batch + use_dynamic_bsz = data.meta_info['use_dynamic_bsz'] + + if use_dynamic_bsz: + # split using dynamic bsz + max_token_len = data.meta_info['max_token_len'] * self.ulysses_sequence_parallel_size + micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len) + else: + micro_batches = batch.split(micro_batch_size) + + values_lst = [] + for micro_batch in micro_batches: + with torch.no_grad(): + values = self._forward_micro_batch(micro_batch) + values_lst.append(values) + values = torch.concat(values_lst, dim=0) + responses = data.batch['responses'] + attention_mask = data.batch['attention_mask'] + response_length = responses.size(1) + values = values * attention_mask[:, -response_length - 1:-1] + + if use_dynamic_bsz: + indices = list(itertools.chain.from_iterable(indices)) + assert len(indices) == values.size(0), f"{len(indices)} vs. {values.size()}" + revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) + values = values[revert_indices] + + return values + + def update_critic(self, data: DataProto): + # make sure we are in training mode + self.critic_module.train() + metrics = {} + + select_keys = ['input_ids', 'responses', 'attention_mask', 'position_ids', 'values', 'returns'] + batch = data.select(batch_keys=select_keys).batch + # Split to make minibatch iterator for updating the actor + # See PPO paper for details. 
https://arxiv.org/abs/1707.06347 + dataloader = batch.split(self.config.ppo_mini_batch_size) + + for batch_idx, data in enumerate(dataloader): + # split batch into micro_batches + mini_batch = data + if self.config.use_dynamic_bsz: + max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size + micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len) + else: + micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu) + self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu + + self.critic_optimizer.zero_grad() + + for data in micro_batches: + data = data.cuda() # critic device is cpu when using offload + input_ids = data['input_ids'] + responses = data['responses'] + attention_mask = data['attention_mask'] + position_ids = data['position_ids'] + values = data['values'] + returns = data['returns'] + response_length = responses.size(1) + + eos_mask = attention_mask[:, -response_length - 1:-1] + + vpreds = self._forward_micro_batch(data) + + # assert not torch.any(torch.isnan(vpreds)).item() + + vf_loss, vf_clipfrac = core_algos.compute_value_loss(vpreds=vpreds, + values=values, + returns=returns, + eos_mask=eos_mask, + cliprange_value=self.config.cliprange_value) + if self.config.use_dynamic_bsz: + # relative to the dynamic bsz + loss = vf_loss * (len(data) / self.config.ppo_mini_batch_size) + else: + loss = vf_loss / self.gradient_accumulation + + loss.backward() + + data = { + 'critic/vf_loss': vf_loss.detach().item(), + 'critic/vf_clipfrac': vf_clipfrac.detach().item(), + 'critic/vpred_mean': masked_mean(vpreds, eos_mask).detach().item(), + } + + append_to_dict(metrics, data) + + grad_norm = self._optimizer_step() + data = {'critic/grad_norm': grad_norm.detach().item()} + append_to_dict(metrics, data) + self.critic_optimizer.zero_grad() + return metrics \ No newline at end of file diff --git a/recipe/prime/prime_fsdp_workers.py b/recipe/prime/prime_fsdp_workers.py index 0ebf5c84..d3a2257f 100644 --- a/recipe/prime/prime_fsdp_workers.py +++ b/recipe/prime/prime_fsdp_workers.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
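One detail of the loss scaling in the update loop above: with dynamic batch sizes the micro-batch loss is scaled by `len(data) / ppo_mini_batch_size`, otherwise by `1 / gradient_accumulation`; for equal-sized micro-batches the two coincide. A quick numeric check with illustrative values, not taken from any particular configuration here:

# Equal-sized micro-batches: len(data)/mini_batch == 1/gradient_accumulation, so both scalings match.
ppo_mini_batch_size = 64
ppo_micro_batch_size_per_gpu = 8
gradient_accumulation = ppo_mini_batch_size // ppo_micro_batch_size_per_gpu   # 8

len_data = ppo_micro_batch_size_per_gpu        # size of one (non-dynamic) micro-batch
assert len_data / ppo_mini_batch_size == 1 / gradient_accumulation            # 0.125 == 0.125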
- +import copy import logging import os import warnings @@ -37,7 +37,7 @@ from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager from codetiming import Timer -from verl.workers.fsdp_workers import create_device_mesh +from verl.workers.fsdp_workers import create_device_mesh, get_sharding_strategy logger = logging.getLogger(__file__) logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN')) @@ -74,17 +74,14 @@ def __init__(self, config): self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload # normalize config - self.config.ppo_mini_batch_size //= (torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size) - if self.config.ppo_micro_batch_size is not None: - self.config.ppo_micro_batch_size //= (torch.distributed.get_world_size() // + self.config.mini_batch_size //= (torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size) + if self.config.micro_batch_size is not None: + self.config.micro_batch_size //= (torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size) - self.config.forward_micro_batch_size //= (torch.distributed.get_world_size() // - self.ulysses_sequence_parallel_size) - self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size - self.config.forward_micro_batch_size_per_gpu = self.config.forward_micro_batch_size - assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0 + self.config.micro_batch_size_per_gpu = self.config.micro_batch_size + assert self.config.mini_batch_size % self.config.micro_batch_size_per_gpu == 0 - def _build_critic_model_optimizer(self, config): + def _build_reward_ref_model_optimizer(self, config): # the following line is necessary from verl.utils.model import LambdaLayer, print_model_size, squeeze from verl.utils.torch_dtypes import PrecisionType @@ -92,8 +89,6 @@ def _build_critic_model_optimizer(self, config): from torch import optim local_path = copy_local_path_from_hdfs(config.model.path) - # note that the tokenizer between actor and critic may be different. So override tokenizer info with actor info - # using random initialized model from any architecture. May not be the same as Actor. 
tokenizer_path = copy_local_path_from_hdfs(config.model.tokenizer_path) self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get('trust_remote_code', False)) @@ -107,47 +102,47 @@ def _build_critic_model_optimizer(self, config): } override_config_kwargs.update(override_config) if self.rank == 0: - print(f'Critic overriding config {override_config_kwargs}') + print(f'Reward model overriding config {override_config_kwargs}') torch_dtype = self.config.model.fsdp_config.get('model_dtype', 'fp32') torch_dtype = PrecisionType.to_dtype(torch_dtype) - from transformers import AutoConfig, AutoModelForTokenClassification + from transformers import AutoConfig, AutoModelForCausalLM from torch import nn trust_remote_code = False - critic_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code) - critic_model_config.num_labels = 1 + reward_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code) + reward_model_config.num_labels = 1 use_remove_padding = config.model.get('use_remove_padding', False) if use_remove_padding: from verl.models.registry import check_model_support_rmpad - check_model_support_rmpad(critic_model_config.model_type) + check_model_support_rmpad(reward_model_config.model_type) if use_remove_padding and self.ulysses_sequence_parallel_size > 1: from verl.models.transformers.monkey_patch import apply_monkey_patch - apply_monkey_patch(critic_model_config, verbose=True) + apply_monkey_patch(reward_model_config, verbose=True) init_context = get_init_weight_context_manager() with init_context(), warnings.catch_warnings(): warnings.simplefilter("ignore") - setattr(critic_model_config, 'classifier_dropout', 0.) - setattr(critic_model_config, 'hidden_dropout', '0') - critic_module = AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path=local_path, + setattr(reward_model_config, 'classifier_dropout', 0.) 
+ setattr(reward_model_config, 'hidden_dropout', '0') + reward_module = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=local_path, torch_dtype=torch_dtype, - config=critic_model_config, + config=reward_model_config, attn_implementation='flash_attention_2', trust_remote_code=trust_remote_code) # some parameters may not in torch_dtype - critic_module.to(torch_dtype) + reward_module.to(torch_dtype) if config.model.get('enable_gradient_checkpointing', False): - critic_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False}) + reward_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False}) if self.rank == 0: - print_model_size(critic_module) + print_model_size(reward_module) - self.critic_model_config = critic_model_config + self.reward_model_config = reward_model_config fsdp_config = self.config.model.fsdp_config mixed_precision_config = fsdp_config.get('mixed_precision', None) @@ -162,7 +157,7 @@ def _build_critic_model_optimizer(self, config): mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype) - auto_wrap_policy = get_fsdp_wrap_policy(module=critic_module, config=self.config.model.fsdp_config.wrap_policy) + auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config.wrap_policy) log_gpu_memory_usage('Before critic FSDP', logger=None) @@ -170,7 +165,7 @@ def _build_critic_model_optimizer(self, config): sharding_strategy = get_sharding_strategy(fsdp_mesh) # Note: We force turn off CPUOffload for critic because it causes incorrect results when using grad accumulation - critic_module = FSDP(critic_module, + reward_module = FSDP(reward_module, param_init_fn=init_fn, use_orig_params=False, auto_wrap_policy=auto_wrap_policy, @@ -182,9 +177,11 @@ def _build_critic_model_optimizer(self, config): device_mesh=self.device_mesh, cpu_offload=None) - log_gpu_memory_usage('After critic FSDP', logger=None) + log_gpu_memory_usage('After reward FSDP', logger=None) + + ref_module = copy.deepcopy(reward_module) - critic_optimizer = optim.AdamW(critic_module.parameters(), + reward_optimizer = optim.AdamW(reward_module.parameters(), lr=config.optim.lr, betas=config.optim.get('betas', (0.9, 0.999)), weight_decay=config.optim.get('weight_decay', 1e-2)) @@ -196,90 +193,94 @@ def _build_critic_model_optimizer(self, config): print(f'Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}') from verl.utils.torch_functional import get_constant_schedule_with_warmup - critic_lr_scheduler = get_constant_schedule_with_warmup(optimizer=critic_optimizer, + reward_lr_scheduler = get_constant_schedule_with_warmup(optimizer=reward_optimizer, num_warmup_steps=num_warmup_steps) - return critic_module, critic_optimizer, critic_lr_scheduler + return reward_module, ref_module, reward_optimizer, reward_lr_scheduler @register(dispatch_mode=Dispatch.ONE_TO_ALL) def init_model(self): # This is used to import external_lib into the huggingface systems import_external_libs(self.config.model.get('external_lib', None)) - from verl.workers.critic import DataParallelPPOCritic - self.critic_module, self.critic_optimizer, self.critic_lr_scheduler = self._build_critic_model_optimizer( - self.config) + from .prime_dp_rm import DataParallelPRIMERewardModel + self.reward_module, self.ref_module, self.reward_optimizer, self.reward_lr_scheduler= self._build_reward_ref_model_optimizer(config=self.config) if self._is_offload_param: - 
offload_fsdp_model_to_cpu(self.critic_module) + offload_fsdp_model_to_cpu(self.reward_module) + offload_fsdp_model_to_cpu(self.ref_module) if self._is_offload_optimizer: - offload_fsdp_optimizer(optimizer=self.critic_optimizer) + offload_fsdp_optimizer(optimizer=self.reward_optimizer) - self.critic = DataParallelPPOCritic(config=self.config, - critic_module=self.critic_module, - critic_optimizer=self.critic_optimizer) + self.rm = DataParallelPRIMERewardModel(config=self.config, + reward_module=self.reward_module, + ref_module = self.ref_module, + reward_optimizer=self.reward_optimizer) - self.flops_counter = FlopsCounter(self.critic_model_config) - self.checkpoint_manager = FSDPCheckpointManager(model=self.critic_module, - optimizer=self.critic_optimizer, - lr_scheduler=self.critic_lr_scheduler, + self.flops_counter = FlopsCounter(self.reward_model_config) + self.checkpoint_manager = FSDPCheckpointManager(model=self.reward_module, + optimizer=self.reward_optimizer, + lr_scheduler=self.reward_lr_scheduler, tokenizer=self.tokenizer) torch.cuda.empty_cache() - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) - def compute_values(self, data: DataProto): + def compute_rm_score(self, data: DataProto): data = data.to('cuda') if self._is_offload_param: - load_fsdp_model_to_gpu(self.critic_module) - micro_batch_size = self.config.forward_micro_batch_size_per_gpu + load_fsdp_model_to_gpu(self.reward_module) + load_fsdp_model_to_gpu(self.ref_module) + micro_batch_size = self.config.micro_batch_size_per_gpu data.meta_info['micro_batch_size'] = micro_batch_size data.meta_info['max_token_len'] = self.config.forward_max_token_len_per_gpu data.meta_info['use_dynamic_bsz'] = self.config.use_dynamic_bsz # perform forward computation with self.ulysses_sharding_manager: data = self.ulysses_sharding_manager.preprocess_data(data=data) - values = self.critic.compute_values(data=data) - output = DataProto.from_dict(tensors={'values': values}) + rm_scores = self.rm.compute_rm_score(data=data) + output = DataProto.from_dict(tensors={'rm_scores': rm_scores}) output = self.ulysses_sharding_manager.postprocess_data(data=output) output = output.to('cpu') if self._is_offload_param: - offload_fsdp_model_to_cpu(self.critic_module) + offload_fsdp_model_to_cpu(self.reward_module) + offload_fsdp_model_to_cpu(self.ref_module) return output @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) - def update_critic(self, data: DataProto): + def update_rm(self, data: DataProto): data = data.to('cuda') if self._is_offload_param: - load_fsdp_model_to_gpu(self.critic_module) + load_fsdp_model_to_gpu(self.ref_module) + load_fsdp_model_to_gpu(self.reward_module) if self._is_offload_optimizer: - load_fsdp_optimizer(optimizer=self.critic_optimizer, device_id=torch.cuda.current_device()) + load_fsdp_optimizer(optimizer=self.reward_optimizer, device_id=torch.cuda.current_device()) # perform forward computation with self.ulysses_sharding_manager: data = self.ulysses_sharding_manager.preprocess_data(data=data) - with Timer(name='update_critic', logger=None) as timer: - metrics = self.critic.update_critic(data=data) + with Timer(name='update_rm', logger=None) as timer: + rm_scores, metrics = self.rm.update_rm(data=data) delta_time = timer.last global_num_tokens = data.meta_info['global_token_num'] estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) - metrics['mfu/critic'] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size + metrics['mfu/reward'] = estimated_flops * 
self.config.ppo_epochs / promised_flops / self.world_size - self.critic_lr_scheduler.step() - lr = self.critic_lr_scheduler.get_last_lr()[0] - metrics['critic/lr'] = lr + self.reward_lr_scheduler.step() + lr = self.reward_lr_scheduler.get_last_lr()[0] + metrics['rm/lr'] = lr - output = DataProto(batch=None, meta_info={'metrics': metrics}) + output = DataProto.from_dict(tensors={'rm_scores':rm_scores}, meta_info={'metrics': metrics}) output = self.ulysses_sharding_manager.postprocess_data(data=output) if self._is_offload_param: - offload_fsdp_model_to_cpu(self.critic_module) + offload_fsdp_model_to_cpu(self.reward_module) + offload_fsdp_model_to_cpu(self.ref_module) if self._is_offload_optimizer: - offload_fsdp_optimizer(optimizer=self.critic_optimizer) + offload_fsdp_optimizer(optimizer=self.reward_optimizer) torch.cuda.empty_cache() output = output.to('cpu') return output @@ -288,7 +289,7 @@ def update_critic(self, data: DataProto): def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, remove_previous_ckpt=False): import torch if self._is_offload_param: - load_fsdp_model_to_gpu(self.critic_module) + load_fsdp_model_to_gpu(self.reward_module) self.checkpoint_manager.save_checkpoint(local_path=local_path, hdfs_path=hdfs_path, @@ -297,16 +298,16 @@ def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, remove_prev torch.distributed.barrier() if self._is_offload_param: - offload_fsdp_model_to_cpu(self.critic_module) + offload_fsdp_model_to_cpu(self.reward_module) @register(dispatch_mode=Dispatch.ONE_TO_ALL) def load_checkpoint(self, path, del_local_after_load=True): import torch if self._is_offload_param: - load_fsdp_model_to_gpu(self.critic_module) + load_fsdp_model_to_gpu(self.reward_module) self.checkpoint_manager.load_checkpoint(path=path, del_local_after_load=del_local_after_load) torch.distributed.barrier() if self._is_offload_param: - offload_fsdp_model_to_cpu(self.critic_module) + offload_fsdp_model_to_cpu(self.reward_module) diff --git a/recipe/prime/prime_ray_trainer.py b/recipe/prime/prime_ray_trainer.py index 2ae171b4..801e974f 100644 --- a/recipe/prime/prime_ray_trainer.py +++ b/recipe/prime/prime_ray_trainer.py @@ -19,31 +19,21 @@ import os import statistics import uuid -from contextlib import contextmanager -from dataclasses import dataclass, field -from enum import Enum -from pprint import pprint -from typing import Type, Dict from copy import deepcopy +from pprint import pprint import numpy as np -from codetiming import Timer +import torch from omegaconf import OmegaConf, open_dict + from verl import DataProto -from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto -from verl.single_controller.base import Worker -from verl.single_controller.ray import RayResourcePool, RayWorkerGroup, RayClassWithInitArgs -from verl.single_controller.ray.base import create_colocated_worker_cls -from verl.trainer.ppo import core_algos -from . 
import prime_core_algos -from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance +from verl.single_controller.ray import RayWorkerGroup +from verl.trainer.ppo.ray_trainer import RayPPOTrainer +from verl.trainer.ppo.ray_trainer import Role, WorkerType, ResourcePoolManager, reduce_metrics, _compute_response_info, \ + _timer from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn -from verl.trainer.ppo.ray_trainer import Role, WorkerType, ResourcePoolManager, apply_kl_penalty, reduce_metrics, _compute_response_info, _timer -from verl.trainer.ppo.ray_trainer import RayPPOTrainer - -import torch -from verl.utils.torch_functional import masked_mean +from . import prime_core_algos def compute_advantage(data: DataProto, adv_estimator, config): @@ -430,15 +420,18 @@ def fit(self): if self.use_rm: update_style = self.config.reward_model.model.update if update_style == 'none': # only run forward - reward_tensor = self.rm_wg.compute_rm_score(batch) + reward_output = self.rm_wg.compute_rm_score(batch) elif update_style == 'after': # update and directly return the reward - reward_tensor = self.rm_wg.update_rm(batch) + reward_output = self.rm_wg.update_rm(batch) elif update_style == 'before': # update reward model, and then run forward - reward_tensor = self.rm_wg.update_rm(batch) - reward_tensor = self.rm_wg.compute_rm_score(batch) + reward_output = self.rm_wg.update_rm(batch) + reward_output = self.rm_wg.compute_rm_score(batch) else: raise NotImplementedError - batch = batch.union(reward_tensor) + batch = batch.union(reward_output) + if 'metrics' in reward_output.meta_info['metrics']: + reward_output_metrics = reduce_metrics(reward_output.meta_info['metrics']) + metrics.update(reward_output_metrics) # compute advantages, executed on the driver process batch = compute_advantage(batch, diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index bf919089..1f4c9d3f 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -153,6 +153,7 @@ algorithm: kl_coef: 0.001 trainer: + balance_batch: True total_epochs: 30 total_training_steps: null project_name: verl_examples diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 80412d3e..0dd87a3f 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -894,7 +894,8 @@ def fit(self): # balance the number of valid tokens on each dp rank. # Note that this breaks the order of data inside the batch. 
# Please take care when you implement group based adv computation such as GRPO and rloo - self._balance_batch(batch, metrics=metrics) + if self.config.balance_batch: + self._balance_batch(batch, metrics=metrics) # compute global_valid tokens batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist() From 15510e8daba084b41026b851d4406dccf54ff561 Mon Sep 17 00:00:00 2001 From: Zefan Wang Date: Tue, 25 Feb 2025 18:30:58 +0800 Subject: [PATCH 4/6] dp reward model --- recipe/prime/config/prime_trainer.yaml | 1 + recipe/prime/prime_core_algos.py | 31 +++ recipe/prime/prime_dp_rm.py | 283 +++++++++++++++++-------- recipe/prime/prime_fsdp_workers.py | 45 ++-- recipe/prime/prime_ray_trainer.py | 10 +- verl/trainer/config/ppo_trainer.yaml | 3 +- 6 files changed, 265 insertions(+), 108 deletions(-) diff --git a/recipe/prime/config/prime_trainer.yaml b/recipe/prime/config/prime_trainer.yaml index 0394e985..157ff10d 100644 --- a/recipe/prime/config/prime_trainer.yaml +++ b/recipe/prime/config/prime_trainer.yaml @@ -29,6 +29,7 @@ reward_model: strategy: fsdp model: use_remove_padding: True + tokenizer_path: ${actor_rollout_ref.model.path} enable_gradient_checkpointing: ${actor_rollout_ref.model.enable_gradient_checkpointing} ref_type: freeze fsdp_config: diff --git a/recipe/prime/prime_core_algos.py b/recipe/prime/prime_core_algos.py index 19393e35..c27ae259 100644 --- a/recipe/prime/prime_core_algos.py +++ b/recipe/prime/prime_core_algos.py @@ -74,3 +74,34 @@ def masked_rloo(reward_tensor_original, mask_tensor): advantages = verl_F.masked_whiten(advantages, eos_mask) return advantages, returns + + +def compute_ce_dpo_loss_rm(token_level_scores, acc, eos_mask, beta): + cur_scores = ((token_level_scores * eos_mask).sum(dim=1) * beta).sigmoid() + cur_dpo_loss = torch.nn.functional.binary_cross_entropy(cur_scores, acc) + return cur_dpo_loss + + +def compute_dpo_accuracy(token_level_scores, acc, eos_mask, n_samples): + dpo_acc = [] + for start_id in range(0, token_level_scores.shape[0], n_samples): + cur_scores = (token_level_scores[start_id:start_id + n_samples] * + eos_mask[start_id:start_id + n_samples]).sum(dim=1) + + def get_upper_triangle(tensor_x): + diff_matrix = tensor_x.unsqueeze(1) - tensor_x.unsqueeze(0) + upper_tri_indices = torch.triu(torch.ones_like(diff_matrix).bool(), diagonal=1) + return diff_matrix[upper_tri_indices] + + cur_acc_diff = get_upper_triangle(acc[start_id:start_id + n_samples]) # in range [-1,1] + cur_score_diff = get_upper_triangle(cur_scores) # in R + cur_score_prediction = (cur_score_diff > 0).float() # in [0,1] + if cur_acc_diff.abs().sum() == 0: + cur_acc = torch.zeros_like(cur_score_prediction[0]) + 0.5 + else: + cur_acc = (((cur_score_diff > 0) == (cur_acc_diff > 0)).float() * + cur_acc_diff.abs()).sum() / cur_acc_diff.abs().sum() + + dpo_acc.append(cur_acc.unsqueeze(0)) + + return torch.cat(dpo_acc, dim=0).mean() diff --git a/recipe/prime/prime_dp_rm.py b/recipe/prime/prime_dp_rm.py index b63c81e5..1e0c499c 100644 --- a/recipe/prime/prime_dp_rm.py +++ b/recipe/prime/prime_dp_rm.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- """ Implement a multiprocess PPOCritic """ @@ -24,6 +23,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from .prime_core_algos import compute_ce_dpo_loss_rm from verl import DataProto from verl.trainer.ppo import core_algos from verl.workers.critic import BasePPOCritic @@ -31,6 +31,7 @@ from verl.utils.torch_functional import masked_mean from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx +import verl.utils.torch_functional as verl_F from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis @@ -40,7 +41,7 @@ class DataParallelPRIMERewardModel: def __init__(self, config, reward_module: nn.Module, ref_module: nn.Module, reward_optimizer: optim.Optimizer): - self.config=config + self.config = config self.reward_module = reward_module self.ref_module = ref_module self.reward_optimizer = reward_optimizer @@ -49,72 +50,165 @@ def __init__(self, config, reward_module: nn.Module, ref_module: nn.Module, rewa self.ulysses_sequence_parallel_size = self.config.get('ulysses_sequence_parallel_size', 1) - def _forward_micro_batch(self, micro_batch): - response_length = micro_batch['responses'].size(-1) - with torch.autocast(device_type='cuda', dtype=torch.bfloat16): - input_ids = micro_batch['input_ids'] - batch, seqlen = input_ids.shape - attention_mask = micro_batch['attention_mask'] - position_ids = micro_batch['position_ids'] - - if self.use_remove_padding: - input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), - attention_mask) # input_ids_rmpad (total_nnz, ...) - input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) - - # unpad the position_ids to align the rotary - position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), - indices).transpose(0, 1) - - # pad and slice the inputs if sp > 1 - if self.ulysses_sequence_parallel_size > 1: - input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(input_ids_rmpad, \ - position_ids_rmpad, \ - sp_size=self.ulysses_sequence_parallel_size) - - # only pass input_ids and position_ids to enable flash_attn_varlen - output = self.critic_module(input_ids=input_ids_rmpad, - attention_mask=None, - position_ids=position_ids_rmpad, - use_cache=False) # prevent model thinks we are generating - values_rmpad = output.logits - values_rmpad = values_rmpad.squeeze(0) # (total_nnz) - - # gather output if sp > 1 - if self.ulysses_sequence_parallel_size > 1: - values_rmpad = gather_outpus_and_unpad(values_rmpad, - gather_dim=0, - unpad_dim=0, - padding_size=pad_size) - - # pad it back - values = pad_input(values_rmpad, indices=indices, batch=batch, seqlen=seqlen).squeeze(-1) - values = values[:, -response_length - 1:-1] + def _forward_micro_batch(self, micro_batch, prompt_length): + from flash_attn.bert_padding import pad_input, unpad_input, index_first_axis, rearrange + from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad + + input_ids = micro_batch['input_ids'] + batch_size, seqlen = input_ids.shape + attention_mask = micro_batch['attention_mask'] + position_ids = micro_batch['position_ids'] + + if self.use_remove_padding: + input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), + attention_mask) # input_ids_rmpad (total_nnz, ...) 
+            input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)
+
+            # unpad the position_ids to align the rotary
+            position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
+                                                  indices).transpose(0, 1)
+
+            # roll input_ids by one position to build the labels for the per-token log-probs
+            input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1)  # (1, total_nnz)
+
+            # pad and slice the inputs if sp > 1
+            if self.ulysses_sequence_parallel_size > 1:
+                input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
+                    input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size)
+                input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(input_ids_rmpad_rolled, None,
+                                                                            self.ulysses_sequence_parallel_size)
+            input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0)
+            rm_output_logits = self.reward_module(input_ids=input_ids_rmpad,
+                                                  attention_mask=None,
+                                                  position_ids=position_ids_rmpad,
+                                                  use_cache=False).logits.squeeze(
+                                                      0)  # squeeze the leading dim of size 1 from the (1, total_nnz, vocab) logits
+            rm_log_labels = verl_F.logprobs_from_logits(logits=rm_output_logits, labels=input_ids_rmpad_rolled)
+            if self.ulysses_sequence_parallel_size > 1:
+                rm_log_labels = gather_outpus_and_unpad(rm_log_labels, gather_dim=0, unpad_dim=0, padding_size=pad_size)
+            rm_log_labels = pad_input(hidden_states=rm_log_labels.unsqueeze(-1),
+                                      indices=indices,
+                                      batch=batch_size,
+                                      seqlen=seqlen).squeeze(-1)
+
+        else:
+            rm_output_logits = self.reward_module(input_ids=micro_batch['input_ids'],
+                                                  attention_mask=micro_batch['attention_mask'],
+                                                  position_ids=micro_batch['position_ids']).logits
+            rm_log_prob = torch.nn.functional.log_softmax(rm_output_logits[:, :-1, :],
+                                                          dim=-1)  # (batch_size, seq_length, vocab_size)
+            rm_log_labels = rm_log_prob.gather(dim=-1, index=micro_batch['input_ids'][:, 1:].unsqueeze(-1)).squeeze(
+                -1)  # (batch, seq_length)
+
+        if self.ref_module is not None:
+            # no need to remove padding again here; it is enough to re-pad the outputs correctly
+            with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.bfloat16):
+                if self.ulysses_sequence_parallel_size > 1 and self.use_remove_padding:
+                    ref_output_logits = self.ref_module(input_ids=input_ids_rmpad,
+                                                        attention_mask=None,
+                                                        position_ids=position_ids_rmpad,
+                                                        use_cache=False).logits.squeeze(0)
+                    ref_log_labels = verl_F.logprobs_from_logits(logits=ref_output_logits,
+                                                                 labels=input_ids_rmpad_rolled)
+                    ref_log_labels = gather_outpus_and_unpad(ref_log_labels,
+                                                             gather_dim=0,
+                                                             unpad_dim=0,
+                                                             padding_size=pad_size)
+                    ref_log_labels = pad_input(hidden_states=ref_log_labels.unsqueeze(-1),
+                                               indices=indices,
+                                               batch=batch_size,
+                                               seqlen=seqlen).squeeze(-1)
+                else:
+                    ref_output_logits = self.ref_module(input_ids=micro_batch['input_ids'],
+                                                        attention_mask=micro_batch['attention_mask'],
+                                                        position_ids=micro_batch['position_ids']).logits
+                    ref_log_prob = torch.nn.functional.log_softmax(ref_output_logits[:, :-1, :],
+                                                                   dim=-1)  # (batch_size, seq_length, vocab_size)
+                    ref_log_labels = ref_log_prob.gather(dim=-1,
+                                                         index=micro_batch['input_ids'][:, 1:].unsqueeze(-1)).squeeze(
+                                                             -1)  # (batch, seq_length)
+        else:
+            ref_log_labels = micro_batch['old_log_probs']
+
+        num_actions = micro_batch['input_ids'].shape[-1] - prompt_length
+        max_positions = micro_batch['attention_mask'][:, prompt_length:].sum(-1)
+
+        ref_log_labels = ref_log_labels.to(rm_log_labels.dtype)
+        q = rm_log_labels[:, -num_actions:] - ref_log_labels[:, -num_actions:]  # this is actually the diff of q
+
+        # reward computation does not need gradients; only q does
+        with torch.no_grad():
+
+            # the generalized estimation of r must happen before the reward filling; r is the process reward for the policy model, i.e. the advantage of the reward model.
+            lam = self.config.get('lambda', 0.)
+            beta = self.config.model.get('beta_train', 0.05)
+            if lam == 0.:
+                r = q * beta
             else:
-                output = self.critic_module(input_ids=input_ids,
-                                            attention_mask=attention_mask,
-                                            position_ids=position_ids,
-                                            use_cache=False)  # prevent model thinks we are generating
-                values = output.logits
-                values = values[:, -response_length - 1:-1].squeeze(-1)
-            return values
+                # the reward coefficient takes no effect here
+                acc = micro_batch['acc']
+                q_ = q * beta
+                r = torch.zeros_like(q)
+                # TODO: follow how the implicit value model handles this: directly set q at position max_positions[0] - 1 to r - Q_{t-1} and zero out every r after it
+                lastgaelam = 0
+                # change the last token and mask out all paddings to make this process easier
+                for i in range(q.shape[0]):
+                    if self.config.prime_use_gt:
+                        q_[i, max_positions[i] - 1] = acc[i] - q_[i, :max_positions[i] - 1].sum()
+                    q_[i, max_positions[i]:] = 0
+
+                for t in reversed(range(num_actions)):
+                    delta = q_[:, t]
+                    lastgaelam = delta + lam * lastgaelam
+                    r[:, t] = lastgaelam
+
+            step_ends = []
+
+            if self.config.prime_granularity == 'token':
+                for i in range(micro_batch['input_ids'].shape[0]):
+                    step_ends.append(list(range(max_positions[i])))
+            elif self.config.prime_granularity == 'whole':
+                for i in range(micro_batch['input_ids'].shape[0]):
+                    step_ends.append([max_positions[i] - 1])
+            else:
+                raise NotImplementedError
+
+            token_level_score = torch.zeros_like(q)
+
+            for i, step_end in enumerate(step_ends):
+                for j in range(len(step_end)):
+                    step_range = [
+                        min(step_end[j - 1] + 1, num_actions - 1) if j > 0 else 0,
+                        min(num_actions - 1, step_end[j])
+                    ]
+                    token_level_score[i, step_range[1]] = r[i, step_range[0]:step_range[1] + 1].sum()
+
+        return token_level_score, q
 
     def _optimizer_step(self):
         assert self.config.grad_clip is not None
-        if isinstance(self.critic_module, FSDP):
-            grad_norm = self.critic_module.clip_grad_norm_(self.config.grad_clip)
+        if isinstance(self.reward_module, FSDP):
+            grad_norm = self.reward_module.clip_grad_norm_(self.config.grad_clip)
         else:
-            grad_norm = torch.nn.utils.clip_grad_norm_(self.critic_module.parameters(), max_norm=self.config.grad_clip)
-        self.critic_optimizer.step()
+            grad_norm = torch.nn.utils.clip_grad_norm_(self.reward_module.parameters(), max_norm=self.config.grad_clip)
+        self.reward_optimizer.step()
         return grad_norm
 
-    def compute_values(self, data: DataProto) -> torch.Tensor:
-        self.critic_module.eval()
+    def prime_norm(self, token_level_scores):
+        if self.config.prime_norm == 'batch_norm':
+            reverse_cumsum = torch.cumsum(token_level_scores.flip(dims=[1]), dim=-1).flip(dims=[1])
+            token_level_scores = token_level_scores / (reverse_cumsum.abs().max() + 1e-6)
+        return token_level_scores
+
+    def compute_rm_score(self, data: DataProto):
+        self.reward_module.eval()
+        self.ref_module.eval()
         micro_batch_size = data.meta_info['micro_batch_size']
         select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids']
         batch = data.select(batch_keys=select_keys).batch
         use_dynamic_bsz = data.meta_info['use_dynamic_bsz']
+        prompt_length = data.batch['prompt_ids'].shape[-1]
 
         if use_dynamic_bsz:
             # split using dynamic bsz
@@ -123,36 +217,38 @@ def compute_values(self, data: DataProto) -> torch.Tensor:
         else:
             micro_batches = batch.split(micro_batch_size)
 
-        values_lst = []
+        rm_scores_lst = []
         for micro_batch in micro_batches:
             with torch.no_grad():
-                values = 
self._forward_micro_batch(micro_batch) - values_lst.append(values) - values = torch.concat(values_lst, dim=0) - responses = data.batch['responses'] - attention_mask = data.batch['attention_mask'] - response_length = responses.size(1) - values = values * attention_mask[:, -response_length - 1:-1] + rm_score, q = self._forward_micro_batch(micro_batch, prompt_length) + rm_scores_lst.append(rm_score) + rm_scores = torch.concat(rm_scores_lst, dim=0) + + rm_scores = self.prime_norm(rm_scores) if use_dynamic_bsz: indices = list(itertools.chain.from_iterable(indices)) - assert len(indices) == values.size(0), f"{len(indices)} vs. {values.size()}" + assert len(indices) == rm_scores.size(0), f"{len(indices)} vs. {rm_scores.size()}" revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) - values = values[revert_indices] + rm_scores = rm_scores[revert_indices] - return values + return rm_scores, {} - def update_critic(self, data: DataProto): + def update_rm(self, data: DataProto): # make sure we are in training mode - self.critic_module.train() + self.reward_module.train() metrics = {} + beta = self.config.model.get('beta_train', 0.05) + select_keys = ['input_ids', 'responses', 'attention_mask', 'position_ids', 'values', 'returns'] batch = data.select(batch_keys=select_keys).batch # Split to make minibatch iterator for updating the actor # See PPO paper for details. https://arxiv.org/abs/1707.06347 dataloader = batch.split(self.config.ppo_mini_batch_size) + rm_scores_lst = [] + for batch_idx, data in enumerate(dataloader): # split batch into micro_batches mini_batch = data @@ -163,10 +259,10 @@ def update_critic(self, data: DataProto): micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu) self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu - self.critic_optimizer.zero_grad() + self.reward_optimizer.zero_grad() for data in micro_batches: - data = data.cuda() # critic device is cpu when using offload + data = data.cuda() input_ids = data['input_ids'] responses = data['responses'] attention_mask = data['attention_mask'] @@ -174,36 +270,43 @@ def update_critic(self, data: DataProto): values = data['values'] returns = data['returns'] response_length = responses.size(1) + acc = data['acc'] + + prompt_ids = data.batch['prompts'] + prompt_length = prompt_ids.shape[-1] - eos_mask = attention_mask[:, -response_length - 1:-1] + eos_mask = attention_mask[:, prompt_length:] - vpreds = self._forward_micro_batch(data) + rm_score, q = self._forward_micro_batch(data, response_length) + + rm_scores_lst.append(rm_score) + + if self.config.loss_type == 'ce': + dpo_loss = compute_ce_dpo_loss_rm(q, acc, eos_mask=eos_mask, beta=beta) + else: + raise NotImplementedError - # assert not torch.any(torch.isnan(vpreds)).item() + data = {'reward_model/dpo_loss': dpo_loss.detach().item()} - vf_loss, vf_clipfrac = core_algos.compute_value_loss(vpreds=vpreds, - values=values, - returns=returns, - eos_mask=eos_mask, - cliprange_value=self.config.cliprange_value) if self.config.use_dynamic_bsz: # relative to the dynamic bsz - loss = vf_loss * (len(data) / self.config.ppo_mini_batch_size) + loss = dpo_loss * (len(data) / self.config.ppo_mini_batch_size) else: - loss = vf_loss / self.gradient_accumulation + loss = dpo_loss / self.gradient_accumulation loss.backward() - data = { - 'critic/vf_loss': vf_loss.detach().item(), - 'critic/vf_clipfrac': vf_clipfrac.detach().item(), - 'critic/vpred_mean': masked_mean(vpreds, eos_mask).detach().item(), - } 
- append_to_dict(metrics, data) + loss.backward() + grad_norm = self._optimizer_step() - data = {'critic/grad_norm': grad_norm.detach().item()} + data = {'reward_model/grad_norm': grad_norm.detach().item()} append_to_dict(metrics, data) - self.critic_optimizer.zero_grad() - return metrics \ No newline at end of file + self.reward_optimizer.zero_grad() + + rm_scores = torch.cat(rm_scores_lst, dim=0) + + rm_scores = self.prime_norm(rm_scores) + + return rm_scores, metrics diff --git a/recipe/prime/prime_fsdp_workers.py b/recipe/prime/prime_fsdp_workers.py index d3a2257f..101670cd 100644 --- a/recipe/prime/prime_fsdp_workers.py +++ b/recipe/prime/prime_fsdp_workers.py @@ -38,6 +38,7 @@ from codetiming import Timer from verl.workers.fsdp_workers import create_device_mesh, get_sharding_strategy +from .prime_core_algos import compute_dpo_accuracy logger = logging.getLogger(__file__) logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN')) @@ -76,8 +77,7 @@ def __init__(self, config): # normalize config self.config.mini_batch_size //= (torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size) if self.config.micro_batch_size is not None: - self.config.micro_batch_size //= (torch.distributed.get_world_size() // - self.ulysses_sequence_parallel_size) + self.config.micro_batch_size //= (torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size) self.config.micro_batch_size_per_gpu = self.config.micro_batch_size assert self.config.mini_batch_size % self.config.micro_batch_size_per_gpu == 0 @@ -129,10 +129,10 @@ def _build_reward_ref_model_optimizer(self, config): setattr(reward_model_config, 'classifier_dropout', 0.) setattr(reward_model_config, 'hidden_dropout', '0') reward_module = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=local_path, - torch_dtype=torch_dtype, - config=reward_model_config, - attn_implementation='flash_attention_2', - trust_remote_code=trust_remote_code) + torch_dtype=torch_dtype, + config=reward_model_config, + attn_implementation='flash_attention_2', + trust_remote_code=trust_remote_code) # some parameters may not in torch_dtype reward_module.to(torch_dtype) @@ -204,7 +204,8 @@ def init_model(self): import_external_libs(self.config.model.get('external_lib', None)) from .prime_dp_rm import DataParallelPRIMERewardModel - self.reward_module, self.ref_module, self.reward_optimizer, self.reward_lr_scheduler= self._build_reward_ref_model_optimizer(config=self.config) + self.reward_module, self.ref_module, self.reward_optimizer, self.reward_lr_scheduler = self._build_reward_ref_model_optimizer( + config=self.config) if self._is_offload_param: offload_fsdp_model_to_cpu(self.reward_module) @@ -213,9 +214,9 @@ def init_model(self): offload_fsdp_optimizer(optimizer=self.reward_optimizer) self.rm = DataParallelPRIMERewardModel(config=self.config, - reward_module=self.reward_module, - ref_module = self.ref_module, - reward_optimizer=self.reward_optimizer) + reward_module=self.reward_module, + ref_module=self.ref_module, + reward_optimizer=self.reward_optimizer) self.flops_counter = FlopsCounter(self.reward_model_config) self.checkpoint_manager = FSDPCheckpointManager(model=self.reward_module, @@ -224,6 +225,7 @@ def init_model(self): tokenizer=self.tokenizer) torch.cuda.empty_cache() + @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def compute_rm_score(self, data: DataProto): data = data.to('cuda') @@ -238,8 +240,17 @@ def compute_rm_score(self, data: DataProto): # perform forward computation with self.ulysses_sharding_manager: 
data = self.ulysses_sharding_manager.preprocess_data(data=data) - rm_scores = self.rm.compute_rm_score(data=data) - output = DataProto.from_dict(tensors={'rm_scores': rm_scores}) + rm_scores, metrics = self.rm.compute_rm_score(data=data) + + prompt_length = data.batch['prompts'].shape[-1] + eos_mask = data.batch['attention_mask'][:, prompt_length:] + acc = data.batch['acc'] + + dpo_acc = compute_dpo_accuracy(rm_scores, acc, eos_mask=eos_mask, n_samples=data.meta_info['n']) + + metrics['reward_model/dpo_acc'] = dpo_acc.detach().item() + + output = DataProto.from_dict(tensors={'rm_scores': rm_scores}, meta_info={'metrics': metrics}) output = self.ulysses_sharding_manager.postprocess_data(data=output) output = output.to('cpu') @@ -273,7 +284,15 @@ def update_rm(self, data: DataProto): lr = self.reward_lr_scheduler.get_last_lr()[0] metrics['rm/lr'] = lr - output = DataProto.from_dict(tensors={'rm_scores':rm_scores}, meta_info={'metrics': metrics}) + prompt_length = data.batch['prompts'].shape[-1] + eos_mask = data.batch['attention_mask'][:, prompt_length:] + acc = data.batch['acc'] + + dpo_acc_before = compute_dpo_accuracy(rm_scores, acc, eos_mask=eos_mask, n_samples=data.meta_info['n']) + + metrics['reward_model/dpo_acc_before'] = dpo_acc_before.detach().item() + + output = DataProto.from_dict(tensors={'rm_scores': rm_scores}, meta_info={'metrics': metrics}) output = self.ulysses_sharding_manager.postprocess_data(data=output) if self._is_offload_param: diff --git a/recipe/prime/prime_ray_trainer.py b/recipe/prime/prime_ray_trainer.py index 801e974f..7ffa9fe8 100644 --- a/recipe/prime/prime_ray_trainer.py +++ b/recipe/prime/prime_ray_trainer.py @@ -169,11 +169,8 @@ def __init__(self, self.use_critic = False - self._validate_config() - self._create_dataloader() - def _validate_config(self): - super()._validate() + super()._validate_config() # TODO: Additional config checks can be added here config = self.config @@ -403,6 +400,7 @@ def fit(self): # filter the batch. 1/oversample_factor samples will be kept. If there is a filter, prompts passing it will be prioritized. 
batch = self.filter_and_downsample(scores, batch) + batch.meta_info['n'] = self.config.actor_rollout_ref.rollout.n # recompute old_log_probs with _timer('old_log_prob', timing_raw): @@ -425,6 +423,10 @@ def fit(self): reward_output = self.rm_wg.update_rm(batch) elif update_style == 'before': # update reward model, and then run forward reward_output = self.rm_wg.update_rm(batch) + if 'metrics' in reward_output.meta_info['metrics']: + reward_output_metrics = reduce_metrics(reward_output.meta_info['metrics']) + metrics.update(reward_output_metrics) + reward_output = self.rm_wg.compute_rm_score(batch) else: raise NotImplementedError diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index 1f4c9d3f..7b4cd678 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -132,7 +132,8 @@ reward_model: external_lib: ${actor_rollout_ref.model.external_lib} use_remove_padding: False fsdp_config: - min_num_params: 0 + wrap_policy: + min_num_params: 0 param_offload: False fsdp_size: -1 micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu From b5178d2abf7f0b480bf321f1b8c84693be209736 Mon Sep 17 00:00:00 2001 From: Zefan Wang Date: Wed, 26 Feb 2025 00:32:47 +0800 Subject: [PATCH 5/6] passed basic running test --- recipe/prime/config/prime_trainer.yaml | 2 +- recipe/prime/prime_core_algos.py | 8 +++--- recipe/prime/prime_dp_rm.py | 32 +++++++++-------------- recipe/prime/prime_fsdp_workers.py | 35 +++++++++++++++----------- recipe/prime/prime_ray_trainer.py | 7 +++--- 5 files changed, 41 insertions(+), 43 deletions(-) diff --git a/recipe/prime/config/prime_trainer.yaml b/recipe/prime/config/prime_trainer.yaml index 157ff10d..5a053d92 100644 --- a/recipe/prime/config/prime_trainer.yaml +++ b/recipe/prime/config/prime_trainer.yaml @@ -45,7 +45,7 @@ reward_model: warmup_style: constant total_training_steps: -1 # must be overridden by program weight_decay: 0. 
- grad_clip: 1.0 + grad_clip: 10.0 beta_train: 0.05 loss_type: ce # currently only supports ce loss prime_granularity: token diff --git a/recipe/prime/prime_core_algos.py b/recipe/prime/prime_core_algos.py index c27ae259..e2854704 100644 --- a/recipe/prime/prime_core_algos.py +++ b/recipe/prime/prime_core_algos.py @@ -43,13 +43,13 @@ def masked_rloo(reward_tensor_original, mask_tensor): with torch.no_grad(): - if 'rm_scores' in data.batch and config.algorithm.dpo_coef != 0.: + if 'rm_scores' in data.batch and config.algorithm.reward_dpo_coef != 0.: reward_tensor = data.batch['rm_scores'] reward_mask = eos_mask.bool() - reward_tensors.append(masked_rloo(reward_tensor, reward_mask) * config.algorithm.dpo_coef) + reward_tensors.append(masked_rloo(reward_tensor, reward_mask) * config.algorithm.reward_dpo_coef) - if 'acc' in data.batch and config.algorithm.gt_coef != 0.: + if 'acc' in data.batch and config.algorithm.reward_gt_coef != 0.: reward_tensor = torch.zeros_like(eos_mask, dtype=torch.float32) reward_mask = torch.zeros_like(eos_mask, dtype=torch.bool) @@ -64,7 +64,7 @@ def masked_rloo(reward_tensor_original, mask_tensor): torch.arange(0, valid_response_length.shape[0], dtype=torch.long, device=valid_response_length.device), valid_response_length - 1] = data.batch['acc'] - reward_tensors.append(masked_rloo(reward_tensor, reward_mask) * config.algorithm.gt_coef) + reward_tensors.append(masked_rloo(reward_tensor, reward_mask) * config.algorithm.reward_gt_coef) final_reward_tensor = sum(reward_tensors) diff --git a/recipe/prime/prime_dp_rm.py b/recipe/prime/prime_dp_rm.py index 1e0c499c..7b2913d0 100644 --- a/recipe/prime/prime_dp_rm.py +++ b/recipe/prime/prime_dp_rm.py @@ -186,12 +186,12 @@ def _forward_micro_batch(self, micro_batch, prompt_length): return token_level_score, q def _optimizer_step(self): - assert self.config.grad_clip is not None + assert self.config.model.optim.grad_clip is not None if isinstance(self.reward_module, FSDP): - grad_norm = self.reward_module.clip_grad_norm_(self.config.grad_clip) + grad_norm = self.reward_module.clip_grad_norm_(self.config.model.optim.grad_clip) else: - grad_norm = torch.nn.utils.clip_grad_norm_(self.reward_module.parameters(), max_norm=self.config.grad_clip) + grad_norm = torch.nn.utils.clip_grad_norm_(self.reward_module.parameters(), max_norm=self.config.model.optim.grad_clip) self.reward_optimizer.step() return grad_norm @@ -205,10 +205,10 @@ def compute_rm_score(self, data: DataProto): self.reward_module.eval() self.ref_module.eval() micro_batch_size = data.meta_info['micro_batch_size'] - select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids'] + select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids', 'acc'] batch = data.select(batch_keys=select_keys).batch use_dynamic_bsz = data.meta_info['use_dynamic_bsz'] - prompt_length = data.batch['prompt_ids'].shape[-1] + prompt_length = data.batch['input_ids'].shape[-1] - data.batch['responses'].shape[-1] if use_dynamic_bsz: # split using dynamic bsz @@ -241,11 +241,11 @@ def update_rm(self, data: DataProto): beta = self.config.model.get('beta_train', 0.05) - select_keys = ['input_ids', 'responses', 'attention_mask', 'position_ids', 'values', 'returns'] + select_keys = ['input_ids', 'responses', 'attention_mask', 'position_ids', 'acc', 'prompts'] batch = data.select(batch_keys=select_keys).batch # Split to make minibatch iterator for updating the actor # See PPO paper for details. 
https://arxiv.org/abs/1707.06347 - dataloader = batch.split(self.config.ppo_mini_batch_size) + dataloader = batch.split(self.config.mini_batch_size) rm_scores_lst = [] @@ -256,32 +256,26 @@ def update_rm(self, data: DataProto): max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len) else: - micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu) - self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu + micro_batches = mini_batch.split(self.config.micro_batch_size_per_gpu) + self.gradient_accumulation = self.config.mini_batch_size // self.config.micro_batch_size_per_gpu self.reward_optimizer.zero_grad() for data in micro_batches: data = data.cuda() - input_ids = data['input_ids'] - responses = data['responses'] attention_mask = data['attention_mask'] - position_ids = data['position_ids'] - values = data['values'] - returns = data['returns'] - response_length = responses.size(1) acc = data['acc'] - prompt_ids = data.batch['prompts'] + prompt_ids = data['prompts'] prompt_length = prompt_ids.shape[-1] eos_mask = attention_mask[:, prompt_length:] - rm_score, q = self._forward_micro_batch(data, response_length) + rm_score, q = self._forward_micro_batch(data, prompt_length) rm_scores_lst.append(rm_score) - if self.config.loss_type == 'ce': + if self.config.model.loss_type == 'ce': dpo_loss = compute_ce_dpo_loss_rm(q, acc, eos_mask=eos_mask, beta=beta) else: raise NotImplementedError @@ -298,8 +292,6 @@ def update_rm(self, data: DataProto): append_to_dict(metrics, data) - loss.backward() - grad_norm = self._optimizer_step() data = {'reward_model/grad_norm': grad_norm.detach().item()} append_to_dict(metrics, data) diff --git a/recipe/prime/prime_fsdp_workers.py b/recipe/prime/prime_fsdp_workers.py index 101670cd..263d619b 100644 --- a/recipe/prime/prime_fsdp_workers.py +++ b/recipe/prime/prime_fsdp_workers.py @@ -159,12 +159,13 @@ def _build_reward_ref_model_optimizer(self, config): auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config.wrap_policy) - log_gpu_memory_usage('Before critic FSDP', logger=None) + log_gpu_memory_usage('Before reward model FSDP', logger=None) fsdp_mesh = self.device_mesh sharding_strategy = get_sharding_strategy(fsdp_mesh) - # Note: We force turn off CPUOffload for critic because it causes incorrect results when using grad accumulation + ref_module = copy.deepcopy(reward_module) + reward_module = FSDP(reward_module, param_init_fn=init_fn, use_orig_params=False, @@ -179,15 +180,25 @@ def _build_reward_ref_model_optimizer(self, config): log_gpu_memory_usage('After reward FSDP', logger=None) - ref_module = copy.deepcopy(reward_module) + ref_module = FSDP(ref_module, + param_init_fn=init_fn, + use_orig_params=False, + auto_wrap_policy=auto_wrap_policy, + device_id=torch.cuda.current_device(), + sharding_strategy=sharding_strategy, + mixed_precision=mixed_precision, + sync_module_states=True, + forward_prefetch=False, + device_mesh=self.device_mesh, + cpu_offload=None) reward_optimizer = optim.AdamW(reward_module.parameters(), - lr=config.optim.lr, - betas=config.optim.get('betas', (0.9, 0.999)), - weight_decay=config.optim.get('weight_decay', 1e-2)) + lr=config.model.optim.lr, + betas=config.model.optim.get('betas', (0.9, 0.999)), + weight_decay=config.model.optim.get('weight_decay', 1e-2)) - total_steps = 
config.optim.get('total_training_steps', 0) - num_warmup_steps_ratio = config.optim.get('lr_warmup_steps_ratio', 0.) + total_steps = config.model.optim.get('total_training_steps', 0) + num_warmup_steps_ratio = config.model.optim.get('lr_warmup_steps_ratio', 0.) num_warmup_steps = int(num_warmup_steps_ratio * total_steps) print(f'Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}') @@ -272,13 +283,7 @@ def update_rm(self, data: DataProto): with self.ulysses_sharding_manager: data = self.ulysses_sharding_manager.preprocess_data(data=data) - with Timer(name='update_rm', logger=None) as timer: - rm_scores, metrics = self.rm.update_rm(data=data) - delta_time = timer.last - - global_num_tokens = data.meta_info['global_token_num'] - estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) - metrics['mfu/reward'] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size + rm_scores, metrics = self.rm.update_rm(data=data) self.reward_lr_scheduler.step() lr = self.reward_lr_scheduler.get_last_lr()[0] diff --git a/recipe/prime/prime_ray_trainer.py b/recipe/prime/prime_ray_trainer.py index 7ffa9fe8..235e7f4d 100644 --- a/recipe/prime/prime_ray_trainer.py +++ b/recipe/prime/prime_ray_trainer.py @@ -393,7 +393,7 @@ def fit(self): batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist() # verify - with _timer(name='verify', text="{name}: {seconds:.1f} seconds"): + with _timer('verify', timing_raw): scores = self.reward_fn.verify(batch) metrics['acc'] = statistics.mean(scores) @@ -416,7 +416,7 @@ def fit(self): with _timer('adv', timing_raw): if self.use_rm: - update_style = self.config.reward_model.model.update + update_style = self.config.reward_model.model.get('update', 'none') if update_style == 'none': # only run forward reward_output = self.rm_wg.compute_rm_score(batch) elif update_style == 'after': # update and directly return the reward @@ -497,6 +497,7 @@ def filter_and_downsample(self, scores, batch: DataProto): reorder_index = torch.argsort(filter_mask, descending=True) reorder_index = (reorder_index.unsqueeze(-1) * n_samples + torch.arange(0, n_samples + 1).unsqueeze(0)).view(-1) - batch = batch.reorder(reorder_index[:int(len(batch) // self.config.data.oversample_factor)]) + batch.reorder(reorder_index[:int(len(batch) // + self.config.data.oversample_factor)]) # this operation is inplace return batch From d169794789a68ba0edc89b0d5290c60328303b32 Mon Sep 17 00:00:00 2001 From: Zefan Wang Date: Wed, 26 Feb 2025 00:33:25 +0800 Subject: [PATCH 6/6] formatting --- recipe/prime/prime_dp_rm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/recipe/prime/prime_dp_rm.py b/recipe/prime/prime_dp_rm.py index 7b2913d0..72bb6168 100644 --- a/recipe/prime/prime_dp_rm.py +++ b/recipe/prime/prime_dp_rm.py @@ -191,7 +191,8 @@ def _optimizer_step(self): if isinstance(self.reward_module, FSDP): grad_norm = self.reward_module.clip_grad_norm_(self.config.model.optim.grad_clip) else: - grad_norm = torch.nn.utils.clip_grad_norm_(self.reward_module.parameters(), max_norm=self.config.model.optim.grad_clip) + grad_norm = torch.nn.utils.clip_grad_norm_(self.reward_module.parameters(), + max_norm=self.config.model.optim.grad_clip) self.reward_optimizer.step() return grad_norm
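
A note on the implicit reward computed in prime_dp_rm.py above: the per-token score q is the log-probability ratio between the trained reward model and the frozen reference model, and compute_ce_dpo_loss_rm trains it as a sequence-level classifier against the verifier outcome `acc`. The sketch below is a minimal restatement of that relationship, assuming dense padded tensors and illustrative names and shapes; remove-padding, FSDP, sequence parallelism and the lambda > 0 return-style estimation are omitted.

    import torch

    def ce_dpo_loss_rm(q, acc, eos_mask, beta=0.05):
        # q[b, t] = log pi_rm(y_t | y_<t, x) - log pi_ref(y_t | y_<t, x) over response tokens
        # sequence score = sigmoid(beta * sum_t q_t), trained with BCE against the verifier label acc
        seq_scores = ((q * eos_mask).sum(dim=1) * beta).sigmoid()
        return torch.nn.functional.binary_cross_entropy(seq_scores, acc)

    def token_level_reward(q, beta=0.05):
        # with lambda == 0 (the default branch in _forward_micro_batch), the process reward is simply beta * q
        return beta * q

    # toy usage with random tensors
    q = torch.randn(4, 16)                   # (batch, response_len) log-ratios
    eos_mask = torch.ones(4, 16)             # 1 for valid response tokens, 0 for padding
    acc = torch.randint(0, 2, (4,)).float()  # verifier outcome per sample
    loss = ce_dpo_loss_rm(q, acc, eos_mask)
    rewards = token_level_reward(q)

With `prime_norm: batch_norm`, these token-level scores are additionally divided by the largest absolute reverse cumulative sum in the batch, which keeps their scale roughly bounded independent of `beta_train`.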