[proto] Ported RandomIoUCrop from detection refs (#6401)
* [proto] Ported RandomIoUCrop from detection refs

* Scope acceptable data types

* Added get_params test

* Added test__transform_empty_params

* Added support for OneHotLabel and tests

* Added tests for mask

* Updated error message

* Apply suggestions from code review

Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* Added support for OHE masks and tests

* Ignored mypy error

* Fixed forward call on sample

* Added a todo

Co-authored-by: Philip Meier <github.pmeier@posteo.de>
vfdev-5 and pmeier authored Aug 18, 2022
1 parent b9e9c28 commit 961d97b
Showing 6 changed files with 253 additions and 2 deletions.
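For context, a minimal usage sketch of the transform this commit adds (not part of the diff; it assumes the unstable torchvision.prototype API as of this commit):

import torch
from torchvision.prototype import features, transforms

image = features.Image(torch.rand(3, 32, 24))
bboxes = features.BoundingBox(
    torch.tensor([[1, 1, 10, 10], [20, 20, 23, 23]]),
    format="XYXY",
    image_size=(32, 24),
)
labels = features.Label(torch.tensor([1, 2]))

# RandomIoUCrop requires the sample to contain an image, bounding boxes,
# and labels (or one-hot labels); boxes whose centers fall outside the
# sampled crop are dropped together with their labels.
transform = transforms.RandomIoUCrop()
image, bboxes, labels = transform(image, bboxes, labels)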
121 changes: 120 additions & 1 deletion test/test_prototype_transforms.py
@@ -6,7 +6,7 @@

import pytest
import torch
from common_utils import assert_equal
from common_utils import assert_equal, cpu_and_gpu
from test_prototype_transforms_functional import (
make_bounding_box,
make_bounding_boxes,
@@ -15,6 +15,7 @@
make_one_hot_labels,
make_segmentation_mask,
)
from torchvision.ops.boxes import box_iou
from torchvision.prototype import features, transforms
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image

@@ -1127,6 +1128,124 @@ def test_ctor(self, trfms):
assert isinstance(output, torch.Tensor)


class TestRandomIoUCrop:
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("options", [[0.5, 0.9], [2.0]])
def test__get_params(self, device, options, mocker):
image = mocker.MagicMock(spec=features.Image)
image.num_channels = 3
image.image_size = (24, 32)
bboxes = features.BoundingBox(
torch.tensor([[1, 1, 10, 10], [20, 20, 23, 23], [1, 20, 10, 23], [20, 1, 23, 10]]),
format="XYXY",
image_size=image.image_size,
device=device,
)
sample = [image, bboxes]

transform = transforms.RandomIoUCrop(sampler_options=options)

n_samples = 5
for _ in range(n_samples):

params = transform._get_params(sample)

if options == [2.0]:
assert len(params) == 0
return

assert len(params["is_within_crop_area"]) > 0
assert params["is_within_crop_area"].dtype == torch.bool

orig_h = image.image_size[0]
orig_w = image.image_size[1]
assert int(transform.min_scale * orig_h) <= params["height"] <= int(transform.max_scale * orig_h)
assert int(transform.min_scale * orig_w) <= params["width"] <= int(transform.max_scale * orig_w)

left, top = params["left"], params["top"]
new_h, new_w = params["height"], params["width"]
ious = box_iou(
bboxes,
torch.tensor([[left, top, left + new_w, top + new_h]], dtype=bboxes.dtype, device=bboxes.device),
)
assert ious.max() >= options[0] or ious.max() >= options[1], f"{ious} vs {options}"

def test__transform_empty_params(self, mocker):
transform = transforms.RandomIoUCrop(sampler_options=[2.0])
image = features.Image(torch.rand(1, 3, 4, 4))
bboxes = features.BoundingBox(torch.tensor([[1, 1, 2, 2]]), format="XYXY", image_size=(4, 4))
label = features.Label(torch.tensor([1]))
sample = [image, bboxes, label]
# Let's mock transform._get_params to control the output:
transform._get_params = mocker.MagicMock(return_value={})
output = transform(sample)
torch.testing.assert_close(output, sample)

def test_forward_assertion(self):
transform = transforms.RandomIoUCrop()
with pytest.raises(
TypeError,
match="requires input sample to contain Images or PIL Images, BoundingBoxes and Labels or OneHotLabels",
):
transform(torch.tensor(0))

def test__transform(self, mocker):
transform = transforms.RandomIoUCrop()

image = features.Image(torch.rand(3, 32, 24))
bboxes = make_bounding_box(format="XYXY", image_size=(32, 24), extra_dims=(6,))
label = features.Label(torch.randint(0, 10, size=(6,)))
ohe_label = features.OneHotLabel(torch.zeros(6, 10).scatter_(1, label.unsqueeze(1), 1))
masks = make_segmentation_mask((32, 24))
ohe_masks = features.SegmentationMask(torch.randint(0, 2, size=(6, 32, 24)))
sample = [image, bboxes, label, ohe_label, masks, ohe_masks]

fn = mocker.patch("torchvision.prototype.transforms.functional.crop", side_effect=lambda x, **params: x)
is_within_crop_area = torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.bool)

params = dict(top=1, left=2, height=12, width=12, is_within_crop_area=is_within_crop_area)
transform._get_params = mocker.MagicMock(return_value=params)
output = transform(sample)

assert fn.call_count == 4

expected_calls = [
mocker.call(image, top=params["top"], left=params["left"], height=params["height"], width=params["width"]),
mocker.call(bboxes, top=params["top"], left=params["left"], height=params["height"], width=params["width"]),
mocker.call(masks, top=params["top"], left=params["left"], height=params["height"], width=params["width"]),
mocker.call(
ohe_masks, top=params["top"], left=params["left"], height=params["height"], width=params["width"]
),
]

fn.assert_has_calls(expected_calls)

expected_within_targets = sum(is_within_crop_area)

# check number of bboxes vs number of labels:
output_bboxes = output[1]
assert isinstance(output_bboxes, features.BoundingBox)
assert len(output_bboxes) == expected_within_targets

# check labels
output_label = output[2]
assert isinstance(output_label, features.Label)
assert len(output_label) == expected_within_targets
torch.testing.assert_close(output_label, label[is_within_crop_area])

output_ohe_label = output[3]
assert isinstance(output_ohe_label, features.OneHotLabel)
torch.testing.assert_close(output_ohe_label, ohe_label[is_within_crop_area])

output_masks = output[4]
assert isinstance(output_masks, features.SegmentationMask)
assert output_masks.shape[:-2] == masks.shape[:-2]

output_ohe_masks = output[5]
assert isinstance(output_ohe_masks, features.SegmentationMask)
assert len(output_ohe_masks) == expected_within_targets


class TestScaleJitter:
def test__get_params(self, mocker):
image_size = (24, 32)
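A side note on the one-hot construction in test__transform above: torch.zeros(6, 10).scatter_(1, label.unsqueeze(1), 1) writes a 1 at each label's index per row. An equivalent, arguably clearer formulation:

import torch

label = torch.randint(0, 10, size=(6,))
# Same rows as the scatter_ version; dtype is int64 rather than float32.
ohe = torch.nn.functional.one_hot(label, num_classes=10)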
1 change: 1 addition & 0 deletions torchvision/prototype/transforms/__init__.py
@@ -24,6 +24,7 @@
RandomAffine,
RandomCrop,
RandomHorizontalFlip,
RandomIoUCrop,
RandomPerspective,
RandomResizedCrop,
RandomRotation,
114 changes: 113 additions & 1 deletion torchvision/prototype/transforms/_geometry.py
@@ -5,15 +5,17 @@

import PIL.Image
import torch
from torchvision.ops.boxes import box_iou
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor
from torchvision.transforms.functional_tensor import _parse_pad_padding
from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size

from typing_extensions import Literal

from ._transform import _RandomApplyTransform
from ._utils import get_image_dimensions, has_any, is_simple_tensor, query_image
from ._utils import get_image_dimensions, has_all, has_any, is_simple_tensor, query_bounding_box, query_image


class RandomHorizontalFlip(_RandomApplyTransform):
@@ -620,6 +622,116 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
)


class RandomIoUCrop(Transform):
def __init__(
self,
min_scale: float = 0.3,
max_scale: float = 1.0,
min_aspect_ratio: float = 0.5,
max_aspect_ratio: float = 2.0,
sampler_options: Optional[List[float]] = None,
trials: int = 40,
):
super().__init__()
# Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174
self.min_scale = min_scale
self.max_scale = max_scale
self.min_aspect_ratio = min_aspect_ratio
self.max_aspect_ratio = max_aspect_ratio
if sampler_options is None:
sampler_options = [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0]
self.options = sampler_options
self.trials = trials

def _get_params(self, sample: Any) -> Dict[str, Any]:

image = query_image(sample)
_, orig_h, orig_w = get_image_dimensions(image)
bboxes = query_bounding_box(sample)

while True:
# sample an option
idx = int(torch.randint(low=0, high=len(self.options), size=(1,)))
min_jaccard_overlap = self.options[idx]
if min_jaccard_overlap >= 1.0: # a value larger than 1 encodes the leave as-is option
return dict()

for _ in range(self.trials):
# check the aspect ratio limitations
r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2)
new_w = int(orig_w * r[0])
new_h = int(orig_h * r[1])
aspect_ratio = new_w / new_h
if not (self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio):
continue

# check for 0 area crops
r = torch.rand(2)
left = int((orig_w - new_w) * r[0])
top = int((orig_h - new_h) * r[1])
right = left + new_w
bottom = top + new_h
if left == right or top == bottom:
continue

# check for any valid boxes with centers within the crop area
xyxy_bboxes = F.convert_bounding_box_format(
bboxes, old_format=bboxes.format, new_format=features.BoundingBoxFormat.XYXY, copy=True
)
cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2])
cy = 0.5 * (xyxy_bboxes[..., 1] + xyxy_bboxes[..., 3])
is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom)
if not is_within_crop_area.any():
continue

# check at least 1 box with jaccard limitations
xyxy_bboxes = xyxy_bboxes[is_within_crop_area]
ious = box_iou(
xyxy_bboxes,
torch.tensor([[left, top, right, bottom]], dtype=xyxy_bboxes.dtype, device=xyxy_bboxes.device),
)
if ious.max() < min_jaccard_overlap:
continue

return dict(top=top, left=left, height=new_h, width=new_w, is_within_crop_area=is_within_crop_area)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if len(params) < 1:
return inpt

is_within_crop_area = params["is_within_crop_area"]

if isinstance(inpt, (features.Label, features.OneHotLabel)):
return inpt.new_like(inpt, inpt[is_within_crop_area]) # type: ignore[arg-type]

output = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])

if isinstance(output, features.BoundingBox):
bboxes = output[is_within_crop_area]
bboxes = F.clamp_bounding_box(bboxes, output.format, output.image_size)
output = features.BoundingBox.new_like(output, bboxes)
elif isinstance(output, features.SegmentationMask) and output.shape[-3] > 1:
# apply is_within_crop_area if mask is one-hot encoded
masks = output[is_within_crop_area]
output = features.SegmentationMask.new_like(output, masks)

return output

def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
# TODO: Allow image to be a torch.Tensor
if not (
has_all(sample, features.BoundingBox)
and has_any(sample, PIL.Image.Image, features.Image)
and has_any(sample, features.Label, features.OneHotLabel)
):
raise TypeError(
f"{type(self).__name__}() requires input sample to contain Images or PIL Images, "
"BoundingBoxes and Labels or OneHotLabels. Sample can also contain Segmentation Masks."
)
return super().forward(sample)


class ScaleJitter(Transform):
def __init__(
self,
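A quick numeric check of the jaccard gate in _get_params (toy values, not taken from the test suite): a 9x9 box fully contained in a 12x12 candidate crop has IoU 81/144 = 0.5625, so it passes a sampled threshold of 0.5 but fails 0.7.

import torch
from torchvision.ops.boxes import box_iou

bboxes = torch.tensor([[1.0, 1.0, 10.0, 10.0]])
crop = torch.tensor([[0.0, 0.0, 12.0, 12.0]])  # [left, top, right, bottom]
print(box_iou(bboxes, crop))  # tensor([[0.5625]])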
9 changes: 9 additions & 0 deletions torchvision/prototype/transforms/_utils.py
@@ -17,6 +17,15 @@ def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]:
raise TypeError("No image was found in the sample")


def query_bounding_box(sample: Any) -> features.BoundingBox:
flat_sample, _ = tree_flatten(sample)
for i in flat_sample:
if isinstance(i, features.BoundingBox):
return i

raise TypeError("No bounding box was found in the sample")


def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]:
if isinstance(image, features.Image):
channels = image.num_channels
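The new query_bounding_box helper flattens an arbitrarily nested sample with tree_flatten and returns the first BoundingBox it finds. A small sketch of its behavior (hypothetical sample layout; note the import comes from a private module):

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms._utils import query_bounding_box

sample = {
    "image": torch.rand(3, 4, 4),
    "target": {
        "boxes": features.BoundingBox(torch.tensor([[0, 0, 2, 2]]), format="XYXY", image_size=(4, 4)),
    },
}
print(query_bounding_box(sample))  # returns the nested BoundingBox
# query_bounding_box(torch.rand(3, 4, 4)) would raise TypeError("No bounding box was found in the sample")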
1 change: 1 addition & 0 deletions torchvision/prototype/transforms/functional/__init__.py
@@ -1,5 +1,6 @@
from torchvision.transforms import InterpolationMode # usort: skip
from ._meta import (
clamp_bounding_box,
convert_bounding_box_format,
convert_color_space_image_tensor,
convert_color_space_image_pil,
9 changes: 9 additions & 0 deletions torchvision/prototype/transforms/functional/_meta.py
@@ -61,6 +61,15 @@ def convert_bounding_box_format(
return bounding_box


def clamp_bounding_box(
bounding_box: torch.Tensor, format: BoundingBoxFormat, image_size: Tuple[int, int]
) -> torch.Tensor:
xyxy_boxes = convert_bounding_box_format(bounding_box, format, BoundingBoxFormat.XYXY)
xyxy_boxes[..., 0::2].clamp_(min=0, max=image_size[1])
xyxy_boxes[..., 1::2].clamp_(min=0, max=image_size[0])
return convert_bounding_box_format(xyxy_boxes, BoundingBoxFormat.XYXY, format, copy=False)


def _split_alpha(image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
return image[..., :-1, :, :], image[..., -1:, :, :]

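An illustration of the clamping behavior added above (toy values; assumes the prototype functional API as of this commit): with image_size=(24, 32), x coordinates are clipped to [0, 32] and y coordinates to [0, 24], so boxes spilling past the border are pulled back inside.

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

boxes = torch.tensor([[-5.0, 2.0, 40.0, 30.0]])  # partially outside a 24x32 image
clamped = F.clamp_bounding_box(boxes, features.BoundingBoxFormat.XYXY, image_size=(24, 32))
print(clamped)  # tensor([[ 0.,  2., 32., 24.]])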
