Skip to content

Commit

Permalink
[2/2] Remove outputs from evaluation epoch end hooks (#7338)
Browse files: browse the repository at this point in the history
* Remove outputs from on_train_epoch_end

* iterate

* Update callback_hook.py

* update

* early stop?

* fix

* Update pytorch_lightning/trainer/training_loop.py

Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk>

* Update trainer.py

* update

* Update training_loop.py

* early stop?

* fix

* Remove outputs from evaluation epoch end hooks

* update

* Update test_remove_1-5.py

* fix lints

* Update base.py

* rm-outputs

* Update evaluation_loop.py

* try-save-more-memory

* Update trainer.py

* Update trainer.py

* cache-at-start

* Update evaluation_loop.py

* Update training_loop.py

* Update training_loop.py

Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk>
  • Loading branch information
ananthsub and ethanwharris authored May 5, 2021
1 parent fbcd63a commit 7b45bcf
Show file tree
Hide file tree
Showing 11 changed files with 39 additions and 252 deletions.
6 changes: 3 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for the PyTorch 1.8.1 autograd profiler ([#6618](https://github.com/PyTorchLightning/pytorch-lightning/pull/6618))


- Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120))


- Added `configure_sharded_model` hook ([#6679](https://github.com/PyTorchLightning/pytorch-lightning/pull/6679))


Expand Down Expand Up @@ -213,6 +210,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `Trainer.truncated_bptt_steps` in favor of `LightningModule.truncated_bptt_steps` ([#7323](https://github.com/PyTorchLightning/pytorch-lightning/pull/7323))


- Deprecated `outputs` in both `LightningModule.on_train_epoch_end` and `Callback.on_train_epoch_end` hooks ([#7339](https://github.com/PyTorchLightning/pytorch-lightning/pull/7339))


- Deprecated `LightningModule.grad_norm` in favor of `pytorch_lightning.utilities.grads.grad_norm` ([#7292](https://github.com/PyTorchLightning/pytorch-lightning/pull/7292))


Expand Down
8 changes: 3 additions & 5 deletions pytorch_lightning/callbacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
from pytorch_lightning.utilities.types import STEP_OUTPUT


class Callback(abc.ABC):
Expand Down Expand Up @@ -108,17 +108,15 @@ def on_validation_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.Lightn
"""Called when the val epoch begins."""
pass

def on_validation_epoch_end(
self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: EPOCH_OUTPUT
) -> None:
def on_validation_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""Called when the val epoch ends."""
pass

def on_test_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""Called when the test epoch begins."""
pass

def on_test_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: EPOCH_OUTPUT) -> None:
def on_test_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""Called when the test epoch ends."""
pass

Expand Down
6 changes: 3 additions & 3 deletions pytorch_lightning/core/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from torch.utils.data import DataLoader

from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
from pytorch_lightning.utilities.types import STEP_OUTPUT


class ModelHooks:
Expand Down Expand Up @@ -245,7 +245,7 @@ def on_validation_epoch_start(self) -> None:
Called in the validation loop at the very beginning of the epoch.
"""

def on_validation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
def on_validation_epoch_end(self) -> None:
"""
Called in the validation loop at the very end of the epoch.
"""
Expand All @@ -255,7 +255,7 @@ def on_test_epoch_start(self) -> None:
Called in the test loop at the very beginning of the epoch.
"""

def on_test_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
def on_test_epoch_end(self) -> None:
"""
Called in the test loop at the very end of the epoch.
"""
Expand Down
36 changes: 6 additions & 30 deletions pytorch_lightning/trainer/callback_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,44 +111,20 @@ def on_validation_epoch_start(self):
for callback in self.callbacks:
callback.on_validation_epoch_start(self, self.lightning_module)

def on_validation_epoch_end(self, outputs: EPOCH_OUTPUT):
"""Called when the epoch ends.
Args:
outputs: List of outputs on each ``validation`` epoch
"""
def on_validation_epoch_end(self):
"""Called when the validation epoch ends."""
for callback in self.callbacks:
if is_param_in_hook_signature(callback.on_validation_epoch_end, "outputs"):
callback.on_validation_epoch_end(self, self.lightning_module, outputs)
else:
warning_cache.warn(
"`Callback.on_validation_epoch_end` signature has changed in v1.3."
" `outputs` parameter has been added."
" Support for the old signature will be removed in v1.5", DeprecationWarning
)
callback.on_validation_epoch_end(self, self.lightning_module)
callback.on_validation_epoch_end(self, self.lightning_module)

def on_test_epoch_start(self):
"""Called when the epoch begins."""
for callback in self.callbacks:
callback.on_test_epoch_start(self, self.lightning_module)

def on_test_epoch_end(self, outputs: EPOCH_OUTPUT):
"""Called when the epoch ends.
Args:
outputs: List of outputs on each ``test`` epoch
"""
def on_test_epoch_end(self):
"""Called when the test epoch ends."""
for callback in self.callbacks:
if is_param_in_hook_signature(callback.on_test_epoch_end, "outputs"):
callback.on_test_epoch_end(self, self.lightning_module, outputs)
else:
warning_cache.warn(
"`Callback.on_test_epoch_end` signature has changed in v1.3."
" `outputs` parameter has been added."
" Support for the old signature will be removed in v1.5", DeprecationWarning
)
callback.on_test_epoch_end(self, self.lightning_module)
callback.on_test_epoch_end(self, self.lightning_module)

def on_predict_epoch_start(self) -> None:
"""Called when the epoch begins."""
Expand Down
24 changes: 12 additions & 12 deletions pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, List, Optional, Tuple, Union

from torch.utils.data import DataLoader

Expand All @@ -20,7 +20,6 @@
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.trainer.supporters import PredictionCollection
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache

Expand Down Expand Up @@ -76,6 +75,7 @@ def should_skip_evaluation(self, max_batches: List[Union[int, float]]) -> bool:
return sum(max_batches) == 0

def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
self.should_track_batch_outputs_for_epoch_end: bool = self._should_track_batch_outputs_for_epoch_end()
if self.trainer.testing:
self.trainer.call_hook('on_test_start', *args, **kwargs)
else:
Expand Down Expand Up @@ -188,6 +188,13 @@ def evaluation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT
output = self.trainer.call_hook('validation_step_end', *args, **kwargs)
return output

def _should_track_batch_outputs_for_epoch_end(self) -> bool:
model = self.trainer.lightning_module
if self.trainer.testing:
return is_overridden('test_epoch_end', model=model)
else:
return is_overridden('validation_epoch_end', model=model)

def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
# unset dataloader_idx in model
self.trainer.logger_connector.evaluation_epoch_end()
Expand Down Expand Up @@ -241,7 +248,7 @@ def store_predictions(self, output: Optional[STEP_OUTPUT], batch_idx: int, datal
# track debug metrics
self.trainer.dev_debugger.track_eval_loss_history(batch_idx, dataloader_idx, output)

def on_evaluation_epoch_end(self, outputs: Union[List[List[Dict]], List[Dict]]) -> None:
def on_evaluation_epoch_end(self) -> None:
model_ref = self.trainer.lightning_module
hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end"

Expand All @@ -251,18 +258,11 @@ def on_evaluation_epoch_end(self, outputs: Union[List[List[Dict]], List[Dict]])

if hasattr(self.trainer, hook_name):
on_evaluation_epoch_end_hook = getattr(self.trainer, hook_name)
on_evaluation_epoch_end_hook(outputs)
on_evaluation_epoch_end_hook()

if is_overridden(hook_name, model_ref):
model_hook_fx = getattr(model_ref, hook_name)
if is_param_in_hook_signature(model_hook_fx, "outputs"):
model_hook_fx(outputs)
else:
self.warning_cache.warn(
f"`ModelHooks.{hook_name}` signature has changed in v1.3. `outputs` parameter has been added."
" Support for the old signature will be removed in v1.5", DeprecationWarning
)
model_hook_fx()
model_hook_fx()

self.trainer._cache_logged_metrics()

Expand Down
11 changes: 6 additions & 5 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,22 +972,23 @@ def run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT:
dl_outputs = self.track_output_for_epoch_end(dl_outputs, output)

# store batch level output per dataloader
self.evaluation_loop.outputs.append(dl_outputs)
if self.evaluation_loop.should_track_batch_outputs_for_epoch_end:
self.evaluation_loop.outputs.append(dl_outputs)

outputs = self.evaluation_loop.outputs

# reset outputs
self.evaluation_loop.outputs = []

# with a single dataloader don't pass a 2D list
if self.evaluation_loop.num_dataloaders == 1:
if len(outputs) > 0 and self.evaluation_loop.num_dataloaders == 1:
outputs = outputs[0]

# lightning module method
self.evaluation_loop.evaluation_epoch_end(outputs)

# hook
self.evaluation_loop.on_evaluation_epoch_end(outputs)
self.evaluation_loop.on_evaluation_epoch_end()

# update epoch-level lr_schedulers
if on_epoch:
Expand Down Expand Up @@ -1212,8 +1213,8 @@ def _cache_logged_metrics(self):

def call_hook(self, hook_name: str, *args, **kwargs) -> Any:
# Note this implementation is copy/pasted into the TrainLoop class in TrainLoop._on_train_epoch_end_hook
# This was done to manage the deprecation of an argument to on_train_epoch_end
# If making chnages to this function, ensure that those changes are also made to
# This was done to manage the deprecation of the `outputs` argument to on_train_epoch_end
# If making changes to this function, ensure that those changes are also made to
# TrainLoop._on_train_epoch_end_hook

# set hook_name to model + reset Result obj
Expand Down
42 changes: 0 additions & 42 deletions tests/callbacks/test_callback_hook_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,48 +65,6 @@ def training_epoch_end(self, outputs) -> None:
trainer.fit(model)


def test_on_val_epoch_end_outputs(tmpdir):

class CB(Callback):

def on_validation_epoch_end(self, trainer, pl_module, outputs):
if trainer.running_sanity_check:
assert len(outputs) == trainer.num_sanity_val_batches[0]
else:
assert len(outputs) == trainer.num_val_batches[0]

model = BoringModel()

trainer = Trainer(
callbacks=CB(),
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=1,
weights_summary=None,
)

trainer.fit(model)


def test_on_test_epoch_end_outputs(tmpdir):

class CB(Callback):

def on_test_epoch_end(self, trainer, pl_module, outputs):
assert len(outputs) == trainer.num_test_batches[0]

model = BoringModel()

trainer = Trainer(
callbacks=CB(),
default_root_dir=tmpdir,
weights_summary=None,
)

trainer.test(model)


def test_free_memory_on_eval_outputs(tmpdir):

class CB(Callback):
Expand Down
8 changes: 4 additions & 4 deletions tests/callbacks/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
call.on_validation_epoch_start(trainer, model),
call.on_validation_batch_start(trainer, model, ANY, 0, 0),
call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
call.on_validation_epoch_end(trainer, model, ANY),
call.on_validation_epoch_end(trainer, model),
call.on_epoch_end(trainer, model),
call.on_validation_end(trainer, model),
call.on_sanity_check_end(trainer, model),
Expand Down Expand Up @@ -90,7 +90,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
call.on_validation_epoch_start(trainer, model),
call.on_validation_batch_start(trainer, model, ANY, 0, 0),
call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
call.on_validation_epoch_end(trainer, model, ANY),
call.on_validation_epoch_end(trainer, model),
call.on_epoch_end(trainer, model),
call.on_validation_end(trainer, model),
call.on_save_checkpoint(trainer, model), # should take ANY but we are inspecting signature for BC
Expand Down Expand Up @@ -128,7 +128,7 @@ def test_trainer_callback_hook_system_test(tmpdir):
call.on_test_batch_end(trainer, model, ANY, ANY, 0, 0),
call.on_test_batch_start(trainer, model, ANY, 1, 0),
call.on_test_batch_end(trainer, model, ANY, ANY, 1, 0),
call.on_test_epoch_end(trainer, model, ANY),
call.on_test_epoch_end(trainer, model),
call.on_epoch_end(trainer, model),
call.on_test_end(trainer, model),
call.teardown(trainer, model, 'test'),
Expand Down Expand Up @@ -163,7 +163,7 @@ def test_trainer_callback_hook_system_validate(tmpdir):
call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
call.on_validation_batch_start(trainer, model, ANY, 1, 0),
call.on_validation_batch_end(trainer, model, ANY, ANY, 1, 0),
call.on_validation_epoch_end(trainer, model, ANY),
call.on_validation_epoch_end(trainer, model),
call.on_epoch_end(trainer, model),
call.on_validation_end(trainer, model),
call.teardown(trainer, model, 'validate'),
Expand Down
56 changes: 0 additions & 56 deletions tests/core/test_hooks.py

This file was deleted.

Loading

0 comments on commit 7b45bcf

Please sign in to comment.