-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerate_cifar_tfrecords.py
120 lines (96 loc) · 3.79 KB
/
generate_cifar_tfrecords.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""Read CIFAR-10/100 data from pickled numpy arrays and writes TFRecords.
Generates tf.train.Example protos and writes them to TFRecord files from the
python version of the CIFAR-10/100 dataset downloaded from
https://www.cs.toronto.edu/~kriz/cifar.html.
__author__ = "MM. Kamani"
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import sys
import tarfile
from six.moves import cPickle as pickle
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
CIFAR_DOWNLOAD_URL = 'https://www.cs.toronto.edu/~kriz/'
def download_and_extract(data_dir, dataset):
  """Download the requested CIFAR archive (if absent) and unpack it.

  Args:
    data_dir: Directory into which the archive is downloaded and extracted.
    dataset: Either 'cifar10' or 'cifar100'; selects which archive to fetch.
  """
  archive_name = ('cifar-10-python.tar.gz' if dataset == 'cifar10'
                  else 'cifar-100-python.tar.gz')
  # maybe_download is a no-op when the archive already exists locally.
  tf.contrib.learn.datasets.base.maybe_download(
      archive_name, data_dir, CIFAR_DOWNLOAD_URL + archive_name)
  # Fix: the original opened the tarfile without ever closing it, leaking
  # the file handle; the context manager guarantees it is closed.
  with tarfile.open(os.path.join(data_dir, archive_name), 'r:gz') as tar:
    tar.extractall(data_dir)
def _int64_feature(value):
  """Wrap a single integer in a tf.train.Feature (int64 list of length 1)."""
  int_list = tf.train.Int64List(value=[value])
  return tf.train.Feature(int64_list=int_list)
def _bytes_feature(value):
  """Wrap a single bytes value in a tf.train.Feature (bytes list of length 1)."""
  byte_list = tf.train.BytesList(value=[value])
  return tf.train.Feature(bytes_list=byte_list)
def _get_file_names(dataset):
"""Returns the file names expected to exist in the input_dir."""
file_names = {}
if dataset == 'cifar10':
file_names['train'] = ['data_batch_%d' % i for i in xrange(1, 5)]
file_names['validation'] = ['data_batch_5']
file_names['eval'] = ['test_batch']
else:
file_names['train'] = ['train']
file_names['eval'] = ['test']
return file_names
def read_pickle_from_file(filename):
  """Load one pickled CIFAR batch dict via tf.gfile (supports GCS paths)."""
  with tf.gfile.Open(filename, 'rb') as handle:
    # The CIFAR pickles were produced by Python 2, so Python 3 must decode
    # their str payloads as bytes.
    if sys.version_info >= (3, 0):
      return pickle.load(handle, encoding='bytes')
    return pickle.load(handle)
def convert_to_tfrecord(input_files, output_file, dataset):
  """Serialize pickled CIFAR batches into a single TFRecord file.

  Args:
    input_files: Paths of the pickle batch files to convert.
    output_file: Destination TFRecord path.
    dataset: 'cifar10' (uses b'labels') or 'cifar100' (uses b'fine_labels').
  """
  print('Generating %s' % output_file)
  # Dict keys are bytes because the pickles come from Python 2.
  label_key = b'labels' if dataset == 'cifar10' else b'fine_labels'
  with tf.python_io.TFRecordWriter(output_file) as record_writer:
    for input_file in input_files:
      batch = read_pickle_from_file(input_file)
      images = batch[b'data']
      labels = batch[label_key]
      for image, label in zip(images, labels):
        features = tf.train.Features(feature={
            'image': _bytes_feature(image.tobytes()),
            'label': _int64_feature(label),
        })
        record_writer.write(
            tf.train.Example(features=features).SerializeToString())
def main(data_dir, dataset='cifar10'):
  """Download a CIFAR dataset and rewrite each split as TFRecords.

  Args:
    data_dir: Working directory for the download, extraction, and outputs.
    dataset: 'cifar10' (default) or 'cifar100'.
  """
  print('Download from {} and extract.'.format(CIFAR_DOWNLOAD_URL))
  download_and_extract(data_dir, dataset)
  local_folder = ('cifar-10-batches-py' if dataset == 'cifar10'
                  else 'cifar-100-python')
  input_dir = os.path.join(data_dir, local_folder)
  for mode, files in _get_file_names(dataset).items():
    input_files = [os.path.join(input_dir, name) for name in files]
    output_file = os.path.join(data_dir, mode + '.tfrecords')
    # Remove any stale output so the writer starts from a clean file;
    # a missing file is fine.
    try:
      os.remove(output_file)
    except OSError:
      pass
    # Convert to tf.train.Example protos and write them to TFRecords.
    convert_to_tfrecord(input_files, output_file, dataset)
  print('Done!')
if __name__ == '__main__':
  # Command-line entry point: choose the dataset and the working directory.
  arg_parser = argparse.ArgumentParser()
  arg_parser.add_argument(
      '--data-dir', type=str, default='',
      help='Directory to download and extract CIFAR-10 to.')
  arg_parser.add_argument(
      '--dataset', type=str, default='cifar10',
      choices=['cifar10', 'cifar100'],
      help='The dataset to transfer to TFRecords')
  flags = arg_parser.parse_args()
  main(flags.data_dir, flags.dataset)