From b749ffdb6c36d9223fbe2a62c5d1c7823c368901 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 26 May 2021 00:34:48 +0200
Subject: [PATCH] Fix global step update when the epoch is skipped (#7677)

* Fix global step update when the epoch is skipped

* Update CHANGELOG

* Move test
---
 CHANGELOG.md                               |  1 +
 pytorch_lightning/trainer/training_loop.py |  5 ++---
 tests/models/test_hooks.py                 | 20 --------------------
 tests/trainer/loops/test_training_loop.py  | 21 +++++++++++++++++++++
 4 files changed, 24 insertions(+), 23 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index a458a9edda13b..a88caebce2678 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed broadcasting in multi-node, multi-gpu DDP using torch 1.7 ([#7592](https://github.com/PyTorchLightning/pytorch-lightning/pull/7592))
 - Fixed dataloaders are not reset when tuning the model ([#7566](https://github.com/PyTorchLightning/pytorch-lightning/pull/7566))
 - Fixed print errors in `ProgressBar` when `trainer.fit` is not called ([#7674](https://github.com/PyTorchLightning/pytorch-lightning/pull/7674))
+- Fixed global step update when the epoch is skipped ([#7677](https://github.com/PyTorchLightning/pytorch-lightning/pull/7677))
 
 
 ## [1.3.2] - 2021-05-18
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 11fe74bc4f21f..c07d6eeb31566 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -574,9 +574,8 @@ def run_training_epoch(self):
             self.trainer.run_evaluation(on_epoch=True)
             self.trainer.training = True
 
-        # increment the global step once
-        # progress global step according to grads progress
-        self.increment_accumulated_grad_global_step()
+        if batch_output.signal != -1:
+            self.increment_accumulated_grad_global_step()
 
     def on_train_epoch_end(self, epoch_output: List[List[List[Result]]]) -> None:
         # inform logger the batch loop has finished
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 24bf29a9e2eac..58ebc8e271be6 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -225,26 +225,6 @@ def train_dataloader(self):
     trainer.fit(model)
 
 
-@pytest.mark.parametrize('max_epochs,batch_idx_', [(2, 5), (3, 8), (4, 12)])
-def test_on_train_batch_start_hook(max_epochs, batch_idx_):
-
-    class CurrentModel(BoringModel):
-
-        def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
-            if batch_idx == batch_idx_:
-                return -1
-
-    model = CurrentModel()
-    trainer = Trainer(max_epochs=max_epochs)
-    trainer.fit(model)
-    if batch_idx_ > len(model.val_dataloader()) - 1:
-        assert trainer.batch_idx == len(model.val_dataloader()) - 1
-        assert trainer.global_step == len(model.val_dataloader()) * max_epochs
-    else:
-        assert trainer.batch_idx == batch_idx_
-        assert trainer.global_step == (batch_idx_ + 1) * max_epochs
-
-
 def test_trainer_model_hook_system(tmpdir):
     """Test the LightningModule hook system."""
 
diff --git a/tests/trainer/loops/test_training_loop.py b/tests/trainer/loops/test_training_loop.py
index 2d32d8c8878e4..94becf6488fc3 100644
--- a/tests/trainer/loops/test_training_loop.py
+++ b/tests/trainer/loops/test_training_loop.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import pytest
 import torch
 
 from pytorch_lightning import seed_everything, Trainer
@@ -201,3 +202,23 @@ def run_training(**trainer_kwargs):
         num_sanity_val_steps=2,
     )
     assert torch.allclose(sequence0, sequence1)
+
+
+@pytest.mark.parametrize(['max_epochs', 'batch_idx_'], [(2, 5), (3, 8), (4, 12)])
+def test_on_train_batch_start_return_minus_one(max_epochs, batch_idx_):
+
+    class CurrentModel(BoringModel):
+
+        def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+            if batch_idx == batch_idx_:
+                return -1
+
+    model = CurrentModel()
+    trainer = Trainer(max_epochs=max_epochs, limit_train_batches=10)
+    trainer.fit(model)
+    if batch_idx_ > trainer.num_training_batches - 1:
+        assert trainer.batch_idx == trainer.num_training_batches - 1
+        assert trainer.global_step == trainer.num_training_batches * max_epochs
+    else:
+        assert trainer.batch_idx == batch_idx_
+        assert trainer.global_step == batch_idx_ * max_epochs
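
For context, and not part of the patch itself: the guard introduced in run_training_epoch means the epoch-end global-step increment is skipped whenever the batch loop ended with signal -1, i.e. when `on_train_batch_start` returned -1 and the remainder of the epoch was abandoned. Below is a minimal, self-contained sketch of that scenario against PyTorch Lightning 1.3.x; the model, dataset, and skip index are hypothetical and only illustrate the behavior the new test asserts.

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    """Tiny random dataset, just enough to drive the training loop."""

    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return torch.randn(32)


class SkipRestOfEpoch(LightningModule):
    """Hypothetical module that aborts each epoch after three batches."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        # Any scalar works as a loss for this illustration.
        return self(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(RandomDataset(), batch_size=4)

    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
        # Returning -1 from this hook makes the loop skip the rest of the epoch.
        if batch_idx == 3:
            return -1


if __name__ == "__main__":
    trainer = Trainer(max_epochs=2, limit_train_batches=10)
    trainer.fit(SkipRestOfEpoch())
    # Per the arithmetic in the new test, global_step should now count only the
    # optimizer steps that actually ran: 3 per epoch, so 6 in total here.
    # Before the fix, the unconditional epoch-end increment added one extra
    # step per skipped epoch.
    print(trainer.global_step)

The new test in tests/trainer/loops/test_training_loop.py covers the same scenario parametrically, including the case where the skip index lies beyond the number of training batches.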