Commit bfb1333

Merge branch 'activation-checkpoint' of https://github.com/huggingface/trl into activation-checkpoint
qgallouedec committed Feb 25, 2025
2 parents 0aa23ad + 0d156db commit bfb1333
Showing 4 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion tests/test_activation_offloading.py

@@ -60,7 +60,7 @@ def test_offloading_with_sft_trainer(self) -> None:
                 output_dir=tmp_dir,
                 per_device_train_batch_size=2,
                 max_steps=1,
-                enable_activation_offloading=True,
+                activation_offloading=True,
                 report_to="none",
             )
8 changes: 4 additions & 4 deletions trl/models/activation_offloading.py

@@ -333,30 +333,30 @@ def noop(tensor):

 def get_act_offloading_ctx_manager(
     model: nn.Module,
-    enable_activation_offloading: bool,
+    activation_offloading: bool,
     use_pin_memory: bool = True,
     use_streams: bool = True,
     min_offload_size: int = 1024,
     max_fwd_stash_size: int = 5,
 ) -> Union[OffloadActivations, contextlib.nullcontext]:
     """
     Returns the activation offloading context manager for the model, which will be a null context if
-    `enable_activation_offloading` is `False`.
+    `activation_offloading` is `False`.

     If activation offloading is enabled, we return the OffloadActivations context manager.
     If activation offloading is disabled, we return a NoOpManager context manager.

     Args:
         model (`nn.Module`):
             Model to wrap with the activation offloading context manager.
-        enable_activation_offloading (`bool`):
+        activation_offloading (`bool`):
             Whether or not to enable activation offloading for the model.

     Returns:
         `contextlib.ContextDecorator`:
             Activation offloading context manager for the model.
     """
-    if not enable_activation_offloading:
+    if not activation_offloading:
         return contextlib.nullcontext()

     activations_handling_ctx = OffloadActivations(
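
For context, a minimal sketch of calling this function directly, outside the trainer. It assumes the import path matches the file path above (trl.models.activation_offloading) and that a CUDA device is available; the toy model and tensor sizes are illustrative only, not from the commit.

import torch
import torch.nn as nn
from trl.models.activation_offloading import get_act_offloading_ctx_manager

# Toy model; activation offloading only pays off for CUDA activations.
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512)).cuda()

# Returns OffloadActivations when activation_offloading=True,
# contextlib.nullcontext() otherwise (per the function above).
ctx = get_act_offloading_ctx_manager(model=model, activation_offloading=True)

x = torch.randn(4, 512, device="cuda")
with ctx:
    # Activations saved for backward are offloaded to CPU during the forward
    # pass and fetched back as the backward pass needs them.
    loss = model(x).sum()
    loss.backward()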
4 changes: 2 additions & 2 deletions trl/trainer/sft_config.py

@@ -39,7 +39,7 @@ class SFTConfig(TrainingArguments):
             argument of the [`SFTTrainer`] is provided as a string.
         use_liger (`bool`, *optional*, defaults to `False`):
             Monkey patch the model with Liger kernels to increase throughput and reduce memory usage.
-        enable_activation_offloading (`bool`, *optional*, defaults to `False`):
+        activation_offloading (`bool`, *optional*, defaults to `False`):
             Whether to offload the activations to the CPU.

     > Parameters that control the data preprocessing

@@ -145,7 +145,7 @@ class SFTConfig(TrainingArguments):
         metadata={"help": "Deprecated. Use `max_length` instead."},
     )

-    enable_activation_offloading: bool = field(
+    activation_offloading: bool = field(
         default=False,
         metadata={"help": "Whether to offload the activations to the CPU."},
    )
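
Downstream, the renamed option is passed to SFTConfig exactly as in the updated test. A minimal, self-contained sketch (the temporary directory is only there so the snippet runs as-is):

import tempfile
from trl import SFTConfig

with tempfile.TemporaryDirectory() as tmp_dir:
    args = SFTConfig(
        output_dir=tmp_dir,
        per_device_train_batch_size=2,
        max_steps=1,
        activation_offloading=True,  # was enable_activation_offloading
        report_to="none",
    )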
2 changes: 1 addition & 1 deletion trl/trainer/sft_trainer.py

@@ -244,7 +244,7 @@ def __init__(
         # Initialize activation offloading context
         self.activation_offload_context = get_act_offloading_ctx_manager(
             model=self.model,
-            enable_activation_offloading=self.args.enable_activation_offloading,
+            activation_offloading=self.args.activation_offloading,
         )

         # Add tags for models that have been loaded with the correct transformers version
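
The commit shows the context being created and stored, but not where it is entered. A hypothetical sketch of how a trainer might wrap its training step with it; the subclass and override below are assumptions for illustration, not code from this commit.

from transformers import Trainer

class OffloadingTrainer(Trainer):  # hypothetical subclass, for illustration
    def training_step(self, model, inputs, num_items_in_batch=None):
        # Hypothetical (not in this diff): enter the stored context around the
        # forward/backward pass. If offloading was disabled, this is the null
        # context and nothing changes.
        with self.activation_offload_context:
            return super().training_step(model, inputs, num_items_in_batch)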
