-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdiffusion_model.py
384 lines (269 loc) · 12.3 KB
/
diffusion_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from diffusers.models.embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
from diffusers.models.modeling_flax_utils import FlaxModelMixin
from transformers import AutoConfig
from transformers import FlaxBertModel
from transformers.models.bert.modeling_flax_bert import FlaxBertEncoder
import torch
import transformer
import model_utils as u
class DiffusionLM(nn.Module):
    """Diffusion model over token embeddings with an x_start-predicting transformer.

    Token ids are embedded, noised with a fixed beta schedule, and a
    transformer is trained to predict the clean embeddings. A prediction
    head that shares the embedding matrix maps embeddings back to vocabulary
    logits.

    Attributes:
        timesteps: Number of diffusion steps in the noise schedule.
        latent_dim: Embedding width; also passed as the LM head's hidden size.
        batch_size: Not read anywhere inside this class.
        seq_len: Sequence length forwarded to the transformer.
        vocab_size: Size of the embedding table and of the LM head output.
        use_pretrained: Forwarded to ``transformer.Flax1DTransformer``.
        train: Forwarded to ``transformer.Flax1DTransformer``.
        vocab: Not read anywhere inside this class.
        vocab_r: Mapping from token id to token string; only used by
            ``p_sample`` to print decoded sentences at ``t == 1``.
    """

    timesteps: int = 2000
    latent_dim: int = 32
    batch_size: int = 16
    seq_len: int = 64
    vocab_size: int = 333
    use_pretrained: bool = True
    train: bool = True
    vocab: dict = None
    vocab_r: dict = None

    def setup(self):
        """Create the submodules and precompute the diffusion schedule tables."""
        self.embedder = nn.Embed(self.vocab_size, self.latent_dim)
        self.transformer = transformer.Flax1DTransformer(
            latent_dim=self.latent_dim,
            seq_len=self.seq_len,
            vocab_size=self.vocab_size,
            use_pretrained=self.use_pretrained,
            train=self.train,
        )
        self.lm_head = transformer.FlaxBertLMPredictionHead(
            hidden_size=self.latent_dim, vocab_size=self.vocab_size
        )
        self.get_alphas()

    def call_embedder(self, inp):
        """Return the embeddings of the token ids ``inp``."""
        return self.embedder(inp)

    def call_transformer(self, x, t):
        """Run the transformer on ``x`` at timesteps ``t``."""
        return self.transformer(x, t)

    def get_logits(self, x):
        """Map embeddings ``x`` to vocabulary logits.

        The LM head receives the transposed embedding matrix of
        ``self.embedder`` as ``shared_embedding``.
        """
        shared_embedding = self.embedder.variables['params']['embedding']
        return self.lm_head(x, shared_embedding=shared_embedding.T)

    def __call__(self, x, rng: jax.random.PRNGKey = None):
        """Compute the training loss terms for one batch.

        Args:
            x: Mapping with an ``'input_ids'`` entry holding token ids; its
                leading dimension is used as the batch size.
            rng: JAX PRNG key used for timestep sampling and both noise
                draws. Required despite the ``None`` default.

        Returns:
            A dict with scalar entries ``'mse'``, ``'loss'``, ``'tT_loss'``
            and ``'decoder_nll'``.

        Raises:
            ValueError: If ``rng`` is ``None``.
        """
        if rng is None:
            raise ValueError("DiffusionLM.__call__ requires a PRNG key in `rng`.")

        input_ids = x['input_ids']
        rng, sample_rng = jax.random.split(rng)
        t, weights = self.schedule_sampler(sample_rng, input_ids.shape[0])

        # x_start is the embedding perturbed with the step-0 noise level.
        x_start_mean = self.embedder(input_ids)
        std = u.extract_into_tensor(
            self.sqrt_one_minus_alphas_cumprod, jnp.array([0]), x_start_mean.shape
        )
        rng, noise_rng = jax.random.split(rng)
        x_start = self.get_x_start(x_start_mean, std, noise_rng)

        # Split the carried key, not `noise_rng`: that key has already been
        # consumed by get_x_start and JAX keys must never be used twice.
        rng, noise_rng2 = jax.random.split(rng)
        noise = jax.random.normal(noise_rng2, x_start.shape)
        x_t = self.q_sample(x_start, t, noise=noise)

        terms = {}
        model_output = self.transformer(x_t, t)

        # At t == 0 the target is the clean embedding mean; at every other
        # step the target is the noised x_start.
        terms["mse"] = u.mean_flat((x_start - model_output) ** 2)
        t0_mask = t == 0
        t0_loss = u.mean_flat((x_start_mean - model_output) ** 2)
        terms["mse"] = jnp.where(t0_mask, t0_loss, terms["mse"])

        # Penalise the mean of q(x_T | x_start) at the last timestep.
        out_mean, _, _ = self._q_mean_variance(
            x_start, jnp.array([self.timesteps - 1])
        )
        tT_loss = u.mean_flat(out_mean ** 2)

        decoder_nll = self.token_discrete_loss(x_start, input_ids)

        terms['loss'] = terms["mse"] + (decoder_nll + tT_loss)
        terms['loss'] = (terms["loss"] * weights).mean()
        terms["tT_loss"] = tT_loss.mean()
        terms["decoder_nll"] = decoder_nll.mean()
        terms["mse"] = terms["mse"].mean()
        return terms

    def schedule_sampler(self, rng, bsz):
        """Draw ``bsz`` timesteps uniformly and return them with their weights.

        Returns:
            ``(indices, weights)`` where ``weights`` is
            ``1 / (timesteps * p[indices])``; with the uniform ``p`` used
            here every weight equals 1.
        """
        w = jnp.ones([self.timesteps])
        p = w / jnp.sum(w)
        indices = jax.random.choice(rng, len(p), shape=(bsz,), p=p)
        weights = 1 / (len(p) * p[indices])
        return indices, weights

    def get_std(self, timesteps, broadcast_shape):
        """Gather ``self.alphas[timesteps]`` and broadcast to ``broadcast_shape``.

        NOTE(review): despite the name this reads ``self.alphas`` rather than
        a standard-deviation table, and nothing in this class calls it --
        confirm the intended table before relying on it.
        """
        res = self.alphas[timesteps]
        # Append trailing axes until the rank matches the target shape.
        while len(res.shape) < len(broadcast_shape):
            res = res[..., None]
        return jnp.broadcast_to(res, broadcast_shape)

    def get_x_start(self, x_start_mean, std, noise_rng):
        """Return ``x_start_mean + std * eps`` with ``eps`` standard normal."""
        noise = jax.random.normal(noise_rng, x_start_mean.shape)
        return x_start_mean + std * noise

    def q_sample(self, x_start, t, noise):
        """Sample x_t from x_start using the supplied ``noise``.

        Computes ``sqrt(alpha_bar_t) * x_start + sqrt(1 - alpha_bar_t) * noise``.
        """
        return (
            u.extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + u.extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def get_alphas(self):
        """Precompute every per-timestep table derived from the beta schedule."""
        self._betas_for_alpha_bar()
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
        self.sqrt_one_minus_alphas_cumprod = jnp.sqrt(1.0 - self.alphas_cumprod)
        self.sqrt_alphas_cumprod = jnp.sqrt(self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = jnp.log(1.0 - self.alphas_cumprod)
        self.alphas_cumprod_prev = jnp.append(1.0, self.alphas_cumprod[:-1])

        # Coefficients of the posterior q(x_{t-1} | x_t, x_start).
        self.posterior_mean_coef1 = (
            self.betas * jnp.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev) * jnp.sqrt(self.alphas) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        # posterior_variance[0] is 0 (alphas_cumprod_prev[0] == 1), so entry 1
        # is reused at index 0 to keep the log finite.
        self.posterior_log_variance_clipped = jnp.log(
            jnp.append(self.posterior_variance[1], self.posterior_variance[1:])
        )

    def _betas_for_alpha_bar(self, max_beta=0.999):
        """Set ``self.betas`` from ``alpha_bar(t) = 1 - sqrt(t + 0.0001)``.

        Each beta is ``1 - alpha_bar((i + 1) / T) / alpha_bar(i / T)`` with
        ``T = self.timesteps``, capped at ``max_beta``.
        """
        schedule_fn = lambda t: 1 - jnp.sqrt(t + 0.0001)
        steps = jnp.arange(self.timesteps)
        t1_arr = steps / self.timesteps
        t2_arr = (steps + 1) / self.timesteps
        betas = 1 - schedule_fn(t2_arr) / schedule_fn(t1_arr)
        self.betas = jnp.where(betas < max_beta, betas, max_beta)

    def _q_mean_variance(self, x_start, t):
        """Return mean, variance and log-variance of q(x_t | x_start)."""
        mean = u.extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = u.extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = u.extract_into_tensor(
            self.log_one_minus_alphas_cumprod, t, x_start.shape
        )
        return mean, variance, log_variance

    def token_discrete_loss(self, x_t, input_ids):
        """Return ``u.crossEntropy`` between the logits of ``x_t`` and ``input_ids``."""
        logits = self.get_logits(x_t)  # presumably (batch, seq_len, vocab) -- confirm
        return u.crossEntropy(logits, input_ids)

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """Return mean, variance and clipped log-variance of q(x_{t-1} | x_t, x_start)."""
        assert x_start.shape == x_t.shape
        posterior_mean = (
            u.extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + u.extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = u.extract_into_tensor(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = u.extract_into_tensor(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        assert (
            posterior_mean.shape[0]
            == posterior_variance.shape[0]
            == posterior_log_variance_clipped.shape[0]
            == x_start.shape[0]
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
        """Predict x_start with the transformer and derive p(x_{t-1} | x_t).

        Args:
            x: Current noisy tensor x_t.
            t: Timesteps with shape ``(batch,)``.
            clip_denoised: Accepted for interface compatibility; currently
                ignored (no clipping is applied to the prediction).
            denoised_fn: Accepted for interface compatibility; currently
                ignored.
            model_kwargs: Optional extra keyword arguments for the transformer.

        Returns:
            A dict with keys ``'mean'``, ``'variance'``, ``'log_variance'``
            and ``'pred_xstart'``.
        """
        if model_kwargs is None:
            model_kwargs = {}

        batch = x.shape[0]
        assert t.shape == (batch,)

        model_output = self.transformer(x, t, **model_kwargs)
        print('raw transformer output', model_output[:2, :4, :4])

        model_variance = u.extract_into_tensor(self.posterior_variance, t, x.shape)
        model_log_variance = u.extract_into_tensor(
            self.posterior_log_variance_clipped, t, x.shape
        )

        # The transformer output is used directly as the x_start prediction.
        pred_xstart = model_output
        model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
        assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
        return {
            "mean": model_mean,
            "variance": model_variance,
            "log_variance": model_log_variance,
            "pred_xstart": pred_xstart,
        }

    def p_sample(
        self,
        x,
        t,
        rng,
        clip_denoised=True,
        denoised_fn=None,
        model_kwargs=None,
        top_p=None,
    ):
        """
        Sample x_{t-1} from the model at the given timestep.
        :param x: the current tensor x_t.
        :param t: per-example timesteps; 0 is the final denoising step, at
                  which no noise is added.
        :param rng: JAX PRNG key used to draw the sampling noise.
        :param clip_denoised: forwarded to p_mean_variance, which currently
                              ignores it.
        :param denoised_fn: forwarded to p_mean_variance, which currently
                            ignores it.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param top_p: currently ignored; noise truncation is not implemented.
        :return: a dict containing the following keys:
                 - 'sample': a random sample from the model.
                 - 'pred_xstart': a prediction of x_0.
                 - 'greedy_mean': the posterior mean (the sample without noise).
                 - 'out': the full dict returned by p_mean_variance.
        """
        out = self.p_mean_variance(
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        rng, noise_rng = jax.random.split(rng)
        noise = jax.random.normal(noise_rng, x.shape)
        print('p_sample')
        print('t', t[0])
        print('x', x[:2, :4, :4])
        print('noise', noise[:2, :4, :4])
        print('out', out['mean'][:2, :4, :4], out['variance'][:2, :4, :4])

        t_mask = jnp.where(t != 0, 1.0, 0.0)  # no noise when t == 0
        print('nonzero_mask', t_mask[0])
        # Reshape to (batch, 1, ..., 1) so the mask broadcasts against x.
        nonzero_mask = jnp.reshape(t_mask, (-1,) + (1,) * (len(x.shape) - 1))
        sample = out["mean"] + nonzero_mask * jax.lax.exp(0.5 * out["log_variance"]) * noise
        print('final sample', sample[:2, :4, :4])

        # Decoding needs the id -> token mapping; skip it when none was given
        # rather than failing on the default `vocab_r=None`.
        if t[0] == 1 and self.vocab_r is not None:
            print('Decoded sample at t=1')
            logits = self.get_logits(sample)
            _, inds = jax.lax.top_k(logits, 1)
            for seq in inds:
                decoded_sentence = " ".join([self.vocab_r[tok.item()] for tok in seq])
                print(decoded_sentence)

        return {
            "sample": sample,
            "pred_xstart": out["pred_xstart"],
            "greedy_mean": out["mean"],
            "out": out,
        }

    def p_sample_loop_progressive(
        self,
        shape,
        rng,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        top_p=None,
        langevin_func=None,
    ):
        """Yield the output of ``p_sample`` for every timestep, last to first.

        Args:
            shape: Shape of the tensor to sample; ``shape[0]`` is used as the
                batch size even when ``noise`` is supplied.
            rng: JAX PRNG key; a fresh subkey is split off for every step.
            noise: Optional starting tensor; when ``None`` standard normal
                noise of ``shape`` is drawn.
            clip_denoised, denoised_fn, model_kwargs, top_p: Forwarded to
                ``p_sample``.
            device, progress, langevin_func: Accepted but not used.

        Yields:
            The dict returned by ``p_sample`` for timesteps
            ``timesteps - 1`` down to ``0``.
        """
        if noise is not None:
            data = noise
        else:
            rng, noise_rng = jax.random.split(rng)
            data = jax.random.normal(noise_rng, shape)

        for i in reversed(range(self.timesteps)):
            t = np.full((shape[0],), i)
            rng, noise_rng = jax.random.split(rng)
            out = self.p_sample(
                data,
                t,
                noise_rng,
                clip_denoised=clip_denoised,
                denoised_fn=denoised_fn,
                model_kwargs=model_kwargs,
                top_p=top_p,
            )
            yield out
            data = out["sample"]