From 265663af6a64c884e8cb4ec27530039748e61f9e Mon Sep 17 00:00:00 2001
From: burtenshaw
Date: Fri, 31 Jan 2025 10:30:44 +0100
Subject: [PATCH] =?UTF-8?q?=F0=9F=93=96=20Add=20GRPOTrainer=20to=20README.?=
 =?UTF-8?q?md=20(#2713)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* [DOCS] add GRPOTrainer to README.md

I replaced RLOOTrainer with GRPOTrainer because I thought you might want to keep it limited, but let me know if you want both.

* Update README.md

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
---
 README.md | 39 +++++++++++++--------------------------
 1 file changed, 13 insertions(+), 26 deletions(-)

diff --git a/README.md b/README.md
index b7f1e5b7aa..d0af49c83a 100644
--- a/README.md
+++ b/README.md
@@ -137,39 +137,26 @@ trainer = RewardTrainer(
 trainer.train()
 ```
 
-### `RLOOTrainer`
+### `GRPOTrainer`
 
-`RLOOTrainer` implements a [REINFORCE-style optimization](https://huggingface.co/papers/2402.14740) for RLHF that is more performant and memory-efficient than PPO. Here is a basic example of how to use the `RLOOTrainer`:
+`GRPOTrainer` implements the Group Relative Policy Optimization (GRPO) algorithm from [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300). GRPO is more performant than PPO and was used to train [DeepSeek AI's R1](https://huggingface.co/deepseek-ai/DeepSeek-R1).
 
 ```python
-from trl import RLOOConfig, RLOOTrainer, apply_chat_template
 from datasets import load_dataset
-from transformers import (
-    AutoModelForCausalLM,
-    AutoModelForSequenceClassification,
-    AutoTokenizer,
-)
+from trl import GRPOConfig, GRPOTrainer
 
-tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
-reward_model = AutoModelForSequenceClassification.from_pretrained(
-    "Qwen/Qwen2.5-0.5B-Instruct", num_labels=1
-)
-ref_policy = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
-policy = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
+dataset = load_dataset("trl-lib/tldr", split="train")
 
-dataset = load_dataset("trl-lib/ultrafeedback-prompt")
-dataset = dataset.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
-dataset = dataset.map(lambda x: tokenizer(x["prompt"]), remove_columns="prompt")
+# Dummy reward function: rewards completions that are close to 20 characters
+def reward_len(completions, **kwargs):
+    return [-abs(20 - len(completion)) for completion in completions]
 
-training_args = RLOOConfig(output_dir="Qwen2.5-0.5B-RL")
-trainer = RLOOTrainer(
-    config=training_args,
-    processing_class=tokenizer,
-    policy=policy,
-    ref_policy=ref_policy,
-    reward_model=reward_model,
-    train_dataset=dataset["train"],
-    eval_dataset=dataset["test"],
+training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO", logging_steps=10)
+trainer = GRPOTrainer(
+    model="Qwen/Qwen2-0.5B-Instruct",
+    reward_funcs=reward_len,
+    args=training_args,
+    train_dataset=dataset,
 )
 trainer.train()
 ```
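
For anyone trying the new snippet: below is a minimal sketch of how the `reward_funcs` argument introduced in this patch generalizes, assuming TRL's `GRPOTrainer` also accepts a list of reward callables, each of which receives the generated `completions` (plus any extra dataset columns via `**kwargs`) and returns one float per completion. The second function, `reward_no_repeat`, is a hypothetical illustration and not part of this PR. Note that keeping rewards as plain Python callables is what lets the README example stay this short compared to the RLOO version it replaces, which needed a separate reward model.

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

# Reward from the patch: completions closer to 20 characters score higher.
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]

# Hypothetical extra reward (illustration only): penalize repeated words.
def reward_no_repeat(completions, **kwargs):
    rewards = []
    for completion in completions:
        words = completion.split()
        rewards.append(-float(len(words) - len(set(words))))
    return rewards

training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO", logging_steps=10)
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=[reward_len, reward_no_repeat],  # per-completion rewards are combined
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```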