-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathtrain_toy_cond.py
121 lines (90 loc) · 2.91 KB
/
train_toy_cond.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# -*- coding: utf-8 -*-
"""
CP-Flow on toy conditional distributions
"""
import gc
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torch
from lib.flows import SequentialFlow, DeepConvexFlow, ActNorm
from lib.icnn import PICNN as PICNN
from data.toy_data import OneDMixtureOfGaussians as ToyData
from lib.utils import makedirs
# Directory that receives every figure written by this script.
_FIGURE_DIR = 'figures/toy/cond_MoG/'

# Presumably creates the directory when it is missing -- see lib.utils.makedirs.
makedirs(_FIGURE_DIR)


def savefig(fn):
    """Save the current matplotlib figure under ``figures/toy/cond_MoG/``.

    Args:
        fn: File name (including extension) appended to the figure directory.
    """
    target = f'{_FIGURE_DIR}{fn}'
    plt.savefig(target)
# Make float64 the default floating-point dtype for tensors and parameters
# created from here on (the training loop also casts its batches to double).
torch.set_default_dtype(torch.float64)

# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------
batch_size_train = 128
batch_size_test = 64

# noinspection PyUnresolvedReferences
train_loader = torch.utils.data.DataLoader(
    ToyData(50000),
    batch_size=batch_size_train, shuffle=True)
# noinspection PyUnresolvedReferences
test_loader = torch.utils.data.DataLoader(
    ToyData(10000),
    batch_size=batch_size_test, shuffle=True)

# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
dimx = 1  # passed to ActNorm / PICNN / DeepConvexFlow as the data dimension
dimc = 1  # passed to PICNN; presumably the context dimension -- confirm in lib.icnn
nblocks = 1  # number of DeepConvexFlow blocks (each sandwiched between ActNorm layers)
depth = 10  # forwarded to PICNN
k = 64  # forwarded to PICNN; presumably the hidden width -- confirm in lib.icnn
lr = 0.001
factor = 0.5  # NOTE(review): not referenced anywhere in this script as shown
patience = 2000  # NOTE(review): not referenced anywhere in this script as shown
num_epochs = 10
print_every = 100

# ---------------------------------------------------------------------------
# Model: ActNorm, (DeepConvexFlow, ActNorm) * nblocks
# ---------------------------------------------------------------------------
icnns = [
    PICNN(dimx, k, dimc, depth, symm_act_first=True, softplus_type='gaussian_softplus',
          zero_softplus=True)
    for _ in range(nblocks)
]
layers = [None] * (2 * nblocks + 1)
layers[0::2] = [ActNorm(dimx) for _ in range(nblocks + 1)]  # even slots
layers[1::2] = [DeepConvexFlow(icnn, dimx, unbiased=False) for icnn in icnns]  # odd slots
flow = SequentialFlow(layers)

# Move the model to its final device *before* the optimizer is constructed, so
# the optimizer is built against parameters that already live where they will
# be trained.
cuda = torch.cuda.is_available()
if cuda:
    flow = flow.cuda()

# ---------------------------------------------------------------------------
# Optimisation
# ---------------------------------------------------------------------------
optim = torch.optim.Adam(flow.parameters(), lr=lr)
# The scheduler is stepped once per batch, so the cosine period is the total
# number of optimisation steps; the final learning rate (eta_min) is 0.
sch = torch.optim.lr_scheduler.CosineAnnealingLR(optim, num_epochs * len(train_loader), 0)

# Running state consumed by the training loop.
loss_acc = 0  # sum of losses since the last report
t = 0  # global step counter
grad_norm = 0  # gradient norm returned by the most recent clipping call
# Maximum-likelihood training: minimise the mean negative log-density of x
# given the context y, stepping the optimizer and LR scheduler once per batch.
for e in range(num_epochs):
    for batch in train_loader:
        # First column is the modelled variable, remaining columns the context.
        x, y = batch[:, :1], batch[:, 1:]
        x = x.double()
        y = y.double()
        if cuda:
            # Both the input and its context must sit on the model's device;
            # moving only x leaves flow.logp with tensors on mixed devices.
            x = x.cuda()
            y = y.cuda()

        loss = - flow.logp(x, y).mean()
        optim.zero_grad()
        loss.backward()
        # Clip to a total gradient norm of 10; the returned value is the norm
        # measured before clipping and is kept only for logging.
        grad_norm = torch.nn.utils.clip_grad.clip_grad_norm_(flow.parameters(), max_norm=10).item()

        optim.step()
        sch.step()

        loss_acc += loss.item()
        # Drop the loss tensor (and its graph) eagerly.
        # NOTE(review): the explicit gc / autocast-cache clearing looks like a
        # memory-growth workaround -- confirm it is still required, as running
        # a full collection every step is not free.
        del loss
        gc.collect()
        torch.clear_autocast_cache()

        t += 1
        if t == 1:
            # After the first step loss_acc holds exactly the initial loss.
            print('init loss:', loss_acc, grad_norm)
        if t % print_every == 0:
            # Report the average loss over the last print_every steps, then
            # reset the accumulator.
            print(t, loss_acc / print_every, grad_norm)
            loss_acc = 0
# Evaluation: plot the learned conditional density (dashed) against the
# data set's own log-density (solid) for nine context values in [0.1, 0.9].
flow.eval()

colors = sns.color_palette("coolwarm", 9)
fig = plt.figure()
ax = fig.add_subplot()

# Flip the flag on every DeepConvexFlow layer (the odd slots of the flow).
# NOTE(review): presumably switches to the exact/brute-force log-determinant
# for evaluation -- confirm against lib.flows.
for f in flow.flows[1::2]:
    f.no_bruteforce = False

# Evaluation grid as a column vector, kept on the CPU for the data-set density
# and for plotting.
xx = torch.linspace(-5, 5, 1000).unsqueeze(1)
# The model may have been moved to the GPU for training; evaluate it on a copy
# of the grid that lives on the same device as its parameters.
model_device = next(flow.parameters()).device
xx_dev = xx.to(model_device)
xx_np = xx.numpy()

for pi, c in zip(np.linspace(0.1, 0.9, 9), colors):
    # .detach().cpu() is required before .numpy(): the result may carry
    # autograd history and may live on the GPU.
    p = torch.exp(flow.logp(xx_dev, context=torch.ones_like(xx_dev) * pi)).detach().cpu().numpy()
    plt.plot(xx_np, p, '--', c=c)

for pi, c in zip(np.linspace(0.1, 0.9, 9), colors):
    p = torch.exp(train_loader.dataset.logp(xx, torch.ones_like(xx) * pi)).detach().cpu().numpy()
    plt.plot(xx_np, p, '-', c=c, label='{:.1f}'.format(pi))

plt.legend(loc=2, fontsize=12)
ax.tick_params(axis='both', which='major', labelsize=12)
savefig('1dMOG.png')