-
Notifications
You must be signed in to change notification settings - Fork 741
/
Copy pathVAE.py
272 lines (187 loc) · 8.87 KB
/
VAE.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
from tensorflow.keras.layers import Input, Conv2D, Flatten, Dense, Conv2DTranspose, Reshape, Lambda, Activation, BatchNormalization, LeakyReLU, Dropout, Layer
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.utils import plot_model
import tensorflow as tf
from utils.callbacks import CustomCallback, step_decay_schedule
import numpy as np
import json
import os
import pickle
class Sampling(Layer):
    """Keras layer that draws a latent sample via the reparameterisation trick.

    ``inputs`` is the pair ``(mu, log_var)``.  The layer returns
    ``mu + exp(log_var / 2) * eps`` where ``eps`` is standard-normal noise
    drawn with the same dynamic shape as ``mu``.
    """

    def call(self, inputs):
        mean, log_variance = inputs
        # Noise is drawn first, exactly as before, so op creation order
        # (and therefore any graph-level seeding) is unchanged.
        noise = K.random_normal(shape=K.shape(mean), mean=0., stddev=1.)
        # exp(log_var / 2) == sigma, the standard deviation.
        sigma = K.exp(log_variance / 2)
        return mean + sigma * noise
class VAEModel(Model):
    """Keras model that trains an encoder/decoder pair as a VAE.

    Args:
        encoder: model whose output is ``[z_mean, z_log_var, z]`` (as unpacked
            in ``train_step`` and ``call``).
        decoder: model that maps a latent ``z`` back to input space.
        r_loss_factor: multiplier on the reconstruction (MSE) term, balancing
            it against the KL-divergence term.
        **kwargs: forwarded to ``tf.keras.Model``.
    """

    def __init__(self, encoder, decoder, r_loss_factor, **kwargs):
        super(VAEModel, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.r_loss_factor = r_loss_factor

    def train_step(self, data):
        """Run one optimisation step and return scalar (batch-mean) losses.

        ``fit`` may hand over ``(x, y)``; only ``x`` is used, because an
        autoencoder reconstructs its own input.

        Returns:
            dict with scalar ``loss``, ``reconstruction_loss`` and ``kl_loss``.
        """
        if isinstance(data, tuple):
            data = data[0]

        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)

            # Per-sample MSE over axes 1..3 -- assumes 4-D image batches
            # (batch, H, W, C); TODO confirm if non-image inputs are ever used.
            reconstruction_loss = tf.reduce_mean(
                tf.square(data - reconstruction), axis=[1, 2, 3]
            )
            reconstruction_loss *= self.r_loss_factor

            # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over latent dims.
            kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
            kl_loss = tf.reduce_sum(kl_loss, axis=1)
            kl_loss *= -0.5

            # Reduce to scalars.  Leaving per-sample vectors here made the
            # gradient a *sum* over the batch and put arrays into the logs,
            # which scalar consumers (e.g. ModelCheckpoint's "{loss:.2f}"
            # filename format) cannot handle.
            reconstruction_loss = tf.reduce_mean(reconstruction_loss)
            kl_loss = tf.reduce_mean(kl_loss)
            total_loss = reconstruction_loss + kl_loss

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return {
            "loss": total_loss,
            "reconstruction_loss": reconstruction_loss,
            "kl_loss": kl_loss,
        }

    def call(self, inputs):
        """Encode ``inputs`` and decode the sampled latent ``z``."""
        # The encoder yields [z_mean, z_log_var, z]; the decoder has a single
        # latent input, so only the sampled z may be forwarded to it.
        _, _, z = self.encoder(inputs)
        return self.decoder(z)
class VariationalAutoencoder():
    """Convolutional VAE: builds encoder, decoder and a trainable ``VAEModel``.

    Args:
        input_dim: shape of one input sample, passed to ``Input(shape=...)``.
        encoder_conv_filters / encoder_conv_kernel_size / encoder_conv_strides:
            per-layer settings for the encoder's ``Conv2D`` stack (equal length).
        decoder_conv_t_filters / decoder_conv_t_kernel_size /
        decoder_conv_t_strides: per-layer settings for the decoder's
            ``Conv2DTranspose`` stack (equal length).
        z_dim: size of the latent vector.
        r_loss_factor: weight of the reconstruction term in the VAE loss.
        use_batch_norm: insert ``BatchNormalization`` after hidden conv layers.
        use_dropout: insert ``Dropout(0.25)`` after hidden conv layers.
    """

    def __init__(self
        , input_dim
        , encoder_conv_filters
        , encoder_conv_kernel_size
        , encoder_conv_strides
        , decoder_conv_t_filters
        , decoder_conv_t_kernel_size
        , decoder_conv_t_strides
        , z_dim
        , r_loss_factor
        , use_batch_norm = False
        , use_dropout= False
        ):
        self.name = 'variational_autoencoder'
        self.input_dim = input_dim
        self.encoder_conv_filters = encoder_conv_filters
        self.encoder_conv_kernel_size = encoder_conv_kernel_size
        self.encoder_conv_strides = encoder_conv_strides
        self.decoder_conv_t_filters = decoder_conv_t_filters
        self.decoder_conv_t_kernel_size = decoder_conv_t_kernel_size
        self.decoder_conv_t_strides = decoder_conv_t_strides
        self.z_dim = z_dim
        self.r_loss_factor = r_loss_factor
        self.use_batch_norm = use_batch_norm
        self.use_dropout = use_dropout
        self.n_layers_encoder = len(encoder_conv_filters)
        self.n_layers_decoder = len(decoder_conv_t_filters)
        self._build()

    def _build(self):
        """Create ``self.encoder``, ``self.decoder`` and ``self.model``."""
        ### THE ENCODER
        encoder_input = Input(shape=self.input_dim, name='encoder_input')
        x = encoder_input
        for i in range(self.n_layers_encoder):
            conv_layer = Conv2D(
                filters=self.encoder_conv_filters[i],
                kernel_size=self.encoder_conv_kernel_size[i],
                strides=self.encoder_conv_strides[i],
                padding='same',
                name='encoder_conv_' + str(i),
            )
            x = conv_layer(x)
            if self.use_batch_norm:
                x = BatchNormalization()(x)
            x = LeakyReLU()(x)
            if self.use_dropout:
                x = Dropout(rate=0.25)(x)

        # Remembered so the decoder can undo the Flatten with a Reshape.
        shape_before_flattening = K.int_shape(x)[1:]
        x = Flatten()(x)
        self.mu = Dense(self.z_dim, name='mu')(x)
        self.log_var = Dense(self.z_dim, name='log_var')(x)
        self.z = Sampling(name='encoder_output')([self.mu, self.log_var])
        self.encoder = Model(encoder_input, [self.mu, self.log_var, self.z], name='encoder')

        ### THE DECODER
        decoder_input = Input(shape=(self.z_dim,), name='decoder_input')
        x = Dense(np.prod(shape_before_flattening))(decoder_input)
        x = Reshape(shape_before_flattening)(x)
        for i in range(self.n_layers_decoder):
            conv_t_layer = Conv2DTranspose(
                filters=self.decoder_conv_t_filters[i],
                kernel_size=self.decoder_conv_t_kernel_size[i],
                strides=self.decoder_conv_t_strides[i],
                padding='same',
                name='decoder_conv_t_' + str(i),
            )
            x = conv_t_layer(x)
            if i < self.n_layers_decoder - 1:
                if self.use_batch_norm:
                    x = BatchNormalization()(x)
                x = LeakyReLU()(x)
                if self.use_dropout:
                    x = Dropout(rate=0.25)(x)
            else:
                # Final layer: sigmoid squashes outputs into (0, 1).
                x = Activation('sigmoid')(x)
        decoder_output = x
        self.decoder = Model(decoder_input, decoder_output, name='decoder')

        ### THE FULL VAE
        self.model = VAEModel(self.encoder, self.decoder, self.r_loss_factor)

    def compile(self, learning_rate):
        """Compile the VAE with Adam; also stores ``learning_rate`` for training."""
        self.learning_rate = learning_rate
        # ``learning_rate=`` replaces the deprecated ``lr=`` keyword, which
        # newer Keras optimizers reject.
        optimizer = Adam(learning_rate=learning_rate)
        self.model.compile(optimizer=optimizer)

    def save(self, folder):
        """Create the run folder layout, pickle constructor params and plot models.

        The pickled list follows ``__init__``'s positional order, so
        ``VariationalAutoencoder(*params)`` reconstructs an equivalent instance.
        """
        # exist_ok per subfolder: an already-existing ``folder`` previously
        # skipped creating viz/, weights/ and images/ altogether.
        for sub in ('viz', 'weights', 'images'):
            os.makedirs(os.path.join(folder, sub), exist_ok=True)

        with open(os.path.join(folder, 'params.pkl'), 'wb') as f:
            pickle.dump([
                self.input_dim,
                self.encoder_conv_filters,
                self.encoder_conv_kernel_size,
                self.encoder_conv_strides,
                self.decoder_conv_t_filters,
                self.decoder_conv_t_kernel_size,
                self.decoder_conv_t_strides,
                self.z_dim,
                # Required constructor argument; omitting it shifted
                # use_batch_norm/use_dropout one slot to the left on reload.
                self.r_loss_factor,
                self.use_batch_norm,
                self.use_dropout,
            ], f)

        self.plot_model(folder)

    def load_weights(self, filepath):
        """Load weights into the combined VAE model."""
        self.model.load_weights(filepath)

    def _get_callbacks(self, run_folder, print_every_n_batches, initial_epoch, lr_decay):
        """Build the callback list shared by ``train`` and ``train_with_generator``.

        Reads ``self.learning_rate``, so ``compile`` must have been called first.
        ``step_decay_schedule`` and ``CustomCallback`` come from
        ``utils.callbacks`` (not shown here).
        """
        custom_callback = CustomCallback(run_folder, print_every_n_batches, initial_epoch, self)
        lr_sched = step_decay_schedule(initial_lr=self.learning_rate, decay_factor=lr_decay, step_size=1)
        checkpoint_filepath = os.path.join(run_folder, "weights/weights-{epoch:03d}-{loss:.2f}.h5")
        # One checkpoint per epoch, plus a rolling "latest" weights file.
        checkpoint1 = ModelCheckpoint(checkpoint_filepath, save_weights_only=True, verbose=1)
        checkpoint2 = ModelCheckpoint(os.path.join(run_folder, 'weights/weights.h5'), save_weights_only=True, verbose=1)
        return [checkpoint1, checkpoint2, custom_callback, lr_sched]

    def train(self, x_train, batch_size, epochs, run_folder, print_every_n_batches = 100, initial_epoch = 0, lr_decay = 1):
        """Fit on an in-memory array; the input doubles as the target."""
        callbacks_list = self._get_callbacks(run_folder, print_every_n_batches, initial_epoch, lr_decay)

        self.model.fit(
            x_train,
            x_train,
            batch_size=batch_size,
            shuffle=True,
            epochs=epochs,
            initial_epoch=initial_epoch,
            callbacks=callbacks_list,
        )

    def train_with_generator(self, data_flow, epochs, steps_per_epoch, run_folder, print_every_n_batches = 100, initial_epoch = 0, lr_decay = 1, ):
        """Fit from a data generator / dataset for ``steps_per_epoch`` steps per epoch."""
        callbacks_list = self._get_callbacks(run_folder, print_every_n_batches, initial_epoch, lr_decay)

        # Writes the current weights to the rolling checkpoint before training.
        self.model.save_weights(os.path.join(run_folder, 'weights/weights.h5'))

        self.model.fit(
            data_flow,
            shuffle=True,
            epochs=epochs,
            initial_epoch=initial_epoch,
            callbacks=callbacks_list,
            steps_per_epoch=steps_per_epoch,
        )

    def plot_model(self, run_folder):
        """Render model, encoder and decoder diagrams into ``run_folder/viz``."""
        plot_model(self.model, to_file=os.path.join(run_folder ,'viz/model.png'), show_shapes = True, show_layer_names = True)
        plot_model(self.encoder, to_file=os.path.join(run_folder ,'viz/encoder.png'), show_shapes = True, show_layer_names = True)
        plot_model(self.decoder, to_file=os.path.join(run_folder ,'viz/decoder.png'), show_shapes = True, show_layer_names = True)