Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add a callback hook right before the optimizer step #33444

Merged
merged 1 commit into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2417,6 +2417,8 @@ def _inner_training_loop(
else:
grad_norm = _grad_norm

self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control)

self.optimizer.step()

self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)
Expand Down
9 changes: 9 additions & 0 deletions src/transformers/trainer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,12 @@ def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: T
"""
pass

def on_pre_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""
Event called before the optimizer step but after gradient clipping. Useful for monitoring gradients.
"""
pass

def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""
Event called after the optimizer step but before gradients are zeroed out. Useful for monitoring gradients.
Expand Down Expand Up @@ -475,6 +481,9 @@ def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: T
control.should_save = False
return self.call_event("on_step_begin", args, state, control)

def on_pre_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
return self.call_event("on_pre_optimizer_step", args, state, control)

def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
return self.call_event("on_optimizer_step", args, state, control)

Expand Down
5 changes: 4 additions & 1 deletion tests/trainer/test_trainer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ def on_epoch_end(self, args, state, control, **kwargs):
def on_step_begin(self, args, state, control, **kwargs):
self.events.append("on_step_begin")

def on_pre_optimizer_step(self, args, state, control, **kwargs):
self.events.append("on_pre_optimizer_step")

def on_optimizer_step(self, args, state, control, **kwargs):
self.events.append("on_optimizer_step")

Expand Down Expand Up @@ -151,7 +154,7 @@ def get_expected_events(self, trainer):
expected_events.append("on_epoch_begin")
for _ in range(train_dl_len):
step += 1
expected_events += ["on_step_begin", "on_optimizer_step", "on_step_end"]
expected_events += ["on_step_begin", "on_pre_optimizer_step", "on_optimizer_step", "on_step_end"]
if step % trainer.args.logging_steps == 0:
expected_events.append("on_log")
if trainer.args.eval_strategy == IntervalStrategy.STEPS and step % trainer.args.eval_steps == 0:
Expand Down
Loading