Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove unnecessary TrainingEpochLoop return #9298

Merged
merged 5 commits into from
Sep 6, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 19 additions & 27 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def on_advance_end(self):
# progress global step according to grads progress
self._increment_accumulated_grad_global_step()

def on_run_end(self) -> List[List[STEP_OUTPUT]]:
def on_run_end(self) -> None:
"""Calls the on_epoch_end hook.

Returns:
Expand All @@ -193,32 +193,29 @@ def on_run_end(self) -> List[List[STEP_OUTPUT]]:
Raises:
MisconfigurationException: ``train_epoch_end`` does not return ``None``
"""
if self.batch_progress.current.ready == 0:
carmocca marked this conversation as resolved.
Show resolved Hide resolved
# dataloader/iterator did not produce a batch
return

# inform logger the batch loop has finished
self.trainer.logger_connector.epoch_end_reached()

# prepare epoch output
processed_outputs = self._prepare_outputs(self._epoch_output, batch_mode=False)
carmocca marked this conversation as resolved.
Show resolved Hide resolved

# get the model and call model.training_epoch_end
model = self.trainer.lightning_module

if is_overridden("training_epoch_end", model):
# run training_epoch_end
# refresh the result for custom logging at the epoch level
model._current_fx_name = "training_epoch_end"

# lightningmodule hook
training_epoch_end_output = model.training_epoch_end(processed_outputs)

if training_epoch_end_output is not None:
raise MisconfigurationException(
"training_epoch_end expects a return of None. "
"HINT: remove the return statement in training_epoch_end"
)
if is_overridden("training_epoch_end", model) and self._epoch_output:
processed_outputs = self._prepare_outputs(self._epoch_output, batch_mode=False)
# check that the dataloader/iterator produced a batch
if processed_outputs:
# run training_epoch_end
# refresh the result for custom logging at the epoch level
model._current_fx_name = "training_epoch_end"
carmocca marked this conversation as resolved.
Show resolved Hide resolved

# lightningmodule hook
training_epoch_end_output = model.training_epoch_end(processed_outputs)

if training_epoch_end_output is not None:
raise MisconfigurationException(
"training_epoch_end expects a return of None. "
"HINT: remove the return statement in training_epoch_end"
)
# free memory
self._epoch_output = None

self.trainer.fit_loop.epoch_progress.increment_processed()

Expand All @@ -230,11 +227,6 @@ def on_run_end(self) -> List[List[STEP_OUTPUT]]:
if self._num_training_batches_reached(self.is_last_batch):
self.update_lr_schedulers("epoch", update_plateau_schedulers=True)

epoch_output = self._epoch_output
# free memory
self._epoch_output = None
return epoch_output

def teardown(self) -> None:
self._results.cpu()
self.batch_loop.teardown()
Expand Down
6 changes: 1 addition & 5 deletions pytorch_lightning/loops/fit_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,11 +200,7 @@ def advance(self) -> None:
data_fetcher = self.trainer.data_connector.get_profiled_dataloader(dataloader)

with self.trainer.profiler.profile("run_training_epoch"):
# run train epoch
epoch_output = self.epoch_loop.run(data_fetcher)

if epoch_output is None:
return
self.epoch_loop.run(data_fetcher)

# the global step is manually decreased here due to backwards compatibility with existing loggers
# as they expect that the same step is used when logging epoch end metrics even when the batch loop has
Expand Down
20 changes: 19 additions & 1 deletion tests/trainer/test_dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,7 +869,7 @@ def __len__(self):


def test_iterable_dataset_stop_iteration_at_epoch_beginning():
"""Test that the training loop skips execution if the iterator is empty from the start."""
"""Test that the training loop skips execution if the iterator is empty from an epoch."""

class RandomDataset(IterableDataset):
def __init__(self, gen):
Expand Down Expand Up @@ -899,6 +899,24 @@ def gen(self):
assert trainer.current_epoch == 1


def test_iterable_dataset_stop_iteration_at_start():
    """Ensure fitting skips execution entirely when the train iterator yields no batches."""

    class EmptyIterableDataset(IterableDataset):
        def __iter__(self):
            return iter(self._generate())

        def _generate(self):
            # Deliberately produce nothing: the very first batch fetch raises StopIteration.
            yield from ()

    model = BoringModel()
    loader = DataLoader(EmptyIterableDataset(), batch_size=2)
    trainer = Trainer(default_root_dir=os.getcwd(), max_epochs=2)
    trainer.fit(model, train_dataloader=loader)
    # No batch was ever produced, so no optimizer step ran...
    assert trainer.global_step == 0
    # ...but the epoch loop still advanced through the configured epochs.
    assert trainer.current_epoch == 1
carmocca marked this conversation as resolved.
Show resolved Hide resolved


class DistribSamplerCallback(Callback):
def __init__(self, expected_seeds=(0, 0, 0)):
self.expected_seed = expected_seeds
Expand Down