Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Let EarlyStoppingCallback not require load_best_model_at_end #35101

Merged
merged 2 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/transformers/trainer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,10 +707,9 @@ def check_metric_value(self, args, state, control, metric_value):
self.early_stopping_patience_counter += 1

def on_train_begin(self, args, state, control, **kwargs):
assert args.load_best_model_at_end, "EarlyStoppingCallback requires load_best_model_at_end = True"
assert (
args.metric_for_best_model is not None
), "EarlyStoppingCallback requires metric_for_best_model is defined"
), "EarlyStoppingCallback requires metric_for_best_model to be defined"
assert (
args.eval_strategy != IntervalStrategy.NO
), "EarlyStoppingCallback requires IntervalStrategy of steps or epoch"
Expand Down
17 changes: 17 additions & 0 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3252,6 +3252,23 @@ def test_early_stopping_callback(self):
except AssertionError:
self.assertEqual(trainer.state.global_step, 0)

# even if load_best_model_at_end is False, `best_model_checkpoint` should be set
with tempfile.TemporaryDirectory() as tmp_dir:
trainer = get_regression_trainer(
output_dir=tmp_dir,
num_train_epochs=20,
gradient_accumulation_steps=1,
per_device_train_batch_size=16,
load_best_model_at_end=False,
eval_strategy=IntervalStrategy.EPOCH,
save_strategy=IntervalStrategy.EPOCH,
compute_metrics=AlmostAccuracy(),
metric_for_best_model="accuracy",
)
trainer.add_callback(EarlyStoppingCallback(1, 0.0001))
train_output = trainer.train()
self.assertIsNotNone(trainer.state.best_model_checkpoint)

def test_flos_extraction(self):
trainer = get_regression_trainer(learning_rate=0.1)

Expand Down
Loading