utils.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


def cosine_anneal(step, start_value, final_value, start_step, final_step):
    # Cosine schedule that decays from start_value to final_value between
    # start_step and final_step; constant outside that interval.
    assert start_value >= final_value
    assert start_step <= final_step

    if step < start_step:
        value = start_value
    elif step >= final_step:
        value = final_value
    else:
        a = 0.5 * (start_value - final_value)  # amplitude
        b = 0.5 * (start_value + final_value)  # midpoint
        progress = (step - start_step) / (final_step - start_step)
        value = a * math.cos(math.pi * progress) + b

    return value
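

# Illustrative usage (a sketch, not part of the original file): cosine_anneal
# is typically called once per training step, e.g. to decay a Gumbel-softmax
# temperature. The endpoints and step counts below are hypothetical.
def _example_tau_schedule(global_step):
    # Decay tau from 1.0 to 0.1 over the first 30k steps, then hold at 0.1.
    return cosine_anneal(global_step, start_value=1.0, final_value=0.1,
                         start_step=0, final_step=30000)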


def linear_warmup(step, start_value, final_value, start_step, final_step):
    # Linear ramp from start_value up to final_value between start_step and
    # final_step; constant outside that interval.
    assert start_value <= final_value
    assert start_step <= final_step

    if step < start_step:
        value = start_value
    elif step >= final_step:
        value = final_value
    else:
        a = final_value - start_value
        b = start_value
        progress = (step + 1 - start_step) / (final_step - start_step)
        value = a * progress + b

    return value
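

# Illustrative usage (a sketch, not part of the original file): linear
# learning-rate warmup applied to an optimizer's parameter groups. The base
# learning rate and warmup length are hypothetical.
def _example_lr_warmup(optimizer, global_step, base_lr=1e-4):
    lr = linear_warmup(global_step, start_value=0.0, final_value=base_lr,
                       start_step=0, final_step=5000)
    for group in optimizer.param_groups:
        group['lr'] = lr
    return lr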


def gumbel_softmax(logits, tau=1., hard=False, dim=-1):
    # Sample from the Gumbel-softmax (Concrete) distribution. Gumbel(0, 1)
    # noise is generated as -log(Exp(1)); eps guards against log(0).
    eps = torch.finfo(logits.dtype).tiny

    gumbels = -(torch.empty_like(logits).exponential_() + eps).log()
    gumbels = (logits + gumbels) / tau

    y_soft = F.softmax(gumbels, dim)

    if hard:
        # Straight-through estimator: one-hot in the forward pass, softmax
        # gradients in the backward pass.
        index = y_soft.argmax(dim, keepdim=True)
        y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.)
        return y_hard - y_soft.detach() + y_soft
    else:
        return y_soft
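

# Illustrative usage (a sketch, not part of the original file): draw a
# discrete one-hot sample with hard=True while keeping the graph
# differentiable with respect to the logits.
def _example_gumbel_sample(logits, tau=1.0):
    one_hot = gumbel_softmax(logits, tau=tau, hard=True, dim=-1)
    # one_hot is exactly one-hot in the forward pass; its backward pass is
    # that of the underlying softmax, so logits still receive gradients.
    return one_hot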


def conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0,
           dilation=1, groups=1, bias=True, padding_mode='zeros',
           weight_init='xavier'):
    # nn.Conv2d with Xavier (default) or Kaiming weight init and zero bias.
    m = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding,
                  dilation, groups, bias, padding_mode)

    if weight_init == 'kaiming':
        nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
    else:
        nn.init.xavier_uniform_(m.weight)

    if bias:
        nn.init.zeros_(m.bias)

    return m


class Conv2dBlock(nn.Module):
    # Conv2d followed by ReLU; Kaiming init is used since the weights feed
    # into a ReLU nonlinearity.

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        self.m = conv2d(in_channels, out_channels, kernel_size, stride, padding,
                        bias=True, weight_init='kaiming')

    def forward(self, x):
        x = self.m(x)
        return F.relu(x)
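

# Illustrative usage (a sketch, not part of the original file): a small CNN
# encoder stacked from Conv2dBlock layers, ending with a bare conv2d so the
# last layer has no ReLU. The channel sizes are hypothetical.
def _example_encoder(in_channels=3, hidden=64):
    return nn.Sequential(
        Conv2dBlock(in_channels, hidden, kernel_size=5, stride=1, padding=2),
        Conv2dBlock(hidden, hidden, kernel_size=5, stride=1, padding=2),
        conv2d(hidden, hidden, kernel_size=5, stride=1, padding=2),
    )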


def linear(in_features, out_features, bias=True, weight_init='xavier', gain=1.):
    # nn.Linear with Xavier (default, scaled by gain) or Kaiming weight init
    # and zero bias.
    m = nn.Linear(in_features, out_features, bias)

    if weight_init == 'kaiming':
        nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
    else:
        nn.init.xavier_uniform_(m.weight, gain)

    if bias:
        nn.init.zeros_(m.bias)

    return m
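

# Illustrative usage (a sketch, not part of the original file): Xavier-
# initialized projection layers; a gain below 1 is sometimes used to keep a
# layer's output small at initialization. The sizes and gain are hypothetical.
def _example_projections(d_model=128):
    to_q = linear(d_model, d_model, bias=False)
    out_proj = linear(d_model, d_model, gain=0.5)
    return to_q, out_proj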


def gru_cell(input_size, hidden_size, bias=True):
    # nn.GRUCell with Xavier init on the input-to-hidden weights, orthogonal
    # init on the hidden-to-hidden weights, and zero biases.
    m = nn.GRUCell(input_size, hidden_size, bias)

    nn.init.xavier_uniform_(m.weight_ih)
    nn.init.orthogonal_(m.weight_hh)

    if bias:
        nn.init.zeros_(m.bias_ih)
        nn.init.zeros_(m.bias_hh)

    return m
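

# Illustrative usage (a sketch, not part of the original file): a recurrent
# state update with the initialized GRU cell. The shapes are hypothetical.
def _example_gru_update(d_state=64, batch=8):
    cell = gru_cell(d_state, d_state)
    state = torch.zeros(batch, d_state)     # previous hidden state
    inputs = torch.randn(batch, d_state)    # per-step input
    return cell(inputs, state)              # new hidden state, same shape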