Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Static cache GRPO #3023

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import textwrap
import warnings
from collections import defaultdict
from typing import Any, Callable, Optional, Sized, Union
from typing import Any, Callable, Optional, Sized, Union, Generator
from unittest.mock import patch

import torch
Expand Down Expand Up @@ -162,6 +162,20 @@ def __iter__(self):
def __len__(self) -> int:
return self.num_samples * self.mini_repeat_count * self.repeat_count

@contextlib.contextmanager
def disable_gradient_checkpointing(model: PreTrainedModel) -> Generator[None, None, None]:
"""
Temporarily disables gradient checkpointing in the model, if it is enabled.

It is usefull when using the model to generate completions, while training it with gradient checkpointing.

Args:
model (`PreTrainedModel`): Model to disable gradient checkpointing for.
"""
value = model.base_model.gradient_checkpointing
model.base_model.gradient_checkpointing = False
yield
model.base_model.gradient_checkpointing = value

class GRPOTrainer(Trainer):
"""
Expand Down Expand Up @@ -545,6 +559,8 @@ def new_group_context():
top_k=args.top_k,
min_p=args.min_p,
repetition_penalty=args.repetition_penalty,
use_cache=True,
cache_implementation="static", # static cache is faster in this case
)

# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
Expand Down Expand Up @@ -763,9 +779,10 @@ def _generate_and_score_completions(
else:
# Regular generation path
with unwrap_model_for_generation(self.model_wrapped, self.accelerator) as unwrapped_model:
prompt_completion_ids = unwrapped_model.generate(
prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config
)
with disable_gradient_checkpointing(unwrapped_model):
prompt_completion_ids = unwrapped_model.generate(
prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config
)

# Compute prompt length and extract completion ids
prompt_length = prompt_ids.size(1)
Expand Down
Loading