-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcornet_s.py
158 lines (120 loc) · 5.2 KB
/
cornet_s.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import math
from collections import OrderedDict
from torch import nn
import torch
from torch.nn.utils import prune
# Identifier of the pretrained CORnet-S weights; the same value is used to
# build the checkpoint URL in get_custom_cornet_s.
HASH = '1d3f7974'
class Flatten(nn.Module):
    """
    Helper module for flattening each sample of the input tensor so it can be
    fed to Linear modules.

    The first dimension is kept and all remaining dimensions are collapsed
    into a single one, i.e. the result has shape ``(x.size(0), -1)``.
    """

    def forward(self, x):
        # reshape (rather than view) also accepts non-contiguous inputs, such
        # as the result of a transpose/permute, where view raises a
        # RuntimeError. For inputs whose layout allows it, reshape still
        # returns a view, so contiguous inputs behave exactly as before.
        return x.reshape(x.size(0), -1)
class Identity(nn.Module):
    """
    Pass-through module: returns its input unchanged.

    Registering it as a named submodule gives the tensor flowing through it
    a name, which is useful for accessing that tensor by name.
    """

    def forward(self, x):
        passthrough = x  # same object: no computation, no copy
        return passthrough
class CORblock_S(nn.Module):
    """
    Bottleneck block whose convolutions are applied `times` times in a row.

    The input first goes through a 1x1 convolution (`conv_input`). It is then
    passed `times` times through the same conv1 -> conv2 -> conv3 stack, with
    a separate set of BatchNorm layers for every pass. On the first pass
    conv2 runs with stride 2 and the skip connection is a strided 1x1
    convolution followed by BatchNorm; on later passes conv2 runs with
    stride 1 and the skip connection is the pass's own input.

    Args:
        in_channels: number of channels of the incoming tensor.
        out_channels: number of channels produced by the block.
        times: how many passes through the convolution stack to make.
    """

    scale = 4  # scale of the bottleneck convolution channels

    def __init__(self, in_channels, out_channels, times=1):
        super().__init__()

        wide = out_channels * self.scale  # channel count inside the bottleneck
        self.times = times

        self.conv_input = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)

        # skip path used on the first pass only (it has to match the stride-2 conv2)
        self.skip = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=2, bias=False)
        self.norm_skip = nn.BatchNorm2d(out_channels)

        self.conv1 = nn.Conv2d(out_channels, wide, kernel_size=1, bias=False)
        self.nonlin1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(wide, wide, kernel_size=3, stride=2, padding=1, bias=False)
        self.nonlin2 = nn.ReLU(inplace=True)

        self.conv3 = nn.Conv2d(wide, out_channels, kernel_size=1, bias=False)
        self.nonlin3 = nn.ReLU(inplace=True)

        self.output = Identity()  # for an easy access to this block's output

        # need BatchNorm for each time step for training to work well
        for step in range(self.times):
            for idx, width in ((1, wide), (2, wide), (3, out_channels)):
                setattr(self, f'norm{idx}_{step}', nn.BatchNorm2d(width))

    def forward(self, inp):
        x = self.conv_input(inp)

        for step in range(self.times):
            first_pass = step == 0

            # The first pass downsamples, so its skip path is the strided
            # convolution; afterwards the pass input itself is the skip.
            skip = self.norm_skip(self.skip(x)) if first_pass else x
            # conv2 is shared across passes; only its stride is switched.
            self.conv2.stride = (2, 2) if first_pass else (1, 1)

            x = self.nonlin1(getattr(self, f'norm1_{step}')(self.conv1(x)))
            x = self.nonlin2(getattr(self, f'norm2_{step}')(self.conv2(x)))
            x = getattr(self, f'norm3_{step}')(self.conv3(x))

            x += skip
            x = self.nonlin3(x)
            # routed through the Identity on every pass
            output = self.output(x)

        return output
