diff --git a/mmaction/models/backbones/resnet.py b/mmaction/models/backbones/resnet.py
index 0d1ccb1635..a3778846cf 100644
--- a/mmaction/models/backbones/resnet.py
+++ b/mmaction/models/backbones/resnet.py
@@ -1,5 +1,3 @@
-import warnings
-
 import torch.nn as nn
 from mmcv.cnn import ConvModule
 from mmcv.runner import BaseModule, _load_checkpoint
@@ -236,7 +234,8 @@ def make_res_layer(block,
                    conv_cfg=None,
                    norm_cfg=None,
                    act_cfg=None,
-                   with_cp=False):
+                   with_cp=False,
+                   **kwargs):
     """Build residual layer for ResNet.
 
     Args:
@@ -282,7 +281,8 @@ def make_res_layer(block,
             conv_cfg=conv_cfg,
             norm_cfg=norm_cfg,
             act_cfg=act_cfg,
-            with_cp=with_cp))
+            with_cp=with_cp,
+            **kwargs))
     inplanes = planes * block.expansion
     for _ in range(1, blocks):
         layers.append(
@@ -295,7 +295,8 @@ def make_res_layer(block,
                 conv_cfg=conv_cfg,
                 norm_cfg=norm_cfg,
                 act_cfg=act_cfg,
-                with_cp=with_cp))
+                with_cp=with_cp,
+                **kwargs))
 
     return nn.Sequential(*layers)
 
@@ -361,6 +362,7 @@ def __init__(self,
                  init_cfg=None):
         super().__init__(init_cfg)
         self.zero_init_residual = zero_init_residual
+        self.torchvision_pretrain = torchvision_pretrain
         if depth not in self.arch_settings:
             raise KeyError(f'invalid depth {depth} for resnet')
 
@@ -369,9 +371,8 @@ def __init__(self,
                     and pretrained), ('init_cfg and pretrained cannot '
                                       'be setting at the same time')
         if isinstance(pretrained, str):
-            warnings.warn('DeprecationWarning: pretrained is a deprecated, '
-                          'please use "init_cfg" instead')
-            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+            if not self.torchvision_pretrain:
+                self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
         elif pretrained is None:
             if init_cfg is None:
                 self.init_cfg = [
@@ -399,7 +400,6 @@ def __init__(self,
         self.depth = depth
         self.in_channels = in_channels
         self.pretrained = pretrained
-        self.torchvision_pretrain = torchvision_pretrain
         self.num_stages = num_stages
         assert 1 <= num_stages <= 4
         self.out_indices = out_indices
@@ -553,6 +553,13 @@ def _load_torchvision_checkpoint(self, logger=None):
                 f'These parameters in pretrained checkpoint are not loaded'
                 f': {remaining_names}')
 
+    def init_weights(self):
+        if self.torchvision_pretrain:
+            logger = get_root_logger()
+            self._load_torchvision_checkpoint(logger)
+        else:
+            super().init_weights()
+
     def forward(self, x):
         """Defines the computation performed at every call.
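A rough usage sketch (not part of the patch) of how the reworked 2D ResNet initialization above is expected to behave; the checkpoint paths are placeholders, and `ResNet` refers to `mmaction.models.backbones.resnet.ResNet`:

    # Sketch only: illustrates the intended call flow after this patch,
    # with placeholder checkpoint identifiers.
    from mmaction.models.backbones import ResNet

    # torchvision-style checkpoint: the new init_weights() sees
    # torchvision_pretrain=True and routes through
    # _load_torchvision_checkpoint() instead of a `Pretrained` init_cfg.
    tv_backbone = ResNet(
        depth=50,
        pretrained='torchvision://resnet50',
        torchvision_pretrain=True)
    tv_backbone.init_weights()

    # ordinary checkpoint: __init__ converts `pretrained` into
    # init_cfg=dict(type='Pretrained', checkpoint=...) and defers to
    # BaseModule.init_weights().
    mm_backbone = ResNet(
        depth=50,
        pretrained='path/to/ckpt.pth',  # placeholder path
        torchvision_pretrain=False)
    mm_backbone.init_weights()
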
diff --git a/mmaction/models/backbones/resnet3d.py b/mmaction/models/backbones/resnet3d.py
index f97cd4d074..9073aa0848 100644
--- a/mmaction/models/backbones/resnet3d.py
+++ b/mmaction/models/backbones/resnet3d.py
@@ -392,7 +392,7 @@ class ResNet3d(BaseModule):
 
     def __init__(self,
                  depth,
-                 pretrained,
+                 pretrained=None,
                  pretrained2d=True,
                  in_channels=3,
                  num_stages=4,
@@ -430,14 +430,13 @@ def __init__(self,
                     and pretrained), ('init_cfg and pretrained cannot '
                                       'be setting at the same time')
         if isinstance(pretrained, str):
-            warnings.warn('DeprecationWarning: pretrained is a deprecated, '
-                          'please use "init_cfg" instead')
+            self.pretrained = pretrained
             if not self.pretrained2d:
                 self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
         elif pretrained is None:
             if init_cfg is None:
                 self.init_cfg = [
-                    dict(type='Kaiming', layer='Conv2d'),
+                    dict(type='Kaiming', layer='Conv3d'),
                     dict(
                         type='Constant',
                         val=1,
@@ -801,10 +800,14 @@ def _freeze_stages(self):
             for param in m.parameters():
                 param.requires_grad = False
 
-    def init_weights(self):
+    def init_weights(self, pretrained=None):
+        if pretrained:
+            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
         if not self.pretrained2d:
             super().init_weights()
         else:
+            assert pretrained is not None
+            self.pretrained = pretrained
             logger = get_root_logger()
             self.inflate_weights(logger)
 
@@ -1008,10 +1011,14 @@ def _freeze_stages(self):
             for param in layer.parameters():
                 param.requires_grad = False
 
-    def init_weights(self):
+    def init_weights(self, pretrained=None):
+        if pretrained:
+            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
         if not self.pretrained2d:
             super().init_weights()
         else:
+            assert pretrained is not None
+            self.pretrained = pretrained
             logger = get_root_logger()
             self.inflate_weights(logger)
 
diff --git a/mmaction/models/backbones/resnet3d_slowfast.py b/mmaction/models/backbones/resnet3d_slowfast.py
index 3db7794dd5..3794ab092e 100644
--- a/mmaction/models/backbones/resnet3d_slowfast.py
+++ b/mmaction/models/backbones/resnet3d_slowfast.py
@@ -99,7 +99,8 @@ def make_res_layer(self,
                        conv_cfg=None,
                        norm_cfg=None,
                        act_cfg=None,
-                       with_cp=False):
+                       with_cp=False,
+                       **kwargs):
         """Build residual layer for Slowfast.
 
         Args:
@@ -178,7 +179,8 @@ def make_res_layer(self,
                 conv_cfg=conv_cfg,
                 norm_cfg=norm_cfg,
                 act_cfg=act_cfg,
-                with_cp=with_cp))
+                with_cp=with_cp,
+                **kwargs))
         inplanes = planes * block.expansion
 
         for i in range(1, blocks):
@@ -197,7 +199,8 @@ def make_res_layer(self,
                     conv_cfg=conv_cfg,
                     norm_cfg=norm_cfg,
                     act_cfg=act_cfg,
-                    with_cp=with_cp))
+                    with_cp=with_cp,
+                    **kwargs))
 
         return nn.Sequential(*layers)
 
@@ -330,7 +333,7 @@ def init_weights(self, pretrained=None):
             self.pretrained = pretrained
 
         # Override the init_weights of i3d
-        super().init_weights()
+        super().init_weights(pretrained=self.pretrained)
         for module_name in self.lateral_connections:
             layer = getattr(self, module_name)
             for m in layer.modules():
diff --git a/mmaction/models/backbones/resnet_audio.py b/mmaction/models/backbones/resnet_audio.py
index ea5792d874..5813188712 100644
--- a/mmaction/models/backbones/resnet_audio.py
+++ b/mmaction/models/backbones/resnet_audio.py
@@ -1,15 +1,16 @@
+import warnings
+
 import torch.nn as nn
 import torch.utils.checkpoint as cp
-from mmcv.cnn import ConvModule, constant_init, kaiming_init
-from mmcv.runner import load_checkpoint
+from mmcv.cnn import ConvModule
+from mmcv.runner import BaseModule
 from torch.nn.modules.batchnorm import _BatchNorm
 from torch.nn.modules.utils import _ntuple
 
 from mmaction.models.registry import BACKBONES
-from mmaction.utils import get_root_logger
 
 
-class Bottleneck2dAudio(nn.Module):
+class Bottleneck2dAudio(BaseModule):
     """Bottleneck2D block for ResNet2D.
 
     Args:
@@ -35,8 +36,9 @@ def __init__(self,
                  downsample=None,
                  factorize=True,
                  norm_cfg=None,
-                 with_cp=False):
-        super().__init__()
+                 with_cp=False,
+                 init_cfg=None):
+        super().__init__(init_cfg)
 
         self.inplanes = inplanes
         self.planes = planes
@@ -109,7 +111,7 @@ def _inner_forward(x):
 
 
 @BACKBONES.register_module()
-class ResNetAudio(nn.Module):
+class ResNetAudio(BaseModule):
     """ResNet 2d audio backbone. Reference:
 
         `_.
@@ -171,10 +173,40 @@ def __init__(self,
                  conv_cfg=dict(type='Conv'),
                  norm_cfg=dict(type='BN2d', requires_grad=True),
                  act_cfg=dict(type='ReLU', inplace=True),
-                 zero_init_residual=True):
-        super().__init__()
+                 zero_init_residual=True,
+                 init_cfg=None):
+        super().__init__(init_cfg)
+        self.zero_init_residual = zero_init_residual
         if depth not in self.arch_settings:
             raise KeyError(f'invalid depth {depth} for resnet')
+
+        block_init_cfg = None
+        assert not (init_cfg
+                    and pretrained), ('init_cfg and pretrained cannot '
+                                      'be setting at the same time')
+        if isinstance(pretrained, str):
+            warnings.warn('DeprecationWarning: pretrained is a deprecated, '
+                          'please use "init_cfg" instead')
+            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+        elif pretrained is None:
+            if init_cfg is None:
+                self.init_cfg = [
+                    dict(type='Kaiming', layer='Conv2d'),
+                    dict(
+                        type='Constant',
+                        val=1,
+                        layer=['_BatchNorm', 'GroupNorm'])
+                ]
+                block = self.arch_settings[depth][0]
+                if self.zero_init_residual:
+                    if block is Bottleneck2dAudio:
+                        block_init_cfg = dict(
+                            type='Constant',
+                            val=0,
+                            override=dict(name='norm3'))
+        else:
+            raise TypeError('pretrained must be a str or None')
+
         self.depth = depth
         self.pretrained = pretrained
         self.in_channels = in_channels
@@ -191,7 +223,6 @@ def __init__(self,
         self.conv_cfg = conv_cfg
         self.norm_cfg = norm_cfg
         self.act_cfg = act_cfg
-        self.zero_init_residual = zero_init_residual
 
         self.block, stage_blocks = self.arch_settings[depth]
         self.stage_blocks = stage_blocks[:num_stages]
@@ -213,7 +244,8 @@ def __init__(self,
                 dilation=dilation,
                 factorize=self.stage_factorization[i],
                 norm_cfg=self.norm_cfg,
-                with_cp=with_cp)
+                with_cp=with_cp,
+                init_cfg=block_init_cfg)
             self.inplanes = planes * self.block.expansion
             layer_name = f'layer{i + 1}'
             self.add_module(layer_name, res_layer)
@@ -231,7 +263,8 @@ def make_res_layer(self,
                        dilation=1,
                        factorize=1,
                        norm_cfg=None,
-                       with_cp=False):
+                       with_cp=False,
+                       **kwargs):
         """Build residual layer for ResNetAudio.
 
         Args:
@@ -280,7 +313,8 @@ def make_res_layer(self,
                 downsample,
                 factorize=(factorize[0] == 1),
                 norm_cfg=norm_cfg,
-                with_cp=with_cp))
+                with_cp=with_cp,
+                **kwargs))
         inplanes = planes * block.expansion
         for i in range(1, blocks):
             layers.append(
@@ -291,7 +325,8 @@ def make_res_layer(self,
                     dilation,
                     factorize=(factorize[i] == 1),
                     norm_cfg=norm_cfg,
-                    with_cp=with_cp))
+                    with_cp=with_cp,
+                    **kwargs))
 
         return nn.Sequential(*layers)
 
@@ -323,27 +358,6 @@ def _freeze_stages(self):
             for param in m.parameters():
                 param.requires_grad = False
 
-    def init_weights(self):
-        """Initiate the parameters either from existing checkpoint or from
-        scratch."""
-        if isinstance(self.pretrained, str):
-            logger = get_root_logger()
-            logger.info(f'load model from: {self.pretrained}')
-            load_checkpoint(self, self.pretrained, strict=False, logger=logger)
-        elif self.pretrained is None:
-            for m in self.modules():
-                if isinstance(m, nn.Conv2d):
-                    kaiming_init(m)
-                elif isinstance(m, _BatchNorm):
-                    constant_init(m, 1)
-
-            if self.zero_init_residual:
-                for m in self.modules():
-                    if isinstance(m, Bottleneck2dAudio):
-                        constant_init(m.conv3.bn, 0)
-
-        else:
-            raise TypeError('pretrained must be a str or None')
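A rough sketch (not from this diff) of how ResNetAudio initialization is expected to work after the move to BaseModule; the checkpoint path is a placeholder:

    # Sketch only: the hand-written init_weights() above is removed, so
    # initialization now flows through BaseModule via init_cfg.
    from mmaction.models.backbones import ResNetAudio

    # With a checkpoint string, __init__ turns `pretrained` into
    # init_cfg=dict(type='Pretrained', checkpoint=...).
    audio_pretrained = ResNetAudio(depth=50, pretrained='path/to/audio.pth')  # placeholder
    audio_pretrained.init_weights()

    # Without a checkpoint, the Kaiming/Constant defaults apply, and
    # zero_init_residual is expressed as a per-block init_cfg that zeroes
    # 'norm3' of each Bottleneck2dAudio.
    audio_scratch = ResNetAudio(depth=50, pretrained=None)
    audio_scratch.init_weights()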
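Likewise, a hedged sketch of the changed 3D entry point: ResNet3d.init_weights() again accepts an optional `pretrained` argument (required in practice when pretrained2d=True), which the SlowFast pathway (ResNet3dPathway) now forwards via super().init_weights(pretrained=self.pretrained); the path below is a placeholder:

    # Sketch only: illustrates the init_weights(pretrained=...) round trip
    # introduced for ResNet3d / ResNet3dLayer and the SlowFast pathways.
    from mmaction.models.backbones import ResNet3d

    backbone3d = ResNet3d(
        depth=50,
        pretrained='path/to/resnet50_2d.pth',  # placeholder 2D checkpoint
        pretrained2d=True)
    # With pretrained2d=True the patch asserts a checkpoint is passed here,
    # stores it on self.pretrained, and inflates the 2D weights.
    backbone3d.init_weights(pretrained='path/to/resnet50_2d.pth')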