diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 7a12d7e766dae..95ba6d5306309 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -16,7 +16,7 @@ from contextlib import contextmanager, suppress from copy import copy from functools import partial, update_wrapper -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union import numpy as np import torch @@ -269,6 +269,16 @@ def _check_training_step_output(self, training_step_output): if training_step_output.grad_fn is None: # TODO: Find why - RuntimeError: Expected to mark a variable ready only once ... raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor") + elif self.trainer.lightning_module.automatic_optimization: + if not any(( + isinstance(training_step_output, torch.Tensor), + (isinstance(training_step_output, Mapping) + and 'loss' in training_step_output), training_step_output is None + )): + raise MisconfigurationException( + "In automatic optimization, `training_step` must either return a Tensor, " + "a dict with key 'loss' or None (where the step will be skipped)." + ) def training_step(self, split_batch, batch_idx, opt_idx, hiddens): # give the PL module a result for logging diff --git a/tests/trainer/loops/test_training_loop.py b/tests/trainer/loops/test_training_loop.py index da4ecbe5a9f05..a2706e5d37bc0 100644 --- a/tests/trainer/loops/test_training_loop.py +++ b/tests/trainer/loops/test_training_loop.py @@ -11,10 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import re + import pytest import torch from pytorch_lightning import seed_everything, Trainer +from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel @@ -142,3 +145,24 @@ def validation_step(self, *args): assert trainer.current_epoch == 0 assert trainer.global_step == 5 assert model.validation_called_at == (0, 4) + + +@pytest.mark.parametrize(['output'], [(5., ), ({'a': 5}, )]) +def test_warning_invalid_trainstep_output(tmpdir, output): + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx): + return output + + model = TestModel() + + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1) + with pytest.raises( + MisconfigurationException, + match=re.escape( + "In automatic optimization, `training_step` must either return a Tensor, " + "a dict with key 'loss' or None (where the step will be skipped)." + ) + ): + trainer.fit(model)