batch_norm.py
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Detectron2 batch norm modules."""

from __future__ import annotations

import torch
from torch import nn
from torch.nn import functional as f


class FrozenBatchNorm2d(nn.Module):
    """BatchNorm2d where the batch statistics and the affine parameters are fixed.

    It contains non-trainable buffers called
    "weight", "bias", "running_mean", and "running_var",
    initialized to perform an identity transformation.

    The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
    which are computed from the original four parameters of BN.
    The affine transform `x * weight + bias` will perform the equivalent
    computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
    When loading a backbone model from Caffe2, "running_mean" and "running_var"
    will be left unchanged as the identity transformation.

    Other pre-trained backbone models may contain all 4 parameters.

    The forward is implemented by `F.batch_norm(..., training=False)`.
    """

    _version = 3

    def __init__(self, num_features: int, eps: float = 1e-5) -> None:
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features) - eps)
        self.register_buffer("num_batches_tracked", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass."""
        if x.requires_grad:
            # When gradients are needed, F.batch_norm will use extra memory
            # because its backward op computes gradients for weight/bias as well.
            scale = self.weight * (self.running_var + self.eps).rsqrt()
            bias = self.bias - self.running_mean * scale
            scale = scale.reshape(1, -1, 1, 1)
            bias = bias.reshape(1, -1, 1, 1)
            out_dtype = x.dtype  # may be half
            return x * scale.to(out_dtype) + bias.to(out_dtype)
        return f.batch_norm(
            x,
            self.running_mean,
            self.running_var,
            self.weight,
            self.bias,
            training=False,
            eps=self.eps,
        )


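# --- Usage sketch (illustrative addition, not part of the original file) ----
# A minimal check that the two code paths in FrozenBatchNorm2d.forward agree:
# the fused scale/bias path taken when the input requires gradients, and the
# f.batch_norm(training=False) path taken otherwise. The helper name
# `_frozen_bn_paths_agree` is hypothetical and exists only for illustration.
def _frozen_bn_paths_agree() -> bool:
    layer = FrozenBatchNorm2d(num_features=4)
    x = torch.randn(2, 4, 8, 8)
    y_eval = layer(x)  # x.requires_grad is False -> f.batch_norm path
    y_grad = layer(x.clone().requires_grad_(True))  # fused scale/bias path
    return torch.allclose(y_eval, y_grad.detach(), atol=1e-6)

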
class CNNBlockBase(nn.Module):
    """A CNN block is assumed to have input channels, output channels and a stride.

    The input and output of the `forward()` method must be NCHW tensors.
    The method can perform arbitrary computation but must match the given
    channels and stride specification.

    Attributes:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        stride (int): stride of the block
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int) -> None:
        """The `__init__` method of any subclass should also contain these arguments.

        Args:
            in_channels (int): number of input channels
            out_channels (int): number of output channels
            stride (int): stride
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride


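# --- Example subclass (illustrative addition, not part of the original file).
# A minimal CNNBlockBase subclass that respects the documented contract:
# NCHW tensors in and out, matching the declared in_channels, out_channels and
# stride. The class name `BasicConvBlock` is hypothetical.
class BasicConvBlock(CNNBlockBase):
    """3x3 convolution followed by a frozen batch norm and a ReLU."""

    def __init__(self, in_channels: int, out_channels: int, stride: int = 1) -> None:
        super().__init__(in_channels, out_channels, stride)
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.norm = FrozenBatchNorm2d(out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv -> frozen BN -> ReLU to an NCHW tensor."""
        return f.relu(self.norm(self.conv(x)))

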
def get_norm(norm: str, out_channels: int) -> nn.Module:
    """Create a normalization layer.

    Args:
        norm (str): name of the normalization layer; one of "BN", "FrozenBN" or "GN".
        out_channels (int): number of channels the layer normalizes over.

    Returns:
        nn.Module: the normalization layer.
    """
    norm_module = {
        "BN": torch.nn.BatchNorm2d,
        "FrozenBN": FrozenBatchNorm2d,
        "GN": lambda channels: nn.GroupNorm(32, channels),
    }[norm]
    return norm_module(out_channels)


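# --- Usage sketch (illustrative addition, not part of the original file) ----
# get_norm maps a short string key to a per-channel normalization layer; the
# supported keys are exactly those in the dictionary above. An unknown key
# raises KeyError. The helper name `_build_example_norms` is hypothetical.
def _build_example_norms() -> list[nn.Module]:
    # 64 channels is divisible by the 32 groups used by the "GN" option.
    return [get_norm(name, 64) for name in ("BN", "FrozenBN", "GN")]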