Remove unnecessary TrainingEpochLoop return #9298

Merged
merged 5 commits on Sep 6, 2021
Changes from all commits
46 changes: 19 additions & 27 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -184,7 +184,7 @@ def on_advance_end(self):
         # progress global step according to grads progress
         self._increment_accumulated_grad_global_step()

-    def on_run_end(self) -> List[List[STEP_OUTPUT]]:
+    def on_run_end(self) -> None:
         """Calls the on_epoch_end hook.

         Returns:
@@ -193,32 +193,29 @@ def on_run_end(self) -> List[List[STEP_OUTPUT]]:
         Raises:
             MisconfigurationException: ``train_epoch_end`` does not return ``None``
         """
-        if self.batch_progress.current.ready == 0:
-            # dataloader/iterator did not produce a batch
-            return
-
         # inform logger the batch loop has finished
         self.trainer.logger_connector.epoch_end_reached()

-        # prepare epoch output
-        processed_outputs = self._prepare_outputs(self._epoch_output, batch_mode=False)
-
         # get the model and call model.training_epoch_end
         model = self.trainer.lightning_module

-        if is_overridden("training_epoch_end", model):
-            # run training_epoch_end
-            # refresh the result for custom logging at the epoch level
-            model._current_fx_name = "training_epoch_end"
-
-            # lightningmodule hook
-            training_epoch_end_output = model.training_epoch_end(processed_outputs)
-
-            if training_epoch_end_output is not None:
-                raise MisconfigurationException(
-                    "training_epoch_end expects a return of None. "
-                    "HINT: remove the return statement in training_epoch_end"
-                )
+        if is_overridden("training_epoch_end", model) and self._epoch_output:
+            processed_outputs = self._prepare_outputs(self._epoch_output, batch_mode=False)
+            # check that the dataloader/iterator produced a batch
+            if processed_outputs:
+                # run training_epoch_end
+                # refresh the result for custom logging at the epoch level
+                model._current_fx_name = "training_epoch_end"
+
+                # lightningmodule hook
+                training_epoch_end_output = model.training_epoch_end(processed_outputs)
+
+                if training_epoch_end_output is not None:
+                    raise MisconfigurationException(
+                        "training_epoch_end expects a return of None. "
+                        "HINT: remove the return statement in training_epoch_end"
+                    )
+        # free memory
+        self._epoch_output = None

         self.trainer.fit_loop.epoch_progress.increment_processed()

@@ -230,11 +227,6 @@ def on_run_end(self) -> List[List[STEP_OUTPUT]]:
         if self._num_training_batches_reached(self.is_last_batch):
             self.update_lr_schedulers("epoch", update_plateau_schedulers=True)

-        epoch_output = self._epoch_output
-        # free memory
-        self._epoch_output = None
-        return epoch_output
-
     def teardown(self) -> None:
         self._results.cpu()
         self.batch_loop.teardown()
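The net effect of the change above is that `TrainingEpochLoop.on_run_end` no longer hands the accumulated epoch outputs back to its caller: the outputs are consumed by the `training_epoch_end` hook and freed inside the loop itself. A minimal sketch of that control-flow difference (illustrative names only, not the actual Lightning classes):

```python
from typing import Any, List, Optional


class EpochLoopBefore:
    """Old behavior: the loop returns the collected outputs to its caller."""

    def __init__(self) -> None:
        self._epoch_output: List[Any] = []

    def run(self) -> Optional[List[Any]]:
        # ... batches are advanced here, appending step outputs to self._epoch_output ...
        epoch_output, self._epoch_output = self._epoch_output, []
        return epoch_output  # the caller (fit loop) must handle or ignore this


class EpochLoopAfter:
    """New behavior: the loop consumes the outputs itself and returns None."""

    def __init__(self) -> None:
        self._epoch_output: List[Any] = []

    def run(self) -> None:
        # ... batches are advanced here, appending step outputs to self._epoch_output ...
        self._call_training_epoch_end(self._epoch_output)
        self._epoch_output = []  # memory is freed inside the loop

    def _call_training_epoch_end(self, outputs: List[Any]) -> None:
        # stand-in for handing the outputs to the LightningModule hook
        pass
```

Because there is no longer a return value to inspect, the `FitLoop.advance` caller in the next file simply drops its `epoch_output is None` check.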
6 changes: 1 addition & 5 deletions pytorch_lightning/loops/fit_loop.py
@@ -200,11 +200,7 @@ def advance(self) -> None:
         data_fetcher = self.trainer.data_connector.get_profiled_dataloader(dataloader)

         with self.trainer.profiler.profile("run_training_epoch"):
-            # run train epoch
-            epoch_output = self.epoch_loop.run(data_fetcher)
-
-            if epoch_output is None:
-                return
+            self.epoch_loop.run(data_fetcher)

             # the global step is manually decreased here due to backwards compatibility with existing loggers
             # as they expect that the same step is used when logging epoch end metrics even when the batch loop has
18 changes: 8 additions & 10 deletions tests/trainer/test_dataloaders.py
@@ -868,34 +868,32 @@ def __len__(self):
     trainer.predict(model, dataloaders=dataloader)


-def test_iterable_dataset_stop_iteration_at_epoch_beginning():
+@pytest.mark.parametrize("yield_at_all", (False, True))
+def test_iterable_dataset_stop_iteration_at_epoch_beginning(yield_at_all):
     """Test that the training loop skips execution if the iterator is empty from the start."""

-    class RandomDataset(IterableDataset):
+    class TestDataset(IterableDataset):
         def __init__(self, gen):
             self.gen = gen

         def __iter__(self):
             return iter(self.gen())

     class TestModel(BoringModel):
-        def train_dataloader(self):
-            return DataLoader(RandomDataset(self.gen), batch_size=2)
-
         def gen(self):
-            # produce data in epoch 0
-            # no data otherwise
-            if self.current_epoch == 0:
+            # produce data in epoch 0, no data otherwise
+            if yield_at_all and self.current_epoch == 0:
                 yield torch.rand(32)
                 yield torch.rand(32)
                 yield torch.rand(32)

     model = TestModel()
+    train_dataloader = DataLoader(TestDataset(model.gen), batch_size=2)
     trainer = Trainer(
         default_root_dir=os.getcwd(), max_epochs=2, weights_summary=None  # we expect the second epoch to be skipped
     )
-    trainer.fit(model)
-    assert trainer.global_step == 2
+    trainer.fit(model, train_dataloader=train_dataloader)
+    assert trainer.global_step == 2 * yield_at_all
     assert trainer.current_epoch == 1
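A short note on the updated assertion (standalone illustration, not part of the test file): the generator yields three 32-element samples with `batch_size=2`, so a data-producing epoch runs two optimizer steps, and Python's bool-to-int coercion makes `2 * yield_at_all` the expected step count for both parametrizations.

```python
# bools are ints in Python, so the expected global step collapses to one expression
assert 2 * True == 2   # data yielded in epoch 0 -> 2 batches -> 2 optimizer steps
assert 2 * False == 0  # iterator empty from the start -> no steps at all
```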