diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 8283c2ddd71ec..c35fc64e2e115 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -101,7 +101,13 @@ def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningMo def on_train_epoch_end( self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', unused: Optional = None ) -> None: - """Called when the train epoch ends.""" + """Called when the train epoch ends. + + To access all batch outputs at the end of the epoch, either: + + 1. Implement `training_epoch_end` in the `LightningModule` and access outputs via the module OR + 2. Cache data across train batch hooks inside the callback implementation to post-process in this hook. + """ pass def on_validation_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index d311bd4f58f06..7ab0c8acbe329 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -238,6 +238,11 @@ def on_train_epoch_start(self) -> None: def on_train_epoch_end(self, unused: Optional = None) -> None: """ Called in the training loop at the very end of the epoch. + + To access all batch outputs at the end of the epoch, either: + + 1. Implement `training_epoch_end` in the LightningModule OR + 2. Cache data across steps on the attribute(s) of the `LightningModule` and access them in this hook """ def on_validation_epoch_start(self) -> None: