-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
112 lines (86 loc) · 2.79 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import torch
import torch.nn
import numpy as np
def tensor_f32(x):
    """Return a new float32 tensor built from *x*.

    A non-empty list whose first element is a tensor is stacked along a new
    leading dimension. An existing tensor is copied (detached from any
    autograd graph) instead of being passed to ``torch.tensor``, which warns
    when copy-constructing from a tensor. Anything else (scalars, nested
    lists, NumPy arrays) goes through ``torch.tensor``.
    """
    if isinstance(x, list) and len(x) > 0 and isinstance(x[0], torch.Tensor):
        # Only the first element is inspected; torch.stack raises if the
        # remaining items are not tensors of the same shape.
        return torch.stack(x).float()
    if isinstance(x, torch.Tensor):
        # Same result as torch.tensor(x, dtype=torch.float32) -- a detached
        # copy on x's device -- but without the copy-construct UserWarning.
        return x.detach().to(dtype=torch.float32, copy=True)
    return torch.tensor(x, dtype=torch.float32)
def tensor_int(x):
    """Return a new int64 (``torch.long``) tensor built from *x*.

    An existing tensor is copied (detached from any autograd graph) instead
    of being passed to ``torch.tensor``, which warns when copy-constructing
    from a tensor. Anything else goes through ``torch.tensor``.
    """
    if isinstance(x, torch.Tensor):
        # Same result as torch.tensor(x, dtype=torch.long) -- a detached copy
        # on x's device -- but without the copy-construct UserWarning.
        return x.detach().to(dtype=torch.long, copy=True)
    return torch.tensor(x, dtype=torch.long)
def reset_parameters(model):
    """Re-initialise the Conv2d and Linear layers of *model* in place.

    ``Module.apply`` invokes the callback on every submodule as well as on
    *model* itself; the callback decides which modules get reset. Returns
    ``None``.
    """
    reinit = _reset_parameters
    model.apply(reinit)
def _reset_parameters(layer):
    """Call ``layer.reset_parameters()`` if *layer* is a Conv2d or Linear.

    Any other module is left untouched. Intended as a per-module callback
    for ``Module.apply``.
    """
    resettable = (torch.nn.Conv2d, torch.nn.Linear)
    if not isinstance(layer, resettable):
        return
    layer.reset_parameters()
def tostr(x):
    """Render *x* as a compact string for tabular console output.

    - floats (including float subclasses) use four decimal places;
    - lists and tuples are rendered element-wise and comma-joined;
    - dicts become ``{key:value,...}`` with keys and values rendered recursively;
    - NumPy arrays and torch tensors are summarised by their shape, e.g. ``<2,3>``;
    - everything else falls back to ``str``.
    """
    if isinstance(x, float):
        return '{:.4f}'.format(x)
    if isinstance(x, (list, tuple)):
        return ','.join(tostr(item) for item in x)
    if isinstance(x, dict):
        entries = [f'{tostr(key)}:{tostr(val)}' for key, val in x.items()]
        return '{%s}' % ','.join(entries)
    if isinstance(x, (np.ndarray, torch.Tensor)):
        # .shape is a tuple (torch.Size subclasses tuple), so it goes
        # through the sequence branch above.
        return f'<{tostr(x.shape)}>'
    return str(x)
def print_row(*args):
    """Print *args* on one line as right-aligned, 12-character-wide columns.

    Each argument is rendered with ``tostr``; columns are separated by a
    single space.
    """
    cells = [tostr(arg).rjust(12) for arg in args]
    print(' '.join(cells))
def count_parameters(model):
    """Return the total number of scalar elements in ``model.parameters()``.

    Uses ``Tensor.numel()`` so the result is always a Python ``int``.
    ``np.prod(p.shape)`` returns a NumPy integer, and ``1.0`` (a float) for
    0-d parameters, whose shape is the empty tuple.
    """
    return sum(p.numel() for p in model.parameters())
def summarize_parameters(model):
    """Print a table of each named parameter's shape, mean and standard deviation.

    The standard deviation is the square root of the tensor's ``var()``
    (torch's default variance estimator).
    """
    print_row('parameter', 'shape', 'mean', 'sd')
    for name, param in model.named_parameters():
        print_row(
            name,
            tuple(param.shape),
            param.mean().item(),
            np.sqrt(param.var().item()),
        )
