[proto] Ported SimpleCopyPaste transform #6451

Merged: 13 commits, Aug 23, 2022
91 changes: 91 additions & 0 deletions test/test_prototype_transforms.py
@@ -1313,6 +1313,97 @@ def test__transform(self, mocker):
mock.assert_called_once_with(inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel)


class TestSimpleCopyPaste:
def create_fake_image(self, mocker, image_type):
if image_type == PIL.Image.Image:
return PIL.Image.new("RGB", (32, 32), 123)
return mocker.MagicMock(spec=image_type)

def test__extract_image_targets_assertion(self, mocker):
transform = transforms.SimpleCopyPaste()

flat_sample = [
# only one image, but two sets of targets below -> triggers the size mismatch
self.create_fake_image(mocker, features.Image),
# labels, bboxes, masks
mocker.MagicMock(spec=features.Label),
mocker.MagicMock(spec=features.BoundingBox),
mocker.MagicMock(spec=features.SegmentationMask),
# bboxes, masks (label intentionally missing)
mocker.MagicMock(spec=features.BoundingBox),
mocker.MagicMock(spec=features.SegmentationMask),
]

with pytest.raises(TypeError, match="requires input sample to contain equal-sized list of Images"):
transform._extract_image_targets(flat_sample)

@pytest.mark.parametrize("image_type", [features.Image, PIL.Image.Image, torch.Tensor])
def test__extract_image_targets(self, image_type, mocker):
transform = transforms.SimpleCopyPaste()

flat_sample = [
# images, batch size = 2
self.create_fake_image(mocker, image_type),
self.create_fake_image(mocker, image_type),
# labels, bboxes, masks
mocker.MagicMock(spec=features.Label),
mocker.MagicMock(spec=features.BoundingBox),
mocker.MagicMock(spec=features.SegmentationMask),
# labels, bboxes, masks
mocker.MagicMock(spec=features.Label),
mocker.MagicMock(spec=features.BoundingBox),
mocker.MagicMock(spec=features.SegmentationMask),
]

images, targets = transform._extract_image_targets(flat_sample)

assert len(images) == len(targets) == 2
if image_type == PIL.Image.Image:
torch.testing.assert_close(images[0], pil_to_tensor(flat_sample[0]))
torch.testing.assert_close(images[1], pil_to_tensor(flat_sample[1]))
else:
assert images[0] == flat_sample[0]
assert images[1] == flat_sample[1]

def test__copy_paste(self):
image = 2 * torch.ones(3, 32, 32)
masks = torch.zeros(2, 32, 32)
masks[0, 3:9, 2:8] = 1
masks[1, 20:30, 20:30] = 1
target = {
"boxes": features.BoundingBox(
torch.tensor([[2.0, 3.0, 8.0, 9.0], [20.0, 20.0, 30.0, 30.0]]), format="XYXY", image_size=(32, 32)
),
"masks": features.SegmentationMask(masks),
"labels": features.Label(torch.tensor([1, 2])),
}

paste_image = 10 * torch.ones(3, 32, 32)
paste_masks = torch.zeros(2, 32, 32)
paste_masks[0, 13:19, 12:18] = 1
paste_masks[1, 15:19, 1:8] = 1
paste_target = {
"boxes": features.BoundingBox(
torch.tensor([[12.0, 13.0, 19.0, 18.0], [1.0, 15.0, 8.0, 19.0]]), format="XYXY", image_size=(32, 32)
),
"masks": features.SegmentationMask(paste_masks),
"labels": features.Label(torch.tensor([3, 4])),
}

transform = transforms.SimpleCopyPaste()
random_selection = torch.tensor([0, 1])
output_image, output_target = transform._copy_paste(image, target, paste_image, paste_target, random_selection)

assert output_image.unique().tolist() == [2, 10]
assert output_target["boxes"].shape == (4, 4)
torch.testing.assert_close(output_target["boxes"][:2, :], target["boxes"])
torch.testing.assert_close(output_target["boxes"][2:, :], paste_target["boxes"])
torch.testing.assert_close(output_target["labels"], features.Label(torch.tensor([1, 2, 3, 4])))
assert output_target["masks"].shape == (4, 32, 32)
torch.testing.assert_close(output_target["masks"][:2, :], target["masks"])
torch.testing.assert_close(output_target["masks"][2:, :], paste_target["masks"])


class TestFixedSizeCrop:
def test__get_params(self, mocker):
crop_size = (7, 7)
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/__init__.py
@@ -2,7 +2,7 @@

from ._transform import Transform # usort: skip

-from ._augment import RandomCutmix, RandomErasing, RandomMixup
+from ._augment import RandomCutmix, RandomErasing, RandomMixup, SimpleCopyPaste
from ._auto_augment import AugMix, AutoAugment, AutoAugmentPolicy, RandAugment, TrivialAugmentWide
from ._color import (
ColorJitter,
191 changes: 190 additions & 1 deletion torchvision/prototype/transforms/_augment.py
@@ -1,11 +1,16 @@
import math
import numbers
import warnings
-from typing import Any, Dict, Tuple
+from typing import Any, Dict, List, Tuple

import PIL.Image
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.ops import masks_to_boxes
from torchvision.prototype import features

from torchvision.prototype.transforms import functional as F
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor

from ._transform import _RandomApplyTransform
from ._utils import has_any, is_simple_tensor, query_chw
@@ -178,3 +183,187 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return self._mixup_onehotlabel(inpt, lam_adjusted)
else:
return inpt


class SimpleCopyPaste(_RandomApplyTransform):
def __init__(
self,
p: float = 0.5,
blending: bool = True,
resize_interpolation: InterpolationMode = F.InterpolationMode.BILINEAR,
) -> None:
super().__init__(p=p)
self.resize_interpolation = resize_interpolation
self.blending = blending

def _copy_paste(
self,
image: Any,
target: Dict[str, Any],
paste_image: Any,
paste_target: Dict[str, Any],
random_selection: torch.Tensor,
blending: bool = True,
resize_interpolation: F.InterpolationMode = F.InterpolationMode.BILINEAR,
) -> Tuple[Any, Dict[str, Any]]:

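# restrict the paste data to the randomly selected instances; new_like keeps the feature class and its metadata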
paste_masks = paste_target["masks"].new_like(paste_target["masks"], paste_target["masks"][random_selection])
paste_boxes = paste_target["boxes"].new_like(paste_target["boxes"], paste_target["boxes"][random_selection])
paste_labels = paste_target["labels"].new_like(paste_target["labels"], paste_target["labels"][random_selection])

masks = target["masks"]

# We resize source and paste data if they have different sizes
# This resize step is an addition over the TF implementation, where the
# algorithm operates on equal-sized data
# (for example, coming from LSJ data augmentations)
size1 = image.shape[-2:]
size2 = paste_image.shape[-2:]
if size1 != size2:
paste_image = F.resize(paste_image, size=size1, interpolation=resize_interpolation)
paste_masks = F.resize(paste_masks, size=size1)
paste_boxes = F.resize(paste_boxes, size=size1)

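# the union of the selected paste masks is the region copied onto the source image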
paste_alpha_mask = paste_masks.sum(dim=0) > 0

if blending:
paste_alpha_mask = F.gaussian_blur(paste_alpha_mask.unsqueeze(0), kernel_size=[5, 5], sigma=[2.0])

# Copy-paste images:
image = (image * (~paste_alpha_mask)) + (paste_image * paste_alpha_mask)

# Copy-paste masks:
masks = masks * (~paste_alpha_mask)
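# drop source instances that are now fully occluded by the pasted objects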
non_all_zero_masks = masks.sum((-1, -2)) > 0
masks = masks[non_all_zero_masks]

# Do a shallow copy of the target dict
out_target = {k: v for k, v in target.items()}

out_target["masks"] = torch.cat([masks, paste_masks])

# Copy-paste boxes and labels
bbox_format = target["boxes"].format
xyxy_boxes = masks_to_boxes(masks)
# TODO: masks_to_boxes produces bboxes with x2y2 inclusive, but x2y2 should be exclusive,
# so we need to add +1 to x2y2. We need to investigate that.
xyxy_boxes[:, 2:] += 1
boxes = F.convert_bounding_box_format(
xyxy_boxes, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format, copy=False
)
out_target["boxes"] = torch.cat([boxes, paste_boxes])

labels = target["labels"][non_all_zero_masks]
out_target["labels"] = torch.cat([labels, paste_labels])

# Check for degenerated boxes and remove them
boxes = F.convert_bounding_box_format(
out_target["boxes"], old_format=bbox_format, new_format=features.BoundingBoxFormat.XYXY
)
degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
if degenerate_boxes.any():
valid_targets = ~degenerate_boxes.any(dim=1)

out_target["boxes"] = boxes[valid_targets]
out_target["masks"] = out_target["masks"][valid_targets]
out_target["labels"] = out_target["labels"][valid_targets]

return image, out_target

def _extract_image_targets(self, flat_sample: List[Any]) -> Tuple[List[Any], List[Dict[str, Any]]]:
# fetch all images, bboxes, masks and labels from unstructured input
# with List[image], List[BoundingBox], List[SegmentationMask], List[Label]
images, bboxes, masks, labels = [], [], [], []
for obj in flat_sample:
if isinstance(obj, features.Image) or is_simple_tensor(obj):
images.append(obj)
elif isinstance(obj, PIL.Image.Image):
images.append(pil_to_tensor(obj))
elif isinstance(obj, features.BoundingBox):
bboxes.append(obj)
elif isinstance(obj, features.SegmentationMask):
masks.append(obj)
elif isinstance(obj, (features.Label, features.OneHotLabel)):
labels.append(obj)

if not (len(images) == len(bboxes) == len(masks) == len(labels)):
raise TypeError(
f"{type(self).__name__}() requires input sample to contain equal-sized list of Images, "
"BoundingBoxes, Segmentation Masks and Labels or OneHotLabels."
)

targets = []
for bbox, mask, label in zip(bboxes, masks, labels):
targets.append({"boxes": bbox, "masks": mask, "labels": label})

return images, targets

def _insert_outputs(
self, flat_sample: List[Any], output_images: List[Any], output_targets: List[Dict[str, Any]]
) -> None:
c0, c1, c2, c3 = 0, 0, 0, 0
Review comment (Collaborator):
Can't we save these indices the first time we iterate over the flat sample? If yes, maybe we can get away with only doing flat_sample[i] = output[i] and only looping over the output.

Reply (Collaborator, Author):
Yes, we can do that. I just found that passing indices everywhere would be a bit bulky...
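For illustration, a minimal sketch of the index-based variant being discussed (the helper name and the image_indices bookkeeping are hypothetical, not part of this PR):

# Hypothetical variant: _extract_image_targets would additionally return the
# flat-sample positions of each image, so outputs are written back by position
# instead of re-dispatching on type.
def _insert_output_images(self, flat_sample: List[Any], image_indices: List[int], output_images: List[Any]) -> None:
    for idx, output in zip(image_indices, output_images):
        obj = flat_sample[idx]
        if isinstance(obj, features.Image):
            flat_sample[idx] = features.Image.new_like(obj, output)
        elif isinstance(obj, PIL.Image.Image):
            flat_sample[idx] = F.to_image_pil(output)
        else:  # simple tensor
            flat_sample[idx] = output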

for i, obj in enumerate(flat_sample):
if isinstance(obj, features.Image):
flat_sample[i] = features.Image.new_like(obj, output_images[c0])
c0 += 1
elif isinstance(obj, PIL.Image.Image):
flat_sample[i] = F.to_image_pil(output_images[c0])
c0 += 1
elif is_simple_tensor(obj):
flat_sample[i] = output_images[c0]
c0 += 1
elif isinstance(obj, features.BoundingBox):
flat_sample[i] = features.BoundingBox.new_like(obj, output_targets[c1]["boxes"])
c1 += 1
elif isinstance(obj, features.SegmentationMask):
flat_sample[i] = features.SegmentationMask.new_like(obj, output_targets[c2]["masks"])
c2 += 1
elif isinstance(obj, (features.Label, features.OneHotLabel)):
flat_sample[i] = obj.new_like(obj, output_targets[c3]["labels"]) # type: ignore[arg-type]
c3 += 1

def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]

flat_sample, spec = tree_flatten(sample)

images, targets = self._extract_image_targets(flat_sample)

# images = [t1, t2, ..., tN]
# Let's define paste_images as shifted list of input images
# paste_images = [t2, t3, ..., tN, t1]
# FYI: in TF they mix data on the dataset level
images_rolled = images[-1:] + images[:-1]
targets_rolled = targets[-1:] + targets[:-1]

output_images, output_targets = [], []

for image, target, paste_image, paste_target in zip(images, targets, images_rolled, targets_rolled):

# Random paste targets selection:
num_masks = len(paste_target["masks"])

if num_masks < 1:
# Such a degenerate case with num_masks=0 can happen with LSJ
# Let's just return (image, target)
output_image, output_target = image, target
else:
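# sample paste instances with replacement, then deduplicate to keep a random subset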
random_selection = torch.randint(0, num_masks, (num_masks,), device=paste_image.device)
random_selection = torch.unique(random_selection)

output_image, output_target = self._copy_paste(
image,
target,
paste_image,
paste_target,
random_selection=random_selection,
blending=self.blending,
resize_interpolation=self.resize_interpolation,
)
output_images.append(output_image)
output_targets.append(output_target)

# Insert updated images and targets into input flat_sample
self._insert_outputs(flat_sample, output_images, output_targets)

return tree_unflatten(flat_sample, spec)
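
For context, a minimal usage sketch modeled on the tests above (the batch structure is an assumption; the transform only requires that pytree-flattening the input yields equal-sized lists of images, bounding boxes, segmentation masks, and labels):

import torch
from torchvision.prototype import features, transforms

def make_sample(fill, label):
    # one 32x32 image containing a single square instance
    image = features.Image(fill * torch.ones(3, 32, 32))
    masks = torch.zeros(1, 32, 32)
    masks[0, 4:10, 4:10] = 1
    target = {
        "boxes": features.BoundingBox(
            torch.tensor([[4.0, 4.0, 10.0, 10.0]]), format="XYXY", image_size=(32, 32)
        ),
        "masks": features.SegmentationMask(masks),
        "labels": features.Label(torch.tensor([label])),
    }
    return image, target

# batch of two samples: forward() flattens the structure, pastes instances
# from each sample onto its neighbor, and restores the original structure
batch = [make_sample(2, 1), make_sample(10, 2)]
output = transforms.SimpleCopyPaste()(batch)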