-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathdataloader.py
87 lines (65 loc) · 2.81 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms
import os
import numpy as np
class Load_Dataset(Dataset):
    """Map-style dataset over a dict holding ``"samples"`` and optional ``"labels"``.

    Args:
        dataset: Mapping with a ``"samples"`` entry and, optionally, a
            ``"labels"`` entry. Each value may be a torch tensor, a numpy
            array, or anything else ``torch.as_tensor`` accepts (e.g. nested
            lists). The first dimension indexes samples.

    Samples are stored as float32 (``.float()``) and labels as int64
    (``.long()``). Each item is a dict ``{"samples": tensor, "labels": int}``
    whose sample tensor has a trailing size-1 dimension squeezed away; the
    ``"labels"`` key is omitted when the dataset carries no labels.

    Raises:
        ValueError: if labels are given and their count differs from the
            number of samples.
    """

    def __init__(self, dataset):
        super().__init__()
        # as_tensor is a no-op for tensors, shares memory with numpy arrays
        # (like from_numpy) and also accepts plain Python sequences.
        x_data = torch.as_tensor(dataset["samples"])

        # Labels are optional (e.g. for unlabeled / inference splits).
        y_data = dataset.get("labels")
        if y_data is not None:
            y_data = torch.as_tensor(y_data)
            if y_data.shape[0] != x_data.shape[0]:
                raise ValueError(
                    f"got {x_data.shape[0]} samples but {y_data.shape[0]} labels"
                )

        self.x_data = x_data.float()
        self.y_data = y_data.long() if y_data is not None else None
        self.len = x_data.shape[0]

    def get_labels(self):
        """Return the label tensor, or ``None`` for an unlabeled dataset."""
        return self.y_data

    def __getitem__(self, index):
        sample = {"samples": self.x_data[index].squeeze(-1)}
        # Only emit a label when one exists; unlabeled datasets used to crash
        # here with int(None).
        if self.y_data is not None:
            sample["labels"] = int(self.y_data[index])
        return sample

    def __len__(self):
        return self.len
def data_generator(data_path, data_type, hparams):
    """Build train/val/test DataLoaders plus class weights for one dataset.

    Loads ``train.pt``, ``val.pt`` and ``test.pt`` from
    ``os.path.join(data_path, data_type)`` with ``torch.load`` and wraps each
    in ``Load_Dataset``.

    Args:
        data_path: Root directory containing one sub-directory per dataset.
        data_type: Name of the sub-directory to load the splits from.
        hparams: Mapping that must provide ``"batch_size"``.

    Returns:
        ``(train_loader, val_loader, test_loader, class_weight)`` where
        ``class_weight`` is ``get_class_weight`` applied to the per-label
        sample counts of the training split.
    """
    split_dir = os.path.join(data_path, data_type)
    # NOTE(review): torch.load unpickles the file, so only use it on trusted
    # data. On newer torch releases ``weights_only`` defaults to True, which
    # may reject files that store numpy arrays -- confirm against the saved files.
    train_dataset = Load_Dataset(torch.load(os.path.join(split_dir, "train.pt")))
    val_dataset = Load_Dataset(torch.load(os.path.join(split_dir, "val.pt")))
    test_dataset = Load_Dataset(torch.load(os.path.join(split_dir, "test.pt")))

    # Count only labels that actually occur in the training split, keyed by
    # the label value itself. Iterating range(number_of_distinct_labels)
    # would assume labels are exactly 0..K-1: with gaps (e.g. {0, 2}) it
    # records a zero count for the absent label (a division by zero in
    # get_class_weight) and drops the highest labels.
    labels, counts = np.unique(train_dataset.y_data.numpy(), return_counts=True)
    cw_dict = {int(label): int(count) for label, count in zip(labels, counts)}

    batch_size = hparams["batch_size"]
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                              shuffle=True, drop_last=True, num_workers=0)
    # NOTE(review): drop_last=True means the last partial validation batch is
    # never evaluated -- confirm this is intended.
    val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size,
                            shuffle=False, drop_last=True, num_workers=0)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size,
                             shuffle=False, drop_last=False, num_workers=0)

    return train_loader, val_loader, test_loader, get_class_weight(cw_dict)
import math
def get_class_weight(labels_dict, min_weight=1.0):
    """Compute log-scaled class weights from per-class sample counts.

    Each class receives ``max(log(max_count / count), min_weight)``. Classes
    that are rarer than the most frequent class by more than a factor of
    ``e ** min_weight`` are up-weighted; every other class gets ``min_weight``.

    Args:
        labels_dict: Mapping from class key to its sample count.
        min_weight: Lower bound applied to every weight (default 1.0).

    Returns:
        dict mapping each key of ``labels_dict`` to its float weight. An
        empty mapping yields an empty dict.

    Raises:
        ValueError: if any count is not strictly positive (a zero count would
            otherwise divide by zero).
    """
    if not labels_dict:
        return {}

    non_positive = {key: value for key, value in labels_dict.items() if value <= 0}
    if non_positive:
        raise ValueError(f"class counts must be positive, got {non_positive}")

    max_num = max(labels_dict.values())
    # log(mu * total / count) with mu = max_num / total: the total cancels,
    # leaving log(max_num / count).
    return {
        key: max(math.log(max_num / float(value)), min_weight)
        for key, value in labels_dict.items()
    }