Skip to content

Commit

Permalink
[fbsync] Added crop_segmentation_mask op (#5851)
Browse files Browse the repository at this point in the history
Summary:
* Added `crop_segmentation_mask` op

* Fixed failed mypy

Reviewed By: jdsgomes, NicolasHug

Differential Revision: D36095716

fbshipit-source-id: 7e471babab53882870878d919a8c865836dff995
  • Loading branch information
YosuaMichael authored and facebook-github-bot committed May 6, 2022
1 parent 3117a34 commit a49871e
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 0 deletions.
55 changes: 55 additions & 0 deletions test/test_prototype_transforms_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,20 @@ def crop_bounding_box():
)


@register_kernel_info_from_sample_inputs_fn
def crop_segmentation_mask():
for mask, top, left, height, width in itertools.product(
make_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20]
):
yield SampleInput(
mask,
top=top,
left=left,
height=height,
width=width,
)


@pytest.mark.parametrize(
"kernel",
[
Expand Down Expand Up @@ -860,3 +874,44 @@ def test_correctness_crop_bounding_box(device, top, left, height, width, expecte
)

torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize(
"top, left, height, width",
[
[4, 6, 30, 40],
[-8, 6, 70, 40],
[-8, -6, 70, 8],
],
)
def test_correctness_crop_segmentation_mask(device, top, left, height, width):
def _compute_expected_mask(mask, top_, left_, height_, width_):
h, w = mask.shape[-2], mask.shape[-1]
if top_ >= 0 and left_ >= 0 and top_ + height_ < h and left_ + width_ < w:
expected = mask[..., top_ : top_ + height_, left_ : left_ + width_]
else:
# Create output mask
expected_shape = mask.shape[:-2] + (height_, width_)
expected = torch.zeros(expected_shape, device=mask.device, dtype=mask.dtype)

out_y1 = abs(top_) if top_ < 0 else 0
out_y2 = h - top_ if top_ + height_ >= h else height_
out_x1 = abs(left_) if left_ < 0 else 0
out_x2 = w - left_ if left_ + width_ >= w else width_

in_y1 = 0 if top_ < 0 else top_
in_y2 = h if top_ + height_ >= h else top_ + height_
in_x1 = 0 if left_ < 0 else left_
in_x2 = w if left_ + width_ >= w else left_ + width_
# Paste input mask into output
expected[..., out_y1:out_y2, out_x1:out_x2] = mask[..., in_y1:in_y2, in_x1:in_x2]

return expected

for mask in make_segmentation_masks():
if mask.device != torch.device(device):
mask = mask.to(device)
output_mask = F.crop_segmentation_mask(mask, top, left, height, width)
expected_mask = _compute_expected_mask(mask, top, left, height, width)
torch.testing.assert_close(output_mask, expected_mask)
1 change: 1 addition & 0 deletions torchvision/prototype/transforms/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
crop_bounding_box,
crop_image_tensor,
crop_image_pil,
crop_segmentation_mask,
perspective_image_tensor,
perspective_image_pil,
vertical_flip_image_tensor,
Expand Down
4 changes: 4 additions & 0 deletions torchvision/prototype/transforms/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,10 @@ def crop_bounding_box(
).view(shape)


def crop_segmentation_mask(img: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
return crop_image_tensor(img, top, left, height, width)


def perspective_image_tensor(
img: torch.Tensor,
perspective_coeffs: List[float],
Expand Down

0 comments on commit a49871e

Please sign in to comment.