From 7501a22ff9f133b841457902dd4ce3b91e1cbd92 Mon Sep 17 00:00:00 2001
From: Maxim Zhiltsov
Date: Mon, 31 May 2021 15:12:51 +0300
Subject: [PATCH 1/4] Add tests

---
 tests/test_dataset.py | 50 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 50 insertions(+)

diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index cb8a776c02..3bf6ad25ff 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -608,6 +608,56 @@ def test_loader():

         self.assertFalse(called)

+    def test_can_filter_items(self):
+        dataset = Dataset.from_iterable([
+            DatasetItem(id=0, subset='train'),
+            DatasetItem(id=1, subset='test'),
+        ])
+
+        dataset.filter('/item[id > 0]')
+
+        self.assertEqual(1, len(dataset))
+
+    def test_can_filter_annotations(self):
+        dataset = Dataset.from_iterable([
+            DatasetItem(id=0, subset='train', annotations=[Label(0), Label(1)]),
+            DatasetItem(id=1, subset='val', annotations=[Label(2)]),
+            DatasetItem(id=2, subset='test', annotations=[Label(0), Label(2)]),
+        ], categories=['a', 'b', 'c'])
+
+        dataset.filter('/item/annotation[label = "c"]',
+            filter_annotations=True, remove_empty=True)
+
+        self.assertEqual(2, len(dataset))
+
+    def test_can_filter_items_in_merged_dataset(self):
+        dataset = Dataset.from_extractors(
+            Dataset.from_iterable([ DatasetItem(id=0, subset='train') ]),
+            Dataset.from_iterable([ DatasetItem(id=1, subset='test') ]),
+        )
+
+        dataset.filter('/item[id > 0]')
+
+        self.assertEqual(1, len(dataset))
+
+    def test_can_filter_annotations_in_merged_dataset(self):
+        dataset = Dataset.from_extractors(
+            Dataset.from_iterable([
+                DatasetItem(id=0, subset='train', annotations=[Label(0)]),
+            ], categories=['a', 'b', 'c']),
+            Dataset.from_iterable([
+                DatasetItem(id=1, subset='val', annotations=[Label(1)]),
+            ], categories=['a', 'b', 'c']),
+            Dataset.from_iterable([
+                DatasetItem(id=2, subset='test', annotations=[Label(2)]),
+            ], categories=['a', 'b', 'c']),
+        )
+
+        dataset.filter('/item/annotation[label = "c"]',
+            filter_annotations=True, remove_empty=True)
+
+        self.assertEqual(1, len(dataset))
+

 class DatasetItemTest(TestCase):
     def test_ctor_requires_id(self):

From 1d70a2682d8b886a47c2ef966f9959420720895e Mon Sep 17 00:00:00 2001
From: Maxim Zhiltsov
Date: Tue, 15 Jun 2021 12:34:34 +0300
Subject: [PATCH 2/4] Fix xpathfilter transform

---
 datumaro/components/dataset.py        |  2 +-
 datumaro/components/dataset_filter.py | 37 +++++++++++++++++++--------
 datumaro/components/extractor.py      |  8 ++++++
 tests/requirements.py                 |  1 +
 tests/test_dataset.py                 |  7 ++++-
 5 files changed, 42 insertions(+), 13 deletions(-)

diff --git a/datumaro/components/dataset.py b/datumaro/components/dataset.py
index e8e1b1d813..20e46ed8fa 100644
--- a/datumaro/components/dataset.py
+++ b/datumaro/components/dataset.py
@@ -686,4 +686,4 @@ def eager_mode(new_mode=True, dataset: Dataset = None):
         Dataset._global_eager = new_mode
         yield
     finally:
-        Dataset._global_eager = old_mode
\ No newline at end of file
+        Dataset._global_eager = old_mode

diff --git a/datumaro/components/dataset_filter.py b/datumaro/components/dataset_filter.py
index 2fe1443d51..21fe84ce56 100644
--- a/datumaro/components/dataset_filter.py
+++ b/datumaro/components/dataset_filter.py
@@ -213,17 +213,32 @@ def encode_annotation(cls, o, categories=None):
     def to_string(encoded_item):
         return ET.tostring(encoded_item, encoding='unicode', pretty_print=True)

-def XPathDatasetFilter(extractor, xpath=None):
-    if xpath is None:
-        return extractor
-    try:
-        xpath = ET.XPath(xpath)
-    except Exception:
-        log.error("Failed to create XPath from expression '%s'", xpath)
-        raise
-    f = lambda item: bool(xpath(
-        DatasetItemEncoder.encode(item, extractor.categories())))
-    return extractor.select(f)
+class XPathDatasetFilter(Transform):
+    def __init__(self, extractor, xpath=None):
+        super().__init__(extractor)
+
+        if xpath is not None:
+            try:
+                xpath = ET.XPath(xpath)
+            except Exception:
+                log.error("Failed to create XPath from expression '%s'", xpath)
+                raise
+
+            self._f = lambda item: bool(xpath(
+                DatasetItemEncoder.encode(item, extractor.categories())))
+        else:
+            self._f = None
+
+    def __iter__(self):
+        if self._f:
+            if hasattr(self._extractor, 'select'):
+                yield from self._extractor.select(self._f)
+            else:
+                for item in self._extractor:
+                    if self._f(item):
+                        yield item
+        else:
+            yield from self._extractor

 class XPathAnnotationsFilter(Transform):
     def __init__(self, extractor, xpath=None, remove_empty=False):

diff --git a/datumaro/components/extractor.py b/datumaro/components/extractor.py
index 708a6f544e..46474ca316 100644
--- a/datumaro/components/extractor.py
+++ b/datumaro/components/extractor.py
@@ -717,6 +717,7 @@ def _find_sources_recursive(cls, path: str, ext: Optional[str],
                     break
         return sources

+
 class Transform(Extractor):
     @staticmethod
     def wrap_item(item, **kwargs):
@@ -748,4 +749,11 @@ def __len__(self):
         return super().__len__()

     def transform_item(self, item: DatasetItem) -> DatasetItem:
+        """
+        Supposed to return a modified copy of the input item.
+
+        Avoid changing and returning the input item, because it can lead to
+        unexpected problems. wrap_item() can be used to simplify copying.
+        """
+
         raise NotImplementedError()

diff --git a/tests/requirements.py b/tests/requirements.py
index 6be36ef676..37e9a9722b 100644
--- a/tests/requirements.py
+++ b/tests/requirements.py
@@ -26,6 +26,7 @@ class Requirements:
     # GitHub issues (bugs)
     # https://github.com/openvinotoolkit/datumaro/issues
     DATUM_BUG_219 = "Return format is not uniform"
+    DATUM_BUG_259 = "Dataset.filter fails on merged datasets"


 class SkipMessages:

diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index f8443a7e9a..a221c01f74 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -1,7 +1,7 @@
-import numpy as np
 import os
 import os.path as osp

+import numpy as np
 from unittest import TestCase

 from datumaro.components.dataset_filter import (
@@ -15,6 +15,7 @@
     LabelCategories, AnnotationType, Transform)
 from datumaro.util.image import Image
 from datumaro.util.test_utils import TestDir, compare_datasets
+
 from .requirements import Requirements, mark_requirement


@@ -641,6 +642,7 @@ def test_loader():

         self.assertFalse(called)

+    @mark_requirement(Requirements.DATUM_BUG_259)
     def test_can_filter_items(self):
         dataset = Dataset.from_iterable([
             DatasetItem(id=0, subset='train'),
@@ -651,6 +653,7 @@ def test_loader():

         self.assertEqual(1, len(dataset))

+    @mark_requirement(Requirements.DATUM_BUG_259)
     def test_can_filter_annotations(self):
         dataset = Dataset.from_iterable([
             DatasetItem(id=0, subset='train', annotations=[Label(0), Label(1)]),
@@ -663,6 +666,7 @@ def test_loader():

         self.assertEqual(2, len(dataset))

+    @mark_requirement(Requirements.DATUM_BUG_259)
     def test_can_filter_items_in_merged_dataset(self):
         dataset = Dataset.from_extractors(
             Dataset.from_iterable([ DatasetItem(id=0, subset='train') ]),
@@ -673,6 +677,7 @@ def test_loader():

         self.assertEqual(1, len(dataset))

+    @mark_requirement(Requirements.DATUM_BUG_259)
     def test_can_filter_annotations_in_merged_dataset(self):
         dataset = Dataset.from_extractors(
             Dataset.from_iterable([

From 9ea854a2935eeb75a37796c102d29cd4a4b2b026 Mon Sep 17 00:00:00 2001
From: Maxim Zhiltsov
Date: Tue, 15 Jun 2021 12:37:10 +0300
Subject: [PATCH 3/4] Update changelog

---
 CHANGELOG.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 56401a9ffe..2c07fa92e2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### Fixed
 - Incorrect image layout on saving and a problem with encoding on loading ()
+- An error when XPath filter is applied to the dataset or its subset ()

 ### Security
 -

From 9bce132c5866c17fa4b177f3dab4ddadace75731 Mon Sep 17 00:00:00 2001
From: Maxim Zhiltsov
Date: Tue, 15 Jun 2021 12:50:05 +0300
Subject: [PATCH 4/4] Use filter function

---
 datumaro/components/dataset_filter.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/datumaro/components/dataset_filter.py b/datumaro/components/dataset_filter.py
index 21fe84ce56..68c30867ab 100644
--- a/datumaro/components/dataset_filter.py
+++ b/datumaro/components/dataset_filter.py
@@ -234,9 +234,7 @@ def __iter__(self):
         if self._f:
             if hasattr(self._extractor, 'select'):
                 yield from self._extractor.select(self._f)
             else:
-                for item in self._extractor:
-                    if self._f(item):
-                        yield item
+                yield from filter(self._f, self._extractor)
         else:
             yield from self._extractor
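
For reference, a minimal usage sketch of the behaviour covered by the new tests (DATUM_BUG_259):
applying an XPath filter to a dataset built from several extractors. The import paths and the
category names are assumptions based on the test code above, not something introduced by the
patches themselves.

    from datumaro.components.dataset import Dataset
    from datumaro.components.extractor import DatasetItem, Label

    # Two sources merged into one dataset, mirroring
    # test_can_filter_annotations_in_merged_dataset.
    dataset = Dataset.from_extractors(
        Dataset.from_iterable([
            DatasetItem(id=0, subset='train', annotations=[Label(0)]),
        ], categories=['a', 'b']),
        Dataset.from_iterable([
            DatasetItem(id=1, subset='test', annotations=[Label(1)]),
        ], categories=['a', 'b']),
    )

    # Keep only 'b' annotations and drop items left without annotations.
    # Before the fix this failed on merged datasets, because XPathDatasetFilter
    # was a plain function that relied on the extractor's select() method.
    dataset.filter('/item/annotation[label = "b"]',
        filter_annotations=True, remove_empty=True)

    assert len(dataset) == 1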