-
Notifications
You must be signed in to change notification settings - Fork 36
/
Copy pathdata.py
40 lines (32 loc) · 1.16 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
from torch.utils.data import Dataset
from torchvision.datasets import MNIST, FashionMNIST, CIFAR10
from torchvision import transforms
class DiffSet(Dataset):
    """Image-only (no labels) dataset over MNIST, FashionMNIST or CIFAR10.

    The whole split is loaded into memory once at construction time. It is
    converted to float, rescaled from [0, 255] to [-1, 1], and stored
    channels-first in ``self.input_seq`` with shape ``(N, depth, size, size)``.
    MNIST/FashionMNIST images are zero-padded by 2 pixels on each side, so
    every supported dataset yields 32x32 images.

    Attributes:
        depth: number of image channels (1 for "MNIST"/"Fashion", 3 for "CIFAR").
        size: spatial height/width of each image (always 32).
        dataset_len: number of images in the loaded split.
        input_seq: float tensor holding every image of the split.
    """

    def __init__(self, train, dataset="MNIST", root="./data"):
        """Download (if needed) and preprocess one split of ``dataset``.

        Args:
            train: forwarded to the torchvision dataset's ``train`` flag
                (True selects the training split).
            dataset: one of ``"MNIST"``, ``"Fashion"`` or ``"CIFAR"``.
            root: directory the torchvision dataset is stored in / downloaded to.

        Raises:
            KeyError: if ``dataset`` is not one of the supported names.
        """
        datasets = {
            "MNIST": MNIST,
            "Fashion": FashionMNIST,
            "CIFAR": CIFAR10,
        }
        if dataset not in datasets:
            # Fail before any download, and list the valid choices.
            raise KeyError(
                f"unknown dataset {dataset!r}; expected one of {sorted(datasets)}"
            )

        # No per-item transform is passed: the raw ``.data`` attribute is read
        # directly below, so a torchvision transform would never be applied.
        source = datasets[dataset](root, download=True, train=train)
        self.dataset_len = len(source.data)

        if dataset == "CIFAR":
            # Channels-last images (channel axis moved to 1 below); cast to float32.
            data = torch.as_tensor(source.data, dtype=torch.float32)
            self.depth = 3
        else:
            # "MNIST" / "Fashion": pad 2 pixels per side (28 -> 32 assumed from
            # size=32) and add a trailing channel axis so both branches are
            # channels-last before the moveaxis below.
            data = transforms.Pad(2)(source.data).unsqueeze(3)
            self.depth = 1
        self.size = 32

        # Map [0, 255] -> [-1, 1], then move channels to axis 1: (N, C, H, W).
        # NOTE: the full split is held in memory as a float tensor.
        self.input_seq = ((data / 255.0) * 2.0 - 1.0).moveaxis(3, 1)

    def __len__(self):
        """Return the number of images in the loaded split."""
        return self.dataset_len

    def __getitem__(self, item):
        """Return the preprocessed image tensor(s) at index ``item`` (no label)."""
        return self.input_seq[item]