[Refactoring] Revise function of layers and override keys in init_cfg #893

Merged: 5 commits, Mar 26, 2021
Changes from 2 commits
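
For context, a minimal usage sketch of the revised semantics, assuming the initialize() helper from mmcv/cnn/utils/weight_init.py as changed in this PR: an init_cfg only touches layers named in its layer key, and a config with neither layer nor override now warns and does nothing.

# Minimal sketch, assuming the initialize() API from this diff.
import torch.nn as nn
from mmcv.cnn.utils.weight_init import initialize

model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))

# 'layer' is now required for whole-model initialization; a config with
# neither 'layer' nor 'override' only emits a warning and does nothing.
init_cfg = dict(type='Constant', layer=['Conv2d', 'Linear'], val=1, bias=2)
initialize(model, init_cfg)  # sets all Conv2d/Linear weights to 1, biases to 2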
53 changes: 30 additions & 23 deletions mmcv/cnn/utils/weight_init.py
@@ -1,4 +1,6 @@
# Copyright (c) Open-MMLab. All rights reserved.
import warnings

import numpy as np
import torch.nn as nn

@@ -78,6 +80,7 @@ def bias_init_with_prob(prior_prob):
class BaseInit(object):

def __init__(self, *, bias=0, bias_prob=None, layer=None):
self.overmodule = False
if not isinstance(bias, (int, float)):
raise TypeError(f'bias must be a number, but got a {type(bias)}')

@@ -90,7 +93,11 @@ def __init__(self, *, bias=0, bias_prob=None, layer=None):
if not isinstance(layer, (str, list)):
raise TypeError(f'layer must be a str or a list of str, \
but got a {type(layer)}')

else:
layer = []
warnings.warn(
'init_cfg without layer key, if you do not define override'
' key either, this init_cfg will do nothing')
if bias_prob is not None:
self.bias = bias_init_with_prob(bias_prob)
else:
@@ -119,13 +126,12 @@ def __init__(self, val, **kwargs):
def __call__(self, module):

def init(m):
if self.layer is None:
if self.overmodule:
constant_init(m, self.val, self.bias)
else:
layername = m.__class__.__name__
for layer_ in self.layer:
if layername == layer_:
constant_init(m, self.val, self.bias)
if layername in self.layer:
constant_init(m, self.val, self.bias)

module.apply(init)

@@ -157,13 +163,12 @@ def __init__(self, gain=1, distribution='normal', **kwargs):
def __call__(self, module):

def init(m):
if self.layer is None:
if self.overmodule:
xavier_init(m, self.gain, self.bias, self.distribution)
else:
layername = m.__class__.__name__
for layer_ in self.layer:
if layername == layer_:
xavier_init(m, self.gain, self.bias, self.distribution)
if layername in self.layer:
xavier_init(m, self.gain, self.bias, self.distribution)

module.apply(init)

@@ -194,7 +199,7 @@ def __init__(self, mean=0, std=1, **kwargs):
def __call__(self, module):

def init(m):
if self.layer is None:
if self.overmodule:
normal_init(m, self.mean, self.std, self.bias)
else:
layername = m.__class__.__name__
@@ -231,13 +236,12 @@ def __init__(self, a=0, b=1, **kwargs):
def __call__(self, module):

def init(m):
if self.layer is None:
if self.overmodule:
uniform_init(m, self.a, self.b, self.bias)
else:
layername = m.__class__.__name__
for layer_ in self.layer:
if layername == layer_:
uniform_init(m, self.a, self.b, self.bias)
if layername in self.layer:
uniform_init(m, self.a, self.b, self.bias)

module.apply(init)

@@ -285,15 +289,14 @@ def __init__(self,
def __call__(self, module):

def init(m):
if self.layer is None:
if self.overmodule:
kaiming_init(m, self.a, self.mode, self.nonlinearity,
self.bias, self.distribution)
else:
layername = m.__class__.__name__
for layer_ in self.layer:
if layername == layer_:
kaiming_init(m, self.a, self.mode, self.nonlinearity,
self.bias, self.distribution)
if layername in self.layer:
kaiming_init(m, self.a, self.mode, self.nonlinearity,
self.bias, self.distribution)

module.apply(init)

@@ -339,22 +342,25 @@ def __call__(self, module):
load_state_dict(module, state_dict, strict=False, logger=logger)


def _initialize(module, cfg):
def _initialize(module, cfg, overmodule=False):
func = build_from_cfg(cfg, INITIALIZERS)
func.overmodule = overmodule
func(module)


def _initialize_override(module, override):
def _initialize_override(module, override, cfg):
if not isinstance(override, (dict, list)):
raise TypeError(f'override must be a dict or a list of dict, \
but got {type(override)}')

override = [override] if isinstance(override, dict) else override

for override_ in override:
if 'type' not in override_.keys():
override_.update(cfg)
name = override_.pop('name', None)
if hasattr(module, name):
_initialize(getattr(module, name), override_)
_initialize(getattr(module, name), override_, overmodule=True)
else:
raise RuntimeError(f'module did not have attribute {name}')

@@ -424,7 +430,8 @@ def initialize(module, init_cfg):
_initialize(module, cfg)

if override is not None:
_initialize_override(module, override)
cfg.pop('layer', None)
_initialize_override(module, override, cfg)
else:
# All attributes in module have same initialization.
pass
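
To make the new override path above easier to follow: initialize() drops the layer key from the parent cfg, _initialize_override() merges that cfg into any override entry that has no type of its own, and the resulting initializer runs with overmodule=True, so it applies to the whole named submodule instead of filtering by layer name. A small sketch of that merge, with illustrative values:

# Sketch of the override merge performed above (values are illustrative).
cfg = dict(type='Constant', layer='Conv2d', val=1, bias=2)
override = dict(name='reg')  # no 'type' key, so it inherits from cfg

cfg.pop('layer', None)       # layer filtering never applies to overrides
override.update(cfg)         # -> dict(name='reg', type='Constant', val=1, bias=2)
name = override.pop('name')  # the submodule attribute to initialize
# _initialize(getattr(module, name), override, overmodule=True) then applies
# the Constant initializer to every parameter of module.reg.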
27 changes: 10 additions & 17 deletions tests/test_cnn/test_weight_init.py
@@ -102,13 +102,6 @@ def test_constaninit():
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, res))

func = ConstantInit(val=4, bias=5)
func(model)
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 4.))
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 4.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 5.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 5.))

# test bias input type
with pytest.raises(TypeError):
func = ConstantInit(val=1, bias='1')
Expand All @@ -128,8 +121,8 @@ def test_xavierinit():
assert model[0].bias.allclose(torch.full_like(model[2].bias, 0.1))
assert not model[2].bias.allclose(torch.full_like(model[0].bias, 0.1))

constant_func = ConstantInit(val=0, bias=0)
func = XavierInit(gain=100, bias_prob=0.01)
constant_func = ConstantInit(val=0, bias=0, layer=['Conv2d', 'Linear'])
func = XavierInit(gain=100, bias_prob=0.01, layer=['Conv2d', 'Linear'])
model.apply(constant_func)
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 0.))
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 0.))
@@ -157,7 +150,7 @@ def test_normalinit():
"""test Normalinit class."""
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))

func = NormalInit(mean=100, std=1e-5, bias=200)
func = NormalInit(mean=100, std=1e-5, bias=200, layer=['Conv2d', 'Linear'])
func(model)
assert model[0].weight.allclose(torch.tensor(100.))
assert model[2].weight.allclose(torch.tensor(100.))
Expand All @@ -177,7 +170,7 @@ def test_normalinit():
def test_uniforminit():
""""test UniformInit class."""
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
func = UniformInit(a=1, b=1, bias=2)
func = UniformInit(a=1, b=1, bias=2, layer=['Conv2d', 'Linear'])
func(model)
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 1.))
Expand All @@ -202,8 +195,8 @@ def test_kaiminginit():
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.1))
assert not torch.equal(model[2].bias, torch.full(model[2].bias.shape, 0.1))

func = KaimingInit(a=100, bias=10)
constant_func = ConstantInit(val=0, bias=0)
func = KaimingInit(a=100, bias=10, layer=['Conv2d', 'Linear'])
constant_func = ConstantInit(val=0, bias=0, layer=['Conv2d', 'Linear'])
model.apply(constant_func)
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 0.))
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 0.))
@@ -232,7 +225,7 @@ def test_pretrainedinit():
"""test PretrainedInit class."""

modelA = FooModule()
constant_func = ConstantInit(val=1, bias=2)
constant_func = ConstantInit(val=1, bias=2, layer=['Conv2d', 'Linear'])
modelA.apply(constant_func)
modelB = FooModule()
funcB = PretrainedInit(checkpoint='modelA.pth')
@@ -263,15 +256,15 @@ def test_initialize():
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
foonet = FooModule()

init_cfg = dict(type='Constant', val=1, bias=2)
init_cfg = dict(type='Constant', layer=['Conv2d', 'Linear'], val=1, bias=2)
initialize(model, init_cfg)
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 1.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 2.))

init_cfg = [
dict(type='Constant', layer='Conv1d', val=1, bias=2),
dict(type='Constant', layer='Conv2d', val=1, bias=2),
dict(type='Constant', layer='Linear', val=3, bias=4)
]
initialize(model, init_cfg)
Expand Down Expand Up @@ -305,7 +298,7 @@ def test_initialize():
checkpoint='modelA.pth',
override=dict(type='Constant', name='conv2d_2', val=3, bias=4))
modelA = FooModule()
constant_func = ConstantInit(val=1, bias=2)
constant_func = ConstantInit(val=1, bias=2, layer=['Conv2d', 'Linear'])
modelA.apply(constant_func)
with TemporaryDirectory():
torch.save(modelA.state_dict(), 'modelA.pth')
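
The test updates above all add an explicit layer argument because of the constructor change in BaseInit: building an initializer without layer now warns that it will do nothing unless an override is given. A quick sketch of that behaviour, assuming ConstantInit is importable from mmcv/cnn/utils/weight_init.py, where this diff defines it:

# Sketch: constructing an initializer without 'layer' now warns.
import warnings
from mmcv.cnn.utils.weight_init import ConstantInit

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    func = ConstantInit(val=1, bias=2)  # no 'layer' argument
assert any('without layer key' in str(w.message) for w in caught)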
91 changes: 89 additions & 2 deletions tests/test_runner/test_basemodule.py
@@ -187,9 +187,11 @@ def test_nest_components_weight_init():
dict(type='Constant', val=5, bias=6, layer='Conv2d'),
],
component1=dict(
type='FooConv1d', init_cfg=dict(type='Constant', val=7, bias=8)),
type='FooConv1d',
init_cfg=dict(type='Constant', layer='Conv1d', val=7, bias=8)),
component2=dict(
type='FooConv2d', init_cfg=dict(type='Constant', val=9, bias=10)),
type='FooConv2d',
init_cfg=dict(type='Constant', layer='Conv2d', val=9, bias=10)),
component3=dict(type='FooLinear'),
component4=dict(
type='FooLinearConv1d',
@@ -226,3 +228,88 @@ def test_nest_components_weight_init():
assert torch.equal(model.reg.weight,
torch.full(model.reg.weight.shape, 13.0))
assert torch.equal(model.reg.bias, torch.full(model.reg.bias.shape, 14.0))


def test_without_layer_weight_init():
model_cfg = dict(
type='FooModel',
init_cfg=[
dict(type='Constant', val=1, bias=2, layer='Linear'),
dict(type='Constant', val=3, bias=4, layer='Conv1d'),
dict(type='Constant', val=5, bias=6, layer='Conv2d')
],
component1=dict(
type='FooConv1d', init_cfg=dict(type='Constant', val=7, bias=8)),
component2=dict(type='FooConv2d'),
component3=dict(type='FooLinear'))
model = build_from_cfg(model_cfg, FOOMODELS)
model.init_weight()

assert torch.equal(model.component1.conv1d.weight,
torch.full(model.component1.conv1d.weight.shape, 3.0))
assert torch.equal(model.component1.conv1d.bias,
torch.full(model.component1.conv1d.bias.shape, 4.0))

# init_cfg in component1 does not have layer key, so it does nothing
assert torch.equal(model.component2.conv2d.weight,
torch.full(model.component2.conv2d.weight.shape, 5.0))
assert torch.equal(model.component2.conv2d.bias,
torch.full(model.component2.conv2d.bias.shape, 6.0))
assert torch.equal(model.component3.linear.weight,
torch.full(model.component3.linear.weight.shape, 1.0))
assert torch.equal(model.component3.linear.bias,
torch.full(model.component3.linear.bias.shape, 2.0))

assert torch.equal(model.reg.weight, torch.full(model.reg.weight.shape,
1.0))
assert torch.equal(model.reg.bias, torch.full(model.reg.bias.shape, 2.0))


def test_override_weight_init():

# only initialize 'override'
model_cfg = dict(
type='FooModel',
init_cfg=[
dict(type='Constant', val=10, bias=20, override=dict(name='reg'))
],
component1=dict(type='FooConv1d'),
component3=dict(type='FooLinear'))
model = build_from_cfg(model_cfg, FOOMODELS)
model.init_weight()
assert torch.equal(model.reg.weight,
torch.full(model.reg.weight.shape, 10.0))
assert torch.equal(model.reg.bias, torch.full(model.reg.bias.shape, 20.0))
# do not initialize others
assert not torch.equal(
model.component1.conv1d.weight,
torch.full(model.component1.conv1d.weight.shape, 10.0))
assert not torch.equal(
model.component1.conv1d.bias,
torch.full(model.component1.conv1d.bias.shape, 20.0))
assert not torch.equal(
model.component3.linear.weight,
torch.full(model.component3.linear.weight.shape, 10.0))
assert not torch.equal(
model.component3.linear.bias,
torch.full(model.component3.linear.bias.shape, 20.0))

# 'override' has higher priority
model_cfg = dict(
type='FooModel',
init_cfg=[
dict(
type='Constant',
val=1,
bias=2,
override=dict(name='reg', type='Constant', val=30, bias=40))
],
component1=dict(type='FooConv1d'),
component2=dict(type='FooConv2d'),
component3=dict(type='FooLinear'))
model = build_from_cfg(model_cfg, FOOMODELS)
model.init_weight()

assert torch.equal(model.reg.weight,
torch.full(model.reg.weight.shape, 30.0))
assert torch.equal(model.reg.bias, torch.full(model.reg.bias.shape, 40.0))
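
The last test above encodes the priority rule: an override entry that carries its own type/val/bias wins over the parent cfg for the named submodule, while the parent cfg still governs the layers it names. The same pattern outside the Foo* fixtures, with a hypothetical ToyModel used purely for illustration:

# Hedged sketch of override priority; ToyModel is hypothetical.
import torch.nn as nn
from mmcv.cnn.utils.weight_init import initialize

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Conv2d(3, 8, 3)   # matched by layer='Conv2d'
        self.reg = nn.Linear(8, 4)           # targeted by the override

model = ToyModel()
init_cfg = dict(
    type='Constant', layer='Conv2d', val=1, bias=2,
    override=dict(name='reg', type='Constant', val=30, bias=40))
initialize(model, init_cfg)
# backbone.weight/bias -> 1/2 (layer rule); reg.weight/bias -> 30/40 (override wins)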