diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 8ece6c4ec1b101..e69addf234a366 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -113,11 +113,12 @@ def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure):
 
         # model hook
         model_ref.optimizer_step(
-            self.trainer.current_epoch,
-            batch_idx,
-            optimizer,
-            opt_idx,
-            lambda_closure,
+            epoch=self.trainer.current_epoch,
+            batch_idx=batch_idx,
+            optimizer=optimizer,
+            optimizer_idx=opt_idx,
+            optimizer_closure=lambda_closure,
+            on_tpu=False,  # TPUAccelerator class sets this to True
             using_native_amp=native_amp,
             using_lbfgs=is_lbfgs
         )
diff --git a/pytorch_lightning/accelerators/tpu_accelerator.py b/pytorch_lightning/accelerators/tpu_accelerator.py
index 1988f83601b8cc..b60cd5a9dfc9c4 100644
--- a/pytorch_lightning/accelerators/tpu_accelerator.py
+++ b/pytorch_lightning/accelerators/tpu_accelerator.py
@@ -242,11 +242,13 @@ def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure):
 
         # model hook
         model_ref.optimizer_step(
-            self.trainer.current_epoch,
-            batch_idx, optimizer,
-            opt_idx,
-            lambda_closure,
+            epoch=self.trainer.current_epoch,
+            batch_idx=batch_idx,
+            optimizer=optimizer,
+            optimizer_idx=opt_idx,
+            optimizer_closure=lambda_closure,
             on_tpu=True,
+            using_native_amp=False,
             using_lbfgs=is_lbfgs
         )
 
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index d7125eb171a9cc..aa83a49873d4f0 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -1128,10 +1128,10 @@ def optimizer_step(
         batch_idx: int,
         optimizer: Optimizer,
         optimizer_idx: int,
-        optimizer_closure: Optional[Callable] = None,
-        on_tpu: bool = False,
-        using_native_amp: bool = False,
-        using_lbfgs: bool = False,
+        optimizer_closure: Optional[Callable],
+        on_tpu: bool,
+        using_native_amp: bool,
+        using_lbfgs: bool,
     ) -> None:
         r"""
         Override this method to adjust the default way the
@@ -1139,6 +1139,12 @@ def optimizer_step(
         By default, Lightning calls ``step()`` and ``zero_grad()`` as shown in the example
         once per optimizer.
 
+        Warning:
+            If you are overriding this method, make sure that you pass the ``optimizer_closure`` parameter
+            to ``optimizer.step()`` as shown in the examples. This ensures that
+            ``train_step_and_backward_closure`` is called within
+            :meth:`~pytorch_lightning.trainer.training_loop.TrainLoop.run_training_batch`.
+
         Args:
             epoch: Current epoch
             batch_idx: Index of current batch
@@ -1153,23 +1159,23 @@ def optimizer_step(
         .. code-block:: python
 
             # DEFAULT
-            def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx,
+            def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                                optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
-                optimizer.step()
+                optimizer.step(closure=optimizer_closure)
 
             # Alternating schedule for optimizer steps (i.e.: GANs)
-            def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx,
+            def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                                optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
                 # update generator opt every 2 steps
                 if optimizer_idx == 0:
                     if batch_idx % 2 == 0 :
-                        optimizer.step()
+                        optimizer.step(closure=optimizer_closure)
                         optimizer.zero_grad()
 
                 # update discriminator opt every 4 steps
                 if optimizer_idx == 1:
                     if batch_idx % 4 == 0 :
-                        optimizer.step()
+                        optimizer.step(closure=optimizer_closure)
                         optimizer.zero_grad()
 
             # ...
@@ -1182,8 +1188,8 @@ def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx,
         .. code-block:: python
 
             # learning rate warm-up
-            def optimizer_step(self, current_epoch, batch_idx, optimizer,
-                               optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
+            def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
+                               optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
                 # warm up lr
                 if self.trainer.global_step < 500:
                     lr_scale = min(1., float(self.trainer.global_step + 1) / 500.)
@@ -1191,7 +1197,7 @@ def optimizer_step(self, current_epoch, batch_idx, optimizer,
                     pg['lr'] = lr_scale * self.learning_rate
 
                 # update params
-                optimizer.step()
+                optimizer.step(closure=optimizer_closure)
                 optimizer.zero_grad()
 
         Note:
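
The warning added above can be illustrated outside of Lightning: closure-based optimizers such as ``torch.optim.LBFGS`` call the closure themselves to recompute the loss and gradients, so an override that drops ``optimizer_closure`` silently skips the forward/backward pass. Below is a minimal plain-PyTorch sketch of that behaviour (the ``model``, ``x`` and ``y`` names are illustrative, not part of this patch); Lightning's ``train_step_and_backward_closure`` plays the same role as ``closure`` here.

.. code-block:: python

    import torch
    from torch import nn

    # Toy model and data, purely for illustration.
    model = nn.Linear(2, 1)
    optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1)
    x, y = torch.randn(8, 2), torch.randn(8, 1)

    def closure():
        # Re-evaluates the model and recomputes gradients; LBFGS may call
        # this several times within a single `step`.
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        return loss

    # If the closure were not passed, no forward/backward pass would run at all,
    # which is exactly what the docstring warning guards against.
    optimizer.step(closure)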