Commit: other

dreamerlin committed Apr 29, 2021
1 parent 10c34b8 commit 8eb825c
Showing 11 changed files with 162 additions and 148 deletions.
56 changes: 27 additions & 29 deletions mmaction/models/backbones/c3d.py
@@ -1,14 +1,12 @@
import torch.nn as nn
from mmcv.cnn import ConvModule, constant_init, kaiming_init, normal_init
from mmcv.runner import load_checkpoint
from mmcv.utils import _BatchNorm
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule

from ...utils import get_root_logger
from ..registry import BACKBONES


@BACKBONES.register_module()
class C3D(nn.Module):
class C3D(BaseModule):
"""C3D backbone.
Args:
@@ -35,19 +33,40 @@ def __init__(self,
norm_cfg=None,
act_cfg=None,
dropout_ratio=0.5,
init_std=0.005):
super().__init__()
init_std=0.005,
init_cfg=None):
super().__init__(init_cfg)
if conv_cfg is None:
conv_cfg = dict(type='Conv3d')
if act_cfg is None:
act_cfg = dict(type='ReLU')
self.pretrained = pretrained
self.init_std = init_std

        assert not (init_cfg
                    and pretrained), ('init_cfg and pretrained cannot '
                                      'be set at the same time')
        if isinstance(pretrained, str):
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer='Conv3d'),
dict(type='Normal', std=init_std, layer='Linear'),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
]
else:
raise TypeError('pretrained must be a str or None')

self.style = style
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.dropout_ratio = dropout_ratio
self.init_std = init_std

c3d_conv_param = dict(
kernel_size=(3, 3, 3),
@@ -81,27 +100,6 @@ def __init__(self,
self.relu = nn.ReLU()
self.dropout = nn.Dropout(p=self.dropout_ratio)

def init_weights(self):
"""Initiate the parameters either from existing checkpoint or from
scratch."""
if isinstance(self.pretrained, str):
logger = get_root_logger()
logger.info(f'load model from: {self.pretrained}')

load_checkpoint(self, self.pretrained, strict=False, logger=logger)

elif self.pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv3d):
kaiming_init(m)
elif isinstance(m, nn.Linear):
normal_init(m, std=self.init_std)
elif isinstance(m, _BatchNorm):
constant_init(m, 1)

else:
raise TypeError('pretrained must be a str or None')

def forward(self, x):
"""Defines the computation performed at every call.
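With init_weights() gone, C3D's initialization is declared in init_cfg and executed by mmcv's BaseModule. A minimal sketch of the new flow, assuming mmaction2 at this commit; the checkpoint path is a placeholder, not from the diff:

from mmaction.models import C3D

# Scratch init: __init__ installed the default init_cfg above
# (Kaiming for Conv3d, Normal(std=init_std) for Linear,
# Constant(1) for norm layers); BaseModule.init_weights() applies it.
model = C3D(pretrained=None)
model.init_weights()

# Pretrained init: __init__ rewrote init_cfg to
# dict(type='Pretrained', checkpoint=...), so the same call loads weights.
model = C3D(pretrained='checkpoints/c3d_sports1m.pth')  # placeholder path
model.init_weights()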
53 changes: 30 additions & 23 deletions mmaction/models/backbones/mobilenet_v2.py
@@ -1,10 +1,9 @@
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import ConvModule, constant_init, kaiming_init
from mmcv.runner import load_checkpoint
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from torch.nn.modules.batchnorm import _BatchNorm

from ...utils import get_root_logger
from ..builder import BACKBONES


@@ -33,7 +32,7 @@ def make_divisible(value, divisor, min_value=None, min_ratio=0.9):
return new_value


class InvertedResidual(nn.Module):
class InvertedResidual(BaseModule):
"""InvertedResidual block for MobileNetV2.
Args:
@@ -61,9 +60,10 @@ def __init__(self,
expand_ratio,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU6'),
with_cp=False):
super(InvertedResidual, self).__init__()
                 act_cfg=dict(type='ReLU6'),
with_cp=False,
init_cfg=None):
super().__init__(init_cfg)
self.stride = stride
        assert stride in [1, 2], f'stride must be in [1, 2]. ' \
f'But received {stride}.'
@@ -119,7 +119,7 @@ def _inner_forward(x):


@BACKBONES.register_module()
class MobileNetV2(nn.Module):
class MobileNetV2(BaseModule):
"""MobileNetV2 backbone.
Args:
@@ -158,9 +158,29 @@ def __init__(self,
norm_cfg=dict(type='BN2d', requires_grad=True),
act_cfg=dict(type='ReLU6', inplace=True),
norm_eval=False,
with_cp=False):
super().__init__()
with_cp=False,
init_cfg=None):
super().__init__(init_cfg)
self.pretrained = pretrained

        assert not (init_cfg
                    and pretrained), ('init_cfg and pretrained cannot '
                                      'be set at the same time')
        if isinstance(pretrained, str):
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer='Conv2d'),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
]
else:
raise TypeError('pretrained must be a str or None')

self.widen_factor = widen_factor
self.out_indices = out_indices
for index in out_indices:
@@ -250,19 +270,6 @@ def make_layer(self, out_channels, num_blocks, stride, expand_ratio):

return nn.Sequential(*layers)

def init_weights(self):
if isinstance(self.pretrained, str):
logger = get_root_logger()
load_checkpoint(self, self.pretrained, strict=False, logger=logger)
elif self.pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)
else:
raise TypeError('pretrained must be a str or None')

def forward(self, x):
x = self.conv1(x)

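For MobileNetV2, the deleted imperative loop and the new declarative list encode the same policy. A hedged side-by-side sketch using only names that appear in this diff:

import torch.nn as nn
from mmcv.cnn import constant_init, kaiming_init
from torch.nn.modules.batchnorm import _BatchNorm

def old_style_init(module):
    # What the removed MobileNetV2.init_weights() did by hand.
    for m in module.modules():
        if isinstance(m, nn.Conv2d):
            kaiming_init(m)
        elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
            constant_init(m, 1)

# The declarative equivalent now stored on the instance and executed
# later by BaseModule.init_weights().
init_cfg = [
    dict(type='Kaiming', layer='Conv2d'),
    dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm']),
]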
1 change: 1 addition & 0 deletions mmaction/models/backbones/resnet3d.py
@@ -802,6 +802,7 @@ def _freeze_stages(self):

def init_weights(self, pretrained=None):
if pretrained:
self.pretrained = pretrained
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
if not self.pretrained2d:
super().init_weights()
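The single added line keeps self.pretrained in sync when a checkpoint is supplied at call time instead of at construction. A sketch under that assumption; the path is illustrative:

from mmaction.models import ResNet3d

model = ResNet3d(depth=50, pretrained=None, pretrained2d=False)
model.init_weights(pretrained='checkpoints/r3d_50.pth')  # placeholder path
# Internally:
#   self.pretrained = pretrained
#   self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
#   super().init_weights()  # BaseModule loads the 3D checkpoint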
5 changes: 2 additions & 3 deletions mmaction/models/backbones/resnet3d_csn.py
@@ -37,8 +37,7 @@ def __init__(self,
*args,
bottleneck_mode='ir',
**kwargs):
super(CSNBottleneck3d, self).__init__(inplanes, planes, *args,
**kwargs)
super().__init__(inplanes, planes, *args, **kwargs)
self.bottleneck_mode = bottleneck_mode
conv2 = []
if self.bottleneck_mode == 'ip':
@@ -124,7 +123,7 @@ def __init__(self,
if bottleneck_mode not in ['ip', 'ir']:
raise ValueError(f'Bottleneck mode must be "ip" or "ir",'
f'but got {bottleneck_mode}.')
super(ResNet3dCSN, self).__init__(
super().__init__(
depth,
pretrained,
temporal_strides=temporal_strides,
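The CSN edits are pure Python 3 cleanup: zero-argument super() resolves to the same method as the explicit two-argument form without repeating the class name. In miniature:

class Base:
    def __init__(self, inplanes, planes):
        self.inplanes, self.planes = inplanes, planes

class Child(Base):
    def __init__(self, inplanes, planes, *args, **kwargs):
        # Identical in effect to super(Child, self).__init__(inplanes, planes)
        super().__init__(inplanes, planes)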
39 changes: 22 additions & 17 deletions mmaction/models/backbones/resnet3d_slowfast.py
@@ -1,10 +1,8 @@
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, kaiming_init
from mmcv.runner import _load_checkpoint, load_checkpoint
from mmcv.utils import print_log
from mmcv.runner import BaseModule, _load_checkpoint

from ...utils import get_root_logger
from ..registry import BACKBONES
from .resnet3d import ResNet3d

@@ -329,11 +327,8 @@ def _freeze_stages(self):
def init_weights(self, pretrained=None):
"""Initiate the parameters either from existing checkpoint or from
scratch."""
if pretrained:
self.pretrained = pretrained

# Override the init_weights of i3d
super().init_weights(pretrained=self.pretrained)
super().init_weights(pretrained=pretrained)
for module_name in self.lateral_connections:
layer = getattr(self, module_name)
for m in layer.modules():
@@ -372,7 +367,7 @@ def build_pathway(cfg, *args, **kwargs):


@BACKBONES.register_module()
class ResNet3dSlowFast(nn.Module):
class ResNet3dSlowFast(BaseModule):
"""Slowfast backbone.
This module is proposed in `SlowFast Networks for Video Recognition
@@ -430,7 +425,8 @@ def __init__(self,
dilations=(1, 1, 1, 1),
conv1_stride_t=1,
pool1_stride_t=1,
inflate=(0, 0, 1, 1)),
inflate=(0, 0, 1, 1),
init_cfg=None),
fast_pathway=dict(
type='resnet3d',
depth=50,
@@ -439,9 +435,21 @@
base_channels=8,
conv1_kernel=(5, 7, 7),
conv1_stride_t=1,
pool1_stride_t=1)):
super().__init__()
pool1_stride_t=1,
init_cfg=None),
init_cfg=None):
super().__init__(init_cfg)
self.pretrained = pretrained

        assert not (init_cfg
                    and pretrained), ('init_cfg and pretrained cannot '
                                      'be set at the same time')
if isinstance(pretrained, str):
self.pretrained = pretrained
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is not None:
raise TypeError('pretrained must be a str or None')

self.resample_rate = resample_rate
self.speed_ratio = speed_ratio
self.channel_ratio = channel_ratio
@@ -458,13 +466,10 @@ def init_weights(self, pretrained=None):
scratch."""
if pretrained:
self.pretrained = pretrained

self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
if isinstance(self.pretrained, str):
logger = get_root_logger()
msg = f'load model from: {self.pretrained}'
print_log(msg, logger=logger)
# Directly load 3D model.
load_checkpoint(self, self.pretrained, strict=True, logger=logger)
super().init_weights()

elif self.pretrained is None:
            # Init the two branches separately.
self.fast_path.init_weights()
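ResNet3dSlowFast now maps a string pretrained onto a whole-model Pretrained init_cfg, while pretrained=None lets each pathway initialize itself. A hedged sketch; the checkpoint path is illustrative:

from mmaction.models import ResNet3dSlowFast

# No checkpoint: the slow and fast pathways run their own init.
model = ResNet3dSlowFast(pretrained=None)
model.init_weights()

# Whole-model 3D checkpoint, loaded through BaseModule.
model = ResNet3dSlowFast(pretrained='checkpoints/slowfast_r50.pth')
model.init_weights()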
7 changes: 4 additions & 3 deletions mmaction/models/backbones/tanet.py
@@ -1,14 +1,15 @@
from copy import deepcopy

import torch.nn as nn
from mmcv.runner import BaseModule
from torch.utils import checkpoint as cp

from ..common import TAM
from ..registry import BACKBONES
from .resnet import Bottleneck, ResNet


class TABlock(nn.Module):
class TABlock(BaseModule):
"""Temporal Adaptive Block (TA-Block) for TANet.
This block is proposed in `TAM: TEMPORAL ADAPTIVE MODULE FOR VIDEO
@@ -25,8 +26,8 @@ class TABlock(nn.Module):
Default: dict().
"""

def __init__(self, block, num_segments, tam_cfg=dict()):
super().__init__()
def __init__(self, block, num_segments, tam_cfg=dict(), init_cfg=None):
super().__init__(init_cfg)
self.tam_cfg = deepcopy(tam_cfg)
self.block = block
self.num_segments = num_segments
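TABlock picks up the same boilerplate every class in this commit adopts: accept init_cfg, forward it to BaseModule, and leave weight initialization to mmcv. The pattern in miniature, as a sketch rather than code from the repository:

import torch.nn as nn
from mmcv.runner import BaseModule

class MyBlock(BaseModule):  # hypothetical example class
    def __init__(self, in_channels, out_channels, init_cfg=None):
        super().__init__(init_cfg)  # BaseModule stores init_cfg for later
        self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)

    def forward(self, x):
        return self.conv(x)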