Skip to content

Commit

Permalink
add initializers and BaseModule for unified parameter initialization (#…
Browse files Browse the repository at this point in the history
…780)

* add initializers and BaseModule for unified parameter initialization

* fix circle import

* bug fix

* add is_init flag in BaseModule

* fix docstring

* sort import and fix doc format

* fix bug

* fix docformat and double quote string

* fix import sort

* import sort

* sort import

* revise according to comments

* fix doc format

* revise according to comments

* revise import and fix typo

* polish code

* revise minors

* revice minors

* revise apply function

* revise bias initialization with probability

* add type test for bias_prob

* revise minors
  • Loading branch information
MeowZheng authored Feb 7, 2021
1 parent 11b9264 commit a4c3702
Show file tree
Hide file tree
Showing 12 changed files with 1,039 additions and 30 deletions.
11 changes: 7 additions & 4 deletions mmcv/cnn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
build_upsample_layer, conv_ws_2d, is_norm)
# yapf: enable
from .resnet import ResNet, make_res_layer
from .utils import (bias_init_with_prob, caffe2_xavier_init, constant_init,
fuse_conv_bn, get_model_complexity_info, kaiming_init,
normal_init, uniform_init, xavier_init)
from .utils import (INITIALIZERS, ConstantInit, KaimingInit, NormalInit,
PretrainedInit, UniformInit, XavierInit,
bias_init_with_prob, caffe2_xavier_init, constant_init,
fuse_conv_bn, get_model_complexity_info, initialize,
kaiming_init, normal_init, uniform_init, xavier_init)
from .vgg import VGG, make_vgg_layer

__all__ = [
Expand All @@ -30,5 +32,6 @@
'PLUGIN_LAYERS', 'Scale', 'get_model_complexity_info', 'conv_ws_2d',
'ConvAWS2d', 'ConvWS2d', 'fuse_conv_bn', 'DepthwiseSeparableConvModule',
'Linear', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d',
'MaxPool3d', 'Conv3d'
'MaxPool3d', 'Conv3d', 'initialize', 'INITIALIZERS', 'ConstantInit',
'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit'
]
3 changes: 1 addition & 2 deletions mmcv/cnn/alexnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

import torch.nn as nn

from ..runner import load_checkpoint


class AlexNet(nn.Module):
"""AlexNet backbone.
Expand Down Expand Up @@ -45,6 +43,7 @@ def __init__(self, num_classes=-1):
def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
logger = logging.getLogger()
from ..runner import load_checkpoint
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
# use default initializer
Expand Down
2 changes: 1 addition & 1 deletion mmcv/cnn/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import torch.nn as nn
import torch.utils.checkpoint as cp

from ..runner import load_checkpoint
from .utils import constant_init, kaiming_init


Expand Down Expand Up @@ -266,6 +265,7 @@ def __init__(self,
def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
logger = logging.getLogger()
from ..runner import load_checkpoint
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
Expand Down
10 changes: 7 additions & 3 deletions mmcv/cnn/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
# Copyright (c) Open-MMLab. All rights reserved.
from .flops_counter import get_model_complexity_info
from .fuse_conv_bn import fuse_conv_bn
from .weight_init import (bias_init_with_prob, caffe2_xavier_init,
constant_init, kaiming_init, normal_init,
from .weight_init import (INITIALIZERS, ConstantInit, KaimingInit, NormalInit,
PretrainedInit, UniformInit, XavierInit,
bias_init_with_prob, caffe2_xavier_init,
constant_init, initialize, kaiming_init, normal_init,
uniform_init, xavier_init)

__all__ = [
'get_model_complexity_info', 'bias_init_with_prob', 'caffe2_xavier_init',
'constant_init', 'kaiming_init', 'normal_init', 'uniform_init',
'xavier_init', 'fuse_conv_bn'
'xavier_init', 'fuse_conv_bn', 'initialize', 'INITIALIZERS',
'ConstantInit', 'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit',
'PretrainedInit'
]
Loading

0 comments on commit a4c3702

Please sign in to comment.