Remove unnecessary TrainingEpochLoop return (#9298)
carmocca authored Sep 6, 2021
1 parent 9a14f04 commit 05ff1b2
Showing 3 changed files with 28 additions and 42 deletions.
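Not part of the commit, but for orientation: the epoch loop continues to require that an overridden ``training_epoch_end`` returns ``None``; the change below only stops the loop from handing the collected epoch outputs back to ``FitLoop``. A minimal sketch of a conforming hook follows, assuming the standard Lightning 1.x hook signature; the module name and the logged metric key are invented for illustration.

import torch
from pytorch_lightning import LightningModule


class LitModel(LightningModule):  # hypothetical example module
    def training_epoch_end(self, outputs) -> None:
        # ``outputs`` collects the per-batch results of ``training_step``; here we
        # assume each entry is a dict holding the "loss" tensor.
        epoch_loss = torch.stack([out["loss"] for out in outputs]).mean()
        self.log("train/epoch_loss", epoch_loss)
        # returning anything other than None raises MisconfigurationException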
46 changes: 19 additions & 27 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -189,7 +189,7 @@ def on_advance_end(self):
         # progress global step according to grads progress
         self._increment_accumulated_grad_global_step()
 
-    def on_run_end(self) -> List[List[STEP_OUTPUT]]:
+    def on_run_end(self) -> None:
         """Calls the on_epoch_end hook.
 
         Returns:
@@ -198,32 +198,29 @@ def on_run_end(self) -> List[List[STEP_OUTPUT]]:
         Raises:
             MisconfigurationException: ``train_epoch_end`` does not return ``None``
         """
-        if self.batch_progress.current.ready == 0:
-            # dataloader/iterator did not produce a batch
-            return
-
         # inform logger the batch loop has finished
         self.trainer.logger_connector.epoch_end_reached()
 
-        # prepare epoch output
-        processed_outputs = self._prepare_outputs(self._epoch_output, batch_mode=False)
-
         # get the model and call model.training_epoch_end
         model = self.trainer.lightning_module
-
-        if is_overridden("training_epoch_end", model):
-            # run training_epoch_end
-            # refresh the result for custom logging at the epoch level
-            model._current_fx_name = "training_epoch_end"
-
-            # lightningmodule hook
-            training_epoch_end_output = model.training_epoch_end(processed_outputs)
-
-            if training_epoch_end_output is not None:
-                raise MisconfigurationException(
-                    "training_epoch_end expects a return of None. "
-                    "HINT: remove the return statement in training_epoch_end"
-                )
+        if is_overridden("training_epoch_end", model) and self._epoch_output:
+            processed_outputs = self._prepare_outputs(self._epoch_output, batch_mode=False)
+            # check that the dataloader/iterator produced a batch
+            if processed_outputs:
+                # run training_epoch_end
+                # refresh the result for custom logging at the epoch level
+                model._current_fx_name = "training_epoch_end"
+
+                # lightningmodule hook
+                training_epoch_end_output = model.training_epoch_end(processed_outputs)
+
+                if training_epoch_end_output is not None:
+                    raise MisconfigurationException(
+                        "training_epoch_end expects a return of None. "
+                        "HINT: remove the return statement in training_epoch_end"
+                    )
+        # free memory
+        self._epoch_output = None
 
         self.trainer.fit_loop.epoch_progress.increment_processed()

@@ -235,11 +232,6 @@ def on_run_end(self) -> List[List[STEP_OUTPUT]]:
         if self._num_training_batches_reached(self.is_last_batch):
             self.update_lr_schedulers("epoch", update_plateau_schedulers=True)
 
-        epoch_output = self._epoch_output
-        # free memory
-        self._epoch_output = None
-        return epoch_output
-
     def teardown(self) -> None:
         self._results.cpu()
         self.batch_loop.teardown()
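A simplified, self-contained sketch (not the library code) of the control flow the rewritten ``on_run_end`` follows: the epoch-end hook only fires when the epoch actually produced outputs, the hook must return ``None``, and the output buffer is dropped instead of being forwarded to the caller. The helper name ``run_epoch_end`` is invented for illustration.

from typing import Callable, List, Optional


def run_epoch_end(outputs: Optional[List], hook: Optional[Callable]) -> None:
    if hook is not None and outputs:
        processed = [o for o in outputs if o]  # stand-in for ``_prepare_outputs``
        if processed:
            if hook(processed) is not None:
                raise RuntimeError("training_epoch_end expects a return of None.")
    # nothing is returned: the outputs are freed here rather than handed back


# an empty iterator produces no batches, so the hook is never invoked:
run_epoch_end(outputs=[], hook=lambda outs: print(len(outs), "batches"))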
6 changes: 1 addition & 5 deletions pytorch_lightning/loops/fit_loop.py
@@ -226,11 +226,7 @@ def advance(self) -> None:
         data_fetcher = self.trainer.data_connector.get_profiled_dataloader(dataloader)
 
         with self.trainer.profiler.profile("run_training_epoch"):
-            # run train epoch
-            epoch_output = self.epoch_loop.run(data_fetcher)
-
-            if epoch_output is None:
-                return
+            self.epoch_loop.run(data_fetcher)
 
         # the global step is manually decreased here due to backwards compatibility with existing loggers
         # as they expect that the same step is used when logging epoch end metrics even when the batch loop has
18 changes: 8 additions & 10 deletions tests/trainer/test_dataloaders.py
@@ -868,34 +868,32 @@ def __len__(self):
     trainer.predict(model, dataloaders=dataloader)
 
 
-def test_iterable_dataset_stop_iteration_at_epoch_beginning():
+@pytest.mark.parametrize("yield_at_all", (False, True))
+def test_iterable_dataset_stop_iteration_at_epoch_beginning(yield_at_all):
     """Test that the training loop skips execution if the iterator is empty from the start."""
 
-    class RandomDataset(IterableDataset):
+    class TestDataset(IterableDataset):
         def __init__(self, gen):
             self.gen = gen
 
         def __iter__(self):
             return iter(self.gen())
 
     class TestModel(BoringModel):
-        def train_dataloader(self):
-            return DataLoader(RandomDataset(self.gen), batch_size=2)
-
         def gen(self):
-            # produce data in epoch 0
-            # no data otherwise
-            if self.current_epoch == 0:
+            # produce data in epoch 0, no data otherwise
+            if yield_at_all and self.current_epoch == 0:
                 yield torch.rand(32)
                 yield torch.rand(32)
                 yield torch.rand(32)
 
     model = TestModel()
+    train_dataloader = DataLoader(TestDataset(model.gen), batch_size=2)
     trainer = Trainer(
         default_root_dir=os.getcwd(), max_epochs=2, weights_summary=None  # we expect the second epoch to be skipped
     )
-    trainer.fit(model)
-    assert trainer.global_step == 2
+    trainer.fit(model, train_dataloader=train_dataloader)
+    assert trainer.global_step == 2 * yield_at_all
     assert trainer.current_epoch == 1


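A quick arithmetic check (not part of the test) for the ``2 * yield_at_all`` assertion above, assuming ``batch_size=2`` and three 32-element tensors yielded only in epoch 0:

import math

for yield_at_all in (False, True):
    samples_epoch_0 = 3 if yield_at_all else 0
    steps_epoch_0 = math.ceil(samples_epoch_0 / 2)  # 2 optimizer steps, or 0
    steps_epoch_1 = 0  # the generator yields nothing after epoch 0
    assert steps_epoch_0 + steps_epoch_1 == 2 * yield_at_all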
