diff --git a/CHANGELOG.md b/CHANGELOG.md index 223cabc4f4eb4..4e6505890004f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for datamodules to save and load checkpoints when training ([#3563]https://github.com/PyTorchLightning/pytorch-lightning/pull/3563) +- Added support for datamodule in learning rate finder ([#3425](https://github.com/PyTorchLightning/pytorch-lightning/pull/3425)) + ### Changed - Changed `LearningRateLogger` to `LearningRateMonitor` ([#3251](https://github.com/PyTorchLightning/pytorch-lightning/pull/3251)) diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index 71756678af9c5..a3ba2550186a7 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -11,21 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import importlib import os +from typing import List, Optional, Sequence, Union + +import numpy as np import torch -from typing import Optional, Sequence, List, Union +from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data import DataLoader + +from pytorch_lightning import _logger as log +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.loggers.base import DummyLogger from pytorch_lightning.utilities.exceptions import MisconfigurationException -from torch.optim.lr_scheduler import _LRScheduler -import importlib -from pytorch_lightning import _logger as log -import numpy as np -from pytorch_lightning.callbacks import Callback from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr - # check if ipywidgets is installed before importing tqdm.auto # to ensure it won't fail and a progress bar is displayed if importlib.util.find_spec('ipywidgets') is not None: @@ -71,6 +73,7 @@ def lr_find( num_training: int = 100, mode: str = 'exponential', early_stop_threshold: float = 4.0, + datamodule: Optional[LightningDataModule] = None, ): r""" lr_find enables the user to do a range test of good initial learning rates, @@ -81,7 +84,7 @@ def lr_find( train_dataloader: A PyTorch DataLoader with training samples. If the model has - a predefined train_dataloader method this will be skipped. + a predefined train_dataloader method, this will be skipped. min_lr: minimum learning rate to investigate @@ -98,6 +101,12 @@ def lr_find( loss at any point is larger than early_stop_threshold*best_loss then the search is stopped. To disable, set to None. + datamodule: An optional `LightningDataModule` which holds the training + and validation dataloader(s). Note that the `train_dataloader` and + `val_dataloaders` parameters cannot be used at the same time as + this parameter, or a `MisconfigurationException` will be raised. + + Example:: # Setup model and trainer @@ -167,7 +176,8 @@ def lr_find( # Fit, lr & loss logged in callback trainer.fit(model, train_dataloader=train_dataloader, - val_dataloaders=val_dataloaders) + val_dataloaders=val_dataloaders, + datamodule=datamodule) # Prompt if we stopped early if trainer.global_step != num_training: diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py index 1f1423a38db56..8c55ffac92c6a 100644 --- a/pytorch_lightning/tuner/tuning.py +++ b/pytorch_lightning/tuner/tuning.py @@ -15,6 +15,7 @@ from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus from pytorch_lightning.tuner.lr_finder import _run_lr_finder_internally, lr_find from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.core.datamodule import LightningDataModule from typing import Optional, List, Union from torch.utils.data import DataLoader @@ -50,6 +51,7 @@ def lr_find( num_training: int = 100, mode: str = 'exponential', early_stop_threshold: float = 4.0, + datamodule: Optional[LightningDataModule] = None ): return lr_find( self.trainer, @@ -60,7 +62,8 @@ def lr_find( max_lr, num_training, mode, - early_stop_threshold + early_stop_threshold, + datamodule, ) def internal_find_lr(self, trainer, model: LightningModule): diff --git a/tests/trainer/test_lr_finder.py b/tests/trainer/test_lr_finder.py index cafa79e3f575b..67c673df1318d 100755 --- a/tests/trainer/test_lr_finder.py +++ b/tests/trainer/test_lr_finder.py @@ -5,6 +5,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import EvalModelTemplate +from tests.base.datamodules import TrialMNISTDataModule def test_error_on_more_than_1_optimizer(tmpdir): @@ -152,6 +153,30 @@ def test_call_to_trainer_method(tmpdir): 'Learning rate was not altered after running learning rate finder' +def test_datamodule_parameter(tmpdir): + """ Test that the datamodule parameter works """ + + # trial datamodule + dm = TrialMNISTDataModule(tmpdir) + + hparams = EvalModelTemplate.get_default_hparams() + model = EvalModelTemplate(**hparams) + + before_lr = hparams.get('learning_rate') + # logger file to get meta + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=2, + ) + + lrfinder = trainer.tuner.lr_find(model, datamodule=dm) + after_lr = lrfinder.suggestion() + model.learning_rate = after_lr + + assert before_lr != after_lr, \ + 'Learning rate was not altered after running learning rate finder' + + def test_accumulation_and_early_stopping(tmpdir): """ Test that early stopping of learning rate finder works, and that accumulation also works for this feature """