Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix merged dataset item filtering #258

Merged
merged 5 commits into from
Jun 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed
- Incorrect image layout on saving and a problem with encoding on loading (<https://github.com/openvinotoolkit/datumaro/pull/284>)
- An error when an XPath filter is applied to the dataset or its subset (<https://github.com/openvinotoolkit/datumaro/issues/259>)

### Security
-
Expand Down
2 changes: 1 addition & 1 deletion datumaro/components/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,4 +686,4 @@ def eager_mode(new_mode=True, dataset: Dataset = None):
Dataset._global_eager = new_mode
yield
finally:
Dataset._global_eager = old_mode
Dataset._global_eager = old_mode
35 changes: 24 additions & 11 deletions datumaro/components/dataset_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,17 +213,30 @@ def encode_annotation(cls, o, categories=None):
def to_string(encoded_item):
return ET.tostring(encoded_item, encoding='unicode', pretty_print=True)

def XPathDatasetFilter(extractor, xpath=None):
    """
    Selects the items of an extractor that match an XPath expression.

    When 'xpath' is None, no filtering is requested and the extractor
    is returned unchanged. Otherwise the expression is compiled with
    ET.XPath and the result of extractor.select() is returned, using a
    predicate that keeps an item when the compiled expression gives a
    truthy result on the encoded item.

    A failure to compile the expression is logged and the original
    exception is re-raised.
    """

    if xpath is None:
        return extractor

    try:
        compiled = ET.XPath(xpath)
    except Exception:
        log.error("Failed to create XPath from expression '%s'", xpath)
        raise

    def matches(item):
        # Categories are requested for every item, they are not cached here
        encoded = DatasetItemEncoder.encode(item, extractor.categories())
        return bool(compiled(encoded))

    return extractor.select(matches)
class XPathDatasetFilter(Transform):
    """
    A transform that yields only the items matching an XPath expression.

    The expression is evaluated against the encoded form of each item
    (see DatasetItemEncoder.encode). When no expression is given,
    the wrapped extractor is iterated as is.
    """

    def __init__(self, extractor, xpath=None):
        """
        extractor - the source of items; its categories() are passed
            to the item encoder
        xpath - an XPath expression string, or None to disable filtering

        A failure to compile the expression is logged and the original
        exception is re-raised.
        """

        super().__init__(extractor)

        if xpath is None:
            # No expression - __iter__ passes every item through
            self._f = None
            return

        try:
            compiled = ET.XPath(xpath)
        except Exception:
            log.error("Failed to create XPath from expression '%s'", xpath)
            raise

        def _matches(item):
            # Categories are requested for every item, they are not cached here
            encoded = DatasetItemEncoder.encode(item, extractor.categories())
            return bool(compiled(encoded))

        self._f = _matches

    def __iter__(self):
        # NOTE(review): self._extractor is presumably set by
        # Transform.__init__ - confirm against the base class
        predicate = self._f

        if not predicate:
            yield from self._extractor
        elif hasattr(self._extractor, 'select'):
            # Delegate to the source's own selection when it offers one
            yield from self._extractor.select(predicate)
        else:
            yield from filter(predicate, self._extractor)

class XPathAnnotationsFilter(Transform):
def __init__(self, extractor, xpath=None, remove_empty=False):
Expand Down
8 changes: 8 additions & 0 deletions datumaro/components/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,7 @@ def _find_sources_recursive(cls, path: str, ext: Optional[str],
break
return sources


class Transform(Extractor):
@staticmethod
def wrap_item(item, **kwargs):
Expand Down Expand Up @@ -748,4 +749,11 @@ def __len__(self):
return super().__len__()

def transform_item(self, item: DatasetItem) -> DatasetItem:
    """
    Expected to produce a modified copy of the given item.

    The input item should not be modified in place and returned,
    as this can cause unexpected problems. Copying can be simplified
    with wrap_item().

    This implementation always raises NotImplementedError.
    """

    raise NotImplementedError
1 change: 1 addition & 0 deletions tests/requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class Requirements:
# GitHub issues (bugs)
# https://github.com/openvinotoolkit/datumaro/issues
DATUM_BUG_219 = "Return format is not uniform"
DATUM_BUG_259 = "Dataset.filter fails on merged datasets"


class SkipMessages:
Expand Down
57 changes: 56 additions & 1 deletion tests/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import os
import os.path as osp

import numpy as np
from unittest import TestCase

from datumaro.components.dataset_filter import (
Expand All @@ -15,6 +15,7 @@
LabelCategories, AnnotationType, Transform)
from datumaro.util.image import Image
from datumaro.util.test_utils import TestDir, compare_datasets

from .requirements import Requirements, mark_requirement


Expand Down Expand Up @@ -641,6 +642,60 @@ def test_loader():

self.assertFalse(called)

@mark_requirement(Requirements.DATUM_BUG_259)
def test_can_filter_items(self):
    # Two items in different subsets; the expression below is
    # expected to leave exactly one of them
    source_items = [
        DatasetItem(id=0, subset='train'),
        DatasetItem(id=1, subset='test'),
    ]
    filtered = Dataset.from_iterable(source_items)

    # The result is checked on the same object, so the call is
    # relied upon to update the dataset itself
    filtered.filter('/item[id > 0]')

    self.assertEqual(1, len(filtered))

@mark_requirement(Requirements.DATUM_BUG_259)
def test_can_filter_annotations(self):
    # Presumably Label(2) corresponds to the 'c' category, so the item
    # with id=0 is the only one without a matching annotation
    source_items = [
        DatasetItem(id=0, subset='train', annotations=[Label(0), Label(1)]),
        DatasetItem(id=1, subset='val', annotations=[Label(2)]),
        DatasetItem(id=2, subset='test', annotations=[Label(0), Label(2)]),
    ]
    filtered = Dataset.from_iterable(source_items,
        categories=['a', 'b', 'c'])

    filtered.filter('/item/annotation[label = "c"]',
        filter_annotations=True, remove_empty=True)

    # Two items are expected to remain in the dataset
    self.assertEqual(2, len(filtered))

@mark_requirement(Requirements.DATUM_BUG_259)
def test_can_filter_items_in_merged_dataset(self):
    # Same scenario as for a plain dataset, but the items come from
    # two separate datasets combined with from_extractors()
    train_part = Dataset.from_iterable([ DatasetItem(id=0, subset='train') ])
    test_part = Dataset.from_iterable([ DatasetItem(id=1, subset='test') ])
    merged = Dataset.from_extractors(train_part, test_part)

    merged.filter('/item[id > 0]')

    self.assertEqual(1, len(merged))

@mark_requirement(Requirements.DATUM_BUG_259)
def test_can_filter_annotations_in_merged_dataset(self):
    # (item id, subset, label index) for each single-item source dataset
    item_specs = [
        (0, 'train', 0),
        (1, 'val', 1),
        (2, 'test', 2),
    ]

    # Every source gets its own copy of the same category list
    sources = [
        Dataset.from_iterable([
            DatasetItem(id=item_id, subset=subset,
                annotations=[Label(label)]),
        ], categories=['a', 'b', 'c'])
        for item_id, subset, label in item_specs
    ]
    merged = Dataset.from_extractors(*sources)

    merged.filter('/item/annotation[label = "c"]',
        filter_annotations=True, remove_empty=True)

    # Only one item is expected to remain in the dataset
    self.assertEqual(1, len(merged))


class DatasetItemTest(TestCase):
@mark_requirement(Requirements.DATUM_GENERAL_REQ)
Expand Down