Allow GradientAccumulationPlugin to be configured from AcceleratorConfig #29589

Merged (11 commits) on Mar 28, 2024
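For context, a minimal sketch of the configuration path this PR enables, based on the tests added below; the output directory is a placeholder and the AcceleratorConfig import path is assumed rather than taken from the diff:

    # Sketch only: forward gradient-accumulation kwargs to accelerate's
    # GradientAccumulationPlugin through AcceleratorConfig.
    from transformers import TrainingArguments
    from transformers.trainer_pt_utils import AcceleratorConfig  # assumed import path

    accelerator_config = AcceleratorConfig(
        gradient_accumulation_kwargs={
            "num_steps": 4,  # only valid while gradient_accumulation_steps stays at 1
            "adjust_scheduler": False,
            "sync_each_batch": True,
        },
    )
    args = TrainingArguments(
        output_dir="placeholder_output",  # placeholder path
        accelerator_config=accelerator_config,
    )
    # Trainer(model=..., args=args, ...) then builds its Accelerator with a
    # GradientAccumulationPlugin constructed from these kwargs.

With gradient_accumulation_steps left at its default of 1, num_steps from the kwargs takes precedence; setting both raises a ValueError, as the trainer.py change below shows.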
5 changes: 3 additions & 2 deletions src/transformers/testing_utils.py
@@ -52,6 +52,7 @@
)
from .integrations.deepspeed import is_deepspeed_available
from .utils import (
ACCELERATE_MIN_VERSION,
is_accelerate_available,
is_apex_available,
is_aqlm_available,
@@ -354,11 +355,11 @@ def require_nltk(test_case):
return unittest.skipUnless(is_nltk_available(), "test requires NLTK")(test_case)


def require_accelerate(test_case):
def require_accelerate(test_case, min_version: str = ACCELERATE_MIN_VERSION):
"""
Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
"""
return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)
return unittest.skipUnless(is_accelerate_available(min_version), "test requires accelerate")(test_case)


def require_fsdp(test_case, min_version: str = "1.12.0"):
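The min_version parameter added above can be bound once with functools.partial and then reused as a plain decorator, mirroring the pattern the test changes below use; the test class and method names here are hypothetical:

    # Hypothetical test module showing the version-gated decorator in use.
    import unittest
    from functools import partial

    from transformers.testing_utils import require_accelerate

    # Bind the minimum accelerate version once, then reuse it as a normal decorator.
    require_accelerate_version = partial(require_accelerate, min_version="0.28")


    @require_accelerate_version
    class GradAccumKwargsTest(unittest.TestCase):  # hypothetical test class
        def test_placeholder(self):
            # Skipped automatically unless accelerate >= 0.28 is installed.
            self.assertTrue(True)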
25 changes: 23 additions & 2 deletions src/transformers/trainer.py
@@ -4094,15 +4094,36 @@ def _add_sm_patterns_to_gitignore(self) -> None:
self.repo.git_push()

def create_accelerator_and_postprocess(self):
grad_acc_kwargs = {"num_steps": self.args.gradient_accumulation_steps}
grad_acc_kwargs = {}
if self.args.accelerator_config.gradient_accumulation_kwargs is not None:
grad_acc_kwargs = self.args.accelerator_config.gradient_accumulation_kwargs

# check if num_steps is attempted to be passed in gradient_accumulation_kwargs
if "num_steps" in grad_acc_kwargs and self.args.gradient_accumulation_steps > 1:
# raise because we do not know which setting is intended.
raise ValueError(
"The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`"
"If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`."
)
elif "num_steps" not in grad_acc_kwargs:
# take the gradient_accumulation_steps setting from TrainingArguments.
grad_acc_kwargs["num_steps"] = self.args.gradient_accumulation_steps

# this is legacy code. Are we sure it is a good idea to overwrite without any warning?
grad_acc_kwargs["sync_with_dataloader"] = False

gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)

# create accelerator object
# convert AcceleratorConfig to Accelerator kwargs, but first
# remove the grad accumulation kwargs that should not be passed in
accelerator_config = {
k: v for k, v in self.args.accelerator_config.to_dict().items() if k != "gradient_accumulation_kwargs"
}
self.accelerator = Accelerator(
deepspeed_plugin=self.args.deepspeed_plugin,
gradient_accumulation_plugin=gradient_accumulation_plugin,
**self.args.accelerator_config.to_dict(),
**accelerator_config,
)
# some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
self.gather_function = self.accelerator.gather_for_metrics
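To make the precedence rules above easier to follow, here is a distilled, standalone sketch rather than the Trainer code itself; the helper name resolve_num_steps is invented for illustration:

    # Standalone illustration of the num_steps precedence implemented above.
    from typing import Optional


    def resolve_num_steps(gradient_accumulation_steps: int, grad_acc_kwargs: Optional[dict]) -> dict:
        kwargs = dict(grad_acc_kwargs or {})
        if "num_steps" in kwargs and gradient_accumulation_steps > 1:
            # Both sources are configured, so the intent is ambiguous: refuse to guess.
            raise ValueError("`num_steps` set in AcceleratorConfig while `gradient_accumulation_steps` > 1")
        if "num_steps" not in kwargs:
            # Fall back to the TrainingArguments value (which may simply be 1).
            kwargs["num_steps"] = gradient_accumulation_steps
        return kwargs


    # kwargs win when gradient_accumulation_steps is left at 1
    assert resolve_num_steps(1, {"num_steps": 10})["num_steps"] == 10
    # with no kwargs, the TrainingArguments value is used
    assert resolve_num_steps(4, None)["num_steps"] == 4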
15 changes: 15 additions & 0 deletions src/transformers/trainer_pt_utils.py
@@ -1170,6 +1170,15 @@ class AcceleratorConfig:
training results are fully reproducible using a different sampling technique. While seed-to-seed results
may differ, on average the differences are negligible when using multiple different seeds to compare. Should
also be run with [`~utils.set_seed`] for the best results.
gradient_accumulation_kwargs (`dict`, *optional*):
Additional kwargs to configure gradient accumulation, see [`accelerate.utils.GradientAccumulationPlugin`].
Any of the following (optional) keys are acceptable:
num_steps (`int`): Will take precedence over [`~.TrainingArguments.gradient_accumulation_steps`] if
the latter is set to 1, otherwise an exception will be raised.
adjust_scheduler (`bool`): Whether to adjust the scheduler steps to account for [`~.TrainingArguments.gradient_accumulation_steps`].
The [`accelerate.utils.GradientAccumulationPlugin`] default is `True`.
sync_each_batch (`bool`): Whether to synchronize the gradients at each data batch.
The [`accelerate.utils.GradientAccumulationPlugin`] default is `False`.

"""

@@ -1208,6 +1217,12 @@ class AcceleratorConfig:
"multiple different seeds to compare. Should also be ran with [`~utils.set_seed`] for the best results."
},
)
gradient_accumulation_kwargs: Optional[Dict] = field(
default=None,
metadata={
"help": "Additional kwargs to configure gradient accumulation, see GradientAccumulationPlugin. The "
},
)

@classmethod
def from_json_file(cls, json_file):
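The same kwargs can also be supplied as a plain dict through TrainingArguments(accelerator_config=...), which is the form most of the tests below exercise; the output directory here is a placeholder:

    # Dict form of the same configuration (placeholder output_dir).
    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir="placeholder_output",
        gradient_accumulation_steps=1,  # must stay at 1 when num_steps is given below
        accelerator_config={
            "gradient_accumulation_kwargs": {
                "num_steps": 2,
                "sync_each_batch": True,
            }
        },
    )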
121 changes: 108 additions & 13 deletions tests/trainer/test_trainer.py
@@ -24,6 +24,7 @@
import sys
import tempfile
import unittest
from functools import partial
from itertools import product
from pathlib import Path
from typing import Dict, List
@@ -91,6 +92,7 @@
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
is_accelerate_available,
is_apex_available,
is_bitsandbytes_available,
is_safetensors_available,
@@ -791,6 +793,9 @@ def test_tf32(self):
@require_sentencepiece
@require_tokenizers
class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
GRAD_ACCUM_KWARGS_VERSION_AVAILABLE = is_accelerate_available("0.28")
require_accelerate_version = partial(require_accelerate, min_version="0.28")
Contributor:
No need for this logic; let's just use @require_accelerate(min_version="0.28.0") on the tests that need it.

Contributor Author:
@muellerzr sorry, can I clarify this remark?

  1. You highlighted both lines 796 and 797. Are you saying you want to remove the boolean GRAD_ACCUM_KWARGS_VERSION_AVAILABLE and use is_accelerate_available("0.28") directly in the conditional?
  2. Or are you saying you want to skip the partial bind and directly decorate the test as @require_accelerate(min_version="0.28")? If so, then I do not think we can follow the require_fsdp function style; we would need to rewrite require_accelerate into a "builder":

  def require_accelerate(min_version: str = "0.28"):
      def _require(test_case):
          return unittest.skipUnless(is_accelerate_available(min_version), "test requires accelerate")(test_case)
      return _require

But the issue with this is that the bare @require_accelerate pattern no longer works; we would instead need to write @require_accelerate(), i.e., with the extra () parentheses.

@muellerzr (Contributor), Mar 14, 2024:
Just mimic what's done in test_fsdp.py and make the partial decorator in the test file:

require_fsdp_version = require_fsdp
if is_accelerate_available():
    ...
    require_fsdp_version = partial(require_fsdp, min_version=FSDP_PYTORCH_VERSION)

Contributor:
Here, you can then just use @require_fsdp_version on the test


def setUp(self):
super().setUp()
args = TrainingArguments("..")
@@ -2499,6 +2504,10 @@ def test_accelerator_config_empty(self):
self.assertEqual(trainer.accelerator.even_batches, True)
self.assertEqual(trainer.accelerator.use_seedable_sampler, True)

if self.GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
# gradient accumulation kwargs configure gradient_state
self.assertNotIn("sync_each_batch", trainer.accelerator.gradient_state.plugin_kwargs)

def test_accelerator_config_from_dict(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized respectively
@@ -2507,22 +2516,29 @@ def test_accelerator_config_from_dict(self):
model = RegressionPreTrainedModel(config)
eval_dataset = SampleIterableDataset()

accelerator_config = {
"split_batches": True,
"dispatch_batches": True,
"even_batches": False,
"use_seedable_sampler": True,
}
if self.GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
accelerator_config["gradient_accumulation_kwargs"] = {"sync_each_batch": True}

# Leaves all options as something *not* basic
args = RegressionTrainingArguments(
output_dir=tmp_dir,
accelerator_config={
"split_batches": True,
"dispatch_batches": True,
"even_batches": False,
"use_seedable_sampler": True,
},
accelerator_config=accelerator_config,
)
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertEqual(trainer.accelerator.split_batches, True)
self.assertEqual(trainer.accelerator.dispatch_batches, True)
self.assertEqual(trainer.accelerator.even_batches, False)
self.assertEqual(trainer.accelerator.use_seedable_sampler, True)

if self.GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)

def test_accelerator_config_from_yaml(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized respectively
@@ -2535,6 +2551,8 @@ def test_accelerator_config_from_yaml(self):
"even_batches": False,
"use_seedable_sampler": False,
}
if self.GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
accelerator_config["gradient_accumulation_kwargs"] = {"sync_each_batch": True}
json.dump(accelerator_config, f)
config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config)
@@ -2548,11 +2566,18 @@ def test_accelerator_config_from_yaml(self):
self.assertEqual(trainer.accelerator.even_batches, False)
self.assertEqual(trainer.accelerator.use_seedable_sampler, False)

if self.GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)

def test_accelerator_config_from_dataclass(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized respectively

accelerator_config = AcceleratorConfig(
split_batches=True, dispatch_batches=True, even_batches=False, use_seedable_sampler=False
split_batches=True,
dispatch_batches=True,
even_batches=False,
use_seedable_sampler=False,
)
config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config)
@@ -2565,6 +2590,35 @@ def test_accelerator_config_from_dataclass(self):
self.assertEqual(trainer.accelerator.even_batches, False)
self.assertEqual(trainer.accelerator.use_seedable_sampler, False)

@require_accelerate_version
def test_accelerate_config_from_dataclass_grad_accum(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized respectively

grad_acc_kwargs = {
"num_steps": 10,
"adjust_scheduler": False,
"sync_with_dataloader": False,
"sync_each_batch": True,
}
accelerator_config = AcceleratorConfig(
split_batches=True,
dispatch_batches=True,
even_batches=False,
use_seedable_sampler=False,
gradient_accumulation_kwargs=grad_acc_kwargs,
)
config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config)
eval_dataset = SampleIterableDataset()
with tempfile.TemporaryDirectory() as tmp_dir:
args = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config=accelerator_config)
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["num_steps"], 10)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["adjust_scheduler"], False)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_with_dataloader"], False)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)

def test_accelerator_config_from_partial(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized respectively
@@ -2574,18 +2628,21 @@ def test_accelerator_config_from_partial(self):
eval_dataset = SampleIterableDataset()

# Leaves one option as something *not* basic
args = RegressionTrainingArguments(
output_dir=tmp_dir,
accelerator_config={
"split_batches": True,
},
)
accelerator_config = {
"split_batches": True,
}
if self.GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
accelerator_config["gradient_accumulation_kwargs"] = {"sync_each_batch": True}
Contributor:
This changes what the test does. Let's not do this please.

@fabianlim (Contributor Author), Mar 13, 2024:
Yes, agreed. In this case I think it's best to revert this test to its previous state and not introduce gradient_accumulation_kwargs in this particular test at all.

Since specifying gradient_accumulation_kwargs is already non-default, we have a host of other tests that cover this case.

Do you agree?

Contributor:
Yup

args = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config=accelerator_config)
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertEqual(trainer.accelerator.split_batches, True)
self.assertEqual(trainer.accelerator.dispatch_batches, None)
self.assertEqual(trainer.accelerator.even_batches, True)
self.assertEqual(trainer.accelerator.use_seedable_sampler, True)

if self.GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)

def test_accelerator_config_from_dict_with_deprecated_args(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized respectively
@@ -2636,6 +2693,44 @@ def test_accelerator_config_only_deprecated_args(self):
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertEqual(trainer.accelerator.split_batches, True)

@require_accelerate_version
def test_accelerator_config_from_dict_grad_accum_num_steps(self):
with tempfile.TemporaryDirectory() as tmp_dir:
config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config)
eval_dataset = SampleIterableDataset()

# case - TrainingArguments.gradient_accumulation_steps == 1
# - gradient_accumulation_kwargs['num_steps'] == 1
# results in grad accum set to 1
args = RegressionTrainingArguments(
output_dir=tmp_dir,
gradient_accumulation_steps=1,
accelerator_config={
"gradient_accumulation_kwargs": {
"num_steps": 1,
}
},
)
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["num_steps"], 1)

# case - TrainingArguments.gradient_accumulation_steps > 1
# - gradient_accumulation_kwargs['num_steps'] specified
# results in exception raised
args = RegressionTrainingArguments(
output_dir=tmp_dir,
gradient_accumulation_steps=2,
accelerator_config={
"gradient_accumulation_kwargs": {
"num_steps": 10,
}
},
)
with self.assertRaises(Exception) as context:
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertTrue("The `AcceleratorConfig`'s `num_steps` is set but" in str(context.exception))


@require_torch
@is_staging_test