-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrain.py
131 lines (101 loc) · 4.35 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# Copyright 2019 ChangyuLiu Authors. All Rights Reserved.
#
# Licensed under the MIT License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://opensource.org/licenses/MIT
# ==============================================================================
"""The training loop begins with generator receiving a random seed as input.
That seed is used to produce an image.
The discriminator is then used to classify real images (drawn from the training set)
and fake images (produced by the generator).
The loss is calculated for each of these models,
and the gradients are used to update the generator and discriminator.
"""
from dataset.load_dataset import load_dataset
from network.generator import make_generator_model
from network.discriminator import make_discriminator_model
from util.loss_and_optim import generator_loss, generator_optimizer
from util.loss_and_optim import discriminator_loss, discriminator_optimizer
from util.save_checkpoints import save_checkpoints
from util.generate_and_save_images import generate_and_save_images
import tensorflow as tf
import time
import os
import argparse
# Command-line options: which dataset to train on and for how many epochs.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='mnist', type=str,
                    help='use dataset {mnist or cifar}.')
parser.add_argument('--epochs', default=50, type=int,
                    help='Epochs for training.')
args = parser.parse_args()
print(args)

# Directory handed to save_checkpoints() and generate_and_save_images().
save_path = 'training_checkpoint'
# exist_ok=True replaces the separate os.path.exists() check: the directory
# is created when missing and left alone when it is already there, without a
# window in which another process could create it between check and create.
os.makedirs(save_path, exist_ok=True)

# Noise batch of shape [16, 100], created once so that every call to
# generate_and_save_images() receives the same latent vectors.
noise = tf.random.normal([16, 100])

# Both datasets are loaded here; the one actually used is picked in __main__.
# NOTE(review): the four positional arguments look like (buffer size, batch
# size) pairs for MNIST and CIFAR respectively - confirm against load_dataset.
mnist_train_dataset, cifar_train_dataset = load_dataset(60000, 128, 50000, 64)

# Build the two networks and their optimizers.
# The optimizer factory functions imported above are rebound to the optimizer
# instances they return, so they cannot be called a second time after this.
generator = make_generator_model(args.dataset)
generator_optimizer = generator_optimizer()
discriminator = make_discriminator_model(args.dataset)
discriminator_optimizer = discriminator_optimizer()

# Checkpoint handles used by train() to persist the models and optimizers.
checkpoint_dir, checkpoint, checkpoint_prefix = save_checkpoints(
    generator,
    discriminator,
    generator_optimizer,
    discriminator_optimizer,
    save_path)
# This annotation causes the function to be "compiled".
@tf.function
def train_step(images):
    """ run one adversarial update of the generator and the discriminator.
    Args:
      images: batch of real training images fed to the discriminator.
    """
    # Sample fresh latent vectors for every step, one per real image.
    # The module-level `noise` is a fixed batch kept for the preview images;
    # training on it would show the generator the same few latent points at
    # every step. Only its last dimension (the latent size) is reused here.
    latent_dim = noise.shape[-1]
    batch_noise = tf.random.normal([tf.shape(images)[0], latent_dim])

    # Two tapes, because each network is differentiated against its own loss.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(batch_noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(
        gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(
        disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(
        zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(
        zip(gradients_of_discriminator, discriminator.trainable_variables))
def train(dataset, epochs):
    """ train op
    Runs `train_step` on every batch of `dataset` for `epochs` epochs. After
    each epoch a grid of sample images is written, and a checkpoint is saved
    every 15 epochs as well as after the final epoch.
    Args:
      dataset: mnist dataset or cifar10 dataset, iterated batch by batch.
      epochs: number of iterative training.
    """
    for epoch in range(epochs):
        start = time.time()
        finished = epoch + 1

        for image_batch in dataset:
            train_step(image_batch)

        # Produce images for the GIF as we go
        generate_and_save_images(generator,
                                 finished,
                                 noise,
                                 save_path)

        # Save the model every 15 epochs, and always after the last epoch so
        # that a run whose length is not a multiple of 15 (e.g. the default
        # 50) does not lose the weights trained since the previous save.
        if finished % 15 == 0 or finished == epochs:
            checkpoint.save(file_prefix=checkpoint_prefix)

        print(f'Time for epoch {epoch+1} is {time.time()-start:.3f} sec.')

    # Generate after the final epoch
    generate_and_save_images(generator,
                             epochs,
                             noise,
                             save_path)
if __name__ == '__main__':
    # Anything other than 'mnist' falls through to the CIFAR dataset.
    selected_dataset = (mnist_train_dataset
                        if args.dataset == 'mnist'
                        else cifar_train_dataset)
    train(selected_dataset, args.epochs)