
Commit

fix COCO category IDs
brimoor committed Oct 3, 2024
1 parent 739d6b1 commit f08b8cc
Showing 4 changed files with 78 additions and 25 deletions.
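At a glance, this change makes exported COCO category IDs 1-based and alphabetized, matching the examples updated in the docs below. A minimal sketch of the mapping the updated `_to_labels_map_rev` helper produces and how the exporter turns it into a `categories` list (the class names here are illustrative):

    # Illustrative class list; dynamic classes are sorted alphabetically on export
    classes = ["cat", "dog"]

    # Before this commit: enumerate(classes) gave 0-based IDs ({"cat": 0, "dog": 1})
    # After: enumerate(classes, 1) gives 1-based IDs ({"cat": 1, "dog": 2})
    labels_map_rev = {c: i for i, c in enumerate(classes, 1)}

    # The exporter emits one category entry per class, sorted by ID
    categories = [
        {"id": i, "name": c, "supercategory": None}
        for c, i in sorted(labels_map_rev.items(), key=lambda t: t[1])
    ]
    # -> [{"id": 1, "name": "cat", ...}, {"id": 2, "name": "dog", ...}]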
5 changes: 2 additions & 3 deletions docs/source/user_guide/dataset_creation/datasets.rst
@@ -1499,9 +1499,8 @@ where `labels.json` is a JSON file in the following format:
         ...
     ],
     "categories": [
-        ...
         {
-            "id": 2,
+            "id": 1,
             "name": "cat",
             "supercategory": "animal",
             "keypoints": ["nose", "head", ...],
@@ -1524,7 +1523,7 @@ where `labels.json` is a JSON file in the following format:
         {
             "id": 1,
             "image_id": 1,
-            "category_id": 2,
+            "category_id": 1,
             "bbox": [260, 177, 231, 199],
             "segmentation": [...],
             "keypoints": [224, 226, 2, ...],
5 changes: 2 additions & 3 deletions docs/source/user_guide/export_datasets.rst
@@ -1646,9 +1646,8 @@ where `labels.json` is a JSON file in the following format:
     },
     "licenses": [],
     "categories": [
-        ...
         {
-            "id": 2,
+            "id": 1,
             "name": "cat",
             "supercategory": "animal"
         },
@@ -1669,7 +1668,7 @@ where `labels.json` is a JSON file in the following format:
         {
             "id": 1,
             "image_id": 1,
-            "category_id": 2,
+            "category_id": 1,
             "bbox": [260, 177, 231, 199],
             "segmentation": [...],
             "score": 0.95,
34 changes: 15 additions & 19 deletions fiftyone/utils/coco.py
@@ -563,15 +563,11 @@ def setup(self):
             self.labels_path, extra_attrs=self.extra_attrs
         )
 
-        classes = None
         if classes_map is not None:
-            classes = _to_classes(classes_map)
-
-        if classes is not None:
-            info["classes"] = classes
+            info["classes"] = _to_classes(classes_map)
 
         image_ids = _get_matching_image_ids(
-            classes,
+            classes_map,
             images,
             annotations,
             image_ids=self.image_ids,
@@ -907,12 +903,11 @@ def export_sample(self, image_or_path, label, metadata=None):
 
     def close(self, *args):
         if self._dynamic_classes:
-            classes = sorted(self._classes)
-            labels_map_rev = _to_labels_map_rev(classes)
+            labels_map_rev = _to_labels_map_rev(sorted(self._classes))
             for anno in self._annotations:
                 anno["category_id"] = labels_map_rev[anno["category_id"]]
-        else:
-            classes = self.classes
+        elif self.categories is None:
+            labels_map_rev = _to_labels_map_rev(self.classes)
 
         _info = self.info or {}
         _date_created = datetime.now().replace(microsecond=0).isoformat()
@@ -933,10 +928,10 @@ def close(self, *args):
         categories = [
             {
                 "id": i,
-                "name": l,
+                "name": c,
                 "supercategory": None,
             }
-            for i, l in enumerate(classes)
+            for c, i in sorted(labels_map_rev.items(), key=lambda t: t[1])
         ]
 
         labels = {
@@ -1681,7 +1676,7 @@ def download_coco_dataset_split(
     if classes is not None:
         # Filter by specified classes
         all_ids, any_ids = _get_images_with_classes(
-            image_ids, annotations, classes, all_classes
+            image_ids, annotations, classes, all_classes_map
        )
     else:
         all_ids = image_ids
@@ -1846,7 +1841,7 @@ def _parse_include_license(include_license):
 
 
 def _get_matching_image_ids(
-    all_classes,
+    classes_map,
     images,
     annotations,
     image_ids=None,
@@ -1862,7 +1857,7 @@ def _get_matching_image_ids(
 
     if classes is not None:
         all_ids, any_ids = _get_images_with_classes(
-            image_ids, annotations, classes, all_classes
+            image_ids, annotations, classes, classes_map
        )
     else:
         all_ids = image_ids
@@ -1930,7 +1925,7 @@ def _do_download(args):
 
 
 def _get_images_with_classes(
-    image_ids, annotations, target_classes, all_classes
+    image_ids, annotations, target_classes, classes_map
 ):
     if annotations is None:
         logger.warning("Dataset is unlabeled; ignoring classes requirement")
@@ -1939,11 +1934,12 @@ def _get_images_with_classes(
     if etau.is_str(target_classes):
         target_classes = [target_classes]
 
-    bad_classes = [c for c in target_classes if c not in all_classes]
+    labels_map_rev = {c: i for i, c in classes_map.items()}
+
+    bad_classes = [c for c in target_classes if c not in labels_map_rev]
     if bad_classes:
         raise ValueError("Unsupported classes: %s" % bad_classes)
 
-    labels_map_rev = _to_labels_map_rev(all_classes)
     class_ids = {labels_map_rev[c] for c in target_classes}
 
     all_ids = []
@@ -2029,7 +2025,7 @@ def _load_image_ids_json(json_path):
 
 
 def _to_labels_map_rev(classes):
-    return {c: i for i, c in enumerate(classes)}
+    return {c: i for i, c in enumerate(classes, 1)}
 
 
 def _to_classes(classes_map):
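On the import side, class filtering now validates requested names against the `classes_map` parsed from `labels.json` rather than a 0-based class list, so the requested classes resolve to the file's own category IDs. A rough, self-contained sketch of that resolution (the map values are hypothetical):

    # classes_map maps category ID -> class name, as parsed from labels.json
    # (hypothetical values for illustration)
    classes_map = {1: "cat", 2: "dog"}
    target_classes = ["cat"]

    # Invert to name -> ID and validate the requested classes against it
    labels_map_rev = {c: i for i, c in classes_map.items()}
    bad_classes = [c for c in target_classes if c not in labels_map_rev]
    if bad_classes:
        raise ValueError("Unsupported classes: %s" % bad_classes)

    # IDs used to select the matching annotations/images
    class_ids = {labels_map_rev[c] for c in target_classes}  # {1}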
59 changes: 59 additions & 0 deletions tests/unittests/import_export_tests.py
@@ -1317,6 +1317,65 @@ def test_coco_detection_dataset(self):
             {c["id"] for c in categories2},
         )
 
+        # Alphabetized 1-based categories by default
+
+        export_dir = self._new_dir()
+
+        dataset.export(
+            export_dir=export_dir,
+            dataset_type=fo.types.COCODetectionDataset,
+        )
+
+        dataset2 = fo.Dataset.from_dir(
+            dataset_dir=export_dir,
+            dataset_type=fo.types.COCODetectionDataset,
+            label_types="detections",
+            label_field="predictions",
+        )
+        categories2 = dataset2.info["categories"]
+
+        self.assertListEqual([c["id"] for c in categories2], [1, 2])
+        self.assertListEqual([c["name"] for c in categories2], ["cat", "dog"])
+
+        # Only load matching classes
+
+        export_dir = self._new_dir()
+
+        dataset.export(
+            export_dir=export_dir,
+            dataset_type=fo.types.COCODetectionDataset,
+        )
+
+        dataset2 = fo.Dataset.from_dir(
+            dataset_dir=export_dir,
+            dataset_type=fo.types.COCODetectionDataset,
+            label_types="detections",
+            label_field="predictions",
+            classes="cat",
+            only_matching=False,
+        )
+
+        self.assertEqual(len(dataset2), 2)
+        self.assertListEqual(
+            dataset2.distinct("predictions.detections.label"),
+            ["cat", "dog"],
+        )
+
+        dataset3 = fo.Dataset.from_dir(
+            dataset_dir=export_dir,
+            dataset_type=fo.types.COCODetectionDataset,
+            label_types="detections",
+            label_field="predictions",
+            classes="cat",
+            only_matching=True,
+        )
+
+        self.assertEqual(len(dataset3), 2)
+        self.assertListEqual(
+            dataset3.distinct("predictions.detections.label"),
+            ["cat"],
+        )
+
     @drop_datasets
     def test_voc_detection_dataset(self):
         dataset = self._make_dataset()
