From f08b8cc9be17ff18385cf89e8e7f7c0b0a7fa1fa Mon Sep 17 00:00:00 2001 From: brimoor Date: Wed, 2 Oct 2024 23:28:08 +0200 Subject: [PATCH] fix COCO category IDs --- .../user_guide/dataset_creation/datasets.rst | 5 +- docs/source/user_guide/export_datasets.rst | 5 +- fiftyone/utils/coco.py | 34 +++++------ tests/unittests/import_export_tests.py | 59 +++++++++++++++++++ 4 files changed, 78 insertions(+), 25 deletions(-) diff --git a/docs/source/user_guide/dataset_creation/datasets.rst b/docs/source/user_guide/dataset_creation/datasets.rst index c25550a1911..ac4085bb3be 100644 --- a/docs/source/user_guide/dataset_creation/datasets.rst +++ b/docs/source/user_guide/dataset_creation/datasets.rst @@ -1499,9 +1499,8 @@ where `labels.json` is a JSON file in the following format: ... ], "categories": [ - ... { - "id": 2, + "id": 1, "name": "cat", "supercategory": "animal", "keypoints": ["nose", "head", ...], @@ -1524,7 +1523,7 @@ where `labels.json` is a JSON file in the following format: { "id": 1, "image_id": 1, - "category_id": 2, + "category_id": 1, "bbox": [260, 177, 231, 199], "segmentation": [...], "keypoints": [224, 226, 2, ...], diff --git a/docs/source/user_guide/export_datasets.rst b/docs/source/user_guide/export_datasets.rst index 293672544a2..810601036b4 100644 --- a/docs/source/user_guide/export_datasets.rst +++ b/docs/source/user_guide/export_datasets.rst @@ -1646,9 +1646,8 @@ where `labels.json` is a JSON file in the following format: }, "licenses": [], "categories": [ - ... { - "id": 2, + "id": 1, "name": "cat", "supercategory": "animal" }, @@ -1669,7 +1668,7 @@ where `labels.json` is a JSON file in the following format: { "id": 1, "image_id": 1, - "category_id": 2, + "category_id": 1, "bbox": [260, 177, 231, 199], "segmentation": [...], "score": 0.95, diff --git a/fiftyone/utils/coco.py b/fiftyone/utils/coco.py index 76a4fd494b0..ed6d2f6ae8e 100644 --- a/fiftyone/utils/coco.py +++ b/fiftyone/utils/coco.py @@ -563,15 +563,11 @@ def setup(self): self.labels_path, extra_attrs=self.extra_attrs ) - classes = None if classes_map is not None: - classes = _to_classes(classes_map) - - if classes is not None: - info["classes"] = classes + info["classes"] = _to_classes(classes_map) image_ids = _get_matching_image_ids( - classes, + classes_map, images, annotations, image_ids=self.image_ids, @@ -907,12 +903,11 @@ def export_sample(self, image_or_path, label, metadata=None): def close(self, *args): if self._dynamic_classes: - classes = sorted(self._classes) - labels_map_rev = _to_labels_map_rev(classes) + labels_map_rev = _to_labels_map_rev(sorted(self._classes)) for anno in self._annotations: anno["category_id"] = labels_map_rev[anno["category_id"]] - else: - classes = self.classes + elif self.categories is None: + labels_map_rev = _to_labels_map_rev(self.classes) _info = self.info or {} _date_created = datetime.now().replace(microsecond=0).isoformat() @@ -933,10 +928,10 @@ def close(self, *args): categories = [ { "id": i, - "name": l, + "name": c, "supercategory": None, } - for i, l in enumerate(classes) + for c, i in sorted(labels_map_rev.items(), key=lambda t: t[1]) ] labels = { @@ -1681,7 +1676,7 @@ def download_coco_dataset_split( if classes is not None: # Filter by specified classes all_ids, any_ids = _get_images_with_classes( - image_ids, annotations, classes, all_classes + image_ids, annotations, classes, all_classes_map ) else: all_ids = image_ids @@ -1846,7 +1841,7 @@ def _parse_include_license(include_license): def _get_matching_image_ids( - all_classes, + classes_map, images, annotations, image_ids=None, @@ -1862,7 +1857,7 @@ def _get_matching_image_ids( if classes is not None: all_ids, any_ids = _get_images_with_classes( - image_ids, annotations, classes, all_classes + image_ids, annotations, classes, classes_map ) else: all_ids = image_ids @@ -1930,7 +1925,7 @@ def _do_download(args): def _get_images_with_classes( - image_ids, annotations, target_classes, all_classes + image_ids, annotations, target_classes, classes_map ): if annotations is None: logger.warning("Dataset is unlabeled; ignoring classes requirement") @@ -1939,11 +1934,12 @@ def _get_images_with_classes( if etau.is_str(target_classes): target_classes = [target_classes] - bad_classes = [c for c in target_classes if c not in all_classes] + labels_map_rev = {c: i for i, c in classes_map.items()} + + bad_classes = [c for c in target_classes if c not in labels_map_rev] if bad_classes: raise ValueError("Unsupported classes: %s" % bad_classes) - labels_map_rev = _to_labels_map_rev(all_classes) class_ids = {labels_map_rev[c] for c in target_classes} all_ids = [] @@ -2029,7 +2025,7 @@ def _load_image_ids_json(json_path): def _to_labels_map_rev(classes): - return {c: i for i, c in enumerate(classes)} + return {c: i for i, c in enumerate(classes, 1)} def _to_classes(classes_map): diff --git a/tests/unittests/import_export_tests.py b/tests/unittests/import_export_tests.py index 54798733f58..f3e27dcb82c 100644 --- a/tests/unittests/import_export_tests.py +++ b/tests/unittests/import_export_tests.py @@ -1317,6 +1317,65 @@ def test_coco_detection_dataset(self): {c["id"] for c in categories2}, ) + # Alphabetized 1-based categories by default + + export_dir = self._new_dir() + + dataset.export( + export_dir=export_dir, + dataset_type=fo.types.COCODetectionDataset, + ) + + dataset2 = fo.Dataset.from_dir( + dataset_dir=export_dir, + dataset_type=fo.types.COCODetectionDataset, + label_types="detections", + label_field="predictions", + ) + categories2 = dataset2.info["categories"] + + self.assertListEqual([c["id"] for c in categories2], [1, 2]) + self.assertListEqual([c["name"] for c in categories2], ["cat", "dog"]) + + # Only load matching classes + + export_dir = self._new_dir() + + dataset.export( + export_dir=export_dir, + dataset_type=fo.types.COCODetectionDataset, + ) + + dataset2 = fo.Dataset.from_dir( + dataset_dir=export_dir, + dataset_type=fo.types.COCODetectionDataset, + label_types="detections", + label_field="predictions", + classes="cat", + only_matching=False, + ) + + self.assertEqual(len(dataset2), 2) + self.assertListEqual( + dataset2.distinct("predictions.detections.label"), + ["cat", "dog"], + ) + + dataset3 = fo.Dataset.from_dir( + dataset_dir=export_dir, + dataset_type=fo.types.COCODetectionDataset, + label_types="detections", + label_field="predictions", + classes="cat", + only_matching=True, + ) + + self.assertEqual(len(dataset3), 2) + self.assertListEqual( + dataset3.distinct("predictions.detections.label"), + ["cat"], + ) + @drop_datasets def test_voc_detection_dataset(self): dataset = self._make_dataset()