Commit 0cbba47
Update aistdshadow.py
jinyeying authored Aug 17, 2024
1 parent 74fb8eb commit 0cbba47
Showing 1 changed file with 26 additions and 23 deletions.
49 changes: 26 additions & 23 deletions datasets/aistdshadow.py
@@ -9,26 +9,28 @@
 import re
 import random
 
-class AISTD:
+class AISTDShadow:
     def __init__(self, config):
         self.config = config
         self.transforms = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
 
     def get_loaders(self, parse_patches=True, validation='AISTD'):
         print("=> evaluating AISTD test set...")
-        train_dataset = ShadowDataset(dir=os.path.join(self.config.data.data_dir, 'data', 'AISTD'),
+        train_dataset = AISTDShadowDataset(dir=os.path.join(self.config.data.data_dir),
                                       n=self.config.training.patch_n,
                                       patch_size=self.config.data.image_size,
                                       transforms=self.transforms,
                                       filelist=None,
                                       parse_patches=parse_patches)
-        val_dataset = ShadowDataset(dir=os.path.join(self.config.data.data_dir, 'data', 'AISTD'),
+
+        val_dataset = AISTDShadowDataset(dir=os.path.join(self.config.data.data_dir),
                                     n=self.config.training.patch_n,
                                     patch_size=self.config.data.image_size,
                                     transforms=self.transforms,
-                                    filelist='aistdtest.txt',
+                                    filelist='/home1/yeying/DeS3_Deshadow/aistdtest.txt',
                                     parse_patches=parse_patches)
 
+
         if not parse_patches:
             self.config.training.batch_size = 1
             self.config.sampling.batch_size = 1
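
For context, a minimal sketch of how the renamed loader might be driven after this change, assuming a config object with the fields the code reads (data.data_dir, data.image_size, training.patch_n, training.batch_size, and sampling.batch_size all appear in the diff; the SimpleNamespace wrapper and the concrete values below are hypothetical, and the elided middle of get_loaders may read further fields):

    from types import SimpleNamespace

    # Hypothetical config mirroring the fields visible in the diff.
    config = SimpleNamespace(
        data=SimpleNamespace(data_dir='/path/to/AISTD', image_size=64),
        training=SimpleNamespace(patch_n=16, batch_size=8),
        sampling=SimpleNamespace(batch_size=1),
    )

    train_loader, val_loader = AISTDShadow(config).get_loaders(parse_patches=True)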
@@ -43,14 +45,16 @@ def get_loaders(self, parse_patches=True, validation='AISTD'):
         return train_loader, val_loader
 
 
-class ShadowDataset(torch.utils.data.Dataset):
+class AISTDShadowDataset(torch.utils.data.Dataset):
     def __init__(self, dir, patch_size, n, transforms, filelist=None, parse_patches=True):
         super().__init__()
-
-        if filelist is None:
+        print('-dir-',dir)
+        if filelist is None:
             shadow_dir = dir
             input_names, gt_names = [], []
 
+
+            # train filelist
             shadow_inputs = os.path.join(shadow_dir, 'trainA')
             images = [f for f in listdir(shadow_inputs) if isfile(os.path.join(shadow_inputs, f))]
             assert len(images) == 1330
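
The new assert pins the training split to exactly 1330 shadow images under trainA, the size of the AISTD/ISTD training set. A quick standalone check of that precondition (the root path here is hypothetical; the loader receives it via config.data.data_dir):

    import os

    root = '/path/to/AISTD'  # hypothetical dataset root
    train_a = os.path.join(root, 'trainA')
    n = len([f for f in os.listdir(train_a) if os.path.isfile(os.path.join(train_a, f))])
    print(n)  # the constructor asserts this equals 1330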
@@ -69,7 +73,7 @@ def __init__(self, dir, patch_size, n, transforms, filelist=None, parse_patches=
             with open(train_list) as f:
                 contents = f.readlines()
                 input_names = [i.strip() for i in contents]
-                gt_names = [i.strip().replace('input', 'gt') for i in input_names]
+                gt_names = [i.strip().replace('testA', 'testB') for i in input_names]
 
         self.input_names = input_names
         self.gt_names = gt_names
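
Ground-truth paths are now derived by swapping the testA path segment for testB, so the commit assumes the file list holds shadow-image paths containing testA, with shadow-free ground truth in a sibling testB directory. A one-line illustration (the path is hypothetical):

    p = '/path/to/AISTD/test/testA/101-1.png'  # hypothetical entry from aistdtest.txt
    print(p.replace('testA', 'testB'))         # -> /path/to/AISTD/test/testB/101-1.png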
@@ -100,33 +104,32 @@ def n_random_crops(img, x, y, h, w):
     def get_images(self, index):
         input_name = self.input_names[index]
         gt_name = self.gt_names[index]
+        # print('input_name,gt_name',input_name,gt_name)
+        datasetname = re.split('/', input_name)[-3]
+        # print('datasetname',datasetname)
         img_id = re.split('/', input_name)[-1][:-4]
-        input_img = PIL.Image.open(os.path.join(self.dir, input_name)) if self.dir else PIL.Image.open(input_name)
-        try:
-            gt_img = PIL.Image.open(os.path.join(self.dir, gt_name)) if self.dir else PIL.Image.open(gt_name)
-        except:
-            gt_img = PIL.Image.open(os.path.join(self.dir, gt_name)).convert('RGB') if self.dir else \
-                PIL.Image.open(gt_name).convert('RGB')
+        # print('img_id',img_id)
+        input_img = PIL.Image.open(input_name)
+        gt_img = PIL.Image.open(gt_name)
 
         if self.parse_patches:
+            wd_new = 512
+            ht_new = 512
+            input_img = input_img.resize((wd_new, ht_new), PIL.Image.ANTIALIAS)
+            gt_img = gt_img.resize((wd_new, ht_new), PIL.Image.ANTIALIAS)
+            # print('-input_img.shape,gt_img.shape-',input_img.size,gt_img.size)
             i, j, h, w = self.get_params(input_img, (self.patch_size, self.patch_size), self.n)
             input_img = self.n_random_crops(input_img, i, j, h, w)
             gt_img = self.n_random_crops(gt_img, i, j, h, w)
             outputs = [torch.cat([self.transforms(input_img[i]), self.transforms(gt_img[i])], dim=0)
                        for i in range(self.n)]
             return torch.stack(outputs, dim=0), img_id
         else:
-            wd_new, ht_new = input_img.size
-            if ht_new > wd_new and ht_new > 1024:
-                wd_new = int(np.ceil(wd_new * 1024 / ht_new))
-                ht_new = 1024
-            elif ht_new <= wd_new and wd_new > 1024:
-                ht_new = int(np.ceil(ht_new * 1024 / wd_new))
-                wd_new = 1024
-            wd_new = int(16 * np.ceil(wd_new / 16.0))
-            ht_new = int(16 * np.ceil(ht_new / 16.0))
+            wd_new = 256
+            ht_new = 256
             input_img = input_img.resize((wd_new, ht_new), PIL.Image.ANTIALIAS)
             gt_img = gt_img.resize((wd_new, ht_new), PIL.Image.ANTIALIAS)
+            # print(input_img.shape,gt_img.shape)
 
             return torch.cat([self.transforms(input_img), self.transforms(gt_img)], dim=0), img_id

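A portability note on the resize calls kept in both branches: PIL.Image.ANTIALIAS was deprecated in Pillow 9.1 and removed in Pillow 10, where the same filter is exposed as LANCZOS. A small compatibility shim if this code needs to run on newer Pillow (a sketch, not part of the commit):

    import PIL.Image

    # ANTIALIAS is an alias of the Lanczos filter; use whichever name this Pillow has.
    try:
        RESAMPLE = PIL.Image.Resampling.LANCZOS  # Pillow >= 9.1
    except AttributeError:
        RESAMPLE = PIL.Image.ANTIALIAS           # older Pillow

    input_img = input_img.resize((wd_new, ht_new), RESAMPLE)
    gt_img = gt_img.resize((wd_new, ht_new), RESAMPLE)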
