Skip to content

Commit

Permalink
📖 Add GRPOTrainer to README.md (#2713)
Browse files Browse the repository at this point in the history
* [DOCS] add GRPOTrainer to README.md

I replaced RLOOTrainer with GRPOTrainer because you thought you might want to keep it limited, but let me know if you want both.

* Update README.md

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
  • Loading branch information
burtenshaw and qgallouedec authored Jan 31, 2025
1 parent 5ab15d3 commit 265663a
Showing 1 changed file with 13 additions and 26 deletions.
39 changes: 13 additions & 26 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,39 +137,26 @@ trainer = RewardTrainer(
trainer.train()
```

### `RLOOTrainer`
### `GRPOTrainer`

`RLOOTrainer` implements a [REINFORCE-style optimization](https://huggingface.co/papers/2402.14740) for RLHF that is more performant and memory-efficient than PPO. Here is a basic example of how to use the `RLOOTrainer`:
`GRPOTrainer` implements a [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models]([https://huggingface.co/papers/2402.14740](https://huggingface.co/papers/2402.03300)) for reinforcement learning. Group Relative Policy Optimization (GRPO) is more performant than PPO and was used to train [Deepseek AI's R1](https://huggingface.co/deepseek-ai/DeepSeek-R1).

```python
from trl import RLOOConfig, RLOOTrainer, apply_chat_template
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
AutoModelForSequenceClassification,
AutoTokenizer,
)
from trl import GRPOConfig, GRPOTrainer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
reward_model = AutoModelForSequenceClassification.from_pretrained(
"Qwen/Qwen2.5-0.5B-Instruct", num_labels=1
)
ref_policy = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
policy = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/tldr", split="train")

dataset = load_dataset("trl-lib/ultrafeedback-prompt")
dataset = dataset.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
dataset = dataset.map(lambda x: tokenizer(x["prompt"]), remove_columns="prompt")
# Dummy reward function: rewards completions that are close to 20 characters
def reward_len(completions, **kwargs):
return [abs(20 - len(completion)) for completion in completions]

training_args = RLOOConfig(output_dir="Qwen2.5-0.5B-RL")
trainer = RLOOTrainer(
config=training_args,
processing_class=tokenizer,
policy=policy,
ref_policy=ref_policy,
reward_model=reward_model,
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO", logging_steps=10)
trainer = GRPOTrainer(
model="Qwen/Qwen2-0.5B-Instruct",
reward_funcs=reward_len,
args=training_args,
train_dataset=dataset,
)
trainer.train()
```
Expand Down

0 comments on commit 265663a

Please sign in to comment.