-
Notifications
You must be signed in to change notification settings - Fork 146
/
Copy pathdata_loader.py
executable file
·65 lines (54 loc) · 2.06 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from torch.utils import data
import os
from os.path import join, abspath, splitext, split, isdir, isfile
from PIL import Image
import numpy as np
import cv2
def prepare_image_PIL(im):
    """Convert an RGB (H x W x 3) image into a mean-subtracted BGR (3 x H x W) array.

    The input array is never modified; all work happens on a private copy.

    Args:
        im: array-like of shape (H, W, 3) with channels in RGB order.
            Floating-point inputs keep their dtype; any other dtype
            (e.g. uint8) is promoted to float32 so the mean subtraction
            is well defined.

    Returns:
        numpy.ndarray of shape (3, H, W), channels in BGR order, with the
        per-channel offsets subtracted.

    Raises:
        ValueError: if ``im`` is not a 3-D array with 3 channels on the last axis.
    """
    # Offsets for the B, G and R channels respectively
    # (presumably the training-set channel means -- confirm against the model).
    mean_bgr = np.array((104.00698793, 116.66876762, 122.67891434))

    im = np.asarray(im)
    if im.ndim != 3 or im.shape[2] != 3:
        raise ValueError(
            "expected an image of shape (H, W, 3), got shape %s" % (im.shape,)
        )

    out_dtype = im.dtype if np.issubdtype(im.dtype, np.floating) else np.float32
    # Reversing the last axis turns RGB into BGR; np.array() materialises the
    # reversed view as a fresh array so the caller's data stays untouched.
    bgr = np.array(im[:, :, ::-1], dtype=out_dtype)
    bgr -= mean_bgr
    return np.transpose(bgr, (2, 0, 1))  # (H x W x C) to (C x H x W)
def prepare_image_cv2(im):
    """Mean-subtract a BGR (H x W x 3) image and return it as (3 x H x W).

    Unlike a bare in-place ``im -= mean``, this works on a private copy, so
    the array passed in by the caller is left unchanged.

    Args:
        im: array-like of shape (H, W, 3) whose channels are already in BGR
            order (the order OpenCV's ``imread`` produces). Floating-point
            inputs keep their dtype; any other dtype (e.g. uint8) is
            promoted to float32.

    Returns:
        numpy.ndarray of shape (3, H, W) with the per-channel offsets subtracted.

    Raises:
        ValueError: if ``im`` is not a 3-D array with 3 channels on the last axis.
    """
    # Offsets for the B, G and R channels respectively
    # (presumably the training-set channel means -- confirm against the model).
    mean_bgr = np.array((104.00698793, 116.66876762, 122.67891434))

    im = np.asarray(im)
    if im.ndim != 3 or im.shape[2] != 3:
        raise ValueError(
            "expected an image of shape (H, W, 3), got shape %s" % (im.shape,)
        )

    out_dtype = im.dtype if np.issubdtype(im.dtype, np.floating) else np.float32
    shifted = np.array(im, dtype=out_dtype)  # np.array copies by default
    shifted -= mean_bgr
    return np.transpose(shifted, (2, 0, 1))  # (H x W x C) to (C x H x W)
class BSDS_RCFLoader(data.Dataset):
    """Dataset over the BSDS500 / PASCAL edge-detection file lists.

    For ``split='train'`` each list entry is ``"<image path> <label path>"``
    (paths relative to ``root``) and an item is an ``(image, label)`` pair.
    For ``split='test'`` each entry is a single image path and an item is the
    image alone.

    Args:
        root: directory containing the list files and the paths they reference.
        split: ``'train'`` (reads ``bsds_pascal_train_pair.lst``) or
            ``'test'`` (reads ``test.lst``).
        transform: stored on the instance as ``self.transform``; this class
            does not apply it anywhere.

    Raises:
        ValueError: if ``split`` is neither ``'train'`` nor ``'test'``.
    """

    def __init__(self, root='data/HED-BSDS_PASCAL', split='train', transform=False):
        self.root = root
        self.split = split
        self.transform = transform
        if self.split == 'train':
            list_path = join(self.root, 'bsds_pascal_train_pair.lst')
        elif self.split == 'test':
            list_path = join(self.root, 'test.lst')
        else:
            raise ValueError("Invalid split type!")
        with open(list_path, 'r') as f:
            # Drop blank lines (e.g. a trailing empty line) so every entry
            # names a real sample; non-blank lines are kept verbatim.
            self.filelist = [line for line in f if line.strip()]

    def __len__(self):
        """Return the number of samples listed for this split."""
        return len(self.filelist)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            ``(image, label)`` for the train split, where ``image`` is a
            (3, H, W) float32 array and ``label`` a (1, H, W) float32 array;
            only ``image`` for any other split.

        Raises:
            IOError: if OpenCV cannot read a training image.
            ValueError: if a label image does not reduce to a 2-D map.
        """
        entry = self.filelist[index]
        if self.split == "train":
            img_file, lb_file = entry.split()
            img_path = join(self.root, img_file)
            raw = cv2.imread(img_path)
            if raw is None:
                # cv2.imread signals failure by returning None, not by raising.
                raise IOError("cannot read image file: %s" % img_path)
            img = prepare_image_cv2(np.array(raw, dtype=np.float32))
            lb = self._load_label(join(self.root, lb_file))
            return img, lb

        img_file = entry.strip()
        with Image.open(join(self.root, img_file)) as pil_img:
            # Force 3-channel RGB so grayscale / palette / RGBA files load the
            # same way cv2.imread's default mode does on the train path.
            img = np.array(pil_img.convert('RGB'), dtype=np.float32)
        return prepare_image_PIL(img)

    @staticmethod
    def _load_label(path):
        """Read a label image and recode it to {0, 1, 2} with shape (1, H, W)."""
        with Image.open(path) as pil_lb:
            lb = np.array(pil_lb, dtype=np.float32)
        if lb.ndim == 3:
            # Multi-channel label: only the first channel is used.
            lb = lb[:, :, 0]
        if lb.ndim != 2:
            raise ValueError(
                "label %s has unsupported shape %s" % (path, lb.shape)
            )
        lb = lb[np.newaxis, :, :]
        # Recode pixel values: 0 stays 0, values in (0, 128) become 2 and
        # values >= 128 become 1 (presumably 2 marks pixels the loss should
        # ignore and 1 marks edges -- confirm against the training code).
        # Order matters: the freshly written 2s are below 128 and so are not
        # touched by the second assignment.
        lb[np.logical_and(lb > 0, lb < 128)] = 2
        lb[lb >= 128] = 1
        return lb