diff --git a/README.md b/README.md
index b7f1e5b7aa..d0af49c83a 100644
--- a/README.md
+++ b/README.md
@@ -137,39 +137,26 @@ trainer = RewardTrainer(
 trainer.train()
 ```
 
-### `RLOOTrainer`
+### `GRPOTrainer`
 
-`RLOOTrainer` implements a [REINFORCE-style optimization](https://huggingface.co/papers/2402.14740) for RLHF that is more performant and memory-efficient than PPO. Here is a basic example of how to use the `RLOOTrainer`:
+`GRPOTrainer` implements Group Relative Policy Optimization (GRPO), the reinforcement learning algorithm introduced in [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300). GRPO is more performant than PPO and was used to train [Deepseek AI's R1](https://huggingface.co/deepseek-ai/DeepSeek-R1).
 
 ```python
-from trl import RLOOConfig, RLOOTrainer, apply_chat_template
 from datasets import load_dataset
-from transformers import (
-    AutoModelForCausalLM,
-    AutoModelForSequenceClassification,
-    AutoTokenizer,
-)
+from trl import GRPOConfig, GRPOTrainer
 
-tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
-reward_model = AutoModelForSequenceClassification.from_pretrained(
-    "Qwen/Qwen2.5-0.5B-Instruct", num_labels=1
-)
-ref_policy = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
-policy = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
+dataset = load_dataset("trl-lib/tldr", split="train")
 
-dataset = load_dataset("trl-lib/ultrafeedback-prompt")
-dataset = dataset.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
-dataset = dataset.map(lambda x: tokenizer(x["prompt"]), remove_columns="prompt")
+# Dummy reward function: rewards completions that are close to 20 characters
+def reward_len(completions, **kwargs):
+    return [-abs(20 - len(completion)) for completion in completions]
 
-training_args = RLOOConfig(output_dir="Qwen2.5-0.5B-RL")
-trainer = RLOOTrainer(
-    config=training_args,
-    processing_class=tokenizer,
-    policy=policy,
-    ref_policy=ref_policy,
-    reward_model=reward_model,
-    train_dataset=dataset["train"],
-    eval_dataset=dataset["test"],
+training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO", logging_steps=10)
+trainer = GRPOTrainer(
+    model="Qwen/Qwen2-0.5B-Instruct",
+    reward_funcs=reward_len,
+    args=training_args,
+    train_dataset=dataset,
 )
 trainer.train()
 ```
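
As a quick sanity check of the reward interface introduced above (this note is not part of the diff): the custom reward function simply maps a list of completion strings to one score per completion, so it can be tested on its own before being passed to the trainer. A minimal sketch, assuming plain Python strings as completions:

```python
# Illustrative sketch, not part of the diff above: a GRPO reward function
# receives the generated completions and returns one score per completion.
def reward_len(completions, **kwargs):
    # The closer a completion is to 20 characters, the higher (less negative) its reward.
    return [-abs(20 - len(completion)) for completion in completions]

# Standalone check of the reward shaping, independent of any trainer:
print(reward_len(["short", "exactly twenty chars", "a noticeably longer completion"]))
# [-15, 0, -10]
```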