From c8b9c3559ac3beb7bccf85517c07765403df2532 Mon Sep 17 00:00:00 2001
From: zhou fan <1247714429@qq.com>
Date: Sun, 16 Feb 2025 00:18:34 +0800
Subject: [PATCH] fix the split placement example (#281)

The split placement example is outdated; I tried it and encountered some
errors. To address this, the following changes were made in this PR:

1. Copied the content from `verl/trainer/config/ppo_trainer.yaml` to
   `examples/split_placement/config/ppo_trainer_split.yaml`
2. Copied the `RayPPOTrainer.fit` method into the `fit` function in
   `examples/split_placement/split_monkey_patch.py` and modified it to get
   the futures of `critic_output` and `actor_output`
---
 examples/split_placement/README.md            |   2 +-
 .../config/ppo_trainer_split.yaml             |  63 ++++-
 .../split_placement/split_monkey_patch.py     | 246 ++++++++++--------
 3 files changed, 191 insertions(+), 120 deletions(-)

diff --git a/examples/split_placement/README.md b/examples/split_placement/README.md
index 19676c1b..a5529725 100644
--- a/examples/split_placement/README.md
+++ b/examples/split_placement/README.md
@@ -44,7 +44,7 @@ def update_critic(self, data: DataProto):
 ...
 ```
 
-We can also parallelize the computation of `ref_log_prob` and `values` and `rewards` in the split placement. For simplicity of the tutorial, we
+We can also parallelize the computation of `ref_log_prob` and `values` and `rewards` in the split placement. For simplicity of the tutorial, we don't do this in this example.
 
 ### Step 3: Execute these operation in parallel in the single controller process
 To implement the parallel execution of the actor and critic update, the only thing we need to modify in the `ray_trainer.py` is to `get` the concurrent `futures` on the single controller process.
diff --git a/examples/split_placement/config/ppo_trainer_split.yaml b/examples/split_placement/config/ppo_trainer_split.yaml
index a475d7af..6ac24b57 100644
--- a/examples/split_placement/config/ppo_trainer_split.yaml
+++ b/examples/split_placement/config/ppo_trainer_split.yaml
@@ -9,24 +9,32 @@ data:
   val_batch_size: 1312
   return_raw_input_ids: False  # This should be set to true when the tokenizer between policy and rm differs
   return_raw_chat: False
+  shuffle: True
 
 actor_rollout_ref:
   hybrid_engine: True
   model:
     path: ~/models/deepseek-llm-7b-chat
     external_lib: null
-    override_config: {}
-    enable_gradient_checkpointing: False
+    override_config: { }
+    enable_gradient_checkpointing: True
+    use_remove_padding: False
   actor:
     strategy: fsdp  # This is for backward-compatibility
     ppo_mini_batch_size: 256
     ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
-    ppo_micro_batch_size_per_gpu: 64
+    ppo_micro_batch_size_per_gpu: null
+    use_dynamic_bsz: False
+    ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
     grad_clip: 1.0
     clip_ratio: 0.2
     entropy_coeff: 0.001
+    use_kl_loss: False # True for GRPO
+    kl_loss_coef: 0.001 # for grpo
+    kl_loss_type: low_var_kl # for grpo
     ppo_epochs: 1
-    shuffle: True
+    shuffle: False
+    ulysses_sequence_parallel_size: 1 # sp size
     optim:
       lr: 1e-6
       lr_warmup_steps_ratio: 0.  # the total steps will be injected during runtime
@@ -40,6 +48,7 @@ actor_rollout_ref:
       param_offload: False
       grad_offload: False
       optimizer_offload: False
+      fsdp_size: -1
   ref:
     fsdp_config:
       param_offload: False
@@ -47,7 +56,10 @@ actor_rollout_ref:
         # transformer_layer_cls_to_wrap: None
         min_num_params: 0
     log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
-    log_prob_micro_batch_size_per_gpu: 128
+    log_prob_micro_batch_size_per_gpu: null
+    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
+    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
+    ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
   rollout:
     name: vllm
     temperature: 1.0
@@ -66,7 +78,11 @@ actor_rollout_ref:
     max_num_batched_tokens: 8192
     max_num_seqs: 1024
     log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
-    log_prob_micro_batch_size_per_gpu: 128
+    log_prob_micro_batch_size_per_gpu: null
+    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
+    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
+    disable_log_stats: True
+    enable_chunked_prefill: True # could get higher throughput
     # for hf rollout
     do_sample: True
     # number of responses (i.e. num sample times)
@@ -83,9 +99,10 @@ critic:
   model:
     path: ~/models/deepseek-llm-7b-chat
     tokenizer_path: ${actor_rollout_ref.model.path}
-    override_config: {}
+    override_config: { }
     external_lib: ${actor_rollout_ref.model.external_lib}
-    enable_gradient_checkpointing: False
+    enable_gradient_checkpointing: True
+    use_remove_padding: False
     fsdp_config:
       param_offload: False
       grad_offload: False
@@ -93,9 +110,16 @@ critic:
       wrap_policy:
         # transformer_layer_cls_to_wrap: None
         min_num_params: 0
+      fsdp_size: -1
   ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
   ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
-  ppo_micro_batch_size_per_gpu: 64
+  ppo_micro_batch_size_per_gpu: null
+  forward_micro_batch_size: ${critic.ppo_micro_batch_size}
+  forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu}
+  use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
+  ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2
+  forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
+  ulysses_sequence_parallel_size: 1 # sp size
   ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
   shuffle: ${actor_rollout_ref.actor.shuffle}
   grad_clip: 1.0
@@ -108,12 +132,18 @@ reward_model:
     input_tokenizer: ${actor_rollout_ref.model.path}  # set this to null if the chat template is identical
     path: ~/models/FsfairX-LLaMA3-RM-v0.1
     external_lib: ${actor_rollout_ref.model.external_lib}
+    use_remove_padding: False
     fsdp_config:
       min_num_params: 0
       param_offload: False
+      fsdp_size: -1
   micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu
-  micro_batch_size_per_gpu: 64
+  micro_batch_size_per_gpu: null # set a number
   max_length: null
+  ulysses_sequence_parallel_size: 1 # sp size
+  use_dynamic_bsz: ${critic.use_dynamic_bsz}
+  forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}
+  reward_manager: naive
 
 algorithm:
   gamma: 1.0
@@ -126,13 +156,18 @@ algorithm:
 
 trainer:
   total_epochs: 30
+  total_training_steps: null
   project_name: verl_examples
   experiment_name: gsm8k
-  logger: ['console', 'wandb']
+  logger: [ 'console', 'wandb' ]
+  val_generations_to_log_to_wandb: 0
   nnodes: 1
   n_gpus_per_node: 8
   save_freq: -1
-  test_freq: 2
+  # auto: find the last ckpt to resume. If can't find, start from scratch
+  resume_mode: auto # or auto or resume_path if
+  resume_from_path: False
+  test_freq: -1
   critic_warmup: 0
-  default_hdfs_dir: ~/experiments/gsm8k/ppo/${trainer.experiment_name}
-  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
+  default_hdfs_dir: null
+  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
\ No newline at end of file
diff --git a/examples/split_placement/split_monkey_patch.py b/examples/split_placement/split_monkey_patch.py
index 5e09377b..87608c10 100644
--- a/examples/split_placement/split_monkey_patch.py
+++ b/examples/split_placement/split_monkey_patch.py
@@ -14,12 +14,13 @@
 """
 An naive implementation of split placment example
 """
-import os
 from pprint import pprint
-from verl.single_controller.ray import RayResourcePool, RayWorkerGroup, RayClassWithInitArgs
 from verl import DataProto
-from verl.trainer.ppo.ray_trainer import compute_advantage, apply_kl_penalty, reduce_metrics, compute_data_metrics, Role, create_colocated_worker_cls
-from codetiming import Timer
+from verl.trainer.ppo.ray_trainer import compute_advantage, apply_kl_penalty, reduce_metrics, compute_data_metrics, _timer, compute_timing_metrics
+from copy import deepcopy
+import numpy as np
+import torch
+import uuid
 
 
 def fit(self):
@@ -36,126 +37,161 @@ def fit(self):
                       default_backend=self.config.trainer.logger,
                       config=OmegaConf.to_container(self.config, resolve=True))
 
-    global_steps = 0
+    self.global_steps = 0
+
+    # load checkpoint before doing anything
+    self._load_checkpoint()
 
     # perform validation before training
     # currently, we only support validation using the reward_function.
-    if self.val_reward_fn is not None:
+    if self.val_reward_fn is not None and self.config.trainer.get('val_before_train', True):
         val_metrics = self._validate()
         pprint(f'Initial validation metrics: {val_metrics}')
+        logger.log(data=val_metrics, step=self.global_steps)
+        if self.config.trainer.get('val_only', False):
+            return
+
+    # we start from step 1
+    self.global_steps += 1
 
     for epoch in range(self.config.trainer.total_epochs):
         for batch_dict in self.train_dataloader:
             metrics = {}
+            timing_raw = {}
 
             batch: DataProto = DataProto.from_single_dict(batch_dict)
-            # batch = batch.to('cuda')
 
             # pop those keys for generation
             gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids'])
 
-            # generate a batch
-            with Timer(name='gen', logger=None) as timer:
-                gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
-            metrics['timing/gen'] = timer.last
-
-            batch = batch.union(gen_batch_output)
-
-            if self.use_reference_policy:
-                # compute reference log_prob
-                with Timer(name='ref', logger=None) as timer:
-                    ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
-                    batch = batch.union(ref_log_prob)
-                metrics['timing/ref'] = timer.last
-
-            # compute values
-            with Timer(name='values', logger=None) as timer:
-                values = self.critic_wg.compute_values(batch)
-                batch = batch.union(values)
-            metrics['timing/values'] = timer.last
-
-            with Timer(name='adv', logger=None) as timer:
-                # compute scores. Support both model and function-based.
-                # We first compute the scores using reward model. Then, we call reward_fn to combine
-                # the results from reward model and rule-based results.
-                if self.use_rm:
-                    # we first compute reward model score
-                    reward_tensor = self.rm_wg.compute_rm_score(batch)
-                    batch = batch.union(reward_tensor)
-
-                # we combine with rule-based rm
-                reward_tensor = self.reward_fn(batch)
-                batch.batch['token_level_scores'] = reward_tensor
-
-                # compute rewards. apply_kl_penalty if available
-                batch, kl_metrics = apply_kl_penalty(batch,
-                                                     kl_ctrl=self.kl_ctrl,
-                                                     kl_penalty=self.config.algorithm.kl_penalty)
-                metrics.update(kl_metrics)
-
-                # compute advantages, executed on the driver process
-                batch = compute_advantage(batch,
-                                          self.config.algorithm.gamma,
-                                          self.config.algorithm.lam,
-                                          adv_estimator=self.config.algorithm.adv_estimator)
-            metrics['timing/adv'] = timer.last
-
-            # update critic
-            if self.use_critic:
-                with Timer(name='update_critic_call', logger=None) as timer:
-                    critic_output = self.critic_wg.update_critic(batch)
-                metrics['timing/update_critic_call'] = timer.last
-
-            # implement critic warmup
-            if self.config.trainer.critic_warmup <= global_steps:
-                # update actor
-                with Timer(name='update_actor_call', logger=None) as timer:
-                    actor_output = self.actor_rollout_wg.update_actor(batch)
-                metrics['timing/update_acto_call'] = timer.last
-
-            # NOTE: make sure you set blocking=False in update_actor and update_crtic in the worker class
-            with Timer(name='update_actor_critic', logger=None) as timer:
-                # NOTE: get the DataProtoFuture
-                critic_output = critic_output.get()
-                critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics'])
-                metrics.update(critic_output_metrics)
-
-                # NOTE: get the DataProtoFuture
-                actor_output = actor_output.get()
-                actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
-                metrics.update(actor_output_metrics)
-            metrics['timing/update_actor_critic'] = timer.last
-
-            # validate
-            if self.val_reward_fn is not None and (global_steps + 1) % self.config.trainer.test_freq == 0:
-                with Timer(name='testing', logger=None) as timer:
-                    val_metrics: dict = self._validate()
-                    val_metrics = {f'val/{key}': val for key, val in val_metrics.items()}
-                metrics['timing/testing'] = timer.last
-                metrics.update(val_metrics)
+            with _timer('step', timing_raw):
+                # generate a batch
+                with _timer('gen', timing_raw):
+                    gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
 
-            # collect metrics
-            data_metrics = compute_data_metrics(batch=batch)
-            metrics.update(data_metrics)
+                if self.config.algorithm.adv_estimator == 'remax':
+                    with _timer('gen_max', timing_raw):
+                        gen_baseline_batch = deepcopy(gen_batch)
+                        gen_baseline_batch.meta_info['do_sample'] = False
+                        gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
 
-            # TODO: make a canonical logger that supports various backend
-            logger.log(data=metrics, step=global_steps)
+                        batch = batch.union(gen_baseline_output)
+                        reward_baseline_tensor = self.reward_fn(batch)
+                        reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
+
+                        batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
+
+                        batch.batch['reward_baselines'] = reward_baseline_tensor
+
+                        del gen_baseline_batch, gen_baseline_output
+
+                batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))],
+                                                         dtype=object)
+                # repeat to align with repeated responses in rollout
+                batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
+                batch = batch.union(gen_batch_output)
+
+                # balance the number of valid tokens on each dp rank.
+                # Note that this breaks the order of data inside the batch.
+                # Please take care when you implement group based adv computation such as GRPO and rloo
+                self._balance_batch(batch, metrics=metrics)
 
-            if self.config.trainer.save_freq > 0 and (global_steps + 1) % self.config.trainer.save_freq == 0:
-                actor_local_path = os.path.join(self.config.trainer.default_local_dir, 'actor',
-                                                f'global_step_{global_steps}')
-                actor_remote_path = os.path.join(self.config.trainer.default_hdfs_dir, 'actor')
-                self.actor_rollout_wg.save_checkpoint(actor_local_path, actor_remote_path)
+                # compute global_valid tokens
+                batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist()
 
+                # recompute old_log_probs
+                with _timer('old_log_prob', timing_raw):
+                    old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
+                    batch = batch.union(old_log_prob)
+
+                if self.use_reference_policy:
+                    # compute reference log_prob
+                    with _timer('ref', timing_raw):
+                        ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
+                        batch = batch.union(ref_log_prob)
+
+                # compute values
+                if self.use_critic:
+                    with _timer('values', timing_raw):
+                        values = self.critic_wg.compute_values(batch)
+                        batch = batch.union(values)
+
+                with _timer('adv', timing_raw):
+                    # compute scores. Support both model and function-based.
+                    # We first compute the scores using reward model. Then, we call reward_fn to combine
+                    # the results from reward model and rule-based results.
+                    if self.use_rm:
+                        # we first compute reward model score
+                        reward_tensor = self.rm_wg.compute_rm_score(batch)
+                        batch = batch.union(reward_tensor)
+
+                    # we combine with rule-based rm
+                    reward_tensor = self.reward_fn(batch)
+                    batch.batch['token_level_scores'] = reward_tensor
+
+                    # compute rewards. apply_kl_penalty if available
+                    if not self.config.actor_rollout_ref.actor.get('use_kl_loss', False):
+                        batch, kl_metrics = apply_kl_penalty(batch,
+                                                             kl_ctrl=self.kl_ctrl,
+                                                             kl_penalty=self.config.algorithm.kl_penalty)
+                        metrics.update(kl_metrics)
+                    else:
+                        batch.batch['token_level_rewards'] = batch.batch['token_level_scores']
+
+                    # compute advantages, executed on the driver process
+                    batch = compute_advantage(batch,
+                                              adv_estimator=self.config.algorithm.adv_estimator,
+                                              gamma=self.config.algorithm.gamma,
+                                              lam=self.config.algorithm.lam,
+                                              num_repeat=self.config.actor_rollout_ref.rollout.n)
+
+                # update critic
                 if self.use_critic:
-                    critic_local_path = os.path.join(self.config.trainer.default_local_dir, 'critic',
-                                                     f'global_step_{global_steps}')
-                    critic_remote_path = os.path.join(self.config.trainer.default_hdfs_dir, 'critic')
-                    self.critic_wg.save_checkpoint(critic_local_path, critic_remote_path)
+                    with _timer('update_critic_call', timing_raw):
+                        critic_output = self.critic_wg.update_critic(batch)
+
+                # implement critic warmup
+                if self.config.trainer.critic_warmup <= self.global_steps:
+                    # update actor
+                    with _timer('update_actor_call', timing_raw):
+                        actor_output = self.actor_rollout_wg.update_actor(batch)
+
+                # NOTE: make sure you set blocking=False in update_actor and update_crtic in the worker class
+                with _timer('update_actor_critic', timing_raw):
+                    critic_output = critic_output.get()
+                    critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics'])
+                    metrics.update(critic_output_metrics)
+
+                    actor_output = actor_output.get()
+                    actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
+                    metrics.update(actor_output_metrics)
+
+                # validate
+                if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \
+                    self.global_steps % self.config.trainer.test_freq == 0:
+                    with _timer('testing', timing_raw):
+                        val_metrics: dict = self._validate()
+                    metrics.update(val_metrics)
+
+                if self.config.trainer.save_freq > 0 and \
+                        self.global_steps % self.config.trainer.save_freq == 0:
+                    with _timer('save_checkpoint', timing_raw):
+                        self._save_checkpoint()
 
-            global_steps += 1
+            # collect metrics
+            metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
+            metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
 
-    # perform validation after training
-    if self.val_reward_fn is not None:
-        val_metrics = self._validate()
-        pprint(f'Final validation metrics: {val_metrics}')
+            # TODO: make a canonical logger that supports various backend
+            logger.log(data=metrics, step=self.global_steps)
+
+            self.global_steps += 1
+
+            if self.global_steps >= self.total_training_steps:
+
+                # perform validation after training
+                if self.val_reward_fn is not None:
+                    val_metrics = self._validate()
+                    pprint(f'Final validation metrics: {val_metrics}')
+                    logger.log(data=val_metrics, step=self.global_steps)
+                return
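Note on the pattern the patched `fit` relies on: it dispatches `update_critic` and `update_actor` first and only afterwards calls `.get()` on both futures, so the two updates overlap when the workers live on disjoint resource pools (the NOTE comment in the diff reminds you to register both worker methods with `blocking=False`). Below is a minimal sketch of the same dispatch-then-get idea using plain Ray tasks; the `update_actor`/`update_critic` functions are hypothetical stand-ins, not the verl worker-group API.

```python
# Minimal sketch of the dispatch-then-get pattern, assuming only `ray` is installed.
# The two remote functions are hypothetical stand-ins for the worker-group calls.
import time

import ray


@ray.remote
def update_critic(batch):
    time.sleep(2)  # pretend this is the critic update
    return {'critic/loss': 0.1}


@ray.remote
def update_actor(batch):
    time.sleep(2)  # pretend this is the actor update
    return {'actor/loss': 0.2}


if __name__ == '__main__':
    ray.init()
    batch = {'dummy': True}

    # Dispatch both updates without blocking; each call returns a future immediately.
    critic_future = update_critic.remote(batch)
    actor_future = update_actor.remote(batch)

    # Block only after both are in flight, so they run concurrently
    # (~2s wall clock here instead of ~4s when run back to back).
    metrics = {}
    metrics.update(ray.get(critic_future))
    metrics.update(ray.get(actor_future))
    print(metrics)
```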