class TestElastic:
    """Tests for the elastic kernels, the ``F.elastic`` dispatcher and ``transforms.ElasticTransform``."""

    # Input factories shared by the dispatcher-level and transform-level tests below.
    _INPUT_FACTORIES = [
        make_image_tensor,
        make_image_pil,
        make_image,
        make_bounding_box,
        make_segmentation_mask,
        make_video,
    ]

    def _make_displacement(self, inpt):
        """Return a random float32 displacement of shape ``(1, *F.get_size(inpt), 2)``.

        The tensor is created on ``inpt.device`` when ``inpt`` is a tensor and on the CPU otherwise.
        """
        size = F.get_size(inpt)
        target_device = inpt.device if isinstance(inpt, torch.Tensor) else "cpu"
        return torch.rand(1, *size, 2, dtype=torch.float32, device=target_device)

    @param_value_parametrization(
        interpolation=[transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR],
        fill=EXHAUSTIVE_TYPE_FILLS,
    )
    @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel_image_tensor(self, param, value, dtype, device):
        """Run ``check_kernel`` on ``F.elastic_image_tensor`` for each interpolation / fill value."""
        img = make_image_tensor(dtype=dtype, device=device)

        # The scripted-vs-eager comparison is disabled only for a scalar (int / float) ``fill``.
        compare_scripted = param != "fill" or not isinstance(value, (int, float))

        check_kernel(
            F.elastic_image_tensor,
            img,
            displacement=self._make_displacement(img),
            **{param: value},
            check_scripted_vs_eager=compare_scripted,
        )

    @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
    @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel_bounding_boxes(self, format, dtype, device):
        """Run ``check_kernel`` on ``F.elastic_bounding_boxes`` for every box format, dtype and device."""
        boxes = make_bounding_box(format=format, dtype=dtype, device=device)
        field = self._make_displacement(boxes)

        check_kernel(
            F.elastic_bounding_boxes,
            boxes,
            format=boxes.format,
            canvas_size=boxes.canvas_size,
            displacement=field,
        )

    @pytest.mark.parametrize("make_mask", [make_segmentation_mask, make_detection_mask])
    def test_kernel_mask(self, make_mask):
        """Run ``check_kernel`` on ``F.elastic_mask`` for segmentation and detection masks."""
        sample = make_mask()
        field = self._make_displacement(sample)
        check_kernel(F.elastic_mask, sample, displacement=field)

    def test_kernel_video(self):
        """Run ``check_kernel`` on ``F.elastic_video``."""
        clip = make_video()
        field = self._make_displacement(clip)
        check_kernel(F.elastic_video, clip, displacement=field)

    @pytest.mark.parametrize("make_input", _INPUT_FACTORIES)
    def test_functional(self, make_input):
        """Run ``check_functional`` on ``F.elastic`` for every supported input kind."""
        sample = make_input()
        field = self._make_displacement(sample)
        check_functional(F.elastic, sample, displacement=field)

    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.elastic_image_tensor, torch.Tensor),
            (F.elastic_image_pil, PIL.Image.Image),
            (F.elastic_image_tensor, datapoints.Image),
            (F.elastic_bounding_boxes, datapoints.BoundingBoxes),
            (F.elastic_mask, datapoints.Mask),
            (F.elastic_video, datapoints.Video),
        ],
    )
    def test_functional_signature(self, kernel, input_type):
        """Check that the ``F.elastic`` signature matches the kernel paired with each input type."""
        check_functional_kernel_signature_match(F.elastic, kernel=kernel, input_type=input_type)

    @pytest.mark.parametrize("make_input", _INPUT_FACTORIES)
    def test_displacement_error(self, make_input):
        """``F.elastic`` must raise for a non-tensor displacement and for a wrongly shaped one."""
        sample = make_input()

        # A displacement that is not a tensor at all.
        with pytest.raises(TypeError, match="displacement should be a Tensor"):
            F.elastic(sample, displacement=None)

        # A tensor built from the bare size only, i.e. without the extra first and last dimensions
        # that ``_make_displacement`` adds.
        wrongly_shaped = torch.rand(F.get_size(sample))
        with pytest.raises(ValueError, match="displacement shape should be"):
            F.elastic(sample, displacement=wrongly_shaped)

    @pytest.mark.parametrize("make_input", _INPUT_FACTORIES)
    # ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image
    @pytest.mark.parametrize("size", [(163, 163), (72, 333), (313, 95)])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_transform(self, make_input, size, device):
        """Run ``check_transform`` on ``transforms.ElasticTransform`` for each input kind, size and device."""
        sample = make_input(size, device=device)
        check_transform(transforms.ElasticTransform, sample)
a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 0872d71dd8e..898e7e0c1a8 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -1755,6 +1755,9 @@ def elastic_image_tensor( interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, fill: _FillTypeJIT = None, ) -> torch.Tensor: + if not isinstance(displacement, torch.Tensor): + raise TypeError("Argument displacement should be a Tensor") + interpolation = _check_interpolation(interpolation) if image.numel() == 0: @@ -1835,6 +1838,12 @@ def elastic_bounding_boxes( canvas_size: Tuple[int, int], displacement: torch.Tensor, ) -> torch.Tensor: + expected_shape = (1, canvas_size[0], canvas_size[1], 2) + if not isinstance(displacement, torch.Tensor): + raise TypeError("Argument displacement should be a Tensor") + elif displacement.shape != expected_shape: + raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}") + if bounding_boxes.numel() == 0: return bounding_boxes