Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
dreamerlin committed Apr 28, 2021
1 parent 295807a commit 10c34b8
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 54 deletions.
25 changes: 16 additions & 9 deletions mmaction/models/backbones/resnet.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import warnings

import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, _load_checkpoint
Expand Down Expand Up @@ -236,7 +234,8 @@ def make_res_layer(block,
conv_cfg=None,
norm_cfg=None,
act_cfg=None,
with_cp=False):
with_cp=False,
**kwargs):
"""Build residual layer for ResNet.
Args:
Expand Down Expand Up @@ -282,7 +281,8 @@ def make_res_layer(block,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
with_cp=with_cp))
with_cp=with_cp,
**kwargs))
inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(
Expand All @@ -295,7 +295,8 @@ def make_res_layer(block,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
with_cp=with_cp))
with_cp=with_cp,
**kwargs))

return nn.Sequential(*layers)

Expand Down Expand Up @@ -361,6 +362,7 @@ def __init__(self,
init_cfg=None):
super().__init__(init_cfg)
self.zero_init_residual = zero_init_residual
self.torchvision_pretrain = torchvision_pretrain
if depth not in self.arch_settings:
raise KeyError(f'invalid depth {depth} for resnet')

Expand All @@ -369,9 +371,8 @@ def __init__(self,
and pretrained), ('init_cfg and pretrained cannot '
'be setting at the same time')
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
if not self.torchvision_pretrain:
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
if init_cfg is None:
self.init_cfg = [
Expand Down Expand Up @@ -399,7 +400,6 @@ def __init__(self,
self.depth = depth
self.in_channels = in_channels
self.pretrained = pretrained
self.torchvision_pretrain = torchvision_pretrain
self.num_stages = num_stages
assert 1 <= num_stages <= 4
self.out_indices = out_indices
Expand Down Expand Up @@ -553,6 +553,13 @@ def _load_torchvision_checkpoint(self, logger=None):
f'These parameters in pretrained checkpoint are not loaded'
f': {remaining_names}')

def init_weights(self):
    """Initialize the weights of the backbone.

    If ``self.torchvision_pretrain`` is truthy, the weights are loaded via
    ``self._load_torchvision_checkpoint`` using the root logger. Otherwise
    initialization is delegated to the parent class ``init_weights``.
    """
    if not self.torchvision_pretrain:
        # No torchvision checkpoint requested: use the parent-class path.
        super().init_weights()
        return

    self._load_torchvision_checkpoint(get_root_logger())

def forward(self, x):
"""Defines the computation performed at every call.
Expand Down
19 changes: 13 additions & 6 deletions mmaction/models/backbones/resnet3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ class ResNet3d(BaseModule):

def __init__(self,
depth,
pretrained,
pretrained=None,
pretrained2d=True,
in_channels=3,
num_stages=4,
Expand Down Expand Up @@ -430,14 +430,13 @@ def __init__(self,
and pretrained), ('init_cfg and pretrained cannot '
'be setting at the same time')
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
'please use "init_cfg" instead')
self.pretrained = pretrained
if not self.pretrained2d:
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer='Conv2d'),
dict(type='Kaiming', layer='Conv3d'),
dict(
type='Constant',
val=1,
Expand Down Expand Up @@ -801,10 +800,14 @@ def _freeze_stages(self):
for param in m.parameters():
param.requires_grad = False

def init_weights(self):
def init_weights(self, pretrained=None):
    """Initialize the weights, optionally from a checkpoint.

    Args:
        pretrained (str | None): Checkpoint to initialize from. When
            truthy, ``self.init_cfg`` is replaced by a ``Pretrained``
            config pointing at it. Default: None.
    """
    if pretrained:
        self.init_cfg = {'type': 'Pretrained', 'checkpoint': pretrained}

    if self.pretrained2d:
        # The inflation path needs an explicit checkpoint.
        assert pretrained is not None
        self.pretrained = pretrained
        self.inflate_weights(get_root_logger())
    else:
        super().init_weights()

Expand Down Expand Up @@ -1008,10 +1011,14 @@ def _freeze_stages(self):
for param in layer.parameters():
param.requires_grad = False

def init_weights(self):
def init_weights(self, pretrained=None):
    """Initialize the weights, optionally from a checkpoint.

    Args:
        pretrained (str | None): Checkpoint to initialize from. A truthy
            value overwrites ``self.init_cfg`` with a ``Pretrained``
            config that references it. Default: None.
    """
    if pretrained:
        self.init_cfg = dict(checkpoint=pretrained, type='Pretrained')

    if not self.pretrained2d:
        # No inflation requested: defer to the parent-class initializer.
        super().init_weights()
        return

    assert pretrained is not None
    self.pretrained = pretrained
    root_logger = get_root_logger()
    self.inflate_weights(root_logger)

Expand Down
11 changes: 7 additions & 4 deletions mmaction/models/backbones/resnet3d_slowfast.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ def make_res_layer(self,
conv_cfg=None,
norm_cfg=None,
act_cfg=None,
with_cp=False):
with_cp=False,
**kwargs):
"""Build residual layer for Slowfast.
Args:
Expand Down Expand Up @@ -178,7 +179,8 @@ def make_res_layer(self,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
with_cp=with_cp))
with_cp=with_cp,
**kwargs))
inplanes = planes * block.expansion

for i in range(1, blocks):
Expand All @@ -197,7 +199,8 @@ def make_res_layer(self,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
with_cp=with_cp))
with_cp=with_cp,
**kwargs))

return nn.Sequential(*layers)

Expand Down Expand Up @@ -330,7 +333,7 @@ def init_weights(self, pretrained=None):
self.pretrained = pretrained

# Override the init_weights of i3d
super().init_weights()
super().init_weights(pretrained=self.pretrained)
for module_name in self.lateral_connections:
layer = getattr(self, module_name)
for m in layer.modules():
Expand Down
84 changes: 49 additions & 35 deletions mmaction/models/backbones/resnet_audio.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import warnings

import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import ConvModule, constant_init, kaiming_init
from mmcv.runner import load_checkpoint
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.utils import _ntuple

from mmaction.models.registry import BACKBONES
from mmaction.utils import get_root_logger


class Bottleneck2dAudio(nn.Module):
class Bottleneck2dAudio(BaseModule):
"""Bottleneck2D block for ResNet2D.
Args:
Expand All @@ -35,8 +36,9 @@ def __init__(self,
downsample=None,
factorize=True,
norm_cfg=None,
with_cp=False):
super().__init__()
with_cp=False,
init_cfg=None):
super().__init__(init_cfg)

