videotransforms.py
import numbers
import random

import numpy as np
import torch
from torchvision.transforms import functional as F

class RandomCrop(object):
    """Crop the given video sequence (t x h x w x c) at a random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of a sequence like (h, w), a square crop (size, size)
            is made.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    @staticmethod
    def get_params(imgs, output_size):
        """Get parameters for a random crop.

        Args:
            imgs (numpy.ndarray): Video of shape (t, h, w, c) to be cropped.
            output_size (tuple): Expected output size (th, tw) of the crop.

        Returns:
            tuple: params (i, j, h, w) describing the crop window.
        """
        t, h, w, c = imgs.shape
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w
        # Pick a random top-left corner that keeps the crop inside the frame.
        i = random.randint(0, h - th) if h != th else 0
        j = random.randint(0, w - tw) if w != tw else 0
        return i, j, th, tw

    def __call__(self, imgs):
        i, j, h, w = self.get_params(imgs, self.size)
        # Apply the same spatial crop to every frame.
        imgs = imgs[:, i:i+h, j:j+w, :]
        return imgs

    def __repr__(self):
        return self.__class__.__name__ + '(size={0})'.format(self.size)

class CenterCrop(object):
    """Crop the given video sequence (t x h x w x c) at the center.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of a sequence like (h, w), a square crop (size, size)
            is made.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, imgs):
        """
        Args:
            imgs (numpy.ndarray): Video of shape (t, h, w, c) to be cropped.

        Returns:
            numpy.ndarray: Center-cropped video.
        """
        t, h, w, c = imgs.shape
        th, tw = self.size
        # Offsets that center the crop window in each frame.
        i = int(np.round((h - th) / 2.))
        j = int(np.round((w - tw) / 2.))
        return imgs[:, i:i+th, j:j+tw, :]

    def __repr__(self):
        return self.__class__.__name__ + '(size={0})'.format(self.size)

class RandomHorizontalFlip(object):
    """Horizontally flip the given video sequence randomly with a given probability.

    Args:
        p (float): Probability of the video being flipped. Default value is 0.5.
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, imgs):
        """
        Args:
            imgs (numpy.ndarray): Video of shape (t, h, w, c) to be flipped.

        Returns:
            numpy.ndarray: Randomly flipped video.
        """
        if random.random() < self.p:
            # Flip along the width axis (axis 2 of t x h x w x c).
            return np.flip(imgs, axis=2).copy()
        return imgs

    def __repr__(self):
        return self.__class__.__name__ + '(p={})'.format(self.p)

class ToTensor(object):
    """Convert a video sequence (t x h x w x c) to a float tensor (c x t x h x w).

    torchvision's ``to_tensor`` converts uint8 frames in [0, 255] to floats in
    [0, 1]; float frames are converted to float tensors unchanged.
    """

    def __call__(self, imgs):
        output = []
        for img in imgs:
            # Convert each frame from (h, w, c) to a (c, h, w) tensor.
            output.append(F.to_tensor(img))
        # Stack frames to (t, c, h, w), then permute to (c, t, h, w).
        output = torch.stack(output)
        output = output.permute(1, 0, 2, 3)
        return output
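

# Minimal usage sketch: compose these transforms with
# torchvision.transforms.Compose on a dummy uint8 clip. The clip shape,
# crop size, and flip probability below are illustrative assumptions.
if __name__ == "__main__":
    from torchvision import transforms

    video_transforms = transforms.Compose([
        RandomCrop(224),
        RandomHorizontalFlip(p=0.5),
        ToTensor(),
    ])

    # Dummy video clip: 16 frames of 256 x 256 RGB, values in [0, 255].
    clip = np.random.randint(0, 256, size=(16, 256, 256, 3), dtype=np.uint8)
    out = video_transforms(clip)
    print(out.shape)  # expected: torch.Size([3, 16, 224, 224])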