-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutil.py
executable file
·93 lines (82 loc) · 3.88 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import torch.nn as nn
import numpy
import torchvision.models as models
class BinOp(object):
    """XNOR-Net style weight binarization for selected Conv2d/Linear layers.

    Every ``nn.Conv2d`` and ``nn.Linear`` found by ``model.modules()`` is
    numbered in traversal order starting at 0. The layers whose number lies
    in the inclusive range ``bin_range`` have their ``weight`` parameter
    managed by this object.

    Presumed per-batch usage (inferred from the method set; confirm against
    the training script): ``binarization()`` before the forward pass,
    ``restore()`` after ``backward()``, then ``updateBinaryGradWeight()``
    before the optimizer step.
    """

    def __init__(self, model, bin_range):
        """Select the weights to binarize and snapshot them.

        Args:
            model: module tree scanned for ``nn.Conv2d`` / ``nn.Linear`` layers.
            bin_range: ``(start, end)`` inclusive integer indices into the
                ordered list of Conv2d/Linear layers. Indices beyond the last
                such layer are ignored.
        """
        start_range = bin_range[0]
        end_range = bin_range[1]
        # Inclusive integer range (what the former numpy.linspace call built).
        self.bin_range = list(range(start_range, end_range + 1))
        self.saved_params = []
        self.target_params = []  # not used by this class; kept for callers
        self.target_modules = []
        wanted = set(self.bin_range)
        layer_index = -1
        for module in model.modules():
            if not isinstance(module, (nn.Conv2d, nn.Linear)):
                continue
            layer_index += 1
            if layer_index in wanted:
                self.saved_params.append(module.weight.data.clone())
                self.target_modules.append(module.weight)
                print("hey, i binarized %d th conv %s" % (layer_index, module))
        # Count what was actually collected: bin_range may name indices past
        # the last Conv2d/Linear, and every method indexes target_modules.
        self.num_of_params = len(self.target_modules)

    @staticmethod
    def _filter_mean(tensor):
        """Mean over every dim except dim 0, keepdim so it broadcasts back."""
        n = tensor[0].nelement()
        total = tensor
        # Reduce innermost dims first (3, 2, 1 for conv weights).
        for dim in range(tensor.dim() - 1, 0, -1):
            total = total.sum(dim, keepdim=True)
        return total.div(n)

    def binarization(self):
        """Mean-center, clamp, snapshot, then binarize all target weights.

        The snapshot written by ``save_params`` is the mean-centered and
        clamped real-valued weight; ``restore`` copies it back.
        """
        self.meancenterConvParams()
        self.clampConvParams()
        self.save_params()
        self.binarizeConvParams()

    def meancenterConvParams(self):
        """Subtract from each weight its mean taken over dimension 1."""
        for param in self.target_modules:
            weight = param.data
            param.data = weight - weight.mean(1, keepdim=True)

    def clampConvParams(self):
        """Clamp each target weight element-wise into [-1, 1]."""
        for param in self.target_modules:
            param.data = param.data.clamp(-1.0, 1.0)

    def save_params(self):
        """Copy the current target weights into the saved snapshots."""
        for saved, param in zip(self.saved_params, self.target_modules):
            saved.copy_(param.data)

    def binarizeConvParams(self):
        """Replace each weight by sign(w) * alpha.

        alpha is the mean |w| of each output filter (every dim except dim 0).
        """
        for param in self.target_modules:
            weight = param.data
            alpha = self._filter_mean(weight.abs())
            param.data = weight.sign().mul(alpha.expand_as(weight))

    def restore(self):
        """Copy the saved real-valued snapshots back into the target weights."""
        for saved, param in zip(self.saved_params, self.target_modules):
            param.data.copy_(saved)

    def updateBinaryGradWeight(self):
        """Rewrite each target weight's ``.grad`` for the binarized forward.

        For weight ``w`` with gradient ``g``, ``n`` elements per output filter
        and ``alpha`` = per-filter mean |w|, the new gradient is::

            (alpha * g * [|w| <= 1] + sign(w) * mean_f(sign(w) * g))
                * (1 - 1/w.size(1)) * n * 1e9

        where ``mean_f`` averages over every dim except dim 0.

        Requires each target weight to already have a ``.grad`` tensor.
        """
        for param in self.target_modules:
            weight = param.data
            grad = param.grad.data
            n = weight[0].nelement()
            in_dim = weight.size(1)
            sign = weight.sign()
            # Multiply first so the result is a fresh dense tensor, then zero
            # the entries whose |w| > 1. Writing into the stride-0 expanded
            # view (as before) zeroed the shared per-filter scalar, and
            # recent PyTorch rejects or warns about that write.
            term_alpha = self._filter_mean(weight.abs()).expand_as(weight).mul(grad)
            term_alpha.masked_fill_(weight.abs().gt(1.0), 0)
            term_sign = self._filter_mean(sign.mul(grad)).expand_as(weight).mul(sign)
            new_grad = term_alpha.add(term_sign).mul(1.0 - 1.0 / in_dim).mul(n)
            # NOTE(review): 1e+9 scale carried over unchanged from the
            # original implementation; its intent is not visible here.
            param.grad.data = new_grad.mul(1e+9)