Fix/mismatched toggle optimizer (#7563)
* fix: avoid potential mismatched toggling of optimizer
Refs #7405

chore: update CHANGELOG

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix: resolve a conflict

chore: update changelog

* feat: add a test that fails in master

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo in tests/trainer/optimization/test_multiple_optimizers.py

Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>

* Polish tests/trainer/optimization/test_multiple_optimizers.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Polish tests/trainer/optimization/test_multiple_optimizers.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* fix: change placeholder in optimizer_step from positional args to keyword args

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
fix whitespace


fix parentheses
awaelchli committed May 25, 2021
1 parent a6d9532 commit 84cfcbf
Showing 3 changed files with 72 additions and 5 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [1.3.3] - 2021-05-27

### Changed

- Changed calling of `untoggle_optimizer(opt_idx)` out of the closure function ([#7563](https://github.com/PyTorchLightning/pytorch-lightning/pull/7563))

### Fixed

- Fixed `ProgressBar` pickling after calling `trainer.predict` ([#7608](https://github.com/PyTorchLightning/pytorch-lightning/pull/7608))
8 changes: 3 additions & 5 deletions pytorch_lightning/trainer/training_loop.py
@@ -727,7 +727,9 @@ def train_step_and_backward_closure():

                    # optimizer step
                    self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)

                    if len(self.trainer.optimizers) > 1:
                        # revert back to previous state
                        self.trainer.lightning_module.untoggle_optimizer(opt_idx)
                else:
                    self._curr_step_result = self.training_step(
                        split_batch, batch_idx, opt_idx, self.trainer.hiddens
@@ -838,10 +840,6 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer,
"training_step returned None. If this was on purpose, ignore this warning..."
)

if len(self.trainer.optimizers) > 1:
# revert back to previous state
self.trainer.lightning_module.untoggle_optimizer(opt_idx)

return result

def _check_finite(self, loss: torch.Tensor) -> None:
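For orientation: the hunks above move the `untoggle_optimizer(opt_idx)` call out of `training_step_and_backward` (the body of the closure handed to `optimizer_step`) and run it after `optimizer_step` returns. Below is a minimal, self-contained sketch of that control flow; the helpers `toggle_optimizer`, `untoggle_optimizer` and `run_optimization` are simplified stand-ins written for this note, not the Lightning source.

import torch


def toggle_optimizer(layers, opt_idx):
    # enable grads only for the parameters owned by the active optimizer
    for i, layer in enumerate(layers):
        for p in layer.parameters():
            p.requires_grad = i == opt_idx


def untoggle_optimizer(layers):
    # revert back to the previous state: every parameter trainable again
    for layer in layers:
        for p in layer.parameters():
            p.requires_grad = True


def run_optimization(layers, optimizer, opt_idx, batch, skip_step):
    toggle_optimizer(layers, opt_idx)

    def closure():
        # stand-in for training_step_and_backward
        out = layers[opt_idx](batch)
        loss = torch.nn.functional.mse_loss(out, torch.ones_like(out))
        loss.backward()
        return loss

    # a custom optimizer_step may never call optimizer.step for this optimizer,
    # in which case the closure (and anything placed inside it) never runs
    if not skip_step:
        optimizer.step(closure=closure)
        optimizer.zero_grad()

    # after this commit the untoggle happens here, unconditionally,
    # rather than inside the closure where it could be skipped
    untoggle_optimizer(layers)


layers = [torch.nn.Linear(32, 2), torch.nn.Linear(32, 2)]
optimizers = [torch.optim.SGD(layer.parameters(), lr=0.001) for layer in layers]
batch = torch.rand(4, 32)
for batch_idx in range(4):
    for opt_idx, opt in enumerate(optimizers):
        skip = opt_idx == 1 and batch_idx % 2 == 1
        run_optimization(layers, opt, opt_idx, batch, skip_step=skip)
# every parameter is trainable again, regardless of the skipped steps
assert all(p.requires_grad for layer in layers for p in layer.parameters())

The new test below exercises exactly this situation by stepping the second optimizer only on every other batch.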
65 changes: 65 additions & 0 deletions tests/trainer/optimization/test_multiple_optimizers.py
@@ -168,3 +168,68 @@ def training_step(self, batch, batch_idx):

with pytest.raises(ValueError, match='`training_step` is missing the `optimizer_idx`'):
trainer.fit(TestModel())


def test_custom_optimizer_step_with_multiple_optimizers(tmpdir):
    """
    This test ensures custom optimizer_step works,
    even when optimizer.step is not called for a particular optimizer.
    """

    class TestModel(BoringModel):
        training_step_called = [0, 0]
        optimizer_step_called = [0, 0]

        def __init__(self):
            super().__init__()
            self.layer_a = torch.nn.Linear(32, 2)
            self.layer_b = torch.nn.Linear(32, 2)

        def configure_optimizers(self):
            opt_a = torch.optim.SGD(self.layer_a.parameters(), lr=0.001)
            opt_b = torch.optim.SGD(self.layer_b.parameters(), lr=0.001)
            return opt_a, opt_b

        def training_step(self, batch, batch_idx, optimizer_idx):
            self.training_step_called[optimizer_idx] += 1
            x = self.layer_a(batch[0]) if (optimizer_idx == 0) else self.layer_b(batch[0])
            loss = torch.nn.functional.mse_loss(x, torch.ones_like(x))
            return loss

        def training_epoch_end(self, outputs) -> None:
            # outputs should be an array with an entry per optimizer
            assert len(outputs) == 2

        def optimizer_step(
            self,
            epoch,
            batch_idx,
            optimizer,
            optimizer_idx,
            optimizer_closure,
            **_,
        ):
            # update first optimizer every step
            if optimizer_idx == 0:
                self.optimizer_step_called[optimizer_idx] += 1
                optimizer.step(closure=optimizer_closure)

            # update second optimizer every 2 steps
            if optimizer_idx == 1:
                if batch_idx % 2 == 0:
                    self.optimizer_step_called[optimizer_idx] += 1
                    optimizer.step(closure=optimizer_closure)

    model = TestModel()
    model.val_dataloader = None

    trainer = pl.Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=4,
        max_epochs=1,
        log_every_n_steps=1,
        weights_summary=None,
    )
    trainer.fit(model)
    assert model.training_step_called == [4, 2]
    assert model.optimizer_step_called == [4, 2]
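With `limit_train_batches=4`, the override steps the first optimizer on every batch and the second only on even `batch_idx`, which is what the final `[4, 2]` assertions encode. Before this commit, skipping `optimizer.step` also skipped the closure, so the `untoggle_optimizer(opt_idx)` call that lived inside it never ran and the parameters toggled off for that optimizer were not restored; per the commit message, that is why this test fails on master and passes once the untoggle call is moved out of the closure.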
