-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathload_dataset.py
107 lines (90 loc) · 3.3 KB
/
load_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""
Load datasets from TFRecords
__author__ = "MM. Kamani"
"""
import os
import numpy as np
import tensorflow as tf
class CifarDataset():
    """Reads CIFAR-10 / CIFAR-100 examples from TFRecord files (TF1-style API).

    Expects one ``<subset>.tfrecords`` file per subset under ``data_dir``,
    where each record holds a raw uint8 image buffer ('image') and an
    int64 class index ('label').
    """

    def __init__(self,
                 data_dir,
                 subset='train',
                 use_distortion=True,
                 dataset='cifar10'):
        """
        Args:
          data_dir: Directory containing the '<subset>.tfrecords' files.
          subset: One of 'train', 'validation', 'eval'.
          use_distortion: If True, apply random crop/flip augmentation to
            the 'train' subset (see `preprocess`).
          dataset: 'cifar10' or 'cifar100'.

        Raises:
          ValueError: If `dataset` is not a supported dataset name.
        """
        self.data_dir = data_dir
        self.subset = subset
        self.use_distortion = use_distortion
        if dataset == 'cifar10':
            self.num_class = 10
        elif dataset == 'cifar100':
            self.num_class = 100
        else:
            # Fail fast: previously an unknown dataset silently left
            # num_class unset, deferring the error to an AttributeError
            # at some later use.
            raise ValueError('Invalid dataset "%s"' % dataset)
        # CIFAR images are fixed at 32x32x3.
        self.WIDTH = 32
        self.HEIGHT = 32
        self.DEPTH = 3

    def get_filenames(self, subset):
        """Return the list of TFRecord file paths for `subset`.

        Raises:
          ValueError: If `subset` is not 'train', 'validation', or 'eval'.
        """
        if subset in ['train', 'validation', 'eval']:
            return [os.path.join(self.data_dir, subset + '.tfrecords')]
        else:
            raise ValueError('Invalid data subset "%s"' % subset)

    def parser(self, serialized_example):
        """Parses a single tf.Example into image and label tensors.

        Returns:
          A (image, label) pair: image is a float32 tensor of shape
          [HEIGHT, WIDTH, DEPTH] scaled to roughly [-1, 1), label is int32.
        """
        features = tf.parse_single_example(
            serialized_example,
            features={
                'image': tf.FixedLenFeature([], tf.string),
                'label': tf.FixedLenFeature([], tf.int64),
            })
        image = tf.decode_raw(features['image'], tf.uint8)
        # Scale uint8 pixels [0, 255] to roughly [-1, 1).
        image = tf.cast(image, tf.float32) / 128.0 - 1
        image.set_shape([self.HEIGHT * self.WIDTH * self.DEPTH])
        # Already float32 after the scaling above; the original re-cast
        # to float32 here was redundant.
        image = tf.reshape(image, [self.HEIGHT, self.WIDTH, self.DEPTH])
        image = self.preprocess(image)
        label = tf.cast(features['label'], tf.int32)
        return image, label

    def make_batch(self, batch_size):
        """Read the images and labels from 'filenames'."""
        return self._create_tfiterator(batch_size, self.subset)

    def _create_tfiterator(self, batch_size, subset):
        """Build a one-shot iterator yielding (image_batch, label_batch)."""
        filenames = self.get_filenames(subset=subset)
        dataset = tf.data.TFRecordDataset(filenames)
        dataset = dataset.repeat()
        # Parse records.
        dataset = dataset.map(
            self.parser, num_parallel_calls=batch_size)
        # Ensure that the capacity is sufficiently large to provide good
        # random shuffling.
        dataset = dataset.shuffle(buffer_size=3 * batch_size)
        # Batch it up.
        dataset = dataset.batch(batch_size)
        iterator = dataset.make_one_shot_iterator()
        image_batch, label_batch = iterator.get_next()
        return image_batch, label_batch

    def preprocess(self, image):
        """Preprocess a single image in [height, width, depth] layout.

        Augmentation (pad -> random crop -> random flip) is applied only
        for the 'train' subset when use_distortion is set.
        """
        if self.subset == 'train' and self.use_distortion:
            # Pad 4 pixels on each dimension of feature map, done in
            # mini-batch.
            image = tf.image.resize_image_with_crop_or_pad(image, 40, 40)
            image = tf.random_crop(image, [self.HEIGHT, self.WIDTH, self.DEPTH])
            image = tf.image.random_flip_left_right(image)
        return image

    @staticmethod
    def num_examples_per_epoch(subset='train', dataset='cifar10'):
        """Return the number of examples in `subset` of `dataset`.

        Raises:
          ValueError: If `subset` or `dataset` is not recognized.
        """
        if dataset == 'cifar10':
            if subset == 'train':
                return 45000
            elif subset == 'validation':
                return 5000
            elif subset == 'eval':
                return 10000
            else:
                raise ValueError('Invalid data subset "%s"' % subset)
        elif dataset == 'cifar100':
            if subset == 'train':
                return 50000
            elif subset == 'validation':
                return 0
            elif subset == 'eval':
                return 10000
            else:
                raise ValueError('Invalid data subset "%s"' % subset)
        else:
            # Previously fell off the end and returned None for an
            # unknown dataset; raise to match the subset handling.
            raise ValueError('Invalid dataset "%s"' % dataset)