def summarize_gradient(model):
    """Print a table of the mean and standard deviation of each parameter's gradient.

    Parameters whose ``.grad`` is ``None`` (e.g. before any backward pass)
    are still listed, with ``None`` in both statistic columns. The standard
    deviation is the square root of the gradient's ``var()``, matching
    ``summarize_parameters``.
    """
    print_row('gradient', 'mean', 'sd')
    for key, value in model.named_parameters():
        grad = value.grad
        # Compare with None explicitly: truth-testing a tensor with more than
        # one element raises RuntimeError.
        if grad is None:
            print_row(key, None, None)
            continue
        # Statistics are taken over the gradient, not the parameter values.
        mean = grad.mean().item()
        sd = np.sqrt(grad.var().item())
        print_row(key, mean, sd)
def _as_image_batch(t):
    """Reshape *t* to (N, C, H, W); also return whether it is single-channel data eligible for the red/blue map."""
    rank = t.dim()
    if rank not in (2, 3, 4):
        raise ValueError(f'expected a tensor of rank 2, 3 or 4, got rank {rank}')
    if rank == 4:
        return t, False
    n = t.shape[0]
    if rank == 3:
        return t.reshape(n, 1, t.shape[1], t.shape[2]), True
    width = t.shape[1]
    if width > 8:
        # Lay each row of `width` values out as a near-square image. Sizes
        # must be Python ints for reshape, and non-square widths are
        # zero-padded so that rows * cols covers every value.
        rows = int(np.ceil(np.sqrt(width)))
        cols = -(-width // rows)  # ceiling division
        pad = rows * cols - width
        if pad:
            t = torch.cat([t, t.new_zeros(n, pad)], dim=1)
    else:
        # Short rows are shown as a one-pixel-wide column.
        rows, cols = width, 1
    return t.reshape(n, 1, rows, cols), True


def _red_blue_colormap(t):
    """Map single-channel images (N, 1, H, W) to RGB: positive -> red, negative -> blue, zero -> white."""
    scale = max(-t.min(), t.max())
    scaled = t / scale
    # For positive values, shift green and blue down.
    # For negative values, shift red and green down.
    r = 1 + torch.clamp(scaled, -1, 0)
    g = 1 - abs(scaled)
    b = 1 - torch.clamp(scaled, 0, 1)
    return torch.cat([r, g, b], dim=1)


def batched_to_flat_image(t):
    """Tile a batch of tensors into a single image grid with torchvision's make_grid.

    Args:
        t: a tensor, or a list of tensors concatenated along dim 0.
            Rank 2 (N, D): each row becomes a near-square single-channel
            image (zero-padded when D is not a perfect square), or a
            D x 1 column when D <= 8.
            Rank 3 (N, H, W): gets a channel axis.
            Rank 4 (N, C, H, W): used as-is.

    Returns:
        The grid produced by ``make_grid`` with 10 images per row and
        1-pixel padding. Rank-2/3 data that contains both negative and
        positive values is coloured with a red (positive) / blue (negative) /
        white (zero) map. Other rank-2/3 data is passed through unnormalised.
        Rank-4 data is min-max normalised by ``make_grid``.

    Raises:
        ValueError: if the (concatenated) tensor is not rank 2, 3 or 4.
    """
    if isinstance(t, list):
        t = torch.cat(t)
    images, red_blue = _as_image_batch(t)
    if red_blue and images.min() < 0 < images.max():
        images = _red_blue_colormap(images)
    import torchvision
    return torchvision.utils.make_grid(images, 10, normalize=(not red_blue), padding=1)
def print_image(t):
    """Display *t* as an image grid in a matplotlib window.

    *t* is anything ``batched_to_flat_image`` accepts. The displayed image
    is inverted (``1 - rgb``).
    """
    from matplotlib import pyplot
    grid = batched_to_flat_image(t)
    # make_grid returns a (C, H, W) torch tensor, but matplotlib expects an
    # (H, W, C) array. Detach and move to CPU before converting, so tensors
    # that require grad or live on an accelerator can still be shown.
    rgb = grid.detach().cpu().permute(1, 2, 0).numpy()
    pyplot.imshow(1 - rgb)
    pyplot.show()