-
Notifications
You must be signed in to change notification settings - Fork 39
/
Copy pathmain.py
494 lines (453 loc) · 17.6 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
import os
import cv2
import copy
import math
import argparse
import numpy as np
from time import time
from tqdm import tqdm
from easydict import EasyDict
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from data import get_metadata, get_dataset, fix_legacy_dict
import unets
def unsqueeze3x(x):
    """Append three trailing singleton axes to ``x``.

    Used to turn per-sample scalars into something that broadcasts against
    a tensor with three more trailing dimensions.
    """
    return x[(..., None, None, None)]
class GuassianDiffusion:
    """Gaussian diffusion process with 1) Cosine schedule for beta values (https://arxiv.org/abs/2102.09672)
    2) L_simple training objective from https://arxiv.org/abs/2006.11239.

    The schedule helpers (``alpha_bar_scheduler``, ``clamp_x0``,
    ``get_x0_from_xt_eps``, ``get_pred_mean_from_x0_xt``) are regular methods
    rather than lambdas stored on the instance, so instances can be pickled
    and deep-copied without the copies silently referring to the original.

    Attributes:
        timesteps: Number of steps in the diffusion process.
        device: Device on which the schedule scalars are stored.
        scalars: EasyDict of per-timestep schedule tensors
            (see ``get_all_scalars``).
    """

    def __init__(self, timesteps=1000, device="cuda:0"):
        self.timesteps = timesteps
        self.device = device
        self.scalars = self.get_all_scalars(
            self.alpha_bar_scheduler, self.timesteps, self.device
        )

    def alpha_bar_scheduler(self, t):
        """Cosine schedule: cumulative signal level (alpha_bar) at step ``t``."""
        return math.cos((t / self.timesteps + 0.008) / 1.008 * math.pi / 2) ** 2

    def clamp_x0(self, x):
        """Clamp a predicted clean image to the [-1, 1] pixel range."""
        return x.clamp(-1, 1)

    def get_x0_from_xt_eps(self, xt, eps, t, scalars):
        """Recover (and clamp) x0 from a noisy sample and its noise estimate.

        Inverts ``xt = sqrt(alpha_bar[t]) * x0 + sqrt(1 - alpha_bar[t]) * eps``.

        Args:
            xt: Noisy samples.
            eps: Noise estimate with the same shape as ``xt``.
            t: Integer tensor of per-sample indices into ``scalars``.
            scalars: Schedule scalars as returned by ``get_all_scalars``.
        """
        return self.clamp_x0(
            1
            / unsqueeze3x(scalars.alpha_bar[t].sqrt())
            * (xt - unsqueeze3x((1 - scalars.alpha_bar[t]).sqrt()) * eps)
        )

    def get_pred_mean_from_x0_xt(self, xt, x0, t, scalars):
        """Mean of the reverse-step distribution given ``xt`` and ``x0``.

        Since alpha_bar is the cumulative product of alpha,
        ``sqrt(alpha_bar[t]) / sqrt(alpha[t])`` equals ``sqrt(alpha_bar[t-1])``,
        so the two coefficients below are the usual posterior-mean weights
        written without indexing ``t - 1``.
        """
        denominator = (1 - scalars.alpha_bar[t]) * scalars.alpha[t].sqrt()
        x0_weight = (scalars.alpha_bar[t].sqrt() * scalars.beta[t]) / denominator
        xt_weight = (scalars.alpha[t] - scalars.alpha_bar[t]) / denominator
        return unsqueeze3x(x0_weight) * x0 + unsqueeze3x(xt_weight) * xt

    def get_all_scalars(self, alpha_bar_scheduler, timesteps, device, betas=None):
        """Compute every per-timestep scalar of the diffusion schedule.

        Args:
            alpha_bar_scheduler: Callable mapping a step index to alpha_bar;
                only used when ``betas`` is None.
            timesteps: Number of steps; only used when ``betas`` is None.
            device: Device for the derived betas; only used when ``betas``
                is None (supplied betas are used where they already live).
            betas: Optional 1-D tensor of beta values to use directly.

        Returns:
            EasyDict of float32 tensors with keys ``beta``, ``beta_log``,
            ``alpha``, ``alpha_bar``, ``beta_tilde`` and ``beta_tilde_log``.
        """
        all_scalars = {}
        if betas is None:
            all_scalars["beta"] = torch.from_numpy(
                np.array(
                    [
                        min(
                            1 - alpha_bar_scheduler(t + 1) / alpha_bar_scheduler(t),
                            0.999,
                        )
                        for t in range(timesteps)
                    ]
                )
            ).to(
                device
            )  # hardcoding beta_max to 0.999
        else:
            all_scalars["beta"] = betas
        all_scalars["beta_log"] = torch.log(all_scalars["beta"])
        all_scalars["alpha"] = 1 - all_scalars["beta"]
        all_scalars["alpha_bar"] = torch.cumprod(all_scalars["alpha"], dim=0)
        # beta_tilde[t] needs alpha_bar[t - 1], so it is only defined for t >= 1 ...
        all_scalars["beta_tilde"] = (
            all_scalars["beta"][1:]
            * (1 - all_scalars["alpha_bar"][:-1])
            / (1 - all_scalars["alpha_bar"][1:])
        )
        # ... and the value for t = 1 is duplicated to fill the t = 0 slot.
        all_scalars["beta_tilde"] = torch.cat(
            [all_scalars["beta_tilde"][0:1], all_scalars["beta_tilde"]]
        )
        all_scalars["beta_tilde_log"] = torch.log(all_scalars["beta_tilde"])
        return EasyDict({k: v.float() for k, v in all_scalars.items()})

    def sample_from_forward_process(self, x0, t):
        """Single step of the forward process, where we add noise in the image.
        Note that we will use this particular realization of noise vector (eps) in training.

        Args:
            x0: Clean images.
            t: Integer tensor of per-sample timestep indices.

        Returns:
            Tuple ``(xt, eps)``: the noised images (float32) and the noise used.
        """
        eps = torch.randn_like(x0)
        xt = (
            unsqueeze3x(self.scalars.alpha_bar[t].sqrt()) * x0
            + unsqueeze3x((1 - self.scalars.alpha_bar[t]).sqrt()) * eps
        )
        return xt.float(), eps

    def sample_from_reverse_process(
        self, model, xT, timesteps=None, model_kwargs=None, ddim=False
    ):
        """Sampling images by iterating over all timesteps.

        model: diffusion model
        xT: Starting noise vector.
        timesteps: Number of sampling steps (can be smaller the default,
            i.e., timesteps in the diffusion process).
        model_kwargs: Additional kwargs for model (using it to feed class label for conditioning).
            Defaults to no extra kwargs.
        ddim: Use ddim sampling (https://arxiv.org/abs/2010.02502). With very small number of
            sampling steps, use ddim sampling for better image quality.

        Note: the model is switched to eval mode and is not switched back.

        Return: An image tensor with identical shape as XT.
        """
        # A None default avoids sharing one dict object across calls.
        model_kwargs = {} if model_kwargs is None else model_kwargs
        model.eval()
        final = xT

        # sub-sampling timesteps for faster sampling
        timesteps = timesteps or self.timesteps
        new_timesteps = np.linspace(
            0, self.timesteps - 1, num=timesteps, endpoint=True, dtype=int
        )
        # Re-derive betas so that the sub-sampled chain has the same
        # alpha_bar values as the full chain at the selected steps.
        alpha_bar = self.scalars["alpha_bar"][new_timesteps]
        new_betas = 1 - (
            alpha_bar / torch.nn.functional.pad(alpha_bar, [1, 0], value=1.0)[:-1]
        )
        scalars = self.get_all_scalars(
            self.alpha_bar_scheduler, timesteps, self.device, new_betas
        )

        with torch.no_grad():
            # i indexes the sub-sampled schedule, t the original one; the
            # model is conditioned on t, the scalars are looked up with i.
            for i, t in zip(np.arange(timesteps)[::-1], new_timesteps[::-1]):
                current_t = torch.tensor([t] * len(final), device=final.device)
                current_sub_t = torch.tensor([i] * len(final), device=final.device)
                pred_epsilon = model(final, current_t, **model_kwargs)
                # using xt+x0 to derive mu_t, instead of using xt+eps (former is more stable)
                pred_x0 = self.get_x0_from_xt_eps(
                    final, pred_epsilon, current_sub_t, scalars
                )
                pred_mean = self.get_pred_mean_from_x0_xt(
                    final, pred_x0, current_sub_t, scalars
                )
                if i == 0:
                    # last step: return the mean without adding noise
                    final = pred_mean
                elif ddim:
                    prev_alpha_bar = unsqueeze3x(
                        scalars["alpha_bar"][current_sub_t - 1]
                    )
                    final = (
                        prev_alpha_bar.sqrt() * pred_x0
                        + (1 - prev_alpha_bar).sqrt() * pred_epsilon
                    )
                else:
                    final = pred_mean + unsqueeze3x(
                        scalars.beta_tilde[current_sub_t].sqrt()
                    ) * torch.randn_like(final)
                final = final.detach()
        return final
