diff --git a/recipe/prime/__init__.py b/recipe/prime/__init__.py new file mode 100644 index 00000000..b1697c70 --- /dev/null +++ b/recipe/prime/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 PRIME team and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/recipe/prime/config/prime_trainer.yaml b/recipe/prime/config/prime_trainer.yaml new file mode 100644 index 00000000..5a053d92 --- /dev/null +++ b/recipe/prime/config/prime_trainer.yaml @@ -0,0 +1,69 @@ +# the prime config will override default ppo_trainer.yaml + +hydra: + searchpath: + - file://verl/trainer/config + +defaults: + - ppo_trainer + - _self_ + +data: + filter_accuracy: True + accuracy_lower_bound: 0.2 + accuracy_upper_bound: 0.8 + oversample_factor: 4.0 # Sample more responses than the batch size. prompts satisfying the filter will be prioritized. + +actor_rollout_ref: + hybrid_engine: True + model: + use_remove_padding: True + rollout: + # number of responses (i.e. num sample times) + n: 4 + actor: + entropy_coeff: 0.001 + +reward_model: + enable: True + strategy: fsdp + model: + use_remove_padding: True + tokenizer_path: ${actor_rollout_ref.model.path} + enable_gradient_checkpointing: ${actor_rollout_ref.model.enable_gradient_checkpointing} + ref_type: freeze + fsdp_config: + min_num_params: 0 + param_offload: ${actor_rollout_ref.actor.fsdp_config.param_offload} +# grad_offload: ${actor_rollout_ref.actor.fsdp_config.grad_offload} + optimizer_offload: ${actor_rollout_ref.actor.fsdp_config.optimizer_offload} + update: before # ``before`` for double-forward, ``after`` for single-forward + optim: + lr: 1e-6 + lr_warmup_steps_ratio: 0. + min_lr_ratio: null + warmup_style: constant + total_training_steps: -1 # must be overridden by program + weight_decay: 0. + grad_clip: 10.0 + beta_train: 0.05 + loss_type: ce # currently only supports ce loss + prime_granularity: token + prime_norm: batch_norm # batch_norm or none. if set to none, the normalizer is beta_train + mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} + reward_manager: prime + +algorithm: + adv_estimator: rloo + # now supports rloo. it treats different source of reward separately. + kl_ctrl: + type: fixed + kl_coef: 0.000 + reward_gt_coef: 5 + reward_dpo_coef: 5 + +trainer: + project_name: prime + experiment_name: examples + val_before_train: False + balance_batch: False \ No newline at end of file diff --git a/recipe/prime/main_prime.py b/recipe/prime/main_prime.py new file mode 100644 index 00000000..b15e3873 --- /dev/null +++ b/recipe/prime/main_prime.py @@ -0,0 +1,130 @@ +# Copyright 2024 PRIME team and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. +""" +from .prime_ray_trainer import RayPRIMETrainer + +import ray +import hydra + + +@hydra.main(config_path='config', config_name='prime_trainer', version_base=None) +def main(config): + run_prime(config) + + +def run_prime(config, compute_score=None): + if not ray.is_initialized(): + # this is for local ray cluster + ray.init(runtime_env={'env_vars': {'TOKENIZERS_PARALLELISM': 'true', 'NCCL_DEBUG': 'WARN'}}) + + ray.get(main_task.remote(config, compute_score)) + + +@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head +def main_task(config, compute_score=None): + from verl.utils.fs import copy_local_path_from_hdfs + # print initial config + from pprint import pprint + from omegaconf import OmegaConf + pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values + OmegaConf.resolve(config) + + # download the checkpoint from hdfs + local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path) + + # instantiate tokenizer + from verl.utils import hf_tokenizer + tokenizer = hf_tokenizer(local_path) + + # define worker classes + if config.actor_rollout_ref.actor.strategy == 'fsdp': + assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker + from verl.single_controller.ray import RayWorkerGroup + ray_worker_group_cls = RayWorkerGroup + + elif config.actor_rollout_ref.actor.strategy == 'megatron': + assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker + from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup + ray_worker_group_cls = NVMegatronRayWorkerGroup + + else: + raise NotImplementedError + + from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role + + role_worker_mapping = { + Role.ActorRollout: ray.remote(ActorRolloutRefWorker), + Role.RefPolicy: ray.remote(ActorRolloutRefWorker) + } + + global_pool_id = 'global_pool' + resource_pool_spec = { + global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, + } + mapping = { + Role.ActorRollout: global_pool_id, + Role.RefPolicy: global_pool_id, + } + if config.reward_model.enable: + from .prime_fsdp_workers import PRIMERewardModelWorker + role_worker_mapping[Role.RewardModel] = ray.remote(PRIMERewardModelWorker) + mapping[Role.RewardModel] = 
global_pool_id + + reward_manager_name = config.reward_model.get("reward_manager", "naive") + if reward_manager_name == 'naive': + from verl.workers.reward_manager import NaiveRewardManager + reward_manager_cls = NaiveRewardManager + elif reward_manager_name == 'prime': + from verl.workers.reward_manager import PrimeRewardManager + reward_manager_cls = PrimeRewardManager + else: + raise NotImplementedError + reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=0, compute_score=compute_score) + + # Note that we always use function-based RM for validation + val_reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=1, compute_score=compute_score) + + resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) + + trainer = RayPRIMETrainer(config=config, + tokenizer=tokenizer, + role_worker_mapping=role_worker_mapping, + resource_pool_manager=resource_pool_manager, + ray_worker_group_cls=ray_worker_group_cls, + reward_fn=reward_fn, + val_reward_fn=val_reward_fn) + trainer.init_workers() + trainer.fit() + + +if __name__ == '__main__': + main() diff --git a/recipe/prime/prime_core_algos.py b/recipe/prime/prime_core_algos.py new file mode 100644 index 00000000..e2854704 --- /dev/null +++ b/recipe/prime/prime_core_algos.py @@ -0,0 +1,107 @@ +# Copyright 2024 PRIME team and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
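+# Core PRIME algorithms: RLOO-style advantage/return computation over multiple reward sources (implicit process rewards and ground-truth accuracy), plus the cross-entropy (DPO-style) loss and pairwise ranking accuracy used to train and evaluate the implicit process reward model.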
+ +import torch +import verl +import verl.utils.torch_functional as verl_F + + +def compute_rloo_advantage_return(data: verl.DataProto, eos_mask: torch.Tensor, n_samples, config): + # calculate rloo reward on different reward sources, and sum again + + def masked_rloo(reward_tensor_original, mask_tensor): + reward_tensor = reward_tensor_original.clone() + reward_tensor[~mask_tensor] = 0 + for start_pos in range(0, reward_tensor.shape[0], n_samples): + cur_rewards_mean = torch.cat([ + reward_tensor[pos:pos + 1][mask_tensor[pos:pos + 1]].mean(dim=0, keepdim=True) + for pos in range(start_pos, start_pos + n_samples) + ], + dim=0) + cur_rewards_sum = cur_rewards_mean.sum() + cur_reward_baseline = cur_rewards_sum / (n_samples - 1) + reward_tensor[start_pos:start_pos + n_samples][ + mask_tensor[start_pos:start_pos + n_samples]] = \ + reward_tensor[start_pos:start_pos + n_samples][ + mask_tensor[start_pos:start_pos + n_samples]] * ( + n_samples / (n_samples - 1)) - cur_reward_baseline + + return reward_tensor + + reward_tensors = [] + + with torch.no_grad(): + + if 'rm_scores' in data.batch and config.algorithm.reward_dpo_coef != 0.: + reward_tensor = data.batch['rm_scores'] + reward_mask = eos_mask.bool() + + reward_tensors.append(masked_rloo(reward_tensor, reward_mask) * config.algorithm.reward_dpo_coef) + + if 'acc' in data.batch and config.algorithm.reward_gt_coef != 0.: + reward_tensor = torch.zeros_like(eos_mask, dtype=torch.float32) + reward_mask = torch.zeros_like(eos_mask, dtype=torch.bool) + + prompt_ids = data.batch['prompts'] + prompt_length = prompt_ids.shape[-1] + valid_response_length = data.batch['attention_mask'][:, prompt_length:].sum(-1) + + reward_mask[ + torch.arange(0, valid_response_length.shape[0], dtype=torch.long, device=valid_response_length.device), + valid_response_length - 1] = True + reward_tensor[ + torch.arange(0, valid_response_length.shape[0], dtype=torch.long, device=valid_response_length.device), + valid_response_length - 1] = data.batch['acc'] + + reward_tensors.append(masked_rloo(reward_tensor, reward_mask) * config.algorithm.reward_gt_coef) + + final_reward_tensor = sum(reward_tensors) + + returns = (final_reward_tensor * eos_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) + + advantages = returns.clone() + advantages = verl_F.masked_whiten(advantages, eos_mask) + + return advantages, returns + + +def compute_ce_dpo_loss_rm(token_level_scores, acc, eos_mask, beta): + cur_scores = ((token_level_scores * eos_mask).sum(dim=1) * beta).sigmoid() + cur_dpo_loss = torch.nn.functional.binary_cross_entropy(cur_scores, acc) + return cur_dpo_loss + + +def compute_dpo_accuracy(token_level_scores, acc, eos_mask, n_samples): + dpo_acc = [] + for start_id in range(0, token_level_scores.shape[0], n_samples): + cur_scores = (token_level_scores[start_id:start_id + n_samples] * + eos_mask[start_id:start_id + n_samples]).sum(dim=1) + + def get_upper_triangle(tensor_x): + diff_matrix = tensor_x.unsqueeze(1) - tensor_x.unsqueeze(0) + upper_tri_indices = torch.triu(torch.ones_like(diff_matrix).bool(), diagonal=1) + return diff_matrix[upper_tri_indices] + + cur_acc_diff = get_upper_triangle(acc[start_id:start_id + n_samples]) # in range [-1,1] + cur_score_diff = get_upper_triangle(cur_scores) # in R + cur_score_prediction = (cur_score_diff > 0).float() # in [0,1] + if cur_acc_diff.abs().sum() == 0: + cur_acc = torch.zeros_like(cur_score_prediction[0]) + 0.5 + else: + cur_acc = (((cur_score_diff > 0) == (cur_acc_diff > 0)).float() * + cur_acc_diff.abs()).sum() / 
cur_acc_diff.abs().sum() + + dpo_acc.append(cur_acc.unsqueeze(0)) + + return torch.cat(dpo_acc, dim=0).mean() diff --git a/recipe/prime/prime_dp_rm.py b/recipe/prime/prime_dp_rm.py new file mode 100644 index 00000000..72bb6168 --- /dev/null +++ b/recipe/prime/prime_dp_rm.py @@ -0,0 +1,305 @@ +# Copyright 2024 PRIME team and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Implement a multiprocess PPOCritic +""" +import itertools +from typing import Iterable + +import torch +import torch.distributed +from torch import nn, optim + +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + +from .prime_core_algos import compute_ce_dpo_loss_rm +from verl import DataProto +from verl.trainer.ppo import core_algos +from verl.workers.critic import BasePPOCritic +from verl.utils.py_functional import append_to_dict +from verl.utils.torch_functional import masked_mean +from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad +from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx +import verl.utils.torch_functional as verl_F + +from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis + +__all__ = ['DataParallelPRIMERewardModel'] + + +class DataParallelPRIMERewardModel: + + def __init__(self, config, reward_module: nn.Module, ref_module: nn.Module, reward_optimizer: optim.Optimizer): + self.config = config + self.reward_module = reward_module + self.ref_module = ref_module + self.reward_optimizer = reward_optimizer + self.use_remove_padding = self.config.model.get('use_remove_padding', False) + print(f'Reward model use_remove_padding={self.use_remove_padding}') + + self.ulysses_sequence_parallel_size = self.config.get('ulysses_sequence_parallel_size', 1) + + def _forward_micro_batch(self, micro_batch, prompt_length): + from flash_attn.bert_padding import pad_input, unpad_input, index_first_axis, rearrange + from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad + + input_ids = micro_batch['input_ids'] + batch_size, seqlen = input_ids.shape + attention_mask = micro_batch['attention_mask'] + position_ids = micro_batch['position_ids'] + + if self.use_remove_padding: + input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), + attention_mask) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) + + # unpad the position_ids to align the rotary + position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... 
-> (b s) ..."), + indices).transpose(0, 1) + + # for compute the log_prob + input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz) + + # pad and slice the inputs if sp > 1 + if self.ulysses_sequence_parallel_size > 1: + input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs( + input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size) + input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(input_ids_rmpad_rolled, None, + self.ulysses_sequence_parallel_size) + input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) + rm_output_logits = self.reward_module(input_ids=input_ids_rmpad, + attention_mask=None, + position_ids=position_ids_rmpad, + use_cache=False).logits.squeeze( + 0) # copied. I don't really know why there is a squeeze + rm_log_labels = verl_F.logprobs_from_logits(logits=rm_output_logits, labels=input_ids_rmpad_rolled) + if self.ulysses_sequence_parallel_size > 1: + rm_log_labels = gather_outpus_and_unpad(rm_log_labels, gather_dim=0, unpad_dim=0, padding_size=pad_size) + rm_log_labels = pad_input(hidden_states=rm_log_labels.unsqueeze(-1), + indices=indices, + batch=batch_size, + seqlen=seqlen).squeeze(-1) + + else: + rm_output_logits = self.reward_module(input_ids=micro_batch['input_ids'], + attention_mask=micro_batch['attention_mask'], + position_ids=micro_batch['position_ids']).logits + rm_log_prob = torch.nn.functional.log_softmax(rm_output_logits[:, :-1, :], + dim=-1) # (batch_size, seq_length, vocab_size) + rm_log_labels = rm_log_prob.gather(dim=-1, index=micro_batch['input_ids'][:, 1:].unsqueeze(-1)).squeeze( + -1) # (batch, seq_length) + + if self.ref_module is not None: + # 不用重复remove pad,只用做好re-pad即可 + with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.bfloat16): + if self.ulysses_sequence_parallel_size > 1 and self.use_remove_padding: + ref_output_logits = self.ref_module(input_ids=input_ids_rmpad, + attention_mask=None, + position_ids=position_ids_rmpad, + use_cache=False).logits.squeeze(0) + ref_log_labels = verl_F.logprobs_from_logits(logits=ref_output_logits, + labels=input_ids_rmpad_rolled) + ref_log_labels = gather_outpus_and_unpad(ref_log_labels, + gather_dim=0, + unpad_dim=0, + padding_size=pad_size) + ref_log_labels = pad_input(hidden_states=ref_log_labels.unsqueeze(-1), + indices=indices, + batch=batch_size, + seqlen=seqlen).squeeze(-1) + else: + ref_output_logits = self.ref_module(input_ids=micro_batch['input_ids'], + attention_mask=micro_batch['attention_mask'], + position_ids=micro_batch['position_ids']).logits + ref_log_prob = torch.nn.functional.log_softmax(ref_output_logits[:, :-1, :], + dim=-1) # (batch_size, seq_length, vocab_size) + ref_log_labels = ref_log_prob.gather(dim=-1, + index=micro_batch['input_ids'][:, 1:].unsqueeze(-1)).squeeze( + -1) # (batch, seq_length) + else: + ref_log_labels = micro_batch['old_log_probs'] + + num_actions = micro_batch['input_ids'].shape[-1] - prompt_length + max_positions = micro_batch['attention_mask'][:, prompt_length:].sum(-1) + + ref_log_labels.to(rm_log_labels.dtype) + q = rm_log_labels[:, -num_actions:] - ref_log_labels[:, -num_actions:] # this is actually diff of q + + # reward computation does not need gradient. only q needs + with torch.no_grad(): + + # generalized estimation of r should go before the reward filling. r means process reward for policy model, or the advantage of reward model. + lam = self.config.get('lambda', 0.) 
+ beta = self.config.model.get('beta_train', 0.05) + if lam == 0.: + r = q * beta + else: + # the reward coefficient has no effect here + acc = micro_batch['acc'] + q_ = q * beta + r = torch.zeros_like(q) + # TODO: follow how the implicit value model handles this: directly set q at position max_positions[0]-1 to r - Q_{t-1}, then zero out all later r values + lastgaelam = 0 + # change the last token and mask out all paddings to make this process easier + for i in range(q.shape[0]): + if self.config.prime_use_gt: + q_[i, max_positions[i] - 1] = acc[i] - q_[i, :max_positions[i] - 1].sum() + q_[i, max_positions[i]:] = 0 + + for t in reversed(range(num_actions)): + delta = q_[:, t] + lastgaelam = delta + lam * lastgaelam + r[:, t] = lastgaelam + + step_ends = [] + + if self.config.prime_granularity == 'token': + for i in range(micro_batch['input_ids'].shape[0]): + step_ends.append(list(range(max_positions[i]))) + elif self.config.prime_granularity == 'whole': + for i in range(micro_batch['input_ids'].shape[0]): + step_ends.append([max_positions[i] - 1]) + else: + raise NotImplementedError + + token_level_score = torch.zeros_like(q) + + for i, step_end in enumerate(step_ends): + for j in range(len(step_end)): + step_range = [ + min(step_end[j - 1] + 1, num_actions - 1) if j > 0 else 0, + min(num_actions - 1, step_end[j]) + ] + token_level_score[i, step_range[1]] = r[i, step_range[0]:step_range[1] + 1].sum() + + return token_level_score, q + + def _optimizer_step(self): + assert self.config.model.optim.grad_clip is not None + + if isinstance(self.reward_module, FSDP): + grad_norm = self.reward_module.clip_grad_norm_(self.config.model.optim.grad_clip) + else: + grad_norm = torch.nn.utils.clip_grad_norm_(self.reward_module.parameters(), + max_norm=self.config.model.optim.grad_clip) + self.reward_optimizer.step() + return grad_norm + + def prime_norm(self, token_level_scores): + if self.config.prime_norm == 'batch_norm': + reverse_cumsum = torch.cumsum(token_level_scores.flip(dims=[1]), dim=-1).flip(dims=[1]) + token_level_scores = token_level_scores / (reverse_cumsum.abs().max() + 1e-6) + return token_level_scores + + def compute_rm_score(self, data: DataProto): + self.reward_module.eval() + self.ref_module.eval() + micro_batch_size = data.meta_info['micro_batch_size'] + select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids', 'acc'] + batch = data.select(batch_keys=select_keys).batch + use_dynamic_bsz = data.meta_info['use_dynamic_bsz'] + prompt_length = data.batch['input_ids'].shape[-1] - data.batch['responses'].shape[-1] + + if use_dynamic_bsz: + # split using dynamic bsz + max_token_len = data.meta_info['max_token_len'] * self.ulysses_sequence_parallel_size + micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len) + else: + micro_batches = batch.split(micro_batch_size) + + rm_scores_lst = [] + for micro_batch in micro_batches: + with torch.no_grad(): + rm_score, q = self._forward_micro_batch(micro_batch, prompt_length) + rm_scores_lst.append(rm_score) + rm_scores = torch.concat(rm_scores_lst, dim=0) + + rm_scores = self.prime_norm(rm_scores) + + if use_dynamic_bsz: + indices = list(itertools.chain.from_iterable(indices)) + assert len(indices) == rm_scores.size(0), f"{len(indices)} vs. 
{rm_scores.size()}" + revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) + rm_scores = rm_scores[revert_indices] + + return rm_scores, {} + + def update_rm(self, data: DataProto): + # make sure we are in training mode + self.reward_module.train() + metrics = {} + + beta = self.config.model.get('beta_train', 0.05) + + select_keys = ['input_ids', 'responses', 'attention_mask', 'position_ids', 'acc', 'prompts'] + batch = data.select(batch_keys=select_keys).batch + # Split to make minibatch iterator for updating the actor + # See PPO paper for details. https://arxiv.org/abs/1707.06347 + dataloader = batch.split(self.config.mini_batch_size) + + rm_scores_lst = [] + + for batch_idx, data in enumerate(dataloader): + # split batch into micro_batches + mini_batch = data + if self.config.use_dynamic_bsz: + max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size + micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len) + else: + micro_batches = mini_batch.split(self.config.micro_batch_size_per_gpu) + self.gradient_accumulation = self.config.mini_batch_size // self.config.micro_batch_size_per_gpu + + self.reward_optimizer.zero_grad() + + for data in micro_batches: + data = data.cuda() + attention_mask = data['attention_mask'] + acc = data['acc'] + + prompt_ids = data['prompts'] + prompt_length = prompt_ids.shape[-1] + + eos_mask = attention_mask[:, prompt_length:] + + rm_score, q = self._forward_micro_batch(data, prompt_length) + + rm_scores_lst.append(rm_score) + + if self.config.model.loss_type == 'ce': + dpo_loss = compute_ce_dpo_loss_rm(q, acc, eos_mask=eos_mask, beta=beta) + else: + raise NotImplementedError + + data = {'reward_model/dpo_loss': dpo_loss.detach().item()} + + if self.config.use_dynamic_bsz: + # relative to the dynamic bsz + loss = dpo_loss * (len(data) / self.config.ppo_mini_batch_size) + else: + loss = dpo_loss / self.gradient_accumulation + + loss.backward() + + append_to_dict(metrics, data) + + grad_norm = self._optimizer_step() + data = {'reward_model/grad_norm': grad_norm.detach().item()} + append_to_dict(metrics, data) + self.reward_optimizer.zero_grad() + + rm_scores = torch.cat(rm_scores_lst, dim=0) + + rm_scores = self.prime_norm(rm_scores) + + return rm_scores, metrics diff --git a/recipe/prime/prime_fsdp_workers.py b/recipe/prime/prime_fsdp_workers.py new file mode 100644 index 00000000..263d619b --- /dev/null +++ b/recipe/prime/prime_fsdp_workers.py @@ -0,0 +1,337 @@ +# Copyright 2024 PRIME team and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import copy +import logging +import os +import warnings + +import torch +import torch.distributed +from torch.distributed.device_mesh import init_device_mesh +import verl.utils.torch_functional as verl_F +from omegaconf import DictConfig, open_dict +from verl import DataProto +from verl.single_controller.base import Worker +from verl.single_controller.base.decorator import register, Dispatch +from verl.utils import hf_tokenizer +from verl.utils.debug import log_gpu_memory_usage +from verl.utils.fs import copy_local_path_from_hdfs +from verl.utils.fsdp_utils import get_fsdp_wrap_policy, init_fn, get_init_weight_context_manager +from verl.utils.fsdp_utils import offload_fsdp_optimizer, offload_fsdp_model_to_cpu, load_fsdp_optimizer, \ + load_fsdp_model_to_gpu +from verl.utils.import_utils import import_external_libs +from verl.utils.model import compute_position_id_with_mask +from verl.utils.flops_counter import FlopsCounter +from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager +from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager + +from codetiming import Timer +from verl.workers.fsdp_workers import create_device_mesh, get_sharding_strategy +from .prime_core_algos import compute_dpo_accuracy + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN')) + + +class PRIMERewardModelWorker(Worker): + + def __init__(self, config): + super().__init__() + import torch.distributed + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group(backend="nccl") + self.config = config + + # build device mesh for Ulysses Sequence Parallel + world_size = torch.distributed.get_world_size() + from torch.distributed.device_mesh import init_device_mesh + + fsdp_size = self.config.model.fsdp_config.fsdp_size + self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size) + + self.ulysses_device_mesh = None + self.ulysses_sequence_parallel_size = self.config.get('ulysses_sequence_parallel_size', 1) + dp = world_size // self.ulysses_sequence_parallel_size + if self.ulysses_sequence_parallel_size > 1: + self.ulysses_device_mesh = init_device_mesh('cuda', + mesh_shape=(dp, self.ulysses_sequence_parallel_size), + mesh_dim_names=['dp', 'sp']) + + self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) + + # set FSDP offload params + self._is_offload_param = self.config.model.fsdp_config.param_offload + self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload + + # normalize config + self.config.mini_batch_size //= (torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size) + if self.config.micro_batch_size is not None: + self.config.micro_batch_size //= (torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size) + self.config.micro_batch_size_per_gpu = self.config.micro_batch_size + assert self.config.mini_batch_size % self.config.micro_batch_size_per_gpu == 0 + + def _build_reward_ref_model_optimizer(self, config): + # the following line is necessary + from verl.utils.model import LambdaLayer, print_model_size, squeeze + from verl.utils.torch_dtypes import PrecisionType + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision + from torch import optim + + local_path = copy_local_path_from_hdfs(config.model.path) + + tokenizer_path = copy_local_path_from_hdfs(config.model.tokenizer_path) + self.tokenizer = hf_tokenizer(tokenizer_path, 
trust_remote_code=config.model.get('trust_remote_code', False)) + + from omegaconf import OmegaConf + override_config = OmegaConf.to_container(self.config.model.get('override_config', OmegaConf.create())) + override_config_kwargs = { + 'bos_token_id': self.tokenizer.bos_token_id, + 'eos_token_id': self.tokenizer.eos_token_id, + 'pad_token_id': self.tokenizer.pad_token_id, + } + override_config_kwargs.update(override_config) + if self.rank == 0: + print(f'Reward model overriding config {override_config_kwargs}') + + torch_dtype = self.config.model.fsdp_config.get('model_dtype', 'fp32') + torch_dtype = PrecisionType.to_dtype(torch_dtype) + + from transformers import AutoConfig, AutoModelForCausalLM + from torch import nn + + trust_remote_code = False + reward_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code) + reward_model_config.num_labels = 1 + + use_remove_padding = config.model.get('use_remove_padding', False) + if use_remove_padding: + from verl.models.registry import check_model_support_rmpad + check_model_support_rmpad(reward_model_config.model_type) + + if use_remove_padding and self.ulysses_sequence_parallel_size > 1: + from verl.models.transformers.monkey_patch import apply_monkey_patch + apply_monkey_patch(reward_model_config, verbose=True) + + init_context = get_init_weight_context_manager() + with init_context(), warnings.catch_warnings(): + warnings.simplefilter("ignore") + setattr(reward_model_config, 'classifier_dropout', 0.) + setattr(reward_model_config, 'hidden_dropout', '0') + reward_module = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=local_path, + torch_dtype=torch_dtype, + config=reward_model_config, + attn_implementation='flash_attention_2', + trust_remote_code=trust_remote_code) + + # some parameters may not in torch_dtype + reward_module.to(torch_dtype) + + if config.model.get('enable_gradient_checkpointing', False): + reward_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False}) + if self.rank == 0: + print_model_size(reward_module) + + self.reward_model_config = reward_model_config + + fsdp_config = self.config.model.fsdp_config + mixed_precision_config = fsdp_config.get('mixed_precision', None) + if mixed_precision_config is not None: + param_dtype = PrecisionType.to_dtype(mixed_precision_config.get('param_dtype', 'bf16')) + reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get('reduce_dtype', 'fp32')) + buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get('buffer_dtype', 'fp32')) + else: + param_dtype = torch.bfloat16 + reduce_dtype = torch.float32 + buffer_dtype = torch.float32 + + mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype) + + auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config.wrap_policy) + + log_gpu_memory_usage('Before reward model FSDP', logger=None) + + fsdp_mesh = self.device_mesh + sharding_strategy = get_sharding_strategy(fsdp_mesh) + + ref_module = copy.deepcopy(reward_module) + + reward_module = FSDP(reward_module, + param_init_fn=init_fn, + use_orig_params=False, + auto_wrap_policy=auto_wrap_policy, + device_id=torch.cuda.current_device(), + sharding_strategy=sharding_strategy, + mixed_precision=mixed_precision, + sync_module_states=True, + forward_prefetch=False, + device_mesh=self.device_mesh, + cpu_offload=None) + + log_gpu_memory_usage('After reward FSDP', logger=None) + + ref_module = FSDP(ref_module, + 
param_init_fn=init_fn, + use_orig_params=False, + auto_wrap_policy=auto_wrap_policy, + device_id=torch.cuda.current_device(), + sharding_strategy=sharding_strategy, + mixed_precision=mixed_precision, + sync_module_states=True, + forward_prefetch=False, + device_mesh=self.device_mesh, + cpu_offload=None) + + reward_optimizer = optim.AdamW(reward_module.parameters(), + lr=config.model.optim.lr, + betas=config.model.optim.get('betas', (0.9, 0.999)), + weight_decay=config.model.optim.get('weight_decay', 1e-2)) + + total_steps = config.model.optim.get('total_training_steps', 0) + num_warmup_steps_ratio = config.model.optim.get('lr_warmup_steps_ratio', 0.) + num_warmup_steps = int(num_warmup_steps_ratio * total_steps) + + print(f'Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}') + + from verl.utils.torch_functional import get_constant_schedule_with_warmup + reward_lr_scheduler = get_constant_schedule_with_warmup(optimizer=reward_optimizer, + num_warmup_steps=num_warmup_steps) + + return reward_module, ref_module, reward_optimizer, reward_lr_scheduler + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def init_model(self): + # This is used to import external_lib into the huggingface systems + import_external_libs(self.config.model.get('external_lib', None)) + + from .prime_dp_rm import DataParallelPRIMERewardModel + self.reward_module, self.ref_module, self.reward_optimizer, self.reward_lr_scheduler = self._build_reward_ref_model_optimizer( + config=self.config) + + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.reward_module) + offload_fsdp_model_to_cpu(self.ref_module) + if self._is_offload_optimizer: + offload_fsdp_optimizer(optimizer=self.reward_optimizer) + + self.rm = DataParallelPRIMERewardModel(config=self.config, + reward_module=self.reward_module, + ref_module=self.ref_module, + reward_optimizer=self.reward_optimizer) + + self.flops_counter = FlopsCounter(self.reward_model_config) + self.checkpoint_manager = FSDPCheckpointManager(model=self.reward_module, + optimizer=self.reward_optimizer, + lr_scheduler=self.reward_lr_scheduler, + tokenizer=self.tokenizer) + + torch.cuda.empty_cache() + + @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) + def compute_rm_score(self, data: DataProto): + data = data.to('cuda') + + if self._is_offload_param: + load_fsdp_model_to_gpu(self.reward_module) + load_fsdp_model_to_gpu(self.ref_module) + micro_batch_size = self.config.micro_batch_size_per_gpu + data.meta_info['micro_batch_size'] = micro_batch_size + data.meta_info['max_token_len'] = self.config.forward_max_token_len_per_gpu + data.meta_info['use_dynamic_bsz'] = self.config.use_dynamic_bsz + # perform forward computation + with self.ulysses_sharding_manager: + data = self.ulysses_sharding_manager.preprocess_data(data=data) + rm_scores, metrics = self.rm.compute_rm_score(data=data) + + prompt_length = data.batch['prompts'].shape[-1] + eos_mask = data.batch['attention_mask'][:, prompt_length:] + acc = data.batch['acc'] + + dpo_acc = compute_dpo_accuracy(rm_scores, acc, eos_mask=eos_mask, n_samples=data.meta_info['n']) + + metrics['reward_model/dpo_acc'] = dpo_acc.detach().item() + + output = DataProto.from_dict(tensors={'rm_scores': rm_scores}, meta_info={'metrics': metrics}) + output = self.ulysses_sharding_manager.postprocess_data(data=output) + + output = output.to('cpu') + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.reward_module) + offload_fsdp_model_to_cpu(self.ref_module) + return output + + @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) + def 
update_rm(self, data: DataProto): + data = data.to('cuda') + if self._is_offload_param: + load_fsdp_model_to_gpu(self.ref_module) + load_fsdp_model_to_gpu(self.reward_module) + if self._is_offload_optimizer: + load_fsdp_optimizer(optimizer=self.reward_optimizer, device_id=torch.cuda.current_device()) + + # perform forward computation + with self.ulysses_sharding_manager: + data = self.ulysses_sharding_manager.preprocess_data(data=data) + + rm_scores, metrics = self.rm.update_rm(data=data) + + self.reward_lr_scheduler.step() + lr = self.reward_lr_scheduler.get_last_lr()[0] + metrics['rm/lr'] = lr + + prompt_length = data.batch['prompts'].shape[-1] + eos_mask = data.batch['attention_mask'][:, prompt_length:] + acc = data.batch['acc'] + + dpo_acc_before = compute_dpo_accuracy(rm_scores, acc, eos_mask=eos_mask, n_samples=data.meta_info['n']) + + metrics['reward_model/dpo_acc_before'] = dpo_acc_before.detach().item() + + output = DataProto.from_dict(tensors={'rm_scores': rm_scores}, meta_info={'metrics': metrics}) + output = self.ulysses_sharding_manager.postprocess_data(data=output) + + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.reward_module) + offload_fsdp_model_to_cpu(self.ref_module) + if self._is_offload_optimizer: + offload_fsdp_optimizer(optimizer=self.reward_optimizer) + torch.cuda.empty_cache() + output = output.to('cpu') + return output + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, remove_previous_ckpt=False): + import torch + if self._is_offload_param: + load_fsdp_model_to_gpu(self.reward_module) + + self.checkpoint_manager.save_checkpoint(local_path=local_path, + hdfs_path=hdfs_path, + global_step=global_step, + remove_previous_ckpt=remove_previous_ckpt) + + torch.distributed.barrier() + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.reward_module) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def load_checkpoint(self, path, del_local_after_load=True): + import torch + if self._is_offload_param: + load_fsdp_model_to_gpu(self.reward_module) + + self.checkpoint_manager.load_checkpoint(path=path, del_local_after_load=del_local_after_load) + + torch.distributed.barrier() + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.reward_module) diff --git a/recipe/prime/prime_ray_trainer.py b/recipe/prime/prime_ray_trainer.py new file mode 100644 index 00000000..235e7f4d --- /dev/null +++ b/recipe/prime/prime_ray_trainer.py @@ -0,0 +1,503 @@ +# Copyright 2024 PRIME team and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +FSDP PPO Trainer with Ray-based single controller. 
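+It extends RayPPOTrainer with the PRIME recipe: oversampled rollouts are filtered and downsampled based on accuracy, and an implicit process reward model is updated online to provide token-level rewards.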
+This trainer supports model-agonistic model initialization with huggingface +""" + +import os +import statistics +import uuid +from copy import deepcopy +from pprint import pprint + +import numpy as np +import torch +from omegaconf import OmegaConf, open_dict + +from verl import DataProto +from verl.single_controller.ray import RayWorkerGroup +from verl.trainer.ppo.ray_trainer import RayPPOTrainer +from verl.trainer.ppo.ray_trainer import Role, WorkerType, ResourcePoolManager, reduce_metrics, _compute_response_info, \ + _timer +from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path +from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn +from . import prime_core_algos + + +def compute_advantage(data: DataProto, adv_estimator, config): + if adv_estimator == 'rloo': + responses = data.batch['responses'] + response_length = responses.size(-1) + attention_mask = data.batch['attention_mask'] + response_mask = attention_mask[:, -response_length:] + advantages, returns = prime_core_algos.compute_rloo_advantage_return(data, response_mask, + config.actor_rollout_ref.rollout.n, config) + data.batch['advantages'] = advantages + data.batch['returns'] = returns + else: + raise NotImplementedError + return data + + +def compute_data_metrics(batch, use_critic=True): + + advantages = batch.batch['advantages'] + returns = batch.batch['returns'] + + max_response_length = batch.batch['responses'].shape[-1] + + prompt_mask = batch.batch['attention_mask'][:, :-max_response_length].bool() + response_mask = batch.batch['attention_mask'][:, -max_response_length:].bool() + + max_prompt_length = prompt_mask.size(-1) + + response_info = _compute_response_info(batch) + prompt_length = response_info['prompt_length'] + response_length = response_info['response_length'] + + valid_adv = torch.masked_select(advantages, response_mask) + valid_returns = torch.masked_select(returns, response_mask) + + if use_critic: + values = batch.batch['values'] + valid_values = torch.masked_select(values, response_mask) + return_diff_var = torch.var(valid_returns - valid_values) + return_var = torch.var(valid_returns) + + metrics = { + # adv + 'critic/advantages/mean': + torch.mean(valid_adv).detach().item(), + 'critic/advantages/max': + torch.max(valid_adv).detach().item(), + 'critic/advantages/min': + torch.min(valid_adv).detach().item(), + # returns + 'critic/returns/mean': + torch.mean(valid_returns).detach().item(), + 'critic/returns/max': + torch.max(valid_returns).detach().item(), + 'critic/returns/min': + torch.min(valid_returns).detach().item(), + **({ + # values + 'critic/values/mean': torch.mean(valid_values).detach().item(), + 'critic/values/max': torch.max(valid_values).detach().item(), + 'critic/values/min': torch.min(valid_values).detach().item(), + # vf explained var + 'critic/vf_explained_var': (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(), + } if use_critic else {}), + + # response length + 'response_length/mean': + torch.mean(response_length).detach().item(), + 'response_length/max': + torch.max(response_length).detach().item(), + 'response_length/min': + torch.min(response_length).detach().item(), + 'response_length/clip_ratio': + torch.mean(torch.eq(response_length, max_response_length).float()).detach().item(), + # prompt length + 'prompt_length/mean': + torch.mean(prompt_length).detach().item(), + 'prompt_length/max': + torch.max(prompt_length).detach().item(), + 'prompt_length/min': + torch.min(prompt_length).detach().item(), + 'prompt_length/clip_ratio': + 
torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(), + } + return metrics + + +def compute_timing_metrics(batch, timing_raw): + response_info = _compute_response_info(batch) + num_prompt_tokens = torch.sum(response_info['prompt_length']).item() + num_response_tokens = torch.sum(response_info['response_length']).item() + num_overall_tokens = num_prompt_tokens + num_response_tokens + + num_tokens_of_section = { + 'gen': num_response_tokens, + **{ + name: num_overall_tokens for name in ['ref', 'values', 'adv', 'update_critic', 'update_actor'] + }, + } + + return { + **{ + f'timing_s/{name}': value for name, value in timing_raw.items() + }, + **{ + f'timing_per_token_ms/{name}': timing_raw[name] * 1000 / num_tokens_of_section[name] for name in set(num_tokens_of_section.keys( + )) & set(timing_raw.keys()) + }, + } + + +class RayPRIMETrainer(RayPPOTrainer): + """ + Note that this trainer runs on the driver process on a single CPU/GPU node. + """ + + # TODO: support each role have individual ray_worker_group_cls, + # i.e., support different backend of different role + def __init__(self, + config, + tokenizer, + role_worker_mapping: dict[Role, WorkerType], + resource_pool_manager: ResourcePoolManager, + ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup, + reward_fn=None, + val_reward_fn=None): + + # assert torch.cuda.is_available(), 'cuda must be available on driver' + + super().__init__(config, tokenizer, role_worker_mapping, resource_pool_manager, ray_worker_group_cls, reward_fn, + val_reward_fn) + + self.use_critic = False + + def _validate_config(self): + super()._validate_config() + # TODO: Additional config checks can be added here + config = self.config + + def _create_dataloader(self): + from torch.utils.data import DataLoader, RandomSampler, SequentialSampler + # TODO: we have to make sure the batch size is divisible by the dp size + self.train_dataset = RLHFDataset(parquet_files=self.config.data.train_files, + tokenizer=self.tokenizer, + prompt_key=self.config.data.prompt_key, + max_prompt_length=self.config.data.max_prompt_length, + filter_prompts=True, + return_raw_chat=self.config.data.get('return_raw_chat', False), + truncation='error') + # use sampler for better ckpt resume + if self.config.data.shuffle: + train_dataloader_generator = torch.Generator() + train_dataloader_generator.manual_seed(self.config.data.get('seed', 1)) + sampler = RandomSampler(data_source=self.train_dataset, generator=train_dataloader_generator) + else: + sampler = SequentialSampler(data_source=self.train_dataset) + + self.train_dataloader = DataLoader(dataset=self.train_dataset, + batch_size=int(self.config.data.train_batch_size * + self.config.data.oversample_factor), + drop_last=True, + collate_fn=collate_fn, + sampler=sampler) + + self.val_dataset = RLHFDataset(parquet_files=self.config.data.val_files, + tokenizer=self.tokenizer, + prompt_key=self.config.data.prompt_key, + max_prompt_length=self.config.data.max_prompt_length, + filter_prompts=True, + return_raw_chat=self.config.data.get('return_raw_chat', False), + truncation='error') + self.val_dataloader = DataLoader(dataset=self.val_dataset, + batch_size=len(self.val_dataset), + shuffle=True, + drop_last=True, + collate_fn=collate_fn) + + assert len(self.train_dataloader) >= 1 + assert len(self.val_dataloader) >= 1 + + print(f'Size of train dataloader: {len(self.train_dataloader)}') + print(f'Size of val dataloader: {len(self.val_dataloader)}') + + # inject total_training_steps to actor/critic optim_config. This is hacky. 
+ total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs + + if self.config.trainer.total_training_steps is not None: + total_training_steps = self.config.trainer.total_training_steps + + self.total_training_steps = total_training_steps + print(f'Total training steps: {self.total_training_steps}') + + OmegaConf.set_struct(self.config, True) + with open_dict(self.config): + self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps + self.config.critic.optim.total_training_steps = total_training_steps + + def _save_checkpoint(self): + # path: given_path + `/global_step_{global_steps}` + `/actor` + local_global_step_folder = os.path.join(self.config.trainer.default_local_dir, + f'global_step_{self.global_steps}') + actor_local_path = os.path.join(local_global_step_folder, 'actor') + + actor_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join( + self.config.trainer.default_hdfs_dir, f'global_step_{self.global_steps}', 'actor') + self.actor_rollout_wg.save_checkpoint(actor_local_path, + actor_remote_path, + self.global_steps, + remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save) + + if self.use_rm: + reward_local_path = os.path.join(local_global_step_folder, 'reward') + reward_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join( + self.config.trainer.default_hdfs_dir, f'global_step_{self.global_steps}', 'reward') + self.rm_wg.save_checkpoint(reward_local_path, + reward_remote_path, + self.global_steps, + remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save) + + # save dataloader + dataloader_local_path = os.path.join(local_global_step_folder, 'data.pt') + import dill + torch.save(self.train_dataloader, dataloader_local_path, pickle_module=dill) + + # latest checkpointed iteration tracker (for atomic usage) + local_latest_checkpointed_iteration = os.path.join(self.config.trainer.default_local_dir, + 'latest_checkpointed_iteration.txt') + with open(local_latest_checkpointed_iteration, 'w') as f: + f.write(str(self.global_steps)) + + def _load_checkpoint(self): + if self.config.trainer.resume_mode == 'disable': + return 0 + + # load from hdfs + if self.config.trainer.default_hdfs_dir is not None: + NotImplementedError('load from hdfs is not implemented yet') + else: + checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path + if not os.path.isabs(checkpoint_folder): + working_dir = os.getcwd() + checkpoint_folder = os.path.join(working_dir, checkpoint_folder) + global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest + + # find global_step_folder + if self.config.trainer.resume_mode == 'auto': + if global_step_folder is None: + print('Training from scratch') + return 0 + else: + if not (self.config.trainer.resume_from_path and global_step_folder is not None): + assert isinstance(self.config.trainer.resume_mode, str), "resume ckpt must be str type" + assert 'global_step_' in self.config.trainer.resume_mode, "resume ckpt must specify the global_steps" + global_step_folder = self.config.trainer.resume_mode + if not os.path.isabs(global_step_folder): + working_dir = os.getcwd() + global_step_folder = os.path.join(working_dir, global_step_folder) + print(f'Load from checkpoint folder: {global_step_folder}') + # set global step + self.global_steps = int(global_step_folder.split('global_step_')[-1]) + + print(f'Setting global step to {self.global_steps}') + print(f'Resuming from {global_step_folder}') + 
+ actor_path = os.path.join(global_step_folder, 'actor') + reward_path = os.path.join(global_step_folder, 'reward') + # load actor + self.actor_rollout_wg.load_checkpoint(actor_path, + del_local_after_load=self.config.trainer.del_local_ckpt_after_load) + # load critic + if self.use_rm: + self.critic_wg.load_checkpoint(reward_path, + del_local_after_load=self.config.trainer.del_local_ckpt_after_load) + + # load dataloader, + # TODO: from remote not implemented yet + dataloader_local_path = os.path.join(global_step_folder, 'data.pt') + self.train_dataloader = torch.load(dataloader_local_path) + if isinstance(self.train_dataloader.dataset, RLHFDataset): + self.train_dataloader.dataset.resume_dataset_state() + + def fit(self): + """ + The training loop of PPO. + The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow. + The light-weight advantage computation is done on the driver process. + """ + from verl.utils.tracking import Tracking + from omegaconf import OmegaConf + + logger = Tracking(project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True)) + + self.global_steps = 0 + + # load checkpoint before doing anything + self._load_checkpoint() + + # perform validation before training + # currently, we only support validation using the reward_function. + if self.val_reward_fn is not None and self.config.trainer.get('val_before_train', True): + val_metrics = self._validate() + pprint(f'Initial validation metrics: {val_metrics}') + logger.log(data=val_metrics, step=self.global_steps) + if self.config.trainer.get('val_only', False): + return + + # we start from step 1 + self.global_steps += 1 + + for epoch in range(self.config.trainer.total_epochs): + for batch_dict in self.train_dataloader: + metrics = {} + timing_raw = {} + + batch: DataProto = DataProto.from_single_dict(batch_dict) + + # pop those keys for generation + gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids']) + + with _timer('step', timing_raw): + # generate a batch + with _timer('gen', timing_raw): + gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) + + if self.config.algorithm.adv_estimator == 'remax': + with _timer('gen_max', timing_raw): + gen_baseline_batch = deepcopy(gen_batch) + gen_baseline_batch.meta_info['do_sample'] = False + gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) + + batch = batch.union(gen_baseline_output) + reward_baseline_tensor = self.reward_fn(batch) + reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) + + batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) + + batch.batch['reward_baselines'] = reward_baseline_tensor + + del gen_baseline_batch, gen_baseline_output + + batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], + dtype=object) + # repeat to align with repeated responses in rollout + batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + batch = batch.union(gen_batch_output) + + # balance the number of valid tokens on each dp rank. + # Note that this breaks the order of data inside the batch. 
+ # Please take care when you implement group based adv computation such as GRPO and rloo + # self._balance_batch(batch, metrics=metrics) + + # compute global_valid tokens + batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist() + + # verify + with _timer('verify', timing_raw): + scores = self.reward_fn.verify(batch) + metrics['acc'] = statistics.mean(scores) + + # filter the batch. 1/oversample_factor samples will be kept. If there is a filter, prompts passing it will be prioritized. + + batch = self.filter_and_downsample(scores, batch) + batch.meta_info['n'] = self.config.actor_rollout_ref.rollout.n + + # recompute old_log_probs + with _timer('old_log_prob', timing_raw): + old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) + batch = batch.union(old_log_prob) + + if self.use_reference_policy: + # compute reference log_prob + with _timer('ref', timing_raw): + ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + batch = batch.union(ref_log_prob) + + with _timer('adv', timing_raw): + + if self.use_rm: + update_style = self.config.reward_model.model.get('update', 'none') + if update_style == 'none': # only run forward + reward_output = self.rm_wg.compute_rm_score(batch) + elif update_style == 'after': # update and directly return the reward + reward_output = self.rm_wg.update_rm(batch) + elif update_style == 'before': # update reward model, and then run forward + reward_output = self.rm_wg.update_rm(batch) + if 'metrics' in reward_output.meta_info['metrics']: + reward_output_metrics = reduce_metrics(reward_output.meta_info['metrics']) + metrics.update(reward_output_metrics) + + reward_output = self.rm_wg.compute_rm_score(batch) + else: + raise NotImplementedError + batch = batch.union(reward_output) + if 'metrics' in reward_output.meta_info['metrics']: + reward_output_metrics = reduce_metrics(reward_output.meta_info['metrics']) + metrics.update(reward_output_metrics) + + # compute advantages, executed on the driver process + batch = compute_advantage(batch, + adv_estimator=self.config.algorithm.adv_estimator, + config=self.config) + + # update actor + with _timer('update_actor', timing_raw): + actor_output = self.actor_rollout_wg.update_actor(batch) + actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics']) + metrics.update(actor_output_metrics) + + # validate + if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \ + self.global_steps % self.config.trainer.test_freq == 0: + with _timer('testing', timing_raw): + val_metrics: dict = self._validate() + metrics.update(val_metrics) + + if self.config.trainer.save_freq > 0 and \ + self.global_steps % self.config.trainer.save_freq == 0: + with _timer('save_checkpoint', timing_raw): + self._save_checkpoint() + + # collect metrics + metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) + metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) + + # TODO: make a canonical logger that supports various backend + logger.log(data=metrics, step=self.global_steps) + + self.global_steps += 1 + + if self.global_steps >= self.total_training_steps: + + # perform validation after training + if self.val_reward_fn is not None: + val_metrics = self._validate() + pprint(f'Final validation metrics: {val_metrics}') + logger.log(data=val_metrics, step=self.global_steps) + if self.config.trainer.save_freq > 0 and \ + (self.global_steps - 1) % self.config.trainer.save_freq != 0: + with _timer('save_checkpoint', timing_raw): + 
self._save_checkpoint() + return + + def filter_and_downsample(self, scores, batch: DataProto): + """ + Downsample the batch according to oversample_factor. + Samples passing the filters are prioritized. + """ + n_samples = int(self.config.actor_rollout_ref.rollout.n) + reward_matrix = torch.tensor(scores).reshape(-1, n_samples) + + filter_mask = torch.ones((reward_matrix.shape[0]), dtype=torch.bool) + + if self.config.data.filter_accuracy: + acc_tensor = torch.mean(reward_matrix, dim=-1) + filter_mask[(acc_tensor > self.config.data.accuracy_upper_bound) | + (acc_tensor < self.config.data.accuracy_lower_bound)] = False + + reorder_index = torch.argsort(filter_mask, descending=True) + reorder_index = (reorder_index.unsqueeze(-1) * n_samples + torch.arange(0, n_samples).unsqueeze(0)).view(-1) + batch.reorder(reorder_index[:int(len(batch) // + self.config.data.oversample_factor)]) # this operation is inplace + + return batch diff --git a/recipe/prime/run_prime_qwen.sh b/recipe/prime/run_prime_qwen.sh new file mode 100644 index 00000000..62d3fb1a --- /dev/null +++ b/recipe/prime/run_prime_qwen.sh @@ -0,0 +1,55 @@ +set -x + +gsm8k_train_path=$HOME/data/gsm8k/train.parquet +gsm8k_test_path=$HOME/data/gsm8k/test.parquet +math_train_path=$HOME/data/math/train.parquet +math_test_path=$HOME/data/math/test.parquet + +train_files="['$gsm8k_train_path', '$math_train_path']" +test_files="['$gsm8k_test_path', '$math_test_path']" + +model_path=PRIME-RL/Eurus-2-7B-SFT + +python3 -m recipe.prime.main_prime \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.train_batch_size=64 \ + data.val_batch_size=6312 \ + data.max_prompt_length=1024 \ + data.max_response_length=3072 \ + data.filter_accuracy=True \ + data.accuracy_lower_bound=0.2 \ + data.accuracy_upper_bound=0.8 \ + data.oversample_factor=4 \ + actor_rollout_ref.model.path=$model_path \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=64 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=4 \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ + algorithm.adv_estimator=rloo \ + reward_model.model.path=$model_path \ + reward_model.micro_batch_size=8 \ + reward_model.model.update=before \ + reward_model.model.beta_train=0.05 \ + reward_model.model.optim.lr=1e-6 \ + reward_model.model.optim.grad_clip=10.0 \ + reward_model.model.input_tokenizer=null \ + reward_model.mini_batch_size=64 \ + trainer.val_before_train=False \ + trainer.logger=['console','wandb'] \ + trainer.project_name='prime_example' \ + trainer.experiment_name='Eurus-2-7B-SFT' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=10 \ + trainer.total_epochs=15 $@ diff --git a/scripts/format.sh b/scripts/format.sh index ed49d6f1..d8562caf 100644 --- a/scripts/format.sh +++ b/scripts/format.sh @@ -1,3 +1,3 @@ #!/bin/bash pip3 install --upgrade yapf -python3 -m yapf -ir -vv --style ./.style.yapf verl tests single_controller examples +python3 -m yapf -ir -vv --style ./.style.yapf 
diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml
index bf919089..7b4cd678 100644
--- a/verl/trainer/config/ppo_trainer.yaml
+++ b/verl/trainer/config/ppo_trainer.yaml
@@ -132,7 +132,8 @@ reward_model:
     external_lib: ${actor_rollout_ref.model.external_lib}
     use_remove_padding: False
     fsdp_config:
-      min_num_params: 0
+      wrap_policy:
+        min_num_params: 0
       param_offload: False
       fsdp_size: -1
   micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu
@@ -153,6 +154,7 @@ algorithm:
     kl_coef: 0.001
 
 trainer:
+  balance_batch: True
   total_epochs: 30
   total_training_steps: null
   project_name: verl_examples
diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py
index 80412d3e..0dd87a3f 100644
--- a/verl/trainer/ppo/ray_trainer.py
+++ b/verl/trainer/ppo/ray_trainer.py
@@ -894,7 +894,8 @@ def fit(self):
                     # balance the number of valid tokens on each dp rank.
                     # Note that this breaks the order of data inside the batch.
                     # Please take care when you implement group based adv computation such as GRPO and rloo
-                    self._balance_batch(batch, metrics=metrics)
+                    if self.config.trainer.balance_batch:
+                        self._balance_batch(batch, metrics=metrics)
 
                     # compute global_valid tokens
                     batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist()
diff --git a/verl/workers/reward_manager/prime.py b/verl/workers/reward_manager/prime.py
index f2f8856f..14330c2b 100644
--- a/verl/workers/reward_manager/prime.py
+++ b/verl/workers/reward_manager/prime.py
@@ -84,23 +84,14 @@ def __init__(self, tokenizer, num_examine, compute_score=None) -> None:
         self.num_examine = num_examine  # the number of batches of decoded responses to print to the console
         self.compute_score = compute_score or _default_compute_score
 
-    def __call__(self, data: DataProto):
-        """We will expand this function gradually based on the available datasets"""
-
-        # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn
-        if 'rm_scores' in data.batch.keys():
-            return data.batch['rm_scores']
-
-        reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32)
-
-        already_print_data_sources = {}
-
+    def verify(self, data):
+        """
+        Verify the batch and save the result as the ``acc`` tensor.
+        """
         # batched scoring
         prompt_ids = data.batch['prompts']
-        prompt_length = prompt_ids.shape[-1]
         response_ids = data.batch['responses']
-        valid_response_length = data.batch['attention_mask'][:, prompt_length:].sum(dim=-1)
         sequences_str = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True)
         ground_truth = [data_item.non_tensor_batch['reward_model']['ground_truth'] for data_item in data]
         data_sources = data.non_tensor_batch['data_source']
@@ -119,6 +110,30 @@ def __call__(self, data: DataProto):
         except Exception as e:
             print(f"Unexpected error in batched reward computing. Setting all as 0.: {e}")
             scores = [0. for _ in range(len(sequences_str))]
+        data.batch['acc'] = torch.tensor(scores, dtype=torch.float32, device=prompt_ids.device)
+        return scores
+
+    def __call__(self, data: DataProto):
+        """We will expand this function gradually based on the available datasets"""
+
+        # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn
+        if 'rm_scores' in data.batch.keys():
+            return data.batch['rm_scores']
+
+        reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32)
+
+        already_print_data_sources = {}
+
+        # batched scoring
+        prompt_ids = data.batch['prompts']
+        prompt_length = prompt_ids.shape[-1]
+
+        response_ids = data.batch['responses']
+        valid_response_length = data.batch['attention_mask'][:, prompt_length:].sum(dim=-1)
+        sequences_str = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True)
+        data_sources = data.non_tensor_batch['data_source']
+
+        scores = self.verify(data)
 
         for i in range(len(data)):
             data_source = data_sources[i]