Skip to content

Commit

Permalink
[fbsync] fix elastic error (#7838)
Browse files Browse the repository at this point in the history
Summary: Co-authored-by: vfdev <vfdev.5@gmail.com>

Reviewed By: matteobettini

Differential Revision: D48642275

fbshipit-source-id: 7bc2e758f5efb3631c5460932d1a9e2fde96f36b
  • Loading branch information
NicolasHug authored and facebook-github-bot committed Aug 25, 2023
1 parent d360f4f commit 7a3827d
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 0 deletions.
96 changes: 96 additions & 0 deletions test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -2259,3 +2259,99 @@ def test_image_correctness(self, permutation, batch_dims):
expected = self.reference_image_correctness(image, permutation=permutation)

torch.testing.assert_close(actual, expected)


class TestElastic:
def _make_displacement(self, inpt):
return torch.rand(
1,
*F.get_size(inpt),
2,
dtype=torch.float32,
device=inpt.device if isinstance(inpt, torch.Tensor) else "cpu",
)

@param_value_parametrization(
interpolation=[transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR],
fill=EXHAUSTIVE_TYPE_FILLS,
)
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_image_tensor(self, param, value, dtype, device):
image = make_image_tensor(dtype=dtype, device=device)

check_kernel(
F.elastic_image_tensor,
image,
displacement=self._make_displacement(image),
**{param: value},
check_scripted_vs_eager=not (param == "fill" and isinstance(value, (int, float))),
)

@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_bounding_boxes(self, format, dtype, device):
bounding_boxes = make_bounding_box(format=format, dtype=dtype, device=device)

check_kernel(
F.elastic_bounding_boxes,
bounding_boxes,
format=bounding_boxes.format,
canvas_size=bounding_boxes.canvas_size,
displacement=self._make_displacement(bounding_boxes),
)

@pytest.mark.parametrize("make_mask", [make_segmentation_mask, make_detection_mask])
def test_kernel_mask(self, make_mask):
mask = make_mask()
check_kernel(F.elastic_mask, mask, displacement=self._make_displacement(mask))

def test_kernel_video(self):
video = make_video()
check_kernel(F.elastic_video, video, displacement=self._make_displacement(video))

@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video],
)
def test_functional(self, make_input):
input = make_input()
check_functional(F.elastic, input, displacement=self._make_displacement(input))

@pytest.mark.parametrize(
("kernel", "input_type"),
[
(F.elastic_image_tensor, torch.Tensor),
(F.elastic_image_pil, PIL.Image.Image),
(F.elastic_image_tensor, datapoints.Image),
(F.elastic_bounding_boxes, datapoints.BoundingBoxes),
(F.elastic_mask, datapoints.Mask),
(F.elastic_video, datapoints.Video),
],
)
def test_functional_signature(self, kernel, input_type):
check_functional_kernel_signature_match(F.elastic, kernel=kernel, input_type=input_type)

@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video],
)
def test_displacement_error(self, make_input):
input = make_input()

with pytest.raises(TypeError, match="displacement should be a Tensor"):
F.elastic(input, displacement=None)

with pytest.raises(ValueError, match="displacement shape should be"):
F.elastic(input, displacement=torch.rand(F.get_size(input)))

@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video],
)
# ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image
@pytest.mark.parametrize("size", [(163, 163), (72, 333), (313, 95)])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_transform(self, make_input, size, device):
check_transform(transforms.ElasticTransform, make_input(size, device=device))
9 changes: 9 additions & 0 deletions torchvision/transforms/v2/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -1755,6 +1755,9 @@ def elastic_image_tensor(
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: _FillTypeJIT = None,
) -> torch.Tensor:
if not isinstance(displacement, torch.Tensor):
raise TypeError("Argument displacement should be a Tensor")

interpolation = _check_interpolation(interpolation)

if image.numel() == 0:
Expand Down Expand Up @@ -1835,6 +1838,12 @@ def elastic_bounding_boxes(
canvas_size: Tuple[int, int],
displacement: torch.Tensor,
) -> torch.Tensor:
expected_shape = (1, canvas_size[0], canvas_size[1], 2)
if not isinstance(displacement, torch.Tensor):
raise TypeError("Argument displacement should be a Tensor")
elif displacement.shape != expected_shape:
raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}")

if bounding_boxes.numel() == 0:
return bounding_boxes

Expand Down

0 comments on commit 7a3827d

Please sign in to comment.