-
Notifications
You must be signed in to change notification settings - Fork 7k
/
Copy pathfunctional_tensor.py
962 lines (704 loc) · 33.1 KB
/
functional_tensor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
import warnings
from typing import Optional, Tuple, List
import torch
from torch import Tensor
from torch.nn.functional import grid_sample, conv2d, interpolate, pad as torch_pad
def _is_tensor_a_torch_image(x: Tensor) -> bool:
return x.ndim >= 2
def _assert_image_tensor(img: Tensor) -> None:
if not _is_tensor_a_torch_image(img):
raise TypeError("Tensor is not a torch image.")
def _assert_threshold(img: Tensor, threshold: float) -> None:
bound = 1 if img.is_floating_point() else 255
if threshold > bound:
raise TypeError("Threshold should be less than bound of img.")
def get_dimensions(img: Tensor) -> List[int]:
_assert_image_tensor(img)
channels = 1 if img.ndim == 2 else img.shape[-3]
height, width = img.shape[-2:]
return [channels, height, width]
def get_image_size(img: Tensor) -> List[int]:
# Returns (w, h) of tensor image
_assert_image_tensor(img)
return [img.shape[-1], img.shape[-2]]
def get_image_num_channels(img: Tensor) -> int:
_assert_image_tensor(img)
if img.ndim == 2:
return 1
elif img.ndim > 2:
return img.shape[-3]
raise TypeError(f"Input ndim should be 2 or more. Got {img.ndim}")
def _max_value(dtype: torch.dtype) -> int:
if dtype == torch.uint8:
return int(2 ** 8) - 1
elif dtype == torch.int8:
return int(2 ** 7) - 1
elif dtype == torch.int16:
return int(2 ** 15) - 1
elif dtype == torch.int32:
return int(2 ** 31) - 1
elif dtype == torch.int64:
return int(2 ** 63) - 1
else:
return 1
def _assert_channels(img: Tensor, permitted: List[int]) -> None:
c = get_dimensions(img)[0]
if c not in permitted:
raise TypeError(f"Input image tensor permitted channel values are {permitted}, but found {c}")
def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
if image.dtype == dtype:
return image
if image.is_floating_point():
# TODO: replace with dtype.is_floating_point when torchscript supports it
if torch.tensor(0, dtype=dtype).is_floating_point():
return image.to(dtype)
# float to int
if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
image.dtype == torch.float64 and dtype == torch.int64
):
msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
raise RuntimeError(msg)
# https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
# For data in the range 0-1, (float * 255).to(uint) is only 255
# when float is exactly 1.0.
# `max + 1 - epsilon` provides more evenly distributed mapping of
# ranges of floats to ints.
eps = 1e-3
max_val = float(_max_value(dtype))
result = image.mul(max_val + 1.0 - eps)
return result.to(dtype)
else:
input_max = float(_max_value(image.dtype))
# int to float
# TODO: replace with dtype.is_floating_point when torchscript supports it
if torch.tensor(0, dtype=dtype).is_floating_point():
image = image.to(dtype)
return image / input_max
output_max = float(_max_value(dtype))
# int to int
if input_max > output_max:
# factor should be forced to int for torch jit script
# otherwise factor is a float and image // factor can produce different results
factor = int((input_max + 1) // (output_max + 1))
image = torch.div(image, factor, rounding_mode="floor")
return image.to(dtype)
else:
# factor should be forced to int for torch jit script
# otherwise factor is a float and image * factor can produce different results
factor = int((output_max + 1) // (input_max + 1))
image = image.to(dtype)
return image * factor
def vflip(img: Tensor) -> Tensor:
_assert_image_tensor(img)
return img.flip(-2)
def hflip(img: Tensor) -> Tensor:
_assert_image_tensor(img)
return img.flip(-1)
def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
_assert_image_tensor(img)
_, h, w = get_dimensions(img)
right = left + width
bottom = top + height
if left < 0 or top < 0 or right > w or bottom > h:
padding_ltrb = [max(-left, 0), max(-top, 0), max(right - w, 0), max(bottom - h, 0)]
return pad(img[..., max(top, 0) : bottom, max(left, 0) : right], padding_ltrb, fill=0)
return img[..., top:bottom, left:right]
def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
if img.ndim < 3:
raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
_assert_channels(img, [3])
if num_output_channels not in (1, 3):
raise ValueError("num_output_channels should be either 1 or 3")
r, g, b = img.unbind(dim=-3)
# This implementation closely follows the TF one:
# https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/ops/image_ops_impl.py#L2105-L2138
l_img = (0.2989 * r + 0.587 * g + 0.114 * b).to(img.dtype)
l_img = l_img.unsqueeze(dim=-3)
if num_output_channels == 3:
return l_img.expand(img.shape)
return l_img
def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
if brightness_factor < 0:
raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")
_assert_image_tensor(img)
_assert_channels(img, [1, 3])
return _blend(img, torch.zeros_like(img), brightness_factor)
def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
if contrast_factor < 0:
raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.")
_assert_image_tensor(img)
_assert_channels(img, [3, 1])
c = get_dimensions(img)[0]
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
if c == 3:
mean = torch.mean(rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True)
else:
mean = torch.mean(img.to(dtype), dim=(-3, -2, -1), keepdim=True)
return _blend(img, mean, contrast_factor)
def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
if not (-0.5 <= hue_factor <= 0.5):
raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")
if not (isinstance(img, torch.Tensor)):
raise TypeError("Input img should be Tensor image")
_assert_image_tensor(img)
_assert_channels(img, [1, 3])
if get_dimensions(img)[0] == 1: # Match PIL behaviour
return img
orig_dtype = img.dtype
if img.dtype == torch.uint8:
img = img.to(dtype=torch.float32) / 255.0
img = _rgb2hsv(img)
h, s, v = img.unbind(dim=-3)
h = (h + hue_factor) % 1.0
img = torch.stack((h, s, v), dim=-3)
img_hue_adj = _hsv2rgb(img)
if orig_dtype == torch.uint8:
img_hue_adj = (img_hue_adj * 255.0).to(dtype=orig_dtype)
return img_hue_adj
def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
if saturation_factor < 0:
raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")
_assert_image_tensor(img)
_assert_channels(img, [1, 3])
if get_dimensions(img)[0] == 1: # Match PIL behaviour
return img
return _blend(img, rgb_to_grayscale(img), saturation_factor)
def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
if not isinstance(img, torch.Tensor):
raise TypeError("Input img should be a Tensor.")
_assert_channels(img, [1, 3])
if gamma < 0:
raise ValueError("Gamma should be a non-negative real number")
result = img
dtype = img.dtype
if not torch.is_floating_point(img):
result = convert_image_dtype(result, torch.float32)
result = (gain * result ** gamma).clamp(0, 1)
result = convert_image_dtype(result, dtype)
return result
def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
ratio = float(ratio)
bound = 1.0 if img1.is_floating_point() else 255.0
return (ratio * img1 + (1.0 - ratio) * img2).clamp(0, bound).to(img1.dtype)
def _rgb2hsv(img: Tensor) -> Tensor:
r, g, b = img.unbind(dim=-3)
# Implementation is based on https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/
# src/libImaging/Convert.c#L330
maxc = torch.max(img, dim=-3).values
minc = torch.min(img, dim=-3).values
# The algorithm erases S and H channel where `maxc = minc`. This avoids NaN
# from happening in the results, because
# + S channel has division by `maxc`, which is zero only if `maxc = minc`
# + H channel has division by `(maxc - minc)`.
#
# Instead of overwriting NaN afterwards, we just prevent it from occuring so
# we don't need to deal with it in case we save the NaN in a buffer in
# backprop, if it is ever supported, but it doesn't hurt to do so.
eqc = maxc == minc
cr = maxc - minc
# Since `eqc => cr = 0`, replacing denominator with 1 when `eqc` is fine.
ones = torch.ones_like(maxc)
s = cr / torch.where(eqc, ones, maxc)
# Note that `eqc => maxc = minc = r = g = b`. So the following calculation
# of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6` so it
# would not matter what values `rc`, `gc`, and `bc` have here, and thus
# replacing denominator with 1 when `eqc` is fine.
cr_divisor = torch.where(eqc, ones, cr)
rc = (maxc - r) / cr_divisor
gc = (maxc - g) / cr_divisor
bc = (maxc - b) / cr_divisor
hr = (maxc == r) * (bc - gc)
hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc)
hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc)
h = hr + hg + hb
h = torch.fmod((h / 6.0 + 1.0), 1.0)
return torch.stack((h, s, maxc), dim=-3)
def _hsv2rgb(img: Tensor) -> Tensor:
h, s, v = img.unbind(dim=-3)
i = torch.floor(h * 6.0)
f = (h * 6.0) - i
i = i.to(dtype=torch.int32)
p = torch.clamp((v * (1.0 - s)), 0.0, 1.0)
q = torch.clamp((v * (1.0 - s * f)), 0.0, 1.0)
t = torch.clamp((v * (1.0 - s * (1.0 - f))), 0.0, 1.0)
i = i % 6
mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1)
a1 = torch.stack((v, q, p, p, t, v), dim=-3)
a2 = torch.stack((t, v, v, q, p, p), dim=-3)
a3 = torch.stack((p, p, t, v, v, q), dim=-3)
a4 = torch.stack((a1, a2, a3), dim=-4)
return torch.einsum("...ijk, ...xijk -> ...xjk", mask.to(dtype=img.dtype), a4)
def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor:
# padding is left, right, top, bottom
# crop if needed
if padding[0] < 0 or padding[1] < 0 or padding[2] < 0 or padding[3] < 0:
neg_min_padding = [-min(x, 0) for x in padding]
crop_left, crop_right, crop_top, crop_bottom = neg_min_padding
img = img[..., crop_top : img.shape[-2] - crop_bottom, crop_left : img.shape[-1] - crop_right]
padding = [max(x, 0) for x in padding]
in_sizes = img.size()
_x_indices = [i for i in range(in_sizes[-1])] # [0, 1, 2, 3, ...]
left_indices = [i for i in range(padding[0] - 1, -1, -1)] # e.g. [3, 2, 1, 0]
right_indices = [-(i + 1) for i in range(padding[1])] # e.g. [-1, -2, -3]
x_indices = torch.tensor(left_indices + _x_indices + right_indices, device=img.device)
_y_indices = [i for i in range(in_sizes[-2])]
top_indices = [i for i in range(padding[2] - 1, -1, -1)]
bottom_indices = [-(i + 1) for i in range(padding[3])]
y_indices = torch.tensor(top_indices + _y_indices + bottom_indices, device=img.device)
ndim = img.ndim
if ndim == 3:
return img[:, y_indices[:, None], x_indices[None, :]]
elif ndim == 4:
return img[:, :, y_indices[:, None], x_indices[None, :]]
else:
raise RuntimeError("Symmetric padding of N-D tensors are not supported yet")
def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor:
_assert_image_tensor(img)
if not isinstance(padding, (int, tuple, list)):
raise TypeError("Got inappropriate padding arg")
if not isinstance(fill, (int, float)):
raise TypeError("Got inappropriate fill arg")
if not isinstance(padding_mode, str):
raise TypeError("Got inappropriate padding_mode arg")
if isinstance(padding, tuple):
padding = list(padding)
if isinstance(padding, list) and len(padding) not in [1, 2, 4]:
raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
if isinstance(padding, int):
if torch.jit.is_scripting():
# This maybe unreachable
raise ValueError("padding can't be an int while torchscripting, set it as a list [value, ]")
pad_left = pad_right = pad_top = pad_bottom = padding
elif len(padding) == 1:
pad_left = pad_right = pad_top = pad_bottom = padding[0]
elif len(padding) == 2:
pad_left = pad_right = padding[0]
pad_top = pad_bottom = padding[1]
else:
pad_left = padding[0]
pad_top = padding[1]
pad_right = padding[2]
pad_bottom = padding[3]
p = [pad_left, pad_right, pad_top, pad_bottom]
if padding_mode == "edge":
# remap padding_mode str
padding_mode = "replicate"
elif padding_mode == "symmetric":
# route to another implementation
return _pad_symmetric(img, p)
need_squeeze = False
if img.ndim < 4:
img = img.unsqueeze(dim=0)
need_squeeze = True
out_dtype = img.dtype
need_cast = False
if (padding_mode != "constant") and img.dtype not in (torch.float32, torch.float64):
# Here we temporary cast input tensor to float
# until pytorch issue is resolved :
# https://github.com/pytorch/pytorch/issues/40763
need_cast = True
img = img.to(torch.float32)
img = torch_pad(img, p, mode=padding_mode, value=float(fill))
if need_squeeze:
img = img.squeeze(dim=0)
if need_cast:
img = img.to(out_dtype)
return img
def resize(
img: Tensor,
size: List[int],
interpolation: str = "bilinear",
max_size: Optional[int] = None,
antialias: Optional[bool] = None,
) -> Tensor:
_assert_image_tensor(img)
if not isinstance(size, (int, tuple, list)):
raise TypeError("Got inappropriate size arg")
if not isinstance(interpolation, str):
raise TypeError("Got inappropriate interpolation arg")
if interpolation not in ["nearest", "bilinear", "bicubic"]:
raise ValueError("This interpolation mode is unsupported with Tensor input")
if isinstance(size, tuple):
size = list(size)
if isinstance(size, list):
if len(size) not in [1, 2]:
raise ValueError(
f"Size must be an int or a 1 or 2 element tuple/list, not a {len(size)} element tuple/list"
)
if max_size is not None and len(size) != 1:
raise ValueError(
"max_size should only be passed if size specifies the length of the smaller edge, "
"i.e. size should be an int or a sequence of length 1 in torchscript mode."
)
if antialias is None:
antialias = False
if antialias and interpolation not in ["bilinear", "bicubic"]:
raise ValueError("Antialias option is supported for bilinear and bicubic interpolation modes only")
_, h, w = get_dimensions(img)
if isinstance(size, int) or len(size) == 1: # specified size only for the smallest edge
short, long = (w, h) if w <= h else (h, w)
requested_new_short = size if isinstance(size, int) else size[0]
new_short, new_long = requested_new_short, int(requested_new_short * long / short)
if max_size is not None:
if max_size <= requested_new_short:
raise ValueError(
f"max_size = {max_size} must be strictly greater than the requested "
f"size for the smaller edge size = {size}"
)
if new_long > max_size:
new_short, new_long = int(max_size * new_short / new_long), max_size
new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
if (w, h) == (new_w, new_h):
return img
else: # specified both h and w
new_w, new_h = size[1], size[0]
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32, torch.float64])
# Define align_corners to avoid warnings
align_corners = False if interpolation in ["bilinear", "bicubic"] else None
img = interpolate(img, size=[new_h, new_w], mode=interpolation, align_corners=align_corners, antialias=antialias)
if interpolation == "bicubic" and out_dtype == torch.uint8:
img = img.clamp(min=0, max=255)
img = _cast_squeeze_out(img, need_cast=need_cast, need_squeeze=need_squeeze, out_dtype=out_dtype)
return img
def _assert_grid_transform_inputs(
img: Tensor,
matrix: Optional[List[float]],
interpolation: str,
fill: Optional[List[float]],
supported_interpolation_modes: List[str],
coeffs: Optional[List[float]] = None,
) -> None:
if not (isinstance(img, torch.Tensor)):
raise TypeError("Input img should be Tensor")
_assert_image_tensor(img)
if matrix is not None and not isinstance(matrix, list):
raise TypeError("Argument matrix should be a list")
if matrix is not None and len(matrix) != 6:
raise ValueError("Argument matrix should have 6 float values")
if coeffs is not None and len(coeffs) != 8:
raise ValueError("Argument coeffs should have 8 float values")
if fill is not None and not isinstance(fill, (int, float, tuple, list)):
warnings.warn("Argument fill should be either int, float, tuple or list")
# Check fill
num_channels = get_dimensions(img)[0]
if isinstance(fill, (tuple, list)) and (len(fill) > 1 and len(fill) != num_channels):
msg = (
"The number of elements in 'fill' cannot broadcast to match the number of "
"channels of the image ({} != {})"
)
raise ValueError(msg.format(len(fill), num_channels))
if interpolation not in supported_interpolation_modes:
raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input")
def _cast_squeeze_in(img: Tensor, req_dtypes: List[torch.dtype]) -> Tuple[Tensor, bool, bool, torch.dtype]:
need_squeeze = False
# make image NCHW
if img.ndim < 4:
img = img.unsqueeze(dim=0)
need_squeeze = True
out_dtype = img.dtype
need_cast = False
if out_dtype not in req_dtypes:
need_cast = True
req_dtype = req_dtypes[0]
img = img.to(req_dtype)
return img, need_cast, need_squeeze, out_dtype
def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtype: torch.dtype) -> Tensor:
if need_squeeze:
img = img.squeeze(dim=0)
if need_cast:
if out_dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
# it is better to round before cast
img = torch.round(img)
img = img.to(out_dtype)
return img
def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[List[float]]) -> Tensor:
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(
img,
[
grid.dtype,
],
)
if img.shape[0] > 1:
# Apply same grid to a batch of images
grid = grid.expand(img.shape[0], grid.shape[1], grid.shape[2], grid.shape[3])
# Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
if fill is not None:
dummy = torch.ones((img.shape[0], 1, img.shape[2], img.shape[3]), dtype=img.dtype, device=img.device)
img = torch.cat((img, dummy), dim=1)
img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)
# Fill with required color
if fill is not None:
mask = img[:, -1:, :, :] # N * 1 * H * W
img = img[:, :-1, :, :] # N * C * H * W
mask = mask.expand_as(img)
len_fill = len(fill) if isinstance(fill, (tuple, list)) else 1
fill_img = torch.tensor(fill, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img)
if mode == "nearest":
mask = mask < 0.5
img[mask] = fill_img[mask]
else: # 'bilinear'
img = img * mask + (1.0 - mask) * fill_img
img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
return img
def _gen_affine_grid(
theta: Tensor,
w: int,
h: int,
ow: int,
oh: int,
) -> Tensor:
# https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/
# AffineGridGenerator.cpp#L18
# Difference with AffineGridGenerator is that:
# 1) we normalize grid values after applying theta
# 2) we can normalize by other image size, such that it covers "extend" option like in PIL.Image.rotate
d = 0.5
base_grid = torch.empty(1, oh, ow, 3, dtype=theta.dtype, device=theta.device)
x_grid = torch.linspace(-ow * 0.5 + d, ow * 0.5 + d - 1, steps=ow, device=theta.device)
base_grid[..., 0].copy_(x_grid)
y_grid = torch.linspace(-oh * 0.5 + d, oh * 0.5 + d - 1, steps=oh, device=theta.device).unsqueeze_(-1)
base_grid[..., 1].copy_(y_grid)
base_grid[..., 2].fill_(1)
rescaled_theta = theta.transpose(1, 2) / torch.tensor([0.5 * w, 0.5 * h], dtype=theta.dtype, device=theta.device)
output_grid = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta)
return output_grid.view(1, oh, ow, 2)
def affine(
img: Tensor, matrix: List[float], interpolation: str = "nearest", fill: Optional[List[float]] = None
) -> Tensor:
_assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
shape = img.shape
# grid will be generated on the same device as theta and img
grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2])
return _apply_grid_transform(img, grid, interpolation, fill=fill)
def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
# Inspired of PIL implementation:
# https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054
# pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
pts = torch.tensor(
[
[-0.5 * w, -0.5 * h, 1.0],
[-0.5 * w, 0.5 * h, 1.0],
[0.5 * w, 0.5 * h, 1.0],
[0.5 * w, -0.5 * h, 1.0],
]
)
theta = torch.tensor(matrix, dtype=torch.float).reshape(1, 2, 3)
new_pts = pts.view(1, 4, 3).bmm(theta.transpose(1, 2)).view(4, 2)
min_vals, _ = new_pts.min(dim=0)
max_vals, _ = new_pts.max(dim=0)
# Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0
tol = 1e-4
cmax = torch.ceil((max_vals / tol).trunc_() * tol)
cmin = torch.floor((min_vals / tol).trunc_() * tol)
size = cmax - cmin
return int(size[0]), int(size[1])
def rotate(
img: Tensor,
matrix: List[float],
interpolation: str = "nearest",
expand: bool = False,
fill: Optional[List[float]] = None,
) -> Tensor:
_assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])
w, h = img.shape[-1], img.shape[-2]
ow, oh = _compute_output_size(matrix, w, h) if expand else (w, h)
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
# grid will be generated on the same device as theta and img
grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh)
return _apply_grid_transform(img, grid, interpolation, fill=fill)
def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device) -> Tensor:
# https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
# src/libImaging/Geometry.c#L394
#
# x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
# y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
#
theta1 = torch.tensor(
[[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device
)
theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device)
d = 0.5
base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
x_grid = torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow, device=device)
base_grid[..., 0].copy_(x_grid)
y_grid = torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh, device=device).unsqueeze_(-1)
base_grid[..., 1].copy_(y_grid)
base_grid[..., 2].fill_(1)
rescaled_theta1 = theta1.transpose(1, 2) / torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device)
output_grid1 = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta1)
output_grid2 = base_grid.view(1, oh * ow, 3).bmm(theta2.transpose(1, 2))
output_grid = output_grid1 / output_grid2 - 1.0
return output_grid.view(1, oh, ow, 2)
def perspective(
img: Tensor, perspective_coeffs: List[float], interpolation: str = "bilinear", fill: Optional[List[float]] = None
) -> Tensor:
if not (isinstance(img, torch.Tensor)):
raise TypeError("Input img should be Tensor.")
_assert_image_tensor(img)
_assert_grid_transform_inputs(
img,
matrix=None,
interpolation=interpolation,
fill=fill,
supported_interpolation_modes=["nearest", "bilinear"],
coeffs=perspective_coeffs,
)
ow, oh = img.shape[-1], img.shape[-2]
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=img.device)
return _apply_grid_transform(img, grid, interpolation, fill=fill)
def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor:
ksize_half = (kernel_size - 1) * 0.5
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
kernel1d = pdf / pdf.sum()
return kernel1d
def _get_gaussian_kernel2d(
kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
) -> Tensor:
kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype)
kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype)
kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :])
return kernel2d
def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor:
if not (isinstance(img, torch.Tensor)):
raise TypeError(f"img should be Tensor. Got {type(img)}")
_assert_image_tensor(img)
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device)
kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(
img,
[
kernel.dtype,
],
)
# padding = (left, right, top, bottom)
padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
img = torch_pad(img, padding, mode="reflect")
img = conv2d(img, kernel, groups=img.shape[-3])
img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
return img
def invert(img: Tensor) -> Tensor:
_assert_image_tensor(img)
if img.ndim < 3:
raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
_assert_channels(img, [1, 3])
bound = torch.tensor(1 if img.is_floating_point() else 255, dtype=img.dtype, device=img.device)
return bound - img
def posterize(img: Tensor, bits: int) -> Tensor:
_assert_image_tensor(img)
if img.ndim < 3:
raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
if img.dtype != torch.uint8:
raise TypeError(f"Only torch.uint8 image tensors are supported, but found {img.dtype}")
_assert_channels(img, [1, 3])
mask = -int(2 ** (8 - bits)) # JIT-friendly for: ~(2 ** (8 - bits) - 1)
return img & mask
def solarize(img: Tensor, threshold: float) -> Tensor:
_assert_image_tensor(img)
if img.ndim < 3:
raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
_assert_channels(img, [1, 3])
_assert_threshold(img, threshold)
inverted_img = invert(img)
return torch.where(img >= threshold, inverted_img, img)
def _blurred_degenerate_image(img: Tensor) -> Tensor:
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
kernel = torch.ones((3, 3), dtype=dtype, device=img.device)
kernel[1, 1] = 5.0
kernel /= kernel.sum()
kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])
result_tmp, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(
img,
[
kernel.dtype,
],
)
result_tmp = conv2d(result_tmp, kernel, groups=result_tmp.shape[-3])
result_tmp = _cast_squeeze_out(result_tmp, need_cast, need_squeeze, out_dtype)
result = img.clone()
result[..., 1:-1, 1:-1] = result_tmp
return result
def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
if sharpness_factor < 0:
raise ValueError(f"sharpness_factor ({sharpness_factor}) is not non-negative.")
_assert_image_tensor(img)
_assert_channels(img, [1, 3])
if img.size(-1) <= 2 or img.size(-2) <= 2:
return img
return _blend(img, _blurred_degenerate_image(img), sharpness_factor)
def autocontrast(img: Tensor) -> Tensor:
_assert_image_tensor(img)
if img.ndim < 3:
raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
_assert_channels(img, [1, 3])
bound = 1.0 if img.is_floating_point() else 255.0
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
minimum = img.amin(dim=(-2, -1), keepdim=True).to(dtype)
maximum = img.amax(dim=(-2, -1), keepdim=True).to(dtype)
scale = bound / (maximum - minimum)
eq_idxs = torch.isfinite(scale).logical_not()
minimum[eq_idxs] = 0
scale[eq_idxs] = 1
return ((img - minimum) * scale).clamp(0, bound).to(img.dtype)
def _scale_channel(img_chan: Tensor) -> Tensor:
# TODO: we should expect bincount to always be faster than histc, but this
# isn't always the case. Once
# https://github.com/pytorch/pytorch/issues/53194 is fixed, remove the if
# block and only use bincount.
if img_chan.is_cuda:
hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)
else:
hist = torch.bincount(img_chan.view(-1), minlength=256)
nonzero_hist = hist[hist != 0]
step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode="floor")
if step == 0:
return img_chan
lut = torch.div(torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode="floor"), step, rounding_mode="floor")
lut = torch.nn.functional.pad(lut, [1, 0])[:-1].clamp(0, 255)
return lut[img_chan.to(torch.int64)].to(torch.uint8)
def _equalize_single_image(img: Tensor) -> Tensor:
return torch.stack([_scale_channel(img[c]) for c in range(img.size(0))])
def equalize(img: Tensor) -> Tensor:
_assert_image_tensor(img)
if not (3 <= img.ndim <= 4):
raise TypeError(f"Input image tensor should have 3 or 4 dimensions, but found {img.ndim}")
if img.dtype != torch.uint8:
raise TypeError(f"Only torch.uint8 image tensors are supported, but found {img.dtype}")
_assert_channels(img, [1, 3])
if img.ndim == 3:
return _equalize_single_image(img)
return torch.stack([_equalize_single_image(x) for x in img])
def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor:
_assert_image_tensor(tensor)
if not tensor.is_floating_point():
raise TypeError(f"Input tensor should be a float tensor. Got {tensor.dtype}.")
if tensor.ndim < 3:
raise ValueError(
f"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = {tensor.size()}"
)
if not inplace:
tensor = tensor.clone()
dtype = tensor.dtype
mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
if (std == 0).any():
raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.")
if mean.ndim == 1:
mean = mean.view(-1, 1, 1)
if std.ndim == 1:
std = std.view(-1, 1, 1)
tensor.sub_(mean).div_(std)
return tensor
def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
_assert_image_tensor(img)
if not inplace:
img = img.clone()
img[..., i : i + h, j : j + w] = v
return img