self.inplanes = inplanes
self.planes = planes
Expand Down Expand Up @@ -109,7 +111,7 @@ def _inner_forward(x):


@BACKBONES.register_module()
class ResNetAudio(nn.Module):
class ResNetAudio(BaseModule):
"""ResNet 2d audio backbone. Reference:
`<https://arxiv.org/abs/2001.08740>`_.
Expand Down Expand Up @@ -171,10 +173,40 @@ def __init__(self,
conv_cfg=dict(type='Conv'),
norm_cfg=dict(type='BN2d', requires_grad=True),
act_cfg=dict(type='ReLU', inplace=True),
zero_init_residual=True):
super().__init__()
zero_init_residual=True,
init_cfg=None):
super().__init__(init_cfg)
self.zero_init_residual = zero_init_residual
if depth not in self.arch_settings:
raise KeyError(f'invalid depth {depth} for resnet')

block_init_cfg = None
assert not (init_cfg
and pretrained), ('init_cfg and pretrained cannot '
'be setting at the same time')
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer='Conv2d'),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
]
block = self.arch_settings[depth][0]
if self.zero_init_residual:
if block is Bottleneck2dAudio:
block_init_cfg = dict(
type='Constant',
val=0,
override=dict(name='norm3'))
else:
raise TypeError('pretrained must be a str or None')

self.depth = depth
self.pretrained = pretrained
self.in_channels = in_channels
Expand All @@ -191,7 +223,6 @@ def __init__(self,
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.zero_init_residual = zero_init_residual

self.block, stage_blocks = self.arch_settings[depth]
self.stage_blocks = stage_blocks[:num_stages]
Expand All @@ -213,7 +244,8 @@ def __init__(self,
dilation=dilation,
factorize=self.stage_factorization[i],
norm_cfg=self.norm_cfg,
with_cp=with_cp)
with_cp=with_cp,
init_cfg=block_init_cfg)
self.inplanes = planes * self.block.expansion
layer_name = f'layer{i + 1}'
self.add_module(layer_name, res_layer)
Expand All @@ -231,7 +263,8 @@ def make_res_layer(self,
dilation=1,
factorize=1,
norm_cfg=None,
with_cp=False):
with_cp=False,
**kwargs):
"""Build residual layer for ResNetAudio.
Args:
Expand Down Expand Up @@ -280,7 +313,8 @@ def make_res_layer(self,
downsample,
factorize=(factorize[0] == 1),
norm_cfg=norm_cfg,
with_cp=with_cp))
with_cp=with_cp,
**kwargs))
inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(
Expand All @@ -291,7 +325,8 @@ def make_res_layer(self,
dilation,
factorize=(factorize[i] == 1),
norm_cfg=norm_cfg,
with_cp=with_cp))
with_cp=with_cp,
**kwargs))

return nn.Sequential(*layers)

Expand Down Expand Up @@ -323,27 +358,6 @@ def _freeze_stages(self):
for param in m.parameters():
param.requires_grad = False

def init_weights(self):
    """Initiate the parameters either from existing checkpoint or from
    scratch.

    Raises:
        TypeError: If ``self.pretrained`` is neither a str nor None.
    """
    if isinstance(self.pretrained, str):
        # A checkpoint was given: load it non-strictly into this module.
        logger = get_root_logger()
        logger.info(f'load model from: {self.pretrained}')
        load_checkpoint(self, self.pretrained, strict=False, logger=logger)
        return

    if self.pretrained is not None:
        raise TypeError('pretrained must be a str or None')

    # No checkpoint: kaiming_init for Conv2d layers and constant_init with
    # value 1 for batch-norm layers.
    for module in self.modules():
        if isinstance(module, nn.Conv2d):
            kaiming_init(module)
        elif isinstance(module, _BatchNorm):
            constant_init(module, 1)

    if not self.zero_init_residual:
        return

    # Separate pass, performed after the generic initialization above.
    for module in self.modules():
        if isinstance(module, Bottleneck2dAudio):
            constant_init(module.conv3.bn, 0)

Expand Down

0 comments on commit 10c34b8

Please sign in to comment.