-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathhelper.py
111 lines (85 loc) · 2.98 KB
/
helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# Python Standard Library
import glob
import os
import pickle
from urllib.request import urlretrieve
import zipfile
# Public Libraries
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
# Project
import config
# Glob pattern load_data() uses to find the image files of a data set.
LOAD_PNG_PATTERN = 'sign*.png'
# File-name template extract_images() fills with the image index. The index is
# zero-padded to five digits so that the alphabetical sort in load_data()
# yields the images in index order (for indices up to 99999).
SAVE_PNG_PATTERN = 'sign_%05d.png'
class DLProgress(tqdm):
    """tqdm progress bar whose ``hook`` method is passed to ``urlretrieve``
    as the progress callback."""

    # Block count reported by the previous hook() call (0 before the first).
    last_block = 0

    def hook(self, block_num=1, block_size=1, total_size=None):
        """Advance the bar by the blocks received since the previous call.

        Args:
            block_num: Running count of blocks reported so far.
            block_size: Size of a single block.
            total_size: Value stored as the bar's total on every call.
        """
        self.total = total_size
        # Only the blocks that arrived since the last call are new progress.
        blocks_since_last_call = block_num - self.last_block
        self.update(blocks_since_last_call * block_size)
        self.last_block = block_num
def load_data(data_set):
    """Load the images and labels of one data set.

    If no file matching ``LOAD_PNG_PATTERN`` exists in the data set's image
    directory, ``extract_images()`` is called first to create the files.

    Args:
        data_set: Name of the data set. It is used as the sub-directory of
            ``config.IMAGES_DIR`` holding the images and as the stem of the
            labels CSV file in ``config.LABELS_DIR``.

    Returns:
        A tuple ``(x, y)``. ``x`` is a numpy array built from every image,
        in alphabetical file-name order. ``y`` is the content of the labels
        CSV file (read without a header row) as a numpy array.
    """
    image_pattern = os.path.join(config.IMAGES_DIR, data_set, LOAD_PNG_PATTERN)

    # Convert pickled data to human readable images
    image_files = glob.glob(image_pattern)
    if not image_files:
        extract_images()
        # Look again only when the images had to be generated.
        image_files = glob.glob(image_pattern)

    # Sort file names in alphabetical order to line up with labels
    image_files.sort()

    # Load images and save in X matrix. Convert to numpy array.
    # The context manager closes each file even if decoding raises, and
    # np.array() copies the pixel data so nothing refers to the closed image.
    images = []
    for image_file in image_files:
        with Image.open(image_file) as img:
            images.append(np.array(img))
    x = np.array(images)

    # Load labels
    labels_file = os.path.join(config.LABELS_DIR, '%s.csv' % data_set)
    y = pd.read_csv(labels_file, header=None).values

    # Return images and labels
    return x, y
def maybe_download_traffic_signs():
    """Download and extract the traffic sign archive if it is not present.

    The archive is stored as ``config.DATA_FILE`` inside ``config.DATA_DIR``.
    When that file already exists the function does nothing. Otherwise the
    archive is fetched from ``config.DATA_URL`` (with a progress bar) and its
    content is extracted into ``config.DATA_DIR``.

    Raises:
        Any exception raised by the download or the extraction is propagated.
        A partially written archive is removed first so that a later call
        downloads it again.
    """
    data_file = os.path.join(config.DATA_DIR, config.DATA_FILE)
    if os.path.exists(data_file):
        return

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(config.DATA_DIR, exist_ok=True)

    # Download Traffic Sign data
    print('Downloading Traffic Sign data...')
    try:
        with DLProgress(unit='B', unit_scale=True, miniters=1) as pbar:
            urlretrieve(
                config.DATA_URL,
                data_file,
                pbar.hook)
    except BaseException:
        # A truncated archive would otherwise pass the exists() check above
        # on the next call and never be downloaded again.
        if os.path.exists(data_file):
            os.remove(data_file)
        raise

    # Extract
    print('Extracting Traffic Sign data...')
    # The context manager closes the archive even if extraction fails.
    with zipfile.ZipFile(data_file, 'r') as zip_ref:
        zip_ref.extractall(config.DATA_DIR)
def extract_images():
    """Convert the pickled data sets into PNG images and label CSV files.

    Calls ``maybe_download_traffic_signs()`` first. Then, for every name in
    ``config.DATA_SETS``, ``<DATA_DIR>/<name>.p`` is unpickled, its
    ``'labels'`` entry is written to ``<LABELS_DIR>/<name>.csv`` (no header,
    no index column), and each item of its ``'features'`` entry is saved as a
    PNG named by ``SAVE_PNG_PATTERN`` in ``<IMAGES_DIR>/<name>/``.
    """
    # Download data
    maybe_download_traffic_signs()

    for data_set in config.DATA_SETS:
        # Load Data
        # NOTE(security): pickle.load can execute arbitrary code. The pickle
        # files are presumably the ones extracted from the downloaded
        # archive, so config.DATA_URL must point to a trusted source.
        pickle_file = os.path.join(config.DATA_DIR, '%s.p' % data_set)
        with open(pickle_file, 'rb') as f:
            data = pickle.load(f)
        x = data['features']
        y = data['labels']

        # Save to CSV. No label for columns or rows
        labels_dir = config.LABELS_DIR
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(labels_dir, exist_ok=True)
        labels_file = os.path.join(labels_dir, '%s.csv' % data_set)
        pd.DataFrame(y).to_csv(labels_file, header=False, index=False)

        # Create image directory
        directory = os.path.join(config.IMAGES_DIR, '%s' % data_set)
        os.makedirs(directory, exist_ok=True)

        # Load images and save as picture files. The running index becomes
        # the file number, so the image order matches the label rows.
        for i, pixels in enumerate(x):
            image_file = os.path.join(directory, SAVE_PNG_PATTERN % i)
            Image.fromarray(pixels).save(image_file)