-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathdata.py
242 lines (180 loc) · 10.5 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
import concurrent.futures
from collections import defaultdict
import os
import numpy as np
import torch
import tqdm
from PIL import ImageFile
from torch.utils.data import Dataset
from utils.dataset_tools import get_label_set, load_dataset, load_image, check_download_dataset
import re
# Compiled once at import: matches runs of characters that are neither digits nor '.'.
non_decimal = re.compile(r"[^\d.]+")
# Allow PIL to load image files that are truncated.
ImageFile.LOAD_TRUNCATED_IMAGES = True
def remove_non_numerical_chars(input_text):
    """Return only the characters of ``str(input_text)`` for which ``str.isdigit()`` is true.

    Character order is preserved, e.g. ``"n0153a28"`` -> ``"015328"``; non-string inputs are
    converted with ``str()`` first. Note that ``isdigit`` also accepts characters such as
    superscript digits, which ``int()`` cannot parse.
    """
    return "".join(filter(str.isdigit, str(input_text)))
def augment_image(image, transforms):
    """Pipe ``image`` through every callable in ``transforms``, left to right.

    :param image: Initial value handed to the first transform.
    :param transforms: Iterable of one-argument callables; each receives the previous output.
    :return: The output of the last transform, or ``image`` itself if ``transforms`` is empty.
    """
    result = image
    for step in transforms:
        result = step(result)
    return result
class ConvertToThreeChannels(object):
    """Tile an image tensor along its first axis so that a single-channel image gets three channels.

    An input whose first dimension is already 3 is returned unchanged (the same object).
    Every other input is passed to ``pic.repeat([3, 1, 1])``. With a ``torch.Tensor`` this turns
    a ``(1, H, W)`` tensor into ``(3, H, W)`` by cloning the channel, and a 2-D ``(H, W)`` tensor
    into ``(3, H, W)``.
    NOTE(review): inputs with another channel count are tiled as well (e.g. 4 channels become
    12); the docstring of earlier versions claimed these were returned unchanged.
    """

    def __call__(self, pic):
        """
        Args:
            pic: Image tensor exposing ``.shape`` and ``.repeat`` (assumed ``torch.Tensor``,
                channel-first).
        Returns:
            ``pic`` itself when it already has 3 channels, otherwise ``pic.repeat([3, 1, 1])``.
        """
        if pic.shape[0] != 3:
            pic = pic.repeat([3, 1, 1])
        return pic
