-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathdiffusion_multinomial.py
411 lines (294 loc) · 13.4 KB
/
diffusion_multinomial.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
import torch
import torch.nn.functional as F
import numpy as np
from inspect import isfunction
"""
Based in part on: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/5989f4c77eafcdc6be0fb4739f0f277a6dd7f7d8/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L281
"""
# Small positive constant; no code visible in this file references it.
eps = 1e-8
def sum_except_batch(x, num_dims=1):
    """
    Collapse every non-batch dimension of ``x`` by summation.

    Args:
        x: Tensor, shape (batch_size, ...)
        num_dims: int, how many leading dimensions are kept (default=1)

    Returns:
        Tensor whose shape is the first ``num_dims`` dimensions of ``x``.
    """
    kept_shape = tuple(x.shape[:num_dims])
    # Flatten everything after the kept dimensions into one trailing axis.
    flattened = x.reshape(kept_shape + (-1,))
    return flattened.sum(dim=-1)
def log_1_min_a(a):
    """Return log(1 - exp(a)) element-wise for a tensor ``a`` of log-values.

    A constant of 1e-40 is added inside the logarithm before it is taken.
    """
    complement = 1 - torch.exp(a)
    return torch.log(complement + 1e-40)
def log_add_exp(a, b):
    """Return log(exp(a) + exp(b)) element-wise.

    The element-wise maximum is subtracted before exponentiating and added
    back afterwards, so the larger exponent is always exp(0).
    """
    peak = torch.max(a, b)
    shifted_total = (a - peak).exp() + (b - peak).exp()
    return peak + shifted_total.log()
def exists(x):
    """Return True for any value other than ``None``."""
    if x is None:
        return False
    return True
def extract(a, t, x_shape):
    """Gather ``a[t]`` and reshape it so it broadcasts against ``x_shape``.

    Args:
        a: tensor indexed along its last dimension.
        t: integer index tensor; its first dimension is the batch size.
        x_shape: shape of the tensor the result will be combined with.

    Returns:
        Tensor of shape (batch, 1, ..., 1) with ``len(x_shape)`` dimensions.
    """
    batch, *_rest = t.shape
    picked = torch.gather(a, -1, t)
    singleton_axes = [1] * (len(x_shape) - 1)
    return picked.reshape(batch, *singleton_axes)
def default(val, d):
    """Return ``val`` unless it is ``None``; otherwise fall back to ``d``.

    If ``d`` is a plain Python function it is called with no arguments and
    its result is returned; any other ``d`` is returned as-is.
    """
    if val is None:
        if isfunction(d):
            return d()
        return d
    return val
def log_categorical(log_x_start, log_prob):
    """Sum ``exp(log_x_start) * log_prob`` over dimension 1."""
    weights = torch.exp(log_x_start)
    return torch.sum(weights * log_prob, dim=1)
def index_to_log_onehot(x, num_classes):
    """Turn an integer class tensor into a log one-hot tensor.

    Args:
        x: integer tensor of class indices, shape (batch, ...).
        num_classes: number of classes; every index must be smaller.

    Returns:
        Float tensor of shape (batch, num_classes, ...) holding log(1) = 0 at
        the selected class and log(1e-30) everywhere else.
    """
    largest = x.max().item()
    assert largest < num_classes, \
        f'Error: {largest} >= {num_classes}'

    onehot = F.one_hot(x, num_classes)

    # one_hot puts the class axis last; move it right after the batch axis.
    axes = (0, -1) + tuple(range(1, x.dim()))
    onehot = onehot.permute(axes)

    # Clamp before the log so the zero entries do not become -inf.
    return onehot.float().clamp(min=1e-30).log()
def log_onehot_to_index(log_x):
    """Return the index of the largest entry along dimension 1."""
    return torch.argmax(log_x, dim=1)
def cosine_beta_schedule(timesteps, s = 0.008):
    """
    Cosine noise schedule,
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ

    Despite the name, the returned array holds per-step alphas (after a
    square root), not betas.

    Args:
        timesteps: number of diffusion steps; the result has this length.
        s: offset used inside the cosine.

    Returns:
        NumPy array of length ``timesteps``.
    """
    n_points = timesteps + 1
    grid = np.linspace(0, n_points, n_points)
    angle = ((grid / n_points) + s) / (1 + s) * np.pi * 0.5
    cumulative = np.cos(angle) ** 2
    cumulative = cumulative / cumulative[0]

    # Per-step ratio of consecutive cumulative products, kept within [0.001, 1].
    ratios = cumulative[1:] / cumulative[:-1]
    ratios = np.clip(ratios, a_min=0.001, a_max=1.)

    # Use sqrt of this, so the alpha in our paper is the alpha_sqrt from the
    # Gaussian diffusion in Ho et al.
    return np.sqrt(ratios)
