[proto] Ported RandomIoUCrop from detection refs #6401

Merged
merged 21 commits into main from vfdev-5:proto-random-iou-crop
Aug 18, 2022
Changes from 15 commits
21 commits
bf21a25
[proto] Ported RandomIoUCrop from detection refs
vfdev-5 Aug 10, 2022
ed229cd
Scope acceptable data types
vfdev-5 Aug 12, 2022
19076a8
Merge branch 'main' of github.com:pytorch/vision into proto-random-io…
vfdev-5 Aug 15, 2022
5c2275e
Added get_params test
vfdev-5 Aug 15, 2022
84d2f09
Added test__transform_empty_params
vfdev-5 Aug 15, 2022
6d664e5
Merge branch 'main' of github.com:pytorch/vision into proto-random-io…
vfdev-5 Aug 17, 2022
3085dbf
Added support for OneHotLabel and tests
vfdev-5 Aug 17, 2022
1d06ec3
Merge branch 'main' into proto-random-iou-crop
vfdev-5 Aug 17, 2022
abf4381
Added tests for mask
vfdev-5 Aug 17, 2022
b4fe1a9
Merge branch 'proto-random-iou-crop' of github.com:vfdev-5/vision int…
vfdev-5 Aug 17, 2022
3eaacf2
Merge branch 'main' of github.com:pytorch/vision into proto-random-io…
vfdev-5 Aug 17, 2022
6231608
Updated error message
vfdev-5 Aug 17, 2022
0b61852
Apply suggestions from code review
vfdev-5 Aug 17, 2022
419ba8f
Added support for OHE masks and tests
vfdev-5 Aug 17, 2022
e862e89
Merge branch 'main' into proto-random-iou-crop
vfdev-5 Aug 17, 2022
f8253aa
Ignored mypy error
vfdev-5 Aug 17, 2022
4967ae9
Merge branch 'proto-random-iou-crop' of github.com:vfdev-5/vision int…
vfdev-5 Aug 17, 2022
79bbe76
Fixed forward call on sample
vfdev-5 Aug 18, 2022
b7a5591
Merge branch 'main' into proto-random-iou-crop
vfdev-5 Aug 18, 2022
1353cfe
Added a todo
vfdev-5 Aug 18, 2022
c5c46a5
Merge branch 'main' into proto-random-iou-crop
vfdev-5 Aug 18, 2022
121 changes: 120 additions & 1 deletion test/test_prototype_transforms.py
@@ -6,7 +6,7 @@

import pytest
import torch
from common_utils import assert_equal
from common_utils import assert_equal, cpu_and_gpu
from test_prototype_transforms_functional import (
make_bounding_box,
make_bounding_boxes,
@@ -15,6 +15,7 @@
make_one_hot_labels,
make_segmentation_mask,
)
from torchvision.ops.boxes import box_iou
from torchvision.prototype import features, transforms
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image

@@ -1127,6 +1128,124 @@ def test_ctor(self, trfms):
assert isinstance(output, torch.Tensor)


class TestRandomIoUCrop:
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("options", [[0.5, 0.9], [2.0]])
def test__get_params(self, device, options, mocker):
image = mocker.MagicMock(spec=features.Image)
image.num_channels = 3
image.image_size = (24, 32)
bboxes = features.BoundingBox(
torch.tensor([[1, 1, 10, 10], [20, 20, 23, 23], [1, 20, 10, 23], [20, 1, 23, 10]]),
format="XYXY",
image_size=image.image_size,
device=device,
)
sample = [image, bboxes]

transform = transforms.RandomIoUCrop(sampler_options=options)

n_samples = 5
for _ in range(n_samples):

params = transform._get_params(sample)

if options == [2.0]:
assert len(params) == 0
return

assert len(params["is_within_crop_area"]) > 0
assert params["is_within_crop_area"].dtype == torch.bool

orig_h = image.image_size[0]
orig_w = image.image_size[1]
assert int(transform.min_scale * orig_h) <= params["height"] <= int(transform.max_scale * orig_h)
assert int(transform.min_scale * orig_w) <= params["width"] <= int(transform.max_scale * orig_w)

left, top = params["left"], params["top"]
new_h, new_w = params["height"], params["width"]
ious = box_iou(
bboxes,
torch.tensor([[left, top, left + new_w, top + new_h]], dtype=bboxes.dtype, device=bboxes.device),
)
assert ious.max() >= options[0] or ious.max() >= options[1], f"{ious} vs {options}"

def test__transform_empty_params(self, mocker):
transform = transforms.RandomIoUCrop(sampler_options=[2.0])
image = features.Image(torch.rand(1, 3, 4, 4))
bboxes = features.BoundingBox(torch.tensor([[1, 1, 2, 2]]), format="XYXY", image_size=(4, 4))
label = features.Label(torch.tensor([1]))
sample = [image, bboxes, label]
# Let's mock transform._get_params to control the output:
transform._get_params = mocker.MagicMock(return_value={})
output = transform(sample)
torch.testing.assert_close(output, sample)

def test_forward_assertion(self):
transform = transforms.RandomIoUCrop()
with pytest.raises(
TypeError,
match="requires input sample to contain Images or PIL Images, BoundingBoxes and Labels or OneHotLabels",
):
transform(torch.tensor(0))

def test__transform(self, mocker):
transform = transforms.RandomIoUCrop()

image = features.Image(torch.rand(3, 32, 24))
bboxes = make_bounding_box(format="XYXY", image_size=(32, 24), extra_dims=(6,))
label = features.Label(torch.randint(0, 10, size=(6,)))
ohe_label = features.OneHotLabel(torch.zeros(6, 10).scatter_(1, label.unsqueeze(1), 1))
masks = make_segmentation_mask((32, 24))
ohe_masks = features.SegmentationMask(torch.randint(0, 2, size=(6, 32, 24)))
sample = [image, bboxes, label, ohe_label, masks, ohe_masks]

fn = mocker.patch("torchvision.prototype.transforms.functional.crop", side_effect=lambda x, **params: x)
is_within_crop_area = torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.bool)

params = dict(top=1, left=2, height=12, width=12, is_within_crop_area=is_within_crop_area)
transform._get_params = mocker.MagicMock(return_value=params)
output = transform(sample)

assert fn.call_count == 4

expected_calls = [
mocker.call(image, top=params["top"], left=params["left"], height=params["height"], width=params["width"]),
mocker.call(bboxes, top=params["top"], left=params["left"], height=params["height"], width=params["width"]),
mocker.call(masks, top=params["top"], left=params["left"], height=params["height"], width=params["width"]),
mocker.call(
ohe_masks, top=params["top"], left=params["left"], height=params["height"], width=params["width"]
),
]

fn.assert_has_calls(expected_calls)

expected_within_targets = sum(is_within_crop_area)

# check number of bboxes vs number of labels:
output_bboxes = output[1]
assert isinstance(output_bboxes, features.BoundingBox)
assert len(output_bboxes) == expected_within_targets

# check labels
output_label = output[2]
assert isinstance(output_label, features.Label)
assert len(output_label) == expected_within_targets
torch.testing.assert_close(output_label, label[is_within_crop_area])

output_ohe_label = output[3]
assert isinstance(output_ohe_label, features.OneHotLabel)
torch.testing.assert_close(output_ohe_label, ohe_label[is_within_crop_area])

output_masks = output[4]
assert isinstance(output_masks, features.SegmentationMask)
assert output_masks.shape[:-2] == masks.shape[:-2]

output_ohe_masks = output[5]
assert isinstance(output_ohe_masks, features.SegmentationMask)
assert len(output_ohe_masks) == expected_within_targets


class TestScaleJitter:
def test__get_params(self, mocker):
image_size = (24, 32)
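(Not part of the diff: a quick way to run just the new test class. The sketch below assumes pytest and pytest-mock are installed, since the tests use the mocker fixture.)

    # Sketch: run only the new RandomIoUCrop tests from the repository root.
    import pytest

    pytest.main(["test/test_prototype_transforms.py", "-k", "TestRandomIoUCrop", "-v"])
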
1 change: 1 addition & 0 deletions torchvision/prototype/transforms/__init__.py
@@ -24,6 +24,7 @@
RandomAffine,
RandomCrop,
RandomHorizontalFlip,
RandomIoUCrop,
RandomPerspective,
RandomResizedCrop,
RandomRotation,
113 changes: 112 additions & 1 deletion torchvision/prototype/transforms/_geometry.py
@@ -5,15 +5,17 @@

import PIL.Image
import torch
from torchvision.ops.boxes import box_iou
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor
from torchvision.transforms.functional_tensor import _parse_pad_padding
from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size

from typing_extensions import Literal

from ._transform import _RandomApplyTransform
from ._utils import get_image_dimensions, has_any, is_simple_tensor, query_image
from ._utils import get_image_dimensions, has_all, has_any, is_simple_tensor, query_bounding_box, query_image


class RandomHorizontalFlip(_RandomApplyTransform):
@@ -620,6 +622,115 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
)


class RandomIoUCrop(Transform):
def __init__(
self,
min_scale: float = 0.3,
max_scale: float = 1.0,
min_aspect_ratio: float = 0.5,
max_aspect_ratio: float = 2.0,
sampler_options: Optional[List[float]] = None,
trials: int = 40,
):
super().__init__()
# Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174
self.min_scale = min_scale
self.max_scale = max_scale
self.min_aspect_ratio = min_aspect_ratio
self.max_aspect_ratio = max_aspect_ratio
if sampler_options is None:
sampler_options = [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0]
self.options = sampler_options
self.trials = trials

def _get_params(self, sample: Any) -> Dict[str, Any]:

image = query_image(sample)
_, orig_h, orig_w = get_image_dimensions(image)
bboxes = query_bounding_box(sample)

while True:
# sample an option
idx = int(torch.randint(low=0, high=len(self.options), size=(1,)))
min_jaccard_overlap = self.options[idx]
if min_jaccard_overlap >= 1.0: # a value larger than 1 encodes the leave as-is option
return dict()

for _ in range(self.trials):
# check the aspect ratio limitations
r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2)
new_w = int(orig_w * r[0])
new_h = int(orig_h * r[1])
aspect_ratio = new_w / new_h
if not (self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio):
continue

# check for 0 area crops
r = torch.rand(2)
left = int((orig_w - new_w) * r[0])
top = int((orig_h - new_h) * r[1])
right = left + new_w
bottom = top + new_h
if left == right or top == bottom:
continue

# check for any valid boxes with centers within the crop area
xyxy_bboxes = F.convert_bounding_box_format(
bboxes, old_format=bboxes.format, new_format=features.BoundingBoxFormat.XYXY, copy=True
)
cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2])
cy = 0.5 * (xyxy_bboxes[..., 1] + xyxy_bboxes[..., 3])
is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom)
if not is_within_crop_area.any():
continue

# check at least 1 box with jaccard limitations
xyxy_bboxes = xyxy_bboxes[is_within_crop_area]
ious = box_iou(
xyxy_bboxes,
torch.tensor([[left, top, right, bottom]], dtype=xyxy_bboxes.dtype, device=xyxy_bboxes.device),
)
if ious.max() < min_jaccard_overlap:
continue

return dict(top=top, left=left, height=new_h, width=new_w, is_within_crop_area=is_within_crop_area)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if len(params) < 1:
return inpt

is_within_crop_area = params["is_within_crop_area"]

if isinstance(inpt, (features.Label, features.OneHotLabel)):
return inpt.new_like(inpt, inpt[is_within_crop_area])

output = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])

if isinstance(output, features.BoundingBox):
bboxes = output[is_within_crop_area]
bboxes = F.clamp_bounding_box(bboxes, output.format, output.image_size)
output = features.BoundingBox.new_like(output, bboxes)
elif isinstance(output, features.SegmentationMask) and output.shape[-3] > 1:
# apply is_within_crop_area if mask is one-hot encoded
masks = output[is_within_crop_area]
output = features.SegmentationMask.new_like(output, masks)
Comment on lines +713 to +716
vfdev-5 (Collaborator, Author):
@datumbox here is support for one-hot encoded masks in the meantime, while we decide on other solutions.

return output

def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
if not (
has_all(sample, features.BoundingBox)
and has_any(sample, PIL.Image.Image, features.Image)
Collaborator:
What about plain tensors? For CutMix and MixUp we don't allow the "old" image types at all:

    if not has_all(sample, features.Image, features.OneHotLabel):
        raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.")

Given that this has been ported from the references, there is no BC constraint to allow plain tensors and PIL images. Still, it feels unnecessarily restrictive.

Collaborator Author:
To make sure I understand your comment: you want to add support for images as plain Tensors and keep Image and PIL Image?

Collaborator:
Either that, or remove support for PIL here. We should either support all image types or only the "new" one, like we do in CutMix and MixUp (not saying this is a good thing, but we should be consistent).

Collaborator Author:
Is the reason CutMix and MixUp do not support PIL a lack of implementation? Here in RandomIoUCrop we can support everything, so I would add support for images as plain torch.Tensor and keep PIL.

Contributor:
I think we should support all types. Having part of the API support them and part not is weird and will hinder adoption.

and has_any(sample, features.Label, features.OneHotLabel)
):
raise TypeError(
f"{type(self).__name__}() requires input sample to contain Images or PIL Images, "
"BoundingBoxes and Labels or OneHotLabels. Sample can also contain Segmentation Masks."
)
return super().forward(*inputs)


class ScaleJitter(Transform):
def __init__(
self,
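(Not part of the diff: a minimal usage sketch of the RandomIoUCrop transform added above. It mirrors the fixtures in the new tests; since this is a prototype API, the constructors and the accepted sample layout may still change.)

    # Sketch: crop an (image, boxes, labels) sample with the new transform.
    # Boxes whose centers fall outside the sampled crop are dropped, together
    # with their labels, as exercised in TestRandomIoUCrop.test__transform.
    import torch
    from torchvision.prototype import features, transforms

    image = features.Image(torch.rand(3, 32, 24))
    bboxes = features.BoundingBox(
        torch.tensor([[1, 1, 10, 10], [20, 20, 23, 23]]),
        format="XYXY",
        image_size=(32, 24),
    )
    labels = features.Label(torch.tensor([1, 2]))

    transform = transforms.RandomIoUCrop()
    image_out, bboxes_out, labels_out = transform([image, bboxes, labels])
    assert len(bboxes_out) == len(labels_out)
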
9 changes: 9 additions & 0 deletions torchvision/prototype/transforms/_utils.py
@@ -17,6 +17,15 @@ def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]:
raise TypeError("No image was found in the sample")


def query_bounding_box(sample: Any) -> features.BoundingBox:
flat_sample, _ = tree_flatten(sample)
for i in flat_sample:
if isinstance(i, features.BoundingBox):
return i

raise TypeError("No bounding box was found in the sample")


def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]:
if isinstance(image, features.Image):
channels = image.num_channels
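(Not part of the diff: a small sketch of what the new query_bounding_box helper does. The nested sample below is only an illustration; any structure that tree_flatten can traverse works the same way.)

    # Sketch: query_bounding_box returns the first BoundingBox found in a
    # (possibly nested) sample and raises TypeError if there is none.
    import torch
    from torchvision.prototype import features
    from torchvision.prototype.transforms._utils import query_bounding_box

    bboxes = features.BoundingBox(torch.tensor([[1, 1, 10, 10]]), format="XYXY", image_size=(24, 32))
    sample = {"image": torch.rand(3, 24, 32), "target": {"boxes": bboxes, "label": 3}}

    assert query_bounding_box(sample) is bboxes
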
1 change: 1 addition & 0 deletions torchvision/prototype/transforms/functional/__init__.py
@@ -1,5 +1,6 @@
from torchvision.transforms import InterpolationMode # usort: skip
from ._meta import (
clamp_bounding_box,
convert_bounding_box_format,
convert_color_space_image_tensor,
convert_color_space_image_pil,
9 changes: 9 additions & 0 deletions torchvision/prototype/transforms/functional/_meta.py
@@ -61,6 +61,15 @@ def convert_bounding_box_format(
return bounding_box


def clamp_bounding_box(
bounding_box: torch.Tensor, format: BoundingBoxFormat, image_size: Tuple[int, int]
) -> torch.Tensor:
xyxy_boxes = convert_bounding_box_format(bounding_box, format, BoundingBoxFormat.XYXY)
xyxy_boxes[..., 0::2].clamp_(min=0, max=image_size[1])
xyxy_boxes[..., 1::2].clamp_(min=0, max=image_size[0])
return convert_bounding_box_format(xyxy_boxes, BoundingBoxFormat.XYXY, format, copy=False)


def _split_alpha(image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
return image[..., :-1, :, :], image[..., -1:, :, :]

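(Not part of the diff: a numeric sketch of the new clamp_bounding_box helper. With image_size given as (height, width), x coordinates are clamped to [0, width] and y coordinates to [0, height].)

    # Sketch: clamp an out-of-bounds XYXY box to a 24 x 32 (H x W) image.
    import torch
    from torchvision.prototype import features
    from torchvision.prototype.transforms.functional import clamp_bounding_box

    boxes = torch.tensor([[-5.0, 2.0, 40.0, 30.0]])
    clamped = clamp_bounding_box(boxes, features.BoundingBoxFormat.XYXY, image_size=(24, 32))
    # clamped is tensor([[ 0.,  2., 32., 24.]])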