diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index 33fbbd352f9df..1db4f74008ce8 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -189,7 +189,7 @@ def on_advance_end(self):
         # progress global step according to grads progress
         self._increment_accumulated_grad_global_step()

-    def on_run_end(self) -> List[List[STEP_OUTPUT]]:
+    def on_run_end(self) -> None:
         """Calls the on_epoch_end hook.

         Returns:
@@ -198,32 +198,29 @@ def on_run_end(self) -> List[List[STEP_OUTPUT]]:
         Raises:
             MisconfigurationException: ``train_epoch_end`` does not return ``None``
         """
-        if self.batch_progress.current.ready == 0:
-            # dataloader/iterator did not produce a batch
-            return
-
         # inform logger the batch loop has finished
         self.trainer.logger_connector.epoch_end_reached()

-        # prepare epoch output
-        processed_outputs = self._prepare_outputs(self._epoch_output, batch_mode=False)
-
         # get the model and call model.training_epoch_end
         model = self.trainer.lightning_module
-
-        if is_overridden("training_epoch_end", model):
-            # run training_epoch_end
-            # refresh the result for custom logging at the epoch level
-            model._current_fx_name = "training_epoch_end"
-
-            # lightningmodule hook
-            training_epoch_end_output = model.training_epoch_end(processed_outputs)
-
-            if training_epoch_end_output is not None:
-                raise MisconfigurationException(
-                    "training_epoch_end expects a return of None. "
-                    "HINT: remove the return statement in training_epoch_end"
-                )
+        if is_overridden("training_epoch_end", model) and self._epoch_output:
+            processed_outputs = self._prepare_outputs(self._epoch_output, batch_mode=False)
+            # check that the dataloader/iterator produced a batch
+            if processed_outputs:
+                # run training_epoch_end
+                # refresh the result for custom logging at the epoch level
+                model._current_fx_name = "training_epoch_end"
+
+                # lightningmodule hook
+                training_epoch_end_output = model.training_epoch_end(processed_outputs)
+
+                if training_epoch_end_output is not None:
+                    raise MisconfigurationException(
+                        "training_epoch_end expects a return of None. "
+                        "HINT: remove the return statement in training_epoch_end"
+                    )
+        # free memory
+        self._epoch_output = None

         self.trainer.fit_loop.epoch_progress.increment_processed()

@@ -235,11 +232,6 @@ def on_run_end(self) -> List[List[STEP_OUTPUT]]:
         if self._num_training_batches_reached(self.is_last_batch):
             self.update_lr_schedulers("epoch", update_plateau_schedulers=True)

-        epoch_output = self._epoch_output
-        # free memory
-        self._epoch_output = None
-        return epoch_output
-
     def teardown(self) -> None:
         self._results.cpu()
         self.batch_loop.teardown()
diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py
index db93678a3437a..bfcd3f15242da 100644
--- a/pytorch_lightning/loops/fit_loop.py
+++ b/pytorch_lightning/loops/fit_loop.py
@@ -226,11 +226,7 @@ def advance(self) -> None:
         data_fetcher = self.trainer.data_connector.get_profiled_dataloader(dataloader)

         with self.trainer.profiler.profile("run_training_epoch"):
-            # run train epoch
-            epoch_output = self.epoch_loop.run(data_fetcher)
-
-            if epoch_output is None:
-                return
+            self.epoch_loop.run(data_fetcher)

             # the global step is manually decreased here due to backwards compatibility with existing loggers
             # as they expect that the same step is used when logging epoch end metrics even when the batch loop has
diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py
index 6f7da1e03b7ab..3db27591fb24a 100644
--- a/tests/trainer/test_dataloaders.py
+++ b/tests/trainer/test_dataloaders.py
@@ -868,10 +868,11 @@ def __len__(self):
     trainer.predict(model, dataloaders=dataloader)


-def test_iterable_dataset_stop_iteration_at_epoch_beginning():
+@pytest.mark.parametrize("yield_at_all", (False, True))
+def test_iterable_dataset_stop_iteration_at_epoch_beginning(yield_at_all):
     """Test that the training loop skips execution if the iterator is empty from the start."""

-    class RandomDataset(IterableDataset):
+    class TestDataset(IterableDataset):
         def __init__(self, gen):
             self.gen = gen

@@ -879,23 +880,20 @@ def __iter__(self):
             return iter(self.gen())

     class TestModel(BoringModel):
-        def train_dataloader(self):
-            return DataLoader(RandomDataset(self.gen), batch_size=2)
-
         def gen(self):
-            # produce data in epoch 0
-            # no data otherwise
-            if self.current_epoch == 0:
+            # produce data in epoch 0, no data otherwise
+            if yield_at_all and self.current_epoch == 0:
                 yield torch.rand(32)
                 yield torch.rand(32)
                 yield torch.rand(32)

     model = TestModel()
+    train_dataloader = DataLoader(TestDataset(model.gen), batch_size=2)
     trainer = Trainer(
         default_root_dir=os.getcwd(), max_epochs=2, weights_summary=None  # we expect the second epoch to be skipped
     )
-    trainer.fit(model)
-    assert trainer.global_step == 2
+    trainer.fit(model, train_dataloader=train_dataloader)
+    assert trainer.global_step == 2 * yield_at_all
     assert trainer.current_epoch == 1