# embed_res18_cifar100.py
# Embeds a QR-code pattern into the pruning mask of a ResNet-18 / CIFAR-100 checkpoint.
import torch
import numpy as np
from models.resnets import resnet20
# Load the pruned checkpoint on CPU; its "state_dict" entry is read further down.
# (The former bare `a.keys()` was notebook residue: its result was discarded.)
a = torch.load("ownership/res18_cifar100_extreme.pth.tar", map_location="cpu")
def check_sparsity(mask, conv1=True):
    """Print and return the fraction of non-zero entries across all mask tensors.

    Only entries whose key contains ``'mask'`` are counted. Despite the name,
    the value is the *density* (``1 - zeros / total``), i.e. the share of
    weights that survive pruning.

    Args:
        mask: mapping from parameter name to tensor.
        conv1: unused; kept so existing callers keep working.

    Returns:
        The density as a float, or ``None`` when no mask entries are present.
    """
    total = 0.0
    zeros = 0.0
    for name, tensor in mask.items():
        if 'mask' not in name:
            continue
        total += float(tensor.nelement())
        zeros += float(torch.sum(tensor == 0))
    if total == 0:
        # Nothing to measure; the plain ratio would raise ZeroDivisionError.
        print("check_sparsity: no mask entries found")
        return None
    density = 1 - zeros / total
    print(density)
    return density
# Seed NumPy's global RNG so the np.random.permutation draws later in the
# script are reproducible from run to run.
_NUMPY_SEED = 2
np.random.seed(_NUMPY_SEED)
def extract_mask(model_dict):
    """Return a new dict holding only the entries whose key contains ``'mask'``.

    The key order of ``model_dict`` is preserved and the values are the same
    objects (no copies are made).
    """
    return {name: value for name, value in model_dict.items() if 'mask' in name}
# Keep only the pruning masks from the checkpoint and report their density
# before anything is modified.
_state_dict = a['state_dict']
mask = extract_mask(_state_dict)
check_sparsity(mask)
import qrcode

# Render the payload 'signature' as a QR code: requested version 3, highest
# error-correction level, one pixel per module (box_size=1) and no quiet-zone
# border, so each pixel of `code` is one QR module.
_qr_options = dict(
    version=3,
    error_correction=qrcode.constants.ERROR_CORRECT_H,
    box_size=1,
    border=0,
)
qr = qrcode.QRCode(**_qr_options)
qr.add_data('signature')
qr.make()
img = qr.make_image(fill_color="black", back_color="white")
# NOTE(review): the pixel polarity of np.array(img) (which colour maps to 1)
# is not established here — confirm before relying on code == 1 meaning "dark".
code = np.array(img)
from scipy.signal import correlate2d  # NOTE: currently unused

h, w = code.shape[0], code.shape[1]


def _window_similarity(kernel_map, pattern):
    """Return agreement scores for every placement of `pattern` in `kernel_map`.

    Entry [i, j] is the fraction of cells where
    kernel_map[i:i+ph, j:j+pw] equals `pattern`, with (ph, pw) = pattern.shape.
    Returns None when `pattern` does not fit inside `kernel_map`.
    """
    ph, pw = pattern.shape[0], pattern.shape[1]
    rows = kernel_map.shape[0] - ph + 1
    cols = kernel_map.shape[1] - pw + 1
    if rows <= 0 or cols <= 0:
        return None
    scores = np.zeros((rows, cols))
    for i in range(rows):
        for j in range(cols):
            scores[i, j] = (kernel_map[i:i + ph, j:j + pw] == pattern).mean()
    return scores


def _kernel_map(name):
    """Collapse mask `name` over its last two axes: 1.0 where the kernel sums to > 0."""
    return (mask[name].sum((2, 3)).numpy() > 0).astype(float)


# Search the layer1 masks for the placement that already agrees best with the code.
max_name = None
max_sim = 0
for name in mask:
    if 'layer1' not in name:
        continue
    sim = _window_similarity(_kernel_map(name), code)
    if sim is None:
        continue  # this mask is smaller than the code
    if np.max(sim) > max_sim:
        max_name = name
        max_sim = np.max(sim)
print(max_name)
print(max_sim)

import sys
if len(sys.argv) > 1:
    # Optional CLI override: embed into the mask named by the first argument.
    max_name = sys.argv[1]
print(mask.keys())
print(max_name)
if max_name is None:
    raise SystemExit("no layer1 mask can hold the QR code; pass a mask name as the first argument")
if max_name not in mask:
    raise SystemExit(f"unknown mask name: {max_name!r}")

# Find the best offset (r, c) of the code inside the chosen mask.
mask_ = _kernel_map(max_name)
sim = _window_similarity(mask_, code)
if sim is None:
    raise SystemExit(f"mask {max_name!r} with kernel grid {mask_.shape} is smaller than the {h}x{w} code")
r, c = np.where(sim == np.max(sim))
# Several placements may tie; take the first one in row-major order.
r = r[0]
c = c[0]
print(r, c)
# Work on a copy of the (h, w) window at (r, c); axes 0-1 index the kernel
# grid and the trailing axes the individual kernel weights.
real_mask = mask[max_name].numpy()[r:r+h, c:c+w].copy()
real_mask_one = (real_mask == 1).sum()
real_mask_flat = ((real_mask).sum((2, 3)) > 0).astype(float)
print(real_mask_flat.shape)

