-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathconv_pm.py
116 lines (82 loc) · 3.02 KB
/
conv_pm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import matplotlib.pyplot as plt
import torch
# NCHW shape of the convolution input: (batch, in_channels, height, width).
A_SHAPE = (4, 8, 16, 16)
# OIHW shape of the convolution weights: (out_channels, in_channels, kH, kW);
# in_channels (8) matches A_SHAPE's channel dim so conv2d(A, B) is valid.
B_SHAPE = (8, 8, 4, 4)
# Fixed seed so every run generates identical random tensors.
torch.manual_seed(12345)
def generate_input(shape):
    """Return a float32 tensor of the given shape with uniform [0, 1) values."""
    return torch.rand(shape, dtype=torch.float)
A = generate_input(A_SHAPE)
B = generate_input(B_SHAPE)
# Random per-sample (A) and per-output-channel (B) multipliers, shaped to
# broadcast over the remaining dims, so each slice gets a different dynamic
# range — this makes the per-channel quantization path meaningful.
A_QUANT = torch.rand((A_SHAPE[0],1,1,1), dtype=torch.float)
B_QUANT = torch.rand((B_SHAPE[0],1,1,1), dtype=torch.float)
# Shift by +1.0 so every value is strictly positive (nonzero minimum).
A = A * A_QUANT + 1.0
B = B * B_QUANT + 1.0
def quant_i8_per_tensor(A, offset=True):
    """Quantize ``A`` to uint8 with one scale/offset shared by the whole tensor.

    The affine mapping q = round(A * S + O) is chosen so the observed range
    [min(A), max(A)] (or [0, max(A)] when ``offset`` is False) maps onto the
    full [0, 255] codomain.

    Returns (uint8 tensor, dequantization scale flattened to shape (N,),
    quantization offset flattened to shape (N,)); dequantize via
    (q - offset) * scale.
    """
    lo = torch.min(A) if offset else 0.0
    span = torch.max(A) - lo
    # Materialize the single per-tensor parameters as broadcastable
    # (N, 1, 1, 1) tensors (every row holds the same value).
    scale = torch.full((A.shape[0], 1, 1, 1), 255.0 / span)
    shift = torch.full((A.shape[0], 1, 1, 1), -lo) * scale
    q = torch.clamp(torch.round(A * scale + shift), 0, 255.0).to(torch.uint8)
    return q, torch.flatten(1.0 / scale), torch.flatten(shift)
def quant_i8_per_channel(A, axislist, offset=True):
    """Quantize ``A`` to uint8 with a separate scale/offset per leading slice.

    The dimensions in ``axislist`` are reduced one at a time (with keepdim)
    to find each remaining slice's min/max; with axislist=[1, 2, 3] every
    A[i] gets its own affine mapping q = round(A * S + O) onto [0, 255].
    When ``offset`` is False the minimum is taken to be 0, i.e. the mapping
    is scale-only (assumes non-negative data).

    Returns (uint8 tensor, dequantization scale flattened to shape (N,),
    quantization offset flattened to shape (N,)); dequantize via
    (q - offset) * scale.

    Fix: removed leftover debug ``print(SCALE.shape)`` / ``print(OFFSET.shape)``
    statements that polluted stdout on every call.
    """
    if offset:
        mins = A
        for ax in axislist:
            mins = torch.min(mins, dim=ax, keepdim=True).values
    else:
        # Scale-only mode: treat 0 as the lower bound of every slice.
        mins = torch.full((A.shape[0], 1, 1, 1), 0.0)
    maxs = A
    for ax in axislist:
        maxs = torch.max(maxs, dim=ax, keepdim=True).values
    scale = 255.0 / (maxs - mins)
    shift = -mins * scale
    q = torch.round(A * scale + shift)
    q = torch.clamp(q, 0, 255.0).to(torch.uint8)
    # Flatten parameters to shape (N,); invert scale for dequantization.
    return q, 1.0 / torch.flatten(scale), torch.flatten(shift)
def mmt_quant(A, B, per_channel=True, offset=True):
    """Emulate conv2d(A, B) in uint8 arithmetic and dequantize the result.

    Quantizes A (per sample) and B (per output channel) to uint8, convolves
    the quantized tensors, then removes the offset cross-terms and rescales.
    Since a ~ (qa - Oa)*sa and b ~ (qb - Ob)*sb, the float convolution is
    approximately
        sa*sb * [conv(qa, qb) - Oa*sum(qb) - Ob*window_sum(qa) + Oa*Ob*K],
    where K = Cin*kH*kW is the number of terms in each output element.
    Returns a float tensor shaped like conv2d(A, B).
    """
    if per_channel:
        A, A_SCALE, A_OFFSET = quant_i8_per_channel(A, axislist=[1,2,3], offset=offset)
        B, B_SCALE, B_OFFSET = quant_i8_per_channel(B, axislist=[1,2,3], offset=offset)
    else:
        A, A_SCALE, A_OFFSET = quant_i8_per_tensor(A, offset=offset)
        B, B_SCALE, B_OFFSET = quant_i8_per_tensor(B, offset=offset)
    # conv2d has no uint8 kernel here, so emulate integer accumulation in
    # float (exact for these magnitudes).
    A = A.to(torch.float)
    B = B.to(torch.float)
    CONV = torch.nn.functional.conv2d(A, B)
    # Oa * sum(qb): per-(sample, out-channel) term, broadcast over H', W'.
    A_FIX = torch.sum(torch.flatten(B, 1), dim=1).unsqueeze(0) * A_OFFSET.unsqueeze(1)
    A_FIX = A_FIX.unsqueeze(2).unsqueeze(3)
    # Ob * window_sum(qa): avg_pool2d with divisor_override=1 computes the
    # sliding-window sum of qa over all input channels at each output pixel.
    B_FIX = torch.nn.functional.avg_pool2d(A, (B.shape[2], B.shape[3]), stride=1, divisor_override=1).sum(1, keepdim=True)
    B_FIX = B_FIX * B_OFFSET.unsqueeze(0).unsqueeze(2).unsqueeze(3)
    # Oa * Ob * K constant term (K = Cin * kH * kW).
    AB_FIX = A_OFFSET.unsqueeze(1) * B_OFFSET.unsqueeze(0) * B.shape[1] * B.shape[2] * B.shape[3]
    AB_FIX = AB_FIX.unsqueeze(2).unsqueeze(3)
    CONV = CONV - A_FIX - B_FIX + AB_FIX
    # Rescale: A scales broadcast over samples, B scales over out-channels.
    A_SCALE = A_SCALE.reshape(A_SCALE.shape[0], 1, 1, 1)
    B_SCALE = B_SCALE.reshape(1, B_SCALE.shape[0], 1, 1)
    CONV = CONV * A_SCALE
    CONV = CONV * B_SCALE
    return CONV
def mmt_float(A, B):
    """Reference float convolution of A with weights B (stride 1, no padding)."""
    result = torch.nn.functional.conv2d(A, B)
    return result
# Compare the quantized convolution against the float reference on the same
# random inputs.
OUT_QUANT = mmt_quant(A, B)
OUT_FLOAT = mmt_float(A, B)
OUT_DIFF = OUT_QUANT - OUT_FLOAT
# Print mean |value| of the reference, the quantized result, and their
# difference — the last one is the average absolute quantization error.
print(torch.mean(torch.abs(OUT_FLOAT)).item())
print(torch.mean(torch.abs(OUT_QUANT)).item())
print(torch.mean(torch.abs(OUT_DIFF)).item())