Commit 7347c29

qgallouedec and kashif authored
🥾 Allow bootstrap GRPO (#2829)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
1 parent 2106b31 commit 7347c29

File tree: 2 files changed, +23 -4 lines changed


trl/data_utils.py (+19 -3)
@@ -90,8 +90,21 @@ def apply_chat_template(

     # Apply the chat template to the prompt, adding the generation prompt
     if "prompt" in example:
+        last_role = example["prompt"][-1]["role"]
+        if last_role == "user":
+            add_generation_prompt = True
+            continue_final_message = False
+        elif last_role == "assistant":
+            add_generation_prompt = False
+            continue_final_message = True
+        else:
+            raise ValueError(f"Invalid role in the last message: {last_role}")
         prompt = tokenizer.apply_chat_template(
-            example["prompt"], tools=tools, tokenize=False, add_generation_prompt=True
+            example["prompt"],
+            tools=tools,
+            continue_final_message=continue_final_message,
+            tokenize=False,
+            add_generation_prompt=add_generation_prompt,
         )

     # Apply the chat template to the entire prompt + completion
@@ -180,10 +193,13 @@ def maybe_apply_chat_template(
     Returns:
         `dict[str, str]`: The formatted example with the chat template applied.

-    Note:
-        This function does not alter the keys, except for Language modeling dataset, where `"messages"` is replaced by
+    Notes:
+        - This function does not alter the keys, except for Language modeling dataset, where `"messages"` is replaced by
         `"text"`.

+        - In case of prompt-only data, if the last role is `"user"`, the generation prompt is added to the prompt. Else,
+          if the last role is `"assistant"`, the final message is continued.
+
     Example:

     ```python
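With this change, a prompt-only conversational example may end with a partial assistant turn that generation "bootstraps" from. A minimal sketch of the two behaviors, assuming a tokenizer that defines a chat template (the model name is illustrative, not part of this commit):

```python
from transformers import AutoTokenizer

from trl.data_utils import apply_chat_template

# Illustrative model choice; any tokenizer with a chat template works.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

# Last message from the user: a generation prompt is appended, so the
# model starts a fresh assistant turn.
example = {"prompt": [{"role": "user", "content": "What color is the sky?"}]}
print(apply_chat_template(example, tokenizer)["prompt"])

# Last message is a partial assistant turn: the message is left open
# (continue_final_message=True), so generation resumes mid-message.
example = {
    "prompt": [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "The sky is"},
    ]
}
print(apply_chat_template(example, tokenizer)["prompt"])
```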

trl/trainer/grpo_trainer.py (+4 -1)
@@ -577,7 +577,10 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
         # Decode the generated completions
         completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
         if is_conversational(inputs[0]):
-            completions = [[{"role": "assistant", "content": completion}] for completion in completions_text]
+            completions = []
+            for prompt, completion in zip(prompts, completions_text):
+                bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
+                completions.append([{"role": "assistant", "content": bootstrap + completion}])
         else:
             completions = completions_text
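On the trainer side, the popped bootstrap text is glued back onto the decoded completion, so reward functions still receive one complete assistant message per sample. A self-contained sketch of that loop, with toy data standing in for the trainer's real `prompts` and `completions_text`:

```python
# Toy stand-ins for the trainer's conversational prompts and the
# batch-decoded completion strings.
prompts = [
    [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "The sky is"},  # partial turn to bootstrap
    ],
    [{"role": "user", "content": "Name a prime number."}],  # no bootstrap
]
completions_text = [" blue.", "7 is a prime number."]

completions = []
for prompt, completion in zip(prompts, completions_text):
    # If the prompt ends with an assistant turn, remove it and prepend its
    # content, so the completion reads as one full assistant message.
    bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
    completions.append([{"role": "assistant", "content": bootstrap + completion}])

print(completions)
# [[{'role': 'assistant', 'content': 'The sky is blue.'}],
#  [{'role': 'assistant', 'content': '7 is a prime number.'}]]
```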
