split unlabeled data into subsets for task-specific splitters #211

Merged · 3 commits · Apr 9, 2021

Changes from 1 commit
98 changes: 80 additions & 18 deletions datumaro/plugins/splitter.py
@@ -65,14 +65,19 @@ def _set_parts(self, by_splits):
@staticmethod
def _get_uniq_annotations(dataset):
annotations = []
for item in dataset:
unlabeled_or_multi = []

for idx, item in enumerate(dataset):
labels = [a for a in item.annotations
if a.type == AnnotationType.label]
if len(labels) != 1:
raise Exception("Item '%s' contains %s labels, "
"but exactly one is expected" % (item.id, len(labels)))
annotations.append(labels[0])
return annotations
if len(labels) == 1:
annotations.append(labels[0])
else:
unlabeled_or_multi.append(idx)
# raise Exception("Item '%s' contains %s labels, "
# "but exactly one is expected" % (item.id, len(labels)))

return annotations, unlabeled_or_multi

@staticmethod
def _validate_splits(splits, restrict=False):
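A minimal standalone sketch (toy lists instead of Datumaro objects) of the new _get_uniq_annotations contract: items with exactly one label contribute their annotation, while the indices of unlabeled or multi-labeled items are collected for separate handling instead of raising.

# Toy stand-in for a dataset: each entry is one item's list of labels.
items = [["a"], [], ["a", "b"], ["c"]]

annotations, unlabeled_or_multi = [], []
for idx, labels in enumerate(items):
    if len(labels) == 1:
        annotations.append(labels[0])
    else:
        # previously this raised; now the index is kept for later routing
        unlabeled_or_multi.append(idx)

print(annotations)         # ['a', 'c']
print(unlabeled_or_multi)  # [1, 2]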
@@ -143,7 +148,7 @@ def _get_sections(dataset_size, ratio):
n_splits[ii] += 1
n_splits[midx] -= 1
sections = np.add.accumulate(n_splits[:-1])
return sections
return sections, n_splits

@staticmethod
def _group_by_attr(items):
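For context, _get_sections converts a split ratio into cumulative boundaries for np.array_split; the change above also surfaces the per-split counts so _split_unlabeled can reuse them. A rough illustration, not the exact library code (which additionally redistributes rounding remainders so the counts sum to the dataset size):

import numpy as np

dataset_size = 10
ratio = np.array([0.7, 0.3])
n_splits = np.rint(dataset_size * ratio).astype(int)  # [7, 3]
sections = np.add.accumulate(n_splits[:-1])           # [7]

splits = np.array_split(np.arange(dataset_size), sections)
print([len(s) for s in splits])                       # [7, 3]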
@@ -187,7 +192,7 @@ def _split_by_attr(self, datasets, snames, ratio, out_splits,
merge_small_classes=True):

def _split_indice(indice):
sections = self._get_sections(len(indice), ratio)
sections, _ = self._get_sections(len(indice), ratio)
splits = np.array_split(indice, sections)
for subset, split in zip(snames, splits):
if 0 < len(split):
@@ -223,6 +228,26 @@ def _split_indice(indice):
if len(rest) > 0:
_split_indice(rest)

def _split_unlabeled(self, unlabeled, by_splits):
"""
split unlabeled data into subsets (detection, classification)
Args:
unlabeled: list of index of unlabeled or multi-labeled data
by_splits: splits up to now
Returns:
by_splits: final splits
"""
dataset_size = len(self._extractor)
_, n_splits = self._get_sections(dataset_size, self._sratio)
counts = [len(by_splits[sname]) for sname in self._snames]
expected = [max(0, v) for v in np.subtract(n_splits, counts)]
sections = np.add.accumulate(expected[:-1])
np.random.shuffle(unlabeled)
splits = np.array_split(unlabeled, sections)
for subset, split in zip(self._snames, splits):
if 0 < len(split):
by_splits[subset].extend(split)
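The arithmetic behind the method, as a small worked example with hypothetical numbers: the ratio-implied targets minus the labeled counts give each subset's deficit, and np.array_split distributes the shuffled unlabeled indices accordingly.

import numpy as np

n_splits = np.array([7, 3])  # ratio-implied targets for ('train', 'test')
counts = [5, 2]              # labeled items already assigned per subset
expected = [max(0, v) for v in np.subtract(n_splits, counts)]  # [2, 1]
sections = np.add.accumulate(expected[:-1])                    # [2]

unlabeled = np.array([0, 4, 9])  # indices of unlabeled items, pre-shuffled
splits = np.array_split(unlabeled, sections)
print([list(s) for s in splits])  # [[0, 4], [9]]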

def _find_split(self, index):
for subset_indices, subset in self._parts:
if index in subset_indices:
@@ -248,7 +273,8 @@ class ClassificationSplit(_TaskSpecificSplit):
distribution.|n
|n
Notes:|n
- Each image is expected to have only one Label|n
- Each image is expected to have only one Label. Unlabeled or
multi-labeled images will be split into subsets randomly. |n
- If Labels also have attributes, also splits by attribute values.|n
- If there are not enough images in some class or attribute group,
the split ratio can't be guaranteed.|n
@@ -274,7 +300,7 @@ def _split_dataset(self):
# support only single label for a DatasetItem
# 1. group by label
by_labels = dict()
annotations = self._get_uniq_annotations(self._extractor)
annotations, unlabeled = self._get_uniq_annotations(self._extractor)

for idx, ann in enumerate(annotations):
label = getattr(ann, 'label', None)
@@ -288,6 +314,12 @@

# 2. group by attributes
self._split_by_attr(by_labels, self._snames, self._sratio, by_splits)

# 3. split unlabeled data
if len(unlabeled) > 0:
self._split_unlabeled(unlabeled, by_splits)

# 4. set parts
self._set_parts(by_splits)
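A usage sketch mirroring the updated test further down: with this change, a fully unlabeled dataset is split by the requested ratio instead of raising (import paths are assumed from the era of this PR).

import datumaro.plugins.splitter as splitter
from datumaro.components.dataset import Dataset        # import path assumed
from datumaro.components.extractor import DatasetItem  # import path assumed

source = Dataset.from_iterable(
    [DatasetItem(i, annotations=[]) for i in range(10)],
    categories=["a", "b"])
actual = splitter.ClassificationSplit(source, [("train", 0.7), ("test", 0.3)])
assert len(actual.get_subset("train")) == 7
assert len(actual.get_subset("test")) == 3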


@@ -310,7 +342,8 @@ class ReidentificationSplit(_TaskSpecificSplit):
'train', 'val', 'test-gallery' and 'test-query'. |n
|n
Notes:|n
- Each image is expected to have a single Label|n
- Each image is expected to have a single Label. Unlabeled or multi-labeled
images will be put into a separate 'not-supported' subset.|n
- Object ID can be described by Label, or by attribute (--attr parameter)|n
- The splits of the test set are controlled by the '--query' parameter. |n
|s|sGallery ratio would be 1.0 - query.|n
@@ -377,7 +410,7 @@ def _split_dataset(self):

# group by ID(attr_for_id)
by_id = dict()
annotations = self._get_uniq_annotations(dataset)
annotations, unlabeled = self._get_uniq_annotations(dataset)
if attr_for_id is None: # use label
for idx, ann in enumerate(annotations):
ID = getattr(ann, 'label', None)
@@ -408,7 +441,7 @@ def _split_dataset(self):
split_ratio = np.array([test, 1.0 - test])
IDs = list(by_id.keys())
np.random.shuffle(IDs)
sections = self._get_sections(len(IDs), split_ratio)
sections, _ = self._get_sections(len(IDs), split_ratio)
splits = np.array_split(IDs, sections)
testset = {pid: by_id[pid] for pid in splits[0]}
trval = {pid: by_id[pid] for pid in splits[1]}
@@ -458,6 +491,11 @@ def _split_dataset(self):
self._split_by_attr(trval, trval_snames, trval_ratio, by_splits,
merge_small_classes=False)

# split unlabeled data into 'not-supported'.
if len(unlabeled) > 0:
self._subsets.add("not-supported")
by_splits["not-supported"] = unlabeled

self._set_parts(by_splits)
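As exercised by the updated re-identification test below, items without exactly one Label now land in a dedicated 'not-supported' subset rather than raising; a sketch, with the same assumed imports as the classification example:

import datumaro.plugins.splitter as splitter
from datumaro.components.dataset import Dataset        # import path assumed
from datumaro.components.extractor import DatasetItem  # import path assumed

source = Dataset.from_iterable(
    [DatasetItem(i, annotations=[]) for i in range(10)],
    categories=["a", "b"])
actual = splitter.ReidentificationSplit(
    source, [("train", 0.6), ("test", 0.4)], 0.5)  # 0.5 = query ratio
assert len(actual.get_subset("not-supported")) == 10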

@staticmethod
@@ -506,6 +544,20 @@ def _rebalancing(test, trval, expected_count, testset_total):
test[id_trval] = trval.pop(id_trval)
trval[id_test] = test.pop(id_test)

def get_subset(self, name):
# lazy splitting
if self._initialized is False:
self._split_dataset()
self._initialized = True
return super().get_subset(name)

def subsets(self):
# lazy splitting
if self._initialized is False:
self._split_dataset()
self._initialized = True
return super().subsets()
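The two overrides implement a lazy-initialization pattern: the potentially expensive split runs only when a subset is first requested. A generic, self-contained illustration of the same pattern (not the library code):

class LazySplit:
    """Sketch of the lazy-split pattern used by the overrides above."""

    def __init__(self):
        self._initialized = False
        self._parts = None

    def _split_dataset(self):
        # stand-in for the expensive work deferred until first access
        self._parts = {"train": [0, 1], "test": [2]}

    def get_subset(self, name):
        if self._initialized is False:
            self._split_dataset()
            self._initialized = True
        return self._parts[name]

print(LazySplit().get_subset("train"))  # [0, 1]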


class DetectionSplit(_TaskSpecificSplit):
"""
@@ -545,26 +597,28 @@ def __init__(self, dataset, splits, seed=None):
@staticmethod
def _group_by_bbox_labels(dataset):
by_labels = dict()
unlabeled = []
for idx, item in enumerate(dataset):
bbox_anns = [a for a in item.annotations
if a.type == AnnotationType.bbox]
assert 0 < len(bbox_anns), \
"Expected more than one bbox annotation in the dataset"
if len(bbox_anns) == 0:
unlabeled.append(idx)
continue
for ann in bbox_anns:
label = getattr(ann, 'label', None)
if label not in by_labels:
by_labels[label] = [(idx, ann)]
else:
by_labels[label].append((idx, ann))
return by_labels
return by_labels, unlabeled
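A toy illustration of the new return shape, using plain ints for labels: images without boxes go to unlabeled instead of failing the old assertion.

from collections import defaultdict

# per-image lists of bbox label ids (toy stand-in for a detection dataset)
items = [[0, 1], [], [1]]

by_labels, unlabeled = defaultdict(list), []
for idx, bbox_labels in enumerate(items):
    if len(bbox_labels) == 0:
        unlabeled.append(idx)
        continue
    for label in bbox_labels:
        by_labels[label].append(idx)

print(dict(by_labels))  # {0: [0], 1: [0, 2]}
print(unlabeled)        # [1]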

def _split_dataset(self):
np.random.seed(self._seed)

subsets, sratio = self._snames, self._sratio

# 1. group by bbox label
by_labels = self._group_by_bbox_labels(self._extractor)
by_labels, unlabeled = self._group_by_bbox_labels(self._extractor)

# 2. group by attributes
required = self._get_required(sratio)
@@ -595,7 +649,11 @@ def _split_dataset(self):
n_combs = [len(v) for v in by_combinations]

# 3-1. initially count per-image GT samples
counts_all = {idx: dict() for idx in range(total)}
counts_all = {}
for idx in range(total):
if idx not in unlabeled:
counts_all[idx] = dict()

for idx_comb, indice in enumerate(by_combinations):
for idx in indice:
if idx_comb not in counts_all[idx]:
@@ -668,4 +726,8 @@ def update_nc(counts, n_combs):
by_splits[sname].append(idx)
update_nc(counts, nc)

# split unlabeled data
if len(unlabeled) > 0:
self._split_unlabeled(unlabeled, by_splits)

self._set_parts(by_splits)
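Putting it together for detection, a hedged usage sketch analogous to the updated detection test below (the Bbox signature and import paths are assumptions): unlabeled images now join the ratio-based split instead of tripping an assertion.

import datumaro.plugins.splitter as splitter
from datumaro.components.dataset import Dataset              # path assumed
from datumaro.components.extractor import Bbox, DatasetItem  # path assumed

items = [DatasetItem(i, annotations=[Bbox(1, 1, 2, 2, label=i % 2)])
         for i in range(10)]
items += [DatasetItem(i + 10, annotations=[]) for i in range(10)]
source = Dataset.from_iterable(items, categories=["a", "b"])

splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)]
actual = splitter.DetectionSplit(source, splits)
# sizes total 20 and track the ratio, roughly [10, 4, 6]
print([len(actual.get_subset(s)) for s in ("train", "val", "test")])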
91 changes: 45 additions & 46 deletions tests/test_splitter.py
@@ -236,29 +236,27 @@ def test_split_for_classification_zero_ratio(self):
self.assertEqual(4, len(actual.get_subset("val")))
self.assertEqual(0, len(actual.get_subset("test")))

def test_split_for_classification_gives_error(self):
def test_split_for_classification_unlabeled(self):
with self.subTest("no label"):
source = Dataset.from_iterable([
DatasetItem(1, annotations=[]),
DatasetItem(2, annotations=[]),
], categories=["a", "b", "c"])
iterable = [DatasetItem(i, annotations=[]) for i in range(10)]
source = Dataset.from_iterable(iterable, categories=["a", "b"])
splits = [("train", 0.7), ("test", 0.3)]
actual = splitter.ClassificationSplit(source, splits)

with self.assertRaisesRegex(Exception, "exactly one is expected"):
splits = [("train", 0.7), ("test", 0.3)]
actual = splitter.ClassificationSplit(source, splits)
len(actual.get_subset("train"))
self.assertEqual(7, len(actual.get_subset("train")))
self.assertEqual(3, len(actual.get_subset("test")))

with self.subTest("multi label"):
source = Dataset.from_iterable([
DatasetItem(1, annotations=[Label(0), Label(1)]),
DatasetItem(2, annotations=[Label(0), Label(2)]),
], categories=["a", "b", "c"])
anns = [Label(0), Label(1)]
iterable = [DatasetItem(i, annotations=anns) for i in range(10)]
source = Dataset.from_iterable(iterable, categories=["a", "b"])
splits = [("train", 0.7), ("test", 0.3)]
actual = splitter.ClassificationSplit(source, splits)

with self.assertRaisesRegex(Exception, "exactly one is expected"):
splits = [("train", 0.7), ("test", 0.3)]
splitter.ClassificationSplit(source, splits)
len(actual.get_subset("train"))
self.assertEqual(7, len(actual.get_subset("train")))
self.assertEqual(3, len(actual.get_subset("test")))

def test_split_for_classification_gives_error(self):
source = Dataset.from_iterable([
DatasetItem(1, annotations=[Label(0)]),
DatasetItem(2, annotations=[Label(1)]),
@@ -396,30 +394,27 @@ def test_split_for_reidentification_rebalance(self):
self.assertEqual(90, len(actual.get_subset("test-gallery")))
self.assertEqual(120, len(actual.get_subset("test-query")))

def test_split_for_reidentification_gives_error(self):
query = 0.4 / 0.7 # valid query ratio
def test_split_for_reidentification_unlabeled(self):
query = 0.5

with self.subTest("no label"):
source = Dataset.from_iterable([
DatasetItem(1, annotations=[]),
DatasetItem(2, annotations=[]),
], categories=["a", "b", "c"])
iterable = [DatasetItem(i, annotations=[]) for i in range(10)]
source = Dataset.from_iterable(iterable, categories=["a", "b"])
splits = [("train", 0.6), ("test", 0.4)]
actual = splitter.ReidentificationSplit(source, splits, query)
self.assertEqual(10, len(actual.get_subset("not-supported")))

with self.assertRaisesRegex(Exception, "exactly one is expected"):
splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)]
actual = splitter.ReidentificationSplit(source, splits, query)
len(actual.get_subset("train"))
with self.subTest("multi label"):
anns = [Label(0), Label(1)]
iterable = [DatasetItem(i, annotations=anns) for i in range(10)]
source = Dataset.from_iterable(iterable, categories=["a", "b"])
splits = [("train", 0.6), ("test", 0.4)]
actual = splitter.ReidentificationSplit(source, splits, query)

with self.subTest(msg="multi label"):
source = Dataset.from_iterable([
DatasetItem(1, annotations=[Label(0), Label(1)]),
DatasetItem(2, annotations=[Label(0), Label(2)]),
], categories=["a", "b", "c"])
self.assertEqual(10, len(actual.get_subset("not-supported")))

with self.assertRaisesRegex(Exception, "exactly one is expected"):
splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)]
actual = splitter.ReidentificationSplit(source, splits, query)
len(actual.get_subset("train"))
def test_split_for_reidentification_gives_error(self):
query = 0.4 / 0.7 # valid query ratio

counts = {i: (i % 3 + 1) * 7 for i in range(10)}
config = {"person": {"attrs": ["PID"], "counts": counts}}
@@ -638,18 +633,22 @@ def test_split_for_detection(self):
list(r1.get_subset("test")), list(r3.get_subset("test"))
)

def test_split_for_detection_gives_error(self):
with self.subTest(msg="bbox annotation"):
source = Dataset.from_iterable([
DatasetItem(1, annotations=[Label(0), Label(1)]),
DatasetItem(2, annotations=[Label(0), Label(2)]),
], categories=["a", "b", "c"])
def test_split_for_detection_with_unlabeled(self):
source, _ = self._generate_detection_dataset(
append_bbox=self._get_append_bbox("cvat"),
with_attr=True,
nimages=10,
)
for i in range(10):
source.put(DatasetItem(i + 10, annotations=[]))

with self.assertRaisesRegex(Exception, "more than one bbox"):
splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)]
actual = splitter.DetectionSplit(source, splits)
len(actual.get_subset("train"))
splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)]
actual = splitter.DetectionSplit(source, splits)
self.assertEqual(10, len(actual.get_subset("train")))
self.assertEqual(4, len(actual.get_subset("val")))
self.assertEqual(6, len(actual.get_subset("test")))

def test_split_for_detection_gives_error(self):
source, _ = self._generate_detection_dataset(
append_bbox=self._get_append_bbox("cvat"),
with_attr=True,