diff --git a/.azure-pipelines/gpu-tests.yml b/.azure-pipelines/gpu-tests.yml
index 15832a640448c..7e153ad64b2a7 100644
--- a/.azure-pipelines/gpu-tests.yml
+++ b/.azure-pipelines/gpu-tests.yml
@@ -61,7 +61,7 @@ jobs:
       pip install --requirement requirements.txt
       python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if 'horovod' not in line] ; open(fname, 'w').writelines(lines)"
       pip install --requirement ./requirements/devel.txt --upgrade-strategy only-if-needed
-      pip install fairscale>=0.3.4 --upgrade-strategy only-if-needed
+      pip install fairscale>=0.3.4 deepspeed==0.3.14 --upgrade-strategy only-if-needed
       pip list
     displayName: 'Install dependencies'
diff --git a/.gitignore b/.gitignore
index 99939ff7fce0c..fad0342e1168a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -17,6 +17,7 @@ docs/source/api
 docs/source/*.md
 docs/source/generated
 docs/source/*/generated
+docs/source/notebooks
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
diff --git a/CHANGELOG.md b/CHANGELOG.md
index e6367977d237b..a6e93ed260b07 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,16 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 
+## [1.3.6] - 2021-06-15
+
+### Fixed
+
+- Fixed logs overwriting issue for remote filesystems ([#7889](https://github.com/PyTorchLightning/pytorch-lightning/pull/7889))
+- Fixed a bug where `DataModule.prepare_data` could only be called on the global rank 0 process ([#7945](https://github.com/PyTorchLightning/pytorch-lightning/pull/7945))
+- Fixed setting `worker_init_fn` to seed dataloaders correctly when using DDP ([#7942](https://github.com/PyTorchLightning/pytorch-lightning/pull/7942))
+- Fixed `BaseFinetuning` callback to properly handle parent modules w/ parameters ([#7931](https://github.com/PyTorchLightning/pytorch-lightning/pull/7931))
+
+
 ## [1.3.5] - 2021-06-08
 
 ### Added
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 1651a8b08a5a9..002be87b3f1ea 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -132,6 +132,10 @@ def _transform_changelog(path_in: str, path_out: str) -> None:
 nbsphinx_allow_errors = True
 nbsphinx_requirejs_path = ''
 
+# myst-parser, forcing to parse all html pages with mathjax
+# https://github.com/executablebooks/MyST-Parser/issues/394
+myst_update_mathjax = False
+
 # The suffix(es) of source filenames.
 # You can specify multiple suffix as a list of string:
 #
diff --git a/pl_examples/domain_templates/reinforce_learn_Qnet.py b/pl_examples/domain_templates/reinforce_learn_Qnet.py
index 70726a748818c..114097df483af 100644
--- a/pl_examples/domain_templates/reinforce_learn_Qnet.py
+++ b/pl_examples/domain_templates/reinforce_learn_Qnet.py
@@ -34,7 +34,7 @@
 
 import argparse
 from collections import deque, namedtuple, OrderedDict
-from typing import List, Tuple
+from typing import Iterator, List, Tuple
 
 import gym
 import numpy as np
@@ -139,7 +139,7 @@ def __init__(self, buffer: ReplayBuffer, sample_size: int = 200) -> None:
         self.buffer = buffer
         self.sample_size = sample_size
 
-    def __iter__(self) -> Tuple:
+    def __iter__(self) -> Iterator:
         states, actions, rewards, dones, new_states = self.buffer.sample(self.sample_size)
         for i in range(len(dones)):
             yield states[i], actions[i], rewards[i], dones[i], new_states[i]
diff --git a/pl_examples/domain_templates/reinforce_learn_ppo.py b/pl_examples/domain_templates/reinforce_learn_ppo.py
index f3453a5eb86f0..1686aa1954e7a 100644
--- a/pl_examples/domain_templates/reinforce_learn_ppo.py
+++ b/pl_examples/domain_templates/reinforce_learn_ppo.py
@@ -28,7 +28,7 @@
 [3] https://github.com/sid-sundrani/ppo_lightning
 """
 import argparse
-from typing import Callable, Iterable, List, Tuple
+from typing import Callable, Iterator, List, Tuple
 
 import gym
 import torch
@@ -144,7 +144,7 @@ class ExperienceSourceDataset(IterableDataset):
 
     def __init__(self, generate_batch: Callable):
         self.generate_batch = generate_batch
 
-    def __iter__(self) -> Iterable:
+    def __iter__(self) -> Iterator:
         iterator = self.generate_batch()
         return iterator
diff --git a/pytorch_lightning/__about__.py b/pytorch_lightning/__about__.py
index 9471cf85b2b4b..b5c5f46e8911f 100644
--- a/pytorch_lightning/__about__.py
+++ b/pytorch_lightning/__about__.py
@@ -1,7 +1,7 @@
 import time
 
 _this_year = time.strftime("%Y")
-__version__ = '1.3.5'
+__version__ = '1.3.6'
 __author__ = 'William Falcon et al.'
 __author_email__ = 'waf2107@columbia.edu'
 __license__ = 'Apache-2.0'
diff --git a/pytorch_lightning/callbacks/finetuning.py b/pytorch_lightning/callbacks/finetuning.py
index d3f52b4ba9a15..e88fa6a2ce76a 100644
--- a/pytorch_lightning/callbacks/finetuning.py
+++ b/pytorch_lightning/callbacks/finetuning.py
@@ -105,7 +105,8 @@ def on_load_checkpoint(
     @staticmethod
     def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> List[Module]:
         """
-        This function is used to flatten a module or an iterable of modules into a list of its modules.
+        This function is used to flatten a module or an iterable of modules into a list of its leaf modules (modules
+        with no children) and parent modules that have parameters directly themselves.
 
         Args:
             modules: A given module or an iterable of modules
@@ -121,8 +122,8 @@ def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -
         else:
             _modules = modules.modules()
 
-        # Leaf nodes in the graph have no children, so we use that to filter
-        return [m for m in _modules if not list(m.children())]
+        # Capture all leaf modules as well as parent modules that have parameters directly themselves
+        return [m for m in _modules if not list(m.children()) or m._parameters]
 
     @staticmethod
     def filter_params(
@@ -136,7 +137,6 @@ def filter_params(
             modules: A given module or an iterable of modules
             train_bn: Whether to train BatchNorm module
            requires_grad: Whether to create a generator for trainable or non-trainable parameters.
-
         Returns:
             Generator
         """
@@ -144,7 +144,8 @@
         for mod in modules:
             if isinstance(mod, _BatchNorm) and not train_bn:
                 continue
-            for param in mod.parameters():
+            # recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
+            for param in mod.parameters(recurse=False):
                 if param.requires_grad == requires_grad:
                     yield param
 
@@ -158,7 +159,8 @@ def make_trainable(modules: Union[Module, Iterable[Union[Module, Iterable]]]) ->
         """
         modules = BaseFinetuning.flatten_modules(modules)
         for module in modules:
-            for param in module.parameters():
+            # recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
+            for param in module.parameters(recurse=False):
                 param.requires_grad = True
 
     @staticmethod
@@ -178,7 +180,8 @@ def freeze(modules: Union[Module, Iterable[Union[Module, Iterable]]], train_bn:
             if isinstance(mod, _BatchNorm) and train_bn:
                 BaseFinetuning.make_trainable(mod)
             else:
-                for param in mod.parameters():
+                # recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
+                for param in mod.parameters(recurse=False):
                     param.requires_grad = False
 
     @staticmethod
diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py
index 23626ed9cbeae..63a79ae902c4b 100644
--- a/pytorch_lightning/core/datamodule.py
+++ b/pytorch_lightning/core/datamodule.py
@@ -20,7 +20,6 @@
 from torch.utils.data import DataLoader, Dataset, IterableDataset
 
 from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
-from pytorch_lightning.utilities import rank_zero_only
 from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types
 
 
@@ -329,7 +328,7 @@ def test_dataloader():
     def __new__(cls, *args: Any, **kwargs: Any) -> 'LightningDataModule':
         obj = super().__new__(cls)
         # track `DataHooks` calls and run `prepare_data` only on rank zero
-        obj.prepare_data = cls._track_data_hook_calls(obj, rank_zero_only(obj.prepare_data))
+        obj.prepare_data = cls._track_data_hook_calls(obj, obj.prepare_data)
         obj.setup = cls._track_data_hook_calls(obj, obj.setup)
         obj.teardown = cls._track_data_hook_calls(obj, obj.teardown)
         return obj
diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py
index af3802476571b..b8a8c16f311e8 100644
--- a/pytorch_lightning/loggers/tensorboard.py
+++ b/pytorch_lightning/loggers/tensorboard.py
@@ -254,14 +254,16 @@ def version(self) -> int:
         return self._version
 
     def _get_next_version(self):
-        root_dir = os.path.join(self.save_dir, self.name)
+        root_dir = self.root_dir
 
-        if not self._fs.isdir(root_dir):
+        try:
+            listdir_info = self._fs.listdir(root_dir)
+        except OSError:
             log.warning('Missing logger folder: %s', root_dir)
             return 0
 
         existing_versions = []
-        for listing in self._fs.listdir(root_dir):
+        for listing in listdir_info:
             d = listing["name"]
             bn = os.path.basename(d)
             if self._fs.isdir(d) and bn.startswith("version_"):
diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py
index 51547d5576e74..7c20b7d1b3b1e 100644
--- a/pytorch_lightning/utilities/seed.py
+++ b/pytorch_lightning/utilities/seed.py
@@ -84,8 +84,9 @@ def reset_seed() -> None:
     If :func:`pytorch_lightning.utilities.seed.seed_everything` is unused, this function will do nothing.
""" seed = os.environ.get("PL_GLOBAL_SEED", None) + workers = os.environ.get("PL_SEED_WORKERS", False) if seed is not None: - seed_everything(int(seed)) + seed_everything(int(seed), workers=bool(workers)) def pl_worker_init_function(worker_id: int, rank: Optional = None) -> None: # pragma: no cover @@ -100,6 +101,9 @@ def pl_worker_init_function(worker_id: int, rank: Optional = None) -> None: # p process_seed = torch.initial_seed() # back out the base seed so we can use all the bits base_seed = process_seed - worker_id + log.debug( + f'Initializing random number generators of process {global_rank} worker {worker_id} with base seed {base_seed}' + ) ss = np.random.SeedSequence([base_seed, worker_id, global_rank]) # use 128 bits (4 x 32-bit words) np.random.seed(ss.generate_state(4)) diff --git a/requirements/adjust_versions.py b/requirements/adjust_versions.py index 3d9da2a2f1a22..a09128c6200db 100644 --- a/requirements/adjust_versions.py +++ b/requirements/adjust_versions.py @@ -40,6 +40,8 @@ def main(path_req: str, torch_version: Optional[str] = None) -> None: with open(path_req, "r") as fp: req = fp.read() + # remove comments + req = re.sub(rf"\s*#.*{os.linesep}", os.linesep, req) latest = find_latest(torch_version) for lib, version in latest.items(): diff --git a/tests/callbacks/test_finetuning_callback.py b/tests/callbacks/test_finetuning_callback.py index 53d34c4645bef..0d6a0e3f0a3d1 100644 --- a/tests/callbacks/test_finetuning_callback.py +++ b/tests/callbacks/test_finetuning_callback.py @@ -307,7 +307,11 @@ def configure_optimizers(self): trainer.fit(model) -def test_deep_nested_model(): +def test_complex_nested_model(): + """ + Test flattening, freezing, and thawing of models which contain parent (non-leaf) modules with parameters + directly themselves rather than exclusively their submodules containing parameters. 
+ """ class ConvBlock(nn.Module): @@ -322,23 +326,39 @@ def forward(self, x): x = self.act(x) return self.bn(x) + class ConvBlockParam(nn.Module): + + def __init__(self, in_channels, out_channels): + super().__init__() + self.conv = nn.Conv2d(in_channels, out_channels, 3) + self.act = nn.ReLU() + # add trivial test parameter to convblock to validate parent (non-leaf) module parameter handling + self.parent_param = nn.Parameter(torch.zeros((1), dtype=torch.float)) + self.bn = nn.BatchNorm2d(out_channels) + + def forward(self, x): + x = self.conv(x) + x = self.act(x) + return self.bn(x) + model = nn.Sequential( OrderedDict([ - ("encoder", nn.Sequential(ConvBlock(3, 64), ConvBlock(64, 128))), + ("encoder", nn.Sequential(ConvBlockParam(3, 64), ConvBlock(64, 128))), ("decoder", ConvBlock(128, 10)), ]) ) - # There's 9 leaf layers in that model - assert len(BaseFinetuning.flatten_modules(model)) == 9 + # There are 10 leaf modules or parent modules w/ parameters in the test model + assert len(BaseFinetuning.flatten_modules(model)) == 10 BaseFinetuning.freeze(model.encoder, train_bn=True) - assert not model.encoder[0].conv.weight.requires_grad + assert not model.encoder[0].conv.weight.requires_grad # Validate a leaf module parameter is frozen + assert not model.encoder[0].parent_param.requires_grad # Validate the parent module parameter is frozen assert model.encoder[0].bn.weight.requires_grad BaseFinetuning.make_trainable(model) encoder_params = list(BaseFinetuning.filter_params(model.encoder, train_bn=True)) - # The 8 parameters of the encoder are: - # conv0.weight, conv0.bias, bn0.weight, bn0.bias + # The 9 parameters of the encoder are: + # conv0.weight, conv0.bias, bn0.weight, bn0.bias, parent_param # conv1.weight, conv1.bias, bn1.weight, bn1.bias - assert len(encoder_params) == 8 + assert len(encoder_params) == 9 diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 7cfa569115550..b443118fe25a2 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -34,6 +34,7 @@ @mock.patch("pytorch_lightning.trainer.trainer.Trainer.local_rank", new_callable=PropertyMock) def test_can_prepare_data(local_rank, node_rank): + model = BoringModel() dm = BoringDataModule() trainer = Trainer() trainer.datamodule = dm @@ -43,30 +44,54 @@ def test_can_prepare_data(local_rank, node_rank): # local rank = 0 (True) trainer.prepare_data_per_node = True + dm.random_full = None + dm._has_prepared_data = False local_rank.return_value = 0 assert trainer.local_rank == 0 assert trainer.data_connector.can_prepare_data() + trainer.data_connector.prepare_data(model) + assert dm.random_full is not None + # local rank = 1 (False) + dm.random_full = None + dm._has_prepared_data = False local_rank.return_value = 1 assert trainer.local_rank == 1 assert not trainer.data_connector.can_prepare_data() + trainer.data_connector.prepare_data(model) + assert dm.random_full is None + # prepare_data_per_node = False (prepare across all nodes) # global rank = 0 (True) + dm.random_full = None + dm._has_prepared_data = False trainer.prepare_data_per_node = False node_rank.return_value = 0 local_rank.return_value = 0 assert trainer.data_connector.can_prepare_data() + trainer.data_connector.prepare_data(model) + assert dm.random_full is not None + # global rank = 1 (False) + dm.random_full = None + dm._has_prepared_data = False node_rank.return_value = 1 local_rank.return_value = 0 assert not trainer.data_connector.can_prepare_data() + + trainer.data_connector.prepare_data(model) + assert 
dm.random_full is None + node_rank.return_value = 0 local_rank.return_value = 1 assert not trainer.data_connector.can_prepare_data() + trainer.data_connector.prepare_data(model) + assert dm.random_full is None + # 2 dm # prepar per node = True # local rank = 0 (True) diff --git a/tests/loggers/test_tensorboard.py b/tests/loggers/test_tensorboard.py index 350120b77c1e2..2b44893c71c20 100644 --- a/tests/loggers/test_tensorboard.py +++ b/tests/loggers/test_tensorboard.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging import os from argparse import Namespace from unittest import mock @@ -320,3 +321,15 @@ def test_tensorboard_with_symlink(log, tmpdir): _ = logger.version log.warning.assert_not_called() + + +def test_tensorboard_missing_folder_warning(tmpdir, caplog): + """Verify that the logger throws a warning for invalid directory""" + + name = "fake_dir" + logger = TensorBoardLogger(save_dir=tmpdir, name=name) + + with caplog.at_level(logging.WARNING): + assert logger.version == 0 + + assert 'Missing logger folder:' in caplog.text
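
A note on the `BaseFinetuning` change above (not part of the patch): `flatten_modules` now keeps parent modules that hold parameters directly, in addition to leaf modules, and the freeze/filter helpers switch to `Module.parameters(recurse=False)` so those directly-held parameters are visited exactly once. A minimal sketch of how the two pieces interact, using a hypothetical `ParentWithParam` module as a stand-in for the test's `ConvBlockParam`:

```python
import torch
from torch import nn


class ParentWithParam(nn.Module):
    """Hypothetical parent module that owns a parameter directly, plus a leaf child."""

    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))  # parameter held by the parent itself
        self.linear = nn.Linear(2, 2)  # leaf child module with its own weight and bias


model = ParentWithParam()

# Flatten the way the patched BaseFinetuning.flatten_modules does: keep leaf modules
# and any parent module that holds parameters directly.
flattened = [m for m in model.modules() if not list(m.children()) or m._parameters]

# Iterate with recurse=False so the parent contributes only `scale` and does not
# re-yield the child's weight and bias, which the leaf entry already covers.
params = [p for m in flattened for p in m.parameters(recurse=False)]
assert len(params) == len(list(model.parameters()))  # 3 parameters, each seen exactly once
```

With the previous leaf-only flattening, a parameter like `scale` above was never reached by `freeze` or `filter_params`; that is the behaviour addressed by #7931.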