diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 4419f055d4..a8d55f341a 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -834,7 +834,8 @@ def tokenize_function(examples): self.assertLess(max(diff_truth), 0.01, f"Difference {max(diff_truth)} is not within 0.01") # max diff broken should be very off - self.assertGreater(max(diff_broken), 1.5, f"Difference {max(diff_broken)} is not greater than 1.5") + # updated target value compared original implementation https://github.com/huggingface/transformers/blob/v4.49.0/tests/trainer/test_trainer.py#L888 + self.assertGreater(max(diff_broken), 1.2, f"Difference {max(diff_broken)} is not greater than 1.2") loss_base = sum(base_loss_callback.losses) loss_broken = sum(broken_loss_callback.losses)