From 7ab86cae02ca007d669ed21ee025c3e702a54ff2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?=
Date: Thu, 6 Mar 2025 16:11:59 +0000
Subject: [PATCH 1/2] static cache grpo

---
 trl/trainer/grpo_trainer.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 7d0e523d15..03c217036c 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -545,6 +545,8 @@ def new_group_context():
                 top_k=args.top_k,
                 min_p=args.min_p,
                 repetition_penalty=args.repetition_penalty,
+                use_cache=True,
+                cache_implementation="static",  # static cache is faster in this case
             )
 
         # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the

From a4d433771ae8601ae6a7f15d75a6121a71a3c521 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?=
Date: Fri, 7 Mar 2025 11:12:45 +0000
Subject: [PATCH 2/2] disabling gradient checkpointing for generation

---
 trl/trainer/grpo_trainer.py | 23 +++++++++++++++++++----
 1 file changed, 19 insertions(+), 4 deletions(-)

diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 03c217036c..cc6e1f1c49 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -18,7 +18,7 @@
 import textwrap
 import warnings
 from collections import defaultdict
-from typing import Any, Callable, Optional, Sized, Union
+from typing import Any, Callable, Optional, Sized, Union, Generator
 from unittest.mock import patch
 
 import torch
@@ -162,6 +162,20 @@ def __iter__(self):
     def __len__(self) -> int:
         return self.num_samples * self.mini_repeat_count * self.repeat_count
 
+@contextlib.contextmanager
+def disable_gradient_checkpointing(model: PreTrainedModel) -> Generator[None, None, None]:
+    """
+    Temporarily disables gradient checkpointing in the model, if it is enabled.
+
+    It is useful when using the model to generate completions while training it with gradient checkpointing.
+
+    Args:
+        model (`PreTrainedModel`): Model to disable gradient checkpointing for.
+    """
+    value = model.base_model.gradient_checkpointing
+    model.base_model.gradient_checkpointing = False
+    yield
+    model.base_model.gradient_checkpointing = value
 
 class GRPOTrainer(Trainer):
     """
@@ -765,9 +779,10 @@ def _generate_and_score_completions(
         else:
             # Regular generation path
             with unwrap_model_for_generation(self.model_wrapped, self.accelerator) as unwrapped_model:
-                prompt_completion_ids = unwrapped_model.generate(
-                    prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config
-                )
+                with disable_gradient_checkpointing(unwrapped_model):
+                    prompt_completion_ids = unwrapped_model.generate(
+                        prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config
+                    )
 
         # Compute prompt length and extract completion ids
         prompt_length = prompt_ids.size(1)
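
For context, a minimal self-contained sketch of what the two patches combine to do in the regular (non-vLLM) generation path: sample completions with a pre-allocated static KV cache while gradient checkpointing is temporarily switched off. This is illustration, not part of the patch: the checkpoint name and prompt are placeholders, model.train() stands in for the policy being in train mode during GRPO training, and the try/finally is an extra safeguard that the patched helper does without.

import contextlib
from typing import Generator

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, PreTrainedModel


@contextlib.contextmanager
def disable_gradient_checkpointing(model: PreTrainedModel) -> Generator[None, None, None]:
    # Mirrors the helper added in PATCH 2/2, with try/finally so the flag is
    # restored even if generation raises.
    value = model.base_model.gradient_checkpointing
    model.base_model.gradient_checkpointing = False
    try:
        yield
    finally:
        model.base_model.gradient_checkpointing = value


model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model.gradient_checkpointing_enable()
model.train()  # stands in for the trainer keeping the policy in train mode

generation_config = GenerationConfig(
    max_new_tokens=64,
    do_sample=True,
    use_cache=True,
    cache_implementation="static",  # KV cache pre-allocated to a fixed maximum shape
)

inputs = tokenizer("The capital of France is", return_tensors="pt")
with torch.no_grad(), disable_gradient_checkpointing(model):
    output_ids = model.generate(**inputs, generation_config=generation_config)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

The two changes likely belong together: when gradient checkpointing is active on a model in train mode, transformers disables the KV cache during the forward pass (it warns that use_cache=True is incompatible with gradient checkpointing), so without the helper the static cache requested in PATCH 1/2 would not take effect. The fixed cache shape is also what makes repeated, same-length sampling of the kind GRPO does cheaper, as the in-diff comment suggests.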