Fix global step update when the epoch is skipped (#7677)
* Fix global step update when the epoch is skipped

* Update CHANGELOG

* Move test
awaelchli authored and Borda committed May 26, 2021
1 parent a795a1d · commit b749ffd
Showing 4 changed files with 24 additions and 23 deletions.
CHANGELOG.md (1 addition, 0 deletions)
@@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed broadcasting in multi-node, multi-gpu DDP using torch 1.7 ([#7592](https://github.com/PyTorchLightning/pytorch-lightning/pull/7592))
- Fixed dataloaders are not reset when tuning the model ([#7566](https://github.com/PyTorchLightning/pytorch-lightning/pull/7566))
- Fixed print errors in `ProgressBar` when `trainer.fit` is not called ([#7674](https://github.com/PyTorchLightning/pytorch-lightning/pull/7674))
+- Fixed global step update when the epoch is skipped ([#7677](https://github.com/PyTorchLightning/pytorch-lightning/pull/7677))

## [1.3.2] - 2021-05-18

pytorch_lightning/trainer/training_loop.py (2 additions, 3 deletions)
@@ -574,9 +574,8 @@ def run_training_epoch(self):
            self.trainer.run_evaluation(on_epoch=True)
            self.trainer.training = True

-        # increment the global step once
-        # progress global step according to grads progress
-        self.increment_accumulated_grad_global_step()
+        if batch_output.signal != -1:
+            self.increment_accumulated_grad_global_step()

    def on_train_epoch_end(self, epoch_output: List[List[List[Result]]]) -> None:
        # inform logger the batch loop has finished
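Why the guard matters: when on_train_batch_start returns -1, run_training_batch reports batch_output.signal == -1 and the batch loop breaks out of the epoch early, but before this patch the trailing increment_accumulated_grad_global_step() call still ran, so even a fully skipped epoch advanced trainer.global_step. A runnable toy sketch of that control flow (every name below is an illustrative stand-in, not the real Trainer internals, and Lightning's actual per-batch update is further gated by gradient accumulation):

def run_training_epoch(num_batches, skip_at, global_step, fixed):
    """Toy model of one epoch whose batch hook may return -1 to skip it."""
    signal = 0
    for batch_idx in range(num_batches):
        if batch_idx == skip_at:
            signal = -1  # on_train_batch_start returned -1
            break
        global_step += 1  # per-batch global-step progress inside the loop
    # the trailing update this commit guards: running it unconditionally
    # advanced the global step even when the epoch was skipped
    if not fixed or signal != -1:
        global_step += 1
    return global_step

# An epoch skipped at batch 5 contributes exactly 5 steps, matching the
# new test's `global_step == batch_idx_ * max_epochs` expectation:
assert run_training_epoch(10, skip_at=5, global_step=0, fixed=True) == 5
assert run_training_epoch(10, skip_at=5, global_step=0, fixed=False) == 6  # pre-fix behavior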
tests/models/test_hooks.py (0 additions, 20 deletions)
@@ -225,26 +225,6 @@ def train_dataloader(self):
    trainer.fit(model)


-@pytest.mark.parametrize('max_epochs,batch_idx_', [(2, 5), (3, 8), (4, 12)])
-def test_on_train_batch_start_hook(max_epochs, batch_idx_):
-
-    class CurrentModel(BoringModel):
-
-        def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
-            if batch_idx == batch_idx_:
-                return -1
-
-    model = CurrentModel()
-    trainer = Trainer(max_epochs=max_epochs)
-    trainer.fit(model)
-    if batch_idx_ > len(model.val_dataloader()) - 1:
-        assert trainer.batch_idx == len(model.val_dataloader()) - 1
-        assert trainer.global_step == len(model.val_dataloader()) * max_epochs
-    else:
-        assert trainer.batch_idx == batch_idx_
-        assert trainer.global_step == (batch_idx_ + 1) * max_epochs
-
-
def test_trainer_model_hook_system(tmpdir):
    """Test the LightningModule hook system."""

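(Per the "Move test" message above, the deleted hook test is re-added below in tests/trainer/loops/test_training_loop.py with its expectations corrected for this fix.)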
tests/trainer/loops/test_training_loop.py (21 additions, 0 deletions)
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import pytest
import torch

from pytorch_lightning import seed_everything, Trainer
@@ -201,3 +201,23 @@ def run_training(**trainer_kwargs):
        num_sanity_val_steps=2,
    )
    assert torch.allclose(sequence0, sequence1)
+
+
+@pytest.mark.parametrize(['max_epochs', 'batch_idx_'], [(2, 5), (3, 8), (4, 12)])
+def test_on_train_batch_start_return_minus_one(max_epochs, batch_idx_):
+
+    class CurrentModel(BoringModel):
+
+        def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+            if batch_idx == batch_idx_:
+                return -1
+
+    model = CurrentModel()
+    trainer = Trainer(max_epochs=max_epochs, limit_train_batches=10)
+    trainer.fit(model)
+    if batch_idx_ > trainer.num_training_batches - 1:
+        assert trainer.batch_idx == trainer.num_training_batches - 1
+        assert trainer.global_step == trainer.num_training_batches * max_epochs
+    else:
+        assert trainer.batch_idx == batch_idx_
+        assert trainer.global_step == batch_idx_ * max_epochs
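Note the corrected arithmetic relative to the deleted test: with limit_train_batches=10 and batch_idx_=5, each epoch completes batches 0-4 before the hook returns -1 at batch 5, so max_epochs=2 ends with global_step == 5 * 2 == 10. The old assertion, (batch_idx_ + 1) * max_epochs, only held because of the extra epoch-end increment that is now guarded; the new test also replaces the len(model.val_dataloader()) proxy with trainer.num_training_batches.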
