diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 338911d1445..0f485a10550 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -1313,6 +1313,97 @@ def test__transform(self, mocker): mock.assert_called_once_with(inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel) +class TestSimpleCopyPaste: + def create_fake_image(self, mocker, image_type): + if image_type == PIL.Image.Image: + return PIL.Image.new("RGB", (32, 32), 123) + return mocker.MagicMock(spec=image_type) + + def test__extract_image_targets_assertion(self, mocker): + transform = transforms.SimpleCopyPaste() + + flat_sample = [ + # images, batch size = 2 + self.create_fake_image(mocker, features.Image), + # labels, bboxes, masks + mocker.MagicMock(spec=features.Label), + mocker.MagicMock(spec=features.BoundingBox), + mocker.MagicMock(spec=features.SegmentationMask), + # labels, bboxes, masks + mocker.MagicMock(spec=features.BoundingBox), + mocker.MagicMock(spec=features.SegmentationMask), + ] + + with pytest.raises(TypeError, match="requires input sample to contain equal-sized list of Images"): + transform._extract_image_targets(flat_sample) + + @pytest.mark.parametrize("image_type", [features.Image, PIL.Image.Image, torch.Tensor]) + def test__extract_image_targets(self, image_type, mocker): + transform = transforms.SimpleCopyPaste() + + flat_sample = [ + # images, batch size = 2 + self.create_fake_image(mocker, image_type), + self.create_fake_image(mocker, image_type), + # labels, bboxes, masks + mocker.MagicMock(spec=features.Label), + mocker.MagicMock(spec=features.BoundingBox), + mocker.MagicMock(spec=features.SegmentationMask), + # labels, bboxes, masks + mocker.MagicMock(spec=features.Label), + mocker.MagicMock(spec=features.BoundingBox), + mocker.MagicMock(spec=features.SegmentationMask), + ] + + images, targets = transform._extract_image_targets(flat_sample) + + assert len(images) == len(targets) == 2 + if image_type == PIL.Image.Image: + torch.testing.assert_close(images[0], pil_to_tensor(flat_sample[0])) + torch.testing.assert_close(images[1], pil_to_tensor(flat_sample[1])) + else: + assert images[0] == flat_sample[0] + assert images[1] == flat_sample[1] + + def test__copy_paste(self): + image = 2 * torch.ones(3, 32, 32) + masks = torch.zeros(2, 32, 32) + masks[0, 3:9, 2:8] = 1 + masks[1, 20:30, 20:30] = 1 + target = { + "boxes": features.BoundingBox( + torch.tensor([[2.0, 3.0, 8.0, 9.0], [20.0, 20.0, 30.0, 30.0]]), format="XYXY", image_size=(32, 32) + ), + "masks": features.SegmentationMask(masks), + "labels": features.Label(torch.tensor([1, 2])), + } + + paste_image = 10 * torch.ones(3, 32, 32) + paste_masks = torch.zeros(2, 32, 32) + paste_masks[0, 13:19, 12:18] = 1 + paste_masks[1, 15:19, 1:8] = 1 + paste_target = { + "boxes": features.BoundingBox( + torch.tensor([[12.0, 13.0, 19.0, 18.0], [1.0, 15.0, 8.0, 19.0]]), format="XYXY", image_size=(32, 32) + ), + "masks": features.SegmentationMask(paste_masks), + "labels": features.Label(torch.tensor([3, 4])), + } + + transform = transforms.SimpleCopyPaste() + random_selection = torch.tensor([0, 1]) + output_image, output_target = transform._copy_paste(image, target, paste_image, paste_target, random_selection) + + assert output_image.unique().tolist() == [2, 10] + assert output_target["boxes"].shape == (4, 4) + torch.testing.assert_close(output_target["boxes"][:2, :], target["boxes"]) + torch.testing.assert_close(output_target["boxes"][2:, :], paste_target["boxes"]) + torch.testing.assert_close(output_target["labels"], features.Label(torch.tensor([1, 2, 3, 4]))) + assert output_target["masks"].shape == (4, 32, 32) + torch.testing.assert_close(output_target["masks"][:2, :], target["masks"]) + torch.testing.assert_close(output_target["masks"][2:, :], paste_target["masks"]) + + class TestFixedSizeCrop: def test__get_params(self, mocker): crop_size = (7, 7) diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index 5c81436008e..90dfd297da8 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -2,7 +2,7 @@ from ._transform import Transform # usort: skip -from ._augment import RandomCutmix, RandomErasing, RandomMixup +from ._augment import RandomCutmix, RandomErasing, RandomMixup, SimpleCopyPaste from ._auto_augment import AugMix, AutoAugment, AutoAugmentPolicy, RandAugment, TrivialAugmentWide from ._color import ( ColorJitter, diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 3a3e0068169..32697f32257 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -1,11 +1,16 @@ import math import numbers import warnings -from typing import Any, Dict, Tuple +from typing import Any, Dict, List, Tuple +import PIL.Image import torch +from torch.utils._pytree import tree_flatten, tree_unflatten +from torchvision.ops import masks_to_boxes from torchvision.prototype import features + from torchvision.prototype.transforms import functional as F +from torchvision.transforms.functional import InterpolationMode, pil_to_tensor from ._transform import _RandomApplyTransform from ._utils import has_any, is_simple_tensor, query_chw @@ -178,3 +183,187 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._mixup_onehotlabel(inpt, lam_adjusted) else: return inpt + + +class SimpleCopyPaste(_RandomApplyTransform): + def __init__( + self, + p: float = 0.5, + blending: bool = True, + resize_interpolation: InterpolationMode = F.InterpolationMode.BILINEAR, + ) -> None: + super().__init__(p=p) + self.resize_interpolation = resize_interpolation + self.blending = blending + + def _copy_paste( + self, + image: Any, + target: Dict[str, Any], + paste_image: Any, + paste_target: Dict[str, Any], + random_selection: torch.Tensor, + blending: bool = True, + resize_interpolation: F.InterpolationMode = F.InterpolationMode.BILINEAR, + ) -> Tuple[Any, Dict[str, Any]]: + + paste_masks = paste_target["masks"].new_like(paste_target["masks"], paste_target["masks"][random_selection]) + paste_boxes = paste_target["boxes"].new_like(paste_target["boxes"], paste_target["boxes"][random_selection]) + paste_labels = paste_target["labels"].new_like(paste_target["labels"], paste_target["labels"][random_selection]) + + masks = target["masks"] + + # We resize source and paste data if they have different sizes + # This is something different to TF implementation we introduced here as + # originally the algorithm works on equal-sized data + # (for example, coming from LSJ data augmentations) + size1 = image.shape[-2:] + size2 = paste_image.shape[-2:] + if size1 != size2: + paste_image = F.resize(paste_image, size=size1, interpolation=resize_interpolation) + paste_masks = F.resize(paste_masks, size=size1) + paste_boxes = F.resize(paste_boxes, size=size1) + + paste_alpha_mask = paste_masks.sum(dim=0) > 0 + + if blending: + paste_alpha_mask = F.gaussian_blur(paste_alpha_mask.unsqueeze(0), kernel_size=[5, 5], sigma=[2.0]) + + # Copy-paste images: + image = (image * (~paste_alpha_mask)) + (paste_image * paste_alpha_mask) + + # Copy-paste masks: + masks = masks * (~paste_alpha_mask) + non_all_zero_masks = masks.sum((-1, -2)) > 0 + masks = masks[non_all_zero_masks] + + # Do a shallow copy of the target dict + out_target = {k: v for k, v in target.items()} + + out_target["masks"] = torch.cat([masks, paste_masks]) + + # Copy-paste boxes and labels + bbox_format = target["boxes"].format + xyxy_boxes = masks_to_boxes(masks) + # TODO: masks_to_boxes produces bboxes with x2y2 inclusive but x2y2 should be exclusive + # we need to add +1 to x2y2. We need to investigate that. + xyxy_boxes[:, 2:] += 1 + boxes = F.convert_bounding_box_format( + xyxy_boxes, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format, copy=False + ) + out_target["boxes"] = torch.cat([boxes, paste_boxes]) + + labels = target["labels"][non_all_zero_masks] + out_target["labels"] = torch.cat([labels, paste_labels]) + + # Check for degenerated boxes and remove them + boxes = F.convert_bounding_box_format( + out_target["boxes"], old_format=bbox_format, new_format=features.BoundingBoxFormat.XYXY + ) + degenerate_boxes = boxes[:, 2:] <= boxes[:, :2] + if degenerate_boxes.any(): + valid_targets = ~degenerate_boxes.any(dim=1) + + out_target["boxes"] = boxes[valid_targets] + out_target["masks"] = out_target["masks"][valid_targets] + out_target["labels"] = out_target["labels"][valid_targets] + + return image, out_target + + def _extract_image_targets(self, flat_sample: List[Any]) -> Tuple[List[Any], List[Dict[str, Any]]]: + # fetch all images, bboxes, masks and labels from unstructured input + # with List[image], List[BoundingBox], List[SegmentationMask], List[Label] + images, bboxes, masks, labels = [], [], [], [] + for obj in flat_sample: + if isinstance(obj, features.Image) or is_simple_tensor(obj): + images.append(obj) + elif isinstance(obj, PIL.Image.Image): + images.append(pil_to_tensor(obj)) + elif isinstance(obj, features.BoundingBox): + bboxes.append(obj) + elif isinstance(obj, features.SegmentationMask): + masks.append(obj) + elif isinstance(obj, (features.Label, features.OneHotLabel)): + labels.append(obj) + + if not (len(images) == len(bboxes) == len(masks) == len(labels)): + raise TypeError( + f"{type(self).__name__}() requires input sample to contain equal-sized list of Images, " + "BoundingBoxes, Segmentation Masks and Labels or OneHotLabels." + ) + + targets = [] + for bbox, mask, label in zip(bboxes, masks, labels): + targets.append({"boxes": bbox, "masks": mask, "labels": label}) + + return images, targets + + def _insert_outputs( + self, flat_sample: List[Any], output_images: List[Any], output_targets: List[Dict[str, Any]] + ) -> None: + c0, c1, c2, c3 = 0, 0, 0, 0 + for i, obj in enumerate(flat_sample): + if isinstance(obj, features.Image): + flat_sample[i] = features.Image.new_like(obj, output_images[c0]) + c0 += 1 + elif isinstance(obj, PIL.Image.Image): + flat_sample[i] = F.to_image_pil(output_images[c0]) + c0 += 1 + elif is_simple_tensor(obj): + flat_sample[i] = output_images[c0] + c0 += 1 + elif isinstance(obj, features.BoundingBox): + flat_sample[i] = features.BoundingBox.new_like(obj, output_targets[c1]["boxes"]) + c1 += 1 + elif isinstance(obj, features.SegmentationMask): + flat_sample[i] = features.SegmentationMask.new_like(obj, output_targets[c2]["masks"]) + c2 += 1 + elif isinstance(obj, (features.Label, features.OneHotLabel)): + flat_sample[i] = obj.new_like(obj, output_targets[c3]["labels"]) # type: ignore[arg-type] + c3 += 1 + + def forward(self, *inputs: Any) -> Any: + sample = inputs if len(inputs) > 1 else inputs[0] + + flat_sample, spec = tree_flatten(sample) + + images, targets = self._extract_image_targets(flat_sample) + + # images = [t1, t2, ..., tN] + # Let's define paste_images as shifted list of input images + # paste_images = [t2, t3, ..., tN, t1] + # FYI: in TF they mix data on the dataset level + images_rolled = images[-1:] + images[:-1] + targets_rolled = targets[-1:] + targets[:-1] + + output_images, output_targets = [], [] + + for image, target, paste_image, paste_target in zip(images, targets, images_rolled, targets_rolled): + + # Random paste targets selection: + num_masks = len(paste_target["masks"]) + + if num_masks < 1: + # Such degerante case with num_masks=0 can happen with LSJ + # Let's just return (image, target) + output_image, output_target = image, target + else: + random_selection = torch.randint(0, num_masks, (num_masks,), device=paste_image.device) + random_selection = torch.unique(random_selection) + + output_image, output_target = self._copy_paste( + image, + target, + paste_image, + paste_target, + random_selection=random_selection, + blending=self.blending, + resize_interpolation=self.resize_interpolation, + ) + output_images.append(output_image) + output_targets.append(output_target) + + # Insert updated images and targets into input flat_sample + self._insert_outputs(flat_sample, output_images, output_targets) + + return tree_unflatten(flat_sample, spec)