class MultinomialDiffusion(torch.nn.Module):
    """Diffusion model over categorical variables, computed in log space.

    Data enters as integer class indices and is handled internally as log
    one-hot tensors with the class axis at dimension 1.

    Args:
        num_classes: number of categories.
        shape: tuple with the per-sample shape, excluding batch and class
            axes (it is concatenated onto ``(batch, num_classes)``).
        denoise_fn: callable invoked as ``denoise_fn(t, x_t)`` with integer
            class indices ``x_t``; it must return a tensor of shape
            ``(batch, num_classes, *x_t.shape[1:])``.
        timesteps: number of diffusion steps.
        loss_type: ``'vb_stochastic'`` (one sampled timestep per example) or
            ``'vb_all'`` (sum over every timestep).
        parametrization: ``'x0'`` (network output is treated as a prediction
            of x_0 and pushed through the posterior) or ``'direct'`` (network
            output is used as the reverse-step distribution directly).
    """

    def __init__(self, num_classes, shape, denoise_fn, timesteps=1000,
                 loss_type='vb_stochastic', parametrization='x0'):
        super(MultinomialDiffusion, self).__init__()
        assert loss_type in ('vb_stochastic', 'vb_all')
        assert parametrization in ('x0', 'direct')

        if loss_type == 'vb_all':
            print('Computing the loss using the bound on _all_ timesteps.'
                  ' This is expensive both in terms of memory and computation.')

        self.num_classes = num_classes
        self._denoise_fn = denoise_fn
        self.loss_type = loss_type
        self.shape = shape
        self.num_timesteps = timesteps
        self.parametrization = parametrization

        alphas = cosine_beta_schedule(timesteps)

        # Build the schedule in float64 with torch-native ops. The values are
        # torch tensors, so torch.log / torch.cumsum avoid relying on NumPy's
        # array-wrapping of foreign tensor types.
        alphas = torch.tensor(alphas.astype('float64'))
        log_alpha = torch.log(alphas)
        log_cumprod_alpha = torch.cumsum(log_alpha, dim=0)

        log_1_min_alpha = log_1_min_a(log_alpha)
        log_1_min_cumprod_alpha = log_1_min_a(log_cumprod_alpha)

        # Sanity checks: alpha + (1 - alpha) must be 1, i.e. 0 in log space.
        assert log_add_exp(log_alpha, log_1_min_alpha).abs().sum().item() < 1.e-5
        assert log_add_exp(log_cumprod_alpha, log_1_min_cumprod_alpha).abs().sum().item() < 1e-5
        assert (torch.cumsum(log_alpha, dim=0) - log_cumprod_alpha).abs().sum().item() < 1.e-5

        # Convert to float32 and register buffers.
        self.register_buffer('log_alpha', log_alpha.float())
        self.register_buffer('log_1_min_alpha', log_1_min_alpha.float())
        self.register_buffer('log_cumprod_alpha', log_cumprod_alpha.float())
        self.register_buffer('log_1_min_cumprod_alpha', log_1_min_cumprod_alpha.float())

        # Per-timestep running statistics used by importance time sampling.
        self.register_buffer('Lt_history', torch.zeros(timesteps))
        self.register_buffer('Lt_count', torch.zeros(timesteps))

    def multinomial_kl(self, log_prob1, log_prob2):
        """Return KL(p1 || p2) summed over the class axis (dimension 1)."""
        kl = (log_prob1.exp() * (log_prob1 - log_prob2)).sum(dim=1)
        return kl

    def q_pred_one_timestep(self, log_x_t, t):
        """Return the log-probabilities of one forward step applied at ``t``."""
        log_alpha_t = extract(self.log_alpha, t, log_x_t.shape)
        log_1_min_alpha_t = extract(self.log_1_min_alpha, t, log_x_t.shape)

        # alpha_t * E[xt] + (1 - alpha_t) 1 / K
        log_probs = log_add_exp(
            log_x_t + log_alpha_t,
            log_1_min_alpha_t - np.log(self.num_classes)
        )

        return log_probs

    def q_pred(self, log_x_start, t):
        """Return log q(x_t | x_0) using the cumulative schedule at ``t``."""
        log_cumprod_alpha_t = extract(self.log_cumprod_alpha, t, log_x_start.shape)
        log_1_min_cumprod_alpha = extract(self.log_1_min_cumprod_alpha, t, log_x_start.shape)

        log_probs = log_add_exp(
            log_x_start + log_cumprod_alpha_t,
            log_1_min_cumprod_alpha - np.log(self.num_classes)
        )

        return log_probs

    def predict_start(self, log_x_t, t):
        """Run the denoising network and return log-softmax class scores.

        The network receives class indices (argmax of ``log_x_t``), and its
        output shape is checked against ``(batch, num_classes, ...)``.
        """
        x_t = log_onehot_to_index(log_x_t)

        out = self._denoise_fn(t, x_t)

        assert out.size(0) == x_t.size(0)
        assert out.size(1) == self.num_classes
        assert out.size()[2:] == x_t.size()[1:]

        log_pred = F.log_softmax(out, dim=1)
        return log_pred

    def q_posterior(self, log_x_start, log_x_t, t):
        """Return the normalized log posterior over x_{t-1} given x_t and x_0."""
        # q(xt-1 | xt, x0) = q(xt | xt-1, x0) * q(xt-1 | x0) / q(xt | x0)
        # where q(xt | xt-1, x0) = q(xt | xt-1).

        t_minus_1 = t - 1
        # Remove negative values, will not be used anyway for final decoder
        t_minus_1 = torch.where(t_minus_1 < 0, torch.zeros_like(t_minus_1), t_minus_1)
        log_EV_qxtmin_x0 = self.q_pred(log_x_start, t_minus_1)

        # Where t == 0 there is no earlier step, so x_0 itself is used.
        num_axes = (1,) * (len(log_x_start.size()) - 1)
        t_broadcast = t.view(-1, *num_axes) * torch.ones_like(log_x_start)
        log_EV_qxtmin_x0 = torch.where(t_broadcast == 0, log_x_start, log_EV_qxtmin_x0)

        # Note: _NOT_ x_tmin1, which is how the formula is typically used!!!
        # Not very easy to see why this is true. But it is :)
        unnormed_logprobs = log_EV_qxtmin_x0 + self.q_pred_one_timestep(log_x_t, t)

        log_EV_xtmin_given_xt_given_xstart = \
            unnormed_logprobs \
            - torch.logsumexp(unnormed_logprobs, dim=1, keepdim=True)

        return log_EV_xtmin_given_xt_given_xstart

    def p_pred(self, log_x, t):
        """Return the model's log-distribution for the reverse step at ``t``.

        Raises:
            ValueError: if ``self.parametrization`` is not a known value.
        """
        if self.parametrization == 'x0':
            log_x_recon = self.predict_start(log_x, t=t)
            log_model_pred = self.q_posterior(
                log_x_start=log_x_recon, log_x_t=log_x, t=t)
        elif self.parametrization == 'direct':
            log_model_pred = self.predict_start(log_x, t=t)
        else:
            raise ValueError
        return log_model_pred

    @torch.no_grad()
    def p_sample(self, log_x, t):
        """Draw one reverse-step sample; returns a log one-hot tensor."""
        model_log_prob = self.p_pred(log_x=log_x, t=t)
        out = self.log_sample_categorical(model_log_prob)
        return out

    @torch.no_grad()
    def p_sample_loop(self, shape):
        """Iterate ``p_sample`` from t = T-1 down to t = 1, starting from noise.

        NOTE(review): this starts from Gaussian noise and never runs t = 0;
        it looks like a leftover from the Gaussian-diffusion code this file
        is based on. ``sample`` starts from uniform categorical logits and
        covers every timestep - confirm whether this method is still needed.
        """
        device = self.log_alpha.device

        b = shape[0]
        # start with random normal image.
        img = torch.randn(shape, device=device)

        for i in reversed(range(1, self.num_timesteps)):
            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long))
        return img

    @torch.no_grad()
    def _sample(self, image_size, batch_size = 16):
        """Call ``p_sample_loop`` with shape (batch_size, 3, image_size, image_size).

        NOTE(review): the hard-coded 3 ends up on the class axis that
        ``p_sample`` reads - presumably a leftover image-channel count; verify.
        """
        return self.p_sample_loop((batch_size, 3, image_size, image_size))

    @torch.no_grad()
    def interpolate(self, x1, x2, t = None, lam = 0.5):
        """Noise ``x1`` and ``x2`` to step ``t``, blend them, then denoise.

        ``t`` defaults to the last timestep; ``lam`` is the blend weight of
        ``x2``. NOTE(review): the inputs go straight into ``q_sample``, so
        they are presumably log one-hot tensors - confirm against callers.
        """
        b, *_, device = *x1.shape, x1.device
        t = default(t, self.num_timesteps - 1)

        assert x1.shape == x2.shape

        t_batched = torch.stack([torch.tensor(t, device=device)] * b)
        xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))

        img = (1 - lam) * xt1 + lam * xt2
        for i in reversed(range(0, t)):
            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long))

        return img

    def log_sample_categorical(self, logits):
        """Sample classes from ``logits`` with the Gumbel-max trick.

        Returns the sample as a log one-hot tensor.
        """
        uniform = torch.rand_like(logits)
        # Small constants keep both logarithms finite.
        gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30)
        sample = (gumbel_noise + logits).argmax(dim=1)
        log_sample = index_to_log_onehot(sample, self.num_classes)
        return log_sample

    def q_sample(self, log_x_start, t):
        """Sample x_t from q(x_t | x_0); returns a log one-hot tensor."""
        log_EV_qxt_x0 = self.q_pred(log_x_start, t)

        log_sample = self.log_sample_categorical(log_EV_qxt_x0)

        return log_sample

    def nll(self, log_x_start):
        """Sum ``compute_Lt`` over every timestep and add the prior KL.

        Args:
            log_x_start: log one-hot tensor of the data (not class indices).

        Returns:
            Per-example loss tensor of shape (batch,).
        """
        b = log_x_start.size(0)
        device = log_x_start.device
        loss = 0
        for t in range(0, self.num_timesteps):
            t_array = (torch.ones(b, device=device) * t).long()

            kl = self.compute_Lt(
                log_x_start=log_x_start,
                log_x_t=self.q_sample(log_x_start=log_x_start, t=t_array),
                t=t_array)

            loss += kl

        loss += self.kl_prior(log_x_start)

        return loss

    def kl_prior(self, log_x_start):
        """Return KL between q(x_T | x_0) and the uniform distribution, per example."""
        b = log_x_start.size(0)
        device = log_x_start.device
        ones = torch.ones(b, device=device).long()

        log_qxT_prob = self.q_pred(log_x_start, t=(self.num_timesteps - 1) * ones)
        # log(1 / num_classes) for every entry, i.e. the uniform distribution.
        log_half_prob = -torch.log(self.num_classes * torch.ones_like(log_qxT_prob))

        kl_prior = self.multinomial_kl(log_qxT_prob, log_half_prob)
        return sum_except_batch(kl_prior)

    def compute_Lt(self, log_x_start, log_x_t, t, detach_mean=False):
        """Return the per-example loss term for timestep ``t``.

        For t > 0 this is the KL between the true posterior and the model
        prediction; for t == 0 it is the decoder negative log-likelihood.
        With ``detach_mean=True`` the model prediction is detached first.
        """
        log_true_prob = self.q_posterior(
            log_x_start=log_x_start, log_x_t=log_x_t, t=t)

        log_model_prob = self.p_pred(log_x=log_x_t, t=t)

        if detach_mean:
            log_model_prob = log_model_prob.detach()

        kl = self.multinomial_kl(log_true_prob, log_model_prob)
        kl = sum_except_batch(kl)

        decoder_nll = -log_categorical(log_x_start, log_model_prob)
        decoder_nll = sum_except_batch(decoder_nll)

        # Select the decoder term where t == 0 and the KL term elsewhere.
        mask = (t == torch.zeros_like(t)).float()
        loss = mask * decoder_nll + (1. - mask) * kl

        return loss

    def sample_time(self, b, device, method='uniform'):
        """Sample ``b`` timesteps and return them with their probabilities.

        Args:
            b: number of timesteps to draw.
            device: device for the uniformly sampled timesteps.
            method: ``'uniform'`` or ``'importance'``. Importance sampling
                falls back to uniform until every timestep has been counted
                more than 10 times in ``Lt_count``.

        Returns:
            Tuple ``(t, pt)``: sampled timesteps and the probability with
            which each one was drawn.

        Raises:
            ValueError: if ``method`` is not a known value.
        """
        if method == 'importance':
            if not (self.Lt_count > 10).all():
                return self.sample_time(b, device, method='uniform')

            Lt_sqrt = torch.sqrt(self.Lt_history + 1e-10) + 0.0001
            Lt_sqrt[0] = Lt_sqrt[1]  # Overwrite decoder term with L1.
            pt_all = Lt_sqrt / Lt_sqrt.sum()

            t = torch.multinomial(pt_all, num_samples=b, replacement=True)

            pt = pt_all.gather(dim=0, index=t)

            return t, pt

        elif method == 'uniform':
            t = torch.randint(0, self.num_timesteps, (b,), device=device).long()

            pt = torch.ones_like(t).float() / self.num_timesteps
            return t, pt
        else:
            raise ValueError

    def _train_loss(self, x):
        """Return the negated training loss for a batch of class indices.

        For ``'vb_stochastic'`` one timestep per example is importance
        sampled, the running statistics in ``Lt_history`` / ``Lt_count`` are
        updated, and the timestep term is reweighted by ``1 / pt``.
        For ``'vb_all'`` the loss over all timesteps from ``nll`` is used.

        Raises:
            ValueError: if ``self.loss_type`` is not a known value.
        """
        b, device = x.size(0), x.device

        if self.loss_type == 'vb_stochastic':
            x_start = x

            t, pt = self.sample_time(b, device, 'importance')

            log_x_start = index_to_log_onehot(x_start, self.num_classes)

            kl = self.compute_Lt(
                log_x_start, self.q_sample(log_x_start=log_x_start, t=t), t)

            # Exponential moving average of the squared loss per timestep.
            Lt2 = kl.pow(2)
            Lt2_prev = self.Lt_history.gather(dim=0, index=t)
            new_Lt_history = (0.1 * Lt2 + 0.9 * Lt2_prev).detach()
            self.Lt_history.scatter_(dim=0, index=t, src=new_Lt_history)
            self.Lt_count.scatter_add_(dim=0, index=t, src=torch.ones_like(Lt2))

            kl_prior = self.kl_prior(log_x_start)

            # Upweigh loss term of the kl
            vb_loss = kl / pt + kl_prior

            return -vb_loss

        elif self.loss_type == 'vb_all':
            # Expensive, dont do it ;).
            # nll works on log one-hot tensors, so convert the indices first.
            log_x_start = index_to_log_onehot(x, self.num_classes)
            return -self.nll(log_x_start)
        else:
            raise ValueError()

    def log_prob(self, x):
        """Return a stochastic estimate of the negated loss for class indices ``x``.

        In training mode this delegates to ``_train_loss``. Otherwise one
        timestep per example is sampled and no running statistics are updated.
        """
        b, device = x.size(0), x.device
        if self.training:
            return self._train_loss(x)

        else:
            log_x_start = index_to_log_onehot(x, self.num_classes)

            t, pt = self.sample_time(b, device, 'importance')

            kl = self.compute_Lt(
                log_x_start, self.q_sample(log_x_start=log_x_start, t=t), t)

            kl_prior = self.kl_prior(log_x_start)

            # Upweigh loss term of the kl
            loss = kl / pt + kl_prior

            return -loss

    def sample(self, num_samples):
        """Generate ``num_samples`` samples and return them as class indices.

        Starts from a draw of uniform categorical logits and applies
        ``p_sample`` for t = T-1 ... 0, printing progress to stdout.
        """
        b = num_samples
        device = self.log_alpha.device
        uniform_logits = torch.zeros((b, self.num_classes) + self.shape, device=device)
        log_z = self.log_sample_categorical(uniform_logits)
        for i in reversed(range(0, self.num_timesteps)):
            print(f'Sample timestep {i:4d}', end='\r')
            t = torch.full((b,), i, device=device, dtype=torch.long)
            log_z = self.p_sample(log_z, t)
        print()
        return log_onehot_to_index(log_z)

    def sample_chain(self, num_samples):
        """Like ``sample``, but return the class indices after every step.

        Returns:
            Long tensor of shape ``(num_timesteps, num_samples) + self.shape``;
            entry ``i`` holds the state produced by the step at timestep ``i``.
        """
        b = num_samples
        device = self.log_alpha.device
        uniform_logits = torch.zeros(
            (b, self.num_classes) + self.shape, device=device)

        zs = torch.zeros((self.num_timesteps, b) + self.shape).long()

        log_z = self.log_sample_categorical(uniform_logits)
        for i in reversed(range(0, self.num_timesteps)):
            print(f'Chain timestep {i:4d}', end='\r')
            t = torch.full((b,), i, device=device, dtype=torch.long)
            log_z = self.p_sample(log_z, t)

            zs[i] = log_onehot_to_index(log_z)
        print()
        return zs