class FewShotLearningDatasetParallel(Dataset):
    """PyTorch ``Dataset`` that builds few-shot tasks (support/target sets) on demand.

    Each item returned by ``__getitem__`` is a whole task produced by ``get_set``; the reported
    length is ``num_tasks_per_epoch - current_iter``, not the number of images.
    """

    def __init__(self, dataset_name, indexes_of_folders_indicating_class, train_val_test_split,
                 labels_as_int, transforms, num_classes_per_set, num_support_sets,
                 num_samples_per_support_class, num_channels,
                 num_samples_per_target_class, seed, sets_are_pre_split,
                 load_into_memory, set_name, num_tasks_per_epoch, overwrite_classes_in_each_task,
                 class_change_interval):
        """
        Create a data provider that builds task sets for few-shot training and evaluation.

        :param dataset_name: Dataset folder name under ``$DATASET_DIR``; also passed to
            ``check_download_dataset`` and used in the label-map JSON file names.
        :param indexes_of_folders_indicating_class: Forwarded to ``load_dataset``.
        :param train_val_test_split: Forwarded to ``load_dataset``.
        :param labels_as_int: Forwarded to ``load_dataset``.
        :param transforms: Sequence of callables applied, in order, to every loaded image.
        :param num_classes_per_set: Number of classes in each set (N-way).
        :param num_support_sets: Number of support sets requested per task.
        :param num_samples_per_support_class: Support samples per class (K-shot).
        :param num_channels: Channel count images are coerced to in ``get_set``.
        :param num_samples_per_target_class: Target samples per class.
        :param seed: Base seed for sample selection; shifted by ``set_current_iter_idx``.
        :param sets_are_pre_split: Forwarded to ``load_dataset``.
        :param load_into_memory: If True, every image is loaded and transformed up front.
        :param set_name: Key selecting the split from the dict returned by ``load_dataset``.
        :param num_tasks_per_epoch: Nominal number of tasks per epoch (see ``__len__``).
        :param overwrite_classes_in_each_task: If False, the episode labels of successive class
            groups are offset so they do not collide (see ``get_set``).
        :param class_change_interval: Number of consecutive support sets that reuse the same
            group of classes.
        """
        check_download_dataset(dataset_name=dataset_name)
        dataset_path = os.path.join(os.path.abspath(os.environ['DATASET_DIR']), dataset_name)
        self.indexes_of_folders_indicating_class = indexes_of_folders_indicating_class
        self.labels_as_int = labels_as_int
        self.train_val_test_split = train_val_test_split
        self.num_samples_per_support_class = num_samples_per_support_class
        self.num_classes_per_set = num_classes_per_set
        self.num_samples_per_target_class = num_samples_per_target_class
        self.num_support_sets = num_support_sets
        self.overwrite_classes_in_each_task = overwrite_classes_in_each_task
        self.class_change_interval = class_change_interval
        # Mapping of class key -> list of samples (file paths until optionally loaded below).
        self.dataset = load_dataset(dataset_path, dataset_name, labels_as_int, seed, sets_are_pre_split,
                                    load_into_memory,
                                    indexes_of_folders_indicating_class, train_val_test_split)[set_name]
        self.num_tasks_per_epoch = num_tasks_per_epoch
        self.dataset_size_dict = {key: len(self.dataset[key]) for key in list(self.dataset.keys())}
        self.index_to_label_name_dict_file = "{}/map_to_label_name_{}.json".format(dataset_path, dataset_name)
        self.label_name_to_map_dict_file = "{}/label_name_to_map_{}.json".format(dataset_path, dataset_name)
        self.label_set = get_label_set(index_to_label_name_dict_file=self.index_to_label_name_dict_file)
        self.data_length = np.sum([len(self.dataset[key]) for key in self.dataset])
        self.num_channels = num_channels
        self.load_into_memory = load_into_memory
        self.current_iter = 0
        if self.load_into_memory:
            print('load_into_memory flag is True. Loading the {} set into memory'.format(set_name))
            dataset_loaded = defaultdict(list)
            # A single worker pool serves every class; previously a fresh pool of processes
            # was spawned and torn down once per class.
            with concurrent.futures.ProcessPoolExecutor(max_workers=4) as executor:
                with tqdm.tqdm(total=len(self.dataset)) as pbar:
                    for key, file_paths in self.dataset.items():
                        file_path_transforms_list = [(file_path, transforms) for file_path in file_paths]
                        with tqdm.tqdm(total=len(file_paths)) as pbar_process_images:
                            # executor.map yields results in input order, so sample order is kept.
                            for processed_image in executor.map(load_preprocess_image,
                                                                file_path_transforms_list):
                                dataset_loaded[key].append(processed_image)
                                pbar_process_images.update(1)
                        pbar.update(1)
            # From here on, dataset entries are preprocessed numpy arrays instead of paths.
            self.dataset = dataset_loaded
        self.seed = seed
        self.transforms = transforms
        print("data", self.data_length)

    def get_set(self, seed, class_seed, num_channels):
        """
        Generate one task made of several support/target sets.

        Classes are drawn with ``class_seed`` without replacement across the whole task, so
        class groups within a task are disjoint; samples are drawn with ``seed``. Each group of
        ``num_classes_per_set`` classes is reused for ``class_change_interval`` consecutive sets.

        :param seed: Seed for choosing which samples of each class are used.
        :param class_seed: Seed for choosing and ordering the classes of each group.
        :param num_channels: Channel count images are coerced to before stacking.
        :return: ``(x_support_set_task, x_target_set_task, y_support_set_task,
            y_target_set_task, x_task, y_task)``. With S built sets, N classes, K support and
            T target samples per class, the shapes are ``(S, 1, N, K, C, H, W)``,
            ``(S, 1, N, T, C, H, W)``, ``(S, 1, N, K)``, ``(S, 1, N, T)``,
            ``(S, N*(K+T), C, H, W)`` and ``(S, N*(K+T))``. The ``y_*_set_task`` tensors hold
            episode labels (offset per class group when ``overwrite_classes_in_each_task`` is
            False); ``y_task`` holds the digits of each sample's original class key as an int.
        """
        # NSS, CCI, N-WAY, K-SHOT, OVERWRITE
        rng = np.random.RandomState(seed)
        class_rng = np.random.RandomState(class_seed)
        samples_per_class = self.num_samples_per_support_class + self.num_samples_per_target_class
        x_support_set_task = []
        x_target_set_task = []
        y_support_set_task = []
        y_target_set_task = []
        x_task = []
        y_task = []
        # Used classes are removed so no two class groups in one task share a class.
        class_keys_copy = list(self.dataset_size_dict)
        # NOTE(review): when num_support_sets is not a multiple of class_change_interval only
        # int(num_support_sets / class_change_interval) * class_change_interval sets are built,
        # and the label-offset loop at the end raises IndexError if overwrite is disabled.
        for _ in range(int(self.num_support_sets / self.class_change_interval)):
            selected_classes = class_rng.choice(class_keys_copy,
                                                size=self.num_classes_per_set, replace=False)
            for key in selected_classes:
                class_keys_copy.remove(key)
            class_rng.shuffle(selected_classes)
            episode_labels = list(range(self.num_classes_per_set))
            class_to_episode_label = dict(zip(selected_classes, episode_labels))
            episode_label_to_orig_class = dict(zip(episode_labels, selected_classes))
            for _ in range(self.class_change_interval):
                # One rng.choice call per class, in class order (keeps seeded results stable).
                set_paths = [self.dataset[class_idx][sample_idx]
                             for class_idx in selected_classes
                             for sample_idx in rng.choice(len(self.dataset[class_idx]),
                                                          size=samples_per_class, replace=False)]
                if not self.load_into_memory:
                    x = [augment_image(load_image(image_path), transforms=self.transforms)
                         for image_path in set_paths]
                else:
                    # In-memory entries are already-preprocessed numpy arrays, not paths.
                    x = [torch.tensor(image.copy()) for image in set_paths]
                y = np.array([samples_per_class * [class_to_episode_label[class_idx]]
                              for class_idx in selected_classes])
                # Coerce channel counts: drop surplus channels, tile single-channel images.
                # NOTE(review): images with 1 < channels < num_channels are left as-is and will
                # make torch.stack fail if mixed with other shapes.
                for idx, item in enumerate(x):
                    if item.shape[0] > num_channels:
                        x[idx] = item[:num_channels]
                    elif item.shape[0] == 1 and num_channels != 1:
                        x[idx] = item.repeat([num_channels, 1, 1])
                x = torch.stack(x)
                y = torch.Tensor(y.reshape(1, self.num_classes_per_set, samples_per_class))
                x = x.view(1, self.num_classes_per_set, samples_per_class,
                           x.shape[1], x.shape[2], x.shape[3])
                x_support_set = x[:, :, :self.num_samples_per_support_class]
                y_support_set = y[:, :, :self.num_samples_per_support_class]
                x_target_set = x[:, :, self.num_samples_per_support_class:]
                y_target_set = y[:, :, self.num_samples_per_support_class:]
                x = x.view(-1, x.shape[-3], x.shape[-2], x.shape[-1])
                # Relabel each sample with the numeric part of its original class key.
                y = torch.Tensor([int(remove_non_numerical_chars(episode_label_to_orig_class[label]))
                                  for label in y.reshape(-1).numpy()])
                x_support_set_task.append(x_support_set)
                x_target_set_task.append(x_target_set)
                y_support_set_task.append(y_support_set)
                y_target_set_task.append(y_target_set)
                x_task.append(x)
                y_task.append(y)
        x_support_set_task = torch.stack(x_support_set_task, dim=0)
        x_target_set_task = torch.stack(x_target_set_task, dim=0)
        y_support_set_task = torch.stack(y_support_set_task, dim=0).long()
        y_target_set_task = torch.stack(y_target_set_task, dim=0).long()
        x_task = torch.stack(x_task, dim=0)
        y_task = torch.stack(y_task, dim=0).long()
        if not self.overwrite_classes_in_each_task:
            # Shift each class group's labels by group_index * N so labels are unique across
            # groups; sets in the same group (same i // class_change_interval) share the offset.
            for i in range(self.num_support_sets):
                offset = (i // self.class_change_interval) * self.num_classes_per_set
                y_support_set_task[i] += offset
                y_target_set_task[i] += offset
        return x_support_set_task, x_target_set_task, y_support_set_task, y_target_set_task, x_task, y_task

    def set_current_iter_idx(self, idx):
        """Resume at iteration ``idx``: adds ``idx`` to ``self.seed`` (cumulatively across calls)
        and records it so ``__len__`` shrinks accordingly."""
        self.seed = self.seed + idx
        self.current_iter = idx

    def __len__(self):
        """Number of tasks remaining in the epoch: ``num_tasks_per_epoch - current_iter``."""
        return self.num_tasks_per_epoch - self.current_iter

    def __getitem__(self, idx):
        """Build task ``idx`` with classes seeded by ``idx`` and samples by ``self.seed + idx``."""
        return self.get_set(class_seed=idx, seed=self.seed + idx,
                            num_channels=self.num_channels)
def load_preprocess_image(file_path_transform):
    """Load one image, run it through a transform pipeline, and return it as a numpy array.

    :param file_path_transform: ``(image_path, transforms)`` pair, packed into one argument so
        the function can be passed directly to ``executor.map``.
    :return: ``augment_image(load_image(image_path), transforms).numpy()``.
    """
    path, pipeline = file_path_transform
    return augment_image(load_image(path), transforms=pipeline).numpy()