data_utils.py
import random
import albumentations as A
import cv2
import numpy as np
import torch
from data_custom_augmentations import SafeHorizontalFlip, SafePerspective
from utils import GRID_SIZE, IMG_SIZE


def get_appearance_transform(transform_types):
    """
    Returns an albumentations Compose augmentation.
    transform_types is a list containing the types of pixel-wise data augmentation to use.
    Possible augmentations are 'shadow', 'blur', 'visual', 'noise' and 'color'.
    """
    transforms = []

    if "shadow" in transform_types:
        transforms.append(A.RandomShadow(p=0.1))
    if "blur" in transform_types:
        # Within A.OneOf, the per-transform p values act as relative selection weights.
        transforms.append(
            A.OneOf(
                transforms=[
                    A.Defocus(p=5),
                    A.Downscale(p=15, interpolation=cv2.INTER_LINEAR),
                    A.GaussianBlur(p=65),
                    A.MedianBlur(p=15),
                ],
                p=0.75,
            )
        )
    if "visual" in transform_types:
        transforms.append(
            A.OneOf(
                transforms=[
                    A.ToSepia(p=15),
                    A.ToGray(p=20),
                    A.Equalize(p=15),
                    A.Sharpen(p=20),
                ],
                p=0.5,
            )
        )
    if "noise" in transform_types:
        transforms.append(
            A.OneOf(
                transforms=[
                    A.GaussNoise(var_limit=(10.0, 20.0), p=70),
                    A.ISONoise(intensity=(0.1, 0.25), p=30),
                ],
                p=0.6,
            )
        )
    if "color" in transform_types:
        transforms.append(
            A.OneOf(
                transforms=[
                    A.ColorJitter(p=5),
                    A.HueSaturationValue(p=10),
                    A.RandomBrightnessContrast(brightness_limit=[-0.05, 0.25], p=85),
                ],
                p=0.95,
            )
        )

    return A.Compose(transforms=transforms)
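
# Usage sketch (illustrative, not part of the original file): the returned Compose is
# called with keyword arguments and yields a dict holding the augmented image.
# `rgb_img` below is a placeholder for an H x W x 3 uint8 array.
#
#   appearance = get_appearance_transform(["blur", "noise", "color"])
#   augmented = appearance(image=rgb_img)["image"]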


def get_geometric_transform(transform_types, gridsize):
    """
    Returns an albumentations ReplayCompose augmentation.
    transform_types is a list containing the types of geometric data augmentation to use.
    Possible augmentations are 'rotate', 'flip' and 'perspective'.
    """
    transforms = []

    if "rotate" in transform_types:
        transforms.append(
            A.SafeRotate(
                limit=[-30, 30],
                interpolation=cv2.INTER_LINEAR,
                border_mode=cv2.BORDER_REPLICATE,
                p=0.5,
            )
        )
    if "flip" in transform_types:
        transforms.append(SafeHorizontalFlip(gridsize=gridsize, p=0.25))
    if "perspective" in transform_types:
        transforms.append(SafePerspective(p=0.5))

    return A.ReplayCompose(
        transforms=transforms,
        keypoint_params=A.KeypointParams(format="xy", remove_invisible=False),
    )
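
# Usage sketch (illustrative): the ReplayCompose is called with the image and the
# flattened 2D grid as keypoints; the "replay" entry records the sampled parameters
# so the exact same warp can be re-applied to other data. Names are placeholders.
#
#   geometric = get_geometric_transform(["rotate", "flip"], gridsize=GRID_SIZE)
#   out = geometric(image=rgb_img, keypoints=grid2D.reshape(2, -1).T)
#   img_aug, kps_aug, replay = out["image"], out["keypoints"], out["replay"]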


def crop_image_tight(img, grid2D):
    """
    Crops the image tightly around the keypoints in grid2D,
    i.e. creates a tight crop around the document in the image.
    """
    size = img.shape

    minx = np.floor(np.amin(grid2D[0, :, :])).astype(int)
    maxx = np.ceil(np.amax(grid2D[0, :, :])).astype(int)
    miny = np.floor(np.amin(grid2D[1, :, :])).astype(int)
    maxy = np.ceil(np.amax(grid2D[1, :, :])).astype(int)
    s = 20
    s = min(min(s, minx), miny)  # s should not exceed the natural padding actually available
    s = min(min(s, size[1] - 1 - maxx), size[0] - 1 - maxy)

    # Crop the image slightly larger than strictly necessary
    img = img[miny - s : maxy + s, minx - s : maxx + s, :]

    # Randomly trim a few extra pixels on each side to vary the crop tightness
    cx1 = random.randint(0, max(s - 5, 1))
    cx2 = random.randint(0, max(s - 5, 1)) + 1
    cy1 = random.randint(0, max(s - 5, 1))
    cy2 = random.randint(0, max(s - 5, 1)) + 1
    img = img[cy1:-cy2, cx1:-cx2, :]

    top = miny - s + cy1
    bot = size[0] - maxy - s + cy2
    left = minx - s + cx1
    right = size[1] - maxx - s + cx2
    return img, top, bot, left, right
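
# Usage sketch (illustrative): the returned margins describe how many pixels were
# removed on each side of the original image, which crop_tight below uses to
# re-normalize the grid coordinates.
#
#   cropped, top, bot, left, right = crop_image_tight(rgb_img, grid2D)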


class BaseDataset(torch.utils.data.Dataset):
    """
    Base torch dataset class for all unwarping datasets.
    """

    def __init__(
        self,
        data_path,
        appearance_augmentation=[],
        img_size=IMG_SIZE,
        grid_size=GRID_SIZE,
    ) -> None:
        super().__init__()
        self.dataroot = data_path
        self.img_size = img_size
        self.grid_size = grid_size
        self.normalize_3Dgrid = True
        self.appearance_transform = get_appearance_transform(appearance_augmentation)
        self.all_samples = []

    def __len__(self):
        return len(self.all_samples)

    def crop_tight(self, img_RGB, grid2D):
        # The incoming grid2D array is expressed in pixel coordinates
        # (resolution of img_RGB before crop/resize).
        size = img_RGB.shape
        img, top, bot, left, right = crop_image_tight(img_RGB, grid2D)
        img = cv2.resize(img, self.img_size)
        img = img.transpose(2, 0, 1)
        img = torch.from_numpy(img).float()

        # Re-express grid2D relative to the crop and normalize it to [-1, 1].
        grid2D[0, :, :] = (grid2D[0, :, :] - left) / (size[1] - left - right)
        grid2D[1, :, :] = (grid2D[1, :, :] - top) / (size[0] - top - bot)
        grid2D = (grid2D * 2.0) - 1.0

        return img, grid2D
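
# Usage sketch (illustrative, assumed sample layout): a concrete dataset would fill
# self.all_samples in __init__ and call crop_tight in __getitem__; after crop_tight,
# grid2D is normalized to [-1, 1] relative to the resized crop.
#
#   class MyUnwarpDataset(BaseDataset):
#       def __getitem__(self, idx):
#           img_RGB, grid2D = self.load_sample(self.all_samples[idx])  # hypothetical helper
#           img, grid2D = self.crop_tight(img_RGB, grid2D)
#           return img, torch.from_numpy(grid2D).float()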