-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcustom_dataset.py
56 lines (49 loc) · 2.26 KB
/
custom_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from torchvision.datasets.folder import *
import numpy as np
import torch
import os
class CustomDataset(DatasetFolder):
    """Image-folder dataset whose items bundle several images together.

    Builds on torchvision's ``DatasetFolder`` (restricted to
    ``IMG_EXTENSIONS``) and changes what ``__getitem__`` returns,
    depending on ``target_num``:

    * ``target_num is None`` ("pair mode"): the indexed image plus one
      randomly chosen image from a *different* class.
    * ``target_num == k``: the indexed image plus ``k - 1`` images drawn
      uniformly *with replacement* from the *same* class (the indexed image
      itself may be drawn again).

    Images are concatenated along dim 0 with ``torch.cat``. This presumably
    expects ``transform`` to yield tensors shaped like (C, H, W), so the
    group ends up stacked along the channel axis (TODO confirm with callers).

    Args:
        root: Root directory passed to ``DatasetFolder``.
        transform: Optional callable applied to every loaded image.
        target_transform: Optional callable applied to every class index.
        loader: Callable that loads an image from a file path.
        target_num: ``None`` for pair mode, otherwise the total number of
            same-class images per item (must be >= 1).

    Raises:
        ValueError: If ``target_num`` is given and is smaller than 1.
    """

    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader, target_num=None):
        # Validate before DatasetFolder scans the directory tree.
        if target_num is not None and target_num < 1:
            raise ValueError(
                "target_num must be None or >= 1, got {!r}".format(target_num))
        super(CustomDataset, self).__init__(root, loader, IMG_EXTENSIONS,
                                            transform=transform,
                                            target_transform=target_transform)
        self.imgs = self.samples
        self.target_num = target_num
        # Map raw class index -> positions in self.samples. Built once so that
        # __getitem__ no longer converts the whole sample list on every call.
        # NOTE: becomes stale if self.samples is mutated after construction.
        self._class_to_indices = {}
        for position, (_, label) in enumerate(self.samples):
            self._class_to_indices.setdefault(label, []).append(position)

    def _load_image(self, path):
        """Load the image at ``path`` and apply ``self.transform`` if set."""
        image = self.loader(path)
        if self.transform is not None:
            image = self.transform(image)
        return image

    def _encode_target(self, label):
        """Apply ``self.target_transform`` (if set) to one raw class index."""
        if self.target_transform is not None:
            label = self.target_transform(label)
        return label

    def __getitem__(self, index):
        """Return ``(sample, target, f_name)`` for the item at ``index``.

        ``sample`` is the dim-0 concatenation of the loaded images.
        ``target`` is a 1-D ``torch.LongTensor`` with one (transformed)
        class index per image, in the same order. ``f_name`` is the base
        name of the indexed image's file.

        Raises:
            ValueError: In pair mode (``target_num is None``) when the
                dataset has fewer than two classes, so no different-class
                partner exists.
        """
        path, label = self.samples[index]
        f_name = os.path.basename(path)
        images = [self._load_image(path)]
        targets = [self._encode_target(label)]

        if self.target_num is None:
            # Guard: with a single class the rejection loop below never ends.
            if len(self._class_to_indices) < 2:
                raise ValueError(
                    "pair mode needs at least two classes to draw a partner")
            # Rejection-sample a partner from a different class. Labels are
            # compared in raw form (before target_transform), so the check
            # stays correct even when target_transform changes the values.
            while True:
                path2, label2 = self.samples[
                    np.random.choice(len(self.samples), 1)[0]]
                if label2 != label:
                    break
            images.append(self._load_image(path2))
            targets.append(self._encode_target(label2))
        else:
            # Draw target_num - 1 extra same-class samples with replacement.
            same_class = self._class_to_indices[label]
            picks = np.random.choice(len(same_class), self.target_num - 1)
            for pick in picks:
                path2, label2 = self.samples[same_class[pick]]
                images.append(self._load_image(path2))
                # target_transform applied once per element, not to an array.
                targets.append(self._encode_target(label2))

        sample = torch.cat(images, 0)
        target = torch.LongTensor(targets)
        return sample, target, f_name