From e4e60e9b82adc48482db4721ce3e1fdc3ab6d6fe Mon Sep 17 00:00:00 2001
From: GimmickNG
Date: Thu, 1 Oct 2020 02:33:12 -0600
Subject: [PATCH] Add datamodule parameter to lr_find() (#3425)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Add datamodule parameter to lr_find()

* Fixed missing import

* Move datamodule parameter to end

* Add datamodule parameter test with auto_lr_find

* Change test for datamodule parameter

* Apply suggestions from code review

Co-authored-by: Nicki Skafte

* Fix lr_find documentation

Co-authored-by: Carlos Mocholí

* formatting

* Add description to datamodule param in lr_find

* pep8: remove trailing whitespace on line 105

* added changelog

Co-authored-by: Nicki Skafte
Co-authored-by: Nicki Skafte
Co-authored-by: Carlos Mocholí
Co-authored-by: Jirka Borovec
---
 CHANGELOG.md                         |  2 ++
 pytorch_lightning/tuner/lr_finder.py | 28 +++++++++++++++++++---------
 pytorch_lightning/tuner/tuning.py    |  5 ++++-
 tests/trainer/test_lr_finder.py      | 25 +++++++++++++++++++++++++
 4 files changed, 50 insertions(+), 10 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 223cabc4f4eb4..4e6505890004f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -27,6 +27,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for datamodules to save and load checkpoints when training ([#3563]https://github.com/PyTorchLightning/pytorch-lightning/pull/3563)
 
+- Added support for datamodule in learning rate finder ([#3425](https://github.com/PyTorchLightning/pytorch-lightning/pull/3425))
+
 ### Changed
 
 - Changed `LearningRateLogger` to `LearningRateMonitor` ([#3251](https://github.com/PyTorchLightning/pytorch-lightning/pull/3251))
diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py
index 71756678af9c5..a3ba2550186a7 100644
--- a/pytorch_lightning/tuner/lr_finder.py
+++ b/pytorch_lightning/tuner/lr_finder.py
@@ -11,21 +11,23 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import importlib
 import os
+from typing import List, Optional, Sequence, Union
+
+import numpy as np
 import torch
-from typing import Optional, Sequence, List, Union
+from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
+
+from pytorch_lightning import _logger as log
+from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.loggers.base import DummyLogger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from torch.optim.lr_scheduler import _LRScheduler
-import importlib
-from pytorch_lightning import _logger as log
-import numpy as np
-from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr
-
 
 # check if ipywidgets is installed before importing tqdm.auto
 # to ensure it won't fail and a progress bar is displayed
 if importlib.util.find_spec('ipywidgets') is not None:
@@ -71,6 +73,7 @@ def lr_find(
         num_training: int = 100,
         mode: str = 'exponential',
         early_stop_threshold: float = 4.0,
+        datamodule: Optional[LightningDataModule] = None,
 ):
     r"""
     lr_find enables the user to do a range test of good initial learning rates,
@@ -81,7 +84,7 @@
 
 
         train_dataloader: A PyTorch DataLoader with training samples. If the model has
-            a predefined train_dataloader method this will be skipped.
+            a predefined train_dataloader method, this will be skipped.
 
         min_lr: minimum learning rate to investigate
 
@@ -98,6 +101,12 @@
             loss at any point is larger than early_stop_threshold*best_loss
             then the search is stopped. To disable, set to None.
 
+        datamodule: An optional `LightningDataModule` which holds the training
+            and validation dataloader(s). Note that the `train_dataloader` and
+            `val_dataloaders` parameters cannot be used at the same time as
+            this parameter, or a `MisconfigurationException` will be raised.
+
+
     Example::
 
         # Setup model and trainer
@@ -167,7 +176,8 @@ def lr_find(
     # Fit, lr & loss logged in callback
     trainer.fit(model,
                 train_dataloader=train_dataloader,
-                val_dataloaders=val_dataloaders)
+                val_dataloaders=val_dataloaders,
+                datamodule=datamodule)
 
     # Prompt if we stopped early
     if trainer.global_step != num_training:
diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py
index 1f1423a38db56..8c55ffac92c6a 100644
--- a/pytorch_lightning/tuner/tuning.py
+++ b/pytorch_lightning/tuner/tuning.py
@@ -15,6 +15,7 @@
 from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
 from pytorch_lightning.tuner.lr_finder import _run_lr_finder_internally, lr_find
 from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.core.datamodule import LightningDataModule
 from typing import Optional, List, Union
 from torch.utils.data import DataLoader
 
@@ -50,6 +51,7 @@ def lr_find(
             num_training: int = 100,
             mode: str = 'exponential',
             early_stop_threshold: float = 4.0,
+            datamodule: Optional[LightningDataModule] = None
     ):
         return lr_find(
             self.trainer,
@@ -60,7 +62,8 @@ def lr_find(
             max_lr,
             num_training,
             mode,
-            early_stop_threshold
+            early_stop_threshold,
+            datamodule,
         )
 
     def internal_find_lr(self, trainer, model: LightningModule):
diff --git a/tests/trainer/test_lr_finder.py b/tests/trainer/test_lr_finder.py
index cafa79e3f575b..67c673df1318d 100755
--- a/tests/trainer/test_lr_finder.py
+++ b/tests/trainer/test_lr_finder.py
@@ -5,6 +5,7 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
+from tests.base.datamodules import TrialMNISTDataModule
 
 
 def test_error_on_more_than_1_optimizer(tmpdir):
@@ -152,6 +153,30 @@ def test_call_to_trainer_method(tmpdir):
         'Learning rate was not altered after running learning rate finder'
 
 
+def test_datamodule_parameter(tmpdir):
+    """ Test that the datamodule parameter works """
+
+    # trial datamodule
+    dm = TrialMNISTDataModule(tmpdir)
+
+    hparams = EvalModelTemplate.get_default_hparams()
+    model = EvalModelTemplate(**hparams)
+
+    before_lr = hparams.get('learning_rate')
+    # logger file to get meta
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=2,
+    )
+
+    lrfinder = trainer.tuner.lr_find(model, datamodule=dm)
+    after_lr = lrfinder.suggestion()
+    model.learning_rate = after_lr
+
+    assert before_lr != after_lr, \
+        'Learning rate was not altered after running learning rate finder'
+
+
 def test_accumulation_and_early_stopping(tmpdir):
     """ Test that early stopping of learning rate finder works, and that
     accumulation also works for this feature """
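
Usage sketch (not part of the patch above): the snippet below mirrors the new
`test_datamodule_parameter` test and shows how the added `datamodule` argument is
meant to be used. It assumes the repository's test helpers `EvalModelTemplate` and
`TrialMNISTDataModule` are importable (run from the repository root), and uses
`tempfile.mkdtemp()` in place of pytest's `tmpdir` fixture::

    import tempfile

    from pytorch_lightning import Trainer
    from tests.base import EvalModelTemplate
    from tests.base.datamodules import TrialMNISTDataModule

    tmpdir = tempfile.mkdtemp()

    # The datamodule supplies the train/val dataloaders for the range test;
    # per the docstring added above, passing explicit `train_dataloader`/
    # `val_dataloaders` together with `datamodule` raises a MisconfigurationException.
    dm = TrialMNISTDataModule(tmpdir)
    model = EvalModelTemplate(**EvalModelTemplate.get_default_hparams())
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=2)

    lr_finder = trainer.tuner.lr_find(model, datamodule=dm)

    # Apply the suggested learning rate before the actual training run.
    model.learning_rate = lr_finder.suggestion()
    trainer.fit(model, datamodule=dm)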