Allow GradientAccumulationPlugin to be configured from AcceleratorConfig #29589

Merged (11 commits) on Mar 28, 2024
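In short, gradient accumulation behaviour can now be tuned through AcceleratorConfig.gradient_accumulation_kwargs rather than only through TrainingArguments.gradient_accumulation_steps. A minimal usage sketch, assuming accelerate >= 0.28 (needed for sync_each_batch) and the import path of the module edited below; the output directory, model, and dataset are placeholders, not part of the PR:

```python
from transformers import TrainingArguments
from transformers.trainer_pt_utils import AcceleratorConfig

# Configure gradient accumulation via AcceleratorConfig instead of only
# TrainingArguments.gradient_accumulation_steps.
accelerator_config = AcceleratorConfig(
    gradient_accumulation_kwargs={
        "num_steps": 4,            # only honored when gradient_accumulation_steps == 1
        "adjust_scheduler": True,  # scale scheduler steps for accumulation
        "sync_each_batch": True,   # sync gradients every batch (accelerate >= 0.28)
    },
)

args = TrainingArguments(
    output_dir="out",               # placeholder
    gradient_accumulation_steps=1,  # leave at 1 so num_steps above takes precedence
    accelerator_config=accelerator_config,
)

# trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
```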
src/transformers/trainer.py (35 changes: 33 additions & 2 deletions)
@@ -4094,15 +4094,46 @@ def _add_sm_patterns_to_gitignore(self) -> None:
self.repo.git_push()

def create_accelerator_and_postprocess(self):
grad_acc_kwargs = {"num_steps": self.args.gradient_accumulation_steps}
grad_acc_kwargs = {}
if self.args.accelerator_config.gradient_accumulation_kwargs is not None:
grad_acc_kwargs = self.args.accelerator_config.gradient_accumulation_kwargs

# check whether the user tried to pass num_steps in gradient_accumulation_kwargs
if "num_steps" in grad_acc_kwargs:
if self.args.gradient_accumulation_steps > 1:
# raise because we do not know which setting is intended.
raise ValueError(
"AccelerateConfig.gradient_accumulation_kwargs['num_steps'] is specified but TrainingArguments.gradient_accumulation_steps > 1. "
"If the gradient_accumulation_kwargs['num_steps'] is desired, set TrainingArguments.gradient_accumulation_steps == 1."
)
elif grad_acc_kwargs["num_steps"] > 1 and self.args.gradient_accumulation_steps == 1:
# warn that grad_acc_kwargs["num_steps"] > 1 will take precedence
warnings.warn(
'"num_steps" in AccelerateConfig.gradient_accumulation_kwargs takes precedence over TrainingArguments.gradient_accumulation_steps.'
)
else:
# the case grad_acc_kwargs["num_steps"] = self.args.gradient_accumulation_steps = 1
# passthrough without warning
pass
else:
# take the gradient_accumulation_steps setting from TrainingArguments.
grad_acc_kwargs["num_steps"] = self.args.gradient_accumulation_steps

# this is legacy code. Are we sure it is a good idea to overwrite without any warning?
grad_acc_kwargs["sync_with_dataloader"] = False

gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)

# create accelerator object
# convert AcceleratorConfig to Accelerator kwargs, but first
# remove the grad accumulation kwargs that should not be passed in
accelerator_config = {
k: v for k, v in self.args.accelerator_config.to_dict().items() if k != "gradient_accumulation_kwargs"
}
self.accelerator = Accelerator(
deepspeed_plugin=self.args.deepspeed_plugin,
gradient_accumulation_plugin=gradient_accumulation_plugin,
-**self.args.accelerator_config.to_dict(),
+**accelerator_config,
)
# some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
self.gather_function = self.accelerator.gather_for_metrics
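For reference, a standalone sketch of the accelerate objects the resolved kwargs end up building. This illustrates the GradientAccumulationPlugin / Accelerator API outside of Trainer, with arbitrary values, and assumes accelerate >= 0.28 for sync_each_batch:

```python
from accelerate import Accelerator
from accelerate.utils import GradientAccumulationPlugin

# The resolved kwargs (num_steps taken either from gradient_accumulation_kwargs
# or from TrainingArguments.gradient_accumulation_steps) feed the plugin.
plugin = GradientAccumulationPlugin(
    num_steps=4,
    adjust_scheduler=True,
    sync_with_dataloader=False,  # Trainer always forces this off
    sync_each_batch=True,        # requires accelerate >= 0.28
)

accelerator = Accelerator(gradient_accumulation_plugin=plugin)

# The tests below inspect exactly this state to verify the kwargs passed through.
print(accelerator.gradient_state.plugin_kwargs)
```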
src/transformers/trainer_pt_utils.py (20 changes: 19 additions & 1 deletion)
@@ -1170,7 +1170,18 @@ class AcceleratorConfig:
training results are fully reproducible using a different sampling technique. While seed-to-seed results
may differ, on average the differences are negligible when using multiple different seeds to compare. Should
also be run with [`~utils.set_seed`] for the best results.

gradient_accumulation_kwargs (`dict`, *optional*):
Additional kwargs to configure gradient accumulation, see [`accelerate.utils.GradientAccumulationPlugin`].
Any of the following (optional) keys are acceptable:
num_steps (`int`): Will take precedence over [`~.TrainingArguments.gradient_accumulation_steps`] if
the latter is set to 1, otherwise an exception will be raised.
adjust_scheduler (`bool`): Whether to adjust the scheduler steps to account for [`~.TrainingArguments.gradient_accumulation_steps`].
The [`accelerate.utils.GradientAccumulationPlugin`] default is `True`.
sync_each_batch (`bool`): Whether to synchronize the gradients at each data batch.
The [`accelerate.utils.GradientAccumulationPlugin`] default is `False`.

The following key has no effect and will be treated as if it is not passed.
sync_with_dataloader (`bool`): Will be ignored and always set to `False`.
"""

# Data related arguments
Expand Down Expand Up @@ -1208,6 +1219,13 @@ class AcceleratorConfig:
"multiple different seeds to compare. Should also be ran with [`~utils.set_seed`] for the best results."
},
)
gradient_accumulation_kwargs: Optional[Dict] = field(
default=None,
metadata={
"help": "Additional kwargs to configure gradient accumulation, see GradientAccumulationPlugin. The "
"This should exclude GradientAccumulationPlugin.num_steps that will be set to TrainingArguments.gradient_accumulation_steps."
},
)

@classmethod
def from_json_file(cls, json_file):
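Because the dataclass keeps its from_json_file constructor (shown above), the new key can also come from a config file. A hypothetical sketch, with the file name and values invented for illustration and the import path mirroring the module edited here:

```python
import json

from transformers.trainer_pt_utils import AcceleratorConfig

# Hypothetical accelerator_config.json containing the new key.
config_dict = {
    "split_batches": True,
    "even_batches": False,
    "gradient_accumulation_kwargs": {"sync_each_batch": True},
}
with open("accelerator_config.json", "w") as f:
    json.dump(config_dict, f)

# Load it back through the dataclass.
accelerator_config = AcceleratorConfig.from_json_file("accelerator_config.json")
print(accelerator_config.gradient_accumulation_kwargs)
```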
tests/trainer/test_trainer.py (144 changes: 129 additions & 15 deletions)
@@ -91,6 +91,7 @@
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
is_accelerate_available,
is_apex_available,
is_bitsandbytes_available,
is_safetensors_available,
@@ -791,6 +792,8 @@ def test_tf32(self):
@require_sentencepiece
@require_tokenizers
class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
FLAG_ACCELERATOR_SUPPORT_GRAD_ACCUM_KWARGS = is_accelerate_available("0.28")

def setUp(self):
super().setUp()
args = TrainingArguments("..")
@@ -2499,6 +2502,10 @@ def test_accelerator_config_empty(self):
self.assertEqual(trainer.accelerator.even_batches, True)
self.assertEqual(trainer.accelerator.use_seedable_sampler, True)

if self.FLAG_ACCELERATOR_SUPPORT_GRAD_ACCUM_KWARGS:
# gradient accumulation kwargs configures gradient_state
self.assertNotIn("sync_each_batch", trainer.accelerator.gradient_state.plugin_kwargs)

def test_accelerator_config_from_dict(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized accordingly
@@ -2507,22 +2514,29 @@ def test_accelerator_config_from_dict(self):
model = RegressionPreTrainedModel(config)
eval_dataset = SampleIterableDataset()

accelerator_config = {
"split_batches": True,
"dispatch_batches": True,
"even_batches": False,
"use_seedable_sampler": True,
}
if self.FLAG_ACCELERATOR_SUPPORT_GRAD_ACCUM_KWARGS:
accelerator_config["gradient_accumulation_kwargs"] = {"sync_each_batch": True}

# Leaves all options as something *not* basic
args = RegressionTrainingArguments(
output_dir=tmp_dir,
-accelerator_config={
-"split_batches": True,
-"dispatch_batches": True,
-"even_batches": False,
-"use_seedable_sampler": True,
-},
+accelerator_config=accelerator_config,
)
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertEqual(trainer.accelerator.split_batches, True)
self.assertEqual(trainer.accelerator.dispatch_batches, True)
self.assertEqual(trainer.accelerator.even_batches, False)
self.assertEqual(trainer.accelerator.use_seedable_sampler, True)

if self.FLAG_ACCELERATOR_SUPPORT_GRAD_ACCUM_KWARGS:
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)

def test_accelerator_config_from_yaml(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized accordingly
@@ -2533,8 +2547,10 @@ def test_accelerator_config_from_yaml(self):
"split_batches": True,
"dispatch_batches": True,
"even_batches": False,
"use_seedable_sampler": False,
"use_seedable_sampler": True,
}
if self.FLAG_ACCELERATOR_SUPPORT_GRAD_ACCUM_KWARGS:
accelerator_config["gradient_accumulation_kwargs"] = {"sync_each_batch": True}
json.dump(accelerator_config, f)
config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config)
@@ -2546,13 +2562,20 @@ def test_accelerator_config_from_yaml(self):
self.assertEqual(trainer.accelerator.split_batches, True)
self.assertEqual(trainer.accelerator.dispatch_batches, True)
self.assertEqual(trainer.accelerator.even_batches, False)
-self.assertEqual(trainer.accelerator.use_seedable_sampler, False)
+self.assertEqual(trainer.accelerator.use_seedable_sampler, True)

if self.FLAG_ACCELERATOR_SUPPORT_GRAD_ACCUM_KWARGS:
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)

def test_accelerator_config_from_dataclass(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized accordingly

accelerator_config = AcceleratorConfig(
-split_batches=True, dispatch_batches=True, even_batches=False, use_seedable_sampler=False
+split_batches=True,
+dispatch_batches=True,
+even_batches=False,
+use_seedable_sampler=False,
)
config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config)
@@ -2565,6 +2588,37 @@ def test_accelerator_config_from_dataclass(self):
self.assertEqual(trainer.accelerator.even_batches, False)
self.assertEqual(trainer.accelerator.use_seedable_sampler, False)

@unittest.skipUnless(
FLAG_ACCELERATOR_SUPPORT_GRAD_ACCUM_KWARGS, "setting of gradient_accumulation_kwargs not supported"
)
def test_accelerate_config_from_dataclass_grad_accum(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized accordingly

grad_acc_kwargs = {
"num_steps": 10,
"adjust_scheduler": False,
"sync_with_dataloader": False,
"sync_each_batch": True,
}
accelerator_config = AcceleratorConfig(
split_batches=True,
dispatch_batches=True,
even_batches=False,
use_seedable_sampler=False,
gradient_accumulation_kwargs=grad_acc_kwargs,
)
config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config)
eval_dataset = SampleIterableDataset()
with tempfile.TemporaryDirectory() as tmp_dir:
args = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config=accelerator_config)
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["num_steps"], 10)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["adjust_scheduler"], False)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_with_dataloader"], False)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)

def test_accelerator_config_from_partial(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized accordingly
@@ -2574,18 +2628,21 @@ def test_accelerator_config_from_partial(self):
eval_dataset = SampleIterableDataset()

# Leaves one option as something *not* basic
-args = RegressionTrainingArguments(
-output_dir=tmp_dir,
-accelerator_config={
-"split_batches": True,
-},
-)
+accelerator_config = {
+"split_batches": True,
+}
+if self.FLAG_ACCELERATOR_SUPPORT_GRAD_ACCUM_KWARGS:
+accelerator_config["gradient_accumulation_kwargs"] = {"sync_each_batch": True}
+args = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config=accelerator_config)
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertEqual(trainer.accelerator.split_batches, True)
self.assertEqual(trainer.accelerator.dispatch_batches, None)
self.assertEqual(trainer.accelerator.even_batches, True)
self.assertEqual(trainer.accelerator.use_seedable_sampler, True)

if self.FLAG_ACCELERATOR_SUPPORT_GRAD_ACCUM_KWARGS:
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)

def test_accelerator_config_from_dict_with_deprecated_args(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized accordingly
@@ -2636,6 +2693,63 @@ def test_accelerator_config_only_deprecated_args(self):
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertEqual(trainer.accelerator.split_batches, True)

@unittest.skipUnless(
FLAG_ACCELERATOR_SUPPORT_GRAD_ACCUM_KWARGS, "setting of gradient_accumulation_kwargs not supported"
)
def test_accelerator_config_from_dict_grad_accum_num_steps(self):
with tempfile.TemporaryDirectory() as tmp_dir:
config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config)
eval_dataset = SampleIterableDataset()

# case - TrainingArguments.gradient_accumulation_steps == 1
# - gradient_accumulation_kwargs['num_steps'] == 1
# no warning and grad accum set to 1
args = RegressionTrainingArguments(
output_dir=tmp_dir,
gradient_accumulation_steps=1,
accelerator_config={
"gradient_accumulation_kwargs": {
"num_steps": 1,
}
},
)
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["num_steps"], 1)

# case - TrainingArguments.gradient_accumulation_steps == 1
# - gradient_accumulation_kwargs['num_steps'] > 1
# gradient_accumulation_kwargs takes precedence with a warning
with self.assertWarns(UserWarning) as cm:
args = RegressionTrainingArguments(
output_dir=tmp_dir,
gradient_accumulation_steps=1,
accelerator_config={
"gradient_accumulation_kwargs": {
"num_steps": 10,
}
},
)
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertIn("num_steps", str(cm.warnings[0].message))
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["num_steps"], 10)

# case - TrainingArguments.gradient_accumulation_steps > 1
# - gradient_accumulation_kwargs['num_steps'] specified
# raise exception
args = RegressionTrainingArguments(
output_dir=tmp_dir,
gradient_accumulation_steps=2,
accelerator_config={
"gradient_accumulation_kwargs": {
"num_steps": 10,
}
},
)
with self.assertRaises(Exception) as context:
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertTrue("set TrainingArguments.gradient_accumulation_steps" in str(context.exception))


@require_torch
@is_staging_test