
[proto] Ported SimpleCopyPaste transform #6451

Merged 13 commits on Aug 23, 2022
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/__init__.py
@@ -2,7 +2,7 @@

from ._transform import Transform # usort: skip

from ._augment import RandomCutmix, RandomErasing, RandomMixup
from ._augment import RandomCutmix, RandomErasing, RandomMixup, SimpleCopyPaste
from ._auto_augment import AugMix, AutoAugment, AutoAugmentPolicy, RandAugment, TrivialAugmentWide
from ._color import (
ColorJitter,
193 changes: 192 additions & 1 deletion torchvision/prototype/transforms/_augment.py
@@ -1,12 +1,16 @@
import math
import numbers
import warnings
from typing import Any, Dict, Tuple
from typing import Any, Dict, List, Tuple, Union

import PIL.Image
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.ops import masks_to_boxes
from torchvision.prototype import features

from torchvision.prototype.transforms import functional as F
from torchvision.transforms.functional import InterpolationMode

from ._transform import _RandomApplyTransform
from ._utils import get_image_dimensions, has_any, is_simple_tensor, query_image
@@ -180,3 +184,190 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return self._mixup_onehotlabel(inpt, lam_adjusted)
else:
return inpt


class SimpleCopyPaste(_RandomApplyTransform):
def __init__(
self,
p: float = 0.5,
blending: bool = True,
resize_interpolation: InterpolationMode = F.InterpolationMode.BILINEAR,
) -> None:
super().__init__(p=p)
self.resize_interpolation = resize_interpolation
self.blending = blending

def _copy_paste(
self,
image: Any,
target: Dict[str, Any],
paste_image: Any,
paste_target: Dict[str, Any],
blending: bool = True,
resize_interpolation: F.InterpolationMode = F.InterpolationMode.BILINEAR,
) -> Tuple[Any, Dict[str, Any]]:

# Random paste targets selection:
num_masks = len(paste_target["masks"])

if num_masks < 1:
# Such a degenerate case with num_masks=0 can happen with LSJ
# Let's just return (image, target)
return image, target

# We have to appease TorchScript by explicitly specifying dtype as torch.long
random_selection = torch.randint(0, num_masks, (num_masks,), device=paste_image.device)
random_selection = torch.unique(random_selection).to(torch.long)

paste_masks = paste_target["masks"][random_selection]
paste_boxes = paste_target["boxes"][random_selection]
paste_labels = paste_target["labels"][random_selection]
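# For illustration (hypothetical values): torch.randint may repeat indices and torch.unique
# de-duplicates (and sorts) them, so a random non-empty subset of the paste objects is kept, e.g.
#   torch.randint(0, 4, (4,))           ->  tensor([2, 0, 2, 3])
#   torch.unique(tensor([2, 0, 2, 3]))  ->  tensor([0, 2, 3])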

masks = target["masks"]

# We resize source and paste data if they have different sizes
# This is a departure from the TF implementation, introduced here because
# the original algorithm works on equal-sized data
# (for example, coming from LSJ data augmentations)
size1 = image.shape[-2:]
size2 = paste_image.shape[-2:]
if size1 != size2:
paste_image = F.resize(paste_image, size1, interpolation=resize_interpolation)
paste_masks = F.resize(paste_masks, size1, interpolation=F.InterpolationMode.NEAREST)
# resize bboxes:
ratios = torch.tensor((size1[1] / size2[1], size1[0] / size2[0]), device=paste_boxes.device)
paste_boxes = paste_boxes.view(-1, 2, 2).mul(ratios).view(paste_boxes.shape)
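# A rough sketch of the box rescaling above (hypothetical sizes): viewing the (N, 4) xyxy boxes
# as (N, 2, 2) lets the (w_ratio, h_ratio) pair broadcast over each (x, y) corner, e.g.
#   size1, size2 = (100, 200), (50, 100)            # (H, W) of source and paste image
#   ratios = torch.tensor((200 / 100, 100 / 50))    # (w_ratio, h_ratio) = (2.0, 2.0)
#   torch.tensor([[10., 5., 30., 20.]]).view(-1, 2, 2).mul(ratios).view(-1, 4)
#   -> tensor([[20., 10., 60., 40.]])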

paste_alpha_mask = paste_masks.sum(dim=0) > 0

if blending:
paste_alpha_mask = F.gaussian_blur(
paste_alpha_mask.unsqueeze(0),
kernel_size=(5, 5),
sigma=[2.0],
)

# Copy-paste images:
image = (image * (~paste_alpha_mask)) + (paste_image * paste_alpha_mask)

# Copy-paste masks:
masks = masks * (~paste_alpha_mask)
non_all_zero_masks = masks.sum((-1, -2)) > 0
masks = masks[non_all_zero_masks]
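# Intuition (illustration only): the pasted objects occlude the source image, so the source masks
# are zeroed wherever the paste alpha mask is set; any source mask left with no visible pixels
# (sum over H and W == 0) is dropped, and its box and label are dropped with it below.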

# Do a shallow copy of the target dict
out_target = {k: v for k, v in target.items()}

out_target["masks"] = torch.cat([masks, paste_masks])

# Copy-paste boxes and labels
boxes = masks_to_boxes(masks)
out_target["boxes"] = torch.cat([boxes, paste_boxes])

labels = target["labels"][non_all_zero_masks]
out_target["labels"] = torch.cat([labels, paste_labels])

# Update additional optional keys: area and iscrowd, if they exist
if "area" in target:
out_target["area"] = out_target["masks"].sum((-1, -2)).to(torch.float32)

if "iscrowd" in target and "iscrowd" in paste_target:
# target['iscrowd'] size can differ from the mask size (non_all_zero_masks)
# For example, if a previous transform geometrically modifies masks/boxes/labels but
# does not update "iscrowd"
if len(target["iscrowd"]) == len(non_all_zero_masks):
iscrowd = target["iscrowd"][non_all_zero_masks]
paste_iscrowd = paste_target["iscrowd"][random_selection]
out_target["iscrowd"] = torch.cat([iscrowd, paste_iscrowd])

# Check for degenerate boxes and remove them
boxes = out_target["boxes"]
degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
if degenerate_boxes.any():
valid_targets = ~degenerate_boxes.any(dim=1)

out_target["boxes"] = boxes[valid_targets]
out_target["masks"] = out_target["masks"][valid_targets]
out_target["labels"] = out_target["labels"][valid_targets]

if "area" in out_target:
out_target["area"] = out_target["area"][valid_targets]
if "iscrowd" in out_target and len(out_target["iscrowd"]) == len(valid_targets):
out_target["iscrowd"] = out_target["iscrowd"][valid_targets]

return image, out_target
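# Illustration of the degenerate-box filtering above (hypothetical values): a box is degenerate
# when x2 <= x1 or y2 <= y1, e.g. when occlusion shrinks a mask to a single row or column.
#   boxes = torch.tensor([[10., 10., 20., 20.],    # valid
#                         [ 5.,  8.,  5., 12.]])   # zero width: x2 <= x1
#   (boxes[:, 2:] <= boxes[:, :2]).any(dim=1)      -> tensor([False, True])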

def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]

flat_sample, spec = tree_flatten(sample)
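# Rough illustration (hypothetical sample structure): tree_flatten gives a flat list of leaves
# plus a spec that tree_unflatten uses at the end of forward() to restore the original nesting:
#   sample = (img, {"boxes": bbox, "masks": mask, "labels": label})
#   flat_sample, spec = tree_flatten(sample)     # flat list containing the four leaves
#   flat_sample[0] = transformed_img             # modify leaves in place
#   sample = tree_unflatten(flat_sample, spec)   # same nesting as before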

# fetch all images, bboxes, masks and labels from the unstructured input
# as List[image], List[BoundingBox], List[SegmentationMask], List[Label]
images, bboxes, masks, labels = [], [], [], []
for obj in flat_sample:
if isinstance(obj, (PIL.Image.Image, features.Image)) or is_simple_tensor(obj):
images.append(F.to_image_tensor(obj))
elif isinstance(obj, features.BoundingBox):
bboxes.append(obj)
elif isinstance(obj, features.SegmentationMask):
masks.append(obj)
elif isinstance(obj, (features.Label, features.OneHotLabel)):
labels.append(obj)
Collaborator Author:

As our input is a non-collated batch of datapoints, we need to fetch the images, bboxes, masks, and labels, then transform them, and finally put them back into their places.
I'm not a fan of this code, but it does what we want. Maybe we could specify the input structure a bit more and thus write forward() in a more elegant way.


if not (len(images) == len(bboxes) == len(masks) == len(labels)):
raise TypeError(
f"{type(self).__name__}() requires input sample to contain equal-sized list of Images, "
"BoundingBoxes, Segmentation Masks and Labels or OneHotLabels."
)

targets = []
for bbox, mask, label in zip(bboxes, masks, labels):
targets.append({"boxes": bbox, "masks": mask, "labels": label})

# images = [t1, t2, ..., tN]
# Let's define paste_images as the input images rolled by one position:
# paste_images = [tN, t1, ..., tN-1]
# FYI: in TF they mix data on the dataset level
images_rolled = images[-1:] + images[:-1]
targets_rolled = targets[-1:] + targets[:-1]
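# For example (hypothetical batch of three): images = [t1, t2, t3] gives
# images[-1:] + images[:-1] = [t3, t1, t2], so every sample is paired with a different
# sample of the same batch as its paste source.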

output_images, output_targets = [], []

for image, target, paste_image, paste_target in zip(images, targets, images_rolled, targets_rolled):
output_image, output_data = self._copy_paste(
image,
target,
paste_image,
paste_target,
blending=self.blending,
resize_interpolation=self.resize_interpolation,
)
output_images.append(output_image)
output_targets.append(output_data)

# Insert updated images and targets into input flat_sample
c0, c1, c2, c3 = 0, 0, 0, 0
Collaborator:

Can't we save these indices the first time we iterate over the flat sample? If yes, maybe we can get away with only doing flat_sample[i] = output[i] and only looping over the output.

Collaborator Author:

Yes, we can do that. I just found that passing indices everywhere would be a bit bulky...

for i, obj in enumerate(flat_sample):
if isinstance(obj, features.Image):
flat_sample[i] = features.Image.new_like(obj, output_images[c0])
c0 += 1
elif isinstance(obj, PIL.Image.Image):
flat_sample[i] = F.to_image_pil(output_images[c0])
c0 += 1
elif is_simple_tensor(obj):
flat_sample[i] = output_images[c0]
c0 += 1
elif isinstance(obj, features.BoundingBox):
flat_sample[i] = features.BoundingBox.new_like(obj, output_targets[c1]["boxes"])
c1 += 1
elif isinstance(obj, features.SegmentationMask):
flat_sample[i] = features.SegmentationMask.new_like(obj, output_targets[c2]["masks"])
c2 += 1
elif isinstance(obj, (features.Label, features.OneHotLabel)):
flat_sample[i] = obj.new_like(obj, output_targets[c3]["labels"]) # type: ignore[arg-type]
c3 += 1

return tree_unflatten(flat_sample, spec)
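
For context, here is a rough usage sketch of the new transform on a two-sample detection batch. It is illustrative only: the constructors of the prototype datapoints (features.Image, features.BoundingBox with its format/image_size arguments, features.SegmentationMask, features.Label) are assumptions about the prototype API at the time of this PR and are not part of the diff.

import torch
from torchvision.prototype import features, transforms

def make_sample(h, w, num_objects):
    # Hypothetical helper: one (image, target) pair with random content
    image = features.Image(torch.randint(0, 256, (3, h, w), dtype=torch.uint8))
    masks = features.SegmentationMask(torch.randint(0, 2, (num_objects, h, w), dtype=torch.uint8))
    boxes = features.BoundingBox(
        torch.tensor([[10.0, 10.0, 50.0, 50.0]] * num_objects),
        format=features.BoundingBoxFormat.XYXY,
        image_size=(h, w),
    )
    labels = features.Label(torch.randint(0, 91, (num_objects,)))
    return image, {"boxes": boxes, "masks": masks, "labels": labels}

transform = transforms.SimpleCopyPaste(p=1.0)  # p=1.0 so the copy-paste is always applied
sample = (make_sample(224, 224, 2), make_sample(224, 224, 3))
out = transform(*sample)  # same nested structure back, with objects pasted across the two samples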