From 11484ed3ab52c31f10e7b70ba4065c1cdf943ab0 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Tue, 15 Jun 2021 12:52:57 +0300 Subject: [PATCH] Fix merged dataset item filtering (#258) * Add tests * Fix xpathfilter transform * Update changelog --- CHANGELOG.md | 1 + datumaro/components/dataset.py | 2 +- datumaro/components/dataset_filter.py | 35 ++++++++++------ datumaro/components/extractor.py | 8 ++++ tests/requirements.py | 1 + tests/test_dataset.py | 57 ++++++++++++++++++++++++++- 6 files changed, 91 insertions(+), 13 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 00fbbcee2b..32c27abf99 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Incorrect image layout on saving and a problem with ecoding on loading () +- An error when xpath filter is applied to the dataset or its subset () ### Security - diff --git a/datumaro/components/dataset.py b/datumaro/components/dataset.py index 2e9313200f..ea82cc6fc9 100644 --- a/datumaro/components/dataset.py +++ b/datumaro/components/dataset.py @@ -689,4 +689,4 @@ def eager_mode(new_mode=True, dataset: Dataset = None): Dataset._global_eager = new_mode yield finally: - Dataset._global_eager = old_mode \ No newline at end of file + Dataset._global_eager = old_mode diff --git a/datumaro/components/dataset_filter.py b/datumaro/components/dataset_filter.py index 2fe1443d51..68c30867ab 100644 --- a/datumaro/components/dataset_filter.py +++ b/datumaro/components/dataset_filter.py @@ -213,17 +213,30 @@ def encode_annotation(cls, o, categories=None): def to_string(encoded_item): return ET.tostring(encoded_item, encoding='unicode', pretty_print=True) -def XPathDatasetFilter(extractor, xpath=None): - if xpath is None: - return extractor - try: - xpath = ET.XPath(xpath) - except Exception: - log.error("Failed to create XPath from expression '%s'", xpath) - raise - f = lambda item: bool(xpath( - 
DatasetItemEncoder.encode(item, extractor.categories()))) - return extractor.select(f) +class XPathDatasetFilter(Transform): + def __init__(self, extractor, xpath=None): + super().__init__(extractor) + + if xpath is not None: + try: + xpath = ET.XPath(xpath) + except Exception: + log.error("Failed to create XPath from expression '%s'", xpath) + raise + + self._f = lambda item: bool(xpath( + DatasetItemEncoder.encode(item, extractor.categories()))) + else: + self._f = None + + def __iter__(self): + if self._f: + if hasattr(self._extractor, 'select'): + yield from self._extractor.select(self._f) + else: + yield from filter(self._f, self._extractor) + else: + yield from self._extractor class XPathAnnotationsFilter(Transform): def __init__(self, extractor, xpath=None, remove_empty=False): diff --git a/datumaro/components/extractor.py b/datumaro/components/extractor.py index 52c50a31d4..91a5654f0b 100644 --- a/datumaro/components/extractor.py +++ b/datumaro/components/extractor.py @@ -715,6 +715,7 @@ def _find_sources_recursive(cls, path: str, ext: Optional[str], break return sources + class Transform(Extractor): @staticmethod def wrap_item(item, **kwargs): @@ -746,4 +747,11 @@ def __len__(self): return super().__len__() def transform_item(self, item: DatasetItem) -> DatasetItem: + """ + Supposed to return a modified copy of the input item. + + Avoid changing and returning the input item, because it can lead to + unexpected problems. wrap_item() can be used to simplify copying. 
+ """ + raise NotImplementedError() diff --git a/tests/requirements.py b/tests/requirements.py index 6be36ef676..37e9a9722b 100644 --- a/tests/requirements.py +++ b/tests/requirements.py @@ -26,6 +26,7 @@ class Requirements: # GitHub issues (bugs) # https://github.com/openvinotoolkit/datumaro/issues DATUM_BUG_219 = "Return format is not uniform" + DATUM_BUG_259 = "Dataset.filter fails on merged datasets" class SkipMessages: diff --git a/tests/test_dataset.py b/tests/test_dataset.py index f8f7f0a085..a221c01f74 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1,7 +1,7 @@ -import numpy as np import os import os.path as osp +import numpy as np from unittest import TestCase from datumaro.components.dataset_filter import ( @@ -15,6 +15,7 @@ LabelCategories, AnnotationType, Transform) from datumaro.util.image import Image from datumaro.util.test_utils import TestDir, compare_datasets + from .requirements import Requirements, mark_requirement @@ -641,6 +642,60 @@ def test_loader(): self.assertFalse(called) + @mark_requirement(Requirements.DATUM_BUG_259) + def test_can_filter_items(self): + dataset = Dataset.from_iterable([ + DatasetItem(id=0, subset='train'), + DatasetItem(id=1, subset='test'), + ]) + + dataset.filter('/item[id > 0]') + + self.assertEqual(1, len(dataset)) + + @mark_requirement(Requirements.DATUM_BUG_259) + def test_can_filter_annotations(self): + dataset = Dataset.from_iterable([ + DatasetItem(id=0, subset='train', annotations=[Label(0), Label(1)]), + DatasetItem(id=1, subset='val', annotations=[Label(2)]), + DatasetItem(id=2, subset='test', annotations=[Label(0), Label(2)]), + ], categories=['a', 'b', 'c']) + + dataset.filter('/item/annotation[label = "c"]', + filter_annotations=True, remove_empty=True) + + self.assertEqual(2, len(dataset)) + + @mark_requirement(Requirements.DATUM_BUG_259) + def test_can_filter_items_in_merged_dataset(self): + dataset = Dataset.from_extractors( + Dataset.from_iterable([ DatasetItem(id=0, subset='train') 
]), + Dataset.from_iterable([ DatasetItem(id=1, subset='test') ]), + ) + + dataset.filter('/item[id > 0]') + + self.assertEqual(1, len(dataset)) + + @mark_requirement(Requirements.DATUM_BUG_259) + def test_can_filter_annotations_in_merged_dataset(self): + dataset = Dataset.from_extractors( + Dataset.from_iterable([ + DatasetItem(id=0, subset='train', annotations=[Label(0)]), + ], categories=['a', 'b', 'c']), + Dataset.from_iterable([ + DatasetItem(id=1, subset='val', annotations=[Label(1)]), + ], categories=['a', 'b', 'c']), + Dataset.from_iterable([ + DatasetItem(id=2, subset='test', annotations=[Label(2)]), + ], categories=['a', 'b', 'c']), + ) + + dataset.filter('/item/annotation[label = "c"]', + filter_annotations=True, remove_empty=True) + + self.assertEqual(1, len(dataset)) + class DatasetItemTest(TestCase): @mark_requirement(Requirements.DATUM_GENERAL_REQ)