From 1e385bb0246509b1adcaaa2a2cd9c92cf39d9306 Mon Sep 17 00:00:00 2001 From: Unknown Author <unknown@example.com> Date: Thu, 5 Dec 2024 08:21:34 -0500 Subject: [PATCH 1/2] Allow EarlyStoppingCallback without load_best_model_at_end --- src/transformers/trainer_callback.py | 3 +-- tests/trainer/test_trainer.py | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py index 7b711f65701d..45ee7de74eff 100644 --- a/src/transformers/trainer_callback.py +++ b/src/transformers/trainer_callback.py @@ -707,10 +707,9 @@ def check_metric_value(self, args, state, control, metric_value): self.early_stopping_patience_counter += 1 def on_train_begin(self, args, state, control, **kwargs): - assert args.load_best_model_at_end, "EarlyStoppingCallback requires load_best_model_at_end = True" assert ( args.metric_for_best_model is not None - ), "EarlyStoppingCallback requires metric_for_best_model is defined" + ), "EarlyStoppingCallback requires metric_for_best_model to be defined" assert ( args.eval_strategy != IntervalStrategy.NO ), "EarlyStoppingCallback requires IntervalStrategy of steps or epoch" diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 5658372fa713..bf9d362297b8 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -3252,6 +3252,23 @@ def test_early_stopping_callback(self): except AssertionError: self.assertEqual(trainer.state.global_step, 0) + # even if load_best_model_at_end is False, `best_model_checkpoint` should be set + with tempfile.TemporaryDirectory() as tmp_dir: + trainer = get_regression_trainer( + output_dir=tmp_dir, + num_train_epochs=20, + gradient_accumulation_steps=1, + per_device_train_batch_size=16, + load_best_model_at_end=False, + eval_strategy=IntervalStrategy.EPOCH, + save_strategy=IntervalStrategy.EPOCH, + compute_metrics=AlmostAccuracy(), + metric_for_best_model="accuracy", + ) + 
trainer.add_callback(EarlyStoppingCallback(1, 0.0001)) + train_output = trainer.train() + self.assertIsNotNone(trainer.state.best_model_checkpoint) + def test_flos_extraction(self): trainer = get_regression_trainer(learning_rate=0.1) From 26216b6cf8e7231a037c84701b266a6ff1b5d973 Mon Sep 17 00:00:00 2001 From: Unknown Author <unknown@example.com> Date: Fri, 10 Jan 2025 10:00:12 -0500 Subject: [PATCH 2/2] Warn when EarlyStoppingCallback is used without load_best_model_at_end --- src/transformers/trainer_callback.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py index 45ee7de74eff..8f241a9db4a3 100644 --- a/src/transformers/trainer_callback.py +++ b/src/transformers/trainer_callback.py @@ -707,6 +707,11 @@ def check_metric_value(self, args, state, control, metric_value): self.early_stopping_patience_counter += 1 def on_train_begin(self, args, state, control, **kwargs): + if not args.load_best_model_at_end: + logger.warning( + "Using EarlyStoppingCallback without load_best_model_at_end=True. " + "Once training is finished, the best model will not be loaded automatically." + ) assert ( args.metric_for_best_model is not None ), "EarlyStoppingCallback requires metric_for_best_model to be defined"