diff --git a/tests/test_activation_offloading.py b/tests/test_activation_offloading.py
index df1ba42be0..bf2d2b7684 100644
--- a/tests/test_activation_offloading.py
+++ b/tests/test_activation_offloading.py
@@ -60,7 +60,7 @@ def test_offloading_with_sft_trainer(self) -> None:
                 output_dir=tmp_dir,
                 per_device_train_batch_size=2,
                 max_steps=1,
-                enable_activation_offloading=True,
+                activation_offloading=True,
                 report_to="none",
             )
 
diff --git a/trl/models/activation_offloading.py b/trl/models/activation_offloading.py
index 9f4d5a91bd..055440662b 100644
--- a/trl/models/activation_offloading.py
+++ b/trl/models/activation_offloading.py
@@ -333,7 +333,7 @@ def noop(tensor):
 
 def get_act_offloading_ctx_manager(
     model: nn.Module,
-    enable_activation_offloading: bool,
+    activation_offloading: bool,
     use_pin_memory: bool = True,
     use_streams: bool = True,
     min_offload_size: int = 1024,
@@ -341,7 +341,7 @@ def get_act_offloading_ctx_manager(
 ) -> Union[OffloadActivations, contextlib.nullcontext]:
     """
     Returns the activation offloading context manager for the model, which will be a null context if
-    `enable_activation_offloading` is `False`.
+    `activation_offloading` is `False`.
 
     If activation offloading is enabled, we return the OffloadActivations context manager. If activation offloading
     is disabled, we return a NoOpManager context manager.
@@ -349,14 +349,14 @@ def get_act_offloading_ctx_manager(
     Args:
         model (`nn.Module`):
             Model to wrap with the activation offloading context manager.
-        enable_activation_offloading (`bool`):
+        activation_offloading (`bool`):
             Whether or not to enable activation offloading for the model.
 
     Returns:
         `contextlib.ContextDecorator`:
             Activation offloading context manager for the model.
     """
-    if not enable_activation_offloading:
+    if not activation_offloading:
         return contextlib.nullcontext()
 
     activations_handling_ctx = OffloadActivations(
diff --git a/trl/trainer/sft_config.py b/trl/trainer/sft_config.py
index f16ba54163..15ff1185d0 100644
--- a/trl/trainer/sft_config.py
+++ b/trl/trainer/sft_config.py
@@ -39,7 +39,7 @@ class SFTConfig(TrainingArguments):
             argument of the [`SFTTrainer`] is provided as a string.
         use_liger (`bool`, *optional*, defaults to `False`):
             Monkey patch the model with Liger kernels to increase throughput and reduce memory usage.
-        enable_activation_offloading (`bool`, *optional*, defaults to `False`):
+        activation_offloading (`bool`, *optional*, defaults to `False`):
            Whether to offload the activations to the CPU.
 
     > Parameters that control the data preprocessing
@@ -145,7 +145,7 @@ class SFTConfig(TrainingArguments):
         metadata={"help": "Deprecated. Use `max_length` instead."},
     )
 
-    enable_activation_offloading: bool = field(
+    activation_offloading: bool = field(
         default=False,
         metadata={"help": "Whether to offload the activations to the CPU."},
     )
diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index e71e8ad9d4..eea19f1cd5 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -244,7 +244,7 @@ def __init__(
         # Initialize activation offloading context
         self.activation_offload_context = get_act_offloading_ctx_manager(
             model=self.model,
-            enable_activation_offloading=self.args.enable_activation_offloading,
+            activation_offloading=self.args.activation_offloading,
         )
 
         # Add tags for models that have been loaded with the correct transformers version
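
For reference, a minimal usage sketch of the renamed option, mirroring the updated test above. The model and dataset identifiers are illustrative placeholders, not part of this change:

```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

# `activation_offloading` replaces the old `enable_activation_offloading` flag.
training_args = SFTConfig(
    output_dir="./sft-output",
    per_device_train_batch_size=2,
    max_steps=1,
    activation_offloading=True,
    report_to="none",
)

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",  # placeholder model id
    args=training_args,
    train_dataset=load_dataset("trl-lib/Capybara", split="train"),  # placeholder dataset
)
trainer.train()
```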