Properly handle parent modules w/ parameters in BaseFinetuning callback (#7931)



Co-authored-by: Daniel Dale <dan@distributedinsight.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
5 people authored Jun 14, 2021
1 parent ce93d8b commit 3a0ed02
Showing 3 changed files with 41 additions and 15 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -238,6 +238,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `dataloader_idx` argument value when predicting with only one `DataLoader` ([#7941](https://github.com/PyTorchLightning/pytorch-lightning/pull/7941))


- Fixed `BaseFinetuning` callback to properly handle parent modules w/ parameters ([#7931](https://github.com/PyTorchLightning/pytorch-lightning/pull/7931))


## [1.3.5] - 2021-06-08

### Added
17 changes: 10 additions & 7 deletions pytorch_lightning/callbacks/finetuning.py
@@ -105,7 +105,8 @@ def on_load_checkpoint(
@staticmethod
def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> List[Module]:
"""
This function is used to flatten a module or an iterable of modules into a list of its modules.
This function is used to flatten a module or an iterable of modules into a list of its leaf modules (modules
with no children) and parent modules that have parameters directly themselves.
Args:
modules: A given module or an iterable of modules
@@ -121,8 +122,8 @@ def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -
else:
_modules = modules.modules()

# Leaf nodes in the graph have no children, so we use that to filter
return [m for m in _modules if not list(m.children())]
# Capture all leaf modules as well as parent modules that have parameters directly themselves
return [m for m in _modules if not list(m.children()) or m._parameters]

@staticmethod
def filter_params(
@@ -136,15 +137,15 @@ def filter_params(
modules: A given module or an iterable of modules
train_bn: Whether to train BatchNorm module
requires_grad: Whether to create a generator for trainable or non-trainable parameters.
Returns:
Generator
"""
modules = BaseFinetuning.flatten_modules(modules)
for mod in modules:
if isinstance(mod, _BatchNorm) and not train_bn:
continue
for param in mod.parameters():
# recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
for param in mod.parameters(recurse=False):
if param.requires_grad == requires_grad:
yield param

@@ -158,7 +159,8 @@ def make_trainable(modules: Union[Module, Iterable[Union[Module, Iterable]]]) ->
"""
modules = BaseFinetuning.flatten_modules(modules)
for module in modules:
for param in module.parameters():
# recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
for param in module.parameters(recurse=False):
param.requires_grad = True

@staticmethod
@@ -178,7 +180,8 @@ def freeze(modules: Union[Module, Iterable[Union[Module, Iterable]]], train_bn:
if isinstance(mod, _BatchNorm) and train_bn:
BaseFinetuning.make_trainable(mod)
else:
for param in mod.parameters():
# recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
for param in mod.parameters(recurse=False):
param.requires_grad = False

@staticmethod
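A minimal sketch of the patched behaviour (not part of this commit; the `Block` module and its attribute names are invented for illustration): `flatten_modules` now keeps a parent module that registers a parameter on itself, and the parameter iteration runs per flattened module with `recurse=False`, so that parameter is yielded exactly once.

```python
import torch
from torch import nn

from pytorch_lightning.callbacks.finetuning import BaseFinetuning


class Block(nn.Module):
    """Toy parent module: owns a parameter directly in addition to a child submodule."""

    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 4)                # leaf child with its own parameters
        self.scale = nn.Parameter(torch.ones(1))  # parameter registered on the parent itself


block = Block()

# Previously only leaf modules were kept, so `block` (and therefore `scale`) was dropped.
# With this fix the parent is kept as well because it holds parameters directly.
flattened = BaseFinetuning.flatten_modules(block)
assert len(flattened) == 2  # [Block, Linear]

# Each flattened module is iterated with recurse=False, so `scale` appears once and the
# Linear's weight/bias are not re-collected through the parent.
params = list(BaseFinetuning.filter_params(block, train_bn=True, requires_grad=True))
assert len(params) == 3  # lin.weight, lin.bias, scale
```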
36 changes: 28 additions & 8 deletions tests/callbacks/test_finetuning_callback.py
@@ -307,7 +307,11 @@ def configure_optimizers(self):
trainer.fit(model)


def test_deep_nested_model():
def test_complex_nested_model():
"""
Test flattening, freezing, and thawing of models which contain parent (non-leaf) modules with parameters
directly themselves rather than exclusively their submodules containing parameters.
"""

class ConvBlock(nn.Module):

@@ -322,23 +326,39 @@ def forward(self, x):
x = self.act(x)
return self.bn(x)

class ConvBlockParam(nn.Module):

def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, 3)
self.act = nn.ReLU()
# add trivial test parameter to convblock to validate parent (non-leaf) module parameter handling
self.parent_param = nn.Parameter(torch.zeros((1), dtype=torch.float))
self.bn = nn.BatchNorm2d(out_channels)

def forward(self, x):
x = self.conv(x)
x = self.act(x)
return self.bn(x)

model = nn.Sequential(
OrderedDict([
("encoder", nn.Sequential(ConvBlock(3, 64), ConvBlock(64, 128))),
("encoder", nn.Sequential(ConvBlockParam(3, 64), ConvBlock(64, 128))),
("decoder", ConvBlock(128, 10)),
])
)

# There's 9 leaf layers in that model
assert len(BaseFinetuning.flatten_modules(model)) == 9
# There are 10 leaf modules or parent modules w/ parameters in the test model
assert len(BaseFinetuning.flatten_modules(model)) == 10

BaseFinetuning.freeze(model.encoder, train_bn=True)
assert not model.encoder[0].conv.weight.requires_grad
assert not model.encoder[0].conv.weight.requires_grad # Validate a leaf module parameter is frozen
assert not model.encoder[0].parent_param.requires_grad # Validate the parent module parameter is frozen
assert model.encoder[0].bn.weight.requires_grad

BaseFinetuning.make_trainable(model)
encoder_params = list(BaseFinetuning.filter_params(model.encoder, train_bn=True))
# The 8 parameters of the encoder are:
# conv0.weight, conv0.bias, bn0.weight, bn0.bias
# The 9 parameters of the encoder are:
# conv0.weight, conv0.bias, bn0.weight, bn0.bias, parent_param
# conv1.weight, conv1.bias, bn1.weight, bn1.bias
assert len(encoder_params) == 8
assert len(encoder_params) == 9
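
Why `recurse=False`? The duplication it avoids can be shown with plain PyTorch; the `Parent` module below is invented for illustration and is not part of the test suite.

```python
import torch
from torch import nn


class Parent(nn.Module):
    """Invented example: a parameter on the parent plus a child submodule."""

    def __init__(self):
        super().__init__()
        self.child = nn.Linear(2, 2)
        self.offset = nn.Parameter(torch.zeros(1))  # lives directly on the parent


parent = Parent()
flattened = [parent, parent.child]  # what the fixed flatten_modules keeps for this model

# Recursing from the parent re-visits the child's weight and bias: 5 entries for 3 parameters.
duplicated = [p for m in flattened for p in m.parameters()]
# recurse=False yields each parameter exactly once: 3 entries.
unique = [p for m in flattened for p in m.parameters(recurse=False)]

assert len(duplicated) == 5
assert len(unique) == 3
```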
