From f4a00dee8e0375b2408220e1f8648934d5033c8d Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Thu, 23 May 2024 12:48:27 +0000 Subject: [PATCH 1/2] Fix PPOv2 / RLOO refactor's stuff --- docs/source/ppov2_trainer.md | 2 +- examples/scripts/ppo/ppo.py | 4 ++-- examples/scripts/ppo/ppo_zephyr.py | 4 ++-- examples/scripts/rloo/rloo.py | 4 ++-- examples/scripts/rloo/rloo_zephyr.py | 4 ++-- tests/test_ppov2_trainer.py | 2 +- tests/test_rloo_trainer.py | 2 +- trl/trainer/ppov2_trainer.py | 8 ++++---- trl/trainer/rloo_config.py | 12 ++++++------ trl/trainer/rloo_trainer.py | 14 +++++++------- 10 files changed, 28 insertions(+), 28 deletions(-) diff --git a/docs/source/ppov2_trainer.md b/docs/source/ppov2_trainer.md index 084351610d..c217f7921b 100644 --- a/docs/source/ppov2_trainer.md +++ b/docs/source/ppov2_trainer.md @@ -13,7 +13,7 @@ References: To just run a PPO script to make sure the trainer can run, you can run the following command to train a PPO model with a dummy reward model. ```bash -python -i examples/scripts/minimal/rloo.py \ +python -i examples/scripts/ppo/ppo.py \ --learning_rate 3e-6 \ --num_ppo_epochs 1 \ --num_mini_batches 1 \ diff --git a/examples/scripts/ppo/ppo.py b/examples/scripts/ppo/ppo.py index 997188e51f..e74d2e52a5 100644 --- a/examples/scripts/ppo/ppo.py +++ b/examples/scripts/ppo/ppo.py @@ -14,7 +14,7 @@ """ -python -i examples/scripts/minimal/ppo.py \ +python -i examples/scripts/ppo/ppo.py \ --learning_rate 3e-6 \ --output_dir models/minimal/ppo \ --per_device_train_batch_size 64 \ @@ -24,7 +24,7 @@ --non_eos_penalty \ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \ - examples/scripts/minimal/ppo.py \ + examples/scripts/ppo/ppo.py \ --output_dir models/minimal/ppo \ --num_ppo_epochs 1 \ --num_mini_batches 1 \ diff --git a/examples/scripts/ppo/ppo_zephyr.py b/examples/scripts/ppo/ppo_zephyr.py index 1fac2df669..afbf585916 100644 --- a/examples/scripts/ppo/ppo_zephyr.py +++ b/examples/scripts/ppo/ppo_zephyr.py @@ -15,7 +15,7 @@ """ -python -i examples/scripts/minimal/ppo_zephyr.py \ +python -i examples/scripts/ppo/ppo_zephyr.py \ --learning_rate 3e-6 \ --output_dir models/minimal/ppo \ --per_device_train_batch_size 1 \ @@ -28,7 +28,7 @@ --truncate_token eos \ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \ - examples/scripts/minimal/ppo_zephyr.py \ + examples/scripts/ppo/ppo_zephyr.py \ --output_dir models/minimal/ppo_zephyr10 \ --num_ppo_epochs 1 \ --num_mini_batches 1 \ diff --git a/examples/scripts/rloo/rloo.py b/examples/scripts/rloo/rloo.py index aa677251ae..c8d6f193e4 100644 --- a/examples/scripts/rloo/rloo.py +++ b/examples/scripts/rloo/rloo.py @@ -14,7 +14,7 @@ """ -python -i examples/scripts/minimal/rloo.py \ +python -i examples/scripts/rloo/rloo.py \ --learning_rate 3e-6 \ --num_ppo_epochs 1 \ --num_mini_batches 1 \ @@ -25,7 +25,7 @@ --model_name_or_path EleutherAI/pythia-1b-deduped \ --non_eos_penalty \ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \ - examples/scripts/minimal/rloo.py \ + examples/scripts/rloo/rloo.py \ --output_dir models/minimal/rloo \ --rloo_k 2 \ --num_ppo_epochs 1 \ diff --git a/examples/scripts/rloo/rloo_zephyr.py b/examples/scripts/rloo/rloo_zephyr.py index cd98ed6f81..aef5bdf058 100644 --- a/examples/scripts/rloo/rloo_zephyr.py +++ b/examples/scripts/rloo/rloo_zephyr.py @@ -15,7 +15,7 @@ """ -python -i examples/scripts/minimal/rloo_zephyr.py \ +python -i examples/scripts/rloo/rloo_zephyr.py \ --learning_rate 3e-6 \ --output_dir 
models/minimal/rloo_zephyr \ --per_device_train_batch_size 64 \ @@ -29,7 +29,7 @@ --response_length 53 \ --sanity_check accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \ - examples/scripts/minimal/rloo_zephyr.py \ + examples/scripts/rloo/rloo_zephyr.py \ --num_ppo_epochs 1 \ --num_mini_batches 1 \ --rloo_k 2 \ diff --git a/tests/test_ppov2_trainer.py b/tests/test_ppov2_trainer.py index a954f3a4fb..ef2d9121fe 100644 --- a/tests/test_ppov2_trainer.py +++ b/tests/test_ppov2_trainer.py @@ -16,7 +16,7 @@ def test(): command = """\ -python -i examples/scripts/minimal/ppo.py \ +python -i examples/scripts/ppo/ppo.py \ --learning_rate 3e-6 \ --output_dir models/minimal/ppo \ --per_device_train_batch_size 5 \ diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py index 452f63af27..d19b219c50 100644 --- a/tests/test_rloo_trainer.py +++ b/tests/test_rloo_trainer.py @@ -18,7 +18,7 @@ def test(): command = """\ -python -i examples/scripts/minimal/rloo.py \ +python -i examples/scripts/rloo/rloo.py \ --learning_rate 3e-6 \ --output_dir models/minimal/rloo \ --per_device_train_batch_size 5 \ diff --git a/trl/trainer/ppov2_trainer.py b/trl/trainer/ppov2_trainer.py index 2ab3e11749..dc74f3b352 100644 --- a/trl/trainer/ppov2_trainer.py +++ b/trl/trainer/ppov2_trainer.py @@ -135,7 +135,7 @@ def __init__( ######### for module in [policy, ref_policy, value_model, reward_model]: disable_dropout_in_model(module) - if args.truncate_token and args.truncate_token == "eos": + if args.stop_token and args.stop_token == "eos": args.stop_token_id = tokenizer.eos_token_id self.model = PolicyAndValueWrapper(policy, value_model) self.create_optimizer_and_scheduler(num_training_steps=args.num_updates) @@ -285,7 +285,7 @@ def repeat_generator(): query_response, logits = generate( unwrapped_model.policy, query, - tokenizer, + tokenizer.pad_token_id, generation_config, ) response = query_response[:, context_length:] @@ -407,7 +407,7 @@ def repeat_generator(): mb_query_responses = query_responses[micro_batch_inds] mb_logprobs = logprobs[micro_batch_inds] - output, vpred_temp = forward(model, mb_query_responses, tokenizer) + output, vpred_temp = forward(model, mb_query_responses, tokenizer.pad_token_id) logits = output.logits[:, context_length - 1 : -1] logits /= args.temperature + 1e-7 new_all_logprobs = F.log_softmax(logits, dim=-1) @@ -543,7 +543,7 @@ def generate_completions(self, sampling: bool = False): query_response, _ = generate( unwrapped_model.policy, query, - tokenizer, + tokenizer.pad_token_id, generation_config, ) response = query_response[:, context_length:] diff --git a/trl/trainer/rloo_config.py b/trl/trainer/rloo_config.py index dd1039c3f4..4c3c303e83 100644 --- a/trl/trainer/rloo_config.py +++ b/trl/trainer/rloo_config.py @@ -37,16 +37,16 @@ class RLOOConfig(OnpolicyRuntimeConfig, TrainingArguments): """the name of the pretrained model to use""" response_length: int = 53 """the length of the response""" - truncate_token: Optional[Literal["eos"]] = None - """the truncate token""" - truncate_token_id: Optional[int] = None - """the truncation token id""" + stop_token: Optional[Literal["eos"]] = None + """the stop token""" + stop_token_id: Optional[int] = None + """the stop token id""" temperature: float = 0.7 """the sampling temperature""" penalty_reward_value: int = -1 - """the reward value for responses that do not contain `truncate_token_id`""" + """the reward value for responses that do not contain `stop_token_id`""" non_eos_penalty: bool = False - """whether to 
penalize responses that do not contain `truncate_token_id`""" + """whether to penalize responses that do not contain `stop_token_id`""" reward_model_path: str = "EleutherAI/pythia-160m" """the path to the reward model""" sft_model_path: str = "EleutherAI/pythia-160m" diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 50c178a967..ecd95d5020 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -120,8 +120,8 @@ def __init__( ######### for module in [policy, ref_policy, reward_model]: disable_dropout_in_model(module) - if args.truncate_token and args.truncate_token == "eos": - args.truncate_token_id = tokenizer.eos_token_id + if args.stop_token and args.stop_token == "eos": + args.stop_token_id = tokenizer.eos_token_id self.model = policy self.create_optimizer_and_scheduler(num_training_steps=args.num_updates) @@ -246,7 +246,7 @@ def repeat_generator(): query_response, logits = generate( unwrapped_model, query, - tokenizer, + tokenizer.pad_token_id, generation_config, ) response = query_response[:, context_length:] @@ -268,7 +268,7 @@ def repeat_generator(): # Response Processing 1. truncate response after the first occurrence of `truncate_token_id` postprocessed_response = response if ( - args.truncate_token_id is not None + args.stop_token_id is not None ): # handle the edge case when truncate_token_id exists but is 0 postprocessed_response = truncate_response( args.stop_token_id, tokenizer.pad_token_id, response @@ -342,7 +342,7 @@ def repeat_generator(): mb_query_responses = query_responses[micro_batch_inds] mb_logprobs = logprobs[micro_batch_inds] - output = forward(model, mb_query_responses, tokenizer) + output = forward(model, mb_query_responses, tokenizer.pad_token_id) logits = output.logits[:, context_length - 1 : -1] logits /= args.temperature + 1e-7 new_all_logprobs = F.log_softmax(logits, dim=-1) @@ -439,12 +439,12 @@ def generate_completions(self, sampling: bool = False): query_response, _ = generate( unwrapped_model, query, - tokenizer, + tokenizer.pad_token_id, generation_config, ) response = query_response[:, context_length:] postprocessed_response = response - if args.truncate_token_id is not None: # handle the edge case when truncate_token_id exists but is 0 + if args.stop_token_id is not None: # handle the edge case when truncate_token_id exists but is 0 postprocessed_response = truncate_response(args.stop_token_id, tokenizer.pad_token_id, response) table["query"].extend(gather_object(tokenizer.batch_decode(query, skip_special_tokens=True))) table["model response"].extend(gather_object(tokenizer.batch_decode(postprocessed_response))) From 9c03c08c7f422d6e34a81752a63fc1f12cce9031 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Thu, 23 May 2024 12:51:40 +0000 Subject: [PATCH 2/2] update terminology to use stop token --- docs/source/ppov2_trainer.md | 6 +++--- docs/source/rloo_trainer.md | 8 ++++---- examples/scripts/ppo/ppo_tldr.py | 4 ++-- examples/scripts/ppo/ppo_zephyr.py | 4 ++-- examples/scripts/rloo/rloo_tldr.py | 4 ++-- examples/scripts/rloo/rloo_zephyr.py | 4 ++-- tests/test_ppov2_trainer.py | 2 +- tests/test_rloo_trainer.py | 2 +- trl/trainer/rloo_trainer.py | 10 ++++------ 9 files changed, 21 insertions(+), 23 deletions(-) diff --git a/docs/source/ppov2_trainer.md b/docs/source/ppov2_trainer.md index c217f7921b..f9c8aaa58d 100644 --- a/docs/source/ppov2_trainer.md +++ b/docs/source/ppov2_trainer.md @@ -55,7 +55,7 @@ The logged metrics are as follows. 
Here is an example [tracked run at Weights an * Debugging TIP: `val/ratio`: this number should float around 1.0, and it gets clipped by `--cliprange 0.2` with PPO's surrogate loss. So if this `ratio` is too high like 2.0 or 1000.0 or too small like 0.1, it means the updates between consecutive policies are too drastic. You should try undertand why this is happening and try to fix it. * Memory TIP: If you are running out of memory, you can try to reduce the `--per_device_train_batch_size` or increase the `--gradient_accumulation_steps` to reduce the memory footprint. * Memory TIP: If you have multiple GPUs, you can also run training with DeepSpeed stage 3 to reduce the memory footprint `accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml`. -* Usage TIP: We recommend to use the "EOS trick" via `--non_eos_penalty --truncate_token eos`, which replaces the score of completions that do not end with an EOS token with a static scalar penalty `--penalty_reward_value`. This can help the model learn to generate more coherent completions. +* Usage TIP: We recommend to use the "EOS trick" via `--non_eos_penalty --stop_token eos`, which replaces the score of completions that do not end with an EOS token with a static scalar penalty `--penalty_reward_value`. This can help the model learn to generate more coherent completions. ## What is my model doing exactly? @@ -186,7 +186,7 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \ --local_rollout_forward_batch_size 16 \ --non_eos_penalty \ - --truncate_token eos \ + --stop_token eos \ # 6.9B PPO experiment accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \ @@ -201,7 +201,7 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml --reward_model_path cleanrl/EleutherAI_pythia-6.9b-deduped__reward__tldr \ --local_rollout_forward_batch_size 2 \ --non_eos_penalty \ - --truncate_token eos \ + --stop_token eos \ ``` 1B experiment can be found here: diff --git a/docs/source/rloo_trainer.md b/docs/source/rloo_trainer.md index 8581e0b7fc..1e47b5b019 100644 --- a/docs/source/rloo_trainer.md +++ b/docs/source/rloo_trainer.md @@ -15,7 +15,7 @@ References: To just run a RLOO script to make sure the trainer can run, you can run the following command to train a RLOO model with a dummy reward model. ```bash -python examples/scripts/minimal/rloo.py \ +python examples/scripts/rloo/rloo.py \ --learning_rate 3e-6 \ --output_dir models/minimal/rloo \ --per_device_train_batch_size 64 \ @@ -57,7 +57,7 @@ The logged metrics are as follows. Here is an example [tracked run at Weights an * Debugging TIP: `val/ratio`: this number should float around 1.0, and it gets clipped by `--cliprange 0.2` with PPO's surrogate loss. So if this `ratio` is too high like 2.0 or 1000.0 or too small like 0.1, it means the updates between consecutive policies are too drastic. You should try undertand why this is happening and try to fix it. * Memory TIP: If you are running out of memory, you can try to reduce the `--per_device_train_batch_size` or increase the `--gradient_accumulation_steps` to reduce the memory footprint. * Memory TIP: If you have multiple GPUs, you can also run training with DeepSpeed stage 3 to reduce the memory footprint `accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml`. 
-* Usage TIP: We recommend to use the "EOS trick" via `--non_eos_penalty --truncate_token eos`, which replaces the score of completions that do not end with an EOS token with a static scalar penalty `--penalty_reward_value`. This can help the model learn to generate more coherent completions. +* Usage TIP: We recommend to use the "EOS trick" via `--non_eos_penalty --stop_token eos`, which replaces the score of completions that do not end with an EOS token with a static scalar penalty `--penalty_reward_value`. This can help the model learn to generate more coherent completions. ## What is my model doing exactly? @@ -226,7 +226,7 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \ --local_rollout_forward_batch_size 16 \ --non_eos_penalty \ - --truncate_token eos \ + --stop_token eos \ --kl_coef 0.03 # 6.9B RLOO experiment @@ -244,7 +244,7 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml --reward_model_path cleanrl/EleutherAI_pythia-6.9b-deduped__reward__tldr \ --local_rollout_forward_batch_size 2 \ --non_eos_penalty \ - --truncate_token eos \ + --stop_token eos \ --kl_coef 0.03 ``` diff --git a/examples/scripts/ppo/ppo_tldr.py b/examples/scripts/ppo/ppo_tldr.py index 936d676480..d9ed61f60f 100644 --- a/examples/scripts/ppo/ppo_tldr.py +++ b/examples/scripts/ppo/ppo_tldr.py @@ -25,7 +25,7 @@ --sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \ --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \ --non_eos_penalty \ - --truncate_token eos \ + --stop_token eos \ --response_length 53 \ --sanity_check @@ -41,7 +41,7 @@ --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \ --local_rollout_forward_batch_size 16 \ --non_eos_penalty \ - --truncate_token eos \ + --stop_token eos \ """ diff --git a/examples/scripts/ppo/ppo_zephyr.py b/examples/scripts/ppo/ppo_zephyr.py index afbf585916..8c5d98e35e 100644 --- a/examples/scripts/ppo/ppo_zephyr.py +++ b/examples/scripts/ppo/ppo_zephyr.py @@ -25,7 +25,7 @@ --sft_model_path EleutherAI/pythia-1b-deduped \ --reward_model_path EleutherAI/pythia-1b-deduped \ --non_eos_penalty \ - --truncate_token eos \ + --stop_token eos \ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \ examples/scripts/ppo/ppo_zephyr.py \ @@ -43,7 +43,7 @@ --deepspeed3 \ --kl_coef 0.10 \ --non_eos_penalty \ - --truncate_token eos \ + --stop_token eos \ --response_length 512 \ """ diff --git a/examples/scripts/rloo/rloo_tldr.py b/examples/scripts/rloo/rloo_tldr.py index d2729e028c..98e5f1bf58 100644 --- a/examples/scripts/rloo/rloo_tldr.py +++ b/examples/scripts/rloo/rloo_tldr.py @@ -25,7 +25,7 @@ --sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \ --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \ --non_eos_penalty \ - --truncate_token eos \ + --stop_token eos \ --response_length 53 \ --sanity_check @@ -43,7 +43,7 @@ --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \ --local_rollout_forward_batch_size 16 \ --non_eos_penalty \ - --truncate_token eos \ + --stop_token eos \ """ diff --git a/examples/scripts/rloo/rloo_zephyr.py b/examples/scripts/rloo/rloo_zephyr.py index aef5bdf058..0121abd395 100644 --- a/examples/scripts/rloo/rloo_zephyr.py +++ b/examples/scripts/rloo/rloo_zephyr.py @@ -25,7 +25,7 @@ --sft_model_path HuggingFaceH4/mistral-7b-sft-beta \ --reward_model_path weqweasdas/RM-Mistral-7B \ 
--non_eos_penalty \ - --truncate_token eos \ + --stop_token eos \ --response_length 53 \ --sanity_check accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \ @@ -45,7 +45,7 @@ --deepspeed3 \ --kl_coef 0.10 \ --non_eos_penalty \ - --truncate_token eos \ + --stop_token eos \ --response_length 512 \ """ diff --git a/tests/test_ppov2_trainer.py b/tests/test_ppov2_trainer.py index ef2d9121fe..7e8fe2f3fd 100644 --- a/tests/test_ppov2_trainer.py +++ b/tests/test_ppov2_trainer.py @@ -24,7 +24,7 @@ def test(): --total_episodes 10 \ --model_name_or_path EleutherAI/pythia-14m \ --non_eos_penalty \ - --truncate_token eos \ + --stop_token eos \ """ subprocess.run( command, diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py index d19b219c50..fbeec86125 100644 --- a/tests/test_rloo_trainer.py +++ b/tests/test_rloo_trainer.py @@ -26,7 +26,7 @@ def test(): --total_episodes 10 \ --model_name_or_path EleutherAI/pythia-14m \ --non_eos_penalty \ - --truncate_token eos \ + --stop_token eos \ """ subprocess.run( command, diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index ecd95d5020..02f69df5e5 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -265,11 +265,9 @@ def repeat_generator(): del ref_output, ref_logits, ref_all_logprob torch.cuda.empty_cache() - # Response Processing 1. truncate response after the first occurrence of `truncate_token_id` + # Response Processing 1. truncate response after the first occurrence of `stop_token_id` postprocessed_response = response - if ( - args.stop_token_id is not None - ): # handle the edge case when truncate_token_id exists but is 0 + if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 postprocessed_response = truncate_response( args.stop_token_id, tokenizer.pad_token_id, response ) @@ -299,7 +297,7 @@ def repeat_generator(): torch.cuda.empty_cache() gc.collect() - # Response Processing 3. filter response. Ensure that the sample contains truncate_token_id + # Response Processing 3. filter response. Ensure that the sample contains stop_token_id # responses not passing that filter will receive a low (fixed) score # only query humans on responses that pass that filter contain_eos_token = torch.any(postprocessed_responses == tokenizer.eos_token_id, dim=-1) @@ -444,7 +442,7 @@ def generate_completions(self, sampling: bool = False): ) response = query_response[:, context_length:] postprocessed_response = response - if args.stop_token_id is not None: # handle the edge case when truncate_token_id exists but is 0 + if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 postprocessed_response = truncate_response(args.stop_token_id, tokenizer.pad_token_id, response) table["query"].extend(gather_object(tokenizer.batch_decode(query, skip_special_tokens=True))) table["model response"].extend(gather_object(tokenizer.batch_decode(postprocessed_response)))
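
For reference, the two response post-processing steps these diffs keep touching — truncating at the first `stop_token_id` via `truncate_response(args.stop_token_id, tokenizer.pad_token_id, response)` and the "EOS trick" controlled by `--non_eos_penalty` / `--penalty_reward_value` — behave roughly as in the standalone sketch below. This is an illustrative approximation, not the TRL implementation; `apply_eos_penalty` is a hypothetical helper name used only for this example.

```python
import torch


def truncate_response(stop_token_id: int, pad_token_id: int, responses: torch.Tensor) -> torch.Tensor:
    """Keep each response up to (and including) its first `stop_token_id`;
    every token after it is overwritten with `pad_token_id`.

    `responses` is a (batch, response_length) tensor of generated token ids.
    """
    is_stop = responses == stop_token_id
    seq_len = responses.shape[1]
    # index of the first stop token in each row; rows without one get `seq_len`,
    # which leaves that row untouched
    first_stop = torch.where(
        is_stop.any(dim=-1),
        is_stop.int().argmax(dim=-1),
        torch.full((responses.shape[0],), seq_len, device=responses.device),
    )
    positions = torch.arange(seq_len, device=responses.device).unsqueeze(0)
    return responses.masked_fill(positions > first_stop.unsqueeze(1), pad_token_id)


def apply_eos_penalty(scores: torch.Tensor, responses: torch.Tensor, eos_token_id: int,
                      penalty_reward_value: float = -1.0) -> torch.Tensor:
    """The "EOS trick": responses that never emit the EOS token get a fixed
    scalar penalty in place of their reward-model score."""
    contain_eos = torch.any(responses == eos_token_id, dim=-1)
    return torch.where(contain_eos, scores, torch.full_like(scores, penalty_reward_value))
```

With `--stop_token eos`, the trainers above set `args.stop_token_id = tokenizer.eos_token_id`, so truncation happens at the first EOS token; with `--non_eos_penalty`, completions that never reach EOS receive `penalty_reward_value` (default `-1`) as their score, which is what the Usage TIP in the docs describes.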