# imagepreprocess.py
import numbers

import torch
from torchvision import transforms
from torchvision.transforms import functional as F

# ImageNet channel-wise mean/std used for normalization.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])


def center_crop_with_flip(img, size, vertical_flip=False):
    """Return the center crop of ``img`` together with the center crop of its
    flipped (mirrored) version."""
    crop_h, crop_w = size
    first_crop = F.center_crop(img, (crop_h, crop_w))
    if vertical_flip:
        img = F.vflip(img)
    else:
        img = F.hflip(img)
    second_crop = F.center_crop(img, (crop_h, crop_w))
    return first_crop, second_crop


class CenterCropWithFlip(object):
    """Center crop paired with the center crop of its mirror version.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of a sequence like (h, w), a square crop (size, size)
            is made.
        vertical_flip (bool): Mirror with a vertical flip instead of the
            default horizontal flip.
    """

    def __init__(self, size, vertical_flip=False):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
            self.size = size
        self.vertical_flip = vertical_flip

    def __call__(self, img):
        return center_crop_with_flip(img, self.size, self.vertical_flip)

    def __repr__(self):
        return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip)
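# A minimal usage sketch (illustrative only; 'bird.jpg' is a hypothetical file):
# CenterCropWithFlip returns a pair of PIL images, so any transform placed
# after it must accept a tuple of crops rather than a single image.
#
#     from PIL import Image
#     img = Image.open('bird.jpg')
#     crop, mirrored_crop = CenterCropWithFlip(448)(img)
#     # both crops are 448x448; the second comes from the flipped image
#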
def preprocess_strategy(dataset):
    """Return (train, val, evaluate) transforms for the given dataset name."""
    evaluate_transforms = None
    if dataset.startswith('CUB'):
        train_transforms = transforms.Compose([
            transforms.Resize(448),
            transforms.CenterCrop(448),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        val_transforms = transforms.Compose([
            transforms.Resize(448),
            transforms.CenterCrop(448),
            transforms.ToTensor(),
            normalize,
        ])
        # Evaluation stacks the center crop and its mirror into one tensor.
        evaluate_transforms = transforms.Compose([
            transforms.Resize(448),
            CenterCropWithFlip(448),
            transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
            transforms.Lambda(lambda crops: torch.stack([normalize(crop) for crop in crops])),
        ])
    elif dataset.startswith('datasets-fgvc-aircraft') or dataset.startswith('Dogs'):
        train_transforms = transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.CenterCrop(448),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        val_transforms = transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.CenterCrop(448),
            transforms.ToTensor(),
            normalize,
        ])
        evaluate_transforms = transforms.Compose([
            transforms.Resize((512, 512)),
            CenterCropWithFlip(448),
            transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
            transforms.Lambda(lambda crops: torch.stack([normalize(crop) for crop in crops])),
        ])
    elif dataset.startswith('datasets-fgvc-cars'):
        train_transforms = transforms.Compose([
            transforms.Resize((448, 448)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        val_transforms = transforms.Compose([
            transforms.Resize((448, 448)),
            transforms.ToTensor(),
            normalize,
        ])
        evaluate_transforms = transforms.Compose([
            transforms.Resize((448, 448)),
            CenterCropWithFlip(448),
            transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
            transforms.Lambda(lambda crops: torch.stack([normalize(crop) for crop in crops])),
        ])
    elif dataset.startswith('ILSVRC2012') or dataset.startswith('imagenet'):
        train_transforms = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        val_transforms = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])
        # Standard ten-crop evaluation for ImageNet.
        evaluate_transforms = transforms.Compose([
            transforms.Resize(256),
            transforms.TenCrop(224),
            transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
            transforms.Lambda(lambda crops: torch.stack([normalize(crop) for crop in crops])),
        ])
    elif dataset.startswith('cifar10'):
        train_transforms = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        val_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        evaluate_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
    else:
        raise KeyError("=> transform method of '{}' does not exist!".format(dataset))
    return train_transforms, val_transforms, evaluate_transforms
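

# A minimal usage sketch, not part of the original module: it assumes an
# ImageFolder-style directory layout under the hypothetical path
# 'data/CUB_200_2011/train', and a dataset name matching one of the prefixes
# handled above.
if __name__ == '__main__':
    from torchvision import datasets

    train_tf, val_tf, eval_tf = preprocess_strategy('CUB_200_2011')
    train_set = datasets.ImageFolder('data/CUB_200_2011/train', transform=train_tf)
    # eval_tf turns each image into a (num_crops, C, H, W) tensor, so a batch
    # from a DataLoader has shape (batch, num_crops, C, H, W); predictions are
    # typically computed per crop and then averaged:
    #     bs, ncrops, c, h, w = inputs.size()
    #     outputs = model(inputs.view(-1, c, h, w)).view(bs, ncrops, -1).mean(1)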