From 651f93a69fc9d5f995af039dbc18f8c61040716d Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 5 May 2021 15:18:45 -0700 Subject: [PATCH] Add documentation for ways to access all batch outputs for on_train_epoch_end hook (#7389) Co-authored-by: Jirka Borovec --- pytorch_lightning/callbacks/base.py | 8 +++++++- pytorch_lightning/core/hooks.py | 5 +++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 8283c2ddd71ec..c35fc64e2e115 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -101,7 +101,13 @@ def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningMo def on_train_epoch_end( self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', unused: Optional = None ) -> None: - """Called when the train epoch ends.""" + """Called when the train epoch ends. + + To access all batch outputs at the end of the epoch, either: + + 1. Implement `training_epoch_end` in the `LightningModule` and access outputs via the module OR + 2. Cache data across train batch hooks inside the callback implementation to post-process in this hook. + """ pass def on_validation_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index d311bd4f58f06..7ab0c8acbe329 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -238,6 +238,11 @@ def on_train_epoch_start(self) -> None: def on_train_epoch_end(self, unused: Optional = None) -> None: """ Called in the training loop at the very end of the epoch. + + To access all batch outputs at the end of the epoch, either: + + 1. Implement `training_epoch_end` in the LightningModule OR + 2. Cache data across steps on the attribute(s) of the `LightningModule` and access them in this hook """ def on_validation_epoch_start(self) -> None: