forked from jakugel/oct-stargardtretina-seg
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_generator.py
333 lines (246 loc) · 13.2 KB
/
data_generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
import numpy as np
import keras
from math import floor
class BatchGenerator:
    """Generate batches of images and labels for fit_generator (training) or predict_generator (evaluation).

    Samples are read sequentially (in ``sample_shuffle`` order) from an image database, optionally
    augmented, and stacked into fixed-size batches by ``get_batch_list``. The object is stateful:
    ``full_counter``, ``aug_counter`` and ``batch_counter`` track the position within the current epoch.
    _________
    imdb: image database object. As used by this class it must expose the attributes ``num_images``,
        ``image_width``, ``image_height``, ``num_channels``, ``labels_shape`` and ``set``, and the methods
        ``get_image(i)``, ``get_label(i)`` and ``get_seg(i)``.
    _________
    batch_size: size of the batch for neural network to process
    _________
    aug_fn_args: tuple of two-tuples containing augmentation function and argument pairs. Each function is
        called as ``aug_fn(image, label, seg, aug_arg, sample_ind=..., set=...)`` and must return a 5-tuple
        whose first two items are the augmented image and label.
    _________
    aug_mode: mode to use for augmentation
        none: no augmentations -> will just use what is in the images and labels arrays as is
        one: for each image, one augmentation will be picked from the list of possible augmentation functions
            chosen based on probabilities in aug_probs.
        all: for each image, all augmentations will be performed creating a new separate image for each
        (original author's note: for patch mode, augs are applied to the full size images before being
        broken into patches)
    _________
    aug_probs: probabilities used for selecting augmentations in 'one' mode. Should be values between 0 and 1
        which add to 1.
    _________
    aug_fly: whether to perform augmentations each time an image is required (truthy) or to perform them all
        once at construction time and keep the results in memory (falsy).
    _________
    shuffle: whether or not to shuffle the order of the images at the start as well as at the end of each epoch
    _________
    transpose: whether to swap the two spatial axes (axes 1 and 2) of each batch before returning it
    _________
    normalise: whether to divide the batch images by 255 before returning them
    _________
    ram_load: a value of 0 is rejected in combination with pre-computed augmentations
        (aug_fly falsy and aug_mode != 'none'); otherwise it is only stored.
    _________
    """

    def __init__(self, imdb, batch_size, aug_fn_args, aug_mode, aug_probs, aug_fly, shuffle=True, transpose=False,
                 normalise=True, ram_load=1):
        # Flags are coerced to plain bools so truthy non-bool values (numpy.bool_, 1, ...) behave
        # exactly like True/False instead of silently matching neither branch later on.
        self.shuffle = bool(shuffle)        # whether to shuffle the order that images are iterated
        self.transpose = bool(transpose)    # whether to swap rows and columns of batches
        self.normalise = bool(normalise)    # whether to scale batch images by 1/255

        self.batch_counter = 0  # number of batches generated in the current epoch
        self.batch_size = batch_size  # number of samples in a batch
        self.full_counter = 0  # used to track which full size image we are up to
        self.aug_counter = 0  # used to track which augmentation index we are up to (for aug_mode='all')

        self.imdb = imdb
        self.aug_fn_args = aug_fn_args
        self.aug_mode = aug_mode
        self.aug_probs = aug_probs
        self.aug_fly = bool(aug_fly)
        self.ram_load = ram_load

        self.total_full_images = self.imdb.num_images
        self.total_raw_samples = self.total_full_images  # total raw samples (w/out augs)
        self.labels_shape = self.imdb.labels_shape

        if self.aug_mode == 'none':
            self.total_augs = 0
            self.total_samples = self.total_raw_samples
        elif self.aug_mode == 'all':
            # every augmentation yields its own sample
            self.total_augs = len(self.aug_fn_args)
            self.total_samples = self.total_raw_samples * self.total_augs  # total samples (including augmentations)
        elif self.aug_mode == 'one':
            # one randomly chosen augmentation replaces each raw sample
            self.total_augs = len(self.aug_fn_args)
            self.total_samples = self.total_raw_samples
        else:
            # fail early with a clear message rather than with a missing-attribute error further down
            raise ValueError("Unknown aug_mode: {!r} (expected 'none', 'one' or 'all')".format(aug_mode))

        # shape of the label array for one batch: labels_shape with its first axis replaced by batch_size
        self.batch_labels_shape = (self.batch_size,) + tuple(self.labels_shape)[1:]

        if not self.aug_fly and self.aug_mode != 'none':
            if self.ram_load == 0:
                print("Incompatible parameter selection: ")
                # same effect as exit(1) but does not depend on the site-provided builtin
                raise SystemExit(1)

            # don't augment on the fly so generate samples now
            self.aug_images, self.aug_labels = self.setup_augnofly_data()

        self.sample_shuffle = np.arange(self.total_full_images)
        # only whole batches are produced; a trailing partial batch is dropped
        self.num_batches = int(self.total_samples // self.batch_size)

        self.handle_epoch_end()

    def _apply_aug(self, aug_ind, image, label, seg, sample_ind):
        """Run augmentation number aug_ind and return its (image, label) results."""
        aug_fn, aug_arg = self.aug_fn_args[aug_ind][0], self.aug_fn_args[aug_ind][1]
        aug_image, aug_label, _, _, _ = aug_fn(image, label, seg, aug_arg,
                                               sample_ind=sample_ind, set=self.imdb.set)
        return aug_image, aug_label

    def _advance_all_mode(self):
        """Step to the next augmentation, rolling over to the next full image after the last one."""
        self.aug_counter += 1
        if self.aug_counter == self.total_augs:
            self.aug_counter = 0
            self.full_counter += 1

    def setup_augnofly_data(self):
        """Setup augmented data to be used when aug_fly=False.

        Every augmentation is applied to every image and the results are stored as uint8.
        _________
        Returns:
            (1) array of images with
            shape: (total full images, total number of augs, image width, image height, num_channels).
            (2) array of labels with shape equal to labels_shape where the first axis is
            total full images and an axis of length total number of augs is inserted at position 1.
        _________
        """
        aug_labels_shape = (self.total_full_images, self.total_augs) + tuple(self.labels_shape)[1:]

        aug_images = np.zeros((self.total_full_images, self.total_augs, self.imdb.image_width,
                               self.imdb.image_height, self.imdb.num_channels), dtype='uint8')
        aug_labels = np.zeros(aug_labels_shape, dtype='uint8')

        for i in range(self.total_full_images):
            for j in range(self.total_augs):
                # the raw data is re-read for every augmentation so that no augmentation can
                # observe changes made by a previous one
                image = self.imdb.get_image(i)
                label = self.imdb.get_label(i)

                # NOTE(review): the segmentation argument is None here whereas get_aug_fly passes
                # imdb.get_seg(...) - confirm this difference is intended.
                aug_images[i, j], aug_labels[i, j] = self._apply_aug(j, image, label, None, i)

        return aug_images, aug_labels

    def get_aug_fly(self, sample_ind):
        """Get next sample where augmentation needs to be generated on the fly.

        Advances full_counter (and aug_counter in 'all' mode) after the sample has been produced.
        _________
        Returns:
            aug_image: next sample; must be assignable to one slot of the batch image array
            aug_label: next label; must be assignable to one slot of the batch label array
        _________
        """
        raw_image = self.imdb.get_image(sample_ind)
        raw_label = self.imdb.get_label(sample_ind)

        if self.aug_mode == 'all':
            # perform each augmentation in turn (current augmentation indicated by aug_counter)
            raw_seg = self.imdb.get_seg(sample_ind)
            aug_image, aug_label = self._apply_aug(self.aug_counter, raw_image, raw_label, raw_seg, sample_ind)
            self._advance_all_mode()
        elif self.aug_mode == 'one':
            # choose single augmentation for replacement based on probabilities
            raw_seg = self.imdb.get_seg(sample_ind)
            aug_ind_choice = np.random.choice(np.arange(self.total_augs), p=self.aug_probs)
            aug_image, aug_label = self._apply_aug(aug_ind_choice, raw_image, raw_label, raw_seg, sample_ind)
            self.full_counter += 1  # just the single random augmentation so move to the next raw image
        else:
            # no augmentation: just use the raw image and label as is (segmentation is not needed)
            aug_image, aug_label = raw_image, raw_label
            self.full_counter += 1  # move to the next image

        return aug_image, aug_label

    def get_aug_nofly(self, sample_ind):
        """Get next sample from pre-constructed augmentation data.

        Advances full_counter (and aug_counter in 'all' mode) after the sample has been produced.
        _________
        Returns:
            aug_image: next sample; must be assignable to one slot of the batch image array
            aug_label: next label; must be assignable to one slot of the batch label array
        _________
        """
        if self.aug_mode == 'all':
            # all augmentations are used
            aug_image = self.aug_images[sample_ind, self.aug_counter]
            aug_label = self.aug_labels[sample_ind, self.aug_counter]
            self._advance_all_mode()
        elif self.aug_mode == 'one':
            # just one random augmentation is used
            aug_ind_choice = np.random.choice(np.arange(self.total_augs), p=self.aug_probs)
            aug_image = self.aug_images[sample_ind, aug_ind_choice]
            aug_label = self.aug_labels[sample_ind, aug_ind_choice]
            self.full_counter += 1
        else:
            # no augmentation: the raw data is only read from the imdb in this branch
            aug_image = self.imdb.get_image(sample_ind)
            aug_label = self.imdb.get_label(sample_ind)
            self.full_counter += 1

        return aug_image, aug_label

    def get_batch_list(self):
        """Generate next batch of data

        _________
        Returns: [batch_images, batch_labels]
            batch_images: float32 array of shape (batch_size, image width, image height, num_channels),
                divided by 255 when normalise is set
            batch_labels: array of shape batch_labels_shape
            When transpose is set, axes 1 and 2 of the images (and of 4-D labels) are swapped.
        _________
        """
        batch_images = np.zeros((self.batch_size, self.imdb.image_width, self.imdb.image_height,
                                 self.imdb.num_channels), dtype='float32')
        batch_labels = np.zeros(self.batch_labels_shape)

        # on the fly: augment now; otherwise: look up the augmentation computed at construction time
        next_sample = self.get_aug_fly if self.aug_fly else self.get_aug_nofly

        for slot in range(self.batch_size):
            full_sample_ind = self.sample_shuffle[self.full_counter]
            batch_images[slot], batch_labels[slot] = next_sample(full_sample_ind)

            # wrap around once every full image has been consumed
            if self.full_counter == self.total_full_images:
                self.full_counter = 0

        # we have done another batch
        self.batch_counter += 1
        if self.batch_counter == self.num_batches:
            self.batch_counter = 0

        # normalise batch images before passing to network
        if self.normalise:
            batch_images /= 255

        if self.transpose:
            batch_images = np.transpose(batch_images, axes=(0, 2, 1, 3))
            if len(batch_labels.shape) == 4:
                # labels are masks
                batch_labels = np.transpose(batch_labels, axes=(0, 2, 1, 3))

        return [batch_images, batch_labels]

    def handle_epoch_end(self):
        """Handle the end of the epoch by resetting the augmentation index, and the counters for number of batches
        and number of images. If shuffle is enabled, shuffle the order of the raw images for the next epoch.

        Returns nothing.
        """
        self.batch_counter = 0
        self.full_counter = 0
        self.aug_counter = 0

        if self.shuffle:
            # NOTE(review): this re-seeds numpy's *global* RNG without a fixed seed, so the shuffle
            # order is not reproducible even if the caller seeded numpy - confirm this is intended.
            np.random.seed()
            order = np.random.permutation(self.total_raw_samples)
            self.sample_shuffle = self.sample_shuffle[order]
class DataGenerator(keras.utils.Sequence):
    """Generates data for Keras: see BatchGenerator for parameter details.

    Thin keras.utils.Sequence adapter around a BatchGenerator. All batching, augmentation and
    shuffling state lives in ``self.batch_gen``.

    Note that the default for ``shuffle`` here is False, while BatchGenerator itself defaults to True.
    """

    def __init__(self, imdb, batch_size, aug_fn_args, aug_mode, aug_probs, aug_fly, shuffle=False, transpose=False,
                 normalise=True, ram_load=1):
        # Initialise the Keras base class first; it takes no required arguments.
        super(DataGenerator, self).__init__()

        self.batch_gen = BatchGenerator(imdb=imdb, batch_size=batch_size,
                                        aug_fn_args=aug_fn_args, aug_mode=aug_mode,
                                        aug_probs=aug_probs, aug_fly=aug_fly, shuffle=shuffle, transpose=transpose,
                                        normalise=normalise, ram_load=ram_load)

    def __len__(self):
        """Denotes the number of batches per epoch"""
        return self.batch_gen.num_batches

    def __getitem__(self, index):
        """Generate one batch of data.

        ``index`` is ignored: the wrapped BatchGenerator is stateful and simply returns its next
        batch, so batches come back in call order rather than by index.
        """
        # Generate data
        X, y = self.__data_generation()

        return X, y

    def on_epoch_end(self):
        """Runs when a training epoch ends: resets the counters and reshuffles if enabled"""
        self.batch_gen.handle_epoch_end()

    def __data_generation(self):
        """Generates data to be used for a batch"""
        [X, y] = self.batch_gen.get_batch_list()

        return X, y