[fbsync] [proto] Ported SimpleCopyPaste transform (#6451)
Summary:
* WIP

* [proto] Added SimpleCopyPaste transform

* Refactored and cleaned the implementation and added tests

* Fixing code

* Fixed code formatting issue

* Minor updates

* Fixed merge issue

Reviewed By: datumbox

Differential Revision: D39013674

fbshipit-source-id: 212f54c4fd9cbbc72011dc331de61106842d0e99

Co-authored-by: Philip Meier <github.pmeier@posteo.de>
2 people authored and facebook-github-bot committed Aug 25, 2022
1 parent 443016d commit 58ef09e
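
For orientation, here is a minimal usage sketch of the new transform. The make_sample helper is hypothetical and mirrors the sample construction in the tests below; the call pattern relies on forward() flattening arbitrarily nested samples via tree_flatten:

    import torch
    from torchvision.prototype import features, transforms

    def make_sample(fill, lo, hi):
        # one 3x32x32 image containing a single square instance
        image = features.Image(fill * torch.ones(3, 32, 32))
        masks = torch.zeros(1, 32, 32)
        masks[0, lo:hi, lo:hi] = 1
        target = {
            "boxes": features.BoundingBox(
                torch.tensor([[float(lo), float(lo), float(hi), float(hi)]]), format="XYXY", image_size=(32, 32)
            ),
            "masks": features.SegmentationMask(masks),
            "labels": features.Label(torch.tensor([1])),
        }
        return image, target

    # a batch of two samples; each image receives paste content from the other one
    batch = [make_sample(2.0, 2, 8), make_sample(10.0, 20, 30)]
    transform = transforms.SimpleCopyPaste(blending=False)
    out = transform(batch)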
Showing 3 changed files with 282 additions and 2 deletions.
91 changes: 91 additions & 0 deletions test/test_prototype_transforms.py
@@ -1313,6 +1313,97 @@ def test__transform(self, mocker):
        mock.assert_called_once_with(inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel)


class TestSimpleCopyPaste:
    def create_fake_image(self, mocker, image_type):
        if image_type == PIL.Image.Image:
            return PIL.Image.new("RGB", (32, 32), 123)
        return mocker.MagicMock(spec=image_type)

    def test__extract_image_targets_assertion(self, mocker):
        transform = transforms.SimpleCopyPaste()

        flat_sample = [
            # images, batch size = 2: the second image is intentionally missing
            self.create_fake_image(mocker, features.Image),
            # labels, bboxes, masks
            mocker.MagicMock(spec=features.Label),
            mocker.MagicMock(spec=features.BoundingBox),
            mocker.MagicMock(spec=features.SegmentationMask),
            # bboxes, masks (the label is intentionally missing)
            mocker.MagicMock(spec=features.BoundingBox),
            mocker.MagicMock(spec=features.SegmentationMask),
        ]

        with pytest.raises(TypeError, match="requires input sample to contain equal-sized list of Images"):
            transform._extract_image_targets(flat_sample)

    @pytest.mark.parametrize("image_type", [features.Image, PIL.Image.Image, torch.Tensor])
    def test__extract_image_targets(self, image_type, mocker):
        transform = transforms.SimpleCopyPaste()

        flat_sample = [
            # images, batch size = 2
            self.create_fake_image(mocker, image_type),
            self.create_fake_image(mocker, image_type),
            # labels, bboxes, masks
            mocker.MagicMock(spec=features.Label),
            mocker.MagicMock(spec=features.BoundingBox),
            mocker.MagicMock(spec=features.SegmentationMask),
            # labels, bboxes, masks
            mocker.MagicMock(spec=features.Label),
            mocker.MagicMock(spec=features.BoundingBox),
            mocker.MagicMock(spec=features.SegmentationMask),
        ]

        images, targets = transform._extract_image_targets(flat_sample)

        assert len(images) == len(targets) == 2
        if image_type == PIL.Image.Image:
            torch.testing.assert_close(images[0], pil_to_tensor(flat_sample[0]))
            torch.testing.assert_close(images[1], pil_to_tensor(flat_sample[1]))
        else:
            assert images[0] == flat_sample[0]
            assert images[1] == flat_sample[1]

    def test__copy_paste(self):
        image = 2 * torch.ones(3, 32, 32)
        masks = torch.zeros(2, 32, 32)
        masks[0, 3:9, 2:8] = 1
        masks[1, 20:30, 20:30] = 1
        target = {
            "boxes": features.BoundingBox(
                torch.tensor([[2.0, 3.0, 8.0, 9.0], [20.0, 20.0, 30.0, 30.0]]), format="XYXY", image_size=(32, 32)
            ),
            "masks": features.SegmentationMask(masks),
            "labels": features.Label(torch.tensor([1, 2])),
        }

        paste_image = 10 * torch.ones(3, 32, 32)
        paste_masks = torch.zeros(2, 32, 32)
        paste_masks[0, 13:19, 12:18] = 1
        paste_masks[1, 15:19, 1:8] = 1
        paste_target = {
            "boxes": features.BoundingBox(
                torch.tensor([[12.0, 13.0, 19.0, 18.0], [1.0, 15.0, 8.0, 19.0]]), format="XYXY", image_size=(32, 32)
            ),
            "masks": features.SegmentationMask(paste_masks),
            "labels": features.Label(torch.tensor([3, 4])),
        }

        transform = transforms.SimpleCopyPaste()
        random_selection = torch.tensor([0, 1])
        output_image, output_target = transform._copy_paste(image, target, paste_image, paste_target, random_selection)

        assert output_image.unique().tolist() == [2, 10]
        assert output_target["boxes"].shape == (4, 4)
        torch.testing.assert_close(output_target["boxes"][:2, :], target["boxes"])
        torch.testing.assert_close(output_target["boxes"][2:, :], paste_target["boxes"])
        torch.testing.assert_close(output_target["labels"], features.Label(torch.tensor([1, 2, 3, 4])))
        assert output_target["masks"].shape == (4, 32, 32)
        torch.testing.assert_close(output_target["masks"][:2, :], target["masks"])
        torch.testing.assert_close(output_target["masks"][2:, :], paste_target["masks"])


class TestFixedSizeCrop:
    def test__get_params(self, mocker):
        crop_size = (7, 7)
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/__init__.py
@@ -2,7 +2,7 @@

from ._transform import Transform # usort: skip

-from ._augment import RandomCutmix, RandomErasing, RandomMixup
+from ._augment import RandomCutmix, RandomErasing, RandomMixup, SimpleCopyPaste
from ._auto_augment import AugMix, AutoAugment, AutoAugmentPolicy, RandAugment, TrivialAugmentWide
from ._color import (
    ColorJitter,
191 changes: 190 additions & 1 deletion torchvision/prototype/transforms/_augment.py
@@ -1,11 +1,16 @@
import math
import numbers
import warnings
-from typing import Any, Dict, Tuple
+from typing import Any, Dict, List, Tuple

import PIL.Image
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.ops import masks_to_boxes
from torchvision.prototype import features

from torchvision.prototype.transforms import functional as F
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor

from ._transform import _RandomApplyTransform
from ._utils import has_any, is_simple_tensor, query_chw
@@ -178,3 +183,187 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
            return self._mixup_onehotlabel(inpt, lam_adjusted)
        else:
            return inpt


class SimpleCopyPaste(_RandomApplyTransform):
    def __init__(
        self,
        p: float = 0.5,
        blending: bool = True,
        resize_interpolation: InterpolationMode = F.InterpolationMode.BILINEAR,
    ) -> None:
        super().__init__(p=p)
        self.resize_interpolation = resize_interpolation
        self.blending = blending

    def _copy_paste(
        self,
        image: Any,
        target: Dict[str, Any],
        paste_image: Any,
        paste_target: Dict[str, Any],
        random_selection: torch.Tensor,
        blending: bool = True,
        resize_interpolation: F.InterpolationMode = F.InterpolationMode.BILINEAR,
    ) -> Tuple[Any, Dict[str, Any]]:

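        # Keep only the randomly selected paste instances; new_like preserves the feature class and its metadata.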
        paste_masks = paste_target["masks"].new_like(paste_target["masks"], paste_target["masks"][random_selection])
        paste_boxes = paste_target["boxes"].new_like(paste_target["boxes"], paste_target["boxes"][random_selection])
        paste_labels = paste_target["labels"].new_like(paste_target["labels"], paste_target["labels"][random_selection])

        masks = target["masks"]

        # Resize the source and paste data if they have different sizes. This diverges from the TF
        # implementation, where the algorithm operates on equal-sized data
        # (for example, coming from LSJ data augmentations).
        size1 = image.shape[-2:]
        size2 = paste_image.shape[-2:]
        if size1 != size2:
            paste_image = F.resize(paste_image, size=size1, interpolation=resize_interpolation)
            paste_masks = F.resize(paste_masks, size=size1)
            paste_boxes = F.resize(paste_boxes, size=size1)

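        # Union of the selected paste masks: True wherever any pasted instance covers a pixel.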
        paste_alpha_mask = paste_masks.sum(dim=0) > 0

        if blending:
            paste_alpha_mask = F.gaussian_blur(paste_alpha_mask.unsqueeze(0), kernel_size=[5, 5], sigma=[2.0])

        # Copy-paste images:
        image = (image * (~paste_alpha_mask)) + (paste_image * paste_alpha_mask)

        # Copy-paste masks:
        masks = masks * (~paste_alpha_mask)
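        # Source instances that are now fully occluded by the pasted content get dropped below.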
        non_all_zero_masks = masks.sum((-1, -2)) > 0
        masks = masks[non_all_zero_masks]

        # Do a shallow copy of the target dict
        out_target = {k: v for k, v in target.items()}

        out_target["masks"] = torch.cat([masks, paste_masks])

        # Copy-paste boxes and labels
        bbox_format = target["boxes"].format
        xyxy_boxes = masks_to_boxes(masks)
        # TODO: masks_to_boxes produces bboxes with x2y2 inclusive, but x2y2 should be exclusive,
        # so we add +1 to x2y2. We need to investigate this further.
        xyxy_boxes[:, 2:] += 1
        boxes = F.convert_bounding_box_format(
            xyxy_boxes, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format, copy=False
        )
        out_target["boxes"] = torch.cat([boxes, paste_boxes])

        labels = target["labels"][non_all_zero_masks]
        out_target["labels"] = torch.cat([labels, paste_labels])

        # Check for degenerate boxes and remove them
        boxes = F.convert_bounding_box_format(
            out_target["boxes"], old_format=bbox_format, new_format=features.BoundingBoxFormat.XYXY
        )
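        # A box is degenerate if x2 <= x1 or y2 <= y1 after conversion to XYXY.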
        degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
        if degenerate_boxes.any():
            valid_targets = ~degenerate_boxes.any(dim=1)

            out_target["boxes"] = boxes[valid_targets]
            out_target["masks"] = out_target["masks"][valid_targets]
            out_target["labels"] = out_target["labels"][valid_targets]

        return image, out_target

    def _extract_image_targets(self, flat_sample: List[Any]) -> Tuple[List[Any], List[Dict[str, Any]]]:
        # Fetch all images, bboxes, masks and labels from the unstructured input,
        # i.e. List[image], List[BoundingBox], List[SegmentationMask], List[Label]
        images, bboxes, masks, labels = [], [], [], []
        for obj in flat_sample:
            if isinstance(obj, features.Image) or is_simple_tensor(obj):
                images.append(obj)
            elif isinstance(obj, PIL.Image.Image):
                images.append(pil_to_tensor(obj))
            elif isinstance(obj, features.BoundingBox):
                bboxes.append(obj)
            elif isinstance(obj, features.SegmentationMask):
                masks.append(obj)
            elif isinstance(obj, (features.Label, features.OneHotLabel)):
                labels.append(obj)

        if not (len(images) == len(bboxes) == len(masks) == len(labels)):
            raise TypeError(
                f"{type(self).__name__}() requires input sample to contain equal-sized list of Images, "
                "BoundingBoxes, Segmentation Masks and Labels or OneHotLabels."
            )

        targets = []
        for bbox, mask, label in zip(bboxes, masks, labels):
            targets.append({"boxes": bbox, "masks": mask, "labels": label})

        return images, targets

    def _insert_outputs(
        self, flat_sample: List[Any], output_images: List[Any], output_targets: List[Dict[str, Any]]
    ) -> None:
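        # c0..c3 count how many output images, boxes, masks and labels have been written back so far.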
        c0, c1, c2, c3 = 0, 0, 0, 0
        for i, obj in enumerate(flat_sample):
            if isinstance(obj, features.Image):
                flat_sample[i] = features.Image.new_like(obj, output_images[c0])
                c0 += 1
            elif isinstance(obj, PIL.Image.Image):
                flat_sample[i] = F.to_image_pil(output_images[c0])
                c0 += 1
            elif is_simple_tensor(obj):
                flat_sample[i] = output_images[c0]
                c0 += 1
            elif isinstance(obj, features.BoundingBox):
                flat_sample[i] = features.BoundingBox.new_like(obj, output_targets[c1]["boxes"])
                c1 += 1
            elif isinstance(obj, features.SegmentationMask):
                flat_sample[i] = features.SegmentationMask.new_like(obj, output_targets[c2]["masks"])
                c2 += 1
            elif isinstance(obj, (features.Label, features.OneHotLabel)):
                flat_sample[i] = obj.new_like(obj, output_targets[c3]["labels"])  # type: ignore[arg-type]
                c3 += 1
    def forward(self, *inputs: Any) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]

        flat_sample, spec = tree_flatten(sample)

        images, targets = self._extract_image_targets(flat_sample)

        # images = [t1, t2, ..., tN]
        # Let's define paste_images as the input images rolled by one:
        # paste_images = [tN, t1, ..., tN-1],
        # so that every image receives paste content from a different image in the batch.
        # FYI: in TF they mix data on the dataset level.
        images_rolled = images[-1:] + images[:-1]
        targets_rolled = targets[-1:] + targets[:-1]

        output_images, output_targets = [], []

        for image, target, paste_image, paste_target in zip(images, targets, images_rolled, targets_rolled):

            # Random paste targets selection:
            num_masks = len(paste_target["masks"])

            if num_masks < 1:
                # Such a degenerate case with num_masks=0 can happen with LSJ.
                # Let's just return (image, target).
                output_image, output_target = image, target
            else:
                random_selection = torch.randint(0, num_masks, (num_masks,), device=paste_image.device)
                random_selection = torch.unique(random_selection)
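                # Sampling with replacement and then deduplicating yields a random, non-empty subset of the paste instances.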

                output_image, output_target = self._copy_paste(
                    image,
                    target,
                    paste_image,
                    paste_target,
                    random_selection=random_selection,
                    blending=self.blending,
                    resize_interpolation=self.resize_interpolation,
                )
            output_images.append(output_image)
            output_targets.append(output_target)

        # Insert the updated images and targets back into the input flat_sample
        self._insert_outputs(flat_sample, output_images, output_targets)

        return tree_unflatten(flat_sample, spec)
