Skip to content

Commit

Permalink
Add datamodule parameter to lr_find() (#3425)
Browse files Browse the repository at this point in the history
* Add datamodule parameter to lr_find()

* Fixed missing import

* Move datamodule parameter to end

* Add datamodule parameter test with auto_lr_find

* Change test for datamodule parameter

* Apply suggestions from code review

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>

* Fix lr_find documentation

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* formatting

* Add description to datamodule param in lr_find

* pep8: remove trailing whitespace on line 105

* added changelog

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
Co-authored-by: Nicki Skafte <nugginea@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
  • Loading branch information
5 people authored Oct 1, 2020
1 parent 7c61fc7 commit e4e60e9
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 10 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for datamodules to save and load checkpoints when training ([#3563]https://github.com/PyTorchLightning/pytorch-lightning/pull/3563)

- Added support for datamodule in learning rate finder ([#3425](https://github.com/PyTorchLightning/pytorch-lightning/pull/3425))

### Changed

- Changed `LearningRateLogger` to `LearningRateMonitor` ([#3251](https://github.com/PyTorchLightning/pytorch-lightning/pull/3251))
Expand Down
28 changes: 19 additions & 9 deletions pytorch_lightning/tuner/lr_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
from typing import List, Optional, Sequence, Union

import numpy as np
import torch
from typing import Optional, Sequence, List, Union
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader

from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers.base import DummyLogger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.optim.lr_scheduler import _LRScheduler
import importlib
from pytorch_lightning import _logger as log
import numpy as np
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr


# check if ipywidgets is installed before importing tqdm.auto
# to ensure it won't fail and a progress bar is displayed
if importlib.util.find_spec('ipywidgets') is not None:
Expand Down Expand Up @@ -71,6 +73,7 @@ def lr_find(
num_training: int = 100,
mode: str = 'exponential',
early_stop_threshold: float = 4.0,
datamodule: Optional[LightningDataModule] = None,
):
r"""
lr_find enables the user to do a range test of good initial learning rates,
Expand All @@ -81,7 +84,7 @@ def lr_find(
train_dataloader: A PyTorch
DataLoader with training samples. If the model has
a predefined train_dataloader method this will be skipped.
a predefined train_dataloader method, this will be skipped.
min_lr: minimum learning rate to investigate
Expand All @@ -98,6 +101,12 @@ def lr_find(
loss at any point is larger than early_stop_threshold*best_loss
then the search is stopped. To disable, set to None.
datamodule: An optional `LightningDataModule` which holds the training
and validation dataloader(s). Note that the `train_dataloader` and
`val_dataloaders` parameters cannot be used at the same time as
this parameter, or a `MisconfigurationException` will be raised.
Example::
# Setup model and trainer
Expand Down Expand Up @@ -167,7 +176,8 @@ def lr_find(
# Fit, lr & loss logged in callback
trainer.fit(model,
train_dataloader=train_dataloader,
val_dataloaders=val_dataloaders)
val_dataloaders=val_dataloaders,
datamodule=datamodule)

# Prompt if we stopped early
if trainer.global_step != num_training:
Expand Down
5 changes: 4 additions & 1 deletion pytorch_lightning/tuner/tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
from pytorch_lightning.tuner.lr_finder import _run_lr_finder_internally, lr_find
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.datamodule import LightningDataModule
from typing import Optional, List, Union
from torch.utils.data import DataLoader

Expand Down Expand Up @@ -50,6 +51,7 @@ def lr_find(
num_training: int = 100,
mode: str = 'exponential',
early_stop_threshold: float = 4.0,
datamodule: Optional[LightningDataModule] = None
):
return lr_find(
self.trainer,
Expand All @@ -60,7 +62,8 @@ def lr_find(
max_lr,
num_training,
mode,
early_stop_threshold
early_stop_threshold,
datamodule,
)

def internal_find_lr(self, trainer, model: LightningModule):
Expand Down
25 changes: 25 additions & 0 deletions tests/trainer/test_lr_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
from tests.base.datamodules import TrialMNISTDataModule


def test_error_on_more_than_1_optimizer(tmpdir):
Expand Down Expand Up @@ -152,6 +153,30 @@ def test_call_to_trainer_method(tmpdir):
'Learning rate was not altered after running learning rate finder'


def test_datamodule_parameter(tmpdir):
""" Test that the datamodule parameter works """

# trial datamodule
dm = TrialMNISTDataModule(tmpdir)

hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(**hparams)

before_lr = hparams.get('learning_rate')
# logger file to get meta
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=2,
)

lrfinder = trainer.tuner.lr_find(model, datamodule=dm)
after_lr = lrfinder.suggestion()
model.learning_rate = after_lr

assert before_lr != after_lr, \
'Learning rate was not altered after running learning rate finder'


def test_accumulation_and_early_stopping(tmpdir):
""" Test that early stopping of learning rate finder works, and that
accumulation also works for this feature """
Expand Down

0 comments on commit e4e60e9

Please sign in to comment.