From a69f0e93b375997995da189e71b27b57dd31ce04 Mon Sep 17 00:00:00 2001
From: Dhruv Pai
Date: Tue, 28 May 2024 13:41:38 -0700
Subject: [PATCH 1/4] Modified test

---
 tests/trainer/test_trainer_callback.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/tests/trainer/test_trainer_callback.py b/tests/trainer/test_trainer_callback.py
index 9eeb1d5e412e..22fa1c91794a 100644
--- a/tests/trainer/test_trainer_callback.py
+++ b/tests/trainer/test_trainer_callback.py
@@ -77,6 +77,9 @@ def on_epoch_end(self, args, state, control, **kwargs):
     def on_step_begin(self, args, state, control, **kwargs):
         self.events.append("on_step_begin")
+
+    def on_optimizer_step(self, args, state, control, **kwargs):
+        self.events.append("on_optimizer_step")
 
     def on_step_end(self, args, state, control, **kwargs):
         self.events.append("on_step_end")
@@ -148,7 +151,7 @@ def get_expected_events(self, trainer):
             expected_events.append("on_epoch_begin")
             for _ in range(train_dl_len):
                 step += 1
-                expected_events += ["on_step_begin", "on_step_end"]
+                expected_events += ["on_step_begin", "on_optimizer_step", "on_step_end"]
                 if step % trainer.args.logging_steps == 0:
                     expected_events.append("on_log")
                 if trainer.args.eval_strategy == IntervalStrategy.STEPS and step % trainer.args.eval_steps == 0:

From 95a457952733c10032838b0d600286da95b8b40d Mon Sep 17 00:00:00 2001
From: Dhruv Pai
Date: Tue, 28 May 2024 13:47:29 -0700
Subject: [PATCH 2/4] Added on_optimizer_step to callbacks

---
 src/transformers/trainer.py          | 2 ++
 src/transformers/trainer_callback.py | 9 +++++++++
 2 files changed, 11 insertions(+)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 58e5fd14b6ff..302bc5f52f56 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2304,6 +2304,8 @@ def _inner_training_loop(
                     else:
                         grad_norm = _grad_norm
 
+                    self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)
+
                     self.optimizer.step()
 
                     optimizer_was_run = not self.accelerator.optimizer_step_was_skipped

diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py
index 45ecf7c80c52..833de1621b8e 100644
--- a/src/transformers/trainer_callback.py
+++ b/src/transformers/trainer_callback.py
@@ -344,6 +344,12 @@ def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: T
         several inputs.
         """
         pass
+
+    def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """
+        Event called after the optimizer step but before gradients are zeroed out. Useful for monitoring gradients.
+        """
+        pass
 
     def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
         """
@@ -470,6 +476,9 @@ def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: T
         control.should_save = False
         return self.call_event("on_step_begin", args, state, control)
 
+    def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+        return self.call_event("on_optimizer_step", args, state, control)
+
     def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
         return self.call_event("on_substep_end", args, state, control)

From 5dcca4c8efd4d3cdb4dfa93fce1704f64ffc85af Mon Sep 17 00:00:00 2001
From: Dhruv Pai
Date: Tue, 28 May 2024 13:48:02 -0700
Subject: [PATCH 3/4] Move callback after step is called

---
 src/transformers/trainer.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 302bc5f52f56..227f6d7f09ea 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2304,9 +2304,10 @@ def _inner_training_loop(
                     else:
                         grad_norm = _grad_norm
 
-                    self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)
 
                     self.optimizer.step()
+
+                    self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)
 
                     optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
                     if optimizer_was_run:

From 87e740ff65de5b4bd4576cc81d2f78888228a483 Mon Sep 17 00:00:00 2001
From: Dhruv Pai
Date: Tue, 28 May 2024 14:00:25 -0700
Subject: [PATCH 4/4] Added on optimizer step callback

---
 src/transformers/trainer.py            | 3 +--
 src/transformers/trainer_callback.py   | 4 ++--
 tests/trainer/test_trainer_callback.py | 2 +-
 3 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 227f6d7f09ea..49e780306611 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2304,9 +2304,8 @@ def _inner_training_loop(
                     else:
                         grad_norm = _grad_norm
 
-                    self.optimizer.step()
-
+                    self.optimizer.step()
+
                     self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)
 
                     optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
                     if optimizer_was_run:

diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py
index 833de1621b8e..207d8ebdffce 100644
--- a/src/transformers/trainer_callback.py
+++ b/src/transformers/trainer_callback.py
@@ -344,7 +344,7 @@ def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: T
         several inputs.
         """
         pass
-
+
     def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
         """
         Event called after the optimizer step but before gradients are zeroed out. Useful for monitoring gradients.
         """
@@ -478,7 +478,7 @@ def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: T
 
     def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
         return self.call_event("on_optimizer_step", args, state, control)
-
+
     def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
         return self.call_event("on_substep_end", args, state, control)

diff --git a/tests/trainer/test_trainer_callback.py b/tests/trainer/test_trainer_callback.py
index 22fa1c91794a..edd73f29dc98 100644
--- a/tests/trainer/test_trainer_callback.py
+++ b/tests/trainer/test_trainer_callback.py
@@ -77,7 +77,7 @@ def on_epoch_end(self, args, state, control, **kwargs):
     def on_step_begin(self, args, state, control, **kwargs):
         self.events.append("on_step_begin")
-
+
    def on_optimizer_step(self, args, state, control, **kwargs):
         self.events.append("on_optimizer_step")