Commit

Fix merged dataset item filtering (#258)
* Add tests

* Fix xpathfilter transform

* Update changelog
Maxim Zhiltsov authored Jun 15, 2021
1 parent b86a6eb commit 11484ed
Showing 6 changed files with 91 additions and 13 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -24,6 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed
- Incorrect image layout on saving and a problem with encoding on loading (<https://github.com/openvinotoolkit/datumaro/pull/284>)
- An error when xpath filter is applied to the dataset or its subset (<https://github.com/openvinotoolkit/datumaro/issues/259>)

### Security
-
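For context, the changelog entry refers to the dataset-level filter API; a minimal sketch of the call that previously failed on merged datasets (it mirrors the new tests below; the public Dataset API is assumed):

```python
from datumaro.components.dataset import Dataset
from datumaro.components.extractor import DatasetItem

# Merge two sources, then apply an XPath item filter. Before this fix,
# filtering a merged dataset raised an error (issue #259).
dataset = Dataset.from_extractors(
    Dataset.from_iterable([DatasetItem(id=0, subset='train')]),
    Dataset.from_iterable([DatasetItem(id=1, subset='test')]),
)
dataset.filter('/item[id > 0]')  # keep only items whose id is greater than 0
assert len(dataset) == 1
```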
2 changes: 1 addition & 1 deletion datumaro/components/dataset.py
@@ -689,4 +689,4 @@ def eager_mode(new_mode=True, dataset: Dataset = None):
            Dataset._global_eager = new_mode
            yield
        finally:
            Dataset._global_eager = old_mode
            Dataset._global_eager = old_mode
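The dataset.py change only adds a trailing newline, but the surrounding eager_mode() helper shown in the hunk is worth illustrating; a hedged usage sketch based on the signature above (the exact eager-evaluation semantics are assumed):

```python
from datumaro.components.dataset import Dataset, eager_mode
from datumaro.components.extractor import DatasetItem

dataset = Dataset.from_iterable([DatasetItem(id=0, subset='train')])

# Inside the block, transforms on the given dataset are applied eagerly;
# the previous mode is restored by the finally clause shown in the hunk.
with eager_mode(True, dataset):
    dataset.filter('/item[id >= 0]')
```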
35 changes: 24 additions & 11 deletions datumaro/components/dataset_filter.py
@@ -213,17 +213,30 @@ def encode_annotation(cls, o, categories=None):
    def to_string(encoded_item):
        return ET.tostring(encoded_item, encoding='unicode', pretty_print=True)

def XPathDatasetFilter(extractor, xpath=None):
    if xpath is None:
        return extractor
    try:
        xpath = ET.XPath(xpath)
    except Exception:
        log.error("Failed to create XPath from expression '%s'", xpath)
        raise
    f = lambda item: bool(xpath(
        DatasetItemEncoder.encode(item, extractor.categories())))
    return extractor.select(f)
class XPathDatasetFilter(Transform):
    def __init__(self, extractor, xpath=None):
        super().__init__(extractor)

        if xpath is not None:
            try:
                xpath = ET.XPath(xpath)
            except Exception:
                log.error("Failed to create XPath from expression '%s'", xpath)
                raise

            self._f = lambda item: bool(xpath(
                DatasetItemEncoder.encode(item, extractor.categories())))
        else:
            self._f = None

    def __iter__(self):
        if self._f:
            if hasattr(self._extractor, 'select'):
                yield from self._extractor.select(self._f)
            else:
                yield from filter(self._f, self._extractor)
        else:
            yield from self._extractor

class XPathAnnotationsFilter(Transform):
    def __init__(self, extractor, xpath=None, remove_empty=False):
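The rewritten filter keeps the old fast path but no longer requires a select() method on the wrapped extractor, which is what merged sources lack; a minimal sketch of applying the transform directly (direct construction is assumed to mirror the class above):

```python
from datumaro.components.dataset import Dataset
from datumaro.components.dataset_filter import XPathDatasetFilter
from datumaro.components.extractor import DatasetItem

source = Dataset.from_iterable([
    DatasetItem(id=0, subset='train'),
    DatasetItem(id=1, subset='test'),
])

# Items are filtered lazily on iteration: select() is used when the wrapped
# extractor provides it, otherwise the plain filter() fallback applies.
filtered = XPathDatasetFilter(source, '/item[id > 0]')
assert len(list(filtered)) == 1
```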
8 changes: 8 additions & 0 deletions datumaro/components/extractor.py
@@ -715,6 +715,7 @@ def _find_sources_recursive(cls, path: str, ext: Optional[str],
                    break
        return sources


class Transform(Extractor):
    @staticmethod
    def wrap_item(item, **kwargs):
@@ -746,4 +747,11 @@ def __len__(self):
        return super().__len__()

    def transform_item(self, item: DatasetItem) -> DatasetItem:
        """
        Supposed to return a modified copy of the input item.
        Avoid changing and returning the input item, because it can lead to
        unexpected problems. wrap_item() can be used to simplify copying.
        """

        raise NotImplementedError()
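The new docstring spells out the contract for transform_item(); a hedged sketch of a custom transform that follows it, using wrap_item() to return a copy instead of mutating the input (the transform name and the id change are illustrative only):

```python
from datumaro.components.extractor import DatasetItem, Transform

class PrefixIds(Transform):
    """Illustrative transform: yields copies of items with a prefixed id."""

    def transform_item(self, item: DatasetItem) -> DatasetItem:
        # Return a modified copy; wrap_item() copies the item with the given
        # overrides instead of changing the input, as the docstring advises.
        return self.wrap_item(item, id='prefixed_' + item.id)
```

A dataset would then apply it with something like dataset.transform(PrefixIds), assuming the usual Transform plumbing in Dataset.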
1 change: 1 addition & 0 deletions tests/requirements.py
@@ -26,6 +26,7 @@ class Requirements:
    # GitHub issues (bugs)
    # https://github.com/openvinotoolkit/datumaro/issues
    DATUM_BUG_219 = "Return format is not uniform"
    DATUM_BUG_259 = "Dataset.filter fails on merged datasets"


class SkipMessages:
57 changes: 56 additions & 1 deletion tests/test_dataset.py
@@ -1,7 +1,7 @@
import numpy as np
import os
import os.path as osp

import numpy as np
from unittest import TestCase

from datumaro.components.dataset_filter import (
@@ -15,6 +15,7 @@
    LabelCategories, AnnotationType, Transform)
from datumaro.util.image import Image
from datumaro.util.test_utils import TestDir, compare_datasets

from .requirements import Requirements, mark_requirement


@@ -641,6 +642,60 @@ def test_loader():

        self.assertFalse(called)

    @mark_requirement(Requirements.DATUM_BUG_259)
    def test_can_filter_items(self):
        dataset = Dataset.from_iterable([
            DatasetItem(id=0, subset='train'),
            DatasetItem(id=1, subset='test'),
        ])

        dataset.filter('/item[id > 0]')

        self.assertEqual(1, len(dataset))

    @mark_requirement(Requirements.DATUM_BUG_259)
    def test_can_filter_annotations(self):
        dataset = Dataset.from_iterable([
            DatasetItem(id=0, subset='train', annotations=[Label(0), Label(1)]),
            DatasetItem(id=1, subset='val', annotations=[Label(2)]),
            DatasetItem(id=2, subset='test', annotations=[Label(0), Label(2)]),
        ], categories=['a', 'b', 'c'])

        dataset.filter('/item/annotation[label = "c"]',
            filter_annotations=True, remove_empty=True)

        self.assertEqual(2, len(dataset))

    @mark_requirement(Requirements.DATUM_BUG_259)
    def test_can_filter_items_in_merged_dataset(self):
        dataset = Dataset.from_extractors(
            Dataset.from_iterable([ DatasetItem(id=0, subset='train') ]),
            Dataset.from_iterable([ DatasetItem(id=1, subset='test') ]),
        )

        dataset.filter('/item[id > 0]')

        self.assertEqual(1, len(dataset))

    @mark_requirement(Requirements.DATUM_BUG_259)
    def test_can_filter_annotations_in_merged_dataset(self):
        dataset = Dataset.from_extractors(
            Dataset.from_iterable([
                DatasetItem(id=0, subset='train', annotations=[Label(0)]),
            ], categories=['a', 'b', 'c']),
            Dataset.from_iterable([
                DatasetItem(id=1, subset='val', annotations=[Label(1)]),
            ], categories=['a', 'b', 'c']),
            Dataset.from_iterable([
                DatasetItem(id=2, subset='test', annotations=[Label(2)]),
            ], categories=['a', 'b', 'c']),
        )

        dataset.filter('/item/annotation[label = "c"]',
            filter_annotations=True, remove_empty=True)

        self.assertEqual(1, len(dataset))


class DatasetItemTest(TestCase):
    @mark_requirement(Requirements.DATUM_GENERAL_REQ)
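As a quick reference alongside the new tests, a hedged sketch of the two filter modes exercised above: item-level XPath filtering versus annotation-level filtering with remove_empty (the values mirror the tests; the Dataset.filter signature is taken from them):

```python
from datumaro.components.dataset import Dataset
from datumaro.components.extractor import DatasetItem, Label

dataset = Dataset.from_iterable([
    DatasetItem(id=0, annotations=[Label(0), Label(2)]),
    DatasetItem(id=1, annotations=[Label(1)]),
], categories=['a', 'b', 'c'])

# Item-level mode keeps or drops whole items:
#   dataset.filter('/item[id > 0]')
# Annotation-level mode drops non-matching annotations inside each item and,
# with remove_empty=True, also removes items left without annotations:
dataset.filter('/item/annotation[label = "c"]',
    filter_annotations=True, remove_empty=True)

assert len(dataset) == 1  # only the item that had label "c" remains
```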
