port FixedSizeCrop from detection references to prototype transforms #6417

Merged · 13 commits · Aug 19, 2022
159 changes: 159 additions & 0 deletions test/test_prototype_transforms.py
@@ -10,6 +10,7 @@
from test_prototype_transforms_functional import (
make_bounding_box,
make_bounding_boxes,
make_image,
make_images,
make_label,
make_one_hot_labels,
@@ -1328,3 +1329,161 @@ def test__transform(self, mocker):
transform(inpt_sentinel)

mock.assert_called_once_with(inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel)


class TestFixedSizeCrop:
def test__get_params(self, mocker):
crop_size = (7, 7)
batch_shape = (10,)
image_size = (11, 5)

transform = transforms.FixedSizeCrop(size=crop_size)

sample = dict(
image=make_image(size=image_size, color_space=features.ColorSpace.RGB),
bounding_boxes=make_bounding_box(
format=features.BoundingBoxFormat.XYXY, image_size=image_size, extra_dims=batch_shape
),
)
params = transform._get_params(sample)

assert params["needs_crop"]
assert params["height"] <= crop_size[0]
assert params["width"] <= crop_size[1]

assert (
isinstance(params["is_valid"], torch.Tensor)
and params["is_valid"].dtype is torch.bool
and params["is_valid"].shape == batch_shape
)

assert params["needs_pad"]
assert any(pad > 0 for pad in params["padding"])

@pytest.mark.parametrize("needs", list(itertools.product((False, True), repeat=2)))
def test__transform(self, mocker, needs):
fill_sentinel = mocker.MagicMock()
padding_mode_sentinel = mocker.MagicMock()

transform = transforms.FixedSizeCrop((-1, -1), fill=fill_sentinel, padding_mode=padding_mode_sentinel)
transform._transformed_types = (mocker.MagicMock,)
mocker.patch("torchvision.prototype.transforms._geometry.has_all", return_value=True)
mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)

needs_crop, needs_pad = needs
top_sentinel = mocker.MagicMock()
left_sentinel = mocker.MagicMock()
height_sentinel = mocker.MagicMock()
width_sentinel = mocker.MagicMock()
padding_sentinel = mocker.MagicMock()
mocker.patch(
"torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
return_value=dict(
needs_crop=needs_crop,
top=top_sentinel,
left=left_sentinel,
height=height_sentinel,
width=width_sentinel,
padding=padding_sentinel,
needs_pad=needs_pad,
),
)

inpt_sentinel = mocker.MagicMock()

mock_crop = mocker.patch("torchvision.prototype.transforms._geometry.F.crop")
mock_pad = mocker.patch("torchvision.prototype.transforms._geometry.F.pad")
transform(inpt_sentinel)

if needs_crop:
mock_crop.assert_called_once_with(
inpt_sentinel,
top=top_sentinel,
left=left_sentinel,
height=height_sentinel,
width=width_sentinel,
)
else:
mock_crop.assert_not_called()

if needs_pad:
# If we cropped before, the input to F.pad is no longer inpt_sentinel. Thus, we can't use
# `MagicMock.assert_called_once_with` and have to perform the checks manually
mock_pad.assert_called_once()
args, kwargs = mock_pad.call_args
if not needs_crop:
assert args[0] is inpt_sentinel
assert args[1] is padding_sentinel
assert kwargs == dict(fill=fill_sentinel, padding_mode=padding_mode_sentinel)
else:
mock_pad.assert_not_called()

def test__transform_culling(self, mocker):
batch_size = 10
image_size = (10, 10)

is_valid = torch.randint(0, 2, (batch_size,), dtype=torch.bool)
mocker.patch(
"torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
return_value=dict(
needs_crop=True,
top=0,
left=0,
height=image_size[0],
width=image_size[1],
is_valid=is_valid,
needs_pad=False,
),
)

bounding_boxes = make_bounding_box(
format=features.BoundingBoxFormat.XYXY, image_size=image_size, extra_dims=(batch_size,)
)
segmentation_masks = make_segmentation_mask(size=image_size, extra_dims=(batch_size,))
labels = make_label(size=(batch_size,))

transform = transforms.FixedSizeCrop((-1, -1))
mocker.patch("torchvision.prototype.transforms._geometry.has_all", return_value=True)
mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)

output = transform(
dict(
bounding_boxes=bounding_boxes,
segmentation_masks=segmentation_masks,
labels=labels,
)
)

assert_equal(output["bounding_boxes"], bounding_boxes[is_valid])
assert_equal(output["segmentation_masks"], segmentation_masks[is_valid])
assert_equal(output["labels"], labels[is_valid])

def test__transform_bounding_box_clamping(self, mocker):
batch_size = 3
image_size = (10, 10)

mocker.patch(
"torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
return_value=dict(
needs_crop=True,
top=0,
left=0,
height=image_size[0],
width=image_size[1],
is_valid=torch.full((batch_size,), fill_value=True),
needs_pad=False,
),
)

bounding_box = make_bounding_box(
format=features.BoundingBoxFormat.XYXY, image_size=image_size, extra_dims=(batch_size,)
)
mock = mocker.patch("torchvision.prototype.transforms._geometry.F.clamp_bounding_box")

transform = transforms.FixedSizeCrop((-1, -1))
mocker.patch("torchvision.prototype.transforms._geometry.has_all", return_value=True)
mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)

transform(bounding_box)

mock.assert_called_once()
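
For orientation, a minimal usage sketch of the transform added in this PR (assuming the prototype feature constructors exercised by the test helpers above; the sample keys and values are illustrative, not required names):

import torch
from torchvision.prototype import features, transforms

# FixedSizeCrop.forward() requires the sample to contain an image, bounding boxes,
# and labels (or one-hot labels); segmentation masks are optional.
sample = dict(
    image=features.Image(torch.rand(3, 11, 5), color_space=features.ColorSpace.RGB),
    bounding_boxes=features.BoundingBox(
        torch.tensor([[0.0, 1.0, 4.0, 9.0]]),
        format=features.BoundingBoxFormat.XYXY,
        image_size=(11, 5),
    ),
    labels=features.Label(torch.tensor([0])),
)

transform = transforms.FixedSizeCrop(size=(7, 7))
output = transform(sample)
# The 11x5 image is cropped to 7x5 (random vertical offset) and padded on the right
# to exactly 7x7; boxes that end up with zero area are culled together with their labels.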
1 change: 1 addition & 0 deletions torchvision/prototype/transforms/__init__.py
@@ -20,6 +20,7 @@
CenterCrop,
ElasticTransform,
FiveCrop,
FixedSizeCrop,
Pad,
RandomAffine,
RandomCrop,
97 changes: 97 additions & 0 deletions torchvision/prototype/transforms/_geometry.py
@@ -783,3 +783,100 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.resize(inpt, size=params["size"], interpolation=self.interpolation)


class FixedSizeCrop(Transform):
def __init__(
self,
size: Union[int, Sequence[int]],
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
padding_mode: str = "constant",
) -> None:
super().__init__()
size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
self.crop_height = size[0]
self.crop_width = size[1]
self.fill = fill # TODO: Fill is currently respected only on PIL. Apply tensor patch.
self.padding_mode = padding_mode

def _get_params(self, sample: Any) -> Dict[str, Any]:
image = query_image(sample)
_, height, width = get_image_dimensions(image)
new_height = min(height, self.crop_height)
new_width = min(width, self.crop_width)

needs_crop = new_height != height or new_width != width

offset_height = max(height - self.crop_height, 0)
offset_width = max(width - self.crop_width, 0)

r = torch.rand(1)
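        # The same draw determines both offsets, i.e. ``top`` and ``left`` are fully
        # correlated rather than sampled independently per axis.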
top = int(offset_height * r)
left = int(offset_width * r)

if needs_crop:
bounding_boxes = query_bounding_box(sample)
bounding_boxes = cast(
features.BoundingBox, F.crop(bounding_boxes, top=top, left=left, height=height, width=width)
)
bounding_boxes = features.BoundingBox.new_like(
bounding_boxes,
F.clamp_bounding_box(
bounding_boxes, format=bounding_boxes.format, image_size=bounding_boxes.image_size
),
)
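            # A box survives the crop only if it still has positive height and width after clamping.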
height_and_width = bounding_boxes.to_format(features.BoundingBoxFormat.XYWH)[..., 2:]
is_valid = torch.all(height_and_width > 0, dim=-1)
else:
is_valid = None

pad_bottom = max(self.crop_height - new_height, 0)
pad_right = max(self.crop_width - new_width, 0)

needs_pad = pad_bottom != 0 or pad_right != 0

return dict(
needs_crop=needs_crop,
top=top,
left=left,
height=new_height,
width=new_width,
is_valid=is_valid,
padding=[0, 0, pad_right, pad_bottom],
needs_pad=needs_pad,
)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if params["needs_crop"]:
inpt = F.crop(
inpt,
top=params["top"],
left=params["left"],
height=params["height"],
width=params["width"],
)
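            # Cull the entries whose bounding box was cropped away; ``is_valid`` comes from ``_get_params``.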
if isinstance(inpt, (features.Label, features.OneHotLabel, features.SegmentationMask)):
inpt = inpt.new_like(inpt, inpt[params["is_valid"]]) # type: ignore[arg-type]
elif isinstance(inpt, features.BoundingBox):
inpt = features.BoundingBox.new_like(
inpt,
F.clamp_bounding_box(inpt[params["is_valid"]], format=inpt.format, image_size=inpt.image_size),
)

if params["needs_pad"]:
inpt = F.pad(inpt, params["padding"], fill=self.fill, padding_mode=self.padding_mode)

return inpt

def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
if not (
has_all(sample, features.BoundingBox)
and has_any(sample, PIL.Image.Image, features.Image, is_simple_tensor)
and has_any(sample, features.Label, features.OneHotLabel)
):
raise TypeError(
f"{type(self).__name__}() requires input sample to contain Images or PIL Images, "
"BoundingBoxes and Labels or OneHotLabels. Sample can also contain Segmentation Masks."
)
return super().forward(sample)
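
To make the ``_get_params`` arithmetic concrete, here is the computation it performs for the shapes used in ``test__get_params`` above (crop size (7, 7) on an 11x5 image); a worked sketch, not part of the diff:

crop_height, crop_width = 7, 7
height, width = 11, 5

new_height = min(height, crop_height)  # 7 -> the height needs a crop
new_width = min(width, crop_width)     # 5 -> the width is already small enough
needs_crop = new_height != height or new_width != width  # True

offset_height = max(height - crop_height, 0)  # 4 -> top is int(4 * rand()) in {0, ..., 3}
offset_width = max(width - crop_width, 0)     # 0 -> left is always 0

pad_bottom = max(crop_height - new_height, 0)  # 0
pad_right = max(crop_width - new_width, 0)     # 2 -> pad the width back up to 7
needs_pad = pad_bottom != 0 or pad_right != 0  # True, padding == [0, 0, 2, 0]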