Fix ppov2 test case #1661

Merged · 2 commits · May 23, 2024
8 changes: 4 additions & 4 deletions docs/source/ppov2_trainer.md
@@ -13,7 +13,7 @@ References:
To quickly check that the PPO trainer runs, you can use the following command to train a PPO model with a dummy reward model.

```bash
- python -i examples/scripts/minimal/rloo.py \
+ python -i examples/scripts/ppo/ppo.py \
--learning_rate 3e-6 \
--num_ppo_epochs 1 \
--num_mini_batches 1 \
@@ -55,7 +55,7 @@ The logged metrics are as follows. Here is an example [tracked run at Weights an
* Debugging TIP: `val/ratio`: this number should hover around 1.0, and it gets clipped by `--cliprange 0.2` in PPO's surrogate loss. If this `ratio` is too high (e.g., 2.0 or 1000.0) or too low (e.g., 0.1), the updates between consecutive policies are too drastic. You should try to understand why this is happening and fix it.
* Memory TIP: If you are running out of memory, you can try to reduce the `--per_device_train_batch_size` or increase the `--gradient_accumulation_steps` to reduce the memory footprint.
* Memory TIP: If you have multiple GPUs, you can also run training with DeepSpeed stage 3 to reduce the memory footprint: `accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml`.
- * Usage TIP: We recommend using the "EOS trick" via `--non_eos_penalty --truncate_token eos`, which replaces the score of completions that do not end with an EOS token with a static scalar penalty `--penalty_reward_value`. This can help the model learn to generate more coherent completions.
+ * Usage TIP: We recommend using the "EOS trick" via `--non_eos_penalty --stop_token eos`, which replaces the score of completions that do not end with an EOS token with a static scalar penalty `--penalty_reward_value`. This can help the model learn to generate more coherent completions.


## What is my model doing exactly?
@@ -186,7 +186,7 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml
--reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
--local_rollout_forward_batch_size 16 \
--non_eos_penalty \
- --truncate_token eos \
+ --stop_token eos \

# 6.9B PPO experiment
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
@@ -201,7 +201,7 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml
--reward_model_path cleanrl/EleutherAI_pythia-6.9b-deduped__reward__tldr \
--local_rollout_forward_batch_size 2 \
--non_eos_penalty \
- --truncate_token eos \
+ --stop_token eos \
```

1B experiment can be found here:
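The "EOS trick" referenced by the renamed `--stop_token eos` flag replaces the reward of any completion that never emits an EOS token with the fixed `--penalty_reward_value`. Below is a minimal PyTorch sketch of that idea, assuming a batch of token-id responses and scalar reward scores; it is an illustration, not the trainer's exact code.

```python
import torch

def apply_non_eos_penalty(scores, responses, eos_token_id, penalty_reward_value=-1.0):
    # Completions that never produced the EOS token receive a fixed scalar
    # penalty instead of their reward-model score (sketch of the "EOS trick").
    contain_eos = torch.any(responses == eos_token_id, dim=-1)
    penalty = torch.full_like(scores, penalty_reward_value)
    return torch.where(contain_eos, scores, penalty)

# Toy batch: completions 0 and 2 contain the EOS token (id 0), completion 1 does not.
responses = torch.tensor([[5, 7, 0], [5, 7, 9], [3, 0, 0]])
scores = torch.tensor([0.8, 0.9, 0.4])
print(apply_non_eos_penalty(scores, responses, eos_token_id=0))  # tensor([ 0.8000, -1.0000,  0.4000])
```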
8 changes: 4 additions & 4 deletions docs/source/rloo_trainer.md
@@ -15,7 +15,7 @@ References:
To quickly check that the RLOO trainer runs, you can use the following command to train an RLOO model with a dummy reward model.

```bash
- python examples/scripts/minimal/rloo.py \
+ python examples/scripts/rloo/rloo.py \
--learning_rate 3e-6 \
--output_dir models/minimal/rloo \
--per_device_train_batch_size 64 \
@@ -57,7 +57,7 @@ The logged metrics are as follows. Here is an example [tracked run at Weights an
* Debugging TIP: `val/ratio`: this number should hover around 1.0, and it gets clipped by `--cliprange 0.2` in PPO's surrogate loss. If this `ratio` is too high (e.g., 2.0 or 1000.0) or too low (e.g., 0.1), the updates between consecutive policies are too drastic. You should try to understand why this is happening and fix it.
* Memory TIP: If you are running out of memory, you can try to reduce the `--per_device_train_batch_size` or increase the `--gradient_accumulation_steps` to reduce the memory footprint.
* Memory TIP: If you have multiple GPUs, you can also run training with DeepSpeed stage 3 to reduce the memory footprint: `accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml`.
- * Usage TIP: We recommend using the "EOS trick" via `--non_eos_penalty --truncate_token eos`, which replaces the score of completions that do not end with an EOS token with a static scalar penalty `--penalty_reward_value`. This can help the model learn to generate more coherent completions.
+ * Usage TIP: We recommend using the "EOS trick" via `--non_eos_penalty --stop_token eos`, which replaces the score of completions that do not end with an EOS token with a static scalar penalty `--penalty_reward_value`. This can help the model learn to generate more coherent completions.


## What is my model doing exactly?
@@ -226,7 +226,7 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml
--reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
--local_rollout_forward_batch_size 16 \
--non_eos_penalty \
- --truncate_token eos \
+ --stop_token eos \
--kl_coef 0.03

# 6.9B RLOO experiment
@@ -244,7 +244,7 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml
--reward_model_path cleanrl/EleutherAI_pythia-6.9b-deduped__reward__tldr \
--local_rollout_forward_batch_size 2 \
--non_eos_penalty \
- --truncate_token eos \
+ --stop_token eos \
--kl_coef 0.03
```

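The `val/ratio` debugging tip in both trainer docs refers to the importance-sampling ratio in PPO's clipped surrogate objective, which `--cliprange 0.2` bounds. The sketch below shows how that ratio is formed and clipped; it illustrates the objective rather than the trainer's implementation.

```python
import torch

def clipped_surrogate_loss(new_logprobs, old_logprobs, advantages, cliprange=0.2):
    # ratio = pi_new / pi_old per action; it should hover around 1.0.
    # Values far from 1.0 (e.g. 2.0 or 0.1) signal overly drastic policy updates.
    ratio = torch.exp(new_logprobs - old_logprobs)
    unclipped = -advantages * ratio
    clipped = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
    return torch.max(unclipped, clipped).mean(), ratio.mean()

new_lp = torch.tensor([-1.00, -0.90, -1.20])
old_lp = torch.tensor([-1.10, -1.00, -1.00])
advantages = torch.tensor([0.5, -0.2, 0.1])
loss, mean_ratio = clipped_surrogate_loss(new_lp, old_lp, advantages)
print(float(mean_ratio))  # close to 1.0 when consecutive policies are close
```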
4 changes: 2 additions & 2 deletions examples/scripts/ppo/ppo.py
@@ -14,7 +14,7 @@


"""
- python -i examples/scripts/minimal/ppo.py \
+ python -i examples/scripts/ppo/ppo.py \
--learning_rate 3e-6 \
--output_dir models/minimal/ppo \
--per_device_train_batch_size 64 \
@@ -24,7 +24,7 @@
--non_eos_penalty \

accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
- examples/scripts/minimal/ppo.py \
+ examples/scripts/ppo/ppo.py \
--output_dir models/minimal/ppo \
--num_ppo_epochs 1 \
--num_mini_batches 1 \
4 changes: 2 additions & 2 deletions examples/scripts/ppo/ppo_tldr.py
@@ -25,7 +25,7 @@
--sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \
--reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
--non_eos_penalty \
- --truncate_token eos \
+ --stop_token eos \
--response_length 53 \
--sanity_check

@@ -41,7 +41,7 @@
--reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
--local_rollout_forward_batch_size 16 \
--non_eos_penalty \
- --truncate_token eos \
+ --stop_token eos \
"""


8 changes: 4 additions & 4 deletions examples/scripts/ppo/ppo_zephyr.py
@@ -15,7 +15,7 @@


"""
- python -i examples/scripts/minimal/ppo_zephyr.py \
+ python -i examples/scripts/ppo/ppo_zephyr.py \
--learning_rate 3e-6 \
--output_dir models/minimal/ppo \
--per_device_train_batch_size 1 \
@@ -25,10 +25,10 @@
--sft_model_path EleutherAI/pythia-1b-deduped \
--reward_model_path EleutherAI/pythia-1b-deduped \
--non_eos_penalty \
- --truncate_token eos \
+ --stop_token eos \

accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
- examples/scripts/minimal/ppo_zephyr.py \
+ examples/scripts/ppo/ppo_zephyr.py \
--output_dir models/minimal/ppo_zephyr10 \
--num_ppo_epochs 1 \
--num_mini_batches 1 \
@@ -43,7 +43,7 @@
--deepspeed3 \
--kl_coef 0.10 \
--non_eos_penalty \
- --truncate_token eos \
+ --stop_token eos \
--response_length 512 \
"""

4 changes: 2 additions & 2 deletions examples/scripts/rloo/rloo.py
@@ -14,7 +14,7 @@


"""
- python -i examples/scripts/minimal/rloo.py \
+ python -i examples/scripts/rloo/rloo.py \
--learning_rate 3e-6 \
--num_ppo_epochs 1 \
--num_mini_batches 1 \
@@ -25,7 +25,7 @@
--model_name_or_path EleutherAI/pythia-1b-deduped \
--non_eos_penalty \
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
- examples/scripts/minimal/rloo.py \
+ examples/scripts/rloo/rloo.py \
--output_dir models/minimal/rloo \
--rloo_k 2 \
--num_ppo_epochs 1 \
4 changes: 2 additions & 2 deletions examples/scripts/rloo/rloo_tldr.py
@@ -25,7 +25,7 @@
--sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \
--reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
--non_eos_penalty \
- --truncate_token eos \
+ --stop_token eos \
--response_length 53 \
--sanity_check

@@ -43,7 +43,7 @@
--reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
--local_rollout_forward_batch_size 16 \
--non_eos_penalty \
- --truncate_token eos \
+ --stop_token eos \
"""


8 changes: 4 additions & 4 deletions examples/scripts/rloo/rloo_zephyr.py
@@ -15,7 +15,7 @@


"""
- python -i examples/scripts/minimal/rloo_zephyr.py \
+ python -i examples/scripts/rloo/rloo_zephyr.py \
--learning_rate 3e-6 \
--output_dir models/minimal/rloo_zephyr \
--per_device_train_batch_size 64 \
@@ -25,11 +25,11 @@
--sft_model_path HuggingFaceH4/mistral-7b-sft-beta \
--reward_model_path weqweasdas/RM-Mistral-7B \
--non_eos_penalty \
- --truncate_token eos \
+ --stop_token eos \
--response_length 53 \
--sanity_check
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
- examples/scripts/minimal/rloo_zephyr.py \
+ examples/scripts/rloo/rloo_zephyr.py \
--num_ppo_epochs 1 \
--num_mini_batches 1 \
--rloo_k 2 \
@@ -45,7 +45,7 @@
--deepspeed3 \
--kl_coef 0.10 \
--non_eos_penalty \
- --truncate_token eos \
+ --stop_token eos \
--response_length 512 \
"""

4 changes: 2 additions & 2 deletions tests/test_ppov2_trainer.py
@@ -16,15 +16,15 @@

def test():
command = """\
- python -i examples/scripts/minimal/ppo.py \
+ python -i examples/scripts/ppo/ppo.py \
--learning_rate 3e-6 \
--output_dir models/minimal/ppo \
--per_device_train_batch_size 5 \
--gradient_accumulation_steps 1 \
--total_episodes 10 \
--model_name_or_path EleutherAI/pythia-14m \
--non_eos_penalty \
- --truncate_token eos \
+ --stop_token eos \
"""
subprocess.run(
command,
4 changes: 2 additions & 2 deletions tests/test_rloo_trainer.py
@@ -18,15 +18,15 @@

def test():
command = """\
- python -i examples/scripts/minimal/rloo.py \
+ python -i examples/scripts/rloo/rloo.py \
--learning_rate 3e-6 \
--output_dir models/minimal/rloo \
--per_device_train_batch_size 5 \
--gradient_accumulation_steps 1 \
--total_episodes 10 \
--model_name_or_path EleutherAI/pythia-14m \
--non_eos_penalty \
- --truncate_token eos \
+ --stop_token eos \
"""
subprocess.run(
command,
8 changes: 4 additions & 4 deletions trl/trainer/ppov2_trainer.py
@@ -135,7 +135,7 @@ def __init__(
#########
for module in [policy, ref_policy, value_model, reward_model]:
disable_dropout_in_model(module)
- if args.truncate_token and args.truncate_token == "eos":
+ if args.stop_token and args.stop_token == "eos":
args.stop_token_id = tokenizer.eos_token_id
self.model = PolicyAndValueWrapper(policy, value_model)
self.create_optimizer_and_scheduler(num_training_steps=args.num_updates)
@@ -285,7 +285,7 @@ def repeat_generator():
query_response, logits = generate(
unwrapped_model.policy,
query,
- tokenizer,
+ tokenizer.pad_token_id,
generation_config,
)
response = query_response[:, context_length:]
@@ -407,7 +407,7 @@ def repeat_generator():
mb_query_responses = query_responses[micro_batch_inds]
mb_logprobs = logprobs[micro_batch_inds]

- output, vpred_temp = forward(model, mb_query_responses, tokenizer)
+ output, vpred_temp = forward(model, mb_query_responses, tokenizer.pad_token_id)
logits = output.logits[:, context_length - 1 : -1]
logits /= args.temperature + 1e-7
new_all_logprobs = F.log_softmax(logits, dim=-1)
@@ -543,7 +543,7 @@ def generate_completions(self, sampling: bool = False):
query_response, _ = generate(
unwrapped_model.policy,
query,
- tokenizer,
+ tokenizer.pad_token_id,
generation_config,
)
response = query_response[:, context_length:]
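The call-site changes above pass `tokenizer.pad_token_id` to the `generate` and `forward` helpers instead of the whole tokenizer, since the pad id is all they need to build attention masks. The hypothetical helper below sketches that pattern; it is not trl's actual utility.

```python
import torch

def masked_forward(model, query_responses, pad_token_id):
    # The only tokenizer-specific information required is the pad token id:
    # it defines the attention mask and which positions to zero out.
    attention_mask = (query_responses != pad_token_id).long()
    input_ids = query_responses.masked_fill(query_responses == pad_token_id, 0)
    return model(input_ids=input_ids, attention_mask=attention_mask)
```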
12 changes: 6 additions & 6 deletions trl/trainer/rloo_config.py
@@ -37,16 +37,16 @@ class RLOOConfig(OnpolicyRuntimeConfig, TrainingArguments):
"""the name of the pretrained model to use"""
response_length: int = 53
"""the length of the response"""
- truncate_token: Optional[Literal["eos"]] = None
- """the truncate token"""
- truncate_token_id: Optional[int] = None
- """the truncation token id"""
+ stop_token: Optional[Literal["eos"]] = None
+ """the stop token"""
+ stop_token_id: Optional[int] = None
+ """the stop token id"""
temperature: float = 0.7
"""the sampling temperature"""
penalty_reward_value: int = -1
- """the reward value for responses that do not contain `truncate_token_id`"""
+ """the reward value for responses that do not contain `stop_token_id`"""
non_eos_penalty: bool = False
- """whether to penalize responses that do not contain `truncate_token_id`"""
+ """whether to penalize responses that do not contain `stop_token_id`"""
reward_model_path: str = "EleutherAI/pythia-160m"
"""the path to the reward model"""
sft_model_path: str = "EleutherAI/pythia-160m"
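After the rename, the `stop_token` option resolves to `stop_token_id` exactly as the trainer snippets in this PR show. A minimal sketch of that resolution, using the `EleutherAI/pythia-14m` tokenizer that appears in the tests:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m")

stop_token = "eos"    # previously `truncate_token`
stop_token_id = None  # previously `truncate_token_id`

# Mirrors the trainer logic: if the stop token is "eos", use the tokenizer's EOS id.
if stop_token and stop_token == "eos":
    stop_token_id = tokenizer.eos_token_id

print(stop_token_id)  # the id of the model's EOS token
```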
20 changes: 9 additions & 11 deletions trl/trainer/rloo_trainer.py
@@ -120,8 +120,8 @@ def __init__(
#########
for module in [policy, ref_policy, reward_model]:
disable_dropout_in_model(module)
- if args.truncate_token and args.truncate_token == "eos":
- args.truncate_token_id = tokenizer.eos_token_id
+ if args.stop_token and args.stop_token == "eos":
+ args.stop_token_id = tokenizer.eos_token_id
self.model = policy
self.create_optimizer_and_scheduler(num_training_steps=args.num_updates)

@@ -246,7 +246,7 @@ def repeat_generator():
query_response, logits = generate(
unwrapped_model,
query,
- tokenizer,
+ tokenizer.pad_token_id,
generation_config,
)
response = query_response[:, context_length:]
@@ -265,11 +265,9 @@ def repeat_generator():
del ref_output, ref_logits, ref_all_logprob
torch.cuda.empty_cache()

- # Response Processing 1. truncate response after the first occurrence of `truncate_token_id`
+ # Response Processing 1. truncate response after the first occurrence of `stop_token_id`
postprocessed_response = response
- if (
- args.truncate_token_id is not None
- ): # handle the edge case when truncate_token_id exists but is 0
+ if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
postprocessed_response = truncate_response(
args.stop_token_id, tokenizer.pad_token_id, response
)
@@ -299,7 +297,7 @@ def repeat_generator():
torch.cuda.empty_cache()
gc.collect()

- # Response Processing 3. filter response. Ensure that the sample contains stop_token_id
+ # Response Processing 3. filter response. Ensure that the sample contains stop_token_id
# responses not passing that filter will receive a low (fixed) score
# only query humans on responses that pass that filter
contain_eos_token = torch.any(postprocessed_responses == tokenizer.eos_token_id, dim=-1)
@@ -342,7 +340,7 @@ def repeat_generator():
mb_query_responses = query_responses[micro_batch_inds]
mb_logprobs = logprobs[micro_batch_inds]

- output = forward(model, mb_query_responses, tokenizer)
+ output = forward(model, mb_query_responses, tokenizer.pad_token_id)
logits = output.logits[:, context_length - 1 : -1]
logits /= args.temperature + 1e-7
new_all_logprobs = F.log_softmax(logits, dim=-1)
@@ -439,12 +437,12 @@ def generate_completions(self, sampling: bool = False):
query_response, _ = generate(
unwrapped_model,
query,
- tokenizer,
+ tokenizer.pad_token_id,
generation_config,
)
response = query_response[:, context_length:]
postprocessed_response = response
- if args.truncate_token_id is not None: # handle the edge case when truncate_token_id exists but is 0
+ if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
postprocessed_response = truncate_response(args.stop_token_id, tokenizer.pad_token_id, response)
table["query"].extend(gather_object(tokenizer.batch_decode(query, skip_special_tokens=True)))
table["model response"].extend(gather_object(tokenizer.batch_decode(postprocessed_response)))
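`truncate_response`, called above with the renamed `stop_token_id`, cuts each sampled response at the first occurrence of the stop token. The stand-in below shows the idea only; it is not trl's implementation, and keeping the stop token itself is an assumption.

```python
import torch

def truncate_after_stop(stop_token_id, pad_token_id, responses):
    # Keep everything up to and including the first stop token,
    # then pad the remainder of each row (illustrative sketch).
    truncated = responses.clone()
    for row in truncated:
        hits = (row == stop_token_id).nonzero(as_tuple=True)[0]
        if len(hits) > 0:
            first = int(hits[0])
            row[first + 1:] = pad_token_id
    return truncated

responses = torch.tensor([[5, 7, 2, 9, 4],
                          [5, 7, 9, 4, 4]])
print(truncate_after_stop(stop_token_id=2, pad_token_id=0, responses=responses))
# tensor([[5, 7, 2, 0, 0],
#         [5, 7, 9, 4, 4]])
```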