diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 86e4b88fe63..7057fdaa2a9 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -138,6 +138,22 @@ def make_one_hot_labels( yield make_one_hot_label(extra_dims_) +def make_segmentation_mask(size=None, *, num_categories=80, extra_dims=(), dtype=torch.long): + size = size or torch.randint(16, 33, (2,)).tolist() + shape = (*extra_dims, 1, *size) + data = make_tensor(shape, low=0, high=num_categories, dtype=dtype) + return features.SegmentationMask(data) + + +def make_segmentation_masks( + image_sizes=((16, 16), (7, 33), (31, 9)), + dtypes=(torch.long,), + extra_dims=((), (4,), (2, 3)), +): + for image_size, dtype, extra_dims_ in itertools.product(image_sizes, dtypes, extra_dims): + yield make_segmentation_mask(size=image_size, dtype=dtype, extra_dims=extra_dims_) + + class SampleInput: def __init__(self, *args, **kwargs): self.args = args @@ -212,7 +228,7 @@ def resize_bounding_box(): @register_kernel_info_from_sample_inputs_fn def affine_image_tensor(): for image, angle, translate, scale, shear in itertools.product( - make_images(extra_dims=()), + make_images(extra_dims=((), (4,))), [-87, 15, 90], # angle [5, -5], # translate [0.77, 1.27], # scale @@ -248,6 +264,24 @@ def affine_bounding_box(): ) +@register_kernel_info_from_sample_inputs_fn +def affine_segmentation_mask(): + for image, angle, translate, scale, shear in itertools.product( + make_segmentation_masks(extra_dims=((), (4,))), + [-87, 15, 90], # angle + [5, -5], # translate + [0.77, 1.27], # scale + [0, 12], # shear + ): + yield SampleInput( + image, + angle=angle, + translate=(translate, translate), + scale=scale, + shear=(shear, shear), + ) + + @register_kernel_info_from_sample_inputs_fn def rotate_bounding_box(): for bounding_box, angle, expand, center in itertools.product( @@ -444,6 +478,76 @@ def test_correctness_affine_bounding_box_on_fixed_input(device): np.testing.assert_allclose(out_box.cpu().numpy(), a_out_box) +@pytest.mark.parametrize("angle", [-54, 56]) +@pytest.mark.parametrize("translate", [-7, 8]) +@pytest.mark.parametrize("scale", [0.89, 1.12]) +@pytest.mark.parametrize("shear", [4]) +@pytest.mark.parametrize("center", [None, (12, 14)]) +def test_correctness_affine_segmentation_mask(angle, translate, scale, shear, center): + def _compute_expected_mask(mask, angle_, translate_, scale_, shear_, center_): + assert mask.ndim == 3 and mask.shape[0] == 1 + affine_matrix = _compute_affine_matrix(angle_, translate_, scale_, shear_, center_) + inv_affine_matrix = np.linalg.inv(affine_matrix) + inv_affine_matrix = inv_affine_matrix[:2, :] + + expected_mask = torch.zeros_like(mask.cpu()) + for out_y in range(expected_mask.shape[1]): + for out_x in range(expected_mask.shape[2]): + output_pt = np.array([out_x + 0.5, out_y + 0.5, 1.0]) + input_pt = np.floor(np.dot(inv_affine_matrix, output_pt)).astype(np.int32) + in_x, in_y = input_pt[:2] + if 0 <= in_x < mask.shape[2] and 0 <= in_y < mask.shape[1]: + expected_mask[0, out_y, out_x] = mask[0, in_y, in_x] + return expected_mask.to(mask.device) + + for mask in make_segmentation_masks(extra_dims=((), (4,))): + output_mask = F.affine_segmentation_mask( + mask, + angle=angle, + translate=(translate, translate), + scale=scale, + shear=(shear, shear), + center=center, + ) + if center is None: + center = [s // 2 for s in mask.shape[-2:][::-1]] + + if mask.ndim < 4: + masks = [mask] + else: + masks = [m for m in mask] + + 
+        expected_masks = []
+        for mask in masks:
+            expected_mask = _compute_expected_mask(mask, angle, (translate, translate), scale, (shear, shear), center)
+            expected_masks.append(expected_mask)
+        if len(expected_masks) > 1:
+            expected_masks = torch.stack(expected_masks)
+        else:
+            expected_masks = expected_masks[0]
+        torch.testing.assert_close(output_mask, expected_masks)
+
+
+@pytest.mark.parametrize("device", cpu_and_gpu())
+def test_correctness_affine_segmentation_mask_on_fixed_input(device):
+    # Check transformation against known expected output and CPU/CUDA devices
+
+    # Create a fixed input segmentation mask with 2 square masks
+    # in top-left, bottom-left corners
+    mask = torch.zeros(1, 32, 32, dtype=torch.long, device=device)
+    mask[0, 2:10, 2:10] = 1
+    mask[0, 32 - 9 : 32 - 3, 3:9] = 2
+
+    # Rotate 90 degrees and scale
+    expected_mask = torch.rot90(mask, k=-1, dims=(-2, -1))
+    expected_mask = torch.nn.functional.interpolate(expected_mask[None, :].float(), size=(64, 64), mode="nearest")
+    expected_mask = expected_mask[0, :, 16 : 64 - 16, 16 : 64 - 16].long()
+
+    out_mask = F.affine_segmentation_mask(mask, 90, [0.0, 0.0], 64.0 / 32.0, [0.0, 0.0])
+
+    torch.testing.assert_close(out_mask, expected_mask)
+
+
 @pytest.mark.parametrize("angle", range(-90, 90, 56))
 @pytest.mark.parametrize("expand", [True, False])
 @pytest.mark.parametrize("center", [None, (12, 14)])
diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py
index ace1f585d82..51bf73a18f7 100644
--- a/torchvision/prototype/transforms/functional/__init__.py
+++ b/torchvision/prototype/transforms/functional/__init__.py
@@ -52,6 +52,7 @@
     affine_bounding_box,
     affine_image_tensor,
     affine_image_pil,
+    affine_segmentation_mask,
     rotate_bounding_box,
     rotate_image_tensor,
     rotate_image_pil,
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index c3f294a8546..71882f06270 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -294,6 +294,25 @@ def affine_bounding_box(
     ).view(original_shape)


+def affine_segmentation_mask(
+    img: torch.Tensor,
+    angle: float,
+    translate: List[float],
+    scale: float,
+    shear: List[float],
+    center: Optional[List[float]] = None,
+) -> torch.Tensor:
+    return affine_image_tensor(
+        img,
+        angle=angle,
+        translate=translate,
+        scale=scale,
+        shear=shear,
+        interpolation=InterpolationMode.NEAREST,
+        center=center,
+    )
+
+
 def rotate_image_tensor(
     img: torch.Tensor,
     angle: float,
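
For reviewers who want to try the new kernel locally, here is a minimal usage sketch. It is not part of the diff; the mask contents and affine parameters are arbitrary, and it assumes the prototype imports exercised by the test file above (`torchvision.prototype.features` and `torchvision.prototype.transforms.functional`).

```python
import torch

from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

# Build a (1, H, W) integer mask with a single labelled square and wrap it in
# the prototype SegmentationMask feature, as the tests above do.
data = torch.zeros(1, 32, 32, dtype=torch.long)
data[0, 8:16, 8:16] = 3
mask = features.SegmentationMask(data)

# The new kernel takes the same parameters as affine_image_tensor but always
# resamples with InterpolationMode.NEAREST, so category ids are never blended.
out = F.affine_segmentation_mask(
    mask, angle=45.0, translate=[2.0, -3.0], scale=1.1, shear=[0.0, 0.0]
)

# Nearest-neighbour resampling cannot invent new category ids: every output
# value is either an input value or the 0 used for pixels the warp leaves empty.
assert set(out.unique().tolist()) <= set(data.unique().tolist()) | {0}
```

Since the kernel simply forwards to `affine_image_tensor` with nearest interpolation, any future fixes to the image kernel apply to segmentation masks as well.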