def CORnet_S():
    """
    Build a freshly initialised CORnet-S model.

    Returns:
        An ``nn.Sequential`` with the named stages ``V1``, ``V2``, ``V4``,
        ``IT`` and ``decoder``. The decoder ends in ``nn.Linear(512, 1000)``.
        Conv2d and BatchNorm2d weights are re-initialised below; the Linear
        layer keeps its default initialisation.
    """
    # this one is custom to save GPU memory
    v1 = nn.Sequential(OrderedDict([
        ('conv1', nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)),
        ('norm1', nn.BatchNorm2d(64)),
        ('nonlin1', nn.ReLU(inplace=True)),
        ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ('conv2', nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)),
        ('norm2', nn.BatchNorm2d(64)),
        ('nonlin2', nn.ReLU(inplace=True)),
        ('output', Identity()),
    ]))
    v2 = CORblock_S(64, 128, times=2)
    v4 = CORblock_S(128, 256, times=4)
    it = CORblock_S(256, 512, times=2)
    decoder = nn.Sequential(OrderedDict([
        ('avgpool', nn.AdaptiveAvgPool2d(1)),
        ('flatten', Flatten()),
        ('linear', nn.Linear(512, 1000)),
        ('output', Identity()),
    ]))

    model = nn.Sequential(OrderedDict([
        ('V1', v1),
        ('V2', v2),
        ('V4', v4),
        ('IT', it),
        ('decoder', decoder),
    ]))

    # weight initialization
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            module.weight.data.fill_(1)
            module.bias.data.zero_()
        elif isinstance(module, nn.Conv2d):
            kernel_h, kernel_w = module.kernel_size
            fan = kernel_h * kernel_w * module.out_channels
            module.weight.data.normal_(0, math.sqrt(2. / fan))
        # nn.Linear is missing here because I originally forgot
        # to add it during the training of this network

    return model
def get_custom_cornet_s(region_idx, lesion_iter):
    """
    Build a CORnet-S model and load weights into it.

    Args:
        region_idx: index into ``[V1, V2, V4, IT]`` selecting the region whose
            Conv2d layers get pruning applied before a lesioned checkpoint is
            loaded. Not used when ``lesion_iter == -1``.
        lesion_iter: ``-1`` loads the pretrained weights from the
            cornet-models S3 URL; any other value is used, together with
            ``region_idx``, to build the path of a local checkpoint file.

    Returns:
        The underlying model (unwrapped from ``DataParallel``) with the
        loaded weights.
    """
    # Both branches load into a DataParallel wrapper and return the wrapped
    # module. Presumably the checkpoints were saved from a DataParallel model
    # and their keys carry its prefix; verify against the checkpoint files.
    wrapped = torch.nn.DataParallel(CORnet_S())

    if lesion_iter != -1:
        candidates = [wrapped.module.V1, wrapped.module.V2, wrapped.module.V4, wrapped.module.IT]
        target = candidates[region_idx]

        # NOTE(review): the path below contains a literal "region_idx"
        # directory component (it is not interpolated) - confirm that this
        # matches the on-disk layout.
        ckpt_path = f'/vision/u/ynshah/NeuroDP/runs/lesioned_retrain_plates_final/region_idx/checkpoints/{region_idx}_{lesion_iter}_4096_ckpt.pt'
        # NOTE(review): torch.load unpickles the file; only use it on
        # checkpoints from a trusted source.
        lesioned_state = torch.load(ckpt_path, map_location='cpu')

        # Apply pruning to every Conv2d of the selected region before the
        # checkpoint is loaded into the model.
        prunable = [layer for layer in target.modules() if isinstance(layer, torch.nn.Conv2d)]
        for layer in prunable:
            prune.random_unstructured(layer, name='weight', amount=0.2)

        wrapped.load_state_dict(lesioned_state)
        return wrapped.module

    # Unlesioned case: fetch the published pretrained weights.
    model_hash = '1d3f7974'
    url = f'https://s3.amazonaws.com/cornet-models/cornet_s-{model_hash}.pth'
    pretrained = torch.hub.load_state_dict_from_url(url, map_location='cpu')
    wrapped.load_state_dict(pretrained['state_dict'])
    return wrapped.module