class loss_logger:
    """Track per-step training losses and an exponential moving average.

    Args:
        max_steps: Total number of steps expected; only used in the printed
            progress message.
        ema_w: Weight given to the previous average when a new loss is
            folded in (the new loss gets ``1 - ema_w``). Defaults to 0.9.
    """

    def __init__(self, max_steps, ema_w=0.9):
        self.max_steps = max_steps
        self.loss = []
        self.start_time = time()
        self.ema_loss = None
        self.ema_w = ema_w

    def log(self, v, display=False):
        """Record loss value ``v``; optionally print a progress line.

        Args:
            v: Loss value for the current step.
            display: When true, print the step count, the smoothed loss and
                the hours elapsed since this logger was created.
        """
        self.loss.append(v)
        if self.ema_loss is None:
            # first value seeds the moving average
            self.ema_loss = v
        else:
            self.ema_loss = self.ema_w * self.ema_loss + (1 - self.ema_w) * v

        if display:
            print(
                f"Steps: {len(self.loss)}/{self.max_steps} \t loss (ema): {self.ema_loss:.3f} "
                + f"\t Time elapsed: {(time() - self.start_time)/3600:.3f} hr"
            )
def train_one_epoch(
    model,
    dataloader,
    diffusion,
    optimizer,
    logger,
    lrs,
    args,
):
    """Train ``model`` for one pass over ``dataloader`` with the noise-prediction loss.

    Args:
        model: Network called as ``model(xt, t, y=labels)`` and expected to
            return a noise estimate shaped like ``xt``.
        dataloader: Iterable of ``(images, labels)`` batches; images must be
            in the [0, 1] range.
        diffusion: Object providing ``timesteps`` and
            ``sample_from_forward_process(images, t)``.
        optimizer: Optimizer stepping the model parameters.
        logger: Object with a ``log(value, display=...)`` method.
        lrs: Optional learning-rate scheduler, stepped once per batch.
        args: Namespace providing ``device``, ``class_cond``, ``local_rank``,
            ``ema_w`` and ``ema_dict`` (updated in place on rank 0).

    Raises:
        ValueError: If a batch of images is not within [0, 1].
    """
    model.train()
    for step, (images, labels) in enumerate(dataloader):
        # Explicit check instead of ``assert`` so it also runs under ``python -O``.
        # Written as ``not (...)`` so that NaN pixels are rejected as well.
        if not ((images.max().item() <= 1) and (0 <= images.min().item())):
            raise ValueError("Images must be in the [0, 1] pixel range.")

        # must use [-1, 1] pixel range for images
        images, labels = (
            2 * images.to(args.device) - 1,
            labels.to(args.device) if args.class_cond else None,
        )
        t = torch.randint(diffusion.timesteps, (len(images),), dtype=torch.int64).to(
            args.device
        )
        xt, eps = diffusion.sample_from_forward_process(images, t)
        pred_eps = model(xt, t, y=labels)

        loss = ((pred_eps - eps) ** 2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if lrs is not None:
            lrs.step()

        # update ema_dict
        if args.local_rank == 0:
            with torch.no_grad():
                new_dict = model.state_dict()
                for k in args.ema_dict:
                    args.ema_dict[k] = (
                        args.ema_w * args.ema_dict[k] + (1 - args.ema_w) * new_dict[k]
                    )
            # print a progress line every 100 steps
            logger.log(loss.item(), display=not step % 100)
def sample_N_images(
    N,
    model,
    diffusion,
    xT=None,
    sampling_steps=250,
    batch_size=64,
    num_channels=3,
    image_size=32,
    num_classes=None,
    args=None,
):
    """use this function to sample any number of images from a given
        diffusion model and diffusion process.

    Works both with and without an initialized ``torch.distributed`` process
    group; when one is initialized, every batch is all-gathered across ranks.

    Args:
        N : Number of images
        model : Diffusion model
        diffusion : Diffusion process
        xT : Starting instantiation of noise vector. When None, fresh noise
            is drawn for every batch; when given, it is reused for every batch.
        sampling_steps : Number of sampling steps.
        batch_size : Batch-size for sampling (per process).
        num_channels : Number of channels in the image.
        image_size : Image size (assuming square images).
        num_classes : Number of classes in the dataset (needed for class-conditioned models)
        args : All args from the argparser. Required despite the None
            default: ``device``, ``class_cond`` and ``ddim`` are read from it.
    Returns: Tuple of a uint8 numpy array with N images (channels last) and
        the corresponding labels (None when not class-conditioned).
    """
    samples, labels, num_samples = [], [], 0

    # Querying the world size without an initialized process group raises,
    # so fall back to a single process in that case (e.g., one GPU).
    distributed = dist.is_available() and dist.is_initialized()
    if distributed:
        num_processes, group = dist.get_world_size(), dist.group.WORLD
    else:
        num_processes, group = 1, None

    per_process = batch_size if xT is None else len(xT)
    with tqdm(total=math.ceil(N / (per_process * num_processes))) as pbar:
        while num_samples < N:
            if xT is None:
                # draw new starting noise for every batch
                batch_xT = (
                    torch.randn(batch_size, num_channels, image_size, image_size)
                    .float()
                    .to(args.device)
                )
            else:
                batch_xT = xT

            if args.class_cond:
                y = torch.randint(
                    num_classes, (len(batch_xT),), dtype=torch.int64
                ).to(args.device)
            else:
                y = None

            gen_images = diffusion.sample_from_reverse_process(
                model, batch_xT, sampling_steps, {"y": y}, args.ddim
            )

            if args.class_cond:
                if distributed:
                    labels_list = [torch.zeros_like(y) for _ in range(num_processes)]
                    dist.all_gather(labels_list, y, group)
                    y = torch.cat(labels_list)
                labels.append(y.detach().cpu().numpy())

            if distributed:
                samples_list = [
                    torch.zeros_like(gen_images) for _ in range(num_processes)
                ]
                dist.all_gather(samples_list, gen_images, group)
                gen_images = torch.cat(samples_list)
            samples.append(gen_images.detach().cpu().numpy())

            num_samples += len(batch_xT) * num_processes
            pbar.update(1)

    # (N, C, H, W) in [-1, 1] -> (N, H, W, C) in [0, 255]
    samples = np.concatenate(samples).transpose(0, 2, 3, 1)[:N]
    samples = (127.5 * (samples + 1)).astype(np.uint8)
    # truncate labels like the samples so the two stay aligned
    return (samples, np.concatenate(labels)[:N] if args.class_cond else None)
def _build_arg_parser():
    """Build the command-line argument parser used by ``main``."""
    parser = argparse.ArgumentParser("Minimal implementation of diffusion models")
    # diffusion model
    parser.add_argument("--arch", type=str, help="Neural network architecture")
    parser.add_argument(
        "--class-cond",
        action="store_true",
        default=False,
        help="train class-conditioned diffusion model",
    )
    parser.add_argument(
        "--diffusion-steps",
        type=int,
        default=1000,
        help="Number of timesteps in diffusion process",
    )
    parser.add_argument(
        "--sampling-steps",
        type=int,
        default=250,
        help="Number of timesteps used when sampling from the model",
    )
    parser.add_argument(
        "--ddim",
        action="store_true",
        default=False,
        help="Sampling using DDIM update step",
    )
    # dataset
    parser.add_argument("--dataset", type=str)
    parser.add_argument("--data-dir", type=str, default="./dataset/")
    # optimizer
    parser.add_argument(
        "--batch-size", type=int, default=128, help="batch-size per gpu"
    )
    parser.add_argument("--lr", type=float, default=0.0001)
    parser.add_argument("--epochs", type=int, default=500)
    parser.add_argument("--ema_w", type=float, default=0.9995)
    # sampling/finetuning
    parser.add_argument("--pretrained-ckpt", type=str, help="Pretrained model ckpt")
    parser.add_argument(
        "--delete-keys",
        nargs="+",
        help="Keys to delete from the pretrained ckpt before loading it",
    )
    parser.add_argument(
        "--sampling-only",
        action="store_true",
        default=False,
        help="No training, just sample images (will save them in --save-dir)",
    )
    parser.add_argument(
        "--num-sampled-images",
        type=int,
        default=50000,
        help="Number of images required to sample from the model",
    )
    # misc
    parser.add_argument("--save-dir", type=str, default="./trained_models/")
    parser.add_argument("--local_rank", default=0, type=int)
    parser.add_argument("--seed", default=112233, type=int)
    return parser


def main():
    """Entry point: train a diffusion model or, with --sampling-only, sample from one.

    Parses the command line, builds the model and the diffusion process,
    optionally loads a checkpoint, wraps the model in DDP when more than one
    GPU is visible, and then either samples images to an ``.npz`` file or
    trains, writing a sample grid and checkpoints to --save-dir every epoch.
    """
    # setup
    args = _build_arg_parser().parse_args()
    metadata = get_metadata(args.dataset)
    torch.backends.cudnn.benchmark = True
    args.device = "cuda:{}".format(args.local_rank)
    torch.cuda.set_device(args.device)
    # offset the seed by rank so each process draws different noise
    torch.manual_seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)
    if args.local_rank == 0:
        print(args)
    # Every output below is written into save_dir; create it up front
    # (cv2.imwrite only returns False when the directory is missing).
    os.makedirs(args.save_dir, exist_ok=True)

    # Create model and diffusion process
    model = unets.__dict__[args.arch](
        image_size=metadata.image_size,
        in_channels=metadata.num_channels,
        out_channels=metadata.num_channels,
        num_classes=metadata.num_classes if args.class_cond else None,
    ).to(args.device)
    if args.local_rank == 0:
        print(
            "We are assuming that model input/ouput pixel range is [-1, 1]. Please adhere to it."
        )
    diffusion = GuassianDiffusion(args.diffusion_steps, args.device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

    # load pre-trained model
    if args.pretrained_ckpt:
        print(f"Loading pretrained model from {args.pretrained_ckpt}")
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        d = fix_legacy_dict(torch.load(args.pretrained_ckpt, map_location=args.device))
        dm = model.state_dict()
        if args.delete_keys:
            for k in args.delete_keys:
                print(
                    f"Deleting key {k} becuase its shape in ckpt ({d[k].shape}) doesn't match "
                    + f"with shape in model ({dm[k].shape})"
                )
                del d[k]
        model.load_state_dict(d, strict=False)
        print(
            "Mismatched keys in ckpt and model: ",
            set(d.keys()) ^ set(dm.keys()),
        )
        print(f"Loaded pretrained model from {args.pretrained_ckpt}")

    # distributed training
    ngpus = torch.cuda.device_count()
    if ngpus > 1:
        if args.local_rank == 0:
            print(f"Using distributed training on {ngpus} gpus.")
        # NOTE(review): --batch-size is documented as per-gpu, yet it is
        # divided by the gpu count here -- confirm which one is intended.
        args.batch_size = args.batch_size // ngpus
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
        model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)

    # sampling
    if args.sampling_only:
        sampled_images, labels = sample_N_images(
            args.num_sampled_images,
            model,
            diffusion,
            None,
            args.sampling_steps,
            args.batch_size,
            metadata.num_channels,
            metadata.image_size,
            metadata.num_classes,
            args,
        )
        # All ranks hold the same gathered samples; write the file once
        # instead of having every process write the same path.
        if args.local_rank == 0:
            np.savez(
                os.path.join(
                    args.save_dir,
                    f"{args.arch}_{args.dataset}-{args.sampling_steps}-sampling_steps-{len(sampled_images)}_images-class_condn_{args.class_cond}.npz",
                ),
                sampled_images,
                labels,
            )
        return

    # Load dataset
    train_set = get_dataset(args.dataset, args.data_dir, metadata)
    sampler = DistributedSampler(train_set) if ngpus > 1 else None
    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=sampler is None,
        sampler=sampler,
        num_workers=4,
        pin_memory=True,
    )
    if args.local_rank == 0:
        print(
            f"Training dataset loaded: Number of batches: {len(train_loader)}, Number of images: {len(train_set)}"
        )
    logger = loss_logger(len(train_loader) * args.epochs)

    # ema model
    args.ema_dict = copy.deepcopy(model.state_dict())

    # lets start training the model
    for epoch in range(args.epochs):
        if sampler is not None:
            sampler.set_epoch(epoch)
        train_one_epoch(model, train_loader, diffusion, optimizer, logger, None, args)

        # sample a strip of 64 images after every epoch
        sampled_images, _ = sample_N_images(
            64,
            model,
            diffusion,
            None,
            args.sampling_steps,
            args.batch_size,
            metadata.num_channels,
            metadata.image_size,
            metadata.num_classes,
            args,
        )
        if args.local_rank == 0:
            cv2.imwrite(
                os.path.join(
                    args.save_dir,
                    f"{args.arch}_{args.dataset}-{args.diffusion_steps}_steps-{args.sampling_steps}-sampling_steps-class_condn_{args.class_cond}.png",
                ),
                # lay the images side by side and reverse the channel order
                np.concatenate(sampled_images, axis=1)[:, :, ::-1],
            )
        if args.local_rank == 0:
            # file names use the total epoch count, so each epoch overwrites
            # the previous checkpoint
            torch.save(
                model.state_dict(),
                os.path.join(
                    args.save_dir,
                    f"{args.arch}_{args.dataset}-epoch_{args.epochs}-timesteps_{args.diffusion_steps}-class_condn_{args.class_cond}.pt",
                ),
            )
            torch.save(
                args.ema_dict,
                os.path.join(
                    args.save_dir,
                    f"{args.arch}_{args.dataset}-epoch_{args.epochs}-timesteps_{args.diffusion_steps}-class_condn_{args.class_cond}_ema_{args.ema_w}.pt",
                ),
            )


if __name__ == "__main__":
    main()