diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 935d25edd6d..311a442ffed 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -406,59 +406,6 @@ def was_applied(output, inpt):
         assert transform.was_applied(output, input)


-@pytest.mark.parametrize("p", [0.0, 1.0])
-class TestRandomHorizontalFlip:
-    def input_expected_image_tensor(self, p, dtype=torch.float32):
-        input = torch.tensor([[[0, 1], [0, 1]], [[1, 0], [1, 0]]], dtype=dtype)
-        expected = torch.tensor([[[1, 0], [1, 0]], [[0, 1], [0, 1]]], dtype=dtype)
-
-        return input, expected if p == 1 else input
-
-    def test_simple_tensor(self, p):
-        input, expected = self.input_expected_image_tensor(p)
-        transform = transforms.RandomHorizontalFlip(p=p)
-
-        actual = transform(input)
-
-        assert_equal(expected, actual)
-
-    def test_pil_image(self, p):
-        input, expected = self.input_expected_image_tensor(p, dtype=torch.uint8)
-        transform = transforms.RandomHorizontalFlip(p=p)
-
-        actual = transform(to_pil_image(input))
-
-        assert_equal(expected, pil_to_tensor(actual))
-
-    def test_datapoints_image(self, p):
-        input, expected = self.input_expected_image_tensor(p)
-        transform = transforms.RandomHorizontalFlip(p=p)
-
-        actual = transform(datapoints.Image(input))
-
-        assert_equal(datapoints.Image(expected), actual)
-
-    def test_datapoints_mask(self, p):
-        input, expected = self.input_expected_image_tensor(p)
-        transform = transforms.RandomHorizontalFlip(p=p)
-
-        actual = transform(datapoints.Mask(input))
-
-        assert_equal(datapoints.Mask(expected), actual)
-
-    def test_datapoints_bounding_box(self, p):
-        input = datapoints.BoundingBox([0, 0, 5, 5], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(10, 10))
-        transform = transforms.RandomHorizontalFlip(p=p)
-
-        actual = transform(input)
-
-        expected_image_tensor = torch.tensor([5, 0, 10, 5]) if p == 1.0 else input
-        expected = datapoints.BoundingBox.wrap_like(input, expected_image_tensor)
-        assert_equal(expected, actual)
-        assert actual.format == expected.format
-        assert actual.spatial_size == expected.spatial_size
-
-
 @pytest.mark.parametrize("p", [0.0, 1.0])
 class TestRandomVerticalFlip:
     def input_expected_image_tensor(self, p, dtype=torch.float32):
diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py
index 002da24ac89..05eb47ab69e 100644
--- a/test/test_transforms_v2_refactored.py
+++ b/test/test_transforms_v2_refactored.py
@@ -295,9 +295,9 @@ def check_transform(transform_cls, input, *args, **kwargs):
     _check_transform_v1_compatibility(transform, input)


-def transform_cls_to_functional(transform_cls):
+def transform_cls_to_functional(transform_cls, **transform_specific_kwargs):
     def wrapper(input, *args, **kwargs):
-        transform = transform_cls(*args, **kwargs)
+        transform = transform_cls(*args, **transform_specific_kwargs, **kwargs)
         return transform(input)

     wrapper.__name__ = transform_cls.__name__
@@ -321,14 +321,14 @@ def assert_warns_antialias_default_value():


 def reference_affine_bounding_box_helper(bounding_box, *, format, spatial_size, affine_matrix):
-    def transform(bbox, affine_matrix_, format_, spatial_size_):
+    def transform(bbox):
         # Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1
         in_dtype = bbox.dtype
         if not torch.is_floating_point(bbox):
             bbox = bbox.float()
         bbox_xyxy = F.convert_format_bounding_box(
             bbox.as_subclass(torch.Tensor),
-            old_format=format_,
+            old_format=format,
             new_format=datapoints.BoundingBoxFormat.XYXY,
             inplace=True,
         )
@@ -340,7 +340,7 @@ def transform(bbox, affine_matrix_, format_, spatial_size_):
                 [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
             ]
         )
-        transformed_points = np.matmul(points, affine_matrix_.T)
+        transformed_points = np.matmul(points, affine_matrix.T)
         out_bbox = torch.tensor(
             [
                 np.min(transformed_points[:, 0]).item(),
@@ -351,23 +351,14 @@ def transform(bbox, affine_matrix_, format_, spatial_size_):
             dtype=bbox_xyxy.dtype,
         )
         out_bbox = F.convert_format_bounding_box(
-            out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_, inplace=True
+            out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
         )
         # It is important to clamp before casting, especially for CXCYWH format, dtype=int64
-        out_bbox = F.clamp_bounding_box(out_bbox, format=format_, spatial_size=spatial_size_)
+        out_bbox = F.clamp_bounding_box(out_bbox, format=format, spatial_size=spatial_size)
         out_bbox = out_bbox.to(dtype=in_dtype)
         return out_bbox

-    if bounding_box.ndim < 2:
-        bounding_box = [bounding_box]
-
-    expected_bboxes = [transform(bbox, affine_matrix, format, spatial_size) for bbox in bounding_box]
-    if len(expected_bboxes) > 1:
-        expected_bboxes = torch.stack(expected_bboxes)
-    else:
-        expected_bboxes = expected_bboxes[0]
-
-    return expected_bboxes
+    return torch.stack([transform(b) for b in bounding_box.reshape(-1, 4).unbind()]).reshape(bounding_box.shape)


 class TestResize:
@@ -493,7 +484,7 @@ def test_kernel_video(self):

     @pytest.mark.parametrize("size", OUTPUT_SIZES)
     @pytest.mark.parametrize(
-        "input_type_and_kernel",
+        ("input_type", "kernel"),
         [
             (torch.Tensor, F.resize_image_tensor),
             (PIL.Image.Image, F.resize_image_pil),
@@ -503,8 +494,7 @@ def test_kernel_video(self):
             (datapoints.Video, F.resize_video),
         ],
     )
-    def test_dispatcher(self, size, input_type_and_kernel):
-        input_type, kernel = input_type_and_kernel
+    def test_dispatcher(self, size, input_type, kernel):
         check_dispatcher(
             F.resize,
             kernel,
@@ -726,3 +716,147 @@ def test_no_regression_5405(self, input_type):
         output = F.resize(input, size=size, max_size=max_size, antialias=True)

         assert max(F.get_spatial_size(output)) == max_size
+
+
+class TestHorizontalFlip:
+    def _make_input(self, input_type, *, dtype=None, device="cpu", spatial_size=(17, 11), **kwargs):
+        if input_type in {torch.Tensor, PIL.Image.Image, datapoints.Image}:
+            input = make_image(size=spatial_size, dtype=dtype or torch.uint8, device=device, **kwargs)
+            if input_type is torch.Tensor:
+                input = input.as_subclass(torch.Tensor)
+            elif input_type is PIL.Image.Image:
+                input = F.to_image_pil(input)
+        elif input_type is datapoints.BoundingBox:
+            kwargs.setdefault("format", datapoints.BoundingBoxFormat.XYXY)
+            input = make_bounding_box(
+                dtype=dtype or torch.float32,
+                device=device,
+                spatial_size=spatial_size,
+                **kwargs,
+            )
+        elif input_type is datapoints.Mask:
+            input = make_segmentation_mask(size=spatial_size, dtype=dtype or torch.uint8, device=device, **kwargs)
+        elif input_type is datapoints.Video:
+            input = make_video(size=spatial_size, dtype=dtype or torch.uint8, device=device, **kwargs)
+
+        return input
+
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
+    @pytest.mark.parametrize("device", cpu_and_cuda())
+    def test_kernel_image_tensor(self, dtype, device):
+        check_kernel(F.horizontal_flip_image_tensor, self._make_input(torch.Tensor, dtype=dtype, device=device))
+
+    @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
+    @pytest.mark.parametrize("device", cpu_and_cuda())
+    def test_kernel_bounding_box(self, format, dtype, device):
+        bounding_box = self._make_input(datapoints.BoundingBox, dtype=dtype, device=device, format=format)
+        check_kernel(
+            F.horizontal_flip_bounding_box,
+            bounding_box,
+            format=format,
+            spatial_size=bounding_box.spatial_size,
+        )
+
+    @pytest.mark.parametrize(
+        "dtype_and_make_mask", [(torch.uint8, make_segmentation_mask), (torch.bool, make_detection_mask)]
+    )
+    def test_kernel_mask(self, dtype_and_make_mask):
+        dtype, make_mask = dtype_and_make_mask
+        check_kernel(F.horizontal_flip_mask, make_mask(dtype=dtype))
+
+    def test_kernel_video(self):
+        check_kernel(F.horizontal_flip_video, self._make_input(datapoints.Video))
+
+    @pytest.mark.parametrize(
+        ("input_type", "kernel"),
+        [
+            (torch.Tensor, F.horizontal_flip_image_tensor),
+            (PIL.Image.Image, F.horizontal_flip_image_pil),
+            (datapoints.Image, F.horizontal_flip_image_tensor),
+            (datapoints.BoundingBox, F.horizontal_flip_bounding_box),
+            (datapoints.Mask, F.horizontal_flip_mask),
+            (datapoints.Video, F.horizontal_flip_video),
+        ],
+    )
+    def test_dispatcher(self, kernel, input_type):
+        check_dispatcher(F.horizontal_flip, kernel, self._make_input(input_type))
+
+    @pytest.mark.parametrize(
+        ("input_type", "kernel"),
+        [
+            (torch.Tensor, F.horizontal_flip_image_tensor),
+            (PIL.Image.Image, F.horizontal_flip_image_pil),
+            (datapoints.Image, F.horizontal_flip_image_tensor),
+            (datapoints.BoundingBox, F.horizontal_flip_bounding_box),
+            (datapoints.Mask, F.horizontal_flip_mask),
+            (datapoints.Video, F.horizontal_flip_video),
+        ],
+    )
+    def test_dispatcher_signature(self, kernel, input_type):
+        check_dispatcher_signatures_match(F.horizontal_flip, kernel=kernel, input_type=input_type)
+
+    @pytest.mark.parametrize(
+        "input_type",
+        [torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.BoundingBox, datapoints.Mask, datapoints.Video],
+    )
+    @pytest.mark.parametrize("device", cpu_and_cuda())
+    def test_transform(self, input_type, device):
+        input = self._make_input(input_type, device=device)
+
+        check_transform(transforms.RandomHorizontalFlip, input, p=1)
+
+    @pytest.mark.parametrize(
+        "fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)]
+    )
+    def test_image_correctness(self, fn):
+        image = self._make_input(torch.Tensor, dtype=torch.uint8, device="cpu")
+
+        actual = fn(image)
+        expected = F.to_image_tensor(F.horizontal_flip(F.to_image_pil(image)))
+
+        torch.testing.assert_close(actual, expected)
+
+    def _reference_horizontal_flip_bounding_box(self, bounding_box):
+        affine_matrix = np.array(
+            [
+                [-1, 0, bounding_box.spatial_size[1]],
+                [0, 1, 0],
+            ],
+            dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
+        )
+
+        expected_bboxes = reference_affine_bounding_box_helper(
+            bounding_box,
+            format=bounding_box.format,
+            spatial_size=bounding_box.spatial_size,
+            affine_matrix=affine_matrix,
+        )
+
+        return datapoints.BoundingBox.wrap_like(bounding_box, expected_bboxes)
+
+    @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
+    @pytest.mark.parametrize(
+        "fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)]
+    )
+    def test_bounding_box_correctness(self, format, fn):
+        bounding_box = self._make_input(datapoints.BoundingBox, format=format)
+
+        actual = fn(bounding_box)
+        expected = self._reference_horizontal_flip_bounding_box(bounding_box)
+
+        torch.testing.assert_close(actual, expected)
+
+    @pytest.mark.parametrize(
+        "input_type",
+        [torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.BoundingBox, datapoints.Mask, datapoints.Video],
+    )
+    @pytest.mark.parametrize("device", cpu_and_cuda())
+    def test_transform_noop(self, input_type, device):
+        input = self._make_input(input_type, device=device)
+
+        transform = transforms.RandomHorizontalFlip(p=0)
+
+        output = transform(input)
+
+        assert_equal(output, input)
diff --git a/test/transforms_v2_dispatcher_infos.py b/test/transforms_v2_dispatcher_infos.py
index cb1bc257e50..e0f7edd7129 100644
--- a/test/transforms_v2_dispatcher_infos.py
+++ b/test/transforms_v2_dispatcher_infos.py
@@ -138,16 +138,6 @@ def fill_sequence_needs_broadcast(args_kwargs):


 DISPATCHER_INFOS = [
-    DispatcherInfo(
-        F.horizontal_flip,
-        kernels={
-            datapoints.Image: F.horizontal_flip_image_tensor,
-            datapoints.Video: F.horizontal_flip_video,
-            datapoints.BoundingBox: F.horizontal_flip_bounding_box,
-            datapoints.Mask: F.horizontal_flip_mask,
-        },
-        pil_kernel_info=PILKernelInfo(F.horizontal_flip_image_pil, kernel_name="horizontal_flip_image_pil"),
-    ),
     DispatcherInfo(
         F.affine,
         kernels={
diff --git a/test/transforms_v2_kernel_infos.py b/test/transforms_v2_kernel_infos.py
index 547e708b726..54fd3a679a5 100644
--- a/test/transforms_v2_kernel_infos.py
+++ b/test/transforms_v2_kernel_infos.py
@@ -156,88 +156,6 @@ def xfail_jit_python_scalar_arg(name, *, reason=None):
 KERNEL_INFOS = []


-def sample_inputs_horizontal_flip_image_tensor():
-    for image_loader in make_image_loaders(sizes=["random"], dtypes=[torch.float32]):
-        yield ArgsKwargs(image_loader)
-
-
-def reference_inputs_horizontal_flip_image_tensor():
-    for image_loader in make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]):
-        yield ArgsKwargs(image_loader)
-
-
-def sample_inputs_horizontal_flip_bounding_box():
-    for bounding_box_loader in make_bounding_box_loaders(
-        formats=[datapoints.BoundingBoxFormat.XYXY], dtypes=[torch.float32]
-    ):
-        yield ArgsKwargs(
-            bounding_box_loader, format=bounding_box_loader.format, spatial_size=bounding_box_loader.spatial_size
-        )
-
-
-def sample_inputs_horizontal_flip_mask():
-    for image_loader in make_mask_loaders(sizes=["random"], dtypes=[torch.uint8]):
-        yield ArgsKwargs(image_loader)
-
-
-def sample_inputs_horizontal_flip_video():
-    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
-        yield ArgsKwargs(video_loader)
-
-
-def reference_horizontal_flip_bounding_box(bounding_box, *, format, spatial_size):
-    affine_matrix = np.array(
-        [
-            [-1, 0, spatial_size[1]],
-            [0, 1, 0],
-        ],
-        dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
-    )
-
-    expected_bboxes = reference_affine_bounding_box_helper(
-        bounding_box, format=format, spatial_size=spatial_size, affine_matrix=affine_matrix
-    )
-
-    return expected_bboxes
-
-
-def reference_inputs_flip_bounding_box():
-    for bounding_box_loader in make_bounding_box_loaders(extra_dims=[()]):
-        yield ArgsKwargs(
-            bounding_box_loader,
-            format=bounding_box_loader.format,
-            spatial_size=bounding_box_loader.spatial_size,
-        )
-
-
-KERNEL_INFOS.extend(
-    [
-        KernelInfo(
-            F.horizontal_flip_image_tensor,
-            kernel_name="horizontal_flip_image_tensor",
-            sample_inputs_fn=sample_inputs_horizontal_flip_image_tensor,
-            reference_fn=pil_reference_wrapper(F.horizontal_flip_image_pil),
-            reference_inputs_fn=reference_inputs_horizontal_flip_image_tensor,
-            float32_vs_uint8=True,
-        ),
-        KernelInfo(
-            F.horizontal_flip_bounding_box,
-            sample_inputs_fn=sample_inputs_horizontal_flip_bounding_box,
-            reference_fn=reference_horizontal_flip_bounding_box,
-            reference_inputs_fn=reference_inputs_flip_bounding_box,
-        ),
-        KernelInfo(
-            F.horizontal_flip_mask,
-            sample_inputs_fn=sample_inputs_horizontal_flip_mask,
-        ),
-        KernelInfo(
-            F.horizontal_flip_video,
-            sample_inputs_fn=sample_inputs_horizontal_flip_video,
-        ),
-    ]
-)
-
-
 _AFFINE_KWARGS = combinations_grid(
     angle=[-87, 15, 90],
     translate=[(5, 5), (-5, -5)],
@@ -573,6 +491,15 @@ def reference_vertical_flip_bounding_box(bounding_box, *, format, spatial_size):
     return expected_bboxes


+def reference_inputs_vertical_flip_bounding_box():
+    for bounding_box_loader in make_bounding_box_loaders(extra_dims=[()]):
+        yield ArgsKwargs(
+            bounding_box_loader,
+            format=bounding_box_loader.format,
+            spatial_size=bounding_box_loader.spatial_size,
+        )
+
+
 KERNEL_INFOS.extend(
     [
         KernelInfo(
@@ -587,7 +514,7 @@ def reference_vertical_flip_bounding_box(bounding_box, *, format, spatial_size):
             F.vertical_flip_bounding_box,
             sample_inputs_fn=sample_inputs_vertical_flip_bounding_box,
             reference_fn=reference_vertical_flip_bounding_box,
-            reference_inputs_fn=reference_inputs_flip_bounding_box,
+            reference_inputs_fn=reference_inputs_vertical_flip_bounding_box,
         ),
         KernelInfo(
             F.vertical_flip_mask,
diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py
index aab3be24e0b..b56205e6123 100644
--- a/torchvision/transforms/v2/functional/_geometry.py
+++ b/torchvision/transforms/v2/functional/_geometry.py
@@ -43,7 +43,8 @@ def horizontal_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
     return image.flip(-1)


-horizontal_flip_image_pil = _FP.hflip
+def horizontal_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
+    return _FP.hflip(image)


 def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor:
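
Note on the bounding-box references above: `_reference_horizontal_flip_bounding_box` (and the kernel-info reference it replaces) builds a 2x3 affine matrix and hands it to `reference_affine_bounding_box_helper`, which maps all four corners of each box and re-forms an XYXY box from the min/max coordinates. Below is a minimal self-contained sketch of that idea, assuming a plain XYXY box and an image width; the `flip_xyxy_box` helper is hypothetical and not part of the patch, and it uses NumPy only.

# Illustrative sketch only (not part of the diff): horizontal flip of an XYXY box
# via the same affine-matrix approach as reference_affine_bounding_box_helper.
import numpy as np


def flip_xyxy_box(box, width):
    # Horizontal flip as an affine map: x' = -x + width, y' = y
    affine_matrix = np.array([[-1.0, 0.0, width], [0.0, 1.0, 0.0]])
    x1, y1, x2, y2 = box
    # All four corners in homogeneous coordinates
    corners = np.array([[x1, y1, 1.0], [x2, y1, 1.0], [x1, y2, 1.0], [x2, y2, 1.0]])
    transformed = corners @ affine_matrix.T
    # Re-form an XYXY box from the extremes of the transformed corners
    return [
        float(transformed[:, 0].min()),
        float(transformed[:, 1].min()),
        float(transformed[:, 0].max()),
        float(transformed[:, 1].max()),
    ]


# A box hugging the left edge of a width-10 image ends up hugging the right edge,
# matching the [5, 0, 10, 5] expectation from the removed TestRandomHorizontalFlip test.
assert flip_xyxy_box([0, 0, 5, 5], width=10) == [5.0, 0.0, 10.0, 5.0]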