Weekly Patch Release v1.3.6 [full merge, no squash] #7986

Merged · 10 commits · Jun 17, 2021
1 change: 1 addition & 0 deletions .gitignore
@@ -17,6 +17,7 @@ docs/source/api
docs/source/*.md
docs/source/generated
docs/source/*/generated
docs/source/notebooks

# Byte-compiled / optimized / DLL files
__pycache__/
10 changes: 10 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,16 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


## [1.3.6] - 2021-06-15

### Fixed

- Fixed logs overwriting issue for remote filesystems ([#7889](https://github.com/PyTorchLightning/pytorch-lightning/pull/7889))
- Fixed a bug where `DataModule.prepare_data` could only be called on the global rank 0 process ([#7945](https://github.com/PyTorchLightning/pytorch-lightning/pull/7945))
- Fixed setting `worker_init_fn` to seed dataloaders correctly when using DDP ([#7942](https://github.com/PyTorchLightning/pytorch-lightning/pull/7942))
- Fixed `BaseFinetuning` callback to properly handle parent modules w/ parameters ([#7931](https://github.com/PyTorchLightning/pytorch-lightning/pull/7931))


## [1.3.5] - 2021-06-08

### Added
4 changes: 4 additions & 0 deletions docs/source/conf.py
@@ -132,6 +132,10 @@ def _transform_changelog(path_in: str, path_out: str) -> None:
nbsphinx_allow_errors = True
nbsphinx_requirejs_path = ''

# myst-parser, forcing to parse all html pages with mathjax
# https://github.com/executablebooks/MyST-Parser/issues/394
myst_update_mathjax = False

# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
17 changes: 10 additions & 7 deletions pytorch_lightning/callbacks/finetuning.py
@@ -105,7 +105,8 @@ def on_load_checkpoint(
@staticmethod
def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> List[Module]:
"""
This function is used to flatten a module or an iterable of modules into a list of its modules.
This function is used to flatten a module or an iterable of modules into a list of its leaf modules (modules
with no children) and parent modules that have parameters directly themselves.

Args:
modules: A given module or an iterable of modules
@@ -121,8 +122,8 @@ def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -
else:
_modules = modules.modules()

# Leaf nodes in the graph have no children, so we use that to filter
return [m for m in _modules if not list(m.children())]
# Capture all leaf modules as well as parent modules that have parameters directly themselves
return [m for m in _modules if not list(m.children()) or m._parameters]

@staticmethod
def filter_params(
@@ -136,15 +137,15 @@ def filter_params(
modules: A given module or an iterable of modules
train_bn: Whether to train BatchNorm module
requires_grad: Whether to create a generator for trainable or non-trainable parameters.

Returns:
Generator
"""
modules = BaseFinetuning.flatten_modules(modules)
for mod in modules:
if isinstance(mod, _BatchNorm) and not train_bn:
continue
for param in mod.parameters():
# recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
for param in mod.parameters(recurse=False):
if param.requires_grad == requires_grad:
yield param

@@ -158,7 +159,8 @@ def make_trainable(modules: Union[Module, Iterable[Union[Module, Iterable]]]) ->
"""
modules = BaseFinetuning.flatten_modules(modules)
for module in modules:
for param in module.parameters():
# recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
for param in module.parameters(recurse=False):
param.requires_grad = True

@staticmethod
@@ -178,7 +180,8 @@ def freeze(modules: Union[Module, Iterable[Union[Module, Iterable]]], train_bn:
if isinstance(mod, _BatchNorm) and train_bn:
BaseFinetuning.make_trainable(mod)
else:
for param in mod.parameters():
# recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
for param in mod.parameters(recurse=False):
param.requires_grad = False

@staticmethod
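The change above keeps parent modules that hold parameters directly (in addition to leaf modules) and iterates parameters with recurse=False, so a parent's parameters are not yielded a second time through its children. A minimal standalone sketch of the same idea in plain PyTorch (the Block module here is illustrative, not part of the patch):

import torch
from torch import nn

class Block(nn.Module):
    """Toy parent module that owns a parameter directly, alongside a child module."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.scale = nn.Parameter(torch.ones(1))  # parameter held by the parent itself

model = nn.Sequential(Block(), nn.Linear(8, 2))

# flatten as in the patched flatten_modules: keep leaf modules plus any
# parent module that holds parameters directly
flat = [m for m in model.modules() if not list(m.children()) or m._parameters]

# iterating non-recursively per flattened module avoids counting a parent's
# parameters again through its children
params = [p for m in flat for p in m.parameters(recurse=False)]
assert len(params) == len(list(model.parameters()))  # 5 parameters, no duplicates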
3 changes: 1 addition & 2 deletions pytorch_lightning/core/datamodule.py
@@ -20,7 +20,6 @@
from torch.utils.data import DataLoader, Dataset, IterableDataset

from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types


@@ -329,7 +328,7 @@ def test_dataloader():
def __new__(cls, *args: Any, **kwargs: Any) -> 'LightningDataModule':
obj = super().__new__(cls)
# track `DataHooks` calls and run `prepare_data` only on rank zero
obj.prepare_data = cls._track_data_hook_calls(obj, rank_zero_only(obj.prepare_data))
obj.prepare_data = cls._track_data_hook_calls(obj, obj.prepare_data)
obj.setup = cls._track_data_hook_calls(obj, obj.setup)
obj.teardown = cls._track_data_hook_calls(obj, obj.teardown)
return obj
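Removing the rank_zero_only wrapper means prepare_data is no longer hard-wired to the global rank 0 process; whether it runs is now decided by the data connector from the prepare_data_per_node setting, as exercised in the updated tests/core/test_datamodules.py below. A hedged usage sketch, assuming the Trainer flag available in the 1.3 series:

import pytorch_lightning as pl

# run prepare_data on local_rank 0 of every node (each node prepares its own copy)
trainer = pl.Trainer(prepare_data_per_node=True)

# or run it exactly once, on the global rank 0 process
# trainer = pl.Trainer(prepare_data_per_node=False)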
8 changes: 5 additions & 3 deletions pytorch_lightning/loggers/tensorboard.py
@@ -254,14 +254,16 @@ def version(self) -> int:
return self._version

def _get_next_version(self):
root_dir = os.path.join(self.save_dir, self.name)
root_dir = self.root_dir

if not self._fs.isdir(root_dir):
try:
listdir_info = self._fs.listdir(root_dir)
except OSError:
log.warning('Missing logger folder: %s', root_dir)
return 0

existing_versions = []
for listing in self._fs.listdir(root_dir):
for listing in listdir_info:
d = listing["name"]
bn = os.path.basename(d)
if self._fs.isdir(d) and bn.startswith("version_"):
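The logger now resolves the version folder from self.root_dir and lists it only once, treating a failed listing as "no versions yet"; per the changelog this addresses a logs-overwriting issue on remote filesystems. A standalone sketch of the same pattern written directly against fsspec (the next_version helper and the local "file" backend are illustrative assumptions):

import os
import fsspec

def next_version(root_dir: str) -> int:
    fs = fsspec.filesystem("file")  # any fsspec backend (e.g. "s3") follows the same pattern
    try:
        listing = fs.listdir(root_dir)  # a single listing call
    except OSError:  # the folder does not exist yet
        return 0
    versions = [
        int(os.path.basename(item["name"]).split("_")[1])
        for item in listing
        if fs.isdir(item["name"]) and os.path.basename(item["name"]).startswith("version_")
    ]
    return max(versions) + 1 if versions else 0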
6 changes: 5 additions & 1 deletion pytorch_lightning/utilities/seed.py
@@ -84,8 +84,9 @@ def reset_seed() -> None:
If :func:`pytorch_lightning.utilities.seed.seed_everything` is unused, this function will do nothing.
"""
seed = os.environ.get("PL_GLOBAL_SEED", None)
workers = os.environ.get("PL_SEED_WORKERS", False)
if seed is not None:
seed_everything(int(seed))
seed_everything(int(seed), workers=bool(workers))


def pl_worker_init_function(worker_id: int, rank: Optional = None) -> None: # pragma: no cover
@@ -100,6 +101,9 @@ def pl_worker_init_function(worker_id: int, rank: Optional = None) -> None: # p
process_seed = torch.initial_seed()
# back out the base seed so we can use all the bits
base_seed = process_seed - worker_id
log.debug(
f'Initializing random number generators of process {global_rank} worker {worker_id} with base seed {base_seed}'
)
ss = np.random.SeedSequence([base_seed, worker_id, global_rank])
# use 128 bits (4 x 32-bit words)
np.random.seed(ss.generate_state(4))
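pl_worker_init_function derives a distinct, well-mixed NumPy seed for every dataloader worker from the process seed, the worker id, and the global rank. A rough standalone equivalent for a plain DataLoader (single node, rank 0 assumed; a sketch rather than Lightning's exact implementation):

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

def worker_init_fn(worker_id: int, rank: int = 0) -> None:
    # torch already gives each worker the seed `base_seed + worker_id`; recover the base seed
    process_seed = torch.initial_seed()
    base_seed = process_seed - worker_id
    # mix base seed, worker id and rank so NumPy gets a distinct, well-spread seed per worker
    ss = np.random.SeedSequence([base_seed, worker_id, rank])
    np.random.seed(ss.generate_state(4))  # 128 bits (4 x 32-bit words)

dataset = TensorDataset(torch.arange(8, dtype=torch.float32))
loader = DataLoader(dataset, num_workers=2, worker_init_fn=worker_init_fn)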
2 changes: 2 additions & 0 deletions requirements/adjust_versions.py
@@ -40,6 +40,8 @@ def main(path_req: str, torch_version: Optional[str] = None) -> None:

with open(path_req, "r") as fp:
req = fp.read()
# remove comments
req = re.sub(rf"\s*#.*{os.linesep}", os.linesep, req)

latest = find_latest(torch_version)
for lib, version in latest.items():
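The added substitution strips trailing comments from the requirements file before the version pins are adjusted, so a comment after a package specifier cannot leak into the rewritten requirement. A quick illustration of the regex (the sample requirement lines are made up):

import os
import re

req = f"torch>=1.4  # pinned separately{os.linesep}tqdm>=4.41.0{os.linesep}"
req = re.sub(rf"\s*#.*{os.linesep}", os.linesep, req)
print(req)  # -> "torch>=1.4" and "tqdm>=4.41.0" with the comment removed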
36 changes: 28 additions & 8 deletions tests/callbacks/test_finetuning_callback.py
@@ -307,7 +307,11 @@ def configure_optimizers(self):
trainer.fit(model)


def test_deep_nested_model():
def test_complex_nested_model():
"""
Test flattening, freezing, and thawing of models that contain parent (non-leaf) modules holding parameters
directly themselves, rather than only in their submodules.
"""

class ConvBlock(nn.Module):

@@ -322,23 +326,39 @@ def forward(self, x):
x = self.act(x)
return self.bn(x)

class ConvBlockParam(nn.Module):

def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, 3)
self.act = nn.ReLU()
# add trivial test parameter to convblock to validate parent (non-leaf) module parameter handling
self.parent_param = nn.Parameter(torch.zeros((1), dtype=torch.float))
self.bn = nn.BatchNorm2d(out_channels)

def forward(self, x):
x = self.conv(x)
x = self.act(x)
return self.bn(x)

model = nn.Sequential(
OrderedDict([
("encoder", nn.Sequential(ConvBlock(3, 64), ConvBlock(64, 128))),
("encoder", nn.Sequential(ConvBlockParam(3, 64), ConvBlock(64, 128))),
("decoder", ConvBlock(128, 10)),
])
)

# There's 9 leaf layers in that model
assert len(BaseFinetuning.flatten_modules(model)) == 9
# There are 10 leaf modules or parent modules w/ parameters in the test model
assert len(BaseFinetuning.flatten_modules(model)) == 10

BaseFinetuning.freeze(model.encoder, train_bn=True)
assert not model.encoder[0].conv.weight.requires_grad
assert not model.encoder[0].conv.weight.requires_grad # Validate a leaf module parameter is frozen
assert not model.encoder[0].parent_param.requires_grad # Validate the parent module parameter is frozen
assert model.encoder[0].bn.weight.requires_grad

BaseFinetuning.make_trainable(model)
encoder_params = list(BaseFinetuning.filter_params(model.encoder, train_bn=True))
# The 8 parameters of the encoder are:
# conv0.weight, conv0.bias, bn0.weight, bn0.bias
# The 9 parameters of the encoder are:
# conv0.weight, conv0.bias, bn0.weight, bn0.bias, parent_param
# conv1.weight, conv1.bias, bn1.weight, bn1.bias
assert len(encoder_params) == 8
assert len(encoder_params) == 9
25 changes: 25 additions & 0 deletions tests/core/test_datamodules.py
@@ -34,6 +34,7 @@
@mock.patch("pytorch_lightning.trainer.trainer.Trainer.local_rank", new_callable=PropertyMock)
def test_can_prepare_data(local_rank, node_rank):

model = BoringModel()
dm = BoringDataModule()
trainer = Trainer()
trainer.datamodule = dm
@@ -43,30 +44,54 @@ def test_can_prepare_data(local_rank, node_rank):
# local rank = 0 (True)
trainer.prepare_data_per_node = True

dm.random_full = None
dm._has_prepared_data = False
local_rank.return_value = 0
assert trainer.local_rank == 0
assert trainer.data_connector.can_prepare_data()

trainer.data_connector.prepare_data(model)
assert dm.random_full is not None

# local rank = 1 (False)
dm.random_full = None
dm._has_prepared_data = False
local_rank.return_value = 1
assert trainer.local_rank == 1
assert not trainer.data_connector.can_prepare_data()

trainer.data_connector.prepare_data(model)
assert dm.random_full is None

# prepare_data_per_node = False (prepare across all nodes)
# global rank = 0 (True)
dm.random_full = None
dm._has_prepared_data = False
trainer.prepare_data_per_node = False
node_rank.return_value = 0
local_rank.return_value = 0
assert trainer.data_connector.can_prepare_data()

trainer.data_connector.prepare_data(model)
assert dm.random_full is not None

# global rank = 1 (False)
dm.random_full = None
dm._has_prepared_data = False
node_rank.return_value = 1
local_rank.return_value = 0
assert not trainer.data_connector.can_prepare_data()

trainer.data_connector.prepare_data(model)
assert dm.random_full is None

node_rank.return_value = 0
local_rank.return_value = 1
assert not trainer.data_connector.can_prepare_data()

trainer.data_connector.prepare_data(model)
assert dm.random_full is None

# 2 dm
# prepare_data_per_node = True
# local rank = 0 (True)
13 changes: 13 additions & 0 deletions tests/loggers/test_tensorboard.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from argparse import Namespace
from unittest import mock
@@ -320,3 +321,15 @@ def test_tensorboard_with_symlink(log, tmpdir):
_ = logger.version

log.warning.assert_not_called()


def test_tensorboard_missing_folder_warning(tmpdir, caplog):
"""Verify that the logger throws a warning for invalid directory"""

name = "fake_dir"
logger = TensorBoardLogger(save_dir=tmpdir, name=name)

with caplog.at_level(logging.WARNING):
assert logger.version == 0

assert 'Missing logger folder:' in caplog.text
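With version discovery based on root_dir and a single guarded listdir call, the logger behaves the same on local paths and on fsspec-backed remote paths. A hedged usage sketch (the bucket name is hypothetical, and the matching fsspec backend, e.g. s3fs, is assumed to be installed and configured):

from pytorch_lightning.loggers import TensorBoardLogger

logger = TensorBoardLogger(save_dir="s3://my-bucket/logs", name="my_experiment")
# if the folder does not exist yet, a "Missing logger folder" warning is logged
# and versioning starts at 0
print(logger.version)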