From bbb6ab3bf12088d3a87840189d4dd809a7e83fae Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 11 Sep 2024 18:43:48 -0400
Subject: [PATCH] add a callback hook right before the optimizer step

---
 src/transformers/trainer.py            | 2 ++
 src/transformers/trainer_callback.py   | 9 +++++++++
 tests/trainer/test_trainer_callback.py | 5 ++++-
 3 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 525708645c2c..f815c50d597f 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2417,6 +2417,8 @@ def _inner_training_loop(
                         else:
                             grad_norm = _grad_norm
 
+                    self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control)
+
                     self.optimizer.step()
 
                     self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)
diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py
index 932fd937d26f..d457a65993db 100644
--- a/src/transformers/trainer_callback.py
+++ b/src/transformers/trainer_callback.py
@@ -344,6 +344,12 @@ def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: T
         """
         pass
 
+    def on_pre_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """
+        Event called before the optimizer step but after gradient clipping. Useful for monitoring gradients.
+        """
+        pass
+
     def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
         """
         Event called after the optimizer step but before gradients are zeroed out. Useful for monitoring gradients.
@@ -475,6 +481,9 @@ def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: T
         control.should_save = False
         return self.call_event("on_step_begin", args, state, control)
 
+    def on_pre_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+        return self.call_event("on_pre_optimizer_step", args, state, control)
+
     def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
         return self.call_event("on_optimizer_step", args, state, control)
 
diff --git a/tests/trainer/test_trainer_callback.py b/tests/trainer/test_trainer_callback.py
index 48145979e362..0d1e6645f9a5 100644
--- a/tests/trainer/test_trainer_callback.py
+++ b/tests/trainer/test_trainer_callback.py
@@ -78,6 +78,9 @@ def on_epoch_end(self, args, state, control, **kwargs):
     def on_step_begin(self, args, state, control, **kwargs):
         self.events.append("on_step_begin")
 
+    def on_pre_optimizer_step(self, args, state, control, **kwargs):
+        self.events.append("on_pre_optimizer_step")
+
     def on_optimizer_step(self, args, state, control, **kwargs):
         self.events.append("on_optimizer_step")
 
@@ -151,7 +154,7 @@ def get_expected_events(self, trainer):
             expected_events.append("on_epoch_begin")
             for _ in range(train_dl_len):
                 step += 1
-                expected_events += ["on_step_begin", "on_optimizer_step", "on_step_end"]
+                expected_events += ["on_step_begin", "on_pre_optimizer_step", "on_optimizer_step", "on_step_end"]
                 if step % trainer.args.logging_steps == 0:
                     expected_events.append("on_log")
                 if trainer.args.eval_strategy == IntervalStrategy.STEPS and step % trainer.args.eval_steps == 0:
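
Not part of the patch itself: below is a minimal user-side sketch of how the new on_pre_optimizer_step event could be consumed, e.g. to log the post-clipping gradient norm right before optimizer.step(). It assumes the callback handler forwards the model via **kwargs, as it does for the existing events; GradientMonitorCallback is a hypothetical name.

from transformers import Trainer, TrainerCallback


class GradientMonitorCallback(TrainerCallback):
    """Hypothetical callback: log the total gradient norm right before each optimizer step."""

    def on_pre_optimizer_step(self, args, state, control, **kwargs):
        # At this point gradients have been clipped but not yet applied or zeroed.
        model = kwargs.get("model")
        if model is None:
            return
        sq_sum = 0.0
        for param in model.parameters():
            if param.grad is not None:
                sq_sum += param.grad.detach().float().norm(2).item() ** 2
        print(f"step {state.global_step}: grad norm before optimizer.step() = {sq_sum ** 0.5:.4f}")


# Usage sketch: register the callback when constructing the Trainer.
# trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset,
#                   callbacks=[GradientMonitorCallback()])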