From 28293dc5df923907f64fe00a07d5a18c486f005c Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 13 Jul 2023 15:55:15 +0000 Subject: [PATCH 1/5] Properly handle maskrcnn in detection ref --- references/detection/coco_utils.py | 17 +++++++--------- references/detection/train.py | 31 ++++++++++++++++++++++++------ 2 files changed, 32 insertions(+), 16 deletions(-) diff --git a/references/detection/coco_utils.py b/references/detection/coco_utils.py index 313faacdb7c..88b8c069e41 100644 --- a/references/detection/coco_utils.py +++ b/references/detection/coco_utils.py @@ -219,7 +219,7 @@ def __getitem__(self, idx): return img, target -def get_coco(root, image_set, transforms, mode="instances", use_v2=False): +def get_coco(root, image_set, transforms, args, mode="instances"): anno_file_template = "{}_{}2017.json" PATHS = { "train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))), @@ -231,11 +231,14 @@ def get_coco(root, image_set, transforms, mode="instances", use_v2=False): img_folder = os.path.join(root, img_folder) ann_file = os.path.join(root, ann_file) - if use_v2: + if args.use_v2: dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms) - # TODO: need to update target_keys to handle masks for segmentation! - dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"boxes", "labels", "image_id"}) + target_keys = ["boxes", "labels", "image_id"] + if args.with_masks: + target_keys += ["masks"] + dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys) else: + # TODO: handle with_masks for V1? t = [ConvertCocoPolysToMask()] if transforms is not None: t.append(transforms) @@ -249,9 +252,3 @@ def get_coco(root, image_set, transforms, mode="instances", use_v2=False): # dataset = torch.utils.data.Subset(dataset, [i for i in range(500)]) return dataset - - -def get_coco_kp(root, image_set, transforms, use_v2=False): - if use_v2: - raise ValueError("KeyPoints aren't supported by transforms V2 yet.") - return get_coco(root, image_set, transforms, mode="person_keypoints") diff --git a/references/detection/train.py b/references/detection/train.py index db86f33aaa9..d722c63a13a 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -28,7 +28,7 @@ import torchvision.models.detection import torchvision.models.detection.mask_rcnn import utils -from coco_utils import get_coco, get_coco_kp +from coco_utils import get_coco from engine import evaluate, train_one_epoch from group_by_aspect_ratio import create_aspect_ratio_groups, GroupedBatchSampler from torchvision.transforms import InterpolationMode @@ -42,10 +42,10 @@ def copypaste_collate_fn(batch): def get_dataset(is_train, args): image_set = "train" if is_train else "val" - paths = {"coco": (args.data_path, get_coco, 91), "coco_kp": (args.data_path, get_coco_kp, 2)} - p, ds_fn, num_classes = paths[args.dataset] - - ds = ds_fn(p, image_set=image_set, transforms=get_transform(is_train, args), use_v2=args.use_v2) + num_classes, mode = {"coco": (91, "instances"), "coco_kp": (2, "person_keypoints")}[args.dataset] + ds = get_coco( + root=args.data_path, image_set=image_set, transforms=get_transform(is_train, args), args=args, mode=mode + ) return ds, num_classes @@ -68,7 +68,7 @@ def get_args_parser(add_help=True): parser = argparse.ArgumentParser(description="PyTorch Detection Training", add_help=add_help) parser.add_argument("--data-path", default="/datasets01/COCO/022719/", type=str, help="dataset path") - parser.add_argument("--dataset", default="coco", type=str, help="dataset name") + parser.add_argument("--dataset", default="coco", type=str, help="dataset name. Use coco_kp for Keypoint detection") parser.add_argument("--model", default="maskrcnn_resnet50_fpn", type=str, help="model name") parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)") parser.add_argument( @@ -164,6 +164,14 @@ def get_args_parser(add_help=True): parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive") parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms") + parser.add_argument( + "--with-masks", + action="store_true", + help=( + "Whether the dataset should return masks. Only relevant when --use-v2 is passed. " + "True by default when using mask_rcnn." + ), + ) return parser @@ -171,6 +179,17 @@ def get_args_parser(add_help=True): def main(args): if args.backend.lower() == "datapoint" and not args.use_v2: raise ValueError("Use --use-v2 if you want to use the datapoint backend.") + if args.dataset not in ("coco", "coco_kp"): + raise ValueError(f"Dataset should be coco or coco_kp, got {args.dataset}") + if "keypoint" in args.model and args.dataset != "coco_kp": + raise ValueError("Oops, if you want Keypoint detection, set --dataset coco_kp") + if args.dataset == "coco_kp" and args.use_v2: + raise ValueError("KeyPoint detection doesn't support V2 transforms yet") + + if "mask" in args.model: + args.with_masks = True + + print(args.model) if args.output_dir: utils.mkdir(args.output_dir) From 03767e83a611a5febac2cd228f072507824e6710 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 13 Jul 2023 20:22:08 +0000 Subject: [PATCH 2/5] Remove unused FilterAndRemapCocoCategories --- references/detection/coco_utils.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/references/detection/coco_utils.py b/references/detection/coco_utils.py index 88b8c069e41..8b657e59434 100644 --- a/references/detection/coco_utils.py +++ b/references/detection/coco_utils.py @@ -10,24 +10,6 @@ from torchvision.datasets import wrap_dataset_for_transforms_v2 -class FilterAndRemapCocoCategories: - def __init__(self, categories, remap=True): - self.categories = categories - self.remap = remap - - def __call__(self, image, target): - anno = target["annotations"] - anno = [obj for obj in anno if obj["category_id"] in self.categories] - if not self.remap: - target["annotations"] = anno - return image, target - anno = copy.deepcopy(anno) - for obj in anno: - obj["category_id"] = self.categories.index(obj["category_id"]) - target["annotations"] = anno - return image, target - - def convert_coco_poly_to_mask(segmentations, height, width): masks = [] for polygons in segmentations: From dc584dcb514b2fa024e99d781cfffb43e58e5dbd Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 24 Jul 2023 10:47:12 +0100 Subject: [PATCH 3/5] Apply suggestions from code review Co-authored-by: Philip Meier --- references/detection/train.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/references/detection/train.py b/references/detection/train.py index d722c63a13a..c82bf55ab66 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -68,7 +68,7 @@ def get_args_parser(add_help=True): parser = argparse.ArgumentParser(description="PyTorch Detection Training", add_help=add_help) parser.add_argument("--data-path", default="/datasets01/COCO/022719/", type=str, help="dataset path") - parser.add_argument("--dataset", default="coco", type=str, help="dataset name. Use coco_kp for Keypoint detection") + parser.add_argument("--dataset", default="coco", type=str, help="dataset name. Use coco for object detection and instance segmentation and coco_kp for Keypoint detection") parser.add_argument("--model", default="maskrcnn_resnet50_fpn", type=str, help="model name") parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)") parser.add_argument( @@ -189,8 +189,6 @@ def main(args): if "mask" in args.model: args.with_masks = True - print(args.model) - if args.output_dir: utils.mkdir(args.output_dir) From d076b4c9ae52a309b04cc45e4cf24ae6119383d2 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 24 Jul 2023 14:13:04 +0000 Subject: [PATCH 4/5] Address comments --- references/detection/coco_utils.py | 6 +++--- references/detection/train.py | 26 +++++++++++++------------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/references/detection/coco_utils.py b/references/detection/coco_utils.py index 8b657e59434..0f7c1de84fe 100644 --- a/references/detection/coco_utils.py +++ b/references/detection/coco_utils.py @@ -201,7 +201,7 @@ def __getitem__(self, idx): return img, target -def get_coco(root, image_set, transforms, args, mode="instances"): +def get_coco(root, image_set, transforms, mode="instances", use_v2=False, with_masks=False): anno_file_template = "{}_{}2017.json" PATHS = { "train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))), @@ -213,10 +213,10 @@ def get_coco(root, image_set, transforms, args, mode="instances"): img_folder = os.path.join(root, img_folder) ann_file = os.path.join(root, ann_file) - if args.use_v2: + if use_v2: dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms) target_keys = ["boxes", "labels", "image_id"] - if args.with_masks: + if with_masks: target_keys += ["masks"] dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys) else: diff --git a/references/detection/train.py b/references/detection/train.py index c82bf55ab66..892ffbbbc1c 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -43,8 +43,14 @@ def copypaste_collate_fn(batch): def get_dataset(is_train, args): image_set = "train" if is_train else "val" num_classes, mode = {"coco": (91, "instances"), "coco_kp": (2, "person_keypoints")}[args.dataset] + with_masks = "mask" in args.model ds = get_coco( - root=args.data_path, image_set=image_set, transforms=get_transform(is_train, args), args=args, mode=mode + root=args.data_path, + image_set=image_set, + transforms=get_transform(is_train, args), + mode=mode, + use_v2=args.use_v2, + with_masks=with_masks, ) return ds, num_classes @@ -68,7 +74,12 @@ def get_args_parser(add_help=True): parser = argparse.ArgumentParser(description="PyTorch Detection Training", add_help=add_help) parser.add_argument("--data-path", default="/datasets01/COCO/022719/", type=str, help="dataset path") - parser.add_argument("--dataset", default="coco", type=str, help="dataset name. Use coco for object detection and instance segmentation and coco_kp for Keypoint detection") + parser.add_argument( + "--dataset", + default="coco", + type=str, + help="dataset name. Use coco for object detection and instance segmentation and coco_kp for Keypoint detection", + ) parser.add_argument("--model", default="maskrcnn_resnet50_fpn", type=str, help="model name") parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)") parser.add_argument( @@ -164,14 +175,6 @@ def get_args_parser(add_help=True): parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive") parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms") - parser.add_argument( - "--with-masks", - action="store_true", - help=( - "Whether the dataset should return masks. Only relevant when --use-v2 is passed. " - "True by default when using mask_rcnn." - ), - ) return parser @@ -186,9 +189,6 @@ def main(args): if args.dataset == "coco_kp" and args.use_v2: raise ValueError("KeyPoint detection doesn't support V2 transforms yet") - if "mask" in args.model: - args.with_masks = True - if args.output_dir: utils.mkdir(args.output_dir) From 044336a5d8f74fad77b259d1814cafdef656609d Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 27 Jul 2023 09:53:22 +0100 Subject: [PATCH 5/5] Fix flake8 --- references/detection/coco_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/references/detection/coco_utils.py b/references/detection/coco_utils.py index 0f7c1de84fe..7cf19d39dc9 100644 --- a/references/detection/coco_utils.py +++ b/references/detection/coco_utils.py @@ -1,4 +1,3 @@ -import copy import os import torch