Skip to content

Commit

Permalink
fix info message when max training time reached (#7780)
Browse files Browse the repository at this point in the history
* call time_elapsed

* elapsed formatting

* format

* update test

* changelog
  • Loading branch information
awaelchli authored and tchaton committed Jun 1, 2021
1 parent 16c9049 commit ec371f3
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed print errors in `ProgressBar` when `trainer.fit` is not called ([#7674](https://github.com/PyTorchLightning/pytorch-lightning/pull/7674))
- Fixed global step update when the epoch is skipped ([#7677](https://github.com/PyTorchLightning/pytorch-lightning/pull/7677))
- Fixed training loop total batch counter when accumulate grad batches was enabled ([#7692](https://github.com/PyTorchLightning/pytorch-lightning/pull/7692))
- Fixed formatting of info message when max training time reached ([#7780](https://github.com/PyTorchLightning/pytorch-lightning/pull/7780))

## [1.3.2] - 2021-05-18

Expand Down
3 changes: 2 additions & 1 deletion pytorch_lightning/callbacks/timer.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,4 +170,5 @@ def _check_time_remaining(self, trainer: 'pl.Trainer') -> None:
should_stop = trainer.accelerator.broadcast(should_stop)
trainer.should_stop = trainer.should_stop or should_stop
if should_stop and self._verbose:
rank_zero_info(f"Time limit reached. Elapsed time is {self.time_elapsed}. Signaling Trainer to stop.")
elapsed = timedelta(seconds=int(self.time_elapsed(RunningStage.TRAINING)))
rank_zero_info(f"Time limit reached. Elapsed time is {elapsed}. Signaling Trainer to stop.")
7 changes: 5 additions & 2 deletions tests/callbacks/test_timer.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def test_timer_time_remaining(time_mock):
assert round(timer.time_elapsed()) == 3


def test_timer_stops_training(tmpdir):
def test_timer_stops_training(tmpdir, caplog):
""" Test that the timer stops training before reaching max_epochs """
model = BoringModel()
duration = timedelta(milliseconds=100)
Expand All @@ -106,9 +106,12 @@ def test_timer_stops_training(tmpdir):
max_epochs=1000,
callbacks=[timer],
)
trainer.fit(model)
with caplog.at_level(logging.INFO):
trainer.fit(model)
assert trainer.global_step > 1
assert trainer.current_epoch < 999
assert "Time limit reached." in caplog.text
assert "Signaling Trainer to stop." in caplog.text


@pytest.mark.parametrize("interval", ["step", "epoch"])
Expand Down

0 comments on commit ec371f3

Please sign in to comment.