From 17866a2d389356a4cf24f6ceea2f523d7b6e50e6 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Wed, 17 Aug 2022 17:34:07 +0000
Subject: [PATCH 1/7] WIP

---
 torchvision/prototype/transforms/__init__.py |  2 +-
 torchvision/prototype/transforms/_augment.py | 32 ++++++++++++++++++++
 2 files changed, 33 insertions(+), 1 deletion(-)

diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py
index dc6476ab4b5..22181a19e59 100644
--- a/torchvision/prototype/transforms/__init__.py
+++ b/torchvision/prototype/transforms/__init__.py
@@ -2,7 +2,7 @@

 from ._transform import Transform  # usort: skip

-from ._augment import RandomCutmix, RandomErasing, RandomMixup
+from ._augment import RandomCutmix, RandomErasing, RandomMixup, SimpleCopyPaste
 from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide
 from ._color import (
     ColorJitter,
diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py
index fd8ae9ab378..5e2ad3956b9 100644
--- a/torchvision/prototype/transforms/_augment.py
+++ b/torchvision/prototype/transforms/_augment.py
@@ -7,6 +7,7 @@
 import torch
 from torchvision.prototype import features
 from torchvision.prototype.transforms import functional as F, Transform
+from torchvision.transforms.functional import InterpolationMode

 from ._transform import _RandomApplyTransform
 from ._utils import get_image_dimensions, has_all, has_any, is_simple_tensor, query_image
@@ -178,3 +179,34 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
             return self._mixup_onehotlabel(inpt, lam_adjusted)
         else:
             return inpt
+
+
+class SimpleCopyPaste(Transform):
+    def __init__(
+        self, blending: bool = True, resize_interpolation: InterpolationMode = F.InterpolationMode.BILINEAR
+    ) -> None:
+        super().__init__()
+        self.resize_interpolation = resize_interpolation
+        self.blending = blending
+
+    def _get_params(self, sample: Any) -> Dict[str, Any]:
+        return super()._get_params(sample)
+
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return super()._transform(inpt, params)
+
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+
+        # TODO: Ensure that inputs is batched
+
+        if not (
+            has_all(sample, features.BoundingBox, features.SegmentationMask)
+            and has_any(sample, PIL.Image.Image, features.Image)
+            and has_any(sample, features.Label, features.OneHotLabel)
+        ):
+            raise TypeError(
+                f"{type(self).__name__}() requires input sample to contain Images or PIL Images, "
+                "BoundingBoxes, Segmentation Masks and Labels or OneHotLabels."
+            )
+        return super().forward(*inputs)
\ No newline at end of file
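[For context before the full implementation lands in PATCH 2: the WIP type check above relies on `has_all`/`has_any` scanning the (possibly nested) sample for the listed feature types. A minimal sketch of a sample that would pass it; the shapes and values are illustrative only and not part of the patch series:]

    import torch
    from torchvision.prototype import features

    # One detection sample: an image plus its boxes, masks and labels.
    image = features.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8))
    boxes = features.BoundingBox(
        torch.tensor([[2.0, 3.0, 8.0, 9.0]]), format="XYXY", image_size=(32, 32)
    )
    masks = features.SegmentationMask(torch.zeros(1, 32, 32, dtype=torch.bool))
    labels = features.Label(torch.tensor([1]))

    sample = (image, {"boxes": boxes, "masks": masks, "labels": labels})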
From dce3a29cbedf0562d6ef9a866da53ab35991860c Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Thu, 18 Aug 2022 16:06:12 +0000
Subject: [PATCH 2/7] [proto] Added SimpleCopyPaste transform

---
 torchvision/prototype/transforms/_augment.py | 192 +++++++++++++++++--
 1 file changed, 175 insertions(+), 17 deletions(-)

diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py
index 0a55bf0bdac..2fcb17e26bb 100644
--- a/torchvision/prototype/transforms/_augment.py
+++ b/torchvision/prototype/transforms/_augment.py
@@ -1,10 +1,12 @@
 import math
 import numbers
 import warnings
-from typing import Any, Dict, Tuple
+from typing import Any, Dict, List, Tuple, Union

 import PIL.Image
 import torch
+from torch.utils._pytree import tree_flatten, tree_unflatten
+from torchvision.ops import masks_to_boxes

 from torchvision.prototype import features
 from torchvision.prototype.transforms import functional as F
@@ -184,32 +186,188 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return inpt


-class SimpleCopyPaste(Transform):
+class SimpleCopyPaste(_RandomApplyTransform):
     def __init__(
-        self, blending: bool = True, resize_interpolation: InterpolationMode = F.InterpolationMode.BILINEAR
+        self,
+        p: float = 0.5,
+        blending: bool = True,
+        resize_interpolation: InterpolationMode = F.InterpolationMode.BILINEAR,
     ) -> None:
-        super().__init__()
+        super().__init__(p=p)
         self.resize_interpolation = resize_interpolation
         self.blending = blending

-    def _get_params(self, sample: Any) -> Dict[str, Any]:
-        return super()._get_params(sample)
+    def _copy_paste(
+        self,
+        image: Any,
+        target: Dict[str, Any],
+        paste_image: Any,
+        paste_target: Dict[str, Any],
+        blending: bool = True,
+        resize_interpolation: F.InterpolationMode = F.InterpolationMode.BILINEAR,
+    ) -> Tuple[Any, Dict[str, Any]]:
+
+        # Random paste targets selection:
+        num_masks = len(paste_target["masks"])
+
+        if num_masks < 1:
+            # Such a degenerate case with num_masks=0 can happen with LSJ
+            # Let's just return (image, target)
+            return image, target
+
+        # We have to please TorchScript by explicitly specifying dtype as torch.long
+        random_selection = torch.randint(0, num_masks, (num_masks,), device=paste_image.device)
+        random_selection = torch.unique(random_selection).to(torch.long)
+
+        paste_masks = paste_target["masks"][random_selection]
+        paste_boxes = paste_target["boxes"][random_selection]
+        paste_labels = paste_target["labels"][random_selection]
+
+        masks = target["masks"]
+
+        # We resize source and paste data if they have different sizes
+        # This differs from the TF implementation; we introduced it here because
+        # the original algorithm works on equal-sized data
+        # (for example, coming from LSJ data augmentations)
+        size1 = image.shape[-2:]
+        size2 = paste_image.shape[-2:]
+        if size1 != size2:
+            paste_image = F.resize(paste_image, size1, interpolation=resize_interpolation)
+            paste_masks = F.resize(paste_masks, size1, interpolation=F.InterpolationMode.NEAREST)
+            # resize bboxes:
+            ratios = torch.tensor((size1[1] / size2[1], size1[0] / size2[0]), device=paste_boxes.device)
+            paste_boxes = paste_boxes.view(-1, 2, 2).mul(ratios).view(paste_boxes.shape)
+
+        paste_alpha_mask = paste_masks.sum(dim=0) > 0
+
+        if blending:
+            paste_alpha_mask = F.gaussian_blur(
+                paste_alpha_mask.unsqueeze(0),
+                kernel_size=(5, 5),
+                sigma=[
+                    2.0,
+                ],
+            )

-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return super()._transform(inpt, params)
+        # Copy-paste images:
+        image = (image * (~paste_alpha_mask)) + (paste_image * paste_alpha_mask)
+
+        # Copy-paste masks:
+        masks = masks * (~paste_alpha_mask)
+        non_all_zero_masks = masks.sum((-1, -2)) > 0
+        masks = masks[non_all_zero_masks]
+
+        # Do a shallow copy of the target dict
+        out_target = {k: v for k, v in target.items()}
+
+        out_target["masks"] = torch.cat([masks, paste_masks])
+
+        # Copy-paste boxes and labels
+        boxes = masks_to_boxes(masks)
+        out_target["boxes"] = torch.cat([boxes, paste_boxes])
+
+        labels = target["labels"][non_all_zero_masks]
+        out_target["labels"] = torch.cat([labels, paste_labels])
+
+        # Update additional optional keys: area and iscrowd if they exist
+        if "area" in target:
+            out_target["area"] = out_target["masks"].sum((-1, -2)).to(torch.float32)
+
+        if "iscrowd" in target and "iscrowd" in paste_target:
+            # target['iscrowd'] size can differ from mask size (non_all_zero_masks)
+            # For example, if previous transforms geometrically modify masks/boxes/labels but
+            # do not update "iscrowd"
+            if len(target["iscrowd"]) == len(non_all_zero_masks):
+                iscrowd = target["iscrowd"][non_all_zero_masks]
+                paste_iscrowd = paste_target["iscrowd"][random_selection]
+                out_target["iscrowd"] = torch.cat([iscrowd, paste_iscrowd])
+
+        # Check for degenerated boxes and remove them
+        boxes = out_target["boxes"]
+        degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
+        if degenerate_boxes.any():
+            valid_targets = ~degenerate_boxes.any(dim=1)
+
+            out_target["boxes"] = boxes[valid_targets]
+            out_target["masks"] = out_target["masks"][valid_targets]
+            out_target["labels"] = out_target["labels"][valid_targets]
+
+            if "area" in out_target:
+                out_target["area"] = out_target["area"][valid_targets]
+            if "iscrowd" in out_target and len(out_target["iscrowd"]) == len(valid_targets):
+                out_target["iscrowd"] = out_target["iscrowd"][valid_targets]
+
+        return image, out_target

     def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]

-        # TODO: Ensure that inputs is batched
-
-        if not (
-            has_all(sample, features.BoundingBox, features.SegmentationMask)
-            and has_any(sample, PIL.Image.Image, features.Image)
-            and has_any(sample, features.Label, features.OneHotLabel)
-        ):
+        flat_sample, spec = tree_flatten(sample)
+
+        # fetch all images, bboxes, masks and labels from unstructured input
+        # with List[image], List[BoundingBox], List[SegmentationMask], List[Label]
+        images, bboxes, masks, labels = [], [], [], []
+        for obj in flat_sample:
+            if isinstance(obj, (PIL.Image.Image, features.Image)) or is_simple_tensor(obj):
+                images.append(F.to_image_tensor(obj))
+            elif isinstance(obj, features.BoundingBox):
+                bboxes.append(obj)
+            elif isinstance(obj, features.SegmentationMask):
+                masks.append(obj)
+            elif isinstance(obj, (features.Label, features.OneHotLabel)):
+                labels.append(obj)
+
+        if not (len(images) == len(bboxes) == len(masks) == len(labels)):
             raise TypeError(
-                f"{type(self).__name__}() requires input sample to contain Images or PIL Images, "
+                f"{type(self).__name__}() requires input sample to contain equal-sized list of Images, "
                 "BoundingBoxes, Segmentation Masks and Labels or OneHotLabels."
             )
-        return super().forward(*inputs)
\ No newline at end of file
+
+        targets = []
+        for bbox, mask, label in zip(bboxes, masks, labels):
+            targets.append({"boxes": bbox, "masks": mask, "labels": label})
+
+        # images = [t1, t2, ..., tN]
+        # Let's define paste_images as shifted list of input images
+        # paste_images = [t2, t3, ..., tN, t1]
+        # FYI: in TF they mix data on the dataset level
+        images_rolled = images[-1:] + images[:-1]
+        targets_rolled = targets[-1:] + targets[:-1]
+
+        output_images, output_targets = [], []
+
+        for image, target, paste_image, paste_target in zip(images, targets, images_rolled, targets_rolled):
+            output_image, output_data = self._copy_paste(
+                image,
+                target,
+                paste_image,
+                paste_target,
+                blending=self.blending,
+                resize_interpolation=self.resize_interpolation,
+            )
+            output_images.append(output_image)
+            output_targets.append(output_data)
+
+        # Insert updated images and targets into input flat_sample
+        c0, c1, c2, c3 = 0, 0, 0, 0
+        for i, obj in enumerate(flat_sample):
+            if isinstance(obj, features.Image):
+                flat_sample[i] = features.Image.new_like(obj, output_images[c0])
+                c0 += 1
+            elif isinstance(obj, PIL.Image.Image):
+                flat_sample[i] = F.to_image_pil(output_images[c0])
+                c0 += 1
+            elif is_simple_tensor(obj):
+                flat_sample[i] = output_images[c0]
+                c0 += 1
+            elif isinstance(obj, features.BoundingBox):
+                flat_sample[i] = features.BoundingBox.new_like(obj, output_targets[c1]["boxes"])
+                c1 += 1
+            elif isinstance(obj, features.SegmentationMask):
+                flat_sample[i] = features.SegmentationMask.new_like(obj, output_targets[c2]["masks"])
+                c2 += 1
+            elif isinstance(obj, (features.Label, features.OneHotLabel)):
+                flat_sample[i] = obj.new_like(obj, output_targets[c3]["labels"])  # type: ignore[arg-type]
+                c3 += 1
+
+        return tree_unflatten(flat_sample, spec)
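[A usage sketch for the transform as of PATCH 2; illustrative only, not part of the series. It assumes the prototype API shown in the patches (features.Image wrapping a plain tensor, and p=1.0 from _RandomApplyTransform to always apply); make_sample is a hypothetical helper. forward() flattens any nesting via tree_flatten, pairs every image with the next sample in the batch as its paste source, and restores the original structure:]

    import torch
    from torchvision.prototype import features, transforms

    def make_sample(fill: int):
        mask = torch.zeros(1, 32, 32, dtype=torch.bool)
        mask[0, 4:10, 4:10] = True
        return (
            features.Image(torch.full((3, 32, 32), fill, dtype=torch.uint8)),
            {
                "boxes": features.BoundingBox(
                    torch.tensor([[4.0, 4.0, 10.0, 10.0]]), format="XYXY", image_size=(32, 32)
                ),
                "masks": features.SegmentationMask(mask),
                "labels": features.Label(torch.tensor([1])),
            },
        )

    # Batch of two samples: each image receives objects pasted from the other.
    batch = [make_sample(2), make_sample(10)]
    transform = transforms.SimpleCopyPaste(p=1.0, blending=False)
    output = transform(batch)  # same nesting as the input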
From 9a52b8b96eb0e0e06d14872a68a6bd6f6951ae80 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Mon, 22 Aug 2022 11:02:24 +0000
Subject: [PATCH 3/7] Refactored and cleaned the implementation and added tests

---
 test/test_prototype_transforms.py            |  79 +++++++++++
 torchvision/prototype/transforms/_augment.py | 141 +++++++++----------
 2 files changed, 146 insertions(+), 74 deletions(-)

diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index 8839d842b85..a5485b929e8 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -1328,3 +1328,82 @@ def test__transform(self, mocker):
         transform(inpt_sentinel)

         mock.assert_called_once_with(inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel)
+
+
+class TestSimpleCopyPaste:
+    def test_forward_assertions(self):
+        pass
+
+    @pytest.mark.parametrize("image_type", [features.Image, PIL.Image.Image, torch.Tensor])
+    def test__extract_image_targets(self, image_type, mocker):
+        transform = transforms.SimpleCopyPaste()
+
+        def create_fake_image(image_type):
+            if image_type == PIL.Image.Image:
+                return PIL.Image.new("RGB", (32, 32), 123)
+            return mocker.MagicMock(spec=image_type)
+
+        flat_sample = [
+            # images, batch size = 2
+            create_fake_image(image_type),
+            create_fake_image(image_type),
+            # labels, bboxes, masks
+            mocker.MagicMock(spec=features.Label),
+            mocker.MagicMock(spec=features.BoundingBox),
+            mocker.MagicMock(spec=features.SegmentationMask),
+            # labels, bboxes, masks
+            mocker.MagicMock(spec=features.Label),
+            mocker.MagicMock(spec=features.BoundingBox),
+            mocker.MagicMock(spec=features.SegmentationMask),
+        ]
+
+        images, targets = transform._extract_image_targets(flat_sample)
+
+        assert len(images) == len(targets) == 2
+        if image_type == PIL.Image.Image:
+            torch.testing.assert_close(images[0], pil_to_tensor(flat_sample[0]))
+            torch.testing.assert_close(images[1], pil_to_tensor(flat_sample[1]))
+        else:
+            assert images[0] == flat_sample[0]
+            assert images[1] == flat_sample[1]
+
+    def test__copy_paste(self):
+        image = 2 * torch.ones(3, 32, 32)
+        masks = torch.zeros(2, 32, 32)
+        masks[0, 3:9, 2:8] = 1
+        masks[1, 20:30, 20:30] = 1
+        target = {
+            "boxes": features.BoundingBox(
+                torch.tensor([[2.0, 3.0, 8.0, 9.0], [20.0, 20.0, 30.0, 30.0]]), format="XYXY", image_size=(32, 32)
+            ),
+            "masks": features.SegmentationMask(masks),
+            "labels": features.Label(torch.tensor([1, 2])),
+        }
+
+        paste_image = 10 * torch.ones(3, 32, 32)
+        paste_masks = torch.zeros(2, 32, 32)
+        paste_masks[0, 13:19, 12:18] = 1
+        paste_masks[1, 15:19, 1:8] = 1
+        paste_target = {
+            "boxes": features.BoundingBox(
+                torch.tensor([[12.0, 13.0, 19.0, 18.0], [1.0, 15.0, 8.0, 19.0]]), format="XYXY", image_size=(32, 32)
+            ),
+            "masks": features.SegmentationMask(paste_masks),
+            "labels": features.Label(torch.tensor([3, 4])),
+        }
+
+        transform = transforms.SimpleCopyPaste()
+        random_selection = torch.tensor([0, 1])
+        output_image, output_target = transform._copy_paste(image, target, paste_image, paste_target, random_selection)
+
+        assert output_image.unique().tolist() == [2, 10]
+        assert output_target["boxes"].shape == (4, 4)
+        torch.testing.assert_close(output_target["boxes"][:2, :], target["boxes"])
+        torch.testing.assert_close(output_target["boxes"][2:, :], paste_target["boxes"])
+        torch.testing.assert_close(output_target["labels"], features.Label(torch.tensor([1, 2, 3, 4])))
+        assert output_target["masks"].shape == (4, 32, 32)
+        torch.testing.assert_close(output_target["masks"][:2, :], target["masks"])
+        torch.testing.assert_close(output_target["masks"][2:, :], paste_target["masks"])
+
+    def test_forward(self):
+        pass
diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py
index 2fcb17e26bb..49c76b13850 100644
--- a/torchvision/prototype/transforms/_augment.py
+++ b/torchvision/prototype/transforms/_augment.py
@@ -1,7 +1,7 @@
 import math
 import numbers
 import warnings
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, List, Tuple

 import PIL.Image
 import torch
@@ -10,7 +10,7 @@

 from torchvision.prototype import features
 from torchvision.prototype.transforms import functional as F
-from torchvision.transforms.functional import InterpolationMode
+from torchvision.transforms.functional import InterpolationMode, pil_to_tensor

 from ._transform import _RandomApplyTransform
 from ._utils import get_image_dimensions, has_any, is_simple_tensor, query_image
@@ -203,22 +203,11 @@ def _copy_paste(
         target: Dict[str, Any],
         paste_image: Any,
         paste_target: Dict[str, Any],
+        random_selection: torch.Tensor,
         blending: bool = True,
         resize_interpolation: F.InterpolationMode = F.InterpolationMode.BILINEAR,
     ) -> Tuple[Any, Dict[str, Any]]:

-        # Random paste targets selection:
-        num_masks = len(paste_target["masks"])
-
-        if num_masks < 1:
-            # Such a degenerate case with num_masks=0 can happen with LSJ
-            # Let's just return (image, target)
-            return image, target
-
-        # We have to please TorchScript by explicitly specifying dtype as torch.long
-        random_selection = torch.randint(0, num_masks, (num_masks,), device=paste_image.device)
-        random_selection = torch.unique(random_selection).to(torch.long)
-
         paste_masks = paste_target["masks"][random_selection]
         paste_boxes = paste_target["boxes"][random_selection]
         paste_labels = paste_target["labels"][random_selection]
@@ -232,22 +221,14 @@ def _copy_paste(
         size1 = image.shape[-2:]
         size2 = paste_image.shape[-2:]
         if size1 != size2:
-            paste_image = F.resize(paste_image, size1, interpolation=resize_interpolation)
-            paste_masks = F.resize(paste_masks, size1, interpolation=F.InterpolationMode.NEAREST)
-            # resize bboxes:
-            ratios = torch.tensor((size1[1] / size2[1], size1[0] / size2[0]), device=paste_boxes.device)
-            paste_boxes = paste_boxes.view(-1, 2, 2).mul(ratios).view(paste_boxes.shape)
+            paste_image = F.resize(paste_image, size=size1, interpolation=resize_interpolation)
+            paste_masks = F.resize(paste_masks, size=size1)
+            paste_boxes = F.resize(paste_boxes, size=size1)

         paste_alpha_mask = paste_masks.sum(dim=0) > 0

         if blending:
-            paste_alpha_mask = F.gaussian_blur(
-                paste_alpha_mask.unsqueeze(0),
-                kernel_size=(5, 5),
-                sigma=[
-                    2.0,
-                ],
-            )
+            paste_alpha_mask = F.gaussian_blur(paste_alpha_mask.unsqueeze(0), kernel_size=[5, 5], sigma=[2.0])

         # Copy-paste images:
         image = (image * (~paste_alpha_mask)) + (paste_image * paste_alpha_mask)
@@ -263,27 +244,19 @@ def _copy_paste(
         out_target["masks"] = torch.cat([masks, paste_masks])

         # Copy-paste boxes and labels
-        boxes = masks_to_boxes(masks)
+        bbox_format = target["boxes"].format
+        xyxy_boxes = masks_to_boxes(masks)
+        # masks_to_boxes produces bboxes with x2y2 inclusive but x2y2 should be exclusive
+        # we need to add +1 to x2y2
+        xyxy_boxes[:, 2:] += 1
+        boxes = F.convert_bounding_box_format(xyxy_boxes, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format, copy=False)
         out_target["boxes"] = torch.cat([boxes, paste_boxes])

         labels = target["labels"][non_all_zero_masks]
         out_target["labels"] = torch.cat([labels, paste_labels])

-        # Update additional optional keys: area and iscrowd if they exist
-        if "area" in target:
-            out_target["area"] = out_target["masks"].sum((-1, -2)).to(torch.float32)
-
-        if "iscrowd" in target and "iscrowd" in paste_target:
-            # target['iscrowd'] size can differ from mask size (non_all_zero_masks)
-            # For example, if previous transforms geometrically modify masks/boxes/labels but
-            # do not update "iscrowd"
-            if len(target["iscrowd"]) == len(non_all_zero_masks):
-                iscrowd = target["iscrowd"][non_all_zero_masks]
-                paste_iscrowd = paste_target["iscrowd"][random_selection]
-                out_target["iscrowd"] = torch.cat([iscrowd, paste_iscrowd])
-
         # Check for degenerated boxes and remove them
-        boxes = out_target["boxes"]
+        boxes = F.convert_bounding_box_format(out_target["boxes"], old_format=bbox_format, new_format=features.BoundingBoxFormat.XYXY)
         degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
         if degenerate_boxes.any():
             valid_targets = ~degenerate_boxes.any(dim=1)
@@ -292,24 +265,17 @@ def _copy_paste(
             out_target["masks"] = out_target["masks"][valid_targets]
             out_target["labels"] = out_target["labels"][valid_targets]

-            if "area" in out_target:
-                out_target["area"] = out_target["area"][valid_targets]
-            if "iscrowd" in out_target and len(out_target["iscrowd"]) == len(valid_targets):
-                out_target["iscrowd"] = out_target["iscrowd"][valid_targets]
-
         return image, out_target

-    def forward(self, *inputs: Any) -> Any:
-        sample = inputs if len(inputs) > 1 else inputs[0]
-
-        flat_sample, spec = tree_flatten(sample)
-
+    def _extract_image_targets(self, flat_sample: List[Any]) -> Tuple[List[Any], List[Dict[str, Any]]]:
         # fetch all images, bboxes, masks and labels from unstructured input
         # with List[image], List[BoundingBox], List[SegmentationMask], List[Label]
         images, bboxes, masks, labels = [], [], [], []
         for obj in flat_sample:
-            if isinstance(obj, (PIL.Image.Image, features.Image)) or is_simple_tensor(obj):
-                images.append(F.to_image_tensor(obj))
+            if isinstance(obj, features.Image) or is_simple_tensor(obj):
+                images.append(obj)
+            elif isinstance(obj, PIL.Image.Image):
+                images.append(pil_to_tensor(obj))
             elif isinstance(obj, features.BoundingBox):
                 bboxes.append(obj)
             elif isinstance(obj, features.SegmentationMask):
@@ -327,28 +293,11 @@ def forward(self, *inputs: Any) -> Any:
         for bbox, mask, label in zip(bboxes, masks, labels):
             targets.append({"boxes": bbox, "masks": mask, "labels": label})

-        # images = [t1, t2, ..., tN]
-        # Let's define paste_images as shifted list of input images
-        # paste_images = [t2, t3, ..., tN, t1]
-        # FYI: in TF they mix data on the dataset level
-        images_rolled = images[-1:] + images[:-1]
-        targets_rolled = targets[-1:] + targets[:-1]
+        return images, targets

-        output_images, output_targets = [], []
-
-        for image, target, paste_image, paste_target in zip(images, targets, images_rolled, targets_rolled):
-            output_image, output_data = self._copy_paste(
-                image,
-                target,
-                paste_image,
-                paste_target,
-                blending=self.blending,
-                resize_interpolation=self.resize_interpolation,
-            )
-            output_images.append(output_image)
-            output_targets.append(output_data)
-
-        # Insert updated images and targets into input flat_sample
+    def _insert_outputs(
+        self, flat_sample: List[Any], output_images: List[Any], output_targets: List[Dict[str, Any]]
+    ) -> None:
         c0, c1, c2, c3 = 0, 0, 0, 0
         for i, obj in enumerate(flat_sample):
             if isinstance(obj, features.Image):
@@ -370,4 +319,48 @@ def forward(self, *inputs: Any) -> Any:
                 flat_sample[i] = obj.new_like(obj, output_targets[c3]["labels"])  # type: ignore[arg-type]
                 c3 += 1

+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+
+        flat_sample, spec = tree_flatten(sample)
+
+        images, targets = self._extract_image_targets(flat_sample)
+
+        # images = [t1, t2, ..., tN]
+        # Let's define paste_images as shifted list of input images
+        # paste_images = [t2, t3, ..., tN, t1]
+        # FYI: in TF they mix data on the dataset level
+        images_rolled = images[-1:] + images[:-1]
+        targets_rolled = targets[-1:] + targets[:-1]
+
+        output_images, output_targets = [], []
+
+        for image, target, paste_image, paste_target in zip(images, targets, images_rolled, targets_rolled):
+
+            # Random paste targets selection:
+            num_masks = len(paste_target["masks"])
+
+            if num_masks < 1:
+                # Such a degenerate case with num_masks=0 can happen with LSJ
+                # Let's just return (image, target)
+                output_image, output_target = image, target
+            else:
+                random_selection = torch.randint(0, num_masks, (num_masks,), device=paste_image.device)
+                random_selection = torch.unique(random_selection).to(torch.long)
+
+                output_image, output_target = self._copy_paste(
+                    image,
+                    target,
+                    paste_image,
+                    paste_target,
+                    random_selection=random_selection,
+                    blending=self.blending,
+                    resize_interpolation=self.resize_interpolation,
+                )
+            output_images.append(output_image)
+            output_targets.append(output_target)
+
+        # Insert updated images and targets into input flat_sample
+        self._insert_outputs(flat_sample, output_images, output_targets)
+
         return tree_unflatten(flat_sample, spec)
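[The "+1 to x2y2" added in PATCH 3 compensates for torchvision.ops.masks_to_boxes returning inclusive right/bottom coordinates. A quick illustration, not part of the patch, reusing the first mask from test__copy_paste above:]

    import torch
    from torchvision.ops import masks_to_boxes

    mask = torch.zeros(1, 32, 32, dtype=torch.bool)
    mask[0, 3:9, 2:8] = True     # rows 3..8 and columns 2..7 are set
    print(masks_to_boxes(mask))  # tensor([[2., 3., 7., 8.]]), x2/y2 inclusive
    # After the transform applies `xyxy_boxes[:, 2:] += 1`, the box becomes
    # [2., 3., 8., 9.], matching the exclusive XYXY box used in the test.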
From 541b3d3d58c94d9fc77c9738f82ce327b20a48b1 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Mon, 22 Aug 2022 14:34:06 +0000
Subject: [PATCH 4/7] Fixing code

---
 test/test_prototype_transforms.py            | 37 +++++++++++++-------
 torchvision/prototype/transforms/_augment.py | 14 +++++---
 2 files changed, 34 insertions(+), 17 deletions(-)

diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index fd6f9fc2c8f..8de32f67640 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -1332,22 +1332,38 @@ def test__transform(self, mocker):


 class TestSimpleCopyPaste:
-    def test_forward_assertions(self):
-        pass
+
+    def create_fake_image(self, mocker, image_type):
+        if image_type == PIL.Image.Image:
+            return PIL.Image.new("RGB", (32, 32), 123)
+        return mocker.MagicMock(spec=image_type)
+
+    def test__extract_image_targets_assertion(self, mocker):
+        transform = transforms.SimpleCopyPaste()
+
+        flat_sample = [
+            # images, batch size = 2
+            self.create_fake_image(mocker, features.Image),
+            # labels, bboxes, masks
+            mocker.MagicMock(spec=features.Label),
+            mocker.MagicMock(spec=features.BoundingBox),
+            mocker.MagicMock(spec=features.SegmentationMask),
+            # labels, bboxes, masks
+            mocker.MagicMock(spec=features.BoundingBox),
+            mocker.MagicMock(spec=features.SegmentationMask),
+        ]
+
+        with pytest.raises(TypeError, match="requires input sample to contain equal-sized list of Images"):
+            transform._extract_image_targets(flat_sample)

     @pytest.mark.parametrize("image_type", [features.Image, PIL.Image.Image, torch.Tensor])
     def test__extract_image_targets(self, image_type, mocker):
         transform = transforms.SimpleCopyPaste()

-        def create_fake_image(image_type):
-            if image_type == PIL.Image.Image:
-                return PIL.Image.new("RGB", (32, 32), 123)
-            return mocker.MagicMock(spec=image_type)
-
         flat_sample = [
             # images, batch size = 2
-            create_fake_image(image_type),
-            create_fake_image(image_type),
+            self.create_fake_image(mocker, image_type),
+            self.create_fake_image(mocker, image_type),
             # labels, bboxes, masks
             mocker.MagicMock(spec=features.Label),
             mocker.MagicMock(spec=features.BoundingBox),
@@ -1406,9 +1422,6 @@ def test__copy_paste(self):
         torch.testing.assert_close(output_target["masks"][:2, :], target["masks"])
         torch.testing.assert_close(output_target["masks"][2:, :], paste_target["masks"])

-    def test_forward(self):
-        pass
-

 class TestFixedSizeCrop:
     def test__get_params(self, mocker):
diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py
index 49c76b13850..ccb2d557db5 100644
--- a/torchvision/prototype/transforms/_augment.py
+++ b/torchvision/prototype/transforms/_augment.py
@@ -208,9 +208,9 @@ def _copy_paste(
         resize_interpolation: F.InterpolationMode = F.InterpolationMode.BILINEAR,
     ) -> Tuple[Any, Dict[str, Any]]:

-        paste_masks = paste_target["masks"][random_selection]
-        paste_boxes = paste_target["boxes"][random_selection]
-        paste_labels = paste_target["labels"][random_selection]
+        paste_masks = paste_target["masks"].new_like(paste_target["masks"], paste_target["masks"][random_selection])
+        paste_boxes = paste_target["boxes"].new_like(paste_target["boxes"], paste_target["boxes"][random_selection])
+        paste_labels = paste_target["labels"].new_like(paste_target["labels"], paste_target["labels"][random_selection])

         masks = target["masks"]

@@ -249,14 +249,18 @@ def _copy_paste(
         # masks_to_boxes produces bboxes with x2y2 inclusive but x2y2 should be exclusive
         # we need to add +1 to x2y2
         xyxy_boxes[:, 2:] += 1
-        boxes = F.convert_bounding_box_format(xyxy_boxes, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format, copy=False)
+        boxes = F.convert_bounding_box_format(
+            xyxy_boxes, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format, copy=False
+        )
         out_target["boxes"] = torch.cat([boxes, paste_boxes])

         labels = target["labels"][non_all_zero_masks]
         out_target["labels"] = torch.cat([labels, paste_labels])

         # Check for degenerated boxes and remove them
-        boxes = F.convert_bounding_box_format(out_target["boxes"], old_format=bbox_format, new_format=features.BoundingBoxFormat.XYXY)
+        boxes = F.convert_bounding_box_format(
+            out_target["boxes"], old_format=bbox_format, new_format=features.BoundingBoxFormat.XYXY
+        )
         degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
         if degenerate_boxes.any():
             valid_targets = ~degenerate_boxes.any(dim=1)

From ec770d7741bf65d5474d4ebe2f2114552b709716 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Mon, 22 Aug 2022 14:40:00 +0000
Subject: [PATCH 5/7] Fixed code formatting issue

---
 test/test_prototype_transforms.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index 8de32f67640..62bf532fe46 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -1332,7 +1332,6 @@ def test__transform(self, mocker):


 class TestSimpleCopyPaste:
-
     def create_fake_image(self, mocker, image_type):
         if image_type == PIL.Image.Image:
             return PIL.Image.new("RGB", (32, 32), 123)
         return mocker.MagicMock(spec=image_type)

From 3c7a9cd75e1ab583009a8d1e2674fa4d266460b8 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Mon, 22 Aug 2022 19:33:36 +0000
Subject: [PATCH 6/7] Minor updates

---
 torchvision/prototype/transforms/_augment.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py
index ccb2d557db5..71da7464ac4 100644
--- a/torchvision/prototype/transforms/_augment.py
+++ b/torchvision/prototype/transforms/_augment.py
@@ -246,8 +246,8 @@ def _copy_paste(
         # Copy-paste boxes and labels
         bbox_format = target["boxes"].format
         xyxy_boxes = masks_to_boxes(masks)
-        # masks_to_boxes produces bboxes with x2y2 inclusive but x2y2 should be exclusive
-        # we need to add +1 to x2y2
+        # TODO: masks_to_boxes produces bboxes with x2y2 inclusive but x2y2 should be exclusive
+        # we need to add +1 to x2y2. We need to investigate that.
         xyxy_boxes[:, 2:] += 1
         boxes = F.convert_bounding_box_format(
             xyxy_boxes, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format, copy=False
@@ -350,7 +350,7 @@ def forward(self, *inputs: Any) -> Any:
             output_image, output_target = image, target
         else:
             random_selection = torch.randint(0, num_masks, (num_masks,), device=paste_image.device)
-            random_selection = torch.unique(random_selection).to(torch.long)
+            random_selection = torch.unique(random_selection)

             output_image, output_target = self._copy_paste(
                 image,

From 902e0f22ada1404ab296a2f43d9e57ba22ddc8a6 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Tue, 23 Aug 2022 09:05:19 +0000
Subject: [PATCH 7/7] Fixed merge issue

---
 torchvision/prototype/transforms/_augment.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py
index 9bfbe718f4a..32697f32257 100644
--- a/torchvision/prototype/transforms/_augment.py
+++ b/torchvision/prototype/transforms/_augment.py
@@ -3,6 +3,7 @@
 import warnings
 from typing import Any, Dict, List, Tuple

+import PIL.Image
 import torch
 from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision.ops import masks_to_boxes
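[A closing note on the paste-target selection that PATCH 6 simplifies: drawing num_masks indices with replacement and de-duplicating yields a random, non-empty, sorted subset of the paste objects. A minimal sketch:]

    import torch

    num_masks = 4
    random_selection = torch.randint(0, num_masks, (num_masks,))  # e.g. tensor([2, 0, 2, 3])
    random_selection = torch.unique(random_selection)             # e.g. tensor([0, 2, 3])
    # Between 1 and num_masks paste objects survive. torch.randint already returns
    # torch.int64, which is why PATCH 6 drops the redundant `.to(torch.long)`.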