From 2169216e71b27ec8fd2e6eedee8802e4722a7880 Mon Sep 17 00:00:00 2001
From: Hao Chen
Date: Thu, 2 Jul 2020 10:45:19 +0930
Subject: [PATCH] fix crop augmentation #127, #129

---
 adet/data/augmentation.py    | 123 +++++++++++++++++++++++++++++++++
 adet/data/dataset_mapper.py  | 128 +++++++++++++++++++++--------------
 adet/data/detection_utils.py | 123 +++++++--------------------------
 3 files changed, 226 insertions(+), 148 deletions(-)
 create mode 100644 adet/data/augmentation.py

diff --git a/adet/data/augmentation.py b/adet/data/augmentation.py
new file mode 100644
index 000000000..fec98656d
--- /dev/null
+++ b/adet/data/augmentation.py
@@ -0,0 +1,123 @@
+import numpy as np
+from fvcore.transforms import transform as T
+
+from detectron2.data.transforms import RandomCrop, StandardAugInput
+from detectron2.structures import BoxMode
+
+
+class InstanceAugInput(StandardAugInput):
+    """
+    Keep the old behavior of instance-aware augmentation.
+    """
+
+    def __init__(self, *args, **kwargs):
+        instances = kwargs.pop("instances", None)
+        super().__init__(*args, **kwargs)
+        if instances is not None:
+            self.instances = instances
+
+
+def gen_crop_transform_with_instance(crop_size, image_size, instances, crop_box=True):
+    """
+    Generate a CropTransform so that the cropping region contains
+    the center of a randomly chosen instance.
+
+    Args:
+        crop_size (tuple): h, w in pixels
+        image_size (tuple): h, w
+        instances (list[dict]): annotation dicts in Detectron2's dataset
+            format; one is picked at random to center the crop.
+    """
+    instance = (np.random.choice(instances),)
+    instance = instance[0]
+    crop_size = np.asarray(crop_size, dtype=np.int32)
+    bbox = BoxMode.convert(instance["bbox"], instance["bbox_mode"], BoxMode.XYXY_ABS)
+    center_yx = (bbox[1] + bbox[3]) * 0.5, (bbox[0] + bbox[2]) * 0.5
+    assert (
+        image_size[0] >= center_yx[0] and image_size[1] >= center_yx[1]
+    ), "The annotation bounding box is outside of the image!"
+    assert (
+        image_size[0] >= crop_size[0] and image_size[1] >= crop_size[1]
+    ), "Crop size is larger than image size!"
+
+    min_yx = np.maximum(np.floor(center_yx).astype(np.int32) - crop_size, 0)
+    max_yx = np.maximum(np.asarray(image_size, dtype=np.int32) - crop_size, 0)
+    max_yx = np.minimum(max_yx, np.ceil(center_yx).astype(np.int32))
+
+    y0 = np.random.randint(min_yx[0], max_yx[0] + 1)
+    x0 = np.random.randint(min_yx[1], max_yx[1] + 1)
+
+    # if some instance is cropped, extend the box
+    if not crop_box:
+        num_modifications = 0
+        modified = True
+
+        # convert crop_size to float
+        crop_size = crop_size.astype(np.float32)
+        while modified:
+            modified, x0, y0, crop_size = adjust_crop(x0, y0, crop_size, instances)
+            num_modifications += 1
+            if num_modifications > 100:
+                raise ValueError(
+                    "Cannot finish cropping adjustment within 100 tries (#instances {}).".format(
+                        len(instances)
+                    )
+                )
+                return T.CropTransform(0, 0, image_size[1], image_size[0])
+
+    return T.CropTransform(*map(int, (x0, y0, crop_size[1], crop_size[0])))
+
+
+def adjust_crop(x0, y0, crop_size, instances, eps=1e-3):
+    modified = False
+
+    x1 = x0 + crop_size[1]
+    y1 = y0 + crop_size[0]
+
+    for instance in instances:
+        bbox = BoxMode.convert(
+            instance["bbox"], instance["bbox_mode"], BoxMode.XYXY_ABS
+        )
+
+        if bbox[0] < x0 - eps and bbox[2] > x0 + eps:
+            crop_size[1] += x0 - bbox[0]
+            x0 = bbox[0]
+            modified = True
+
+        if bbox[0] < x1 - eps and bbox[2] > x1 + eps:
+            crop_size[1] += bbox[2] - x1
+            x1 = bbox[2]
+            modified = True
+
+        if bbox[1] < y0 - eps and bbox[3] > y0 + eps:
+            crop_size[0] += y0 - bbox[1]
+            y0 = bbox[1]
+            modified = True
+
+        if bbox[1] < y1 - eps and bbox[3] > y1 + eps:
+            crop_size[0] += bbox[3] - y1
+            y1 = bbox[3]
+            modified = True
+
+    return modified, x0, y0, crop_size
+
+
+class RandomCropWithInstance(RandomCrop):
+    """ Instance-aware cropping.
+    """
+
+    def __init__(self, crop_type, crop_size, crop_instance=True):
+        """
+        Args:
+            crop_instance (bool): if False, extend cropping boxes to avoid cropping instances
+        """
+        super().__init__(crop_type, crop_size)
+        self.crop_instance = crop_instance
+        self.input_args = ("image", "instances")
+
+    def get_transform(self, img, instances):
+        image_size = img.shape[:2]
+        crop_size = self.get_crop_size(image_size)
+        return gen_crop_transform_with_instance(
+            crop_size, image_size, instances, crop_box=self.crop_instance
+        )
diff --git a/adet/data/dataset_mapper.py b/adet/data/dataset_mapper.py
index f101f5912..11e26544f 100755
--- a/adet/data/dataset_mapper.py
+++ b/adet/data/dataset_mapper.py
@@ -1,22 +1,21 @@
 import copy
-import numpy as np
+import logging
 import os.path as osp
+
+import numpy as np
 import torch
 from fvcore.common.file_io import PathManager
 from PIL import Image
-import logging
+from pycocotools import mask as maskUtils
 
-from detectron2.data.dataset_mapper import DatasetMapper
-from detectron2.data.detection_utils import SizeMismatchError
 from detectron2.data import detection_utils as utils
 from detectron2.data import transforms as T
+from detectron2.data.dataset_mapper import DatasetMapper
+from detectron2.data.detection_utils import SizeMismatchError
 
-from .detection_utils import (
-    build_augmentation,
-    transform_instance_annotations,
-    annotations_to_instances,
-    gen_crop_transform_with_instance,
-)
+from .augmentation import InstanceAugInput, RandomCropWithInstance
+from .detection_utils import (annotations_to_instances, build_augmentation,
+                              transform_instance_annotations)
 
 """
 This file contains the default mapping that's applied to "dataset dicts".
@@ -27,6 +26,28 @@
 logger = logging.getLogger(__name__)
 
 
+def segmToRLE(segm, img_size):
+    h, w = img_size
+    if type(segm) == list:
+        # polygon -- a single object might consist of multiple parts
+        # we merge all parts into one mask rle code
+        rles = maskUtils.frPyObjects(segm, h, w)
+        rle = maskUtils.merge(rles)
+    elif type(segm["counts"]) == list:
+        # uncompressed RLE
+        rle = maskUtils.frPyObjects(segm, h, w)
+    else:
+        # rle
+        rle = segm
+    return rle
+
+
+def segmToMask(segm, img_size):
+    rle = segmToRLE(segm, img_size)
+    m = maskUtils.decode(rle)
+    return m
+
+
 class DatasetMapperWithBasis(DatasetMapper):
     """
     This caller enables the default Detectron2 mapper to read an additional basis semantic label
@@ -36,13 +57,27 @@ def __init__(self, cfg, is_train=True):
         super().__init__(cfg, is_train)
 
         # Rebuild augmentations
-        logger.info("Rebuilding the augmentations. The previous augmentations will be overridden.")
+        logger.info(
+            "Rebuilding the augmentations. The previous augmentations will be overridden."
+        )
         self.augmentation = build_augmentation(cfg, is_train)
 
+        if cfg.INPUT.CROP.ENABLED and is_train:
+            self.augmentation.insert(
+                0,
+                RandomCropWithInstance(
+                    cfg.INPUT.CROP.TYPE,
+                    cfg.INPUT.CROP.SIZE,
+                    cfg.INPUT.CROP.CROP_INSTANCE,
+                ),
+            )
+            logging.getLogger(__name__).info(
+                "Cropping used in training: " + str(self.augmentation[0])
+            )
+
         # fmt: off
-        self.basis_loss_on = cfg.MODEL.BASIS_MODULE.LOSS_ON
-        self.ann_set = cfg.MODEL.BASIS_MODULE.ANN_SET
-        self.crop_box = cfg.INPUT.CROP.CROP_INSTANCE
+        self.basis_loss_on       = cfg.MODEL.BASIS_MODULE.LOSS_ON
+        self.ann_set             = cfg.MODEL.BASIS_MODULE.ANN_SET
         # fmt: on
 
     def __call__(self, dataset_dict):
@@ -72,35 +107,27 @@ def __call__(self, dataset_dict):
             else:
                 raise e
 
-        if "annotations" not in dataset_dict or len(dataset_dict["annotations"]) == 0:
-            image, transforms = T.apply_augmentations(
-                ([self.crop] if self.crop else []) + self.augmentation, image
-            )
+        # USER: Remove if you don't do semantic/panoptic segmentation.
+        if "sem_seg_file_name" in dataset_dict:
+            sem_seg_gt = utils.read_image(
+                dataset_dict.pop("sem_seg_file_name"), "L"
+            ).squeeze(2)
         else:
-            # Crop around an instance if there are instances in the image.
-            # USER: Remove if you don't use cropping
-            if self.crop:
-                crop_tfm = gen_crop_transform_with_instance(
-                    self.crop.get_crop_size(image.shape[:2]),
-                    image.shape[:2],
-                    dataset_dict["annotations"],
-                    crop_box=self.crop_box,
-                )
-                image = crop_tfm.apply_image(image)
-            try:
-                image, transforms = T.apply_augmentations(self.augmentation, image)
-            except ValueError as e:
-                print(dataset_dict["file_name"])
-                raise e
-            if self.crop:
-                transforms = crop_tfm + transforms
+            sem_seg_gt = None
 
-        image_shape = image.shape[:2]  # h, w
+        aug_input = InstanceAugInput(image, sem_seg=sem_seg_gt, instances=dataset_dict["annotations"])
+        transforms = aug_input.apply_augmentations(self.augmentation)
+        image, sem_seg_gt = aug_input.image, aug_input.sem_seg
 
+        image_shape = image.shape[:2]  # h, w
         # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
         # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
         # Therefore it's important to use torch.Tensor.
-        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
+        dataset_dict["image"] = torch.as_tensor(
+            np.ascontiguousarray(image.transpose(2, 0, 1))
+        )
+        if sem_seg_gt is not None:
+            dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))
 
         # USER: Remove if you don't use pre-computed proposals.
         # Most users would not need this feature.
@@ -130,7 +157,10 @@ def __call__(self, dataset_dict):
             # USER: Implement additional transformations if you have other types of data
             annos = [
                 transform_instance_annotations(
-                    obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices
+                    obj,
+                    transforms,
+                    image_shape,
+                    keypoint_hflip_indices=self.keypoint_hflip_indices,
                 )
                 for obj in dataset_dict.pop("annotations")
                 if obj.get("iscrowd", 0) == 0
@@ -143,24 +173,24 @@
             # tightly bound the object. As an example, imagine a triangle object
             # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight
             # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to
-            # the intersection of original bounding box and the cropping box.
-            if self.crop and instances.has("gt_masks"):
+            if self.compute_tight_boxes and instances.has("gt_masks"):
                 instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
             dataset_dict["instances"] = utils.filter_empty_instances(instances)
 
-            # USER: Remove if you don't do semantic/panoptic segmentation.
-            if "sem_seg_file_name" in dataset_dict:
-                sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2)
-                sem_seg_gt = transforms.apply_segmentation(sem_seg_gt)
-                sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long"))
-                dataset_dict["sem_seg"] = sem_seg_gt
-
         if self.basis_loss_on and self.is_train:
             # load basis supervisions
             if self.ann_set == "coco":
-                basis_sem_path = dataset_dict["file_name"].replace('train2017', 'thing_train2017').replace('image/train', 'thing_train')
+                basis_sem_path = (
+                    dataset_dict["file_name"]
+                    .replace("train2017", "thing_train2017")
+                    .replace("image/train", "thing_train")
+                )
             else:
-                basis_sem_path = dataset_dict["file_name"].replace('coco', 'lvis').replace('train2017', 'thing_train')
+                basis_sem_path = (
+                    dataset_dict["file_name"]
+                    .replace("coco", "lvis")
+                    .replace("train2017", "thing_train")
+                )
             # change extension to npz
             basis_sem_path = osp.splitext(basis_sem_path)[0] + ".npz"
             basis_sem_gt = np.load(basis_sem_path)["mask"]
diff --git a/adet/data/detection_utils.py b/adet/data/detection_utils.py
index 2dd8f492d..1448e4338 100644
--- a/adet/data/detection_utils.py
+++ b/adet/data/detection_utils.py
@@ -1,94 +1,13 @@
 import logging
-import numpy as np
 
+import numpy as np
 import torch
 
 from detectron2.data import transforms as T
-from detectron2.data.detection_utils import transform_instance_annotations as d2_transform_inst_anno
-from detectron2.data.detection_utils import annotations_to_instances as d2_anno_to_inst
-from detectron2.structures import BoxMode
-
-
-logger = logging.getLogger(__name__)
-
-
-def gen_crop_transform_with_instance(crop_size, image_size, instances, crop_box=True):
-    """
-    Generate a CropTransform so that the cropping region contains
-    the center of the given instance.
-
-    Args:
-        crop_size (tuple): h, w in pixels
-        image_size (tuple): h, w
-        instance (dict): an annotation dict of one instance, in Detectron2's
-            dataset format.
-    """
-    instance = np.random.choice(instances),
-    instance = instance[0]
-    crop_size = np.asarray(crop_size, dtype=np.int32)
-    bbox = BoxMode.convert(instance["bbox"], instance["bbox_mode"], BoxMode.XYXY_ABS)
-    center_yx = (bbox[1] + bbox[3]) * 0.5, (bbox[0] + bbox[2]) * 0.5
-    assert (
-        image_size[0] >= center_yx[0] and image_size[1] >= center_yx[1]
-    ), "The annotation bounding box is outside of the image!"
-    assert (
-        image_size[0] >= crop_size[0] and image_size[1] >= crop_size[1]
-    ), "Crop size is larger than image size!"
-
-    min_yx = np.maximum(np.floor(center_yx).astype(np.int32) - crop_size, 0)
-    max_yx = np.maximum(np.asarray(image_size, dtype=np.int32) - crop_size, 0)
-    max_yx = np.minimum(max_yx, np.ceil(center_yx).astype(np.int32))
-
-    y0 = np.random.randint(min_yx[0], max_yx[0] + 1)
-    x0 = np.random.randint(min_yx[1], max_yx[1] + 1)
-
-    # if some instance is cropped extend the box
-    if not crop_box:
-        num_modifications = 0
-        modified = True
-
-        # convert crop_size to float
-        crop_size = crop_size.astype(np.float32)
-        while modified:
-            modified, x0, y0, crop_size = adjust_crop(x0, y0, crop_size, instances)
-            num_modifications += 1
-            if num_modifications > 100:
-                logger.info("Cannot finished cropping adjustment within 100 tries (#instances {}).".format(len(instances)))
-                return T.CropTransform(0, 0, image_size[1], image_size[0])
-
-    return T.CropTransform(*map(int, (x0, y0, crop_size[1], crop_size[0])))
-
-
-def adjust_crop(x0, y0, crop_size, instances, eps=1e-3):
-    modified = False
-
-    x1 = x0 + crop_size[1]
-    y1 = y0 + crop_size[0]
-
-    for instance in instances:
-        bbox = BoxMode.convert(instance["bbox"], instance["bbox_mode"], BoxMode.XYXY_ABS)
-
-        if bbox[0] < x0 - eps and bbox[2] > x0 + eps:
-            crop_size[1] += x0 - bbox[0]
-            x0 = bbox[0]
-            modified = True
-
-        if bbox[0] < x1 - eps and bbox[2] > x1 + eps:
-            crop_size[1] += bbox[2] - x1
-            x1 = bbox[2]
-            modified = True
-
-        if bbox[1] < y0 - eps and bbox[3] > y0 + eps:
-            crop_size[0] += y0 - bbox[1]
-            y0 = bbox[1]
-            modified = True
-
-        if bbox[1] < y1 - eps and bbox[3] > y1 + eps:
-            crop_size[0] += bbox[3] - y1
-            y1 = bbox[3]
-            modified = True
-
-    return modified, x0, y0, crop_size
+from detectron2.data.detection_utils import \
+    annotations_to_instances as d2_anno_to_inst
+from detectron2.data.detection_utils import \
+    transform_instance_annotations as d2_transform_inst_anno
 
 
 def transform_instance_annotations(
@@ -96,13 +15,14 @@
 ):
 
     annotation = d2_transform_inst_anno(
-        annotation, transforms, image_size,
-        keypoint_hflip_indices=keypoint_hflip_indices)
+        annotation,
+        transforms,
+        image_size,
+        keypoint_hflip_indices=keypoint_hflip_indices,
+    )
 
     if "beziers" in annotation:
-        beziers = transform_beziers_annotations(
-            annotation["beziers"], transforms
-        )
+        beziers = transform_beziers_annotations(annotation["beziers"], transforms)
         annotation["beziers"] = beziers
 
     return annotation
@@ -120,7 +40,9 @@ def transform_beziers_annotations(beziers, transforms):
     beziers = transforms.apply_coords(beziers).reshape(-1)
 
     # This assumes that HorizFlipTransform is the only one that does flip
-    do_hflip = sum(isinstance(t, T.HFlipTransform) for t in transforms.transforms) % 2 == 1
+    do_hflip = (
+        sum(isinstance(t, T.HFlipTransform) for t in transforms.transforms) % 2 == 1
+    )
     if do_hflip:
         raise ValueError("Flipping text data is not supported (also disencouraged).")
 
@@ -130,16 +52,17 @@ def transform_beziers_annotations(beziers, transforms):
 def annotations_to_instances(annos, image_size, mask_format="polygon"):
     instance = d2_anno_to_inst(annos, image_size, mask_format)
 
+    if not annos:
+        return instance
+
     # add attributes
     if "beziers" in annos[0]:
         beziers = [obj.get("beziers", []) for obj in annos]
-        instance.beziers = torch.as_tensor(
-            beziers, dtype=torch.float32)
+        instance.beziers = torch.as_tensor(beziers, dtype=torch.float32)
 
     if "rec" in annos[0]:
         text = [obj.get("rec", []) for obj in annos]
-        instance.text = torch.as_tensor(
-            text, dtype=torch.int32)
+        instance.text = torch.as_tensor(text, dtype=torch.int32)
 
     return instance
 
@@ -160,9 +83,11 @@ def build_augmentation(cfg, is_train):
         max_size = cfg.INPUT.MAX_SIZE_TEST
         sample_style = "choice"
     if sample_style == "range":
-        assert len(min_size) == 2, "more than 2 ({}) min_size(s) are provided for ranges".format(
-            len(min_size)
-        )
+        assert (
+            len(min_size) == 2
+        ), "more than 2 ({}) min_size(s) are provided for ranges".format(len(min_size))
+
+    logger = logging.getLogger(__name__)
 
     augmentation = []
     augmentation.append(T.ResizeShortestEdge(min_size, max_size, sample_style))
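-- 
Usage note (illustrative sketch, kept outside the diff): the snippet below shows one way
the new crop path is expected to be driven from a config once this patch is applied. The
INPUT.CROP.* keys are the ones read in dataset_mapper.py above; the config file path, the
TYPE/SIZE values, and the loader wiring are assumptions for illustration, and get_cfg() /
build_detection_train_loader are the stock AdelaiDet/detectron2 entry points, not something
introduced here. It also assumes INPUT.CROP.CROP_INSTANCE is registered in adet's config
defaults.

    from detectron2.data import build_detection_train_loader

    from adet.config import get_cfg
    from adet.data.dataset_mapper import DatasetMapperWithBasis

    cfg = get_cfg()
    cfg.merge_from_file("configs/BlendMask/R_50_1x.yaml")  # example config path
    cfg.INPUT.CROP.ENABLED = True
    cfg.INPUT.CROP.TYPE = "relative_range"   # any crop type detectron2's RandomCrop accepts
    cfg.INPUT.CROP.SIZE = [0.9, 0.9]         # example crop size
    cfg.INPUT.CROP.CROP_INSTANCE = False     # False: grow the crop so no instance box is cut

    # With INPUT.CROP.ENABLED and is_train=True, the mapper inserts
    # RandomCropWithInstance at the front of its augmentation list (see __init__ above).
    mapper = DatasetMapperWithBasis(cfg, is_train=True)
    train_loader = build_detection_train_loader(cfg, mapper=mapper)

The crop transform itself can also be exercised directly; the instance dict below is
synthetic and only carries the two fields gen_crop_transform_with_instance actually reads
("bbox" and "bbox_mode").

    import numpy as np
    from detectron2.structures import BoxMode

    from adet.data.augmentation import gen_crop_transform_with_instance

    fake_instances = [{"bbox": [40, 60, 200, 180], "bbox_mode": BoxMode.XYXY_ABS}]
    tfm = gen_crop_transform_with_instance(
        crop_size=(300, 300),       # h, w of the crop
        image_size=(480, 640),      # h, w of the image
        instances=fake_instances,
        crop_box=False,             # extend the crop instead of cutting through boxes
    )
    image = np.zeros((480, 640, 3), dtype=np.uint8)
    cropped = tfm.apply_image(image)  # cropped region always contains the chosen box center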