From 57474f97bae13531625c62bfb2e52e7668bcd413 Mon Sep 17 00:00:00 2001 From: dreamerlin <528557675@qq.com> Date: Thu, 29 Apr 2021 22:36:22 +0800 Subject: [PATCH] fix --- mmaction/models/backbones/resnet.py | 2 +- mmaction/models/backbones/resnet3d.py | 4 +--- mmaction/models/backbones/resnet_tsm.py | 11 ++++++++--- mmaction/models/backbones/x3d.py | 10 ++++++---- mmaction/models/common/conv2plus1d.py | 4 ++-- mmaction/models/common/conv_audio.py | 4 ++-- 6 files changed, 20 insertions(+), 15 deletions(-) diff --git a/mmaction/models/backbones/resnet.py b/mmaction/models/backbones/resnet.py index a3778846cf..231ac6c8a9 100644 --- a/mmaction/models/backbones/resnet.py +++ b/mmaction/models/backbones/resnet.py @@ -554,7 +554,7 @@ def _load_torchvision_checkpoint(self, logger=None): f': {remaining_names}') def init_weights(self): - if self.torchvision_pretrain: + if self.torchvision_pretrain and isinstance(self.pretrained, str): logger = get_root_logger() self._load_torchvision_checkpoint(logger) else: diff --git a/mmaction/models/backbones/resnet3d.py b/mmaction/models/backbones/resnet3d.py index 3493ba5eac..6520dff5c4 100644 --- a/mmaction/models/backbones/resnet3d.py +++ b/mmaction/models/backbones/resnet3d.py @@ -806,9 +806,7 @@ def init_weights(self, pretrained=None): self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) if not self.pretrained2d: super().init_weights() - else: - assert pretrained is not None - self.pretrained = pretrained + elif isinstance(self.pretrained, str): logger = get_root_logger() self.inflate_weights(logger) diff --git a/mmaction/models/backbones/resnet_tsm.py b/mmaction/models/backbones/resnet_tsm.py index 2c4f999b5c..ccefb825ed 100644 --- a/mmaction/models/backbones/resnet_tsm.py +++ b/mmaction/models/backbones/resnet_tsm.py @@ -1,13 +1,14 @@ import torch import torch.nn as nn from mmcv.cnn import NonLocal3d +from mmcv.runner import BaseModule from torch.nn.modules.utils import _ntuple from ..registry import BACKBONES from .resnet import ResNet -class NL3DWrapper(nn.Module): +class NL3DWrapper(BaseModule): """3D Non-local wrapper for ResNet50. Wrap ResNet layers with 3D NonLocal modules. @@ -18,8 +19,12 @@ class NL3DWrapper(nn.Module): non_local_cfg (dict): Config for non-local layers. Default: ``dict()``. """ - def __init__(self, block, num_segments, non_local_cfg=dict()): - super(NL3DWrapper, self).__init__() + def __init__(self, + block, + num_segments, + non_local_cfg=dict(), + init_cfg=None): + super().__init__(init_cfg) self.block = block self.non_local_cfg = non_local_cfg self.non_local_block = NonLocal3d(self.block.conv3.norm.num_features, diff --git a/mmaction/models/backbones/x3d.py b/mmaction/models/backbones/x3d.py index 74f6ea3b41..9ef994284c 100644 --- a/mmaction/models/backbones/x3d.py +++ b/mmaction/models/backbones/x3d.py @@ -3,7 +3,7 @@ import torch.nn as nn import torch.utils.checkpoint as cp from mmcv.cnn import ConvModule, Swish, build_activation_layer -from mmcv.runner import BaseModule +from mmcv.runner import BaseModule, Sequential from mmcv.utils import _BatchNorm from ..registry import BACKBONES @@ -230,6 +230,8 @@ def __init__(self, self.gamma_b = gamma_b self.gamma_d = gamma_d + self.zero_init_residual = zero_init_residual + block_init_cfg = None assert not (init_cfg and pretrained), ('init_cfg and pretrained cannot ' @@ -249,7 +251,7 @@ def __init__(self, ] if self.zero_init_residual: block_init_cfg = dict( - type='Constant', val=0, override=dict(name='norm3')) + type='Constant', val=0, override=dict(name='conv3.bn')) else: raise TypeError('pretrained must be a str or None') @@ -283,7 +285,6 @@ def __init__(self, self.act_cfg = act_cfg self.norm_eval = norm_eval self.with_cp = with_cp - self.zero_init_residual = zero_init_residual self.block = BlockX3D self.stage_blocks = self.stage_blocks[:num_stages] @@ -451,7 +452,8 @@ def make_res_layer(self, with_cp=with_cp, **kwargs)) - return nn.Sequential(*layers) + # return nn.Sequential(*layers) + return Sequential(*layers) def _make_stem_layer(self): """Construct the stem layers consists of a conv+norm+act module and a diff --git a/mmaction/models/common/conv2plus1d.py b/mmaction/models/common/conv2plus1d.py index 804acba6f8..2480e98a42 100644 --- a/mmaction/models/common/conv2plus1d.py +++ b/mmaction/models/common/conv2plus1d.py @@ -38,9 +38,9 @@ def __init__(self, if init_cfg is None: self.init_cfg = [ - dict(type='kaiming', layer='Conv3d'), + dict(type='Kaiming', layer='Conv3d'), dict( - type='constant', val=1, layer=['_BatchNorm', 'GroupNorm']) + type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm']) ] kernel_size = _triple(kernel_size) diff --git a/mmaction/models/common/conv_audio.py b/mmaction/models/common/conv_audio.py index ef13c28dac..4ec97e7a3f 100644 --- a/mmaction/models/common/conv_audio.py +++ b/mmaction/models/common/conv_audio.py @@ -41,9 +41,9 @@ def __init__(self, if init_cfg is None: self.init_cfg = [ - dict(type='kaiming', layer='Conv3d'), + dict(type='Kaiming', layer='Conv3d'), dict( - type='constant', val=1, layer=['_BatchNorm', 'GroupNorm']) + type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm']) ] kernel_size = _pair(kernel_size)