diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index 01ba70365b3..5b0693a2e78 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -2,9 +2,10 @@
 
 import pytest
 import torch
+from common_utils import assert_equal
 from test_prototype_transforms_functional import make_images, make_bounding_boxes, make_one_hot_labels
 from torchvision.prototype import transforms, features
-from torchvision.transforms.functional import to_pil_image
+from torchvision.transforms.functional import to_pil_image, pil_to_tensor
 
 
 def make_vanilla_tensor_images(*args, **kwargs):
@@ -66,10 +67,10 @@ def parametrize_from_transforms(*transforms):
 class TestSmoke:
     @parametrize_from_transforms(
         transforms.RandomErasing(p=1.0),
-        transforms.HorizontalFlip(),
         transforms.Resize([16, 16]),
         transforms.CenterCrop([16, 16]),
         transforms.ConvertImageDtype(),
+        transforms.RandomHorizontalFlip(),
     )
     def test_common(self, transform, input):
         transform(input)
@@ -188,3 +189,56 @@ def test_random_resized_crop(self, transform, input):
     )
     def test_convert_image_color_space(self, transform, input):
         transform(input)
+
+
+@pytest.mark.parametrize("p", [0.0, 1.0])
+class TestRandomHorizontalFlip:
+    def input_expected_image_tensor(self, p, dtype=torch.float32):
+        input = torch.tensor([[[0, 1], [0, 1]], [[1, 0], [1, 0]]], dtype=dtype)
+        expected = torch.tensor([[[1, 0], [1, 0]], [[0, 1], [0, 1]]], dtype=dtype)
+
+        return input, expected if p == 1 else input
+
+    def test_simple_tensor(self, p):
+        input, expected = self.input_expected_image_tensor(p)
+        transform = transforms.RandomHorizontalFlip(p=p)
+
+        actual = transform(input)
+
+        assert_equal(expected, actual)
+
+    def test_pil_image(self, p):
+        input, expected = self.input_expected_image_tensor(p, dtype=torch.uint8)
+        transform = transforms.RandomHorizontalFlip(p=p)
+
+        actual = transform(to_pil_image(input))
+
+        assert_equal(expected, pil_to_tensor(actual))
+
+    def test_features_image(self, p):
+        input, expected = self.input_expected_image_tensor(p)
+        transform = transforms.RandomHorizontalFlip(p=p)
+
+        actual = transform(features.Image(input))
+
+        assert_equal(features.Image(expected), actual)
+
+    def test_features_segmentation_mask(self, p):
+        input, expected = self.input_expected_image_tensor(p)
+        transform = transforms.RandomHorizontalFlip(p=p)
+
+        actual = transform(features.SegmentationMask(input))
+
+        assert_equal(features.SegmentationMask(expected), actual)
+
+    def test_features_bounding_box(self, p):
+        input = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10))
+        transform = transforms.RandomHorizontalFlip(p=p)
+
+        actual = transform(input)
+
+        expected_image_tensor = torch.tensor([5, 0, 10, 5]) if p == 1.0 else input
+        expected = features.BoundingBox.new_like(input, data=expected_image_tensor)
+        assert_equal(expected, actual)
+        assert actual.format == expected.format
+        assert actual.image_size == expected.image_size
diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py
index 2b52a253820..0a3d23db3bd 100644
--- a/torchvision/prototype/transforms/__init__.py
+++ b/torchvision/prototype/transforms/__init__.py
@@ -8,13 +8,13 @@
 from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment, AugMix
 from ._container import Compose, RandomApply, RandomChoice, RandomOrder
 from ._geometry import (
-    HorizontalFlip,
     Resize,
     CenterCrop,
     RandomResizedCrop,
     FiveCrop,
     TenCrop,
     BatchMultiCrop,
+    RandomHorizontalFlip,
     RandomZoomOut,
 )
 from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py
index 2a965959629..19e5ced791e 100644
--- a/torchvision/prototype/transforms/_geometry.py
+++ b/torchvision/prototype/transforms/_geometry.py
@@ -13,11 +13,25 @@
 from ._utils import query_image, get_image_dimensions, has_any, is_simple_tensor
 
 
-class HorizontalFlip(Transform):
+class RandomHorizontalFlip(Transform):
+    def __init__(self, p: float = 0.5) -> None:
+        super().__init__()
+        self.p = p
+
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        if torch.rand(1) >= self.p:
+            return sample
+
+        return super().forward(sample)
+
     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
         if isinstance(input, features.Image):
             output = F.horizontal_flip_image_tensor(input)
             return features.Image.new_like(input, output)
+        elif isinstance(input, features.SegmentationMask):
+            output = F.horizontal_flip_segmentation_mask(input)
+            return features.SegmentationMask.new_like(input, output)
         elif isinstance(input, features.BoundingBox):
             output = F.horizontal_flip_bounding_box(input, format=input.format, image_size=input.image_size)
             return features.BoundingBox.new_like(input, output)
diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py
index ed6e9989328..fa65051dfac 100644
--- a/torchvision/prototype/transforms/functional/__init__.py
+++ b/torchvision/prototype/transforms/functional/__init__.py
@@ -40,6 +40,7 @@
     horizontal_flip_bounding_box,
     horizontal_flip_image_tensor,
     horizontal_flip_image_pil,
+    horizontal_flip_segmentation_mask,
     resize_bounding_box,
     resize_image_tensor,
     resize_image_pil,
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 1bff7a3f2e6..6ee76228fbc 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -15,6 +15,10 @@
 horizontal_flip_image_pil = _FP.hflip
 
 
+def horizontal_flip_segmentation_mask(segmentation_mask: torch.Tensor) -> torch.Tensor:
+    return horizontal_flip_image_tensor(segmentation_mask)
+
+
 def horizontal_flip_bounding_box(
     bounding_box: torch.Tensor, format: features.BoundingBoxFormat, image_size: Tuple[int, int]
 ) -> torch.Tensor:
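
A minimal usage sketch follows (not part of the patch), assuming the prototype API added above: RandomHorizontalFlip.forward draws the random number once per call and then dispatches per input type, so an image, a segmentation mask, and a bounding box passed together are all flipped, or all left alone, consistently. The tensors, sizes, and multi-input call below are illustrative assumptions, not code from this PR.

    import torch
    from torchvision.prototype import features, transforms

    # p=1.0 makes the flip deterministic for the sake of the example.
    transform = transforms.RandomHorizontalFlip(p=1.0)

    image = features.Image(torch.rand(3, 10, 10))
    mask = features.SegmentationMask(torch.zeros(1, 10, 10, dtype=torch.uint8))
    box = features.BoundingBox(
        [0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10)
    )

    # All inputs share the same random draw, so they stay aligned:
    # on a width-10 image, the XYXY box [0, 0, 5, 5] flips to [5, 0, 10, 5],
    # matching test_features_bounding_box in the test file above.
    flipped_image, flipped_mask, flipped_box = transform(image, mask, box)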