kernel_shape = real_mask.shape[2:]
kernel_size = int(np.prod(kernel_shape))
for i in range(code.shape[0]):
    for j in range(code.shape[1]):
        if code[i, j] == 1 and real_mask_flat[i, j] == 0:
            # Kernel must become "on": enable exactly one randomly placed weight.
            one_hot = np.zeros(kernel_size, dtype=int)
            one_hot[0] = 1
            real_mask[i, j] = np.random.permutation(one_hot).reshape(kernel_shape)
            # Assignment (was a no-op `==` comparison): keep the flat map in sync.
            real_mask_flat[i, j] = 1
        elif code[i, j] == 0 and real_mask_flat[i, j] == 1:
            # Kernel must become "off": clear all of its weights.
            real_mask[i, j] = 0
            real_mask_flat[i, j] = 0
# Copy of the same window from the unmodified mask; the regions below are
# restored from it so they keep their original pruning pattern.
# NOTE(review): the offsets appear to hard-code the geometry of a version-3
# QR symbol (4 * 3 + 9 uses version 3); they only make sense while `code`
# is a version-3 code with border=0 — confirm if the QR settings change.
original_mask = mask[max_name][r:r+h, c:c+w].clone().numpy()
# 9x9 top-left corner (presumably the finder pattern plus separator).
real_mask[0:9, 0:9] = original_mask[0:9, 0:9]
# 9-row x 9-column bottom-left corner (presumably the bottom-left finder pattern).
real_mask[-9:,:9] = original_mask[-9:,:9]
# 9-row x 9-column top-right corner (presumably the top-right finder pattern).
real_mask[:9,-9:] = original_mask[:9,-9:]
# 5x5 block at rows/cols 20-24 (presumably the version-3 alignment pattern).
real_mask[20:25, 20:25] = original_mask[20:25, 20:25]
# Switch on every weight of one kernel.
# NOTE(review): the QR dark module sits at row 4*V+9, column 8. Row -8 equals
# 4*V+9 for a 4*V+17-wide window, but the column used here is 4*3+9 = 21,
# not 8, and (21, 21) lies inside the block restored just above — confirm.
real_mask[-8, 4 * 3 + 9] = 1
# Row 6 and column 6 (presumably the QR timing patterns).
real_mask[6] = original_mask[6]
real_mask[:, 6] = original_mask[:, 6]
# Bring the number of active weights in the window back to its original
# count, so embedding the code does not change the layer's sparsity.
real_mask_one_new = (real_mask == 1).sum()
real_mask_flat_new = (real_mask).sum((2, 3))
diff = real_mask_one_new - real_mask_one
print(diff)
if diff > 0:
    # Too many active weights: switch off `diff` of them at random. Only
    # kernels that keep at least one active weight are touched, so the
    # kernel-level on/off pattern (the embedded code) stays unchanged.
    # (Previously this branch computed candidates but removed nothing.)
    active_per_kernel = (real_mask == 1).sum((2, 3))
    ones = np.stack(np.where(real_mask == 1))
    removed = 0
    for k in np.random.permutation(ones.shape[1]):
        if removed >= diff:
            break
        p = ones[:, k]
        if active_per_kernel[p[0], p[1]] > 1:
            real_mask[p[0], p[1], p[2], p[3]] = 0
            active_per_kernel[p[0], p[1]] -= 1
            removed += 1
    if removed < diff:
        print(f"warning: could only remove {removed} of {diff} connections")
else:
    # Too few active weights: switch on `-diff` currently-off weights, chosen
    # at random among kernels marked 1 in both `code` and `real_mask_flat`.
    pos = (np.expand_dims((code == 1), (2, 3))
           * np.expand_dims(real_mask_flat == 1, (2, 3))
           * (real_mask == 0))
    pos = np.stack(np.where(pos))
    print(pos.shape)
    if pos.shape[1] < -diff:
        print(f"warning: only {pos.shape[1]} of {-diff} connections can be recovered")
    pos = pos[:, np.random.permutation(pos.shape[1])[:(-diff)]]
    print(pos.shape)
    for i in range(pos.shape[1]):
        p = pos[:, i]
        real_mask[p[0], p[1], p[2], p[3]] = 1
import matplotlib.pyplot as plt

# Write the edited window back into the full mask tensor of the chosen layer.
mask[max_name][r:r+h, c:c+w] = torch.from_numpy(real_mask)

# Plot which kernels of the edited layer are on (weights sum to > 0).
vis_path = f"ownership/res18_cifar100_vis_{max_name}.pdf"
vis = mask[max_name].sum((2,3)).numpy() > 0
plt.imshow(vis)
plt.savefig(vis_path)
plt.close()

# Persist the (masks-only) dict and report the resulting density.
out_path = f'ownership/res18_cifar100_qrcode_{max_name}.pth.tar'
torch.save(mask, out_path)
check_sparsity(mask)
# Disabled snippet kept as a bare string literal; it is never executed.
# NOTE(review): if re-enabled, its last line would overwrite the input
# checkpoint loaded at the top of this script with the masks-only dict,
# discarding the rest of that checkpoint.
'''
vis = mask[max_name].sum((2,3)).numpy() > 0
plt.imshow(vis)
plt.savefig("vis2.png")
plt.close()
torch.save(mask, 'ownership/res18_cifar100_extreme.pth.tar')
'''