Skip to content

Commit

Permalink
split unlabeled data into subsets for task-specific splitters (#211)
Browse files Browse the repository at this point in the history
* Split unlabeled data into subsets for classification and detection; for re-id, such data is placed into a 'not-supported' subset
  • Loading branch information
Jihyeon Yi authored Apr 9, 2021
1 parent 7b42340 commit 8b4a997
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 68 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
- LabelMe format saves dataset items with their relative paths by subsets without changing names (<https://github.com/openvinotoolkit/datumaro/pull/200>)
- Allowed arbitrary subset count and names in classification and detection splitters (<https://github.com/openvinotoolkit/datumaro/pull/207>)
- Annotation-less dataset elements now participate in subset splitting (<https://github.com/openvinotoolkit/datumaro/pull/211>)

### Deprecated
-
Expand Down
103 changes: 81 additions & 22 deletions datumaro/plugins/splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ def __init__(self, dataset, splits, seed, restrict=False):
self._seed = seed

# remove subset name restriction
# regarding https://github.com/openvinotoolkit/datumaro/issues/194
# self._subsets = {"train", "val", "test"} # output subset names
# https://github.com/openvinotoolkit/datumaro/issues/194
self._subsets = subsets
self._parts = []
self._length = "parent"
Expand All @@ -65,24 +64,27 @@ def _set_parts(self, by_splits):
@staticmethod
def _get_uniq_annotations(dataset):
    """Collect the unique Label annotation of each dataset item.

    Args:
        dataset: iterable of dataset items.
    Returns:
        A tuple ``(annotations, unlabeled_or_multi)`` where ``annotations``
        holds the single Label of each singly-labeled item, and
        ``unlabeled_or_multi`` holds the indices of items that carry zero
        or more than one Label (left for the caller to handle separately).
    """
    annotations = []
    unlabeled_or_multi = []

    for idx, item in enumerate(dataset):
        labels = [a for a in item.annotations
            if a.type == AnnotationType.label]
        if len(labels) == 1:
            annotations.append(labels[0])
        else:
            # no longer an error: remember the index instead of raising
            unlabeled_or_multi.append(idx)

    return annotations, unlabeled_or_multi

@staticmethod
def _validate_splits(splits, restrict=False):
snames = []
ratios = []
subsets = set()
valid = ["train", "val", "test"]
# remove subset name restriction
# regarding https://github.com/openvinotoolkit/datumaro/issues/194
for subset, ratio in splits:
# remove subset name restriction
# https://github.com/openvinotoolkit/datumaro/issues/194
if restrict:
assert subset in valid, \
"Subset name must be one of %s, got %s" % (valid, subset)
Expand Down Expand Up @@ -143,7 +145,7 @@ def _get_sections(dataset_size, ratio):
n_splits[ii] += 1
n_splits[midx] -= 1
sections = np.add.accumulate(n_splits[:-1])
return sections
return sections, n_splits

@staticmethod
def _group_by_attr(items):
Expand Down Expand Up @@ -187,7 +189,7 @@ def _split_by_attr(self, datasets, snames, ratio, out_splits,
merge_small_classes=True):

def _split_indice(indice):
sections = self._get_sections(len(indice), ratio)
sections, _ = self._get_sections(len(indice), ratio)
splits = np.array_split(indice, sections)
for subset, split in zip(snames, splits):
if 0 < len(split):
Expand Down Expand Up @@ -223,6 +225,26 @@ def _split_indice(indice):
if len(rest) > 0:
_split_indice(rest)

def _split_unlabeled(self, unlabeled, by_splits):
"""
split unlabeled data into subsets (detection, classification)
Args:
unlabeled: list of index of unlabeled or multi-labeled data
by_splits: splits up to now
Returns:
by_splits: final splits
"""
dataset_size = len(self._extractor)
_, n_splits = list(self._get_sections(dataset_size, self._sratio))
counts = [len(by_splits[sname]) for sname in self._snames]
expected = [max(0, v) for v in np.subtract(n_splits, counts)]
sections = np.add.accumulate(expected[:-1])
np.random.shuffle(unlabeled)
splits = np.array_split(unlabeled, sections)
for subset, split in zip(self._snames, splits):
if 0 < len(split):
by_splits[subset].extend(split)

def _find_split(self, index):
for subset_indices, subset in self._parts:
if index in subset_indices:
Expand All @@ -248,7 +270,8 @@ class ClassificationSplit(_TaskSpecificSplit):
distribution.|n
|n
Notes:|n
- Each image is expected to have only one Label|n
- Each image is expected to have only one Label. Unlabeled or
multi-labeled images will be split into subsets randomly. |n
- If Labels also have attributes, also splits by attribute values.|n
    - If there are not enough images in some class or attribute group,
    the split ratio can't be guaranteed.|n
Expand All @@ -274,7 +297,7 @@ def _split_dataset(self):
# support only single label for a DatasetItem
# 1. group by label
by_labels = dict()
annotations = self._get_uniq_annotations(self._extractor)
annotations, unlabeled = self._get_uniq_annotations(self._extractor)

for idx, ann in enumerate(annotations):
label = getattr(ann, 'label', None)
Expand All @@ -288,6 +311,12 @@ def _split_dataset(self):

# 2. group by attributes
self._split_by_attr(by_labels, self._snames, self._sratio, by_splits)

# 3. split unlabeled data
if len(unlabeled) > 0:
self._split_unlabeled(unlabeled, by_splits)

# 4. set parts
self._set_parts(by_splits)


Expand All @@ -310,7 +339,8 @@ class ReidentificationSplit(_TaskSpecificSplit):
'train', 'val', 'test-gallery' and 'test-query'. |n
|n
Notes:|n
- Each image is expected to have a single Label|n
- Each image is expected to have a single Label. Unlabeled or multi-labeled
images will be split into 'not-supported'.|n
- Object ID can be described by Label, or by attribute (--attr parameter)|n
- The splits of the test set are controlled by '--query' parameter. |n
|s|sGallery ratio would be 1.0 - query.|n
Expand Down Expand Up @@ -377,7 +407,7 @@ def _split_dataset(self):

# group by ID(attr_for_id)
by_id = dict()
annotations = self._get_uniq_annotations(dataset)
annotations, unlabeled = self._get_uniq_annotations(dataset)
if attr_for_id is None: # use label
for idx, ann in enumerate(annotations):
ID = getattr(ann, 'label', None)
Expand Down Expand Up @@ -408,7 +438,7 @@ def _split_dataset(self):
split_ratio = np.array([test, 1.0 - test])
IDs = list(by_id.keys())
np.random.shuffle(IDs)
sections = self._get_sections(len(IDs), split_ratio)
sections, _ = self._get_sections(len(IDs), split_ratio)
splits = np.array_split(IDs, sections)
testset = {pid: by_id[pid] for pid in splits[0]}
trval = {pid: by_id[pid] for pid in splits[1]}
Expand Down Expand Up @@ -458,6 +488,11 @@ def _split_dataset(self):
self._split_by_attr(trval, trval_snames, trval_ratio, by_splits,
merge_small_classes=False)

# split unlabeled data into 'not-supported'.
if len(unlabeled) > 0:
self._subsets.add("not-supported")
by_splits["not-supported"] = unlabeled

self._set_parts(by_splits)

@staticmethod
Expand Down Expand Up @@ -506,6 +541,20 @@ def _rebalancing(test, trval, expected_count, testset_total):
test[id_trval] = trval.pop(id_trval)
trval[id_test] = test.pop(id_test)

def get_subset(self, name):
    """Return the named subset, performing the split lazily on first access."""
    if not self._initialized:  # idiomatic truthiness check instead of 'is False'
        self._split_dataset()
        self._initialized = True
    return super().get_subset(name)

def subsets(self):
    """Return all subsets, performing the split lazily on first access."""
    if not self._initialized:  # idiomatic truthiness check instead of 'is False'
        self._split_dataset()
        self._initialized = True
    return super().subsets()


class DetectionSplit(_TaskSpecificSplit):
"""
Expand Down Expand Up @@ -545,26 +594,28 @@ def __init__(self, dataset, splits, seed=None):
@staticmethod
def _group_by_bbox_labels(dataset):
    """Group bbox annotations of a dataset by their label.

    Args:
        dataset: iterable of dataset items.
    Returns:
        A tuple ``(by_labels, unlabeled)`` where ``by_labels`` maps a bbox
        label to a list of ``(item index, annotation)`` pairs, and
        ``unlabeled`` lists the indices of items without bbox annotations.
    """
    by_labels = dict()
    unlabeled = []
    for idx, item in enumerate(dataset):
        bbox_anns = [a for a in item.annotations
            if a.type == AnnotationType.bbox]
        if len(bbox_anns) == 0:
            # no longer an error: remember the index instead of asserting
            unlabeled.append(idx)
            continue
        for ann in bbox_anns:
            label = getattr(ann, 'label', None)
            if label not in by_labels:
                by_labels[label] = [(idx, ann)]
            else:
                by_labels[label].append((idx, ann))
    return by_labels, unlabeled

def _split_dataset(self):
np.random.seed(self._seed)

subsets, sratio = self._snames, self._sratio

# 1. group by bbox label
by_labels = self._group_by_bbox_labels(self._extractor)
by_labels, unlabeled = self._group_by_bbox_labels(self._extractor)

# 2. group by attributes
required = self._get_required(sratio)
Expand Down Expand Up @@ -595,7 +646,11 @@ def _split_dataset(self):
n_combs = [len(v) for v in by_combinations]

# 3-1. initially count per-image GT samples
counts_all = {idx: dict() for idx in range(total)}
counts_all = {}
for idx in range(total):
if idx not in unlabeled:
counts_all[idx] = dict()

for idx_comb, indice in enumerate(by_combinations):
for idx in indice:
if idx_comb not in counts_all[idx]:
Expand Down Expand Up @@ -668,4 +723,8 @@ def update_nc(counts, n_combs):
by_splits[sname].append(idx)
update_nc(counts, nc)

# split unlabeled data
if len(unlabeled) > 0:
self._split_unlabeled(unlabeled, by_splits)

self._set_parts(by_splits)
91 changes: 45 additions & 46 deletions tests/test_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,29 +236,27 @@ def test_split_for_classification_zero_ratio(self):
self.assertEqual(4, len(actual.get_subset("val")))
self.assertEqual(0, len(actual.get_subset("test")))

def test_split_for_classification_gives_error(self):
def test_split_for_classification_unlabeled(self):
    # unlabeled and multi-labeled items are split randomly by ratio
    # instead of raising an error
    with self.subTest("no label"):
        iterable = [DatasetItem(i, annotations=[]) for i in range(10)]
        source = Dataset.from_iterable(iterable, categories=["a", "b"])
        splits = [("train", 0.7), ("test", 0.3)]
        actual = splitter.ClassificationSplit(source, splits)

        self.assertEqual(7, len(actual.get_subset("train")))
        self.assertEqual(3, len(actual.get_subset("test")))

    with self.subTest("multi label"):
        anns = [Label(0), Label(1)]
        iterable = [DatasetItem(i, annotations=anns) for i in range(10)]
        source = Dataset.from_iterable(iterable, categories=["a", "b"])
        splits = [("train", 0.7), ("test", 0.3)]
        actual = splitter.ClassificationSplit(source, splits)

        self.assertEqual(7, len(actual.get_subset("train")))
        self.assertEqual(3, len(actual.get_subset("test")))

def test_split_for_classification_gives_error(self):
source = Dataset.from_iterable([
DatasetItem(1, annotations=[Label(0)]),
DatasetItem(2, annotations=[Label(1)]),
Expand Down Expand Up @@ -396,30 +394,27 @@ def test_split_for_reidentification_rebalance(self):
self.assertEqual(90, len(actual.get_subset("test-gallery")))
self.assertEqual(120, len(actual.get_subset("test-query")))

def test_split_for_reidentification_unlabeled(self):
    # re-id cannot place unlabeled items by ratio; they go into
    # a dedicated 'not-supported' subset
    query = 0.5

    with self.subTest("no label"):
        iterable = [DatasetItem(i, annotations=[]) for i in range(10)]
        source = Dataset.from_iterable(iterable, categories=["a", "b"])
        splits = [("train", 0.6), ("test", 0.4)]
        actual = splitter.ReidentificationSplit(source, splits, query)
        self.assertEqual(10, len(actual.get_subset("not-supported")))

    with self.subTest("multi label"):
        anns = [Label(0), Label(1)]
        iterable = [DatasetItem(i, annotations=anns) for i in range(10)]
        source = Dataset.from_iterable(iterable, categories=["a", "b"])
        splits = [("train", 0.6), ("test", 0.4)]
        actual = splitter.ReidentificationSplit(source, splits, query)

        self.assertEqual(10, len(actual.get_subset("not-supported")))
def test_split_for_reidentification_gives_error(self):
query = 0.4 / 0.7 # valid query ratio

counts = {i: (i % 3 + 1) * 7 for i in range(10)}
config = {"person": {"attrs": ["PID"], "counts": counts}}
Expand Down Expand Up @@ -638,18 +633,22 @@ def test_split_for_detection(self):
list(r1.get_subset("test")), list(r3.get_subset("test"))
)

def test_split_for_detection_with_unlabeled(self):
    # items without bbox annotations are split randomly by ratio
    # instead of raising an error
    source, _ = self._generate_detection_dataset(
        append_bbox=self._get_append_bbox("cvat"),
        with_attr=True,
        nimages=10,
    )
    for i in range(10):
        source.put(DatasetItem(i + 10, annotations={}))

    splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)]
    actual = splitter.DetectionSplit(source, splits)
    self.assertEqual(10, len(actual.get_subset("train")))
    self.assertEqual(4, len(actual.get_subset("val")))
    self.assertEqual(6, len(actual.get_subset("test")))

def test_split_for_detection_gives_error(self):
source, _ = self._generate_detection_dataset(
append_bbox=self._get_append_bbox("cvat"),
with_attr=True,
Expand Down

0 comments on commit 8b4a997

Please sign in to comment.