diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0a89b9b83d..3a0cdd23b5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -19,9 +19,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
   ()
 - New command `describe-downloads` to print information about downloadable datasets
   ()
+- Detection for Cityscapes format
+  ()
+- Maximum recursion `--depth` parameter for the `detect-format` CLI command
+  ()

 ### Changed
-- TBD
+- `env.detect_dataset()` now returns a list of detected formats at all recursion levels
+  instead of just the lowest one
+  ()
+- Open Images: allowed to store annotation files in the root path as well
+  ()

 ### Deprecated
 - `--save-images` is replaced with `--save-media` in CLI and converter API
diff --git a/datumaro/cli/commands/detect_format.py b/datumaro/cli/commands/detect_format.py
index c84c6a2827..4251e43270 100644
--- a/datumaro/cli/commands/detect_format.py
+++ b/datumaro/cli/commands/detect_format.py
@@ -8,7 +8,7 @@
 from datumaro.cli.util.project import load_project
 from datumaro.components.environment import Environment
 from datumaro.components.errors import ProjectNotFoundError
-from datumaro.components.format_detection import RejectionReason, detect_dataset_format
+from datumaro.components.format_detection import RejectionReason
 from datumaro.util import dump_json_file
 from datumaro.util.scope import scope_add, scoped

@@ -53,6 +53,7 @@ def build_parser(parser_ctor=argparse.ArgumentParser):
         help="Path to which to save a JSON report describing detected "
         "and rejected formats. By default, no report is saved.",
     )
+    parser.add_argument("--depth", help="The maximum depth for recursive search (default: 2)")

     parser.set_defaults(command=detect_format_command)
     return parser
@@ -90,10 +91,9 @@ def rejection_callback(
             "message": human_message,
         }

-    detected_formats = detect_dataset_format(
-        ((format_name, importer.detect) for format_name, importer in env.importers.items.items()),
-        args.url,
-        rejection_callback=rejection_callback,
+    depth = 2 if not args.depth else int(args.depth)
+    detected_formats = env.detect_dataset(
+        args.url, rejection_callback=rejection_callback, depth=depth
     )

     report["detected_formats"] = detected_formats
diff --git a/datumaro/components/dataset.py b/datumaro/components/dataset.py
index f49b0ab695..d5309dadf1 100644
--- a/datumaro/components/dataset.py
+++ b/datumaro/components/dataset.py
@@ -4,7 +4,6 @@

 from __future__ import annotations

-import glob
 import inspect
 import logging as log
 import os
@@ -1227,7 +1226,7 @@ def import_from(
         return dataset

     @staticmethod
-    def detect(path: str, *, env: Optional[Environment] = None, depth: int = 1) -> str:
+    def detect(path: str, *, env: Optional[Environment] = None, depth: int = 2) -> str:
         """
         Attempts to detect dataset format of a given directory.

@@ -1247,21 +1246,13 @@ def detect(path: str, *, env: Optional[Environment] = None, depth: int = 1) -> s
         if depth < 0:
             raise ValueError("Depth cannot be less than zero")

-        for _ in range(depth + 1):
-            matches = env.detect_dataset(path)
-            if matches and len(matches) == 1:
-                return matches[0]
-
-            paths = glob.glob(osp.join(path, "*"))
-            path = "" if len(paths) != 1 else paths[0]
-            ignore_dirs = {"__MSOSX", "__MACOSX"}
-            if not osp.isdir(path) or osp.basename(path) in ignore_dirs:
-                break
-
+        matches = env.detect_dataset(path, depth=depth)
         if not matches:
             raise NoMatchingFormatsError()
-        if 1 < len(matches):
+        elif 1 < len(matches):
             raise MultipleFormatsMatchError(matches)
+        else:
+            return matches[0]


 @contextmanager
diff --git a/datumaro/components/environment.py b/datumaro/components/environment.py
index 0059dac05e..2f1fa83415 100644
--- a/datumaro/components/environment.py
+++ b/datumaro/components/environment.py
@@ -8,10 +8,10 @@
 import os.path as osp
 from functools import partial
 from inspect import isclass
-from typing import Callable, Dict, Generic, Iterable, Iterator, Optional, Type, TypeVar
+from typing import Callable, Dict, Generic, Iterable, Iterator, List, Optional, Type, TypeVar

 from datumaro.components.cli_plugin import CliPlugin, plugin_types
-from datumaro.components.format_detection import detect_dataset_format
+from datumaro.components.format_detection import RejectionReason, detect_dataset_format
 from datumaro.util.os_util import import_foreign_module, split_path

 T = TypeVar("T")
@@ -240,11 +240,32 @@ def make_transform(self, name, *args, **kwargs):
     def is_format_known(self, name):
         return name in self.importers or name in self.extractors

-    def detect_dataset(self, path):
-        return detect_dataset_format(
-            (
-                (format_name, importer.detect)
-                for format_name, importer in self.importers.items.items()
-            ),
-            path,
-        )
+    def detect_dataset(
+        self,
+        path: str,
+        depth: int = 1,
+        rejection_callback: Optional[Callable[[str, RejectionReason, str], None]] = None,
+    ) -> List[str]:
+        ignore_dirs = {"__MSOSX", "__MACOSX"}
+        matched_formats = set()
+        for _ in range(depth + 1):
+            detected_formats = detect_dataset_format(
+                (
+                    (format_name, importer.detect)
+                    for format_name, importer in self.importers.items.items()
+                ),
+                path,
+                rejection_callback=rejection_callback,
+            )
+
+            if detected_formats and len(detected_formats) == 1:
+                return detected_formats
+            elif detected_formats:
+                matched_formats |= set(detected_formats)
+
+            paths = glob.glob(osp.join(path, "*"))
+            path = "" if len(paths) != 1 else paths[0]
+            if not osp.isdir(path) or osp.basename(path) in ignore_dirs:
+                break
+
+        return list(matched_formats)
diff --git a/datumaro/plugins/cityscapes_format.py b/datumaro/plugins/cityscapes_format.py
index b67c096d43..cf2873ab59 100644
--- a/datumaro/plugins/cityscapes_format.py
+++ b/datumaro/plugins/cityscapes_format.py
@@ -22,6 +22,7 @@
 from datumaro.components.dataset import ItemStatus
 from datumaro.components.errors import MediaTypeError
 from datumaro.components.extractor import DatasetItem, Importer, SourceExtractor
+from datumaro.components.format_detection import FormatDetectionContext
 from datumaro.components.media import Image
 from datumaro.util import find
 from datumaro.util.annotation_util import make_label_id_mapping
@@ -299,6 +300,19 @@ def _lazy_extract_mask(mask, c):


 class CityscapesImporter(Importer):
+    @classmethod
+    def detect(cls, context: FormatDetectionContext) -> None:
+        patterns = [
+            f"{CityscapesPath.GT_FINE_DIR}/**/*{CityscapesPath.GT_INSTANCE_MASK_SUFFIX}",
f"{CityscapesPath.GT_FINE_DIR}/**/*{CityscapesPath.LABEL_TRAIN_IDS_SUFFIX}", + f"{CityscapesPath.IMGS_FINE_DIR}/{CityscapesPath.ORIGINAL_IMAGE_DIR}" + f"/**/*{CityscapesPath.ORIGINAL_IMAGE}.*", + ] + with context.require_any(): + for pattern in patterns: + with context.alternative(): + context.require_file(pattern) + @classmethod def find_sources(cls, path): sources = cls._find_sources_recursive( diff --git a/datumaro/plugins/lfw_format.py b/datumaro/plugins/lfw_format.py index 9a80ca81d3..6bf3d4a9b6 100644 --- a/datumaro/plugins/lfw_format.py +++ b/datumaro/plugins/lfw_format.py @@ -216,7 +216,7 @@ def get_image_name(person, image_id): class LfwImporter(Importer): @classmethod def detect(cls, context: FormatDetectionContext) -> None: - context.require_file(f"*/{LfwPath.ANNOTATION_DIR}/{LfwPath.PAIRS_FILE}") + context.require_file(f"{LfwPath.ANNOTATION_DIR}/{LfwPath.PAIRS_FILE}") @classmethod def find_sources(cls, path): diff --git a/datumaro/plugins/open_images_format.py b/datumaro/plugins/open_images_format.py index 12e1e7c768..f29617117f 100644 --- a/datumaro/plugins/open_images_format.py +++ b/datumaro/plugins/open_images_format.py @@ -175,7 +175,10 @@ def __init__(self, path, image_meta=None): self._dataset_dir = path - self._annotation_files = os.listdir(osp.join(path, OpenImagesPath.ANNOTATIONS_DIR)) + self._annotation_dir = osp.join(path, OpenImagesPath.ANNOTATIONS_DIR) + if not osp.exists(self._annotation_dir): + self._annotation_dir = path + self._annotation_files = os.listdir(self._annotation_dir) self._categories = {} self._items = [] @@ -188,7 +191,7 @@ def __init__(self, path, image_meta=None): elif image_meta is None: try: self._image_meta = load_image_meta_file( - osp.join(path, OpenImagesPath.ANNOTATIONS_DIR, DEFAULT_IMAGE_META_FILE_NAME) + osp.join(self._annotation_dir, DEFAULT_IMAGE_META_FILE_NAME) ) except FileNotFoundError: self._image_meta = {} @@ -209,7 +212,7 @@ def categories(self): @contextlib.contextmanager def _open_csv_annotation(self, file_name): - absolute_path = osp.join(self._dataset_dir, OpenImagesPath.ANNOTATIONS_DIR, file_name) + absolute_path = osp.join(self._annotation_dir, file_name) with open(absolute_path, "r", encoding="utf-8", newline="") as f: yield csv.DictReader(f) @@ -266,9 +269,7 @@ def _load_categories(self): def _load_label_category_parents(self): label_categories = self._categories[AnnotationType.label] - hierarchy_path = osp.join( - self._dataset_dir, OpenImagesPath.ANNOTATIONS_DIR, OpenImagesPath.HIERARCHY_FILE_NAME - ) + hierarchy_path = osp.join(self._annotation_dir, OpenImagesPath.HIERARCHY_FILE_NAME) try: root_node = parse_json_file(hierarchy_path) @@ -590,17 +591,20 @@ class OpenImagesImporter(Importer): @classmethod def detect(cls, context: FormatDetectionContext) -> None: + ann_dirs = [f"{OpenImagesPath.ANNOTATIONS_DIR}/", ""] + ann_patterns = itertools.product(ann_dirs, cls.POSSIBLE_ANNOTATION_PATTERNS) with context.require_any(): - for pattern in cls.POSSIBLE_ANNOTATION_PATTERNS: + for ann_dir, ann_file in ann_patterns: with context.alternative(): - context.require_file(f"{OpenImagesPath.ANNOTATIONS_DIR}/{pattern}") + context.require_file(ann_dir + ann_file) @classmethod def find_sources(cls, path): for pattern in cls.POSSIBLE_ANNOTATION_PATTERNS: if glob.glob(osp.join(glob.escape(path), OpenImagesPath.ANNOTATIONS_DIR, pattern)): return [{"url": path, "format": OpenImagesExtractor.NAME}] - + elif glob.glob(osp.join(glob.escape(path), pattern)): + return [{"url": path, "format": OpenImagesExtractor.NAME}] return [] diff --git 
a/site/content/en/docs/formats/open_images.md b/site/content/en/docs/formats/open_images.md index 2b2f3408ce..01cafd1c07 100644 --- a/site/content/en/docs/formats/open_images.md +++ b/site/content/en/docs/formats/open_images.md @@ -191,6 +191,8 @@ The mask images must be extracted from the ZIP archives linked above. To use per-subset image description files instead of `image_ids_and_rotation.csv`, place them in the `annotations` subdirectory. +The `annotations` directory is optional and you can store all annotation files +in the root of input path. To add custom classes, you can use [`dataset_meta.json`](/docs/user-manual/supported_formats/#dataset-meta-file). diff --git a/tests/cli/test_detect_format.py b/tests/cli/test_detect_format.py index cbd6b2836f..f7ed1fb1c6 100644 --- a/tests/cli/test_detect_format.py +++ b/tests/cli/test_detect_format.py @@ -9,6 +9,7 @@ from datumaro.plugins.ade20k2017_format import Ade20k2017Importer from datumaro.plugins.ade20k2020_format import Ade20k2020Importer from datumaro.plugins.image_dir_format import ImageDirImporter +from datumaro.plugins.lfw_format import LfwImporter from datumaro.util.os_util import suppress_output from datumaro.util.test_utils import TestDir from datumaro.util.test_utils import run_datum as run @@ -17,6 +18,7 @@ ADE20K2017_DIR = osp.join(osp.dirname(__file__), "../assets/ade20k2017_dataset/dataset") ADE20K2020_DIR = osp.join(osp.dirname(__file__), "../assets/ade20k2020_dataset/dataset") +LFW_DIR = osp.join(osp.dirname(__file__), "../assets/lfw_dataset") class DetectFormatTest(TestCase): @@ -32,6 +34,38 @@ def test_unambiguous(self): self.assertIn(Ade20k2017Importer.NAME, output) self.assertNotIn(Ade20k2020Importer.NAME, output) + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_deep_nested_folders(self): + with TestDir() as test_dir: + output_file = io.StringIO() + + annotation_dir = osp.join(test_dir, "a", "b", "c", "annotations") + os.makedirs(annotation_dir) + shutil.copy(osp.join(LFW_DIR, "test", "annotations", "pairs.txt"), annotation_dir) + + with contextlib.redirect_stdout(output_file): + run(self, "detect-format", test_dir, "--depth", "3") + + output = output_file.getvalue() + + self.assertIn(LfwImporter.NAME, output) + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_nested_folders(self): + with TestDir() as test_dir: + output_file = io.StringIO() + + annotation_dir = osp.join(test_dir, "a", "training/street") + os.makedirs(annotation_dir) + shutil.copy(osp.join(ADE20K2020_DIR, "training/street/1.json"), annotation_dir) + + with contextlib.redirect_stdout(output_file): + run(self, "detect-format", test_dir) + + output = output_file.getvalue() + + self.assertIn(Ade20k2020Importer.NAME, output) + @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_ambiguous(self): with TestDir() as test_dir: diff --git a/tests/test_cityscapes_format.py b/tests/test_cityscapes_format.py index ff00dbcfbe..11ac4766b5 100644 --- a/tests/test_cityscapes_format.py +++ b/tests/test_cityscapes_format.py @@ -184,7 +184,7 @@ def test_can_import_with_train_label_map(self): @mark_requirement(Requirements.DATUM_267) def test_can_detect_cityscapes(self): detected_formats = Environment().detect_dataset(DUMMY_DATASET_DIR) - self.assertIn(CityscapesImporter.NAME, detected_formats) + self.assertEquals([CityscapesImporter.NAME], detected_formats) class TestExtractorBase(Extractor):
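
A minimal usage sketch of the reworked detection API above; the dataset path, the callback body, and the printed output are illustrative only, while `Environment.detect_dataset` (its `depth` and `rejection_callback` parameters and its list return value) is taken from this diff:

    from datumaro.components.environment import Environment
    from datumaro.components.format_detection import RejectionReason


    def on_rejection(format_name: str, reason: RejectionReason, human_message: str) -> None:
        # Same callback signature as used by the detect-format command in this diff.
        print(f"{format_name} was rejected ({reason}): {human_message}")


    # Hypothetical dataset root; with depth=2, up to two directory levels below it are probed.
    detected = Environment().detect_dataset(
        "datasets/my_dataset", depth=2, rejection_callback=on_rejection
    )
    print(detected)  # list of matching format names gathered across all probed levels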
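
For reference, the layouts targeted by the detection changes (directory names other than `annotations`, `pairs.txt`, and `image_ids_and_rotation.csv` are hypothetical): LFW is now matched by `annotations/pairs.txt` at whatever level the recursive search reaches, as in the new `test_deep_nested_folders` test, and Open Images annotation files may sit either in `annotations/` or directly in the dataset root.

    lfw_like_dataset/
    └── a/b/c/
        └── annotations/
            └── pairs.txt            # found with: datum detect-format lfw_like_dataset --depth 3

    open_images_dataset/
    └── image_ids_and_rotation.csv   # previously required to be under annotations/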