-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathGatedCNN_MambaOut.py
126 lines (114 loc) · 5.01 KB
/
GatedCNN_MambaOut.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
# (ref) https://github.com/yuweihao/MambaOut/blob/main/models/mambaout.py
"""
+------------------------+
| Input Tensor [B, H, W, C] |
+------------------------+
|
v
+-------------------+
| Normalization Layer |
+-------------------+
|
v
+-------------------------------+
| Fully Connected Layer (fc1) |
+-------------------------------+
/ | \
/ | \
v v v
+--------+ +-----------+ +---------------------+
| g | | i | | c |
| [B, H, | | [B, H, W, | | [B, H, W, conv_ |
| W, | | hidden - | | channels] Permute |
| hidden]| | conv_ | | to [B, C, H, W] |
+--------+ | channels] | +---------------------+
| |
| v
| +-------------------+
| | Depthwise |
| | Convolution |
| +-------------------+
| |
| v
| +-----------------------+
| | Permute Back [B, H, |
| | W, C] |
| +-----------------------+
| |
| v
| +------------------+
+-------------->| Concatenate with |
| i |
+------------------+
|
v
+-------------------------+
| Activation and Element- |
| wise Multiplication |
| with g |
+-------------------------+
|
v
+------------------------+
| Fully Connected Layer |
| (fc2) |
+------------------------+
|
v
+-------------------+
| DropPath (if > 0) |
+-------------------+
|
v
+------------------------+
| Add Shortcut Connection |
+------------------------+
|
v
+------------------------+
| Output Tensor [B, H, W, dim] |
+------------------------+
"""
from functools import partial
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_, DropPath
class GatedCNNBlock(nn.Module):
    r"""Our implementation of Gated CNN Block: https://arxiv.org/pdf/1612.08083

    Operates on channels-last tensors: input and output are [B, H, W, dim].

    Args:
        dim: number of input/output channels.
        expension_ratio: MLP expansion ratio; the hidden width is
            int(expension_ratio * dim). (Name kept as-is, although misspelled,
            for backward compatibility with existing callers.)
        kernel_size: depthwise convolution kernel size (odd, so 'same' padding
            via kernel_size // 2 preserves H and W).
        conv_ratio: control the number of channels to conduct depthwise convolution.
            Conducting convolution on partial channels can improve practical efficiency.
            The idea of partial channels is from ShuffleNet V2 (https://arxiv.org/abs/1807.11164) and
            also used by InceptionNeXt (https://arxiv.org/abs/2303.16900) and FasterNet (https://arxiv.org/abs/2303.03667)
        norm_layer: constructor for the pre-normalization layer.
        act_layer: constructor for the gate activation.
        drop_path: stochastic-depth rate; 0 disables DropPath entirely.

    Raises:
        ValueError: if conv_ratio * dim exceeds expension_ratio * dim, which
            would make the split sizes negative.
    """
    def __init__(self, dim, expension_ratio=8/3, kernel_size=7, conv_ratio=1.0,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6),
                 act_layer=nn.GELU,
                 drop_path=0.,
                 **kwargs):
        super().__init__()
        self.norm = norm_layer(dim)
        hidden = int(expension_ratio * dim)
        conv_channels = int(conv_ratio * dim)
        # Fail fast with a clear message; otherwise torch.split would raise a
        # cryptic error at forward time on a negative split size.
        if conv_channels > hidden:
            raise ValueError(
                f"conv_ratio * dim ({conv_channels}) must not exceed "
                f"expension_ratio * dim ({hidden})"
            )
        # One matmul produces both the gate branch (hidden) and the value
        # branch (hidden): fc1 output is 2 * hidden channels wide.
        self.fc1 = nn.Linear(dim, hidden * 2)
        self.act = act_layer()
        # Value branch splits into a pass-through part `i` and a convolved
        # part `c`: (hidden - conv_channels) + conv_channels == hidden.
        self.split_indices = (hidden, hidden - conv_channels, conv_channels)
        # groups == channels -> depthwise convolution.
        self.conv = nn.Conv2d(conv_channels, conv_channels, kernel_size=kernel_size,
                              padding=kernel_size // 2, groups=conv_channels)
        self.fc2 = nn.Linear(hidden, dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        """Gated-CNN mixing with a residual connection.

        Args:
            x: channels-last tensor of shape [B, H, W, dim].

        Returns:
            Tensor of shape [B, H, W, dim].
        """
        shortcut = x  # [B, H, W, C]
        x = self.norm(x)
        g, i, c = torch.split(self.fc1(x), self.split_indices, dim=-1)
        c = c.permute(0, 3, 1, 2)  # [B, H, W, C] -> [B, C, H, W] for Conv2d
        c = self.conv(c)
        c = c.permute(0, 2, 3, 1)  # [B, C, H, W] -> [B, H, W, C]
        # Gate: elementwise product of activated `g` with the (partially
        # convolved) value branch, then project back to `dim`.
        x = self.fc2(self.act(g) * torch.cat((i, c), dim=-1))
        x = self.drop_path(x)
        return x + shortcut
if __name__ == "__main__":
    # Smoke test: push a channels-last tensor through a single block and
    # confirm the output shape matches the input shape.
    channels = 64
    sample = torch.randn(1, 7, 7, channels)  # (B, H, W, C)
    block = GatedCNNBlock(
        channels,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        act_layer=nn.GELU,
        kernel_size=7,
        conv_ratio=1.0,
        drop_path=0.,
    )
    print(block(sample).shape)